Pre-load CP heads + GPU decoder test (reverted) + headArgmaxOffset

- Pre-load all 15 CP heads at first CP call (eliminates lazy-load lag)
- Tested BigVGAN on GPU Adreno: no gain (+300ms vs CPU), kept on CPU
- Added headArgmaxOffset for future batch optimization
- Cancel previous pipeline on new run_pipeline intent

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 09:57:01 +02:00
parent 6e6c562d53
commit fb6045a635
3 changed files with 44 additions and 9 deletions

View File

@ -9,4 +9,7 @@ object NeonOps {
/** Batch argmax for all heads at once (avoids JNI overhead per head). */ /** Batch argmax for all heads at once (avoids JNI overhead per head). */
external fun headArgmaxBatch(hidden: FloatArray, allHeads: FloatArray, numHeads: Int, vocab: Int, dim: Int): IntArray external fun headArgmaxBatch(hidden: FloatArray, allHeads: FloatArray, numHeads: Int, vocab: Int, dim: Int): IntArray
/** Argmax with offset into a large weight buffer (avoids array copy). */
external fun headArgmaxOffset(hidden: FloatArray, allHeads: FloatArray, offset: Int, vocab: Int, dim: Int): Int
} }

View File

@ -127,6 +127,7 @@ class Qwen3TtsEngine(
private var cpRotarySin: FloatArray? = null private var cpRotarySin: FloatArray? = null
private var cpHeadsPath: String? = null // path dir for head_0..14.npy private var cpHeadsPath: String? = null // path dir for head_0..14.npy
private var cpHeadsCache: Array<FloatArray?>? = null // lazy-loaded heads cache (8MB each) private var cpHeadsCache: Array<FloatArray?>? = null // lazy-loaded heads cache (8MB each)
private var cpAllHeads: FloatArray? = null // all 15 heads concatenated [15*2048*1024] for batch NEON
private var cpPteModule: org.pytorch.executorch.Module? = null // ExecuTorch CP on NPU (JNI) private var cpPteModule: org.pytorch.executorch.Module? = null // ExecuTorch CP on NPU (JNI)
private var talkerPteModule: org.pytorch.executorch.Module? = null // ExecuTorch talker on NPU (JNI) private var talkerPteModule: org.pytorch.executorch.Module? = null // ExecuTorch talker on NPU (JNI)
private var talkerPteRotaryCos: FloatArray? = null private var talkerPteRotaryCos: FloatArray? = null
@ -207,7 +208,7 @@ class Qwen3TtsEngine(
return session return session
} }
// Speech decoder: try V2 ONNX CPU (correct weights), fall back to QNN HTP // Speech decoder V2: CPU ONNX (GPU tested: no gain, +300ms overhead)
val v2Path = "$path/v2_pre_conv" val v2Path = "$path/v2_pre_conv"
if (File("$v2Path/model.onnx").exists()) { if (File("$v2Path/model.onnx").exists()) {
nlog("Loading V2 speech decoder (CPU ONNX)...") nlog("Loading V2 speech decoder (CPU ONNX)...")
@ -1627,16 +1628,23 @@ class Qwen3TtsEngine(
// .pte outputs: hidden[1,1,1024], k0[1,8,16,128], v0[1,8,16,128], ... // .pte outputs: hidden[1,1,1024], k0[1,8,16,128], v0[1,8,16,128], ...
val hiddenOut = outputs[0].toTensor().dataAsFloatArray val hiddenOut = outputs[0].toTensor().dataAsFloatArray
// Head argmax using NEON SIMD (5× faster than Java) // Head argmax using NEON SIMD (individual 8MB heads, pre-loaded)
if (step >= 1 && step - 1 < 15) { if (step >= 1 && step - 1 < 15) {
if (cpHeadsCache == null) cpHeadsCache = arrayOfNulls(15) if (cpHeadsCache == null) {
val cache = cpHeadsCache!! cpHeadsCache = arrayOfNulls(15)
val cbIdx = step - 1 val hp = cpHeadsPath
if (cache[cbIdx] == null) { if (hp != null) {
val hp = cpHeadsPath ?: return codes for (h in 0 until 15) {
cache[cbIdx] = loadNpy(hp.replace("cp_heads.npy", "head_${cbIdx}.npy")) cpHeadsCache!![h] = loadNpy(hp.replace("cp_heads.npy", "head_${h}.npy"))
}
nlog("CP heads pre-loaded: 15 × ${cpHeadsCache!![0]?.size} floats")
}
}
val cbIdx = step - 1
val headData = cpHeadsCache?.get(cbIdx)
if (headData != null) {
codes[cbIdx] = NeonOps.headArgmax(hiddenOut, headData, CODEBOOK_SIZE, TALKER_DIM)
} }
codes[cbIdx] = NeonOps.headArgmax(hiddenOut, cache[cbIdx]!!, CODEBOOK_SIZE, TALKER_DIM)
} }
// Update KV caches (output is [1,8,16,128] — fixed size, already shifted) // Update KV caches (output is [1,8,16,128] — fixed size, already shifted)

View File

@ -112,4 +112,28 @@ Java_com_kazeia_tts_NeonOps_headArgmaxBatch(
return jResult; return jResult;
} }
/**
* Argmax with offset into weight buffer (avoids Java array copy).
*/
JNIEXPORT jint JNICALL
Java_com_kazeia_tts_NeonOps_headArgmaxOffset(
JNIEnv* env, jclass,
jfloatArray jHidden, jfloatArray jAllHeads, jint offset, jint vocab, jint dim)
{
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr);
const float* W = allHeads + offset;
int best = 0;
float bestVal = -FLT_MAX;
for (int j = 0; j < vocab; j++) {
float dot = dot_neon(hidden, W + j * dim, dim);
if (dot > bestVal) { bestVal = dot; best = j; }
}
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT);
return best;
}
} // extern "C" } // extern "C"