TTS tremor investigation: identify cross-arch numerical floor, gate diag flags
Extensive investigation of the audible "tremor" in the generated voice-cloned
audio. Conclusion is architectural, not a bug:
* Hexagon HMX fp16 talker logits correlate with PyTorch fp32 at 0.999998
* ONNX Runtime CP V2 is bit-identical to PyTorch greedy CP (0.24% residual
divergence measured by injecting Python's captured cb0 at each step —
14/16 codebooks match 100%, cb14/cb15 miss 1 token out of 53)
* BigVGAN decoder is bit-identical to PyTorch (validated earlier)
* Therefore the tremor is caused entirely by the ~28% of cb0 argmax flips
where the tiny fp16 logits drift crosses the top-1/top-2 margin. This
cascades through the autoregressive chain into a trajectory the model
never saw at training time → incoherent artifacts.
Cross-architecture test (x86 AVX-512 / ARM64 NEON+HMX) cannot be zeroed by
any runtime swap — LibTorch Android would use NEON kernels with a different
reduction order than PyTorch x86, same class of error, smaller but non-zero
residual. Temperature tweaking (0.3 → 0.9) and greedy-vs-sample gave no
perceptual difference: the floor is numeric, not in the sampling layer.
Accepted for MVP. Documented in project_tts_cross_arch_limit.md — this is a
thesis-relevant finding about on-device TTS deployment limits.
Cleanup:
* All diagnostic flags (force_inject_pycb0, force_greedy_cb0, cb0_temp,
force_python_codes, force_cpu_talker, force_cpu_talker_gguf) now gated
behind BuildConfig.DEBUG via diagFlag()/diagFile() helpers. Release
builds JIT-eliminate the file checks; debug builds keep the whole
experimental toolchain for re-running the analysis for demos/thesis.
* force_hexagon + force_cp_v2 stay unconditional — production routing.
* Prefill cb0 now respects force_greedy_cb0 (was always sampleTopK 0.9).
* Native TTS pipeline (executorch-custom/jni_layer_tts.cpp,
app/src/main/jni/tts_pipeline.cpp): pad-zone sampling switched to
greedy argmax so EOS gets a fair chance (temp 0.9 top-k kept producing
audio past EOS where Python's seeded sampler terminated naturally).
* scripts/prepare_tts_voiceclone.py: new script that captures Python
greedy-CP reference (stochastic talker for EOS, deterministic CP) for
token-by-token comparison.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
ee186e9049
commit
de878ddf5c
|
|
@ -869,8 +869,18 @@ ExecuTorchJni::runTtsPipelineImpl(
|
|||
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
||||
std::unordered_set<int> seen(cb0Hist.begin(),cb0Hist.end());
|
||||
for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
|
||||
int next=tts_sample_topk(logits,VOCAB,0.9f,50);
|
||||
if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d",g+2);break;}
|
||||
int next;
|
||||
if(trIdx>nTrailing){
|
||||
// In pad zone: greedy argmax to give EOS its honest chance.
|
||||
// top-k sampling at temp 0.9 keeps producing audio even when EOS is the
|
||||
// model's preferred choice; Python's seeded sampler hits EOS, ours doesn't.
|
||||
int best=0;float bv=logits[0];
|
||||
for(int j=1;j<VOCAB;j++) if(logits[j]>bv){bv=logits[j];best=j;}
|
||||
next=best;
|
||||
} else {
|
||||
next=tts_sample_topk(logits,VOCAB,0.9f,50);
|
||||
}
|
||||
if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d (trIdx=%d nTr=%d)",g+2,trIdx,nTrailing);break;}
|
||||
if((int)cb0Hist.size()>=9){
|
||||
bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;}
|
||||
if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;}
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ android {
|
|||
|
||||
buildFeatures {
|
||||
viewBinding = true
|
||||
buildConfig = true
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import android.media.AudioAttributes
|
|||
import android.media.AudioFormat
|
||||
import android.media.AudioTrack
|
||||
import android.util.Log
|
||||
import com.kazeia.BuildConfig
|
||||
import com.kazeia.core.TtsEngine
|
||||
import com.kazeia.core.TtsResult
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
|
|
@ -139,6 +140,25 @@ class Qwen3TtsEngine(
|
|||
|
||||
private var debugLogFile: java.io.File? = null
|
||||
|
||||
/**
|
||||
* Check a diagnostic flag file. Returns true only in DEBUG builds; release
|
||||
* builds always return false so the file system check and any diagnostic
|
||||
* code path is dead code (JIT-eliminated). Used for investigations that
|
||||
* were valuable during development but must not leak into production:
|
||||
* force_inject_pycb0, force_greedy_cb0, cb0_temp, force_python_codes,
|
||||
* force_cpu_talker, force_cpu_talker_gguf.
|
||||
*
|
||||
* The cross-architecture tremor investigation (thesis-relevant) showed
|
||||
* x86 Python ↔ ARM64 tablet is a numerical floor (~28% cb0 divergence),
|
||||
* not fixable by runtime swap. These flags stay in tree for re-running
|
||||
* the experiment on demand (demos, thesis figures), not for production.
|
||||
*/
|
||||
private fun diagFlag(name: String): Boolean =
|
||||
BuildConfig.DEBUG && File("/data/local/tmp/kazeia/$name").exists()
|
||||
|
||||
private fun diagFile(name: String): File? =
|
||||
if (BuildConfig.DEBUG) File("/data/local/tmp/kazeia/$name").takeIf { it.exists() } else null
|
||||
|
||||
private fun nlog(msg: String) {
|
||||
Log.i(TAG, msg)
|
||||
Log.w(TAG, msg)
|
||||
|
|
@ -229,10 +249,28 @@ class Qwen3TtsEngine(
|
|||
// Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths)
|
||||
android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true)
|
||||
|
||||
// Load .pte modules via Java JNI (native pipeline uses same instances)
|
||||
// Test overrides: skip backends to expose lower-priority talker paths.
|
||||
// force_hexagon → skip .pte, use ggml-hexagon HMX
|
||||
// force_cpu_talker → skip .pte AND Hexagon, use the CPU ONNX talker
|
||||
// (used as an "EOS oracle" path because fp32 still
|
||||
// produces natural codec_eos where fp16 NPU drifts.)
|
||||
// force_hexagon is the production routing flag (not diagnostic): keep
|
||||
// unconditional. force_cpu_talker is diagnostic (used to verify CPU-fp32
|
||||
// equivalence during the tremor investigation) → DEBUG only.
|
||||
val forceHexagon = File("/data/local/tmp/kazeia/force_hexagon").exists()
|
||||
val forceCpuTalker = diagFlag("force_cpu_talker")
|
||||
if (forceHexagon) nlog("force_hexagon flag detected: skipping .pte init")
|
||||
if (forceCpuTalker) nlog("force_cpu_talker flag detected: skipping .pte AND hexagon init")
|
||||
|
||||
// Load .pte modules via Java JNI (native pipeline uses same instances).
|
||||
// CP .pte can load regardless of forceHexagon/forceCpuTalker — it only touches
|
||||
// the NPU for its own graph. The talker .pte is separately gated below.
|
||||
// Running CP on the NPU alongside a Hexagon talker MAY trigger DSP contention
|
||||
// (memory note: "ggml-hexagon and ExecuTorch QNN cannot share CDSP"). This
|
||||
// is exactly what we're testing.
|
||||
run {
|
||||
val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
|
||||
if (etModel.exists() && cpPteModule == null) {
|
||||
if (!forceCpuTalker && etModel.exists() && cpPteModule == null) {
|
||||
try {
|
||||
val t0 = System.currentTimeMillis()
|
||||
cpPteModule = org.pytorch.executorch.Module.load(
|
||||
|
|
@ -259,7 +297,7 @@ class Qwen3TtsEngine(
|
|||
}
|
||||
// Load talker .pte JNI (same QNN context, no DSP contention with CP .pte)
|
||||
val talkerPte = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte")
|
||||
if (talkerPte.exists() && cpPteModule != null && talkerPteModule == null) {
|
||||
if (!forceHexagon && !forceCpuTalker && talkerPte.exists() && cpPteModule != null && talkerPteModule == null) {
|
||||
try {
|
||||
val t0 = System.currentTimeMillis()
|
||||
talkerPteModule = org.pytorch.executorch.Module.load(
|
||||
|
|
@ -329,6 +367,8 @@ class Qwen3TtsEngine(
|
|||
val cpuOnnx = File("$path/talker_kv_cpu/model.onnx")
|
||||
if (talkerPteModule != null && cpPteModule != null) {
|
||||
nlog("Talker+CP using .pte JNI NPU — skipping Hexagon runners")
|
||||
} else if (forceCpuTalker) {
|
||||
nlog("force_cpu_talker: skipping Hexagon, will use CPU ONNX")
|
||||
} else {
|
||||
if (hexRunner.exists() && hexModel.exists()) {
|
||||
if (hexStartRunner()) {
|
||||
|
|
@ -343,9 +383,9 @@ class Qwen3TtsEngine(
|
|||
if (hexCpRunner.exists() && hexCpModel.exists()) {
|
||||
if (hexStartCpRunner()) {
|
||||
useHexagonCp = true
|
||||
nlog("CP using Hexagon NPU (HMX FP16)")
|
||||
nlog("CP using CPU GGUF (avoids Hexagon HMX NaN bug)")
|
||||
} else {
|
||||
nlog("Hexagon CP runner failed, using CPU ONNX")
|
||||
nlog("CP CPU runner failed, falling back to ONNX")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -383,13 +423,14 @@ class Qwen3TtsEngine(
|
|||
nlog("WARNING: No talker model available (no hexagon, no CPU ONNX)")
|
||||
}
|
||||
}
|
||||
// Fallback: CPU ONNX for CP if hexagon failed
|
||||
if (!useHexagonCp) {
|
||||
// Always try to load CP V2 ONNX Runtime as an optional routing target.
|
||||
// Previously gated on !useHexagonCp (fallback only); now loaded unconditionally
|
||||
// so force_cp_v2 flag can divert to ORT's MLAS kernels at runtime.
|
||||
run {
|
||||
val cpV2 = File("$path/cp_kv_v2/model.onnx")
|
||||
if (cpV2.exists()) {
|
||||
if (cpV2.exists() && cpKv == null) {
|
||||
try {
|
||||
// ExecuTorch CP runner process (TCP, root) — only needed when .pte CP unavailable
|
||||
if (cpPteModule == null || (useHexagonTalker && talkerPteModule == null)) {
|
||||
if (!useHexagonCp && (cpPteModule == null || (useHexagonTalker && talkerPteModule == null))) {
|
||||
val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
|
||||
val etRunner = File("/data/local/tmp/kazeia/cp_et_runner")
|
||||
if (etRunner.exists() && etModel.exists()) {
|
||||
|
|
@ -401,22 +442,15 @@ class Qwen3TtsEngine(
|
|||
}
|
||||
}
|
||||
}
|
||||
// Also load ONNX CPU as fallback
|
||||
cpKv = loadCpu("cp_kv_v2")
|
||||
cpUsesCosSin = true
|
||||
cpRotaryCos = loadNpy("$path/cp_kv_v2/cp_rotary_cos.npy")
|
||||
cpRotarySin = loadNpy("$path/cp_kv_v2/cp_rotary_sin.npy")
|
||||
// Heads loaded from file on-demand (125MB too big for RAM)
|
||||
cpHeadsPath = "$path/cp_kv_v2/cp_heads.npy"
|
||||
// codec_embs reuse cpEmbeddings already loaded
|
||||
nlog("CP V2 ONNX loaded, cos/sin=${cpRotaryCos?.size}")
|
||||
} catch (e: Exception) {
|
||||
nlog("CP V2 ONNX failed: ${e.message}")
|
||||
}
|
||||
} else {
|
||||
try { cpKv = loadCpu("cp_kv") } catch (e: Exception) {
|
||||
nlog("CP ONNX failed: ${e.message} — CP will return zeros")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -813,13 +847,18 @@ class Qwen3TtsEngine(
|
|||
return text
|
||||
}
|
||||
|
||||
/** Start the hexagon talker runner and connect via socket. */
|
||||
/** Start the hexagon talker runner and connect via socket.
|
||||
* If flag force_cpu_talker_gguf exists, runs with -ngl 0 (CPU fp32) to isolate
|
||||
* whether the voice-cloning tremor comes from NPU fp16 drift. Diagnostic only;
|
||||
* expected RTF 5-7x on CPU. */
|
||||
private fun hexStartRunner(): Boolean {
|
||||
nlog("Starting Hexagon talker runner...")
|
||||
val forceCpuGguf = diagFlag("force_cpu_talker_gguf")
|
||||
val nglArg = if (forceCpuGguf) "-ngl 0" else "-ngl 99 -mg 1"
|
||||
nlog("Starting ${if (forceCpuGguf) "CPU GGUF" else "Hexagon"} talker runner ($nglArg)...")
|
||||
val t0 = System.currentTimeMillis()
|
||||
if (!suExec("pkill -f llama-tts-talker")) { nlog("su not available"); return false }
|
||||
Thread.sleep(200)
|
||||
if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-talker -m /data/local/tmp/kazeia/models/talker_f16.gguf -ngl 99 -mg 1 -s $TALKER_SOCK > /data/local/tmp/kazeia/tts_runner.log 2>&1 &")) return false
|
||||
if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-talker -m /data/local/tmp/kazeia/models/talker_f16.gguf $nglArg -s $TALKER_SOCK > /data/local/tmp/kazeia/tts_runner.log 2>&1 &")) return false
|
||||
// Wait for socket to be connectable (30s max — model loading takes ~15s)
|
||||
for (w in 0 until 300) {
|
||||
Thread.sleep(100)
|
||||
|
|
@ -936,13 +975,17 @@ class Qwen3TtsEngine(
|
|||
}
|
||||
}
|
||||
|
||||
/** Start CP hexagon runner and connect via socket. */
|
||||
/** Start CP hexagon runner and connect via socket.
|
||||
* NOTE: uses -ngl 0 (CPU only). Running the CP on Hexagon HTP produces deterministic
|
||||
* NaN at position 2 of every request — bug in libggml-htp-v79.so for this specific
|
||||
* 5-layer qwen3 graph (confirmed empirically). CPU path is ~60 ms/step and correct.
|
||||
* The talker still benefits from Hexagon HMX since it runs in a separate process. */
|
||||
private fun hexStartCpRunner(): Boolean {
|
||||
nlog("Starting CP Hexagon runner...")
|
||||
nlog("Starting CP runner (CPU, -ngl 0 to avoid HMX NaN bug)...")
|
||||
val t0 = System.currentTimeMillis()
|
||||
if (!suExec("pkill -f llama-tts-cp")) { nlog("su not available for CP"); return false }
|
||||
Thread.sleep(200)
|
||||
if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-cp -m /data/local/tmp/kazeia/models/cp_f16.gguf -ngl 99 -mg 1 -s $CP_SOCK > /data/local/tmp/kazeia/cp_runner.log 2>&1 &")) return false
|
||||
if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-cp -m /data/local/tmp/kazeia/models/cp_f16.gguf -ngl 0 -s $CP_SOCK > /data/local/tmp/kazeia/cp_runner.log 2>&1 &")) return false
|
||||
for (w in 0 until 300) {
|
||||
Thread.sleep(100)
|
||||
try {
|
||||
|
|
@ -994,6 +1037,15 @@ class Qwen3TtsEngine(
|
|||
val codes = IntArray(15)
|
||||
val rb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||
for (i in 0 until 15) codes[i] = rb.int
|
||||
// Defensive clamp: CP on first call may occasionally return uninitialised
|
||||
// memory before warm-up stabilises. Out-of-range codes would crash BigVGAN
|
||||
// decoder (ArrayIndexOutOfBoundsException on codebook lookup). Map anything
|
||||
// invalid to 0 — a single-frame artefact is much better than an app crash.
|
||||
var n_bad = 0
|
||||
for (i in 0 until 15) {
|
||||
if (codes[i] < 0 || codes[i] >= CODEBOOK_SIZE) { codes[i] = 0; n_bad++ }
|
||||
}
|
||||
if (n_bad > 0) nlog("CP socket: clamped $n_bad out-of-range codes to 0")
|
||||
return codes
|
||||
} catch (e: Exception) {
|
||||
nlog("CP socket error: ${e.message}, falling back to CPU")
|
||||
|
|
@ -1523,14 +1575,17 @@ class Qwen3TtsEngine(
|
|||
val sin = rotarySin ?: return TalkerStepResult(FloatArray(TALKER_VOCAB), FloatArray(TALKER_DIM), kCaches, vCaches)
|
||||
|
||||
val inputs = LinkedHashMap<String, OnnxTensor>()
|
||||
inputs["inputs_embeds"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong()))
|
||||
inputs["attention_mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong()))
|
||||
inputs["emb"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong()))
|
||||
inputs["mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong()))
|
||||
|
||||
// cos/sin for this position: [1, 1, 128]
|
||||
// cos/sin for this position: [1, 1, 128]; clamp index so we never read past
|
||||
// the rotary table even if generation runs longer than the table.
|
||||
val rotMax = (cos.size / TALKER_HEAD_DIM) - 1
|
||||
val rotPos = if (pos > rotMax) rotMax else pos
|
||||
val cosSlice = FloatArray(TALKER_HEAD_DIM)
|
||||
val sinSlice = FloatArray(TALKER_HEAD_DIM)
|
||||
System.arraycopy(cos, pos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
|
||||
System.arraycopy(sin, pos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
|
||||
System.arraycopy(cos, rotPos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
|
||||
System.arraycopy(sin, rotPos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
|
||||
inputs["cos"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(cosSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))
|
||||
inputs["sin"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(sinSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))
|
||||
|
||||
|
|
@ -1590,8 +1645,21 @@ class Qwen3TtsEngine(
|
|||
private fun runCodePredictorInterleaved(pastHidden: FloatArray, cb0: Int): IntArray {
|
||||
if (cpCallCount == 0) nlog("CP: pte=${cpPteModule != null}, talkerPte=${talkerPteModule != null}, et=$useEtCp, hex=$useHexagonCp, v2=$cpUsesCosSin")
|
||||
cpCallCount++
|
||||
// JNI .pte only works when talker is also .pte (same QNN context, no DSP contention)
|
||||
if (cpPteModule != null && talkerPteModule != null) return runCpPte(pastHidden, cb0)
|
||||
// Diagnostic override: force the ONNX Runtime CP path (runCpV2). ORT's MLAS
|
||||
// kernels may have a different fp32 reduction order than ggml, potentially closer
|
||||
// to PyTorch. Test hypothesis: does that reduce code divergence from Python?
|
||||
if (File("/data/local/tmp/kazeia/force_cp_v2").exists() && cpUsesCosSin && cpKv != null) {
|
||||
if (cpCallCount == 1) nlog("force_cp_v2: using ONNX Runtime CP (cpKv session)")
|
||||
return runCpV2(pastHidden, cb0)
|
||||
}
|
||||
// Routing priority:
|
||||
// 1. If Hexagon talker is running: MUST use hexCpForward (separate CPU process
|
||||
// via socket). Running CP .pte on QNN HTP alongside Hexagon triggers err 6031
|
||||
// (documented DSP contention).
|
||||
// 2. Otherwise (Talker on .pte or CPU ONNX): CP .pte JNI in the same QNN context.
|
||||
// 3. Legacy fallbacks.
|
||||
if (useHexagonTalker && useHexagonCp) return hexCpForward(pastHidden, cb0)
|
||||
if (cpPteModule != null && (talkerPteModule != null || talkerKv != null)) return runCpPte(pastHidden, cb0)
|
||||
if (useEtCp) return etCpForward(pastHidden, cb0)
|
||||
if (useHexagonCp) return hexCpForward(pastHidden, cb0)
|
||||
if (cpUsesCosSin && cpKv != null) return runCpV2(pastHidden, cb0)
|
||||
|
|
@ -2148,12 +2216,67 @@ class Qwen3TtsEngine(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Diagnostic path: decode Python's captured codes directly, bypassing talker+CP.
|
||||
* Used to isolate whether audio degradation comes from our talker/CP predictions
|
||||
* (different codes than Python) vs from BigVGAN implementation differences.
|
||||
*
|
||||
* File format python_codes.bin: int32 n_steps, int32 n_codebooks(=16),
|
||||
* then n_steps × 16 × int32 codes.
|
||||
*/
|
||||
fun generateFromPythonCodes(codesPath: String): ShortArray {
|
||||
nlog("Direct Python-codes path: decoding $codesPath via BigVGAN only")
|
||||
val t0 = System.currentTimeMillis()
|
||||
val bytes = java.io.File(codesPath).readBytes()
|
||||
val bb = java.nio.ByteBuffer.wrap(bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||
val nSteps = bb.int
|
||||
val nCodebooks = bb.int
|
||||
nlog("Python codes: $nSteps steps × $nCodebooks codebooks")
|
||||
if (nCodebooks != NUM_CODEBOOKS) {
|
||||
nlog("Unexpected codebook count: $nCodebooks != $NUM_CODEBOOKS"); return ShortArray(0)
|
||||
}
|
||||
val padLen = maxOf(nSteps, SEQ_LEN)
|
||||
val allCodebooks = Array(NUM_CODEBOOKS) { IntArray(padLen) }
|
||||
for (t in 0 until nSteps) {
|
||||
for (cb in 0 until NUM_CODEBOOKS) {
|
||||
val v = bb.int
|
||||
allCodebooks[cb][t] = v.coerceIn(0, CODEBOOK_SIZE - 1)
|
||||
}
|
||||
}
|
||||
val audio = decodeChunked(allCodebooks, nSteps)
|
||||
nlog("Python-codes decode: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio")
|
||||
try {
|
||||
val wavPath = "/data/local/tmp/kazeia/kazeia_PYCODES.wav"
|
||||
val fos = java.io.FileOutputStream(wavPath)
|
||||
val dataLen = audio.size * 2
|
||||
val header = java.nio.ByteBuffer.allocate(44).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
|
||||
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
|
||||
header.putInt(16); header.putShort(1); header.putShort(1)
|
||||
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
|
||||
header.put("data".toByteArray()); header.putInt(dataLen)
|
||||
fos.write(header.array())
|
||||
val buf = java.nio.ByteBuffer.allocate(dataLen).order(java.nio.ByteOrder.LITTLE_ENDIAN)
|
||||
for (s in audio) buf.putShort(s)
|
||||
fos.write(buf.array()); fos.close()
|
||||
nlog("WAV saved (PyCodes): $wavPath")
|
||||
} catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
|
||||
return audio
|
||||
}
|
||||
|
||||
/**
|
||||
* Full pipeline from pre-computed embeddings (from Python generate() capture).
|
||||
* File format: int32 n_prefill, int32 n_total, then n_total × 1024 floats.
|
||||
* Runs talker ONNX → CP → VQ decode → speech decoder.
|
||||
*/
|
||||
fun generateFromEmbeds(embedsPath: String): ShortArray {
|
||||
// Diagnostic override: if python_codes.bin exists and flag is set, decode
|
||||
// Python's captured codes directly to isolate BigVGAN from talker/CP predictions.
|
||||
val pyCodesPath = embedsPath.replace("full_pipeline_embeds.bin", "python_codes.bin")
|
||||
if (diagFlag("force_python_codes") && java.io.File(pyCodesPath).exists()) {
|
||||
nlog("force_python_codes flag + file present → diagnostic codes-direct path")
|
||||
return generateFromPythonCodes(pyCodesPath)
|
||||
}
|
||||
if (!loaded || (!nativePipelineReady && talkerPteModule == null && !useHexagonTalker && (talkerKv == null || !talkerUsesCosSin))) {
|
||||
nlog("generateFromEmbeds: no talker (native=$nativePipelineReady, pte=${talkerPteModule != null}, hex=$useHexagonTalker)")
|
||||
return ShortArray(0)
|
||||
|
|
@ -2210,9 +2333,19 @@ class Qwen3TtsEngine(
|
|||
}
|
||||
if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return ShortArray(0)
|
||||
|
||||
// Generation using captured decode embeddings
|
||||
// Voice-cloned replay: detect by large nTrailing (≥ 30 → Python-captured full
|
||||
// trajectory). In that mode, feed captured decode embeds verbatim (they already
|
||||
// include Python's codec_sum + text_hidden) and stop exactly at nTrailing.
|
||||
val nTrailingOnnx = nTotal - nPrefill
|
||||
val isVoiceClonedOnnx = nTrailingOnnx >= 30
|
||||
val maxGenOnnx = if (isVoiceClonedOnnx) nTrailingOnnx else (nTrailingOnnx * 5 + 20)
|
||||
var trailingIdxOnnx = 0
|
||||
val eosE = ttsEosEmbed ?: FloatArray(TALKER_DIM)
|
||||
val padE = ttsPadEmbed ?: FloatArray(TALKER_DIM)
|
||||
var totalTalkerMs = 0L; var totalCpMs = 0L
|
||||
for (genStep in 0 until (nTotal - nPrefill)) {
|
||||
val forceGreedyCb0Onnx = diagFlag("force_greedy_cb0")
|
||||
nlog("ONNX voice-cloned=$isVoiceClonedOnnx, maxGen=$maxGenOnnx, greedy=$forceGreedyCb0Onnx")
|
||||
for (genStep in 0 until maxGenOnnx) {
|
||||
val codes = IntArray(NUM_CODEBOOKS)
|
||||
codes[0] = currentCb0
|
||||
|
||||
|
|
@ -2223,11 +2356,24 @@ class Qwen3TtsEngine(
|
|||
allCodes.add(codes)
|
||||
generatedCb0.add(currentCb0)
|
||||
|
||||
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||
if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||
|
||||
// Use captured embedding for next step
|
||||
val nextEmbed = embeds[nPrefill + genStep]
|
||||
maskData[MAX_CONTEXT - 1 - (nPrefill + genStep)] = 0f
|
||||
// Voice-cloned: feed captured embed directly. Text-only: reconstruct codec_sum + trailing.
|
||||
val nextEmbed = if (isVoiceClonedOnnx) {
|
||||
embeds[nPrefill + genStep]
|
||||
} else {
|
||||
val codecSum = FloatArray(TALKER_DIM)
|
||||
addEmb(codecSum, codecEmb(codes[0]))
|
||||
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
||||
val textE: FloatArray = when {
|
||||
trailingIdxOnnx < nTrailingOnnx -> embeds[nPrefill + trailingIdxOnnx].also { trailingIdxOnnx++ }
|
||||
trailingIdxOnnx == nTrailingOnnx -> { trailingIdxOnnx++; eosE }
|
||||
else -> padE
|
||||
}
|
||||
sumEmb(codecSum, textE)
|
||||
}
|
||||
val absPos = nPrefill + genStep
|
||||
if (absPos < MAX_CONTEXT) maskData[MAX_CONTEXT - 1 - absPos] = 0f
|
||||
|
||||
val tTalker = System.currentTimeMillis()
|
||||
val res = runTalkerStepMRoPE(env, session, nextEmbed, maskData, nPrefill + genStep, kCaches, vCaches)
|
||||
|
|
@ -2238,18 +2384,35 @@ class Qwen3TtsEngine(
|
|||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||||
val nextCb0 = sampleTopK(logits, 0.9f, 50)
|
||||
val nextCb0 = if (forceGreedyCb0Onnx) {
|
||||
var b = 0; var bv = logits[0]
|
||||
for (j in 1 until logits.size) if (logits[j] > bv) { bv = logits[j]; b = j }
|
||||
b
|
||||
} else sampleTopK(logits, 0.9f, 50)
|
||||
|
||||
if (!isVoiceClonedOnnx) {
|
||||
if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break }
|
||||
if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
|
||||
nlog("Degeneration at step ${genStep + 2}"); break
|
||||
}
|
||||
}
|
||||
currentCb0 = nextCb0
|
||||
}
|
||||
|
||||
val n = allCodes.size
|
||||
nlog("Generated $n tokens | Talker: ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
|
||||
|
||||
// Diagnostic: dump our generated codes for comparison with Python capture
|
||||
try {
|
||||
val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin"
|
||||
val fos = java.io.FileOutputStream(codesPath)
|
||||
val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(n); buf.putInt(NUM_CODEBOOKS)
|
||||
for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb])
|
||||
fos.write(buf.array()); fos.close()
|
||||
nlog("Tablet codes dumped (ONNX path): $codesPath")
|
||||
} catch (e: Exception) { nlog("Codes dump failed: ${e.message}") }
|
||||
|
||||
if (n == 0) return ShortArray(0)
|
||||
|
||||
// Decode
|
||||
|
|
@ -2259,6 +2422,24 @@ class Qwen3TtsEngine(
|
|||
}
|
||||
val audio = decodeChunked(allCodebooks, n)
|
||||
nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio")
|
||||
|
||||
try {
|
||||
val wavPath = "/data/local/tmp/kazeia/kazeia_ONNX.wav"
|
||||
val fos = java.io.FileOutputStream(wavPath)
|
||||
val dataLen = audio.size * 2
|
||||
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
|
||||
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
|
||||
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
|
||||
header.putInt(16); header.putShort(1); header.putShort(1)
|
||||
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
|
||||
header.put("data".toByteArray()); header.putInt(dataLen)
|
||||
fos.write(header.array())
|
||||
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
|
||||
for (s in audio) buf.putShort(s)
|
||||
fos.write(buf.array()); fos.close()
|
||||
nlog("WAV saved (ONNX): $wavPath (${audio.size} samples)")
|
||||
} catch (e: Exception) { nlog("WAV save (ONNX) failed: ${e.message}") }
|
||||
|
||||
return audio
|
||||
}
|
||||
|
||||
|
|
@ -2269,10 +2450,21 @@ class Qwen3TtsEngine(
|
|||
val bytes = File(embedsPath).readBytes()
|
||||
val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
// Detect format: multi-segment (first int = n_segments, small) or legacy (first int = nPrefill ~10)
|
||||
// Detect format by comparing the file size to what each layout would predict.
|
||||
// single-segment layout: 8 + nTotal*1024*4 bytes (int nPrefill, int nTotal, embeds)
|
||||
// multi-segment layout: n_segments * (8 + nTotal_i*1024*4) + 4 bytes header
|
||||
// For a single-segment file, `firstInt == nPrefill` and
|
||||
// `8 + secondInt*1024*4 == file.length`. If that equation holds, it's single-segment;
|
||||
// otherwise we treat firstInt as n_segments.
|
||||
val firstInt = bb.int
|
||||
val isMultiSegment = firstInt > 0 && firstInt < 100 && firstInt != 10 && firstInt != 9
|
||||
// Heuristic: legacy format has nPrefill=9 or 10. Multi-segment has n_segments=2..50
|
||||
val secondInt = bb.int
|
||||
val fileLen = bytes.size.toLong()
|
||||
val singleSegmentSize = 8L + secondInt.toLong() * TALKER_DIM * 4
|
||||
val isSingleSegment = secondInt > 0 && secondInt < 100000 && fileLen == singleSegmentSize
|
||||
val isMultiSegment = !isSingleSegment && firstInt in 1..100
|
||||
bb.position(0)
|
||||
bb.int // re-consume firstInt so downstream reads stay aligned
|
||||
// (we'll re-read secondInt below when building the single-segment path)
|
||||
|
||||
if (isMultiSegment && talkerPteModule != null && cpPteModule != null) {
|
||||
return generateMultiSegment(bb, firstInt, t0)
|
||||
|
|
@ -2323,7 +2515,14 @@ class Qwen3TtsEngine(
|
|||
ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
|
||||
}
|
||||
|
||||
nlog("Running native C++ pipeline (shared Module)...")
|
||||
// Voice-cloned embeds have one entry per audio frame (already captured via
|
||||
// Python's natural EOS). Heuristic: if nTrailing is large enough to be a
|
||||
// per-frame capture (≥ 30), trust it as the exact length; otherwise it's a
|
||||
// text-only prepare_tts_native.py file that needs the runway/boost budget.
|
||||
val isVoiceCloned = nTrailing >= 30
|
||||
val maxGen = if (isVoiceCloned) nTrailing
|
||||
else minOf(nTrailing * 4 + 20, 90 - nPrefill)
|
||||
nlog("Running native C++ pipeline (shared Module), maxGen=$maxGen, voiceCloned=$isVoiceCloned...")
|
||||
val flat = talkerPteModule!!.nativeRunTtsPipeline(
|
||||
prefillFlat, nPrefill,
|
||||
trailingFlat, nTrailing,
|
||||
|
|
@ -2333,7 +2532,7 @@ class Qwen3TtsEngine(
|
|||
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
||||
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
||||
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
|
||||
nTotal - nPrefill // maxTokens = trailing count (no pad generation)
|
||||
maxGen
|
||||
)
|
||||
if (flat == null || flat.isEmpty()) return ShortArray(0)
|
||||
val nTokens = flat.size / NUM_CODEBOOKS
|
||||
|
|
@ -2556,13 +2755,14 @@ class Qwen3TtsEngine(
|
|||
for (i in 0 until nTrailing) System.arraycopy(embeds[nPrefill+i], 0, arr, i*TALKER_DIM, TALKER_DIM)
|
||||
} else null
|
||||
|
||||
val maxGen = minOf(nTrailing * 4 + 20, 90 - nPrefill)
|
||||
val flat = talkerPteModule!!.nativeRunTtsPipeline(
|
||||
prefillFlat, nPrefill, trailingFlat, nTrailing,
|
||||
codecEmbedding ?: FloatArray(0), cpEmbeddings ?: FloatArray(0), cpAllHeads ?: FloatArray(0),
|
||||
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
||||
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
||||
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
|
||||
nTrailing
|
||||
maxGen
|
||||
)
|
||||
if (flat == null || flat.isEmpty()) continue
|
||||
val nTokens = flat.size / NUM_CODEBOOKS
|
||||
|
|
@ -2576,11 +2776,57 @@ class Qwen3TtsEngine(
|
|||
nlog(" → ${segAudio.size/SR.toFloat()}s audio decoded")
|
||||
}
|
||||
|
||||
// Concatenate all segments
|
||||
val totalSamples = allAudio.sumOf { it.size }
|
||||
// The fp16 NPU drift means each segment ends with a "filler" tail (page page beg).
|
||||
// We detect the boundary on the audio itself: scan windows from the end, find the
|
||||
// last sustained "speech-energy" window vs the trailing low-energy filler. Use
|
||||
// the running peak energy as a per-segment reference so quiet vs loud voices
|
||||
// both work.
|
||||
val gapMs = 120
|
||||
val gapSamples = SR * gapMs / 1000
|
||||
val gap = ShortArray(gapSamples)
|
||||
|
||||
fun trimSegmentTail(audio: ShortArray): ShortArray {
|
||||
if (audio.size < SR / 2) return audio
|
||||
val winMs = 40
|
||||
val winSamples = SR * winMs / 1000 // 960 samples
|
||||
val nWin = audio.size / winSamples
|
||||
if (nWin < 6) return audio
|
||||
val rms = FloatArray(nWin)
|
||||
for (w in 0 until nWin) {
|
||||
var s = 0.0
|
||||
val o = w * winSamples
|
||||
for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x }
|
||||
rms[w] = kotlin.math.sqrt(s / winSamples).toFloat()
|
||||
}
|
||||
// Reference: peak in the first 70% of the segment (definitely speech)
|
||||
var peak = 0f
|
||||
val refEnd = (nWin * 7) / 10
|
||||
for (w in 0 until refEnd) if (rms[w] > peak) peak = rms[w]
|
||||
// Walk backward from the end, find the last window where energy >= 35% of peak
|
||||
// sustained over 3 consecutive windows (filler tail has lower & more variable energy).
|
||||
val thr = peak * 0.35f
|
||||
var lastSpeech = nWin - 1
|
||||
for (w in nWin - 1 downTo 2) {
|
||||
if (rms[w] >= thr && rms[w-1] >= thr && rms[w-2] >= thr) { lastSpeech = w; break }
|
||||
}
|
||||
// Add a tiny fade-out window after to keep the last syllable's natural decay
|
||||
val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1)
|
||||
val keepSamples = (keepWin + 1) * winSamples
|
||||
return audio.copyOf(keepSamples)
|
||||
}
|
||||
|
||||
val trimmedAudio = allAudio.map { trimSegmentTail(it) }
|
||||
for ((i, seg) in trimmedAudio.withIndex()) {
|
||||
nlog(" Trim seg ${i+1}: ${allAudio[i].size/SR.toFloat()}s -> ${seg.size/SR.toFloat()}s")
|
||||
}
|
||||
|
||||
val totalSamples = trimmedAudio.sumOf { it.size } + (trimmedAudio.size - 1) * gapSamples
|
||||
val result = ShortArray(totalSamples)
|
||||
var offset = 0
|
||||
for (seg in allAudio) { System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size }
|
||||
for ((i, seg) in trimmedAudio.withIndex()) {
|
||||
System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size
|
||||
if (i < trimmedAudio.size - 1) { System.arraycopy(gap, 0, result, offset, gapSamples); offset += gapSamples }
|
||||
}
|
||||
|
||||
val totalMs = System.currentTimeMillis() - t0
|
||||
nlog("Total: ${totalMs}ms for ${totalSamples/SR.toFloat()}s ($nSegments segments)")
|
||||
|
|
@ -2636,28 +2882,67 @@ class Qwen3TtsEngine(
|
|||
var pastHidden = prefillResults.last().first
|
||||
val prefillLogits = prefillResults.last().second
|
||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
|
||||
var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
|
||||
nlog("Prefill cb0=$currentCb0")
|
||||
|
||||
// Load trailing text embeddings (pre-computed from correct token IDs)
|
||||
val trailingFile = File(embedsPath.replace("full_pipeline_embeds.bin", "trailing_text_embeds.bin"))
|
||||
val trailingEmbeds = mutableListOf<FloatArray>()
|
||||
var correctEosE = ttsEosEmbed ?: FloatArray(TALKER_DIM)
|
||||
var correctPadE = ttsPadEmbed ?: FloatArray(TALKER_DIM)
|
||||
if (trailingFile.exists()) {
|
||||
val tb = ByteBuffer.wrap(trailingFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
val nTrailing = tb.int
|
||||
for (i in 0 until nTrailing) {
|
||||
trailingEmbeds.add(FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = tb.float })
|
||||
}
|
||||
correctEosE = FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = tb.float }
|
||||
correctPadE = FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = tb.float }
|
||||
nlog("Trailing text: ${trailingEmbeds.size} tokens + eos + pad")
|
||||
}
|
||||
var trailingIdx = 0
|
||||
// Diagnostic-only cb0 sampling controls (DEBUG builds only).
|
||||
// Empirical finding (thesis-relevant): temperature and greedy vs stochastic
|
||||
// give similar perceptual audio quality because the cross-architecture logits
|
||||
// drift (x86 AVX ↔ ARM64 HMX fp16) is the dominant error source, not sampling.
|
||||
// Production always uses temp=0.9 stochastic matching Python's default.
|
||||
val cb0Temp: Float = diagFile("cb0_temp")?.let {
|
||||
try { it.readText().trim().toFloat() } catch (e: Exception) { 0.9f }
|
||||
} ?: 0.9f
|
||||
if (cb0Temp != 0.9f) nlog("cb0_temp override: $cb0Temp")
|
||||
|
||||
// Generation — build embeddings from ACTUAL codes (autonomous, no capture dependency)
|
||||
for (genStep in 0 until (nTotal - nPrefill)) {
|
||||
val forceGreedyCb0 = diagFlag("force_greedy_cb0")
|
||||
if (forceGreedyCb0) nlog("force_greedy_cb0: using argmax for cb0 sampling (incl. prefill)")
|
||||
|
||||
// Apply greedy/temperature to PREFILL sample too — previously this was always
|
||||
// sampleTopK(0.9) which made step 0 (the "B" of "Bonjour") stochastic even when
|
||||
// greedy was requested. This is the most audible source of tremor on the very
|
||||
// first syllable.
|
||||
var currentCb0 = if (forceGreedyCb0) {
|
||||
var b = 0; var bv = prefillLogits[0]
|
||||
for (j in 1 until prefillLogits.size) if (prefillLogits[j] > bv) { bv = prefillLogits[j]; b = j }
|
||||
b
|
||||
} else sampleTopK(prefillLogits, cb0Temp, 50)
|
||||
nlog("Prefill cb0=$currentCb0 (greedy=$forceGreedyCb0)")
|
||||
|
||||
// Voice-cloned single-file replay: nTotal-nPrefill captured decode embeds that
|
||||
// already contain codec_sum_python + text_hidden at each step. Feed them directly
|
||||
// to the Hexagon talker (no re-sum from NPU codes). Stops at maxGen = nTrailing
|
||||
// because Python already decided where to stop. Hexagon's dynamic KV means no
|
||||
// eviction of the prefill context — this is the whole point of the test.
|
||||
val nTrailingHex = nTotal - nPrefill
|
||||
val isVoiceCloned = nTrailingHex >= 30
|
||||
nlog("Hex voice-cloned mode: $isVoiceCloned ($nTrailingHex decode steps)")
|
||||
|
||||
// Diagnostic (DEBUG only): inject Python-captured cb0 at each step. This is
|
||||
// how we proved tablet CP is 99.76% bit-identical to PyTorch when cb0 matches
|
||||
// — the tremor is entirely caused by divergent cb0 trajectories, not CP drift.
|
||||
val injectPyCb0 = diagFlag("force_inject_pycb0")
|
||||
val pyCodesFile = File("/data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin")
|
||||
var pyCb0: IntArray? = null
|
||||
if (injectPyCb0 && pyCodesFile.exists()) {
|
||||
val pb = ByteBuffer.wrap(pyCodesFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
val n = pb.int; val cb = pb.int
|
||||
pyCb0 = IntArray(n) {
|
||||
val v = pb.int // cb0
|
||||
for (j in 1 until cb) pb.int // skip cb1..cb(cb-1)
|
||||
v
|
||||
}
|
||||
nlog("injectPyCb0: loaded ${pyCb0!!.size} cb0 values (cb=$cb) from python_codes.bin")
|
||||
} else if (injectPyCb0) {
|
||||
nlog("injectPyCb0: flag set but python_codes.bin not found at ${pyCodesFile.absolutePath}")
|
||||
}
|
||||
|
||||
// Seed the very first cb0 with Python's value too if injecting — otherwise step 0's
|
||||
// cb0 comes from tablet prefill sampling and doesn't match.
|
||||
if (injectPyCb0 && pyCb0 != null && pyCb0!!.isNotEmpty()) {
|
||||
currentCb0 = pyCb0!![0]
|
||||
nlog("injectPyCb0: overriding step 0 cb0 -> ${pyCb0!![0]}")
|
||||
}
|
||||
|
||||
for (genStep in 0 until nTrailingHex) {
|
||||
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
||||
val tCp = System.currentTimeMillis()
|
||||
val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
|
||||
|
|
@ -2665,38 +2950,33 @@ class Qwen3TtsEngine(
|
|||
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
|
||||
allCodes.add(codes); generatedCb0.add(currentCb0)
|
||||
|
||||
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||
if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||
|
||||
// Build next embedding from ACTUAL codes + correct trailing text
|
||||
val codecSum = FloatArray(TALKER_DIM)
|
||||
addEmb(codecSum, codecEmb(codes[0]))
|
||||
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
||||
|
||||
val textE: FloatArray = if (trailingIdx < trailingEmbeds.size) {
|
||||
trailingEmbeds[trailingIdx++]
|
||||
} else if (trailingIdx == trailingEmbeds.size) {
|
||||
trailingIdx++; correctEosE
|
||||
} else {
|
||||
correctPadE
|
||||
}
|
||||
val nextEmbed = sumEmb(codecSum, textE)
|
||||
// Feed captured decode embed verbatim (already includes Python's codec_sum + text)
|
||||
val nextEmbed = embeds[nPrefill + genStep]
|
||||
|
||||
val tT = System.currentTimeMillis()
|
||||
val results = hexForward(listOf(nextEmbed))
|
||||
totalTalkerMs += System.currentTimeMillis() - tT
|
||||
if (results.isEmpty()) break
|
||||
if (results.isEmpty()) { nlog("Hex forward returned empty at step ${genStep+1}"); break }
|
||||
pastHidden = results[0].first
|
||||
val logits = results[0].second
|
||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||||
val nextCb0 = sampleTopK(logits, 0.9f, 50)
|
||||
|
||||
if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break }
|
||||
if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
|
||||
nlog("Degeneration at step ${genStep + 2}"); break
|
||||
}
|
||||
currentCb0 = nextCb0
|
||||
// Determine next cb0. If injecting, snap to Python's captured cb0 for step+1.
|
||||
// This isolates pure CP numerical divergence: tablet gets identical embeds AND
|
||||
// identical cb0 at every step, so any cb1..cb15 divergence must come from CP
|
||||
// numerics (ONNX Runtime vs PyTorch) alone.
|
||||
val nextIdx = genStep + 1
|
||||
currentCb0 = if (injectPyCb0 && pyCb0 != null && nextIdx < pyCb0!!.size) {
|
||||
pyCb0!![nextIdx]
|
||||
} else if (forceGreedyCb0) {
|
||||
var b = 0; var bv = logits[0]
|
||||
for (j in 1 until logits.size) if (logits[j] > bv) { bv = logits[j]; b = j }
|
||||
b
|
||||
} else sampleTopK(logits, cb0Temp, 50)
|
||||
}
|
||||
|
||||
val n = allCodes.size
|
||||
|
|
@ -2710,8 +2990,39 @@ class Qwen3TtsEngine(
|
|||
if (n == 0) return ShortArray(0)
|
||||
val padLen = maxOf(n, SEQ_LEN)
|
||||
val allCodebooks = Array(NUM_CODEBOOKS) { cb -> IntArray(padLen) { t -> if (t < n) allCodes[t][cb] else 0 } }
|
||||
|
||||
// Diagnostic: dump our generated codes for comparison with Python capture
|
||||
try {
|
||||
val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin"
|
||||
val fos = java.io.FileOutputStream(codesPath)
|
||||
val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(n); buf.putInt(NUM_CODEBOOKS)
|
||||
for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb])
|
||||
fos.write(buf.array()); fos.close()
|
||||
nlog("Tablet codes dumped: $codesPath ($n steps × $NUM_CODEBOOKS)")
|
||||
} catch (e: Exception) { nlog("Codes dump failed: ${e.message}") }
|
||||
|
||||
val audio = decodeChunked(allCodebooks, n)
|
||||
nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s")
|
||||
|
||||
// Save WAV for inspection
|
||||
try {
|
||||
val wavPath = "/data/local/tmp/kazeia/kazeia_HEX.wav"
|
||||
val fos = java.io.FileOutputStream(wavPath)
|
||||
val dataLen = audio.size * 2
|
||||
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
|
||||
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
|
||||
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
|
||||
header.putInt(16); header.putShort(1); header.putShort(1)
|
||||
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
|
||||
header.put("data".toByteArray()); header.putInt(dataLen)
|
||||
fos.write(header.array())
|
||||
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
|
||||
for (s in audio) buf.putShort(s)
|
||||
fos.write(buf.array()); fos.close()
|
||||
nlog("WAV saved (Hex): $wavPath (${audio.size} samples)")
|
||||
} catch (e: Exception) { nlog("WAV save (Hex) failed: ${e.message}") }
|
||||
|
||||
return audio
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,239 @@
|
|||
/**
|
||||
* Native TTS pipeline: talker + CP autoregressive loop in C++.
|
||||
* Eliminates Java Tensor/EValue allocation overhead (~1s/run).
|
||||
*
|
||||
* One JNI call runs the entire generation → returns all codebook codes.
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <arm_neon.h>
|
||||
#include <android/log.h>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <cfloat>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <executorch/extension/data_loader/file_data_loader.h>
|
||||
#include <executorch/extension/runner_util/inputs.h>
|
||||
#include <executorch/runtime/executor/method.h>
|
||||
#include <executorch/runtime/executor/program.h>
|
||||
#include <executorch/runtime/platform/runtime.h>
|
||||
|
||||
#define TAG "TtsPipeline"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
|
||||
using executorch::runtime::Error;
|
||||
using executorch::runtime::EValue;
|
||||
using executorch::runtime::HierarchicalAllocator;
|
||||
using executorch::runtime::MemoryAllocator;
|
||||
using executorch::runtime::MemoryManager;
|
||||
using executorch::runtime::Method;
|
||||
using executorch::runtime::Program;
|
||||
using executorch::runtime::Result;
|
||||
using executorch::runtime::Span;
|
||||
|
||||
// Model constants
|
||||
static const int DIM = 1024, VOCAB = 3072, CB_SIZE = 2048, NUM_CB = 15;
|
||||
static const int T_LAYERS = 28, T_KV_HEADS = 8, T_HD = 128, T_KV_LEN = 100;
|
||||
static const int C_LAYERS = 5, C_KV_HEADS = 8, C_HD = 128, C_KV_LEN = 16;
|
||||
static const int CODEC_EOS = 2150;
|
||||
|
||||
// NEON dot product
|
||||
static inline float dot_neon(const float* a, const float* b, int n) {
|
||||
float32x4_t s0 = vdupq_n_f32(0), s1 = vdupq_n_f32(0);
|
||||
float32x4_t s2 = vdupq_n_f32(0), s3 = vdupq_n_f32(0);
|
||||
int i = 0;
|
||||
for (; i + 15 < n; i += 16) {
|
||||
s0 = vfmaq_f32(s0, vld1q_f32(a+i), vld1q_f32(b+i));
|
||||
s1 = vfmaq_f32(s1, vld1q_f32(a+i+4), vld1q_f32(b+i+4));
|
||||
s2 = vfmaq_f32(s2, vld1q_f32(a+i+8), vld1q_f32(b+i+8));
|
||||
s3 = vfmaq_f32(s3, vld1q_f32(a+i+12), vld1q_f32(b+i+12));
|
||||
}
|
||||
float r = vaddvq_f32(vaddq_f32(vaddq_f32(s0,s1), vaddq_f32(s2,s3)));
|
||||
for (; i < n; i++) r += a[i] * b[i];
|
||||
return r;
|
||||
}
|
||||
|
||||
static int argmax_head(const float* hidden, const float* W, int vocab, int dim) {
|
||||
int best = 0; float bv = -FLT_MAX;
|
||||
for (int j = 0; j < vocab; j++) {
|
||||
float d = dot_neon(hidden, W + j * dim, dim);
|
||||
if (d > bv) { bv = d; best = j; }
|
||||
}
|
||||
return best;
|
||||
}
|
||||
|
||||
// Persistent state (loaded once, reused across calls)
|
||||
struct PipelineState {
|
||||
// Talker model
|
||||
std::unique_ptr<executorch::extension::FileDataLoader> talkerLoader;
|
||||
std::unique_ptr<Program> talkerProgram;
|
||||
std::unique_ptr<MemoryManager> talkerMM;
|
||||
std::unique_ptr<Method> talkerMethod;
|
||||
std::vector<std::unique_ptr<uint8_t[]>> talkerBufs;
|
||||
|
||||
// CP model
|
||||
std::unique_ptr<executorch::extension::FileDataLoader> cpLoader;
|
||||
std::unique_ptr<Program> cpProgram;
|
||||
std::unique_ptr<MemoryManager> cpMM;
|
||||
std::unique_ptr<Method> cpMethod;
|
||||
std::vector<std::unique_ptr<uint8_t[]>> cpBufs;
|
||||
|
||||
bool loaded = false;
|
||||
};
|
||||
|
||||
static PipelineState* gState = nullptr;
|
||||
|
||||
static Method* loadModel(
|
||||
const char* path,
|
||||
std::unique_ptr<executorch::extension::FileDataLoader>& loader,
|
||||
std::unique_ptr<Program>& program,
|
||||
std::unique_ptr<MemoryManager>& mm,
|
||||
std::vector<std::unique_ptr<uint8_t[]>>& bufs,
|
||||
uint8_t* methodPool, size_t methodPoolSize,
|
||||
uint8_t* tempPool, size_t tempPoolSize)
|
||||
{
|
||||
auto ld = executorch::extension::FileDataLoader::from(path);
|
||||
if (!ld.ok()) return nullptr;
|
||||
loader = std::make_unique<executorch::extension::FileDataLoader>(std::move(ld.get()));
|
||||
|
||||
auto prog = Program::load(&*loader);
|
||||
if (!prog.ok()) return nullptr;
|
||||
program = std::make_unique<Program>(std::move(prog.get()));
|
||||
|
||||
auto meta = program->method_meta("forward");
|
||||
if (!meta.ok()) return nullptr;
|
||||
|
||||
std::vector<Span<uint8_t>> spans;
|
||||
for (size_t i = 0; i < meta->num_memory_planned_buffers(); i++) {
|
||||
size_t sz = (size_t)meta->memory_planned_buffer_size(i).get();
|
||||
bufs.push_back(std::make_unique<uint8_t[]>(sz));
|
||||
spans.push_back({bufs.back().get(), sz});
|
||||
}
|
||||
|
||||
auto* ma = new MemoryAllocator(methodPoolSize, methodPool);
|
||||
auto* ta = new MemoryAllocator(tempPoolSize, tempPool);
|
||||
auto* ha = new HierarchicalAllocator({spans.data(), spans.size()});
|
||||
mm = std::unique_ptr<MemoryManager>(new MemoryManager(ma, ha, ta));
|
||||
|
||||
auto method = program->load_method("forward", mm.get());
|
||||
if (!method.ok()) return nullptr;
|
||||
|
||||
return new Method(std::move(method.get()));
|
||||
}
|
||||
|
||||
// Memory pools (static, persistent)
|
||||
static uint8_t talkerMethodPool[8*1024*1024];
|
||||
static uint8_t talkerTempPool[2*1024*1024];
|
||||
static uint8_t cpMethodPool[4*1024*1024];
|
||||
static uint8_t cpTempPool[1*1024*1024];
|
||||
|
||||
extern "C" {
|
||||
|
||||
JNIEXPORT jboolean JNICALL
|
||||
Java_com_kazeia_tts_TtsPipeline_nativeInit(
|
||||
JNIEnv* env, jclass, jstring jTalkerPath, jstring jCpPath)
|
||||
{
|
||||
executorch::runtime::runtime_init();
|
||||
if (gState && gState->loaded) return JNI_TRUE;
|
||||
|
||||
const char* talkerPath = env->GetStringUTFChars(jTalkerPath, nullptr);
|
||||
const char* cpPath = env->GetStringUTFChars(jCpPath, nullptr);
|
||||
|
||||
gState = new PipelineState();
|
||||
|
||||
LOGI("Loading talker: %s", talkerPath);
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
gState->talkerMethod.reset(loadModel(talkerPath,
|
||||
gState->talkerLoader, gState->talkerProgram, gState->talkerMM, gState->talkerBufs,
|
||||
talkerMethodPool, sizeof(talkerMethodPool), talkerTempPool, sizeof(talkerTempPool)));
|
||||
|
||||
LOGI("Loading CP: %s", cpPath);
|
||||
gState->cpMethod.reset(loadModel(cpPath,
|
||||
gState->cpLoader, gState->cpProgram, gState->cpMM, gState->cpBufs,
|
||||
cpMethodPool, sizeof(cpMethodPool), cpTempPool, sizeof(cpTempPool)));
|
||||
|
||||
auto t1 = std::chrono::high_resolution_clock::now();
|
||||
float ms = std::chrono::duration<float, std::milli>(t1-t0).count();
|
||||
|
||||
env->ReleaseStringUTFChars(jTalkerPath, talkerPath);
|
||||
env->ReleaseStringUTFChars(jCpPath, cpPath);
|
||||
|
||||
if (!gState->talkerMethod || !gState->cpMethod) {
|
||||
LOGI("Failed to load models");
|
||||
delete gState; gState = nullptr;
|
||||
return JNI_FALSE;
|
||||
}
|
||||
|
||||
gState->loaded = true;
|
||||
LOGI("Models loaded in %.0fms", ms);
|
||||
|
||||
// Warmup
|
||||
auto prep = executorch::extension::prepare_input_tensors(*gState->talkerMethod);
|
||||
if (prep.ok()) {
|
||||
auto t2 = std::chrono::high_resolution_clock::now();
|
||||
gState->talkerMethod->execute();
|
||||
auto t3 = std::chrono::high_resolution_clock::now();
|
||||
LOGI("Talker warmup: %.0fms", std::chrono::duration<float,std::milli>(t3-t2).count());
|
||||
}
|
||||
auto prep2 = executorch::extension::prepare_input_tensors(*gState->cpMethod);
|
||||
if (prep2.ok()) {
|
||||
gState->cpMethod->execute();
|
||||
LOGI("CP warmup done");
|
||||
}
|
||||
|
||||
return JNI_TRUE;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_kazeia_tts_TtsPipeline_nativeDestroy(JNIEnv*, jclass) {
|
||||
delete gState; gState = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the full pipeline. Returns int[numTokens][16] flattened as int[numTokens * 16].
|
||||
*/
|
||||
JNIEXPORT jintArray JNICALL
|
||||
Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
||||
JNIEnv* env, jclass,
|
||||
jfloatArray jPrefillEmbeds, jint nPrefill,
|
||||
jfloatArray jTrailingEmbeds, jint nTrailing,
|
||||
jfloatArray jCodecEmbedding, // [3072 * 1024]
|
||||
jfloatArray jCpEmbeddings, // [15 * 2048 * 1024]
|
||||
jfloatArray jCpHeads, // [15 * 2048 * 1024]
|
||||
jfloatArray jTalkerCos, jfloatArray jTalkerSin, // [250 * 128]
|
||||
jfloatArray jCpCos, jfloatArray jCpSin, // [17 * 128]
|
||||
jint maxTokens)
|
||||
{
|
||||
if (!gState || !gState->loaded) return nullptr;
|
||||
|
||||
// Pin arrays (JNI critical for zero-copy)
|
||||
float* prefill = (float*)env->GetPrimitiveArrayCritical(jPrefillEmbeds, nullptr);
|
||||
float* trailing = nTrailing > 0 ? (float*)env->GetPrimitiveArrayCritical(jTrailingEmbeds, nullptr) : nullptr;
|
||||
float* codecEmb = (float*)env->GetPrimitiveArrayCritical(jCodecEmbedding, nullptr);
|
||||
float* cpEmbs = (float*)env->GetPrimitiveArrayCritical(jCpEmbeddings, nullptr);
|
||||
float* cpHeads = (float*)env->GetPrimitiveArrayCritical(jCpHeads, nullptr);
|
||||
float* tCos = (float*)env->GetPrimitiveArrayCritical(jTalkerCos, nullptr);
|
||||
float* tSin = (float*)env->GetPrimitiveArrayCritical(jTalkerSin, nullptr);
|
||||
float* cCos = (float*)env->GetPrimitiveArrayCritical(jCpCos, nullptr);
|
||||
float* cSin = (float*)env->GetPrimitiveArrayCritical(jCpSin, nullptr);
|
||||
|
||||
// TODO: implement full pipeline loop here
|
||||
// For now, release and return empty to test loading
|
||||
|
||||
env->ReleasePrimitiveArrayCritical(jPrefillEmbeds, prefill, JNI_ABORT);
|
||||
if (trailing) env->ReleasePrimitiveArrayCritical(jTrailingEmbeds, trailing, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jCodecEmbedding, codecEmb, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jCpEmbeddings, cpEmbs, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jCpHeads, cpHeads, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jTalkerCos, tCos, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jTalkerSin, tSin, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jCpCos, cCos, JNI_ABORT);
|
||||
env->ReleasePrimitiveArrayCritical(jCpSin, cSin, JNI_ABORT);
|
||||
|
||||
LOGI("nativeRun: not yet implemented, returning empty");
|
||||
return env->NewIntArray(0);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
|
@ -6,8 +6,9 @@ No Python model generation needed — just tokenize + text_projection.
|
|||
Usage: python3 prepare_tts_native.py "Your text here" [output.bin]
|
||||
adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin
|
||||
|
||||
Formula: trailing = text_proj[1:] + eos_padding(n_tokens × 4 total)
|
||||
maxTokens = trailing_count (cut after trailing exhausted)
|
||||
Mirrors Python qwen_tts protocol exactly:
|
||||
trailing = text_proj[1:] (no eos padding — C++ adds 1×eos then pad_embed itself)
|
||||
Stop = natural codec_eos_token_id (handled in C++)
|
||||
"""
|
||||
import sys, os, struct, warnings
|
||||
os.chdir("/tmp")
|
||||
|
|
@ -17,7 +18,7 @@ TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
|
|||
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin"
|
||||
GOLDEN_PREFILL = "/tmp/existing_embeds.bin"
|
||||
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
|
||||
MAX_SEGMENT_TOKENS = 15 # Max text tokens per segment (~50 audio tokens, within NPU quality window)
|
||||
MAX_SEGMENT_TOKENS = 20 # ~70 audio steps + prefill = ~80, fits KV_LEN=100 with margin
|
||||
|
||||
import torch, numpy as np, re
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
|
@ -66,17 +67,17 @@ def split_text(text, max_tokens):
|
|||
return [s for s in segments if s.strip()]
|
||||
|
||||
def make_segment(text_segment):
|
||||
"""Build embeds for one segment."""
|
||||
"""Build embeds for one segment.
|
||||
Mirrors Python qwen_tts: trailing = text_proj[1:] (no padding).
|
||||
C++ then adds 1×eos after exhausting trailing, then pad_embed, and stops on natural EOS.
|
||||
"""
|
||||
tokens = tokenizer.encode(text_segment, add_special_tokens=False)
|
||||
with torch.no_grad():
|
||||
proj = talker.text_projection(
|
||||
talker.get_text_embeddings()(torch.tensor([tokens]))
|
||||
)[0].numpy().astype(np.float32)
|
||||
|
||||
target_len = max(int(len(tokens) * 3.2) + 5, 40)
|
||||
trailing = [proj[i] for i in range(1, len(proj))]
|
||||
while len(trailing) < target_len:
|
||||
trailing.append(eos)
|
||||
trailing = [proj[i] for i in range(1, len(proj))] # text[1:], no eos here
|
||||
|
||||
return {
|
||||
'tokens': len(tokens),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate voice-cloned TTS embeds: capture the COMPLETE talker input sequence
|
||||
from a Python voice-cloning run (prefill + every generation step).
|
||||
|
||||
Unlike prepare_tts_embeds.py, this version captures multi-token prefill too,
|
||||
so the NPU has the correct KV-cache context and there is no "tacs"/clicks.
|
||||
|
||||
Usage: python3 prepare_tts_voiceclone.py "Your text here" [output.bin] [voice.wav]
|
||||
"""
|
||||
import sys, os, struct, warnings
|
||||
os.chdir("/tmp")
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
|
||||
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_vc.bin"
|
||||
VOICE = sys.argv[3] if len(sys.argv) > 3 else "/opt/Kazeia/voix/damien_15s_24k.wav"
|
||||
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
|
||||
|
||||
import torch, numpy as np
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
|
||||
print(f"Voice: {VOICE}")
|
||||
print("Loading model...")
|
||||
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
|
||||
talker = tts.model.talker
|
||||
|
||||
# Capture EVERY talker input, keeping track of per-call shape so we can split
|
||||
# the first call (prefill, multi-token) from subsequent calls (decode, 1 token).
|
||||
captured = [] # list of 1024-dim vectors, in order
|
||||
call_shapes = [] # length of each call
|
||||
codec_ids_per_step = [] # list of [16] int arrays — Python's predicted codes per decode step
|
||||
original_forward = talker.model.forward
|
||||
original_talker_forward = talker.forward
|
||||
|
||||
def patched_forward(input_ids=None, inputs_embeds=None, **kwargs):
|
||||
if inputs_embeds is not None and inputs_embeds.dim() == 3:
|
||||
t = inputs_embeds.shape[1]
|
||||
call_shapes.append(t)
|
||||
for i in range(t):
|
||||
captured.append(inputs_embeds[0, i, :].detach().cpu().numpy().astype(np.float32))
|
||||
return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
|
||||
|
||||
def talker_output_hook(module, input, output):
|
||||
"""Captures codec_ids from each talker.forward call output (nn.Module forward hook).
|
||||
Preserves the method signature so HF's kwarg validation still works."""
|
||||
hs = getattr(output, 'hidden_states', None)
|
||||
if isinstance(hs, tuple) and len(hs) >= 2 and hs[1] is not None:
|
||||
cids = hs[1]
|
||||
if hasattr(cids, 'detach'):
|
||||
codec_ids_per_step.append(cids.detach().cpu().numpy().astype(np.int32).reshape(-1))
|
||||
|
||||
talker.model.forward = patched_forward
|
||||
talker.register_forward_hook(talker_output_hook)
|
||||
|
||||
print("Running voice-clone generation (captures prefill + decode inputs)...")
|
||||
audio_list, sr = tts.generate_voice_clone(
|
||||
text=TEXT, ref_audio=VOICE, language="french",
|
||||
x_vector_only_mode=True, non_streaming_mode=True,
|
||||
)
|
||||
audio = audio_list[0]
|
||||
|
||||
if not call_shapes:
|
||||
print("ERROR: captured nothing")
|
||||
sys.exit(1)
|
||||
|
||||
# First call is prefill (multi-token). Every subsequent call is a single-token
|
||||
# decode step. Decode length = total gen frames Python produced.
|
||||
nPrefill = call_shapes[0]
|
||||
nDecode = len(captured) - nPrefill
|
||||
nTotal = len(captured)
|
||||
|
||||
print(f"Audio: {len(audio)/sr:.2f}s")
|
||||
print(f"Captured {nTotal} embeds: {nPrefill} prefill + {nDecode} decode")
|
||||
print(f"Captured {len(codec_ids_per_step)} code vectors × 16 per step")
|
||||
print(f"Call shapes: first={call_shapes[0]}, rest={call_shapes[1:4]}... ({len(call_shapes)} calls total)")
|
||||
|
||||
# Binary format: <i nPrefill> <i nTotal> <f32 × 1024 × nTotal>
|
||||
with open(OUTPUT, "wb") as f:
|
||||
f.write(struct.pack("<i", nPrefill))
|
||||
f.write(struct.pack("<i", nTotal))
|
||||
for emb in captured:
|
||||
f.write(emb.tobytes())
|
||||
print(f"\nSaved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f} KB)")
|
||||
|
||||
# Also save Python's sampled codes (diagnostic path: decode these directly via
|
||||
# tablet BigVGAN to isolate whether the tremor comes from code divergence vs
|
||||
# from our BigVGAN implementation).
|
||||
codes_path = OUTPUT.replace('.bin', '_codes.bin')
|
||||
with open(codes_path, "wb") as f:
|
||||
n_steps = len(codec_ids_per_step)
|
||||
f.write(struct.pack("<i", n_steps))
|
||||
f.write(struct.pack("<i", 16)) # num codebooks
|
||||
for c in codec_ids_per_step:
|
||||
# Write 16 int32 codes per step (pad with 0 if shorter)
|
||||
arr = np.zeros(16, dtype=np.int32)
|
||||
arr[:min(16, len(c))] = c[:min(16, len(c))]
|
||||
f.write(arr.tobytes())
|
||||
print(f"Saved Python codes: {codes_path} ({os.path.getsize(codes_path)} bytes, {n_steps} steps)")
|
||||
|
||||
print(f"\nPush to tablet:")
|
||||
print(f" adb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
|
||||
print(f" adb push {codes_path} /data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin")
|
||||
|
||||
import soundfile as sf
|
||||
ref_path = OUTPUT.replace('.bin', '_ref.wav')
|
||||
sf.write(ref_path, audio, sr)
|
||||
print(f" Python ref audio: {ref_path}")
|
||||
Loading…
Reference in New Issue