TTS: dynamic EOS-rank boost terminates generation cleanly across voices
Replaces the fixed maxGen + length-based boost with a fully dynamic
end-of-utterance detector that watches the model's own EOS logit rank.
End result on the Baer 3-segment monologue, validated by user as
"FORMIDABLE" / "impeccable" with both Damien and Zelda voices:
- All 3 segments terminate via EOS (no maxGen cap hit)
- No "page beg beg" filler tail
- No abrupt cuts between segments
- Audio durations 5-8 s per segment, matching Python within ~10 %
How it works (runHexGenWithPrefill, in tts/Qwen3TtsEngine.kt):
1. At every decode step, compute the rank of CODEC_EOS in the
repetition-penalised logits. Mid-utterance the rank sits at
150-700 (model is committed to producing speech). Approaching
the natural end, the rank dips toward top-50.
2. Arm the boost only when EOS rank stays below eosRankTrigger=60
for THREE consecutive steps. The 3-step requirement filters out
transient single-step dips that occur during low-energy phonemes
mid-sentence (without it, short sentences would terminate after
~3 s). Arming is also gated by eosBoostMinStep (50 % of expected
speech length) so we never arm in the very first frames.
3. Once armed, the boost increments monotonically: each subsequent
step adds boostStepsActive * eosBoostScale to the EOS logit. The
accumulated boost lifts EOS above top-1 within 1-3 steps, the
argmax check fires, and the loop breaks. Scale=4 gives the model
a small natural decay before termination; scale=5 was perfect-but-
slightly-clipping, scale=3 wasn't strong enough to outpace the
growing top-1 logit.
Other tweaks bundled in this commit because they all contribute to
the clean output:
* Inter-segment gap 120 → 250 ms — gives the listener a perceived
sentence boundary instead of a hard concatenation.
* fadeOut(audio, 40) on every segment — cosine roll-off over the
last 40 ms so the EOS-clipped tail decays naturally instead of
sample-clipping.
* top_k 50 → 200 in the fallback sample call — wider pool to keep
EOS reachable when the boost just fails to hit argmax.
Voice swap is a 45 KB file push (damien_voice_prefix.bin and
damien_voice_suffix.bin). Successfully tested today with Elodie
(female, norm 10.12) and Zelda (norm 9.39) using Damien (norm 10.36)
as the baseline — same Kotlin code, no rebuild needed.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c25040a780
commit
f548e02283
|
|
@ -3417,8 +3417,13 @@ class Qwen3TtsEngine(
|
||||||
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
|
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
|
||||||
for (e in suffix) prefill.add(e)
|
for (e in suffix) prefill.add(e)
|
||||||
|
|
||||||
val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15)
|
// See synthesizeTextStreaming for the rationale: generous absolute
|
||||||
val codes = runHexGenWithPrefill(prefill, maxGen)
|
// cap, dynamic EOS-rank-driven boost, safety floor at 50 % of
|
||||||
|
// expected to suppress early termination on short utterances.
|
||||||
|
val expectedSteps = (ids.size * 24) / 10
|
||||||
|
val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
|
||||||
|
val eosBoostMinStep = expectedSteps / 2
|
||||||
|
val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
|
||||||
if (codes.isEmpty()) return ShortArray(0)
|
if (codes.isEmpty()) return ShortArray(0)
|
||||||
|
|
||||||
val n = codes.size
|
val n = codes.size
|
||||||
|
|
@ -3428,12 +3433,9 @@ class Qwen3TtsEngine(
|
||||||
if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
|
if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val audio = decodeChunked(codebooks, n)
|
// 40 ms fade-out so the EOS-clipped tail decays naturally before
|
||||||
// Match the conservative trim policy in synthesizeTextStreaming:
|
// the AudioTrack write.
|
||||||
// only trim when we hit the maxGen cap, which is the "failure to
|
return fadeOut(decodeChunked(codebooks, n), 40)
|
||||||
// emit EOS" signal. Shorter generations are kept verbatim to
|
|
||||||
// avoid cutting low-energy speech.
|
|
||||||
return if (n >= maxGen) trimTailLowEnergy(audio) else audio
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -3444,7 +3446,22 @@ class Qwen3TtsEngine(
|
||||||
* (since all text has already been consumed in the prefill). On EOS,
|
* (since all text has already been consumed in the prefill). On EOS,
|
||||||
* terminate early. Returns the generated [step, codebook] codes.
|
* terminate early. Returns the generated [step, codebook] codes.
|
||||||
*/
|
*/
|
||||||
private fun runHexGenWithPrefill(prefill: List<FloatArray>, maxGen: Int): Array<IntArray> {
|
private fun runHexGenWithPrefill(
|
||||||
|
prefill: List<FloatArray>,
|
||||||
|
maxGen: Int,
|
||||||
|
// Floor on how many steps to generate before the dynamic boost is
|
||||||
|
// allowed to fire. Prevents short-text segments from terminating on
|
||||||
|
// a stray "EOS-leaning" hidden state during the first few frames.
|
||||||
|
// Callers pass ~50 % of expected speech length as a safety floor.
|
||||||
|
eosBoostMinStep: Int = -1,
|
||||||
|
// Once EOS rank falls below this threshold (model itself is
|
||||||
|
// "thinking about stopping"), start adding eosBoostScale per step
|
||||||
|
// until argmax flips to EOS. Empirically EOS rank plateaus ~150-700
|
||||||
|
// mid-speech and dips to ~50-60 right at the natural end, so this
|
||||||
|
// catches the model's intent without a fixed length budget.
|
||||||
|
eosRankTrigger: Int = 60,
|
||||||
|
eosBoostScale: Float = 4.0f
|
||||||
|
): Array<IntArray> {
|
||||||
val padE = ttsPadEmbed ?: return emptyArray()
|
val padE = ttsPadEmbed ?: return emptyArray()
|
||||||
val eosE = ttsEosEmbed ?: return emptyArray()
|
val eosE = ttsEosEmbed ?: return emptyArray()
|
||||||
val allCodes = mutableListOf<IntArray>()
|
val allCodes = mutableListOf<IntArray>()
|
||||||
|
|
@ -3467,6 +3484,13 @@ class Qwen3TtsEngine(
|
||||||
// decode step. We follow the same schedule so the model's attention
|
// decode step. We follow the same schedule so the model's attention
|
||||||
// sees the same "text exhausted" signal it was trained with.
|
// sees the same "text exhausted" signal it was trained with.
|
||||||
var eosFedOnce = false
|
var eosFedOnce = false
|
||||||
|
// Counter for the dynamic EOS boost — once armed, increments every
|
||||||
|
// step monotonically. Arming requires 3 CONSECUTIVE steps below
|
||||||
|
// the rank trigger (transient dips during normal speech aren't
|
||||||
|
// enough); this keeps short sentences from terminating mid-word
|
||||||
|
// on a fluke low rank.
|
||||||
|
var boostStepsActive = 0
|
||||||
|
var consecLowRank = 0
|
||||||
for (genStep in 0 until maxGen) {
|
for (genStep in 0 until maxGen) {
|
||||||
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
||||||
val tCp = System.currentTimeMillis()
|
val tCp = System.currentTimeMillis()
|
||||||
|
|
@ -3489,15 +3513,41 @@ class Qwen3TtsEngine(
|
||||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||||||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||||||
currentCb0 = sampleTopK(logits, 0.9f, 50)
|
|
||||||
if (currentCb0 == CODEC_EOS) { nlog("VC EOS at step ${genStep+1}"); break }
|
|
||||||
|
|
||||||
// Degeneracy guard: when the talker fails to emit EOS it falls
|
// Dynamic EOS detection: track current EOS rank, require 3
|
||||||
// into a stuck loop where cb0 repeats (the "page beg beg beg"
|
// consecutive low-rank steps before arming. Single-step rank
|
||||||
// artifact audible at the tail of generated phrases). Nine
|
// dips happen mid-utterance during low-energy phonemes and
|
||||||
// consecutive identical cb0s is the threshold the native .pte
|
// would otherwise trigger premature termination on short
|
||||||
// pipeline uses too. The short history is just the last 9
|
// sentences. Once armed, the boost accumulates monotonically.
|
||||||
// entries of generatedCb0 — cheap to scan.
|
val eosLogit0 = logits[CODEC_EOS]
|
||||||
|
var eosRank = 0
|
||||||
|
for (j in logits.indices) if (logits[j] > eosLogit0) eosRank++
|
||||||
|
if (boostStepsActive == 0 && genStep >= eosBoostMinStep) {
|
||||||
|
if (eosRank < eosRankTrigger) {
|
||||||
|
consecLowRank++
|
||||||
|
if (consecLowRank >= 3) {
|
||||||
|
nlog("VC boost armed at step ${genStep+1} (EOS rank $eosRank, 3 consecutive low)")
|
||||||
|
boostStepsActive = 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
consecLowRank = 0
|
||||||
|
}
|
||||||
|
} else if (boostStepsActive > 0) {
|
||||||
|
boostStepsActive++
|
||||||
|
}
|
||||||
|
if (boostStepsActive > 0) {
|
||||||
|
logits[CODEC_EOS] += boostStepsActive * eosBoostScale
|
||||||
|
}
|
||||||
|
|
||||||
|
var argmax = 0; var argmaxV = logits[0]
|
||||||
|
for (j in 1 until logits.size) if (logits[j] > argmaxV) { argmaxV = logits[j]; argmax = j }
|
||||||
|
if (argmax == CODEC_EOS) { nlog("VC EOS (boosted argmax) at step ${genStep+1}"); break }
|
||||||
|
|
||||||
|
currentCb0 = sampleTopK(logits, 0.9f, 200)
|
||||||
|
if (currentCb0 == CODEC_EOS) { nlog("VC EOS (sampled) at step ${genStep+1}"); break }
|
||||||
|
|
||||||
|
// Degeneracy guard #1: 9 consecutive identical cb0 → stop.
|
||||||
|
// Catches the simplest stuck loop.
|
||||||
val nHist = generatedCb0.size
|
val nHist = generatedCb0.size
|
||||||
if (nHist >= 9) {
|
if (nHist >= 9) {
|
||||||
val last = generatedCb0[nHist - 1]
|
val last = generatedCb0[nHist - 1]
|
||||||
|
|
@ -3508,8 +3558,26 @@ class Qwen3TtsEngine(
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Degeneracy guard #2: low diversity in the recent window.
|
||||||
|
// The "page beg beg" filler doesn't repeat a single token — it
|
||||||
|
// cycles through 2-3 tokens. If the last 12 cb0 contain fewer
|
||||||
|
// than 4 unique values, the talker is in the cycle and we stop
|
||||||
|
// before the audio degrades further.
|
||||||
|
if (nHist >= 12) {
|
||||||
|
val recent = HashSet<Int>()
|
||||||
|
for (i in nHist - 12 until nHist) recent.add(generatedCb0[i])
|
||||||
|
if (recent.size < 4) {
|
||||||
|
nlog("VC degen: only ${recent.size} unique cb0 in last 12 at step ${genStep+1}, stopping")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
|
nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
|
||||||
|
// Diagnostic: log full cb0 sequence so we can correlate against the
|
||||||
|
// page-beg region in the produced audio.
|
||||||
|
val cb0Str = generatedCb0.joinToString(",")
|
||||||
|
nlog("VC cb0 sequence: $cb0Str")
|
||||||
return allCodes.toTypedArray()
|
return allCodes.toTypedArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3583,7 +3651,11 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
hexReset()
|
hexReset()
|
||||||
val segmentAudios = mutableListOf<ShortArray>()
|
val segmentAudios = mutableListOf<ShortArray>()
|
||||||
val gapSamples = SR * 120 / 1000
|
// Wider gap (250 ms) between segments — gives the listener a small
|
||||||
|
// breath/pause that mirrors how real speakers separate sentences,
|
||||||
|
// and masks the EOS-boost cut by surrounding it with silence rather
|
||||||
|
// than another sentence's onset.
|
||||||
|
val gapSamples = SR * 250 / 1000
|
||||||
val gap = ShortArray(gapSamples)
|
val gap = ShortArray(gapSamples)
|
||||||
val t0 = System.currentTimeMillis()
|
val t0 = System.currentTimeMillis()
|
||||||
|
|
||||||
|
|
@ -3615,14 +3687,14 @@ class Qwen3TtsEngine(
|
||||||
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
|
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
|
||||||
for (e in suffix) prefill.add(e)
|
for (e in suffix) prefill.add(e)
|
||||||
|
|
||||||
// Empirical budget: Python's voice_clone typically emits ~3.3
|
// Generous absolute cap; the dynamic EOS boost (triggered when
|
||||||
// codec frames per text token for French. Keep a small cushion
|
// the model's own EOS rank dips below threshold) is what
|
||||||
// so ~80% of runs terminate via EOS/degeneracy before exhausting
|
// actually terminates generation. The minStep floor protects
|
||||||
// the budget; trimming is done by the degeneracy guard inside
|
// against early-termination spikes for short sentences.
|
||||||
// runHexGenWithPrefill. Too-generous maxGen guarantees the tail
|
val expectedSteps = (ids.size * 24) / 10 // ids * 2.4 (int math)
|
||||||
// artifacts the user hears as "page beg beg beg".
|
val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
|
||||||
val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15)
|
val eosBoostMinStep = expectedSteps / 2
|
||||||
val codes = runHexGenWithPrefill(prefill, maxGen)
|
val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
|
||||||
if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue }
|
if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue }
|
||||||
|
|
||||||
val n = codes.size
|
val n = codes.size
|
||||||
|
|
@ -3633,16 +3705,13 @@ class Qwen3TtsEngine(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val rawAudio = decodeChunked(codebooks, n)
|
val rawAudio = decodeChunked(codebooks, n)
|
||||||
// Only trim when the talker exhausted its budget — that's the
|
// Cosine fade-out on the last 40 ms — softens the cut imposed by
|
||||||
// case where "page beg beg" filler actually sneaks in. When
|
// the EOS boost so the segment ends on a recognisable phoneme
|
||||||
// EOS or the degeneracy guard fires early, the audio is already
|
// tail instead of an abrupt sample-clip.
|
||||||
// clean and trimTailLowEnergy's 35%-of-peak threshold is too
|
val audio = fadeOut(rawAudio, 40)
|
||||||
// aggressive for natural French cadence (it cut Elodie seg1
|
|
||||||
// from 7.04 s to 2.92 s because the second half of the
|
|
||||||
// sentence was below the speech threshold).
|
|
||||||
val audio = if (n >= maxGen) trimTailLowEnergy(rawAudio) else rawAudio
|
|
||||||
val segMs = System.currentTimeMillis() - tSeg
|
val segMs = System.currentTimeMillis() - tSeg
|
||||||
nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio (raw ${rawAudio.size/SR.toFloat()}s, trimmed=${n >= maxGen}) in ${segMs}ms")
|
val budgetHit = if (n >= maxGen) " [maxGen cap]" else ""
|
||||||
|
nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms$budgetHit")
|
||||||
|
|
||||||
segmentAudios.add(audio)
|
segmentAudios.add(audio)
|
||||||
saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio)
|
saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio)
|
||||||
|
|
@ -3675,7 +3744,7 @@ class Qwen3TtsEngine(
|
||||||
if (audio.size < SR / 2) return audio
|
if (audio.size < SR / 2) return audio
|
||||||
val winSamples = SR * 40 / 1000 // 40 ms windows = 960 samples
|
val winSamples = SR * 40 / 1000 // 40 ms windows = 960 samples
|
||||||
val nWin = audio.size / winSamples
|
val nWin = audio.size / winSamples
|
||||||
if (nWin < 6) return audio
|
if (nWin < 10) return audio
|
||||||
val rms = FloatArray(nWin)
|
val rms = FloatArray(nWin)
|
||||||
for (w in 0 until nWin) {
|
for (w in 0 until nWin) {
|
||||||
var s = 0.0
|
var s = 0.0
|
||||||
|
|
@ -3683,21 +3752,64 @@ class Qwen3TtsEngine(
|
||||||
for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x }
|
for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x }
|
||||||
rms[w] = kotlin.math.sqrt(s / winSamples).toFloat()
|
rms[w] = kotlin.math.sqrt(s / winSamples).toFloat()
|
||||||
}
|
}
|
||||||
// Reference peak over the first 70% of the segment; the tail is
|
|
||||||
// assumed to be the degenerate filler region.
|
|
||||||
var peak = 0f
|
var peak = 0f
|
||||||
val refEnd = (nWin * 7) / 10
|
for (w in 0 until nWin) if (rms[w] > peak) peak = rms[w]
|
||||||
for (w in 0 until refEnd) if (rms[w] > peak) peak = rms[w]
|
|
||||||
val thr = peak * 0.35f
|
// Heuristic 1: classic 35 % sustained threshold from the back.
|
||||||
|
val thrSust = peak * 0.35f
|
||||||
var lastSpeech = nWin - 1
|
var lastSpeech = nWin - 1
|
||||||
for (w in nWin - 1 downTo 2) {
|
for (w in nWin - 1 downTo 2) {
|
||||||
if (rms[w] >= thr && rms[w-1] >= thr && rms[w-2] >= thr) { lastSpeech = w; break }
|
if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) {
|
||||||
|
lastSpeech = w; break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Heuristic 2: "low-energy tail" — the "page beg beg" filler emits
|
||||||
|
// audio that's louder than silence but quieter than real speech.
|
||||||
|
// If the last 30 % of the segment has a max RMS under 40 % of the
|
||||||
|
// global peak, the tail is degenerate; cut at the last sustained-
|
||||||
|
// speech window before the tail starts.
|
||||||
|
val tailStart = (nWin * 7) / 10
|
||||||
|
var tailMax = 0f
|
||||||
|
for (w in tailStart until nWin) if (rms[w] > tailMax) tailMax = rms[w]
|
||||||
|
if (tailMax < peak * 0.40f) {
|
||||||
|
for (w in tailStart - 1 downTo 2) {
|
||||||
|
if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) {
|
||||||
|
if (w < lastSpeech) lastSpeech = w
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1)
|
val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1)
|
||||||
val keepSamples = (keepWin + 1) * winSamples
|
var keepSamples = (keepWin + 1) * winSamples
|
||||||
|
// Over-trim guard: never cut below 40 % of the raw length.
|
||||||
|
val minKeep = (audio.size * 4) / 10
|
||||||
|
if (keepSamples < minKeep) keepSamples = minKeep
|
||||||
return audio.copyOf(keepSamples)
|
return audio.copyOf(keepSamples)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply a short cosine fade-out to the tail of an audio buffer. The
|
||||||
|
* EOS boost ends generation right at the model's chosen stop point,
|
||||||
|
* which usually clips the natural decay of the last syllable. A
|
||||||
|
* 40-ms fade smooths that into a recognisable phoneme tail without
|
||||||
|
* shortening the perceived word.
|
||||||
|
*/
|
||||||
|
private fun fadeOut(audio: ShortArray, fadeMs: Int = 40): ShortArray {
|
||||||
|
val fadeSamples = SR * fadeMs / 1000
|
||||||
|
if (audio.size <= fadeSamples) return audio
|
||||||
|
val out = audio.copyOf()
|
||||||
|
val start = out.size - fadeSamples
|
||||||
|
for (i in 0 until fadeSamples) {
|
||||||
|
// Cosine roll-off: 1 → 0 over fadeSamples
|
||||||
|
val t = i.toFloat() / fadeSamples
|
||||||
|
val gain = 0.5f * (1f + kotlin.math.cos(Math.PI.toFloat() * t))
|
||||||
|
out[start + i] = (out[start + i] * gain).toInt().toShort()
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
/** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
|
/** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
|
||||||
* save one file per segment plus the concatenated result for inspection. */
|
* save one file per segment plus the concatenated result for inspection. */
|
||||||
private fun saveWav(path: String, audio: ShortArray) {
|
private fun saveWav(path: String, audio: ShortArray) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue