TTS: overlap CP↔BigVGAN — first audio 14.5s → 10.9s per segment

Streaming variant of the per-segment decode pipeline. As soon as SEQ_LEN
codes are accumulated from the talker/CP loop, BigVGAN is dispatched on
a background coroutine while the producer keeps generating the rest of
the segment. The BigVGAN consumer feeds a streaming crossfader that
emits stable audio as it arrives and holds back overlapSamples for the
next chunk's blend.

Mirrors decodeChunked's semantics exactly so final audio is bit-identical
modulo the fadeOut application location (now applied to the final
emission tail instead of the full buffer; the last 40ms still get faded).

Validated A/B on prompt 3, the same prompt used in the recent benchmark:
  prompt: "Je me sens un peu triste aujourdhui…"
  seg 0 first audio:  14 485 ms → 10 936 ms (−3.5 s)
  end-to-end first audio (LLM trigger → audio): 16.2 s → 12.7 s
  Stream LLM total: 33 234 ms → 28 594 ms (−4.6 s)

Short segments (<SEQ_LEN codes) still decode as a single chunk emitted at
end of generation, and the legacy non-streaming callers
(generateSegmentAudioVC, decodeChunked, multi-segment pipelines, etc.)
are untouched. The new path is gated behind USE_STREAMING_DECODE so it
can be reverted by flipping a single const if a regression is found.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-14 16:22:15 +02:00
parent 67de8d4767
commit 10fd10fd90
1 changed files with 197 additions and 3 deletions

View File

