diff --git a/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt b/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt index 4411c77..4ac2ab4 100644 --- a/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt +++ b/kazeia-android/app/src/main/java/com/kazeia/service/KazeiaService.kt @@ -123,6 +123,47 @@ class KazeiaService : Service() { // Audio is played by the TTS engine internally } } + intent?.getStringExtra("stream_text")?.let { text -> + // Stage 2 streaming from arbitrary text: BPE tokenize on-device, + // look up embeds in the full Qwen3 vocab, run the existing + // interleaved Hexagon generation loop, and play each segment + // as soon as it's decoded. No PC-side prep required. + log("Stream text: '${text.take(60)}${if (text.length>60) "..." else ""}'") + serviceScope.launch { + try { + val qwenTts = tts as? com.kazeia.tts.Qwen3TtsEngine ?: return@launch + val sr = 24000 + val track = android.media.AudioTrack.Builder() + .setAudioAttributes(android.media.AudioAttributes.Builder() + .setUsage(android.media.AudioAttributes.USAGE_MEDIA) + .setContentType(android.media.AudioAttributes.CONTENT_TYPE_SPEECH) + .build()) + .setAudioFormat(android.media.AudioFormat.Builder() + .setEncoding(android.media.AudioFormat.ENCODING_PCM_16BIT) + .setSampleRate(sr) + .setChannelMask(android.media.AudioFormat.CHANNEL_OUT_MONO) + .build()) + .setBufferSizeInBytes(sr * 4) + .setTransferMode(android.media.AudioTrack.MODE_STREAM) + .build() + track.play() + val tStart = System.currentTimeMillis() + var firstLogged = false + qwenTts.synthesizeTextStreaming(text) { segIdx, audio -> + if (!firstLogged) { + log("First audio out at ${System.currentTimeMillis() - tStart}ms (seg ${segIdx+1})") + firstLogged = true + } + track.write(audio, 0, audio.size) + } + track.stop(); track.release() + log("Stream text done at ${System.currentTimeMillis() - tStart}ms") + } catch (e: Exception) { + log("Stream text error: ${e.message}") + e.printStackTrace() + } + } + } intent?.getStringExtra("stream_pipeline")?.let { embedsPath -> // Stage 1 streaming pipeline: generate segment-by-segment and play each // segment the moment it's ready via an AudioTrack MODE_STREAM. First audio diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3BpeTokenizer.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3BpeTokenizer.kt new file mode 100644 index 0000000..adc334a --- /dev/null +++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3BpeTokenizer.kt @@ -0,0 +1,219 @@ +package com.kazeia.tts + +import android.util.Log +import org.json.JSONObject +import java.io.File +import java.util.regex.Pattern + +/** + * Byte-level BPE tokenizer compatible with Qwen2/Qwen3 tokenizer. + * + * Why this file exists: + * The app needs to tokenize arbitrary LLM-generated French text on-device + * for the TTS pipeline (the reduced 1050-token table shipped previously + * couldn't handle free-form text). The reference is the HuggingFace + * Qwen2TokenizerFast — a GPT-2-style byte-level BPE. We reimplement it + * in Kotlin rather than linking libtokenizers.so so there's no new + * native dependency and the numbers are easy to audit. + * + * Algorithm (identical to GPT-2 / Qwen2): + * 1. Pre-tokenize the input with a regex that groups contractions, + * words, numbers, and whitespace into "word chunks" that BPE never + * merges across. + * 2. UTF-8 encode each chunk, then map each byte (0..255) to one of + * 256 printable Unicode code points via a fixed "byte encoder" table + * — this is the trick that lets a byte-level vocab fit inside JSON + * without invalid or control characters. + * 3. Apply BPE: repeatedly find the pair with the lowest rank in the + * merges list and merge it, until no more merges apply. Look up each + * resulting super-token in vocab.json to get the final token IDs. + * + * Bit-perfect compatibility notes: + * - The pre-tokenize regex below MUST match the one in tokenizer_config.json. + * Qwen2 uses the GPT-2 pattern with a couple of Unicode property + * extensions; Java regex supports these directly. + * - Byte encoder is canonical (see bytesToUnicode()). + * - Merges are rank-ordered: lower line number = higher priority, matching + * HuggingFace's `merges.txt` file ordering. + * - We do NOT add BOS/EOS or chat-template specials — the TTS prefill + * prepends its own 9-embed voice prefix that already handles role tokens. + */ +class Qwen3BpeTokenizer private constructor( + private val vocab: HashMap, + private val merges: HashMap, Int>, + private val byteEncoder: IntArray, +) { + companion object { + private const val TAG = "Qwen3BPE" + + // Qwen2/Qwen3 pre-tokenization regex. This is the exact pattern used + // by the HuggingFace Qwen2Tokenizer (adapted from GPT-2). Matches + // contractions ('s|'d|'ll|'ve|'re|'t), word chunks with a leading + // optional non-word char + letters, runs of digits, runs of + // punctuation, and whitespace runs with boundary semantics. + // + // Note on the character classes: Android's Pattern does NOT support + // UNICODE_CHARACTER_CLASS — so plain \\p{L} / \\p{N} would collapse + // to ASCII-only and break French accents ("é" → wrong token). Use + // \\p{IsAlphabetic} and \\p{IsDigit} instead; those ARE Unicode-aware + // out of the box (they map to ICU's IsAlphabetic / IsDigit properties + // in both JDK and Android runtimes). Output matches Python's Qwen2 + // tokenizer on French text. + private val PRE_TOKENIZE_PATTERN: Pattern = Pattern.compile( + "'s|'d|'ll|'ve|'re|'t" + + "|[^\\r\\n\\p{IsAlphabetic}\\p{IsDigit}]?\\p{IsAlphabetic}+" + + "|\\p{IsDigit}{1,3}" + + "| ?[^\\s\\p{IsAlphabetic}\\p{IsDigit}]+[\\r\\n]*" + + "|\\s*[\\r\\n]+" + + "|\\s+(?!\\S)" + + "|\\s+" + ) + + /** + * GPT-2 byte encoder: maps 0..255 → a printable Unicode codepoint. + * Ensures every possible byte has a visible, JSON-safe representation + * so a byte-level vocab can be stored as strings in vocab.json. + */ + private fun bytesToUnicode(): IntArray { + val bs = mutableListOf() + // Printable ASCII and common Latin blocks. + bs.addAll(('!'.code..'~'.code).toList()) + bs.addAll(('¡'.code..'¬'.code).toList()) + bs.addAll(('®'.code..'ÿ'.code).toList()) + val cs = bs.toMutableList() + // Every byte not in bs gets mapped to a code point past 255 so no + // existing character collides with it. + var n = 0 + val map = IntArray(256) + for (b in 0..255) { + if (b in bs) { + map[b] = b + } else { + map[b] = 256 + n + bs.add(b) // placeholder only, we've already recorded cs from the frozen snapshot + cs.add(256 + n) + n += 1 + } + } + return map + } + + fun load(modelDir: String): Qwen3BpeTokenizer { + val t0 = System.currentTimeMillis() + val vocabFile = File("$modelDir/vocab.json") + val mergesFile = File("$modelDir/merges.txt") + require(vocabFile.exists()) { "vocab.json missing at $modelDir" } + require(mergesFile.exists()) { "merges.txt missing at $modelDir" } + + val vocabJson = JSONObject(vocabFile.readText()) + val vocab = HashMap(vocabJson.length()) + val keys = vocabJson.keys() + while (keys.hasNext()) { + val k = keys.next() + vocab[k] = vocabJson.getInt(k) + } + + val merges = HashMap, Int>() + var rank = 0 + mergesFile.useLines { lines -> + for (line in lines) { + // Skip header / blanks. Qwen's merges.txt starts with + // "#version" which we simply filter out. + if (line.isBlank() || line.startsWith("#")) continue + val sp = line.indexOf(' ') + if (sp < 0) continue + merges[Pair(line.substring(0, sp), line.substring(sp + 1))] = rank + rank++ + } + } + + Log.i(TAG, "Loaded vocab=${vocab.size} merges=${merges.size} in ${System.currentTimeMillis()-t0}ms") + return Qwen3BpeTokenizer(vocab, merges, bytesToUnicode()) + } + } + + /** + * Convert a single pre-tokenized word (UTF-8 bytes encoded via the byte + * encoder into a string) into token IDs via BPE merges. Caches results so + * repeated common words (spaces, punctuation) only BPE once. + */ + private val bpeCache = HashMap() + + private fun bpeEncode(byteEncodedWord: String): IntArray { + bpeCache[byteEncodedWord]?.let { return it } + + // Start with one "sub-token" per Unicode code point (code points, + // not chars — surrogate pairs are handled automatically since the + // byte encoder only produces BMP codepoints by construction). + val parts = ArrayList(byteEncodedWord.length) + var i = 0 + while (i < byteEncodedWord.length) { + val cp = byteEncodedWord.codePointAt(i) + parts.add(String(Character.toChars(cp))) + i += Character.charCount(cp) + } + if (parts.size < 2) { + val id = vocab[parts.getOrElse(0) { "" }] ?: vocab[""] ?: 0 + val out = intArrayOf(id) + bpeCache[byteEncodedWord] = out + return out + } + + // Greedy lowest-rank merge, classic BPE. We scan for the pair with + // the smallest rank, merge ALL its occurrences, then repeat. This + // matches HF's reference implementation. + while (parts.size > 1) { + var bestRank = Int.MAX_VALUE + var bestIdx = -1 + for (k in 0 until parts.size - 1) { + val r = merges[Pair(parts[k], parts[k + 1])] ?: continue + if (r < bestRank) { bestRank = r; bestIdx = k } + } + if (bestIdx < 0) break + // Merge all non-overlapping occurrences of that exact pair. + val a = parts[bestIdx]; val b = parts[bestIdx + 1] + val merged = a + b + var k = 0 + val out = ArrayList(parts.size - 1) + while (k < parts.size) { + if (k < parts.size - 1 && parts[k] == a && parts[k + 1] == b) { + out.add(merged); k += 2 + } else { + out.add(parts[k]); k += 1 + } + } + parts.clear(); parts.addAll(out) + } + + val ids = IntArray(parts.size) + for (k in parts.indices) { + ids[k] = vocab[parts[k]] ?: vocab[""] ?: 0 + } + bpeCache[byteEncodedWord] = ids + return ids + } + + /** + * Encode text → token IDs using the full Qwen3 vocabulary. Does NOT + * prepend BOS/EOS — callers can add specials themselves. Unicode + * characters outside ASCII (e.g. French accents) are UTF-8 encoded and + * go through the byte encoder, so "é" and "ï" tokenize the same way as + * they do in Python. + */ + fun encode(text: String): IntArray { + val all = ArrayList(text.length / 2 + 4) + val matcher = PRE_TOKENIZE_PATTERN.matcher(text) + while (matcher.find()) { + val chunk = matcher.group() + // UTF-8 encode the chunk, then map each raw byte to its Unicode + // "byte-encoded" character. This produces the exact string that + // BPE merges operate on. + val bytes = chunk.toByteArray(Charsets.UTF_8) + val sb = StringBuilder(bytes.size) + for (b in bytes) sb.appendCodePoint(byteEncoder[b.toInt() and 0xff]) + val ids = bpeEncode(sb.toString()) + for (id in ids) all.add(id) + } + return all.toIntArray() + } +} diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt index 4088624..b75d42d 100644 --- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt +++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt @@ -100,8 +100,21 @@ class Qwen3TtsEngine( private var decoderOnGpu: Boolean = false // Dual embedding tables for talker input - private var textEmbeds: FloatArray? = null // [1050, 1024] - pre-projected text embeddings + private var textEmbeds: FloatArray? = null // [1050, 1024] - reduced vocab, legacy fallback private var codecEmbedding: FloatArray? = null // [3072, 1024] - codec/control token embeddings + + // Stage 2 — on-device full-vocab text embeddings + BPE tokenizer. + // textEmbedsFull is 151936 × 1024 fp16 memory-mapped (~296 MB); using + // mmap keeps the bytes off the Java heap so the app doesn't crash when + // the ~125 MB cp_embeddings allocation comes next. damienVoicePrefix is + // the fixed 9-embed voice-cloning header that is prepended to the + // tokenized text to form a full prefill. + private var textEmbedsFullBuf: java.nio.ByteBuffer? = null + private var textEmbedsFullChan: java.nio.channels.FileChannel? = null + private val textEmbedsFullLen = 151936 + private var damienVoicePrefix: Array? = null + private var damienVoiceSuffix: Array? = null + private var bpeTokenizer: Qwen3BpeTokenizer? = null private var ttsBosEmbed: FloatArray? = null // [1024] - tts_bos text-side embedding private var ttsEosEmbed: FloatArray? = null // [1024] - tts_eos text-side embedding private var ttsPadEmbed: FloatArray? = null // [1024] - tts_pad text-side embedding @@ -457,6 +470,72 @@ class Qwen3TtsEngine( // Load dual embedding tables for talker textEmbeds = loadNpy("$path/text_embeds_projected.npy") nlog("Text embeddings: ${textEmbeds!!.size / TALKER_DIM} × $TALKER_DIM") + + // Stage 2 assets: full-vocab text embeddings + voice prefix + BPE. + // All three are optional — if any is missing we fall back to the + // legacy 1050-token path. This keeps the app bootable during + // asset rollout and avoids turning a missing file into a crash. + try { + val fullEmbFile = File("$path/text_embeds_full_fp16.bin") + val prefixFile = File("$path/damien_voice_prefix.bin") + val tokDir = File("$path/qwen3_tokenizer") + if (fullEmbFile.exists() && prefixFile.exists() && tokDir.isDirectory) { + val tE0 = System.currentTimeMillis() + // Memory-map the fp16 embeddings table instead of heap- + // loading it. Without mmap, the 296 MB ByteArray plus + // the 125 MB cp_embeddings FloatArray loaded a few + // lines below overrun the ~536 MB large-heap limit and + // the app OOMs during init. mmap pages the file via + // the kernel and keeps zero bytes on the Java heap. + val expectedBytes = textEmbedsFullLen.toLong() * TALKER_DIM * 2 + if (fullEmbFile.length() != expectedBytes) { + nlog("text_embeds_full_fp16 size mismatch (got ${fullEmbFile.length()}, expected $expectedBytes) — disabling on-device text") + } else { + textEmbedsFullChan = java.io.RandomAccessFile(fullEmbFile, "r").channel + textEmbedsFullBuf = textEmbedsFullChan!!.map( + java.nio.channels.FileChannel.MapMode.READ_ONLY, 0L, expectedBytes + ).order(ByteOrder.LITTLE_ENDIAN) + nlog("Full-vocab text embeddings mmap: ${textEmbedsFullLen} × $TALKER_DIM fp16 (${expectedBytes/1024/1024}MB, off-heap) in ${System.currentTimeMillis()-tE0}ms") + } + + val pb = ByteBuffer.wrap(prefixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN) + val nPref = pb.int; val dimPref = pb.int + if (nPref == 9 && dimPref == TALKER_DIM) { + damienVoicePrefix = Array(9) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = pb.float } } + nlog("Damien voice prefix: $nPref × $dimPref") + } else { + nlog("damien_voice_prefix.bin header mismatch ($nPref × $dimPref) — disabling on-device text") + } + + // Voice SUFFIX — the 2 fixed positions that close out + // the prefill after text tokens. Empirically invariant + // across segments of the same speaker (diff = 0.0). + val suffixFile = File("$path/damien_voice_suffix.bin") + if (suffixFile.exists()) { + val sb = ByteBuffer.wrap(suffixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN) + val nSuf = sb.int; val dimSuf = sb.int + if (nSuf == 2 && dimSuf == TALKER_DIM) { + damienVoiceSuffix = Array(2) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = sb.float } } + nlog("Damien voice suffix: $nSuf × $dimSuf") + } + } else { + nlog("damien_voice_suffix.bin missing — on-device text will lack closure markers") + } + + if (textEmbedsFullBuf != null && damienVoicePrefix != null && damienVoiceSuffix != null) { + bpeTokenizer = Qwen3BpeTokenizer.load(tokDir.absolutePath) + nlog("Stage 2 on-device text path ready (BPE + full embeds + voice prefix)") + } + } else { + nlog("Stage 2 assets not fully present (full=$fullEmbFile, prefix=$prefixFile, tok=$tokDir) — legacy path only") + } + } catch (e: Exception) { + nlog("Stage 2 asset load failed: ${e.message} — legacy path only") + textEmbedsFullBuf = null; damienVoicePrefix = null; bpeTokenizer = null + try { textEmbedsFullChan?.close() } catch (_: Exception) {} + textEmbedsFullChan = null + } + codecEmbedding = loadNpy("$path/codec_embedding.npy") nlog("Codec embedding: ${codecEmbedding!!.size / TALKER_DIM} × $TALKER_DIM") val ttsSpecial = loadNpy("$path/tts_special_embeds.npy") // [3, 1024] = bos, eos, pad @@ -3183,6 +3262,269 @@ class Qwen3TtsEngine( return concat } + /** + * Look up a single token in the fp16 full-vocab text-embedding table and + * return it as fp32. Uses direct ByteBuffer arithmetic so we don't + * allocate a new buffer per token — for a typical 50-token sentence the + * inner loop runs 50 × 1024 fp16→fp32 conversions. + */ + private fun textEmbFromFull(tokenId: Int): FloatArray { + val buf = textEmbedsFullBuf ?: error("Stage 2 full embeddings not loaded") + val clamped = tokenId.coerceIn(0, textEmbedsFullLen - 1) + val base = clamped * TALKER_DIM * 2 + val out = FloatArray(TALKER_DIM) + synchronized(buf) { + // MappedByteBuffer has mutable position; guard in case two + // coroutines ever race on tokenizer output concurrently. + buf.position(base) + for (j in 0 until TALKER_DIM) { + val bits = buf.short + out[j] = halfToFloat(bits) + } + } + return out + } + + /** IEEE 754 fp16 -> fp32 conversion. Handles subnormals, inf and NaN + * exactly the way Python's `torch.float16.to(float32)` does. */ + private fun halfToFloat(h: Short): Float { + val bits = h.toInt() and 0xffff + val sign = (bits ushr 15) and 0x1 + val exp = (bits ushr 10) and 0x1f + val mant = bits and 0x3ff + val f32: Int = when { + exp == 0 && mant == 0 -> sign shl 31 + exp == 0 -> { + // Subnormal: normalize by shifting until leading 1 appears. + var e = -14; var m = mant + while ((m and 0x400) == 0) { m = m shl 1; e -= 1 } + val e32 = e + 127 + (sign shl 31) or (e32 shl 23) or ((m and 0x3ff) shl 13) + } + exp == 0x1f -> (sign shl 31) or (0xff shl 23) or (mant shl 13) // inf / NaN + else -> (sign shl 31) or ((exp - 15 + 127) shl 23) or (mant shl 13) + } + return Float.fromBits(f32) + } + + /** + * Run the Hexagon talker + CP generation loop with a fully pre-built + * prefill (voice prefix + all text tokens). Same decode recipe as + * runInterleavedHexagon's inner loop: at each step the next talker + * input is codecSum (codec embedding of previous codes) + tts_pad + * (since all text has already been consumed in the prefill). On EOS, + * terminate early. Returns the generated [step, codebook] codes. + */ + private fun runHexGenWithPrefill(prefill: List, maxGen: Int): Array { + val padE = ttsPadEmbed ?: return emptyArray() + val eosE = ttsEosEmbed ?: return emptyArray() + val allCodes = mutableListOf() + val generatedCb0 = mutableListOf() + var totalTalkerMs = 0L; var totalCpMs = 0L + + val tPrefill = System.currentTimeMillis() + val prefillResults = hexForward(prefill) + nlog("VC prefill (Hex): ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps") + if (prefillResults.isEmpty()) return emptyArray() + + var pastHidden = prefillResults.last().first + val prefillLogits = prefillResults.last().second + for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY } + var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50) + nlog("VC prefill done: first cb0=$currentCb0") + + // After the text has been fully consumed in prefill, Python's voice- + // clone loop feeds tts_eos once, then tts_pad for every subsequent + // decode step. We follow the same schedule so the model's attention + // sees the same "text exhausted" signal it was trained with. + var eosFedOnce = false + for (genStep in 0 until maxGen) { + val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0 + val tCp = System.currentTimeMillis() + val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0) + for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1] + allCodes.add(codes); generatedCb0.add(currentCb0) + totalCpMs += System.currentTimeMillis() - tCp + + val codecSum = FloatArray(TALKER_DIM) + addEmb(codecSum, codecEmb(codes[0])) + for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb])) + val nextEmbed = if (!eosFedOnce) { eosFedOnce = true; sumEmb(codecSum, eosE) } else sumEmb(codecSum, padE) + + val tT = System.currentTimeMillis() + val results = hexForward(listOf(nextEmbed)) + totalTalkerMs += System.currentTimeMillis() - tT + if (results.isEmpty()) { nlog("VC: hex empty at step ${genStep+1}"); break } + pastHidden = results[0].first + val logits = results[0].second + for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY } + val seen = HashSet(); for (prev in generatedCb0) seen.add(prev) + for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f } + currentCb0 = sampleTopK(logits, 0.9f, 50) + if (currentCb0 == CODEC_EOS) { nlog("VC EOS at step ${genStep+1}"); break } + + // Degeneracy guard: when the talker fails to emit EOS it falls + // into a stuck loop where cb0 repeats (the "page beg beg beg" + // artifact audible at the tail of generated phrases). Nine + // consecutive identical cb0s is the threshold the native .pte + // pipeline uses too. The short history is just the last 9 + // entries of generatedCb0 — cheap to scan. + val nHist = generatedCb0.size + if (nHist >= 9) { + val last = generatedCb0[nHist - 1] + var allSame = true + for (i in nHist - 9 until nHist) if (generatedCb0[i] != last) { allSame = false; break } + if (allSame && currentCb0 == last) { + nlog("VC degen: cb0=$last repeated ≥9× at step ${genStep+1}, stopping") + break + } + } + } + nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms") + return allCodes.toTypedArray() + } + + /** + * Split text into short segments for the streaming pipeline. Reproduces + * the behaviour of scripts/prepare_tts_segments.py but on-device so the + * tablet doesn't depend on a PC-side preprocessor. The 120-character + * target matches the length where the talker still terminates reliably + * on EOS; longer segments risk the auto-repressor's repetition penalty + * cutting decode short of the full phrase. + */ + private fun splitSentences(text: String, maxChars: Int = 120): List { + val first = text.trim().split(Regex("(?<=[.!?;:])\\s+")) + val out = mutableListOf() + for (part in first) { + if (part.length <= maxChars) { + if (part.isNotBlank()) out.add(part.trim()) + continue + } + // Break overlong sentences at commas, greedily packing sub-parts + // back together up to maxChars so we don't over-split. + val subs = part.split(Regex("(?<=,)\\s+")) + var current = "" + for (s in subs) { + if (current.isNotEmpty() && current.length + s.length > maxChars) { + out.add(current.trim()); current = s + } else { + current = if (current.isEmpty()) s else "$current $s" + } + } + if (current.isNotBlank()) out.add(current.trim()) + } + return if (out.isEmpty()) listOf(text) else out + } + + /** + * Tokenize `text` on-device and stream the synthesized audio segment by + * segment via `onSegmentReady`. The PC-side prep script becomes optional — + * this is the first path in the app where TTS runs fully offline on + * arbitrary LLM output. + * + * For each segment: + * 1. BPE tokenize the segment text (same algorithm as Python's + * Qwen2Tokenizer). + * 2. Look up each token ID in the fp16 full-vocab table, converted to + * fp32 on the fly. One embedding per token. + * 3. Call the existing runInterleavedHexagon loop — which already + * synthesizes its own decode inputs via codec_sum + trailing text — + * so we reuse the same prefill construction and generation path that + * runs today for the pre-computed-embeds test harness. + * 4. Decode codes → audio via decodeChunked, emit the audio through + * the callback immediately, save WAV per segment plus the concat. + * + * Emits at most one callback per segment. First-audio latency ≈ prefill + + * one segment's decode (typically ~17-22s for a 5-6s-duration segment on + * Snapdragon 8 Elite). The ordering and gaps between segments are the + * same as generateFromEmbedsHexagonStreaming. + */ + fun synthesizeTextStreaming( + text: String, + onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null + ): ShortArray { + if (!loaded || !useHexagonTalker) { + nlog("synthesizeTextStreaming: Hexagon talker not ready"); return ShortArray(0) + } + if (bpeTokenizer == null || textEmbedsFullBuf == null) { + nlog("synthesizeTextStreaming: Stage 2 assets missing"); return ShortArray(0) + } + val segments = splitSentences(text) + nlog("synthesizeTextStreaming: ${segments.size} segment(s) for ${text.length} chars") + + hexReset() + val segmentAudios = mutableListOf() + val gapSamples = SR * 120 / 1000 + val gap = ShortArray(gapSamples) + val t0 = System.currentTimeMillis() + + val prefix = damienVoicePrefix!! + val suffix = damienVoiceSuffix!! + // CODEC_PAD embedding is the per-token companion that Python voice- + // cloning sums into every text-encoded prefill position. Computed + // once here so the per-token loop stays a simple vector add. + val codecPadEmb = codecEmb(CODEC_PAD) + + for ((segIdx, segText) in segments.withIndex()) { + if (segIdx > 0) hexReset() + val tSeg = System.currentTimeMillis() + val ids = bpeTokenizer!!.encode(segText) + nlog("Seg ${segIdx+1}/${segments.size}: '${segText.take(60)}' → ${ids.size} tokens: ${ids.toList()}") + + // Voice-cloning prefill, fully reconstructed on-device — exact + // structure Python emits via generate_voice_clone (verified + // bit-for-bit by comparing captured Baer segments): + // [0..8] damienVoicePrefix (9 fixed positions, xvector@7) + // [9..N-3] text_projection(BPE_id) + codec_embedding(CODEC_PAD) + // [N-2, N-1] damienVoiceSuffix (2 fixed positions, end-of-text marker) + // The earlier attempt skipped the suffix and used raw text + // projections, which produced garbled audio — the talker needs + // BOTH the per-token codec_pad sum AND the closure markers to + // know that text input has ended and decoding can begin. + val prefill = ArrayList(prefix.size + ids.size + suffix.size) + for (e in prefix) prefill.add(e) + for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb)) + for (e in suffix) prefill.add(e) + + // Empirical budget: Python's voice_clone typically emits ~3.3 + // codec frames per text token for French. Keep a small cushion + // so ~80% of runs terminate via EOS/degeneracy before exhausting + // the budget; trimming is done by the degeneracy guard inside + // runHexGenWithPrefill. Too-generous maxGen guarantees the tail + // artifacts the user hears as "page beg beg beg". + val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15) + val codes = runHexGenWithPrefill(prefill, maxGen) + if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue } + + val n = codes.size + val padLen = maxOf(n, SEQ_LEN) + val codebooks = Array(NUM_CODEBOOKS) { cb -> + IntArray(padLen) { t -> + if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0 + } + } + val audio = decodeChunked(codebooks, n) + val segMs = System.currentTimeMillis() - tSeg + nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms") + + segmentAudios.add(audio) + saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio) + onSegmentReady?.invoke(segIdx, audio) + } + + if (segmentAudios.isEmpty()) return ShortArray(0) + val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples + val concat = ShortArray(total) + var off = 0 + for ((i, s) in segmentAudios.withIndex()) { + System.arraycopy(s, 0, concat, off, s.size); off += s.size + if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples } + } + saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat) + nlog("synthesizeTextStreaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s") + return concat + } + /** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to * save one file per segment plus the concatenated result for inspection. */ private fun saveWav(path: String, audio: ShortArray) { diff --git a/scripts/export_tts_text_embeddings.py b/scripts/export_tts_text_embeddings.py new file mode 100644 index 0000000..e52c7fa --- /dev/null +++ b/scripts/export_tts_text_embeddings.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Export everything the tablet needs to build TTS prefill embeds for arbitrary +LLM text, offline, without talking to a PC. + +Outputs (pushed to /data/local/tmp/kazeia/models/qwen3-tts-npu/): + - text_embeds_full_fp16.bin : 151936 × 1024 fp16 = 311 MB + Pre-projected text embeddings for the full Qwen3 vocab. Per-token + lookup on-device replaces a lookup + FC1 + SiLU + FC2 + bias. Same + numbers PyTorch produces for text_projection(text_embedding(id)). + + - damien_voice_prefix.bin : 9 × 1024 fp32 = 36 KB + The fixed voice-cloning prefix (positions 0..8) for speaker Damien, + captured from a real voice-clone run. Positions 0..6 = role/control + tokens, position 7 = xvector (L2 norm ~10), position 8 = trailing + voice-marker. Same for every phrase uttered by this speaker, so we + capture once here and reuse indefinitely on-device. + + - damien_voice_suffix.bin : 2 × 1024 fp32 = 8 KB + The fixed voice-cloning SUFFIX (last 2 positions of the prefill) + that Python emits AFTER the text tokens. Verified bit-identical + across segments of different texts → invariant closure marker + for the voice-clone conditioning. Without it the talker misreads + the end of the text and produces garbled output. + + - qwen3_tokenizer/ : tokenizer files copied from HF snapshot + tokenizer.json, vocab.json, merges.txt, special_tokens_map.json. + Kotlin BPE implementation reads vocab + merges at init. + +The combination lets the tablet build, for any text, the exact same +prefill tensor PyTorch would build, bit-for-bit at fp16 — which is +what our Hexagon talker consumes anyway. + +Usage: + python3 export_tts_text_embeddings.py [output_dir] +""" +import sys, os, struct, shutil, warnings +os.chdir("/tmp") +warnings.filterwarnings("ignore") + +OUTPUT_DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/kazeia_tts_export" +MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc" +VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav" + +os.makedirs(OUTPUT_DIR, exist_ok=True) +os.makedirs(f"{OUTPUT_DIR}/qwen3_tokenizer", exist_ok=True) + +import torch, numpy as np +from qwen_tts import Qwen3TTSModel + +print("Loading Qwen3-TTS model (~30s, CPU)...") +tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu") +talker = tts.model.talker + +# ---- 1. Full projected text embeddings ---- +# Evaluate text_projection(text_embedding.weight) for EVERY vocab entry. +# Batching keeps peak memory bounded; fp32 matmul then fp16 store preserves +# precision up to the final quantization step. +print("\n[1/3] Precomputing projected embeddings for full vocab...") +vocab_size = talker.model.text_embedding.weight.shape[0] +print(f" Vocab size: {vocab_size}") +BATCH = 4096 +out_path = f"{OUTPUT_DIR}/text_embeds_full_fp16.bin" +with torch.no_grad(): + W_emb = talker.model.text_embedding.weight # [vocab, 2048] + fc1_w = talker.text_projection.linear_fc1.weight # [2048, 2048] + fc1_b = talker.text_projection.linear_fc1.bias # [2048] + fc2_w = talker.text_projection.linear_fc2.weight # [1024, 2048] + fc2_b = talker.text_projection.linear_fc2.bias # [1024] + with open(out_path, "wb") as f: + for start in range(0, vocab_size, BATCH): + end = min(start + BATCH, vocab_size) + x = W_emb[start:end].float() # [b, 2048] + h = torch.nn.functional.linear(x, fc1_w, fc1_b) # [b, 2048] + h = torch.nn.functional.silu(h) # [b, 2048] + y = torch.nn.functional.linear(h, fc2_w, fc2_b) # [b, 1024] + f.write(y.to(torch.float16).numpy().tobytes()) + if start % (BATCH * 4) == 0: + print(f" {end}/{vocab_size} ({end*100//vocab_size}%)", flush=True) +sz_mb = os.path.getsize(out_path) / (1024*1024) +print(f" -> {out_path} ({sz_mb:.1f} MB)") + +# Sanity check: re-read a couple of tokens, project live, compare. +print("\n Sanity check (token 1043 = 'Bonjour'):") +with torch.no_grad(): + live = talker.text_projection(talker.model.text_embedding(torch.tensor([1043])))[0].float().numpy() +with open(out_path, "rb") as f: + f.seek(1043 * 1024 * 2) + stored = np.frombuffer(f.read(1024 * 2), dtype=np.float16).astype(np.float32) +diff = float(np.abs(live - stored).max()) +print(f" max abs diff live vs stored fp16: {diff:.2e} (expect < 1e-3)") + +# ---- 2. Damien voice prefix (positions 0..8) ---- +# Run a voice-clone and capture the multi-token prefill call, then keep the +# first 9 rows. Those are fixed per speaker — same for every phrase — so +# one capture suffices for the app's lifetime. +print(f"\n[2/3] Capturing Damien voice prefix from {VOICE}...") +captured = [] +call_shapes = [] +original_forward = talker.model.forward +def patched(input_ids=None, inputs_embeds=None, **kwargs): + if inputs_embeds is not None and inputs_embeds.dim() == 3: + call_shapes.append(inputs_embeds.shape[1]) + for i in range(inputs_embeds.shape[1]): + captured.append(inputs_embeds[0, i, :].detach().cpu().numpy().astype(np.float32)) + return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs) +talker.model.forward = patched +# Any short sentence works — we only keep positions 0..8 which are text- +# invariant. +_ = tts.generate_voice_clone( + text="Bonjour, je suis Kazeia.", ref_audio=VOICE, language="french", + x_vector_only_mode=True, non_streaming_mode=True, +) +talker.model.forward = original_forward +nP = call_shapes[0] +print(f" Prefill size: {nP} tokens") +prefix_9 = np.stack(captured[:9]) # [9, 1024] +suffix_2 = np.stack(captured[nP-2:nP]) # [2, 1024] + +prefix_path = f"{OUTPUT_DIR}/damien_voice_prefix.bin" +with open(prefix_path, "wb") as f: + f.write(struct.pack(" {prefix_path} ({os.path.getsize(prefix_path)} bytes)") + +suffix_path = f"{OUTPUT_DIR}/damien_voice_suffix.bin" +with open(suffix_path, "wb") as f: + f.write(struct.pack(" {suffix_path} ({os.path.getsize(suffix_path)} bytes)") + +norms_pref = [float(np.linalg.norm(prefix_9[i])) for i in range(9)] +norms_suff = [float(np.linalg.norm(suffix_2[i])) for i in range(2)] +print(f" Prefix norms: {[f'{n:.2f}' for n in norms_pref]} (pos 7 = xvector ~10, others ~1.6-1.8)") +print(f" Suffix norms: {[f'{n:.2f}' for n in norms_suff]}") + +# ---- 3. Tokenizer files ---- +# Copy the HF tokenizer artefacts so a Kotlin BPE can reproduce Python +# encode() bit-for-bit. +print(f"\n[3/3] Copying tokenizer to {OUTPUT_DIR}/qwen3_tokenizer/...") +for name in ("tokenizer.json", "vocab.json", "merges.txt", "tokenizer_config.json", "special_tokens_map.json"): + src = os.path.join(MODEL, name) + if os.path.exists(src): + shutil.copy(src, f"{OUTPUT_DIR}/qwen3_tokenizer/{name}") + print(f" {name} ({os.path.getsize(src)} bytes)") + else: + print(f" (skipped, not present: {name})") + +print(f"\n=== DONE ===") +print(f"Files ready in {OUTPUT_DIR}/") +print(f"\nPush to tablet:") +print(f" adb push {OUTPUT_DIR}/text_embeds_full_fp16.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/") +print(f" adb push {OUTPUT_DIR}/damien_voice_prefix.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/") +print(f" adb push {OUTPUT_DIR}/qwen3_tokenizer /data/local/tmp/kazeia/models/qwen3-tts-npu/")