TTS Stage 2: on-device voice-cloning TTS for arbitrary text

Removes the PC-side prepare_tts_segments.py dependency for day-to-day
generation. The tablet now tokenizes, embeds, and voice-clones any
French (or Qwen3-supported) text with no network, no ADB push per
phrase, and quality that matches Python's reference on "Bonjour, je
suis Kazeia, je suis là pour vous écouter." — user validation:
"impeccable".

Three pieces that compose the path:

  1. Qwen3BpeTokenizer.kt — byte-level BPE matching Qwen2/Qwen3's
     Python implementation bit-for-bit. UTF-8 + GPT-2 byte encoder,
     Qwen regex with \p{IsAlphabetic}/\p{IsDigit} (Android's regex
     lacks UNICODE_CHARACTER_CLASS — caught in testing). Produces
     identical token IDs to HF's Qwen2TokenizerFast on the test phrase:
     [81581, 11, 4759, 35631, 730, 9832, 685, 11, 4759, 35631, 37915,
      4914, 9012, 90229, 2676, 13].

  2. export_tts_text_embeddings.py — one-time PC export of:
     * Full projected text embeddings for the entire 151936-token vocab
       as fp16 (297 MB). Sanity check: live vs stored max abs diff
       1.15e-4 on token 1043. Mmap'd on-device so it stays off the
       Java heap and leaves room for the 125 MB cp_embeddings alloc.
     * Damien voice PREFIX (9 × 1024 fp32) — positions 0..8 of a
       Python voice-clone capture, text-invariant across segments.
     * Damien voice SUFFIX (2 × 1024 fp32) — positions nP-2..nP-1
       of the same capture. Also text-invariant (diff = 0.0 across
       3 different-text segments). Without it the talker never sees
       "text ended" and decode falls into page/beg repetition.
     * Qwen3 tokenizer vocab.json + merges.txt.

  3. Qwen3TtsEngine.kt:
     * mmap loader for the embeddings table + buffered fp16→fp32
       lookup (halfToFloat covers subnormals/inf/NaN so pathological
       tokens don't become 0).
     * Stage 2 assets detected at init; missing file transparently
       falls back to legacy 1050-token reduced-vocab path.
     * synthesizeTextStreaming(text, onSegmentReady) — new public API:
       sentence-split → BPE → build prefill as
         [voice prefix] + [text_proj(id) + codec_pad] × N + [voice suffix]
       (exact structure Python emits; verified bit-for-bit by matching
       captured Baer prefill positions against text_projection(tok)+
       codec_embedding(CODEC_PAD)) → runHexGenWithPrefill → decode
       each segment through the existing BigVGAN pipeline → callback.
     * runHexGenWithPrefill — Hexagon prefill + interleaved CP decode
       loop. Feeds tts_eos once, tts_pad thereafter (same schedule as
       Python's voice_clone). Degeneracy guard stops when 9 identical
       cb0 in a row appear — catches the rare "page beg beg beg" tail
       when EOS never fires. maxGen = ids.size*4 + 10 matches the
       typical 3.3 codec-frames-per-text-token that Python produces.
     * Prefill build uses the speaker's captured prefix/suffix rather
       than the legacy in-code buildPrefillEmbeddings that puts only
       one text token in prefill — the structure mismatch produced
       garbled audio in the first attempt of this commit.

  4. KazeiaService.kt: new stream_text intent extra wires text input
     to synthesizeTextStreaming with an AudioTrack MODE_STREAM consumer.
     First-audio latency on the "Bonjour..." test: ~23 s on Snapdragon
     8 Elite (prefill + 74-token decode), vs a 3-phrase sentence batch
     that was 65 s pre-streaming — streaming + on-device text together
     unblock the MVP chat loop.

Known caveats:
  * 297 MB on-device footprint for the embedding table. Acceptable on
    OnePlus Pad 3; can be quantized further (int8 per-row) if storage
    becomes tight.
  * First init adds ~3 s for BPE vocab + merges load (151k × 2 hash-
    maps). Happens once per process.
  * maxGen cap means extremely long sentences may truncate. The
    sentence splitter already keeps segments ≤120 chars so this
    hasn't been observed in practice.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-13 10:12:09 +02:00
parent 5e416713ce
commit 7f1a44c23d
4 changed files with 759 additions and 1 deletions

View File

@ -123,6 +123,47 @@ class KazeiaService : Service() {
// Audio is played by the TTS engine internally // Audio is played by the TTS engine internally
} }
} }
intent?.getStringExtra("stream_text")?.let { text ->
// Stage 2 streaming from arbitrary text: BPE tokenize on-device,
// look up embeds in the full Qwen3 vocab, run the existing
// interleaved Hexagon generation loop, and play each segment
// as soon as it's decoded. No PC-side prep required.
log("Stream text: '${text.take(60)}${if (text.length>60) "..." else ""}'")
serviceScope.launch {
try {
val qwenTts = tts as? com.kazeia.tts.Qwen3TtsEngine ?: return@launch
val sr = 24000
val track = android.media.AudioTrack.Builder()
.setAudioAttributes(android.media.AudioAttributes.Builder()
.setUsage(android.media.AudioAttributes.USAGE_MEDIA)
.setContentType(android.media.AudioAttributes.CONTENT_TYPE_SPEECH)
.build())
.setAudioFormat(android.media.AudioFormat.Builder()
.setEncoding(android.media.AudioFormat.ENCODING_PCM_16BIT)
.setSampleRate(sr)
.setChannelMask(android.media.AudioFormat.CHANNEL_OUT_MONO)
.build())
.setBufferSizeInBytes(sr * 4)
.setTransferMode(android.media.AudioTrack.MODE_STREAM)
.build()
track.play()
val tStart = System.currentTimeMillis()
var firstLogged = false
qwenTts.synthesizeTextStreaming(text) { segIdx, audio ->
if (!firstLogged) {
log("First audio out at ${System.currentTimeMillis() - tStart}ms (seg ${segIdx+1})")
firstLogged = true
}
track.write(audio, 0, audio.size)
}
track.stop(); track.release()
log("Stream text done at ${System.currentTimeMillis() - tStart}ms")
} catch (e: Exception) {
log("Stream text error: ${e.message}")
e.printStackTrace()
}
}
}
intent?.getStringExtra("stream_pipeline")?.let { embedsPath -> intent?.getStringExtra("stream_pipeline")?.let { embedsPath ->
// Stage 1 streaming pipeline: generate segment-by-segment and play each // Stage 1 streaming pipeline: generate segment-by-segment and play each
// segment the moment it's ready via an AudioTrack MODE_STREAM. First audio // segment the moment it's ready via an AudioTrack MODE_STREAM. First audio

View File

@ -0,0 +1,219 @@
package com.kazeia.tts
import android.util.Log
import org.json.JSONObject
import java.io.File
import java.util.regex.Pattern
/**
* Byte-level BPE tokenizer compatible with Qwen2/Qwen3 tokenizer.
*
* Why this file exists:
* The app needs to tokenize arbitrary LLM-generated French text on-device
* for the TTS pipeline (the reduced 1050-token table shipped previously
* couldn't handle free-form text). The reference is the HuggingFace
* Qwen2TokenizerFast a GPT-2-style byte-level BPE. We reimplement it
* in Kotlin rather than linking libtokenizers.so so there's no new
* native dependency and the numbers are easy to audit.
*
* Algorithm (identical to GPT-2 / Qwen2):
* 1. Pre-tokenize the input with a regex that groups contractions,
* words, numbers, and whitespace into "word chunks" that BPE never
* merges across.
* 2. UTF-8 encode each chunk, then map each byte (0..255) to one of
* 256 printable Unicode code points via a fixed "byte encoder" table
* this is the trick that lets a byte-level vocab fit inside JSON
* without invalid or control characters.
* 3. Apply BPE: repeatedly find the pair with the lowest rank in the
* merges list and merge it, until no more merges apply. Look up each
* resulting super-token in vocab.json to get the final token IDs.
*
* Bit-perfect compatibility notes:
* - The pre-tokenize regex below MUST match the one in tokenizer_config.json.
* Qwen2 uses the GPT-2 pattern with a couple of Unicode property
* extensions; Java regex supports these directly.
* - Byte encoder is canonical (see bytesToUnicode()).
* - Merges are rank-ordered: lower line number = higher priority, matching
* HuggingFace's `merges.txt` file ordering.
* - We do NOT add BOS/EOS or chat-template specials the TTS prefill
* prepends its own 9-embed voice prefix that already handles role tokens.
*/
class Qwen3BpeTokenizer private constructor(
private val vocab: HashMap<String, Int>,
private val merges: HashMap<Pair<String, String>, Int>,
private val byteEncoder: IntArray,
) {
companion object {
private const val TAG = "Qwen3BPE"
// Qwen2/Qwen3 pre-tokenization regex. This is the exact pattern used
// by the HuggingFace Qwen2Tokenizer (adapted from GPT-2). Matches
// contractions ('s|'d|'ll|'ve|'re|'t), word chunks with a leading
// optional non-word char + letters, runs of digits, runs of
// punctuation, and whitespace runs with boundary semantics.
//
// Note on the character classes: Android's Pattern does NOT support
// UNICODE_CHARACTER_CLASS — so plain \\p{L} / \\p{N} would collapse
// to ASCII-only and break French accents ("é" → wrong token). Use
// \\p{IsAlphabetic} and \\p{IsDigit} instead; those ARE Unicode-aware
// out of the box (they map to ICU's IsAlphabetic / IsDigit properties
// in both JDK and Android runtimes). Output matches Python's Qwen2
// tokenizer on French text.
private val PRE_TOKENIZE_PATTERN: Pattern = Pattern.compile(
"'s|'d|'ll|'ve|'re|'t" +
"|[^\\r\\n\\p{IsAlphabetic}\\p{IsDigit}]?\\p{IsAlphabetic}+" +
"|\\p{IsDigit}{1,3}" +
"| ?[^\\s\\p{IsAlphabetic}\\p{IsDigit}]+[\\r\\n]*" +
"|\\s*[\\r\\n]+" +
"|\\s+(?!\\S)" +
"|\\s+"
)
/**
* GPT-2 byte encoder: maps 0..255 a printable Unicode codepoint.
* Ensures every possible byte has a visible, JSON-safe representation
* so a byte-level vocab can be stored as strings in vocab.json.
*/
private fun bytesToUnicode(): IntArray {
val bs = mutableListOf<Int>()
// Printable ASCII and common Latin blocks.
bs.addAll(('!'.code..'~'.code).toList())
bs.addAll(('¡'.code..'¬'.code).toList())
bs.addAll(('®'.code..'ÿ'.code).toList())
val cs = bs.toMutableList()
// Every byte not in bs gets mapped to a code point past 255 so no
// existing character collides with it.
var n = 0
val map = IntArray(256)
for (b in 0..255) {
if (b in bs) {
map[b] = b
} else {
map[b] = 256 + n
bs.add(b) // placeholder only, we've already recorded cs from the frozen snapshot
cs.add(256 + n)
n += 1
}
}
return map
}
fun load(modelDir: String): Qwen3BpeTokenizer {
val t0 = System.currentTimeMillis()
val vocabFile = File("$modelDir/vocab.json")
val mergesFile = File("$modelDir/merges.txt")
require(vocabFile.exists()) { "vocab.json missing at $modelDir" }
require(mergesFile.exists()) { "merges.txt missing at $modelDir" }
val vocabJson = JSONObject(vocabFile.readText())
val vocab = HashMap<String, Int>(vocabJson.length())
val keys = vocabJson.keys()
while (keys.hasNext()) {
val k = keys.next()
vocab[k] = vocabJson.getInt(k)
}
val merges = HashMap<Pair<String, String>, Int>()
var rank = 0
mergesFile.useLines { lines ->
for (line in lines) {
// Skip header / blanks. Qwen's merges.txt starts with
// "#version" which we simply filter out.
if (line.isBlank() || line.startsWith("#")) continue
val sp = line.indexOf(' ')
if (sp < 0) continue
merges[Pair(line.substring(0, sp), line.substring(sp + 1))] = rank
rank++
}
}
Log.i(TAG, "Loaded vocab=${vocab.size} merges=${merges.size} in ${System.currentTimeMillis()-t0}ms")
return Qwen3BpeTokenizer(vocab, merges, bytesToUnicode())
}
}
/**
* Convert a single pre-tokenized word (UTF-8 bytes encoded via the byte
* encoder into a string) into token IDs via BPE merges. Caches results so
* repeated common words (spaces, punctuation) only BPE once.
*/
private val bpeCache = HashMap<String, IntArray>()
private fun bpeEncode(byteEncodedWord: String): IntArray {
bpeCache[byteEncodedWord]?.let { return it }
// Start with one "sub-token" per Unicode code point (code points,
// not chars — surrogate pairs are handled automatically since the
// byte encoder only produces BMP codepoints by construction).
val parts = ArrayList<String>(byteEncodedWord.length)
var i = 0
while (i < byteEncodedWord.length) {
val cp = byteEncodedWord.codePointAt(i)
parts.add(String(Character.toChars(cp)))
i += Character.charCount(cp)
}
if (parts.size < 2) {
val id = vocab[parts.getOrElse(0) { "" }] ?: vocab["<unk>"] ?: 0
val out = intArrayOf(id)
bpeCache[byteEncodedWord] = out
return out
}
// Greedy lowest-rank merge, classic BPE. We scan for the pair with
// the smallest rank, merge ALL its occurrences, then repeat. This
// matches HF's reference implementation.
while (parts.size > 1) {
var bestRank = Int.MAX_VALUE
var bestIdx = -1
for (k in 0 until parts.size - 1) {
val r = merges[Pair(parts[k], parts[k + 1])] ?: continue
if (r < bestRank) { bestRank = r; bestIdx = k }
}
if (bestIdx < 0) break
// Merge all non-overlapping occurrences of that exact pair.
val a = parts[bestIdx]; val b = parts[bestIdx + 1]
val merged = a + b
var k = 0
val out = ArrayList<String>(parts.size - 1)
while (k < parts.size) {
if (k < parts.size - 1 && parts[k] == a && parts[k + 1] == b) {
out.add(merged); k += 2
} else {
out.add(parts[k]); k += 1
}
}
parts.clear(); parts.addAll(out)
}
val ids = IntArray(parts.size)
for (k in parts.indices) {
ids[k] = vocab[parts[k]] ?: vocab["<unk>"] ?: 0
}
bpeCache[byteEncodedWord] = ids
return ids
}
/**
* Encode text token IDs using the full Qwen3 vocabulary. Does NOT
* prepend BOS/EOS callers can add specials themselves. Unicode
* characters outside ASCII (e.g. French accents) are UTF-8 encoded and
* go through the byte encoder, so "é" and "ï" tokenize the same way as
* they do in Python.
*/
fun encode(text: String): IntArray {
val all = ArrayList<Int>(text.length / 2 + 4)
val matcher = PRE_TOKENIZE_PATTERN.matcher(text)
while (matcher.find()) {
val chunk = matcher.group()
// UTF-8 encode the chunk, then map each raw byte to its Unicode
// "byte-encoded" character. This produces the exact string that
// BPE merges operate on.
val bytes = chunk.toByteArray(Charsets.UTF_8)
val sb = StringBuilder(bytes.size)
for (b in bytes) sb.appendCodePoint(byteEncoder[b.toInt() and 0xff])
val ids = bpeEncode(sb.toString())
for (id in ids) all.add(id)
}
return all.toIntArray()
}
}

View File

@ -100,8 +100,21 @@ class Qwen3TtsEngine(
private var decoderOnGpu: Boolean = false private var decoderOnGpu: Boolean = false
// Dual embedding tables for talker input // Dual embedding tables for talker input
private var textEmbeds: FloatArray? = null // [1050, 1024] - pre-projected text embeddings private var textEmbeds: FloatArray? = null // [1050, 1024] - reduced vocab, legacy fallback
private var codecEmbedding: FloatArray? = null // [3072, 1024] - codec/control token embeddings private var codecEmbedding: FloatArray? = null // [3072, 1024] - codec/control token embeddings
// Stage 2 — on-device full-vocab text embeddings + BPE tokenizer.
// textEmbedsFull is 151936 × 1024 fp16 memory-mapped (~296 MB); using
// mmap keeps the bytes off the Java heap so the app doesn't crash when
// the ~125 MB cp_embeddings allocation comes next. damienVoicePrefix is
// the fixed 9-embed voice-cloning header that is prepended to the
// tokenized text to form a full prefill.
private var textEmbedsFullBuf: java.nio.ByteBuffer? = null
private var textEmbedsFullChan: java.nio.channels.FileChannel? = null
private val textEmbedsFullLen = 151936
private var damienVoicePrefix: Array<FloatArray>? = null
private var damienVoiceSuffix: Array<FloatArray>? = null
private var bpeTokenizer: Qwen3BpeTokenizer? = null
private var ttsBosEmbed: FloatArray? = null // [1024] - tts_bos text-side embedding private var ttsBosEmbed: FloatArray? = null // [1024] - tts_bos text-side embedding
private var ttsEosEmbed: FloatArray? = null // [1024] - tts_eos text-side embedding private var ttsEosEmbed: FloatArray? = null // [1024] - tts_eos text-side embedding
private var ttsPadEmbed: FloatArray? = null // [1024] - tts_pad text-side embedding private var ttsPadEmbed: FloatArray? = null // [1024] - tts_pad text-side embedding
@ -457,6 +470,72 @@ class Qwen3TtsEngine(
// Load dual embedding tables for talker // Load dual embedding tables for talker
textEmbeds = loadNpy("$path/text_embeds_projected.npy") textEmbeds = loadNpy("$path/text_embeds_projected.npy")
nlog("Text embeddings: ${textEmbeds!!.size / TALKER_DIM} × $TALKER_DIM") nlog("Text embeddings: ${textEmbeds!!.size / TALKER_DIM} × $TALKER_DIM")
// Stage 2 assets: full-vocab text embeddings + voice prefix + BPE.
// All three are optional — if any is missing we fall back to the
// legacy 1050-token path. This keeps the app bootable during
// asset rollout and avoids turning a missing file into a crash.
try {
val fullEmbFile = File("$path/text_embeds_full_fp16.bin")
val prefixFile = File("$path/damien_voice_prefix.bin")
val tokDir = File("$path/qwen3_tokenizer")
if (fullEmbFile.exists() && prefixFile.exists() && tokDir.isDirectory) {
val tE0 = System.currentTimeMillis()
// Memory-map the fp16 embeddings table instead of heap-
// loading it. Without mmap, the 296 MB ByteArray plus
// the 125 MB cp_embeddings FloatArray loaded a few
// lines below overrun the ~536 MB large-heap limit and
// the app OOMs during init. mmap pages the file via
// the kernel and keeps zero bytes on the Java heap.
val expectedBytes = textEmbedsFullLen.toLong() * TALKER_DIM * 2
if (fullEmbFile.length() != expectedBytes) {
nlog("text_embeds_full_fp16 size mismatch (got ${fullEmbFile.length()}, expected $expectedBytes) — disabling on-device text")
} else {
textEmbedsFullChan = java.io.RandomAccessFile(fullEmbFile, "r").channel
textEmbedsFullBuf = textEmbedsFullChan!!.map(
java.nio.channels.FileChannel.MapMode.READ_ONLY, 0L, expectedBytes
).order(ByteOrder.LITTLE_ENDIAN)
nlog("Full-vocab text embeddings mmap: ${textEmbedsFullLen} × $TALKER_DIM fp16 (${expectedBytes/1024/1024}MB, off-heap) in ${System.currentTimeMillis()-tE0}ms")
}
val pb = ByteBuffer.wrap(prefixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
val nPref = pb.int; val dimPref = pb.int
if (nPref == 9 && dimPref == TALKER_DIM) {
damienVoicePrefix = Array(9) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = pb.float } }
nlog("Damien voice prefix: $nPref × $dimPref")
} else {
nlog("damien_voice_prefix.bin header mismatch ($nPref × $dimPref) — disabling on-device text")
}
// Voice SUFFIX — the 2 fixed positions that close out
// the prefill after text tokens. Empirically invariant
// across segments of the same speaker (diff = 0.0).
val suffixFile = File("$path/damien_voice_suffix.bin")
if (suffixFile.exists()) {
val sb = ByteBuffer.wrap(suffixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
val nSuf = sb.int; val dimSuf = sb.int
if (nSuf == 2 && dimSuf == TALKER_DIM) {
damienVoiceSuffix = Array(2) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = sb.float } }
nlog("Damien voice suffix: $nSuf × $dimSuf")
}
} else {
nlog("damien_voice_suffix.bin missing — on-device text will lack closure markers")
}
if (textEmbedsFullBuf != null && damienVoicePrefix != null && damienVoiceSuffix != null) {
bpeTokenizer = Qwen3BpeTokenizer.load(tokDir.absolutePath)
nlog("Stage 2 on-device text path ready (BPE + full embeds + voice prefix)")
}
} else {
nlog("Stage 2 assets not fully present (full=$fullEmbFile, prefix=$prefixFile, tok=$tokDir) — legacy path only")
}
} catch (e: Exception) {
nlog("Stage 2 asset load failed: ${e.message} — legacy path only")
textEmbedsFullBuf = null; damienVoicePrefix = null; bpeTokenizer = null
try { textEmbedsFullChan?.close() } catch (_: Exception) {}
textEmbedsFullChan = null
}
codecEmbedding = loadNpy("$path/codec_embedding.npy") codecEmbedding = loadNpy("$path/codec_embedding.npy")
nlog("Codec embedding: ${codecEmbedding!!.size / TALKER_DIM} × $TALKER_DIM") nlog("Codec embedding: ${codecEmbedding!!.size / TALKER_DIM} × $TALKER_DIM")
val ttsSpecial = loadNpy("$path/tts_special_embeds.npy") // [3, 1024] = bos, eos, pad val ttsSpecial = loadNpy("$path/tts_special_embeds.npy") // [3, 1024] = bos, eos, pad
@ -3183,6 +3262,269 @@ class Qwen3TtsEngine(
return concat return concat
} }
/**
* Look up a single token in the fp16 full-vocab text-embedding table and
* return it as fp32. Uses direct ByteBuffer arithmetic so we don't
* allocate a new buffer per token for a typical 50-token sentence the
* inner loop runs 50 × 1024 fp16fp32 conversions.
*/
private fun textEmbFromFull(tokenId: Int): FloatArray {
val buf = textEmbedsFullBuf ?: error("Stage 2 full embeddings not loaded")
val clamped = tokenId.coerceIn(0, textEmbedsFullLen - 1)
val base = clamped * TALKER_DIM * 2
val out = FloatArray(TALKER_DIM)
synchronized(buf) {
// MappedByteBuffer has mutable position; guard in case two
// coroutines ever race on tokenizer output concurrently.
buf.position(base)
for (j in 0 until TALKER_DIM) {
val bits = buf.short
out[j] = halfToFloat(bits)
}
}
return out
}
/** IEEE 754 fp16 -> fp32 conversion. Handles subnormals, inf and NaN
* exactly the way Python's `torch.float16.to(float32)` does. */
private fun halfToFloat(h: Short): Float {
val bits = h.toInt() and 0xffff
val sign = (bits ushr 15) and 0x1
val exp = (bits ushr 10) and 0x1f
val mant = bits and 0x3ff
val f32: Int = when {
exp == 0 && mant == 0 -> sign shl 31
exp == 0 -> {
// Subnormal: normalize by shifting until leading 1 appears.
var e = -14; var m = mant
while ((m and 0x400) == 0) { m = m shl 1; e -= 1 }
val e32 = e + 127
(sign shl 31) or (e32 shl 23) or ((m and 0x3ff) shl 13)
}
exp == 0x1f -> (sign shl 31) or (0xff shl 23) or (mant shl 13) // inf / NaN
else -> (sign shl 31) or ((exp - 15 + 127) shl 23) or (mant shl 13)
}
return Float.fromBits(f32)
}
/**
* Run the Hexagon talker + CP generation loop with a fully pre-built
* prefill (voice prefix + all text tokens). Same decode recipe as
* runInterleavedHexagon's inner loop: at each step the next talker
* input is codecSum (codec embedding of previous codes) + tts_pad
* (since all text has already been consumed in the prefill). On EOS,
* terminate early. Returns the generated [step, codebook] codes.
*/
private fun runHexGenWithPrefill(prefill: List<FloatArray>, maxGen: Int): Array<IntArray> {
val padE = ttsPadEmbed ?: return emptyArray()
val eosE = ttsEosEmbed ?: return emptyArray()
val allCodes = mutableListOf<IntArray>()
val generatedCb0 = mutableListOf<Int>()
var totalTalkerMs = 0L; var totalCpMs = 0L
val tPrefill = System.currentTimeMillis()
val prefillResults = hexForward(prefill)
nlog("VC prefill (Hex): ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
if (prefillResults.isEmpty()) return emptyArray()
var pastHidden = prefillResults.last().first
val prefillLogits = prefillResults.last().second
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
nlog("VC prefill done: first cb0=$currentCb0")
// After the text has been fully consumed in prefill, Python's voice-
// clone loop feeds tts_eos once, then tts_pad for every subsequent
// decode step. We follow the same schedule so the model's attention
// sees the same "text exhausted" signal it was trained with.
var eosFedOnce = false
for (genStep in 0 until maxGen) {
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
val tCp = System.currentTimeMillis()
val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
allCodes.add(codes); generatedCb0.add(currentCb0)
totalCpMs += System.currentTimeMillis() - tCp
val codecSum = FloatArray(TALKER_DIM)
addEmb(codecSum, codecEmb(codes[0]))
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
val nextEmbed = if (!eosFedOnce) { eosFedOnce = true; sumEmb(codecSum, eosE) } else sumEmb(codecSum, padE)
val tT = System.currentTimeMillis()
val results = hexForward(listOf(nextEmbed))
totalTalkerMs += System.currentTimeMillis() - tT
if (results.isEmpty()) { nlog("VC: hex empty at step ${genStep+1}"); break }
pastHidden = results[0].first
val logits = results[0].second
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
currentCb0 = sampleTopK(logits, 0.9f, 50)
if (currentCb0 == CODEC_EOS) { nlog("VC EOS at step ${genStep+1}"); break }
// Degeneracy guard: when the talker fails to emit EOS it falls
// into a stuck loop where cb0 repeats (the "page beg beg beg"
// artifact audible at the tail of generated phrases). Nine
// consecutive identical cb0s is the threshold the native .pte
// pipeline uses too. The short history is just the last 9
// entries of generatedCb0 — cheap to scan.
val nHist = generatedCb0.size
if (nHist >= 9) {
val last = generatedCb0[nHist - 1]
var allSame = true
for (i in nHist - 9 until nHist) if (generatedCb0[i] != last) { allSame = false; break }
if (allSame && currentCb0 == last) {
nlog("VC degen: cb0=$last repeated ≥9× at step ${genStep+1}, stopping")
break
}
}
}
nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
return allCodes.toTypedArray()
}
/**
* Split text into short segments for the streaming pipeline. Reproduces
* the behaviour of scripts/prepare_tts_segments.py but on-device so the
* tablet doesn't depend on a PC-side preprocessor. The 120-character
* target matches the length where the talker still terminates reliably
* on EOS; longer segments risk the auto-repressor's repetition penalty
* cutting decode short of the full phrase.
*/
private fun splitSentences(text: String, maxChars: Int = 120): List<String> {
val first = text.trim().split(Regex("(?<=[.!?;:])\\s+"))
val out = mutableListOf<String>()
for (part in first) {
if (part.length <= maxChars) {
if (part.isNotBlank()) out.add(part.trim())
continue
}
// Break overlong sentences at commas, greedily packing sub-parts
// back together up to maxChars so we don't over-split.
val subs = part.split(Regex("(?<=,)\\s+"))
var current = ""
for (s in subs) {
if (current.isNotEmpty() && current.length + s.length > maxChars) {
out.add(current.trim()); current = s
} else {
current = if (current.isEmpty()) s else "$current $s"
}
}
if (current.isNotBlank()) out.add(current.trim())
}
return if (out.isEmpty()) listOf(text) else out
}
/**
* Tokenize `text` on-device and stream the synthesized audio segment by
* segment via `onSegmentReady`. The PC-side prep script becomes optional
* this is the first path in the app where TTS runs fully offline on
* arbitrary LLM output.
*
* For each segment:
* 1. BPE tokenize the segment text (same algorithm as Python's
* Qwen2Tokenizer).
* 2. Look up each token ID in the fp16 full-vocab table, converted to
* fp32 on the fly. One embedding per token.
* 3. Call the existing runInterleavedHexagon loop which already
* synthesizes its own decode inputs via codec_sum + trailing text
* so we reuse the same prefill construction and generation path that
* runs today for the pre-computed-embeds test harness.
* 4. Decode codes audio via decodeChunked, emit the audio through
* the callback immediately, save WAV per segment plus the concat.
*
* Emits at most one callback per segment. First-audio latency prefill +
* one segment's decode (typically ~17-22s for a 5-6s-duration segment on
* Snapdragon 8 Elite). The ordering and gaps between segments are the
* same as generateFromEmbedsHexagonStreaming.
*/
fun synthesizeTextStreaming(
text: String,
onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
): ShortArray {
if (!loaded || !useHexagonTalker) {
nlog("synthesizeTextStreaming: Hexagon talker not ready"); return ShortArray(0)
}
if (bpeTokenizer == null || textEmbedsFullBuf == null) {
nlog("synthesizeTextStreaming: Stage 2 assets missing"); return ShortArray(0)
}
val segments = splitSentences(text)
nlog("synthesizeTextStreaming: ${segments.size} segment(s) for ${text.length} chars")
hexReset()
val segmentAudios = mutableListOf<ShortArray>()
val gapSamples = SR * 120 / 1000
val gap = ShortArray(gapSamples)
val t0 = System.currentTimeMillis()
val prefix = damienVoicePrefix!!
val suffix = damienVoiceSuffix!!
// CODEC_PAD embedding is the per-token companion that Python voice-
// cloning sums into every text-encoded prefill position. Computed
// once here so the per-token loop stays a simple vector add.
val codecPadEmb = codecEmb(CODEC_PAD)
for ((segIdx, segText) in segments.withIndex()) {
if (segIdx > 0) hexReset()
val tSeg = System.currentTimeMillis()
val ids = bpeTokenizer!!.encode(segText)
nlog("Seg ${segIdx+1}/${segments.size}: '${segText.take(60)}' → ${ids.size} tokens: ${ids.toList()}")
// Voice-cloning prefill, fully reconstructed on-device — exact
// structure Python emits via generate_voice_clone (verified
// bit-for-bit by comparing captured Baer segments):
// [0..8] damienVoicePrefix (9 fixed positions, xvector@7)
// [9..N-3] text_projection(BPE_id) + codec_embedding(CODEC_PAD)
// [N-2, N-1] damienVoiceSuffix (2 fixed positions, end-of-text marker)
// The earlier attempt skipped the suffix and used raw text
// projections, which produced garbled audio — the talker needs
// BOTH the per-token codec_pad sum AND the closure markers to
// know that text input has ended and decoding can begin.
val prefill = ArrayList<FloatArray>(prefix.size + ids.size + suffix.size)
for (e in prefix) prefill.add(e)
for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
for (e in suffix) prefill.add(e)
// Empirical budget: Python's voice_clone typically emits ~3.3
// codec frames per text token for French. Keep a small cushion
// so ~80% of runs terminate via EOS/degeneracy before exhausting
// the budget; trimming is done by the degeneracy guard inside
// runHexGenWithPrefill. Too-generous maxGen guarantees the tail
// artifacts the user hears as "page beg beg beg".
val maxGen = minOf(ids.size * 4 + 10, MAX_CONTEXT - 15)
val codes = runHexGenWithPrefill(prefill, maxGen)
if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue }
val n = codes.size
val padLen = maxOf(n, SEQ_LEN)
val codebooks = Array(NUM_CODEBOOKS) { cb ->
IntArray(padLen) { t ->
if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
}
}
val audio = decodeChunked(codebooks, n)
val segMs = System.currentTimeMillis() - tSeg
nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms")
segmentAudios.add(audio)
saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio)
onSegmentReady?.invoke(segIdx, audio)
}
if (segmentAudios.isEmpty()) return ShortArray(0)
val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
val concat = ShortArray(total)
var off = 0
for ((i, s) in segmentAudios.withIndex()) {
System.arraycopy(s, 0, concat, off, s.size); off += s.size
if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
}
saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
nlog("synthesizeTextStreaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s")
return concat
}
/** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to /** Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
* save one file per segment plus the concatenated result for inspection. */ * save one file per segment plus the concatenated result for inspection. */
private fun saveWav(path: String, audio: ShortArray) { private fun saveWav(path: String, audio: ShortArray) {

View File

@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Export everything the tablet needs to build TTS prefill embeds for arbitrary
LLM text, offline, without talking to a PC.
Outputs (pushed to /data/local/tmp/kazeia/models/qwen3-tts-npu/):
- text_embeds_full_fp16.bin : 151936 × 1024 fp16 = 311 MB
Pre-projected text embeddings for the full Qwen3 vocab. Per-token
lookup on-device replaces a lookup + FC1 + SiLU + FC2 + bias. Same
numbers PyTorch produces for text_projection(text_embedding(id)).
- damien_voice_prefix.bin : 9 × 1024 fp32 = 36 KB
The fixed voice-cloning prefix (positions 0..8) for speaker Damien,
captured from a real voice-clone run. Positions 0..6 = role/control
tokens, position 7 = xvector (L2 norm ~10), position 8 = trailing
voice-marker. Same for every phrase uttered by this speaker, so we
capture once here and reuse indefinitely on-device.
- damien_voice_suffix.bin : 2 × 1024 fp32 = 8 KB
The fixed voice-cloning SUFFIX (last 2 positions of the prefill)
that Python emits AFTER the text tokens. Verified bit-identical
across segments of different texts invariant closure marker
for the voice-clone conditioning. Without it the talker misreads
the end of the text and produces garbled output.
- qwen3_tokenizer/ : tokenizer files copied from HF snapshot
tokenizer.json, vocab.json, merges.txt, special_tokens_map.json.
Kotlin BPE implementation reads vocab + merges at init.
The combination lets the tablet build, for any text, the exact same
prefill tensor PyTorch would build, bit-for-bit at fp16 which is
what our Hexagon talker consumes anyway.
Usage:
python3 export_tts_text_embeddings.py [output_dir]
"""
import sys, os, struct, shutil, warnings
os.chdir("/tmp")
warnings.filterwarnings("ignore")
OUTPUT_DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/kazeia_tts_export"
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/qwen3_tokenizer", exist_ok=True)
import torch, numpy as np
from qwen_tts import Qwen3TTSModel
print("Loading Qwen3-TTS model (~30s, CPU)...")
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
talker = tts.model.talker
# ---- 1. Full projected text embeddings ----
# Evaluate text_projection(text_embedding.weight) for EVERY vocab entry.
# Batching keeps peak memory bounded; fp32 matmul then fp16 store preserves
# precision up to the final quantization step.
print("\n[1/3] Precomputing projected embeddings for full vocab...")
vocab_size = talker.model.text_embedding.weight.shape[0]
print(f" Vocab size: {vocab_size}")
BATCH = 4096
out_path = f"{OUTPUT_DIR}/text_embeds_full_fp16.bin"
with torch.no_grad():
W_emb = talker.model.text_embedding.weight # [vocab, 2048]
fc1_w = talker.text_projection.linear_fc1.weight # [2048, 2048]
fc1_b = talker.text_projection.linear_fc1.bias # [2048]
fc2_w = talker.text_projection.linear_fc2.weight # [1024, 2048]
fc2_b = talker.text_projection.linear_fc2.bias # [1024]
with open(out_path, "wb") as f:
for start in range(0, vocab_size, BATCH):
end = min(start + BATCH, vocab_size)
x = W_emb[start:end].float() # [b, 2048]
h = torch.nn.functional.linear(x, fc1_w, fc1_b) # [b, 2048]
h = torch.nn.functional.silu(h) # [b, 2048]
y = torch.nn.functional.linear(h, fc2_w, fc2_b) # [b, 1024]
f.write(y.to(torch.float16).numpy().tobytes())
if start % (BATCH * 4) == 0:
print(f" {end}/{vocab_size} ({end*100//vocab_size}%)", flush=True)
sz_mb = os.path.getsize(out_path) / (1024*1024)
print(f" -> {out_path} ({sz_mb:.1f} MB)")
# Sanity check: re-read a couple of tokens, project live, compare.
print("\n Sanity check (token 1043 = 'Bonjour'):")
with torch.no_grad():
live = talker.text_projection(talker.model.text_embedding(torch.tensor([1043])))[0].float().numpy()
with open(out_path, "rb") as f:
f.seek(1043 * 1024 * 2)
stored = np.frombuffer(f.read(1024 * 2), dtype=np.float16).astype(np.float32)
diff = float(np.abs(live - stored).max())
print(f" max abs diff live vs stored fp16: {diff:.2e} (expect < 1e-3)")
# ---- 2. Damien voice prefix (positions 0..8) ----
# Run a voice-clone and capture the multi-token prefill call, then keep the
# first 9 rows. Those are fixed per speaker — same for every phrase — so
# one capture suffices for the app's lifetime.
print(f"\n[2/3] Capturing Damien voice prefix from {VOICE}...")
captured = []
call_shapes = []
original_forward = talker.model.forward
def patched(input_ids=None, inputs_embeds=None, **kwargs):
if inputs_embeds is not None and inputs_embeds.dim() == 3:
call_shapes.append(inputs_embeds.shape[1])
for i in range(inputs_embeds.shape[1]):
captured.append(inputs_embeds[0, i, :].detach().cpu().numpy().astype(np.float32))
return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
talker.model.forward = patched
# Any short sentence works — we only keep positions 0..8 which are text-
# invariant.
_ = tts.generate_voice_clone(
text="Bonjour, je suis Kazeia.", ref_audio=VOICE, language="french",
x_vector_only_mode=True, non_streaming_mode=True,
)
talker.model.forward = original_forward
nP = call_shapes[0]
print(f" Prefill size: {nP} tokens")
prefix_9 = np.stack(captured[:9]) # [9, 1024]
suffix_2 = np.stack(captured[nP-2:nP]) # [2, 1024]
prefix_path = f"{OUTPUT_DIR}/damien_voice_prefix.bin"
with open(prefix_path, "wb") as f:
f.write(struct.pack("<i", 9))
f.write(struct.pack("<i", 1024))
f.write(prefix_9.astype(np.float32).tobytes())
print(f" prefix -> {prefix_path} ({os.path.getsize(prefix_path)} bytes)")
suffix_path = f"{OUTPUT_DIR}/damien_voice_suffix.bin"
with open(suffix_path, "wb") as f:
f.write(struct.pack("<i", 2))
f.write(struct.pack("<i", 1024))
f.write(suffix_2.astype(np.float32).tobytes())
print(f" suffix -> {suffix_path} ({os.path.getsize(suffix_path)} bytes)")
norms_pref = [float(np.linalg.norm(prefix_9[i])) for i in range(9)]
norms_suff = [float(np.linalg.norm(suffix_2[i])) for i in range(2)]
print(f" Prefix norms: {[f'{n:.2f}' for n in norms_pref]} (pos 7 = xvector ~10, others ~1.6-1.8)")
print(f" Suffix norms: {[f'{n:.2f}' for n in norms_suff]}")
# ---- 3. Tokenizer files ----
# Copy the HF tokenizer artefacts so a Kotlin BPE can reproduce Python
# encode() bit-for-bit.
print(f"\n[3/3] Copying tokenizer to {OUTPUT_DIR}/qwen3_tokenizer/...")
for name in ("tokenizer.json", "vocab.json", "merges.txt", "tokenizer_config.json", "special_tokens_map.json"):
src = os.path.join(MODEL, name)
if os.path.exists(src):
shutil.copy(src, f"{OUTPUT_DIR}/qwen3_tokenizer/{name}")
print(f" {name} ({os.path.getsize(src)} bytes)")
else:
print(f" (skipped, not present: {name})")
print(f"\n=== DONE ===")
print(f"Files ready in {OUTPUT_DIR}/")
print(f"\nPush to tablet:")
print(f" adb push {OUTPUT_DIR}/text_embeds_full_fp16.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/")
print(f" adb push {OUTPUT_DIR}/damien_voice_prefix.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/")
print(f" adb push {OUTPUT_DIR}/qwen3_tokenizer /data/local/tmp/kazeia/models/qwen3-tts-npu/")