Root cause found + on-device embed capture + KV=100 restored

Root cause: embeds must come from SAME NPU model instance.
Python fp32 embeds cause divergence on NPU fp16 after ~20 steps.

Solution: Java pipeline captures embeds on-device during generation.
Captured embeds work perfectly with C++ pipeline (validated "bon").

- Added capture mode: touch /data/local/tmp/kazeia/capture_mode
- Embeds saved to captured_embeds.bin (same format as pipeline input)
- KV_LEN restored to 100 (KV=64 lost role tokens → quality loss)
- C++ uses pre-computed embeds as-is (no double codec_sum)

Production path: Java pipeline RTF 1.8 for new texts (good quality)
Replay path: C++ pipeline RTF 1.26 with captured embeds

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 23:00:37 +02:00
parent 3dcf73aa38
commit 09d36f2025
1 changed files with 54 additions and 4 deletions

View File

@ -1067,10 +1067,13 @@ class Qwen3TtsEngine(
nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
// Capture mode: save all talker inputs for reuse with C++ pipeline
val capturedEmbeds = mutableListOf<FloatArray>()
// ===== PREFILL =====
val tPrefill = System.currentTimeMillis()
for (step in prefill.indices) {
// Unmask position
capturedEmbeds.add(prefill[step].clone()) // capture prefill input
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
if (maskIdx >= 0) maskData[maskIdx] = 0f
@ -1139,6 +1142,8 @@ class Qwen3TtsEngine(
else -> sumEmb(codecSum, padE)
}
capturedEmbeds.add(nextEmbed.clone()) // capture decode input
// 3. Talker step
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
if (maskIdx >= 0) maskData[maskIdx] = 0f
@ -1192,6 +1197,26 @@ class Qwen3TtsEngine(
val n = allCodes.size
nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
// Save captured embeds for C++ pipeline reuse
if (capturedEmbeds.isNotEmpty()) {
try {
val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
val nPrefill = prefill.size
val fos = java.io.FileOutputStream(capPath)
val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
hdr.putInt(nPrefill); hdr.putInt(capturedEmbeds.size)
fos.write(hdr.array())
for (emb in capturedEmbeds) {
val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
for (v in emb) buf.putFloat(v)
fos.write(buf.array())
}
fos.close()
nlog("Captured ${capturedEmbeds.size} embeds → $capPath")
} catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
}
return allCodes.toTypedArray()
}
@ -2299,8 +2324,10 @@ class Qwen3TtsEngine(
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
val allCodes: Array<IntArray>
// Check if capture mode requested (force Java path to capture embeds)
val captureMode = File("/data/local/tmp/kazeia/capture_mode").exists()
// Native C++ pipeline using SAME Java Module instances (no quality loss)
if (talkerPteModule != null && cpPteModule != null) {
if (!captureMode && talkerPteModule != null && cpPteModule != null) {
// C++ loop on Java's Module instances — same QNN compilation, no JNI overhead
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
@ -2449,9 +2476,13 @@ class Qwen3TtsEngine(
return Pair(hidden, logits)
}
// Capture embeds for C++ reuse
val capturedEmbeds = mutableListOf<FloatArray>()
// ===== PREFILL =====
val tPrefill = System.currentTimeMillis()
for (step in prefillEmbeds.indices) {
capturedEmbeds.add(prefillEmbeds[step].clone())
val (h, logits) = talkerStep(prefillEmbeds[step])
pastHidden = h
if (step == prefillEmbeds.size - 1) {
@ -2480,12 +2511,12 @@ class Qwen3TtsEngine(
if (trailingIdx < trailingEmbeds.size) {
nextEmbed = trailingEmbeds[trailingIdx]; trailingIdx++
} else {
// Build from codec embeddings + trailing text (pad)
val codecSum = FloatArray(TALKER_DIM)
addEmb(codecSum, codecEmb(codes[0]))
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
nextEmbed = sumEmb(codecSum, padE)
}
capturedEmbeds.add(nextEmbed.clone())
val tTalker0 = System.currentTimeMillis()
val (h, logits) = talkerStep(nextEmbed)
@ -2506,7 +2537,26 @@ class Qwen3TtsEngine(
val n = allCodes.size
nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
nlog("CB0 Java: [${generatedCb0.joinToString(",")}]")
// Save captured embeds
if (capturedEmbeds.isNotEmpty()) {
try {
val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
val nPrefill = prefillEmbeds.size
val fos = java.io.FileOutputStream(capPath)
val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
hdr.putInt(nPrefill); hdr.putInt(capturedEmbeds.size)
fos.write(hdr.array())
for (emb in capturedEmbeds) {
val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
for (v in emb) buf.putFloat(v)
fos.write(buf.array())
}
fos.close()
nlog("Captured ${capturedEmbeds.size} embeds → $capPath (${nPrefill} prefill + ${capturedEmbeds.size - nPrefill} decode)")
} catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
}
return allCodes.toTypedArray()
}