TTS: dynamic EOS-rank boost terminates generation cleanly across voices

Replaces the fixed maxGen + length-based boost with a fully dynamic
end-of-utterance detector that watches the model's own EOS logit rank.
End result on the Baer 3-segment monologue, validated by user as
"FORMIDABLE" / "impeccable" with both Damien and Zelda voices:

  - All 3 segments terminate via EOS (no maxGen cap hit)
  - No "page beg beg" filler tail
  - No abrupt cuts between segments
  - Audio durations 5-8 s per segment, matching Python within ~10 %

How it works (runHexGenWithPrefill, in tts/Qwen3TtsEngine.kt):

  1. At every decode step, compute the rank of CODEC_EOS in the
     repetition-penalised logits. Mid-utterance the rank sits at
     150-700 (model is committed to producing speech). Approaching
     the natural end, the rank dips toward top-50.

  2. Arm the boost only when EOS rank stays below eosRankTrigger=60
     for THREE consecutive steps. The 3-step requirement filters out
     transient single-step dips that occur during low-energy phonemes
     mid-sentence (without it, short sentences would terminate after
     ~3 s). Arming is also gated by eosBoostMinStep (50 % of expected
     speech length) so we never arm in the very first frames.

  3. Once armed, the boost increments monotonically: each subsequent
     step adds boostStepsActive * eosBoostScale to the EOS logit. The
     accumulated boost lifts EOS above top-1 within 1-3 steps, the
     argmax check fires, and the loop breaks. Scale=4 gives the model
     a small natural decay before termination; scale=5 was perfect-but-
     slightly-clipping, scale=3 wasn't strong enough to outpace the
     growing top-1 logit.

Other tweaks bundled in this commit because they all contribute to
the clean output:

  * Inter-segment gap 120 → 250 ms — gives the listener a perceived
    sentence boundary instead of a hard concatenation.

  * fadeOut(audio, 40) on every segment — cosine roll-off over the
    last 40 ms so the EOS-clipped tail decays naturally instead of
    sample-clipping.

  * top_k 50 → 200 in the fallback sample call — wider pool to keep
    EOS reachable when the boost just fails to hit argmax.

Voice swap is a 45 KB file push (damien_voice_prefix.bin and
damien_voice_suffix.bin). Successfully tested today with Elodie
(female, norm 10.12) and Zelda (norm 9.39) using Damien (norm 10.36)
as the baseline — same Kotlin code, no rebuild needed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-13 14:13:04 +02:00
parent c25040a780
commit f548e02283
1 changed files with 155 additions and 43 deletions

View File

