TTS: overlap CP↔BigVGAN — first audio 14.5s → 10.9s per segment
Streaming variant of the per-segment decode pipeline. As soon as SEQ_LEN codes are accumulated from the talker/CP loop, BigVGAN is dispatched on a background coroutine while the producer keeps generating the rest of the segment. The BigVGAN consumer feeds a streaming crossfader that emits stable audio as it arrives and holds back overlapSamples for the next chunk's blend. Mirrors decodeChunked's semantics exactly so final audio is bit-identical modulo the fadeOut application location (now applied to the final emission tail instead of the full buffer; the last 40ms still get faded). Validated A/B on the same prompt 3 used in the recent benchmark: prompt: "Je me sens un peu triste aujourdhui…" seg 0 first audio: 14 485 ms → 10 936 ms (−3.5 s) end-to-end first audio (LLM trigger → audio): 16.2 s → 12.7 s Stream LLM total: 33 234 ms → 28 594 ms (−4.6 s) Short segments (<SEQ_LEN codes) and the legacy non-streaming callers (generateSegmentAudioVC, decodeChunked, multi-segment pipelines, etc.) are untouched. The new path is gated behind USE_STREAMING_DECODE so it can be reverted by flipping a single const if a regression is found. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
67de8d4767
commit
10fd10fd90
|
|
@ -88,6 +88,15 @@ class Qwen3TtsEngine(
|
||||||
private const val TOKEN_USER = 872
|
private const val TOKEN_USER = 872
|
||||||
private const val TOKEN_ASSISTANT = 1042
|
private const val TOKEN_ASSISTANT = 1042
|
||||||
private const val TOKEN_NEWLINE = 198
|
private const val TOKEN_NEWLINE = 198
|
||||||
|
|
||||||
|
// Streaming decode: when true, BigVGAN dispatches a chunk's audio as
|
||||||
|
// soon as SEQ_LEN codes are ready from the talker/CP loop rather than
|
||||||
|
// waiting for all tokens. For long segments this overlaps the earlier
|
||||||
|
// BigVGAN passes with ongoing talker/CP work on Hexagon, cutting the
|
||||||
|
// first-audio latency by ~4 s. Short segments (<SEQ_LEN codes) fall
|
||||||
|
// back to the single-chunk path with zero difference. Flag exists so
|
||||||
|
// the sequential path can be re-enabled for A/B comparison.
|
||||||
|
private const val USE_STREAMING_DECODE = true
|
||||||
}
|
}
|
||||||
|
|
||||||
private var ortEnv: OrtEnvironment? = null
|
private var ortEnv: OrtEnvironment? = null
|
||||||
|
|
@ -2674,7 +2683,11 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
/** PTE pipeline from pre-computed embeddings (prefill + trailing). */
|
/** PTE pipeline from pre-computed embeddings (prefill + trailing). */
|
||||||
private fun runInterleavedPteFromEmbeds(
|
private fun runInterleavedPteFromEmbeds(
|
||||||
prefillEmbeds: List<FloatArray>, trailingEmbeds: List<FloatArray>, maxGenTokens: Int
|
prefillEmbeds: List<FloatArray>, trailingEmbeds: List<FloatArray>, maxGenTokens: Int,
|
||||||
|
// Invoked synchronously after each generated step with (stepIdx, 16-codebook codes).
|
||||||
|
// Streaming callers use it to dispatch SEQ_LEN-sized chunks to the BigVGAN pipeline
|
||||||
|
// as soon as they are ready. null preserves the original batch behaviour.
|
||||||
|
onCodeStep: ((step: Int, codes: IntArray) -> Unit)? = null
|
||||||
): Array<IntArray> {
|
): Array<IntArray> {
|
||||||
val talkerMod = talkerPteModule ?: return emptyArray()
|
val talkerMod = talkerPteModule ?: return emptyArray()
|
||||||
val cpMod = cpPteModule ?: return emptyArray()
|
val cpMod = cpPteModule ?: return emptyArray()
|
||||||
|
|
@ -2752,6 +2765,7 @@ class Qwen3TtsEngine(
|
||||||
totalCpMs += System.currentTimeMillis() - tCp0
|
totalCpMs += System.currentTimeMillis() - tCp0
|
||||||
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
|
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
|
||||||
allCodes.add(codes); generatedCb0.add(currentCb0)
|
allCodes.add(codes); generatedCb0.add(currentCb0)
|
||||||
|
onCodeStep?.invoke(genStep, codes)
|
||||||
|
|
||||||
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||||
|
|
||||||
|
|
@ -3355,8 +3369,17 @@ class Qwen3TtsEngine(
|
||||||
var segIdx = 0
|
var segIdx = 0
|
||||||
for (sentence in chan) {
|
for (sentence in chan) {
|
||||||
try {
|
try {
|
||||||
val audio = generateSegmentAudioVC(sentence, segIdx)
|
if (USE_STREAMING_DECODE && talkerPteModule != null && cpPteModule != null) {
|
||||||
if (audio.isNotEmpty()) track.write(audio, 0, audio.size)
|
// CP↔BigVGAN overlap path: audio chunks flow to the
|
||||||
|
// shared AudioTrack as soon as BigVGAN finishes each
|
||||||
|
// SEQ_LEN window, instead of after the whole segment.
|
||||||
|
generateSegmentAudioVCStreaming(sentence, segIdx) { pcm ->
|
||||||
|
if (pcm.isNotEmpty()) track.write(pcm, 0, pcm.size)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
val audio = generateSegmentAudioVC(sentence, segIdx)
|
||||||
|
if (audio.isNotEmpty()) track.write(audio, 0, audio.size)
|
||||||
|
}
|
||||||
segIdx++
|
segIdx++
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
nlog("session seg $segIdx error: ${e.message}")
|
nlog("session seg $segIdx error: ${e.message}")
|
||||||
|
|
@ -3451,6 +3474,177 @@ class Qwen3TtsEngine(
|
||||||
return fadeOut(decodeChunked(codebooks, n), 40)
|
return fadeOut(decodeChunked(codebooks, n), 40)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------- Streaming decode (CP ↔ BigVGAN overlap) ----------
|
||||||
|
|
||||||
|
/**
 * One decode job handed from the talker/CP producer to the BigVGAN consumer.
 *
 * @property codebooks per-codebook code rows for a single decode window; the
 *   streaming producer zero-pads every row past [realTokens].
 * @property realTokens number of leading positions in [codebooks] that carry
 *   generated codes; the crossfader trims the decoded audio to this many tokens.
 */
private class ChunkMsg(
    val codebooks: Array<IntArray>,
    val realTokens: Int,
)
|
||||||
|
|
||||||
|
/**
 * Incremental crossfader for consecutive decoded audio chunks.
 *
 * Each chunk passed to [feedChunk] is trimmed to its real token count and
 * linearly crossfaded onto the tail of the audio received so far. Samples
 * that no later chunk can modify (everything except the last
 * `overlapSamples`) are handed to `onAudio` immediately; [flushFinal] emits
 * whatever is still held back, with `fadeOut(…, 40)` applied to that tail.
 *
 * Only the not-yet-emitted tail is kept in memory. Already-emitted samples
 * are never read again, so dropping them does not change the output, and the
 * buffer stays bounded instead of being re-copied in full for every chunk.
 *
 * The class does no synchronization: call [feedChunk] and [flushFinal] from a
 * single thread. `onAudio` runs synchronously on that calling thread.
 */
private inner class StreamingCrossfader(private val onAudio: (ShortArray) -> Unit) {
    private val overlapSamples = CHUNK_OVERLAP * SAMPLES_PER_TOKEN

    // Samples received but not yet emitted. Once anything has been emitted
    // this is exactly the `overlapSamples` hold-back region.
    private var pending = ShortArray(0)
    private var isFirst = true

    /**
     * Blend one decoded chunk into the stream and emit the newly stable part.
     *
     * @param chunkAudio decoded PCM for the chunk (may include padding audio).
     * @param realTokens number of real tokens in the chunk; audio beyond
     *   `realTokens * SAMPLES_PER_TOKEN` samples is discarded.
     */
    fun feedChunk(chunkAudio: ShortArray, realTokens: Int) {
        val trimLen = minOf(realTokens * SAMPLES_PER_TOKEN, chunkAudio.size)
        val trimmed = if (trimLen < chunkAudio.size) chunkAudio.copyOf(trimLen) else chunkAudio

        if (isFirst) {
            // Copy so later in-place crossfades never write into the caller's array.
            pending = trimmed.copyOf()
            isFirst = false
        } else {
            val fadeLen = minOf(overlapSamples, pending.size, trimmed.size)
            val fadeStart = pending.size - fadeLen
            // Linear crossfade: weight moves from the old tail to the new head.
            for (i in 0 until fadeLen) {
                val alpha = i.toFloat() / fadeLen
                val mixed = ((1f - alpha) * pending[fadeStart + i] + alpha * trimmed[i]).toInt()
                    .coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()).toShort()
                pending[fadeStart + i] = mixed
            }
            if (fadeLen < trimmed.size) {
                // Append the part of the chunk that lies past the blended region.
                val appended = trimmed.size - fadeLen
                val grown = pending.copyOf(pending.size + appended)
                System.arraycopy(trimmed, fadeLen, grown, pending.size, appended)
                pending = grown
            }
        }

        // Hold back the last `overlapSamples` so the next chunk's crossfade
        // can still mutate them; emit everything before that.
        val stableLen = pending.size - overlapSamples
        if (stableLen > 0) {
            onAudio(pending.copyOfRange(0, stableLen))
            pending = pending.copyOfRange(stableLen, pending.size)
        }
    }

    /** Emit any remaining buffered samples with the trailing fadeOut. */
    fun flushFinal() {
        if (pending.isNotEmpty()) {
            onAudio(fadeOut(pending, 40))
            pending = ShortArray(0)
        }
    }
}
|
||||||
|
|
||||||
|
/**
 * Streaming variant of generateSegmentAudioVC: decodes audio for one text
 * segment while the talker/CP loop is still generating codes.
 *
 * Flow:
 *  1. Build the prefill: voice prefix, then one embedding per text token
 *     (text embedding summed with the codec-pad embedding), then voice suffix.
 *  2. Run `runInterleavedPteFromEmbeds` on the calling thread. Every time
 *     SEQ_LEN codes are available past the current window start, that window
 *     is sent to the decode channel and the window start advances by
 *     EFFECTIVE_CHUNK.
 *  3. A consumer coroutine on Dispatchers.IO decodes each window
 *     (`vqDecode` → `runSpeechDecoderV2`) and feeds a [StreamingCrossfader],
 *     which calls [onAudio] with each stable slice.
 *  4. After generation returns, the remaining partial window (or the whole
 *     segment when it is shorter than SEQ_LEN) is sent, the channel is
 *     closed, and the function returns once the consumer has flushed.
 *
 * Only full windows overlap with generation; the trailing window is decoded
 * after the talker/CP loop has finished.
 *
 * Failure handling: if the consumer stops for any reason before draining the
 * channel, it cancels the channel so the producer's blocking `send` fails
 * fast instead of waiting forever on a full buffer. The consumer is a child
 * of the caller's coroutine, so cancelling the caller also stops audio
 * emission. Producer and consumer errors are logged, not rethrown;
 * cancellation is propagated.
 *
 * @param segText text of the segment to synthesize.
 * @param segIdx segment index, used only in log messages.
 * @param onAudio receives PCM slices in playback order; invoked from the
 *   consumer coroutine (Dispatchers.IO), not from the caller's thread.
 */
private suspend fun generateSegmentAudioVCStreaming(
    segText: String, segIdx: Int, onAudio: (ShortArray) -> Unit
) {
    // Snapshot the nullable properties once so the null checks and the uses
    // below see the same values (no `!!` on a second read).
    val tokenizer = bpeTokenizer
    val prefix = damienVoicePrefix
    val suffix = damienVoiceSuffix
    if (tokenizer == null || textEmbedsFullBuf == null || prefix == null || suffix == null) {
        nlog("generateSegmentAudioVCStreaming: Stage 2 assets missing"); return
    }
    if (talkerPteModule == null || cpPteModule == null) {
        nlog("generateSegmentAudioVCStreaming: PTE talker/CP not loaded"); return
    }
    val codecPadEmb = codecEmb(CODEC_PAD)
    val ids = tokenizer.encode(segText)
    nlog("session seg $segIdx (stream) '${segText.take(60)}' → ${ids.size} tokens")

    val prefill = ArrayList<FloatArray>(prefix.size + ids.size + suffix.size)
    for (e in prefix) prefill.add(e)
    for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
    for (e in suffix) prefill.add(e)

    // Generation budget: 2.4 steps per text token, plus 50% and 10 steps of
    // headroom, capped by the context window.
    val expectedSteps = (ids.size * 24) / 10
    val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)

    val tStart = System.currentTimeMillis()
    var firstAudioLogged = false
    val bvChan = kotlinx.coroutines.channels.Channel<ChunkMsg>(capacity = 4)
    // Set by the consumer when it exits without draining the channel.
    val consumerGone = java.util.concurrent.atomic.AtomicBoolean(false)
    val cfader = StreamingCrossfader { pcm ->
        if (!firstAudioLogged) {
            nlog("streaming seg $segIdx first audio at ${System.currentTimeMillis() - tStart}ms (${pcm.size} samples)")
            firstAudioLogged = true
        }
        onAudio(pcm)
    }

    // coroutineScope ties the consumer to the caller: it is cancelled with
    // the caller and is always awaited before this function returns.
    kotlinx.coroutines.coroutineScope {
        launch(kotlinx.coroutines.Dispatchers.IO) {
            var drained = false
            try {
                for (msg in bvChan) {
                    val quant = vqDecode(msg.codebooks)
                    val audio = runSpeechDecoderV2(quant)
                    cfader.feedChunk(audio, msg.realTokens)
                }
                cfader.flushFinal()
                drained = true
            } catch (e: java.util.concurrent.CancellationException) {
                // Cancellation must propagate; never swallow it.
                throw e
            } catch (e: Exception) {
                nlog("streaming seg $segIdx consumer error: ${e.message}")
            } finally {
                if (!drained) {
                    // Nobody will receive any more: fail pending and future
                    // sends so the producer cannot block on a full buffer.
                    consumerGone.set(true)
                    bvChan.cancel()
                }
            }
        }

        val collected = mutableListOf<IntArray>()
        var nextChunkStart = 0

        // Build a SEQ_LEN-wide window per codebook starting at `start`;
        // positions past `real` and out-of-range codes are replaced by 0.
        fun buildChunkCb(start: Int, real: Int): Array<IntArray> = Array(NUM_CODEBOOKS) { cb ->
            IntArray(SEQ_LEN) { t ->
                val src = start + t
                if (src < start + real) {
                    val v = collected[src][cb]
                    if (v in 0 until CODEBOOK_SIZE) v else 0
                } else 0
            }
        }

        // The per-step callback is not a suspend function, so the send has to
        // block the producer thread while the channel buffer is full.
        fun dispatchWindow(real: Int) {
            val cb = buildChunkCb(nextChunkStart, real)
            kotlinx.coroutines.runBlocking { bvChan.send(ChunkMsg(cb, real)) }
        }

        try {
            try {
                runInterleavedPteFromEmbeds(prefill, emptyList(), maxGen) { _, codes ->
                    collected.add(codes)
                    while (collected.size >= nextChunkStart + SEQ_LEN) {
                        dispatchWindow(SEQ_LEN)
                        nextChunkStart += EFFECTIVE_CHUNK
                    }
                }
            } catch (e: Exception) {
                nlog("streaming seg $segIdx producer error: ${e.message}")
            }

            // Trailing chunk: any remaining tokens after the last full window
            // (covers both the partial-tail case and the short-segment case
            // where nextChunkStart is still 0). Clamped to SEQ_LEN because a
            // failed send can leave more than one window's worth undelivered.
            val trailing = minOf(collected.size - nextChunkStart, SEQ_LEN)
            if (trailing > 0 && !consumerGone.get()) {
                try {
                    dispatchWindow(trailing)
                } catch (e: Exception) {
                    nlog("streaming seg $segIdx producer error: ${e.message}")
                }
            }
        } finally {
            // Always signal end-of-stream so the consumer loop terminates.
            bvChan.close()
        }
    }
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Run the Hexagon talker + CP generation loop with a fully pre-built
|
* Run the Hexagon talker + CP generation loop with a fully pre-built
|
||||||
* prefill (voice prefix + all text tokens). Same decode recipe as
|
* prefill (voice prefix + all text tokens). Same decode recipe as
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue