3899 lines
194 KiB
Kotlin
3899 lines
194 KiB
Kotlin
package com.kazeia.tts
|
||
|
||
import ai.onnxruntime.OnnxTensor
|
||
import ai.onnxruntime.OnnxJavaType
|
||
import ai.onnxruntime.OrtEnvironment
|
||
import ai.onnxruntime.OrtSession
|
||
import android.media.AudioAttributes
|
||
import android.media.AudioFormat
|
||
import android.media.AudioTrack
|
||
import android.util.Log
|
||
import com.kazeia.BuildConfig
|
||
import com.kazeia.core.TtsEngine
|
||
import com.kazeia.core.TtsResult
|
||
import kotlinx.coroutines.Dispatchers
|
||
import kotlinx.coroutines.launch
|
||
import kotlinx.coroutines.suspendCancellableCoroutine
|
||
import kotlinx.coroutines.withContext
|
||
import java.io.File
|
||
import java.io.RandomAccessFile
|
||
import java.nio.ByteBuffer
|
||
import java.nio.ByteOrder
|
||
import java.nio.FloatBuffer
|
||
import java.nio.LongBuffer
|
||
import java.nio.IntBuffer
|
||
import kotlin.coroutines.resume
|
||
|
||
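/**
 * Qwen3-TTS Engine — Full NPU pipeline for voice cloning TTS.
 *
 * Pipeline (design target):
 * 1. Talker NPU (ORT KV-cache decoder) → codec tokens (codebook 0)
 * 2. Code Predictor NPU (QNN ONNX) → 16 codebooks
 * 3. VQ decode (Kotlin) → quantized [1, 512, 60]
 * 4. pre_conv NPU → preprocessor NPU → ConvNet decoder NPU → audio WAV
 *
 * All heavy computation is intended to run on the NPU; VQ decode is a trivial
 * CPU step (codebook lookup). Note that `load()` picks the actual backend of
 * each stage from the model files present on the device, so individual stages
 * may run on CPU or GPU instead (for example, the V2 speech decoder is loaded
 * as CPU ONNX sessions when its files exist).
 */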
|
||
class Qwen3TtsEngine(
|
||
private val nativeLibDir: String,
|
||
private val onLog: ((String) -> Unit)? = null
|
||
) : TtsEngine {
|
||
|
||
companion object {
    private const val TAG = "Qwen3TTS"

    // ── Audio output and speech-decoder chunking ────────────────────────
    /** Output sample rate in Hz. */
    private const val SR = 24_000
    /** Fixed NPU chunk size (in codec tokens) for the speech decoder. */
    private const val SEQ_LEN = 60
    /** Tokens shared between consecutive decoder chunks. */
    private const val CHUNK_OVERLAP = 10
    /** New tokens contributed by each decoder chunk. */
    private const val EFFECTIVE_CHUNK = SEQ_LEN - CHUNK_OVERLAP
    /** PCM samples produced per codec token. */
    private const val SAMPLES_PER_TOKEN = 1920

    // ── Codec / VQ layout ───────────────────────────────────────────────
    private const val CODEC_OFFSET = 1050
    private const val NUM_CODEBOOKS = 16
    private const val CODEBOOK_SIZE = 2048
    private const val CODEBOOK_DIM = 256
    private const val HIDDEN_DIM = 512

    // ── Talker KV-cache model ───────────────────────────────────────────
    private const val TALKER_DIM = 1024
    /** Codec vocabulary size. */
    private const val TALKER_VOCAB = 3072
    private const val TALKER_LAYERS = 28
    private const val TALKER_HEADS = 8
    private const val TALKER_HEAD_DIM = 128
    private const val KV_LEN = 199
    /** Past positions plus the current one. */
    private const val MAX_CONTEXT = KV_LEN + 1
    private const val MASK_NEG = -10_000f

    // ── Code Predictor KV-cache ─────────────────────────────────────────
    private const val CP_LAYERS = 5
    private const val CP_KV_HEADS = 8
    private const val CP_HEAD_DIM = 128
    /** Max 16 past positions (17 total with the current one). */
    private const val CP_KV_LEN = 16

    // ── Talker .pte ─────────────────────────────────────────────────────
    /** Must match the .pte export (KV=64 caused quality loss). */
    private const val TALKER_PTE_KV_LEN = 100

    // ── Codec special token IDs (in the talker's 3072 vocab space) ──────
    private const val CODEC_PAD = 2148
    private const val CODEC_BOS = 2149
    private const val CODEC_EOS = 2150
    private const val CODEC_THINK = 2154
    private const val CODEC_NOTHINK = 2155
    private const val CODEC_THINK_BOS = 2156
    private const val CODEC_THINK_EOS = 2157
    private const val CODEC_LANG_FR = 2061

    // ── Chat template token IDs (reduced vocab) ─────────────────────────
    private const val IM_START = 1048
    private const val IM_END = 1049
    private const val TOKEN_USER = 872
    private const val TOKEN_ASSISTANT = 1042
    private const val TOKEN_NEWLINE = 198
}
|
||
|
||
// ── ONNX Runtime sessions and backend routing flags ─────────────────────
private var ortEnv: OrtEnvironment? = null
// Talker ONNX session (CPU fallback path).
private var talkerKv: OrtSession? = null
// True when the ggml-hexagon talker runner is in use.
private var useHexagonTalker = false
// Code Predictor KV-cache ONNX session (CPU).
private var cpKv: OrtSession? = null
private var preConv: OrtSession? = null
private var preprocessor: OrtSession? = null
private var convDecoder: OrtSession? = null
private var decoderOnCpu = false
private var decoderOnGpu = false

// ── Dual embedding tables for talker input ──────────────────────────────
// [1050, 1024] reduced-vocab text embeddings (legacy fallback).
private var textEmbeds: FloatArray? = null
// [3072, 1024] codec/control token embeddings.
private var codecEmbedding: FloatArray? = null

// ── Stage 2: on-device full-vocab text embeddings + BPE tokenizer ───────
// textEmbedsFull is 151936 × 1024 fp16, memory-mapped (~296 MB). Using mmap
// keeps the bytes off the Java heap so the app doesn't crash when the
// ~125 MB cp_embeddings allocation comes next. damienVoicePrefix is the
// fixed 9-embed voice-cloning header that is prepended to the tokenized
// text to form a full prefill.
private var textEmbedsFullBuf: ByteBuffer? = null
private var textEmbedsFullChan: java.nio.channels.FileChannel? = null
private val textEmbedsFullLen = 151936
private var damienVoicePrefix: Array<FloatArray>? = null
private var damienVoiceSuffix: Array<FloatArray>? = null
private var bpeTokenizer: Qwen3BpeTokenizer? = null
// [1024] tts_bos text-side embedding.
private var ttsBosEmbed: FloatArray? = null
// [1024] tts_eos text-side embedding.
private var ttsEosEmbed: FloatArray? = null
// [1024] tts_pad text-side embedding.
private var ttsPadEmbed: FloatArray? = null
// [1024] x-vector speaker embedding.
private var speakerEmbed: FloatArray? = null

// ── VQ codebooks (loaded from numpy) ────────────────────────────────────
// [2048, 256]
private var firstCodebook: FloatArray? = null
// 15 × [2048, 256]
private var restCodebooks: Array<FloatArray>? = null
// [512, 256]
private var firstOutputProj: FloatArray? = null
// [512, 256]
private var restOutputProj: FloatArray? = null

// Code predictor embeddings [15, 2048, 1024].
private var cpEmbeddings: FloatArray? = null

// ── Runtime state ───────────────────────────────────────────────────────
private var loaded = false
private var modelPath: String? = null
private var audioTrack: AudioTrack? = null
private var talkerUsesInt64Pos = false
private var talkerUsesCosSin = false
private var rotaryCos: FloatArray? = null
private var rotarySin: FloatArray? = null
private var cpUsesCosSin = false
private var cpRotaryCos: FloatArray? = null
private var cpRotarySin: FloatArray? = null
// Path dir for head_0..14.npy.
private var cpHeadsPath: String? = null
// Lazy-loaded heads cache (8 MB each).
private var cpHeadsCache: Array<FloatArray?>? = null
// All 15 heads concatenated [15*2048*1024] for batch NEON.
private var cpAllHeads: FloatArray? = null
// ExecuTorch Code Predictor on NPU (JNI).
private var cpPteModule: org.pytorch.executorch.Module? = null
// ExecuTorch talker on NPU (JNI).
private var talkerPteModule: org.pytorch.executorch.Module? = null
// True when the C++ native pipeline is available.
private var nativePipelineReady = false
private var talkerPteRotaryCos: FloatArray? = null
private var talkerPteRotarySin: FloatArray? = null
// Code Predictor via the ExecuTorch runner process (root).
private var useEtCp = false
private var cpEtSocket: java.net.Socket? = null
// GPU uses fixed-size KV (shift + mask).
private var cpFixedKv = false

// Destination of the on-disk debug log once a fallback location was chosen.
private var debugLogFile: File? = null
|
||
|
||
/**
 * Tells whether the diagnostic flag file `/data/local/tmp/kazeia/<name>` is present.
 *
 * Release builds short-circuit to `false` before touching the file system, so
 * the diagnostic code paths guarded by this check never run in production.
 * Flags routed through here were useful during development investigations:
 * force_inject_pycb0, force_greedy_cb0, cb0_temp, force_python_codes,
 * force_cpu_talker, force_cpu_talker_gguf.
 *
 * Background (from the cross-architecture tremor investigation, thesis-relevant):
 * x86 Python ↔ ARM64 tablet showed a numerical floor (~28% cb0 divergence)
 * that a runtime swap cannot fix. The flags stay in tree so the experiment can
 * be re-run on demand (demos, thesis figures), not for production use.
 *
 * @param name file name of the flag inside `/data/local/tmp/kazeia/`.
 * @return `true` only in DEBUG builds when the flag file exists.
 */
private fun diagFlag(name: String): Boolean {
    if (!BuildConfig.DEBUG) return false
    return File("/data/local/tmp/kazeia/$name").exists()
}
|
||
|
||
/**
 * Resolves a diagnostic file under `/data/local/tmp/kazeia/`.
 *
 * @param name file name inside the diagnostics directory.
 * @return the file when running a DEBUG build and it exists, otherwise `null`.
 */
private fun diagFile(name: String): File? {
    if (!BuildConfig.DEBUG) return null
    val candidate = File("/data/local/tmp/kazeia/$name")
    return if (candidate.exists()) candidate else null
}
|
||
|
||
/**
 * Emits [msg] to every log sink of the engine: logcat (at both INFO and WARN
 * level), the optional [onLog] callback (prefixed with `[TTS] `), and a
 * timestamped line appended to an on-disk debug log.
 *
 * File logging is best effort: if appending to the current target fails, a
 * second location under the model directory is tried and, on success,
 * remembered in [debugLogFile] for later calls. Failures of both are ignored.
 */
private fun nlog(msg: String) {
    Log.i(TAG, msg)
    Log.w(TAG, msg)
    onLog?.invoke("[TTS] $msg")

    // Timestamp is taken at append time, so a fallback write gets its own stamp.
    fun appendTo(target: File) = target.appendText("${System.currentTimeMillis()} $msg\n")

    try {
        appendTo(debugLogFile ?: File("/data/local/tmp/kazeia/tts_debug.log"))
    } catch (_: Exception) {
        // Primary location not writable: try the model directory instead.
        try {
            val fallback = File("/data/local/tmp/kazeia/models/qwen3-tts-npu/tts_debug.log")
            appendTo(fallback)
            debugLogFile = fallback
        } catch (_: Exception) {
            // Deliberately ignored: logging must never break synthesis.
        }
    }
}
|
||
|
||
/**
 * Loads all models and lookup tables needed for synthesis, on [Dispatchers.IO].
 *
 * Steps, in order:
 *  1. Speech decoder sessions (V2 CPU ONNX when `v2_pre_conv/model.onnx` exists,
 *     otherwise QNN HTP).
 *  2. ExecuTorch `.pte` Code Predictor and talker modules (with warmup), unless
 *     disabled by the `force_hexagon` / `force_cpu_talker` flag files.
 *  3. Hexagon runners when the `.pte` pair is not in use.
 *  4. ONNX talker fallback (QNN GPU first, CPU if the GPU session cannot be
 *     created) whenever the Hexagon talker is not active.
 *  5. Optional Code Predictor V2 ONNX session.
 *  6. Embedding tables, optional Stage 2 assets (mmap'd full-vocab embeddings,
 *     voice prefix/suffix, BPE tokenizer), speaker embedding, VQ codebooks.
 *
 * Any exception is logged and swallowed; in that case `loaded` is not set to
 * `true` by this call.
 *
 * @param modelPath root directory of the model assets; when `null` the call
 *   returns without doing anything.
 * @param voiceId not used by this implementation.
 */
override suspend fun load(modelPath: String?, voiceId: String?) {
    withContext(Dispatchers.IO) {
        val path = modelPath ?: return@withContext
        this@Qwen3TtsEngine.modelPath = path
        try {
            val loadStart = System.currentTimeMillis()
            val env = OrtEnvironment.getEnvironment()
            ortEnv = env
            val htpPath = "$nativeLibDir/libQnnHtp.so"
            val gpuLib = "$nativeLibDir/libQnnGpu.so"

            val qnnCacheDir = "$path/qnn_cache"
            File(qnnCacheDir).mkdirs()

            /** Creates a QNN HTP session for `$path/$name/model.onnx` with context caching. */
            fun loadQnn(name: String, fp16: Boolean = false): OrtSession {
                val t = System.currentTimeMillis()
                val opts = OrtSession.SessionOptions()
                val qnnOpts = mutableMapOf(
                    "backend_path" to htpPath,
                    "qnn_context_cache_enable" to "1",
                    "qnn_context_cache_path" to "$qnnCacheDir/${name}.bin"
                )
                if (fp16) {
                    qnnOpts["htp_graph_finalization_optimization_mode"] = "3"
                    qnnOpts["enable_htp_fp16_precision"] = "1"
                }
                opts.addQnn(qnnOpts)
                val session = env.createSession("$path/$name/model.onnx", opts)
                val mode = if (fp16) "QNN fp16" else "QNN"
                nlog("$name ($mode): ${System.currentTimeMillis() - t}ms")
                return session
            }

            /** Creates a CPU session for `$path/$name/model.onnx`. */
            fun loadCpu(name: String, threads: Int = 8): OrtSession {
                val t = System.currentTimeMillis()
                val opts = OrtSession.SessionOptions()
                opts.setIntraOpNumThreads(threads)
                opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
                val session = env.createSession("$path/$name/model.onnx", opts)
                nlog("$name (CPU ${threads}T): ${System.currentTimeMillis() - t}ms")
                return session
            }

            /** Wraps a float array as an ExecuTorch input value of the given shape. */
            fun pteInput(data: FloatArray, vararg shape: Long): org.pytorch.executorch.EValue =
                org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(data, shape))

            // ── 1. Speech decoder ───────────────────────────────────────────
            // Speech decoder V2 on CPU. Two paths tried, both worse than CPU:
            //  - HTP: BigVGAN convolutions too slow to compile (timeout)
            //  - GPU Adreno via QNN GPU EP: model loads but per-phrase
            //    inference is ~3.5 s vs ~2 s on CPU (GPU/CPU memory transfer
            //    overhead dominates for this conv-heavy model)
            // CPU 8-thread stays the practical optimum.
            val v2Path = "$path/v2_pre_conv"
            if (File("$v2Path/model.onnx").exists()) {
                nlog("Loading V2 speech decoder (CPU ONNX)...")
                preConv = loadCpu("v2_pre_conv", 4)
                preprocessor = loadCpu("v2_pre_transformer", 4)
                convDecoder = loadCpu("v2_decoder_conv", 8) // BigVGAN benefits from more threads
                decoderOnCpu = true
                nlog("Speech decoder V2 on CPU")
            } else {
                nlog("Loading speech decoder (QNN HTP)...")
                preConv = loadQnn("pre_conv")
                preprocessor = loadQnn("preprocessor")
                convDecoder = loadQnn("conv_decoder")
                nlog("Speech decoder on HTP")
            }

            // Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths)
            android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true)

            // ── 2. Routing flags ────────────────────────────────────────────
            // Test overrides: skip backends to expose lower-priority talker paths.
            //   force_hexagon    → skip .pte, use ggml-hexagon HMX
            //   force_cpu_talker → skip .pte AND Hexagon, use the CPU ONNX talker
            //                      (used as an "EOS oracle" path because fp32 still
            //                      produces natural codec_eos where fp16 NPU drifts.)
            // force_hexagon is the production routing flag (not diagnostic): keep
            // unconditional. force_cpu_talker is diagnostic (used to verify CPU-fp32
            // equivalence during the tremor investigation) → DEBUG only.
            val forceHexagon = File("/data/local/tmp/kazeia/force_hexagon").exists()
            val forceCpuTalker = diagFlag("force_cpu_talker")
            if (forceHexagon) nlog("force_hexagon flag detected: skipping .pte init")
            if (forceCpuTalker) nlog("force_cpu_talker flag detected: skipping .pte AND hexagon init")

            // ── 3. ExecuTorch .pte modules (JNI) ────────────────────────────
            // CP .pte can load regardless of forceHexagon — it only touches the NPU
            // for its own graph. The talker .pte is separately gated below.
            // Running CP on the NPU alongside a Hexagon talker MAY trigger DSP
            // contention (memory note: "ggml-hexagon and ExecuTorch QNN cannot share
            // CDSP"). This is exactly what we're testing.
            val cpPteFile = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
            if (!forceCpuTalker && cpPteFile.exists() && cpPteModule == null) {
                try {
                    val tLoad = System.currentTimeMillis()
                    val module = org.pytorch.executorch.Module.load(
                        cpPteFile.absolutePath,
                        org.pytorch.executorch.Module.LOAD_MODE_FILE,
                        1
                    )
                    cpPteModule = module
                    nlog("CP .pte JNI loaded: ${System.currentTimeMillis() - tLoad}ms")
                    val tMethod = System.currentTimeMillis()
                    val lmResult = module.loadMethod("forward")
                    nlog("CP .pte loadMethod: ${System.currentTimeMillis() - tMethod}ms, result=$lmResult")
                    if (lmResult != 0) {
                        nlog("CP .pte loadMethod failed ($lmResult), disabling JNI")
                        cpPteModule = null
                    } else {
                        // Register CP module for native pipeline
                        module.nativeSetCpModule()
                        nlog("CP module registered for native pipeline")
                    }
                } catch (e: Exception) {
                    nlog("CP .pte JNI failed: ${e.message}")
                    cpPteModule = null
                }
            }

            // Load talker .pte JNI (same QNN context, no DSP contention with CP .pte).
            // Only attempted when the CP .pte module is available.
            val cpModule = cpPteModule
            val talkerPteFile = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte")
            if (!forceHexagon && !forceCpuTalker && talkerPteFile.exists() && cpModule != null && talkerPteModule == null) {
                try {
                    val tLoad = System.currentTimeMillis()
                    val talkerModule = org.pytorch.executorch.Module.load(
                        talkerPteFile.absolutePath,
                        org.pytorch.executorch.Module.LOAD_MODE_FILE,
                        1
                    )
                    talkerPteModule = talkerModule
                    val lm = talkerModule.loadMethod("forward")
                    nlog("Talker .pte JNI loaded+compiled: ${System.currentTimeMillis() - tLoad}ms, result=$lm")
                    if (lm != 0) {
                        nlog("Talker .pte loadMethod failed")
                        talkerPteModule = null
                    } else {
                        val pteDir = "/data/local/tmp/kazeia/models"
                        talkerPteRotaryCos = loadNpy("$pteDir/talker_pte_rotary_cos.npy")
                        talkerPteRotarySin = loadNpy("$pteDir/talker_pte_rotary_sin.npy")
                        nlog("Talker .pte rotary: ${talkerPteRotaryCos?.size} floats")

                        // Warmup both models: first forward() triggers QNN DSP compilation (~7s).
                        // Better to pay this cost at init than at first pipeline run.
                        val tw = System.currentTimeMillis()
                        try {
                            // Mask everything except the last (current) position.
                            val mask = FloatArray(TALKER_PTE_KV_LEN) { -1e9f }
                            mask[TALKER_PTE_KV_LEN - 1] = 0f
                            val kvSize = TALKER_HEADS * TALKER_PTE_KV_LEN * TALKER_HEAD_DIM
                            val ins = mutableListOf(
                                pteInput(FloatArray(TALKER_DIM), 1, 1, TALKER_DIM.toLong()),
                                pteInput(mask, 1, 1, 1, TALKER_PTE_KV_LEN.toLong()),
                                pteInput(FloatArray(TALKER_HEAD_DIM) { 1f }, 1, 1, TALKER_HEAD_DIM.toLong()),
                                pteInput(FloatArray(TALKER_HEAD_DIM), 1, 1, TALKER_HEAD_DIM.toLong())
                            )
                            repeat(TALKER_LAYERS * 2) {
                                ins.add(
                                    pteInput(
                                        FloatArray(kvSize),
                                        1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()
                                    )
                                )
                            }
                            talkerModule.forward(*ins.toTypedArray())
                            nlog("Talker warmup: ${System.currentTimeMillis() - tw}ms")
                        } catch (e: Exception) {
                            nlog("Talker warmup failed: ${e.message}")
                        }

                        // CP warmup
                        val cw = System.currentTimeMillis()
                        try {
                            val cpMask = FloatArray(CP_KV_LEN) { -1e9f }
                            cpMask[CP_KV_LEN - 1] = 0f
                            val cpKvSize = CP_KV_HEADS * CP_KV_LEN * CP_HEAD_DIM
                            val cIns = mutableListOf(
                                pteInput(FloatArray(TALKER_DIM), 1, 1, TALKER_DIM.toLong()),
                                pteInput(cpMask, 1, 1, 1, CP_KV_LEN.toLong()),
                                pteInput(FloatArray(CP_HEAD_DIM) { 1f }, 1, 1, CP_HEAD_DIM.toLong()),
                                pteInput(FloatArray(CP_HEAD_DIM), 1, 1, CP_HEAD_DIM.toLong())
                            )
                            repeat(CP_LAYERS * 2) {
                                cIns.add(
                                    pteInput(
                                        FloatArray(cpKvSize),
                                        1, CP_KV_HEADS.toLong(), CP_KV_LEN.toLong(), CP_HEAD_DIM.toLong()
                                    )
                                )
                            }
                            cpModule.forward(*cIns.toTypedArray())
                            nlog("CP warmup: ${System.currentTimeMillis() - cw}ms")
                        } catch (e: Exception) {
                            nlog("CP warmup failed: ${e.message}")
                        }
                    }
                } catch (e: Exception) {
                    nlog("Talker .pte JNI failed: ${e.message}")
                    talkerPteModule = null
                }
            }

            // ── 4. Hexagon runners ──────────────────────────────────────────
            // Talker: skip Hexagon if .pte talker+CP are both loaded (avoids DSP contention)
            val talkerT = System.currentTimeMillis()
            val hexRunner = File("/data/local/tmp/kazeia/llama-hex/llama-tts-talker")
            val hexModel = File("/data/local/tmp/kazeia/models/talker_f16.gguf")
            when {
                talkerPteModule != null && cpPteModule != null ->
                    nlog("Talker+CP using .pte JNI NPU — skipping Hexagon runners")
                forceCpuTalker ->
                    nlog("force_cpu_talker: skipping Hexagon, will use CPU ONNX")
                else -> {
                    if (hexRunner.exists() && hexModel.exists()) {
                        if (hexStartRunner()) {
                            useHexagonTalker = true
                            nlog("Talker using Hexagon NPU (HMX FP16)")
                        } else {
                            nlog("Hexagon talker runner failed, using CPU")
                        }
                    }
                    val hexCpRunner = File("/data/local/tmp/kazeia/llama-hex/llama-tts-cp")
                    val hexCpModel = File("/data/local/tmp/kazeia/models/cp_f16.gguf")
                    if (hexCpRunner.exists() && hexCpModel.exists()) {
                        if (hexStartCpRunner()) {
                            useHexagonCp = true
                            nlog("CP using CPU GGUF (avoids Hexagon HMX NaN bug)")
                        } else {
                            nlog("CP CPU runner failed, falling back to ONNX")
                        }
                    }
                }
            }

            // ── 5. ONNX talker fallback ─────────────────────────────────────
            if (!useHexagonTalker) {
                // The M-RoPE export and the legacy export share the same file name;
                // a file larger than 1 MB is treated as the M-RoPE model.
                val talkerOnnx = File("$path/talker_kv_cpu/model.onnx")
                if (talkerOnnx.exists() && talkerOnnx.length() > 1_000_000) {
                    // GPU Adreno first (fast fp16). If the QNN GPU session cannot be
                    // created, fall back to CPU instead of aborting the whole load.
                    var backendLabel = "CPU 6T"
                    var session: OrtSession? = null
                    if (File(gpuLib).exists()) {
                        try {
                            val gpuOpts = OrtSession.SessionOptions()
                            gpuOpts.addQnn(mapOf("backend_path" to gpuLib))
                            nlog("Talker ONNX: loading on GPU Adreno...")
                            session = env.createSession(talkerOnnx.absolutePath, gpuOpts)
                            backendLabel = "GPU Adreno"
                        } catch (e: Exception) {
                            nlog("Talker ONNX: GPU load failed (${e.message}), falling back to CPU")
                        }
                    }
                    if (session == null) {
                        val cpuOpts = OrtSession.SessionOptions()
                        cpuOpts.setIntraOpNumThreads(6)
                        cpuOpts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
                        nlog("Talker ONNX: loading on CPU 6T...")
                        session = env.createSession(talkerOnnx.absolutePath, cpuOpts)
                    }
                    talkerKv = session
                    talkerUsesCosSin = true
                    // Load rotary tables
                    val cosFile = File("$path/talker_kv_cpu/talker_rotary_cos.npy")
                    val sinFile = File("$path/talker_kv_cpu/talker_rotary_sin.npy")
                    if (cosFile.exists()) rotaryCos = loadNpy(cosFile.absolutePath)
                    if (sinFile.exists()) rotarySin = loadNpy(sinFile.absolutePath)
                    // Report the backend that was actually used.
                    nlog("talker_kv M-RoPE ($backendLabel): ${System.currentTimeMillis() - talkerT}ms, cos/sin=${rotaryCos?.size}")
                } else if (talkerOnnx.exists()) {
                    val talkerOpts = OrtSession.SessionOptions()
                    talkerOpts.setIntraOpNumThreads(6)
                    talkerOpts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
                    talkerKv = env.createSession(talkerOnnx.absolutePath, talkerOpts)
                    talkerUsesInt64Pos = true
                    nlog("talker_kv legacy (CPU fp32 6T): ${System.currentTimeMillis() - talkerT}ms")
                } else {
                    nlog("WARNING: No talker model available (no hexagon, no CPU ONNX)")
                }
            }

            // ── 6. Code Predictor V2 (ONNX Runtime) ─────────────────────────
            // Always try to load CP V2 ONNX Runtime as an optional routing target.
            // Previously gated on !useHexagonCp (fallback only); now loaded
            // unconditionally so force_cp_v2 flag can divert to ORT's MLAS kernels
            // at runtime.
            val cpV2 = File("$path/cp_kv_v2/model.onnx")
            if (cpV2.exists() && cpKv == null) {
                try {
                    if (!useHexagonCp && (cpPteModule == null || (useHexagonTalker && talkerPteModule == null))) {
                        val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
                        val etRunner = File("/data/local/tmp/kazeia/cp_et_runner")
                        if (etRunner.exists() && etModel.exists()) {
                            if (hexStartCpEtRunner()) {
                                useEtCp = true
                                nlog("CP using ExecuTorch NPU runner (TCP, root)")
                            } else {
                                nlog("CP ET runner failed to start")
                            }
                        }
                    }
                    cpKv = loadCpu("cp_kv_v2")
                    cpUsesCosSin = true
                    cpRotaryCos = loadNpy("$path/cp_kv_v2/cp_rotary_cos.npy")
                    cpRotarySin = loadNpy("$path/cp_kv_v2/cp_rotary_sin.npy")
                    cpHeadsPath = "$path/cp_kv_v2/cp_heads.npy"
                    nlog("CP V2 ONNX loaded, cos/sin=${cpRotaryCos?.size}")
                } catch (e: Exception) {
                    nlog("CP V2 ONNX failed: ${e.message}")
                }
            }

            // ── 7. Embedding tables ─────────────────────────────────────────
            val reducedEmbeds = loadNpy("$path/text_embeds_projected.npy")
            textEmbeds = reducedEmbeds
            nlog("Text embeddings: ${reducedEmbeds.size / TALKER_DIM} × $TALKER_DIM")

            // Stage 2 assets: full-vocab text embeddings + voice prefix + BPE.
            // All three are optional — if any is missing we fall back to the
            // legacy 1050-token path. This keeps the app bootable during
            // asset rollout and avoids turning a missing file into a crash.
            try {
                val fullEmbFile = File("$path/text_embeds_full_fp16.bin")
                val prefixFile = File("$path/damien_voice_prefix.bin")
                val tokDir = File("$path/qwen3_tokenizer")
                if (fullEmbFile.exists() && prefixFile.exists() && tokDir.isDirectory) {
                    /** Reads `rows` consecutive rows of TALKER_DIM floats from [buf]. */
                    fun readRows(buf: ByteBuffer, rows: Int): Array<FloatArray> =
                        Array(rows) { FloatArray(TALKER_DIM) { buf.float } }

                    val tE0 = System.currentTimeMillis()
                    // Memory-map the fp16 embeddings table instead of heap-
                    // loading it. Without mmap, the 296 MB ByteArray plus
                    // the 125 MB cp_embeddings FloatArray loaded a few
                    // lines below overrun the ~536 MB large-heap limit and
                    // the app OOMs during init. mmap pages the file via
                    // the kernel and keeps zero bytes on the Java heap.
                    val expectedBytes = textEmbedsFullLen.toLong() * TALKER_DIM * 2
                    if (fullEmbFile.length() != expectedBytes) {
                        nlog("text_embeds_full_fp16 size mismatch (got ${fullEmbFile.length()}, expected $expectedBytes) — disabling on-device text")
                    } else {
                        val chan = java.io.RandomAccessFile(fullEmbFile, "r").channel
                        textEmbedsFullChan = chan
                        textEmbedsFullBuf = chan.map(
                            java.nio.channels.FileChannel.MapMode.READ_ONLY, 0L, expectedBytes
                        ).order(ByteOrder.LITTLE_ENDIAN)
                        nlog("Full-vocab text embeddings mmap: ${textEmbedsFullLen} × $TALKER_DIM fp16 (${expectedBytes/1024/1024}MB, off-heap) in ${System.currentTimeMillis()-tE0}ms")
                    }

                    // File layout: int32 row count, int32 dim, then row-major float32 data.
                    val pb = ByteBuffer.wrap(prefixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
                    val nPref = pb.int
                    val dimPref = pb.int
                    if (nPref == 9 && dimPref == TALKER_DIM) {
                        damienVoicePrefix = readRows(pb, 9)
                        nlog("Damien voice prefix: $nPref × $dimPref")
                    } else {
                        nlog("damien_voice_prefix.bin header mismatch ($nPref × $dimPref) — disabling on-device text")
                    }

                    // Voice SUFFIX — the 2 fixed positions that close out
                    // the prefill after text tokens. Empirically invariant
                    // across segments of the same speaker (diff = 0.0).
                    val suffixFile = File("$path/damien_voice_suffix.bin")
                    if (suffixFile.exists()) {
                        val sb = ByteBuffer.wrap(suffixFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
                        val nSuf = sb.int
                        val dimSuf = sb.int
                        if (nSuf == 2 && dimSuf == TALKER_DIM) {
                            damienVoiceSuffix = readRows(sb, 2)
                            nlog("Damien voice suffix: $nSuf × $dimSuf")
                        }
                    } else {
                        nlog("damien_voice_suffix.bin missing — on-device text will lack closure markers")
                    }

                    if (textEmbedsFullBuf != null && damienVoicePrefix != null && damienVoiceSuffix != null) {
                        bpeTokenizer = Qwen3BpeTokenizer.load(tokDir.absolutePath)
                        nlog("Stage 2 on-device text path ready (BPE + full embeds + voice prefix)")
                    }
                } else {
                    nlog("Stage 2 assets not fully present (full=$fullEmbFile, prefix=$prefixFile, tok=$tokDir) — legacy path only")
                }
            } catch (e: Exception) {
                nlog("Stage 2 asset load failed: ${e.message} — legacy path only")
                // Drop every Stage 2 asset so no half-initialised state survives
                // (the suffix must be cleared together with the prefix).
                textEmbedsFullBuf = null
                damienVoicePrefix = null
                damienVoiceSuffix = null
                bpeTokenizer = null
                try { textEmbedsFullChan?.close() } catch (_: Exception) {}
                textEmbedsFullChan = null
            }

            val codecEmb = loadNpy("$path/codec_embedding.npy")
            codecEmbedding = codecEmb
            nlog("Codec embedding: ${codecEmb.size / TALKER_DIM} × $TALKER_DIM")
            val ttsSpecial = loadNpy("$path/tts_special_embeds.npy") // [3, 1024] = bos, eos, pad
            ttsBosEmbed = ttsSpecial.sliceArray(0 until TALKER_DIM)
            ttsEosEmbed = ttsSpecial.sliceArray(TALKER_DIM until 2 * TALKER_DIM)
            ttsPadEmbed = ttsSpecial.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
            nlog("TTS special embeddings loaded")

            // Load speaker embedding (x-vector for voice cloning)
            val spkFile = File("$path/speaker_embedding.npy")
            if (spkFile.exists()) {
                val spk = loadNpy(spkFile.absolutePath)
                speakerEmbed = spk
                nlog("Speaker embedding: ${spk.size} floats, norm=${Math.sqrt(spk.sumOf { (it * it).toDouble() })}")
            }

            // Load VQ codebooks
            loadVqCodebooks(path)

            // Load code predictor embeddings
            val cpEmb = loadNpy("$path/code_predictor_embeddings.npy")
            cpEmbeddings = cpEmb
            nlog("CP embeddings: ${cpEmb.size} floats")

            loaded = true
            nlog("Qwen3-TTS loaded in ${System.currentTimeMillis() - loadStart}ms")
        } catch (e: Exception) {
            nlog("ERROR: ${e.message}")
            e.printStackTrace()
        }
    }
}
|
||
|
||
/** Reports whether a previous `load()` call completed successfully. */
override fun isLoaded(): Boolean {
    return loaded
}
|
||
|
||
/**
 * Records the requested voice in the log.
 *
 * Currently this only logs [voicePath]; no engine state is changed here.
 */
fun setVoice(voicePath: String) = nlog("Voice: $voicePath")
|
||
|
||
/**
 * Synthesizes [text] off the caller's thread (on [Dispatchers.IO]) and returns
 * the PCM audio together with its sample rate and duration in milliseconds.
 */
override suspend fun synthesize(text: String, language: String): TtsResult =
    withContext(Dispatchers.IO) {
        val pcm = generateSpeech(text, language)
        val lengthMs = pcm.size * 1000L / SR
        TtsResult(audioData = pcm, sampleRate = SR, durationMs = lengthMs)
    }
|
||
|
||
/**
 * Generates codec tokens for [text], then decodes and plays them chunk by chunk
 * through a streaming [AudioTrack]. Runs on [Dispatchers.IO].
 *
 * Decoding uses fixed windows of `SEQ_LEN` tokens. The first window starts at
 * token 0; every later window starts `CHUNK_OVERLAP` tokens before the first
 * new token so the decoder sees left context, and the audio of that context is
 * discarded. Each window therefore contributes up to `EFFECTIVE_CHUNK` new
 * tokens of audio and consecutive chunks are contiguous.
 *
 * @param onStart invoked once generation succeeded, right before playback setup.
 * @param onComplete invoked when playback setup/streaming finished, when it was
 *   interrupted by [stop], or when generation produced nothing. It is not
 *   invoked if decoding throws; the exception propagates to the caller after
 *   the track has been released.
 */
override suspend fun synthesizeAndPlay(
    text: String, language: String,
    onStart: (() -> Unit)?, onComplete: (() -> Unit)?
) {
    withContext(Dispatchers.IO) {
        val codebooks = generateCodebooks(text, language)
        if (codebooks == null) {
            onComplete?.invoke()
            return@withContext
        }
        val (allCodebooks, numRealTokens) = codebooks

        // Release the DSP only when the decoder itself needs HTP (not GPU or CPU),
        // matching the condition used by generateSpeech().
        if (!decoderOnCpu && !decoderOnGpu && (useHexagonTalker || useHexagonCp)) {
            hexStopRunner()
        }

        onStart?.invoke()

        val track = AudioTrack.Builder()
            .setAudioAttributes(
                AudioAttributes.Builder()
                    .setUsage(AudioAttributes.USAGE_MEDIA)
                    .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
                    .build()
            )
            .setAudioFormat(
                AudioFormat.Builder()
                    .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
                    .setSampleRate(SR)
                    .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
                    .build()
            )
            .setBufferSizeInBytes(SAMPLES_PER_TOKEN * 20 * 2) // ~1.6s buffer
            .setTransferMode(AudioTrack.MODE_STREAM)
            .build()

        try {
            track.play()
            audioTrack = track

            val decodeStart = System.currentTimeMillis()
            var pos = 0
            while (pos < numRealTokens) {
                val chunkEnd = minOf(pos + EFFECTIVE_CHUNK, numRealTokens)
                val chunkTokens = chunkEnd - pos

                // Left context for every chunk but the first: the window starts
                // CHUNK_OVERLAP tokens early and that much audio is skipped, so
                // the kept audio begins exactly at token `pos`.
                val contextTokens = if (pos == 0) 0 else CHUNK_OVERLAP
                val windowStart = pos - contextTokens

                // Build chunk codebooks padded to SEQ_LEN
                val chunkCodes = Array(NUM_CODEBOOKS) { cb ->
                    IntArray(SEQ_LEN) { t ->
                        val srcIdx = windowStart + t
                        if (srcIdx < numRealTokens) allCodebooks[cb][srcIdx] else 0
                    }
                }

                val quantized = vqDecode(chunkCodes)
                val chunkAudio = runSpeechDecoder(quantized)

                // Trim: drop the context audio, keep only the new tokens.
                val skipSamples = contextTokens * SAMPLES_PER_TOKEN
                val keepSamples = minOf(chunkTokens * SAMPLES_PER_TOKEN, chunkAudio.size - skipSamples)
                if (keepSamples > 0) {
                    val written = track.write(chunkAudio, skipSamples, keepSamples)
                    if (written < 0) {
                        // Negative result is an AudioTrack error code, e.g. the
                        // track was released by stop(); no point decoding further.
                        nlog("AudioTrack write returned $written — stopping playback")
                        break
                    }
                }

                pos += EFFECTIVE_CHUNK
            }
            nlog("Streaming decode: ${System.currentTimeMillis() - decodeStart}ms")
        } finally {
            // NOTE(review): release() right after stop() may cut off audio still
            // queued in the track buffer — confirm whether a drain wait is needed.
            // stop() throws IllegalStateException if the track was already
            // released by stop() on another thread; that is not an error here.
            try { track.stop() } catch (_: IllegalStateException) {}
            try { track.release() } catch (_: Exception) {}
            if (audioTrack === track) audioTrack = null
        }
        onComplete?.invoke()
    }
}
|
||
|
||
/**
 * Halts playback: stops the current [AudioTrack] if there is one (ignoring any
 * exception from `stop()`), releases it, and clears the reference.
 */
override fun stop() {
    val current = audioTrack
    if (current != null) {
        try {
            current.stop()
        } catch (_: Exception) {
            // Ignored: the track is released below regardless of its state.
        }
        current.release()
    }
    audioTrack = null
}
|
||
|
||
/**
 * Generates codec tokens without decoding them to audio.
 *
 * Text embeddings come from `$modelPath/phrase_embeds.bin` when that file
 * exists (little-endian: int32 count followed by `count` rows of TALKER_DIM
 * float32 values); otherwise [text] is tokenized and embedded on device.
 *
 * @return `(allCodebooks[16][padLen], numRealTokens)` where `padLen` is at
 *   least `SEQ_LEN` and positions past `numRealTokens` are zero, or `null`
 *   when the engine is not loaded, generation yields no tokens, or an
 *   exception occurs (which is logged).
 */
private fun generateCodebooks(text: String, language: String): Pair<Array<IntArray>, Int>? {
    if (!loaded) return null
    val startMs = System.currentTimeMillis()
    nlog("Generating codebooks: '$text' [$language]")
    return try {
        val phraseFile = File("$modelPath/phrase_embeds.bin")
        val embeds: List<FloatArray> = if (phraseFile.exists()) {
            val raw = ByteBuffer.wrap(phraseFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
            val count = raw.int
            // One float view positioned right after the header; rows are read back to back.
            val floats = raw.asFloatBuffer()
            val rows = (0 until count).map {
                FloatArray(TALKER_DIM).also { row -> floats.get(row) }
            }
            nlog("Loaded $count pre-computed text embeddings")
            rows
        } else {
            tokenizeText(text).map { textEmb(it) }
        }

        val frames = runInterleavedGeneration(embeds, MAX_CONTEXT - 15)
        val elapsedMs = System.currentTimeMillis() - startMs
        nlog("Interleaved gen: ${elapsedMs}ms, ${frames.size} tokens")

        if (frames.isEmpty()) {
            null
        } else {
            val realCount = frames.size
            val paddedLen = maxOf(realCount, SEQ_LEN)
            // Transpose [token][codebook] → [codebook][token], zero-padded.
            val perCodebook = Array(NUM_CODEBOOKS) { cb ->
                IntArray(paddedLen) { t -> if (t < realCount) frames[t][cb] else 0 }
            }
            Pair(perCodebook, realCount)
        }
    } catch (e: Exception) {
        nlog("ERROR: ${e.message}")
        null
    }
}
|
||
|
||
/**
 * Full text-to-audio path: text embeddings -> interleaved talker/code-predictor
 * generation -> chunked decode.
 *
 * Text embeddings are read from `$modelPath/phrase_embeds.bin` when that file
 * exists (then [text] is only logged); otherwise [tokenizeText] + [textEmb] are used.
 *
 * @param text phrase to synthesise.
 * @param language only used in the log line here.
 * @return the samples returned by decodeChunked, or an empty array when the engine
 *         is not loaded, nothing was generated, or an exception occurred.
 */
private fun generateSpeech(text: String, language: String): ShortArray {
    if (!loaded) return ShortArray(0)
    val startedAt = System.currentTimeMillis()
    nlog("Generating: '$text' [$language]")

    return try {
        // Step 1: pre-computed text embeddings if present, tokenizer fallback otherwise.
        val precomputed = File("$modelPath/phrase_embeds.bin")
        val embeds: List<FloatArray> = if (precomputed.exists()) {
            // Layout as consumed here: little-endian int32 count, then rows of TALKER_DIM float32.
            val reader = ByteBuffer.wrap(precomputed.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
            val count = reader.int
            val floats = reader.asFloatBuffer()
            val rows = (0 until count).map {
                FloatArray(TALKER_DIM).also { row -> floats.get(row) }
            }
            nlog("Loaded $count pre-computed text embeddings")
            rows
        } else {
            val tokenIds = tokenizeText(text)
            val rows = tokenIds.map { textEmb(it) }
            nlog("Text tokens: ${tokenIds.size}: ${tokenIds.toList()}")
            rows
        }

        // Step 2: interleaved talker + code predictor -> all codebooks per frame.
        // MAX_CONTEXT - 15 is a hard safety cap; generation normally ends earlier.
        val genStart = System.currentTimeMillis()
        val frames = runInterleavedGeneration(embeds, MAX_CONTEXT - 15)
        nlog("Interleaved gen: ${System.currentTimeMillis() - genStart}ms, ${frames.size} tokens")

        if (frames.isEmpty()) {
            ShortArray(0)
        } else {
            // Transpose [frame][codebook] -> [codebook][frame], zero-padded up to SEQ_LEN.
            val realCount = frames.size
            val grid = Array(NUM_CODEBOOKS) { IntArray(maxOf(realCount, SEQ_LEN)) }
            for (t in 0 until realCount) {
                for (cb in 0 until NUM_CODEBOOKS) grid[cb][t] = frames[t][cb]
            }

            // Release DSP if decoder needs HTP (not GPU or CPU).
            val decoderElsewhere = decoderOnCpu || decoderOnGpu
            val runnersActive = useHexagonTalker || useHexagonCp
            if (!decoderElsewhere && runnersActive) {
                hexStopRunner()
            }

            // Step 3: VQ -> decoder -> audio (chunked).
            val decodeStart = System.currentTimeMillis()
            val audio = decodeChunked(grid, realCount)
            nlog("Decode (chunked): ${System.currentTimeMillis() - decodeStart}ms")

            val totalMs = System.currentTimeMillis() - startedAt
            val audioDur = audio.size.toFloat() / SR
            nlog("Total: ${totalMs}ms for ${audioDur}s audio (RTF ${totalMs / 1000f / audioDur})")
            audio
        }
    } catch (e: Exception) {
        nlog("ERROR: ${e.message}")
        e.printStackTrace()
        ShortArray(0)
    }
}
|
||
|
||
// ==================== Tokenization ====================
|
||
|
||
/**
 * Simple text tokenization. Maps known words to reduced vocab IDs.
 * For production, this should be replaced with proper BPE tokenization.
 *
 * A whole-string hit that maps to a single ID is returned directly. Otherwise the
 * text is scanned left to right with greedy longest-match (candidates of at most
 * 20 characters); characters that start no known entry are logged and skipped.
 *
 * @return the matched IDs, or `[1043]` ("Bonjour") when nothing matched.
 */
private fun tokenizeText(text: String): IntArray {
    // Known word -> reduced vocab ID mappings (from Qwen3 tokenizer + token_mapping).
    // Values are either a single Int or an IntArray for multi-token entries.
    val vocab: Map<String, Any> = mapOf(
        "Bonjour" to 1043,
        "bonjour" to 1028,
        " bonjour" to intArrayOf(220, 1028), // space + bonjour
        "Oui" to 1006,
        "Non" to 966,
        "Merci" to 1035,
        "salut" to 1023,
        "." to 13,
        "," to 11,
        "!" to 0,
        "?" to 30,
        " " to 220,
    )

    // Fast path: the whole input is a single-ID entry.
    (vocab[text] as? Int)?.let { return intArrayOf(it) }

    val ids = ArrayList<Int>()
    var cursor = 0
    while (cursor < text.length) {
        val longest = minOf(text.length - cursor, 20)
        // Longest candidate first, shrinking down to a single character.
        val hitLen = (longest downTo 1).firstOrNull { len ->
            vocab.containsKey(text.substring(cursor, cursor + len))
        }
        if (hitLen == null) {
            // Skip unknown character
            nlog("WARN: Unknown token for '${text[cursor]}'")
            cursor += 1
            continue
        }
        when (val entry = vocab.getValue(text.substring(cursor, cursor + hitLen))) {
            is Int -> ids.add(entry)
            is IntArray -> ids.addAll(entry.toList())
        }
        cursor += hitLen
    }

    if (ids.isEmpty()) {
        nlog("WARN: No tokens from text '$text', using Bonjour fallback")
        return intArrayOf(1043) // "Bonjour"
    }
    return ids.toIntArray()
}
|
||
|
||
// ==================== Talker KV-Cache Generation ====================
|
||
|
||
// ==================== Embedding Helpers ====================
|
||
|
||
/**
 * Returns a fresh copy of one TALKER_DIM-wide row of the flat [textEmbeds] table.
 *
 * @param reducedId row index; clamped to 0..1049 before the lookup.
 */
private fun textEmb(reducedId: Int): FloatArray {
    val row = reducedId.coerceIn(0, 1049)
    return FloatArray(TALKER_DIM).also { out ->
        System.arraycopy(textEmbeds!!, row * TALKER_DIM, out, 0, TALKER_DIM)
    }
}
|
||
|
||
/**
 * Returns a fresh copy of one TALKER_DIM-wide row of the flat [codecEmbedding] table.
 *
 * @param codecIdx row index; clamped to 0..TALKER_VOCAB-1 before the lookup.
 */
private fun codecEmb(codecIdx: Int): FloatArray {
    val row = codecIdx.coerceIn(0, TALKER_VOCAB - 1)
    return FloatArray(TALKER_DIM).also { out ->
        System.arraycopy(codecEmbedding!!, row * TALKER_DIM, out, 0, TALKER_DIM)
    }
}
|
||
|
||
/**
 * Returns a fresh copy of one code-predictor embedding row.
 *
 * cpEmbeddings is [15, 2048, 1024] flattened; the row is addressed as
 * `codebookIdx * CODEBOOK_SIZE + tokenIdx`.
 *
 * @param codebookIdx codebook selector (used as given, not clamped).
 * @param tokenIdx token within the codebook; clamped to 0..CODEBOOK_SIZE-1.
 */
private fun cpEmb(codebookIdx: Int, tokenIdx: Int): FloatArray {
    val token = tokenIdx.coerceIn(0, CODEBOOK_SIZE - 1)
    val start = (codebookIdx * CODEBOOK_SIZE + token) * TALKER_DIM
    return FloatArray(TALKER_DIM).also { out ->
        System.arraycopy(cpEmbeddings!!, start, out, 0, TALKER_DIM)
    }
}
|
||
|
||
/** Element-wise sum of the first TALKER_DIM entries of [a] and [b], as a new array. */
private fun sumEmb(a: FloatArray, b: FloatArray): FloatArray =
    FloatArray(TALKER_DIM) { idx -> a[idx] + b[idx] }
|
||
|
||
/** Adds the first TALKER_DIM entries of [src] into [dst] in place. */
private fun addEmb(dst: FloatArray, src: FloatArray) {
    repeat(TALKER_DIM) { idx -> dst[idx] += src[idx] }
}
|
||
|
||
/**
 * Build prefill embeddings with speaker embedding (voice cloning).
 *
 * Sequence produced, in order: three role tokens, four codec-control tokens
 * (each summed with tts_pad), the speaker embedding summed with tts_pad when one
 * is set, codec_pad + tts_bos, and finally the first text embedding + codec_bos
 * when [textEmbedsList] is not empty.
 *
 * @param textEmbedsList pre-computed text embeddings (one per text token)
 * @return the prefill sequence, or an empty list when the tts_pad or tts_bos
 *         embedding is missing.
 */
private fun buildPrefillEmbeddings(textEmbedsList: List<FloatArray>): List<FloatArray> {
    val padE = ttsPadEmbed ?: return emptyList()
    val bosE = ttsBosEmbed ?: return emptyList()
    val spkE = speakerEmbed

    val sequence = ArrayList<FloatArray>()

    // 1. Role: <|im_start|>assistant\n
    listOf(IM_START, TOKEN_ASSISTANT, TOKEN_NEWLINE).forEach { roleId ->
        sequence.add(textEmb(roleId))
    }

    // 2. Codec control: [think, think_bos, lang_fr, think_eos] + tts_pad
    intArrayOf(CODEC_THINK, CODEC_THINK_BOS, CODEC_LANG_FR, CODEC_THINK_EOS).forEach { ctrl ->
        sequence.add(sumEmb(padE, codecEmb(ctrl)))
    }

    // 3. Speaker embedding (x-vector)
    spkE?.let { speaker -> sequence.add(sumEmb(padE, speaker)) }

    // 4. codec_pad + tts_bos
    sequence.add(sumEmb(bosE, codecEmb(CODEC_PAD)))

    // 5. First text token + codec_bos
    textEmbedsList.firstOrNull()?.let { firstText ->
        sequence.add(sumEmb(firstText, codecEmb(CODEC_BOS)))
    }

    val speakerNote = if (spkE != null) "1 spk + " else ""
    nlog("Prefill: ${sequence.size} tokens (3 role + 4 ctrl + ${speakerNote}1 bos + 1 text)")
    return sequence
}
|
||
|
||
// ==================== Hexagon NPU Talker ====================
|
||
|
||
// Working directory from which the llama-tts-talker / llama-tts-cp runners are launched.
private val HEX_DIR: String = "/data/local/tmp/kazeia/llama-hex"

// NOTE(review): the next three paths look like a file-based exchange protocol;
// confirm whether anything still references them.
private val HEX_INPUT: String = "/data/local/tmp/kazeia/tts_input.bin"
private val HEX_OUTPUT: String = "/data/local/tmp/kazeia/tts_logits.bin"
private val HEX_CONTROL: String = "/data/local/tmp/kazeia/tts_control.txt"

// Filesystem-namespace socket paths handed to the runners via "-s".
private val TALKER_SOCK: String = "/data/local/tmp/kazeia/talker.sock"
private val CP_SOCK: String = "/data/local/tmp/kazeia/cp.sock"

// NOTE(review): the CP ExecuTorch runner is reached over TCP in this part of the
// file; confirm whether this socket path is still used.
private val CP_ET_SOCK: String = "/data/local/tmp/kazeia/cp_et.sock"

// Live connections to the talker and CP runners; null while disconnected.
private var talkerSocket: android.net.LocalSocket? = null
private var cpSocket: android.net.LocalSocket? = null

// Set once the CP runner has been started and connected; cleared by hexStopRunner().
private var useHexagonCp: Boolean = false
|
||
|
||
/**
 * Run a command with su and wait for it to finish.
 *
 * The exit status of the command is deliberately NOT inspected: callers use this
 * with commands such as `pkill`, which exit non-zero when nothing matched.
 *
 * @param cmd shell command line passed to `su -c`.
 * @return true once the su process has terminated; false if it could not be
 *         started (e.g. su is not available) or the wait was interrupted.
 */
private fun suExec(cmd: String): Boolean {
    var process: Process? = null
    return try {
        process = Runtime.getRuntime().exec(arrayOf("su", "-c", cmd))
        process.waitFor()
        true
    } catch (e: InterruptedException) {
        // Keep the interruption visible to the caller's thread instead of eating it.
        Thread.currentThread().interrupt()
        false
    } catch (e: Exception) {
        false
    } finally {
        // Releases the stdin/stdout/stderr pipes of the su process (and terminates
        // it if the wait was cut short) so repeated calls do not leak descriptors.
        process?.destroy()
    }
}
|
||
|
||
/**
 * Read a root-owned file via su.
 *
 * @param path file to `cat`. NOTE(review): the path is interpolated unquoted into
 *        the shell command, so it must not contain spaces or shell metacharacters.
 * @return the trimmed standard output of `cat`.
 * @throws java.io.IOException if the su process cannot be started or read.
 */
private fun suReadFile(path: String): String {
    val process = Runtime.getRuntime().exec(arrayOf("su", "-c", "cat $path"))
    try {
        // use {} closes the reader (and the underlying pipe) even if reading fails.
        val text = process.inputStream.bufferedReader().use { it.readText() }.trim()
        process.waitFor()
        return text
    } finally {
        // Release the remaining process pipes; also ends the process on failure paths.
        process.destroy()
    }
}
|
||
|
||
/**
 * Start the hexagon talker runner and connect via socket.
 *
 * If flag force_cpu_talker_gguf exists, runs with -ngl 0 (CPU fp32) to isolate
 * whether the voice-cloning tremor comes from NPU fp16 drift. Diagnostic only;
 * expected RTF 5-7x on CPU.
 *
 * Any previous runner is killed first. After launching, the socket is polled
 * every 100 ms for up to 300 attempts.
 *
 * @return true when [talkerSocket] has been connected; false if su failed or the
 *         socket never became connectable.
 */
private fun hexStartRunner(): Boolean {
    val forceCpuGguf = diagFlag("force_cpu_talker_gguf")
    val nglArg = if (forceCpuGguf) "-ngl 0" else "-ngl 99 -mg 1"
    nlog("Starting ${if (forceCpuGguf) "CPU GGUF" else "Hexagon"} talker runner ($nglArg)...")
    val t0 = System.currentTimeMillis()

    if (!suExec("pkill -f llama-tts-talker")) { nlog("su not available"); return false }
    Thread.sleep(200)
    if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-talker -m /data/local/tmp/kazeia/models/talker_f16.gguf $nglArg -s $TALKER_SOCK > /data/local/tmp/kazeia/tts_runner.log 2>&1 &")) return false

    // Wait for socket to be connectable (30s max — model loading takes ~15s)
    for (attempt in 0 until 300) {
        Thread.sleep(100)
        val sock = android.net.LocalSocket()
        try {
            sock.connect(android.net.LocalSocketAddress(TALKER_SOCK, android.net.LocalSocketAddress.Namespace.FILESYSTEM))
        } catch (_: Exception) {
            // Runner not listening yet. Close this attempt's socket so the retry
            // loop does not leave up to 300 unclosed sockets behind.
            try { sock.close() } catch (_: Exception) {}
            continue
        }
        talkerSocket = sock
        nlog("Hexagon talker connected: ${System.currentTimeMillis() - t0}ms")
        return true
    }
    nlog("Hexagon talker timeout (30s)")
    return false
}
|
||
|
||
/**
 * Talker forward via socket: send embedding, get hidden+logits.
 *
 * For each embedding, writes the 4-byte command "FWRD" followed by TALKER_DIM
 * little-endian floats, then reads back TALKER_DIM hidden floats followed by
 * TALKER_VOCAB logits.
 *
 * @param embeddings talker inputs, processed in order.
 * @return one (hidden, logits) pair per embedding; empty if no talker socket is set.
 * @throws java.io.IOException if the socket fails or is closed before a complete
 *         response arrived (previously a truncated response was silently decoded
 *         as a partly zero-filled result).
 */
private fun hexForward(embeddings: List<FloatArray>): List<Pair<FloatArray, FloatArray>> {
    val sock = talkerSocket ?: return emptyList()
    val os = sock.outputStream
    val ins = sock.inputStream
    val results = ArrayList<Pair<FloatArray, FloatArray>>(embeddings.size)

    val command = "FWRD".toByteArray()
    val respSize = (TALKER_DIM + TALKER_VOCAB) * 4
    // Reused across iterations; it is fully overwritten before every decode.
    val resp = ByteArray(respSize)

    for (emb in embeddings) {
        // Send "FWRD" + embedding
        os.write(command)
        val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        for (v in emb) buf.putFloat(v)
        os.write(buf.array())
        os.flush()

        // Read hidden + logits; read() may deliver the response in several pieces.
        var read = 0
        while (read < respSize) {
            val n = ins.read(resp, read, respSize - read)
            if (n <= 0) {
                throw java.io.IOException("Talker socket closed after $read/$respSize response bytes")
            }
            read += n
        }

        val fb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN).asFloatBuffer()
        val hidden = FloatArray(TALKER_DIM).also { fb.get(it) }
        val logits = FloatArray(TALKER_VOCAB).also { fb.get(it) }
        results.add(Pair(hidden, logits))
    }
    return results
}
|
||
|
||
/**
 * Reset KV cache for a new phrase via socket.
 *
 * Sends the 4-byte command "REST" and blocks until the complete 4-byte
 * acknowledgement ("OK\0\0") has been consumed, so that no stray ack bytes are
 * left in the stream ahead of the next response. Does nothing when no talker
 * socket is set.
 *
 * @throws java.io.IOException if the socket fails or is closed before the full
 *         acknowledgement arrived.
 */
private fun hexReset() {
    val sock = talkerSocket ?: return
    val out = sock.outputStream
    out.write("REST".toByteArray())
    out.flush()

    val ack = ByteArray(4)
    val ins = sock.inputStream
    var got = 0
    // A single read() may legally return fewer than 4 bytes; loop until complete.
    while (got < ack.size) {
        val n = ins.read(ack, got, ack.size - got)
        if (n <= 0) {
            throw java.io.IOException("Talker socket closed during reset ($got/${ack.size} ack bytes)")
        }
        got += n
    }
}
|
||
|
||
/**
 * Start CP ExecuTorch runner (NPU HTP via root process).
 *
 * First tries to attach to a runner already listening on 127.0.0.1:8790 (2 s
 * connect timeout). If that fails, kills any old runner, launches a new one via
 * su and polls the port every 100 ms for up to 300 attempts.
 *
 * @return true when [cpEtSocket] has been connected, false otherwise.
 */
private fun hexStartCpEtRunner(): Boolean {
    nlog("Connecting CP ExecuTorch runner (TCP:8790)...")
    val startedAt = System.currentTimeMillis()

    // Try connecting to existing runner first
    try {
        cpEtSocket = java.net.Socket().apply {
            connect(java.net.InetSocketAddress("127.0.0.1", 8790), 2000)
            tcpNoDelay = true
        }
        nlog("CP ET connected (TCP, existing): ${System.currentTimeMillis() - startedAt}ms")
        return true
    } catch (e: Exception) {
        nlog("CP ET existing connection failed: ${e.message}")
    }

    // Start new runner
    nlog("No running CP ET, starting new...")
    suExec("pkill -f cp_et_runner")
    Thread.sleep(200)

    val launchCmd = "cd /data/local/tmp/kazeia && LD_LIBRARY_PATH=qnn_libs:. ADSP_LIBRARY_PATH=qnn_libs " +
        "nohup ./cp_et_runner --model_path=models/cp_transformer_fp16.pte --tcp_port=8790 " +
        "> /data/local/tmp/kazeia/cp_et.log 2>&1 &"
    val launched = suExec(launchCmd)
    if (!launched) return false

    repeat(300) {
        Thread.sleep(100)
        try {
            val fresh = java.net.Socket("127.0.0.1", 8790)
            fresh.tcpNoDelay = true
            cpEtSocket = fresh
            nlog("CP ET connected (TCP): ${System.currentTimeMillis() - startedAt}ms")
            return true
        } catch (_: Exception) {
            // Port not accepting connections yet; try again on the next tick.
        }
    }

    nlog("CP ET timeout (30s)")
    return false
}
|
||
|
||
/**
 * CP via ExecuTorch NPU TCP socket: send hidden+cb0_emb, recv 15 codes + timing.
 *
 * Request: TALKER_DIM hidden floats followed by the TALKER_DIM codec embedding of
 * [cb0], little-endian. Response: 64 bytes, of which the first 15 little-endian
 * ints are the codes.
 *
 * On any failure (no socket, I/O error, truncated response) the socket is closed,
 * the ET path is disabled and the result of runCpV2 is returned instead.
 *
 * @param pastHidden talker hidden state for the current frame.
 * @param cb0 codebook-0 token of the current frame.
 * @return 15 codes for codebooks 1..15, each within 0 until CODEBOOK_SIZE.
 */
private fun etCpForward(pastHidden: FloatArray, cb0: Int): IntArray {
    val sock = cpEtSocket
    if (sock == null || sock.isClosed) { nlog("CP ET socket null"); return runCpV2(pastHidden, cb0) }

    val codes: IntArray? = try {
        val os = sock.getOutputStream()
        val ins = sock.getInputStream()

        val buf = java.nio.ByteBuffer.allocate(2 * TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        for (v in pastHidden) buf.putFloat(v)
        for (v in codecEmb(cb0)) buf.putFloat(v)
        os.write(buf.array()); os.flush()

        val resp = ByteArray(64)
        var read = 0
        while (read < 64) {
            val n = ins.read(resp, read, 64 - read)
            if (n <= 0) break
            read += n
        }

        if (read < 64) {
            nlog("CP ET read incomplete: $read/64")
            null
        } else {
            val rb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN)
            val decoded = IntArray(15) { rb.int }
            // Same defensive clamp as hexCpForward: an out-of-range code must not
            // reach the codebook lookups downstream.
            var clamped = 0
            for (i in decoded.indices) {
                if (decoded[i] < 0 || decoded[i] >= CODEBOOK_SIZE) { decoded[i] = 0; clamped++ }
            }
            if (clamped > 0) nlog("CP ET: clamped $clamped out-of-range codes to 0")
            decoded
        }
    } catch (e: Exception) {
        nlog("CP ET error: ${e.message}")
        null
    }

    if (codes != null) return codes

    // The connection is unusable (error or peer closed mid-response): release it
    // instead of only dropping the reference, and stop routing through ET.
    try { sock.close() } catch (_: Exception) {}
    cpEtSocket = null
    useEtCp = false
    return runCpV2(pastHidden, cb0)
}
|
||
|
||
/**
 * Start CP hexagon runner and connect via socket.
 *
 * NOTE: uses -ngl 0 (CPU only). Running the CP on Hexagon HTP produces deterministic
 * NaN at position 2 of every request — bug in libggml-htp-v79.so for this specific
 * 5-layer qwen3 graph (confirmed empirically). CPU path is ~60 ms/step and correct.
 * The talker still benefits from Hexagon HMX since it runs in a separate process.
 *
 * Any previous CP runner is killed first. After launching, the socket is polled
 * every 100 ms for up to 300 attempts.
 *
 * @return true when [cpSocket] has been connected; false if su failed or the
 *         socket never became connectable.
 */
private fun hexStartCpRunner(): Boolean {
    nlog("Starting CP runner (CPU, -ngl 0 to avoid HMX NaN bug)...")
    val t0 = System.currentTimeMillis()

    if (!suExec("pkill -f llama-tts-cp")) { nlog("su not available for CP"); return false }
    Thread.sleep(200)
    if (!suExec("cd $HEX_DIR && LD_LIBRARY_PATH=. nohup ./llama-tts-cp -m /data/local/tmp/kazeia/models/cp_f16.gguf -ngl 0 -s $CP_SOCK > /data/local/tmp/kazeia/cp_runner.log 2>&1 &")) return false

    for (attempt in 0 until 300) {
        Thread.sleep(100)
        val sock = android.net.LocalSocket()
        try {
            sock.connect(android.net.LocalSocketAddress(CP_SOCK, android.net.LocalSocketAddress.Namespace.FILESYSTEM))
        } catch (_: Exception) {
            // Runner not listening yet. Close this attempt's socket so the retry
            // loop does not leave up to 300 unclosed sockets behind.
            try { sock.close() } catch (_: Exception) {}
            continue
        }
        cpSocket = sock
        nlog("CP Hexagon connected: ${System.currentTimeMillis() - t0}ms")
        return true
    }
    nlog("CP Hexagon timeout (30s)")
    return false
}
|
||
|
||
/**
 * CP via the llama-tts-cp runner socket: send hidden+cb0_emb, get 15 codes.
 *
 * Request: TALKER_DIM hidden floats followed by the TALKER_DIM codec embedding of
 * [cb0], little-endian. Response: 64 bytes = 15 little-endian int codes (60 bytes)
 * + timing (4 bytes).
 *
 * On any failure (no socket, I/O error, truncated response) the socket is closed
 * and cleared, and the result of runCpCpu is returned instead. The fallback runs
 * outside the socket try-block, so an exception thrown by runCpCpu itself is not
 * mistaken for a socket error and retried.
 *
 * @param pastHidden talker hidden state for the current frame.
 * @param cb0 codebook-0 token of the current frame.
 * @return 15 codes for codebooks 1..15, each within 0 until CODEBOOK_SIZE.
 */
private fun hexCpForward(pastHidden: FloatArray, cb0: Int): IntArray {
    val sock = cpSocket
    if (sock == null) {
        nlog("CP socket NULL, falling back to CPU")
        return runCpCpu(pastHidden, cb0)
    }

    val codes: IntArray? = try {
        val os = sock.outputStream
        val ins = sock.inputStream

        val t0 = System.currentTimeMillis()
        // Send hidden + cb0_emb
        val buf = java.nio.ByteBuffer.allocate(2 * TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        for (v in pastHidden) buf.putFloat(v)
        for (v in codecEmb(cb0)) buf.putFloat(v)
        os.write(buf.array()); os.flush()
        val writeMs = System.currentTimeMillis() - t0

        // Read 15 codes (60 bytes) + timing (4 bytes) = 64 bytes
        val resp = ByteArray(64)
        var read = 0
        while (read < 64) {
            val n = ins.read(resp, read, 64 - read)
            if (n <= 0) break
            read += n
        }
        val totalMs = System.currentTimeMillis() - t0

        if (cpCallCount <= 3) nlog("CP socket: write=${writeMs}ms, total=${totalMs}ms, read=$read/64")

        if (read < 64) {
            nlog("CP socket read incomplete: $read/64 bytes")
            null
        } else {
            val rb = java.nio.ByteBuffer.wrap(resp).order(java.nio.ByteOrder.LITTLE_ENDIAN)
            val decoded = IntArray(15) { rb.int }
            // Defensive clamp: CP on first call may occasionally return uninitialised
            // memory before warm-up stabilises. Out-of-range codes would crash the
            // decoder (ArrayIndexOutOfBoundsException on codebook lookup). Map anything
            // invalid to 0 — a single-frame artefact is much better than an app crash.
            var clamped = 0
            for (i in decoded.indices) {
                if (decoded[i] < 0 || decoded[i] >= CODEBOOK_SIZE) { decoded[i] = 0; clamped++ }
            }
            if (clamped > 0) nlog("CP socket: clamped $clamped out-of-range codes to 0")
            decoded
        }
    } catch (e: Exception) {
        nlog("CP socket error: ${e.message}, falling back to CPU")
        null
    }

    if (codes != null) return codes

    // The connection is unusable (error or peer closed mid-response): release it
    // instead of only dropping the reference.
    try { sock.close() } catch (_: Exception) {}
    cpSocket = null
    return runCpCpu(pastHidden, cb0)
}
|
||
|
||
/**
 * Ensure hexagon runners are alive. Only restarts if they're not connected.
 *
 * For the talker and the CP runner independently: when the corresponding flag is
 * not set, and both the runner binary and its model file exist, the runner is
 * started and the flag is set on success.
 */
private fun ensureHexagonRunners() {
    if (!useHexagonTalker) {
        val talkerBinary = File("/data/local/tmp/kazeia/llama-hex/llama-tts-talker")
        val talkerModel = File("/data/local/tmp/kazeia/models/talker_f16.gguf")
        // Short-circuit order matters: the runner is only started when both files exist.
        val talkerUp = talkerBinary.exists() && talkerModel.exists() && hexStartRunner()
        if (talkerUp) useHexagonTalker = true
    }

    if (!useHexagonCp) {
        val cpBinary = File("/data/local/tmp/kazeia/llama-hex/llama-tts-cp")
        val cpModel = File("/data/local/tmp/kazeia/models/cp_f16.gguf")
        val cpUp = cpBinary.exists() && cpModel.exists() && hexStartCpRunner()
        if (cpUp) useHexagonCp = true
    }
}
|
||
|
||
/**
 * Stop all runners and release DSP for QNN decode. Runners restart on next generate().
 *
 * Every step is attempted independently: a failing QUIT write (e.g. the runner
 * already died) no longer prevents the sockets from being closed or the runner
 * processes from being killed.
 */
private fun hexStopRunner() {
    // Runs one cleanup step, ignoring its failure so the remaining steps still run.
    fun bestEffort(step: () -> Unit) {
        try {
            step()
        } catch (_: InterruptedException) {
            // Preserve the interruption for the caller's thread.
            Thread.currentThread().interrupt()
        } catch (_: Exception) {
        }
    }

    bestEffort { talkerSocket?.outputStream?.write("QUIT".toByteArray()) }
    bestEffort { talkerSocket?.close() }
    bestEffort { cpSocket?.close() }
    bestEffort { cpEtSocket?.close() }

    // Give the runners a moment to exit on their own before killing them.
    bestEffort { Thread.sleep(300) }
    suExec("pkill -f llama-tts")
    suExec("pkill -f cp_et_runner")
    bestEffort { Thread.sleep(200) }

    talkerSocket = null; cpSocket = null; cpEtSocket = null
    useHexagonTalker = false; useHexagonCp = false; useEtCp = false
    nlog("Hexagon runners stopped, DSP released")
}
|
||
|
||
// Full interleaved generation using Hexagon NPU talker + CPU CP: see runInterleavedHexagon().
|
||
/**
 * All-NPU pipeline: talker .pte + CP .pte via JNI, no root.
 *
 * Feeds the prefill embeddings through the talker one position at a time, then
 * alternates code-predictor and talker steps until EOS is sampled, the same
 * codebook-0 token would repeat ten times in a row, or [maxGenTokens] frames were
 * produced. Every talker input is also written to
 * `/data/local/tmp/kazeia/captured_embeds.bin` for reuse with the C++ pipeline.
 *
 * @param textEmbedsList text embeddings, one per text token; the first is part of
 *        the prefill, the others are added to successive decode inputs.
 * @param maxGenTokens upper bound on the number of generated frames.
 * @return one IntArray of NUM_CODEBOOKS codes per generated frame; empty when a
 *         required module, rotary table or special embedding is missing, when the
 *         prefill is empty, or when the first sampled token is EOS.
 */
private fun runInterleavedPte(textEmbedsList: List<FloatArray>, maxGenTokens: Int): Array<IntArray> {
    // Uniform null guards (previously a mix of `!!` and early returns).
    val talkerMod = talkerPteModule ?: return emptyArray()
    if (cpPteModule == null) return emptyArray()
    val tCos = talkerPteRotaryCos ?: return emptyArray()
    val tSin = talkerPteRotarySin ?: return emptyArray()
    if (cpRotaryCos == null || cpRotarySin == null) return emptyArray()
    val eosE = ttsEosEmbed ?: return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()

    val prefill = buildPrefillEmbeddings(textEmbedsList)
    if (prefill.isEmpty()) return emptyArray()
    val trailingEmbeds = if (textEmbedsList.size > 1) textEmbedsList.subList(1, textEmbedsList.size) else emptyList()
    var trailingIdx = 0

    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    // Distinct codebook-0 tokens emitted so far, maintained incrementally for the
    // repetition penalty instead of being rebuilt from generatedCb0 on every step.
    val seenCb0 = HashSet<Int>()

    // Talker KV caches [TALKER_LAYERS × (k,v)] each [1, TALKER_HEADS, TALKER_PTE_KV_LEN, TALKER_HEAD_DIM]
    val tkvSize = TALKER_HEADS * TALKER_PTE_KV_LEN * TALKER_HEAD_DIM
    val tK = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    val tV = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    val maskData = FloatArray(TALKER_PTE_KV_LEN) { -1e9f }

    var pos = 0
    var lastForwardMs = 0L

    // Capture mode: save all talker inputs for reuse with C++ pipeline
    val capturedEmbeds = mutableListOf<FloatArray>()

    // Wraps a float array as a tensor input of the given shape.
    fun input(data: FloatArray, vararg dims: Int) = org.pytorch.executorch.EValue.from(
        org.pytorch.executorch.Tensor.fromBlob(data, LongArray(dims.size) { dims[it].toLong() }))

    // Rotary row for the current position, clamped to the last row of the table
    // (prefill used to index the table unclamped, unlike the decode loop).
    fun rotaryRow(table: FloatArray): FloatArray {
        val row = minOf(pos, table.size / TALKER_HEAD_DIM - 1)
        return table.copyOfRange(row * TALKER_HEAD_DIM, (row + 1) * TALKER_HEAD_DIM)
    }

    // Only codebook tokens and EOS may be sampled.
    fun suppressNonCodec(logits: FloatArray) {
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
            if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY
        }
    }

    // One talker forward at the current position: records the input, unmasks the
    // next slot, runs the module, stores the returned KV caches and advances pos.
    // Returns (hidden, logits); lastForwardMs holds the duration of forward() alone.
    fun talkerStep(embed: FloatArray): Pair<FloatArray, FloatArray> {
        capturedEmbeds.add(embed.clone())

        val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
        if (maskIdx >= 0) maskData[maskIdx] = 0f

        val inputs = mutableListOf(
            input(embed, 1, 1, TALKER_DIM),
            input(maskData.clone(), 1, 1, 1, TALKER_PTE_KV_LEN),
            input(rotaryRow(tCos), 1, 1, TALKER_HEAD_DIM),
            input(rotaryRow(tSin), 1, 1, TALKER_HEAD_DIM)
        )
        for (i in 0 until TALKER_LAYERS) {
            inputs.add(input(tK[i], 1, TALKER_HEADS, TALKER_PTE_KV_LEN, TALKER_HEAD_DIM))
            inputs.add(input(tV[i], 1, TALKER_HEADS, TALKER_PTE_KV_LEN, TALKER_HEAD_DIM))
        }

        val tStart = System.currentTimeMillis()
        val out = talkerMod.forward(*inputs.toTypedArray())
        lastForwardMs = System.currentTimeMillis() - tStart

        // Output layout as consumed here: [hidden, logits, k0, v0, k1, v1, ...].
        for (i in 0 until TALKER_LAYERS) {
            tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
            tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
        }
        pos++
        return Pair(out[0].toTensor().dataAsFloatArray, out[1].toTensor().dataAsFloatArray)
    }

    nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")

    // ===== PREFILL =====
    val tPrefill = System.currentTimeMillis()
    var pastHidden = FloatArray(0)
    var currentCb0 = -1
    for ((step, embed) in prefill.withIndex()) {
        val (hidden, logits) = talkerStep(embed)
        pastHidden = hidden
        if (step == prefill.size - 1) {
            suppressNonCodec(logits)
            currentCb0 = sampleTopK(logits, 0.9f, 50)
        }
    }
    nlog("Prefill (PTE): ${System.currentTimeMillis() - tPrefill}ms, ${prefill.size} steps")

    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // ===== INTERLEAVED GENERATION =====
    var totalTalkerMs = 0L
    var totalCpMs = 0L
    for (genStep in 0 until maxGenTokens) {
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0

        // 1. CP: predict CB1-15
        val tCp = System.currentTimeMillis()
        val cpCodes = runCpPte(pastHidden, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        seenCb0.add(currentCb0)

        if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // 2. Build next talker input: sum of all codebook embeddings of this frame
        //    plus the next text embedding, then EOS once, then padding.
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))

        val nextEmbed: FloatArray = when {
            trailingIdx < trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, trailingEmbeds[trailingIdx - 1]) }
            trailingIdx == trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, eosE) }
            else -> sumEmb(codecSum, padE)
        }

        // 3. Talker step
        val (hidden, logits) = talkerStep(nextEmbed)
        totalTalkerMs += lastForwardMs
        pastHidden = hidden

        // 4. Next CB0: suppress non-codec tokens, apply the 1.05 repetition penalty
        //    to every token already emitted, then sample.
        suppressNonCodec(logits)
        for (tok in seenCb0) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)

        if (nextCb0 == CODEC_EOS) { nlog("EOS at gen step ${genStep + 2}"); break }
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2}"); break
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Save captured embeds for C++ pipeline reuse.
    // File layout: int32 prefill count, int32 total count, then TALKER_DIM float32
    // per embedding, all little-endian.
    if (capturedEmbeds.isNotEmpty()) {
        val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
        try {
            // use {} closes the stream even when a write fails part-way.
            java.io.FileOutputStream(capPath).use { fos ->
                val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                hdr.putInt(prefill.size); hdr.putInt(capturedEmbeds.size)
                fos.write(hdr.array())
                for (emb in capturedEmbeds) {
                    val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                    for (v in emb) buf.putFloat(v)
                    fos.write(buf.array())
                }
            }
            nlog("Captured ${capturedEmbeds.size} embeds → $capPath")
        } catch (e: Exception) {
            nlog("Capture save failed: ${e.message}")
        }
    }

    return allCodes.toTypedArray()
}
|
||
|
||
/**
 * Full interleaved generation with the talker served by the socket runner
 * ([hexForward]) and the code predictor via runCodePredictorInterleaved.
 *
 * Resets the runner's KV cache, runs the prefill, then alternates code-predictor
 * and talker steps until EOS is sampled, the same codebook-0 token would repeat
 * ten times in a row, the talker stops answering, or [maxGenTokens] frames were
 * produced.
 *
 * @param textEmbedsList text embeddings, one per text token; the first is part of
 *        the prefill, the others are added to successive decode inputs.
 * @param maxGenTokens upper bound on the number of generated frames.
 * @return one IntArray of NUM_CODEBOOKS codes per generated frame; empty when a
 *         special embedding is missing, the prefill is empty or returned nothing,
 *         or the first sampled token is EOS.
 */
private fun runInterleavedHexagon(textEmbedsList: List<FloatArray>, maxGenTokens: Int): Array<IntArray> {
    val eosE = ttsEosEmbed ?: return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()

    val prefill = buildPrefillEmbeddings(textEmbedsList)
    if (prefill.isEmpty()) return emptyArray()
    val trailingEmbeds = if (textEmbedsList.size > 1) textEmbedsList.subList(1, textEmbedsList.size) else emptyList()
    var trailingIdx = 0
    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    // Distinct codebook-0 tokens emitted so far, maintained incrementally for the
    // repetition penalty instead of being rebuilt from generatedCb0 on every step.
    val seenCb0 = HashSet<Int>()

    // Only codebook tokens and EOS may be sampled.
    fun suppressNonCodec(logits: FloatArray) {
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
            if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY
        }
    }

    // Reset KV cache for new phrase (runner already started at load time)
    hexReset()

    var totalTalkerMs = 0L
    var totalCpMs = 0L

    // Prefill
    val tPrefill = System.currentTimeMillis()
    val prefillResults = hexForward(prefill)
    val prefillMs = System.currentTimeMillis() - tPrefill
    nlog("Prefill (Hexagon): ${prefillMs}ms, ${prefillResults.size} steps")

    // hexForward returns an empty list when no talker socket is set; bail out
    // cleanly instead of failing on last().
    val lastPrefill = prefillResults.lastOrNull()
    if (lastPrefill == null) {
        nlog("Prefill (Hexagon) returned no results")
        return emptyArray()
    }

    var pastHidden = lastPrefill.first
    val prefillLogits = lastPrefill.second
    suppressNonCodec(prefillLogits)
    var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
    nlog("Prefill done: first cb0=$currentCb0")

    // Same guard as the other generation paths: an immediate EOS must not be
    // emitted as a codebook-0 code (it lies outside the codebook range).
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // Generation loop. Runners stay alive afterwards — hexStopRunner() is called
    // before QNN decode.
    for (genStep in 0 until maxGenTokens) {
        // Code predictor for CB1-15 of this frame.
        val tCp = System.currentTimeMillis()
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        seenCb0.add(currentCb0)
        val cpMs = System.currentTimeMillis() - tCp
        totalCpMs += cpMs

        if (genStep < 3) nlog("Gen step ${genStep + 1}: cb0=$currentCb0 cb1=${codes[1]} [CP=${cpMs}ms]")

        // Build next embedding: sum of all codebook embeddings of this frame plus
        // the next text embedding, then EOS once, then padding.
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))

        val nextEmbed: FloatArray = when {
            trailingIdx < trailingEmbeds.size -> sumEmb(codecSum, trailingEmbeds[trailingIdx++])
            trailingIdx == trailingEmbeds.size -> { trailingIdx++; sumEmb(codecSum, eosE) }
            else -> sumEmb(codecSum, padE)
        }

        // Talker forward via the runner socket.
        val tTalker = System.currentTimeMillis()
        val results = hexForward(listOf(nextEmbed))
        val talkerMs = System.currentTimeMillis() - tTalker
        totalTalkerMs += talkerMs
        if (genStep < 3) nlog("  Talker(HEX)=${talkerMs}ms")

        val stepResult = results.firstOrNull()
        if (stepResult == null) {
            // Talker connection went away mid-phrase; keep the frames produced so far.
            nlog("Talker returned no result at gen step ${genStep + 2}, stopping")
            break
        }

        pastHidden = stepResult.first
        val logits = stepResult.second
        suppressNonCodec(logits)
        // 1.05 repetition penalty on every token already emitted.
        for (tok in seenCb0) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)

        if (nextCb0 == CODEC_EOS) { nlog("EOS at gen step ${genStep + 2}"); break }
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2}"); break
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker(HEX): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
    return allCodes.toTypedArray()
}
|
||
|
||
/**
 * Runs the interleaved talker + code-predictor generation loop and returns every
 * generated frame as one row of [NUM_CODEBOOKS] codes (`[numTokens][16]`).
 *
 * Backend selection, in priority order:
 *  1. Talker and code predictor both loaded as .pte modules: delegate to [runInterleavedPte].
 *  2. Hexagon talker usable after [ensureHexagonRunners]: delegate to [runInterleavedHexagon].
 *  3. Otherwise the ONNX Runtime KV-cache talker implemented in this function.
 *
 * Each generation step of the ONNX path:
 *  1. Code predictor (last talker hidden state, cb0) -> codebooks 1..15.
 *  2. Sum of all 16 codebook embeddings plus one text-side embedding -> next talker input.
 *  3. Talker forward -> logits + hidden state; the next cb0 is sampled from the logits.
 *
 * @param textEmbedsList text embeddings; entries after the first are fed one per
 *   generation step (the first one is expected to be covered by the prefill).
 * @param maxGenTokens upper bound on the number of generated frames.
 * @return the generated frames, or an empty array when a required resource is missing,
 *   the prefill is empty, or the first sampled token is already EOS.
 */
private fun runInterleavedGeneration(textEmbedsList: List<FloatArray>, maxGenTokens: Int = 50): Array<IntArray> {
    // Priority 1: All .pte JNI on NPU (no root needed)
    if (talkerPteModule != null && cpPteModule != null) {
        return runInterleavedPte(textEmbedsList, maxGenTokens)
    }
    // Priority 2: Hexagon talker + socket/ONNX CP
    ensureHexagonRunners()
    if (useHexagonTalker) {
        return runInterleavedHexagon(textEmbedsList, maxGenTokens)
    }

    // Priority 3: ONNX Runtime talker. Bail out early if anything it needs is missing.
    val env = ortEnv ?: return emptyArray()
    val session = talkerKv ?: return emptyArray()
    val eosE = ttsEosEmbed ?: return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()

    val prefill = buildPrefillEmbeddings(textEmbedsList)
    if (prefill.isEmpty()) return emptyArray()

    // Trailing text embeddings (all except first, which is in prefill)
    val trailingEmbeds =
        if (textEmbedsList.size > 1) textEmbedsList.subList(1, textEmbedsList.size) else emptyList()
    var trailingIdx = 0

    val allCodes = mutableListOf<IntArray>() // each entry is [16] codebooks for one time step
    val generatedCb0 = mutableListOf<Int>()
    // Distinct cb0 history, maintained incrementally for the repetition penalty.
    val seenCb0 = HashSet<Int>()

    val kCacheSize = TALKER_HEADS * TALKER_HEAD_DIM * KV_LEN
    val vCacheSize = TALKER_HEADS * KV_LEN * TALKER_HEAD_DIM
    var kCaches = Array(TALKER_LAYERS) { FloatArray(kCacheSize) }
    var vCaches = Array(TALKER_LAYERS) { FloatArray(vCacheSize) }
    val maskData = FloatArray(MAX_CONTEXT) { MASK_NEG }

    // The mask is right-aligned: position p unmasks slot MAX_CONTEXT - 1 - p. Once
    // p reaches MAX_CONTEXT every slot is already open, so the write is skipped
    // instead of indexing the array with a negative offset.
    fun unmask(position: Int) {
        if (position < MAX_CONTEXT) maskData[MAX_CONTEXT - 1 - position] = 0f
    }

    // Only real codec tokens and CODEC_EOS may be sampled from the talker vocabulary.
    fun suppressNonCodec(logits: FloatArray) {
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
            if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY
        }
    }

    // ===== PREFILL =====
    var pos = 0
    var pastHidden = FloatArray(0)
    var prefillLogits = FloatArray(0)
    for (step in prefill.indices) {
        unmask(step)
        val res = runTalkerStep(env, session, prefill[step], maskData, pos, kCaches, vCaches)
        kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden
        prefillLogits = res.logits
        pos++
    }

    // Apply suppression + sampling to the last prefill logits to get the first codec_0.
    suppressNonCodec(prefillLogits)
    var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
    nlog("Prefill done: first cb0=$currentCb0")

    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // ===== INTERLEAVED GENERATION =====
    var totalTalkerMs = 0L
    var totalCpMs = 0L
    for (genStep in 0 until maxGenTokens) {
        // 1. Run code predictor: (pastHidden, cb0_emb) → CB1-15
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0

        val tCp = System.currentTimeMillis()
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        val cpMs = System.currentTimeMillis() - tCp
        totalCpMs += cpMs
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        seenCb0.add(currentCb0)

        if (genStep < 3) {
            nlog("Gen step ${genStep + 1}: cb0=$currentCb0 cb1=${codes[1]} [CP=${cpMs}ms]")
        }

        // 2. Build next talker input: sum of ALL 16 codebook embeddings + text side
        val codecSum = FloatArray(TALKER_DIM)
        addEmb(codecSum, codecEmb(codes[0]))
        for (cb in 1 until NUM_CODEBOOKS) {
            addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
        }

        // Text side: trailing text tokens, then tts_eos exactly once, then tts_pad forever.
        // (Original note: the Python model expects tts_pad after text exhaustion for
        // EOS convergence.)
        val textSide = when {
            trailingIdx < trailingEmbeds.size -> trailingEmbeds[trailingIdx++]
            trailingIdx == trailingEmbeds.size -> {
                trailingIdx++
                eosE
            }
            else -> padE
        }
        val nextEmbed = sumEmb(codecSum, textSide)

        // 3. Run talker step
        unmask(pos)
        val tTalker = System.currentTimeMillis()
        val res = runTalkerStep(env, session, nextEmbed, maskData, pos, kCaches, vCaches)
        val talkerMs = System.currentTimeMillis() - tTalker
        totalTalkerMs += talkerMs
        kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden
        pos++

        if (genStep < 3) {
            nlog(" Talker=${talkerMs}ms")
        }

        // 4. Get next cb0 from logits
        val logits = res.logits
        suppressNonCodec(logits)
        // HuggingFace-style repetition penalty: 1.05x once per unique token in history
        for (tok in seenCb0) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)

        if (nextCb0 == CODEC_EOS) {
            nlog("EOS at gen step ${genStep + 2}")
            break
        }
        // Safety: the last 9 emitted tokens plus the new one being identical means
        // 10 repeats in a row, which is treated as model degeneration.
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2} (token $nextCb0 × 10), stopping")
            break
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker: ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")
    return allCodes.toTypedArray()
}
|
||
|
||
/**
 * Everything a single talker forward pass hands back to the generation loop.
 *
 * @property logits next-token logits produced by the step.
 * @property hidden hidden state produced by the step.
 * @property newK updated per-layer key caches to feed into the following step.
 * @property newV updated per-layer value caches to feed into the following step.
 */
private data class TalkerStepResult(
    val logits: FloatArray,
    val hidden: FloatArray,
    val newK: Array<FloatArray>,
    val newV: Array<FloatArray>,
)
|
||
|
||
/**
 * Runs one single-token talker forward pass through ONNX Runtime.
 *
 * When [talkerUsesCosSin] is set the call is forwarded to [runTalkerStepMRoPE]. Otherwise
 * the legacy graph is fed `inputs_embeds`, `attention_mask`, `position_ids` and one
 * `k_<i>_in` / `v_<i>_in` pair per layer.
 *
 * Every input tensor and the session result are released in `finally` blocks, so native
 * memory is freed even when `session.run` or an output read throws.
 *
 * @param env ONNX Runtime environment used to allocate the input tensors.
 * @param session talker session to run.
 * @param inputEmbed embedding fed for this step (`TALKER_DIM` floats).
 * @param maskData attention mask of `MAX_CONTEXT` floats; it is cloned before being
 *   wrapped, so the caller can keep mutating its own array.
 * @param pos position id sent to the model (int64 or int32 depending on [talkerUsesInt64Pos]).
 * @param kCaches per-layer key caches, sent with shape `[1, heads, headDim, kvLen]`.
 * @param vCaches per-layer value caches, sent with shape `[1, heads, kvLen, headDim]`.
 * @return logits, hidden state and freshly allocated KV caches copied out of the result.
 */
private fun runTalkerStep(
    env: OrtEnvironment, session: OrtSession,
    inputEmbed: FloatArray, maskData: FloatArray, pos: Int,
    kCaches: Array<FloatArray>, vCaches: Array<FloatArray>
): TalkerStepResult {
    if (talkerUsesCosSin) {
        return runTalkerStepMRoPE(env, session, inputEmbed, maskData, pos, kCaches, vCaches)
    }

    // Tensors are registered in the map as soon as they exist so the outer
    // `finally` can close whatever was created, even after a partial failure.
    val inputs = LinkedHashMap<String, OnnxTensor>()
    try {
        inputs["inputs_embeds"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong())
        )
        inputs["attention_mask"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong())
        )
        for (i in 0 until TALKER_LAYERS) {
            inputs["k_${i}_in"] = OnnxTensor.createTensor(
                env, FloatBuffer.wrap(kCaches[i]),
                longArrayOf(1, TALKER_HEADS.toLong(), TALKER_HEAD_DIM.toLong(), KV_LEN.toLong())
            )
            inputs["v_${i}_in"] = OnnxTensor.createTensor(
                env, FloatBuffer.wrap(vCaches[i]),
                longArrayOf(1, TALKER_HEADS.toLong(), KV_LEN.toLong(), TALKER_HEAD_DIM.toLong())
            )
        }
        inputs["position_ids"] = if (talkerUsesInt64Pos) {
            OnnxTensor.createTensor(env, LongBuffer.wrap(longArrayOf(pos.toLong())), longArrayOf(1))
        } else {
            OnnxTensor.createTensor(env, IntBuffer.wrap(intArrayOf(pos)), longArrayOf(1))
        }

        val result = session.run(inputs)
        try {
            // Output 0: logits (TALKER_VOCAB floats).
            val logits = FloatArray(TALKER_VOCAB)
            (result.get(0) as OnnxTensor).floatBuffer.get(logits)

            // Hidden state (TALKER_DIM floats) sits at output index 57 in this graph.
            val hidden = FloatArray(TALKER_DIM)
            (result.get(57) as OnnxTensor).floatBuffer.get(hidden)

            // KV caches: outputs 1..2*TALKER_LAYERS, interleaved as k0, v0, k1, v1, ...
            val kCacheSize = TALKER_HEADS * TALKER_HEAD_DIM * KV_LEN
            val vCacheSize = TALKER_HEADS * KV_LEN * TALKER_HEAD_DIM
            val newK = Array(TALKER_LAYERS) { FloatArray(kCacheSize) }
            val newV = Array(TALKER_LAYERS) { FloatArray(vCacheSize) }
            for (i in 0 until TALKER_LAYERS) {
                (result.get(1 + i * 2) as OnnxTensor).floatBuffer.get(newK[i])
                (result.get(2 + i * 2) as OnnxTensor).floatBuffer.get(newV[i])
            }

            return TalkerStepResult(logits, hidden, newK, newV)
        } finally {
            result.close()
        }
    } finally {
        for (tensor in inputs.values) tensor.close()
    }
}
|
||
|
||
/**
 * Talker step for the graph that takes precomputed rotary `cos` / `sin` inputs.
 *
 * Inputs are `emb`, `mask`, `cos`, `sin` and one `k_<i>_in` / `v_<i>_in` pair per layer;
 * both caches use the same `[1, heads, KV_LEN, headDim]` layout (not transposed like the
 * legacy graph). The graph returns hidden state at output 0, logits at output 1 and the
 * caches from output 2 onwards with `MAX_CONTEXT` positions; the oldest position of each
 * head is dropped so the caches can be fed straight back as inputs.
 *
 * If either rotary table is missing the step is skipped and all-zero logits / hidden
 * state are returned together with the unchanged caches.
 *
 * Input tensors and the session result are closed in `finally` blocks so native memory
 * is released even when the run or an output read throws.
 */
private fun runTalkerStepMRoPE(
    env: OrtEnvironment, session: OrtSession,
    inputEmbed: FloatArray, maskData: FloatArray, pos: Int,
    kCaches: Array<FloatArray>, vCaches: Array<FloatArray>
): TalkerStepResult {
    val cos = rotaryCos
        ?: return TalkerStepResult(FloatArray(TALKER_VOCAB), FloatArray(TALKER_DIM), kCaches, vCaches)
    val sin = rotarySin
        ?: return TalkerStepResult(FloatArray(TALKER_VOCAB), FloatArray(TALKER_DIM), kCaches, vCaches)

    // cos/sin row for this position; the index is clamped so we never read past the
    // rotary table even if generation runs longer than the table.
    val rotMax = (cos.size / TALKER_HEAD_DIM) - 1
    val rotPos = if (pos > rotMax) rotMax else pos
    val cosSlice = FloatArray(TALKER_HEAD_DIM)
    val sinSlice = FloatArray(TALKER_HEAD_DIM)
    System.arraycopy(cos, rotPos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
    System.arraycopy(sin, rotPos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)

    val keptPerHead = KV_LEN * TALKER_HEAD_DIM       // floats kept per head
    val outPerHead = MAX_CONTEXT * TALKER_HEAD_DIM   // floats per head in the model output
    val kvSize = TALKER_HEADS * keptPerHead

    fun kvTensor(cache: FloatArray): OnnxTensor = OnnxTensor.createTensor(
        env, FloatBuffer.wrap(cache),
        longArrayOf(1, TALKER_HEADS.toLong(), KV_LEN.toLong(), TALKER_HEAD_DIM.toLong())
    )

    val inputs = LinkedHashMap<String, OnnxTensor>()
    try {
        inputs["emb"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(inputEmbed), longArrayOf(1, 1, TALKER_DIM.toLong())
        )
        inputs["mask"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(maskData.clone()), longArrayOf(1, 1, 1, MAX_CONTEXT.toLong())
        )
        inputs["cos"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(cosSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())
        )
        inputs["sin"] = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(sinSlice), longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())
        )
        for (i in 0 until TALKER_LAYERS) {
            inputs["k_${i}_in"] = kvTensor(kCaches[i])
            inputs["v_${i}_in"] = kvTensor(vCaches[i])
        }

        val result = session.run(inputs)
        try {
            // New format: hidden at index 0, logits at index 1.
            val hidden = FloatArray(TALKER_DIM)
            (result.get(0) as OnnxTensor).floatBuffer.get(hidden)
            val logits = FloatArray(TALKER_VOCAB)
            (result.get(1) as OnnxTensor).floatBuffer.get(logits)

            // Scratch buffer for one full-size cache output, reused for every layer.
            val scratch = FloatArray(TALKER_HEADS * outPerHead)

            // Copies output `outputIndex` and drops position 0 of every head: for head h
            // the source rows 1..MAX_CONTEXT-1 land at rows 0..KV_LEN-1 of the result.
            fun dropOldest(outputIndex: Int): FloatArray {
                (result.get(outputIndex) as OnnxTensor).floatBuffer.get(scratch)
                val kept = FloatArray(kvSize)
                for (h in 0 until TALKER_HEADS) {
                    System.arraycopy(
                        scratch, h * outPerHead + TALKER_HEAD_DIM,
                        kept, h * keptPerHead, keptPerHead
                    )
                }
                return kept
            }

            val newK = Array(TALKER_LAYERS) { i -> dropOldest(2 + i * 2) }
            val newV = Array(TALKER_LAYERS) { i -> dropOldest(3 + i * 2) }

            return TalkerStepResult(logits, hidden, newK, newV)
        } finally {
            result.close()
        }
    } finally {
        for (tensor in inputs.values) tensor.close()
    }
}
|
||
|
||
|
||
// Number of code-predictor invocations so far; only used to emit one-time log lines.
private var cpCallCount = 0

/**
 * Dispatches one code-predictor call to the first backend that is available.
 *
 * A diagnostic override comes first: if the marker file
 * `/data/local/tmp/kazeia/force_cp_v2` exists and the ONNX Runtime CP session is usable,
 * [runCpV2] is used unconditionally.
 *
 * Regular routing, in order:
 *  1. Hexagon talker + Hexagon CP -> [hexCpForward]. Per the original notes, running the
 *     CP .pte on QNN HTP next to the Hexagon talker triggers error 6031 (DSP contention).
 *  2. CP .pte loaded and the talker runs as .pte or ONNX -> [runCpPte].
 *  3. Legacy fallbacks: [etCpForward], [hexCpForward], [runCpV2], and finally [runCpCpu].
 *
 * @param pastHidden talker hidden state for the current frame.
 * @param cb0 codebook-0 token of the current frame.
 * @return whatever the selected backend returns (codes for codebooks 1..15).
 */
private fun runCodePredictorInterleaved(pastHidden: FloatArray, cb0: Int): IntArray {
    if (cpCallCount == 0) {
        nlog("CP: pte=${cpPteModule != null}, talkerPte=${talkerPteModule != null}, et=$useEtCp, hex=$useHexagonCp, v2=$cpUsesCosSin")
    }
    cpCallCount++

    // Diagnostic override: force the ONNX Runtime CP path to compare its numerics
    // against the other backends (hypothesis from the original author: ORT's fp32
    // reduction order may track PyTorch more closely).
    val forceOnnxCp = File("/data/local/tmp/kazeia/force_cp_v2").exists() && cpUsesCosSin && cpKv != null
    if (forceOnnxCp) {
        if (cpCallCount == 1) nlog("force_cp_v2: using ONNX Runtime CP (cpKv session)")
        return runCpV2(pastHidden, cb0)
    }

    return when {
        useHexagonTalker && useHexagonCp -> hexCpForward(pastHidden, cb0)
        cpPteModule != null && (talkerPteModule != null || talkerKv != null) -> runCpPte(pastHidden, cb0)
        useEtCp -> etCpForward(pastHidden, cb0)
        useHexagonCp -> hexCpForward(pastHidden, cb0)
        cpUsesCosSin && cpKv != null -> runCpV2(pastHidden, cb0)
        else -> runCpCpu(pastHidden, cb0)
    }
}
|
||
|
||
/**
 * Code predictor via the ExecuTorch .pte module.
 *
 * Inputs passed per step: `emb[1,1,TALKER_DIM]`, `mask[1,1,1,CP_KV_LEN]`,
 * `cos[1,1,CP_HEAD_DIM]`, `sin[1,1,CP_HEAD_DIM]`, then one K and one V cache per layer
 * with shape `[1,CP_KV_HEADS,CP_KV_LEN,CP_HEAD_DIM]`.
 * Outputs read per step: output 0 as the hidden state, outputs `1 + 2*i` / `2 + 2*i` as
 * the updated K / V cache of layer `i`.
 *
 * Step 0 feeds the talker hidden state, step 1 the cb0 embedding, and every later step
 * the embedding of the code produced by the previous step. Codes are produced at steps
 * 1..15, so the loop stops after step 15: a further forward pass would only compute
 * values that are never read.
 *
 * Falls back to [runCpV2] when the module or a rotary table is missing. Any exception
 * during inference disables the .pte module for subsequent calls and also falls back.
 *
 * @return 15 codes (codebooks 1..15); entries stay 0 where no head weights are available.
 */
private fun runCpPte(pastHidden: FloatArray, cb0: Int): IntArray {
    val module = cpPteModule ?: return runCpV2(pastHidden, cb0)
    val cos = cpRotaryCos ?: return runCpV2(pastHidden, cb0)
    val sin = cpRotarySin ?: return runCpV2(pastHidden, cb0)

    // Wraps a float array as an ExecuTorch input value with the given shape.
    fun evalue(data: FloatArray, vararg shape: Long) =
        org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(data, shape))

    val codes = IntArray(15)
    // Fixed-size KV caches, zero-initialised; the mask hides the unused positions.
    val kvSize = CP_KV_HEADS * CP_KV_LEN * CP_HEAD_DIM
    val kCaches = Array(CP_LAYERS) { FloatArray(kvSize) }
    val vCaches = Array(CP_LAYERS) { FloatArray(kvSize) }

    try {
        // Resolve the per-codebook head weights once, before the hot loop. When the
        // cache does not exist yet, all 15 heads are loaded up front (if a path is known).
        val heads = cpHeadsCache ?: arrayOfNulls<FloatArray>(15).also { fresh ->
            cpHeadsCache = fresh
            val hp = cpHeadsPath
            if (hp != null) {
                for (h in 0 until 15) {
                    fresh[h] = loadNpy(hp.replace("cp_heads.npy", "head_${h}.npy"))
                }
                nlog("CP heads pre-loaded: 15 × ${fresh[0]?.size} floats")
            }
        }

        var emb = pastHidden
        for (step in 0..codes.size) {
            if (step == 1) {
                // cb0 embedding comes from the talker's codec embedding table.
                val codecTable = checkNotNull(codecEmbedding) { "codec embedding table not loaded" }
                emb = FloatArray(TALKER_DIM)
                System.arraycopy(codecTable, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, emb, 0, TALKER_DIM)
            } else if (step >= 2) {
                // Later steps embed the previously predicted code via the CP's own tables.
                val cembs = cpEmbeddings ?: return codes
                val off = ((step - 2) * CODEBOOK_SIZE + codes[step - 2].coerceIn(0, CODEBOOK_SIZE - 1)) * TALKER_DIM
                emb = FloatArray(TALKER_DIM)
                System.arraycopy(cembs, off, emb, 0, TALKER_DIM)
            }

            // Mask: the last (step + 1) positions are active, clamped to the cache length.
            val mask = FloatArray(CP_KV_LEN) { -1e9f }
            for (p in 0..minOf(step, CP_KV_LEN - 1)) mask[CP_KV_LEN - 1 - p] = 0f

            val cosSlice = FloatArray(CP_HEAD_DIM)
            System.arraycopy(cos, step * CP_HEAD_DIM, cosSlice, 0, CP_HEAD_DIM)
            val sinSlice = FloatArray(CP_HEAD_DIM)
            System.arraycopy(sin, step * CP_HEAD_DIM, sinSlice, 0, CP_HEAD_DIM)

            val allInputs = ArrayList<org.pytorch.executorch.EValue>(4 + 2 * CP_LAYERS)
            allInputs.add(evalue(emb, 1, 1, TALKER_DIM.toLong()))
            allInputs.add(evalue(mask, 1, 1, 1, CP_KV_LEN.toLong()))
            allInputs.add(evalue(cosSlice, 1, 1, CP_HEAD_DIM.toLong()))
            allInputs.add(evalue(sinSlice, 1, 1, CP_HEAD_DIM.toLong()))
            for (i in 0 until CP_LAYERS) {
                allInputs.add(evalue(kCaches[i], 1, CP_KV_HEADS.toLong(), CP_KV_LEN.toLong(), CP_HEAD_DIM.toLong()))
                allInputs.add(evalue(vCaches[i], 1, CP_KV_HEADS.toLong(), CP_KV_LEN.toLong(), CP_HEAD_DIM.toLong()))
            }

            val outputs = module.forward(*allInputs.toTypedArray())
            val hiddenOut = outputs[0].toTensor().dataAsFloatArray

            // From step 1 on, the hidden state is scored against head (step - 1).
            if (step >= 1) {
                val cbIdx = step - 1
                val headData = heads[cbIdx]
                if (headData != null) {
                    codes[cbIdx] = NeonOps.headArgmax(hiddenOut, headData, CODEBOOK_SIZE, TALKER_DIM)
                }
            }

            // Carry the updated caches into the next step.
            for (i in 0 until CP_LAYERS) {
                kCaches[i] = outputs[1 + i * 2].toTensor().dataAsFloatArray
                vCaches[i] = outputs[2 + i * 2].toTensor().dataAsFloatArray
            }
        }
    } catch (e: Exception) {
        nlog("CP .pte error: ${e.message}, falling back to ONNX CPU")
        cpPteModule = null
        return runCpV2(pastHidden, cb0)
    }

    return codes
}
|
||
|
||
/**
 * Code predictor via ONNX Runtime: single-token steps with an explicit KV cache and one
 * lm_head per codebook, decoded autoregressively with a greedy argmax.
 *
 * Step 0 feeds the talker hidden state, step 1 the cb0 embedding, and every later step
 * the embedding of the code produced by the previous step. Codes are produced at steps
 * 1..15, so the loop stops after step 15; a further run would compute nothing that is read.
 *
 * Two cache modes, selected by [cpFixedKv]:
 *  - fixed: caches always hold `CP_KV_LEN` positions and a right-aligned `mask` input
 *    marks the valid ones;
 *  - dynamic: caches start empty and grow by one position per step, no mask input.
 *
 * Input tensors and the session result are closed in `finally` blocks on every step.
 *
 * @return 15 codes (codebooks 1..15), or all zeros when a required resource is missing.
 */
private fun runCpV2(pastHidden: FloatArray, cb0: Int): IntArray {
    val env = ortEnv ?: return IntArray(15)
    val session = cpKv ?: return IntArray(15)
    val cos = cpRotaryCos ?: return IntArray(15)
    val sin = cpRotarySin ?: return IntArray(15)
    val headPath = cpHeadsPath ?: return IntArray(15)
    val cembs = cpEmbeddings ?: return IntArray(15)
    // Handled like the other missing resources instead of failing on `!!` mid-loop.
    val codecTable = codecEmbedding ?: return IntArray(15)

    val codes = IntArray(15)
    val initialKvSize = if (cpFixedKv) CP_KV_HEADS * CP_KV_LEN * CP_HEAD_DIM else 0
    var kCaches = Array(CP_LAYERS) { FloatArray(initialKvSize) }
    var vCaches = Array(CP_LAYERS) { FloatArray(initialKvSize) }

    // Step 0: hidden state
    var emb = pastHidden
    for (step in 0..codes.size) {
        if (step == 1) {
            // cb0 embedding from the talker codec embedding table
            emb = FloatArray(TALKER_DIM)
            System.arraycopy(codecTable, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, emb, 0, TALKER_DIM)
        } else if (step >= 2) {
            // embedding of the previous code from the CP's own tables
            val off = ((step - 2) * CODEBOOK_SIZE + codes[step - 2].coerceIn(0, CODEBOOK_SIZE - 1)) * TALKER_DIM
            emb = FloatArray(TALKER_DIM)
            System.arraycopy(cembs, off, emb, 0, TALKER_DIM)
        }

        val kvLen = if (cpFixedKv) CP_KV_LEN else kCaches[0].size / (CP_KV_HEADS * CP_HEAD_DIM)
        val cosSlice = FloatArray(CP_HEAD_DIM)
        System.arraycopy(cos, step * CP_HEAD_DIM, cosSlice, 0, CP_HEAD_DIM)
        val sinSlice = FloatArray(CP_HEAD_DIM)
        System.arraycopy(sin, step * CP_HEAD_DIM, sinSlice, 0, CP_HEAD_DIM)

        val inputs = LinkedHashMap<String, OnnxTensor>()
        try {
            inputs["emb"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(emb), longArrayOf(1, 1, TALKER_DIM.toLong()))
            if (cpFixedKv) {
                // Fixed KV mask: after step N the last (N+1) positions are valid and
                // right-aligned. The count is clamped to the cache length (as runCpPte
                // does) so the index can never go negative.
                val mask = FloatArray(CP_KV_LEN) { -1e9f }
                for (p in 0..minOf(step, CP_KV_LEN - 1)) mask[CP_KV_LEN - 1 - p] = 0f
                inputs["mask"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(mask), longArrayOf(1, 1, 1, CP_KV_LEN.toLong()))
            }
            inputs["cos"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(cosSlice), longArrayOf(1, CP_HEAD_DIM.toLong()))
            inputs["sin"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(sinSlice), longArrayOf(1, CP_HEAD_DIM.toLong()))
            for (i in 0 until CP_LAYERS) {
                inputs["k_${i}_in"] = OnnxTensor.createTensor(
                    env, FloatBuffer.wrap(kCaches[i]),
                    longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong())
                )
                inputs["v_${i}_in"] = OnnxTensor.createTensor(
                    env, FloatBuffer.wrap(vCaches[i]),
                    longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong())
                )
            }

            val result = session.run(inputs)
            try {
                val hidden = FloatArray(TALKER_DIM)
                (result.get(0) as OnnxTensor).floatBuffer.get(hidden)

                // Extract code from lm_head: hidden @ head[cb]^T → argmax
                if (step >= 1) {
                    val cbIdx = step - 1
                    // Heads are loaded lazily, one file per codebook, and kept for later calls.
                    val cache = cpHeadsCache ?: arrayOfNulls<FloatArray>(15).also { cpHeadsCache = it }
                    val headData = cache[cbIdx]
                        ?: loadNpy(headPath.replace("cp_heads.npy", "head_${cbIdx}.npy")).also { cache[cbIdx] = it }
                    var best = 0
                    var bestVal = Float.NEGATIVE_INFINITY
                    for (j in 0 until CODEBOOK_SIZE) {
                        var dot = 0f
                        val off = j * TALKER_DIM
                        for (k in 0 until TALKER_DIM) dot += hidden[k] * headData[off + k]
                        if (dot > bestVal) { bestVal = dot; best = j }
                    }
                    codes[cbIdx] = best
                }

                // Update KV caches: the fixed mode returns the same length it was given,
                // the dynamic mode returns one position more.
                val outKvLen = if (cpFixedKv) CP_KV_LEN else kvLen + 1
                val outKvSize = CP_KV_HEADS * outKvLen * CP_HEAD_DIM
                kCaches = Array(CP_LAYERS) { i ->
                    FloatArray(outKvSize).also { (result.get(1 + i * 2) as OnnxTensor).floatBuffer.get(it) }
                }
                vCaches = Array(CP_LAYERS) { i ->
                    FloatArray(outKvSize).also { (result.get(2 + i * 2) as OnnxTensor).floatBuffer.get(it) }
                }
            } finally {
                result.close()
            }
        } finally {
            for (tensor in inputs.values) tensor.close()
        }
    }

    return codes
}
|
||
|
||
|
||
/**
 * Code predictor via the ONNX KV-cache graph that computes the head logits itself:
 * single-token steps with a dynamically growing KV cache.
 *
 * Step 0 feeds the talker hidden state, step 1 the cb0 embedding, and every later step
 * the embedding of the code produced by the previous step. Codes are taken greedily
 * (argmax) from output 1 at steps 1..15, so the loop stops after step 15; a further run
 * would compute nothing that is read.
 *
 * Input tensors and the session result are closed in `finally` blocks on every step.
 *
 * @return 15 codes (codebooks 1..15), or all zeros when a required resource is missing.
 */
private fun runCpCpu(pastHidden: FloatArray, cb0: Int): IntArray {
    val env = ortEnv ?: return IntArray(15)
    val cpModel = cpKv ?: return IntArray(15)
    val cpEmbs = cpEmbeddings ?: return IntArray(15)
    // Handled like the other missing resources instead of failing on `!!` mid-loop.
    val codecTable = codecEmbedding ?: return IntArray(15)

    val codes = IntArray(15)
    // Dynamic KV caches: start with zero positions and grow by one per step.
    var kCaches = Array(CP_LAYERS) { FloatArray(0) }
    var vCaches = Array(CP_LAYERS) { FloatArray(0) }

    // Step 0: hidden state
    var emb = pastHidden
    for (step in 0..codes.size) {
        if (step == 1) {
            emb = FloatArray(TALKER_DIM)
            System.arraycopy(codecTable, cb0.coerceIn(0, TALKER_VOCAB - 1) * TALKER_DIM, emb, 0, TALKER_DIM)
        } else if (step >= 2) {
            val off = ((step - 2) * CODEBOOK_SIZE + codes[step - 2].coerceIn(0, CODEBOOK_SIZE - 1)) * TALKER_DIM
            emb = FloatArray(TALKER_DIM)
            System.arraycopy(cpEmbs, off, emb, 0, TALKER_DIM)
        }

        // Past positions plus the token being fed.
        val totalLen = step + 1
        val inputs = LinkedHashMap<String, OnnxTensor>()
        try {
            inputs["input_embeds"] = OnnxTensor.createTensor(
                env, FloatBuffer.wrap(emb), longArrayOf(1, 1, TALKER_DIM.toLong())
            )
            // All-zero additive mask: every position is visible.
            inputs["attention_mask"] = OnnxTensor.createTensor(
                env, FloatBuffer.wrap(FloatArray(totalLen)), longArrayOf(1, 1, 1, totalLen.toLong())
            )
            inputs["position_ids"] = OnnxTensor.createTensor(
                env, LongBuffer.wrap(longArrayOf(step.toLong())), longArrayOf(1, 1)
            )
            for (i in 0 until CP_LAYERS) {
                val kvLen = kCaches[i].size / (CP_KV_HEADS * CP_HEAD_DIM)
                inputs["k_${i}_in"] = OnnxTensor.createTensor(
                    env, FloatBuffer.wrap(kCaches[i]),
                    longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong())
                )
                inputs["v_${i}_in"] = OnnxTensor.createTensor(
                    env, FloatBuffer.wrap(vCaches[i]),
                    longArrayOf(1, CP_KV_HEADS.toLong(), kvLen.toLong(), CP_HEAD_DIM.toLong())
                )
            }

            val result = cpModel.run(inputs)
            try {
                // Output 1 holds the logits of all 15 heads; head (step - 1) decides this step.
                if (step >= 1) {
                    val headLogits = FloatArray(15 * CODEBOOK_SIZE)
                    (result.get(1) as OnnxTensor).floatBuffer.get(headLogits)
                    val cbIdx = step - 1
                    val headOff = cbIdx * CODEBOOK_SIZE
                    var maxIdx = 0
                    var maxVal = Float.NEGATIVE_INFINITY
                    for (j in 0 until CODEBOOK_SIZE) {
                        if (headLogits[headOff + j] > maxVal) { maxVal = headLogits[headOff + j]; maxIdx = j }
                    }
                    codes[cbIdx] = maxIdx
                }

                // Update KV caches (outputs 2.. interleaved k, v; one position longer each step)
                val newKvSize = CP_KV_HEADS * totalLen * CP_HEAD_DIM
                kCaches = Array(CP_LAYERS) { i ->
                    FloatArray(newKvSize).also { (result.get(2 + i * 2) as OnnxTensor).floatBuffer.get(it) }
                }
                vCaches = Array(CP_LAYERS) { i ->
                    FloatArray(newKvSize).also { (result.get(3 + i * 2) as OnnxTensor).floatBuffer.get(it) }
                }
            } finally {
                result.close()
            }
        } finally {
            for (tensor in inputs.values) tensor.close()
        }
    }

    return codes
}
|
||
|
||
// NOTE: a helper that trimmed trailing silence/noise from audio used to live here.
// It was a duplicate and has been removed — see trimTrailingSilence below.
|
||
|
||
/**
 * Samples a token index from [logits] using temperature scaling and top-K filtering.
 *
 * The [topK] highest logits are kept, turned into a temperature-scaled softmax and one
 * of them is drawn at random.
 *
 * Degenerate inputs are handled explicitly instead of producing NaN weights:
 *  - `temperature <= 0` or an effective K of 1 returns the highest-scoring index (greedy);
 *  - a softmax mass that is NaN or not positive (e.g. every candidate is -infinity)
 *    also returns the highest-ranked index;
 *  - [topK] is clamped to `1..logits.size`.
 *
 * @param logits unnormalised scores, one per token; must not be empty.
 * @param temperature softmax temperature.
 * @param topK number of highest-scoring candidates to sample from.
 * @return the index of the sampled token within [logits].
 * @throws IllegalArgumentException if [logits] is empty.
 */
private fun sampleTopK(logits: FloatArray, temperature: Float = 0.9f, topK: Int = 50): Int {
    require(logits.isNotEmpty()) { "logits must not be empty" }
    val k = topK.coerceIn(1, logits.size)

    // Candidate indices, best first.
    val candidates = logits.indices.sortedByDescending { logits[it] }.take(k)
    val greedy = candidates[0]
    if (k == 1 || temperature <= 0f) return greedy

    // Temperature-scaled softmax over the candidates (shifted by the max for stability).
    val maxLogit = logits[greedy]
    val weights = FloatArray(candidates.size)
    var total = 0f
    for (i in candidates.indices) {
        weights[i] = kotlin.math.exp((logits[candidates[i]] - maxLogit) / temperature)
        total += weights[i]
    }
    // Written as a negated comparison so that NaN is caught as well.
    if (!(total > 0f)) return greedy

    // Categorical sampling by walking the cumulative distribution.
    val r = Math.random().toFloat() * total
    var cumulative = 0f
    for (i in candidates.indices) {
        cumulative += weights[i]
        if (cumulative >= r) return candidates[i]
    }
    // Only reachable through float rounding at the very top of the range.
    return candidates.last()
}
|
||
|
||
// ==================== Code Predictor ====================
|
||
|
||
// ==================== Speech Decoder ====================
|
||
|
||
/**
 * Decodes 16 codebooks to PCM audio, chunk by chunk, cross-fading chunk overlaps.
 *
 * Inputs of at most [SEQ_LEN] tokens are decoded in a single pass. Longer inputs are
 * decoded in windows of [SEQ_LEN] tokens that advance by [EFFECTIVE_CHUNK], so consecutive
 * windows share [CHUNK_OVERLAP] tokens; the shared region is blended with a linear
 * cross-fade and only the part after it is appended.
 *
 * A window is skipped when the previous window already covered every remaining real
 * token: decoding it again would cost a full decoder pass and add no new samples.
 *
 * @param codebooks codes indexed as `[codebook][token]`.
 * @param numRealTokens number of leading tokens that carry real audio; output is trimmed
 *   to `numRealTokens * SAMPLES_PER_TOKEN` samples at most.
 * @return 16-bit PCM samples.
 */
private fun decodeChunked(codebooks: Array<IntArray>, numRealTokens: Int): ShortArray {
    val totalTokens = codebooks[0].size

    if (totalTokens <= SEQ_LEN) {
        val quantized = vqDecode(codebooks)
        val fullAudio = runSpeechDecoder(quantized)
        val trimSamples = minOf(numRealTokens * SAMPLES_PER_TOKEN, fullAudio.size)
        return fullAudio.copyOf(trimSamples)
    }

    val overlapSamples = CHUNK_OVERLAP * SAMPLES_PER_TOKEN
    var result = ShortArray(0)
    var pos = 0

    while (pos < numRealTokens) {
        // The previous window spans tokens [pos - EFFECTIVE_CHUNK, pos + CHUNK_OVERLAP).
        // If no real token lies beyond that, everything is already decoded.
        if (pos > 0 && numRealTokens - pos <= CHUNK_OVERLAP) break

        // Window of SEQ_LEN tokens starting at pos, zero-padded past the end of the input.
        val chunkCodes = Array(NUM_CODEBOOKS) { cb ->
            IntArray(SEQ_LEN) { t ->
                val srcIdx = pos + t
                if (srcIdx < totalTokens) codebooks[cb][srcIdx] else 0
            }
        }

        val quantized = vqDecode(chunkCodes)
        val chunkAudio = runSpeechDecoder(quantized)

        // Trim chunk to real tokens
        val realInChunk = minOf(SEQ_LEN, numRealTokens - pos)
        val trimmed = chunkAudio.copyOf(minOf(realInChunk * SAMPLES_PER_TOKEN, chunkAudio.size))

        if (pos == 0) {
            result = trimmed
        } else {
            // Crossfade the tail of what we have with the head of the new chunk.
            val fadeLen = minOf(overlapSamples, result.size, trimmed.size)
            val fadeStart = result.size - fadeLen
            for (i in 0 until fadeLen) {
                val alpha = i.toFloat() / fadeLen
                val mixed = ((1f - alpha) * result[fadeStart + i] + alpha * trimmed[i]).toInt()
                    .coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()).toShort()
                result[fadeStart + i] = mixed
            }
            // Append the non-overlapping part
            if (fadeLen < trimmed.size) {
                val tailLen = trimmed.size - fadeLen
                val combined = result.copyOf(result.size + tailLen)
                System.arraycopy(trimmed, fadeLen, combined, result.size, tailLen)
                result = combined
            }
        }

        pos += EFFECTIVE_CHUNK
    }
    return result
}
|
||
|
||
/**
 * VQ decode: turns 16 codebooks of [SEQ_LEN] codes into the quantized feature map,
 * stored channel-major as `HIDDEN_DIM * SEQ_LEN` floats.
 *
 * Codebook 0 is looked up in the "first" table and projected with the first output
 * projection; codebooks 1..15 are looked up in their own tables, summed per frame and
 * projected with the "rest" output projection. Both projections are added together.
 *
 * @param codebooks codes indexed as `[codebook][frame]`; the first [SEQ_LEN] frames are read.
 * @return the quantized features, or an empty array if any VQ table is not loaded.
 */
private fun vqDecode(codebooks: Array<IntArray>): FloatArray {
    val firstCb = firstCodebook ?: return FloatArray(0)
    val restCbs = restCodebooks ?: return FloatArray(0)
    val firstProj = firstOutputProj ?: return FloatArray(0)
    val restProj = restOutputProj ?: return FloatArray(0)

    val quantized = FloatArray(HIDDEN_DIM * SEQ_LEN)

    // Projects per-frame latents (frame-major, CODEBOOK_DIM each) through a
    // HIDDEN_DIM x CODEBOOK_DIM weight matrix, either overwriting or adding to `quantized`.
    fun project(weights: FloatArray, latents: FloatArray, accumulate: Boolean) {
        for (channel in 0 until HIDDEN_DIM) {
            val weightRow = channel * CODEBOOK_DIM
            for (frame in 0 until SEQ_LEN) {
                val latentRow = frame * CODEBOOK_DIM
                var acc = 0f
                for (d in 0 until CODEBOOK_DIM) {
                    acc += weights[weightRow + d] * latents[latentRow + d]
                }
                val dst = channel * SEQ_LEN + frame
                if (accumulate) quantized[dst] += acc else quantized[dst] = acc
            }
        }
    }

    // Codebook 0: plain table lookup per frame.
    val firstLatents = FloatArray(CODEBOOK_DIM * SEQ_LEN)
    for (frame in 0 until SEQ_LEN) {
        val code = codebooks[0][frame]
        System.arraycopy(firstCb, code * CODEBOOK_DIM, firstLatents, frame * CODEBOOK_DIM, CODEBOOK_DIM)
    }
    project(firstProj, firstLatents, accumulate = false)

    // Codebooks 1..15: look up each one in its own table and sum the vectors per frame.
    val restLatents = FloatArray(CODEBOOK_DIM * SEQ_LEN)
    for (cb in 1 until NUM_CODEBOOKS) {
        val table = restCbs[cb - 1]
        val frameCodes = codebooks[cb]
        for (frame in 0 until SEQ_LEN) {
            val src = frameCodes[frame] * CODEBOOK_DIM
            val dst = frame * CODEBOOK_DIM
            for (d in 0 until CODEBOOK_DIM) {
                restLatents[dst + d] += table[src + d]
            }
        }
    }
    project(restProj, restLatents, accumulate = true)

    return quantized
}
|
||
|
||
/**
 * Runs the speech decoder: quantized → pre_conv → preprocessor → conv decoder → audio.
 *
 * When the decoder is configured for CPU or GPU ([decoderOnCpu] / [decoderOnGpu]) the
 * call is forwarded to [runSpeechDecoderV2]. Otherwise the three sessions are chained
 * using the fixed input names `x`, `pre_conv_out` and `hidden`.
 *
 * Every tensor and session result created here is tracked and closed in a `finally`
 * block (newest first), so native memory is released even if a stage throws.
 *
 * @param quantized features wrapped as a `[1, HIDDEN_DIM, SEQ_LEN]` tensor.
 * @return 16-bit PCM; float samples are clamped to [-1, 1] and scaled by 32767.
 * @throws IllegalStateException if the ORT environment or a decoder session is not loaded.
 */
private fun runSpeechDecoder(quantized: FloatArray): ShortArray {
    val env = checkNotNull(ortEnv) { "ONNX Runtime environment is not initialised" }
    if (decoderOnCpu || decoderOnGpu) {
        return runSpeechDecoderV2(quantized)
    }
    val preConvSession = checkNotNull(preConv) { "pre_conv session is not loaded" }
    val preprocessorSession = checkNotNull(preprocessor) { "preprocessor session is not loaded" }
    val convDecoderSession = checkNotNull(convDecoder) { "conv decoder session is not loaded" }

    // Native resources owned by this call, in creation order.
    val owned = ArrayList<AutoCloseable>()
    try {
        val qTensor = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(quantized), longArrayOf(1, HIDDEN_DIM.toLong(), SEQ_LEN.toLong())
        ).also { owned.add(it) }

        val pcResult = preConvSession.run(mapOf("x" to qTensor)).also { owned.add(it) }
        val pcOut = pcResult[0] as OnnxTensor
        nlog("pre_conv: ${qTensor.info.shape.contentToString()} → ${pcOut.info.shape.contentToString()}")

        // Each stage output is fed straight into the next stage; it stays owned by its
        // result object, which is only closed once the whole chain has finished.
        val ppResult = preprocessorSession.run(mapOf("pre_conv_out" to pcOut)).also { owned.add(it) }
        val ppOut = ppResult[0] as OnnxTensor
        nlog("preprocessor: → ${ppOut.info.shape.contentToString()}")

        val cdResult = convDecoderSession.run(mapOf("hidden" to ppOut)).also { owned.add(it) }
        val cdOut = cdResult[0] as OnnxTensor
        nlog("conv_decoder: → ${cdOut.info.shape.contentToString()}")

        // The decoder output is read as a 3-D array; batch 0, channel 0 holds the waveform.
        @Suppress("UNCHECKED_CAST")
        val audioFloat = (cdOut.value as Array<Array<FloatArray>>)[0][0]

        return ShortArray(audioFloat.size) {
            (audioFloat[it].coerceIn(-1f, 1f) * 32767).toInt().toShort()
        }
    } finally {
        for (resource in owned.asReversed()) resource.close()
    }
}
|
||
|
||
/**
 * V2 decoder: quantized → pre_conv → transpose → pre_transformer → decoder → audio.
 *
 * Unlike the legacy path, each session is fed through its first declared input name,
 * and the pre_conv output is copied out and transposed from channel-major
 * `[1, 1024, SEQ_LEN]` to frame-major `[1, SEQ_LEN, 1024]` on the CPU before the
 * pre_transformer runs. The duration of each stage is logged.
 *
 * Every tensor and session result created here is tracked and closed in a `finally`
 * block (newest first), so native memory is released even if a stage throws.
 *
 * @param quantized features wrapped as a `[1, HIDDEN_DIM, SEQ_LEN]` tensor.
 * @return 16-bit PCM; float samples are clamped to [-1, 1] and scaled by 32767.
 * @throws IllegalStateException if the ORT environment or a decoder session is not loaded.
 */
private fun runSpeechDecoderV2(quantized: FloatArray): ShortArray {
    val env = checkNotNull(ortEnv) { "ONNX Runtime environment is not initialised" }
    val preConvSession = checkNotNull(preConv) { "pre_conv session is not loaded" }
    val preTransformerSession = checkNotNull(preprocessor) { "pre_transformer session is not loaded" }
    val decoderSession = checkNotNull(convDecoder) { "decoder session is not loaded" }

    // Channel count of the pre_conv output that this path transposes.
    val channels = 1024

    // Native resources owned by this call, in creation order.
    val owned = ArrayList<AutoCloseable>()
    try {
        val td0 = System.currentTimeMillis()

        // pre_conv: [1, HIDDEN_DIM, SEQ_LEN] → [1, channels, SEQ_LEN]
        val qTensor = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(quantized), longArrayOf(1, HIDDEN_DIM.toLong(), SEQ_LEN.toLong())
        ).also { owned.add(it) }
        val pcResult = preConvSession.run(mapOf(preConvSession.inputNames.first() to qTensor))
            .also { owned.add(it) }
        val pcData = FloatArray(channels * SEQ_LEN)
        (pcResult[0] as OnnxTensor).floatBuffer.get(pcData)
        val td1 = System.currentTimeMillis()
        nlog("V2 pre_conv: ${td1-td0}ms")

        // Transpose [1, channels, SEQ_LEN] → [1, SEQ_LEN, channels]
        val transposed = FloatArray(SEQ_LEN * channels)
        for (c in 0 until channels) {
            for (t in 0 until SEQ_LEN) {
                transposed[t * channels + c] = pcData[c * SEQ_LEN + t]
            }
        }
        val ptInput = OnnxTensor.createTensor(
            env, FloatBuffer.wrap(transposed), longArrayOf(1, SEQ_LEN.toLong(), channels.toLong())
        ).also { owned.add(it) }

        // pre_transformer: [1, SEQ_LEN, channels] in, fed directly into the decoder
        val ptResult = preTransformerSession.run(mapOf(preTransformerSession.inputNames.first() to ptInput))
            .also { owned.add(it) }
        val ptOut = ptResult[0] as OnnxTensor
        val td2 = System.currentTimeMillis()
        nlog("V2 pre_transformer: ${td2-td1}ms")

        // decoder: waveform is read from batch 0, channel 0 of the 3-D output
        val cdResult = decoderSession.run(mapOf(decoderSession.inputNames.first() to ptOut))
            .also { owned.add(it) }
        val cdOut = cdResult[0] as OnnxTensor
        nlog("V2 BigVGAN: ${System.currentTimeMillis()-td2}ms")
        @Suppress("UNCHECKED_CAST")
        val audioFloat = (cdOut.value as Array<Array<FloatArray>>)[0][0]

        return ShortArray(audioFloat.size) {
            (audioFloat[it].coerceIn(-1f, 1f) * 32767).toInt().toShort()
        }
    } finally {
        for (resource in owned.asReversed()) resource.close()
    }
}
|
||
|
||
// ==================== Loading ====================
|
||
|
||
/**
 * Read the VQ tables from `.npy` files located under [path].
 *
 * Fills the first-layer codebook, both output projections, and the 15
 * "rest" layer codebooks (indexed 0..14 in the file names), in that order.
 */
private fun loadVqCodebooks(path: String) {
    val restLayerCount = 15

    firstCodebook = loadNpy("$path/vq_rvq_first_vq_layers_0_codebook.npy")
    firstOutputProj = loadNpy("$path/vq_rvq_first_output_proj_w.npy")
    restOutputProj = loadNpy("$path/vq_rvq_rest_output_proj_w.npy")

    // Residual layers are stored one file per layer.
    val restLayers = ArrayList<FloatArray>(restLayerCount)
    for (layer in 0 until restLayerCount) {
        restLayers += loadNpy("$path/vq_rvq_rest_vq_layers_${layer}_codebook.npy")
    }
    restCodebooks = restLayers.toTypedArray()

    nlog("VQ codebooks loaded (16 × [${CODEBOOK_SIZE}, ${CODEBOOK_DIM}])")
}
|
||
|
||
/**
 * Load a NumPy `.npy` file as a flat float array.
 *
 * The payload is read as little-endian 32-bit floats in file order. The dtype and
 * shape recorded in the header are NOT inspected, so the caller must know the layout.
 *
 * Header handling follows the NPY format: a 6-byte magic (`\x93NUMPY`), one major and
 * one minor version byte, then a little-endian header length that is 2 bytes wide for
 * format 1.x and 4 bytes wide for 2.x and later.
 *
 * @param path absolute path of the file to read.
 * @return the decoded floats, or an empty array (after logging a warning) when the
 *         file is missing, is not an NPY file, or is truncated.
 */
private fun loadNpy(path: String): FloatArray {
    val file = File(path)
    if (!file.exists()) {
        nlog("WARN: $path not found")
        return FloatArray(0)
    }
    val bytes = file.readBytes()

    // Smallest valid preamble is magic(6) + version(2) + v1 header length(2).
    val magicOk = bytes.size >= 10 &&
        bytes[0] == 0x93.toByte() &&
        String(bytes, 1, 5, Charsets.US_ASCII) == "NUMPY"
    if (!magicOk) {
        nlog("WARN: $path is not a valid .npy file")
        return FloatArray(0)
    }

    val majorVersion = bytes[6].toInt() and 0xFF
    val le = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val preambleLen: Int
    val headerLen: Long
    if (majorVersion >= 2) {
        // Format 2.0+ widens the header length field to an unsigned 32-bit value.
        if (bytes.size < 12) {
            nlog("WARN: $path has a truncated .npy header")
            return FloatArray(0)
        }
        preambleLen = 12
        headerLen = le.getInt(8).toLong() and 0xFFFFFFFFL
    } else {
        preambleLen = 10
        headerLen = (le.getShort(8).toInt() and 0xFFFF).toLong()
    }

    val dataStart = preambleLen + headerLen
    if (dataStart > bytes.size) {
        nlog("WARN: $path has a truncated .npy header")
        return FloatArray(0)
    }
    val dataOffset = dataStart.toInt()

    // Any trailing bytes that do not form a whole float are ignored.
    val result = FloatArray((bytes.size - dataOffset) / 4)
    ByteBuffer.wrap(bytes, dataOffset, bytes.size - dataOffset)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asFloatBuffer()
        .get(result)
    return result
}
|
||
|
||
// ==================== Audio Playback ====================
|
||
|
||
/**
 * Play a mono 16-bit PCM buffer at [SR] Hz through a static-mode [AudioTrack].
 *
 * Any current playback is stopped first via `stop()`. The whole buffer is written
 * up front and a notification marker is placed at the last frame so [onComplete]
 * runs when playback reaches the end.
 *
 * An empty buffer is completed immediately: a zero-byte buffer size is rejected by
 * `AudioTrack.Builder`, which would otherwise throw and leave the caller waiting on
 * a completion callback that never fires.
 *
 * @param audioData PCM samples to play.
 * @param onComplete invoked once when the end marker is reached, or right away for
 *        an empty buffer.
 */
private fun playAudio(audioData: ShortArray, onComplete: () -> Unit) {
    stop()
    if (audioData.isEmpty()) {
        nlog("playAudio: empty audio buffer, nothing to play")
        onComplete()
        return
    }

    // 2 bytes per 16-bit mono sample.
    val bufferSize = audioData.size * 2
    val track = AudioTrack.Builder()
        .setAudioAttributes(
            AudioAttributes.Builder()
                .setUsage(AudioAttributes.USAGE_MEDIA)
                .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
                .build()
        )
        .setAudioFormat(
            AudioFormat.Builder()
                .setSampleRate(SR)
                .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
                .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
                .build()
        )
        .setBufferSizeInBytes(bufferSize)
        .setTransferMode(AudioTrack.MODE_STATIC)
        .build()
    audioTrack = track

    // Configure through the local reference so a concurrent reassignment of the
    // property cannot silently skip the write/play sequence.
    track.write(audioData, 0, audioData.size)
    track.setNotificationMarkerPosition(audioData.size)
    track.setPlaybackPositionUpdateListener(object : AudioTrack.OnPlaybackPositionUpdateListener {
        override fun onMarkerReached(track: AudioTrack?) { onComplete() }
        override fun onPeriodicNotification(track: AudioTrack?) {}
    })
    track.play()
}
|
||
|
||
/**
 * Diagnostic path: decode Python's captured codes directly, bypassing talker+CP.
 * Used to isolate whether audio degradation comes from our talker/CP predictions
 * (different codes than Python) vs from BigVGAN implementation differences.
 *
 * File format python_codes.bin (little-endian): int32 n_steps, int32 n_codebooks(=16),
 * then n_steps × 16 × int32 codes, stored step-major.
 *
 * Each code is clamped to `0 until CODEBOOK_SIZE`. The decoded audio is also written,
 * best effort, to `/data/local/tmp/kazeia/kazeia_PYCODES.wav`.
 *
 * @param codesPath path of the captured codes file.
 * @return decoded 16-bit PCM, or an empty array when the header is inconsistent
 *         with the file contents.
 */
fun generateFromPythonCodes(codesPath: String): ShortArray {
    nlog("Direct Python-codes path: decoding $codesPath via BigVGAN only")
    val t0 = System.currentTimeMillis()
    val bytes = java.io.File(codesPath).readBytes()
    if (bytes.size < 8) {
        nlog("Python codes file too small: ${bytes.size} bytes")
        return ShortArray(0)
    }
    val bb = java.nio.ByteBuffer.wrap(bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN)
    val nSteps = bb.int
    val nCodebooks = bb.int
    nlog("Python codes: $nSteps steps × $nCodebooks codebooks")
    if (nCodebooks != NUM_CODEBOOKS) {
        nlog("Unexpected codebook count: $nCodebooks != $NUM_CODEBOOKS"); return ShortArray(0)
    }
    // Validate the payload before reading so a truncated or corrupt capture is
    // reported instead of surfacing as a BufferUnderflowException.
    val payloadBytes = nSteps.toLong() * NUM_CODEBOOKS * 4
    if (nSteps <= 0 || payloadBytes > bb.remaining()) {
        nlog("Invalid Python codes payload: $nSteps steps need $payloadBytes bytes, ${bb.remaining()} available")
        return ShortArray(0)
    }

    // The decoder works on at least SEQ_LEN frames; the padding stays at code 0.
    val padLen = maxOf(nSteps, SEQ_LEN)
    val allCodebooks = Array(NUM_CODEBOOKS) { IntArray(padLen) }
    for (t in 0 until nSteps) {
        for (cb in 0 until NUM_CODEBOOKS) {
            allCodebooks[cb][t] = bb.int.coerceIn(0, CODEBOOK_SIZE - 1)
        }
    }
    val audio = decodeChunked(allCodebooks, nSteps)
    nlog("Python-codes decode: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio")

    // Best-effort WAV dump (44-byte PCM header + samples) for offline comparison.
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_PYCODES.wav"
        val dataLen = audio.size * 2
        val wav = java.nio.ByteBuffer.allocate(44 + dataLen).order(java.nio.ByteOrder.LITTLE_ENDIAN)
        wav.put("RIFF".toByteArray()); wav.putInt(36 + dataLen)
        wav.put("WAVE".toByteArray()); wav.put("fmt ".toByteArray())
        wav.putInt(16); wav.putShort(1); wav.putShort(1)
        wav.putInt(SR); wav.putInt(SR * 2); wav.putShort(2); wav.putShort(16)
        wav.put("data".toByteArray()); wav.putInt(dataLen)
        for (s in audio) wav.putShort(s)
        // `use` closes the stream even when the write fails.
        java.io.FileOutputStream(wavPath).use { it.write(wav.array()) }
        nlog("WAV saved (PyCodes): $wavPath")
    } catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
    return audio
}
|
||
|
||
/**
 * Full pipeline from pre-computed embeddings (from Python generate() capture).
 * File format (little-endian): int32 n_prefill, int32 n_total, then
 * n_total × TALKER_DIM floats.
 * Runs talker ONNX → CP → VQ decode → speech decoder.
 *
 * Dispatch order:
 *  1. If the `force_python_codes` diagnostic flag is set and a sibling
 *     `python_codes.bin` exists, decode those codes directly.
 *  2. Native / PTE pipeline, then the Hexagon pipeline, when available.
 *  3. Otherwise the ONNX talker loop implemented here.
 *
 * Side effects of the ONNX path: writes `tablet_codes.bin` and `kazeia_ONNX.wav`
 * under `/data/local/tmp/kazeia/` (best effort, failures are only logged).
 *
 * @param embedsPath path of the captured embeddings file.
 * @return decoded 16-bit PCM, or an empty array when no talker is usable, the file
 *         is malformed, or generation ends before producing any frame.
 */
fun generateFromEmbeds(embedsPath: String): ShortArray {
    // Diagnostic override: if python_codes.bin exists and flag is set, decode
    // Python's captured codes directly to isolate BigVGAN from talker/CP predictions.
    val pyCodesPath = embedsPath.replace("full_pipeline_embeds.bin", "python_codes.bin")
    if (diagFlag("force_python_codes") && java.io.File(pyCodesPath).exists()) {
        nlog("force_python_codes flag + file present → diagnostic codes-direct path")
        return generateFromPythonCodes(pyCodesPath)
    }
    if (!loaded || (!nativePipelineReady && talkerPteModule == null && !useHexagonTalker && (talkerKv == null || !talkerUsesCosSin))) {
        nlog("generateFromEmbeds: no talker (native=$nativePipelineReady, pte=${talkerPteModule != null}, hex=$useHexagonTalker)")
        return ShortArray(0)
    }
    // Priority: native C++ > Java .pte > Hexagon > ONNX CPU
    if (nativePipelineReady) {
        return generateFromEmbedsPte(embedsPath)
    }
    if (talkerPteModule != null && cpPteModule != null) {
        return generateFromEmbedsPte(embedsPath)
    }
    if (useHexagonTalker) {
        return generateFromEmbedsHexagon(embedsPath)
    }

    // The guard above also passes when only the PTE talker is loaded (without its CP
    // module). In that case the ONNX session may be absent, so verify it explicitly
    // instead of dereferencing it blindly.
    val session = talkerKv
    if (session == null || rotaryCos == null || rotarySin == null) {
        nlog("generateFromEmbeds: ONNX talker unavailable (session=${session != null}, rotaryCos=${rotaryCos != null}, rotarySin=${rotarySin != null})")
        return ShortArray(0)
    }

    nlog("Full pipeline from: $embedsPath")
    val t0 = System.currentTimeMillis()

    val bytes = File(embedsPath).readBytes()
    if (bytes.size < 8) {
        nlog("generateFromEmbeds: embeds file too small (${bytes.size} bytes)")
        return ShortArray(0)
    }
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val nPrefill = bb.int
    val nTotal = bb.int
    // Reject headers that disagree with the payload instead of failing mid-read.
    val payloadBytes = nTotal.toLong() * TALKER_DIM * 4
    if (nPrefill < 0 || nTotal < nPrefill || payloadBytes > bb.remaining()) {
        nlog("generateFromEmbeds: malformed embeds file (nPrefill=$nPrefill, nTotal=$nTotal, bytes=${bytes.size})")
        return ShortArray(0)
    }
    // Bulk-read each row through a float view of the remaining bytes.
    val floatView = bb.asFloatBuffer()
    val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { row -> floatView.get(row) } }
    nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")

    val env = ortEnv!!

    // KV caches [1, 8, KV_LEN, 128]
    val kvSize = TALKER_HEADS * KV_LEN * TALKER_HEAD_DIM
    var kCaches = Array(TALKER_LAYERS) { FloatArray(kvSize) }
    var vCaches = Array(TALKER_LAYERS) { FloatArray(kvSize) }
    // Additive attention mask, opened one slot per consumed position starting from the end.
    val maskData = FloatArray(MAX_CONTEXT) { -1e9f }

    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    // Distinct cb0 values emitted so far; kept incrementally for the repetition penalty.
    val seenCb0 = HashSet<Int>()
    var currentCb0 = -1
    var pastHidden: FloatArray? = null

    // Prefill
    val tPrefill = System.currentTimeMillis()
    for (step in 0 until nPrefill) {
        // Same bound as the decode loop below: never index before the mask start.
        if (step < MAX_CONTEXT) maskData[MAX_CONTEXT - 1 - step] = 0f
        val res = runTalkerStepMRoPE(env, session, embeds[step], maskData, step, kCaches, vCaches)
        kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden

        if (step == nPrefill - 1) {
            // Only codebook entries and EOS may be sampled.
            val logits = res.logits
            for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
            currentCb0 = sampleTopK(logits, 0.9f, 50)
            nlog("Prefill: ${System.currentTimeMillis() - tPrefill}ms, cb0=$currentCb0")
        }
    }
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return ShortArray(0)

    // Voice-cloned replay: detect by large nTrailing (≥ 30 → Python-captured full
    // trajectory). In that mode, feed captured decode embeds verbatim (they already
    // include Python's codec_sum + text_hidden) and stop exactly at nTrailing.
    val nTrailingOnnx = nTotal - nPrefill
    val isVoiceClonedOnnx = nTrailingOnnx >= 30
    val maxGenOnnx = if (isVoiceClonedOnnx) nTrailingOnnx else (nTrailingOnnx * 5 + 20)
    var trailingIdxOnnx = 0
    val eosE = ttsEosEmbed ?: FloatArray(TALKER_DIM)
    val padE = ttsPadEmbed ?: FloatArray(TALKER_DIM)
    var totalTalkerMs = 0L; var totalCpMs = 0L
    val forceGreedyCb0Onnx = diagFlag("force_greedy_cb0")
    nlog("ONNX voice-cloned=$isVoiceClonedOnnx, maxGen=$maxGenOnnx, greedy=$forceGreedyCb0Onnx")

    for (genStep in 0 until maxGenOnnx) {
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0

        // pastHidden is non-null here: reaching this loop requires a completed prefill step.
        val tCp = System.currentTimeMillis()
        val cpCodes = runCodePredictorInterleaved(pastHidden!!, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)
        seenCb0.add(currentCb0)

        if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // Voice-cloned: feed captured embed directly. Text-only: reconstruct codec_sum + trailing.
        val nextEmbed = if (isVoiceClonedOnnx) {
            embeds[nPrefill + genStep]
        } else {
            val codecSum = FloatArray(TALKER_DIM)
            addEmb(codecSum, codecEmb(codes[0]))
            for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
            // Text side: remaining trailing embeds, then EOS exactly once, then padding.
            val textE: FloatArray = when {
                trailingIdxOnnx < nTrailingOnnx -> embeds[nPrefill + trailingIdxOnnx].also { trailingIdxOnnx++ }
                trailingIdxOnnx == nTrailingOnnx -> { trailingIdxOnnx++; eosE }
                else -> padE
            }
            sumEmb(codecSum, textE)
        }
        val absPos = nPrefill + genStep
        if (absPos < MAX_CONTEXT) maskData[MAX_CONTEXT - 1 - absPos] = 0f

        val tTalker = System.currentTimeMillis()
        val res = runTalkerStepMRoPE(env, session, nextEmbed, maskData, absPos, kCaches, vCaches)
        totalTalkerMs += System.currentTimeMillis() - tTalker
        kCaches = res.newK; vCaches = res.newV; pastHidden = res.hidden

        val logits = res.logits
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
        // Repetition penalty (factor 1.05) on every cb0 value already emitted.
        for (tok in seenCb0) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
        val nextCb0 = if (forceGreedyCb0Onnx) {
            // Argmax; the first index wins on ties.
            var best = 0; var bestVal = logits[0]
            for (j in 1 until logits.size) if (logits[j] > bestVal) { bestVal = logits[j]; best = j }
            best
        } else sampleTopK(logits, 0.9f, 50)

        // NOTE(review): in replay mode EOS is not treated as a stop, so a sampled
        // CODEC_EOS is kept as the next cb0 and ends up in allCodes — confirm that
        // decodeChunked tolerates a cb0 outside 0 until CODEBOOK_SIZE.
        if (!isVoiceClonedOnnx) {
            if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break }
            if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
                nlog("Degeneration at step ${genStep + 2}"); break
            }
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker: ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Diagnostic: dump our generated codes for comparison with Python capture
    try {
        val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin"
        val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN)
        buf.putInt(n); buf.putInt(NUM_CODEBOOKS)
        for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb])
        // `use` closes the stream even when the write fails.
        java.io.FileOutputStream(codesPath).use { it.write(buf.array()) }
        nlog("Tablet codes dumped (ONNX path): $codesPath")
    } catch (e: Exception) { nlog("Codes dump failed: ${e.message}") }

    if (n == 0) return ShortArray(0)

    // Decode: transpose to codebook-major and pad with code 0 up to SEQ_LEN frames.
    val padLen = maxOf(n, SEQ_LEN)
    val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t -> if (t < n) allCodes[t][cb] else 0 }
    }
    val audio = decodeChunked(allCodebooks, n)
    nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s audio")

    // Best-effort WAV dump (44-byte PCM header + samples).
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_ONNX.wav"
        val dataLen = audio.size * 2
        val wav = ByteBuffer.allocate(44 + dataLen).order(ByteOrder.LITTLE_ENDIAN)
        wav.put("RIFF".toByteArray()); wav.putInt(36 + dataLen)
        wav.put("WAVE".toByteArray()); wav.put("fmt ".toByteArray())
        wav.putInt(16); wav.putShort(1); wav.putShort(1)
        wav.putInt(SR); wav.putInt(SR * 2); wav.putShort(2); wav.putShort(16)
        wav.put("data".toByteArray()); wav.putInt(dataLen)
        for (s in audio) wav.putShort(s)
        java.io.FileOutputStream(wavPath).use { it.write(wav.array()) }
        nlog("WAV saved (ONNX): $wavPath (${audio.size} samples)")
    } catch (e: Exception) { nlog("WAV save (ONNX) failed: ${e.message}") }

    return audio
}
|
||
|
||
/**
 * Full pipeline using .pte JNI talker + CP on NPU from pre-computed embeddings.
 *
 * Accepts two little-endian file layouts:
 *  - single-segment: int32 nPrefill, int32 nTotal, then nTotal × TALKER_DIM floats;
 *  - multi-segment: int32 nSegments, then one single-segment record per segment.
 *
 * A multi-segment file is handed to [generateMultiSegment] when both PTE modules are
 * loaded; otherwise only its first segment is processed here. Generation runs through
 * the native pipeline unless the `capture_mode` marker file exists or a PTE module is
 * missing, in which case [runInterleavedPteFromEmbeds] is used.
 *
 * Side effect: writes the trimmed audio to `/data/local/tmp/kazeia/kazeia_PTE_NPU.wav`
 * (best effort, failures are only logged).
 *
 * @param embedsPath path of the embeddings file.
 * @return trimmed 16-bit PCM, or an empty array when the file is malformed or no
 *         tokens were generated.
 */
private fun generateFromEmbedsPte(embedsPath: String): ShortArray {
    nlog("Full pipeline (PTE) from: $embedsPath")
    val t0 = System.currentTimeMillis()
    val bytes = File(embedsPath).readBytes()
    if (bytes.size < 8) {
        nlog("generateFromEmbedsPte: embeds file too small (${bytes.size} bytes)")
        return ShortArray(0)
    }
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)

    // Detect format by comparing the file size to what each layout would predict.
    //   single-segment layout: 8 + nTotal*1024*4 bytes (int nPrefill, int nTotal, embeds)
    //   multi-segment layout: n_segments * (8 + nTotal_i*1024*4) + 4 bytes header
    // For a single-segment file, `firstInt == nPrefill` and
    // `8 + secondInt*1024*4 == file.length`. If that equation holds, it's single-segment;
    // otherwise we treat firstInt as n_segments.
    val firstInt = bb.int
    val secondInt = bb.int
    val fileLen = bytes.size.toLong()
    val singleSegmentSize = 8L + secondInt.toLong() * TALKER_DIM * 4
    val isSingleSegment = secondInt > 0 && secondInt < 100000 && fileLen == singleSegmentSize
    val isMultiSegment = !isSingleSegment && firstInt in 1..100

    val talkerMod = talkerPteModule
    // Rewind to just after the first int: the multi-segment reader expects the buffer
    // positioned on the first segment record.
    bb.position(4)
    if (isMultiSegment && talkerMod != null && cpPteModule != null) {
        return generateMultiSegment(bb, firstInt, t0)
    }

    // Legacy single-segment format. For a multi-segment file without both modules, the
    // first segment is read instead: its nPrefill is the second int of the file.
    val nPrefill = if (isMultiSegment) { bb.position(8); secondInt } else firstInt
    if (bb.remaining() < 4) {
        nlog("generateFromEmbedsPte: truncated embeds header (${bytes.size} bytes)")
        return ShortArray(0)
    }
    val nTotal = bb.int
    // Reject headers that disagree with the payload instead of failing mid-read.
    val payloadBytes = nTotal.toLong() * TALKER_DIM * 4
    if (nPrefill < 0 || nTotal < nPrefill || payloadBytes > bb.remaining()) {
        nlog("generateFromEmbedsPte: malformed embeds file (nPrefill=$nPrefill, nTotal=$nTotal, bytes=${bytes.size})")
        return ShortArray(0)
    }
    val floatView = bb.asFloatBuffer()
    val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { row -> floatView.get(row) } }
    nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")

    // Check if capture mode requested (force Java path to capture embeds)
    val captureMode = File("/data/local/tmp/kazeia/capture_mode").exists()
    // Native C++ pipeline using SAME Java Module instances (no quality loss)
    val allCodes: Array<IntArray> = if (!captureMode && talkerMod != null && cpPteModule != null) {
        // C++ loop on Java's Module instances — same QNN compilation, no JNI overhead
        val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
        for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
        val nTrailing = nTotal - nPrefill
        val trailingFlat = if (nTrailing > 0) FloatArray(nTrailing * TALKER_DIM).also { arr ->
            for (i in 0 until nTrailing) System.arraycopy(embeds[nPrefill + i], 0, arr, i * TALKER_DIM, TALKER_DIM)
        } else null

        // Load CP heads if not already
        if (cpAllHeads == null) {
            val headsFile = java.io.File("/data/local/tmp/kazeia/models/cp_heads.bin")
            if (headsFile.exists()) {
                val hb = headsFile.readBytes()
                cpAllHeads = FloatArray(hb.size / 4)
                ByteBuffer.wrap(hb).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(cpAllHeads!!)
            }
        }

        // Ensure all data arrays are loaded
        val mpath = "/data/local/tmp/kazeia/models/qwen3-tts-npu"
        if (codecEmbedding == null) codecEmbedding = loadNpy("$mpath/codec_embedding.npy")
        if (cpEmbeddings == null) cpEmbeddings = loadNpy("$mpath/code_predictor_embeddings.npy")
        if (cpRotaryCos == null) cpRotaryCos = loadNpy("$mpath/cp_kv_v2/cp_rotary_cos.npy")
        if (cpRotarySin == null) cpRotarySin = loadNpy("$mpath/cp_kv_v2/cp_rotary_sin.npy")
        // NOTE(review): generateMultiSegment loads these two tables from
        // "$mpath/talker_pte_rotary_*.npy" while this path reads the parent models
        // directory — confirm which location is correct.
        if (talkerPteRotaryCos == null) talkerPteRotaryCos = loadNpy("/data/local/tmp/kazeia/models/talker_pte_rotary_cos.npy")
        if (talkerPteRotarySin == null) talkerPteRotarySin = loadNpy("/data/local/tmp/kazeia/models/talker_pte_rotary_sin.npy")
        if (ttsEosEmbed == null || ttsPadEmbed == null) {
            val sp = loadNpy("$mpath/tts_special_embeds.npy")
            // loadNpy returns an empty array for a missing file; slicing it would throw.
            if (sp.size >= 3 * TALKER_DIM) {
                ttsBosEmbed = sp.sliceArray(0 until TALKER_DIM)
                ttsEosEmbed = sp.sliceArray(TALKER_DIM until 2 * TALKER_DIM)
                ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
            } else {
                nlog("WARN: tts_special_embeds.npy has ${sp.size} floats, expected ${3 * TALKER_DIM}; using zero EOS/PAD embeds")
            }
        }

        // Voice-cloned embeds have one entry per audio frame (already captured via
        // Python's natural EOS). Heuristic: if nTrailing is large enough to be a
        // per-frame capture (≥ 30), trust it as the exact length; otherwise it's a
        // text-only prepare_tts_native.py file that needs the runway/boost budget.
        val isVoiceCloned = nTrailing >= 30
        val maxGen = if (isVoiceCloned) nTrailing
                     else minOf(nTrailing * 4 + 20, 90 - nPrefill)
        nlog("Running native C++ pipeline (shared Module), maxGen=$maxGen, voiceCloned=$isVoiceCloned...")
        val flat = talkerMod.nativeRunTtsPipeline(
            prefillFlat, nPrefill,
            trailingFlat, nTrailing,
            codecEmbedding ?: FloatArray(0),
            cpEmbeddings ?: FloatArray(0),
            cpAllHeads ?: FloatArray(0),
            talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
            cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
            ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
            maxGen
        )
        if (flat == null || flat.isEmpty()) return ShortArray(0)
        // The native result is a flat, step-major list of NUM_CODEBOOKS codes per frame.
        val nTokens = flat.size / NUM_CODEBOOKS
        val nativeCodes = Array(nTokens) { t -> IntArray(NUM_CODEBOOKS) { cb -> flat[t * NUM_CODEBOOKS + cb] } }
        nlog("Native pipeline: $nTokens tokens")
        nativeCodes
    } else {
        // Fallback: Java pipeline
        val prefillEmbeds = embeds.sliceArray(0 until nPrefill).toList()
        val trailingEmbeds = if (nPrefill < nTotal) embeds.sliceArray(nPrefill until nTotal).toList() else emptyList()
        runInterleavedPteFromEmbeds(prefillEmbeds, trailingEmbeds, nTotal - nPrefill)
    }

    if (allCodes.isEmpty()) return ShortArray(0)
    // Transpose to codebook-major and pad with code 0 up to SEQ_LEN frames.
    val numRealTokens = allCodes.size
    val padLen = maxOf(numRealTokens, SEQ_LEN)
    val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t -> if (t < numRealTokens) allCodes[t][cb] else 0 }
    }

    val t3 = System.currentTimeMillis()
    val rawAudio = decodeChunked(allCodebooks, numRealTokens)
    nlog("Decode: ${System.currentTimeMillis() - t3}ms")

    // Trim trailing noise/silence: scan from end, find last loud frame
    val audio = trimTrailingSilence(rawAudio)
    nlog("Trimmed: ${rawAudio.size} → ${audio.size} samples (${(rawAudio.size-audio.size)/SR.toFloat()}s removed)")

    val totalMs = System.currentTimeMillis() - t0
    val audioDur = audio.size.toFloat() / SR
    nlog("Total: ${totalMs}ms for ${audioDur}s")

    // Save WAV file for validation (44-byte PCM header + samples).
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_PTE_NPU.wav"
        val dataLen = audio.size * 2
        val wav = ByteBuffer.allocate(44 + dataLen).order(ByteOrder.LITTLE_ENDIAN)
        wav.put("RIFF".toByteArray()); wav.putInt(36 + dataLen)
        wav.put("WAVE".toByteArray()); wav.put("fmt ".toByteArray())
        wav.putInt(16); wav.putShort(1); wav.putShort(1)
        wav.putInt(SR); wav.putInt(SR * 2); wav.putShort(2); wav.putShort(16)
        wav.put("data".toByteArray()); wav.putInt(dataLen)
        for (s in audio) wav.putShort(s)
        // `use` closes the stream even when the write fails.
        java.io.FileOutputStream(wavPath).use { it.write(wav.array()) }
        nlog("WAV saved: $wavPath (${audio.size} samples)")
    } catch (e: Exception) {
        nlog("WAV save failed: ${e.message}")
    }

    return audio
}
|
||
|
||
/**
 * PTE pipeline from pre-computed embeddings (prefill + trailing).
 *
 * Runs the talker .pte module over every prefill embed, samples the first cb0, then
 * alternates code predictor and talker steps for up to [maxGenTokens] frames. Trailing
 * embeds are fed to the talker verbatim while they last; after that the next input is
 * rebuilt as the sum of the frame's codec embeddings plus the TTS pad embed.
 *
 * Every embed fed to the talker is also captured and written, best effort, to
 * `/data/local/tmp/kazeia/captured_embeds.bin` (int32 nPrefill, int32 count, then
 * count × TALKER_DIM little-endian floats).
 *
 * @param prefillEmbeds embeds consumed before any code is generated.
 * @param trailingEmbeds pre-computed decode-step embeds, used in order.
 * @param maxGenTokens upper bound on generated frames.
 * @return one IntArray of NUM_CODEBOOKS codes per generated frame; empty when a
 *         required module or table is missing, or when prefill yields no usable cb0.
 */
private fun runInterleavedPteFromEmbeds(
    prefillEmbeds: List<FloatArray>, trailingEmbeds: List<FloatArray>, maxGenTokens: Int
): Array<IntArray> {
    val talkerMod = talkerPteModule ?: return emptyArray()
    val cpMod = cpPteModule ?: return emptyArray()
    val tCos = talkerPteRotaryCos ?: return emptyArray()
    val tSin = talkerPteRotarySin ?: return emptyArray()
    // The EOS embed is not used below, but its presence is still a precondition.
    if (ttsEosEmbed == null) return emptyArray()
    val padE = ttsPadEmbed ?: return emptyArray()

    // A missing .npy loads as an empty array; without this guard the position clamp in
    // talkerStep would become -1 and arraycopy would throw.
    val rotaryPositions = minOf(tCos.size, tSin.size) / TALKER_HEAD_DIM
    if (rotaryPositions <= 0) {
        nlog("runInterleavedPteFromEmbeds: talker rotary tables are empty")
        return emptyArray()
    }

    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    // Distinct cb0 values emitted so far; kept incrementally for the repetition penalty.
    val seenCb0 = HashSet<Int>()

    val tkvSize = TALKER_HEADS * TALKER_PTE_KV_LEN * TALKER_HEAD_DIM
    val tK = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    val tV = Array(TALKER_LAYERS) { FloatArray(tkvSize) }
    // Additive attention mask, opened one slot per consumed position starting from the end.
    val maskData = FloatArray(TALKER_PTE_KV_LEN) { -1e9f }

    var pos = 0; var currentCb0 = -1; var pastHidden: FloatArray? = null
    var trailingIdx = 0

    // Wrap a float array as an ExecuTorch input value with the given shape.
    fun evalue(data: FloatArray, vararg shape: Long) =
        org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(data, shape))

    // Helper to run one talker step; returns (hidden, logits) and advances `pos`.
    fun talkerStep(emb: FloatArray): Pair<FloatArray, FloatArray> {
        val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
        if (maskIdx >= 0) maskData[maskIdx] = 0f

        // Positions past the end of the rotary table reuse its last row.
        val posIdx = minOf(pos, rotaryPositions - 1)
        val cosSlice = tCos.copyOfRange(posIdx * TALKER_HEAD_DIM, (posIdx + 1) * TALKER_HEAD_DIM)
        val sinSlice = tSin.copyOfRange(posIdx * TALKER_HEAD_DIM, (posIdx + 1) * TALKER_HEAD_DIM)

        val kvShape = longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong())
        val inputs = mutableListOf(
            evalue(emb, 1, 1, TALKER_DIM.toLong()),
            evalue(maskData.clone(), 1, 1, 1, TALKER_PTE_KV_LEN.toLong()),
            evalue(cosSlice, 1, 1, TALKER_HEAD_DIM.toLong()),
            evalue(sinSlice, 1, 1, TALKER_HEAD_DIM.toLong())
        )
        for (i in 0 until TALKER_LAYERS) {
            inputs.add(evalue(tK[i], *kvShape))
            inputs.add(evalue(tV[i], *kvShape))
        }

        // Output layout: [hidden, logits, K0, V0, K1, V1, ...].
        val out = talkerMod.forward(*inputs.toTypedArray())
        val hidden = out[0].toTensor().dataAsFloatArray
        val logits = out[1].toTensor().dataAsFloatArray
        for (i in 0 until TALKER_LAYERS) {
            tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
            tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
        }
        pos++
        return Pair(hidden, logits)
    }

    // Capture embeds for C++ reuse
    val capturedEmbeds = mutableListOf<FloatArray>()

    // ===== PREFILL =====
    val tPrefill = System.currentTimeMillis()
    for (step in prefillEmbeds.indices) {
        capturedEmbeds.add(prefillEmbeds[step].clone())
        val (h, logits) = talkerStep(prefillEmbeds[step])
        pastHidden = h
        if (step == prefillEmbeds.size - 1) {
            // Only codebook entries and EOS may be sampled.
            for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
            currentCb0 = sampleTopK(logits, 0.9f, 50)
        }
    }
    nlog("Prefill (PTE): ${System.currentTimeMillis() - tPrefill}ms, ${prefillEmbeds.size} steps, cb0=$currentCb0")
    if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()

    // ===== GENERATION =====
    var totalTalkerMs = 0L; var totalCpMs = 0L
    for (genStep in 0 until maxGenTokens) {
        val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0

        // pastHidden is non-null here: reaching this loop requires a completed prefill step.
        val tCp0 = System.currentTimeMillis()
        val cpCodes = runCpPte(pastHidden!!, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp0
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes); generatedCb0.add(currentCb0); seenCb0.add(currentCb0)

        if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // Next talker input: use pre-computed decode embed if available
        val nextEmbed: FloatArray = if (trailingIdx < trailingEmbeds.size) {
            trailingEmbeds[trailingIdx++]
        } else {
            val codecSum = FloatArray(TALKER_DIM)
            addEmb(codecSum, codecEmb(codes[0]))
            for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
            sumEmb(codecSum, padE)
        }
        capturedEmbeds.add(nextEmbed.clone())

        val tTalker0 = System.currentTimeMillis()
        val (h, logits) = talkerStep(nextEmbed)
        totalTalkerMs += System.currentTimeMillis() - tTalker0
        pastHidden = h

        for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
        // Repetition penalty (factor 1.05) on every cb0 value already emitted.
        for (tok in seenCb0) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
        val nextCb0 = sampleTopK(logits, 0.9f, 50)

        if (nextCb0 == CODEC_EOS) { nlog("EOS at step ${genStep + 2}"); break }
        // Stop when the candidate would extend a run of nine identical cb0 values.
        if (generatedCb0.size >= 9 && generatedCb0.takeLast(9).all { it == nextCb0 }) {
            nlog("Degeneration at step ${genStep + 2}"); break
        }
        currentCb0 = nextCb0
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker(PTE): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP(PTE): ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Save captured embeds
    if (capturedEmbeds.isNotEmpty()) {
        try {
            val capPath = "/data/local/tmp/kazeia/captured_embeds.bin"
            val nPrefill = prefillEmbeds.size
            // `use` closes the stream even when a write fails part-way through.
            java.io.FileOutputStream(capPath).use { fos ->
                val hdr = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                hdr.putInt(nPrefill); hdr.putInt(capturedEmbeds.size)
                fos.write(hdr.array())
                for (emb in capturedEmbeds) {
                    val buf = java.nio.ByteBuffer.allocate(TALKER_DIM * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN)
                    for (v in emb) buf.putFloat(v)
                    fos.write(buf.array())
                }
            }
            nlog("Captured ${capturedEmbeds.size} embeds → $capPath (${nPrefill} prefill + ${capturedEmbeds.size - nPrefill} decode)")
        } catch (e: Exception) { nlog("Capture save failed: ${e.message}") }
    }

    return allCodes.toTypedArray()
}
|
||
|
||
/** Multi-segment pipeline: process each segment independently, concatenate audio. */
|
||
private fun generateMultiSegment(bb: ByteBuffer, nSegments: Int, t0: Long): ShortArray {
|
||
nlog("Multi-segment: $nSegments segments")
|
||
val allAudio = mutableListOf<ShortArray>()
|
||
|
||
// Ensure data arrays loaded
|
||
val mpath = "/data/local/tmp/kazeia/models/qwen3-tts-npu"
|
||
if (cpAllHeads == null) {
|
||
val hf = java.io.File("/data/local/tmp/kazeia/models/cp_heads.bin")
|
||
if (hf.exists()) { val hb = hf.readBytes(); cpAllHeads = FloatArray(hb.size/4); ByteBuffer.wrap(hb).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(cpAllHeads!!) }
|
||
}
|
||
if (codecEmbedding == null) codecEmbedding = loadNpy("$mpath/codec_embedding.npy")
|
||
if (cpEmbeddings == null) cpEmbeddings = loadNpy("$mpath/code_predictor_embeddings.npy")
|
||
if (cpRotaryCos == null) cpRotaryCos = loadNpy("$mpath/cp_kv_v2/cp_rotary_cos.npy")
|
||
if (cpRotarySin == null) cpRotarySin = loadNpy("$mpath/cp_kv_v2/cp_rotary_sin.npy")
|
||
if (talkerPteRotaryCos == null) talkerPteRotaryCos = loadNpy("$mpath/talker_pte_rotary_cos.npy")
|
||
if (talkerPteRotarySin == null) talkerPteRotarySin = loadNpy("$mpath/talker_pte_rotary_sin.npy")
|
||
if (ttsEosEmbed == null) { val sp = loadNpy("$mpath/tts_special_embeds.npy"); ttsBosEmbed=sp.sliceArray(0 until TALKER_DIM); ttsEosEmbed=sp.sliceArray(TALKER_DIM until 2*TALKER_DIM); ttsPadEmbed=sp.sliceArray(2*TALKER_DIM until 3*TALKER_DIM) }
|
||
|
||
for (seg in 0 until nSegments) {
|
||
val nPrefill = bb.int; val nTotal = bb.int
|
||
val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
|
||
nlog("Segment ${seg+1}/$nSegments: $nPrefill prefill + ${nTotal-nPrefill} decode")
|
||
|
||
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
|
||
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
|
||
val nTrailing = nTotal - nPrefill
|
||
val trailingFlat = if (nTrailing > 0) FloatArray(nTrailing * TALKER_DIM).also { arr ->
|
||
for (i in 0 until nTrailing) System.arraycopy(embeds[nPrefill+i], 0, arr, i*TALKER_DIM, TALKER_DIM)
|
||
} else null
|
||
|
||
val maxGen = minOf(nTrailing * 4 + 20, 90 - nPrefill)
|
||
val flat = talkerPteModule!!.nativeRunTtsPipeline(
|
||
prefillFlat, nPrefill, trailingFlat, nTrailing,
|
||
codecEmbedding ?: FloatArray(0), cpEmbeddings ?: FloatArray(0), cpAllHeads ?: FloatArray(0),
|
||
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
||
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
||
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
|
||
maxGen
|
||
)
|
||
if (flat == null || flat.isEmpty()) continue
|
||
val nTokens = flat.size / NUM_CODEBOOKS
|
||
val segCodes = Array(nTokens) { t -> IntArray(NUM_CODEBOOKS) { cb -> flat[t * NUM_CODEBOOKS + cb] } }
|
||
nlog(" → $nTokens tokens generated")
|
||
|
||
val padLen = maxOf(nTokens, SEQ_LEN)
|
||
val codebooks = Array(NUM_CODEBOOKS) { cb -> IntArray(padLen) { t -> if (t < nTokens) segCodes[t][cb] else 0 } }
|
||
val segAudio = decodeChunked(codebooks, nTokens)
|
||
allAudio.add(segAudio)
|
||
nlog(" → ${segAudio.size/SR.toFloat()}s audio decoded")
|
||
}
|
||
|
||
// The fp16 NPU drift means each segment ends with a "filler" tail (page page beg).
|
||
// We detect the boundary on the audio itself: scan windows from the end, find the
|
||
// last sustained "speech-energy" window vs the trailing low-energy filler. Use
|
||
// the running peak energy as a per-segment reference so quiet vs loud voices
|
||
// both work.
|
||
val gapMs = 120
|
||
val gapSamples = SR * gapMs / 1000
|
||
val gap = ShortArray(gapSamples)
|
||
|
||
/**
 * Cuts the trailing low-energy part off one decoded segment.
 *
 * The clip is split into 40 ms windows and the RMS of each full window is
 * computed. The loudest window in the first 70 % of the clip sets the
 * reference level; scanning backwards, the first window that ends a run of
 * three windows at or above 35 % of that level is taken as the end of speech.
 * Up to two further windows are kept after it.
 *
 * Clips shorter than half a second, or with fewer than six full windows, are
 * returned unchanged (same array instance). Otherwise a copy is returned whose
 * length is a whole number of windows.
 */
fun trimSegmentTail(audio: ShortArray): ShortArray {
    if (audio.size < SR / 2) return audio

    val windowLen = SR * 40 / 1000 // 40 ms analysis window
    val windowCount = audio.size / windowLen
    if (windowCount < 6) return audio

    // RMS of every full window; samples are scaled by 1/32768 first.
    val energy = FloatArray(windowCount) { w ->
        val start = w * windowLen
        var sumSq = 0.0
        for (i in start until start + windowLen) {
            val x = audio[i].toFloat() / 32768f
            sumSq += x * x
        }
        kotlin.math.sqrt(sumSq / windowLen).toFloat()
    }

    // Reference level: loudest window within the first 70 % of the clip.
    val refEnd = (windowCount * 7) / 10
    var peak = 0f
    for (w in 0 until refEnd) {
        if (energy[w] > peak) peak = energy[w]
    }
    val threshold = peak * 0.35f

    // Last window that closes three consecutive windows at/above the
    // threshold; when no such run exists, nothing is considered filler.
    val lastSpeech = (windowCount - 1 downTo 2).firstOrNull { w ->
        energy[w] >= threshold && energy[w - 1] >= threshold && energy[w - 2] >= threshold
    } ?: (windowCount - 1)

    // Keep up to two more windows so the last syllable can decay.
    val lastKept = minOf(lastSpeech + 2, windowCount - 1)
    return audio.copyOf((lastKept + 1) * windowLen)
}
|
||
|
||
val trimmedAudio = allAudio.map { trimSegmentTail(it) }
|
||
for ((i, seg) in trimmedAudio.withIndex()) {
|
||
nlog(" Trim seg ${i+1}: ${allAudio[i].size/SR.toFloat()}s -> ${seg.size/SR.toFloat()}s")
|
||
}
|
||
|
||
val totalSamples = trimmedAudio.sumOf { it.size } + (trimmedAudio.size - 1) * gapSamples
|
||
val result = ShortArray(totalSamples)
|
||
var offset = 0
|
||
for ((i, seg) in trimmedAudio.withIndex()) {
|
||
System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size
|
||
if (i < trimmedAudio.size - 1) { System.arraycopy(gap, 0, result, offset, gapSamples); offset += gapSamples }
|
||
}
|
||
|
||
val totalMs = System.currentTimeMillis() - t0
|
||
nlog("Total: ${totalMs}ms for ${totalSamples/SR.toFloat()}s ($nSegments segments)")
|
||
|
||
// Save WAV
|
||
try {
|
||
val wavPath = "/data/local/tmp/kazeia/kazeia_PTE_NPU.wav"
|
||
val fos = java.io.FileOutputStream(wavPath)
|
||
val dataLen = result.size * 2
|
||
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
|
||
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
|
||
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
|
||
header.putInt(16); header.putShort(1); header.putShort(1)
|
||
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
|
||
header.put("data".toByteArray()); header.putInt(dataLen)
|
||
fos.write(header.array())
|
||
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
|
||
for (s in result) buf.putShort(s)
|
||
fos.write(buf.array()); fos.close()
|
||
nlog("WAV saved: $wavPath ($totalSamples samples)")
|
||
} catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
|
||
|
||
return result
|
||
}
|
||
|
||
/**
 * Intentionally a no-op: the input array is handed back untouched.
 *
 * Per the original author's note, the garbage generated after the text has
 * high energy and cannot be told apart from speech, so output length is
 * bounded through the token budget (maxTokens = trailing count) instead.
 */
private fun trimTrailingSilence(audio: ShortArray): ShortArray {
    return audio
}
|
||
|
||
/**
 * Full pipeline using Hexagon talker + Hexagon CP from pre-computed embeddings.
 *
 * Reads a single-segment embeds file (`<i32 nPrefill> <i32 nTotal>` followed by
 * `nTotal × TALKER_DIM` little-endian floats), prefills the talker with the
 * first `nPrefill` embeds, then feeds the remaining captured embeds verbatim,
 * one per decode step. Each step produces one frame of [NUM_CODEBOOKS] codes
 * (cb0 from the talker logits, cb1.. from the code predictor).
 *
 * Best-effort side effects (failures are logged, never thrown):
 *  - raw generated codes are dumped to `/data/local/tmp/kazeia/tablet_codes.bin`
 *  - decoded audio is written to `/data/local/tmp/kazeia/kazeia_HEX.wav`
 *
 * @param embedsPath path of the embeds file to replay.
 * @return the audio returned by [decodeChunked], or an empty array when the
 *   prefill produced no result or no step was generated.
 */
private fun generateFromEmbedsHexagon(embedsPath: String): ShortArray {
    nlog("Full pipeline (Hexagon) from: $embedsPath")
    val t0 = System.currentTimeMillis()

    val bytes = File(embedsPath).readBytes()
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val nPrefill = bb.int
    val nTotal = bb.int
    val embeds = Array(nTotal) { FloatArray(TALKER_DIM) { bb.float } }
    nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")

    hexReset()

    val allCodes = mutableListOf<IntArray>()
    val generatedCb0 = mutableListOf<Int>()
    var totalTalkerMs = 0L
    var totalCpMs = 0L

    // Prefill
    val tPrefill = System.currentTimeMillis()
    val prefillResults = hexForward(embeds.take(nPrefill))
    nlog("Prefill (Hex): ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")

    if (prefillResults.isEmpty()) return ShortArray(0)
    var pastHidden = prefillResults.last().first
    val prefillLogits = prefillResults.last().second
    // Above the codebook range only CODEC_EOS stays selectable.
    for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
        if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY
    }

    // Diagnostic-only cb0 sampling controls (DEBUG builds only).
    // Empirical finding (thesis-relevant): temperature and greedy vs stochastic
    // give similar perceptual audio quality because the cross-architecture logits
    // drift (x86 AVX ↔ ARM64 HMX fp16) is the dominant error source, not sampling.
    // Production always uses temp=0.9 stochastic matching Python's default.
    val cb0Temp: Float = diagFile("cb0_temp")?.let {
        try { it.readText().trim().toFloat() } catch (e: Exception) { 0.9f }
    } ?: 0.9f
    if (cb0Temp != 0.9f) nlog("cb0_temp override: $cb0Temp")

    val forceGreedyCb0 = diagFlag("force_greedy_cb0")
    if (forceGreedyCb0) nlog("force_greedy_cb0: using argmax for cb0 sampling (incl. prefill)")

    // The greedy/temperature choice also applies to the prefill sample, so the
    // very first frame is deterministic when greedy decoding is requested.
    var currentCb0 = if (forceGreedyCb0) {
        var best = 0
        var bestValue = prefillLogits[0]
        for (j in 1 until prefillLogits.size) {
            if (prefillLogits[j] > bestValue) { bestValue = prefillLogits[j]; best = j }
        }
        best
    } else {
        sampleTopK(prefillLogits, cb0Temp, 50)
    }
    nlog("Prefill cb0=$currentCb0 (greedy=$forceGreedyCb0)")

    // Voice-cloned single-file replay: the nTotal-nPrefill captured decode embeds
    // already contain codec_sum_python + text_hidden for each step, so they are
    // fed directly to the talker (no re-sum from locally generated codes). The
    // number of decode steps is fixed by the capture.
    val nTrailingHex = nTotal - nPrefill
    val isVoiceCloned = nTrailingHex >= 30
    nlog("Hex voice-cloned mode: $isVoiceCloned ($nTrailingHex decode steps)")

    // Diagnostic (DEBUG only): inject Python-captured cb0 at each step so that
    // any remaining cb1.. divergence can be attributed to the code predictor.
    val injectPyCb0 = diagFlag("force_inject_pycb0")
    val pyCodesFile = File("/data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin")
    val pyCb0: IntArray? = when {
        !injectPyCb0 -> null
        pyCodesFile.exists() -> {
            // Layout: <i32 steps> <i32 codebooks> then steps × codebooks i32 codes.
            val pb = ByteBuffer.wrap(pyCodesFile.readBytes()).order(ByteOrder.LITTLE_ENDIAN)
            val nSteps = pb.int
            val nCb = pb.int
            val loaded = IntArray(nSteps) {
                val cb0 = pb.int
                for (j in 1 until nCb) pb.int // skip cb1..cb(nCb-1)
                cb0
            }
            nlog("injectPyCb0: loaded ${loaded.size} cb0 values (cb=$nCb) from python_codes.bin")
            loaded
        }
        else -> {
            nlog("injectPyCb0: flag set but python_codes.bin not found at ${pyCodesFile.absolutePath}")
            null
        }
    }

    // When injecting, step 0 must also use the captured value; otherwise it
    // would come from the local prefill sample and not match.
    if (pyCb0 != null && pyCb0.isNotEmpty()) {
        currentCb0 = pyCb0[0]
        nlog("injectPyCb0: overriding step 0 cb0 -> ${pyCb0[0]}")
    }

    for (genStep in 0 until nTrailingHex) {
        val codes = IntArray(NUM_CODEBOOKS)
        codes[0] = currentCb0
        val tCp = System.currentTimeMillis()
        val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
        totalCpMs += System.currentTimeMillis() - tCp
        for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
        allCodes.add(codes)
        generatedCb0.add(currentCb0)

        if (genStep < 3 || genStep % 20 == 0) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")

        // Feed the captured decode embed verbatim.
        val nextEmbed = embeds[nPrefill + genStep]

        val tT = System.currentTimeMillis()
        val results = hexForward(listOf(nextEmbed))
        totalTalkerMs += System.currentTimeMillis() - tT
        if (results.isEmpty()) { nlog("Hex forward returned empty at step ${genStep+1}"); break }
        pastHidden = results[0].first
        val logits = results[0].second
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
            if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY
        }
        // Repetition penalty (factor 1.05) on every cb0 value emitted so far.
        for (tok in generatedCb0.toSet()) {
            logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f
        }

        // Next cb0: captured value when injecting (and available), else
        // argmax when greedy is forced, else top-k sampling.
        val nextIdx = genStep + 1
        currentCb0 = if (pyCb0 != null && nextIdx < pyCb0.size) {
            pyCb0[nextIdx]
        } else if (forceGreedyCb0) {
            var best = 0
            var bestValue = logits[0]
            for (j in 1 until logits.size) {
                if (logits[j] > bestValue) { bestValue = logits[j]; best = j }
            }
            best
        } else {
            sampleTopK(logits, cb0Temp, 50)
        }
    }

    val n = allCodes.size
    nlog("Generated $n tokens | Talker(HEX): ${totalTalkerMs}ms (${totalTalkerMs/maxOf(n,1)}ms/step) | CP: ${totalCpMs}ms (${totalCpMs/maxOf(n,1)}ms/step)")

    // Stop hexagon runners before decode ONLY if decoder uses HTP (DSP conflict)
    if (!decoderOnCpu && !decoderOnGpu) {
        hexStopRunner()
    }

    if (n == 0) return ShortArray(0)
    val padLen = maxOf(n, SEQ_LEN)
    // cb0 may be CODEC_EOS (it is never masked and this loop has no EOS break),
    // which lies outside the codebook. Clamp any out-of-range code to 0 before
    // decoding, exactly as runHexSegmentFromEmbeds does. The diagnostic dump
    // below still records the raw, unclamped codes.
    val allCodebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t ->
            if (t < n) {
                val v = allCodes[t][cb]
                if (v in 0 until CODEBOOK_SIZE) v else 0
            } else {
                0
            }
        }
    }

    // Diagnostic: dump our generated codes for comparison with Python capture
    try {
        val codesPath = "/data/local/tmp/kazeia/tablet_codes.bin"
        val buf = ByteBuffer.allocate(8 + n * NUM_CODEBOOKS * 4).order(ByteOrder.LITTLE_ENDIAN)
        buf.putInt(n); buf.putInt(NUM_CODEBOOKS)
        for (t in 0 until n) for (cb in 0 until NUM_CODEBOOKS) buf.putInt(allCodes[t][cb])
        // use {} closes the stream even when write() throws.
        java.io.FileOutputStream(codesPath).use { fos -> fos.write(buf.array()) }
        nlog("Tablet codes dumped: $codesPath ($n steps × $NUM_CODEBOOKS)")
    } catch (e: Exception) { nlog("Codes dump failed: ${e.message}") }

    val audio = decodeChunked(allCodebooks, n)
    nlog("Total: ${System.currentTimeMillis() - t0}ms for ${audio.size.toFloat()/SR}s")

    // Save WAV for inspection (44-byte PCM header: mono, 16-bit, SR Hz).
    try {
        val wavPath = "/data/local/tmp/kazeia/kazeia_HEX.wav"
        val dataLen = audio.size * 2
        val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
        header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
        header.putInt(16); header.putShort(1); header.putShort(1)
        header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
        header.put("data".toByteArray()); header.putInt(dataLen)
        val pcm = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
        for (s in audio) pcm.putShort(s)
        java.io.FileOutputStream(wavPath).use { fos ->
            fos.write(header.array())
            fos.write(pcm.array())
        }
        nlog("WAV saved (Hex): $wavPath (${audio.size} samples)")
    } catch (e: Exception) { nlog("WAV save (Hex) failed: ${e.message}") }

    return audio
}
|
||
|
||
/**
 * Generates and decodes one segment on the Hexagon talker + code predictor.
 *
 * The talker is prefilled with [prefillEmbeds]; afterwards one frame of
 * [NUM_CODEBOOKS] codes is produced per entry of [trailingEmbeds], and that
 * entry is fed to the talker as the next input. cb0 is drawn with
 * `sampleTopK(…, 0.9f, 50)` after masking everything above the codebook range
 * except CODEC_EOS and applying a 1.05 repetition penalty to previously
 * emitted cb0 values.
 *
 * The caller owns KV-cache hygiene: hexReset() must have been called before
 * the first segment of a request and again between segments.
 *
 * @param prefillEmbeds talker inputs consumed in the prefill pass.
 * @param trailingEmbeds talker inputs for the decode steps, one per step.
 * @param segIdx zero-based segment index, used only in log lines.
 * @return decoded audio, or an empty array when the prefill returned nothing
 *   or no frame was generated.
 */
private fun runHexSegmentFromEmbeds(
    prefillEmbeds: List<FloatArray>,
    trailingEmbeds: List<FloatArray>,
    segIdx: Int = 0
): ShortArray {
    val frames = ArrayList<IntArray>()
    val cb0History = ArrayList<Int>()
    var talkerMs = 0L
    var cpMs = 0L

    // Prefill pass.
    val prefillStart = System.currentTimeMillis()
    val prefillOut = hexForward(prefillEmbeds)
    nlog("Seg ${segIdx+1} prefill: ${System.currentTimeMillis() - prefillStart}ms, ${prefillOut.size} steps")
    if (prefillOut.isEmpty()) return ShortArray(0)

    val lastPrefill = prefillOut.last()
    var hidden = lastPrefill.first
    val firstLogits = lastPrefill.second
    for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
        if (j == CODEC_EOS) continue
        firstLogits[j] = Float.NEGATIVE_INFINITY
    }
    var cb0 = sampleTopK(firstLogits, 0.9f, 50)

    for ((genStep, nextEmbed) in trailingEmbeds.withIndex()) {
        // Assemble this step's frame: cb0 from the talker, the rest from the CP.
        val frame = IntArray(NUM_CODEBOOKS)
        frame[0] = cb0
        val cpStart = System.currentTimeMillis()
        val residual = runCodePredictorInterleaved(hidden, cb0)
        cpMs += System.currentTimeMillis() - cpStart
        for (cb in 1 until NUM_CODEBOOKS) frame[cb] = residual[cb - 1]
        frames.add(frame)
        cb0History.add(cb0)

        // Advance the talker with the captured embed for this step.
        val talkerStart = System.currentTimeMillis()
        val stepOut = hexForward(listOf(nextEmbed))
        talkerMs += System.currentTimeMillis() - talkerStart
        if (stepOut.isEmpty()) {
            nlog("Seg ${segIdx+1}: hex empty at step ${genStep+1}")
            break
        }
        hidden = stepOut[0].first
        val logits = stepOut[0].second
        for (j in CODEBOOK_SIZE until TALKER_VOCAB) {
            if (j == CODEC_EOS) continue
            logits[j] = Float.NEGATIVE_INFINITY
        }
        // Repetition penalty on every distinct cb0 emitted so far.
        for (tok in cb0History.toSet()) {
            val value = logits[tok]
            logits[tok] = if (value > 0) value / 1.05f else value * 1.05f
        }
        cb0 = sampleTopK(logits, 0.9f, 50)
    }

    val n = frames.size
    nlog("Seg ${segIdx+1} generated $n tokens | Talker(HEX): ${talkerMs}ms | CP: ${cpMs}ms")
    if (n == 0) return ShortArray(0)

    // cb0 can land on CODEC_EOS, which is outside the codebook. Any code that
    // is not in [0, CODEBOOK_SIZE) is replaced by 0 so the decoder never
    // indexes past the codebook buffer. Padding positions are 0 as well.
    val padLen = maxOf(n, SEQ_LEN)
    val codebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t ->
            when {
                t >= n -> 0
                frames[t][cb] in 0 until CODEBOOK_SIZE -> frames[t][cb]
                else -> 0
            }
        }
    }
    return decodeChunked(codebooks, n)
}
|
||
|
||
/**
 * Streaming, multi-segment replay of an embeds file on the Hexagon talker.
 *
 * Segments are generated one after another with [runHexSegmentFromEmbeds].
 * As soon as a segment is decoded it is saved to
 * `/data/local/tmp/kazeia/kazeia_stream_segN.wav` and then passed to
 * [onSegmentReady], so a caller can start playback before later segments
 * exist. After the last segment, all audio is joined with 120 ms of silence
 * between segments, saved to `/data/local/tmp/kazeia/kazeia_stream_full.wav`
 * and returned.
 *
 * Accepted layouts (little-endian):
 *  - single: `<i32 nPrefill> <i32 nTotal> <f32 × nTotal × TALKER_DIM>`
 *  - multi : `<i32 nSegments>` followed by `nSegments` single-layout records
 *
 * A file is treated as single-segment when its size matches the single
 * layout exactly; it is then processed as one segment.
 *
 * @param embedsPath path of the embeds file.
 * @param onSegmentReady optional callback invoked synchronously with the
 *   zero-based segment index and that segment's audio.
 * @return the concatenated audio, or an empty array when the engine is not
 *   loaded or the Hexagon talker is not in use.
 */
fun generateFromEmbedsHexagonStreaming(
    embedsPath: String,
    onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
): ShortArray {
    if (!(loaded && useHexagonTalker)) {
        nlog("Streaming: Hexagon talker not ready")
        return ShortArray(0)
    }
    val startedAt = System.currentTimeMillis()
    val bytes = File(embedsPath).readBytes()
    val reader = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)

    // Layout detection: if the second int, read as nTotal, explains the file
    // size exactly, this is a single-segment file.
    val firstInt = reader.int
    val secondInt = reader.int
    val expectedSingleSize = 8L + secondInt.toLong() * TALKER_DIM * 4
    val isSingle = secondInt in 1 until 100000 && bytes.size.toLong() == expectedSingleSize
    val nSegments: Int
    if (isSingle) {
        nSegments = 1
        reader.position(0) // re-read both ints as the segment header
    } else {
        nSegments = firstInt
        reader.position(4) // first segment header follows the count
    }
    nlog("Streaming: $nSegments segment(s), ${bytes.size} bytes")

    // Start the request from a clean talker state.
    hexReset()

    // Reads `count` consecutive embeds from the current reader position.
    fun readEmbeds(count: Int): List<FloatArray> =
        List(count) { FloatArray(TALKER_DIM) { reader.float } }

    val gapSamples = SR * 120 / 1000
    val segmentAudios = ArrayList<ShortArray>()

    for (seg in 0 until nSegments) {
        // Every segment after the first also needs a fresh KV-cache.
        if (seg > 0) hexReset()

        val nPrefill = reader.int
        val nTotal = reader.int
        val prefill = readEmbeds(nPrefill)
        val nTrailing = nTotal - nPrefill
        val trailing = readEmbeds(nTrailing)
        nlog("Streaming seg ${seg+1}/$nSegments: $nPrefill prefill + $nTrailing decode")

        val segStart = System.currentTimeMillis()
        val audio = runHexSegmentFromEmbeds(prefill, trailing, seg)
        val segMs = System.currentTimeMillis() - segStart
        nlog("Streaming seg ${seg+1}/$nSegments: ${audio.size/SR.toFloat()}s audio in ${segMs}ms")

        // Record, dump, then notify. The WAV dump runs before the callback.
        segmentAudios.add(audio)
        saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${seg+1}.wav", audio)
        onSegmentReady?.invoke(seg, audio)
    }

    // Join the segments for the full-file WAV. The output array starts out
    // zero-filled, so each inter-segment gap only needs the cursor advanced.
    val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
    val concat = ShortArray(total)
    var cursor = 0
    segmentAudios.forEachIndexed { i, pcm ->
        pcm.copyInto(concat, cursor)
        cursor += pcm.size
        if (i < segmentAudios.lastIndex) cursor += gapSamples
    }
    saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
    nlog("Streaming total: ${System.currentTimeMillis() - startedAt}ms for ${concat.size/SR.toFloat()}s ($nSegments seg)")
    return concat
}
|
||
|
||
/**
 * Fetches one row of the fp16 full-vocabulary text-embedding table as fp32.
 *
 * [tokenId] is clamped into `[0, textEmbedsFullLen - 1]`; the row starts at
 * byte offset `row * TALKER_DIM * 2` and holds [TALKER_DIM] half-precision
 * values, each converted with [halfToFloat].
 *
 * The read moves the shared buffer's position, so it is done while holding
 * the buffer's monitor.
 *
 * @throws IllegalStateException if the embedding table has not been loaded.
 */
private fun textEmbFromFull(tokenId: Int): FloatArray {
    val table = textEmbedsFullBuf ?: error("Stage 2 full embeddings not loaded")
    val row = tokenId.coerceIn(0, textEmbedsFullLen - 1)
    val byteOffset = row * TALKER_DIM * 2
    val out = FloatArray(TALKER_DIM)
    synchronized(table) {
        // Seek once, then read the row sequentially with relative gets.
        table.position(byteOffset)
        for (j in out.indices) {
            out[j] = halfToFloat(table.short)
        }
    }
    return out
}
|
||
|
||
/**
 * Converts the bit pattern of an IEEE 754 half-precision value to a Float.
 *
 * Signed zeros, subnormals, normal numbers, infinities and NaNs are all
 * mapped exactly; a NaN keeps its payload bits shifted into the wider
 * mantissa.
 *
 * @param h the 16 raw bits of the half-precision value.
 */
private fun halfToFloat(h: Short): Float {
    val raw = h.toInt() and 0xffff
    val signBits = (raw and 0x8000) shl 16 // sign moved to bit 31
    val exponent = (raw ushr 10) and 0x1f
    val fraction = raw and 0x3ff

    // Infinity / NaN: all-ones exponent in both formats.
    if (exponent == 0x1f) {
        return Float.fromBits(signBits or 0x7f800000 or (fraction shl 13))
    }
    // Normal number: re-bias the exponent (127 - 15 = 112).
    if (exponent != 0) {
        return Float.fromBits(signBits or ((exponent + 112) shl 23) or (fraction shl 13))
    }
    // Signed zero.
    if (fraction == 0) return Float.fromBits(signBits)

    // Subnormal half: shift the fraction left until the implicit leading one
    // reaches bit 10, lowering the exponent by one for every shift.
    var shift = 0
    var normalized = fraction
    while (normalized < 0x400) {
        normalized = normalized shl 1
        shift++
    }
    val exponent32 = 113 - shift // (-14 - shift) + 127
    return Float.fromBits(signBits or (exponent32 shl 23) or ((normalized and 0x3ff) shl 13))
}
|
||
|
||
// Stage 3 streaming-session state. All three are set together by
// startStreamingSession() and cleared together by endStreamingSession().

// Shared output track that every generated sentence is written to.
private var sessionTrack: AudioTrack? = null

// Queue of sentences waiting to be generated; closing it ends the worker loop.
private var sessionChannel: kotlinx.coroutines.channels.Channel<String>? = null

// Background worker that drains the queue, generating and writing one
// sentence at a time in enqueue order.
private var sessionJob: kotlinx.coroutines.Job? = null
|
||
|
||
/**
 * Opens a streaming TTS session: one MODE_STREAM AudioTrack plus a
 * background worker on Dispatchers.IO.
 *
 * Sentences passed to enqueueSentence() are taken from an unbounded channel
 * in order; each is synthesised with generateSegmentAudioVC() and its audio
 * is written to the shared track. A failure on one sentence is logged and the
 * worker moves on to the next. Calling this while a session is already open
 * does nothing. Use endStreamingSession() to close the queue and release the
 * track.
 */
fun startStreamingSession() {
    if (sessionTrack != null) return // session already open

    val attributes = AudioAttributes.Builder()
        .setUsage(AudioAttributes.USAGE_MEDIA)
        .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        .build()
    val format = AudioFormat.Builder()
        .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
        .setSampleRate(SR)
        .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        .build()
    // Buffer sized for 30 s of 16-bit mono audio.
    // NOTE(review): the AudioTrack docs describe the default streaming start
    // threshold as the buffer capacity — confirm on device that playback
    // begins before this buffer has been filled.
    val player = AudioTrack.Builder()
        .setAudioAttributes(attributes)
        .setAudioFormat(format)
        .setBufferSizeInBytes(SR * 30 * 2)
        .setTransferMode(AudioTrack.MODE_STREAM)
        .build()
    player.play()

    val queue = kotlinx.coroutines.channels.Channel<String>(
        capacity = kotlinx.coroutines.channels.Channel.UNLIMITED
    )
    val worker = kotlinx.coroutines.CoroutineScope(Dispatchers.IO).launch {
        // Index of the next successfully generated sentence; it only advances
        // when generation and the track write both complete without throwing.
        var segIdx = 0
        for (sentence in queue) {
            try {
                val pcm = generateSegmentAudioVC(sentence, segIdx)
                if (pcm.size > 0) player.write(pcm, 0, pcm.size)
                segIdx++
            } catch (e: Exception) {
                nlog("session seg $segIdx error: ${e.message}")
            }
        }
    }

    sessionTrack = player
    sessionChannel = queue
    sessionJob = worker
    nlog("streaming session opened")
}
|
||
|
||
/**
 * Queues [sentence] for generation and playback in the session opened by
 * startStreamingSession(). Never blocks: the sentence is offered with
 * trySend and the call returns at once. When no session is open, or the
 * channel rejects the sentence, a log line is written and the sentence is
 * dropped.
 */
fun enqueueSentence(sentence: String) {
    val queue = sessionChannel
    if (queue == null) {
        nlog("enqueueSentence: no session open")
        return
    }
    if (queue.trySend(sentence).isFailure) {
        nlog("enqueueSentence: channel full / closed")
    }
}
|
||
|
||
/**
 * Closes the sentence queue, waits for the worker to finish generating and
 * writing everything that was already enqueued, then stops and releases the
 * shared track and clears the session state.
 *
 * Safe to call more than once: when no session is open it returns at once.
 *
 * Cancellation: if the calling coroutine is cancelled while waiting for the
 * worker, the worker is cancelled too, the track is still released and the
 * session state is still cleared, and the CancellationException then
 * propagates to the caller instead of being swallowed.
 */
suspend fun endStreamingSession() {
    val chan = sessionChannel ?: return
    // No more sentences can be enqueued; the worker loop ends once the
    // remaining ones have been processed.
    chan.close()
    try {
        sessionJob?.join()
    } finally {
        // No-op when the worker completed normally; when join() was
        // interrupted by cancellation this stops the worker at its next
        // channel receive so it does not keep generating for a released track.
        sessionJob?.cancel()
        try {
            sessionTrack?.let { track ->
                // stop() does not block. Per the AudioTrack documentation, a
                // MODE_STREAM track stops after the last written buffer has
                // been played.
                // NOTE(review): release() follows immediately — confirm on
                // device that the tail of the last sentence is not truncated.
                track.stop()
                track.release()
            }
        } catch (e: Exception) {
            // Best-effort teardown: report the failure but never throw.
            nlog("streaming session: track release failed: ${e.message}")
        }
        sessionTrack = null
        sessionChannel = null
        sessionJob = null
        nlog("streaming session closed")
    }
}
|
||
|
||
/**
 * Voice-clones a single sentence and returns its decoded audio; nothing is
 * written to disk.
 *
 * The talker prefill is the stored voice prefix, followed by one embed per
 * text token (text embedding summed with the CODEC_PAD codec embedding),
 * followed by the stored voice suffix. Codes are generated on the Hexagon
 * talker when its socket is open, otherwise on the .pte talker + code
 * predictor modules. Out-of-range codes are replaced by 0, the codes are
 * decoded, and a 40 ms fade-out is applied.
 *
 * @param segText sentence to synthesise.
 * @param segIdx index of the sentence within the session, used for logging.
 * @return the audio, or an empty array when required assets or a talker
 *   backend are missing, or when no codes were generated.
 */
private fun generateSegmentAudioVC(segText: String, segIdx: Int): ShortArray {
    val tokenizer = bpeTokenizer
    val voicePrefix = damienVoicePrefix
    val voiceSuffix = damienVoiceSuffix
    if (tokenizer == null || textEmbedsFullBuf == null || voicePrefix == null || voiceSuffix == null) {
        nlog("generateSegmentAudioVC: Stage 2 assets missing")
        return ShortArray(0)
    }

    val codecPadEmb = codecEmb(CODEC_PAD)
    val ids = tokenizer.encode(segText)
    nlog("session seg $segIdx '${segText.take(60)}' → ${ids.size} tokens")

    // prefix + (text token ⊕ codec pad) per token + suffix
    val prefill = ArrayList<FloatArray>(voicePrefix.size + ids.size + voiceSuffix.size)
    prefill.addAll(voicePrefix)
    ids.mapTo(prefill) { id -> sumEmb(textEmbFromFull(id), codecPadEmb) }
    prefill.addAll(voiceSuffix)

    // Step budget derived from the token count: 2.4 steps per token expected,
    // generation capped at 1.5× that plus 10 (and at MAX_CONTEXT - 15), and
    // the EOS boost held back until half the expected length.
    val expectedSteps = (ids.size * 24) / 10
    val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
    val eosBoostMinStep = expectedSteps / 2

    // Backend dispatch: Hexagon when its socket is open (needs an explicit
    // reset first), otherwise the .pte modules when both are present.
    val codes: Array<IntArray> = when {
        talkerSocket != null -> {
            hexReset()
            runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
        }
        talkerPteModule != null && cpPteModule != null ->
            runInterleavedPteFromEmbeds(prefill, emptyList(), maxGen)
        else -> {
            nlog("generateSegmentAudioVC: no talker backend available")
            return ShortArray(0)
        }
    }
    if (codes.isEmpty()) return ShortArray(0)

    // Transpose [step][codebook] into [codebook][step], padded with 0 up to
    // SEQ_LEN; any code outside [0, CODEBOOK_SIZE) also becomes 0.
    val n = codes.size
    val padLen = maxOf(n, SEQ_LEN)
    val codebooks = Array(NUM_CODEBOOKS) { cb ->
        IntArray(padLen) { t ->
            when {
                t >= n -> 0
                codes[t][cb] in 0 until CODEBOOK_SIZE -> codes[t][cb]
                else -> 0
            }
        }
    }
    // 40 ms fade-out applied to the decoded audio before it is returned.
    return fadeOut(decodeChunked(codebooks, n), 40)
}
|
||
|
||
/**
|
||
* Run the Hexagon talker + CP generation loop with a fully pre-built
|
||
* prefill (voice prefix + all text tokens). Same decode recipe as
|
||
* runInterleavedHexagon's inner loop: at each step the next talker
|
||
* input is codecSum (codec embedding of previous codes) + tts_pad
|
||
* (since all text has already been consumed in the prefill). On EOS,
|
||
* terminate early. Returns the generated [step, codebook] codes.
|
||
*/
|
||
private fun runHexGenWithPrefill(
|
||
prefill: List<FloatArray>,
|
||
maxGen: Int,
|
||
// Floor on how many steps to generate before the dynamic boost is
|
||
// allowed to fire. Prevents short-text segments from terminating on
|
||
// a stray "EOS-leaning" hidden state during the first few frames.
|
||
// Callers pass ~50 % of expected speech length as a safety floor.
|
||
eosBoostMinStep: Int = -1,
|
||
// Once EOS rank falls below this threshold (model itself is
|
||
// "thinking about stopping"), start adding eosBoostScale per step
|
||
// until argmax flips to EOS. Empirically EOS rank plateaus ~150-700
|
||
// mid-speech and dips to ~50-60 right at the natural end, so this
|
||
// catches the model's intent without a fixed length budget.
|
||
eosRankTrigger: Int = 60,
|
||
eosBoostScale: Float = 4.0f
|
||
): Array<IntArray> {
|
||
val padE = ttsPadEmbed ?: return emptyArray()
|
||
val eosE = ttsEosEmbed ?: return emptyArray()
|
||
val allCodes = mutableListOf<IntArray>()
|
||
val generatedCb0 = mutableListOf<Int>()
|
||
var totalTalkerMs = 0L; var totalCpMs = 0L
|
||
|
||
val tPrefill = System.currentTimeMillis()
|
||
val prefillResults = hexForward(prefill)
|
||
nlog("VC prefill (Hex): ${System.currentTimeMillis() - tPrefill}ms, ${prefillResults.size} steps")
|
||
if (prefillResults.isEmpty()) return emptyArray()
|
||
|
||
var pastHidden = prefillResults.last().first
|
||
val prefillLogits = prefillResults.last().second
|
||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) prefillLogits[j] = Float.NEGATIVE_INFINITY }
|
||
var currentCb0 = sampleTopK(prefillLogits, 0.9f, 50)
|
||
nlog("VC prefill done: first cb0=$currentCb0")
|
||
|
||
// After the text has been fully consumed in prefill, Python's voice-
|
||
// clone loop feeds tts_eos once, then tts_pad for every subsequent
|
||
// decode step. We follow the same schedule so the model's attention
|
||
// sees the same "text exhausted" signal it was trained with.
|
||
var eosFedOnce = false
|
||
// Counter for the dynamic EOS boost — once armed, increments every
|
||
// step monotonically. Arming requires 3 CONSECUTIVE steps below
|
||
// the rank trigger (transient dips during normal speech aren't
|
||
// enough); this keeps short sentences from terminating mid-word
|
||
// on a fluke low rank.
|
||
var boostStepsActive = 0
|
||
var consecLowRank = 0
|
||
for (genStep in 0 until maxGen) {
|
||
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
||
val tCp = System.currentTimeMillis()
|
||
val cpCodes = runCodePredictorInterleaved(pastHidden, currentCb0)
|
||
for (cb in 1 until NUM_CODEBOOKS) codes[cb] = cpCodes[cb - 1]
|
||
allCodes.add(codes); generatedCb0.add(currentCb0)
|
||
totalCpMs += System.currentTimeMillis() - tCp
|
||
|
||
val codecSum = FloatArray(TALKER_DIM)
|
||
addEmb(codecSum, codecEmb(codes[0]))
|
||
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
||
val nextEmbed = if (!eosFedOnce) { eosFedOnce = true; sumEmb(codecSum, eosE) } else sumEmb(codecSum, padE)
|
||
|
||
val tT = System.currentTimeMillis()
|
||
val results = hexForward(listOf(nextEmbed))
|
||
totalTalkerMs += System.currentTimeMillis() - tT
|
||
if (results.isEmpty()) { nlog("VC: hex empty at step ${genStep+1}"); break }
|
||
pastHidden = results[0].first
|
||
val logits = results[0].second
|
||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||
|
||
// Dynamic EOS detection: track current EOS rank, require 3
|
||
// consecutive low-rank steps before arming. Single-step rank
|
||
// dips happen mid-utterance during low-energy phonemes and
|
||
// would otherwise trigger premature termination on short
|
||
// sentences. Once armed, the boost accumulates monotonically.
|
||
val eosLogit0 = logits[CODEC_EOS]
|
||
var eosRank = 0
|
||
for (j in logits.indices) if (logits[j] > eosLogit0) eosRank++
|
||
if (boostStepsActive == 0 && genStep >= eosBoostMinStep) {
|
||
if (eosRank < eosRankTrigger) {
|
||
consecLowRank++
|
||
if (consecLowRank >= 3) {
|
||
nlog("VC boost armed at step ${genStep+1} (EOS rank $eosRank, 3 consecutive low)")
|
||
boostStepsActive = 1
|
||
}
|
||
} else {
|
||
consecLowRank = 0
|
||
}
|
||
} else if (boostStepsActive > 0) {
|
||
boostStepsActive++
|
||
}
|
||
if (boostStepsActive > 0) {
|
||
logits[CODEC_EOS] += boostStepsActive * eosBoostScale
|
||
}
|
||
|
||
var argmax = 0; var argmaxV = logits[0]
|
||
for (j in 1 until logits.size) if (logits[j] > argmaxV) { argmaxV = logits[j]; argmax = j }
|
||
if (argmax == CODEC_EOS) { nlog("VC EOS (boosted argmax) at step ${genStep+1}"); break }
|
||
|
||
currentCb0 = sampleTopK(logits, 0.9f, 200)
|
||
if (currentCb0 == CODEC_EOS) { nlog("VC EOS (sampled) at step ${genStep+1}"); break }
|
||
|
||
// Degeneracy guard #1: 9 consecutive identical cb0 → stop.
|
||
// Catches the simplest stuck loop.
|
||
val nHist = generatedCb0.size
|
||
if (nHist >= 9) {
|
||
val last = generatedCb0[nHist - 1]
|
||
var allSame = true
|
||
for (i in nHist - 9 until nHist) if (generatedCb0[i] != last) { allSame = false; break }
|
||
if (allSame && currentCb0 == last) {
|
||
nlog("VC degen: cb0=$last repeated ≥9× at step ${genStep+1}, stopping")
|
||
break
|
||
}
|
||
}
|
||
|
||
// Degeneracy guard #2: low diversity in the recent window.
|
||
// The "page beg beg" filler doesn't repeat a single token — it
|
||
// cycles through 2-3 tokens. If the last 12 cb0 contain fewer
|
||
// than 4 unique values, the talker is in the cycle and we stop
|
||
// before the audio degrades further.
|
||
if (nHist >= 12) {
|
||
val recent = HashSet<Int>()
|
||
for (i in nHist - 12 until nHist) recent.add(generatedCb0[i])
|
||
if (recent.size < 4) {
|
||
nlog("VC degen: only ${recent.size} unique cb0 in last 12 at step ${genStep+1}, stopping")
|
||
break
|
||
}
|
||
}
|
||
}
|
||
nlog("VC gen: ${allCodes.size} tokens | Talker(HEX): ${totalTalkerMs}ms | CP: ${totalCpMs}ms")
|
||
// Diagnostic: log full cb0 sequence so we can correlate against the
|
||
// page-beg region in the produced audio.
|
||
val cb0Str = generatedCb0.joinToString(",")
|
||
nlog("VC cb0 sequence: $cb0Str")
|
||
return allCodes.toTypedArray()
|
||
}
|
||
|
||
/**
 * Split text into short segments for the streaming pipeline. Reproduces
 * the behaviour of scripts/prepare_tts_segments.py but on-device so the
 * tablet doesn't depend on a PC-side preprocessor.
 *
 * Two passes:
 *  1. Break after sentence punctuation (`.`, `!`, `?`, `;`, `:`) that is
 *     followed by whitespace.
 *  2. Any sentence still longer than [maxChars] is broken after commas and
 *     the comma-delimited pieces are greedily packed back together, joined
 *     by a single space, so the result is not over-split.
 *
 * A single comma-free piece longer than [maxChars] is emitted as-is; this
 * function never breaks inside a clause.
 *
 * @param text raw text to segment.
 * @param maxChars target upper bound on segment length. The 120-character
 *   default matches the length where the talker still terminates reliably
 *   on EOS; longer segments risk the repetition penalty cutting decode
 *   short of the full phrase.
 * @return the trimmed, non-blank segments in reading order. If no segment
 *   survives (e.g. blank input) a one-element list holding [text] unchanged.
 */
private fun splitSentences(text: String, maxChars: Int = 120): List<String> {
    // Built once per call instead of once per overlong sentence.
    val sentenceBreak = Regex("(?<=[.!?;:])\\s+")
    val commaBreak = Regex("(?<=,)\\s+")

    val out = mutableListOf<String>()
    for (part in text.trim().split(sentenceBreak)) {
        if (part.length <= maxChars) {
            if (part.isNotBlank()) out.add(part.trim())
            continue
        }
        // Break overlong sentences at commas, greedily packing sub-parts
        // back together up to maxChars so we don't over-split.
        var current = ""
        for (s in part.split(commaBreak)) {
            // "+ 1" accounts for the space inserted when joining; without it
            // a packed segment could end up one character over maxChars.
            if (current.isNotEmpty() && current.length + 1 + s.length > maxChars) {
                out.add(current.trim())
                current = s
            } else {
                current = if (current.isEmpty()) s else "$current $s"
            }
        }
        if (current.isNotBlank()) out.add(current.trim())
    }
    return if (out.isEmpty()) listOf(text) else out
}
|
||
|
||
/**
 * Tokenize `text` on-device and stream the synthesized audio segment by
 * segment via `onSegmentReady`. The PC-side prep script becomes optional —
 * this is the first path in the app where TTS runs fully offline on
 * arbitrary LLM output.
 *
 * For each segment produced by [splitSentences]:
 *  1. BPE-tokenize the segment text.
 *  2. Build the voice-cloning prefill: fixed voice prefix, then one
 *     embedding per token (text embedding + CODEC_PAD codec embedding),
 *     then the fixed voice suffix.
 *  3. Run the Hexagon generation loop on that prefill to obtain codec
 *     tokens.
 *  4. Decode codes → audio via decodeChunked, apply a 40 ms fade-out, save
 *     a per-segment WAV, and emit the audio through the callback.
 *
 * Segments that yield no codes are skipped (no callback, no gap). After the
 * last segment the audio is concatenated with 250 ms of silence between
 * segments and also written to a WAV file.
 *
 * @param text text to synthesize.
 * @param onSegmentReady optional callback invoked once per successfully
 *   synthesized segment with the zero-based segment index and its PCM16
 *   audio.
 * @return the concatenated PCM16 audio, or an empty array when the engine
 *   or any required asset (talker, tokenizer, text-embedding table, voice
 *   prefix/suffix) is unavailable, or when no segment produced audio.
 */
fun synthesizeTextStreaming(
    text: String,
    onSegmentReady: ((segIdx: Int, audio: ShortArray) -> Unit)? = null
): ShortArray {
    if (!loaded || !useHexagonTalker) {
        nlog("synthesizeTextStreaming: Hexagon talker not ready"); return ShortArray(0)
    }
    val tokenizer = bpeTokenizer
    if (tokenizer == null || textEmbedsFullBuf == null) {
        nlog("synthesizeTextStreaming: Stage 2 assets missing"); return ShortArray(0)
    }
    // Fail soft like the other missing-asset cases instead of crashing on `!!`.
    val prefix = damienVoicePrefix
    val suffix = damienVoiceSuffix
    if (prefix == null || suffix == null) {
        nlog("synthesizeTextStreaming: voice prefix/suffix missing"); return ShortArray(0)
    }

    val segments = splitSentences(text)
    nlog("synthesizeTextStreaming: ${segments.size} segment(s) for ${text.length} chars")

    hexReset()
    val segmentAudios = mutableListOf<ShortArray>()
    // Wider gap (250 ms) between segments — gives the listener a small
    // breath/pause that mirrors how real speakers separate sentences,
    // and masks the EOS-boost cut by surrounding it with silence rather
    // than another sentence's onset.
    val gapSamples = SR * 250 / 1000
    val gap = ShortArray(gapSamples)
    val t0 = System.currentTimeMillis()

    // CODEC_PAD embedding is the per-token companion that Python voice-
    // cloning sums into every text-encoded prefill position. Computed
    // once here so the per-token loop stays a simple vector add.
    val codecPadEmb = codecEmb(CODEC_PAD)

    for ((segIdx, segText) in segments.withIndex()) {
        // The first segment was already reset above; later ones start clean.
        if (segIdx > 0) hexReset()
        val tSeg = System.currentTimeMillis()
        val ids = tokenizer.encode(segText)
        nlog("Seg ${segIdx+1}/${segments.size}: '${segText.take(60)}' → ${ids.size} tokens: ${ids.toList()}")

        // Voice-cloning prefill, fully reconstructed on-device — the
        // structure Python emits via generate_voice_clone:
        //   [0..8]      damienVoicePrefix (9 fixed positions, xvector@7)
        //   [9..N-3]    text_projection(BPE_id) + codec_embedding(CODEC_PAD)
        //   [N-2, N-1]  damienVoiceSuffix (2 fixed positions, end-of-text marker)
        // The talker needs BOTH the per-token codec_pad sum AND the closure
        // markers to know that text input has ended and decoding can begin.
        val prefill = ArrayList<FloatArray>(prefix.size + ids.size + suffix.size)
        for (e in prefix) prefill.add(e)
        for (id in ids) prefill.add(sumEmb(textEmbFromFull(id), codecPadEmb))
        for (e in suffix) prefill.add(e)

        // Generous absolute cap; the dynamic EOS boost (triggered when
        // the model's own EOS rank dips below threshold) is what
        // actually terminates generation. The minStep floor protects
        // against early-termination spikes for short sentences.
        val expectedSteps = (ids.size * 24) / 10   // ids * 2.4 (int math)
        val maxGen = minOf(expectedSteps * 3 / 2 + 10, MAX_CONTEXT - 15)
        val eosBoostMinStep = expectedSteps / 2
        val codes = runHexGenWithPrefill(prefill, maxGen, eosBoostMinStep)
        if (codes.isEmpty()) { nlog("Seg ${segIdx+1}: empty codes"); continue }

        // Transpose [step][codebook] → [codebook][step], zero-padding up to
        // SEQ_LEN and replacing out-of-range codes with 0.
        val n = codes.size
        val padLen = maxOf(n, SEQ_LEN)
        val codebooks = Array(NUM_CODEBOOKS) { cb ->
            IntArray(padLen) { t ->
                if (t < n) { val v = codes[t][cb]; if (v in 0 until CODEBOOK_SIZE) v else 0 } else 0
            }
        }
        val rawAudio = decodeChunked(codebooks, n)
        // Cosine fade-out on the last 40 ms — softens the cut imposed by
        // the EOS boost so the segment ends on a recognisable phoneme
        // tail instead of an abrupt sample-clip.
        val audio = fadeOut(rawAudio, 40)
        val segMs = System.currentTimeMillis() - tSeg
        val budgetHit = if (n >= maxGen) " [maxGen cap]" else ""
        nlog("Seg ${segIdx+1}/${segments.size}: $n tokens, ${audio.size/SR.toFloat()}s audio in ${segMs}ms$budgetHit")

        segmentAudios.add(audio)
        saveWav("/data/local/tmp/kazeia/kazeia_stream_seg${segIdx+1}.wav", audio)
        onSegmentReady?.invoke(segIdx, audio)
    }

    if (segmentAudios.isEmpty()) return ShortArray(0)
    val total = segmentAudios.sumOf { it.size } + maxOf(0, segmentAudios.size - 1) * gapSamples
    val concat = ShortArray(total)
    var off = 0
    for ((i, s) in segmentAudios.withIndex()) {
        System.arraycopy(s, 0, concat, off, s.size); off += s.size
        // Silence goes between segments only, not after the last one.
        if (i < segmentAudios.size - 1) { System.arraycopy(gap, 0, concat, off, gapSamples); off += gapSamples }
    }
    saveWav("/data/local/tmp/kazeia/kazeia_stream_full.wav", concat)
    nlog("synthesizeTextStreaming total: ${System.currentTimeMillis() - t0}ms for ${concat.size/SR.toFloat()}s")
    return concat
}
|
||
|
||
/**
 * Trim low-energy tail from a voice-cloned segment. When the talker
 * fails to emit EOS within maxGen, the remaining decoded codes tend
 * to be "page beg beg" fillers — audible as a mumbled tail. This
 * RMS-based trim finds the last sustained high-energy window and cuts
 * shortly after it (two extra 40 ms windows are kept so the last real
 * syllable retains some of its decay).
 *
 * This function only truncates; it applies no fade. Callers that want a
 * smooth ending should run [fadeOut] on the result.
 *
 * Behaviour summary:
 *  - Buffers shorter than half a second, or with fewer than ten 40 ms
 *    windows, are returned unchanged.
 *  - If speech reaches the final analysed window, the buffer is returned
 *    unchanged, including any trailing partial window.
 *  - The result is never shorter than 40 % of the input.
 *
 * @param audio PCM16 mono samples at [SR].
 * @return either [audio] itself (nothing to trim) or a truncated copy.
 */
private fun trimTailLowEnergy(audio: ShortArray): ShortArray {
    if (audio.size < SR / 2) return audio
    val winSamples = SR * 40 / 1000   // 40 ms analysis windows
    val nWin = audio.size / winSamples
    if (nWin < 10) return audio

    // Per-window RMS on samples normalised to [-1, 1).
    val rms = FloatArray(nWin)
    for (w in 0 until nWin) {
        var s = 0.0
        val o = w * winSamples
        for (i in 0 until winSamples) { val x = audio[o + i].toFloat() / 32768f; s += x * x }
        rms[w] = kotlin.math.sqrt(s / winSamples).toFloat()
    }
    var peak = 0f
    for (w in 0 until nWin) if (rms[w] > peak) peak = rms[w]

    // Heuristic 1: scanning from the back, find the last window that ends a
    // run of three consecutive windows at or above 35 % of the peak RMS.
    val thrSust = peak * 0.35f
    var lastSpeech = nWin - 1
    for (w in nWin - 1 downTo 2) {
        if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) {
            lastSpeech = w; break
        }
    }

    // Heuristic 2: "low-energy tail" — the "page beg beg" filler emits
    // audio that's louder than silence but quieter than real speech.
    // If the last 30 % of the segment has a max RMS under 40 % of the
    // global peak, the tail is degenerate; cut at the last sustained-
    // speech window before the tail starts.
    val tailStart = (nWin * 7) / 10
    var tailMax = 0f
    for (w in tailStart until nWin) if (rms[w] > tailMax) tailMax = rms[w]
    if (tailMax < peak * 0.40f) {
        for (w in tailStart - 1 downTo 2) {
            if (rms[w] >= thrSust && rms[w-1] >= thrSust && rms[w-2] >= thrSust) {
                if (w < lastSpeech) lastSpeech = w
                break
            }
        }
    }

    val keepWin = (lastSpeech + 2).coerceAtMost(nWin - 1)
    // Speech (plus its two-window margin) reaches the final analysed window:
    // there is nothing to trim, so keep the whole buffer instead of
    // chopping off the trailing partial window (audio.size % winSamples).
    if (keepWin == nWin - 1) return audio

    var keepSamples = (keepWin + 1) * winSamples
    // Over-trim guard: never cut below 40 % of the raw length.
    val minKeep = (audio.size * 4) / 10
    if (keepSamples < minKeep) keepSamples = minKeep
    return audio.copyOf(keepSamples)
}
|
||
|
||
/**
 * Attenuate the end of a PCM16 buffer with a raised-cosine ramp.
 *
 * The EOS boost stops generation right at the model's chosen stop point,
 * which usually clips the natural decay of the last syllable; ramping the
 * final [fadeMs] milliseconds from full gain towards zero turns that cut
 * into a recognisable phoneme tail without shortening the perceived word.
 *
 * @param audio PCM16 mono samples at [SR]; never modified in place.
 * @param fadeMs ramp duration in milliseconds.
 * @return a faded copy, or [audio] itself when the buffer is not longer
 *   than the ramp.
 */
private fun fadeOut(audio: ShortArray, fadeMs: Int = 40): ShortArray {
    val rampLen = SR * fadeMs / 1000
    if (rampLen >= audio.size) return audio

    val faded = audio.copyOf()
    val rampStart = faded.size - rampLen
    val pi = Math.PI.toFloat()
    repeat(rampLen) { k ->
        // Gain follows half a cosine period: 1 at k = 0, approaching 0 at the end.
        val phase = k.toFloat() / rampLen
        val gain = 0.5f * (1f + kotlin.math.cos(pi * phase))
        val idx = rampStart + k
        faded[idx] = (faded[idx] * gain).toInt().toShort()
    }
    return faded
}
|
||
|
||
/**
 * Write PCM16 mono audio to a WAV file. Used by the streaming pipeline to
 * save one file per segment plus the concatenated result for inspection.
 *
 * The file is a canonical 44-byte RIFF/WAVE header (PCM, 1 channel, [SR] Hz,
 * 16-bit) followed by the little-endian samples. Saving is best-effort:
 * any failure is logged and swallowed so a debug dump never aborts
 * synthesis.
 *
 * @param path destination file path; an existing file is overwritten.
 * @param audio PCM16 mono samples.
 */
private fun saveWav(path: String, audio: ShortArray) {
    try {
        val dataLen = audio.size * 2
        val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
        header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
        // fmt chunk: size 16, format 1 (PCM), 1 channel
        header.putInt(16); header.putShort(1); header.putShort(1)
        // sample rate, byte rate (SR * 2 bytes), block align 2, 16 bits/sample
        header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
        header.put("data".toByteArray()); header.putInt(dataLen)

        val pcm = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
        for (s in audio) pcm.putShort(s)

        // `use` closes the stream even when a write throws, so a failed
        // save no longer leaks the file descriptor.
        java.io.FileOutputStream(path).use { fos ->
            fos.write(header.array())
            fos.write(pcm.array())
        }
    } catch (e: Exception) { nlog("WAV save failed ($path): ${e.message}") }
}
|
||
|
||
/**
 * Test with pre-computed codec tokens from PC (for validation).
 *
 * Reads NUM_CODEBOOKS × SEQ_LEN little-endian 32-bit codes (codebook-major)
 * from [codesPath], clamps each into the valid codebook range, runs VQ
 * decode and the speech decoder, and trims the result to the first
 * [realTokens] tokens' worth of samples. Along the way it logs statistics
 * of the quantized tensor and dumps it to disk for offline comparison
 * (the dump is best-effort).
 *
 * @param codesPath path to the binary codes file.
 * @param realTokens number of leading tokens that carry real audio; the
 *   rest of the fixed-size chunk is treated as padding and trimmed.
 * @return the trimmed PCM16 audio, or an empty array if the engine is not
 *   loaded or anything fails.
 */
fun testWithPrecomputedCodes(codesPath: String, realTokens: Int = 16): ShortArray {
    if (!loaded) return ShortArray(0)
    nlog("Testing with pre-computed codes: $codesPath (realTokens=$realTokens)")
    try {
        val bytes = File(codesPath).readBytes()
        // Fail with a readable message instead of a BufferUnderflowException
        // (whose message is null) when the file is truncated.
        val neededBytes = NUM_CODEBOOKS * SEQ_LEN * 4
        if (bytes.size < neededBytes) {
            nlog("Test error: codes file too short (${bytes.size} bytes, need $neededBytes)")
            return ShortArray(0)
        }
        val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
        val allCodes = Array(NUM_CODEBOOKS) { IntArray(SEQ_LEN) }
        for (cb in 0 until NUM_CODEBOOKS) {
            for (t in 0 until SEQ_LEN) {
                allCodes[cb][t] = bb.int.coerceIn(0, CODEBOOK_SIZE - 1)
            }
        }
        nlog("Loaded ${NUM_CODEBOOKS} × ${SEQ_LEN} codes")
        nlog("codes[0][:5] = ${allCodes[0].take(5).toList()}")
        nlog("codes[1][:5] = ${allCodes[1].take(5).toList()}")
        val t0 = System.currentTimeMillis()
        val quantized = vqDecode(allCodes)

        // Dump quantized stats for debug. The max is seeded with
        // -Float.MAX_VALUE: Float.MIN_VALUE is the smallest POSITIVE float
        // and would misreport the max of an all-negative tensor.
        var qMin = Float.MAX_VALUE; var qMax = -Float.MAX_VALUE; var qSum = 0.0
        for (v in quantized) { qMin = minOf(qMin, v); qMax = maxOf(qMax, v); qSum += v * v }
        nlog("quantized: size=${quantized.size}, range=[${qMin}, ${qMax}], rms=${Math.sqrt(qSum / quantized.size)}")

        // Save quantized to file for comparison. Best-effort, but a failure
        // is now reported rather than silently swallowed.
        try {
            val qf = java.io.File("/data/local/tmp/kazeia/quantized_dump.bin")
            qf.writeBytes(java.nio.ByteBuffer.allocate(quantized.size * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).also {
                for (v in quantized) it.putFloat(v)
            }.array())
            nlog("Saved quantized to ${qf.absolutePath}")
        } catch (e: Exception) {
            nlog("quantized dump failed: ${e.message}")
        }

        val fullAudio = runSpeechDecoder(quantized)
        // Clamp at zero so a negative realTokens yields empty audio rather
        // than relying on the outer catch to absorb NegativeArraySizeException.
        val trimSamples = minOf(realTokens * SAMPLES_PER_TOKEN, fullAudio.size).coerceAtLeast(0)
        val audio = fullAudio.copyOf(trimSamples)
        nlog("Decode: ${System.currentTimeMillis() - t0}ms, audio=${audio.size.toFloat()/SR}s (trimmed from ${fullAudio.size.toFloat()/SR}s)")
        return audio
    } catch (e: Exception) {
        nlog("Test error: ${e.message}")
        return ShortArray(0)
    }
}
|
||
|
||
/**
 * Release every resource held by the engine: stops playback, stops the
 * Hexagon runner if either Hexagon path was in use, then closes the ORT
 * sessions and the ORT environment.
 *
 * The engine is always marked as not loaded afterwards, even if one of the
 * close calls throws, so a partially released engine is never reported as
 * ready. Any exception from the teardown still propagates to the caller.
 */
override fun release() {
    try {
        stop()
        // Stop hexagon runners cleanly
        if (useHexagonTalker || useHexagonCp) {
            hexStopRunner()
        }
        talkerKv?.close()
        cpKv?.close()
        preConv?.close()
        preprocessor?.close()
        convDecoder?.close()
        ortEnv?.close()
    } finally {
        loaded = false
    }
}
|
||
}
|