diff --git a/executorch-custom/tts_pipeline_jni.cpp b/executorch-custom/tts_pipeline_jni.cpp index 4c4f0c6..91d5970 100644 --- a/executorch-custom/tts_pipeline_jni.cpp +++ b/executorch-custom/tts_pipeline_jni.cpp @@ -222,6 +222,7 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun( jfloatArray jTrailing,jint nTrailing, jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads, jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin, + jfloatArray jEosEmbed, jfloatArray jPadEmbed, jint maxTokens) { if(!gState||!gState->loaded) return nullptr; @@ -257,6 +258,10 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun( env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data()); env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data()); + std::vector eosEmbed(DIM,0), padEmbed(DIM,0); + if(jEosEmbed) env->GetFloatArrayRegion(jEosEmbed,0,DIM,eosEmbed.data()); + if(jPadEmbed) env->GetFloatArrayRegion(jPadEmbed,0,DIM,padEmbed.data()); + // Pipeline state int tkvElem=T_KV*T_KV_LEN*T_HD; std::vector tK(T_L*tkvElem,0), tV(T_L*tkvElem,0); @@ -317,13 +322,17 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun( const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM; for(int k=0;k