Optimize decoder: BigVGAN 8T, small models 4T → RTF 1.26

BigVGAN benefits from 8 intra-op threads (all perf cores).
Pre_conv and pre_transformer run at 4T, down from 6T (small models, less thread contention).

BigVGAN: 2757ms → 1872ms (-885ms), decode total: 2830ms → 2035ms
Pipeline: 6438ms → 5834ms (RTF 1.26)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 13:00:05 +02:00
parent a688edc9ec
commit 42bbb96fd8
1 changed file with 12 additions and 9 deletions

View File

@ -187,13 +187,13 @@ class Qwen3TtsEngine(
return session
}
fun loadCpu(name: String): OrtSession {
fun loadCpu(name: String, threads: Int = 8): OrtSession {
val t = System.currentTimeMillis()
val opts = OrtSession.SessionOptions()
opts.setIntraOpNumThreads(6)
opts.setIntraOpNumThreads(threads)
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
val session = ortEnv!!.createSession("$path/$name/model.onnx", opts)
nlog("$name (CPU 6T): ${System.currentTimeMillis() - t}ms")
nlog("$name (CPU ${threads}T): ${System.currentTimeMillis() - t}ms")
return session
}
@ -213,9 +213,9 @@ class Qwen3TtsEngine(
val v2Path = "$path/v2_pre_conv"
if (File("$v2Path/model.onnx").exists()) {
nlog("Loading V2 speech decoder (CPU ONNX)...")
preConv = loadCpu("v2_pre_conv")
preprocessor = loadCpu("v2_pre_transformer")
convDecoder = loadCpu("v2_decoder_conv")
preConv = loadCpu("v2_pre_conv", 4)
preprocessor = loadCpu("v2_pre_transformer", 4)
convDecoder = loadCpu("v2_decoder_conv", 8) // BigVGAN benefits from more threads
decoderOnCpu = true
nlog("Speech decoder V2 on CPU")
} else {
@ -2058,6 +2058,7 @@ class Qwen3TtsEngine(
/** V2 decoder: quantized[1,512,60] → pre_conv → transpose → pre_transformer → decoder → audio */
private fun runSpeechDecoderV2(quantized: FloatArray): ShortArray {
val env = ortEnv!!
val td0 = System.currentTimeMillis()
// pre_conv: [1,512,60] → [1,1024,60]
val qTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(quantized),
longArrayOf(1, HIDDEN_DIM.toLong(), SEQ_LEN.toLong()))
@ -2065,7 +2066,8 @@ class Qwen3TtsEngine(
val pcOutRaw = (pcResult[0] as OnnxTensor).floatBuffer
val pcData = FloatArray(1024 * SEQ_LEN)
pcOutRaw.get(pcData)
nlog("V2 pre_conv done")
val td1 = System.currentTimeMillis()
nlog("V2 pre_conv: ${td1-td0}ms")
qTensor.close()
// Transpose [1,1024,60] → [1,60,1024]
@ -2082,13 +2084,14 @@ class Qwen3TtsEngine(
// pre_transformer: [1,60,1024] → [1,60,1024]
val ptResult = preprocessor!!.run(mapOf(preprocessor!!.inputNames.first() to ptInput))
val ptOut = ptResult[0] as OnnxTensor
nlog("V2 pre_transformer done")
val td2 = System.currentTimeMillis()
nlog("V2 pre_transformer: ${td2-td1}ms")
ptInput.close()
// decoder: [1,60,1024] → [1,1,samples]
val cdResult = convDecoder!!.run(mapOf(convDecoder!!.inputNames.first() to ptOut))
val cdOut = cdResult[0] as OnnxTensor
nlog("V2 decoder done: ${cdOut.info.shape.contentToString()}")
nlog("V2 BigVGAN: ${System.currentTimeMillis()-td2}ms")
@Suppress("UNCHECKED_CAST")
val audioFloat = (cdOut.value as Array<Array<FloatArray>>)[0][0]
ptResult.close()