From 10fd10fd90578aab6707f3ffadb236098bee62fb Mon Sep 17 00:00:00 2001 From: Kazeia Team Date: Tue, 14 Apr 2026 16:22:15 +0200 Subject: [PATCH] =?UTF-8?q?TTS:=20overlap=20CP=E2=86=94BigVGAN=20=E2=80=94?= =?UTF-8?q?=20first=20audio=2014.5s=20=E2=86=92=2010.9s=20per=20segment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Streaming variant of the per-segment decode pipeline. As soon as SEQ_LEN codes are accumulated from the talker/CP loop, BigVGAN is dispatched on a background coroutine while the producer keeps generating the rest of the segment. The BigVGAN consumer feeds a streaming crossfader that emits stable audio as it arrives and holds back overlapSamples for the next chunk's blend. Mirrors decodeChunked's semantics exactly so final audio is bit-identical modulo the fadeOut application location (now applied to the final emission tail instead of the full buffer; the last 40ms still get faded). Validated A/B on the same prompt 3 used in the recent benchmark: prompt: "Je me sens un peu triste aujourdhui…" seg 0 first audio: 14 485 ms → 10 936 ms (−3.5 s) end-to-end first audio (LLM trigger → audio): 16.2 s → 12.7 s Stream LLM total: 33 234 ms → 28 594 ms (−4.6 s) Short segments ( --- .../java/com/kazeia/tts/Qwen3TtsEngine.kt | 200 +++++++++++++++++- 1 file changed, 197 insertions(+), 3 deletions(-) diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt index acf5a9f..765dd17 100644 --- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt +++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt @@ -88,6 +88,15 @@ class Qwen3TtsEngine( private const val TOKEN_USER = 872 private const val TOKEN_ASSISTANT = 1042 private const val TOKEN_NEWLINE = 198 + + // Streaming decode: when true, BigVGAN dispatches a chunk's audio as + // soon as SEQ_LEN codes are ready from the talker/CP loop rather than + // waiting for all tokens. For long segments this overlaps the final + // BigVGAN passes with ongoing talker/CP work on Hexagon, cutting the + // first-audio latency by ~4 s. Short segments (, trailingEmbeds: List, maxGenTokens: Int + prefillEmbeds: List, trailingEmbeds: List, maxGenTokens: Int, + // Invoked synchronously after each generated step with (stepIdx, 16-codebook codes). + // Streaming callers use it to dispatch SEQ_LEN-sized chunks to the BigVGAN pipeline + // as soon as they are ready. null preserves the original batch behaviour. + onCodeStep: ((step: Int, codes: IntArray) -> Unit)? = null ): Array { val talkerMod = talkerPteModule ?: return emptyArray() val cpMod = cpPteModule ?: return emptyArray() @@ -2752,6 +2765,7 @@ class Qwen3TtsEngine( totalCpMs += System.currentTimeMillis() - tCp0 for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1] allCodes.add(codes); generatedCb0.add(currentCb0) + onCodeStep?.invoke(genStep, codes) if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}") @@ -3355,8 +3369,17 @@ class Qwen3TtsEngine( var segIdx = 0 for (sentence in chan) { try { - val audio = generateSegmentAudioVC(sentence, segIdx) - if (audio.isNotEmpty()) track.write(audio, 0, audio.size) + if (USE_STREAMING_DECODE && talkerPteModule != null && cpPteModule != null) { + // CP↔BigVGAN overlap path: audio chunks flow to the + // shared AudioTrack as soon as BigVGAN finishes each + // SEQ_LEN window, instead of after the whole segment. + generateSegmentAudioVCStreaming(sentence, segIdx) { pcm -> + if (pcm.isNotEmpty()) track.write(pcm, 0, pcm.size) + } + } else { + val audio = generateSegmentAudioVC(sentence, segIdx) + if (audio.isNotEmpty()) track.write(audio, 0, audio.size) + } segIdx++ } catch (e: Exception) { nlog("session seg $segIdx error: ${e.message}") @@ -3451,6 +3474,177 @@ class Qwen3TtsEngine( return fadeOut(decodeChunked(codebooks, n), 40) } + // ---------- Streaming decode (CP ↔ BigVGAN overlap) ---------- + + /** Carrier from the talker/CP producer to the BigVGAN consumer. */ + private class ChunkMsg(val codebooks: Array, val realTokens: Int) + + /** + * Streaming variant of decodeChunked. Mirrors its semantics exactly: the + * internal `result` buffer accumulates and crossfades chunks the same + * way, so the final assembled audio is bit-identical. The difference is + * that whenever a portion of `result` becomes "stable" (no future chunk + * can modify it, i.e. anything before the last `overlapSamples`), it is + * emitted via `onAudio` immediately. `flushFinal()` emits the remaining + * tail with fadeOut applied, matching the original behaviour. + */ + private inner class StreamingCrossfader(private val onAudio: (ShortArray) -> Unit) { + private val overlapSamples = CHUNK_OVERLAP * SAMPLES_PER_TOKEN + private var result = ShortArray(0) + private var emittedLen = 0 + private var isFirst = true + + fun feedChunk(chunkAudio: ShortArray, realTokens: Int) { + val trimLen = minOf(realTokens * SAMPLES_PER_TOKEN, chunkAudio.size) + val trimmed = if (trimLen < chunkAudio.size) chunkAudio.copyOf(trimLen) else chunkAudio + + if (isFirst) { + result = trimmed.copyOf() + isFirst = false + } else { + val fadeLen = minOf(overlapSamples, result.size, trimmed.size) + for (i in 0 until fadeLen) { + val alpha = i.toFloat() / fadeLen + val mixed = ((1f - alpha) * result[result.size - fadeLen + i] + alpha * trimmed[i]).toInt() + .coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()).toShort() + result[result.size - fadeLen + i] = mixed + } + if (fadeLen < trimmed.size) { + val newPart = trimmed.copyOfRange(fadeLen, trimmed.size) + val combined = ShortArray(result.size + newPart.size) + System.arraycopy(result, 0, combined, 0, result.size) + System.arraycopy(newPart, 0, combined, result.size, newPart.size) + result = combined + } + } + + // Hold back the last `overlapSamples` so the next chunk's + // crossfade can still mutate them; emit everything before that. + val stableEnd = (result.size - overlapSamples).coerceAtLeast(emittedLen) + if (stableEnd > emittedLen) { + val slice = result.copyOfRange(emittedLen, stableEnd) + onAudio(slice) + emittedLen = stableEnd + } + } + + /** Emit any remaining buffered samples with the trailing fadeOut. */ + fun flushFinal() { + if (emittedLen < result.size) { + val tail = result.copyOfRange(emittedLen, result.size) + onAudio(fadeOut(tail, 40)) + emittedLen = result.size + } + } + } + + /** + * Streaming variant of generateSegmentAudioVC. As the talker/CP loop + * produces codes step by step, BigVGAN chunks are dispatched on a + * background coroutine the moment SEQ_LEN codes are accumulated. For a + * 75-token segment this overlaps the last BigVGAN pass with the final + * ~20 talker/CP steps, cutting first-audio latency by ~4 s vs the + * sequential `generateSegmentAudioVC` path. + * + * Short segments ( Unit + ) { + if (bpeTokenizer == null || textEmbedsFullBuf == null || damienVoicePrefix == null || damienVoiceSuffix == null) { + nlog("generateSegmentAudioVCStreaming: Stage 2 assets missing"); return + } + if (talkerPteModule == null || cpPteModule == null) { + nlog("generateSegmentAudioVCStreaming: PTE talker/CP not loaded"); return + } + val prefix = damienVoicePrefix!! + val suffix = damienVoiceSuffix!! + val codecPadEmb = codecEmb(CODEC_PAD) + val ids = bpeTokenizer!!.encode(segText) + nlog("session seg $segIdx (stream) '${segText.take(60)}' → ${ids.size} tokens") + + val prefill = ArrayList(prefix.size + ids.size + suffix.size) + for (e in prefix) prefill.add(e) + for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb)) + for (e in suffix) prefill.add(e) + + val expectedSteps = (ids.size * 24) / 10 + val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15) + + val tStart = System.currentTimeMillis() + var firstAudioLogged = false + val bvChan = kotlinx.coroutines.channels.Channel(capacity = 4) + val cfader = StreamingCrossfader { pcm -> + if (!firstAudioLogged) { + nlog("streaming seg $segIdx first audio at ${System.currentTimeMillis() - tStart}ms (${pcm.size} samples)") + firstAudioLogged = true + } + onAudio(pcm) + } + val consumerJob = kotlinx.coroutines.CoroutineScope(kotlinx.coroutines.Dispatchers.IO).launch { + try { + for (msg in bvChan) { + val quant = vqDecode(msg.codebooks) + val audio = runSpeechDecoderV2(quant) + cfader.feedChunk(audio, msg.realTokens) + } + cfader.flushFinal() + } catch (e: Exception) { + nlog("streaming seg $segIdx consumer error: ${e.message}") + } + } + + // Producer: run the interleaved talker/CP loop and dispatch each + // SEQ_LEN-aligned window of codes immediately. The consumer's + // crossfader holds back the last `overlapSamples` of audio per + // chunk, so the in-flight chunk's tail can still be mutated by the + // next chunk before being emitted; flushFinal() at end emits the + // last tail with fadeOut. End-of-stream is signalled by closing + // bvChan after the trailing partial chunk is sent. + val collected = mutableListOf() + var nextChunkStart = 0 + + fun buildChunkCb(start: Int, real: Int): Array = Array(NUM_CODEBOOKS) { cb -> + IntArray(SEQ_LEN) { t -> + val src = start + t + if (src < start + real) { + val v = collected[src][cb] + if (v in 0 until CODEBOOK_SIZE) v else 0 + } else 0 + } + } + + try { + runInterleavedPteFromEmbeds(prefill, emptyList(), maxGen) { _, codes -> + collected.add(codes) + while (collected.size >= nextChunkStart + SEQ_LEN) { + val cb = buildChunkCb(nextChunkStart, SEQ_LEN) + kotlinx.coroutines.runBlocking { bvChan.send(ChunkMsg(cb, SEQ_LEN)) } + nextChunkStart += EFFECTIVE_CHUNK + } + } + } catch (e: Exception) { + nlog("streaming seg $segIdx producer error: ${e.message}") + } + + // Trailing chunk: any remaining tokens after the last full window + // (covers both the medium-segment partial-tail case and the + // short-segment nextChunkStart) { + val trailing = total - nextChunkStart + val cb = buildChunkCb(nextChunkStart, trailing) + kotlinx.coroutines.runBlocking { bvChan.send(ChunkMsg(cb, trailing)) } + } + bvChan.close() + consumerJob.join() + } + /** * Run the Hexagon talker + CP generation loop with a fully pre-built * prefill (voice prefix + all text tokens). Same decode recipe as