From 5e416713ceb7d6b8d268ed22d9041c965060e24a Mon Sep 17 00:00:00 2001
From: Kazeia Team
Date: Mon, 13 Apr 2026 08:43:30 +0200
Subject: [PATCH] TTS Stage 1 streaming: play each segment the moment it's decoded
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds a streaming multi-segment pipeline on top of the Hexagon talker + ONNX
CP backend. First audio arrives at ~20s (vs ~65s for the full phrase
non-streamed) on the Baer 16.56s reference (3-segment split). Voice cloning
is preserved per segment because each segment now ships its own full prefill.

Changes:

* Qwen3TtsEngine.generateFromEmbedsHexagonStreaming(path, onSegmentReady)
  reads single- or multi-segment embeds, runs prefill + generation + VQ
  decode + BigVGAN per segment, and fires the callback with each segment's
  ShortArray the moment it's ready. Saves per-segment WAVs
  (kazeia_stream_seg{N}.wav) plus the concatenated kazeia_stream_full.wav
  for offline inspection. Extracted the common generation loop into
  runHexSegmentFromEmbeds(prefill, trailing, idx) so single-segment and
  streaming paths share exactly the same code (no quality drift between
  modes). Added hexReset() between segments so segment 2's prefill logits
  don't contain segment 1's KV state.

* vqDecode buffer overrun fix: when the talker samples CODEC_EOS as cb0 it
  stores a vocab id > CODEBOOK_SIZE, which vqDecode then used as a codebook
  row index — reading past the 2048-row buffer. The short Baer probe never
  hit this; longer phrases do. Clamp any out-of-vocab code to 0 at
  allCodebooks build time.

* KazeiaService: new stream_pipeline intent extra wires the callback to an
  AudioTrack MODE_STREAM instance, writing each segment's audio as soon as
  it comes back. Logs time-to-first-audio.

* prepare_tts_segments.py: the previous version only captured 1-token
  decode calls and substituted a generic 9-embed "prefill_base" pulled
  from an unrelated single-segment file — dropping the per-segment xvector
  conditioning AND the text-encoded embeddings, so Hexagon produced garbled
  mixed speech for segments 2..N. Now captures the multi-token prefill call
  too (like prepare_tts_voiceclone.py) so each segment is self-contained.

Limitation (documented, not fixed in this commit): RTF ~4.4 > 1 on the
Snapdragon 8 Elite with current config means each segment takes longer to
generate than it takes to play, so audible gaps between segments remain.
Removing the gaps requires either (a) producer/consumer parallelism across
two coroutines (doesn't help if RTF stays > 1), or (b) faster CP (the
~180ms/step ONNX MLAS CP is the bottleneck; Hexagon HMX has a known NaN bug
and the .pte path contends with Hexagon talker on the DSP).

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 .../java/com/kazeia/service/KazeiaService.kt  |  59 ++++++
 .../java/com/kazeia/tts/Qwen3TtsEngine.kt     | 176 ++++++++++++++++++
 scripts/prepare_tts_segments.py               |  62 +++---
 3 files changed, 266 insertions(+), 31 deletions(-)

diff --git a/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt b/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt
index 6730c42..4411c77 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt
@@ -123,6 +123,65 @@ class KazeiaService : Service() {
                 // Audio is played by the TTS engine internally
             }
         }
