Pre-allocate Tensor/EValue in Java pipeline: 16s → 8.9s (RTF 1.9)
Reuse float arrays and Tensor/EValue objects across talker steps instead of creating new ones each iteration. Eliminates ~7s of GC overhead from thousands of JNI object allocations. Same validated audio quality as before, no C++ pipeline needed. Talker 35ms/step, CP 58ms/step, total 8.9s for 4.64s audio. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
8e536094df
commit
0f027c5fde
|
|
@ -1074,44 +1074,51 @@ class Qwen3TtsEngine(
|
|||
|
||||
nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
|
||||
|
||||
// Pre-allocate reusable buffers + Tensor/EValue (avoids GC overhead per step)
|
||||
val tEmbBuf = FloatArray(TALKER_DIM)
|
||||
val tMaskBuf = FloatArray(TALKER_PTE_KV_LEN)
|
||||
val tCosBuf = FloatArray(TALKER_HEAD_DIM)
|
||||
val tSinBuf = FloatArray(TALKER_HEAD_DIM)
|
||||
val tInputs = arrayOfNulls<org.pytorch.executorch.EValue>(4 + TALKER_LAYERS * 2)
|
||||
tInputs[0] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tEmbBuf, longArrayOf(1,1,TALKER_DIM.toLong())))
|
||||
tInputs[1] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tMaskBuf, longArrayOf(1,1,1,TALKER_PTE_KV_LEN.toLong())))
|
||||
tInputs[2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tCosBuf, longArrayOf(1,1,TALKER_HEAD_DIM.toLong())))
|
||||
tInputs[3] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tSinBuf, longArrayOf(1,1,TALKER_HEAD_DIM.toLong())))
|
||||
for (i in 0 until TALKER_LAYERS) {
|
||||
tInputs[4+i*2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1,TALKER_HEADS.toLong(),TALKER_PTE_KV_LEN.toLong(),TALKER_HEAD_DIM.toLong())))
|
||||
tInputs[5+i*2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1,TALKER_HEADS.toLong(),TALKER_PTE_KV_LEN.toLong(),TALKER_HEAD_DIM.toLong())))
|
||||
}
|
||||
|
||||
// Talker step helper: update buffers → forward → extract outputs
|
||||
fun talkerStep(emb: FloatArray): Pair<FloatArray, FloatArray> {
|
||||
System.arraycopy(emb, 0, tEmbBuf, 0, TALKER_DIM)
|
||||
System.arraycopy(maskData, 0, tMaskBuf, 0, TALKER_PTE_KV_LEN)
|
||||
val pi = minOf(pos, tCos.size / TALKER_HEAD_DIM - 1)
|
||||
System.arraycopy(tCos, pi * TALKER_HEAD_DIM, tCosBuf, 0, TALKER_HEAD_DIM)
|
||||
System.arraycopy(tSin, pi * TALKER_HEAD_DIM, tSinBuf, 0, TALKER_HEAD_DIM)
|
||||
// KV buffers already point to tK[i]/tV[i] — data updated in-place below
|
||||
|
||||
val out = talkerMod.forward(*tInputs.requireNoNulls())
|
||||
val hidden = out[0].toTensor().dataAsFloatArray
|
||||
val logits = out[1].toTensor().dataAsFloatArray
|
||||
for (i in 0 until TALKER_LAYERS) {
|
||||
val newK = out[2+i*2].toTensor().dataAsFloatArray
|
||||
val newV = out[3+i*2].toTensor().dataAsFloatArray
|
||||
System.arraycopy(newK, 0, tK[i], 0, tkvSize)
|
||||
System.arraycopy(newV, 0, tV[i], 0, tkvSize)
|
||||
}
|
||||
pos++
|
||||
return Pair(hidden, logits)
|
||||
}
|
||||
|
||||
// ===== PREFILL =====
|
||||
val tPrefill = System.currentTimeMillis()
|
||||
for (step in prefill.indices) {
|
||||
// Unmask position
|
||||
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
||||
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
||||
|
||||
val cosSlice = FloatArray(TALKER_HEAD_DIM)
|
||||
System.arraycopy(tCos, pos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
|
||||
val sinSlice = FloatArray(TALKER_HEAD_DIM)
|
||||
System.arraycopy(tSin, pos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
|
||||
|
||||
val inputs = mutableListOf(
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(prefill[step], longArrayOf(1, 1, TALKER_DIM.toLong()))),
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
|
||||
)
|
||||
for (i in 0 until TALKER_LAYERS) {
|
||||
inputs.add(org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||
inputs.add(org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||
}
|
||||
|
||||
val out = talkerMod.forward(*inputs.toTypedArray())
|
||||
pastHidden = out[0].toTensor().dataAsFloatArray
|
||||
val logits = out[1].toTensor().dataAsFloatArray
|
||||
for (i in 0 until TALKER_LAYERS) {
|
||||
tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
|
||||
tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
|
||||
}
|
||||
pos++
|
||||
|
||||
val (h, logits) = talkerStep(prefill[step])
|
||||
pastHidden = h
|
||||
if (step == prefill.size - 1) {
|
||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||
currentCb0 = sampleTopK(logits, 0.9f, 50)
|
||||
|
|
@ -1121,12 +1128,11 @@ class Qwen3TtsEngine(
|
|||
|
||||
if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()
|
||||
|
||||
// ===== INTERLEAVED GENERATION =====
|
||||
// ===== INTERLEAVED GENERATION (reusing pre-allocated tensors) =====
|
||||
var totalTalkerMs = 0L; var totalCpMs = 0L
|
||||
for (genStep in 0 until maxGenTokens) {
|
||||
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
||||
|
||||
// 1. CP: predict CB1-15
|
||||
val tCp = System.currentTimeMillis()
|
||||
val cpCodes = runCpPte(pastHidden!!, currentCb0)
|
||||
totalCpMs += System.currentTimeMillis() - tCp
|
||||
|
|
@ -1135,7 +1141,7 @@ class Qwen3TtsEngine(
|
|||
|
||||
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||
|
||||
// 2. Build next talker input
|
||||
// Build next talker input
|
||||
val codecSum = FloatArray(TALKER_DIM)
|
||||
addEmb(codecSum, codecEmb(codes[0]))
|
||||
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
||||
|
|
@ -1146,45 +1152,14 @@ class Qwen3TtsEngine(
|
|||
else -> sumEmb(codecSum, padE)
|
||||
}
|
||||
|
||||
// 3. Talker step
|
||||
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
||||
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
||||
|
||||
val cosSlice = FloatArray(TALKER_HEAD_DIM)
|
||||
System.arraycopy(tCos, minOf(pos, tCos.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
|
||||
val sinSlice = FloatArray(TALKER_HEAD_DIM)
|
||||
System.arraycopy(tSin, minOf(pos, tSin.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
|
||||
|
||||
val inputs = mutableListOf(
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(nextEmbed, longArrayOf(1, 1, TALKER_DIM.toLong()))),
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
|
||||
org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
|
||||
)
|
||||
for (i in 0 until TALKER_LAYERS) {
|
||||
inputs.add(org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||
inputs.add(org.pytorch.executorch.EValue.from(
|
||||
org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||
}
|
||||
|
||||
val tTalker = System.currentTimeMillis()
|
||||
val out = talkerMod.forward(*inputs.toTypedArray())
|
||||
val (h, logits) = talkerStep(nextEmbed)
|
||||
totalTalkerMs += System.currentTimeMillis() - tTalker
|
||||
pastHidden = h
|
||||
|
||||
pastHidden = out[0].toTensor().dataAsFloatArray
|
||||
val logits = out[1].toTensor().dataAsFloatArray
|
||||
for (i in 0 until TALKER_LAYERS) {
|
||||
tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
|
||||
tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
|
||||
}
|
||||
pos++
|
||||
|
||||
// 4. Next CB0
|
||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||||
|
|
|
|||
Loading…
Reference in New Issue