@ -3417,8 +3417,13 @@ class Qwen3TtsEngine(
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb)) for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
for (e in suffix) prefill.add(e) for (e in suffix) prefill.add(e)
val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15) // See synthesizeTextStreaming for the rationale: generous absolute
val codes = runHexGenWithPrefill(prefill, maxGen) // cap, dynamic EOS-rank-driven boost, safety floor at 50 % of
// expected to suppress early termination on short utterances.
val expectedSteps = (ids.size * 24) / 10
val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
val eosBoostMinStep = expectedSteps / 2
val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
if (codes.isEmpty()) return ShortArray(0) if (codes.isEmpty()) return ShortArray(0)
val n = codes.size val n = codes.size
@ -3428,12 +3433,9 @@ class Qwen3TtsEngine(
if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0 if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
} }
} }
val audio = decodeChunked(codebooks, n) // 40 ms fade-out so the EOS-clipped tail decays naturally before
// Match the conservative trim policy in synthesizeTextStreaming: // the AudioTrack write.
// only trim when we hit the maxGen cap, which is the "failure to return fadeOut(decodeChunked(codebooks, n), 40)
// emit EOS" signal. Shorter generations are kept verbatim to
// avoid cutting low-energy speech.
return if (n >= maxGen) trimTailLowEnergy(audio) else audio
} }
/** /**
@ -3444,7 +3446,22 @@ class Qwen3TtsEngine(
* (since all text has already been consumed in the prefill). On EOS, * (since all text has already been consumed in the prefill). On EOS,
* terminate early. Returns the generated [step, codebook] codes. * terminate early. Returns the generated [step, codebook] codes.
*/ */
private fun runHexGenWithPrefill(prefill: List<FloatArray>, maxGen: Int): Array<IntArray> { private fun runHexGenWithPrefill(
prefill: List<FloatArray>,
maxGen: Int,
// Floor on how many steps to generate before the dynamic boost is
// allowed to fire. Prevents short-text segments from terminating on
// a stray "EOS-leaning" hidden state during the first few frames.
// Callers pass ~50 % of expected speech length as a safety floor.
eosBoostMinStep: Int = -1,
// Once EOS rank falls below this threshold (model itself is
// "thinking about stopping"), start adding eosBoostScale per step
// until argmax flips to EOS. Empirically EOS rank plateaus ~150-700
// mid-speech and dips to ~50-60 right at the natural end, so this
// catches the model's intent without a fixed length budget.
eosRankTrigger: Int = 60,
eosBoostScale: Float = 4.0f
): Array<IntArray> {
val padE = ttsPadEmbed ?: return emptyArray() val padE = ttsPadEmbed ?: return emptyArray()
val eosE = ttsEosEmbed ?: return emptyArray() val eosE = ttsEosEmbed ?: return emptyArray()
val allCodes = mutableListOf<IntArray>() val allCodes = mutableListOf<IntArray>()
@ -3467,6 +3484,13 @@ class Qwen3TtsEngine(
// decode step. We follow the same schedule so the model's attention // decode step. We follow the same schedule so the model's attention
// sees the same "text exhausted" signal it was trained with. // sees the same "text exhausted" signal it was trained with.
var eosFedOnce = false var eosFedOnce = false
// Counter for the dynamic EOS boost — once armed, increments every
// step monotonically. Arming requires 3 CONSECUTIVE steps below
// the rank trigger (transient dips during normal speech aren't
// enough); this keeps short sentences from terminating mid-word
// on a fluke low rank.
var boostStepsActive = 0
var consecLowRank = 0
for (genStep in 0 until maxGen) { for (genStep in 0 until maxGen) {
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0 val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
val tCp = System.currentTimeMillis() val tCp = System.currentTimeMillis()
@ -3489,15 +3513,41 @@ class Qwen3TtsEngine(
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY } for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev) val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f } for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
currentCb0 = sampleTopK(logits, 0.9f, 50)
if (currentCb0 == CODEC_EOS) { nlog("VC EOS at step ${genStep+1}"); break }
// Degeneracy guard: when the talker fails to emit EOS it falls // Dynamic EOS detection: track current EOS rank, require 3
// into a stuck loop where cb0 repeats (the "page beg beg beg" // consecutive low-rank steps before arming. Single-step rank
// artifact audible at the tail of generated phrases). Nine // dips happen mid-utterance during low-energy phonemes and
// consecutive identical cb0s is the threshold the native .pte // would otherwise trigger premature termination on short
// pipeline uses too. The short history is just the last 9 // sentences. Once armed, the boost accumulates monotonically.
// entries of generatedCb0 — cheap to scan. val eosLogit0 = logits[CODEC_EOS]
var eosRank = 0
for (j in logits.indices) if (logits[j] > eosLogit0) eosRank++
if (boostStepsActive == 0 && genStep >= eosBoostMinStep) {
if (eosRank < eosRankTrigger) {
consecLowRank++
if (consecLowRank >= 3) {
nlog("VC boost armed at step ${genStep+1} (EOS rank $eosRank, 3 consecutive low)")
boostStepsActive = 1
}
} else {
consecLowRank = 0
}
} else if (boostStepsActive > 0) {
boostStepsActive++
}
if (boostStepsActive > 0) {
logits[CODEC_EOS] += boostStepsActive * eosBoostScale
}
var argmax = 0; var argmaxV = logits[0]
for (j in 1 until logits.size) if (logits[j] > argmaxV) { argmaxV = logits[j]; argmax = j }
if (argmax == CODEC_EOS) { nlog("VC EOS (boosted argmax) at step ${genStep+1}"); break }
currentCb0 = sampleTopK(logits, 0.9f, 200)
if (currentCb0 == CODEC_EOS) { nlog("VC EOS (sampled) at step ${genStep+1}"); break }
// Degeneracy guard #1: 9 consecutive identical cb0 → stop.
// Catches the simplest stuck loop.
val nHist = generatedCb0.size val nHist = generatedCb0.size
if (nHist >= 9) { if (nHist >= 9) {
val last = generatedCb0[nHist - 1] val last = generatedCb0[nHist - 1]
@ -3508,8 +3558,26 @@ class Qwen3TtsEngine(
break break
} }
} }
// Degeneracy guard #2: low diversity in the recent window.
// The "page beg beg" filler doesn't repeat a single token — it
// cycles through 2-3 tokens. If the last 12 cb0 contain fewer
// than 4 unique values, the talker is in the cycle and we stop
// before the audio degrades further.
if (nHist >= 12) {
val recent = HashSet<Int>()
for (i in nHist - 12 until nHist) recent.add(generatedCb0[i])
if (recent.size < 4) {
nlog("VC degen: only ${recent.size} unique cb0 in last 12 at step ${genStep+1}, stopping")
break
}
}
} }
nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms") nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
// Diagnostic: log full cb0 sequence so we can correlate against the
// page-beg region in the produced audio.
val cb0Str = generatedCb0.joinToString(",")
nlog("VC cb0 sequence: $cb0Str")
return allCodes.toTypedArray() return allCodes.toTypedArray()
} }
@ -3583,7 +3651,11 @@ class Qwen3TtsEngine(
hexReset() hexReset()
val segmentAudios = mutableListOf<ShortArray>() val segmentAudios = mutableListOf<ShortArray>()
val gapSamples = SR * 120 / 1000 // Wider gap (250 ms) between segments — gives the listener a small
// breath/pause that mirrors how real speakers separate sentences,
// and masks the EOS-boost cut by surrounding it with silence rather
// than another sentence's onset.
val gapSamples = SR * 250 / 1000
val gap = ShortArray(gapSamples) val gap = ShortArray(gapSamples)
val t0 = System.currentTimeMillis() val t0 = System.currentTimeMillis()
@ -3615,14 +3687,14 @@ class Qwen3TtsEngine(
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb)) for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
for (e in suffix) prefill.add(e) for (e in suffix) prefill.add(e)
// Empirical budget: Python's voice_clone typically emits ~3.3 // Generous absolute cap; the dynamic EOS boost (triggered when
// codec frames per text token for French. Keep a small cushion // the model's own EOS rank dips below threshold) is what
// so ~80% of runs terminate via EOS/degeneracy before exhausting // actually terminates generation. The minStep floor protects
// the budget; trimming is done by the degeneracy guard inside // against early-termination spikes for short sentences.
// runHexGenWithPrefill. Too-generous maxGen guarantees the tail val expectedSteps = (ids.size * 24) / 10 // ids * 2.4 (int math)
// artifacts the user hears as "page beg beg beg". val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15) val eosBoostMinStep = expectedSteps / 2
val codes = runHexGenWithPrefill(prefill, maxGen) val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue } if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue }
val n = codes.size val n = codes.size
@ -3633,16 +3705,13 @@ class Qwen3TtsEngine(
} }
} }
val rawAudio = decodeChunked(codebooks, n) val rawAudio = decodeChunked(codebooks, n)
// Only trim when the talker exhausted its budget — that's the // Cosine fade-out on the last 40 ms — softens the cut imposed by
// case where "page beg beg" filler actually sneaks in. When // the EOS boost so the segment ends on a recognisable phoneme
// EOS or the degeneracy guard fires early, the audio is already // tail instead of an abrupt sample-clip.
// clean and trimTailLowEnergy's 35%-of-peak threshold is too val audio = fadeOut(rawAudio, 40)
// aggressive for natural French cadence (it cut Elodie seg1
// from 7.04 s to 2.92 s because the second half of the
// sentence was below the speech threshold).
val audio = if (n >= maxGen) trimTailLowEnergy(rawAudio) else rawAudio
val segMs = System.currentTimeMillis() - tSeg val segMs = System.currentTimeMillis() - tSeg
nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio (raw ${rawAudio.size/SR.toFloat()}s, trimmed=${n >= maxGen}) in ${segMs}ms") val budgetHit = if (n >= maxGen) " [maxGen cap]" else ""
nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms$budgetHit")
segmentAudios.add(audio) segmentAudios.add(audio)
saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio) saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio)
@ -3675,7 +3744,7 @@ class Qwen3TtsEngine(
if (audio.size < SR / 2) return audio if (audio.size < SR / 2) return audio
val winSamples = SR * 40 / 1000 // 40 ms windows = 960 samples val winSamples = SR * 40 / 1000 // 40 ms windows = 960 samples
val nWin = audio.size / winSamples val nWin = audio.size / winSamples
if (nWin < 6) return audio if (nWin < 10) return audio
val rms = FloatArray(nWin) val rms = FloatArray(nWin)
for (w in 0 until nWin) { for (w in 0 until nWin) {
var s = 0.0 var s = 0.0
@ -3683,21 +3752,64 @@ class Qwen3TtsEngine(
for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x } for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x }
rms[w] = kotlin.math.sqrt(s / winSamples).toFloat() rms[w] = kotlin.math.sqrt(s / winSamples).toFloat()
} }
// Reference peak over the first 70% of the segment; the tail is
// assumed to be the degenerate filler region.
var peak = 0f var peak = 0f
val refEnd = (nWin * 7) / 10 for (w in 0 until nWin) if (rms[w] > peak) peak = rms[w]
for (w in 0 until refEnd) if (rms[w] > peak) peak = rms[w]
val thr = peak * 0.35f // Heuristic 1: classic 35 % sustained threshold from the back.
val thrSust = peak * 0.35f
var lastSpeech = nWin - 1 var lastSpeech = nWin - 1
for (w in nWin - 1 downTo 2) { for (w in nWin - 1 downTo 2) {
if (rms[w] >= thr && rms[w-1] >= thr && rms[w-2] >= thr) { lastSpeech = w; break } if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) {
lastSpeech = w; break
}
} }
// Heuristic 2: "low-energy tail" — the "page beg beg" filler emits
// audio that's louder than silence but quieter than real speech.
// If the last 30 % of the segment has a max RMS under 40 % of the
// global peak, the tail is degenerate; cut at the last sustained-
// speech window before the tail starts.
val tailStart = (nWin * 7) / 10
var tailMax = 0f
for (w in tailStart until nWin) if (rms[w] > tailMax) tailMax = rms[w]
if (tailMax < peak * 0.40f) {
for (w in tailStart - 1 downTo 2) {
if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) {
if (w < lastSpeech) lastSpeech = w
break
}
}
}
val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1) val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1)
val keepSamples = (keepWin + 1) * winSamples var keepSamples = (keepWin + 1) * winSamples
// Over-trim guard: never cut below 40 % of the raw length.
val minKeep = (audio.size * 4) / 10
if (keepSamples < minKeep) keepSamples = minKeep
return audio.copyOf(keepSamples) return audio.copyOf(keepSamples)
} }
/**
* Apply a short cosine fade-out to the tail of an audio buffer. The
* EOS boost ends generation right at the model's chosen stop point,
* which usually clips the natural decay of the last syllable. A
* 40-ms fade smooths that into a recognisable phoneme tail without
* shortening the perceived word.
*/
private fun fadeOut(audio: ShortArray, fadeMs: Int = 40): ShortArray {
val fadeSamples = SR * fadeMs / 1000
if (audio.size <= fadeSamples) return audio
val out = audio.copyOf()
val start = out.size - fadeSamples
for (i in 0 until fadeSamples) {
// Cosine roll-off: 1 → 0 over fadeSamples
val t = i.toFloat() / fadeSamples
val gain = 0.5f * (1f + kotlin.math.cos(Math.PI.toFloat() * t))
out[start + i] = (out[start + i] * gain).toInt().toShort()
}
return out
}
/** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to /** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
* save one file per segment plus the concatenated result for inspection. */ * save one file per segment plus the concatenated result for inspection. */
private fun saveWav(path: String, audio: ShortArray) { private fun saveWav(path: String, audio: ShortArray) {