+        intent?.getStringExtra("stream_pipeline")?.let { embedsPath ->
+            // Stage 1 streaming pipeline: generate segment-by-segment and play each
+            // segment the moment it's ready via an AudioTrack MODE_STREAM. First audio
+            // arrives once segment 1 has been generated (~20s on the Baer reference)
+            // instead of after the whole phrase (~65s). Per-segment WAVs + concatenated
+            // full WAV are written to /data/local/tmp/kazeia/kazeia_stream_seg*.wav and
+            // _full.wav by the engine itself — the service only handles playback.
+            log("Stream pipeline from pre-computed embeds: $embedsPath")
+            serviceScope.launch {
+                // Declared outside the try so the finally block can release it on every path.
+                var track: android.media.AudioTrack? = null
+                try {
+                    val qwenTts = tts as? com.kazeia.tts.Qwen3TtsEngine ?: return@launch
+                    val sr = 24000
+                    val player = android.media.AudioTrack.Builder()
+                        .setAudioAttributes(android.media.AudioAttributes.Builder()
+                            .setUsage(android.media.AudioAttributes.USAGE_MEDIA)
+                            .setContentType(android.media.AudioAttributes.CONTENT_TYPE_SPEECH)
+                            .build())
+                        .setAudioFormat(android.media.AudioFormat.Builder()
+                            .setEncoding(android.media.AudioFormat.ENCODING_PCM_16BIT)
+                            .setSampleRate(sr)
+                            .setChannelMask(android.media.AudioFormat.CHANNEL_OUT_MONO)
+                            .build())
+                        .setBufferSizeInBytes(sr * 4) // 2s mono pcm16 buffer, plenty for seg handoff
+                        .setTransferMode(android.media.AudioTrack.MODE_STREAM)
+                        .build()
+                    track = player
+                    player.play()
+                    val tStart = System.currentTimeMillis()
+                    var firstAudioLogged = false
+                    var framesWritten = 0L
+                    qwenTts.generateFromEmbedsHexagonStreaming(embedsPath) { segIdx, audio ->
+                        if (!firstAudioLogged) {
+                            log("First audio out at ${System.currentTimeMillis() - tStart}ms (seg ${segIdx+1})")
+                            firstAudioLogged = true
+                        }
+                        // write() returns the number of shorts queued (== frames for mono)
+                        // or a negative AudioTrack error code.
+                        val written = player.write(audio, 0, audio.size)
+                        if (written > 0) framesWritten += written
+                        if (written < 0) log("AudioTrack write error $written (seg ${segIdx+1})")
+                    }
+                    // Let AudioTrack drain the written samples before releasing: write() only
+                    // queues audio, and release() right after stop() would cut off whatever is
+                    // still buffered. The head position is read before stop() because stop()
+                    // resets it to zero.
+                    val pendingFrames = (framesWritten - player.playbackHeadPosition).coerceAtLeast(0L)
+                    player.stop()
+                    kotlinx.coroutines.delay(pendingFrames * 1000L / sr + 100L)
+                    log("Stream pipeline done at ${System.currentTimeMillis() - tStart}ms")
+                } catch (e: Exception) {
+                    log("Stream pipeline error: ${e.message}")
+                    e.printStackTrace()
+                } finally {
+                    track?.release()
+                }
+            }
+        }
         intent?.getStringExtra("full_pipeline")?.let { embedsPath ->
             val savePath = intent.getStringExtra("save_wav") ?: "/data/local/tmp/kazeia/tts_output.wav"
             log("Full pipeline from pre-computed embeds: $embedsPath")
diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
index 7339951..4088624 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
@@ -3026,6 +3026,182 @@ class Qwen3TtsEngine(
         return audio
     }
 
+    /**
+     * Run the Hexagon talker + CP generation loop on a single segment's embeds
+     * and return the decoded audio. Extracted from generateFromEmbedsHexagon so
+     * both single-segment playback and the streaming multi-segment path share
+     * exactly the same generation path (no quality drift between modes).
+     *
+     * Caller is responsible for hexReset() before first call of a request.
+     * Subsequent calls (segments 2..N in multi-segment mode) must hexReset()
+     * between segments so the talker KV-cache doesn't carry stale context.
+     */
+    private fun runHexSegmentFromEmbeds(
+        prefillEmbeds: List<FloatArray>,
+        trailingEmbeds: List<FloatArray>,
+        segIdx: Int = 0
+    ): ShortArray {
+        val allCodes = mutableListOf<IntArray>()
+        val seenCb0 = HashSet<Int>()
+        var totalTalkerMs = 0L; var totalCpMs = 0L
+
+        // Prefill
+        val tPrefill = System.currentTimeMillis()
+        val prefillResults = hexForward(prefillEmbeds)
+        nlog("Seg ${segIdx+1} prefill: ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
+        if (prefillResults.isEmpty()) return ShortArray(0)
+
+        var pastHidden = prefillResults.last().first
+        val prefillLogits = prefillResults.last().second
+        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
+        var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
+
+        val nTrailing = trailingEmbeds.size
+        for (genStep in 0 until nTrailing) {
+            val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
+            val tCp = System.currentTimeMillis()
+            val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
+            totalCpMs += System.currentTimeMillis() - tCp
+            for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
+            allCodes.add(codes); seenCb0.add(currentCb0)
+
+            val nextEmbed = trailingEmbeds[genStep]
+            val tT = System.currentTimeMillis()
+            val results = hexForward(listOf(nextEmbed))
+            totalTalkerMs += System.currentTimeMillis() - tT
+            if (results.isEmpty()) { nlog("Seg ${segIdx+1}: hex empty at step ${genStep+1}"); break }
+            pastHidden = results[0].first
+            val logits = results[0].second
+            for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
+            // Repetition penalty (1.05) on every distinct cb0 sampled so far in this segment.
+            for (tok in seenCb0) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
+            currentCb0 = sampleTopK(logits, 0.9f, 50)
+        }
+
+        val n = allCodes.size
+        nlog("Seg ${segIdx+1} generated $n tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
+        if (n == 0) return ShortArray(0)
+
+        // cb0 can hit CODEC_EOS (> CODEBOOK_SIZE) on longer phrases — the original
+        // single-segment Hexagon path never exercises this because the short Baer
+        // probe stayed well inside the decode budget. Clamp any out-of-vocab code
+        // to 0 (silence) so vqDecode can't read past the codebook buffer.
+        val padLen = maxOf(n, SEQ_LEN)
+        val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
+            IntArray(padLen) { t ->
+                if (t < n) { val v = allCodes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
+            }
+        }
+        return decodeChunked(allCodebooks, n)
+    }
+
+    /**
+     * Streaming multi-segment variant of generateFromEmbeds. Reads a multi-segment
+     * embeds file, generates each segment via Hexagon talker + CP sequentially,
+     * and invokes `onSegmentReady(idx, audio)` the moment each segment's audio is
+     * decoded. The callback writes to an AudioTrack in the calling coroutine so
+     * playback begins as soon as segment 1 has been generated (~20s instead of
+     * ~65s for the full phrase on the 16.56s Baer reference).
+     *
+     * Each segment's raw audio is saved to /data/local/tmp/kazeia/kazeia_stream_segN.wav
+     * and the final concatenated audio to /data/local/tmp/kazeia/kazeia_stream_full.wav
+     * so the caller can inspect individual segments for quality regressions.
+     *
+     * Single-segment files are supported as a degenerate case (nSegments=1) so
+     * the caller doesn't need to branch on format.
+     */
+    fun generateFromEmbedsHexagonStreaming(
+        embedsPath: String,
+        onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
+    ): ShortArray {
+        if (!loaded || !useHexagonTalker) {
+            nlog("Streaming: Hexagon talker not ready")
+            return ShortArray(0)
+        }
+        val t0 = System.currentTimeMillis()
+        val bytes = File(embedsPath).readBytes()
+        val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
+
+        // Format detection (same logic as generateFromEmbedsPte):
+        //   single: [nPrefill:i32][nTotal:i32][nTotal x TALKER_DIM f32]
+        //   multi : [nSegments:i32] then [nPrefill:i32][nTotal:i32][embeds] × nSegments
+        val firstInt = bb.int
+        val secondInt = bb.int
+        val fileLen = bytes.size.toLong()
+        val singleSize = 8L + secondInt.toLong() * TALKER_DIM * 4
+        val isSingle = secondInt > 0 && secondInt < 100000 && fileLen == singleSize
+        val nSegments = if (isSingle) 1 else firstInt
+        bb.position(if (isSingle) 0 else 4)
+        nlog("Streaming: $nSegments segment(s), ${bytes.size} bytes")
+
+        // Ensure a fresh runner connection for this request. Between requests
+        // the KV-cache carries stale state from the previous generation and
+        // prefill logits come out as garbage on segment 1.
+        hexReset()
+
+        val segmentAudios = mutableListOf<ShortArray>()
+        val gapSamples = SR * 120 / 1000
+        val gap = ShortArray(gapSamples)
+
+        for (seg in 0 until nSegments) {
+            // Between segments the talker KV-cache must be reset so segment 2's
+            // prefill logits don't contain segment 1's state. Skipping this
+            // produces garbled speech from segment 2 onwards.
+            if (seg > 0) hexReset()
+
+            val nPrefill = bb.int
+            val nTotal = bb.int
+            val prefill = List(nPrefill) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
+            val nTrailing = nTotal - nPrefill
+            val trailing = List(nTrailing) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
+            nlog("Streaming seg ${seg+1}/$nSegments: $nPrefill prefill + $nTrailing decode")
+
+            val tSeg = System.currentTimeMillis()
+            val audio = runHexSegmentFromEmbeds(prefill, trailing, seg)
+            val segMs = System.currentTimeMillis() - tSeg
+            nlog("Streaming seg ${seg+1}/$nSegments: ${audio.size/SR.toFloat()}s audio in ${segMs}ms")
+
+            // Emit to caller immediately so playback starts now. WAV dump is
+            // synchronous here; for true zero-lag streaming the caller can
+            // write to AudioTrack on its own dispatcher from the callback.
+            segmentAudios.add(audio)
+            saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${seg+1}.wav", audio)
+            onSegmentReady?.invoke(seg, audio)
+        }
+
+        // Concatenate with short gaps between segments for the full-file WAV.
+        // The streamed callback path receives the raw segments without these gaps.
+        val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
+        val concat = ShortArray(total)
+        var off = 0
+        for ((i, s) in segmentAudios.withIndex()) {
+            System.arraycopy(s, 0, concat, off, s.size); off += s.size
+            if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
+        }
+        saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
+        nlog("Streaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s ($nSegments seg)")
+        return concat
+    }
+
+    /** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
+     *  save one file per segment plus the concatenated result for inspection. */
+    private fun saveWav(path: String, audio: ShortArray) {
+        try {
+            val dataLen = audio.size * 2
+            val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
+            header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
+            header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
+            header.putInt(16); header.putShort(1); header.putShort(1)
+            header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
+            header.put("data".toByteArray()); header.putInt(dataLen)
+            val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
+            for (s in audio) buf.putShort(s)
+            // Build everything in memory first, then write inside use {} so the
+            // stream is closed on every path, including when a write throws.
+            java.io.FileOutputStream(path).use { fos -> fos.write(header.array()); fos.write(buf.array()) }
+        } catch (e: Exception) { nlog("WAV save failed ($path): ${e.message}") }
+    }
+
     /** Test with pre-computed codec tokens from PC (for validation) */
     fun testWithPrecomputedCodes(codesPath: String, realTokens: Int = 16): ShortArray {
         if (!loaded) return ShortArray(0)
diff --git a/scripts/prepare_tts_segments.py b/scripts/prepare_tts_segments.py
index 6125d22..43a0017 100644
--- a/scripts/prepare_tts_segments.py
+++ b/scripts/prepare_tts_segments.py
@@ -59,25 +59,25 @@ print("\nLoading model...")
 tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
 talker = tts.model.talker
 
-# Capture generation inputs via monkey-patch on inner model
-captured_inputs = []
+# Capture generation inputs. The previous version of this script only captured
+# 1-token decode calls and reassembled them with a generic 9-embed "prefill_base"
+# pulled from another file — dropping the per-segment prefill that contains the
+# xvector conditioning AND the text-encoded embeddings. With that generic prefix
+# the talker had no idea which sentence to produce → Hexagon output was garbled.
+# Fix: capture the MULTI-token prefill call too (first call has shape[1] > 1),
+# exactly like prepare_tts_voiceclone.py does. Each segment becomes self-contained.
+captured_embeds = [] # 1024-dim vectors in order +call_shapes = [] # length of each talker.model.forward call original_model_forward = talker.model.forward def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs): - if inputs_embeds is not None and inputs_embeds.shape[1] == 1: - captured_inputs.append(inputs_embeds[0, 0, :].detach().cpu().numpy().astype(np.float32)) + if inputs_embeds is not None and inputs_embeds.dim() == 3: + t = inputs_embeds.shape[1] + call_shapes.append(t) + for j in range(t): + captured_embeds.append(inputs_embeds[0, j, :].detach().cpu().numpy().astype(np.float32)) return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs) talker.model.forward = patched_model_forward -# Load prefill structure -EXISTING = "/tmp/existing_embeds.bin" -if not os.path.exists(EXISTING): - os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {EXISTING}") -with open(EXISTING, "rb") as f: - nP = struct.unpack("