diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
index 753f245..b5dcde2 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
@@ -2282,7 +2282,24 @@ class Qwen3TtsEngine(
         val t0 = System.currentTimeMillis()
         val bytes = File(embedsPath).readBytes()
         val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
-        val nPrefill = bb.int; val nTotal = bb.int
+
+        // Detect the container format structurally rather than by guessing from the
+        // first int: a multi-segment file must parse as n_segments well-formed
+        // segments that end exactly at EOF. This keeps 9- and 10-segment files
+        // (which collide with the usual legacy nPrefill of 9 or 10) from being misread.
+        val isMultiSegment = isMultiSegmentEmbeds(bb)
+
+        if (isMultiSegment && talkerPteModule != null && cpPteModule != null) {
+            val nSegments = bb.int
+            return generateMultiSegment(bb, nSegments, t0)
+        }
+
+        // Legacy single-segment path. When a multi-segment file reaches this point
+        // (PTE modules unavailable) only its first segment is used, so skip the
+        // leading n_segments field.
+        if (isMultiSegment) bb.position(4)
+        val nPrefill = bb.int
+        val nTotal = bb.int
         val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
         nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
 
@@ -2498,6 +2515,147 @@ class Qwen3TtsEngine(
         return allCodes.toTypedArray()
     }
 
+    /**
+     * Reports whether [bb] holds a well-formed multi-segment embeds file.
+     *
+     * Layout (little-endian): int32 n_segments, then for each segment int32 n_prefill,
+     * int32 n_total and n_total * [TALKER_DIM] float32 values. The buffer qualifies only
+     * when every segment header is consistent and the last segment ends exactly at the
+     * end of the data. The position of [bb] itself is not modified.
+     */
+    private fun isMultiSegmentEmbeds(bb: ByteBuffer): Boolean {
+        // duplicate() does not carry over the byte order, so set it again on the view.
+        val view = bb.duplicate().order(ByteOrder.LITTLE_ENDIAN)
+        view.position(0)
+        if (view.remaining() < 4) return false
+        val nSegments = view.int
+        if (nSegments <= 0) return false
+        repeat(nSegments) {
+            if (view.remaining() < 8) return false
+            val nPrefill = view.int
+            val nTotal = view.int
+            if (nPrefill < 0 || nTotal < nPrefill) return false
+            val payloadBytes = nTotal.toLong() * TALKER_DIM * 4L
+            if (payloadBytes > view.remaining()) return false
+            view.position(view.position() + payloadBytes.toInt())
+        }
+        return !view.hasRemaining()
+    }
+
+    /**
+     * Multi-segment pipeline: process each segment independently, concatenate audio.
+     * The concatenated audio is also written to a WAV file (best effort).
+     *
+     * @param bb buffer positioned on the first segment header, just past n_segments
+     * @param nSegments number of segments to read from [bb]
+     * @param t0 start time in milliseconds, used for the total-time log line
+     * @return the concatenated 16-bit samples of every segment that produced tokens
+     */
+    private fun generateMultiSegment(bb: ByteBuffer, nSegments: Int, t0: Long): ShortArray {
+        nlog("Multi-segment: $nSegments segments")
+        val talker = checkNotNull(talkerPteModule) {
+            "talkerPteModule is required for the multi-segment pipeline"
+        }
+        val allAudio = mutableListOf<ShortArray>()
+
+        // Reads count floats at the current position of bb and advances past them.
+        fun readFloats(count: Int): FloatArray {
+            val out = FloatArray(count)
+            bb.asFloatBuffer().get(out)
+            bb.position(bb.position() + count * 4)
+            return out
+        }
+
+        // Ensure data arrays loaded
+        val mpath = "/data/local/tmp/kazeia/models/qwen3-tts-npu"
+        if (cpAllHeads == null) {
+            val headsFile = File("/data/local/tmp/kazeia/models/cp_heads.bin")
+            if (headsFile.exists()) {
+                val raw = headsFile.readBytes()
+                cpAllHeads = FloatArray(raw.size / 4).also { heads ->
+                    ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(heads)
+                }
+            } else {
+                nlog("cp_heads.bin not found: ${headsFile.path}")
+            }
+        }
+        if (codecEmbedding == null) codecEmbedding = loadNpy("$mpath/codec_embedding.npy")
+        if (cpEmbeddings == null) cpEmbeddings = loadNpy("$mpath/code_predictor_embeddings.npy")
+        if (cpRotaryCos == null) cpRotaryCos = loadNpy("$mpath/cp_kv_v2/cp_rotary_cos.npy")
+        if (cpRotarySin == null) cpRotarySin = loadNpy("$mpath/cp_kv_v2/cp_rotary_sin.npy")
+        if (talkerPteRotaryCos == null) talkerPteRotaryCos = loadNpy("$mpath/talker_pte_rotary_cos.npy")
+        if (talkerPteRotarySin == null) talkerPteRotarySin = loadNpy("$mpath/talker_pte_rotary_sin.npy")
+        if (ttsEosEmbed == null) {
+            val sp = loadNpy("$mpath/tts_special_embeds.npy")
+            ttsBosEmbed = sp.sliceArray(0 until TALKER_DIM)
+            ttsEosEmbed = sp.sliceArray(TALKER_DIM until 2 * TALKER_DIM)
+            ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
+        }
+
+        for (seg in 0 until nSegments) {
+            val nPrefill = bb.int
+            val nTotal = bb.int
+            check(nPrefill in 0..nTotal) { "Segment ${seg + 1}: bad header (nPrefill=$nPrefill, nTotal=$nTotal)" }
+            val nTrailing = nTotal - nPrefill
+            nlog("Segment ${seg+1}/$nSegments: $nPrefill prefill + ${nTotal-nPrefill} decode")
+
+            // Prefill rows come first in the segment; the trailing (decode) rows follow.
+            val prefillFlat = readFloats(nPrefill * TALKER_DIM)
+            val trailingFlat = if (nTrailing > 0) readFloats(nTrailing * TALKER_DIM) else null
+
+            val flat = talker.nativeRunTtsPipeline(
+                prefillFlat, nPrefill, trailingFlat, nTrailing,
+                codecEmbedding ?: FloatArray(0), cpEmbeddings ?: FloatArray(0), cpAllHeads ?: FloatArray(0),
+                talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
+                cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
+                ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
+                nTrailing // NOTE(review): nTrailing is passed twice; confirm the meaning of the last parameter
+            )
+            if (flat == null || flat.isEmpty()) continue
+            val nTokens = flat.size / NUM_CODEBOOKS
+            nlog(" → $nTokens tokens generated")
+
+            // flat is token-major; transpose it into one zero-padded row per codebook.
+            val padLen = maxOf(nTokens, SEQ_LEN)
+            val codebooks = Array(NUM_CODEBOOKS) { cb ->
+                IntArray(padLen) { t -> if (t < nTokens) flat[t * NUM_CODEBOOKS + cb] else 0 }
+            }
+            val segAudio = decodeChunked(codebooks, nTokens)
+            allAudio.add(segAudio)
+            nlog(" → ${segAudio.size/SR.toFloat()}s audio decoded")
+        }
+
+        // Concatenate all segments
+        val totalSamples = allAudio.sumOf { it.size }
+        val result = ShortArray(totalSamples)
+        var offset = 0
+        for (chunk in allAudio) {
+            System.arraycopy(chunk, 0, result, offset, chunk.size)
+            offset += chunk.size
+        }
+
+        val totalMs = System.currentTimeMillis() - t0
+        nlog("Total: ${totalMs}ms for ${totalSamples/SR.toFloat()}s ($nSegments segments)")
+
+        // Save WAV; best effort, a failure is only logged and the audio is still returned.
+        try {
+            val wavPath = "/data/local/tmp/kazeia/kazeia_PTE_NPU.wav"
+            val dataLen = result.size * 2
+            val wav = ByteBuffer.allocate(44 + dataLen).order(ByteOrder.LITTLE_ENDIAN)
+            wav.put("RIFF".toByteArray()); wav.putInt(36 + dataLen)
+            wav.put("WAVE".toByteArray()); wav.put("fmt ".toByteArray())
+            wav.putInt(16); wav.putShort(1); wav.putShort(1)
+            wav.putInt(SR); wav.putInt(SR * 2); wav.putShort(2); wav.putShort(16)
+            wav.put("data".toByteArray()); wav.putInt(dataLen)
+            wav.asShortBuffer().put(result)
+            // writeBytes opens and closes the stream itself, so nothing leaks if the write fails.
+            File(wavPath).writeBytes(wav.array())
+            nlog("WAV saved: $wavPath ($totalSamples samples)")
+        } catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
+
+        return result
+    }
+
     /** Full pipeline using Hexagon talker + Hexagon CP from pre-computed embeddings. */
     private fun generateFromEmbedsHexagon(embedsPath: String): ShortArray {
         nlog("Full pipeline (Hexagon) from: $embedsPath")
diff --git a/scripts/prepare_tts_segments.py b/scripts/prepare_tts_segments.py
new file mode 100644
index 0000000..bf36f8a
--- /dev/null
+++ b/scripts/prepare_tts_segments.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+"""
+Generate TTS embeddings for long text, split into sentence segments.
+Each segment is generated independently by Python for maximum quality.
+
+Usage: python3 prepare_tts_segments.py "Long text..." [output.bin]
+       adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin
+
+Output format:
+  int32 n_segments
+  for each segment:
+    int32 n_prefill
+    int32 n_total
+    float32[n_total * 1024] embeddings
+"""
+import sys, os, struct, re, types, warnings
+os.chdir("/tmp")
+warnings.filterwarnings("ignore")
+
+TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour. Je suis Kazeia."
+OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_segments.bin"
+VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"
+MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
+
+import torch, numpy as np
+from qwen_tts import Qwen3TTSModel
+
+def split_sentences(text, max_tokens=60):
+    """Split text at sentence boundaries, keeping segments short."""
+    # Split at . ! ? ; and keep the punctuation
+    parts = re.split(r'(?<=[.!?;])\s+', text.strip())
+
+    segments = []
+    current = ""
+    for part in parts:
+        if current and len(current) + len(part) > 200:  # rough char limit
+            segments.append(current.strip())
+            current = part
+        else:
+            current = (current + " " + part).strip() if current else part
+    if current.strip():
+        segments.append(current.strip())
+
+    # If any segment is still too long, split at commas
+    final = []
+    for seg in segments:
+        if len(seg) > 250:
+            parts = re.split(r'(?<=,)\s+', seg)
+            sub = ""
+            for p in parts:
+                if sub and len(sub) + len(p) > 200:
+                    final.append(sub.strip())
+                    sub = p
+                else:
+                    sub = (sub + " " + p).strip() if sub else p
+            if sub.strip():
+                final.append(sub.strip())
+        else:
+            final.append(seg)
+
+    return final if final else [text]
+
+print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
+segments = split_sentences(TEXT)
+print(f"Split into {len(segments)} segments:")
+for i, s in enumerate(segments):
+    print(f" [{i}] '{s[:60]}{'...' if len(s)>60 else ''}'")
+
+print("\nLoading model...")
+tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
+talker = tts.model.talker
+
+# Capture generation inputs via monkey-patch on inner model
+captured_inputs = []
+original_model_forward = talker.model.forward
+def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
+    if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
+        captured_inputs.append(inputs_embeds[0, 0, :].detach().cpu().numpy().astype(np.float32))
+    return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
+talker.model.forward = patched_model_forward
+
+# Load prefill structure
+EXISTING = "/tmp/existing_embeds.bin"
+if not os.path.exists(EXISTING):
+    os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {EXISTING}")
+with open(EXISTING, "rb") as f:
+    nP = struct.unpack("