TTS Stage 1 streaming: play each segment the moment it's decoded
Adds a streaming multi-segment pipeline on top of the Hexagon talker + ONNX
CP backend. First audio arrives at ~20s (vs ~65s for the full phrase
non-streamed) on the Baer 16.56s reference (3-segment split). Voice cloning
is preserved per segment because each segment now ships its own full prefill.
Changes:
* Qwen3TtsEngine.generateFromEmbedsHexagonStreaming(path, onSegmentReady)
reads single- or multi-segment embeds, runs prefill + generation + VQ
decode + BigVGAN per segment, and fires the callback with each
segment's ShortArray the moment it's ready. Saves per-segment WAVs
(kazeia_stream_seg{N}.wav) plus the concatenated kazeia_stream_full.wav
for offline inspection. Extracted the common generation loop into
runHexSegmentFromEmbeds(prefill, trailing, idx) so single-segment and
streaming paths share exactly the same code (no quality drift between
modes). Added hexReset() between segments so segment 2's prefill logits
don't contain segment 1's KV state.
* vqDecode buffer overrun fix: when the talker samples CODEC_EOS as cb0
it stores a vocab id > CODEBOOK_SIZE, which vqDecode then used as a
codebook row index — reading past the 2048-row buffer. The short Baer
probe never hit this; longer phrases do. Clamp any out-of-vocab code
to 0 at allCodebooks build time.
* KazeiaService: new stream_pipeline intent extra wires the callback
to an AudioTrack MODE_STREAM instance, writing each segment's audio as
soon as it comes back. Logs time-to-first-audio.
* prepare_tts_segments.py: the previous version only captured 1-token
decode calls and substituted a generic 9-embed "prefill_base" pulled
from an unrelated single-segment file — dropping the per-segment
xvector conditioning AND the text-encoded embeddings, so Hexagon
produced garbled mixed speech for segments 2..N. Now captures the
multi-token prefill call too (like prepare_tts_voiceclone.py) so each
segment is self-contained.
Limitation (documented, not fixed in this commit): RTF ~4.4 > 1 on the
Snapdragon 8 Elite with current config means each segment takes longer to
generate than it takes to play, so audible gaps between segments remain.
Removing the gaps requires either (a) producer/consumer parallelism across
two coroutines (doesn't help if RTF stays > 1), or (b) faster CP (the
~180ms/step ONNX MLAS CP is the bottleneck; Hexagon HMX has a known NaN bug
and the .pte path contends with Hexagon talker on the DSP).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
de878ddf5c
commit
5e416713ce
|
|
@ -123,6 +123,51 @@ class KazeiaService : Service() {
|
|||
// Audio is played by the TTS engine internally
|
||||
}
|
||||
}
|
||||
intent?.getStringExtra("stream_pipeline")?.let { embedsPath ->
|
||||
// Stage 1 streaming pipeline: generate segment-by-segment and play each
|
||||
// segment the moment it's ready via an AudioTrack MODE_STREAM. First audio
|
||||
// arrives ~5s after request (time to generate 1 segment) instead of ~15s
|
||||
// for the whole phrase. Per-segment WAVs + concatenated full WAV are
|
||||
// written to /data/local/tmp/kazeia/kazeia_stream_seg*.wav and _full.wav
|
||||
// by the engine itself — the service only handles playback.
|
||||
log("Stream pipeline from pre-computed embeds: $embedsPath")
|
||||
serviceScope.launch {
|
||||
try {
|
||||
val qwenTts = tts as? com.kazeia.tts.Qwen3TtsEngine ?: return@launch
|
||||
val sr = 24000
|
||||
val track = android.media.AudioTrack.Builder()
|
||||
.setAudioAttributes(android.media.AudioAttributes.Builder()
|
||||
.setUsage(android.media.AudioAttributes.USAGE_MEDIA)
|
||||
.setContentType(android.media.AudioAttributes.CONTENT_TYPE_SPEECH)
|
||||
.build())
|
||||
.setAudioFormat(android.media.AudioFormat.Builder()
|
||||
.setEncoding(android.media.AudioFormat.ENCODING_PCM_16BIT)
|
||||
.setSampleRate(sr)
|
||||
.setChannelMask(android.media.AudioFormat.CHANNEL_OUT_MONO)
|
||||
.build())
|
||||
.setBufferSizeInBytes(sr * 4) // 2s mono pcm16 buffer, plenty for seg handoff
|
||||
.setTransferMode(android.media.AudioTrack.MODE_STREAM)
|
||||
.build()
|
||||
track.play()
|
||||
val tStart = System.currentTimeMillis()
|
||||
var firstAudioLogged = false
|
||||
qwenTts.generateFromEmbedsHexagonStreaming(embedsPath) { segIdx, audio ->
|
||||
if (!firstAudioLogged) {
|
||||
log("First audio out at ${System.currentTimeMillis() - tStart}ms (seg ${segIdx+1})")
|
||||
firstAudioLogged = true
|
||||
}
|
||||
track.write(audio, 0, audio.size)
|
||||
}
|
||||
// Let AudioTrack drain the written samples before releasing.
|
||||
track.stop()
|
||||
track.release()
|
||||
log("Stream pipeline done at ${System.currentTimeMillis() - tStart}ms")
|
||||
} catch (e: Exception) {
|
||||
log("Stream pipeline error: ${e.message}")
|
||||
e.printStackTrace()
|
||||
}
|
||||
}
|
||||
}
|
||||
intent?.getStringExtra("full_pipeline")?.let { embedsPath ->
|
||||
val savePath = intent.getStringExtra("save_wav") ?: "/data/local/tmp/kazeia/tts_output.wav"
|
||||
log("Full pipeline from pre-computed embeds: $embedsPath")
|
||||
|
|
|
|||
|
|
@ -3026,6 +3026,182 @@ class Qwen3TtsEngine(
|
|||
return audio
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the Hexagon talker + CP generation loop on a single segment's embeds
|
||||
* and return the decoded audio. Extracted from generateFromEmbedsHexagon so
|
||||
* both single-segment playback and the streaming multi-segment path share
|
||||
* exactly the same generation path (no quality drift between modes).
|
||||
*
|
||||
* Caller is responsible for hexReset() before first call of a request.
|
||||
* Subsequent calls (segments 2..N in multi-segment mode) must hexReset()
|
||||
* between segments so the talker KV-cache doesn't carry stale context.
|
||||
*/
|
||||
private fun runHexSegmentFromEmbeds(
|
||||
prefillEmbeds: List<FloatArray>,
|
||||
trailingEmbeds: List<FloatArray>,
|
||||
segIdx: Int = 0
|
||||
): ShortArray {
|
||||
val allCodes = mutableListOf<IntArray>()
|
||||
val generatedCb0 = mutableListOf<Int>()
|
||||
var totalTalkerMs = 0L; var totalCpMs = 0L
|
||||
|
||||
// Prefill
|
||||
val tPrefill = System.currentTimeMillis()
|
||||
val prefillResults = hexForward(prefillEmbeds)
|
||||
nlog("Seg ${segIdx+1} prefill: ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
|
||||
if (prefillResults.isEmpty()) return ShortArray(0)
|
||||
|
||||
var pastHidden = prefillResults.last().first
|
||||
val prefillLogits = prefillResults.last().second
|
||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
|
||||
var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
|
||||
|
||||
val nTrailing = trailingEmbeds.size
|
||||
for (genStep in 0 until nTrailing) {
|
||||
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
||||
val tCp = System.currentTimeMillis()
|
||||
val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
|
||||
totalCpMs += System.currentTimeMillis() - tCp
|
||||
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
|
||||
allCodes.add(codes); generatedCb0.add(currentCb0)
|
||||
|
||||
val nextEmbed = trailingEmbeds[genStep]
|
||||
val tT = System.currentTimeMillis()
|
||||
val results = hexForward(listOf(nextEmbed))
|
||||
totalTalkerMs += System.currentTimeMillis() - tT
|
||||
if (results.isEmpty()) { nlog("Seg ${segIdx+1}: hex empty at step ${genStep+1}"); break }
|
||||
pastHidden = results[0].first
|
||||
val logits = results[0].second
|
||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||||
currentCb0 = sampleTopK(logits, 0.9f, 50)
|
||||
}
|
||||
|
||||
val n = allCodes.size
|
||||
nlog("Seg ${segIdx+1} generated $n tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
|
||||
if (n == 0) return ShortArray(0)
|
||||
|
||||
// cb0 can hit CODEC_EOS (> CODEBOOK_SIZE) on longer phrases — the original
|
||||
// single-segment Hexagon path never exercises this because the short Baer
|
||||
// probe stayed well inside the decode budget. Clamp any out-of-vocab code
|
||||
// to 0 (silence) so vqDecode can't read past the codebook buffer.
|
||||
val padLen = maxOf(n, SEQ_LEN)
|
||||
val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
|
||||
IntArray(padLen) { t ->
|
||||
if (t < n) { val v = allCodes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
|
||||
}
|
||||
}
|
||||
return decodeChunked(allCodebooks, n)
|
||||
}
|
||||
|
||||
/**
|
||||
* Streaming multi-segment variant of generateFromEmbeds. Reads a multi-segment
|
||||
* embeds file, generates each segment via Hexagon talker + CP sequentially,
|
||||
* and invokes `onSegmentReady(idx, audio)` the moment each segment's audio is
|
||||
* decoded. The callback writes to an AudioTrack in the calling coroutine so
|
||||
* playback begins as soon as segment 1 finishes (~5s for a 5s segment,
|
||||
* instead of ~15s for the full phrase).
|
||||
*
|
||||
* Each segment's raw audio is saved to /data/local/tmp/kazeia/kazeia_stream_segN.wav
|
||||
* and the final concatenated audio to /data/local/tmp/kazeia/kazeia_stream_full.wav
|
||||
* so the caller can inspect individual segments for quality regressions.
|
||||
*
|
||||
* Single-segment files are supported as a degenerate case (nSegments=1) so
|
||||
* the caller doesn't need to branch on format.
|
||||
*/
|
||||
fun generateFromEmbedsHexagonStreaming(
|
||||
embedsPath: String,
|
||||
onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
|
||||
): ShortArray {
|
||||
if (!loaded || !useHexagonTalker) {
|
||||
nlog("Streaming: Hexagon talker not ready")
|
||||
return ShortArray(0)
|
||||
}
|
||||
val t0 = System.currentTimeMillis()
|
||||
val bytes = File(embedsPath).readBytes()
|
||||
val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
// Format detection (same logic as generateFromEmbedsPte):
|
||||
// single: <i32 nPrefill> <i32 nTotal> <f32 × nTotal × 1024>
|
||||
// multi : <i32 nSegments> [<i32 nPrefill_i> <i32 nTotal_i> <f32 ...>] × nSegments
|
||||
val firstInt = bb.int
|
||||
val secondInt = bb.int
|
||||
val fileLen = bytes.size.toLong()
|
||||
val singleSize = 8L + secondInt.toLong() * TALKER_DIM * 4
|
||||
val isSingle = secondInt > 0 && secondInt < 100000 && fileLen == singleSize
|
||||
val nSegments = if (isSingle) 1 else firstInt
|
||||
bb.position(if (isSingle) 0 else 4)
|
||||
nlog("Streaming: $nSegments segment(s), ${bytes.size} bytes")
|
||||
|
||||
// Ensure a fresh runner connection for this request. Between requests
|
||||
// the KV-cache carries stale state from the previous generation and
|
||||
// prefill logits come out as garbage on segment 1.
|
||||
hexReset()
|
||||
|
||||
val segmentAudios = mutableListOf<ShortArray>()
|
||||
val gapSamples = SR * 120 / 1000
|
||||
val gap = ShortArray(gapSamples)
|
||||
|
||||
for (seg in 0 until nSegments) {
|
||||
// Between segments the talker KV-cache must be reset so segment 2's
|
||||
// prefill logits don't contain segment 1's state. Skipping this
|
||||
// produces garbled speech from segment 2 onwards.
|
||||
if (seg > 0) hexReset()
|
||||
|
||||
val nPrefill = bb.int
|
||||
val nTotal = bb.int
|
||||
val prefill = List(nPrefill) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
|
||||
val nTrailing = nTotal - nPrefill
|
||||
val trailing = List(nTrailing) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
|
||||
nlog("Streaming seg ${seg+1}/$nSegments: $nPrefill prefill + $nTrailing decode")
|
||||
|
||||
val tSeg = System.currentTimeMillis()
|
||||
val audio = runHexSegmentFromEmbeds(prefill, trailing, seg)
|
||||
val segMs = System.currentTimeMillis() - tSeg
|
||||
nlog("Streaming seg ${seg+1}/$nSegments: ${audio.size/SR.toFloat()}s audio in ${segMs}ms")
|
||||
|
||||
// Emit to caller immediately so playback starts now. WAV dump is
|
||||
// synchronous here; for true zero-lag streaming the caller can
|
||||
// write to AudioTrack on its own dispatcher from the callback.
|
||||
segmentAudios.add(audio)
|
||||
saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${seg+1}.wav", audio)
|
||||
onSegmentReady?.invoke(seg, audio)
|
||||
}
|
||||
|
||||
// Concatenate with short gaps between segments for the full-file WAV.
|
||||
// Playback path already inserted perceptual spacing via the callback order.
|
||||
val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
|
||||
val concat = ShortArray(total)
|
||||
var off = 0
|
||||
for ((i, s) in segmentAudios.withIndex()) {
|
||||
System.arraycopy(s, 0, concat, off, s.size); off += s.size
|
||||
if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
|
||||
}
|
||||
saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
|
||||
nlog("Streaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s ($nSegments seg)")
|
||||
return concat
|
||||
}
|
||||
|
||||
/** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
|
||||
* save one file per segment plus the concatenated result for inspection. */
|
||||
private fun saveWav(path: String, audio: ShortArray) {
|
||||
try {
|
||||
val fos = java.io.FileOutputStream(path)
|
||||
val dataLen = audio.size * 2
|
||||
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
|
||||
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
|
||||
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
|
||||
header.putInt(16); header.putShort(1); header.putShort(1)
|
||||
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
|
||||
header.put("data".toByteArray()); header.putInt(dataLen)
|
||||
fos.write(header.array())
|
||||
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
|
||||
for (s in audio) buf.putShort(s)
|
||||
fos.write(buf.array()); fos.close()
|
||||
} catch (e: Exception) { nlog("WAV save failed ($path): ${e.message}") }
|
||||
}
|
||||
|
||||
/** Test with pre-computed codec tokens from PC (for validation) */
|
||||
fun testWithPrecomputedCodes(codesPath: String, realTokens: Int = 16): ShortArray {
|
||||
if (!loaded) return ShortArray(0)
|
||||
|
|
|
|||
|
|
@ -59,25 +59,25 @@ print("\nLoading model...")
|
|||
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
|
||||
talker = tts.model.talker
|
||||
|
||||
# Capture generation inputs via monkey-patch on inner model
|
||||
captured_inputs = []
|
||||
# Capture generation inputs. The previous version of this script only captured
|
||||
# 1-token decode calls and reassembled them with a generic 9-embed "prefill_base"
|
||||
# pulled from another file — dropping the per-segment prefill that contains the
|
||||
# xvector conditioning AND the text-encoded embeddings. With that generic prefix
|
||||
# the talker had no idea which sentence to produce → Hexagon output was garbled.
|
||||
# Fix: capture the MULTI-token prefill call too (first call has shape[1] > 1),
|
||||
# exactly like prepare_tts_voiceclone.py does. Each segment becomes self-contained.
|
||||
captured_embeds = [] # 1024-dim vectors in order
|
||||
call_shapes = [] # length of each talker.model.forward call
|
||||
original_model_forward = talker.model.forward
|
||||
def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
|
||||
if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
|
||||
captured_inputs.append(inputs_embeds[0, 0, :].detach().cpu().numpy().astype(np.float32))
|
||||
if inputs_embeds is not None and inputs_embeds.dim() == 3:
|
||||
t = inputs_embeds.shape[1]
|
||||
call_shapes.append(t)
|
||||
for j in range(t):
|
||||
captured_embeds.append(inputs_embeds[0, j, :].detach().cpu().numpy().astype(np.float32))
|
||||
return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
|
||||
talker.model.forward = patched_model_forward
|
||||
|
||||
# Load prefill structure
|
||||
EXISTING = "/tmp/existing_embeds.bin"
|
||||
if not os.path.exists(EXISTING):
|
||||
os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {EXISTING}")
|
||||
with open(EXISTING, "rb") as f:
|
||||
nP = struct.unpack("<i", f.read(4))[0]
|
||||
nT = struct.unpack("<i", f.read(4))[0]
|
||||
old_embeds = [np.frombuffer(f.read(1024*4), dtype=np.float32).copy() for _ in range(nT)]
|
||||
prefill_base = old_embeds[:9] # role+ctrl+spk+bos
|
||||
|
||||
# Generate embeds for each segment
|
||||
all_segments_data = []
|
||||
total_audio_duration = 0
|
||||
|
|
@ -86,7 +86,8 @@ for i, seg_text in enumerate(segments):
|
|||
print(f"\n--- Segment {i+1}/{len(segments)} ---")
|
||||
print(f" '{seg_text[:60]}...'")
|
||||
|
||||
captured_inputs.clear()
|
||||
captured_embeds.clear()
|
||||
call_shapes.clear()
|
||||
audio_list, sr = tts.generate_voice_clone(
|
||||
text=seg_text, ref_audio=VOICE, language="french",
|
||||
x_vector_only_mode=True, non_streaming_mode=True,
|
||||
|
|
@ -94,40 +95,39 @@ for i, seg_text in enumerate(segments):
|
|||
audio_dur = len(audio_list[0]) / sr
|
||||
total_audio_duration += audio_dur
|
||||
|
||||
if len(captured_inputs) < 2:
|
||||
print(f" WARNING: Only {len(captured_inputs)} steps, skipping")
|
||||
if not call_shapes:
|
||||
print(f" WARNING: no calls captured, skipping")
|
||||
continue
|
||||
|
||||
nPrefill = 10 # 9 base + first gen input
|
||||
nDecode = len(captured_inputs) - 1
|
||||
nTotal = nPrefill + nDecode
|
||||
# First call = multi-token prefill (xvector + text encoding + role tokens).
|
||||
# Subsequent calls = single-token decode (one per generated codec frame).
|
||||
nPrefill = call_shapes[0]
|
||||
nTotal = len(captured_embeds)
|
||||
if nTotal <= nPrefill:
|
||||
print(f" WARNING: no decode steps captured, skipping")
|
||||
continue
|
||||
|
||||
seg_data = {
|
||||
'nPrefill': nPrefill,
|
||||
'nTotal': nTotal,
|
||||
'prefill': prefill_base.copy(),
|
||||
'first_gen': captured_inputs[0],
|
||||
'decode': captured_inputs[1:],
|
||||
'embeds': captured_embeds.copy(),
|
||||
'audio_dur': audio_dur,
|
||||
'n_tokens': len(captured_inputs),
|
||||
}
|
||||
all_segments_data.append(seg_data)
|
||||
print(f" {nTotal} embeds ({len(captured_inputs)} tokens), {audio_dur:.2f}s audio")
|
||||
print(f" {nTotal} embeds ({nPrefill} prefill + {nTotal - nPrefill} decode), {audio_dur:.2f}s audio")
|
||||
|
||||
# Write multi-segment file
|
||||
# Write multi-segment file. Per-segment layout matches single-segment format so
|
||||
# the tablet can read either shape with the same parser.
|
||||
with open(OUTPUT, "wb") as f:
|
||||
f.write(struct.pack("<i", len(all_segments_data)))
|
||||
for seg in all_segments_data:
|
||||
f.write(struct.pack("<i", seg['nPrefill']))
|
||||
f.write(struct.pack("<i", seg['nTotal']))
|
||||
for emb in seg['prefill']:
|
||||
f.write(emb.tobytes())
|
||||
f.write(seg['first_gen'].tobytes())
|
||||
for emb in seg['decode']:
|
||||
for emb in seg['embeds']:
|
||||
f.write(emb.tobytes())
|
||||
|
||||
sz = os.path.getsize(OUTPUT)
|
||||
total_tokens = sum(s['n_tokens'] for s in all_segments_data)
|
||||
total_tokens = sum(s['nTotal'] - s['nPrefill'] for s in all_segments_data)
|
||||
print(f"\n=== RESULT ===")
|
||||
print(f"Segments: {len(all_segments_data)}")
|
||||
print(f"Total tokens: {total_tokens}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue