From f548e02283371cc4d179ea8d6af37f2b68746002 Mon Sep 17 00:00:00 2001 From: Kazeia Team Date: Mon, 13 Apr 2026 14:13:04 +0200 Subject: [PATCH] TTS: dynamic EOS-rank boost terminates generation cleanly across voices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the fixed maxGen + length-based boost with a fully dynamic end-of-utterance detector that watches the model's own EOS logit rank. End result on the Baer 3-segment monologue, validated by user as "FORMIDABLE" / "impeccable" with both Damien and Zelda voices: - All 3 segments terminate via EOS (no maxGen cap hit) - No "page beg beg" filler tail - No abrupt cuts between segments - Audio durations 5-8 s per segment, matching Python within ~10 % How it works (runHexGenWithPrefill, in tts/Qwen3TtsEngine.kt): 1. At every decode step, compute the rank of CODEC_EOS in the repetition-penalised logits. Mid-utterance the rank sits at 150-700 (model is committed to producing speech). Approaching the natural end, the rank dips toward top-50. 2. Arm the boost only when EOS rank stays below eosRankTrigger=60 for THREE consecutive steps. The 3-step requirement filters out transient single-step dips that occur during low-energy phonemes mid-sentence (without it, short sentences would terminate after ~3 s). Arming is also gated by eosBoostMinStep (50 % of expected speech length) so we never arm in the very first frames. 3. Once armed, the boost increments monotonically: each subsequent step adds boostStepsActive * eosBoostScale to the EOS logit. The accumulated boost lifts EOS above top-1 within 1-3 steps, the argmax check fires, and the loop breaks. Scale=4 gives the model a small natural decay before termination; scale=5 was perfect-but- slightly-clipping, scale=3 wasn't strong enough to outpace the growing top-1 logit. Other tweaks bundled in this commit because they all contribute to the clean output: * Inter-segment gap 120 → 250 ms — gives the listener a perceived sentence boundary instead of a hard concatenation. * fadeOut(audio, 40) on every segment — cosine roll-off over the last 40 ms so the EOS-clipped tail decays naturally instead of sample-clipping. * top_k 50 → 200 in the fallback sample call — wider pool to keep EOS reachable when the boost just fails to hit argmax. Voice swap is a 45 KB file push (damien_voice_prefix.bin and damien_voice_suffix.bin). Successfully tested today with Elodie (female, norm 10.12) and Zelda (norm 9.39) using Damien (norm 10.36) as the baseline — same Kotlin code, no rebuild needed. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../java/com/kazeia/tts/Qwen3TtsEngine.kt | 198 ++++++++++++++---- 1 file changed, 155 insertions(+), 43 deletions(-) diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt index 5060f19..30c8833 100644 --- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt +++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt @@ -3417,8 +3417,13 @@ class Qwen3TtsEngine( for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb)) for (e in suffix) prefill.add(e) - val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15) - val codes = runHexGenWithPrefill(prefill, maxGen) + // See synthesizeTextStreaming for the rationale: generous absolute + // cap, dynamic EOS-rank-driven boost, safety floor at 50 % of + // expected to suppress early termination on short utterances. + val expectedSteps = (ids.size * 24) / 10 + val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15) + val eosBoostMinStep = expectedSteps / 2 + val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep) if (codes.isEmpty()) return ShortArray(0) val n = codes.size @@ -3428,12 +3433,9 @@ class Qwen3TtsEngine( if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0 } } - val audio = decodeChunked(codebooks, n) - // Match the conservative trim policy in synthesizeTextStreaming: - // only trim when we hit the maxGen cap, which is the "failure to - // emit EOS" signal. Shorter generations are kept verbatim to - // avoid cutting low-energy speech. - return if (n >= maxGen) trimTailLowEnergy(audio) else audio + // 40 ms fade-out so the EOS-clipped tail decays naturally before + // the AudioTrack write. + return fadeOut(decodeChunked(codebooks, n), 40) } /** @@ -3444,7 +3446,22 @@ class Qwen3TtsEngine( * (since all text has already been consumed in the prefill). On EOS, * terminate early. Returns the generated [step, codebook] codes. */ - private fun runHexGenWithPrefill(prefill: List, maxGen: Int): Array { + private fun runHexGenWithPrefill( + prefill: List, + maxGen: Int, + // Floor on how many steps to generate before the dynamic boost is + // allowed to fire. Prevents short-text segments from terminating on + // a stray "EOS-leaning" hidden state during the first few frames. + // Callers pass ~50 % of expected speech length as a safety floor. + eosBoostMinStep: Int = -1, + // Once EOS rank falls below this threshold (model itself is + // "thinking about stopping"), start adding eosBoostScale per step + // until argmax flips to EOS. Empirically EOS rank plateaus ~150-700 + // mid-speech and dips to ~50-60 right at the natural end, so this + // catches the model's intent without a fixed length budget. + eosRankTrigger: Int = 60, + eosBoostScale: Float = 4.0f + ): Array { val padE = ttsPadEmbed ?: return emptyArray() val eosE = ttsEosEmbed ?: return emptyArray() val allCodes = mutableListOf() @@ -3467,6 +3484,13 @@ class Qwen3TtsEngine( // decode step. We follow the same schedule so the model's attention // sees the same "text exhausted" signal it was trained with. var eosFedOnce = false + // Counter for the dynamic EOS boost — once armed, increments every + // step monotonically. Arming requires 3 CONSECUTIVE steps below + // the rank trigger (transient dips during normal speech aren't + // enough); this keeps short sentences from terminating mid-word + // on a fluke low rank. + var boostStepsActive = 0 + var consecLowRank = 0 for (genStep in 0 until maxGen) { val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0 val tCp = System.currentTimeMillis() @@ -3489,15 +3513,41 @@ class Qwen3TtsEngine( for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY } val seen = HashSet(); for (prev in generatedCb0) seen.add(prev) for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f } - currentCb0 = sampleTopK(logits, 0.9f, 50) - if (currentCb0 == CODEC_EOS) { nlog("VC EOS at step ${genStep+1}"); break } - // Degeneracy guard: when the talker fails to emit EOS it falls - // into a stuck loop where cb0 repeats (the "page beg beg beg" - // artifact audible at the tail of generated phrases). Nine - // consecutive identical cb0s is the threshold the native .pte - // pipeline uses too. The short history is just the last 9 - // entries of generatedCb0 — cheap to scan. + // Dynamic EOS detection: track current EOS rank, require 3 + // consecutive low-rank steps before arming. Single-step rank + // dips happen mid-utterance during low-energy phonemes and + // would otherwise trigger premature termination on short + // sentences. Once armed, the boost accumulates monotonically. + val eosLogit0 = logits[CODEC_EOS] + var eosRank = 0 + for (j in logits.indices) if (logits[j] > eosLogit0) eosRank++ + if (boostStepsActive == 0 && genStep >= eosBoostMinStep) { + if (eosRank < eosRankTrigger) { + consecLowRank++ + if (consecLowRank >= 3) { + nlog("VC boost armed at step ${genStep+1} (EOS rank $eosRank, 3 consecutive low)") + boostStepsActive = 1 + } + } else { + consecLowRank = 0 + } + } else if (boostStepsActive > 0) { + boostStepsActive++ + } + if (boostStepsActive > 0) { + logits[CODEC_EOS] += boostStepsActive * eosBoostScale + } + + var argmax = 0; var argmaxV = logits[0] + for (j in 1 until logits.size) if (logits[j] > argmaxV) { argmaxV = logits[j]; argmax = j } + if (argmax == CODEC_EOS) { nlog("VC EOS (boosted argmax) at step ${genStep+1}"); break } + + currentCb0 = sampleTopK(logits, 0.9f, 200) + if (currentCb0 == CODEC_EOS) { nlog("VC EOS (sampled) at step ${genStep+1}"); break } + + // Degeneracy guard #1: 9 consecutive identical cb0 → stop. + // Catches the simplest stuck loop. val nHist = generatedCb0.size if (nHist >= 9) { val last = generatedCb0[nHist - 1] @@ -3508,8 +3558,26 @@ class Qwen3TtsEngine( break } } + + // Degeneracy guard #2: low diversity in the recent window. + // The "page beg beg" filler doesn't repeat a single token — it + // cycles through 2-3 tokens. If the last 12 cb0 contain fewer + // than 4 unique values, the talker is in the cycle and we stop + // before the audio degrades further. + if (nHist >= 12) { + val recent = HashSet() + for (i in nHist - 12 until nHist) recent.add(generatedCb0[i]) + if (recent.size < 4) { + nlog("VC degen: only ${recent.size} unique cb0 in last 12 at step ${genStep+1}, stopping") + break + } + } } nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms") + // Diagnostic: log full cb0 sequence so we can correlate against the + // page-beg region in the produced audio. + val cb0Str = generatedCb0.joinToString(",") + nlog("VC cb0 sequence: $cb0Str") return allCodes.toTypedArray() } @@ -3583,7 +3651,11 @@ class Qwen3TtsEngine( hexReset() val segmentAudios = mutableListOf() - val gapSamples = SR * 120 / 1000 + // Wider gap (250 ms) between segments — gives the listener a small + // breath/pause that mirrors how real speakers separate sentences, + // and masks the EOS-boost cut by surrounding it with silence rather + // than another sentence's onset. + val gapSamples = SR * 250 / 1000 val gap = ShortArray(gapSamples) val t0 = System.currentTimeMillis() @@ -3615,14 +3687,14 @@ class Qwen3TtsEngine( for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb)) for (e in suffix) prefill.add(e) - // Empirical budget: Python's voice_clone typically emits ~3.3 - // codec frames per text token for French. Keep a small cushion - // so ~80% of runs terminate via EOS/degeneracy before exhausting - // the budget; trimming is done by the degeneracy guard inside - // runHexGenWithPrefill. Too-generous maxGen guarantees the tail - // artifacts the user hears as "page beg beg beg". - val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15) - val codes = runHexGenWithPrefill(prefill, maxGen) + // Generous absolute cap; the dynamic EOS boost (triggered when + // the model's own EOS rank dips below threshold) is what + // actually terminates generation. The minStep floor protects + // against early-termination spikes for short sentences. + val expectedSteps = (ids.size * 24) / 10 // ids * 2.4 (int math) + val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15) + val eosBoostMinStep = expectedSteps / 2 + val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep) if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue } val n = codes.size @@ -3633,16 +3705,13 @@ class Qwen3TtsEngine( } } val rawAudio = decodeChunked(codebooks, n) - // Only trim when the talker exhausted its budget — that's the - // case where "page beg beg" filler actually sneaks in. When - // EOS or the degeneracy guard fires early, the audio is already - // clean and trimTailLowEnergy's 35%-of-peak threshold is too - // aggressive for natural French cadence (it cut Elodie seg1 - // from 7.04 s to 2.92 s because the second half of the - // sentence was below the speech threshold). - val audio = if (n >= maxGen) trimTailLowEnergy(rawAudio) else rawAudio + // Cosine fade-out on the last 40 ms — softens the cut imposed by + // the EOS boost so the segment ends on a recognisable phoneme + // tail instead of an abrupt sample-clip. + val audio = fadeOut(rawAudio, 40) val segMs = System.currentTimeMillis() - tSeg - nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio (raw ${rawAudio.size/SR.toFloat()}s, trimmed=${n >= maxGen}) in ${segMs}ms") + val budgetHit = if (n >= maxGen) " [maxGen cap]" else "" + nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms$budgetHit") segmentAudios.add(audio) saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio) @@ -3675,7 +3744,7 @@ class Qwen3TtsEngine( if (audio.size < SR / 2) return audio val winSamples = SR * 40 / 1000 // 40 ms windows = 960 samples val nWin = audio.size / winSamples - if (nWin < 6) return audio + if (nWin < 10) return audio val rms = FloatArray(nWin) for (w in 0 until nWin) { var s = 0.0 @@ -3683,21 +3752,64 @@ class Qwen3TtsEngine( for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x } rms[w] = kotlin.math.sqrt(s / winSamples).toFloat() } - // Reference peak over the first 70% of the segment; the tail is - // assumed to be the degenerate filler region. var peak = 0f - val refEnd = (nWin * 7) / 10 - for (w in 0 until refEnd) if (rms[w] > peak) peak = rms[w] - val thr = peak * 0.35f + for (w in 0 until nWin) if (rms[w] > peak) peak = rms[w] + + // Heuristic 1: classic 35 % sustained threshold from the back. + val thrSust = peak * 0.35f var lastSpeech = nWin - 1 for (w in nWin - 1 downTo 2) { - if (rms[w] >= thr && rms[w-1] >= thr && rms[w-2] >= thr) { lastSpeech = w; break } + if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) { + lastSpeech = w; break + } } + + // Heuristic 2: "low-energy tail" — the "page beg beg" filler emits + // audio that's louder than silence but quieter than real speech. + // If the last 30 % of the segment has a max RMS under 40 % of the + // global peak, the tail is degenerate; cut at the last sustained- + // speech window before the tail starts. + val tailStart = (nWin * 7) / 10 + var tailMax = 0f + for (w in tailStart until nWin) if (rms[w] > tailMax) tailMax = rms[w] + if (tailMax < peak * 0.40f) { + for (w in tailStart - 1 downTo 2) { + if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) { + if (w < lastSpeech) lastSpeech = w + break + } + } + } + val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1) - val keepSamples = (keepWin + 1) * winSamples + var keepSamples = (keepWin + 1) * winSamples + // Over-trim guard: never cut below 40 % of the raw length. + val minKeep = (audio.size * 4) / 10 + if (keepSamples < minKeep) keepSamples = minKeep return audio.copyOf(keepSamples) } + /** + * Apply a short cosine fade-out to the tail of an audio buffer. The + * EOS boost ends generation right at the model's chosen stop point, + * which usually clips the natural decay of the last syllable. A + * 40-ms fade smooths that into a recognisable phoneme tail without + * shortening the perceived word. + */ + private fun fadeOut(audio: ShortArray, fadeMs: Int = 40): ShortArray { + val fadeSamples = SR * fadeMs / 1000 + if (audio.size <= fadeSamples) return audio + val out = audio.copyOf() + val start = out.size - fadeSamples + for (i in 0 until fadeSamples) { + // Cosine roll-off: 1 → 0 over fadeSamples + val t = i.toFloat() / fadeSamples + val gain = 0.5f * (1f + kotlin.math.cos(Math.PI.toFloat() * t)) + out[start + i] = (out[start + i] * gain).toInt().toShort() + } + return out + } + /** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to * save one file per segment plus the concatenated result for inspection. */ private fun saveWav(path: String, audio: ShortArray) {