Optimize decoder: BigVGAN 8T, small models 4T → RTF 1.26
BigVGAN benefits from 8 intra-op threads (all perf cores). Pre_conv and pre_transformer kept at 4T (small, less contention). BigVGAN: 2757ms → 1872ms (-885ms), decode total: 2830ms → 2035ms Pipeline: 6438ms → 5834ms → RTF 1.26 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a688edc9ec
commit
42bbb96fd8
|
|
@ -187,13 +187,13 @@ class Qwen3TtsEngine(
|
|||
return session
|
||||
}
|
||||
|
||||
fun loadCpu(name: String): OrtSession {
|
||||
fun loadCpu(name: String, threads: Int = 8): OrtSession {
|
||||
val t = System.currentTimeMillis()
|
||||
val opts = OrtSession.SessionOptions()
|
||||
opts.setIntraOpNumThreads(6)
|
||||
opts.setIntraOpNumThreads(threads)
|
||||
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
|
||||
val session = ortEnv!!.createSession("$path/$name/model.onnx", opts)
|
||||
nlog("$name (CPU 6T): ${System.currentTimeMillis() - t}ms")
|
||||
nlog("$name (CPU ${threads}T): ${System.currentTimeMillis() - t}ms")
|
||||
return session
|
||||
}
|
||||
|
||||
|
|
@ -213,9 +213,9 @@ class Qwen3TtsEngine(
|
|||
val v2Path = "$path/v2_pre_conv"
|
||||
if (File("$v2Path/model.onnx").exists()) {
|
||||
nlog("Loading V2 speech decoder (CPU ONNX)...")
|
||||
preConv = loadCpu("v2_pre_conv")
|
||||
preprocessor = loadCpu("v2_pre_transformer")
|
||||
convDecoder = loadCpu("v2_decoder_conv")
|
||||
preConv = loadCpu("v2_pre_conv", 4)
|
||||
preprocessor = loadCpu("v2_pre_transformer", 4)
|
||||
convDecoder = loadCpu("v2_decoder_conv", 8) // BigVGAN benefits from more threads
|
||||
decoderOnCpu = true
|
||||
nlog("Speech decoder V2 on CPU")
|
||||
} else {
|
||||
|
|
@ -2058,6 +2058,7 @@ class Qwen3TtsEngine(
|
|||
/** V2 decoder: quantized[1,512,60] → pre_conv → transpose → pre_transformer → decoder → audio */
|
||||
private fun runSpeechDecoderV2(quantized: FloatArray): ShortArray {
|
||||
val env = ortEnv!!
|
||||
val td0 = System.currentTimeMillis()
|
||||
// pre_conv: [1,512,60] → [1,1024,60]
|
||||
val qTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(quantized),
|
||||
longArrayOf(1, HIDDEN_DIM.toLong(), SEQ_LEN.toLong()))
|
||||
|
|
@ -2065,7 +2066,8 @@ class Qwen3TtsEngine(
|
|||
val pcOutRaw = (pcResult[0] as OnnxTensor).floatBuffer
|
||||
val pcData = FloatArray(1024 * SEQ_LEN)
|
||||
pcOutRaw.get(pcData)
|
||||
nlog("V2 pre_conv done")
|
||||
val td1 = System.currentTimeMillis()
|
||||
nlog("V2 pre_conv: ${td1-td0}ms")
|
||||
qTensor.close()
|
||||
|
||||
// Transpose [1,1024,60] → [1,60,1024]
|
||||
|
|
@ -2082,13 +2084,14 @@ class Qwen3TtsEngine(
|
|||
// pre_transformer: [1,60,1024] → [1,60,1024]
|
||||
val ptResult = preprocessor!!.run(mapOf(preprocessor!!.inputNames.first() to ptInput))
|
||||
val ptOut = ptResult[0] as OnnxTensor
|
||||
nlog("V2 pre_transformer done")
|
||||
val td2 = System.currentTimeMillis()
|
||||
nlog("V2 pre_transformer: ${td2-td1}ms")
|
||||
ptInput.close()
|
||||
|
||||
// decoder: [1,60,1024] → [1,1,samples]
|
||||
val cdResult = convDecoder!!.run(mapOf(convDecoder!!.inputNames.first() to ptOut))
|
||||
val cdOut = cdResult[0] as OnnxTensor
|
||||
nlog("V2 decoder done: ${cdOut.info.shape.contentToString()}")
|
||||
nlog("V2 BigVGAN: ${System.currentTimeMillis()-td2}ms")
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val audioFloat = (cdOut.value as Array<Array<FloatArray>>)[0][0]
|
||||
ptResult.close()
|
||||
|
|
|
|||
Loading…
Reference in New Issue