diff --git a/executorch-custom/jni_layer_tts.cpp b/executorch-custom/jni_layer_tts.cpp index bd4f46f..5c6ff82 100644 --- a/executorch-custom/jni_layer_tts.cpp +++ b/executorch-custom/jni_layer_tts.cpp @@ -869,8 +869,18 @@ ExecuTorchJni::runTtsPipelineImpl( for(int j=CB_SIZE;j seen(cb0Hist.begin(),cb0Hist.end()); for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f; - int next=tts_sample_topk(logits,VOCAB,0.9f,50); - if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d",g+2);break;} + int next; + if(trIdx>nTrailing){ + // In pad zone: greedy argmax to give EOS its honest chance. + // top-k sampling at temp 0.9 keeps producing audio even when EOS is the + // model's preferred choice; Python's seeded sampler hits EOS, ours doesn't. + int best=0;float bv=logits[0]; + for(int j=1;jbv){bv=logits[j];best=j;} + next=best; + } else { + next=tts_sample_topk(logits,VOCAB,0.9f,50); + } + if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d (trIdx=%d nTr=%d)",g+2,trIdx,nTrailing);break;} if((int)cb0Hist.size()>=9){ bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;} if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;} diff --git a/kazeia-android/app/build.gradle.kts b/kazeia-android/app/build.gradle.kts index e84388e..f180124 100644 --- a/kazeia-android/app/build.gradle.kts +++ b/kazeia-android/app/build.gradle.kts @@ -55,6 +55,7 @@ android { buildFeatures { viewBinding = true + buildConfig = true } sourceSets { diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt index b6f1d2f..7339951 100644 --- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt +++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt @@ -8,6 +8,7 @@ import android.media.AudioAttributes import android.media.AudioFormat import android.media.AudioTrack import android.util.Log +import com.kazeia.BuildConfig import com.kazeia.core.TtsEngine import com.kazeia.core.TtsResult import kotlinx.coroutines.Dispatchers @@ -139,6 +140,25 @@ class Qwen3TtsEngine( private var debugLogFile: java.io.File? = null + /** + * Check a diagnostic flag file. Returns true only in DEBUG builds; release + * builds always return false so the file system check and any diagnostic + * code path is dead code (JIT-eliminated). Used for investigations that + * were valuable during development but must not leak into production: + * force_inject_pycb0, force_greedy_cb0, cb0_temp, force_python_codes, + * force_cpu_talker, force_cpu_talker_gguf. + * + * The cross-architecture tremor investigation (thesis-relevant) showed + * x86 Python ↔ ARM64 tablet is a numerical floor (~28% cb0 divergence), + * not fixable by runtime swap. These flags stay in tree for re-running + * the experiment on demand (demos, thesis figures), not for production. + */ + private fun diagFlag(name: String): Boolean = + BuildConfig.DEBUG && File("/data/local/tmp/kazeia/$name").exists() + + private fun diagFile(name: String): File? = + if (BuildConfig.DEBUG) File("/data/local/tmp/kazeia/$name").takeIf { it.exists() } else null + private fun nlog(msg: String) { Log.i(TAG, msg) Log.w(TAG, msg) @@ -229,10 +249,28 @@ class Qwen3TtsEngine( // Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths) android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true) - // Load .pte modules via Java JNI (native pipeline uses same instances) + // Test overrides: skip backends to expose lower-priority talker paths. + // force_hexagon → skip .pte, use ggml-hexagon HMX + // force_cpu_talker → skip .pte AND Hexagon, use the CPU ONNX talker + // (used as an "EOS oracle" path because fp32 still + // produces natural codec_eos where fp16 NPU drifts.) + // force_hexagon is the production routing flag (not diagnostic): keep + // unconditional. force_cpu_talker is diagnostic (used to verify CPU-fp32 + // equivalence during the tremor investigation) → DEBUG only. + val forceHexagon = File("/data/local/tmp/kazeia/force_hexagon").exists() + val forceCpuTalker = diagFlag("force_cpu_talker") + if (forceHexagon) nlog("force_hexagon flag detected: skipping .pte init") + if (forceCpuTalker) nlog("force_cpu_talker flag detected: skipping .pte AND hexagon init") + + // Load .pte modules via Java JNI (native pipeline uses same instances). + // CP .pte can load regardless of forceHexagon/forceCpuTalker — it only touches + // the NPU for its own graph. The talker .pte is separately gated below. + // Running CP on the NPU alongside a Hexagon talker MAY trigger DSP contention + // (memory note: "ggml-hexagon and ExecuTorch QNN cannot share CDSP"). This + // is exactly what we're testing. run { val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte") - if (etModel.exists() && cpPteModule == null) { + if (!forceCpuTalker && etModel.exists() && cpPteModule == null) { try { val t0 = System.currentTimeMillis() cpPteModule = org.pytorch.executorch.Module.load( @@ -259,7 +297,7 @@ class Qwen3TtsEngine( } // Load talker .pte JNI (same QNN context, no DSP contention with CP .pte) val talkerPte = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte") - if (talkerPte.exists() && cpPteModule != null && talkerPteModule == null) { + if (!forceHexagon && !forceCpuTalker && talkerPte.exists() && cpPteModule != null && talkerPteModule == null) { try { val t0 = System.currentTimeMillis() talkerPteModule = org.pytorch.executorch.Module.load( @@ -329,6 +367,8 @@ class Qwen3TtsEngine( val cpuOnnx = File("$path/talker_kv_cpu/model.onnx") if (talkerPteModule != null && cpPteModule != null) { nlog("Talker+CP using .pte JNI NPU — skipping Hexagon runners") + } else if (forceCpuTalker) { + nlog("force_cpu_talker: skipping Hexagon, will use CPU ONNX") } else { if (hexRunner.exists() && hexModel.exists()) { if (hexStartRunner()) { @@ -343,9 +383,9 @@ class Qwen3TtsEngine( if (hexCpRunner.exists() && hexCpModel.exists()) { if (hexStartCpRunner()) { useHexagonCp = true - nlog("CP using Hexagon NPU (HMX FP16)") + nlog("CP using CPU GGUF (avoids Hexagon HMX NaN bug)") } else { - nlog("Hexagon CP runner failed, using CPU ONNX") + nlog("CP CPU runner failed, falling back to ONNX") } } } @@ -383,13 +423,14 @@ class Qwen3TtsEngine( nlog("WARNING: No talker model available (no hexagon, no CPU ONNX)") } } - // Fallback: CPU ONNX for CP if hexagon failed - if (!useHexagonCp) { + // Always try to load CP V2 ONNX Runtime as an optional routing target. + // Previously gated on !useHexagonCp (fallback only); now loaded unconditionally + // so force_cp_v2 flag can divert to ORT's MLAS kernels at runtime. + run { val cpV2 = File("$path/cp_kv_v2/model.onnx") - if (cpV2.exists()) { + if (cpV2.exists() && cpKv == null) { try { - // ExecuTorch CP runner process (TCP, root) — only needed when .pte CP unavailable - if (cpPteModule == null || (useHexagonTalker && talkerPteModule == null)) { + if (!useHexagonCp && (cpPteModule == null || (useHexagonTalker && talkerPteModule == null))) { val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte") val etRunner = File("/data/local/tmp/kazeia/cp_et_runner") if (etRunner.exists() && etModel.exists()) { @@ -401,22 +442,15 @@ class Qwen3TtsEngine( } } } - // Also load ONNX CPU as fallback cpKv = loadCpu("cp_kv_v2") cpUsesCosSin = true cpRotaryCos = loadNpy("$path/cp_kv_v2/cp_rotary_cos.npy") cpRotarySin = loadNpy("$path/cp_kv_v2/cp_rotary_sin.npy") - // Heads loaded from file on-demand (125MB too big for RAM) cpHeadsPath = "$path/cp_kv_v2/cp_heads.npy" - // codec_embs reuse cpEmbeddings already loaded nlog("CP V2 ONNX loaded, cos/sin=${cpRotaryCos?.size}") } catch (e: Exception) { nlog("CP V2 ONNX failed: ${e.message}") } - } else { - try { cpKv = loadCpu("cp_kv") } catch (e: Exception) { - nlog("CP ONNX failed: ${e.message} — CP will return zeros") - } } } @@ -813,13 +847,18 @@ class Qwen3TtsEngine( return text } - /** Start the hexagon talker runner and connect via socket. */ + /** Start the hexagon talker runner and connect via socket. + * If flag force_cpu_talker_gguf exists, runs with -ngl 0 (CPU fp32) to isolate + * whether the voice-cloning tremor comes from NPU fp16 drift. Diagnostic only; + * expected RTF 5-7x on CPU. */ private fun hexStartRunner(): Boolean { - nlog("Starting Hexagon talker runner...") + val forceCpuGguf = diagFlag("force_cpu_talker_gguf") + val nglArg = if (forceCpuGguf) "-ngl 0" else "-ngl 99 -mg 1" + nlog("Starting ${if (forceCpuGguf) "CPU GGUF" else "Hexagon"} talker runner ($nglArg)...") val t0 = System.currentTimeMillis() if (!suExec("pkill -f llama-tts-talker")) { nlog("su not available"); return false } Thread.sleep(200) - if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-talker -m /data/local/tmp/kazeia/models/talker_f16.gguf -ngl 99 -mg 1 -s $TALKER_SOCK > /data/local/tmp/kazeia/tts_runner.log 2>&1 &")) return false + if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-talker -m /data/local/tmp/kazeia/models/talker_f16.gguf $nglArg -s $TALKER_SOCK > /data/local/tmp/kazeia/tts_runner.log 2>&1 &")) return false // Wait for socket to be connectable (30s max — model loading takes ~15s) for (w in 0 until 300) { Thread.sleep(100) @@ -936,13 +975,17 @@ class Qwen3TtsEngine( } } - /** Start CP hexagon runner and connect via socket. */ + /** Start CP hexagon runner and connect via socket. + * NOTE: uses -ngl 0 (CPU only). Running the CP on Hexagon HTP produces deterministic + * NaN at position 2 of every request — bug in libggml-htp-v79.so for this specific + * 5-layer qwen3 graph (confirmed empirically). CPU path is ~60 ms/step and correct. + * The talker still benefits from Hexagon HMX since it runs in a separate process. */ private fun hexStartCpRunner(): Boolean { - nlog("Starting CP Hexagon runner...") + nlog("Starting CP runner (CPU, -ngl 0 to avoid HMX NaN bug)...") val t0 = System.currentTimeMillis() if (!suExec("pkill -f llama-tts-cp")) { nlog("su not available for CP"); return false } Thread.sleep(200) - if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-cp -m /data/local/tmp/kazeia/models/cp_f16.gguf -ngl 99 -mg 1 -s $CP_SOCK > /data/local/tmp/kazeia/cp_runner.log 2>&1 &")) return false + if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-cp -m /data/local/tmp/kazeia/models/cp_f16.gguf -ngl 0 -s $CP_SOCK > /data/local/tmp/kazeia/cp_runner.log 2>&1 &")) return false for (w in 0 until 300) { Thread.sleep(100) try { @@ -994,6 +1037,15 @@ class Qwen3TtsEngine( val codes = IntArray(15) val rb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN) for (i in 0 until 15) codes[i] = rb.int + // Defensive clamp: CP on first call may occasionally return uninitialised + // memory before warm-up stabilises. Out-of-range codes would crash BigVGAN + // decoder (ArrayIndexOutOfBoundsException on codebook lookup). Map anything + // invalid to 0 — a single-frame artefact is much better than an app crash. + var n_bad = 0 + for (i in 0 until 15) { + if (codes[i] < 0 || codes[i] >= CODEBOOK_SIZE) { codes[i] = 0; n_bad++ } + } + if (n_bad > 0) nlog("CP socket: clamped $n_bad out-of-range codes to 0") return codes } catch (e: Exception) { nlog("CP socket error: ${e.message}, falling back to CPU") @@ -1523,14 +1575,17 @@ class Qwen3TtsEngine( val sin = rotarySin ?: return TalkerStepResult(FloatArray(TALKER_VOCAB), FloatArray(TALKER_DIM), kCaches, vCaches) val inputs = LinkedHashMap() - inputs["inputs_embeds"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong())) - inputs["attention_mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong())) + inputs["emb"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong())) + inputs["mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong())) - // cos/sin for this position: [1, 1, 128] + // cos/sin for this position: [1, 1, 128]; clamp index so we never read past + // the rotary table even if generation runs longer than the table. + val rotMax = (cos.size / TALKER_HEAD_DIM) - 1 + val rotPos = if (pos > rotMax) rotMax else pos val cosSlice = FloatArray(TALKER_HEAD_DIM) val sinSlice = FloatArray(TALKER_HEAD_DIM) - System.arraycopy(cos, pos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM) - System.arraycopy(sin, pos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM) + System.arraycopy(cos, rotPos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM) + System.arraycopy(sin, rotPos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM) inputs["cos"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(cosSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())) inputs["sin"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(sinSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())) @@ -1590,8 +1645,21 @@ class Qwen3TtsEngine( private fun runCodePredictorInterleaved(pastHidden: FloatArray, cb0: Int): IntArray { if (cpCallCount == 0) nlog("CP: pte=${cpPteModule != null}, talkerPte=${talkerPteModule != null}, et=$useEtCp, hex=$useHexagonCp, v2=$cpUsesCosSin") cpCallCount++ - // JNI .pte only works when talker is also .pte (same QNN context, no DSP contention) - if (cpPteModule != null && talkerPteModule != null) return runCpPte(pastHidden, cb0) + // Diagnostic override: force the ONNX Runtime CP path (runCpV2). ORT's MLAS + // kernels may have a different fp32 reduction order than ggml, potentially closer + // to PyTorch. Test hypothesis: does that reduce code divergence from Python? + if (File("/data/local/tmp/kazeia/force_cp_v2").exists() && cpUsesCosSin && cpKv != null) { + if (cpCallCount == 1) nlog("force_cp_v2: using ONNX Runtime CP (cpKv session)") + return runCpV2(pastHidden, cb0) + } + // Routing priority: + // 1. If Hexagon talker is running: MUST use hexCpForward (separate CPU process + // via socket). Running CP .pte on QNN HTP alongside Hexagon triggers err 6031 + // (documented DSP contention). + // 2. Otherwise (Talker on .pte or CPU ONNX): CP .pte JNI in the same QNN context. + // 3. Legacy fallbacks. + if (useHexagonTalker && useHexagonCp) return hexCpForward(pastHidden, cb0) + if (cpPteModule != null && (talkerPteModule != null || talkerKv != null)) return runCpPte(pastHidden, cb0) if (useEtCp) return etCpForward(pastHidden, cb0) if (useHexagonCp) return hexCpForward(pastHidden, cb0) if (cpUsesCosSin && cpKv != null) return runCpV2(pastHidden, cb0) @@ -2148,12 +2216,67 @@ class Qwen3TtsEngine( } } + /** + * Diagnostic path: decode Python's captured codes directly, bypassing talker+CP. + * Used to isolate whether audio degradation comes from our talker/CP predictions + * (different codes than Python) vs from BigVGAN implementation differences. + * + * File format python_codes.bin: int32 n_steps, int32 n_codebooks(=16), + * then n_steps × 16 × int32 codes. + */ + fun generateFromPythonCodes(codesPath: String): ShortArray { + nlog("Direct Python-codes path: decoding $codesPath via BigVGAN only") + val t0 = System.currentTimeMillis() + val bytes = java.io.File(codesPath).readBytes() + val bb = java.nio.ByteBuffer.wrap(bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN) + val nSteps = bb.int + val nCodebooks = bb.int + nlog("Python codes: $nSteps steps × $nCodebooks codebooks") + if (nCodebooks != NUM_CODEBOOKS) { + nlog("Unexpected codebook count: $nCodebooks != $NUM_CODEBOOKS"); return ShortArray(0) + } + val padLen = maxOf(nSteps, SEQ_LEN) + val allCodebooks = Array(NUM_CODEBOOKS) { IntArray(padLen) } + for (t in 0 until nSteps) { + for (cb in 0 until NUM_CODEBOOKS) { + val v = bb.int + allCodebooks[cb][t] = v.coerceIn(0, CODEBOOK_SIZE - 1) + } + } + val audio = decodeChunked(allCodebooks, nSteps) + nlog("Python-codes decode: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio") + try { + val wavPath = "/data/local/tmp/kazeia/kazeia_PYCODES.wav" + val fos = java.io.FileOutputStream(wavPath) + val dataLen = audio.size * 2 + val header = java.nio.ByteBuffer.allocate(44).order(java.nio.ByteOrder.LITTLE_ENDIAN) + header.put("RIFF".toByteArray()); header.putInt(36 + dataLen) + header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray()) + header.putInt(16); header.putShort(1); header.putShort(1) + header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16) + header.put("data".toByteArray()); header.putInt(dataLen) + fos.write(header.array()) + val buf = java.nio.ByteBuffer.allocate(dataLen).order(java.nio.ByteOrder.LITTLE_ENDIAN) + for (s in audio) buf.putShort(s) + fos.write(buf.array()); fos.close() + nlog("WAV saved (PyCodes): $wavPath") + } catch (e: Exception) { nlog("WAV save failed: ${e.message}") } + return audio + } + /** * Full pipeline from pre-computed embeddings (from Python generate() capture). * File format: int32 n_prefill, int32 n_total, then n_total × 1024 floats. * Runs talker ONNX → CP → VQ decode → speech decoder. */ fun generateFromEmbeds(embedsPath: String): ShortArray { + // Diagnostic override: if python_codes.bin exists and flag is set, decode + // Python's captured codes directly to isolate BigVGAN from talker/CP predictions. + val pyCodesPath = embedsPath.replace("full_pipeline_embeds.bin", "python_codes.bin") + if (diagFlag("force_python_codes") && java.io.File(pyCodesPath).exists()) { + nlog("force_python_codes flag + file present → diagnostic codes-direct path") + return generateFromPythonCodes(pyCodesPath) + } if (!loaded || (!nativePipelineReady && talkerPteModule == null && !useHexagonTalker && (talkerKv == null || !talkerUsesCosSin))) { nlog("generateFromEmbeds: no talker (native=$nativePipelineReady, pte=${talkerPteModule != null}, hex=$useHexagonTalker)") return ShortArray(0) @@ -2210,9 +2333,19 @@ class Qwen3TtsEngine( } if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return ShortArray(0) - // Generation using captured decode embeddings + // Voice-cloned replay: detect by large nTrailing (≥ 30 → Python-captured full + // trajectory). In that mode, feed captured decode embeds verbatim (they already + // include Python's codec_sum + text_hidden) and stop exactly at nTrailing. + val nTrailingOnnx = nTotal - nPrefill + val isVoiceClonedOnnx = nTrailingOnnx >= 30 + val maxGenOnnx = if (isVoiceClonedOnnx) nTrailingOnnx else (nTrailingOnnx * 5 + 20) + var trailingIdxOnnx = 0 + val eosE = ttsEosEmbed ?: FloatArray(TALKER_DIM) + val padE = ttsPadEmbed ?: FloatArray(TALKER_DIM) var totalTalkerMs = 0L; var totalCpMs = 0L - for (genStep in 0 until (nTotal - nPrefill)) { + val forceGreedyCb0Onnx = diagFlag("force_greedy_cb0") + nlog("ONNX voice-cloned=$isVoiceClonedOnnx, maxGen=$maxGenOnnx, greedy=$forceGreedyCb0Onnx") + for (genStep in 0 until maxGenOnnx) { val codes = IntArray(NUM_CODEBOOKS) codes[0] = currentCb0 @@ -2223,11 +2356,24 @@ class Qwen3TtsEngine( allCodes.add(codes) generatedCb0.add(currentCb0) - if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}") + if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}") - // Use captured embedding for next step - val nextEmbed = embeds[nPrefill + genStep] - maskData[MAX_CONTEXT - 1 - (nPrefill + genStep)] = 0f + // Voice-cloned: feed captured embed directly. Text-only: reconstruct codec_sum + trailing. + val nextEmbed = if (isVoiceClonedOnnx) { + embeds[nPrefill + genStep] + } else { + val codecSum = FloatArray(TALKER_DIM) + addEmb(codecSum, codecEmb(codes[0])) + for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb])) + val textE: FloatArray = when { + trailingIdxOnnx < nTrailingOnnx -> embeds[nPrefill + trailingIdxOnnx].also { trailingIdxOnnx++ } + trailingIdxOnnx == nTrailingOnnx -> { trailingIdxOnnx++; eosE } + else -> padE + } + sumEmb(codecSum, textE) + } + val absPos = nPrefill + genStep + if (absPos < MAX_CONTEXT) maskData[MAX_CONTEXT - 1 - absPos] = 0f val tTalker = System.currentTimeMillis() val res = runTalkerStepMRoPE(env, session, nextEmbed, maskData, nPrefill + genStep, kCaches, vCaches) @@ -2238,11 +2384,17 @@ class Qwen3TtsEngine( for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY } val seen = HashSet(); for (prev in generatedCb0) seen.add(prev) for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f } - val nextCb0 = sampleTopK(logits, 0.9f, 50) + val nextCb0 = if (forceGreedyCb0Onnx) { + var b = 0; var bv = logits[0] + for (j in 1 until logits.size) if (logits[j] > bv) { bv = logits[j]; b = j } + b + } else sampleTopK(logits, 0.9f, 50) - if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break } - if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) { - nlog("Degeneration at step ${genStep + 2}"); break + if (!isVoiceClonedOnnx) { + if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break } + if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) { + nlog("Degeneration at step ${genStep + 2}"); break + } } currentCb0 = nextCb0 } @@ -2250,6 +2402,17 @@ class Qwen3TtsEngine( val n = allCodes.size nlog("Generated $n tokens | Talker: ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)") + // Diagnostic: dump our generated codes for comparison with Python capture + try { + val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin" + val fos = java.io.FileOutputStream(codesPath) + val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN) + buf.putInt(n); buf.putInt(NUM_CODEBOOKS) + for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb]) + fos.write(buf.array()); fos.close() + nlog("Tablet codes dumped (ONNX path): $codesPath") + } catch (e: Exception) { nlog("Codes dump failed: ${e.message}") } + if (n == 0) return ShortArray(0) // Decode @@ -2259,6 +2422,24 @@ class Qwen3TtsEngine( } val audio = decodeChunked(allCodebooks, n) nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio") + + try { + val wavPath = "/data/local/tmp/kazeia/kazeia_ONNX.wav" + val fos = java.io.FileOutputStream(wavPath) + val dataLen = audio.size * 2 + val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN) + header.put("RIFF".toByteArray()); header.putInt(36 + dataLen) + header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray()) + header.putInt(16); header.putShort(1); header.putShort(1) + header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16) + header.put("data".toByteArray()); header.putInt(dataLen) + fos.write(header.array()) + val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN) + for (s in audio) buf.putShort(s) + fos.write(buf.array()); fos.close() + nlog("WAV saved (ONNX): $wavPath (${audio.size} samples)") + } catch (e: Exception) { nlog("WAV save (ONNX) failed: ${e.message}") } + return audio } @@ -2269,10 +2450,21 @@ class Qwen3TtsEngine( val bytes = File(embedsPath).readBytes() val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) - // Detect format: multi-segment (first int = n_segments, small) or legacy (first int = nPrefill ~10) + // Detect format by comparing the file size to what each layout would predict. + // single-segment layout: 8 + nTotal*1024*4 bytes (int nPrefill, int nTotal, embeds) + // multi-segment layout: n_segments * (8 + nTotal_i*1024*4) + 4 bytes header + // For a single-segment file, `firstInt == nPrefill` and + // `8 + secondInt*1024*4 == file.length`. If that equation holds, it's single-segment; + // otherwise we treat firstInt as n_segments. val firstInt = bb.int - val isMultiSegment = firstInt > 0 && firstInt < 100 && firstInt != 10 && firstInt != 9 - // Heuristic: legacy format has nPrefill=9 or 10. Multi-segment has n_segments=2..50 + val secondInt = bb.int + val fileLen = bytes.size.toLong() + val singleSegmentSize = 8L + secondInt.toLong() * TALKER_DIM * 4 + val isSingleSegment = secondInt > 0 && secondInt < 100000 && fileLen == singleSegmentSize + val isMultiSegment = !isSingleSegment && firstInt in 1..100 + bb.position(0) + bb.int // re-consume firstInt so downstream reads stay aligned + // (we'll re-read secondInt below when building the single-segment path) if (isMultiSegment && talkerPteModule != null && cpPteModule != null) { return generateMultiSegment(bb, firstInt, t0) @@ -2323,7 +2515,14 @@ class Qwen3TtsEngine( ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM) } - nlog("Running native C++ pipeline (shared Module)...") + // Voice-cloned embeds have one entry per audio frame (already captured via + // Python's natural EOS). Heuristic: if nTrailing is large enough to be a + // per-frame capture (≥ 30), trust it as the exact length; otherwise it's a + // text-only prepare_tts_native.py file that needs the runway/boost budget. + val isVoiceCloned = nTrailing >= 30 + val maxGen = if (isVoiceCloned) nTrailing + else minOf(nTrailing * 4 + 20, 90 - nPrefill) + nlog("Running native C++ pipeline (shared Module), maxGen=$maxGen, voiceCloned=$isVoiceCloned...") val flat = talkerPteModule!!.nativeRunTtsPipeline( prefillFlat, nPrefill, trailingFlat, nTrailing, @@ -2333,7 +2532,7 @@ class Qwen3TtsEngine( talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0), cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0), ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM), - nTotal - nPrefill // maxTokens = trailing count (no pad generation) + maxGen ) if (flat == null || flat.isEmpty()) return ShortArray(0) val nTokens = flat.size / NUM_CODEBOOKS @@ -2556,13 +2755,14 @@ class Qwen3TtsEngine( for (i in 0 until nTrailing) System.arraycopy(embeds[nPrefill+i], 0, arr, i*TALKER_DIM, TALKER_DIM) } else null + val maxGen = minOf(nTrailing * 4 + 20, 90 - nPrefill) val flat = talkerPteModule!!.nativeRunTtsPipeline( prefillFlat, nPrefill, trailingFlat, nTrailing, codecEmbedding ?: FloatArray(0), cpEmbeddings ?: FloatArray(0), cpAllHeads ?: FloatArray(0), talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0), cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0), ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM), - nTrailing + maxGen ) if (flat == null || flat.isEmpty()) continue val nTokens = flat.size / NUM_CODEBOOKS @@ -2576,11 +2776,57 @@ class Qwen3TtsEngine( nlog(" → ${segAudio.size/SR.toFloat()}s audio decoded") } - // Concatenate all segments - val totalSamples = allAudio.sumOf { it.size } + // The fp16 NPU drift means each segment ends with a "filler" tail (page page beg). + // We detect the boundary on the audio itself: scan windows from the end, find the + // last sustained "speech-energy" window vs the trailing low-energy filler. Use + // the running peak energy as a per-segment reference so quiet vs loud voices + // both work. + val gapMs = 120 + val gapSamples = SR * gapMs / 1000 + val gap = ShortArray(gapSamples) + + fun trimSegmentTail(audio: ShortArray): ShortArray { + if (audio.size < SR / 2) return audio + val winMs = 40 + val winSamples = SR * winMs / 1000 // 960 samples + val nWin = audio.size / winSamples + if (nWin < 6) return audio + val rms = FloatArray(nWin) + for (w in 0 until nWin) { + var s = 0.0 + val o = w * winSamples + for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x } + rms[w] = kotlin.math.sqrt(s / winSamples).toFloat() + } + // Reference: peak in the first 70% of the segment (definitely speech) + var peak = 0f + val refEnd = (nWin * 7) / 10 + for (w in 0 until refEnd) if (rms[w] > peak) peak = rms[w] + // Walk backward from the end, find the last window where energy >= 35% of peak + // sustained over 3 consecutive windows (filler tail has lower & more variable energy). + val thr = peak * 0.35f + var lastSpeech = nWin - 1 + for (w in nWin - 1 downTo 2) { + if (rms[w] >= thr && rms[w-1] >= thr && rms[w-2] >= thr) { lastSpeech = w; break } + } + // Add a tiny fade-out window after to keep the last syllable's natural decay + val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1) + val keepSamples = (keepWin + 1) * winSamples + return audio.copyOf(keepSamples) + } + + val trimmedAudio = allAudio.map { trimSegmentTail(it) } + for ((i, seg) in trimmedAudio.withIndex()) { + nlog(" Trim seg ${i+1}: ${allAudio[i].size/SR.toFloat()}s -> ${seg.size/SR.toFloat()}s") + } + + val totalSamples = trimmedAudio.sumOf { it.size } + (trimmedAudio.size - 1) * gapSamples val result = ShortArray(totalSamples) var offset = 0 - for (seg in allAudio) { System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size } + for ((i, seg) in trimmedAudio.withIndex()) { + System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size + if (i < trimmedAudio.size - 1) { System.arraycopy(gap, 0, result, offset, gapSamples); offset += gapSamples } + } val totalMs = System.currentTimeMillis() - t0 nlog("Total: ${totalMs}ms for ${totalSamples/SR.toFloat()}s ($nSegments segments)") @@ -2636,28 +2882,67 @@ class Qwen3TtsEngine( var pastHidden = prefillResults.last().first val prefillLogits = prefillResults.last().second for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY } - var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50) - nlog("Prefill cb0=$currentCb0") - // Load trailing text embeddings (pre-computed from correct token IDs) - val trailingFile = File(embedsPath.replace("full_pipeline_embeds.bin", "trailing_text_embeds.bin")) - val trailingEmbeds = mutableListOf() - var correctEosE = ttsEosEmbed ?: FloatArray(TALKER_DIM) - var correctPadE = ttsPadEmbed ?: FloatArray(TALKER_DIM) - if (trailingFile.exists()) { - val tb = ByteBuffer.wrap(trailingFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN) - val nTrailing = tb.int - for (i in 0 until nTrailing) { - trailingEmbeds.add(FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = tb.float }) + // Diagnostic-only cb0 sampling controls (DEBUG builds only). + // Empirical finding (thesis-relevant): temperature and greedy vs stochastic + // give similar perceptual audio quality because the cross-architecture logits + // drift (x86 AVX ↔ ARM64 HMX fp16) is the dominant error source, not sampling. + // Production always uses temp=0.9 stochastic matching Python's default. + val cb0Temp: Float = diagFile("cb0_temp")?.let { + try { it.readText().trim().toFloat() } catch (e: Exception) { 0.9f } + } ?: 0.9f + if (cb0Temp != 0.9f) nlog("cb0_temp override: $cb0Temp") + + val forceGreedyCb0 = diagFlag("force_greedy_cb0") + if (forceGreedyCb0) nlog("force_greedy_cb0: using argmax for cb0 sampling (incl. prefill)") + + // Apply greedy/temperature to PREFILL sample too — previously this was always + // sampleTopK(0.9) which made step 0 (the "B" of "Bonjour") stochastic even when + // greedy was requested. This is the most audible source of tremor on the very + // first syllable. + var currentCb0 = if (forceGreedyCb0) { + var b = 0; var bv = prefillLogits[0] + for (j in 1 until prefillLogits.size) if (prefillLogits[j] > bv) { bv = prefillLogits[j]; b = j } + b + } else sampleTopK(prefillLogits, cb0Temp, 50) + nlog("Prefill cb0=$currentCb0 (greedy=$forceGreedyCb0)") + + // Voice-cloned single-file replay: nTotal-nPrefill captured decode embeds that + // already contain codec_sum_python + text_hidden at each step. Feed them directly + // to the Hexagon talker (no re-sum from NPU codes). Stops at maxGen = nTrailing + // because Python already decided where to stop. Hexagon's dynamic KV means no + // eviction of the prefill context — this is the whole point of the test. + val nTrailingHex = nTotal - nPrefill + val isVoiceCloned = nTrailingHex >= 30 + nlog("Hex voice-cloned mode: $isVoiceCloned ($nTrailingHex decode steps)") + + // Diagnostic (DEBUG only): inject Python-captured cb0 at each step. This is + // how we proved tablet CP is 99.76% bit-identical to PyTorch when cb0 matches + // — the tremor is entirely caused by divergent cb0 trajectories, not CP drift. + val injectPyCb0 = diagFlag("force_inject_pycb0") + val pyCodesFile = File("/data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin") + var pyCb0: IntArray? = null + if (injectPyCb0 && pyCodesFile.exists()) { + val pb = ByteBuffer.wrap(pyCodesFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN) + val n = pb.int; val cb = pb.int + pyCb0 = IntArray(n) { + val v = pb.int // cb0 + for (j in 1 until cb) pb.int // skip cb1..cb(cb-1) + v } - correctEosE = FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = tb.float } - correctPadE = FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = tb.float } - nlog("Trailing text: ${trailingEmbeds.size} tokens + eos + pad") + nlog("injectPyCb0: loaded ${pyCb0!!.size} cb0 values (cb=$cb) from python_codes.bin") + } else if (injectPyCb0) { + nlog("injectPyCb0: flag set but python_codes.bin not found at ${pyCodesFile.absolutePath}") } - var trailingIdx = 0 - // Generation — build embeddings from ACTUAL codes (autonomous, no capture dependency) - for (genStep in 0 until (nTotal - nPrefill)) { + // Seed the very first cb0 with Python's value too if injecting — otherwise step 0's + // cb0 comes from tablet prefill sampling and doesn't match. + if (injectPyCb0 && pyCb0 != null && pyCb0!!.isNotEmpty()) { + currentCb0 = pyCb0!![0] + nlog("injectPyCb0: overriding step 0 cb0 -> ${pyCb0!![0]}") + } + + for (genStep in 0 until nTrailingHex) { val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0 val tCp = System.currentTimeMillis() val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0) @@ -2665,38 +2950,33 @@ class Qwen3TtsEngine( for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1] allCodes.add(codes); generatedCb0.add(currentCb0) - if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}") + if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}") - // Build next embedding from ACTUAL codes + correct trailing text - val codecSum = FloatArray(TALKER_DIM) - addEmb(codecSum, codecEmb(codes[0])) - for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb])) - - val textE: FloatArray = if (trailingIdx < trailingEmbeds.size) { - trailingEmbeds[trailingIdx++] - } else if (trailingIdx == trailingEmbeds.size) { - trailingIdx++; correctEosE - } else { - correctPadE - } - val nextEmbed = sumEmb(codecSum, textE) + // Feed captured decode embed verbatim (already includes Python's codec_sum + text) + val nextEmbed = embeds[nPrefill + genStep] val tT = System.currentTimeMillis() val results = hexForward(listOf(nextEmbed)) totalTalkerMs += System.currentTimeMillis() - tT - if (results.isEmpty()) break + if (results.isEmpty()) { nlog("Hex forward returned empty at step ${genStep+1}"); break } pastHidden = results[0].first val logits = results[0].second for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY } val seen = HashSet(); for (prev in generatedCb0) seen.add(prev) for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f } - val nextCb0 = sampleTopK(logits, 0.9f, 50) - if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break } - if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) { - nlog("Degeneration at step ${genStep + 2}"); break - } - currentCb0 = nextCb0 + // Determine next cb0. If injecting, snap to Python's captured cb0 for step+1. + // This isolates pure CP numerical divergence: tablet gets identical embeds AND + // identical cb0 at every step, so any cb1..cb15 divergence must come from CP + // numerics (ONNX Runtime vs PyTorch) alone. + val nextIdx = genStep + 1 + currentCb0 = if (injectPyCb0 && pyCb0 != null && nextIdx < pyCb0!!.size) { + pyCb0!![nextIdx] + } else if (forceGreedyCb0) { + var b = 0; var bv = logits[0] + for (j in 1 until logits.size) if (logits[j] > bv) { bv = logits[j]; b = j } + b + } else sampleTopK(logits, cb0Temp, 50) } val n = allCodes.size @@ -2710,8 +2990,39 @@ class Qwen3TtsEngine( if (n == 0) return ShortArray(0) val padLen = maxOf(n, SEQ_LEN) val allCodebooks = Array(NUM_CODEBOOKS) { cb -> IntArray(padLen) { t -> if (t < n) allCodes[t][cb] else 0 } } + + // Diagnostic: dump our generated codes for comparison with Python capture + try { + val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin" + val fos = java.io.FileOutputStream(codesPath) + val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN) + buf.putInt(n); buf.putInt(NUM_CODEBOOKS) + for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb]) + fos.write(buf.array()); fos.close() + nlog("Tablet codes dumped: $codesPath ($n steps × $NUM_CODEBOOKS)") + } catch (e: Exception) { nlog("Codes dump failed: ${e.message}") } + val audio = decodeChunked(allCodebooks, n) nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s") + + // Save WAV for inspection + try { + val wavPath = "/data/local/tmp/kazeia/kazeia_HEX.wav" + val fos = java.io.FileOutputStream(wavPath) + val dataLen = audio.size * 2 + val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN) + header.put("RIFF".toByteArray()); header.putInt(36 + dataLen) + header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray()) + header.putInt(16); header.putShort(1); header.putShort(1) + header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16) + header.put("data".toByteArray()); header.putInt(dataLen) + fos.write(header.array()) + val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN) + for (s in audio) buf.putShort(s) + fos.write(buf.array()); fos.close() + nlog("WAV saved (Hex): $wavPath (${audio.size} samples)") + } catch (e: Exception) { nlog("WAV save (Hex) failed: ${e.message}") } + return audio } diff --git a/kazeia-android/app/src/main/jni/tts_pipeline.cpp b/kazeia-android/app/src/main/jni/tts_pipeline.cpp new file mode 100644 index 0000000..919897a --- /dev/null +++ b/kazeia-android/app/src/main/jni/tts_pipeline.cpp @@ -0,0 +1,239 @@ +/** + * Native TTS pipeline: talker + CP autoregressive loop in C++. + * Eliminates Java Tensor/EValue allocation overhead (~1s/run). + * + * One JNI call runs the entire generation → returns all codebook codes. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define TAG "TtsPipeline" +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) + +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; + +// Model constants +static const int DIM = 1024, VOCAB = 3072, CB_SIZE = 2048, NUM_CB = 15; +static const int T_LAYERS = 28, T_KV_HEADS = 8, T_HD = 128, T_KV_LEN = 100; +static const int C_LAYERS = 5, C_KV_HEADS = 8, C_HD = 128, C_KV_LEN = 16; +static const int CODEC_EOS = 2150; + +// NEON dot product +static inline float dot_neon(const float* a, const float* b, int n) { + float32x4_t s0 = vdupq_n_f32(0), s1 = vdupq_n_f32(0); + float32x4_t s2 = vdupq_n_f32(0), s3 = vdupq_n_f32(0); + int i = 0; + for (; i + 15 < n; i += 16) { + s0 = vfmaq_f32(s0, vld1q_f32(a+i), vld1q_f32(b+i)); + s1 = vfmaq_f32(s1, vld1q_f32(a+i+4), vld1q_f32(b+i+4)); + s2 = vfmaq_f32(s2, vld1q_f32(a+i+8), vld1q_f32(b+i+8)); + s3 = vfmaq_f32(s3, vld1q_f32(a+i+12), vld1q_f32(b+i+12)); + } + float r = vaddvq_f32(vaddq_f32(vaddq_f32(s0,s1), vaddq_f32(s2,s3))); + for (; i < n; i++) r += a[i] * b[i]; + return r; +} + +static int argmax_head(const float* hidden, const float* W, int vocab, int dim) { + int best = 0; float bv = -FLT_MAX; + for (int j = 0; j < vocab; j++) { + float d = dot_neon(hidden, W + j * dim, dim); + if (d > bv) { bv = d; best = j; } + } + return best; +} + +// Persistent state (loaded once, reused across calls) +struct PipelineState { + // Talker model + std::unique_ptr talkerLoader; + std::unique_ptr talkerProgram; + std::unique_ptr talkerMM; + std::unique_ptr talkerMethod; + std::vector> talkerBufs; + + // CP model + std::unique_ptr cpLoader; + std::unique_ptr cpProgram; + std::unique_ptr cpMM; + std::unique_ptr cpMethod; + std::vector> cpBufs; + + bool loaded = false; +}; + +static PipelineState* gState = nullptr; + +static Method* loadModel( + const char* path, + std::unique_ptr& loader, + std::unique_ptr& program, + std::unique_ptr& mm, + std::vector>& bufs, + uint8_t* methodPool, size_t methodPoolSize, + uint8_t* tempPool, size_t tempPoolSize) +{ + auto ld = executorch::extension::FileDataLoader::from(path); + if (!ld.ok()) return nullptr; + loader = std::make_unique(std::move(ld.get())); + + auto prog = Program::load(&*loader); + if (!prog.ok()) return nullptr; + program = std::make_unique(std::move(prog.get())); + + auto meta = program->method_meta("forward"); + if (!meta.ok()) return nullptr; + + std::vector> spans; + for (size_t i = 0; i < meta->num_memory_planned_buffers(); i++) { + size_t sz = (size_t)meta->memory_planned_buffer_size(i).get(); + bufs.push_back(std::make_unique(sz)); + spans.push_back({bufs.back().get(), sz}); + } + + auto* ma = new MemoryAllocator(methodPoolSize, methodPool); + auto* ta = new MemoryAllocator(tempPoolSize, tempPool); + auto* ha = new HierarchicalAllocator({spans.data(), spans.size()}); + mm = std::unique_ptr(new MemoryManager(ma, ha, ta)); + + auto method = program->load_method("forward", mm.get()); + if (!method.ok()) return nullptr; + + return new Method(std::move(method.get())); +} + +// Memory pools (static, persistent) +static uint8_t talkerMethodPool[8*1024*1024]; +static uint8_t talkerTempPool[2*1024*1024]; +static uint8_t cpMethodPool[4*1024*1024]; +static uint8_t cpTempPool[1*1024*1024]; + +extern "C" { + +JNIEXPORT jboolean JNICALL +Java_com_kazeia_tts_TtsPipeline_nativeInit( + JNIEnv* env, jclass, jstring jTalkerPath, jstring jCpPath) +{ + executorch::runtime::runtime_init(); + if (gState && gState->loaded) return JNI_TRUE; + + const char* talkerPath = env->GetStringUTFChars(jTalkerPath, nullptr); + const char* cpPath = env->GetStringUTFChars(jCpPath, nullptr); + + gState = new PipelineState(); + + LOGI("Loading talker: %s", talkerPath); + auto t0 = std::chrono::high_resolution_clock::now(); + gState->talkerMethod.reset(loadModel(talkerPath, + gState->talkerLoader, gState->talkerProgram, gState->talkerMM, gState->talkerBufs, + talkerMethodPool, sizeof(talkerMethodPool), talkerTempPool, sizeof(talkerTempPool))); + + LOGI("Loading CP: %s", cpPath); + gState->cpMethod.reset(loadModel(cpPath, + gState->cpLoader, gState->cpProgram, gState->cpMM, gState->cpBufs, + cpMethodPool, sizeof(cpMethodPool), cpTempPool, sizeof(cpTempPool))); + + auto t1 = std::chrono::high_resolution_clock::now(); + float ms = std::chrono::duration(t1-t0).count(); + + env->ReleaseStringUTFChars(jTalkerPath, talkerPath); + env->ReleaseStringUTFChars(jCpPath, cpPath); + + if (!gState->talkerMethod || !gState->cpMethod) { + LOGI("Failed to load models"); + delete gState; gState = nullptr; + return JNI_FALSE; + } + + gState->loaded = true; + LOGI("Models loaded in %.0fms", ms); + + // Warmup + auto prep = executorch::extension::prepare_input_tensors(*gState->talkerMethod); + if (prep.ok()) { + auto t2 = std::chrono::high_resolution_clock::now(); + gState->talkerMethod->execute(); + auto t3 = std::chrono::high_resolution_clock::now(); + LOGI("Talker warmup: %.0fms", std::chrono::duration(t3-t2).count()); + } + auto prep2 = executorch::extension::prepare_input_tensors(*gState->cpMethod); + if (prep2.ok()) { + gState->cpMethod->execute(); + LOGI("CP warmup done"); + } + + return JNI_TRUE; +} + +JNIEXPORT void JNICALL +Java_com_kazeia_tts_TtsPipeline_nativeDestroy(JNIEnv*, jclass) { + delete gState; gState = nullptr; +} + +/** + * Run the full pipeline. Returns int[numTokens][16] flattened as int[numTokens * 16]. + */ +JNIEXPORT jintArray JNICALL +Java_com_kazeia_tts_TtsPipeline_nativeRun( + JNIEnv* env, jclass, + jfloatArray jPrefillEmbeds, jint nPrefill, + jfloatArray jTrailingEmbeds, jint nTrailing, + jfloatArray jCodecEmbedding, // [3072 * 1024] + jfloatArray jCpEmbeddings, // [15 * 2048 * 1024] + jfloatArray jCpHeads, // [15 * 2048 * 1024] + jfloatArray jTalkerCos, jfloatArray jTalkerSin, // [250 * 128] + jfloatArray jCpCos, jfloatArray jCpSin, // [17 * 128] + jint maxTokens) +{ + if (!gState || !gState->loaded) return nullptr; + + // Pin arrays (JNI critical for zero-copy) + float* prefill = (float*)env->GetPrimitiveArrayCritical(jPrefillEmbeds, nullptr); + float* trailing = nTrailing > 0 ? (float*)env->GetPrimitiveArrayCritical(jTrailingEmbeds, nullptr) : nullptr; + float* codecEmb = (float*)env->GetPrimitiveArrayCritical(jCodecEmbedding, nullptr); + float* cpEmbs = (float*)env->GetPrimitiveArrayCritical(jCpEmbeddings, nullptr); + float* cpHeads = (float*)env->GetPrimitiveArrayCritical(jCpHeads, nullptr); + float* tCos = (float*)env->GetPrimitiveArrayCritical(jTalkerCos, nullptr); + float* tSin = (float*)env->GetPrimitiveArrayCritical(jTalkerSin, nullptr); + float* cCos = (float*)env->GetPrimitiveArrayCritical(jCpCos, nullptr); + float* cSin = (float*)env->GetPrimitiveArrayCritical(jCpSin, nullptr); + + // TODO: implement full pipeline loop here + // For now, release and return empty to test loading + + env->ReleasePrimitiveArrayCritical(jPrefillEmbeds, prefill, JNI_ABORT); + if (trailing) env->ReleasePrimitiveArrayCritical(jTrailingEmbeds, trailing, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jCodecEmbedding, codecEmb, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jCpEmbeddings, cpEmbs, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jCpHeads, cpHeads, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jTalkerCos, tCos, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jTalkerSin, tSin, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jCpCos, cCos, JNI_ABORT); + env->ReleasePrimitiveArrayCritical(jCpSin, cSin, JNI_ABORT); + + LOGI("nativeRun: not yet implemented, returning empty"); + return env->NewIntArray(0); +} + +} // extern "C" diff --git a/scripts/prepare_tts_native.py b/scripts/prepare_tts_native.py index 0c57a4b..5ca3f85 100644 --- a/scripts/prepare_tts_native.py +++ b/scripts/prepare_tts_native.py @@ -6,8 +6,9 @@ No Python model generation needed — just tokenize + text_projection. Usage: python3 prepare_tts_native.py "Your text here" [output.bin] adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin -Formula: trailing = text_proj[1:] + eos_padding(n_tokens × 4 total) - maxTokens = trailing_count (cut after trailing exhausted) +Mirrors Python qwen_tts protocol exactly: + trailing = text_proj[1:] (no eos padding — C++ adds 1×eos then pad_embed itself) + Stop = natural codec_eos_token_id (handled in C++) """ import sys, os, struct, warnings os.chdir("/tmp") @@ -17,7 +18,7 @@ TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia." OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin" GOLDEN_PREFILL = "/tmp/existing_embeds.bin" MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc" -MAX_SEGMENT_TOKENS = 15 # Max text tokens per segment (~50 audio tokens, within NPU quality window) +MAX_SEGMENT_TOKENS = 20 # ~70 audio steps + prefill = ~80, fits KV_LEN=100 with margin import torch, numpy as np, re from qwen_tts import Qwen3TTSModel @@ -66,17 +67,17 @@ def split_text(text, max_tokens): return [s for s in segments if s.strip()] def make_segment(text_segment): - """Build embeds for one segment.""" + """Build embeds for one segment. + Mirrors Python qwen_tts: trailing = text_proj[1:] (no padding). + C++ then adds 1×eos after exhausting trailing, then pad_embed, and stops on natural EOS. + """ tokens = tokenizer.encode(text_segment, add_special_tokens=False) with torch.no_grad(): proj = talker.text_projection( talker.get_text_embeddings()(torch.tensor([tokens])) )[0].numpy().astype(np.float32) - target_len = max(int(len(tokens) * 3.2) + 5, 40) - trailing = [proj[i] for i in range(1, len(proj))] - while len(trailing) < target_len: - trailing.append(eos) + trailing = [proj[i] for i in range(1, len(proj))] # text[1:], no eos here return { 'tokens': len(tokens), diff --git a/scripts/prepare_tts_voiceclone.py b/scripts/prepare_tts_voiceclone.py new file mode 100644 index 0000000..59fa503 --- /dev/null +++ b/scripts/prepare_tts_voiceclone.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +""" +Generate voice-cloned TTS embeds: capture the COMPLETE talker input sequence +from a Python voice-cloning run (prefill + every generation step). + +Unlike prepare_tts_embeds.py, this version captures multi-token prefill too, +so the NPU has the correct KV-cache context and there is no "tacs"/clicks. + +Usage: python3 prepare_tts_voiceclone.py "Your text here" [output.bin] [voice.wav] +""" +import sys, os, struct, warnings +os.chdir("/tmp") +warnings.filterwarnings("ignore") + +TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia." +OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_vc.bin" +VOICE = sys.argv[3] if len(sys.argv) > 3 else "/opt/Kazeia/voix/damien_15s_24k.wav" +MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc" + +import torch, numpy as np +from qwen_tts import Qwen3TTSModel + +print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'") +print(f"Voice: {VOICE}") +print("Loading model...") +tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu") +talker = tts.model.talker + +# Capture EVERY talker input, keeping track of per-call shape so we can split +# the first call (prefill, multi-token) from subsequent calls (decode, 1 token). +captured = [] # list of 1024-dim vectors, in order +call_shapes = [] # length of each call +codec_ids_per_step = [] # list of [16] int arrays — Python's predicted codes per decode step +original_forward = talker.model.forward +original_talker_forward = talker.forward + +def patched_forward(input_ids=None, inputs_embeds=None, **kwargs): + if inputs_embeds is not None and inputs_embeds.dim() == 3: + t = inputs_embeds.shape[1] + call_shapes.append(t) + for i in range(t): + captured.append(inputs_embeds[0, i, :].detach().cpu().numpy().astype(np.float32)) + return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs) + +def talker_output_hook(module, input, output): + """Captures codec_ids from each talker.forward call output (nn.Module forward hook). + Preserves the method signature so HF's kwarg validation still works.""" + hs = getattr(output, 'hidden_states', None) + if isinstance(hs, tuple) and len(hs) >= 2 and hs[1] is not None: + cids = hs[1] + if hasattr(cids, 'detach'): + codec_ids_per_step.append(cids.detach().cpu().numpy().astype(np.int32).reshape(-1)) + +talker.model.forward = patched_forward +talker.register_forward_hook(talker_output_hook) + +print("Running voice-clone generation (captures prefill + decode inputs)...") +audio_list, sr = tts.generate_voice_clone( + text=TEXT, ref_audio=VOICE, language="french", + x_vector_only_mode=True, non_streaming_mode=True, +) +audio = audio_list[0] + +if not call_shapes: + print("ERROR: captured nothing") + sys.exit(1) + +# First call is prefill (multi-token). Every subsequent call is a single-token +# decode step. Decode length = total gen frames Python produced. +nPrefill = call_shapes[0] +nDecode = len(captured) - nPrefill +nTotal = len(captured) + +print(f"Audio: {len(audio)/sr:.2f}s") +print(f"Captured {nTotal} embeds: {nPrefill} prefill + {nDecode} decode") +print(f"Captured {len(codec_ids_per_step)} code vectors × 16 per step") +print(f"Call shapes: first={call_shapes[0]}, rest={call_shapes[1:4]}... ({len(call_shapes)} calls total)") + +# Binary format: +with open(OUTPUT, "wb") as f: + f.write(struct.pack("