Revert "Pre-allocate Tensor/EValue in Java pipeline: 16s → 8.9s (RTF 1.9)"
This reverts commit 0f027c5fde.
This commit is contained in:
parent
0f027c5fde
commit
439629c9bf
|
|
@ -1074,51 +1074,44 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
|
nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
|
||||||
|
|
||||||
// Pre-allocate reusable buffers + Tensor/EValue (avoids GC overhead per step)
|
|
||||||
val tEmbBuf = FloatArray(TALKER_DIM)
|
|
||||||
val tMaskBuf = FloatArray(TALKER_PTE_KV_LEN)
|
|
||||||
val tCosBuf = FloatArray(TALKER_HEAD_DIM)
|
|
||||||
val tSinBuf = FloatArray(TALKER_HEAD_DIM)
|
|
||||||
val tInputs = arrayOfNulls<org.pytorch.executorch.EValue>(4 + TALKER_LAYERS * 2)
|
|
||||||
tInputs[0] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tEmbBuf, longArrayOf(1,1,TALKER_DIM.toLong())))
|
|
||||||
tInputs[1] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tMaskBuf, longArrayOf(1,1,1,TALKER_PTE_KV_LEN.toLong())))
|
|
||||||
tInputs[2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tCosBuf, longArrayOf(1,1,TALKER_HEAD_DIM.toLong())))
|
|
||||||
tInputs[3] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tSinBuf, longArrayOf(1,1,TALKER_HEAD_DIM.toLong())))
|
|
||||||
for (i in 0 until TALKER_LAYERS) {
|
|
||||||
tInputs[4+i*2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1,TALKER_HEADS.toLong(),TALKER_PTE_KV_LEN.toLong(),TALKER_HEAD_DIM.toLong())))
|
|
||||||
tInputs[5+i*2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1,TALKER_HEADS.toLong(),TALKER_PTE_KV_LEN.toLong(),TALKER_HEAD_DIM.toLong())))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Talker step helper: update buffers → forward → extract outputs
|
|
||||||
fun talkerStep(emb: FloatArray): Pair<FloatArray, FloatArray> {
|
|
||||||
System.arraycopy(emb, 0, tEmbBuf, 0, TALKER_DIM)
|
|
||||||
System.arraycopy(maskData, 0, tMaskBuf, 0, TALKER_PTE_KV_LEN)
|
|
||||||
val pi = minOf(pos, tCos.size / TALKER_HEAD_DIM - 1)
|
|
||||||
System.arraycopy(tCos, pi * TALKER_HEAD_DIM, tCosBuf, 0, TALKER_HEAD_DIM)
|
|
||||||
System.arraycopy(tSin, pi * TALKER_HEAD_DIM, tSinBuf, 0, TALKER_HEAD_DIM)
|
|
||||||
// KV buffers already point to tK[i]/tV[i] — data updated in-place below
|
|
||||||
|
|
||||||
val out = talkerMod.forward(*tInputs.requireNoNulls())
|
|
||||||
val hidden = out[0].toTensor().dataAsFloatArray
|
|
||||||
val logits = out[1].toTensor().dataAsFloatArray
|
|
||||||
for (i in 0 until TALKER_LAYERS) {
|
|
||||||
val newK = out[2+i*2].toTensor().dataAsFloatArray
|
|
||||||
val newV = out[3+i*2].toTensor().dataAsFloatArray
|
|
||||||
System.arraycopy(newK, 0, tK[i], 0, tkvSize)
|
|
||||||
System.arraycopy(newV, 0, tV[i], 0, tkvSize)
|
|
||||||
}
|
|
||||||
pos++
|
|
||||||
return Pair(hidden, logits)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ===== PREFILL =====
|
// ===== PREFILL =====
|
||||||
val tPrefill = System.currentTimeMillis()
|
val tPrefill = System.currentTimeMillis()
|
||||||
for (step in prefill.indices) {
|
for (step in prefill.indices) {
|
||||||
|
// Unmask position
|
||||||
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
||||||
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
||||||
|
|
||||||
val (h, logits) = talkerStep(prefill[step])
|
val cosSlice = FloatArray(TALKER_HEAD_DIM)
|
||||||
pastHidden = h
|
System.arraycopy(tCos, pos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
|
||||||
|
val sinSlice = FloatArray(TALKER_HEAD_DIM)
|
||||||
|
System.arraycopy(tSin, pos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
|
||||||
|
|
||||||
|
val inputs = mutableListOf(
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(prefill[step], longArrayOf(1, 1, TALKER_DIM.toLong()))),
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
|
||||||
|
)
|
||||||
|
for (i in 0 until TALKER_LAYERS) {
|
||||||
|
inputs.add(org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||||
|
inputs.add(org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||||
|
}
|
||||||
|
|
||||||
|
val out = talkerMod.forward(*inputs.toTypedArray())
|
||||||
|
pastHidden = out[0].toTensor().dataAsFloatArray
|
||||||
|
val logits = out[1].toTensor().dataAsFloatArray
|
||||||
|
for (i in 0 until TALKER_LAYERS) {
|
||||||
|
tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
|
||||||
|
tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
|
||||||
if (step == prefill.size - 1) {
|
if (step == prefill.size - 1) {
|
||||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||||
currentCb0 = sampleTopK(logits, 0.9f, 50)
|
currentCb0 = sampleTopK(logits, 0.9f, 50)
|
||||||
|
|
@ -1128,11 +1121,12 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()
|
if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()
|
||||||
|
|
||||||
// ===== INTERLEAVED GENERATION (reusing pre-allocated tensors) =====
|
// ===== INTERLEAVED GENERATION =====
|
||||||
var totalTalkerMs = 0L; var totalCpMs = 0L
|
var totalTalkerMs = 0L; var totalCpMs = 0L
|
||||||
for (genStep in 0 until maxGenTokens) {
|
for (genStep in 0 until maxGenTokens) {
|
||||||
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
|
||||||
|
|
||||||
|
// 1. CP: predict CB1-15
|
||||||
val tCp = System.currentTimeMillis()
|
val tCp = System.currentTimeMillis()
|
||||||
val cpCodes = runCpPte(pastHidden!!, currentCb0)
|
val cpCodes = runCpPte(pastHidden!!, currentCb0)
|
||||||
totalCpMs += System.currentTimeMillis() - tCp
|
totalCpMs += System.currentTimeMillis() - tCp
|
||||||
|
|
@ -1141,7 +1135,7 @@ class Qwen3TtsEngine(
|
||||||
|
|
||||||
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
|
||||||
|
|
||||||
// Build next talker input
|
// 2. Build next talker input
|
||||||
val codecSum = FloatArray(TALKER_DIM)
|
val codecSum = FloatArray(TALKER_DIM)
|
||||||
addEmb(codecSum, codecEmb(codes[0]))
|
addEmb(codecSum, codecEmb(codes[0]))
|
||||||
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
|
||||||
|
|
@ -1152,14 +1146,45 @@ class Qwen3TtsEngine(
|
||||||
else -> sumEmb(codecSum, padE)
|
else -> sumEmb(codecSum, padE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3. Talker step
|
||||||
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
|
||||||
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
if (maskIdx >= 0) maskData[maskIdx] = 0f
|
||||||
|
|
||||||
val tTalker = System.currentTimeMillis()
|
val cosSlice = FloatArray(TALKER_HEAD_DIM)
|
||||||
val (h, logits) = talkerStep(nextEmbed)
|
System.arraycopy(tCos, minOf(pos, tCos.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
|
||||||
totalTalkerMs += System.currentTimeMillis() - tTalker
|
val sinSlice = FloatArray(TALKER_HEAD_DIM)
|
||||||
pastHidden = h
|
System.arraycopy(tSin, minOf(pos, tSin.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
|
||||||
|
|
||||||
|
val inputs = mutableListOf(
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(nextEmbed, longArrayOf(1, 1, TALKER_DIM.toLong()))),
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
|
||||||
|
org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
|
||||||
|
)
|
||||||
|
for (i in 0 until TALKER_LAYERS) {
|
||||||
|
inputs.add(org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||||
|
inputs.add(org.pytorch.executorch.EValue.from(
|
||||||
|
org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
|
||||||
|
}
|
||||||
|
|
||||||
|
val tTalker = System.currentTimeMillis()
|
||||||
|
val out = talkerMod.forward(*inputs.toTypedArray())
|
||||||
|
totalTalkerMs += System.currentTimeMillis() - tTalker
|
||||||
|
|
||||||
|
pastHidden = out[0].toTensor().dataAsFloatArray
|
||||||
|
val logits = out[1].toTensor().dataAsFloatArray
|
||||||
|
for (i in 0 until TALKER_LAYERS) {
|
||||||
|
tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
|
||||||
|
tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
|
||||||
|
// 4. Next CB0
|
||||||
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
|
||||||
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
|
||||||
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue