TTS Stage 1 streaming: play each segment the moment it's decoded

Adds a streaming multi-segment pipeline on top of the Hexagon talker + ONNX
CP backend. On the Baer 16.56s reference (3-segment split), first audio
arrives after ~20s, versus ~65s when the full phrase is generated before
playback starts. Voice cloning is preserved per segment because each segment
now ships its own full prefill.

Changes:

  * Qwen3TtsEngine.generateFromEmbedsHexagonStreaming(path, onSegmentReady)
    reads single- or multi-segment embeds, runs prefill + generation + VQ
    decode + BigVGAN per segment, and fires the callback with each
    segment's ShortArray the moment it's ready. Saves per-segment WAVs
    (kazeia_stream_seg{N}.wav) plus the concatenated kazeia_stream_full.wav
    for offline inspection. Extracted the common generation loop into
    runHexSegmentFromEmbeds(prefill, trailing, idx) so single-segment and
    streaming paths share exactly the same code (no quality drift between
    modes). Added hexReset() between segments so segment 2's prefill is not
    conditioned on segment 1's KV-cache state.

  * vqDecode buffer overrun fix: when the talker samples CODEC_EOS as cb0
    it stores a vocab id > CODEBOOK_SIZE, which vqDecode then used as a
    codebook row index, reading past the 2048-row buffer. The short Baer
    probe never hit this; longer phrases do. Out-of-vocab codes are now
    clamped to 0 when allCodebooks is built.

  * KazeiaService: new stream_pipeline intent extra wires the callback
    to an AudioTrack MODE_STREAM instance, writing each segment's audio as
    soon as it comes back. Logs time-to-first-audio.

  * prepare_tts_segments.py: the previous version only captured 1-token
    decode calls and substituted a generic 9-embed "prefill_base" pulled
    from an unrelated single-segment file — dropping the per-segment
    xvector conditioning AND the text-encoded embeddings, so Hexagon
    produced garbled mixed speech for segments 2..N. Now captures the
    multi-token prefill call too (like prepare_tts_voiceclone.py) so each
    segment is self-contained.
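
The vqDecode clamp described in the list above can be illustrated with a
small Python sketch. It assumes the 2048-row codebook quoted above; the EOS id
2150 is a hypothetical value picked only because it is out of range (the real
CODEC_EOS is not shown in this commit), and clamp_codes is an illustrative
name, not a function in the codebase.

```python
CODEBOOK_SIZE = 2048   # rows per VQ codebook (figure from the commit message)
CODEC_EOS = 2150       # hypothetical EOS vocab id, assumed >= CODEBOOK_SIZE


def clamp_codes(frames, codebook_size=CODEBOOK_SIZE):
    """Return a copy of frames with every out-of-range code replaced by 0.

    frames is a list of per-step code lists (one int per codebook).
    """
    return [
        [c if 0 <= c < codebook_size else 0 for c in frame]
        for frame in frames
    ]


frames = [[17, 2047, 3], [CODEC_EOS, 5, -1]]
print(clamp_codes(frames))  # [[17, 2047, 3], [0, 5, 0]]
```

Clamping keeps every index inside the codebook but still decodes the EOS step
as codebook row 0, so it guards the buffer rather than trimming the audio.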

Limitation (documented, not fixed in this commit): with the current config
the RTF is ~4.4 (> 1) on the Snapdragon 8 Elite, so each segment takes longer
to generate than to play and audible gaps between segments remain. Removing
the gaps requires either (a) producer/consumer parallelism across two
coroutines (which cannot close the gaps while RTF stays > 1), or (b) a
faster CP: the ~180ms/step ONNX MLAS CP is the bottleneck, Hexagon HMX has a
known NaN bug, and the .pte path contends with the Hexagon talker on the DSP.
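
Why option (a) alone cannot close the gaps can be seen with a toy timing
model. The Python sketch below is not code from this commit: it assumes a
constant RTF, a producer that generates segments back to back without ever
waiting on playback, and a player that starts each segment as soon as it is
generated and the previous one has finished. The function name and the round
numbers are illustrative only.

```python
def playback_schedule(durations, rtf):
    """Return (time_to_first_audio, gaps) in seconds.

    durations: audio length of each segment.
    rtf: generation time divided by audio time (assumed constant).
    gaps[i] is the silence between segment i and segment i + 1.
    """
    gen_done = 0.0      # wall-clock time at which the current segment is ready
    play_end = None     # wall-clock time at which the previous segment ends
    first_audio = None
    gaps = []
    for d in durations:
        gen_done += rtf * d
        if play_end is None:
            start = gen_done
            first_audio = start
        else:
            # Playback waits for whichever finishes last: generation or
            # the previous segment's playback.
            start = max(gen_done, play_end)
            gaps.append(start - play_end)
        play_end = start + d
    return first_audio, gaps


print(playback_schedule([5.0, 5.0, 5.0], rtf=4.0))  # (20.0, [15.0, 15.0])
print(playback_schedule([5.0, 5.0, 5.0], rtf=0.5))  # (2.5, [0.0, 0.0])
```

Even with perfect overlap, each gap is (RTF - 1) times the segment duration
when RTF > 1, and it only drops to zero once RTF is at or below 1.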

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Kazeia Team 2026-04-13 08:43:30 +02:00
parent de878ddf5c
commit 5e416713ce
3 changed files with 252 additions and 31 deletions


@ -123,6 +123,51 @@ class KazeiaService : Service() {
// Audio is played by the TTS engine internally
}
}
intent?.getStringExtra("stream_pipeline")?.let { embedsPath ->
// Stage 1 streaming pipeline: generate segment-by-segment and play each
// segment the moment it's ready via an AudioTrack in MODE_STREAM. First audio
// arrives once the first segment has been generated instead of after the
// whole phrase (~20s vs ~65s on the 3-segment Baer reference at RTF ~4.4).
// Per-segment WAVs + the concatenated full WAV are written to
// /data/local/tmp/kazeia/kazeia_stream_seg*.wav and kazeia_stream_full.wav
// by the engine itself; the service only handles playback.
log("Stream pipeline from pre-computed embeds: $embedsPath")
serviceScope.launch {
try {
val qwenTts = tts as? com.kazeia.tts.Qwen3TtsEngine ?: return@launch
val sr = 24000
val track = android.media.AudioTrack.Builder()
.setAudioAttributes(android.media.AudioAttributes.Builder()
.setUsage(android.media.AudioAttributes.USAGE_MEDIA)
.setContentType(android.media.AudioAttributes.CONTENT_TYPE_SPEECH)
.build())
.setAudioFormat(android.media.AudioFormat.Builder()
.setEncoding(android.media.AudioFormat.ENCODING_PCM_16BIT)
.setSampleRate(sr)
.setChannelMask(android.media.AudioFormat.CHANNEL_OUT_MONO)
.build())
.setBufferSizeInBytes(sr * 4) // 2s mono pcm16 buffer, plenty for seg handoff
.setTransferMode(android.media.AudioTrack.MODE_STREAM)
.build()
track.play()
val tStart = System.currentTimeMillis()
var firstAudioLogged = false
qwenTts.generateFromEmbedsHexagonStreaming(embedsPath) { segIdx, audio ->
if (!firstAudioLogged) {
log("First audio out at ${System.currentTimeMillis() - tStart}ms (seg ${segIdx+1})")
firstAudioLogged = true
}
track.write(audio, 0, audio.size)
}
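// In MODE_STREAM, stop() lets already-written samples finish playing, but the
// immediate release() below may still cut off the buffered tail of the last segment.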
track.stop()
track.release()
log("Stream pipeline done at ${System.currentTimeMillis() - tStart}ms")
} catch (e: Exception) {
log("Stream pipeline error: ${e.message}")
e.printStackTrace()
}
}
}
intent?.getStringExtra("full_pipeline")?.let { embedsPath ->
val savePath = intent.getStringExtra("save_wav") ?: "/data/local/tmp/kazeia/tts_output.wav"
log("Full pipeline from pre-computed embeds: $embedsPath")


@ -3026,6 +3026,182 @@ class Qwen3TtsEngine(
return audio
}
/**
* Run the Hexagon talker + CP generation loop on a single segment's embeds
* and return the decoded audio. Extracted from generateFromEmbedsHexagon so
* both single-segment playback and the streaming multi-segment path share
* exactly the same generation path (no quality drift between modes).
*
* Caller is responsible for hexReset() before first call of a request.
* Subsequent calls (segments 2..N in multi-segment mode) must hexReset()
* between segments so the talker KV-cache doesn't carry stale context.
*/
private fun runHexSegmentFromEmbeds(
prefillEmbeds: List<FloatArray>,
trailingEmbeds: List<FloatArray>,
segIdx: Int = 0
): ShortArray {
val allCodes = mutableListOf<IntArray>()
val generatedCb0 = mutableListOf<Int>()
var totalTalkerMs = 0L; var totalCpMs = 0L
// Prefill
val tPrefill = System.currentTimeMillis()
val prefillResults = hexForward(prefillEmbeds)
nlog("Seg ${segIdx+1} prefill: ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
if (prefillResults.isEmpty()) return ShortArray(0)
var pastHidden = prefillResults.last().first
val prefillLogits = prefillResults.last().second
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
val nTrailing = trailingEmbeds.size
for (genStep in 0 until nTrailing) {
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
val tCp = System.currentTimeMillis()
val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
totalCpMs += System.currentTimeMillis() - tCp
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
allCodes.add(codes); generatedCb0.add(currentCb0)
val nextEmbed = trailingEmbeds[genStep]
val tT = System.currentTimeMillis()
val results = hexForward(listOf(nextEmbed))
totalTalkerMs += System.currentTimeMillis() - tT
if (results.isEmpty()) { nlog("Seg ${segIdx+1}: hex empty at step ${genStep+1}"); break }
pastHidden = results[0].first
val logits = results[0].second
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
currentCb0 = sampleTopK(logits, 0.9f, 50)
}
val n = allCodes.size
nlog("Seg ${segIdx+1} generated $n tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
if (n == 0) return ShortArray(0)
// cb0 can hit CODEC_EOS (> CODEBOOK_SIZE) on longer phrases — the original
// single-segment Hexagon path never exercises this because the short Baer
// probe stayed well inside the decode budget. Clamp any out-of-vocab code
// to 0 (silence) so vqDecode can't read past the codebook buffer.
val padLen = maxOf(n, SEQ_LEN)
val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
IntArray(padLen) { t ->
if (t < n) { val v = allCodes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
}
}
return decodeChunked(allCodebooks, n)
}
/**
* Streaming multi-segment variant of generateFromEmbedsHexagon. Reads a multi-segment
* embeds file, generates each segment via Hexagon talker + CP sequentially,
* and invokes `onSegmentReady(idx, audio)` the moment each segment's audio is
* decoded. The callback writes to an AudioTrack in the calling coroutine so
* playback begins as soon as segment 1 has been generated rather than after
* the full phrase (~20s instead of ~65s on the 3-segment Baer reference).
*
* Each segment's raw audio is saved to /data/local/tmp/kazeia/kazeia_stream_segN.wav
* and the final concatenated audio to /data/local/tmp/kazeia/kazeia_stream_full.wav
* so the caller can inspect individual segments for quality regressions.
*
* Single-segment files are supported as a degenerate case (nSegments=1) so
* the caller doesn't need to branch on format.
*/
fun generateFromEmbedsHexagonStreaming(
embedsPath: String,
onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
): ShortArray {
if (!loaded || !useHexagonTalker) {
nlog("Streaming: Hexagon talker not ready")
return ShortArray(0)
}
val t0 = System.currentTimeMillis()
val bytes = File(embedsPath).readBytes()
val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
// Format detection (same logic as generateFromEmbedsPte):
// single: <i32 nPrefill> <i32 nTotal> <f32 × nTotal × 1024>
// multi : <i32 nSegments> [<i32 nPrefill_i> <i32 nTotal_i> <f32 ...>] × nSegments
val firstInt = bb.int
val secondInt = bb.int
val fileLen = bytes.size.toLong()
val singleSize = 8L + secondInt.toLong() * TALKER_DIM * 4
val isSingle = secondInt > 0 && secondInt < 100000 && fileLen == singleSize
val nSegments = if (isSingle) 1 else firstInt
bb.position(if (isSingle) 0 else 4)
nlog("Streaming: $nSegments segment(s), ${bytes.size} bytes")
// Ensure a fresh runner connection for this request. Between requests
// the KV-cache carries stale state from the previous generation and
// prefill logits come out as garbage on segment 1.
hexReset()
val segmentAudios = mutableListOf<ShortArray>()
val gapSamples = SR * 120 / 1000
val gap = ShortArray(gapSamples)
for (seg in 0 until nSegments) {
// Between segments the talker KV-cache must be reset so segment 2's
// prefill logits don't contain segment 1's state. Skipping this
// produces garbled speech from segment 2 onwards.
if (seg > 0) hexReset()
val nPrefill = bb.int
val nTotal = bb.int
val prefill = List(nPrefill) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
val nTrailing = nTotal - nPrefill
val trailing = List(nTrailing) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
nlog("Streaming seg ${seg+1}/$nSegments: $nPrefill prefill + $nTrailing decode")
val tSeg = System.currentTimeMillis()
val audio = runHexSegmentFromEmbeds(prefill, trailing, seg)
val segMs = System.currentTimeMillis() - tSeg
nlog("Streaming seg ${seg+1}/$nSegments: ${audio.size/SR.toFloat()}s audio in ${segMs}ms")
// Save the per-segment WAV, then hand the audio to the caller so playback
// can start. Both steps run synchronously on this thread; for true zero-lag
// streaming the caller can write to AudioTrack on its own dispatcher from the callback.
segmentAudios.add(audio)
saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${seg+1}.wav", audio)
onSegmentReady?.invoke(seg, audio)
}
// Concatenate with short (120 ms) gaps between segments for the full-file WAV.
// The live playback path does not insert these gaps; it plays segments in callback order.
val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
val concat = ShortArray(total)
var off = 0
for ((i, s) in segmentAudios.withIndex()) {
System.arraycopy(s, 0, concat, off, s.size); off += s.size
if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
}
saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
nlog("Streaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s ($nSegments seg)")
return concat
}
/** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
* save one file per segment plus the concatenated result for inspection. */
private fun saveWav(path: String, audio: ShortArray) {
try {
val dataLen = audio.size * 2
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
header.putInt(16); header.putShort(1); header.putShort(1)
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
header.put("data".toByteArray()); header.putInt(dataLen)
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
for (s in audio) buf.putShort(s)
// use {} closes the stream even if a write throws.
java.io.FileOutputStream(path).use { fos ->
fos.write(header.array())
fos.write(buf.array())
}
} catch (e: Exception) { nlog("WAV save failed ($path): ${e.message}") }
}
/** Test with pre-computed codec tokens from PC (for validation) */
fun testWithPrecomputedCodes(codesPath: String, realTokens: Int = 16): ShortArray {
if (!loaded) return ShortArray(0)
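
The 44-byte header written by saveWav can be sanity-checked off-device. The
sketch below, in Python to match the repo's tooling scripts, rebuilds the same
PCM16 mono layout and parses it back with the standard `wave` module. It
assumes SR is 24000 (the value the service uses for its AudioTrack; the
engine's own SR constant is not shown in this diff), and pcm16_mono_wav is a
hypothetical helper name.

```python
import io
import struct
import wave

SR = 24000  # assumed sample rate; the service-side AudioTrack uses 24000


def pcm16_mono_wav(samples, sample_rate=SR):
    """Build WAV bytes with the same 44-byte header layout as saveWav."""
    data = struct.pack("<%dh" % len(samples), *samples)
    header = b"RIFF" + struct.pack("<I", 36 + len(data)) + b"WAVE"
    # fmt chunk: size 16, PCM (1), mono (1), rate, byte rate, block align, bits
    header += b"fmt " + struct.pack("<IHHIIHH", 16, 1, 1, sample_rate,
                                    sample_rate * 2, 2, 16)
    header += b"data" + struct.pack("<I", len(data))
    return header + data


blob = pcm16_mono_wav([0, 1000, -1000, 32767, -32768])
with wave.open(io.BytesIO(blob), "rb") as w:
    print(w.getnchannels(), w.getsampwidth(), w.getframerate(), w.getnframes())
    # 1 2 24000 5
```

The same `wave.open` call works on a pulled kazeia_stream_seg file, which is a
quick way to confirm the frame count matches the duration logged by the engine.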


@ -59,25 +59,25 @@ print("\nLoading model...")
 tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
 talker = tts.model.talker
-# Capture generation inputs via monkey-patch on inner model
-captured_inputs = []
+# Capture generation inputs. The previous version of this script only captured
+# 1-token decode calls and reassembled them with a generic 9-embed "prefill_base"
+# pulled from another file, dropping the per-segment prefill that contains the
+# xvector conditioning AND the text-encoded embeddings. With that generic prefix
+# the talker had no idea which sentence to produce → Hexagon output was garbled.
+# Fix: capture the MULTI-token prefill call too (first call has shape[1] > 1),
+# exactly like prepare_tts_voiceclone.py does. Each segment becomes self-contained.
+captured_embeds = []  # 1024-dim vectors in order
+call_shapes = []  # length of each talker.model.forward call
 original_model_forward = talker.model.forward
 def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
-    if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
-        captured_inputs.append(inputs_embeds[0, 0, :].detach().cpu().numpy().astype(np.float32))
+    if inputs_embeds is not None and inputs_embeds.dim() == 3:
+        t = inputs_embeds.shape[1]
+        call_shapes.append(t)
+        for j in range(t):
+            captured_embeds.append(inputs_embeds[0, j, :].detach().cpu().numpy().astype(np.float32))
     return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
 talker.model.forward = patched_model_forward
-# Load prefill structure
-EXISTING = "/tmp/existing_embeds.bin"
-if not os.path.exists(EXISTING):
-    os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {EXISTING}")
-with open(EXISTING, "rb") as f:
-    nP = struct.unpack("<i", f.read(4))[0]
-    nT = struct.unpack("<i", f.read(4))[0]
-    old_embeds = [np.frombuffer(f.read(1024*4), dtype=np.float32).copy() for _ in range(nT)]
-prefill_base = old_embeds[:9] # role+ctrl+spk+bos
 # Generate embeds for each segment
 all_segments_data = []
 total_audio_duration = 0
@ -86,7 +86,8 @@ for i, seg_text in enumerate(segments):
     print(f"\n--- Segment {i+1}/{len(segments)} ---")
     print(f" '{seg_text[:60]}...'")
-    captured_inputs.clear()
+    captured_embeds.clear()
+    call_shapes.clear()
     audio_list, sr = tts.generate_voice_clone(
         text=seg_text, ref_audio=VOICE, language="french",
         x_vector_only_mode=True, non_streaming_mode=True,
@ -94,40 +95,39 @@ for i, seg_text in enumerate(segments):
     audio_dur = len(audio_list[0]) / sr
     total_audio_duration += audio_dur
-    if len(captured_inputs) < 2:
-        print(f" WARNING: Only {len(captured_inputs)} steps, skipping")
+    if not call_shapes:
+        print(f" WARNING: no calls captured, skipping")
         continue
-    nPrefill = 10 # 9 base + first gen input
-    nDecode = len(captured_inputs) - 1
-    nTotal = nPrefill + nDecode
+    # First call = multi-token prefill (xvector + text encoding + role tokens).
+    # Subsequent calls = single-token decode (one per generated codec frame).
+    nPrefill = call_shapes[0]
+    nTotal = len(captured_embeds)
+    if nTotal <= nPrefill:
+        print(f" WARNING: no decode steps captured, skipping")
+        continue
     seg_data = {
         'nPrefill': nPrefill,
         'nTotal': nTotal,
-        'prefill': prefill_base.copy(),
-        'first_gen': captured_inputs[0],
-        'decode': captured_inputs[1:],
+        'embeds': captured_embeds.copy(),
         'audio_dur': audio_dur,
-        'n_tokens': len(captured_inputs),
     }
     all_segments_data.append(seg_data)
-    print(f" {nTotal} embeds ({len(captured_inputs)} tokens), {audio_dur:.2f}s audio")
+    print(f" {nTotal} embeds ({nPrefill} prefill + {nTotal - nPrefill} decode), {audio_dur:.2f}s audio")
-# Write multi-segment file
+# Write multi-segment file. Per-segment layout matches single-segment format so
+# the tablet can read either shape with the same parser.
 with open(OUTPUT, "wb") as f:
     f.write(struct.pack("<i", len(all_segments_data)))
     for seg in all_segments_data:
         f.write(struct.pack("<i", seg['nPrefill']))
         f.write(struct.pack("<i", seg['nTotal']))
-        for emb in seg['prefill']:
-            f.write(emb.tobytes())
-        f.write(seg['first_gen'].tobytes())
-        for emb in seg['decode']:
+        for emb in seg['embeds']:
             f.write(emb.tobytes())
 sz = os.path.getsize(OUTPUT)
-total_tokens = sum(s['n_tokens'] for s in all_segments_data)
+total_tokens = sum(s['nTotal'] - s['nPrefill'] for s in all_segments_data)
 print(f"\n=== RESULT ===")
 print(f"Segments: {len(all_segments_data)}")
 print(f"Total tokens: {total_tokens}")
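
The file this script writes can be checked on the PC before it is pushed to
the tablet. The following Python sketch is an assumption-laden mirror of the
format detection in generateFromEmbedsHexagonStreaming (same single/multi
heuristic, same little-endian layout); read_segments is a hypothetical helper,
not part of the repo, and the header validation is an extra check that the
Kotlin parser does not perform.

```python
import struct

TALKER_DIM = 1024  # embedding width used by the script and the Kotlin parser


def read_segments(blob, dim=TALKER_DIM):
    """Parse single- or multi-segment embeds; return [(n_prefill, embeds), ...]."""
    first, second = struct.unpack_from("<ii", blob, 0)
    # Single-segment files are exactly 8 header bytes + nTotal * dim floats.
    single_size = 8 + second * dim * 4
    is_single = 0 < second < 100000 and len(blob) == single_size
    n_segments = 1 if is_single else first
    pos = 0 if is_single else 4
    out = []
    for _ in range(n_segments):
        n_prefill, n_total = struct.unpack_from("<ii", blob, pos)
        pos += 8
        if not 0 < n_prefill < n_total:
            raise ValueError("bad segment header at byte %d" % (pos - 8))
        embeds = []
        for _ in range(n_total):
            embeds.append(struct.unpack_from("<%df" % dim, blob, pos))
            pos += dim * 4
        out.append((n_prefill, embeds))
    if pos != len(blob):
        raise ValueError("trailing bytes after last segment")
    return out


# Tiny round trip with dim=2: one segment, 1 prefill embed + 1 decode embed.
demo = struct.pack("<i", 1) + struct.pack("<ii", 1, 2) + struct.pack("<4f", 1, 2, 3, 4)
print(read_segments(demo, dim=2))  # [(1, [(1.0, 2.0), (3.0, 4.0)])]
```

Requiring nPrefill < nTotal matches the script's own rule of skipping segments
with no decode steps, so a file that fails this check was not produced by the
current writer.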