Root cause found + on-device embed capture + KV=100 restored
Root cause: embeds must come from SAME NPU model instance. Python fp32 embeds cause divergence on NPU fp16 after ~20 steps. Solution: Java pipeline captures embeds on-device during generation. Captured embeds work perfectly with C++ pipeline (validated "bon"). - Added capture mode: touch /data/local/tmp/kazeia/capture_mode - Embeds saved to captured_embeds.bin (same format as pipeline input) - KV_LEN restored to 100 (KV=64 lost role tokens → quality loss) - C++ uses pre-computed embeds as-is (no double codec_sum) Production path: Java pipeline RTF 1.8 for new texts (good quality) Replay path: C++ pipeline RTF 1.26 with captured embeds Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
3dcf73aa38
commit
09d36f2025
|
|
@ -1067,10 +1067,13 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
|
nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
|
||||||
|
|
||||||
|
// Capture mode: save all talker inputs for reuse with C++ pipeline
|
||||||
|
val capturedEmbeds = mutableListOf<FloatArray>()
|
||||||
|
|
||||||
// ===== PREFILL =====
|
// ===== PREFILL =====
|
||||||
val tPrefill = System.currentTimeMillis()
|
val tPrefill = System.currentTimeMillis()
|
||||||
for (step in prefill.indices) {
|
for (step in prefill.indices) {
|
||||||
// Unmask position
|
capturedEmbeds.add(prefill[step].clone()) // capture prefill input
|
||||||
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
||||||
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
||||||
|
|
||||||
|
|
@ -1139,6 +1142,8 @@ class Qwen3TtsEngine(
|
||||||
else -> sumEmb(codecSum, padE)
|
else -> sumEmb(codecSum, padE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
capturedEmbeds.add(nextEmbed.clone()) // capture decode input
|
||||||
|
|
||||||
// 3. Talker step
|
// 3. Talker step
|
||||||
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
||||||
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
||||||
|
|
@ -1192,6 +1197,26 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
val n = allCodes.size
|
val n = allCodes.size
|
||||||
nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
|
nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
|
||||||
|
|
||||||
|
// Save captured embeds for C++ pipeline reuse
|
||||||
|
if (capturedEmbeds.isNotEmpty()) {
|
||||||
|
try {
|
||||||
|
val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
|
||||||
|
val nPrefill = prefill.size
|
||||||
|
val fos = java.io.FileOutputStream(capPath)
|
||||||
|
val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||||
|
hdr.putInt(nPrefill); hdr.putInt(capturedEmbeds.size)
|
||||||
|
fos.write(hdr.array())
|
||||||
|
for (emb in capturedEmbeds) {
|
||||||
|
val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||||
|
for (v in emb) buf.putFloat(v)
|
||||||
|
fos.write(buf.array())
|
||||||
|
}
|
||||||
|
fos.close()
|
||||||
|
nlog("Captured ${capturedEmbeds.size} embeds → $capPath")
|
||||||
|
} catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
|
||||||
|
}
|
||||||
|
|
||||||
return allCodes.toTypedArray()
|
return allCodes.toTypedArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2299,8 +2324,10 @@ class Qwen3TtsEngine(
|
||||||
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
|
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
|
||||||
|
|
||||||
val allCodes: Array<IntArray>
|
val allCodes: Array<IntArray>
|
||||||
|
// Check if capture mode requested (force Java path to capture embeds)
|
||||||
|
val captureMode = File("/data/local/tmp/kazeia/capture_mode").exists()
|
||||||
// Native C++ pipeline using SAME Java Module instances (no quality loss)
|
// Native C++ pipeline using SAME Java Module instances (no quality loss)
|
||||||
if (talkerPteModule != null && cpPteModule != null) {
|
if (!captureMode && talkerPteModule != null && cpPteModule != null) {
|
||||||
// C++ loop on Java's Module instances — same QNN compilation, no JNI overhead
|
// C++ loop on Java's Module instances — same QNN compilation, no JNI overhead
|
||||||
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
|
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
|
||||||
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
|
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
|
||||||
|
|
@ -2449,9 +2476,13 @@ class Qwen3TtsEngine(
|
||||||
return Pair(hidden, logits)
|
return Pair(hidden, logits)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture embeds for C++ reuse
|
||||||
|
val capturedEmbeds = mutableListOf<FloatArray>()
|
||||||
|
|
||||||
// ===== PREFILL =====
|
// ===== PREFILL =====
|
||||||
val tPrefill = System.currentTimeMillis()
|
val tPrefill = System.currentTimeMillis()
|
||||||
for (step in prefillEmbeds.indices) {
|
for (step in prefillEmbeds.indices) {
|
||||||
|
capturedEmbeds.add(prefillEmbeds[step].clone())
|
||||||
val (h, logits) = talkerStep(prefillEmbeds[step])
|
val (h, logits) = talkerStep(prefillEmbeds[step])
|
||||||
pastHidden = h
|
pastHidden = h
|
||||||
if (step == prefillEmbeds.size - 1) {
|
if (step == prefillEmbeds.size - 1) {
|
||||||
|
|
@ -2480,12 +2511,12 @@ class Qwen3TtsEngine(
|
||||||
if (trailingIdx < trailingEmbeds.size) {
|
if (trailingIdx < trailingEmbeds.size) {
|
||||||
nextEmbed = trailingEmbeds[trailingIdx]; trailingIdx++
|
nextEmbed = trailingEmbeds[trailingIdx]; trailingIdx++
|
||||||
} else {
|
} else {
|
||||||
// Build from codec embeddings + trailing text (pad)
|
|
||||||
val codecSum = FloatArray(TALKER_DIM)
|
val codecSum = FloatArray(TALKER_DIM)
|
||||||
addEmb(codecSum, codecEmb(codes[0]))
|
addEmb(codecSum, codecEmb(codes[0]))
|
||||||
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
||||||
nextEmbed = sumEmb(codecSum, padE)
|
nextEmbed = sumEmb(codecSum, padE)
|
||||||
}
|
}
|
||||||
|
capturedEmbeds.add(nextEmbed.clone())
|
||||||
|
|
||||||
val tTalker0 = System.currentTimeMillis()
|
val tTalker0 = System.currentTimeMillis()
|
||||||
val (h, logits) = talkerStep(nextEmbed)
|
val (h, logits) = talkerStep(nextEmbed)
|
||||||
|
|
@ -2506,7 +2537,26 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
val n = allCodes.size
|
val n = allCodes.size
|
||||||
nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
|
nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
|
||||||
nlog("CB0 Java: [${generatedCb0.joinToString(",")}]")
|
|
||||||
|
// Save captured embeds
|
||||||
|
if (capturedEmbeds.isNotEmpty()) {
|
||||||
|
try {
|
||||||
|
val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
|
||||||
|
val nPrefill = prefillEmbeds.size
|
||||||
|
val fos = java.io.FileOutputStream(capPath)
|
||||||
|
val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||||
|
hdr.putInt(nPrefill); hdr.putInt(capturedEmbeds.size)
|
||||||
|
fos.write(hdr.array())
|
||||||
|
for (emb in capturedEmbeds) {
|
||||||
|
val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||||
|
for (v in emb) buf.putFloat(v)
|
||||||
|
fos.write(buf.array())
|
||||||
|
}
|
||||||
|
fos.close()
|
||||||
|
nlog("Captured ${capturedEmbeds.size} embeds → $capPath (${nPrefill} prefill + ${capturedEmbeds.size - nPrefill} decode)")
|
||||||
|
} catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
|
||||||
|
}
|
||||||
|
|
||||||
return allCodes.toTypedArray()
|
return allCodes.toTypedArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue