diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
index 0d4a163..5ed1656 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
@@ -1074,44 +1074,70 @@ class Qwen3TtsEngine(
 
         nlog("PTE pipeline: prefill=${prefill.size}, trailing=${trailingEmbeds.size}")
 
+        // Pre-allocate reusable buffers + Tensor/EValue (avoids GC overhead per step).
+        // Tensor.fromBlob(FloatArray, shape) copies the array into a fresh buffer, so a tensor
+        // built from an array never sees later writes to that array. A direct FloatBuffer is
+        // wrapped without copying, so refilling it before each forward() keeps the reused tensors current.
+        val tEmbBuf = org.pytorch.executorch.Tensor.allocateFloatBuffer(TALKER_DIM)
+        val tMaskBuf = org.pytorch.executorch.Tensor.allocateFloatBuffer(TALKER_PTE_KV_LEN)
+        val tCosBuf = org.pytorch.executorch.Tensor.allocateFloatBuffer(TALKER_HEAD_DIM)
+        val tSinBuf = org.pytorch.executorch.Tensor.allocateFloatBuffer(TALKER_HEAD_DIM)
+        val tKBuf = Array(TALKER_LAYERS) { org.pytorch.executorch.Tensor.allocateFloatBuffer(tK[it].size) }
+        val tVBuf = Array(TALKER_LAYERS) { org.pytorch.executorch.Tensor.allocateFloatBuffer(tV[it].size) }
+        val kvShape = longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong())
+
+        // Overwrite a reusable input buffer from the start with src[off until off + len].
+        fun refill(buf: java.nio.FloatBuffer, src: FloatArray, off: Int, len: Int) {
+            buf.clear()
+            buf.put(src, off, len)
+            buf.rewind()
+        }
+
+        val tInputs = arrayOfNulls<org.pytorch.executorch.EValue>(4 + TALKER_LAYERS * 2)
+        tInputs[0] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tEmbBuf, longArrayOf(1,1,TALKER_DIM.toLong())))
+        tInputs[1] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tMaskBuf, longArrayOf(1,1,1,TALKER_PTE_KV_LEN.toLong())))
+        tInputs[2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tCosBuf, longArrayOf(1,1,TALKER_HEAD_DIM.toLong())))
+        tInputs[3] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tSinBuf, longArrayOf(1,1,TALKER_HEAD_DIM.toLong())))
+        for (i in 0 until TALKER_LAYERS) {
+            refill(tKBuf[i], tK[i], 0, tK[i].size)
+            refill(tVBuf[i], tV[i], 0, tV[i].size)
+            tInputs[4+i*2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tKBuf[i], kvShape))
+            tInputs[5+i*2] = org.pytorch.executorch.EValue.from(org.pytorch.executorch.Tensor.fromBlob(tVBuf[i], kvShape))
+        }
+        val talkerInputs = tInputs.requireNoNulls()
+
+        // Talker step helper: update buffers → forward → extract outputs
+        fun talkerStep(emb: FloatArray): Pair<FloatArray, FloatArray> {
+            refill(tEmbBuf, emb, 0, TALKER_DIM)
+            refill(tMaskBuf, maskData, 0, TALKER_PTE_KV_LEN)
+            val pi = minOf(pos, tCos.size / TALKER_HEAD_DIM - 1)
+            refill(tCosBuf, tCos, pi * TALKER_HEAD_DIM, TALKER_HEAD_DIM)
+            refill(tSinBuf, tSin, pi * TALKER_HEAD_DIM, TALKER_HEAD_DIM)
+
+            val out = talkerMod.forward(*talkerInputs)
+            val hidden = out[0].toTensor().dataAsFloatArray
+            val logits = out[1].toTensor().dataAsFloatArray
+            for (i in 0 until TALKER_LAYERS) {
+                val newK = out[2+i*2].toTensor().dataAsFloatArray
+                val newV = out[3+i*2].toTensor().dataAsFloatArray
+                System.arraycopy(newK, 0, tK[i], 0, tkvSize)
+                System.arraycopy(newV, 0, tV[i], 0, tkvSize)
+                // Mirror the updated cache into the buffers the reused KV tensors wrap.
+                refill(tKBuf[i], tK[i], 0, tK[i].size)
+                refill(tVBuf[i], tV[i], 0, tV[i].size)
+            }
+            pos++
+            return Pair(hidden, logits)
+        }
+
         // ===== PREFILL =====
         val tPrefill = System.currentTimeMillis()
         for (step in prefill.indices) {
-            // Unmask position
             val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
             if (maskIdx >= 0) maskData[maskIdx] = 0f
 
-            val cosSlice = FloatArray(TALKER_HEAD_DIM)
-            System.arraycopy(tCos, pos * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
-            val sinSlice = FloatArray(TALKER_HEAD_DIM)
-            System.arraycopy(tSin, pos * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
-
-            val inputs = mutableListOf(
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(prefill[step], longArrayOf(1, 1, TALKER_DIM.toLong()))),
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
-            )
-            for (i in 0 until TALKER_LAYERS) {
-                inputs.add(org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
-                inputs.add(org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
-            }
-
-            val out = talkerMod.forward(*inputs.toTypedArray())
-            pastHidden = out[0].toTensor().dataAsFloatArray
-            val logits = out[1].toTensor().dataAsFloatArray
-            for (i in 0 until TALKER_LAYERS) {
-                tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
-                tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
-            }
-            pos++
-
+            val (h, logits) = talkerStep(prefill[step])
+            pastHidden = h
             if (step == prefill.size - 1) {
                 for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
                 currentCb0 = sampleTopK(logits, 0.9f, 50)
@@ -1121,12 +1147,11 @@ class Qwen3TtsEngine(
 
         if (currentCb0 < 0 || currentCb0 == CODEC_EOS) return emptyArray()
 
-        // ===== INTERLEAVED GENERATION =====
+        // ===== INTERLEAVED GENERATION (reusing pre-allocated tensors) =====
         var totalTalkerMs = 0L; var totalCpMs = 0L
         for (genStep in 0 until maxGenTokens) {
             val codes = IntArray(NUM_CODEBOOKS); codes[0] = currentCb0
 
-            // 1. CP: predict CB1-15
             val tCp = System.currentTimeMillis()
             val cpCodes = runCpPte(pastHidden!!, currentCb0)
             totalCpMs += System.currentTimeMillis() - tCp
@@ -1135,7 +1160,7 @@ class Qwen3TtsEngine(
 
             if (genStep < 3) nlog("Step ${genStep+1}: cb0=$currentCb0 cb1=${codes[1]}")
 
-            // 2. Build next talker input
+            // Build next talker input
             val codecSum = FloatArray(TALKER_DIM)
             addEmb(codecSum, codecEmb(codes[0]))
             for (cb in 1 until NUM_CODEBOOKS) addEmb(codecSum, cpEmb(cb - 1, codes[cb]))
@@ -1146,45 +1171,14 @@ class Qwen3TtsEngine(
                 else -> sumEmb(codecSum, padE)
             }
 
-            // 3. Talker step
             val maskIdx = TALKER_PTE_KV_LEN - 1 - minOf(pos, TALKER_PTE_KV_LEN - 1)
             if (maskIdx >= 0) maskData[maskIdx] = 0f
 
-            val cosSlice = FloatArray(TALKER_HEAD_DIM)
-            System.arraycopy(tCos, minOf(pos, tCos.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, cosSlice, 0, TALKER_HEAD_DIM)
-            val sinSlice = FloatArray(TALKER_HEAD_DIM)
-            System.arraycopy(tSin, minOf(pos, tSin.size / TALKER_HEAD_DIM - 1) * TALKER_HEAD_DIM, sinSlice, 0, TALKER_HEAD_DIM)
-
-            val inputs = mutableListOf(
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(nextEmbed, longArrayOf(1, 1, TALKER_DIM.toLong()))),
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(maskData.clone(), longArrayOf(1, 1, 1, TALKER_PTE_KV_LEN.toLong()))),
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(cosSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong()))),
-                org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(sinSlice, longArrayOf(1, 1, TALKER_HEAD_DIM.toLong())))
-            )
-            for (i in 0 until TALKER_LAYERS) {
-                inputs.add(org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(tK[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
-                inputs.add(org.pytorch.executorch.EValue.from(
-                    org.pytorch.executorch.Tensor.fromBlob(tV[i], longArrayOf(1, TALKER_HEADS.toLong(), TALKER_PTE_KV_LEN.toLong(), TALKER_HEAD_DIM.toLong()))))
-            }
-
             val tTalker = System.currentTimeMillis()
-            val out = talkerMod.forward(*inputs.toTypedArray())
+            val (h, logits) = talkerStep(nextEmbed)
             totalTalkerMs += System.currentTimeMillis() - tTalker
+            pastHidden = h
 
-            pastHidden = out[0].toTensor().dataAsFloatArray
-            val logits = out[1].toTensor().dataAsFloatArray
-            for (i in 0 until TALKER_LAYERS) {
-                tK[i] = out[2 + i * 2].toTensor().dataAsFloatArray
-                tV[i] = out[3 + i * 2].toTensor().dataAsFloatArray
-            }
-            pos++
-
-            // 4. Next CB0
             for (j in CODEBOOK_SIZE until TALKER_VOCAB) { if (j != CODEC_EOS) logits[j] = Float.NEGATIVE_INFINITY }
             val seen = HashSet<Int>(); for (prev in generatedCb0) seen.add(prev)
             for (tok in seen) { logits[tok] = if (logits[tok] > 0) logits[tok] / 1.05f else logits[tok] * 1.05f }