kazeia/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt

3899 lines
194 KiB
Kotlin
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.kazeia.tts
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OnnxJavaType
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioTrack
import android.util.Log
import com.kazeia.BuildConfig
import com.kazeia.core.TtsEngine
import com.kazeia.core.TtsResult
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.withContext
import java.io.File
import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer
import java.nio.LongBuffer
import java.nio.IntBuffer
import kotlin.coroutines.resume
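/**
* Qwen3-TTS Engine — NPU-first pipeline for voice cloning TTS.
*
* Pipeline:
* 1. Talker (ExecuTorch .pte on NPU, else ggml-hexagon runner, else ONNX on QNN GPU / CPU) → codec tokens (codebook 0)
* 2. Code Predictor (ExecuTorch .pte on NPU, else GGUF runner, else ONNX CPU) → 16 codebooks
* 3. VQ decode (Kotlin) → quantized [1, 512, 60]
* 4. pre_conv → preprocessor → ConvNet decoder → 24 kHz audio
*    (CPU ONNX when the V2 decoder assets are present, otherwise QNN HTP)
*
* Heavy computation runs on the NPU wherever a backend could be loaded; the backend
* for each stage is chosen in load(). VQ decode is trivial CPU (codebook lookup).
*/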
class Qwen3TtsEngine(
private val nativeLibDir: String,
private val onLog: ((String) -> Unit)? = null
) : TtsEngine {
// Compile-time constants shared by the talker, code predictor, VQ decode and audio decoder stages.
companion object {
private const val TAG = "Qwen3TTS" // Logcat tag used by nlog()
private const val SR = 24000 // Output sample rate in Hz (AudioTrack + TtsResult)
private const val SEQ_LEN = 60 // Fixed NPU chunk size for decoder
private const val CHUNK_OVERLAP = 10 // Tokens of overlap between consecutive decoder chunks
private const val EFFECTIVE_CHUNK = SEQ_LEN - CHUNK_OVERLAP // Tokens advanced per decoder chunk
private const val SAMPLES_PER_TOKEN = 1920 // PCM samples per codec token (80 ms at 24 kHz)
private const val CODEC_OFFSET = 1050
private const val NUM_CODEBOOKS = 16 // Codebooks produced per generated token
private const val CODEBOOK_SIZE = 2048 // Entries per codebook
private const val CODEBOOK_DIM = 256
private const val HIDDEN_DIM = 512
// Talker KV-cache model constants
private const val TALKER_DIM = 1024 // Embedding width of every talker input vector
private const val TALKER_VOCAB = 3072 // codec vocabulary size
private const val TALKER_LAYERS = 28
private const val TALKER_HEADS = 8
private const val TALKER_HEAD_DIM = 128
private const val KV_LEN = 199
private const val MAX_CONTEXT = 200 // KV_LEN + 1
private const val MASK_NEG = -10000f
// Code Predictor KV-cache constants
private const val CP_LAYERS = 5
private const val CP_KV_HEADS = 8
private const val CP_HEAD_DIM = 128
private const val CP_KV_LEN = 16 // max 16 past positions (17 total with current)
// Talker .pte constants
private const val TALKER_PTE_KV_LEN = 100 // must match .pte export (KV=64 caused quality loss)
// Codec special token IDs (in talker's 3072 vocab space)
private const val CODEC_EOS = 2150
private const val CODEC_BOS = 2149
private const val CODEC_PAD = 2148
private const val CODEC_THINK = 2154
private const val CODEC_NOTHINK = 2155
private const val CODEC_THINK_BOS = 2156
private const val CODEC_THINK_EOS = 2157
private const val CODEC_LANG_FR = 2061
// Chat template token IDs (reduced vocab)
private const val IM_START = 1048
private const val IM_END = 1049
private const val TOKEN_USER = 872
private const val TOKEN_ASSISTANT = 1042
private const val TOKEN_NEWLINE = 198
}
// ---- Mutable engine state; populated by load() unless noted otherwise ----
private var ortEnv: OrtEnvironment? = null
private var talkerKv: OrtSession? = null // Talker ONNX session (QNN GPU when libQnnGpu.so exists, else CPU); fallback when Hexagon is not used
private var useHexagonTalker: Boolean = false // Use ggml-hexagon runner
private var cpKv: OrtSession? = null // CP KV-cache ONNX CPU
// Speech decoder stage sessions (pre_conv, preprocessor, conv decoder)
private var preConv: OrtSession? = null
private var preprocessor: OrtSession? = null
private var convDecoder: OrtSession? = null
private var decoderOnCpu: Boolean = false // Set when the V2 decoder sessions were created on CPU
private var decoderOnGpu: Boolean = false // NOTE(review): not assigned anywhere in the visible part of this file — confirm it is still used
// Dual embedding tables for talker input
private var textEmbeds: FloatArray? = null // [1050, 1024] - reduced vocab, legacy fallback
private var codecEmbedding: FloatArray? = null // [3072, 1024] - codec/control token embeddings
// Stage 2 — on-device full-vocab text embeddings + BPE tokenizer.
// textEmbedsFull is 151936 × 1024 fp16 memory-mapped (~296 MB); using
// mmap keeps the bytes off the Java heap so the app doesn't crash when
// the ~125 MB cp_embeddings allocation comes next. damienVoicePrefix is
// the fixed 9-embed voice-cloning header that is prepended to the
// tokenized text to form a full prefill. damienVoiceSuffix is the fixed
// 2-embed block that closes the prefill after the text tokens.
private var textEmbedsFullBuf: java.nio.ByteBuffer? = null
private var textEmbedsFullChan: java.nio.channels.FileChannel? = null // Channel backing textEmbedsFullBuf
private val textEmbedsFullLen = 151936 // Row count of the full-vocab embedding table
private var damienVoicePrefix: Array<FloatArray>? = null
private var damienVoiceSuffix: Array<FloatArray>? = null
private var bpeTokenizer: Qwen3BpeTokenizer? = null
private var ttsBosEmbed: FloatArray? = null // [1024] - tts_bos text-side embedding
private var ttsEosEmbed: FloatArray? = null // [1024] - tts_eos text-side embedding
private var ttsPadEmbed: FloatArray? = null // [1024] - tts_pad text-side embedding
private var speakerEmbed: FloatArray? = null // [1024] - x-vector speaker embedding
// VQ codebooks (loaded from numpy)
private var firstCodebook: FloatArray? = null // [2048, 256]
private var restCodebooks: Array<FloatArray>? = null // 15 × [2048, 256]
private var firstOutputProj: FloatArray? = null // [512, 256]
private var restOutputProj: FloatArray? = null // [512, 256]
// Code predictor embeddings [15, 2048, 1024]
private var cpEmbeddings: FloatArray? = null
private var loaded = false // True once load() has run to completion without throwing
private var modelPath: String? = null // Model directory passed to the last load() call
private var audioTrack: AudioTrack? = null // Track currently playing in synthesizeAndPlay(); cleared by stop()
private var talkerUsesInt64Pos = false // Set for the legacy talker ONNX export
private var talkerUsesCosSin = false // Set for the M-RoPE talker ONNX export (takes rotary tables as inputs)
private var rotaryCos: FloatArray? = null
private var rotarySin: FloatArray? = null
private var cpUsesCosSin = false
private var cpRotaryCos: FloatArray? = null
private var cpRotarySin: FloatArray? = null
private var cpHeadsPath: String? = null // path dir for head_0..14.npy
private var cpHeadsCache: Array<FloatArray?>? = null // lazy-loaded heads cache (8MB each)
private var cpAllHeads: FloatArray? = null // all 15 heads concatenated [15*2048*1024] for batch NEON
private var cpPteModule: org.pytorch.executorch.Module? = null // ExecuTorch CP on NPU (JNI)
private var talkerPteModule: org.pytorch.executorch.Module? = null // ExecuTorch talker on NPU (JNI)
private var nativePipelineReady: Boolean = false // C++ native pipeline available
private var talkerPteRotaryCos: FloatArray? = null
private var talkerPteRotarySin: FloatArray? = null
private var useEtCp: Boolean = false // CP via ExecuTorch runner process (root)
private var cpEtSocket: java.net.Socket? = null
private var cpFixedKv: Boolean = false // GPU uses fixed-size KV (shift + mask)
private var debugLogFile: java.io.File? = null // Log file nlog() settled on after its primary path failed
/**
 * Tells whether the diagnostic marker file `/data/local/tmp/kazeia/<name>` is present.
 *
 * The answer is always `false` unless `BuildConfig.DEBUG` is set, in which case the
 * file system is consulted. Release builds therefore never touch the marker directory
 * and never enter a diagnostic code path guarded by this function.
 *
 * Flags gated this way (per the original authors' notes): force_inject_pycb0,
 * force_greedy_cb0, cb0_temp, force_python_codes, force_cpu_talker,
 * force_cpu_talker_gguf. They are kept so the cross-architecture (x86 Python vs
 * ARM64 tablet) tremor experiment can be re-run on demand; that investigation
 * concluded the ~28% cb0 divergence is a numerical floor, not something a runtime
 * swap can fix, so none of these flags belong in production.
 *
 * @param name file name of the marker, relative to `/data/local/tmp/kazeia/`
 * @return `true` only in a debug build where the marker file exists
 */
private fun diagFlag(name: String): Boolean {
    if (!BuildConfig.DEBUG) return false
    val marker = File("/data/local/tmp/kazeia/$name")
    return marker.exists()
}
/**
 * File-returning companion of [diagFlag].
 *
 * @param name file name relative to `/data/local/tmp/kazeia/`
 * @return the diagnostic file when this is a debug build and the file exists, otherwise `null`
 */
private fun diagFile(name: String): File? {
    if (!BuildConfig.DEBUG) return null
    val candidate = File("/data/local/tmp/kazeia/$name")
    return if (candidate.exists()) candidate else null
}
/**
 * Fan-out logger for the engine.
 *
 * The message is written to logcat at both INFO and WARN level, forwarded to the
 * optional [onLog] callback with a `[TTS] ` prefix, and appended (with a millisecond
 * timestamp) to a debug log file. File logging is best effort: if the current target
 * cannot be written, a second location under the models directory is tried and, when
 * that works, remembered in [debugLogFile] for later calls. Any I/O error is ignored.
 */
private fun nlog(msg: String) {
    Log.i(TAG, msg)
    Log.w(TAG, msg)
    onLog?.invoke("[TTS] $msg")

    val primaryWritten = try {
        val target = debugLogFile ?: java.io.File("/data/local/tmp/kazeia/tts_debug.log")
        target.appendText("${System.currentTimeMillis()} $msg\n")
        true
    } catch (_: Exception) {
        false
    }
    if (primaryWritten) return

    // Primary location was not writable: fall back to the model directory.
    try {
        val fallback = java.io.File("/data/local/tmp/kazeia/models/qwen3-tts-npu/tts_debug.log")
        fallback.appendText("${System.currentTimeMillis()} $msg\n")
        debugLogFile = fallback
    } catch (_: Exception) {
        // Deliberately ignored: logging must never break synthesis.
    }
}
/**
 * Loads every model and asset the engine needs, on [Dispatchers.IO].
 *
 * Stages, in order:
 *  1. speech decoder sessions (V2 on CPU when `v2_pre_conv/model.onnx` exists, else QNN HTP);
 *  2. ExecuTorch `.pte` code predictor and talker (with a warm-up forward pass);
 *  3. Hexagon runners when the `.pte` pair is not both loaded;
 *  4. ONNX talker fallback (QNN GPU when `libQnnGpu.so` exists, else CPU);
 *  5. optional CP V2 ONNX session and ExecuTorch CP runner;
 *  6. embedding tables, optional Stage 2 on-device text assets, VQ codebooks.
 *
 * Changes relative to the previous revision:
 *  - `ADSP_LIBRARY_PATH` is exported before the first QNN session is created, so the
 *    legacy HTP decoder path also sees it;
 *  - the talker M-RoPE log line reports the backend that was actually selected;
 *  - a failed Stage 2 load also clears [damienVoiceSuffix], and a previous mmap channel
 *    is closed before a new one is opened (repeat calls no longer leak a descriptor).
 *
 * Failures are logged and swallowed; the caller observes them through [isLoaded]
 * staying `false` (or keeping its previous value on a repeated call).
 *
 * @param modelPath directory holding the model assets; `null` makes the call a no-op
 * @param voiceId currently unused by this implementation
 */
override suspend fun load(modelPath: String?, voiceId: String?) {
    withContext(Dispatchers.IO) {
        val path = modelPath ?: return@withContext
        this@Qwen3TtsEngine.modelPath = path
        try {
            val t0 = System.currentTimeMillis()
            ortEnv = OrtEnvironment.getEnvironment()
            val htpPath = "$nativeLibDir/libQnnHtp.so"
            val qnnCacheDir = "$path/qnn_cache"
            File(qnnCacheDir).mkdirs()

            // Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths).
            // Done before any session is created so the HTP decoder fallback below sees it too.
            android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true)

            fun loadQnn(name: String, fp16: Boolean = false): OrtSession {
                val t = System.currentTimeMillis()
                val opts = OrtSession.SessionOptions()
                val qnnOpts = mutableMapOf(
                    "backend_path" to htpPath,
                    "qnn_context_cache_enable" to "1",
                    "qnn_context_cache_path" to "$qnnCacheDir/${name}.bin"
                )
                if (fp16) {
                    qnnOpts["htp_graph_finalization_optimization_mode"] = "3"
                    qnnOpts["enable_htp_fp16_precision"] = "1"
                }
                opts.addQnn(qnnOpts)
                val session = ortEnv!!.createSession("$path/$name/model.onnx", opts)
                val mode = if (fp16) "QNN fp16" else "QNN"
                nlog("$name ($mode): ${System.currentTimeMillis() - t}ms")
                return session
            }

            fun loadCpu(name: String, threads: Int = 8): OrtSession {
                val t = System.currentTimeMillis()
                val opts = OrtSession.SessionOptions()
                opts.setIntraOpNumThreads(threads)
                opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
                val session = ortEnv!!.createSession("$path/$name/model.onnx", opts)
                nlog("$name (CPU ${threads}T): ${System.currentTimeMillis() - t}ms")
                return session
            }

            val gpuLib = "$nativeLibDir/libQnnGpu.so"
            val hasGpu = File(gpuLib).exists()

            // Kept for experiments; no caller in this function at the moment.
            fun loadGpu(onnxPath: String, label: String): OrtSession {
                val t = System.currentTimeMillis()
                val opts = OrtSession.SessionOptions()
                opts.addQnn(mapOf("backend_path" to gpuLib))
                val session = ortEnv!!.createSession(onnxPath, opts)
                nlog("$label (GPU): ${System.currentTimeMillis() - t}ms")
                return session
            }

            // Speech decoder V2 on CPU. Two paths tried, both worse than CPU:
            // - HTP: BigVGAN convolutions too slow to compile (timeout)
            // - GPU Adreno via QNN GPU EP: model loads but per-phrase
            //   inference is ~3.5 s vs ~2 s on CPU (GPU/CPU memory transfer
            //   overhead dominates for this conv-heavy model)
            // CPU 8-thread stays the practical optimum.
            val v2Path = "$path/v2_pre_conv"
            if (File("$v2Path/model.onnx").exists()) {
                nlog("Loading V2 speech decoder (CPU ONNX)...")
                preConv = loadCpu("v2_pre_conv", 4)
                preprocessor = loadCpu("v2_pre_transformer", 4)
                convDecoder = loadCpu("v2_decoder_conv", 8) // BigVGAN benefits from more threads
                decoderOnCpu = true
                nlog("Speech decoder V2 on CPU")
            } else {
                nlog("Loading speech decoder (QNN HTP)...")
                preConv = loadQnn("pre_conv")
                preprocessor = loadQnn("preprocessor")
                convDecoder = loadQnn("conv_decoder")
                nlog("Speech decoder on HTP")
            }

            // Test overrides: skip backends to expose lower-priority talker paths.
            //   force_hexagon    → skip .pte, use ggml-hexagon HMX
            //   force_cpu_talker → skip .pte AND Hexagon, use the CPU ONNX talker
            //                      (used as an "EOS oracle" path because fp32 still
            //                      produces natural codec_eos where fp16 NPU drifts.)
            // force_hexagon is the production routing flag (not diagnostic): keep
            // unconditional. force_cpu_talker is diagnostic (used to verify CPU-fp32
            // equivalence during the tremor investigation) → DEBUG only.
            val forceHexagon = File("/data/local/tmp/kazeia/force_hexagon").exists()
            val forceCpuTalker = diagFlag("force_cpu_talker")
            if (forceHexagon) nlog("force_hexagon flag detected: skipping .pte init")
            if (forceCpuTalker) nlog("force_cpu_talker flag detected: skipping .pte AND hexagon init")

            // Load .pte modules via Java JNI (native pipeline uses same instances).
            // CP .pte can load regardless of forceHexagon/forceCpuTalker — it only touches
            // the NPU for its own graph. The talker .pte is separately gated below.
            // Running CP on the NPU alongside a Hexagon talker MAY trigger DSP contention
            // (memory note: "ggml-hexagon and ExecuTorch QNN cannot share CDSP"). This
            // is exactly what we're testing.
            run {
                val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
                if (!forceCpuTalker && etModel.exists() && cpPteModule == null) {
                    try {
                        val cpLoadStart = System.currentTimeMillis()
                        cpPteModule = org.pytorch.executorch.Module.load(
                            etModel.absolutePath,
                            org.pytorch.executorch.Module.LOAD_MODE_FILE,
                            1
                        )
                        nlog("CP .pte JNI loaded: ${System.currentTimeMillis() - cpLoadStart}ms")
                        val t1 = System.currentTimeMillis()
                        val lmResult = cpPteModule!!.loadMethod("forward")
                        nlog("CP .pte loadMethod: ${System.currentTimeMillis() - t1}ms, result=$lmResult")
                        if (lmResult != 0) {
                            nlog("CP .pte loadMethod failed ($lmResult), disabling JNI")
                            cpPteModule = null
                        } else {
                            // Register CP module for native pipeline
                            cpPteModule!!.nativeSetCpModule()
                            nlog("CP module registered for native pipeline")
                        }
                    } catch (e: Exception) {
                        nlog("CP .pte JNI failed: ${e.message}")
                        cpPteModule = null
                    }
                }

                // Load talker .pte JNI (same QNN context, no DSP contention with CP .pte)
                val talkerPte = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte")
                if (!forceHexagon && !forceCpuTalker && talkerPte.exists() && cpPteModule != null && talkerPteModule == null) {
                    try {
                        val talkerLoadStart = System.currentTimeMillis()
                        talkerPteModule = org.pytorch.executorch.Module.load(
                            talkerPte.absolutePath,
                            org.pytorch.executorch.Module.LOAD_MODE_FILE,
                            1
                        )
                        val lm = talkerPteModule!!.loadMethod("forward")
                        nlog("Talker .pte JNI loaded+compiled: ${System.currentTimeMillis() - talkerLoadStart}ms, result=$lm")
                        if (lm != 0) {
                            nlog("Talker .pte loadMethod failed")
                            talkerPteModule = null
                        } else {
                            val pteDir = "/data/local/tmp/kazeia/models"
                            talkerPteRotaryCos = loadNpy("$pteDir/talker_pte_rotary_cos.npy")
                            talkerPteRotarySin = loadNpy("$pteDir/talker_pte_rotary_sin.npy")
                            nlog("Talker .pte rotary: ${talkerPteRotaryCos?.size} floats")

                            // Warmup both models: first forward() triggers QNN DSP compilation (~7s)
                            // Better to pay this cost at init than at first pipeline run
                            val tw = System.currentTimeMillis()
                            try {
                                val dE = FloatArray(TALKER_DIM)
                                val dM = FloatArray(TALKER_PTE_KV_LEN) { -1e9f }
                                dM[TALKER_PTE_KV_LEN - 1] = 0f
                                val dC = FloatArray(TALKER_HEAD_DIM) { 1f }
                                val dS = FloatArray(TALKER_HEAD_DIM)
                                val tkvSz = TALKER_HEADS * TALKER_PTE_KV_LEN * TALKER_HEAD_DIM
                                val ins = mutableListOf(
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(dE, longArrayOf(1, 1, TALKER_DIM.toLong()))),
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(dM, longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(dC, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(dS, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
                                )
                                // One zeroed tensor per K and per V cache of every layer.
                                for (i in 0 until TALKER_LAYERS * 2) {
                                    ins.add(
                                        org.pytorch.executorch.EValue.from(
                                            org.pytorch.executorch.Tensor.fromBlob(
                                                FloatArray(tkvSz),
                                                longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong())
                                            )
                                        )
                                    )
                                }
                                talkerPteModule!!.forward(*ins.toTypedArray())
                                nlog("Talker warmup: ${System.currentTimeMillis() - tw}ms")
                            } catch (e: Exception) {
                                nlog("Talker warmup failed: ${e.message}")
                            }

                            // CP warmup
                            val cw = System.currentTimeMillis()
                            try {
                                val ckvSz = CP_KV_HEADS * CP_KV_LEN * CP_HEAD_DIM
                                val cIns = mutableListOf(
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(FloatArray(TALKER_DIM), longArrayOf(1, 1, TALKER_DIM.toLong()))),
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(FloatArray(CP_KV_LEN) { -1e9f }.also { it[CP_KV_LEN - 1] = 0f }, longArrayOf(1, 1, 1, CP_KV_LEN.toLong()))),
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(FloatArray(CP_HEAD_DIM) { 1f }, longArrayOf(1, 1, CP_HEAD_DIM.toLong()))),
                                    org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(FloatArray(CP_HEAD_DIM), longArrayOf(1, 1, CP_HEAD_DIM.toLong())))
                                )
                                for (i in 0 until CP_LAYERS * 2) {
                                    cIns.add(
                                        org.pytorch.executorch.EValue.from(
                                            org.pytorch.executorch.Tensor.fromBlob(
                                                FloatArray(ckvSz),
                                                longArrayOf(1, CP_KV_HEADS.toLong(), CP_KV_LEN.toLong(), CP_HEAD_DIM.toLong())
                                            )
                                        )
                                    )
                                }
                                cpPteModule!!.forward(*cIns.toTypedArray())
                                nlog("CP warmup: ${System.currentTimeMillis() - cw}ms")
                            } catch (e: Exception) {
                                nlog("CP warmup failed: ${e.message}")
                            }
                            // Native pipeline already initialized above
                        }
                    } catch (e: Exception) {
                        nlog("Talker .pte JNI failed: ${e.message}")
                        talkerPteModule = null
                    }
                }
            }

            // Talker: skip Hexagon if .pte talker+CP are both loaded (avoids DSP contention)
            val talkerT = System.currentTimeMillis()
            val hexRunner = File("/data/local/tmp/kazeia/llama-hex/llama-tts-talker")
            val hexModel = File("/data/local/tmp/kazeia/models/talker_f16.gguf")
            val cpuOnnx = File("$path/talker_kv_cpu/model.onnx")
            if (talkerPteModule != null && cpPteModule != null) {
                nlog("Talker+CP using .pte JNI NPU — skipping Hexagon runners")
            } else if (forceCpuTalker) {
                nlog("force_cpu_talker: skipping Hexagon, will use CPU ONNX")
            } else {
                if (hexRunner.exists() && hexModel.exists()) {
                    if (hexStartRunner()) {
                        useHexagonTalker = true
                        nlog("Talker using Hexagon NPU (HMX FP16)")
                    } else {
                        nlog("Hexagon talker runner failed, using CPU")
                    }
                }
                val hexCpRunner = File("/data/local/tmp/kazeia/llama-hex/llama-tts-cp")
                val hexCpModel = File("/data/local/tmp/kazeia/models/cp_f16.gguf")
                if (hexCpRunner.exists() && hexCpModel.exists()) {
                    if (hexStartCpRunner()) {
                        useHexagonCp = true
                        nlog("CP using CPU GGUF (avoids Hexagon HMX NaN bug)")
                    } else {
                        nlog("CP CPU runner failed, falling back to ONNX")
                    }
                }
            }

            // Fallback: CPU ONNX for talker if hexagon failed
            if (!useHexagonTalker) {
                // Try new M-RoPE ONNX: GPU Adreno first (fast fp16), fallback CPU.
                // NOTE(review): mropeOnnx and cpuOnnx resolve to the same file, so the legacy
                // branch is only reachable when that file exists but is <= 1 MB — confirm intent.
                val mropeOnnx = File("$path/talker_kv_cpu/model.onnx")
                if (mropeOnnx.exists() && mropeOnnx.length() > 1_000_000) {
                    val talkerOpts = OrtSession.SessionOptions()
                    val talkerBackend: String
                    if (hasGpu) {
                        talkerOpts.addQnn(mapOf("backend_path" to gpuLib))
                        nlog("Talker ONNX: loading on GPU Adreno...")
                        talkerBackend = "GPU Adreno"
                    } else {
                        talkerOpts.setIntraOpNumThreads(6)
                        talkerOpts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
                        nlog("Talker ONNX: loading on CPU 6T...")
                        talkerBackend = "CPU 6T"
                    }
                    talkerKv = ortEnv!!.createSession(mropeOnnx.absolutePath, talkerOpts)
                    talkerUsesCosSin = true
                    // Load rotary tables
                    val cosFile = File("$path/talker_kv_cpu/talker_rotary_cos.npy")
                    val sinFile = File("$path/talker_kv_cpu/talker_rotary_sin.npy")
                    if (cosFile.exists()) rotaryCos = loadNpy(cosFile.absolutePath)
                    if (sinFile.exists()) rotarySin = loadNpy(sinFile.absolutePath)
                    // Report the backend that was actually selected (was hard-coded to "CPU 6T").
                    nlog("talker_kv M-RoPE ($talkerBackend): ${System.currentTimeMillis() - talkerT}ms, cos/sin=${rotaryCos?.size}")
                } else if (cpuOnnx.exists()) {
                    val talkerOpts = OrtSession.SessionOptions()
                    talkerOpts.setIntraOpNumThreads(6)
                    talkerOpts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
                    talkerKv = ortEnv!!.createSession(cpuOnnx.absolutePath, talkerOpts)
                    talkerUsesInt64Pos = true
                    nlog("talker_kv legacy (CPU fp32 6T): ${System.currentTimeMillis() - talkerT}ms")
                } else {
                    nlog("WARNING: No talker model available (no hexagon, no CPU ONNX)")
                }
            }

            // Always try to load CP V2 ONNX Runtime as an optional routing target.
            // Previously gated on !useHexagonCp (fallback only); now loaded unconditionally
            // so force_cp_v2 flag can divert to ORT's MLAS kernels at runtime.
            run {
                val cpV2 = File("$path/cp_kv_v2/model.onnx")
                if (cpV2.exists() && cpKv == null) {
                    try {
                        if (!useHexagonCp && (cpPteModule == null || (useHexagonTalker && talkerPteModule == null))) {
                            val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
                            val etRunner = File("/data/local/tmp/kazeia/cp_et_runner")
                            if (etRunner.exists() && etModel.exists()) {
                                if (hexStartCpEtRunner()) {
                                    useEtCp = true
                                    nlog("CP using ExecuTorch NPU runner (TCP, root)")
                                } else {
                                    nlog("CP ET runner failed to start")
                                }
                            }
                        }
                        cpKv = loadCpu("cp_kv_v2")
                        cpUsesCosSin = true
                        cpRotaryCos = loadNpy("$path/cp_kv_v2/cp_rotary_cos.npy")
                        cpRotarySin = loadNpy("$path/cp_kv_v2/cp_rotary_sin.npy")
                        cpHeadsPath = "$path/cp_kv_v2/cp_heads.npy"
                        nlog("CP V2 ONNX loaded, cos/sin=${cpRotaryCos?.size}")
                    } catch (e: Exception) {
                        nlog("CP V2 ONNX failed: ${e.message}")
                    }
                }
            }

            // Load dual embedding tables for talker
            textEmbeds = loadNpy("$path/text_embeds_projected.npy")
            nlog("Text embeddings: ${textEmbeds!!.size / TALKER_DIM} × $TALKER_DIM")

            // Stage 2 assets: full-vocab text embeddings + voice prefix + BPE.
            // All three are optional — if any is missing we fall back to the
            // legacy 1050-token path. This keeps the app bootable during
            // asset rollout and avoids turning a missing file into a crash.
            try {
                val fullEmbFile = File("$path/text_embeds_full_fp16.bin")
                val prefixFile = File("$path/damien_voice_prefix.bin")
                val tokDir = File("$path/qwen3_tokenizer")
                if (fullEmbFile.exists() && prefixFile.exists() && tokDir.isDirectory) {
                    val tE0 = System.currentTimeMillis()
                    // Memory-map the fp16 embeddings table instead of heap-
                    // loading it. Without mmap, the 296 MB ByteArray plus
                    // the 125 MB cp_embeddings FloatArray loaded a few
                    // lines below overrun the ~536 MB large-heap limit and
                    // the app OOMs during init. mmap pages the file via
                    // the kernel and keeps zero bytes on the Java heap.
                    val expectedBytes = textEmbedsFullLen.toLong() * TALKER_DIM * 2
                    if (fullEmbFile.length() != expectedBytes) {
                        nlog("text_embeds_full_fp16 size mismatch (got ${fullEmbFile.length()}, expected $expectedBytes) — disabling on-device text")
                    } else {
                        // A repeated load() must not leak the channel opened by the previous call.
                        try { textEmbedsFullChan?.close() } catch (_: Exception) {}
                        textEmbedsFullChan = java.io.RandomAccessFile(fullEmbFile, "r").channel
                        textEmbedsFullBuf = textEmbedsFullChan!!.map(
                            java.nio.channels.FileChannel.MapMode.READ_ONLY, 0L, expectedBytes
                        ).order(ByteOrder.LITTLE_ENDIAN)
                        nlog("Full-vocab text embeddings mmap: ${textEmbedsFullLen} × $TALKER_DIM fp16 (${expectedBytes/1024/1024}MB, off-heap) in ${System.currentTimeMillis()-tE0}ms")
                    }

                    // Prefix file layout: int32 count, int32 dim, then count*dim little-endian floats.
                    val pb = ByteBuffer.wrap(prefixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
                    val nPref = pb.int
                    val dimPref = pb.int
                    if (nPref == 9 && dimPref == TALKER_DIM) {
                        damienVoicePrefix = Array(9) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = pb.float } }
                        nlog("Damien voice prefix: $nPref × $dimPref")
                    } else {
                        nlog("damien_voice_prefix.bin header mismatch ($nPref × $dimPref) — disabling on-device text")
                    }

                    // Voice SUFFIX — the 2 fixed positions that close out
                    // the prefill after text tokens. Empirically invariant
                    // across segments of the same speaker (diff = 0.0).
                    val suffixFile = File("$path/damien_voice_suffix.bin")
                    if (suffixFile.exists()) {
                        val sb = ByteBuffer.wrap(suffixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
                        val nSuf = sb.int
                        val dimSuf = sb.int
                        if (nSuf == 2 && dimSuf == TALKER_DIM) {
                            damienVoiceSuffix = Array(2) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = sb.float } }
                            nlog("Damien voice suffix: $nSuf × $dimSuf")
                        }
                    } else {
                        nlog("damien_voice_suffix.bin missing — on-device text will lack closure markers")
                    }

                    if (textEmbedsFullBuf != null && damienVoicePrefix != null && damienVoiceSuffix != null) {
                        bpeTokenizer = Qwen3BpeTokenizer.load(tokDir.absolutePath)
                        nlog("Stage 2 on-device text path ready (BPE + full embeds + voice prefix)")
                    }
                } else {
                    nlog("Stage 2 assets not fully present (full=$fullEmbFile, prefix=$prefixFile, tok=$tokDir) — legacy path only")
                }
            } catch (e: Exception) {
                nlog("Stage 2 asset load failed: ${e.message} — legacy path only")
                // Drop every Stage 2 artefact, including the suffix, so no half-loaded state survives.
                textEmbedsFullBuf = null
                damienVoicePrefix = null
                damienVoiceSuffix = null
                bpeTokenizer = null
                try { textEmbedsFullChan?.close() } catch (_: Exception) {}
                textEmbedsFullChan = null
            }

            codecEmbedding = loadNpy("$path/codec_embedding.npy")
            nlog("Codec embedding: ${codecEmbedding!!.size / TALKER_DIM} × $TALKER_DIM")
            val ttsSpecial = loadNpy("$path/tts_special_embeds.npy") // [3, 1024] = bos, eos, pad
            ttsBosEmbed = ttsSpecial.sliceArray(0 until TALKER_DIM)
            ttsEosEmbed = ttsSpecial.sliceArray(TALKER_DIM until 2 * TALKER_DIM)
            ttsPadEmbed = ttsSpecial.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
            nlog("TTS special embeddings loaded")

            // Load speaker embedding (x-vector for voice cloning)
            val spkFile = File("$path/speaker_embedding.npy")
            if (spkFile.exists()) {
                speakerEmbed = loadNpy(spkFile.absolutePath)
                nlog("Speaker embedding: ${speakerEmbed!!.size} floats, norm=${Math.sqrt(speakerEmbed!!.sumOf { (it * it).toDouble() })}")
            }

            // Load VQ codebooks
            loadVqCodebooks(path)

            // Load code predictor embeddings
            cpEmbeddings = loadNpy("$path/code_predictor_embeddings.npy")
            nlog("CP embeddings: ${cpEmbeddings!!.size} floats")

            loaded = true
            nlog("Qwen3-TTS loaded in ${System.currentTimeMillis() - t0}ms")
        } catch (e: Exception) {
            nlog("ERROR: ${e.message}")
            e.printStackTrace()
        }
    }
}
/** Reports whether the last [load] call ran to completion. */
override fun isLoaded(): Boolean {
    return loaded
}
/**
 * Records the requested voice in the log.
 *
 * As written this only logs [voicePath]; no engine state is changed.
 */
fun setVoice(voicePath: String) = nlog("Voice: $voicePath")
/**
 * Synthesizes [text] to 16-bit PCM on [Dispatchers.IO].
 *
 * @return the samples with the engine sample rate and the duration in milliseconds
 *         derived from the sample count (empty audio yields a duration of 0)
 */
override suspend fun synthesize(text: String, language: String): TtsResult =
    withContext(Dispatchers.IO) {
        val pcm = generateSpeech(text, language)
        val millis = pcm.size * 1000L / SR
        TtsResult(audioData = pcm, sampleRate = SR, durationMs = millis)
    }
/**
 * Generates codebooks for [text], then decodes and plays them chunk by chunk through a
 * streaming [AudioTrack] on [Dispatchers.IO].
 *
 * The decoder consumes fixed windows of [SEQ_LEN] tokens. Each iteration emits the audio
 * of up to [EFFECTIVE_CHUNK] new tokens; every window after the first is started
 * [CHUNK_OVERLAP] tokens early so the decoder has left context, and the audio belonging
 * to that context is skipped.
 *
 * Fixes relative to the previous revision:
 *  - windows after the first used to start at `pos` while still skipping the overlap
 *    audio, which dropped tokens `pos..pos+CHUNK_OVERLAP-1` of the second chunk and
 *    played padding at the end of the last chunk;
 *  - the Hexagon runner is no longer stopped when the decoder runs on CPU (same guard
 *    as `generateSpeech`);
 *  - the track is always stopped/released, even when decoding throws or [stop] was
 *    called from another thread.
 *
 * @param onStart invoked once codebooks are ready, right before playback is set up
 * @param onComplete invoked when generation yields nothing or after playback finishes normally
 */
override suspend fun synthesizeAndPlay(
    text: String, language: String,
    onStart: (() -> Unit)?, onComplete: (() -> Unit)?
) {
    withContext(Dispatchers.IO) {
        val codebooks = generateCodebooks(text, language)
        if (codebooks == null) {
            onComplete?.invoke()
            return@withContext
        }
        val (allCodebooks, numRealTokens) = codebooks

        // Release DSP for the QNN decoder ONLY if the decoder is on HTP (not GPU, not CPU).
        if (!decoderOnCpu && !decoderOnGpu && (useHexagonTalker || useHexagonCp)) {
            hexStopRunner()
        }
        onStart?.invoke()

        // Decode and play in streaming chunks
        val track = AudioTrack.Builder()
            .setAudioAttributes(
                AudioAttributes.Builder()
                    .setUsage(AudioAttributes.USAGE_MEDIA)
                    .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
                    .build()
            )
            .setAudioFormat(
                AudioFormat.Builder()
                    .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
                    .setSampleRate(SR)
                    .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
                    .build()
            )
            .setBufferSizeInBytes(SAMPLES_PER_TOKEN * 20 * 2) // ~1.6s buffer
            .setTransferMode(AudioTrack.MODE_STREAM)
            .build()

        try {
            track.play()
            audioTrack = track
            val t3 = System.currentTimeMillis()
            var pos = 0
            while (pos < numRealTokens) {
                // stop() clears audioTrack and releases the track; nothing left to play into.
                if (audioTrack !== track) break

                val chunkEnd = minOf(pos + EFFECTIVE_CHUNK, numRealTokens)
                val chunkTokens = chunkEnd - pos

                // Left context fed to the decoder but not played (none for the first chunk).
                val contextTokens = if (pos == 0) 0 else CHUNK_OVERLAP
                val windowStart = pos - contextTokens

                // Build chunk codebooks padded to SEQ_LEN
                val chunkCodes = Array(NUM_CODEBOOKS) { cb ->
                    IntArray(SEQ_LEN) { t ->
                        val srcIdx = windowStart + t
                        if (srcIdx < numRealTokens) allCodebooks[cb][srcIdx] else 0
                    }
                }
                val quantized = vqDecode(chunkCodes)
                val chunkAudio = runSpeechDecoder(quantized)

                // Skip the context audio, keep only the samples of the new tokens.
                val skipSamples = contextTokens * SAMPLES_PER_TOKEN
                val keepSamples = minOf(chunkTokens * SAMPLES_PER_TOKEN, chunkAudio.size - skipSamples)
                if (keepSamples > 0) {
                    track.write(chunkAudio, skipSamples, keepSamples)
                }
                pos += EFFECTIVE_CHUNK
            }
            nlog("Streaming decode: ${System.currentTimeMillis() - t3}ms")
        } finally {
            // stop() throws IllegalStateException on a track that was already released.
            try { track.stop() } catch (_: Exception) {}
            try { track.release() } catch (_: Exception) {}
            if (audioTrack === track) audioTrack = null
        }
        onComplete?.invoke()
    }
}
/**
 * Stops and releases the currently playing [AudioTrack], if any, and clears the reference.
 * A failure of `AudioTrack.stop()` is ignored so that `release()` still runs.
 */
override fun stop() {
    val active = audioTrack
    if (active != null) {
        try {
            active.stop()
        } catch (_: Exception) {
            // Ignored: the track is released right below regardless.
        }
        active.release()
    }
    audioTrack = null
}
/**
 * Runs text → embeddings → interleaved talker/code-predictor generation, without decoding audio.
 *
 * When `phrase_embeds.bin` exists in the model directory its pre-computed embeddings are
 * used and [text] only appears in the log; otherwise [text] goes through [tokenizeText].
 *
 * @return `(allCodebooks[16][padLen], numRealTokens)` where `padLen = max(numRealTokens, SEQ_LEN)`
 *         and padding entries are 0; `null` when the engine is not loaded, nothing was
 *         generated, or any step threw (the error message is logged)
 */
private fun generateCodebooks(text: String, language: String): Pair<Array<IntArray>, Int>? {
    if (!loaded) return null
    val startMs = System.currentTimeMillis()
    nlog("Generating codebooks: '$text' [$language]")
    return try {
        val phraseFile = File("$modelPath/phrase_embeds.bin")
        val textEmbedsList: List<FloatArray> = if (phraseFile.exists()) {
            // Layout: little-endian int32 count, then count rows of TALKER_DIM float32.
            val raw = ByteBuffer.wrap(phraseFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
            val count = raw.int
            val floats = raw.asFloatBuffer()
            val rows = ArrayList<FloatArray>()
            for (row in 0 until count) {
                val vec = FloatArray(TALKER_DIM)
                floats.get(vec)
                rows.add(vec)
            }
            nlog("Loaded $count pre-computed text embeddings")
            rows
        } else {
            tokenizeText(text).map { textEmb(it) }
        }

        val maxGen = MAX_CONTEXT - 15
        val allCodesArray = runInterleavedGeneration(textEmbedsList, maxGen)
        val genMs = System.currentTimeMillis() - startMs
        nlog("Interleaved gen: ${genMs}ms, ${allCodesArray.size} tokens")

        if (allCodesArray.isEmpty()) {
            null
        } else {
            val numRealTokens = allCodesArray.size
            val padLen = maxOf(numRealTokens, SEQ_LEN)
            // Transpose [token][codebook] → [codebook][token], zero-padded to padLen.
            val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
                IntArray(padLen) { t -> if (t < numRealTokens) allCodesArray[t][cb] else 0 }
            }
            Pair(allCodebooks, numRealTokens)
        }
    } catch (e: Exception) {
        nlog("ERROR: ${e.message}")
        null
    }
}
/**
 * Full offline synthesis: text embeddings → interleaved generation → chunked decode.
 *
 * Uses the pre-computed `phrase_embeds.bin` from the model directory when it exists,
 * otherwise the reduced-vocabulary [tokenizeText] fallback. Timing for each stage and
 * the overall real-time factor are logged.
 *
 * @return 16-bit PCM samples, or an empty array when the engine is not loaded, nothing
 *         was generated, or any step threw (the error is logged and its stack trace printed)
 */
private fun generateSpeech(text: String, language: String): ShortArray {
    if (!loaded) return ShortArray(0)
    val startMs = System.currentTimeMillis()
    nlog("Generating: '$text' [$language]")
    return try {
        // Step 1: pre-computed embeddings from PC, or the reduced vocab tokenizer fallback
        val phraseFile = File("$modelPath/phrase_embeds.bin")
        val textEmbedsList: List<FloatArray> = if (phraseFile.exists()) {
            // Layout: little-endian int32 count, then count rows of TALKER_DIM float32.
            val raw = ByteBuffer.wrap(phraseFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
            val count = raw.int
            val floats = raw.asFloatBuffer()
            val rows = ArrayList<FloatArray>()
            for (row in 0 until count) {
                val vec = FloatArray(TALKER_DIM)
                floats.get(vec)
                rows.add(vec)
            }
            nlog("Loaded $count pre-computed text embeddings")
            rows
        } else {
            val tokenIds = tokenizeText(text)
            val embedded = tokenIds.map { textEmb(it) }
            nlog("Text tokens: ${tokenIds.size}: ${tokenIds.toList()}")
            embedded
        }

        // Step 2: interleaved talker + code predictor → all 16 codebooks.
        // maxGen is a hard safety limit; the actual stop is decided inside the generator.
        val genStartMs = System.currentTimeMillis()
        val maxGen = MAX_CONTEXT - 15
        val allCodesArray = runInterleavedGeneration(textEmbedsList, maxGen)
        nlog("Interleaved gen: ${System.currentTimeMillis() - genStartMs}ms, ${allCodesArray.size} tokens")

        if (allCodesArray.isEmpty()) {
            ShortArray(0)
        } else {
            val numRealTokens = allCodesArray.size
            // Reshape to [16][totalTokens] for the decoder (zero-padded up to SEQ_LEN)
            val padLen = maxOf(numRealTokens, SEQ_LEN)
            val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
                IntArray(padLen) { t -> if (t < numRealTokens) allCodesArray[t][cb] else 0 }
            }

            // Release DSP if decoder needs HTP (not GPU or CPU)
            val decoderNeedsHtp = !decoderOnCpu && !decoderOnGpu
            if (decoderNeedsHtp && (useHexagonTalker || useHexagonCp)) {
                hexStopRunner()
            }

            // Step 3: VQ → decoder → audio (chunked)
            val decodeStartMs = System.currentTimeMillis()
            val audio = decodeChunked(allCodebooks, numRealTokens)
            nlog("Decode (chunked): ${System.currentTimeMillis() - decodeStartMs}ms")

            val totalMs = System.currentTimeMillis() - startMs
            val audioDur = audio.size.toFloat() / SR
            nlog("Total: ${totalMs}ms for ${audioDur}s audio (RTF ${totalMs / 1000f / audioDur})")
            audio
        }
    } catch (e: Exception) {
        nlog("ERROR: ${e.message}")
        e.printStackTrace()
        ShortArray(0)
    }
}
// ==================== Tokenization ====================
/**
 * Minimal tokenizer for the reduced vocabulary.
 *
 * A small fixed table maps known strings to reduced-vocab IDs. The whole input is
 * looked up first; otherwise the text is scanned left to right with a greedy
 * longest-match (at most 20 characters per candidate). Characters that match nothing
 * are logged and skipped. If no token at all is produced, the ID for "Bonjour" is
 * returned so the caller always gets a non-empty sequence.
 *
 * For production, this should be replaced with proper BPE tokenization.
 */
private fun tokenizeText(text: String): IntArray {
    // Known word → reduced vocab ID mappings (from Qwen3 tokenizer + token_mapping).
    // Values are either a single Int or an IntArray for multi-token entries.
    val vocab: Map<String, Any> = mapOf(
        "Bonjour" to 1043,
        "bonjour" to 1028,
        " bonjour" to intArrayOf(220, 1028), // space + bonjour
        "Oui" to 1006,
        "Non" to 966,
        "Merci" to 1035,
        "salut" to 1023,
        "." to 13,
        "," to 11,
        "!" to 0,
        "?" to 30,
        " " to 220,
    )

    // Whole-string hit (single-ID entries only)
    (vocab[text] as? Int)?.let { return intArrayOf(it) }

    val collected = ArrayList<Int>()
    var cursor = 0
    while (cursor < text.length) {
        val longest = minOf(text.length - cursor, 20)
        var consumed = 0
        // Greedy: try the longest candidate first
        for (len in longest downTo 1) {
            val hit = vocab[text.substring(cursor, cursor + len)] ?: continue
            when (hit) {
                is Int -> collected.add(hit)
                is IntArray -> hit.forEach { collected.add(it) }
            }
            consumed = len
            break
        }
        if (consumed == 0) {
            // Nothing matched at this position: drop one character
            nlog("WARN: Unknown token for '${text[cursor]}'")
            consumed = 1
        }
        cursor += consumed
    }

    if (collected.isEmpty()) {
        nlog("WARN: No tokens from text '$text', using Bonjour fallback")
        return intArrayOf(1043) // "Bonjour"
    }
    return collected.toIntArray()
}
// ==================== Talker KV-Cache Generation ====================
// ==================== Embedding Helpers ====================
/**
 * Returns a fresh copy of one TALKER_DIM-wide row of the flattened `textEmbeds` table.
 * The row index is [reducedId] clamped to 0..1049.
 */
private fun textEmb(reducedId: Int): FloatArray {
    val rowStart = reducedId.coerceIn(0, 1049) * TALKER_DIM
    return FloatArray(TALKER_DIM).also { row ->
        System.arraycopy(textEmbeds!!, rowStart, row, 0, TALKER_DIM)
    }
}
/**
 * Returns a fresh copy of one TALKER_DIM-wide row of the flattened `codecEmbedding` table.
 * The row index is [codecIdx] clamped to 0 until TALKER_VOCAB.
 */
private fun codecEmb(codecIdx: Int): FloatArray {
    val rowStart = codecIdx.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM
    return FloatArray(TALKER_DIM).also { row ->
        System.arraycopy(codecEmbedding!!, rowStart, row, 0, TALKER_DIM)
    }
}
/**
 * Returns a fresh copy of one TALKER_DIM-wide row of the code-predictor embedding table.
 *
 * cpEmbeddings is [15, 2048, 1024] flattened: [codebookIdx] selects the table and
 * [tokenIdx] (clamped to 0 until CODEBOOK_SIZE) selects the row inside it.
 * [codebookIdx] itself is not range-checked here.
 */
private fun cpEmb(codebookIdx: Int, tokenIdx: Int): FloatArray {
    val row = codebookIdx * CODEBOOK_SIZE + tokenIdx.coerceIn(0, CODEBOOK_SIZE - 1)
    val rowStart = row * TALKER_DIM
    return FloatArray(TALKER_DIM).also { out ->
        System.arraycopy(cpEmbeddings!!, rowStart, out, 0, TALKER_DIM)
    }
}
/** Element-wise sum of the first TALKER_DIM entries of [a] and [b], returned as a new array. */
private fun sumEmb(a: FloatArray, b: FloatArray): FloatArray =
    FloatArray(TALKER_DIM) { i -> a[i] + b[i] }
/** Adds the first TALKER_DIM entries of [src] into [dst] in place. */
private fun addEmb(dst: FloatArray, src: FloatArray) {
    repeat(TALKER_DIM) { i -> dst[i] += src[i] }
}
/**
 * Assembles the talker prefill sequence, optionally including the speaker embedding
 * used for voice cloning.
 *
 * Layout of the returned list, in order:
 *  1. three role tokens (im_start, assistant, newline) as plain text embeddings;
 *  2. four codec control tokens (think, think_bos, lang_fr, think_eos), each added to tts_pad;
 *  3. the speaker embedding added to tts_pad, only when one is set;
 *  4. codec_pad added to tts_bos;
 *  5. the first text embedding added to codec_bos, only when [textEmbedsList] is non-empty.
 *
 * @param textEmbedsList pre-computed text embeddings (one per text token)
 * @return the prefill embeddings, or an empty list when tts_pad / tts_bos are not loaded
 */
private fun buildPrefillEmbeddings(textEmbedsList: List<FloatArray>): List<FloatArray> {
    val padE = ttsPadEmbed ?: return emptyList()
    val bosE = ttsBosEmbed ?: return emptyList()
    val spkE = speakerEmbed

    val seq = ArrayList<FloatArray>()

    // 1. Role: <|im_start|>assistant\n
    for (roleToken in intArrayOf(IM_START, TOKEN_ASSISTANT, TOKEN_NEWLINE)) {
        seq += textEmb(roleToken)
    }

    // 2. Codec control: [think, think_bos, lang_fr, think_eos] + tts_pad
    intArrayOf(CODEC_THINK, CODEC_THINK_BOS, CODEC_LANG_FR, CODEC_THINK_EOS).forEach { ctrl ->
        seq += sumEmb(padE, codecEmb(ctrl))
    }

    // 3. Speaker embedding (x-vector)
    spkE?.let { seq += sumEmb(padE, it) }

    // 4. codec_pad + tts_bos
    seq += sumEmb(bosE, codecEmb(CODEC_PAD))

    // 5. First text token + codec_bos
    textEmbedsList.firstOrNull()?.let { first ->
        seq += sumEmb(first, codecEmb(CODEC_BOS))
    }

    val spkNote = if (spkE != null) "1 spk + " else ""
    nlog("Prefill: ${seq.size} tokens (3 role + 4 ctrl + ${spkNote}1 bos + 1 text)")
    return seq
}
// ==================== Hexagon NPU Talker ====================
// Directory holding the llama-tts-* runner binaries.
private val HEX_DIR: String = "/data/local/tmp/kazeia/llama-hex"

// File paths not referenced in the visible part of this class.
// NOTE(review): presumably left over from a file-based exchange with the runner — confirm before removing.
private val HEX_INPUT: String = "/data/local/tmp/kazeia/tts_input.bin"
private val HEX_OUTPUT: String = "/data/local/tmp/kazeia/tts_logits.bin"
private val HEX_CONTROL: String = "/data/local/tmp/kazeia/tts_control.txt"

// Filesystem-namespace socket paths for the talker and code-predictor (CP) runners.
private val TALKER_SOCK: String = "/data/local/tmp/kazeia/talker.sock"
private val CP_SOCK: String = "/data/local/tmp/kazeia/cp.sock"
// NOTE(review): the ExecuTorch CP runner is reached over TCP in the code visible here; verify this path is still used.
private val CP_ET_SOCK: String = "/data/local/tmp/kazeia/cp_et.sock"

// Live connections to the runners; null while disconnected.
private var talkerSocket: android.net.LocalSocket? = null
private var cpSocket: android.net.LocalSocket? = null

// True once the CP runner has been started and connected.
private var useHexagonCp: Boolean = false
/**
 * Runs [cmd] through `su -c` and waits for it to finish.
 *
 * @return true when the process could be started and waited on, false when doing so
 *         threw (e.g. no `su` binary). The command's exit code is NOT inspected.
 */
private fun suExec(cmd: String): Boolean =
    try {
        val proc = Runtime.getRuntime().exec(arrayOf("su", "-c", cmd))
        proc.waitFor()
        true
    } catch (e: Exception) {
        false
    }
/**
 * Reads a root-owned file by running `su -c "cat <path>"`.
 *
 * [path] is interpolated into the shell command as-is, so it must come from trusted
 * code (no quoting or escaping is applied).
 *
 * @return the command's standard output with surrounding whitespace trimmed
 * @throws java.io.IOException if the process cannot be started or its output cannot be read
 */
private fun suReadFile(path: String): String {
    val p = Runtime.getRuntime().exec(arrayOf("su", "-c", "cat $path"))
    // `use` closes the stdout pipe even when reading fails, so the descriptor is
    // released immediately instead of whenever the Process object gets finalized.
    val text = p.inputStream.bufferedReader().use { it.readText() }.trim()
    p.waitFor()
    return text
}
/**
 * Starts the hexagon talker runner (as root) and connects to it via a local socket.
 *
 * If flag force_cpu_talker_gguf exists, runs with -ngl 0 (CPU fp32) to isolate
 * whether the voice-cloning tremor comes from NPU fp16 drift. Diagnostic only;
 * expected RTF 5-7x on CPU.
 *
 * Any previously running runner is killed first. The connection is retried every
 * 100 ms for up to 30 s (model loading takes ~15 s).
 *
 * @return true once [talkerSocket] is connected; false if `su` is unavailable or the
 *         runner did not accept a connection in time
 */
private fun hexStartRunner(): Boolean {
    val forceCpuGguf = diagFlag("force_cpu_talker_gguf")
    val nglArg = if (forceCpuGguf) "-ngl 0" else "-ngl 99 -mg 1"
    nlog("Starting ${if (forceCpuGguf) "CPU GGUF" else "Hexagon"} talker runner ($nglArg)...")
    val t0 = System.currentTimeMillis()
    if (!suExec("pkill -f llama-tts-talker")) { nlog("su not available"); return false }
    Thread.sleep(200)
    if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-talker -m /data/local/tmp/kazeia/models/talker_f16.gguf $nglArg -s $TALKER_SOCK > /data/local/tmp/kazeia/tts_runner.log 2>&1 &")) return false
    // Wait for socket to be connectable (30s max — model loading takes ~15s)
    repeat(300) {
        Thread.sleep(100)
        val sock = android.net.LocalSocket()
        try {
            sock.connect(android.net.LocalSocketAddress(TALKER_SOCK, android.net.LocalSocketAddress.Namespace.FILESYSTEM))
            talkerSocket = sock
            nlog("Hexagon talker connected: ${System.currentTimeMillis() - t0}ms")
            return true
        } catch (_: Exception) {
            // Not up yet. Close the failed socket so up to 300 retries do not each
            // leave a descriptor behind until finalization.
            try { sock.close() } catch (_: Exception) {}
        }
    }
    nlog("Hexagon talker timeout (30s)")
    return false
}
/**
 * Talker forward via socket: for each embedding, sends "FWRD" followed by the
 * embedding as little-endian floats, then reads back hidden state + logits.
 *
 * @param embeddings talker inputs, processed strictly in order (one round-trip each)
 * @return one (hidden, logits) pair per embedding, or an empty list when no talker
 *         socket is connected
 * @throws java.io.IOException if the runner closes the connection before a full
 *         response was received, or on any other socket failure. Previously a short
 *         read was silently turned into zero-filled hidden/logits.
 */
private fun hexForward(embeddings: List<FloatArray>): List<Pair<FloatArray, FloatArray>> {
    val sock = talkerSocket ?: return emptyList()
    val os = sock.outputStream
    val ins = sock.inputStream

    // Loop-invariant buffers, reused across round-trips.
    val command = "FWRD".toByteArray()
    val request = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
    // Response: hidden(1024*4) + logits(3072*4) = 16384 bytes
    val respSize = (TALKER_DIM + TALKER_VOCAB) * 4
    val resp = ByteArray(respSize)

    val results = ArrayList<Pair<FloatArray, FloatArray>>(embeddings.size)
    for (emb in embeddings) {
        // Send "FWRD" + embedding
        request.clear()
        for (v in emb) request.putFloat(v)
        os.write(command)
        os.write(request.array())
        os.flush()

        var read = 0
        while (read < respSize) {
            val n = ins.read(resp, read, respSize - read)
            if (n <= 0) {
                throw java.io.IOException("Talker socket closed mid-response: $read/$respSize bytes")
            }
            read += n
        }

        // The floats are copied out into fresh arrays, so `resp` can be reused.
        val fb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN).asFloatBuffer()
        val hidden = FloatArray(TALKER_DIM).also { fb.get(it) }
        val logits = FloatArray(TALKER_VOCAB).also { fb.get(it) }
        results.add(Pair(hidden, logits))
    }
    return results
}
/**
 * Asks the talker runner to reset its KV cache before a new phrase.
 *
 * Sends "REST" and then performs one blocking read of up to 4 bytes for the
 * acknowledgement ("OK\0\0"). The reply content is not checked. No-op when no talker
 * socket is connected.
 */
private fun hexReset() {
    val sock = talkerSocket
    if (sock != null) {
        val toRunner = sock.outputStream
        toRunner.write("REST".toByteArray())
        toRunner.flush()
        // blocking, waits for "OK\0\0"
        sock.inputStream.read(ByteArray(4))
    }
}
/**
 * Connects to the CP ExecuTorch runner on TCP 127.0.0.1:8790, starting it (as root)
 * if nothing is listening yet.
 *
 * Steps: probe an already-running runner (2 s connect timeout); otherwise kill any
 * stale `cp_et_runner`, launch a new one and retry the connection every 100 ms for
 * up to 30 s.
 *
 * @return true once `cpEtSocket` holds a connected socket, false on launch failure or timeout
 */
private fun hexStartCpEtRunner(): Boolean {
    nlog("Connecting CP ExecuTorch runner (TCP:8790)...")
    val t0 = System.currentTimeMillis()

    // Try connecting to existing runner first
    val probe = java.net.Socket()
    try {
        probe.connect(java.net.InetSocketAddress("127.0.0.1", 8790), 2000)
        probe.tcpNoDelay = true
        cpEtSocket = probe
        nlog("CP ET connected (TCP, existing): ${System.currentTimeMillis() - t0}ms")
        return true
    } catch (e: Exception) {
        // The unconnected probe socket is closed here; it used to be dropped unclosed.
        try { probe.close() } catch (_: Exception) {}
        nlog("CP ET existing connection failed: ${e.message}")
    }

    // Start new runner
    nlog("No running CP ET, starting new...")
    suExec("pkill -f cp_et_runner")
    Thread.sleep(200)
    if (!suExec("cd /data/local/tmp/kazeia && LD_LIBRARY_PATH=qnn_libs:. ADSP_LIBRARY_PATH=qnn_libs " +
                "nohup ./cp_et_runner --model_path=models/cp_transformer_fp16.pte --tcp_port=8790 " +
                "> /data/local/tmp/kazeia/cp_et.log 2>&1 &")) return false

    repeat(300) {
        Thread.sleep(100)
        var sock: java.net.Socket? = null
        try {
            sock = java.net.Socket("127.0.0.1", 8790)
            sock.tcpNoDelay = true
            cpEtSocket = sock
            nlog("CP ET connected (TCP): ${System.currentTimeMillis() - t0}ms")
            return true
        } catch (_: Exception) {
            // Runner not listening yet, or the connected socket could not be
            // configured: close whatever was opened and retry.
            try { sock?.close() } catch (_: Exception) {}
        }
    }
    nlog("CP ET timeout (30s)")
    return false
}
/**
 * CP via ExecuTorch NPU TCP socket: send hidden + cb0 embedding, receive 15 codes + timing.
 *
 * Request: 2 * TALKER_DIM little-endian floats ([pastHidden], then the codec embedding
 * row of [cb0]). Response: 64 bytes, of which the first 15 little-endian ints are the
 * codes for codebooks 1..15.
 *
 * Falls back to `runCpV2` when the socket is missing/closed, the response is short,
 * or any I/O error occurs. On an I/O error the socket is closed (it used to be dropped
 * unclosed) and the ExecuTorch CP path is disabled.
 *
 * Codes outside 0 until CODEBOOK_SIZE are mapped to 0, matching the defensive clamp
 * that hexCpForward applies to guard the decoder's codebook lookup.
 *
 * @return 15 codes, one per codebook 1..15
 */
private fun etCpForward(pastHidden: FloatArray, cb0: Int): IntArray {
    val sock = cpEtSocket
    if (sock == null || sock.isClosed) { nlog("CP ET socket null"); return runCpV2(pastHidden, cb0) }
    try {
        val os = sock.getOutputStream()
        val ins = sock.getInputStream()

        val buf = java.nio.ByteBuffer.allocate(2 * TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        for (v in pastHidden) buf.putFloat(v)
        val cb0Emb = FloatArray(TALKER_DIM)
        System.arraycopy(codecEmbedding!!, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, cb0Emb, 0, TALKER_DIM)
        for (v in cb0Emb) buf.putFloat(v)
        os.write(buf.array())
        os.flush()

        val resp = ByteArray(64)
        var read = 0
        while (read < 64) {
            val n = ins.read(resp, read, 64 - read)
            if (n <= 0) break
            read += n
        }
        if (read < 64) { nlog("CP ET read incomplete: $read/64"); return runCpV2(pastHidden, cb0) }

        val rb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        val codes = IntArray(15) { rb.int }
        var badCount = 0
        for (i in codes.indices) {
            if (codes[i] < 0 || codes[i] >= CODEBOOK_SIZE) { codes[i] = 0; badCount++ }
        }
        if (badCount > 0) nlog("CP ET: clamped $badCount out-of-range codes to 0")
        return codes
    } catch (e: Exception) {
        nlog("CP ET error: ${e.message}")
        try { sock.close() } catch (_: Exception) {}
        cpEtSocket = null
        useEtCp = false
        return runCpV2(pastHidden, cb0)
    }
}
/**
 * Starts the CP runner (as root) and connects to it via a local socket.
 *
 * NOTE: uses -ngl 0 (CPU only). Running the CP on Hexagon HTP produces deterministic
 * NaN at position 2 of every request — bug in libggml-htp-v79.so for this specific
 * 5-layer qwen3 graph (confirmed empirically). CPU path is ~60 ms/step and correct.
 * The talker still benefits from Hexagon HMX since it runs in a separate process.
 *
 * Any previously running CP runner is killed first; the connection is retried every
 * 100 ms for up to 30 s.
 *
 * @return true once [cpSocket] is connected; false if `su` is unavailable or the
 *         runner did not accept a connection in time
 */
private fun hexStartCpRunner(): Boolean {
    nlog("Starting CP runner (CPU, -ngl 0 to avoid HMX NaN bug)...")
    val t0 = System.currentTimeMillis()
    if (!suExec("pkill -f llama-tts-cp")) { nlog("su not available for CP"); return false }
    Thread.sleep(200)
    if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-cp -m /data/local/tmp/kazeia/models/cp_f16.gguf -ngl 0 -s $CP_SOCK > /data/local/tmp/kazeia/cp_runner.log 2>&1 &")) return false
    repeat(300) {
        Thread.sleep(100)
        val sock = android.net.LocalSocket()
        try {
            sock.connect(android.net.LocalSocketAddress(CP_SOCK, android.net.LocalSocketAddress.Namespace.FILESYSTEM))
            cpSocket = sock
            nlog("CP Hexagon connected: ${System.currentTimeMillis() - t0}ms")
            return true
        } catch (_: Exception) {
            // Not up yet. Close the failed socket so retries do not pile up
            // unclosed descriptors.
            try { sock.close() } catch (_: Exception) {}
        }
    }
    nlog("CP Hexagon timeout (30s)")
    return false
}
/**
 * CP via the runner socket: send hidden + cb0 embedding, get 15 codes.
 *
 * Request: 2 * TALKER_DIM little-endian floats ([pastHidden], then the codec embedding
 * row of [cb0]) = 8192 bytes. Response: 15 codes (60 bytes) + timing (4 bytes).
 *
 * Falls back to `runCpCpu` when no socket is connected, the response is short, or an
 * I/O error occurs. On an I/O error the socket is closed before being dropped, so its
 * descriptor is released immediately.
 *
 * @return 15 codes for codebooks 1..15, each within 0 until CODEBOOK_SIZE
 */
private fun hexCpForward(pastHidden: FloatArray, cb0: Int): IntArray {
    val sock = cpSocket
    if (sock == null) {
        nlog("CP socket NULL, falling back to CPU")
        return runCpCpu(pastHidden, cb0)
    }
    try {
        val os = sock.outputStream
        val ins = sock.inputStream
        val t0 = System.currentTimeMillis()

        // Send hidden(1024) + cb0_emb(1024) = 8192 bytes
        val buf = java.nio.ByteBuffer.allocate(2 * TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        for (v in pastHidden) buf.putFloat(v)
        val cb0Emb = FloatArray(TALKER_DIM)
        System.arraycopy(codecEmbedding!!, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, cb0Emb, 0, TALKER_DIM)
        for (v in cb0Emb) buf.putFloat(v)
        os.write(buf.array())
        os.flush()
        val writeMs = System.currentTimeMillis() - t0

        // Read 15 codes (60 bytes) + timing (4 bytes) = 64 bytes
        val resp = ByteArray(64)
        var read = 0
        while (read < 64) {
            val n = ins.read(resp, read, 64 - read)
            if (n <= 0) break
            read += n
        }
        val totalMs = System.currentTimeMillis() - t0
        if (cpCallCount <= 3) nlog("CP socket: write=${writeMs}ms, total=${totalMs}ms, read=$read/64")
        if (read < 64) {
            nlog("CP socket read incomplete: $read/64 bytes")
            return runCpCpu(pastHidden, cb0)
        }

        val rb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        val codes = IntArray(15) { rb.int }
        // Defensive clamp: CP on first call may occasionally return uninitialised
        // memory before warm-up stabilises. Out-of-range codes would crash BigVGAN
        // decoder (ArrayIndexOutOfBoundsException on codebook lookup). Map anything
        // invalid to 0 — a single-frame artefact is much better than an app crash.
        var badCount = 0
        for (i in codes.indices) {
            if (codes[i] < 0 || codes[i] >= CODEBOOK_SIZE) { codes[i] = 0; badCount++ }
        }
        if (badCount > 0) nlog("CP socket: clamped $badCount out-of-range codes to 0")
        return codes
    } catch (e: Exception) {
        nlog("CP socket error: ${e.message}, falling back to CPU")
        try { sock.close() } catch (_: Exception) {}
        cpSocket = null
        return runCpCpu(pastHidden, cb0)
    }
}
/**
 * Brings up the talker and CP runners if they are not flagged as running.
 *
 * A runner is only started when both its binary and its GGUF model exist on disk;
 * the matching `useHexagon*` flag is set to true only after a successful start and
 * is otherwise left untouched.
 */
private fun ensureHexagonRunners() {
    if (!useHexagonTalker) {
        val talkerFilesPresent =
            java.io.File("/data/local/tmp/kazeia/llama-hex/llama-tts-talker").exists() &&
                java.io.File("/data/local/tmp/kazeia/models/talker_f16.gguf").exists()
        if (talkerFilesPresent && hexStartRunner()) {
            useHexagonTalker = true
        }
    }

    if (!useHexagonCp) {
        val cpFilesPresent =
            java.io.File("/data/local/tmp/kazeia/llama-hex/llama-tts-cp").exists() &&
                java.io.File("/data/local/tmp/kazeia/models/cp_f16.gguf").exists()
        if (cpFilesPresent && hexStartCpRunner()) {
            useHexagonCp = true
        }
    }
}
/**
 * Stop all runners and release DSP for QNN decode. Runners restart on next generate().
 *
 * Every teardown step is best-effort and isolated from the others: a failure while
 * sending QUIT (e.g. the runner already died) no longer skips closing the remaining
 * sockets or killing the runner processes. Socket references and the `use*` flags are
 * always cleared at the end.
 */
private fun hexStopRunner() {
    // Runs one teardown step, ignoring any failure so later steps still execute.
    fun quietly(step: () -> Unit) {
        try { step() } catch (_: Exception) {}
    }

    quietly { talkerSocket?.outputStream?.write("QUIT".toByteArray()) }
    quietly { talkerSocket?.close() }
    quietly { cpSocket?.close() }
    quietly { cpEtSocket?.close() }
    // Give the talker a moment to act on QUIT before force-killing.
    quietly { Thread.sleep(300) }
    suExec("pkill -f llama-tts")
    suExec("pkill -f cp_et_runner")
    quietly { Thread.sleep(200) }

    talkerSocket = null; cpSocket = null; cpEtSocket = null
    useHexagonTalker = false; useHexagonCp = false; useEtCp = false
    nlog("Hexagon runners stopped, DSP released")
}
// NOTE: the Hexagon NPU talker + CPU CP variant of interleaved generation is runInterleavedHexagon(), defined further below.
/**
 * All-NPU pipeline: talker .pte + CP .pte via JNI, no root.
 *
 * Flow:
 *  1. Prefill: feed every prefill embedding through the talker; sample the first
 *     codebook-0 token from the logits of the last prefill step.
 *  2. Per generated frame: run the code predictor for codebooks 1..15, sum all 16
 *     codebook embeddings with the next text-side embedding (trailing text, then
 *     tts_eos once, then tts_pad), run one talker step and sample the next
 *     codebook-0 token (non-codec logits suppressed, 1.05 repetition penalty).
 *  3. Stop on EOS, on 10 identical codebook-0 tokens in a row, or after [maxGenTokens].
 *
 * Every talker input is additionally dumped to
 * /data/local/tmp/kazeia/captured_embeds.bin for reuse with the C++ pipeline.
 *
 * @param textEmbedsList one embedding per text token; the first goes into the prefill
 * @param maxGenTokens hard upper bound on generated frames
 * @return generated codes as [numFrames][16]; empty when required tensors are missing
 *         or the first sampled token is EOS
 */
private fun runInterleavedPte(textEmbedsList: List<FloatArray>, maxGenTokens: Int): Array<IntArray> {
    val talkerMod = checkNotNull(talkerPteModule) { "talker .pte module not loaded" }
    checkNotNull(cpPteModule) { "code predictor .pte module not loaded" }
    val tCos = checkNotNull(talkerPteRotaryCos) { "talker rotary cos table not loaded" }
    val tSin = checkNotNull(talkerPteRotarySin) { "talker rotary sin table not loaded" }
    // The CP rotary tables are not read here, but generation is refused without them.
    if (cpRotaryCos == null || cpRotarySin == null) return emptyArray()
    val eosE = ttsEosEmbed ?: return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()
    val prefill = buildPrefillEmbeddings(textEmbedsList)
    if (prefill.isEmpty()) return emptyArray()
    val trailingEmbeds = if (textEmbedsList.size > 1) textEmbedsList.subList(1, textEmbedsList.size) else emptyList()
    var trailingIdx = 0
    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()

    // Talker KV caches [TALKER_LAYERS × (k,v)] each [1, 8, TALKER_PTE_KV_LEN, 128]
    val tkvSize = TALKER_HEADS * TALKER_PTE_KV_LEN * TALKER_HEAD_DIM
    val tK = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    val tV = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    val maskData = FloatArray(TALKER_PTE_KV_LEN) { -1e9f }
    var pos = 0
    // Overwritten by the first talker step; prefill is known to be non-empty here.
    var pastHidden = FloatArray(0)
    var lastTalkerMs = 0L
    nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
    // Capture mode: save all talker inputs for reuse with C++ pipeline
    val capturedEmbeds = mutableListOf<FloatArray>()

    // Wraps a float array as an ExecuTorch tensor input of the given shape.
    fun floatInput(data: FloatArray, vararg shape: Long): org.pytorch.executorch.EValue =
        org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(data, shape))

    // One talker forward at the current position. Unmasks the slot for `pos`, feeds
    // the rotary slice (position clamped to the table length so a long run can never
    // read past it — prefill used to skip this clamp) plus the KV caches, then stores
    // the new hidden state / KV caches and advances `pos`. Returns the logits.
    fun talkerStep(embed: FloatArray): FloatArray {
        val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
        if (maskIdx >= 0) maskData[maskIdx] = 0f
        val cosSlice = FloatArray(TALKER_HEAD_DIM)
        System.arraycopy(tCos, minOf(pos, tCos.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
        val sinSlice = FloatArray(TALKER_HEAD_DIM)
        System.arraycopy(tSin, minOf(pos, tSin.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)

        val inputs = ArrayList<org.pytorch.executorch.EValue>(4 + 2 * TALKER_LAYERS)
        inputs.add(floatInput(embed, 1, 1, TALKER_DIM.toLong()))
        inputs.add(floatInput(maskData.clone(), 1, 1, 1, TALKER_PTE_KV_LEN.toLong()))
        inputs.add(floatInput(cosSlice, 1, 1, TALKER_HEAD_DIM.toLong()))
        inputs.add(floatInput(sinSlice, 1, 1, TALKER_HEAD_DIM.toLong()))
        for (i in 0 until TALKER_LAYERS) {
            inputs.add(floatInput(tK[i], 1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))
            inputs.add(floatInput(tV[i], 1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))
        }

        val tStart = System.currentTimeMillis()
        val out = talkerMod.forward(*inputs.toTypedArray())
        lastTalkerMs = System.currentTimeMillis() - tStart

        // Output layout: [0] hidden, [1] logits, then (k, v) per layer.
        pastHidden = out[0].toTensor().dataAsFloatArray
        val logits = out[1].toTensor().dataAsFloatArray
        for (i in 0 until TALKER_LAYERS) {
            tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
            tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
        }
        pos++
        return logits
    }

    // ===== PREFILL =====
    var currentCb0 = -1
    val tPrefill = System.currentTimeMillis()
    for (step in prefill.indices) {
        capturedEmbeds.add(prefill[step].clone()) // capture prefill input
        val logits = talkerStep(prefill[step])
        if (step == prefill.size - 1) {
            for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
            currentCb0 = sampleTopK(logits, 0.9f, 50)
        }
    }
    nlog("Prefill (PTE): ${System.currentTimeMillis() - tPrefill}ms, ${prefill.size} steps")
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // ===== INTERLEAVED GENERATION =====
    var totalTalkerMs = 0L; var totalCpMs = 0L
    for (genStep in 0 until maxGenTokens) {
        val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
        // 1. CP: predict CB1-15
        val tCp = System.currentTimeMillis()
        val cpCodes = runCpPte(pastHidden, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes); generatedCb0.add(currentCb0)
        if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // 2. Build next talker input
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
        val nextEmbed: FloatArray = when {
            trailingIdx < trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, trailingEmbeds[trailingIdx - 1]) }
            trailingIdx == trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, eosE) }
            else -> sumEmb(codecSum, padE)
        }
        capturedEmbeds.add(nextEmbed.clone()) // capture decode input

        // 3. Talker step
        val logits = talkerStep(nextEmbed)
        totalTalkerMs += lastTalkerMs

        // 4. Next CB0
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
        // Repetition penalty, applied once per distinct token seen so far.
        for (tok in generatedCb0.toSet()) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)
        if (nextCb0 == CODEC_EOS) { nlog("EOS at gen step ${genStep + 2}"); break }
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2}"); break
        }
        currentCb0 = nextCb0
    }
    val n = allCodes.size
    nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Save captured embeds for C++ pipeline reuse.
    // File layout: int32 prefill count, int32 total count, then TALKER_DIM floats per
    // embedding (all little-endian).
    if (capturedEmbeds.isNotEmpty()) {
        try {
            val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
            // `use` closes the stream even when a write fails part-way through.
            java.io.FileOutputStream(capPath).use { fos ->
                val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                hdr.putInt(prefill.size); hdr.putInt(capturedEmbeds.size)
                fos.write(hdr.array())
                val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                for (emb in capturedEmbeds) {
                    buf.clear()
                    for (v in emb) buf.putFloat(v)
                    fos.write(buf.array())
                }
            }
            nlog("Captured ${capturedEmbeds.size} embeds → $capPath")
        } catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
    }
    return allCodes.toTypedArray()
}
/**
 * Interleaved generation with the talker running in the socket-connected runner
 * and the code predictor dispatched through runCodePredictorInterleaved.
 *
 * Same scheme as the other backends: prefill, then per frame CP -> summed codebook
 * embeddings + text-side embedding -> talker step -> sample next codebook-0 token.
 *
 * Compared with earlier behaviour this now
 *  - returns an empty result instead of crashing when the talker returns no prefill
 *    results (no socket connected);
 *  - returns an empty result when the very first sampled token is EOS, like the
 *    other backends, instead of emitting the EOS id as a codebook-0 code;
 *  - stops generating instead of crashing when a talker step returns nothing.
 *
 * @param textEmbedsList one embedding per text token; the first goes into the prefill
 * @param maxGenTokens hard upper bound on generated frames
 * @return generated codes as [numFrames][16], possibly empty
 */
private fun runInterleavedHexagon(textEmbedsList: List<FloatArray>, maxGenTokens: Int): Array<IntArray> {
    val eosE = ttsEosEmbed ?: return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()
    val prefill = buildPrefillEmbeddings(textEmbedsList)
    if (prefill.isEmpty()) return emptyArray()
    val trailingEmbeds = if (textEmbedsList.size > 1) textEmbedsList.subList(1, textEmbedsList.size) else emptyList()
    var trailingIdx = 0
    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()

    // Reset KV cache for new phrase (runner already started at load time)
    hexReset()
    var totalTalkerMs = 0L; var totalCpMs = 0L

    // Prefill
    val tPrefill = System.currentTimeMillis()
    val prefillResults = hexForward(prefill)
    val prefillMs = System.currentTimeMillis() - tPrefill
    nlog("Prefill (Hexagon): ${prefillMs}ms, ${prefillResults.size} steps")
    if (prefillResults.isEmpty()) {
        nlog("Prefill (Hexagon) returned no results, aborting generation")
        return emptyArray()
    }
    var pastHidden = prefillResults.last().first
    val prefillLogits = prefillResults.last().second
    // Suppress non-codec tokens
    for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
    var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
    nlog("Prefill done: first cb0=$currentCb0")
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // Generation loop. Runners stay alive afterwards — hexStopRunner() is called
    // before QNN decode.
    for (genStep in 0 until maxGenTokens) {
        // Code predictor: codebooks 1..15 for this frame
        val tCp = System.currentTimeMillis()
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        val cpMs = System.currentTimeMillis() - tCp
        totalCpMs += cpMs
        if (genStep < 3) nlog("Gen step ${genStep + 1}: cb0=$currentCb0 cb1=${codes[1]} [CP=${cpMs}ms]")

        // Build next embedding: all 16 codebook embeddings + text side
        // (trailing text, then tts_eos once, then tts_pad).
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
        val nextEmbed: FloatArray = when {
            trailingIdx < trailingEmbeds.size -> sumEmb(codecSum, trailingEmbeds[trailingIdx++])
            trailingIdx == trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, eosE) }
            else -> sumEmb(codecSum, padE)
        }

        // Talker forward on Hexagon NPU
        val tTalker = System.currentTimeMillis()
        val results = hexForward(listOf(nextEmbed))
        val talkerMs = System.currentTimeMillis() - tTalker
        totalTalkerMs += talkerMs
        if (genStep < 3) nlog("  Talker(HEX)=${talkerMs}ms")
        if (results.isEmpty()) {
            nlog("Talker(HEX) returned no result at gen step ${genStep + 1}, stopping")
            break
        }
        pastHidden = results[0].first
        val logits = results[0].second

        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
        // Repetition penalty, applied once per distinct token seen so far.
        for (tok in generatedCb0.toSet()) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)
        if (nextCb0 == CODEC_EOS) { nlog("EOS at gen step ${genStep + 2}"); break }
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2}"); break
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker(HEX): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
    return allCodes.toTypedArray()
}
/**
 * Run interleaved talker + code predictor pipeline.
 *
 * Backend selection, in priority order:
 *  1. both .pte modules loaded -> runInterleavedPte;
 *  2. runner-based talker available -> runInterleavedHexagon;
 *  3. otherwise the ONNX Runtime KV-cache talker handled in this function.
 *
 * At each step:
 *   1. Talker forward → logits + hidden_state
 *   2. Code predictor (hidden, cb0_emb) → CB1-15 autoregressively
 *   3. Sum all 16 codebook embeddings + trailing text → next talker input
 *
 * Generation stops on EOS, on 10 identical codebook-0 tokens in a row, after
 * [maxGenTokens] frames, or when the attention mask (MAX_CONTEXT slots) is full.
 * The last case used to index the mask with a negative value.
 *
 * Returns all 16 codebooks as [numTokens][16]; empty when a required model or
 * embedding is missing or the first sampled token is EOS.
 */
private fun runInterleavedGeneration(textEmbedsList: List<FloatArray>, maxGenTokens: Int = 50): Array<IntArray> {
    // Priority 1: All .pte JNI on NPU (no root needed)
    if (talkerPteModule != null && cpPteModule != null) {
        return runInterleavedPte(textEmbedsList, maxGenTokens)
    }
    // Priority 2: Hexagon talker + socket/ONNX CP
    ensureHexagonRunners()
    if (useHexagonTalker) {
        return runInterleavedHexagon(textEmbedsList, maxGenTokens)
    }

    // Priority 3: ONNX Runtime talker
    val env = ortEnv ?: return emptyArray()
    // Read once into a local so no `!!` is needed further down.
    val session = talkerKv ?: return emptyArray()
    val eosE = ttsEosEmbed ?: return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()
    val prefill = buildPrefillEmbeddings(textEmbedsList)
    if (prefill.isEmpty()) return emptyArray()
    // Trailing text embeddings (all except first, which is in prefill)
    val trailingEmbeds = if (textEmbedsList.size > 1) textEmbedsList.subList(1, textEmbedsList.size) else emptyList()
    var trailingIdx = 0
    val allCodes = mutableListOf<IntArray>() // each entry is [16] codebooks for one time step
    val generatedCb0 = mutableListOf<Int>()

    val kCacheSize = TALKER_HEADS * TALKER_HEAD_DIM * KV_LEN
    val vCacheSize = TALKER_HEADS * KV_LEN * TALKER_HEAD_DIM
    var kCaches = Array(TALKER_LAYERS) { FloatArray(kCacheSize) }
    var vCaches = Array(TALKER_LAYERS) { FloatArray(vCacheSize) }
    // The mask is filled from the end: slot MAX_CONTEXT - 1 - pos is opened per position.
    val maskData = FloatArray(MAX_CONTEXT) { MASK_NEG }
    var pos = 0
    var currentCb0 = -1
    // Overwritten by the first prefill step; prefill is known to be non-empty here.
    var pastHidden = FloatArray(TALKER_DIM)

    // ===== PREFILL =====
    for (step in prefill.indices) {
        maskData[MAX_CONTEXT - 1 - step] = 0f
        val res = runTalkerStep(env, session, prefill[step], maskData, pos, kCaches, vCaches)
        kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden
        pos++
        if (step == prefill.size - 1) {
            // Apply suppression + sampling to get first codec_0
            val logits = res.logits
            for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
            currentCb0 = sampleTopK(logits, 0.9f, 50)
            nlog("Prefill done: first cb0=$currentCb0")
        }
    }
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // ===== INTERLEAVED GENERATION =====
    var totalTalkerMs = 0L; var totalCpMs = 0L
    for (genStep in 0 until maxGenTokens) {
        // 1. Run code predictor: (pastHidden, cb0_emb) → CB1-15
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0
        val tCp = System.currentTimeMillis()
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        val cpMs = System.currentTimeMillis() - tCp
        totalCpMs += cpMs
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        if (genStep < 3) {
            nlog("Gen step ${genStep + 1}: cb0=$currentCb0 cb1=${codes[1]} [CP=${cpMs}ms]")
        }

        // No mask slot left for another talker position: keep the frames generated so
        // far instead of failing on a negative mask index.
        if (pos >= MAX_CONTEXT) {
            nlog("Context full at gen step ${genStep + 1} (pos=$pos), stopping")
            break
        }

        // 2. Build next talker input: sum of ALL 16 codebook embeddings + trailing text
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) {
            addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
        }
        // Text side: trailing text tokens, then tts_eos, then tts_pad (NOT nothing!)
        // Python model expects tts_pad after text exhaustion - crucial for EOS convergence
        val nextEmbed: FloatArray = when {
            trailingIdx < trailingEmbeds.size -> sumEmb(codecSum, trailingEmbeds[trailingIdx++])
            trailingIdx == trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, eosE) }
            else -> sumEmb(codecSum, padE)
        }

        // 3. Run talker step
        maskData[MAX_CONTEXT - 1 - pos] = 0f
        val tTalker = System.currentTimeMillis()
        val res = runTalkerStep(env, session, nextEmbed, maskData, pos, kCaches, vCaches)
        val talkerMs = System.currentTimeMillis() - tTalker
        totalTalkerMs += talkerMs
        kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden
        pos++
        if (genStep < 3) {
            nlog("  Talker=${talkerMs}ms")
        }

        // 4. Get next cb0 from logits
        val logits = res.logits
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
        // HuggingFace repetition penalty: 1.05x once per unique token in history
        for (tok in generatedCb0.toSet()) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)
        if (nextCb0 == CODEC_EOS) {
            nlog("EOS at gen step ${genStep + 2}")
            break
        }
        // Safety: stop on extreme repetition (10+ identical = model degeneration)
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2} (token $nextCb0 × 10), stopping")
            break
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker: ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
    return allCodes.toTypedArray()
}
/**
 * Output of a single talker step.
 *
 * @property logits logits produced by the step
 * @property hidden hidden state produced by the step
 * @property newK updated key caches, one array per layer
 * @property newV updated value caches, one array per layer
 */
private data class TalkerStepResult(
    val logits: FloatArray,
    val hidden: FloatArray,
    val newK: Array<FloatArray>,
    val newV: Array<FloatArray>
)
/**
 * Runs one talker step on the ONNX Runtime session.
 *
 * Delegates to runTalkerStepMRoPE when the loaded model takes cos/sin inputs.
 * Otherwise feeds the legacy graph: `inputs_embeds`, `attention_mask`, per-layer
 * `k_i_in` / `v_i_in` and `position_ids` (int64 or int32 depending on the model).
 *
 * All input tensors and the session result are released in `finally` blocks, so
 * native memory is freed even when `session.run` or output extraction throws.
 *
 * @param inputEmbed talker input for this position (TALKER_DIM floats)
 * @param maskData attention mask over MAX_CONTEXT slots; copied before use
 * @param pos position index passed as `position_ids`
 * @param kCaches per-layer key caches fed as [1, heads, head_dim, KV_LEN]
 * @param vCaches per-layer value caches fed as [1, heads, KV_LEN, head_dim]
 * @return logits, hidden state and freshly allocated updated KV caches
 */
private fun runTalkerStep(
    env: OrtEnvironment, session: OrtSession,
    inputEmbed: FloatArray, maskData: FloatArray, pos: Int,
    kCaches: Array<FloatArray>, vCaches: Array<FloatArray>
): TalkerStepResult {
    if (talkerUsesCosSin) {
        return runTalkerStepMRoPE(env, session, inputEmbed, maskData, pos, kCaches, vCaches)
    }

    val inputs = LinkedHashMap<String, OnnxTensor>()
    try {
        inputs["inputs_embeds"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong()))
        inputs["attention_mask"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong()))
        for (i in 0 until TALKER_LAYERS) {
            inputs["k_${i}_in"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(kCaches[i]),
                longArrayOf(1, TALKER_HEADS.toLong(), TALKER_HEAD_DIM.toLong(), KV_LEN.toLong()))
            inputs["v_${i}_in"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(vCaches[i]),
                longArrayOf(1, TALKER_HEADS.toLong(), KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))
        }
        inputs["position_ids"] = if (talkerUsesInt64Pos) {
            OnnxTensor.createTensor(env, java.nio.LongBuffer.wrap(longArrayOf(pos.toLong())), longArrayOf(1))
        } else {
            OnnxTensor.createTensor(env, IntBuffer.wrap(intArrayOf(pos)), longArrayOf(1))
        }

        val result = session.run(inputs)
        try {
            // logits [1, 3072, 1, 1]
            val logits = FloatArray(TALKER_VOCAB)
            (result.get(0) as OnnxTensor).floatBuffer.get(logits)
            // hidden [1, 1, 1024] at output index 57
            val hidden = FloatArray(TALKER_DIM)
            (result.get(57) as OnnxTensor).floatBuffer.get(hidden)
            // KV caches: outputs 1.. alternate k, v per layer
            val kCacheSize = TALKER_HEADS * TALKER_HEAD_DIM * KV_LEN
            val vCacheSize = TALKER_HEADS * KV_LEN * TALKER_HEAD_DIM
            val newK = Array(TALKER_LAYERS) { FloatArray(kCacheSize) }
            val newV = Array(TALKER_LAYERS) { FloatArray(vCacheSize) }
            for (i in 0 until TALKER_LAYERS) {
                (result.get(1 + i * 2) as OnnxTensor).floatBuffer.get(newK[i])
                (result.get(2 + i * 2) as OnnxTensor).floatBuffer.get(newV[i])
            }
            return TalkerStepResult(logits, hidden, newK, newV)
        } finally {
            result.close()
        }
    } finally {
        for (tensor in inputs.values) tensor.close()
    }
}
/**
 * Runs one talker step through the M-RoPE ONNX graph (cos/sin inputs).
 *
 * Session inputs: `emb` [1, 1, TALKER_DIM], `mask` [1, 1, 1, MAX_CONTEXT], the `cos`/`sin`
 * rows for [pos] shaped [1, 1, TALKER_HEAD_DIM], and per layer `k_i_in`/`v_i_in` shaped
 * [1, TALKER_HEADS, KV_LEN, TALKER_HEAD_DIM] (not transposed like the legacy graph).
 * Session outputs read back: hidden at index 0, logits at index 1, then K/V per layer
 * starting at index 2, each holding MAX_CONTEXT positions per head.
 *
 * @param pos position used to pick the rotary row; clamped to the last row of the table.
 * @return logits, hidden state and the new caches with the oldest position of every head
 *         dropped. If the rotary tables are not loaded, zeroed logits/hidden are returned
 *         together with the caches that were passed in.
 */
private fun runTalkerStepMRoPE(
    env: OrtEnvironment, session: OrtSession,
    inputEmbed: FloatArray, maskData: FloatArray, pos: Int,
    kCaches: Array<FloatArray>, vCaches: Array<FloatArray>
): TalkerStepResult {
    val cos = rotaryCos ?: return TalkerStepResult(FloatArray(TALKER_VOCAB), FloatArray(TALKER_DIM), kCaches, vCaches)
    val sin = rotarySin ?: return TalkerStepResult(FloatArray(TALKER_VOCAB), FloatArray(TALKER_DIM), kCaches, vCaches)

    val inHeadStride = KV_LEN * TALKER_HEAD_DIM        // floats per head in the input caches
    val outHeadStride = MAX_CONTEXT * TALKER_HEAD_DIM  // floats per head in the output caches
    val kvSize = TALKER_HEADS * inHeadStride

    val inputs = LinkedHashMap<String, OnnxTensor>()
    var result: OrtSession.Result? = null
    try {
        inputs["emb"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong()))
        inputs["mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong()))

        // Clamp the index so we never read past the rotary table even if generation
        // runs longer than the table.
        val rotMax = (cos.size / TALKER_HEAD_DIM) - 1
        val rotPos = minOf(pos, rotMax)
        val cosSlice = FloatArray(TALKER_HEAD_DIM)
        val sinSlice = FloatArray(TALKER_HEAD_DIM)
        System.arraycopy(cos, rotPos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
        System.arraycopy(sin, rotPos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
        inputs["cos"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(cosSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))
        inputs["sin"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(sinSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))

        for (i in 0 until TALKER_LAYERS) {
            inputs["k_${i}_in"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(kCaches[i]),
                longArrayOf(1, TALKER_HEADS.toLong(), KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))
            inputs["v_${i}_in"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(vCaches[i]),
                longArrayOf(1, TALKER_HEADS.toLong(), KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))
        }

        val out = session.run(inputs)
        result = out

        // Copies `count` floats of output `index` into a fresh array.
        fun readFloats(index: Int, count: Int): FloatArray =
            FloatArray(count).also { (out.get(index) as OnnxTensor).floatBuffer.get(it) }

        // Reads a cache output and drops position 0 of every head, so the result has
        // KV_LEN positions per head again and can be fed back as the next input.
        fun readTrimmedCache(index: Int): FloatArray {
            val full = readFloats(index, TALKER_HEADS * outHeadStride)
            val trimmed = FloatArray(kvSize)
            for (h in 0 until TALKER_HEADS) {
                System.arraycopy(full, h * outHeadStride + TALKER_HEAD_DIM, trimmed, h * inHeadStride, inHeadStride)
            }
            return trimmed
        }

        val hidden = readFloats(0, TALKER_DIM)
        val logits = readFloats(1, TALKER_VOCAB)
        val newK = Array(TALKER_LAYERS) { i -> readTrimmedCache(2 + i * 2) }
        val newV = Array(TALKER_LAYERS) { i -> readTrimmedCache(3 + i * 2) }
        return TalkerStepResult(logits, hidden, newK, newV)
    } finally {
        // Release native memory even when tensor creation, run() or a read-back throws.
        for (tensor in inputs.values) tensor.close()
        result?.close()
    }
}
/** Number of code-predictor calls made so far; used to emit one-time diagnostics. */
private var cpCallCount = 0

/**
 * Run code predictor — JNI .pte > TCP runner > Hexagon > CPU ONNX.
 *
 * Picks one backend per call and returns whatever that backend returns for
 * ([pastHidden], [cb0]). The backend configuration is logged on the first call.
 */
private fun runCodePredictorInterleaved(pastHidden: FloatArray, cb0: Int): IntArray {
    val firstCall = cpCallCount == 0
    if (firstCall) {
        nlog("CP: pte=${cpPteModule != null}, talkerPte=${talkerPteModule != null}, et=$useEtCp, hex=$useHexagonCp, v2=$cpUsesCosSin")
    }
    cpCallCount += 1

    // Diagnostic override: a marker file forces the ONNX Runtime CP path (runCpV2).
    // ORT's MLAS kernels may use a different fp32 reduction order than ggml, potentially
    // closer to PyTorch; this is used to test whether that reduces code divergence
    // from Python.
    val forceMarker = File("/data/local/tmp/kazeia/force_cp_v2")
    if (forceMarker.exists() && cpUsesCosSin && cpKv != null) {
        if (firstCall) nlog("force_cp_v2: using ONNX Runtime CP (cpKv session)")
        return runCpV2(pastHidden, cb0)
    }

    return when {
        // 1. Hexagon talker running: MUST use hexCpForward (separate CPU process via
        //    socket). Running CP .pte on QNN HTP alongside Hexagon triggers err 6031
        //    (documented DSP contention).
        useHexagonTalker && useHexagonCp -> hexCpForward(pastHidden, cb0)
        // 2. Talker on .pte or CPU ONNX: CP .pte JNI in the same QNN context.
        cpPteModule != null && (talkerPteModule != null || talkerKv != null) -> runCpPte(pastHidden, cb0)
        // 3. Legacy fallbacks, in priority order.
        useEtCp -> etCpForward(pastHidden, cb0)
        useHexagonCp -> hexCpForward(pastHidden, cb0)
        cpUsesCosSin && cpKv != null -> runCpV2(pastHidden, cb0)
        else -> runCpCpu(pastHidden, cb0)
    }
}
/**
 * CP via ExecuTorch .pte on NPU (QNN fp16).
 *
 * Runs 17 single-token steps. Step 0 feeds [pastHidden], step 1 feeds the talker codec
 * embedding of [cb0], and steps 2+ feed the CP embedding of the previously predicted code.
 * The code for codebook `step - 1` is the argmax of that step's hidden state against the
 * pre-loaded per-codebook head matrix.
 *
 * .pte inputs: emb[1,1,TALKER_DIM], mask[1,1,1,CP_KV_LEN], cos/sin[1,1,CP_HEAD_DIM], then
 * k/v per layer shaped [1,CP_KV_HEADS,CP_KV_LEN,CP_HEAD_DIM].
 * .pte outputs as consumed here: hidden at index 0, then k/v per layer from index 1.
 * NOTE(review): an earlier header listed a head_logits output at index 1 and 17-long KV
 * outputs; the code below does not read such an output — confirm against the exported graph.
 *
 * Falls back to [runCpV2] when the module or rotary tables are missing, or when any step
 * throws; in the latter case `cpPteModule` is cleared so later calls skip this path.
 *
 * @return 15 codes (codebooks 1..15); entries stay 0 where no usable head matrix exists.
 */
private fun runCpPte(pastHidden: FloatArray, cb0: Int): IntArray {
    val module = cpPteModule ?: return runCpV2(pastHidden, cb0)
    val cos = cpRotaryCos ?: return runCpV2(pastHidden, cb0)
    val sin = cpRotarySin ?: return runCpV2(pastHidden, cb0)
    val codes = IntArray(15)

    // Wraps a float array as an ExecuTorch input value with the given shape.
    fun evalue(data: FloatArray, vararg shape: Long): org.pytorch.executorch.EValue =
        org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(data, shape))

    // Fixed-size KV caches; only the elements are replaced, never the arrays of arrays.
    val kvSize = CP_KV_HEADS * CP_KV_LEN * CP_HEAD_DIM
    val kCaches = Array(CP_LAYERS) { FloatArray(kvSize) }
    val vCaches = Array(CP_LAYERS) { FloatArray(kvSize) }
    val headFloats = CODEBOOK_SIZE * TALKER_DIM
    var emb = pastHidden
    try {
        for (step in 0 until 17) {
            if (step == 1) {
                emb = FloatArray(TALKER_DIM)
                System.arraycopy(codecEmbedding!!, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, emb, 0, TALKER_DIM)
            } else if (step >= 2) {
                emb = FloatArray(TALKER_DIM)
                val cembs = cpEmbeddings ?: return codes
                val off = ((step - 2) * CODEBOOK_SIZE + codes[step - 2].coerceIn(0, CODEBOOK_SIZE - 1)) * TALKER_DIM
                System.arraycopy(cembs, off, emb, 0, TALKER_DIM)
            }

            // Mask: the last (step + 1) positions are active, capped at the cache length.
            val mask = FloatArray(CP_KV_LEN) { -1e9f }
            for (p in 0..minOf(step, CP_KV_LEN - 1)) mask[CP_KV_LEN - 1 - p] = 0f

            val cosSlice = FloatArray(CP_HEAD_DIM)
            System.arraycopy(cos, step * CP_HEAD_DIM, cosSlice, 0, CP_HEAD_DIM)
            val sinSlice = FloatArray(CP_HEAD_DIM)
            System.arraycopy(sin, step * CP_HEAD_DIM, sinSlice, 0, CP_HEAD_DIM)

            val allInputs = ArrayList<org.pytorch.executorch.EValue>(4 + 2 * CP_LAYERS)
            allInputs.add(evalue(emb, 1, 1, TALKER_DIM.toLong()))
            allInputs.add(evalue(mask, 1, 1, 1, CP_KV_LEN.toLong()))
            allInputs.add(evalue(cosSlice, 1, 1, CP_HEAD_DIM.toLong()))
            allInputs.add(evalue(sinSlice, 1, 1, CP_HEAD_DIM.toLong()))
            for (i in 0 until CP_LAYERS) {
                allInputs.add(evalue(kCaches[i], 1, CP_KV_HEADS.toLong(), CP_KV_LEN.toLong(), CP_HEAD_DIM.toLong()))
                allInputs.add(evalue(vCaches[i], 1, CP_KV_HEADS.toLong(), CP_KV_LEN.toLong(), CP_HEAD_DIM.toLong()))
            }

            val outputs = module.forward(*allInputs.toTypedArray())
            val hiddenOut = outputs[0].toTensor().dataAsFloatArray

            // Head argmax using NEON SIMD (individual heads, loaded once on first use).
            if (step >= 1 && step - 1 < 15) {
                if (cpHeadsCache == null) {
                    cpHeadsCache = arrayOfNulls(15)
                    val hp = cpHeadsPath
                    if (hp != null) {
                        for (h in 0 until 15) {
                            cpHeadsCache!![h] = loadNpy(hp.replace("cp_heads.npy", "head_${h}.npy"))
                        }
                        nlog("CP heads pre-loaded: 15 × ${cpHeadsCache!![0]?.size} floats")
                    }
                }
                val cbIdx = step - 1
                val headData = cpHeadsCache?.get(cbIdx)
                // loadNpy yields an empty array for a missing file; never hand a matrix
                // shorter than CODEBOOK_SIZE x TALKER_DIM to the native routine.
                if (headData != null && headData.size >= headFloats) {
                    codes[cbIdx] = NeonOps.headArgmax(hiddenOut, headData, CODEBOOK_SIZE, TALKER_DIM)
                }
            }

            // The outputs are fed back as the next step's caches unchanged.
            for (i in 0 until CP_LAYERS) {
                kCaches[i] = outputs[1 + i * 2].toTensor().dataAsFloatArray
                vCaches[i] = outputs[2 + i * 2].toTensor().dataAsFloatArray
            }
        }
    } catch (e: Exception) {
        nlog("CP .pte error: ${e.message}, falling back to ONNX CPU")
        cpPteModule = null
        return runCpV2(pastHidden, cb0)
    }
    return codes
}
/**
 * CP V2: ONNX single-step with KV-cache, 15 lm_heads, autoregressive.
 *
 * Runs 17 single-token steps on the `cpKv` session. Step 0 feeds [pastHidden], step 1 the
 * talker codec embedding of [cb0], later steps the CP embedding of the previous code.
 * For steps 1..15 the code is `argmax(hidden · head[cb]ᵀ)` computed on the CPU, with each
 * head matrix lazily loaded from disk and cached.
 *
 * With `cpFixedKv` the caches always hold CP_KV_LEN positions and a right-aligned mask
 * marks the valid ones; otherwise the caches start empty and grow by one per step.
 *
 * @return 15 codes (codebooks 1..15), or all zeros if a required resource is not loaded.
 */
private fun runCpV2(pastHidden: FloatArray, cb0: Int): IntArray {
    val env = ortEnv ?: return IntArray(15)
    val session = cpKv ?: return IntArray(15)
    val cos = cpRotaryCos ?: return IntArray(15)
    val sin = cpRotarySin ?: return IntArray(15)
    val headPath = cpHeadsPath ?: return IntArray(15)
    val cembs = cpEmbeddings ?: return IntArray(15)
    val codes = IntArray(15)

    val kvSize = if (cpFixedKv) CP_KV_HEADS * CP_KV_LEN * CP_HEAD_DIM else 0
    var kCaches = Array(CP_LAYERS) { FloatArray(kvSize) }
    var vCaches = Array(CP_LAYERS) { FloatArray(kvSize) }

    // Step 0: hidden state
    var emb = pastHidden
    for (step in 0 until 17) {
        if (step == 1) {
            // cb0 embedding from talker codec_embedding
            emb = FloatArray(TALKER_DIM)
            System.arraycopy(codecEmbedding!!, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, emb, 0, TALKER_DIM)
        } else if (step >= 2) {
            // codec_embedding from CP's own tables
            emb = FloatArray(TALKER_DIM)
            val off = ((step - 2) * CODEBOOK_SIZE + codes[step - 2].coerceIn(0, CODEBOOK_SIZE - 1)) * TALKER_DIM
            System.arraycopy(cembs, off, emb, 0, TALKER_DIM)
        }

        val kvLen = if (cpFixedKv) CP_KV_LEN else kCaches[0].size / (CP_KV_HEADS * CP_HEAD_DIM)
        val cosSlice = FloatArray(CP_HEAD_DIM)
        System.arraycopy(cos, step * CP_HEAD_DIM, cosSlice, 0, CP_HEAD_DIM)
        val sinSlice = FloatArray(CP_HEAD_DIM)
        System.arraycopy(sin, step * CP_HEAD_DIM, sinSlice, 0, CP_HEAD_DIM)

        val inputs = LinkedHashMap<String, OnnxTensor>()
        var result: OrtSession.Result? = null
        try {
            inputs["emb"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(emb), longArrayOf(1, 1, TALKER_DIM.toLong()))
            if (cpFixedKv) {
                // Fixed KV mask: after step N, the last (N+1) positions are valid.
                // The model shifts left internally, so valid positions are right-aligned.
                val mask = FloatArray(CP_KV_LEN) { -1e9f }
                for (p in 0..step) mask[CP_KV_LEN - 1 - p] = 0f
                inputs["mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(mask), longArrayOf(1, 1, 1, CP_KV_LEN.toLong()))
            }
            inputs["cos"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(cosSlice), longArrayOf(1, CP_HEAD_DIM.toLong()))
            inputs["sin"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(sinSlice), longArrayOf(1, CP_HEAD_DIM.toLong()))
            for (i in 0 until CP_LAYERS) {
                inputs["k_${i}_in"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(kCaches[i]),
                    longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong()))
                inputs["v_${i}_in"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(vCaches[i]),
                    longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong()))
            }

            val out = session.run(inputs)
            result = out
            val hidden = FloatArray(TALKER_DIM)
            (out.get(0) as OnnxTensor).floatBuffer.get(hidden)

            // Extract code from lm_head: hidden @ head[cb]^T → argmax
            if (step >= 1 && step - 1 < 15) {
                val cbIdx = step - 1
                // Lazy-load head into cache (loaded once per codebook)
                if (cpHeadsCache == null) cpHeadsCache = arrayOfNulls(15)
                val cache = cpHeadsCache!!
                val headData = cache[cbIdx]
                    ?: loadNpy(headPath.replace("cp_heads.npy", "head_${cbIdx}.npy")).also { cache[cbIdx] = it }
                var best = 0
                var bestVal = Float.NEGATIVE_INFINITY
                for (j in 0 until CODEBOOK_SIZE) {
                    var dot = 0f
                    val off = j * TALKER_DIM
                    for (k in 0 until TALKER_DIM) dot += hidden[k] * headData[off + k]
                    if (dot > bestVal) { bestVal = dot; best = j }
                }
                codes[cbIdx] = best
            }

            // Read back the caches: same length in fixed mode, one position longer otherwise.
            val newKvLen = if (cpFixedKv) CP_KV_LEN else kvLen + 1
            val newKvSize = CP_KV_HEADS * newKvLen * CP_HEAD_DIM
            kCaches = Array(CP_LAYERS) { i ->
                FloatArray(newKvSize).also { (out.get(1 + i * 2) as OnnxTensor).floatBuffer.get(it) }
            }
            vCaches = Array(CP_LAYERS) { i ->
                FloatArray(newKvSize).also { (out.get(2 + i * 2) as OnnxTensor).floatBuffer.get(it) }
            }
        } finally {
            // Release native memory even if this step fails part-way.
            for (tensor in inputs.values) tensor.close()
            result?.close()
        }
    }
    return codes
}
/**
 * CP via ONNX CPU KV-cache: 17 single-token steps with dynamic KV.
 *
 * Step 0 feeds [pastHidden], step 1 the talker codec embedding of [cb0], later steps the
 * CP embedding of the previous code. The graph returns head logits at output index 1
 * (read as 15 × CODEBOOK_SIZE floats); the code for codebook `step - 1` is the argmax of
 * that codebook's row. K/V outputs start at index 2 and grow by one position per step.
 *
 * @return 15 codes (codebooks 1..15), or all zeros if a required resource is not loaded.
 */
private fun runCpCpu(pastHidden: FloatArray, cb0: Int): IntArray {
    val env = ortEnv ?: return IntArray(15)
    val cpModel = cpKv ?: return IntArray(15)
    val cpEmbs = cpEmbeddings ?: return IntArray(15)
    val codes = IntArray(15)

    // Dynamic KV caches: start empty and gain one position per step.
    var kCaches = Array(CP_LAYERS) { FloatArray(0) }
    var vCaches = Array(CP_LAYERS) { FloatArray(0) }

    // Step 0: hidden state
    var emb = pastHidden
    for (step in 0 until 17) {
        if (step == 1) {
            emb = FloatArray(TALKER_DIM)
            System.arraycopy(codecEmbedding!!, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, emb, 0, TALKER_DIM)
        } else if (step >= 2) {
            emb = FloatArray(TALKER_DIM)
            val off = ((step - 2) * CODEBOOK_SIZE + codes[step - 2].coerceIn(0, CODEBOOK_SIZE - 1)) * TALKER_DIM
            System.arraycopy(cpEmbs, off, emb, 0, TALKER_DIM)
        }

        val totalLen = step + 1
        val inputs = LinkedHashMap<String, OnnxTensor>()
        var result: OrtSession.Result? = null
        try {
            inputs["input_embeds"] = OnnxTensor.createTensor(env,
                FloatBuffer.wrap(emb), longArrayOf(1, 1, TALKER_DIM.toLong()))
            // All-zero additive mask over every position seen so far.
            inputs["attention_mask"] = OnnxTensor.createTensor(env,
                FloatBuffer.wrap(FloatArray(totalLen)), longArrayOf(1, 1, 1, totalLen.toLong()))
            inputs["position_ids"] = OnnxTensor.createTensor(env,
                LongBuffer.wrap(longArrayOf(step.toLong())), longArrayOf(1, 1))
            for (i in 0 until CP_LAYERS) {
                val kvLen = kCaches[i].size / (CP_KV_HEADS * CP_HEAD_DIM)
                inputs["k_${i}_in"] = OnnxTensor.createTensor(env,
                    FloatBuffer.wrap(kCaches[i]), longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong()))
                inputs["v_${i}_in"] = OnnxTensor.createTensor(env,
                    FloatBuffer.wrap(vCaches[i]), longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong()))
            }

            val out = cpModel.run(inputs)
            result = out

            // Pick this step's codebook row out of the head logits and take its argmax.
            if (step >= 1 && step - 1 < 15) {
                val headLogits = FloatArray(15 * CODEBOOK_SIZE)
                (out.get(1) as OnnxTensor).floatBuffer.get(headLogits)
                val cbIdx = step - 1
                val headOff = cbIdx * CODEBOOK_SIZE
                var maxIdx = 0
                var maxVal = Float.NEGATIVE_INFINITY
                for (j in 0 until CODEBOOK_SIZE) {
                    if (headLogits[headOff + j] > maxVal) { maxVal = headLogits[headOff + j]; maxIdx = j }
                }
                codes[cbIdx] = maxIdx
            }

            // Read back the caches, now one position longer.
            val newKvSize = CP_KV_HEADS * totalLen * CP_HEAD_DIM
            kCaches = Array(CP_LAYERS) { i ->
                FloatArray(newKvSize).also { (out.get(2 + i * 2) as OnnxTensor).floatBuffer.get(it) }
            }
            vCaches = Array(CP_LAYERS) { i ->
                FloatArray(newKvSize).also { (out.get(3 + i * 2) as OnnxTensor).floatBuffer.get(it) }
            }
        } finally {
            // Release native memory even if this step fails part-way.
            for (tensor in inputs.values) tensor.close()
            result?.close()
        }
    }
    return codes
}
// Note: a duplicate "trim trailing silence/noise from audio" helper that used to live
// here was removed; see trimTrailingSilence further down in this file.
/**
 * Samples an index from [logits] using temperature scaling and top-K filtering.
 *
 * The [topK] highest logits are kept (clamped to `1..logits.size`), turned into a
 * temperature-scaled softmax, and one index is drawn from that distribution.
 * A non-positive (or NaN) [temperature], or a softmax whose sum is not a finite
 * positive number, degrades to greedy selection of the highest logit instead of
 * silently returning the least likely candidate.
 *
 * @param logits unnormalized scores; must not be empty.
 * @param temperature softmax temperature; values <= 0 select the argmax.
 * @param topK number of candidates to keep.
 * @return index into [logits] of the selected entry.
 * @throws IllegalArgumentException if [logits] is empty.
 */
private fun sampleTopK(logits: FloatArray, temperature: Float = 0.9f, topK: Int = 50): Int {
    require(logits.isNotEmpty()) { "sampleTopK: logits must not be empty" }

    // Stable sort keeps the lowest index first among equal logits.
    val k = topK.coerceIn(1, logits.size)
    val indices = logits.indices.sortedByDescending { logits[it] }.take(k)
    val top = indices[0]

    // Greedy cases: a single candidate, or a temperature that cannot be divided by.
    if (indices.size == 1 || !(temperature > 0f)) return top

    // Temperature-scaled softmax over the candidates (shifted by the max for stability).
    val maxLogit = logits[top]
    val weights = FloatArray(indices.size)
    var total = 0f
    for (i in indices.indices) {
        val w = kotlin.math.exp((logits[indices[i]] - maxLogit) / temperature)
        weights[i] = w
        total += w
    }
    // e.g. every candidate is -inf or NaN: no usable distribution, so stay greedy.
    if (!total.isFinite() || total <= 0f) return top

    // Categorical sampling by walking the cumulative weights.
    val r = Math.random().toFloat() * total
    var cumSum = 0f
    for (i in indices.indices) {
        cumSum += weights[i]
        if (cumSum >= r) return indices[i]
    }
    return indices.last()
}
// ==================== Code Predictor ====================
// ==================== Speech Decoder ====================
/**
 * Decodes the codebooks to PCM in windows of SEQ_LEN tokens, cross-fading consecutive
 * windows so they concatenate without a seam.
 *
 * When the input holds at most SEQ_LEN tokens it is decoded in one pass and cut to
 * [numRealTokens] tokens' worth of samples. Otherwise windows start every EFFECTIVE_CHUNK
 * tokens, each one is cut to the real tokens it covers, and its head is linearly
 * cross-faded into the tail of the audio produced so far.
 *
 * @param codebooks one code row per codebook; rows are read up to `codebooks[0].size`.
 * @param numRealTokens number of leading tokens that are real (the rest is padding).
 */
private fun decodeChunked(codebooks: Array<IntArray>, numRealTokens: Int): ShortArray {
    val availableTokens = codebooks[0].size
    if (availableTokens <= SEQ_LEN) {
        val pcm = runSpeechDecoder(vqDecode(codebooks))
        return pcm.copyOf(minOf(numRealTokens * SAMPLES_PER_TOKEN, pcm.size))
    }

    val maxFade = CHUNK_OVERLAP * SAMPLES_PER_TOKEN
    var stitched = ShortArray(0)
    var start = 0
    while (start < numRealTokens) {
        // Window of SEQ_LEN tokens starting at `start`, zero-padded past the input end.
        val window = Array(NUM_CODEBOOKS) { cb ->
            IntArray(SEQ_LEN) { offset ->
                val src = start + offset
                if (src < availableTokens) codebooks[cb][src] else 0
            }
        }
        val decoded = runSpeechDecoder(vqDecode(window))

        // Keep only the samples that belong to real tokens of this window.
        val realTokens = minOf(SEQ_LEN, numRealTokens - start)
        val piece = decoded.copyOf(minOf(realTokens * SAMPLES_PER_TOKEN, decoded.size))

        if (start == 0) {
            stitched = piece
        } else {
            // Linear cross-fade of the new piece's head into the existing tail.
            val fadeLen = minOf(maxFade, stitched.size, piece.size)
            val tailStart = stitched.size - fadeLen
            for (i in 0 until fadeLen) {
                val alpha = i.toFloat() / fadeLen
                val blended = ((1f - alpha) * stitched[tailStart + i] + alpha * piece[i]).toInt()
                stitched[tailStart + i] =
                    blended.coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()).toShort()
            }
            // Whatever lies beyond the faded region is appended as-is.
            if (fadeLen < piece.size) {
                stitched += piece.copyOfRange(fadeLen, piece.size)
            }
        }
        start += EFFECTIVE_CHUNK
    }
    return stitched
}
/**
 * VQ decode: 16 codebooks → quantized features laid out as [HIDDEN_DIM, SEQ_LEN]
 * (element `i * SEQ_LEN + t`).
 *
 * Codebook 0 is looked up on its own and projected with the "first" output projection;
 * codebooks 1..15 are looked up, summed per time step, and projected with the "rest"
 * output projection. The two projections are added together.
 *
 * @param codebooks code rows; row `cb` must provide at least SEQ_LEN entries.
 * @return the quantized features, or an empty array if any VQ table is not loaded.
 */
private fun vqDecode(codebooks: Array<IntArray>): FloatArray {
    val firstCb = firstCodebook
    val restCbs = restCodebooks
    val firstProj = firstOutputProj
    val restProj = restOutputProj
    if (firstCb == null || restCbs == null || firstProj == null || restProj == null) {
        return FloatArray(0)
    }

    val quantized = FloatArray(HIDDEN_DIM * SEQ_LEN)

    // Adds weights[HIDDEN_DIM x CODEBOOK_DIM] · latent[t] to quantized for every time step.
    fun accumulateProjection(weights: FloatArray, latent: FloatArray) {
        for (row in 0 until HIDDEN_DIM) {
            val wBase = row * CODEBOOK_DIM
            for (t in 0 until SEQ_LEN) {
                val lBase = t * CODEBOOK_DIM
                var acc = 0f
                for (d in 0 until CODEBOOK_DIM) {
                    acc += weights[wBase + d] * latent[lBase + d]
                }
                quantized[row * SEQ_LEN + t] += acc
            }
        }
    }

    // Codebook 0: plain lookup, one CODEBOOK_DIM vector per time step.
    val firstLatent = FloatArray(CODEBOOK_DIM * SEQ_LEN)
    for (t in 0 until SEQ_LEN) {
        val code = codebooks[0][t]
        System.arraycopy(firstCb, code * CODEBOOK_DIM, firstLatent, t * CODEBOOK_DIM, CODEBOOK_DIM)
    }
    accumulateProjection(firstProj, firstLatent)

    // Codebooks 1..15: sum the looked-up vectors per time step.
    val restLatent = FloatArray(CODEBOOK_DIM * SEQ_LEN)
    for (layer in 0 until 15) {
        val table = restCbs[layer]
        val row = codebooks[layer + 1]
        for (t in 0 until SEQ_LEN) {
            val src = row[t] * CODEBOOK_DIM
            val dst = t * CODEBOOK_DIM
            for (d in 0 until CODEBOOK_DIM) {
                restLatent[dst + d] += table[src + d]
            }
        }
    }
    accumulateProjection(restProj, restLatent)

    return quantized
}
/**
 * Run speech decoder: quantized → pre_conv → preprocessor → ConvNet → audio.
 *
 * Delegates to [runSpeechDecoderV2] when the decoder is configured for CPU or GPU.
 * Otherwise the three sessions are chained, each stage consuming the previous stage's
 * output tensor directly. Every intermediate tensor/result is released as soon as the
 * next stage has consumed it, and also when a stage throws.
 *
 * @param quantized features fed as a [1, HIDDEN_DIM, SEQ_LEN] tensor.
 * @return 16-bit PCM: decoder output clamped to [-1, 1] and scaled by 32767.
 */
private fun runSpeechDecoder(quantized: FloatArray): ShortArray {
    val env = ortEnv!!
    if (decoderOnCpu || decoderOnGpu) {
        return runSpeechDecoderV2(quantized)
    }

    // Handles still owned by this call; each is nulled once closed so that the
    // finally block never closes the same native object twice.
    var qTensor: OnnxTensor? = null
    var pcResult: OrtSession.Result? = null
    var ppResult: OrtSession.Result? = null
    var cdResult: OrtSession.Result? = null
    try {
        val q = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(quantized), longArrayOf(1, HIDDEN_DIM.toLong(), SEQ_LEN.toLong())
        )
        qTensor = q
        val pc = preConv!!.run(mapOf("x" to q))
        pcResult = pc
        val pcOut = pc[0] as OnnxTensor
        nlog("pre_conv: ${q.info.shape.contentToString()}${pcOut.info.shape.contentToString()}")
        q.close()
        qTensor = null

        val pp = preprocessor!!.run(mapOf("pre_conv_out" to pcOut))
        ppResult = pp
        val ppOut = pp[0] as OnnxTensor
        nlog("preprocessor: → ${ppOut.info.shape.contentToString()}")
        pc.close()
        pcResult = null

        val cd = convDecoder!!.run(mapOf("hidden" to ppOut))
        cdResult = cd
        val cdOut = cd[0] as OnnxTensor
        nlog("conv_decoder: → ${cdOut.info.shape.contentToString()}")
        @Suppress("UNCHECKED_CAST")
        val audioFloat = (cdOut.value as Array<Array<FloatArray>>)[0][0]
        return ShortArray(audioFloat.size) {
            (audioFloat[it].coerceIn(-1f, 1f) * 32767).toInt().toShort()
        }
    } finally {
        qTensor?.close()
        pcResult?.close()
        ppResult?.close()
        cdResult?.close()
    }
}
/**
 * V2 decoder: quantized → pre_conv → transpose → pre_transformer → decoder → audio.
 *
 * The pre_conv output is read as 1024 × SEQ_LEN floats (channel-major) and transposed
 * on the CPU to SEQ_LEN × 1024 before it is fed to the pre_transformer; the final
 * decoder consumes the pre_transformer output tensor directly. Each session is fed
 * through its first declared input name. Per-stage timings are logged. Native
 * tensors/results are released as soon as they are no longer needed, and also when a
 * stage throws.
 *
 * @param quantized features fed as a [1, HIDDEN_DIM, SEQ_LEN] tensor.
 * @return 16-bit PCM: decoder output clamped to [-1, 1] and scaled by 32767.
 */
private fun runSpeechDecoderV2(quantized: FloatArray): ShortArray {
    val env = ortEnv!!
    val channels = 1024
    val td0 = System.currentTimeMillis()

    // Handles still owned by this call; each is nulled once closed so that the
    // finally block never closes the same native object twice.
    var qTensor: OnnxTensor? = null
    var pcResult: OrtSession.Result? = null
    var ptInput: OnnxTensor? = null
    var ptResult: OrtSession.Result? = null
    var cdResult: OrtSession.Result? = null
    try {
        // pre_conv
        val q = OnnxTensor.createTensor(env, FloatBuffer.wrap(quantized),
            longArrayOf(1, HIDDEN_DIM.toLong(), SEQ_LEN.toLong()))
        qTensor = q
        val pc = preConv!!.run(mapOf(preConv!!.inputNames.first() to q))
        pcResult = pc
        val pcData = FloatArray(channels * SEQ_LEN)
        (pc[0] as OnnxTensor).floatBuffer.get(pcData)
        val td1 = System.currentTimeMillis()
        nlog("V2 pre_conv: ${td1-td0}ms")
        q.close()
        qTensor = null

        // Transpose channel-major [1024, SEQ_LEN] → time-major [SEQ_LEN, 1024]
        val transposed = FloatArray(SEQ_LEN * channels)
        for (c in 0 until channels) {
            for (t in 0 until SEQ_LEN) {
                transposed[t * channels + c] = pcData[c * SEQ_LEN + t]
            }
        }
        val pt = OnnxTensor.createTensor(env, FloatBuffer.wrap(transposed),
            longArrayOf(1, SEQ_LEN.toLong(), channels.toLong()))
        ptInput = pt
        pc.close()
        pcResult = null

        // pre_transformer
        val ptRes = preprocessor!!.run(mapOf(preprocessor!!.inputNames.first() to pt))
        ptResult = ptRes
        val ptOut = ptRes[0] as OnnxTensor
        val td2 = System.currentTimeMillis()
        nlog("V2 pre_transformer: ${td2-td1}ms")
        pt.close()
        ptInput = null

        // decoder
        val cd = convDecoder!!.run(mapOf(convDecoder!!.inputNames.first() to ptOut))
        cdResult = cd
        val cdOut = cd[0] as OnnxTensor
        nlog("V2 BigVGAN: ${System.currentTimeMillis()-td2}ms")
        @Suppress("UNCHECKED_CAST")
        val audioFloat = (cdOut.value as Array<Array<FloatArray>>)[0][0]
        return ShortArray(audioFloat.size) {
            (audioFloat[it].coerceIn(-1f, 1f) * 32767).toInt().toShort()
        }
    } finally {
        qTensor?.close()
        pcResult?.close()
        ptInput?.close()
        ptResult?.close()
        cdResult?.close()
    }
}
// ==================== Loading ====================
/**
 * Loads the VQ tables used by [vqDecode] from the directory [path]: the first codebook,
 * both output projections, and the 15 remaining codebooks.
 */
private fun loadVqCodebooks(path: String) {
    // Resolves a file name inside the VQ directory and reads it.
    fun npy(fileName: String): FloatArray = loadNpy("$path/$fileName")

    firstCodebook = npy("vq_rvq_first_vq_layers_0_codebook.npy")
    firstOutputProj = npy("vq_rvq_first_output_proj_w.npy")
    restOutputProj = npy("vq_rvq_rest_output_proj_w.npy")
    restCodebooks = Array(15) { layer -> npy("vq_rvq_rest_vq_layers_${layer}_codebook.npy") }
    nlog("VQ codebooks loaded (16 × [${CODEBOOK_SIZE}, ${CODEBOOK_DIM}])")
}
/**
 * Load a numpy .npy file as float array.
 *
 * Validates the `\x93NUMPY` magic and honours the format version when locating the
 * payload: version 1.x stores the header length in 2 bytes (payload at `10 + len`),
 * versions 2.x/3.x store it in 4 bytes (payload at `12 + len`). The header's dtype and
 * shape are NOT inspected: the payload is read as little-endian 32-bit floats and any
 * trailing bytes that do not fill a whole float are ignored.
 *
 * @return the payload, or an empty array (after logging a warning) if the file is
 *         missing, is not an .npy file, or its header is truncated.
 */
private fun loadNpy(path: String): FloatArray {
    val file = File(path)
    if (!file.exists()) {
        nlog("WARN: $path not found")
        return FloatArray(0)
    }
    val bytes = file.readBytes()

    // Magic string: 0x93 'N' 'U' 'M' 'P' 'Y', followed by major and minor version bytes.
    val magic = byteArrayOf(0x93.toByte(), 0x4E, 0x55, 0x4D, 0x50, 0x59)
    val hasMagic = bytes.size >= 10 && magic.indices.all { bytes[it] == magic[it] }
    if (!hasMagic) {
        nlog("WARN: $path is not a valid .npy file")
        return FloatArray(0)
    }

    val majorVersion = bytes[6].toInt() and 0xFF
    val lengthBytes = if (majorVersion >= 2) 4 else 2
    val preambleLen = 8 + lengthBytes
    if (bytes.size < preambleLen) {
        nlog("WARN: $path has a truncated .npy header")
        return FloatArray(0)
    }
    // Little-endian unsigned header length, accumulated in a Long to stay non-negative.
    var headerLen = 0L
    for (b in 0 until lengthBytes) {
        headerLen = headerLen or ((bytes[8 + b].toLong() and 0xFF) shl (8 * b))
    }
    val dataOffsetLong = preambleLen + headerLen
    if (dataOffsetLong > bytes.size) {
        nlog("WARN: $path has a truncated .npy header")
        return FloatArray(0)
    }

    val dataOffset = dataOffsetLong.toInt()
    val numFloats = (bytes.size - dataOffset) / 4
    val result = FloatArray(numFloats)
    ByteBuffer.wrap(bytes, dataOffset, bytes.size - dataOffset)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asFloatBuffer()
        .get(result)
    return result
}
// ==================== Audio Playback ====================
/**
 * Plays mono 16-bit PCM at [SR] Hz through a static-mode [AudioTrack] and invokes
 * [onComplete] when the playback marker placed at the last frame is reached.
 *
 * Any current playback is stopped first via [stop]. Empty audio completes immediately:
 * a static track cannot be built with a zero-byte buffer, and without this guard
 * [onComplete] would never be invoked.
 *
 * @param audioData samples to play (one sample per frame).
 * @param onComplete callback fired once playback has finished.
 */
private fun playAudio(audioData: ShortArray, onComplete: () -> Unit) {
    stop()
    if (audioData.isEmpty()) {
        onComplete()
        return
    }

    val attributes = AudioAttributes.Builder()
        .setUsage(AudioAttributes.USAGE_MEDIA)
        .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        .build()
    val format = AudioFormat.Builder()
        .setSampleRate(SR)
        .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
        .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        .build()
    val track = AudioTrack.Builder()
        .setAudioAttributes(attributes)
        .setAudioFormat(format)
        .setBufferSizeInBytes(audioData.size * 2) // 2 bytes per 16-bit sample
        .setTransferMode(AudioTrack.MODE_STATIC)
        .build()
    audioTrack = track

    track.write(audioData, 0, audioData.size)
    // Marker position is in frames; mono 16-bit means one frame per sample.
    track.setNotificationMarkerPosition(audioData.size)
    track.setPlaybackPositionUpdateListener(object : AudioTrack.OnPlaybackPositionUpdateListener {
        override fun onMarkerReached(track: AudioTrack?) { onComplete() }
        override fun onPeriodicNotification(track: AudioTrack?) {}
    })
    track.play()
}
/**
 * Diagnostic path: decode Python's captured codes directly, bypassing talker+CP.
 * Used to isolate whether audio degradation comes from our talker/CP predictions
 * (different codes than Python) vs from BigVGAN implementation differences.
 *
 * File format python_codes.bin (little-endian): int32 n_steps, int32 n_codebooks(=16),
 * then n_steps × 16 × int32 codes. Codes are clamped to the valid codebook range.
 * The decoded audio is also written, best-effort, to a WAV file under
 * /data/local/tmp/kazeia.
 *
 * @param codesPath path of the captured codes file.
 * @return decoded 16-bit PCM, or an empty array if the file is too short, declares an
 *         unexpected codebook count, or holds fewer codes than its header announces.
 */
fun generateFromPythonCodes(codesPath: String): ShortArray {
    nlog("Direct Python-codes path: decoding $codesPath via BigVGAN only")
    val t0 = System.currentTimeMillis()
    val bytes = java.io.File(codesPath).readBytes()
    if (bytes.size < 8) {
        nlog("Python codes file too short: ${bytes.size} bytes"); return ShortArray(0)
    }
    val bb = java.nio.ByteBuffer.wrap(bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN)
    val nSteps = bb.int
    val nCodebooks = bb.int
    nlog("Python codes: $nSteps steps × $nCodebooks codebooks")
    if (nCodebooks != NUM_CODEBOOKS) {
        nlog("Unexpected codebook count: $nCodebooks != $NUM_CODEBOOKS"); return ShortArray(0)
    }
    // Validate the payload before allocating or reading; Long math avoids overflow.
    val neededBytes = nSteps.toLong() * NUM_CODEBOOKS * 4
    if (nSteps < 0 || neededBytes > bb.remaining()) {
        nlog("Python codes file truncated: need $neededBytes bytes, have ${bb.remaining()}"); return ShortArray(0)
    }

    // Rows are padded with zeros up to SEQ_LEN so a single decoder window is always full.
    val padLen = maxOf(nSteps, SEQ_LEN)
    val allCodebooks = Array(NUM_CODEBOOKS) { IntArray(padLen) }
    for (t in 0 until nSteps) {
        for (cb in 0 until NUM_CODEBOOKS) {
            val v = bb.int
            allCodebooks[cb][t] = v.coerceIn(0, CODEBOOK_SIZE - 1)
        }
    }
    val audio = decodeChunked(allCodebooks, nSteps)
    nlog("Python-codes decode: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio")

    // Best-effort dump as a canonical 44-byte-header PCM WAV (mono, 16-bit, SR Hz).
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_PYCODES.wav"
        val dataLen = audio.size * 2
        val header = java.nio.ByteBuffer.allocate(44).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
        header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
        header.putInt(16); header.putShort(1); header.putShort(1)
        header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
        header.put("data".toByteArray()); header.putInt(dataLen)
        val buf = java.nio.ByteBuffer.allocate(dataLen).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        for (s in audio) buf.putShort(s)
        // `use` closes the stream even when a write fails.
        java.io.FileOutputStream(wavPath).use { fos ->
            fos.write(header.array())
            fos.write(buf.array())
        }
        nlog("WAV saved (PyCodes): $wavPath")
    } catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
    return audio
}
/**
 * Full pipeline from pre-computed embeddings (from Python generate() capture).
 * File format: int32 n_prefill, int32 n_total, then n_total × 1024 floats.
 * Runs talker ONNX → CP → VQ decode → speech decoder.
 *
 * Dispatch order, as implemented below:
 *  1. Diagnostic "force_python_codes" flag with a sibling python_codes.bin present:
 *     decode the captured codes directly via generateFromPythonCodes.
 *  2. Native pipeline, or both .pte modules loaded: generateFromEmbedsPte.
 *  3. Hexagon talker: generateFromEmbedsHexagon.
 *  4. Otherwise the ONNX M-RoPE talker loop implemented in this function.
 *
 * Side effects of the ONNX path: writes tablet_codes.bin and kazeia_ONNX.wav under
 * /data/local/tmp/kazeia (both best-effort; failures are only logged).
 *
 * @param embedsPath path of the captured embeddings file.
 * @return decoded 16-bit PCM, or an empty array when no talker is available, when the
 *         prefill yields no usable first code, or when no tokens were generated.
 */
fun generateFromEmbeds(embedsPath: String): ShortArray {
// Diagnostic override: if python_codes.bin exists and flag is set, decode
// Python's captured codes directly to isolate BigVGAN from talker/CP predictions.
val pyCodesPath = embedsPath.replace("full_pipeline_embeds.bin", "python_codes.bin")
if (diagFlag("force_python_codes") && java.io.File(pyCodesPath).exists()) {
nlog("force_python_codes flag + file present → diagnostic codes-direct path")
return generateFromPythonCodes(pyCodesPath)
}
// Bail out unless the engine is loaded and at least one talker backend is usable.
// NOTE(review): this guard passes when only talkerPteModule is set (cpPteModule null,
// no Hexagon, no native); execution then reaches the ONNX path below, where
// `talkerKv!!` / `rotaryCos!!` look like they could throw — confirm that combination
// cannot occur.
if (!loaded || (!nativePipelineReady && talkerPteModule == null && !useHexagonTalker && (talkerKv == null || !talkerUsesCosSin))) {
nlog("generateFromEmbeds: no talker (native=$nativePipelineReady, pte=${talkerPteModule != null}, hex=$useHexagonTalker)")
return ShortArray(0)
}
// Priority: native C++ > Java .pte > Hexagon > ONNX CPU
if (nativePipelineReady) {
return generateFromEmbedsPte(embedsPath)
}
if (talkerPteModule != null && cpPteModule != null) {
return generateFromEmbedsPte(embedsPath)
}
if (useHexagonTalker) {
return generateFromEmbedsHexagon(embedsPath)
}
// ---- ONNX talker path ----
nlog("Full pipeline from: $embedsPath")
val t0 = System.currentTimeMillis()
// Parse the capture: two little-endian int32 counts, then nTotal rows of TALKER_DIM floats.
val bytes = File(embedsPath).readBytes()
val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
val nPrefill = bb.int
val nTotal = bb.int
val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
val env = ortEnv!!
val session = talkerKv!!
// cos/sin are only null-checked here; runTalkerStepMRoPE reads the rotary tables itself.
val cos = rotaryCos!!
val sin = rotarySin!!
// KV caches [1, 8, KV_LEN, 128]
val kvSize = TALKER_HEADS * KV_LEN * TALKER_HEAD_DIM
var kCaches = Array(TALKER_LAYERS) { FloatArray(kvSize) }
var vCaches = Array(TALKER_LAYERS) { FloatArray(kvSize) }
// Additive attention mask: everything masked (-1e9) until a position is opened with 0,
// filling from the right-hand end.
val maskData = FloatArray(MAX_CONTEXT) { -1e9f }
val allCodes = mutableListOf<IntArray>()
val generatedCb0 = mutableListOf<Int>()
var currentCb0 = -1
var pastHidden: FloatArray? = null
// Prefill: feed the captured prefill embeddings one step at a time; only the last
// step's logits are sampled to obtain the first codebook-0 token.
val tPrefill = System.currentTimeMillis()
for (step in 0 until nPrefill) {
maskData[MAX_CONTEXT - 1 - step] = 0f
val res = runTalkerStepMRoPE(env, session, embeds[step], maskData, step, kCaches, vCaches)
kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden
if (step == nPrefill - 1) {
val logits = res.logits
// Suppress every id at or above CODEBOOK_SIZE except the EOS token.
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
currentCb0 = sampleTopK(logits, 0.9f, 50)
nlog("Prefill: ${System.currentTimeMillis() - tPrefill}ms, cb0=$currentCb0")
}
}
// No prefill steps (currentCb0 still -1) or immediate EOS: nothing to synthesize.
if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return ShortArray(0)
// Voice-cloned replay: detect by large nTrailing (≥ 30 → Python-captured full
// trajectory). In that mode, feed captured decode embeds verbatim (they already
// include Python's codec_sum + text_hidden) and stop exactly at nTrailing.
val nTrailingOnnx = nTotal - nPrefill
val isVoiceClonedOnnx = nTrailingOnnx >= 30
val maxGenOnnx = if (isVoiceClonedOnnx) nTrailingOnnx else (nTrailingOnnx * 5 + 20)
var trailingIdxOnnx = 0
// Missing EOS/pad embeddings fall back to all-zero vectors.
val eosE = ttsEosEmbed ?: FloatArray(TALKER_DIM)
val padE = ttsPadEmbed ?: FloatArray(TALKER_DIM)
var totalTalkerMs = 0L; var totalCpMs = 0L
val forceGreedyCb0Onnx = diagFlag("force_greedy_cb0")
nlog("ONNX voice-cloned=$isVoiceClonedOnnx, maxGen=$maxGenOnnx, greedy=$forceGreedyCb0Onnx")
for (genStep in 0 until maxGenOnnx) {
// One frame = codebook 0 from the talker + 15 codebooks from the code predictor.
val codes = IntArray(NUM_CODEBOOKS)
codes[0] = currentCb0
val tCp = System.currentTimeMillis()
val cpCodes = runCodePredictorInterleaved(pastHidden!!, currentCb0)
totalCpMs += System.currentTimeMillis() - tCp
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
allCodes.add(codes)
generatedCb0.add(currentCb0)
if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
// Voice-cloned: feed captured embed directly. Text-only: reconstruct codec_sum + trailing.
val nextEmbed = if (isVoiceClonedOnnx) {
embeds[nPrefill + genStep]
} else {
val codecSum = FloatArray(TALKER_DIM)
addEmb(codecSum, codecEmb(codes[0]))
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
// Text side: consume the captured trailing embeds in order, then EOS once, then pad.
val textE: FloatArray = when {
trailingIdxOnnx < nTrailingOnnx -> embeds[nPrefill + trailingIdxOnnx].also { trailingIdxOnnx++ }
trailingIdxOnnx == nTrailingOnnx -> { trailingIdxOnnx++; eosE }
else -> padE
}
sumEmb(codecSum, textE)
}
// Open the mask slot for this position while it still fits in the context window.
val absPos = nPrefill + genStep
if (absPos < MAX_CONTEXT) maskData[MAX_CONTEXT - 1 - absPos] = 0f
val tTalker = System.currentTimeMillis()
val res = runTalkerStepMRoPE(env, session, nextEmbed, maskData, nPrefill + genStep, kCaches, vCaches)
totalTalkerMs += System.currentTimeMillis() - tTalker
kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden
val logits = res.logits
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
// Repetition penalty (factor 1.05) applied once per distinct previously emitted cb0.
// NOTE(review): the set is rebuilt from the full history on every step.
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
// Next codebook-0 token: plain argmax when the diagnostic flag is set, else top-K sampling.
val nextCb0 = if (forceGreedyCb0Onnx) {
var b = 0; var bv = logits[0]
for (j in 1 until logits.size) if (logits[j] > bv) { bv = logits[j]; b = j }
b
} else sampleTopK(logits, 0.9f, 50)
// Early stop (EOS, or the last 9 tokens all equal to the new one) applies only
// outside voice-cloned replay, which always runs for exactly maxGenOnnx steps.
if (!isVoiceClonedOnnx) {
if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break }
if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
nlog("Degeneration at step ${genStep + 2}"); break
}
}
currentCb0 = nextCb0
}
val n = allCodes.size
nlog("Generated $n tokens | Talker: ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
// Diagnostic: dump our generated codes for comparison with Python capture
// (layout: int32 n, int32 NUM_CODEBOOKS, then n × NUM_CODEBOOKS int32 codes).
// NOTE(review): the stream is not closed if write() throws.
try {
val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin"
val fos = java.io.FileOutputStream(codesPath)
val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(n); buf.putInt(NUM_CODEBOOKS)
for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb])
fos.write(buf.array()); fos.close()
nlog("Tablet codes dumped (ONNX path): $codesPath")
} catch (e: Exception) { nlog("Codes dump failed: ${e.message}") }
if (n == 0) return ShortArray(0)
// Decode: regroup per-step codes into per-codebook rows, zero-padded to at least SEQ_LEN.
val padLen = maxOf(n, SEQ_LEN)
val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
IntArray(padLen) { t -> if (t < n) allCodes[t][cb] else 0 }
}
val audio = decodeChunked(allCodebooks, n)
nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio")
// Best-effort WAV dump (44-byte PCM header, mono, 16-bit, SR Hz).
try {
val wavPath = "/data/local/tmp/kazeia/kazeia_ONNX.wav"
val fos = java.io.FileOutputStream(wavPath)
val dataLen = audio.size * 2
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
header.putInt(16); header.putShort(1); header.putShort(1)
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
header.put("data".toByteArray()); header.putInt(dataLen)
fos.write(header.array())
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
for (s in audio) buf.putShort(s)
fos.write(buf.array()); fos.close()
nlog("WAV saved (ONNX): $wavPath (${audio.size} samples)")
} catch (e: Exception) { nlog("WAV save (ONNX) failed: ${e.message}") }
return audio
}
/**
 * Full pipeline using the .pte talker + code predictor from pre-computed embeddings.
 *
 * The file at [embedsPath] is little-endian and uses one of two layouts:
 *  - single-segment: `int nPrefill`, `int nTotal`, then `nTotal * TALKER_DIM` floats;
 *  - multi-segment: `int nSegments`, followed by one single-segment record per segment.
 *
 * Multi-segment files are delegated to [generateMultiSegment] when both .pte modules are
 * loaded; otherwise only the first record of the file is processed here. The generated
 * codes are decoded with `decodeChunked` and the result is also written to
 * `/data/local/tmp/kazeia/kazeia_PTE_NPU.wav` (best effort).
 *
 * @param embedsPath path of the embeds file to read.
 * @return the decoded samples, or an empty array when the file header is invalid or
 *         no codes were generated.
 */
private fun generateFromEmbedsPte(embedsPath: String): ShortArray {
    nlog("Full pipeline (PTE) from: $embedsPath")
    val t0 = System.currentTimeMillis()
    val bytes = File(embedsPath).readBytes()
    if (bytes.size < 8) {
        nlog("Embeds file too short (${bytes.size} bytes): $embedsPath")
        return ShortArray(0)
    }
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)

    // Detect the format by comparing the file size to what the single-segment layout
    // predicts: if `8 + secondInt * TALKER_DIM * 4 == file length` the file is
    // single-segment, otherwise firstInt is interpreted as the segment count.
    val firstInt = bb.int
    val secondInt = bb.int
    val fileLen = bytes.size.toLong()
    val singleSegmentSize = 8L + secondInt.toLong() * TALKER_DIM * 4
    val isSingleSegment = secondInt > 0 && secondInt < 100000 && fileLen == singleSegmentSize
    val isMultiSegment = !isSingleSegment && firstInt in 1..100

    val talker = talkerPteModule
    if (isMultiSegment && talker != null && cpPteModule != null) {
        bb.position(4) // generateMultiSegment expects the buffer just past the segment count
        return generateMultiSegment(bb, firstInt, t0)
    }

    // Single-record path. For a multi-segment file that could not be delegated, the
    // first record starts right after the segment count, so secondInt is its nPrefill.
    val nPrefill: Int
    val nTotal: Int
    if (isMultiSegment) {
        if (bb.remaining() < 4) {
            nlog("Embeds file truncated in first segment header: $embedsPath")
            return ShortArray(0)
        }
        nPrefill = secondInt
        nTotal = bb.int
    } else {
        nPrefill = firstInt
        nTotal = secondInt
    }
    // Reject headers that would otherwise surface as NegativeArraySizeException,
    // BufferUnderflowException or an out-of-range index further down.
    val payloadBytes = nTotal.toLong() * TALKER_DIM * 4
    if (nTotal < 0 || nPrefill < 0 || nPrefill > nTotal || payloadBytes > bb.remaining()) {
        nlog("Invalid embeds header: nPrefill=$nPrefill nTotal=$nTotal, ${bb.remaining()} payload bytes available")
        return ShortArray(0)
    }
    val floats = bb.asFloatBuffer()
    val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { row -> floats.get(row) } }
    nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")

    // Capture mode forces the Java path so the embeds fed to the talker get captured.
    val captureMode = File("/data/local/tmp/kazeia/capture_mode").exists()

    val allCodes: Array<IntArray> = if (!captureMode && talker != null && cpPteModule != null) {
        // Native C++ loop running on the already-loaded Java Module instances.
        val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
        for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
        val nTrailing = nTotal - nPrefill
        val trailingFlat = if (nTrailing > 0) FloatArray(nTrailing * TALKER_DIM).also { arr ->
            for (i in 0 until nTrailing) System.arraycopy(embeds[nPrefill + i], 0, arr, i * TALKER_DIM, TALKER_DIM)
        } else null

        // Load CP heads if not already loaded.
        if (cpAllHeads == null) {
            val headsFile = java.io.File("/data/local/tmp/kazeia/models/cp_heads.bin")
            if (headsFile.exists()) {
                val hb = headsFile.readBytes()
                val heads = FloatArray(hb.size / 4)
                ByteBuffer.wrap(hb).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(heads)
                cpAllHeads = heads
            }
        }
        // Ensure all data arrays are loaded.
        val mpath = "/data/local/tmp/kazeia/models/qwen3-tts-npu"
        if (codecEmbedding == null) codecEmbedding = loadNpy("$mpath/codec_embedding.npy")
        if (cpEmbeddings == null) cpEmbeddings = loadNpy("$mpath/code_predictor_embeddings.npy")
        if (cpRotaryCos == null) cpRotaryCos = loadNpy("$mpath/cp_kv_v2/cp_rotary_cos.npy")
        if (cpRotarySin == null) cpRotarySin = loadNpy("$mpath/cp_kv_v2/cp_rotary_sin.npy")
        // NOTE(review): generateMultiSegment loads these two tables from "$mpath/..." rather
        // than the parent models directory — confirm which location is the intended one.
        if (talkerPteRotaryCos == null) talkerPteRotaryCos = loadNpy("/data/local/tmp/kazeia/models/talker_pte_rotary_cos.npy")
        if (talkerPteRotarySin == null) talkerPteRotarySin = loadNpy("/data/local/tmp/kazeia/models/talker_pte_rotary_sin.npy")
        if (ttsEosEmbed == null || ttsPadEmbed == null) {
            // The file holds three consecutive TALKER_DIM-sized vectors: BOS, EOS, PAD.
            val sp = loadNpy("$mpath/tts_special_embeds.npy")
            ttsBosEmbed = sp.sliceArray(0 until TALKER_DIM)
            ttsEosEmbed = sp.sliceArray(TALKER_DIM until 2 * TALKER_DIM)
            ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
        }

        // Heuristic: a trailing count >= 30 is treated as a per-frame (voice-cloned)
        // capture whose length is trusted as-is; shorter files are treated as
        // text-only and get the runway/boost budget instead.
        val isVoiceCloned = nTrailing >= 30
        val maxGen = if (isVoiceCloned) nTrailing
        else minOf(nTrailing * 4 + 20, 90 - nPrefill)
        nlog("Running native C++ pipeline (shared Module), maxGen=$maxGen, voiceCloned=$isVoiceCloned...")
        val flat = talker.nativeRunTtsPipeline(
            prefillFlat, nPrefill,
            trailingFlat, nTrailing,
            codecEmbedding ?: FloatArray(0),
            cpEmbeddings ?: FloatArray(0),
            cpAllHeads ?: FloatArray(0),
            talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
            cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
            ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
            maxGen
        )
        if (flat == null || flat.isEmpty()) return ShortArray(0)
        // The native result is read as nTokens rows of NUM_CODEBOOKS codes each.
        val nTokens = flat.size / NUM_CODEBOOKS
        nlog("Native pipeline: $nTokens tokens")
        Array(nTokens) { t -> IntArray(NUM_CODEBOOKS) { cb -> flat[t * NUM_CODEBOOKS + cb] } }
    } else {
        // Fallback: Java pipeline.
        val prefillEmbeds = embeds.sliceArray(0 until nPrefill).toList()
        val trailingEmbeds = if (nPrefill < nTotal) embeds.sliceArray(nPrefill until nTotal).toList() else emptyList()
        runInterleavedPteFromEmbeds(prefillEmbeds, trailingEmbeds, nTotal - nPrefill)
    }
    if (allCodes.isEmpty()) return ShortArray(0)

    // Transpose to [codebook][time], zero-padded to at least SEQ_LEN steps.
    val numRealTokens = allCodes.size
    val padLen = maxOf(numRealTokens, SEQ_LEN)
    val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t -> if (t < numRealTokens) allCodes[t][cb] else 0 }
    }
    val t3 = System.currentTimeMillis()
    val rawAudio = decodeChunked(allCodebooks, numRealTokens)
    nlog("Decode: ${System.currentTimeMillis() - t3}ms")
    val audio = trimTrailingSilence(rawAudio)
    nlog("Trimmed: ${rawAudio.size} -> ${audio.size} samples (${(rawAudio.size-audio.size)/SR.toFloat()}s removed)")
    val totalMs = System.currentTimeMillis() - t0
    val audioDur = audio.size.toFloat() / SR
    nlog("Total: ${totalMs}ms for ${audioDur}s")

    // Save a mono 16-bit PCM WAV for validation; failures are logged, not propagated.
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_PTE_NPU.wav"
        val dataLen = audio.size * 2
        val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
        header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
        header.putInt(16); header.putShort(1); header.putShort(1)
        header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
        header.put("data".toByteArray()); header.putInt(dataLen)
        val pcm = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
        for (s in audio) pcm.putShort(s)
        // `use` closes the stream even when a write throws.
        java.io.FileOutputStream(wavPath).use { fos ->
            fos.write(header.array())
            fos.write(pcm.array())
        }
        nlog("WAV saved: $wavPath (${audio.size} samples)")
    } catch (e: Exception) {
        nlog("WAV save failed: ${e.message}")
    }
    return audio
}
/**
 * PTE pipeline from pre-computed embeddings (prefill + trailing).
 *
 * Runs the .pte talker one position at a time over [prefillEmbeds], samples the first
 * cb0 from the last prefill logits, then alternates code predictor and talker steps.
 * Generation stops after [maxGenTokens] frames, on CODEC_EOS, or when the next cb0
 * would repeat the previous nine. Every embed fed to the talker is also written to
 * `/data/local/tmp/kazeia/captured_embeds.bin` (best effort).
 *
 * @param prefillEmbeds embeds fed to the talker before generation starts.
 * @param trailingEmbeds pre-computed decode-step inputs; once exhausted, the next input
 *        is rebuilt from the generated codes plus the pad embed.
 * @param maxGenTokens upper bound on the number of generated frames.
 * @return one IntArray of NUM_CODEBOOKS codes per generated frame; empty when a required
 *         module/table is missing or the prefill yields no usable cb0.
 */
private fun runInterleavedPteFromEmbeds(
    prefillEmbeds: List<FloatArray>, trailingEmbeds: List<FloatArray>, maxGenTokens: Int
): Array<IntArray> {
    val talkerMod = talkerPteModule ?: return emptyArray()
    if (cpPteModule == null) return emptyArray()
    val tCos = talkerPteRotaryCos ?: return emptyArray()
    val tSin = talkerPteRotarySin ?: return emptyArray()
    if (ttsEosEmbed == null) return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()

    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    // Running set of every cb0 emitted so far, for the repetition penalty. Equivalent to
    // rebuilding the set from generatedCb0 each step, without the O(n^2) cost.
    val seenCb0 = HashSet<Int>()

    val tkvSize = TALKER_HEADS * TALKER_PTE_KV_LEN * TALKER_HEAD_DIM
    val tK = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    val tV = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    // Additive attention mask: -1e9 everywhere, opened (0) from the last slot backwards.
    val maskData = FloatArray(TALKER_PTE_KV_LEN) { -1e9f }
    var pos = 0

    fun evalue(data: FloatArray, shape: LongArray) =
        org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(data, shape))

    fun kvShape() =
        longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong())

    // Suppress every non-codebook logit except CODEC_EOS.
    fun maskNonCodec(logits: FloatArray) {
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
    }

    // Runs one talker step at position `pos`, replaces the KV caches with the module's
    // outputs and returns (hidden, logits).
    fun talkerStep(emb: FloatArray): Pair<FloatArray, FloatArray> {
        val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
        if (maskIdx >= 0) maskData[maskIdx] = 0f
        // Clamp to the last row of the rotary tables when pos runs past them.
        val posIdx = minOf(pos, tCos.size / TALKER_HEAD_DIM - 1)
        val cosSlice = tCos.copyOfRange(posIdx * TALKER_HEAD_DIM, (posIdx + 1) * TALKER_HEAD_DIM)
        val sinSlice = tSin.copyOfRange(posIdx * TALKER_HEAD_DIM, (posIdx + 1) * TALKER_HEAD_DIM)
        val inputs = mutableListOf(
            evalue(emb, longArrayOf(1, 1, TALKER_DIM.toLong())),
            evalue(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong())),
            evalue(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())),
            evalue(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))
        )
        for (i in 0 until TALKER_LAYERS) {
            inputs.add(evalue(tK[i], kvShape()))
            inputs.add(evalue(tV[i], kvShape()))
        }
        val out = talkerMod.forward(*inputs.toTypedArray())
        val hidden = out[0].toTensor().dataAsFloatArray
        val logits = out[1].toTensor().dataAsFloatArray
        for (i in 0 until TALKER_LAYERS) {
            tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
            tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
        }
        pos++
        return Pair(hidden, logits)
    }

    // Every embed fed to the talker, in order, for later reuse by the C++ path.
    val capturedEmbeds = mutableListOf<FloatArray>()

    // ===== PREFILL =====
    var currentCb0 = -1
    var prefillHidden: FloatArray? = null
    val tPrefill = System.currentTimeMillis()
    for (step in prefillEmbeds.indices) {
        capturedEmbeds.add(prefillEmbeds[step].clone())
        val (h, logits) = talkerStep(prefillEmbeds[step])
        prefillHidden = h
        if (step == prefillEmbeds.lastIndex) {
            maskNonCodec(logits)
            currentCb0 = sampleTopK(logits, 0.9f, 50)
        }
    }
    nlog("Prefill (PTE): ${System.currentTimeMillis() - tPrefill}ms, ${prefillEmbeds.size} steps, cb0=$currentCb0")
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()
    // Non-null whenever a cb0 was sampled; the guard only replaces a `!!`.
    var pastHidden: FloatArray = prefillHidden ?: return emptyArray()

    // ===== GENERATION =====
    var totalTalkerMs = 0L; var totalCpMs = 0L
    var trailingIdx = 0
    for (genStep in 0 until maxGenTokens) {
        val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
        val tCp0 = System.currentTimeMillis()
        val cpCodes = runCpPte(pastHidden, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp0
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes); generatedCb0.add(currentCb0); seenCb0.add(currentCb0)
        if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // Next talker input: the pre-computed decode embed when one is left, otherwise
        // the sum of this frame's code embeddings plus the pad embed.
        val nextEmbed: FloatArray = if (trailingIdx < trailingEmbeds.size) {
            trailingEmbeds[trailingIdx++]
        } else {
            val codecSum = FloatArray(TALKER_DIM)
            addEmb(codecSum, codecEmb(codes[0]))
            for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
            sumEmb(codecSum, padE)
        }
        capturedEmbeds.add(nextEmbed.clone())

        val tTalker0 = System.currentTimeMillis()
        val (h, logits) = talkerStep(nextEmbed)
        totalTalkerMs += System.currentTimeMillis() - tTalker0
        pastHidden = h
        maskNonCodec(logits)
        // Repetition penalty (1.05) on every cb0 emitted so far.
        for (tok in seenCb0) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)
        if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break }
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2}"); break
        }
        currentCb0 = nextCb0
    }
    val n = allCodes.size
    nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Save captured embeds in the single-segment layout (int nPrefill, int nTotal, floats).
    if (capturedEmbeds.isNotEmpty()) {
        try {
            val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
            val nPrefill = prefillEmbeds.size
            // `use` closes the stream even when a write throws.
            java.io.FileOutputStream(capPath).use { fos ->
                val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                hdr.putInt(nPrefill); hdr.putInt(capturedEmbeds.size)
                fos.write(hdr.array())
                for (emb in capturedEmbeds) {
                    val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                    for (v in emb) buf.putFloat(v)
                    fos.write(buf.array())
                }
            }
            nlog("Captured ${capturedEmbeds.size} embeds → $capPath (${nPrefill} prefill + ${capturedEmbeds.size - nPrefill} decode)")
        } catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
    }
    return allCodes.toTypedArray()
}
/**
 * Multi-segment pipeline: process each segment independently, concatenate audio.
 *
 * Each segment record (`int nPrefill`, `int nTotal`, `nTotal * TALKER_DIM` floats) is
 * read from [bb], run through the native talker/CP pipeline and decoded. The tail of
 * every segment is trimmed on signal energy, segments are joined with a 120 ms gap of
 * silence, and the result is also written to
 * `/data/local/tmp/kazeia/kazeia_PTE_NPU.wav` (best effort).
 *
 * @param bb little-endian buffer positioned at the first segment record.
 * @param nSegments number of segment records expected in [bb].
 * @param t0 request start time (ms), used only for the total-time log line.
 * @return the concatenated samples, or an empty array when no segment produced audio.
 */
private fun generateMultiSegment(bb: ByteBuffer, nSegments: Int, t0: Long): ShortArray {
    nlog("Multi-segment: $nSegments segments")
    val talker = talkerPteModule ?: run {
        nlog("Multi-segment: talker .pte module not loaded")
        return ShortArray(0)
    }
    val allAudio = mutableListOf<ShortArray>()

    // Ensure data arrays are loaded.
    val mpath = "/data/local/tmp/kazeia/models/qwen3-tts-npu"
    if (cpAllHeads == null) {
        val hf = java.io.File("/data/local/tmp/kazeia/models/cp_heads.bin")
        if (hf.exists()) {
            val hb = hf.readBytes()
            val heads = FloatArray(hb.size / 4)
            ByteBuffer.wrap(hb).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(heads)
            cpAllHeads = heads
        }
    }
    if (codecEmbedding == null) codecEmbedding = loadNpy("$mpath/codec_embedding.npy")
    if (cpEmbeddings == null) cpEmbeddings = loadNpy("$mpath/code_predictor_embeddings.npy")
    if (cpRotaryCos == null) cpRotaryCos = loadNpy("$mpath/cp_kv_v2/cp_rotary_cos.npy")
    if (cpRotarySin == null) cpRotarySin = loadNpy("$mpath/cp_kv_v2/cp_rotary_sin.npy")
    // NOTE(review): generateFromEmbedsPte loads these two tables from the parent
    // "/data/local/tmp/kazeia/models/" directory — confirm which location is intended.
    if (talkerPteRotaryCos == null) talkerPteRotaryCos = loadNpy("$mpath/talker_pte_rotary_cos.npy")
    if (talkerPteRotarySin == null) talkerPteRotarySin = loadNpy("$mpath/talker_pte_rotary_sin.npy")
    if (ttsEosEmbed == null || ttsPadEmbed == null) {
        // The file holds three consecutive TALKER_DIM-sized vectors: BOS, EOS, PAD.
        val sp = loadNpy("$mpath/tts_special_embeds.npy")
        ttsBosEmbed = sp.sliceArray(0 until TALKER_DIM)
        ttsEosEmbed = sp.sliceArray(TALKER_DIM until 2 * TALKER_DIM)
        ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
    }

    for (seg in 0 until nSegments) {
        // A bad header leaves the read position meaningless, so stop and keep whatever
        // segments were already decoded instead of throwing.
        if (bb.remaining() < 8) {
            nlog("Segment ${seg+1}/$nSegments: truncated header, stopping")
            break
        }
        val nPrefill = bb.int; val nTotal = bb.int
        val segBytes = nTotal.toLong() * TALKER_DIM * 4
        if (nTotal < 0 || nPrefill < 0 || nPrefill > nTotal || segBytes > bb.remaining()) {
            nlog("Segment ${seg+1}/$nSegments: invalid header (nPrefill=$nPrefill, nTotal=$nTotal), stopping")
            break
        }
        nlog("Segment ${seg+1}/$nSegments: $nPrefill prefill + ${nTotal-nPrefill} decode")

        // The embeds are stored row-major, so the prefill rows followed by the trailing
        // rows can be read straight into the two flat arrays the native call expects.
        val nTrailing = nTotal - nPrefill
        val floats = bb.asFloatBuffer()
        val prefillFlat = FloatArray(nPrefill * TALKER_DIM).also { floats.get(it) }
        val trailingFlat = if (nTrailing > 0) FloatArray(nTrailing * TALKER_DIM).also { floats.get(it) } else null
        bb.position(bb.position() + segBytes.toInt()) // the float view does not advance bb itself

        val maxGen = minOf(nTrailing * 4 + 20, 90 - nPrefill)
        val flat = talker.nativeRunTtsPipeline(
            prefillFlat, nPrefill, trailingFlat, nTrailing,
            codecEmbedding ?: FloatArray(0), cpEmbeddings ?: FloatArray(0), cpAllHeads ?: FloatArray(0),
            talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
            cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
            ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
            maxGen
        )
        if (flat == null || flat.isEmpty()) continue
        // The native result is read as nTokens rows of NUM_CODEBOOKS codes each;
        // transpose to [codebook][time], zero-padded to at least SEQ_LEN steps.
        val nTokens = flat.size / NUM_CODEBOOKS
        nlog("$nTokens tokens generated")
        val padLen = maxOf(nTokens, SEQ_LEN)
        val codebooks = Array(NUM_CODEBOOKS) { cb ->
            IntArray(padLen) { t -> if (t < nTokens) flat[t * NUM_CODEBOOKS + cb] else 0 }
        }
        val segAudio = decodeChunked(codebooks, nTokens)
        allAudio.add(segAudio)
        nlog("${segAudio.size/SR.toFloat()}s audio decoded")
    }

    // Without this guard the size computation below yields -gapSamples and
    // ShortArray(...) throws NegativeArraySizeException.
    if (allAudio.isEmpty()) {
        nlog("Multi-segment: no audio produced")
        return ShortArray(0)
    }

    // Per the original author's note, fp16 NPU drift leaves each segment with a "filler"
    // tail. The boundary is detected on the audio itself: scan windows from the end and
    // find the last sustained speech-energy window, using the segment's own peak energy
    // as reference so quiet and loud voices both work.
    val gapMs = 120
    val gapSamples = SR * gapMs / 1000
    val gap = ShortArray(gapSamples)

    fun trimSegmentTail(audio: ShortArray): ShortArray {
        if (audio.size < SR / 2) return audio
        val winMs = 40
        val winSamples = SR * winMs / 1000
        val nWin = audio.size / winSamples
        if (nWin < 6) return audio
        // RMS energy per window, samples normalised to [-1, 1).
        val rms = FloatArray(nWin)
        for (w in 0 until nWin) {
            var s = 0.0
            val o = w * winSamples
            for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x }
            rms[w] = kotlin.math.sqrt(s / winSamples).toFloat()
        }
        // Reference: peak energy within the first 70% of the segment.
        var peak = 0f
        val refEnd = (nWin * 7) / 10
        for (w in 0 until refEnd) if (rms[w] > peak) peak = rms[w]
        // Walk backward to the last window where energy stays >= 35% of the peak for
        // three consecutive windows; if none is found the whole segment is kept.
        val thr = peak * 0.35f
        var lastSpeech = nWin - 1
        for (w in nWin - 1 downTo 2) {
            if (rms[w] >= thr && rms[w-1] >= thr && rms[w-2] >= thr) { lastSpeech = w; break }
        }
        // Keep two extra windows so the last syllable's natural decay survives.
        val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1)
        val keepSamples = (keepWin + 1) * winSamples
        return audio.copyOf(keepSamples)
    }

    val trimmedAudio = allAudio.map { trimSegmentTail(it) }
    for ((i, seg) in trimmedAudio.withIndex()) {
        nlog(" Trim seg ${i+1}: ${allAudio[i].size/SR.toFloat()}s -> ${seg.size/SR.toFloat()}s")
    }

    // Concatenate the segments with a gap of silence between neighbours.
    val totalSamples = trimmedAudio.sumOf { it.size } + (trimmedAudio.size - 1) * gapSamples
    val result = ShortArray(totalSamples)
    var offset = 0
    for ((i, seg) in trimmedAudio.withIndex()) {
        System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size
        if (i < trimmedAudio.size - 1) { System.arraycopy(gap, 0, result, offset, gapSamples); offset += gapSamples }
    }
    val totalMs = System.currentTimeMillis() - t0
    nlog("Total: ${totalMs}ms for ${totalSamples/SR.toFloat()}s ($nSegments segments)")

    // Save a mono 16-bit PCM WAV; failures are logged, not propagated.
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_PTE_NPU.wav"
        val dataLen = result.size * 2
        val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
        header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
        header.putInt(16); header.putShort(1); header.putShort(1)
        header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
        header.put("data".toByteArray()); header.putInt(dataLen)
        val pcm = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
        for (s in result) pcm.putShort(s)
        // `use` closes the stream even when a write throws.
        java.io.FileOutputStream(wavPath).use { fos ->
            fos.write(header.array())
            fos.write(pcm.array())
        }
        nlog("WAV saved: $wavPath ($totalSamples samples)")
    } catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
    return result
}
/**
 * Intentionally a no-op: the input is handed back untouched.
 *
 * Per the original author's note, the garbage generated after the text has high energy
 * and cannot be told apart from speech, so output length is bounded by the token budget
 * (trailing count) rather than by trimming here.
 *
 * @param audio decoded samples.
 * @return the very same [audio] array.
 */
private fun trimTrailingSilence(audio: ShortArray): ShortArray {
    return audio
}
/**
 * Full pipeline using Hexagon talker + Hexagon CP from pre-computed embeddings.
 *
 * Reads a single-segment embeds file (`int nPrefill`, `int nTotal`, then
 * `nTotal * TALKER_DIM` little-endian floats), prefills the talker, then replays every
 * captured decode embed verbatim, running the code predictor once per step. The generated
 * codes are dumped to `/data/local/tmp/kazeia/tablet_codes.bin` and the decoded audio to
 * `/data/local/tmp/kazeia/kazeia_HEX.wav` (both best effort).
 *
 * @param embedsPath path of the embeds file to read.
 * @return the decoded samples, or an empty array when the header is invalid, the prefill
 *         returns nothing, or no step was generated.
 */
private fun generateFromEmbedsHexagon(embedsPath: String): ShortArray {
    nlog("Full pipeline (Hexagon) from: $embedsPath")
    val t0 = System.currentTimeMillis()
    val bytes = File(embedsPath).readBytes()
    if (bytes.size < 8) {
        nlog("Embeds file too short (${bytes.size} bytes): $embedsPath")
        return ShortArray(0)
    }
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val nPrefill = bb.int; val nTotal = bb.int
    // Reject headers that would otherwise surface as NegativeArraySizeException,
    // BufferUnderflowException or an IllegalArgumentException from take().
    if (nTotal < 0 || nPrefill < 0 || nTotal.toLong() * TALKER_DIM * 4 > bb.remaining()) {
        nlog("Invalid embeds header: nPrefill=$nPrefill nTotal=$nTotal, ${bb.remaining()} payload bytes available")
        return ShortArray(0)
    }
    val floats = bb.asFloatBuffer()
    val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { row -> floats.get(row) } }
    nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
    hexReset()

    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    var totalTalkerMs = 0L; var totalCpMs = 0L

    // Suppress every non-codebook logit except CODEC_EOS.
    fun maskNonCodec(logits: FloatArray) {
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
    }

    // Index of the first maximum.
    fun argmax(values: FloatArray): Int {
        var best = 0; var bestVal = values[0]
        for (j in 1 until values.size) if (values[j] > bestVal) { bestVal = values[j]; best = j }
        return best
    }

    // Prefill
    val tPrefill = System.currentTimeMillis()
    val prefillResults = hexForward(embeds.take(nPrefill))
    nlog("Prefill (Hex): ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
    if (prefillResults.isEmpty()) return ShortArray(0)
    var pastHidden = prefillResults.last().first
    val prefillLogits = prefillResults.last().second
    maskNonCodec(prefillLogits)

    // Diagnostic cb0 sampling controls read through diagFile/diagFlag (per the original
    // author's note these are honoured in DEBUG builds only). The default is temperature
    // 0.9 stochastic sampling.
    val cb0Temp: Float = diagFile("cb0_temp")?.let {
        try { it.readText().trim().toFloat() } catch (e: Exception) { 0.9f }
    } ?: 0.9f
    if (cb0Temp != 0.9f) nlog("cb0_temp override: $cb0Temp")
    val forceGreedyCb0 = diagFlag("force_greedy_cb0")
    if (forceGreedyCb0) nlog("force_greedy_cb0: using argmax for cb0 sampling (incl. prefill)")
    // The greedy/temperature choice applies to the prefill sample as well, so step 0 is
    // deterministic when greedy is requested.
    var currentCb0 = if (forceGreedyCb0) argmax(prefillLogits) else sampleTopK(prefillLogits, cb0Temp, 50)
    nlog("Prefill cb0=$currentCb0 (greedy=$forceGreedyCb0)")

    // Voice-cloned single-file replay: every captured decode embed is fed to the talker
    // verbatim and generation runs for exactly nTotal - nPrefill steps.
    val nTrailingHex = nTotal - nPrefill
    val isVoiceCloned = nTrailingHex >= 30
    nlog("Hex voice-cloned mode: $isVoiceCloned ($nTrailingHex decode steps)")

    // Diagnostic: optionally force cb0 at every step to the values captured in
    // python_codes.bin (int n, int cb, then n rows of cb ints; column 0 is cb0).
    val injectPyCb0 = diagFlag("force_inject_pycb0")
    val pyCodesFile = File("/data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin")
    val pyCb0: IntArray? = if (injectPyCb0 && pyCodesFile.exists()) {
        val pb = ByteBuffer.wrap(pyCodesFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
        val n = pb.int; val cb = pb.int
        val cb0Column = IntArray(n) {
            val v = pb.int // cb0
            for (j in 1 until cb) pb.int // skip cb1..cb(cb-1)
            v
        }
        nlog("injectPyCb0: loaded ${cb0Column.size} cb0 values (cb=$cb) from python_codes.bin")
        cb0Column
    } else {
        if (injectPyCb0) {
            nlog("injectPyCb0: flag set but python_codes.bin not found at ${pyCodesFile.absolutePath}")
        }
        null
    }
    // Seed step 0 from the capture too; otherwise it would come from the local prefill
    // sample and not match.
    if (injectPyCb0 && pyCb0 != null && pyCb0.isNotEmpty()) {
        currentCb0 = pyCb0[0]
        nlog("injectPyCb0: overriding step 0 cb0 -> ${pyCb0[0]}")
    }

    for (genStep in 0 until nTrailingHex) {
        val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
        val tCp = System.currentTimeMillis()
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes); generatedCb0.add(currentCb0)
        if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // Feed the captured decode embed verbatim.
        val nextEmbed = embeds[nPrefill + genStep]
        val tT = System.currentTimeMillis()
        val results = hexForward(listOf(nextEmbed))
        totalTalkerMs += System.currentTimeMillis() - tT
        if (results.isEmpty()) { nlog("Hex forward returned empty at step ${genStep+1}"); break }
        pastHidden = results[0].first
        val logits = results[0].second
        maskNonCodec(logits)
        // Repetition penalty (1.05) on every cb0 emitted so far.
        val seen = HashSet<Int>(generatedCb0)
        for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }

        // Next cb0: the captured value for step+1 when injecting, else argmax or top-k.
        val nextIdx = genStep + 1
        currentCb0 = if (injectPyCb0 && pyCb0 != null && nextIdx < pyCb0.size) {
            pyCb0[nextIdx]
        } else if (forceGreedyCb0) {
            argmax(logits)
        } else sampleTopK(logits, cb0Temp, 50)
    }
    val n = allCodes.size
    nlog("Generated $n tokens | Talker(HEX): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Stop hexagon runners before decode ONLY if decoder uses HTP (DSP conflict).
    if (!decoderOnCpu && !decoderOnGpu) {
        hexStopRunner()
    }
    if (n == 0) return ShortArray(0)

    // This loop never stops on EOS, so cb0 can be CODEC_EOS (outside the codebook).
    // Clamp out-of-vocab codes to 0 for decoding, exactly as runHexSegmentFromEmbeds
    // does, so the decoder cannot index past the codebook. The dump below stays raw.
    val padLen = maxOf(n, SEQ_LEN)
    val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t ->
            if (t < n) { val v = allCodes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
        }
    }

    // Diagnostic: dump the raw generated codes for comparison with the Python capture.
    try {
        val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin"
        val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN)
        buf.putInt(n); buf.putInt(NUM_CODEBOOKS)
        for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb])
        // `use` closes the stream even when the write throws.
        java.io.FileOutputStream(codesPath).use { fos -> fos.write(buf.array()) }
        nlog("Tablet codes dumped: $codesPath ($n steps × $NUM_CODEBOOKS)")
    } catch (e: Exception) { nlog("Codes dump failed: ${e.message}") }

    val audio = decodeChunked(allCodebooks, n)
    nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s")

    // Save a mono 16-bit PCM WAV for inspection; failures are logged, not propagated.
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_HEX.wav"
        val dataLen = audio.size * 2
        val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
        header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
        header.putInt(16); header.putShort(1); header.putShort(1)
        header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
        header.put("data".toByteArray()); header.putInt(dataLen)
        val pcm = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
        for (s in audio) pcm.putShort(s)
        java.io.FileOutputStream(wavPath).use { fos ->
            fos.write(header.array())
            fos.write(pcm.array())
        }
        nlog("WAV saved (Hex): $wavPath (${audio.size} samples)")
    } catch (e: Exception) { nlog("WAV save (Hex) failed: ${e.message}") }
    return audio
}
/**
 * Generates and decodes one segment with the Hexagon talker and the code predictor.
 *
 * The talker is prefilled with [prefillEmbeds]; afterwards each entry of
 * [trailingEmbeds] is fed verbatim as one decode step, with the code predictor filling
 * the remaining codebooks for that frame. Shared by the single-segment and the
 * streaming multi-segment callers so both go through the same generation path.
 *
 * The caller must call hexReset() before the first segment of a request and again
 * between segments, so the talker KV-cache does not carry stale context.
 *
 * @param prefillEmbeds embeds fed to the talker before generation starts.
 * @param trailingEmbeds one talker input per decode step.
 * @param segIdx zero-based segment index, used only in log messages.
 * @return decoded audio, or an empty array when the prefill returns nothing or no
 *         frame was generated.
 */
private fun runHexSegmentFromEmbeds(
    prefillEmbeds: List<FloatArray>,
    trailingEmbeds: List<FloatArray>,
    segIdx: Int = 0
): ShortArray {
    // Suppress every non-codebook logit except CODEC_EOS.
    fun maskNonCodec(logits: FloatArray) {
        for (idx in CODEBOOK_SIZE until TALKER_VOCAB) {
            if (idx != CODEC_EOS) logits[idx] = Float.NEGATIVE_INFINITY
        }
    }

    val frames = mutableListOf<IntArray>()
    val cb0History = mutableListOf<Int>()
    var talkerMs = 0L
    var cpMs = 0L

    // Prefill
    val prefillStart = System.currentTimeMillis()
    val prefillOut = hexForward(prefillEmbeds)
    nlog("Seg ${segIdx+1} prefill: ${System.currentTimeMillis() - prefillStart}ms, ${prefillOut.size} steps")
    if (prefillOut.isEmpty()) return ShortArray(0)
    val lastPrefill = prefillOut.last()
    var hidden = lastPrefill.first
    val firstLogits = lastPrefill.second
    maskNonCodec(firstLogits)
    var cb0 = sampleTopK(firstLogits, 0.9f, 50)

    for ((step, embed) in trailingEmbeds.withIndex()) {
        // Code predictor: fill codebooks 1..NUM_CODEBOOKS-1 for the current cb0.
        val frame = IntArray(NUM_CODEBOOKS)
        frame[0] = cb0
        val cpStart = System.currentTimeMillis()
        val predicted = runCodePredictorInterleaved(hidden, cb0)
        cpMs += System.currentTimeMillis() - cpStart
        for (cb in 1 until NUM_CODEBOOKS) {
            frame[cb] = predicted[cb - 1]
        }
        frames.add(frame)
        cb0History.add(cb0)

        // Talker: one decode step on the captured embed.
        val talkerStart = System.currentTimeMillis()
        val stepOut = hexForward(listOf(embed))
        talkerMs += System.currentTimeMillis() - talkerStart
        if (stepOut.isEmpty()) {
            nlog("Seg ${segIdx+1}: hex empty at step ${step+1}")
            break
        }
        hidden = stepOut[0].first
        val stepLogits = stepOut[0].second
        maskNonCodec(stepLogits)
        // Repetition penalty (1.05) on every cb0 emitted so far.
        for (tok in cb0History.toHashSet()) {
            val value = stepLogits[tok]
            stepLogits[tok] = if (value > 0) value / 1.05f else value * 1.05f
        }
        cb0 = sampleTopK(stepLogits, 0.9f, 50)
    }

    val frameCount = frames.size
    nlog("Seg ${segIdx+1} generated $frameCount tokens | Talker(HEX): ${talkerMs}ms | CP: ${cpMs}ms")
    if (frameCount == 0) return ShortArray(0)

    // cb0 can be CODEC_EOS (outside the codebook) since this loop never stops on EOS.
    // Any out-of-vocab code is replaced by 0 so decoding cannot read past the codebook.
    val paddedLen = maxOf(frameCount, SEQ_LEN)
    val codebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(paddedLen) { t ->
            when {
                t >= frameCount -> 0
                frames[t][cb] in 0 until CODEBOOK_SIZE -> frames[t][cb]
                else -> 0
            }
        }
    }
    return decodeChunked(codebooks, frameCount)
}
/**
 * Streaming multi-segment variant of generateFromEmbeds. Reads a multi-segment
 * embeds file, generates each segment via Hexagon talker + CP sequentially,
 * and invokes `onSegmentReady(idx, audio)` the moment each segment's audio is
 * decoded, so the caller can start playback after the first segment instead
 * of waiting for the whole phrase.
 *
 * Accepted file layouts (little-endian):
 *   single: <i32 nPrefill> <i32 nTotal> <f32 × nTotal × TALKER_DIM>
 *   multi : <i32 nSegments> [<i32 nPrefill_i> <i32 nTotal_i> <f32 ...>] × nSegments
 * A single-segment file is handled as the degenerate case nSegments = 1, so
 * the caller doesn't need to branch on format.
 *
 * Each segment's raw audio is saved to /data/local/tmp/kazeia/kazeia_stream_segN.wav
 * and the final concatenated audio to /data/local/tmp/kazeia/kazeia_stream_full.wav
 * so individual segments can be inspected for quality regressions.
 *
 * Malformed input never throws: an unreadable or too-short file, or an
 * implausible segment count, returns an empty array; a truncated or
 * inconsistent segment header stops generation and returns whatever was
 * produced so far.
 *
 * @param embedsPath path of the pre-computed embeds file.
 * @param onSegmentReady optional callback invoked synchronously with the
 *   zero-based segment index and that segment's PCM16 audio.
 * @return all generated segments concatenated with 120 ms of silence between
 *   them, or an empty array when nothing could be generated.
 */
fun generateFromEmbedsHexagonStreaming(
    embedsPath: String,
    onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
): ShortArray {
    if (!loaded || !useHexagonTalker) {
        nlog("Streaming: Hexagon talker not ready")
        return ShortArray(0)
    }
    val t0 = System.currentTimeMillis()
    val bytes = try {
        File(embedsPath).readBytes()
    } catch (e: java.io.IOException) {
        nlog("Streaming: cannot read $embedsPath: ${e.message}")
        return ShortArray(0)
    }
    // Both layouts start with two i32 values.
    if (bytes.size < 8) {
        nlog("Streaming: embeds file too short (${bytes.size} bytes)")
        return ShortArray(0)
    }
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val embedBytes = TALKER_DIM * 4L // bytes per fp32 embedding row

    // Format detection: the file is "single" when its second int, read as
    // nTotal, accounts for the file length exactly.
    val firstInt = bb.int
    val secondInt = bb.int
    val fileLen = bytes.size.toLong()
    val singleSize = 8L + secondInt.toLong() * embedBytes
    val isSingle = secondInt > 0 && secondInt < 100000 && fileLen == singleSize
    val nSegments = if (isSingle) 1 else firstInt
    // Every segment carries at least an 8-byte header, which bounds the count.
    if (nSegments < 0 || nSegments > (fileLen - 4) / 8) {
        nlog("Streaming: implausible segment count $nSegments for ${bytes.size} bytes")
        return ShortArray(0)
    }
    bb.position(if (isSingle) 0 else 4)
    nlog("Streaming: $nSegments segment(s), ${bytes.size} bytes")

    // Bulk-reads `count` embedding rows at bb's current position and
    // advances bb past them. Callers validate the byte budget first.
    fun readEmbeds(count: Int): List<FloatArray> {
        val floats = bb.asFloatBuffer() // view inherits bb's byte order
        val rows = List(count) { FloatArray(TALKER_DIM).also { row -> floats.get(row) } }
        bb.position(bb.position() + count * TALKER_DIM * 4)
        return rows
    }

    // Ensure a fresh runner connection for this request. Between requests
    // the KV-cache carries stale state from the previous generation and
    // prefill logits come out as garbage on segment 1.
    hexReset()
    val segmentAudios = mutableListOf<ShortArray>()
    val gapSamples = SR * 120 / 1000
    val gap = ShortArray(gapSamples)
    for (seg in 0 until nSegments) {
        if (bb.remaining() < 8) {
            nlog("Streaming seg ${seg+1}/$nSegments: truncated segment header, stopping")
            break
        }
        val nPrefill = bb.int
        val nTotal = bb.int
        if (nPrefill < 0 || nTotal < nPrefill || nTotal.toLong() * embedBytes > bb.remaining()) {
            nlog("Streaming seg ${seg+1}/$nSegments: malformed header (nPrefill=$nPrefill, nTotal=$nTotal, ${bb.remaining()} bytes left), stopping")
            break
        }
        // Between segments the talker KV-cache must be reset so segment 2's
        // prefill logits don't contain segment 1's state. Skipping this
        // produces garbled speech from segment 2 onwards.
        if (seg > 0) hexReset()
        val nTrailing = nTotal - nPrefill
        val prefill = readEmbeds(nPrefill)
        val trailing = readEmbeds(nTrailing)
        nlog("Streaming seg ${seg+1}/$nSegments: $nPrefill prefill + $nTrailing decode")
        val tSeg = System.currentTimeMillis()
        val audio = runHexSegmentFromEmbeds(prefill, trailing, seg)
        val segMs = System.currentTimeMillis() - tSeg
        nlog("Streaming seg ${seg+1}/$nSegments: ${audio.size/SR.toFloat()}s audio in ${segMs}ms")
        // Emit to caller immediately so playback starts now. WAV dump is
        // synchronous here; for true zero-lag streaming the caller can
        // write to AudioTrack on its own dispatcher from the callback.
        segmentAudios.add(audio)
        saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${seg+1}.wav", audio)
        onSegmentReady?.invoke(seg, audio)
    }
    // Concatenate with short gaps between segments for the full-file WAV.
    val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
    val concat = ShortArray(total)
    var off = 0
    for ((i, s) in segmentAudios.withIndex()) {
        System.arraycopy(s, 0, concat, off, s.size); off += s.size
        if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
    }
    saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
    nlog("Streaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s ($nSegments seg)")
    return concat
}
/**
 * Fetch one row of the fp16 full-vocabulary text-embedding table and
 * return it widened to fp32.
 *
 * [tokenId] is clamped into `0 until textEmbedsFullLen`, so an out-of-range
 * id yields the first or last row. Each element is decoded straight from
 * the shared buffer with [halfToFloat]; the only allocation is the
 * returned array.
 *
 * @throws IllegalStateException if the embedding table is not loaded.
 */
private fun textEmbFromFull(tokenId: Int): FloatArray {
    val table = textEmbedsFullBuf ?: error("Stage 2 full embeddings not loaded")
    val row = tokenId.coerceIn(0, textEmbedsFullLen - 1)
    // Relative reads advance the buffer's shared position, so the seek and
    // the row decode are done under one lock in case two callers overlap.
    return synchronized(table) {
        table.position(row * TALKER_DIM * 2) // 2 bytes per fp16 element
        FloatArray(TALKER_DIM) { halfToFloat(table.short) }
    }
}
/**
 * Widen an IEEE 754 binary16 value (given as its raw 16 bits) to fp32.
 *
 * Signed zeros, subnormals, infinities and NaN payloads are all carried
 * over exactly; every fp16 value is representable in fp32, so no rounding
 * occurs.
 */
private fun halfToFloat(h: Short): Float {
    val raw = h.toInt() and 0xffff
    val signBits = (raw and 0x8000) shl 16 // fp16 bit 15 -> fp32 bit 31
    val exponent = (raw shr 10) and 0x1f
    val fraction = raw and 0x3ff

    if (exponent == 0) {
        // Signed zero.
        if (fraction == 0) return Float.fromBits(signBits)
        // Subnormal: shift the leading 1 up to the implicit-bit position
        // (bit 10). Each shift lowers the effective exponent by one from
        // the fp16 minimum of -14, i.e. from biased fp32 exponent 113.
        val shift = Integer.numberOfLeadingZeros(fraction) - 21
        val normalized = (fraction shl shift) and 0x3ff
        return Float.fromBits(signBits or ((113 - shift) shl 23) or (normalized shl 13))
    }

    // All-ones exponent is inf/NaN in both formats; otherwise re-bias
    // from 15 to 127.
    val exponent32 = if (exponent == 0x1f) 0xff else exponent + 112
    return Float.fromBits(signBits or (exponent32 shl 23) or (fraction shl 13))
}
// Stage 3 streaming session state. A session opens a single AudioTrack
// and a background worker that pulls sentences from a Channel and
// generates+plays them sequentially. Audio for sentence N plays while
// sentence N+1's codes are being generated on Hexagon+CP, giving a
// smoother LLM-to-voice UX than "generate all, play all".
//
// All three fields are assigned together by startStreamingSession() and
// reset to null together by endStreamingSession(). They are plain vars
// with no locking or @Volatile.
// NOTE(review): callers presumably drive a session from one thread — confirm.

// Output track shared by every sentence of the session; non-null while a
// session is open (startStreamingSession() uses it as the "already open" flag).
private var sessionTrack: AudioTrack? = null
// Queue of sentences waiting to be synthesized; fed by enqueueSentence().
private var sessionChannel: kotlinx.coroutines.channels.Channel<String>? = null
// Worker coroutine that drains sessionChannel and writes audio to sessionTrack.
private var sessionJob: kotlinx.coroutines.Job? = null
/**
 * Open a streaming TTS session backed by a persistent AudioTrack. After
 * this returns, callers feed sentences one by one via enqueueSentence();
 * each sentence is voice-cloned on a background worker and its audio is
 * written to the shared track as soon as it's decoded. Call
 * endStreamingSession() to flush the queue and release the track.
 *
 * Calling this while a session is already open is a no-op.
 */
fun startStreamingSession() {
    if (sessionTrack != null) return // already open
    val attributes = AudioAttributes.Builder()
        .setUsage(AudioAttributes.USAGE_MEDIA)
        .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        .build()
    val format = AudioFormat.Builder()
        .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
        .setSampleRate(SR)
        .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        .build()
    val track = AudioTrack.Builder()
        .setAudioAttributes(attributes)
        .setAudioFormat(format)
        // 30 s of PCM16 mono, so a whole sentence can be queued without the
        // blocking write stalling generation of the next one.
        .setBufferSizeInBytes(SR * 30 * 2)
        .setTransferMode(AudioTrack.MODE_STREAM)
        .build()
    // A streaming track only starts rendering once its start threshold has
    // been written, and that threshold defaults to the full buffer capacity
    // (30 s here). Lower it to 100 ms so the first sentence is heard as soon
    // as it is written. The setter only exists on API 31+.
    if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.S) {
        track.setStartThresholdInFrames(SR / 10)
    }
    track.play()
    val chan = kotlinx.coroutines.channels.Channel<String>(
        capacity = kotlinx.coroutines.channels.Channel.UNLIMITED
    )
    val job = kotlinx.coroutines.CoroutineScope(Dispatchers.IO).launch {
        var segIdx = 0
        // Ends when the channel is closed and drained (endStreamingSession()).
        for (sentence in chan) {
            try {
                val audio = generateSegmentAudioVC(sentence, segIdx)
                if (audio.isNotEmpty()) {
                    // write() returns the number of shorts accepted or a
                    // negative error code; surface anything unexpected
                    // instead of dropping audio silently.
                    val written = track.write(audio, 0, audio.size)
                    if (written != audio.size) {
                        nlog("session seg $segIdx: AudioTrack.write returned $written (expected ${audio.size})")
                    }
                }
                segIdx++
            } catch (e: Exception) {
                // One bad sentence must not kill the session; log and move on.
                nlog("session seg $segIdx error: ${e.message}")
            }
        }
    }
    sessionTrack = track; sessionChannel = chan; sessionJob = job
    nlog("streaming session opened")
}
/**
 * Hand one sentence to the worker of the session opened by
 * startStreamingSession(). Never blocks: the sentence is queued and the
 * call returns at once. Sentences are synthesized and played in the order
 * they were enqueued. Without an open session, or once the queue has been
 * closed, the sentence is dropped and the fact is logged.
 */
fun enqueueSentence(sentence: String) {
    val queue = sessionChannel
    if (queue == null) {
        nlog("enqueueSentence: no session open")
        return
    }
    if (queue.trySend(sentence).isFailure) {
        nlog("enqueueSentence: channel full / closed")
    }
}
/**
 * Close the sentence queue, wait for all pending sentences to finish
 * generating, wait for the audio already written to the track to play
 * out, then stop and release the shared track. Safe to call more than
 * once: without an open session it returns immediately.
 *
 * If the calling coroutine is cancelled while waiting, the worker is
 * cancelled and the track is still released before the cancellation
 * propagates to the caller.
 */
suspend fun endStreamingSession() {
    val chan = sessionChannel ?: return
    chan.close()
    val track = sessionTrack
    try {
        sessionJob?.join()
        // release() frees the track at once and would discard anything
        // still queued in it, so let playback catch up first.
        if (track != null) awaitSessionPlaybackDrained(track)
    } finally {
        // No-op when the worker already finished; on caller cancellation it
        // stops the worker from touching the track after release.
        sessionJob?.cancel()
        try {
            track?.stop()
        } catch (e: IllegalStateException) {
            nlog("endStreamingSession: AudioTrack.stop failed: ${e.message}")
        }
        track?.release()
        sessionTrack = null; sessionChannel = null; sessionJob = null
        nlog("streaming session closed")
    }
}

/**
 * Suspend until [track] has stopped consuming audio: its playback head has
 * not moved for ~200 ms, or one full buffer's worth of time has elapsed.
 * Must be called before stop(), which resets the reported head position.
 */
private suspend fun awaitSessionPlaybackDrained(track: AudioTrack) {
    if (track.playState != AudioTrack.PLAYSTATE_PLAYING) return
    val pollMs = 50L
    // Upper bound: everything still queued fits in one buffer.
    val maxWaitMs = track.bufferSizeInFrames.toLong() * 1000L / SR + 500L
    var waitedMs = 0L
    var stalledPolls = 0
    var lastHead = track.playbackHeadPosition
    while (waitedMs < maxWaitMs && stalledPolls < 4) {
        kotlinx.coroutines.delay(pollMs)
        waitedMs += pollMs
        val head = track.playbackHeadPosition
        if (head == lastHead) {
            stalledPolls++
        } else {
            stalledPolls = 0
            lastHead = head
        }
    }
}
/**
 * Synthesize one voice-cloned sentence and hand back its PCM16 audio.
 * Follows the same per-segment recipe as synthesizeTextStreaming, minus
 * the WAV dumps: the streaming session only needs samples for the
 * AudioTrack.
 *
 * @param segText sentence to speak.
 * @param segIdx zero-based sentence index, used for log messages only.
 * @return decoded audio with a 40 ms fade-out applied, or an empty array
 *   when assets are missing, no talker backend is available, or no codes
 *   were generated.
 */
private fun generateSegmentAudioVC(segText: String, segIdx: Int): ShortArray {
    val tokenizer = bpeTokenizer
    val voicePrefix = damienVoicePrefix
    val voiceSuffix = damienVoiceSuffix
    if (tokenizer == null || textEmbedsFullBuf == null || voicePrefix == null || voiceSuffix == null) {
        nlog("generateSegmentAudioVC: Stage 2 assets missing")
        return ShortArray(0)
    }

    val padCodecEmb = codecEmb(CODEC_PAD)
    val tokenIds = tokenizer.encode(segText)
    nlog("session seg $segIdx '${segText.take(60)}' → ${tokenIds.size} tokens")

    // Prefill layout: voice prefix, one (text + codec-pad) embedding per
    // token, voice suffix.
    val prefill = ArrayList<FloatArray>(voicePrefix.size + tokenIds.size + voiceSuffix.size)
    prefill.addAll(voicePrefix)
    for (id in tokenIds) prefill.add(sumEmb(textEmbFromFull(id), padCodecEmb))
    prefill.addAll(voiceSuffix)

    // Step budget, as in synthesizeTextStreaming: roughly 2.4 steps per
    // token expected, capped at 1.5x that plus 10 (and by the context
    // window); the EOS boost may not arm before half the expected length.
    val expectedSteps = (tokenIds.size * 24) / 10
    val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
    val eosBoostMinStep = expectedSteps / 2

    // Backend dispatch: use the Hexagon talker when its socket is open,
    // otherwise fall back to the .pte modules, which are called without a
    // manual KV reset.
    val codes: Array<IntArray> = when {
        talkerSocket != null -> {
            hexReset()
            runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
        }
        talkerPteModule != null && cpPteModule != null ->
            runInterleavedPteFromEmbeds(prefill, emptyList(), maxGen)
        else -> {
            nlog("generateSegmentAudioVC: no talker backend available")
            return ShortArray(0)
        }
    }
    if (codes.isEmpty()) return ShortArray(0)

    // Transpose [step][codebook] into [codebook][step], zero-padded to at
    // least SEQ_LEN; codes outside the codebook stay 0.
    val frames = codes.size
    val paddedLen = maxOf(frames, SEQ_LEN)
    val codebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(paddedLen).also { column ->
            for (t in 0 until frames) {
                val code = codes[t][cb]
                if (code in 0 until CODEBOOK_SIZE) column[t] = code
            }
        }
    }

    // 40 ms fade-out so the EOS-clipped tail decays before the
    // AudioTrack write.
    val decoded = decodeChunked(codebooks, frames)
    return fadeOut(decoded, 40)
}
/**
 * Run the Hexagon talker + code-predictor generation loop on a fully
 * pre-built prefill (voice prefix + all text tokens + suffix).
 *
 * Each step: the code predictor expands the current cb0 token into the
 * remaining codebooks, the summed codec embeddings of that frame are fed
 * back to the talker (plus tts_eos on the first step and tts_pad on every
 * later one, since all text was consumed in the prefill), and the next cb0
 * is sampled from the masked, repetition-penalised logits.
 *
 * Generation stops at the first of: [maxGen] steps, EOS becoming the
 * argmax (possibly after the dynamic boost), EOS being sampled, an empty
 * talker response, or one of the two degeneracy guards firing.
 *
 * @param prefill talker input embeddings for the prefill pass.
 * @param maxGen hard cap on generated frames.
 * @param eosBoostMinStep earliest step at which the EOS boost may start
 *   arming. Callers pass ~50 % of the expected speech length so short
 *   segments don't terminate on a stray EOS-leaning state in the first
 *   few frames.
 * @param eosRankTrigger EOS rank below which the model is considered to be
 *   "thinking about stopping". Three consecutive steps under this rank arm
 *   the boost.
 * @param eosBoostScale amount added to the EOS logit per step since arming.
 * @return generated codes indexed [step][codebook]; empty when the tts
 *   pad/eos embeddings are missing or the prefill returns nothing.
 */
private fun runHexGenWithPrefill(
    prefill: List<FloatArray>,
    maxGen: Int,
    eosBoostMinStep: Int = -1,
    eosRankTrigger: Int = 60,
    eosBoostScale: Float = 4.0f
): Array<IntArray> {
    val padE = ttsPadEmbed ?: return emptyArray()
    val eosE = ttsEosEmbed ?: return emptyArray()
    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    // Distinct cb0 tokens emitted so far, kept in step with generatedCb0 so
    // the repetition penalty doesn't rebuild a set from the whole history
    // on every step.
    val seenCb0 = HashSet<Int>()
    var totalTalkerMs = 0L
    var totalCpMs = 0L

    val tPrefill = System.currentTimeMillis()
    val prefillResults = hexForward(prefill)
    nlog("VC prefill (Hex): ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
    if (prefillResults.isEmpty()) return emptyArray()
    var pastHidden = prefillResults.last().first
    val prefillLogits = prefillResults.last().second
    // Only real codebook entries and CODEC_EOS may be sampled.
    for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
    var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
    nlog("VC prefill done: first cb0=$currentCb0")

    // Dynamic EOS boost state. boostStepsActive is 0 until armed, then
    // grows by one every step; arming needs 3 CONSECUTIVE low-rank steps so
    // a transient dip mid-utterance can't end a short sentence mid-word.
    var boostStepsActive = 0
    var consecLowRank = 0

    for (genStep in 0 until maxGen) {
        // 1. Expand cb0 into the full frame with the code predictor.
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0
        val tCp = System.currentTimeMillis()
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        seenCb0.add(currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp

        // 2. Next talker input: sum of this frame's codec embeddings plus
        //    tts_eos on the first decode step and tts_pad afterwards.
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
        val nextEmbed = sumEmb(codecSum, if (genStep == 0) eosE else padE)

        // 3. One talker step.
        val tT = System.currentTimeMillis()
        val results = hexForward(listOf(nextEmbed))
        totalTalkerMs += System.currentTimeMillis() - tT
        if (results.isEmpty()) { nlog("VC: hex empty at step ${genStep+1}"); break }
        pastHidden = results[0].first
        val logits = results[0].second

        // 4. Mask non-codec entries, then apply a 1.05 repetition penalty
        //    to every cb0 token emitted so far.
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
        for (tok in seenCb0) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }

        // 5. Dynamic EOS detection: rank = number of logits strictly above
        //    the EOS logit.
        val eosLogit0 = logits[CODEC_EOS]
        var eosRank = 0
        for (j in logits.indices) if (logits[j] > eosLogit0) eosRank++
        if (boostStepsActive == 0 && genStep >= eosBoostMinStep) {
            if (eosRank < eosRankTrigger) {
                consecLowRank++
                if (consecLowRank >= 3) {
                    nlog("VC boost armed at step ${genStep+1} (EOS rank $eosRank, 3 consecutive low)")
                    boostStepsActive = 1
                }
            } else {
                consecLowRank = 0
            }
        } else if (boostStepsActive > 0) {
            boostStepsActive++
        }
        if (boostStepsActive > 0) {
            logits[CODEC_EOS] += boostStepsActive * eosBoostScale
        }

        // 6. Stop if EOS is now the most likely token, else sample.
        var argmax = 0
        var argmaxV = logits[0]
        for (j in 1 until logits.size) if (logits[j] > argmaxV) { argmaxV = logits[j]; argmax = j }
        if (argmax == CODEC_EOS) { nlog("VC EOS (boosted argmax) at step ${genStep+1}"); break }
        currentCb0 = sampleTopK(logits, 0.9f, 200)
        if (currentCb0 == CODEC_EOS) { nlog("VC EOS (sampled) at step ${genStep+1}"); break }

        // Degeneracy guard #1: the last nine emitted cb0 are identical and
        // the freshly sampled one matches them too -> stuck loop, stop.
        val nHist = generatedCb0.size
        if (nHist >= 9) {
            val last = generatedCb0[nHist - 1]
            val lastNineSame = generatedCb0.subList(nHist - 9, nHist).all { it == last }
            if (lastNineSame && currentCb0 == last) {
                nlog("VC degen: cb0=$last repeated ≥9× at step ${genStep+1}, stopping")
                break
            }
        }
        // Degeneracy guard #2: low diversity in the recent window. The
        // "page beg beg" filler cycles through 2-3 tokens rather than
        // repeating one, so fewer than 4 distinct values in the last 12
        // cb0 also stops generation.
        if (nHist >= 12) {
            val recentUnique = generatedCb0.subList(nHist - 12, nHist).toSet().size
            if (recentUnique < 4) {
                nlog("VC degen: only $recentUnique unique cb0 in last 12 at step ${genStep+1}, stopping")
                break
            }
        }
    }
    nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
    // Diagnostic: log full cb0 sequence so we can correlate against the
    // page-beg region in the produced audio.
    val cb0Str = generatedCb0.joinToString(",")
    nlog("VC cb0 sequence: $cb0Str")
    return allCodes.toTypedArray()
}
/**
 * Split text into short segments for the streaming pipeline. Reproduces
 * the behaviour of scripts/prepare_tts_segments.py but on-device so the
 * tablet doesn't depend on a PC-side preprocessor.
 *
 * The text is first cut after sentence punctuation (`. ! ? ; :`) followed
 * by whitespace. A sentence longer than [maxChars] is then cut after
 * commas, and the resulting clauses are greedily packed back together so
 * that each packed segment, including the joining spaces, stays within
 * [maxChars]. A single clause that is itself longer than [maxChars] is
 * emitted as is; words are never broken.
 *
 * @param text input text.
 * @param maxChars target maximum segment length. The 120-character default
 *   matches the length where the talker still terminates reliably on EOS.
 * @return the trimmed, non-blank segments in order; if there are none
 *   (blank input), a single-element list holding [text] unchanged.
 */
private fun splitSentences(text: String, maxChars: Int = 120): List<String> {
    val sentences = text.trim().split(Regex("(?<=[.!?;:])\\s+"))
    val out = mutableListOf<String>()
    for (sentence in sentences) {
        if (sentence.length <= maxChars) {
            if (sentence.isNotBlank()) out.add(sentence.trim())
            continue
        }
        // Break overlong sentences at commas, greedily packing clauses
        // back together up to maxChars so we don't over-split.
        val clauses = sentence.split(Regex("(?<=,)\\s+"))
        var packed = ""
        for (clause in clauses) {
            when {
                packed.isEmpty() -> packed = clause
                // +1 accounts for the space inserted when joining.
                packed.length + 1 + clause.length > maxChars -> {
                    out.add(packed.trim())
                    packed = clause
                }
                else -> packed = "$packed $clause"
            }
        }
        if (packed.isNotBlank()) out.add(packed.trim())
    }
    return if (out.isEmpty()) listOf(text) else out
}
/**
 * Tokenize `text` on-device and stream the synthesized audio segment by
 * segment via `onSegmentReady`. The PC-side prep script becomes optional:
 * TTS runs fully offline on arbitrary LLM output.
 *
 * For each segment produced by splitSentences:
 *   1. BPE tokenize the segment text.
 *   2. Look up each token ID in the fp16 full-vocab table (converted to
 *      fp32 on the fly) and sum it with the CODEC_PAD codec embedding.
 *   3. Build the prefill as voice prefix + per-token embeddings + voice
 *      suffix and run runHexGenWithPrefill on it (talker KV-cache is reset
 *      before every segment).
 *   4. Decode codes → audio via decodeChunked, apply a 40 ms fade-out,
 *      emit the audio through the callback immediately, and save one WAV
 *      per segment plus the concatenated result.
 *
 * Emits at most one callback per segment; a segment that yields no codes
 * is skipped. Segments are joined with 250 ms of silence in the returned
 * audio.
 *
 * @param text text to speak.
 * @param onSegmentReady optional callback invoked synchronously with the
 *   zero-based segment index and that segment's PCM16 audio.
 * @return the concatenated audio, or an empty array when the Hexagon
 *   talker is not ready, the tokenizer / embedding table / voice
 *   prefix / voice suffix are not loaded, or no segment produced audio.
 */
fun synthesizeTextStreaming(
    text: String,
    onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
): ShortArray {
    if (!loaded || !useHexagonTalker) {
        nlog("synthesizeTextStreaming: Hexagon talker not ready"); return ShortArray(0)
    }
    // The voice prefix/suffix are as mandatory as the tokenizer and the
    // embedding table: check them all up front rather than hitting a
    // NullPointerException halfway through.
    val tokenizer = bpeTokenizer
    val prefix = damienVoicePrefix
    val suffix = damienVoiceSuffix
    if (tokenizer == null || textEmbedsFullBuf == null || prefix == null || suffix == null) {
        nlog("synthesizeTextStreaming: Stage 2 assets missing"); return ShortArray(0)
    }
    val segments = splitSentences(text)
    nlog("synthesizeTextStreaming: ${segments.size} segment(s) for ${text.length} chars")
    hexReset()
    val segmentAudios = mutableListOf<ShortArray>()
    // Wider gap (250 ms) between segments — gives the listener a small
    // breath/pause that mirrors how real speakers separate sentences,
    // and masks the EOS-boost cut by surrounding it with silence rather
    // than another sentence's onset.
    val gapSamples = SR * 250 / 1000
    val gap = ShortArray(gapSamples)
    val t0 = System.currentTimeMillis()
    // CODEC_PAD embedding is the per-token companion that is summed into
    // every text-encoded prefill position. Computed once here so the
    // per-token loop stays a simple vector add.
    val codecPadEmb = codecEmb(CODEC_PAD)
    for ((segIdx, segText) in segments.withIndex()) {
        // Each segment starts from a clean talker KV-cache.
        if (segIdx > 0) hexReset()
        val tSeg = System.currentTimeMillis()
        val ids = tokenizer.encode(segText)
        nlog("Seg ${segIdx+1}/${segments.size}: '${segText.take(60)}' → ${ids.size} tokens: ${ids.toList()}")
        // Voice-cloning prefill, fully reconstructed on-device:
        //   damienVoicePrefix
        //   text embedding(BPE id) + codec_embedding(CODEC_PAD), per token
        //   damienVoiceSuffix (end-of-text marker)
        // The talker needs BOTH the per-token codec_pad sum AND the closure
        // markers to know that text input has ended and decoding can begin.
        val prefill = ArrayList<FloatArray>(prefix.size + ids.size + suffix.size)
        for (e in prefix) prefill.add(e)
        for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
        for (e in suffix) prefill.add(e)
        // Generous absolute cap; the dynamic EOS boost (triggered when
        // the model's own EOS rank dips below threshold) is what
        // actually terminates generation. The minStep floor protects
        // against early-termination spikes for short sentences.
        val expectedSteps = (ids.size * 24) / 10 // ids * 2.4 (int math)
        val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
        val eosBoostMinStep = expectedSteps / 2
        val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
        if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue }
        // Transpose [step][codebook] → [codebook][step], zero-padded to at
        // least SEQ_LEN; out-of-codebook values are replaced by 0.
        val n = codes.size
        val padLen = maxOf(n, SEQ_LEN)
        val codebooks = Array(NUM_CODEBOOKS) { cb ->
            IntArray(padLen) { t ->
                if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
            }
        }
        val rawAudio = decodeChunked(codebooks, n)
        // Cosine fade-out on the last 40 ms — softens the cut imposed by
        // the EOS boost so the segment ends on a recognisable phoneme
        // tail instead of an abrupt sample-clip.
        val audio = fadeOut(rawAudio, 40)
        val segMs = System.currentTimeMillis() - tSeg
        val budgetHit = if (n >= maxGen) " [maxGen cap]" else ""
        nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms$budgetHit")
        segmentAudios.add(audio)
        saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio)
        onSegmentReady?.invoke(segIdx, audio)
    }
    if (segmentAudios.isEmpty()) return ShortArray(0)
    val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
    val concat = ShortArray(total)
    var off = 0
    for ((i, s) in segmentAudios.withIndex()) {
        System.arraycopy(s, 0, concat, off, s.size); off += s.size
        if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
    }
    saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
    nlog("synthesizeTextStreaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s")
    return concat
}
/**
 * Cut a low-energy tail off a voice-cloned segment. When the talker
 * fails to emit EOS within maxGen, the remaining decoded codes tend to be
 * "page beg beg" fillers — audible as a mumbled tail.
 *
 * The audio is measured in 40 ms RMS windows. The cut point is the last
 * window that closes a run of three windows at or above 35 % of the peak
 * RMS; if the final 30 % of the segment never reaches 40 % of the peak,
 * the search is repeated from just before that tail and the earlier
 * result wins. Two windows of margin are kept after the cut point, and at
 * least 40 % of the input is always retained.
 *
 * Inputs shorter than half a second (or fewer than 10 windows) are
 * returned untouched. No fade is applied here.
 */
private fun trimTailLowEnergy(audio: ShortArray): ShortArray {
    if (audio.size < SR / 2) return audio
    val windowLen = SR * 40 / 1000 // 40 ms windows = 960 samples
    val windowCount = audio.size / windowLen
    if (windowCount < 10) return audio

    // Per-window RMS on samples normalised to [-1, 1).
    val energy = FloatArray(windowCount) { w ->
        val from = w * windowLen
        var sumSq = 0.0
        for (i in from until from + windowLen) {
            val x = audio[i].toFloat() / 32768f
            sumSq += x * x
        }
        kotlin.math.sqrt(sumSq / windowLen).toFloat()
    }
    val peak = energy.fold(0f) { best, v -> if (v > best) v else best }

    // True when window w and the two before it all reach the sustained-
    // speech threshold (35 % of peak).
    val sustainThreshold = peak * 0.35f
    fun sustainedAt(w: Int): Boolean =
        energy[w] >= sustainThreshold &&
            energy[w - 1] >= sustainThreshold &&
            energy[w - 2] >= sustainThreshold

    // Heuristic 1: last sustained window, searching from the very end.
    var lastSpeech = (windowCount - 1 downTo 2).firstOrNull { sustainedAt(it) } ?: (windowCount - 1)

    // Heuristic 2: the filler is louder than silence but quieter than real
    // speech. If the last 30 % never reaches 40 % of the peak, treat it as
    // degenerate and look for sustained speech before it.
    val tailStart = (windowCount * 7) / 10
    var tailMax = 0f
    for (w in tailStart until windowCount) {
        if (energy[w] > tailMax) tailMax = energy[w]
    }
    if (tailMax < peak * 0.40f) {
        val beforeTail = (tailStart - 1 downTo 2).firstOrNull { sustainedAt(it) }
        if (beforeTail != null && beforeTail < lastSpeech) lastSpeech = beforeTail
    }

    // Keep two windows of margin after the cut point, bounded by the last
    // full window.
    val lastKeptWindow = (lastSpeech + 2).coerceAtMost(windowCount - 1)
    val trimmedLen = (lastKeptWindow + 1) * windowLen
    // Over-trim guard: never cut below 40 % of the raw length.
    val floorLen = (audio.size * 4) / 10
    return audio.copyOf(if (trimmedLen < floorLen) floorLen else trimmedLen)
}
/**
 * Return a copy of [audio] whose last [fadeMs] milliseconds are scaled by
 * a raised-cosine ramp falling from 1 towards 0. The EOS boost ends
 * generation right at the model's chosen stop point, which tends to clip
 * the decay of the last syllable; the ramp softens that cut without
 * shortening the buffer.
 *
 * If the buffer is not longer than the ramp, the input array itself is
 * returned unmodified.
 */
private fun fadeOut(audio: ShortArray, fadeMs: Int = 40): ShortArray {
    val rampLen = SR * fadeMs / 1000
    if (rampLen >= audio.size) return audio

    val faded = audio.copyOf()
    val rampStart = faded.size - rampLen
    val pi = Math.PI.toFloat()
    for (idx in rampStart until faded.size) {
        // Progress through the ramp, 0 at its first sample.
        val t = (idx - rampStart).toFloat() / rampLen
        val gain = 0.5f * (1f + kotlin.math.cos(pi * t))
        faded[idx] = (faded[idx] * gain).toInt().toShort()
    }
    return faded
}
/**
 * Write PCM16 mono audio at SR Hz to a WAV file (44-byte canonical RIFF
 * header followed by little-endian samples). Used by the streaming
 * pipeline to save one file per segment plus the concatenated result for
 * inspection.
 *
 * Best-effort: any failure is logged and otherwise ignored. The output
 * stream is always closed, including when the write fails.
 */
private fun saveWav(path: String, audio: ShortArray) {
    try {
        val dataLen = audio.size * 2
        val wav = ByteBuffer.allocate(44 + dataLen).order(ByteOrder.LITTLE_ENDIAN)
        // RIFF chunk descriptor.
        wav.put("RIFF".toByteArray()); wav.putInt(36 + dataLen)
        wav.put("WAVE".toByteArray())
        // "fmt " sub-chunk: 16-byte body, PCM (1), mono (1), sample rate,
        // byte rate, block align (2), bits per sample (16).
        wav.put("fmt ".toByteArray())
        wav.putInt(16); wav.putShort(1); wav.putShort(1)
        wav.putInt(SR); wav.putInt(SR * 2); wav.putShort(2); wav.putShort(16)
        // "data" sub-chunk; the short view starts at the current position
        // and inherits the little-endian order.
        wav.put("data".toByteArray()); wav.putInt(dataLen)
        wav.asShortBuffer().put(audio)
        java.io.FileOutputStream(path).use { it.write(wav.array()) }
    } catch (e: Exception) { nlog("WAV save failed ($path): ${e.message}") }
}
/**
 * Validation helper: decode pre-computed codec tokens produced on a PC.
 *
 * The file must hold NUM_CODEBOOKS × SEQ_LEN little-endian i32 codes,
 * codebook-major. Codes are clamped into the codebook range, run through
 * vqDecode and the speech decoder, and the result is trimmed to
 * [realTokens] tokens' worth of samples. The quantized tensor is also
 * dumped to /data/local/tmp/kazeia/quantized_dump.bin (best-effort) for
 * comparison against the reference.
 *
 * @param codesPath path of the codes file.
 * @param realTokens number of leading tokens that carry real audio.
 * @return the decoded, trimmed audio; empty if the engine is not loaded
 *   or anything fails (the failure is logged).
 */
fun testWithPrecomputedCodes(codesPath: String, realTokens: Int = 16): ShortArray {
    if (!loaded) return ShortArray(0)
    nlog("Testing with pre-computed codes: $codesPath (realTokens=$realTokens)")
    try {
        val bytes = File(codesPath).readBytes()
        val needBytes = NUM_CODEBOOKS * SEQ_LEN * 4
        if (bytes.size < needBytes) {
            // Fail with a readable message instead of a bare buffer underflow.
            nlog("Test error: $codesPath has ${bytes.size} bytes, need $needBytes")
            return ShortArray(0)
        }
        val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
        val allCodes = Array(NUM_CODEBOOKS) { IntArray(SEQ_LEN) }
        for (cb in 0 until NUM_CODEBOOKS) {
            for (t in 0 until SEQ_LEN) {
                allCodes[cb][t] = bb.int.coerceIn(0, CODEBOOK_SIZE - 1)
            }
        }
        nlog("Loaded ${NUM_CODEBOOKS} × ${SEQ_LEN} codes")
        nlog("codes[0][:5] = ${allCodes[0].take(5).toList()}")
        nlog("codes[1][:5] = ${allCodes[1].take(5).toList()}")
        val t0 = System.currentTimeMillis()
        val quantized = vqDecode(allCodes)
        // Dump quantized stats for debug. Float.MIN_VALUE is the smallest
        // POSITIVE float, so the running max must start from -infinity or
        // an all-negative tensor would report a bogus maximum.
        var qMin = Float.POSITIVE_INFINITY
        var qMax = Float.NEGATIVE_INFINITY
        var qSum = 0.0
        for (v in quantized) { qMin = minOf(qMin, v); qMax = maxOf(qMax, v); qSum += v * v }
        nlog("quantized: size=${quantized.size}, range=[${qMin}, ${qMax}], rms=${Math.sqrt(qSum / quantized.size)}")
        // Save quantized to file for comparison (best-effort).
        try {
            val qf = java.io.File("/data/local/tmp/kazeia/quantized_dump.bin")
            qf.writeBytes(java.nio.ByteBuffer.allocate(quantized.size * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).also {
                for (v in quantized) it.putFloat(v)
            }.array())
            nlog("Saved quantized to ${qf.absolutePath}")
        } catch (e: Exception) {
            nlog("quantized dump failed: ${e.message}")
        }
        val fullAudio = runSpeechDecoder(quantized)
        // Keep only the samples that belong to real tokens.
        val trimSamples = (realTokens * SAMPLES_PER_TOKEN).coerceIn(0, fullAudio.size)
        val audio = fullAudio.copyOf(trimSamples)
        nlog("Decode: ${System.currentTimeMillis() - t0}ms, audio=${audio.size.toFloat()/SR}s (trimmed from ${fullAudio.size.toFloat()/SR}s)")
        return audio
    } catch (e: Exception) {
        nlog("Test error: ${e.message}")
        return ShortArray(0)
    }
}
/**
 * Tear the engine down: call stop(), stop the Hexagon runner if either
 * the talker or the code predictor was using it, close the KV-cache
 * sessions, the three decoder-stage sessions and the ORT environment,
 * and mark the engine as not loaded.
 */
override fun release() {
    stop()

    // Stop hexagon runners cleanly
    val hexagonInUse = useHexagonTalker || useHexagonCp
    if (hexagonInUse) hexStopRunner()

    talkerKv?.close()
    cpKv?.close()

    preConv?.close()
    preprocessor?.close()
    convDecoder?.close()

    ortEnv?.close()
    loaded = false
}
}