Direct output→input KV copy: RTF 1.51 → 1.31

Skip intermediate KV buffer: copy output tensors directly into
next step's input pointers. Saves ~1.5GB/run of memcpy for talker
(28L × 2 × 100×8×128 floats × 58 steps) and CP similarly.

Generation: 4007ms → 3713ms, total: 7180ms → 6078ms

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 12:23:45 +02:00
parent 14f7e5b05f
commit 985fd9cff9
1 changed files with 22 additions and 20 deletions

View File

@ -712,7 +712,7 @@ ExecuTorchJni::runTtsPipelineImpl(
auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad);
int tkvElem=T_KV*T_KV_LEN*T_HD;
std::vector<float> tK(T_L*tkvElem,0),tV(T_L*tkvElem,0);
// tK/tV not needed — KV copied directly from output to input each step
float mask[T_KV_LEN]; for(int i=0;i<T_KV_LEN;i++) mask[i]=-1e9f;
float hidden[DIM]={},logits[VOCAB]={};
std::vector<int> allCodes,cb0Hist;
@ -749,27 +749,26 @@ ExecuTorchJni::runTtsPipelineImpl(
memcpy(tInMask, mask, T_KV_LEN*4);
memcpy(tInCos, tCos.data()+pi*T_HD, T_HD*4);
memcpy(tInSin, tSin.data()+pi*T_HD, T_HD*4);
// KV: copy directly from PREVIOUS output to input (skip intermediate buffer)
if(pos > 0) {
for(int i=0;i<T_L;i++){
memcpy(tInKV[i*2], tK.data()+i*tkvElem, tkvElem*4);
memcpy(tInKV[i*2+1], tV.data()+i*tkvElem, tkvElem*4);
memcpy(tInKV[i*2], tMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
memcpy(tInKV[i*2+1], tMethod->get_output(3+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
}
}
// (pos==0: first call, input KV is already zeros from prepare_input_tensors)
auto status = tMethod->execute();
if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;}
memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr<float>(), DIM*4);
memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr<float>(), VOCAB*4);
for(int i=0;i<T_L;i++){
memcpy(tK.data()+i*tkvElem, tMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
memcpy(tV.data()+i*tkvElem, tMethod->get_output(3+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
}
// KV NOT copied to tK/tV — read from output directly next step
pos++;
};
// CP step: 17 autoregressive steps with cached input pointers
// prepare_input_tensors called ONCE, then reuse pointers for all 17×58 steps
// CP step: prepare once, direct output→input KV copy
int ckvElem=C_KV*C_KV_LEN*C_HD;
std::vector<float> ckv(C_L*2*ckvElem,0);
{auto prep=executorch::extension::prepare_input_tensors(*cMethod);} // first alloc
// Cache input data pointers (stable after prepare)
{auto prep=executorch::extension::prepare_input_tensors(*cMethod);}
float* cpInEmb = cMethod->mutable_input(0).toTensor().mutable_data_ptr<float>();
float* cpInMask = cMethod->mutable_input(1).toTensor().mutable_data_ptr<float>();
float* cpInCos = cMethod->mutable_input(2).toTensor().mutable_data_ptr<float>();
@ -781,19 +780,26 @@ ExecuTorchJni::runTtsPipelineImpl(
}
auto cpStep = [&](const float* h, int cb0, int* codes) {
memset(ckv.data(), 0, ckv.size()*4); // reset KV caches
// Reset CP KV to zeros for step 0
for(int i=0;i<C_L*2;i++) memset(cpInKV[i], 0, ckvElem*4);
for(int step=0;step<17;step++){
const float*emb;
if(step==0) emb=h;
else if(step==1) emb=codecEmb.data()+std::min(std::max(cb0,0),VOCAB-1)*DIM;
else emb=cpEmbs.data()+((long)(step-2)*CB_SIZE+std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM;
// Write directly to cached pointers (no prepare_input_tensors!)
memcpy(cpInEmb, emb, DIM*4);
for(int p=0;p<C_KV_LEN;p++) cpInMask[p]=(p>=C_KV_LEN-1-step)?0.0f:-1e9f;
memcpy(cpInCos, cCos.data()+step*C_HD, C_HD*4);
memcpy(cpInSin, cSin.data()+step*C_HD, C_HD*4);
for(int i=0;i<C_L*2;i++) memcpy(cpInKV[i], ckv.data()+i*ckvElem, ckvElem*4);
// KV: copy from previous output directly to input (skip buffer)
if(step>0){
for(int i=0;i<C_L;i++){
memcpy(cpInKV[i*2], cMethod->get_output(1+i*2).toTensor().const_data_ptr<float>(), ckvElem*4);
memcpy(cpInKV[i*2+1], cMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(), ckvElem*4);
}
}
auto status=cMethod->execute();
if(status!=Error::Ok) break;
@ -804,10 +810,6 @@ ExecuTorchJni::runTtsPipelineImpl(
for(int j=0;j<CB_SIZE;j++){float d=tts_dot_neon(ho,W+j*DIM,DIM);if(d>bv){bv=d;best=j;}}
codes[step-1]=best;
}
for(int i=0;i<C_L;i++){
memcpy(ckv.data()+(i*2)*ckvElem,cMethod->get_output(1+i*2).toTensor().const_data_ptr<float>(),ckvElem*4);
memcpy(ckv.data()+(i*2+1)*ckvElem,cMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(),ckvElem*4);
}
}
};