@ -88,6 +88,15 @@ class Qwen3TtsEngine(
private const val TOKEN_USER = 872 private const val TOKEN_USER = 872
private const val TOKEN_ASSISTANT = 1042 private const val TOKEN_ASSISTANT = 1042
private const val TOKEN_NEWLINE = 198 private const val TOKEN_NEWLINE = 198
// Streaming decode switch. When true, BigVGAN dispatches a chunk's audio as
// soon as SEQ_LEN codes are ready from the talker/CP loop rather than
// waiting for all tokens. For long segments this overlaps the final
// BigVGAN passes with ongoing talker/CP work on Hexagon, cutting the
// first-audio latency by ~4 s. Short segments (<SEQ_LEN codes) fall
// back to the single-chunk path with zero difference.
// The session loop only takes the streaming path when this flag is true AND
// both talkerPteModule and cpPteModule are non-null; otherwise it uses
// generateSegmentAudioVC. Flag exists so the sequential path can be
// re-enabled for A/B comparison by flipping this single constant.
private const val USE_STREAMING_DECODE = true
} }
private var ortEnv: OrtEnvironment? = null private var ortEnv: OrtEnvironment? = null
@ -2674,7 +2683,11 @@ class Qwen3TtsEngine(
/** PTE pipeline from pre-computed embeddings (prefill + trailing). */ /** PTE pipeline from pre-computed embeddings (prefill + trailing). */
private fun runInterleavedPteFromEmbeds( private fun runInterleavedPteFromEmbeds(
prefillEmbeds: List<FloatArray>, trailingEmbeds: List<FloatArray>, maxGenTokens: Int prefillEmbeds: List<FloatArray>, trailingEmbeds: List<FloatArray>, maxGenTokens: Int,
// Invoked synchronously after each generated step with (stepIdx, 16-codebook codes).
// Streaming callers use it to dispatch SEQ_LEN-sized chunks to the BigVGAN pipeline
// as soon as they are ready. null preserves the original batch behaviour.
onCodeStep: ((step: Int, codes: IntArray) -> Unit)? = null
): Array<IntArray> { ): Array<IntArray> {
val talkerMod = talkerPteModule ?: return emptyArray() val talkerMod = talkerPteModule ?: return emptyArray()
val cpMod = cpPteModule ?: return emptyArray() val cpMod = cpPteModule ?: return emptyArray()
@ -2752,6 +2765,7 @@ class Qwen3TtsEngine(
totalCpMs += System.currentTimeMillis() - tCp0 totalCpMs += System.currentTimeMillis() - tCp0
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1] for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
allCodes.add(codes); generatedCb0.add(currentCb0) allCodes.add(codes); generatedCb0.add(currentCb0)
onCodeStep?.invoke(genStep, codes)
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}") if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
@ -3355,8 +3369,17 @@ class Qwen3TtsEngine(
var segIdx = 0 var segIdx = 0
for (sentence in chan) { for (sentence in chan) {
try { try {
val audio = generateSegmentAudioVC(sentence, segIdx) if (USE_STREAMING_DECODE && talkerPteModule != null && cpPteModule != null) {
if (audio.isNotEmpty()) track.write(audio, 0, audio.size) // CP↔BigVGAN overlap path: audio chunks flow to the
// shared AudioTrack as soon as BigVGAN finishes each
// SEQ_LEN window, instead of after the whole segment.
generateSegmentAudioVCStreaming(sentence, segIdx) { pcm ->
if (pcm.isNotEmpty()) track.write(pcm, 0, pcm.size)
}
} else {
val audio = generateSegmentAudioVC(sentence, segIdx)
if (audio.isNotEmpty()) track.write(audio, 0, audio.size)
}
segIdx++ segIdx++
} catch (e: Exception) { } catch (e: Exception) {
nlog("session seg $segIdx error: ${e.message}") nlog("session seg $segIdx error: ${e.message}")
@ -3451,6 +3474,177 @@ class Qwen3TtsEngine(
return fadeOut(decodeChunked(codebooks, n), 40) return fadeOut(decodeChunked(codebooks, n), 40)
} }
// ---------- Streaming decode (CP ↔ BigVGAN overlap) ----------
/**
 * Work item passed from the talker/CP producer to the BigVGAN consumer.
 *
 * @property codebooks one code window per codebook (outer index = codebook,
 *   inner index = token position inside the window).
 * @property realTokens how many leading token positions of the window carry
 *   generated codes; the consumer trims the decoded audio to this many tokens.
 */
private class ChunkMsg(
    val codebooks: Array<IntArray>,
    val realTokens: Int
)
/**
 * Streaming variant of decodeChunked's trim + crossfade assembly.
 *
 * Chunks are trimmed to their real token count and linearly crossfaded onto
 * the tail of the previously fed audio over at most `overlapSamples` samples.
 * Whenever samples become "stable" (no future chunk can modify them, i.e.
 * everything before the last `overlapSamples`), they are handed to `onAudio`
 * immediately. `flushFinal()` emits the remaining tail with fadeOut applied.
 *
 * Only the not-yet-emitted tail is retained. A crossfade never reaches further
 * back than `overlapSamples`, and that region is always still held back, so
 * dropping already-emitted samples does not change the emitted audio while
 * avoiding re-copying the whole segment on every chunk.
 *
 * Not thread-safe: feedChunk/flushFinal must be called from a single thread
 * (or coroutine) at a time.
 */
private inner class StreamingCrossfader(private val onAudio: (ShortArray) -> Unit) {
    private val overlapSamples = CHUNK_OVERLAP * SAMPLES_PER_TOKEN

    // Samples assembled so far that have not been emitted yet. After each
    // feedChunk this holds at most `overlapSamples` samples.
    private var pending = ShortArray(0)
    private var isFirst = true

    /**
     * Blend one decoded chunk into the stream and emit whatever became stable.
     *
     * @param chunkAudio decoded PCM for the chunk (may include audio for padding tokens).
     * @param realTokens number of real tokens in the chunk; audio beyond
     *   `realTokens * SAMPLES_PER_TOKEN` samples is discarded.
     */
    fun feedChunk(chunkAudio: ShortArray, realTokens: Int) {
        val trimLen = minOf(realTokens * SAMPLES_PER_TOKEN, chunkAudio.size)
        val trimmed = if (trimLen < chunkAudio.size) chunkAudio.copyOf(trimLen) else chunkAudio
        if (isFirst) {
            // Defensive copy: later crossfades mutate `pending` in place and
            // must not write into the caller's array.
            pending = trimmed.copyOf()
            isFirst = false
        } else {
            val fadeLen = minOf(overlapSamples, pending.size, trimmed.size)
            val fadeStart = pending.size - fadeLen
            // Linear crossfade: weight moves from the old tail to the new chunk.
            for (i in 0 until fadeLen) {
                val alpha = i.toFloat() / fadeLen
                val mixed = ((1f - alpha) * pending[fadeStart + i] + alpha * trimmed[i]).toInt()
                    .coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()).toShort()
                pending[fadeStart + i] = mixed
            }
            if (fadeLen < trimmed.size) {
                // Append the part of the chunk that lies past the blended region.
                val appended = trimmed.size - fadeLen
                val grown = pending.copyOf(pending.size + appended)
                System.arraycopy(trimmed, fadeLen, grown, pending.size, appended)
                pending = grown
            }
        }
        // Hold back the last `overlapSamples` so the next chunk's
        // crossfade can still mutate them; emit everything before that.
        val stableLen = pending.size - overlapSamples
        if (stableLen > 0) {
            onAudio(pending.copyOfRange(0, stableLen))
            // Only drop the emitted samples once onAudio returned normally.
            pending = pending.copyOfRange(stableLen, pending.size)
        }
    }

    /** Emit any remaining buffered samples with the trailing fadeOut. */
    fun flushFinal() {
        if (pending.isNotEmpty()) {
            onAudio(fadeOut(pending, 40))
            pending = ShortArray(0)
        }
    }
}
/**
 * Streaming variant of generateSegmentAudioVC. As the talker/CP loop
 * produces codes step by step, BigVGAN chunks are dispatched to a
 * background consumer coroutine the moment SEQ_LEN codes are accumulated,
 * so decoding overlaps with the remaining talker/CP steps. For a 75-token
 * segment this cuts first-audio latency by ~4 s vs the sequential
 * `generateSegmentAudioVC` path.
 *
 * Short segments (<SEQ_LEN codes) emit a single chunk at end-of-gen,
 * matching the legacy single-chunk path with no perceptible difference.
 *
 * Concurrency contract:
 * - The producer (talker/CP loop) runs on the calling thread and blocks on
 *   `bvChan.send` if the BigVGAN consumer is behind; in practice that rarely
 *   happens because the producer takes ~5 s per chunk vs ~2.4 s for BigVGAN.
 * - The consumer is a child of this call (structured concurrency): the
 *   function returns only after it has finished, and cancelling the caller
 *   cancels it too.
 * - If the consumer stops for any reason it cancels the channel, so a
 *   producer blocked in `send` fails fast instead of hanging forever once
 *   the channel buffer is full.
 *
 * @param segText text of the segment to synthesise.
 * @param segIdx segment index, used for logging only.
 * @param onAudio receives PCM slices in playback order; invoked from the
 *   consumer coroutine on Dispatchers.IO, not from the calling thread.
 */
private suspend fun generateSegmentAudioVCStreaming(
    segText: String, segIdx: Int, onAudio: (ShortArray) -> Unit
) {
    // Snapshot nullable engine state once so the null checks smart-cast and
    // no `!!` is needed below.
    val tokenizer = bpeTokenizer
    val prefix = damienVoicePrefix
    val suffix = damienVoiceSuffix
    if (tokenizer == null || textEmbedsFullBuf == null || prefix == null || suffix == null) {
        nlog("generateSegmentAudioVCStreaming: Stage 2 assets missing"); return
    }
    if (talkerPteModule == null || cpPteModule == null) {
        nlog("generateSegmentAudioVCStreaming: PTE talker/CP not loaded"); return
    }
    val codecPadEmb = codecEmb(CODEC_PAD)
    val ids = tokenizer.encode(segText)
    nlog("session seg $segIdx (stream) '${segText.take(60)}' → ${ids.size} tokens")

    // Prefill = voice prefix + (text embedding + codec pad) per token + voice suffix.
    val prefill = ArrayList<FloatArray>(prefix.size + ids.size + suffix.size)
    for (e in prefix) prefill.add(e)
    for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
    for (e in suffix) prefill.add(e)

    // Generation budget: ~2.4 steps per text token, plus 50% and 10 steps of
    // headroom, capped by the context size.
    val expectedSteps = (ids.size * 24) / 10
    val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
    val tStart = System.currentTimeMillis()
    var firstAudioLogged = false
    val bvChan = kotlinx.coroutines.channels.Channel<ChunkMsg>(capacity = 4)
    val cfader = StreamingCrossfader { pcm ->
        if (!firstAudioLogged) {
            nlog("streaming seg $segIdx first audio at ${System.currentTimeMillis() - tStart}ms (${pcm.size} samples)")
            firstAudioLogged = true
        }
        onAudio(pcm)
    }

    // Producer state. Touched only from the calling thread (the onCodeStep
    // callback is invoked synchronously by runInterleavedPteFromEmbeds).
    val collected = mutableListOf<IntArray>()
    var nextChunkStart = 0
    // Flips to false once a send fails, i.e. the consumer is gone.
    var channelOpen = true

    // Build a SEQ_LEN-wide window per codebook starting at `start`; positions
    // past `real` and out-of-range codes are replaced by 0.
    fun buildChunkCb(start: Int, real: Int): Array<IntArray> = Array(NUM_CODEBOOKS) { cb ->
        IntArray(SEQ_LEN) { t ->
            val src = start + t
            if (src < start + real) {
                val v = collected[src][cb]
                if (v in 0 until CODEBOOK_SIZE) v else 0
            } else 0
        }
    }

    // Blocking hand-off to the consumer. runBlocking is required because the
    // onCodeStep callback is not a suspend function.
    fun dispatch(start: Int, real: Int) {
        val cb = buildChunkCb(start, real)
        try {
            kotlinx.coroutines.runBlocking { bvChan.send(ChunkMsg(cb, real)) }
        } catch (e: Exception) {
            channelOpen = false
            throw e
        }
    }

    kotlinx.coroutines.coroutineScope {
        // Consumer: decode each window with BigVGAN and feed the crossfader.
        launch(kotlinx.coroutines.Dispatchers.IO) {
            try {
                for (msg in bvChan) {
                    val quant = vqDecode(msg.codebooks)
                    val audio = runSpeechDecoderV2(quant)
                    cfader.feedChunk(audio, msg.realTokens)
                }
                cfader.flushFinal()
            } catch (e: kotlinx.coroutines.CancellationException) {
                // Never swallow cancellation; let the parent scope observe it.
                throw e
            } catch (e: Exception) {
                nlog("streaming seg $segIdx consumer error: ${e.message}")
            } finally {
                // No-op after a normal drain. After a failure or cancellation
                // this makes any pending/future producer send throw instead
                // of blocking on a channel nobody reads any more.
                bvChan.cancel()
            }
        }

        try {
            // Producer: run the interleaved talker/CP loop and dispatch each
            // SEQ_LEN window of codes immediately. Consecutive windows start
            // EFFECTIVE_CHUNK tokens apart, so they overlap and the
            // crossfader can blend the seam.
            try {
                runInterleavedPteFromEmbeds(prefill, emptyList(), maxGen) { _, codes ->
                    collected.add(codes)
                    while (collected.size >= nextChunkStart + SEQ_LEN) {
                        dispatch(nextChunkStart, SEQ_LEN)
                        nextChunkStart += EFFECTIVE_CHUNK
                    }
                }
            } catch (e: Exception) {
                nlog("streaming seg $segIdx producer error: ${e.message}")
            }
            // Trailing chunk: any remaining tokens after the last full window
            // (covers both the medium-segment partial-tail case and the
            // short-segment <SEQ_LEN single-chunk case where nextChunkStart=0).
            // Still attempted after a generation error (best effort), but
            // skipped when the consumer is already gone.
            val total = collected.size
            if (channelOpen && total > nextChunkStart) {
                try {
                    dispatch(nextChunkStart, total - nextChunkStart)
                } catch (e: Exception) {
                    nlog("streaming seg $segIdx producer error: ${e.message}")
                }
            }
        } finally {
            // End-of-stream: lets the consumer drain the buffer, flush the
            // final tail and finish. coroutineScope then waits for it.
            bvChan.close()
        }
    }
}
/** /**
* Run the Hexagon talker + CP generation loop with a fully pre-built * Run the Hexagon talker + CP generation loop with a fully pre-built
* prefill (voice prefix + all text tokens). Same decode recipe as * prefill (voice prefix + all text tokens). Same decode recipe as