diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/NeonOps.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/NeonOps.kt
index c938da4..cfc66c9 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/tts/NeonOps.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/tts/NeonOps.kt
@@ -9,4 +9,12 @@ object NeonOps {
 
     /** Batch argmax for all heads at once (avoids JNI overhead per head). */
     external fun headArgmaxBatch(hidden: FloatArray, allHeads: FloatArray, numHeads: Int, vocab: Int, dim: Int): IntArray
+
+    /**
+     * Argmax over [vocab] rows of [dim] floats starting [offset] floats into [allHeads],
+     * so one head can be scored without copying it out of a large weight buffer.
+     *
+     * @throws IllegalArgumentException if that window does not fit inside the arrays.
+     */
+    external fun headArgmaxOffset(hidden: FloatArray, allHeads: FloatArray, offset: Int, vocab: Int, dim: Int): Int
 }
diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
index b378674..1e17c20 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
@@ -127,6 +127,7 @@ class Qwen3TtsEngine(
     private var cpRotarySin: FloatArray? = null
     private var cpHeadsPath: String? = null  // path dir for head_0..14.npy
     private var cpHeadsCache: Array<FloatArray?>? = null  // lazy-loaded heads cache (8MB each)
+    private var cpAllHeads: FloatArray? = null  // all 15 heads concatenated [15*2048*1024] for batch NEON
     private var cpPteModule: org.pytorch.executorch.Module? = null  // ExecuTorch CP on NPU (JNI)
     private var talkerPteModule: org.pytorch.executorch.Module? = null  // ExecuTorch talker on NPU (JNI)
     private var talkerPteRotaryCos: FloatArray? = null
@@ -207,7 +208,7 @@ class Qwen3TtsEngine(
             return session
         }
 
-        // Speech decoder: try V2 ONNX CPU (correct weights), fall back to QNN HTP
+        // Speech decoder V2: CPU ONNX (GPU tested: no gain, +300ms overhead)
         val v2Path = "$path/v2_pre_conv"
         if (File("$v2Path/model.onnx").exists()) {
             nlog("Loading V2 speech decoder (CPU ONNX)...")
@@ -1627,16 +1628,26 @@ class Qwen3TtsEngine(
             // .pte outputs: hidden[1,1,1024], k0[1,8,16,128], v0[1,8,16,128], ...
             val hiddenOut = outputs[0].toTensor().dataAsFloatArray
 
-            // Head argmax using NEON SIMD (5× faster than Java)
+            // Head argmax using NEON SIMD (individual 8MB heads, pre-loaded)
             if (step >= 1 && step - 1 < 15) {
-                if (cpHeadsCache == null) cpHeadsCache = arrayOfNulls(15)
-                val cache = cpHeadsCache!!
-                val cbIdx = step - 1
-                if (cache[cbIdx] == null) {
-                    val hp = cpHeadsPath ?: return codes
-                    cache[cbIdx] = loadNpy(hp.replace("cp_heads.npy", "head_${cbIdx}.npy"))
+                if (cpHeadsCache == null) {
+                    // Publish the cache only after every head has loaded. Assigning an
+                    // empty array up front would make this null check pass forever, so a
+                    // missing path or a failed load could never be retried.
+                    val hp = cpHeadsPath
+                    if (hp != null) {
+                        val heads = Array<FloatArray?>(15) { h ->
+                            loadNpy(hp.replace("cp_heads.npy", "head_${h}.npy"))
+                        }
+                        cpHeadsCache = heads
+                        nlog("CP heads pre-loaded: 15 × ${heads[0]?.size} floats")
+                    }
+                }
+                val cbIdx = step - 1
+                val headData = cpHeadsCache?.get(cbIdx)
+                if (headData != null) {
+                    codes[cbIdx] = NeonOps.headArgmax(hiddenOut, headData, CODEBOOK_SIZE, TALKER_DIM)
                 }
-                codes[cbIdx] = NeonOps.headArgmax(hiddenOut, cache[cbIdx]!!, CODEBOOK_SIZE, TALKER_DIM)
             }
 
             // Update KV caches (output is [1,8,16,128] — fixed size, already shifted)
diff --git a/kazeia-android/app/src/main/jni/neon_ops.cpp b/kazeia-android/app/src/main/jni/neon_ops.cpp
index b2d5c70..55c543a 100644
--- a/kazeia-android/app/src/main/jni/neon_ops.cpp
+++ b/kazeia-android/app/src/main/jni/neon_ops.cpp
@@ -112,4 +112,59 @@ Java_com_kazeia_tts_NeonOps_headArgmaxBatch(
     return jResult;
 }
 
+/**
+ * Argmax over `vocab` rows of `dim` floats that start `offset` floats into
+ * jAllHeads, each row scored by its dot product with jHidden. Taking an
+ * offset lets the caller index one large weight buffer instead of slicing a
+ * head out of it on the Java side.
+ *
+ * Returns the index of the highest-scoring row (the first one on ties).
+ * Throws IllegalArgumentException and returns -1 when the requested window
+ * does not fit inside the arrays.
+ */
+JNIEXPORT jint JNICALL
+Java_com_kazeia_tts_NeonOps_headArgmaxOffset(
+    JNIEnv* env, jclass,
+    jfloatArray jHidden, jfloatArray jAllHeads, jint offset, jint vocab, jint dim)
+{
+    // Validate the window before touching native memory. The sum is done in
+    // 64 bits because offset + vocab * dim can exceed the jint range.
+    const jlong needed = static_cast<jlong>(offset) + static_cast<jlong>(vocab) * dim;
+    if (jHidden == nullptr || jAllHeads == nullptr ||
+        offset < 0 || vocab < 0 || dim < 0 ||
+        env->GetArrayLength(jHidden) < dim ||
+        env->GetArrayLength(jAllHeads) < needed) {
+        jclass iae = env->FindClass("java/lang/IllegalArgumentException");
+        if (iae != nullptr) {
+            env->ThrowNew(iae, "headArgmaxOffset: window does not fit in the arrays");
+        }
+        return -1;
+    }
+
+    // NOTE(review): GetFloatArrayElements is permitted to return a copy of the
+    // whole array; if that shows up in profiles, GetPrimitiveArrayCritical is
+    // the usual alternative.
+    jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
+    if (hidden == nullptr) return -1;  // the JVM could not provide the elements
+    jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr);
+    if (allHeads == nullptr) {
+        env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
+        return -1;
+    }
+    const float* W = allHeads + offset;
+
+    int best = 0;
+    float bestVal = -FLT_MAX;
+    for (int j = 0; j < vocab; j++) {
+        // Widen before multiplying so the row offset cannot wrap.
+        float dot = dot_neon(hidden, W + static_cast<jlong>(j) * dim, dim);
+        if (dot > bestVal) { bestVal = dot; best = j; }
+    }
+
+    // JNI_ABORT: both arrays were only read, so nothing is copied back.
+    env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
+    env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT);
+    return best;
+}
+
 } // extern "C"