diff --git a/executorch-custom/jni_layer_tts.cpp b/executorch-custom/jni_layer_tts.cpp index b5fa121..78e6604 100644 --- a/executorch-custom/jni_layer_tts.cpp +++ b/executorch-custom/jni_layer_tts.cpp @@ -712,7 +712,7 @@ ExecuTorchJni::runTtsPipelineImpl( auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad); int tkvElem=T_KV*T_KV_LEN*T_HD; - std::vector tK(T_L*tkvElem,0),tV(T_L*tkvElem,0); + // tK/tV not needed — KV copied directly from output to input each step float mask[T_KV_LEN]; for(int i=0;i allCodes,cb0Hist; @@ -749,27 +749,26 @@ ExecuTorchJni::runTtsPipelineImpl( memcpy(tInMask, mask, T_KV_LEN*4); memcpy(tInCos, tCos.data()+pi*T_HD, T_HD*4); memcpy(tInSin, tSin.data()+pi*T_HD, T_HD*4); - for(int i=0;i 0) { + for(int i=0;iget_output(2+i*2).toTensor().const_data_ptr(), tkvElem*4); + memcpy(tInKV[i*2+1], tMethod->get_output(3+i*2).toTensor().const_data_ptr(), tkvElem*4); + } } + // (pos==0: first call, input KV is already zeros from prepare_input_tensors) + auto status = tMethod->execute(); if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;} memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr(), DIM*4); memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr(), VOCAB*4); - for(int i=0;iget_output(2+i*2).toTensor().const_data_ptr(), tkvElem*4); - memcpy(tV.data()+i*tkvElem, tMethod->get_output(3+i*2).toTensor().const_data_ptr(), tkvElem*4); - } + // KV NOT copied to tK/tV — read from output directly next step pos++; }; - // CP step: 17 autoregressive steps with cached input pointers - // prepare_input_tensors called ONCE, then reuse pointers for all 17×58 steps + // CP step: prepare once, direct output→input KV copy int ckvElem=C_KV*C_KV_LEN*C_HD; - std::vector ckv(C_L*2*ckvElem,0); - {auto prep=executorch::extension::prepare_input_tensors(*cMethod);} // first alloc - // Cache input data pointers (stable after prepare) + {auto prep=executorch::extension::prepare_input_tensors(*cMethod);} float* cpInEmb = cMethod->mutable_input(0).toTensor().mutable_data_ptr(); float* cpInMask = cMethod->mutable_input(1).toTensor().mutable_data_ptr(); float* cpInCos = cMethod->mutable_input(2).toTensor().mutable_data_ptr(); @@ -781,19 +780,26 @@ ExecuTorchJni::runTtsPipelineImpl( } auto cpStep = [&](const float* h, int cb0, int* codes) { - memset(ckv.data(), 0, ckv.size()*4); // reset KV caches + // Reset CP KV to zeros for step 0 + for(int i=0;i=C_KV_LEN-1-step)?0.0f:-1e9f; memcpy(cpInCos, cCos.data()+step*C_HD, C_HD*4); memcpy(cpInSin, cSin.data()+step*C_HD, C_HD*4); - for(int i=0;i0){ + for(int i=0;iget_output(1+i*2).toTensor().const_data_ptr(), ckvElem*4); + memcpy(cpInKV[i*2+1], cMethod->get_output(2+i*2).toTensor().const_data_ptr(), ckvElem*4); + } + } auto status=cMethod->execute(); if(status!=Error::Ok) break; @@ -804,10 +810,6 @@ ExecuTorchJni::runTtsPipelineImpl( for(int j=0;jbv){bv=d;best=j;}} codes[step-1]=best; } - for(int i=0;iget_output(1+i*2).toTensor().const_data_ptr(),ckvElem*4); - memcpy(ckv.data()+(i*2+1)*ckvElem,cMethod->get_output(2+i*2).toTensor().const_data_ptr(),ckvElem*4); - } } };