/** * Native TTS pipeline: talker + CP autoregressive loop in C++. * One JNI call runs the entire generation → returns all codebook codes. */ /* NOTE(review): this copy of the file is collapsed onto six physical lines and appears to have lost every span that looked like an angle-bracket tag: include targets, template arguments and several loop bodies are missing, and line comments now swallow the rest of their physical line. The code text below is untouched; only block comments were added. Restore the file from version control before building. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define TAG "TtsPipeline" #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::HierarchicalAllocator; using executorch::runtime::MemoryAllocator; using executorch::runtime::MemoryManager; using executorch::runtime::Method; using executorch::runtime::Program; using executorch::runtime::Span; /* Model dimensions. The T_ constants size the talker KV cache and rotary rows used by talkerStep, the C_ constants do the same for cpForward, and CODEC_EOS is the cb0 value that ends generation in nativeRun. */ static const int DIM=1024, VOCAB=3072, CB_SIZE=2048, NUM_CB=16; static const int T_L=28, T_KV=8, T_HD=128, T_KV_LEN=100; static const int C_L=5, C_KV=8, C_HD=128, C_KV_LEN=16; static const int CODEC_EOS=2150; /* dot_neon: takes two float arrays and a length and starts from four zeroed NEON accumulators. NOTE(review): the loop body and the head of the next helper are truncated in this copy. The surviving tail tracks a best value and index (bv, best) and returns best; presumably this is the argmax_head helper that cpForward calls - confirm against the intact file. */ static inline float dot_neon(const float* a, const float* b, int n) { float32x4_t s0=vdupq_n_f32(0),s1=vdupq_n_f32(0),s2=vdupq_n_f32(0),s3=vdupq_n_f32(0); int i=0; for(;i+15bv){bv=d;best=j;}} return best; } // Top-k sampling with temperature (Java-compatible PRNG) static uint64_t g_rng_state = 0x12345678ABCDEF01ULL; /* next_rand: advances the global 64-bit state with one multiply-add step, takes bits 33..63 and divides by 0x7FFFFFFF, giving a float in the closed range 0 to 1. The state is a plain global with a fixed seed and no synchronization visible in this file. NOTE(review): the multiplier and increment are the constants commonly attributed to the Knuth MMIX generator, not the 48-bit java.util.Random generator; confirm the Java-compatible claim against the Java side. */ static float next_rand() { // Java-style LCG for reproducibility g_rng_state = g_rng_state * 6364136223846793005ULL + 1442695040888963407ULL; return (float)((g_rng_state >> 33) & 0x7FFFFFFF) / (float)0x7FFFFFFF; } /* sample_topk: keeps the k largest logits in descending order by inserting into a k-slot array, converts them to unnormalized softmax weights exp((v - max) / temp), draws r = next_rand() * sum and returns the index of the first entry whose running weight reaches r; the top entry is the fallback. Slots never filled keep index 0 and value -FLT_MAX. NOTE(review): topk[k-1] is indexed and temp is used as a divisor without checks, so callers must pass k of at least 1 and a nonzero temp. */ static int sample_topk(const float* logits, int vocab, float temp, int k) { struct IV { int i; float v; }; std::vector topk(k, {0, -FLT_MAX}); for (int i = 0; i < vocab; i++) { if (logits[i] > topk[k-1].v) { topk[k-1] = {i, logits[i]}; for (int j = k-2; j >= 0; j--) { if (topk[j+1].v > topk[j].v) std::swap(topk[j], topk[j+1]); else break; } } } float maxv = topk[0].v; float sum = 0; for (auto& t : topk) { t.v = expf((t.v - maxv) / temp); sum += t.v; } float r = next_rand() * sum; float acc = 0; for (auto& t : topk) { acc += t.v; if (acc >= r) return t.i; } 
return topk[0].i; } /* PipelineState: everything loaded for the two models (talker and cp): data loaders, programs, memory managers, the raw Method pointers returned by loadModel, the planned-memory buffers and a loaded flag. talker and cp are raw owning pointers; the struct has no destructor, and nativeDestroy deletes them by hand. */ struct PipelineState { std::unique_ptr tLoader, cLoader; std::unique_ptr tProg, cProg; std::unique_ptr tMM, cMM; Method* talker = nullptr; Method* cp = nullptr; std::vector> tBufs, cBufs; bool loaded = false; }; /* Single global pipeline instance, created by nativeInit and released by nativeDestroy. */ static PipelineState* gState = nullptr; /* Static arenas passed to loadModel as the method pool and temp pool: 8 MiB and 2 MiB for the talker, 4 MiB and 1 MiB for cp. */ static uint8_t tMethodPool[8*1024*1024], tTempPool[2*1024*1024]; static uint8_t cMethodPool[4*1024*1024], cTempPool[1*1024*1024]; /* loadModel: opens path with FileDataLoader, loads the Program, reads the method metadata for forward, allocates one heap buffer per memory-planned buffer (kept alive in bufs), builds a MemoryManager over the supplied pools (mp/mps for the method allocator, tp/tps for the temp allocator) and loads the forward method. Returns a heap-allocated Method that the caller must delete, or nullptr on any failure. NOTE(review): ma, ta and ha are created with new and no matching delete is visible in this file. NOTE(review): ha is built from spans.data(), and spans is a local vector destroyed on return - confirm that HierarchicalAllocator does not use that pointer after load_method. NOTE(review): memory_planned_buffer_size(i).get() is read without an ok() check. */ static Method* loadModel(const char* path, std::unique_ptr& loader, std::unique_ptr& program, std::unique_ptr& mm, std::vector>& bufs, uint8_t* mp, size_t mps, uint8_t* tp, size_t tps) { auto ld = executorch::extension::FileDataLoader::from(path); if(!ld.ok()) return nullptr; loader=std::make_unique(std::move(ld.get())); auto prog=Program::load(&*loader); if(!prog.ok()) return nullptr; program=std::make_unique(std::move(prog.get())); auto meta=program->method_meta("forward"); if(!meta.ok()) return nullptr; std::vector> spans; for(size_t i=0;inum_memory_planned_buffers();i++){ size_t sz=(size_t)meta->memory_planned_buffer_size(i).get(); bufs.push_back(std::make_unique(sz)); spans.push_back({bufs.back().get(),sz}); } auto*ma=new MemoryAllocator(mps,mp); auto*ta=new MemoryAllocator(tps,tp); auto*ha=new HierarchicalAllocator({spans.data(),spans.size()}); mm=std::unique_ptr(new MemoryManager(ma,ha,ta)); auto method=program->load_method("forward",mm.get()); if(!method.ok()) return nullptr; return new Method(std::move(method.get())); } extern "C" { /* nativeInit(jTP, jCP): initializes the runtime, returns JNI_TRUE at once when models are already loaded, otherwise loads the talker model from jTP and the cp model from jCP into a fresh PipelineState and logs the load time. If either load fails the state is deleted and JNI_FALSE is returned. NOTE(review): the GetStringUTFChars results are not null-checked. NOTE(review): on the failure path delete gState does not delete a Method that did load, because the struct only holds raw pointers. */ JNIEXPORT jboolean JNICALL Java_com_kazeia_tts_TtsPipeline_nativeInit(JNIEnv*env,jclass,jstring jTP,jstring jCP){ executorch::runtime::runtime_init(); if(gState&&gState->loaded) return JNI_TRUE; const char*tp=env->GetStringUTFChars(jTP,nullptr); const char*cp=env->GetStringUTFChars(jCP,nullptr); gState=new PipelineState(); LOGI("Loading talker+CP..."); auto t0=std::chrono::high_resolution_clock::now(); 
gState->talker=loadModel(tp,gState->tLoader,gState->tProg,gState->tMM,gState->tBufs,tMethodPool,sizeof(tMethodPool),tTempPool,sizeof(tTempPool)); gState->cp=loadModel(cp,gState->cLoader,gState->cProg,gState->cMM,gState->cBufs,cMethodPool,sizeof(cMethodPool),cTempPool,sizeof(cTempPool)); env->ReleaseStringUTFChars(jTP,tp); env->ReleaseStringUTFChars(jCP,cp); if(!gState->talker||!gState->cp){LOGI("Load failed");delete gState;gState=nullptr;return JNI_FALSE;} gState->loaded=true; auto t1=std::chrono::high_resolution_clock::now(); LOGI("Models loaded: %.0fms (no warmup — first forward will be slower)",std::chrono::duration(t1-t0).count()); return JNI_TRUE; } /* nativeDestroy: deletes both Method objects, then the state, and resets gState to null; does nothing when no state exists. */ JNIEXPORT void JNICALL Java_com_kazeia_tts_TtsPipeline_nativeDestroy(JNIEnv*,jclass){ if(gState){delete gState->talker;delete gState->cp;delete gState;gState=nullptr;} } // Helper: run one talker step /* talkerStep: copies emb (DIM floats) into input 0, mask (T_KV_LEN floats) into input 1, the cos and sin rows for position min(pos, 249) into inputs 2 and 3, and the per-layer caches into the following inputs (V goes to input 5+2i; the K line is truncated here). After execute it copies output 0 into outHidden (DIM floats), output 1 into outLogits (VOCAB floats) and the updated caches back into tK and tV. Byte counts use a literal 4 as the float size. NOTE(review): the results of prepare_input_tensors and execute are not checked. NOTE(review): 249 is a hard-coded last row index for tCos and tSin - confirm the tables hold 250 rows of T_HD floats. */ static void talkerStep(Method&m, const float*emb, float*mask, int pos, const float*tCos, const float*tSin, float*tK, float*tV, float*outHidden, float*outLogits) { int kvElem = T_KV * T_KV_LEN * T_HD; auto prep = executorch::extension::prepare_input_tensors(m); memcpy(m.mutable_input(0).toTensor().mutable_data_ptr(), emb, DIM*4); memcpy(m.mutable_input(1).toTensor().mutable_data_ptr(), mask, T_KV_LEN*4); int pi = std::min(pos, 249); memcpy(m.mutable_input(2).toTensor().mutable_data_ptr(), tCos+pi*T_HD, T_HD*4); memcpy(m.mutable_input(3).toTensor().mutable_data_ptr(), tSin+pi*T_HD, T_HD*4); for(int i=0;i(), tK+i*kvElem, kvElem*4); memcpy(m.mutable_input(5+i*2).toTensor().mutable_data_ptr(), tV+i*kvElem, kvElem*4); } m.execute(); memcpy(outHidden, m.get_output(0).toTensor().const_data_ptr(), DIM*4); memcpy(outLogits, m.get_output(1).toTensor().const_data_ptr(), VOCAB*4); for(int i=0;i(), kvElem*4); memcpy(tV+i*kvElem, m.get_output(3+i*2).toTensor().const_data_ptr(), kvElem*4); } } // Helper: run full CP (17 steps) → 15 codes /* cpForward: runs the cp method for steps 0..16 over a zero-initialized local KV cache. Step 0 feeds hidden, step 1 feeds the codecEmb row for cb0 clamped to [0, VOCAB-1], and later steps feed the cpEmbs row of table step-2 for the previous code clamped to [0, CB_SIZE-1]. For steps 1..15 the code is argmax_head of output 0 against the cpHeads slice for step-1 and is written to codes[step-1], so codes must hold at least 15 ints. NOTE(review): step 16 runs a forward pass whose output is never turned into a code - confirm it is needed. NOTE(review): the result of execute is not checked. */ static void cpForward(Method&m, const float*hidden, int cb0, const float*codecEmb, const 
float*cpEmbs, const float*cpHeads, const float*cCos, const float*cSin, int*codes) { int kvElem = C_KV * C_KV_LEN * C_HD; std::vector kv(C_L*2*kvElem, 0.0f); for(int step=0;step<17;step++){ const float*emb; if(step==0) emb=hidden; else if(step==1) emb=codecEmb + std::min(std::max(cb0,0),VOCAB-1)*DIM; else emb=cpEmbs + ((step-2)*CB_SIZE + std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM; auto prep=executorch::extension::prepare_input_tensors(m); memcpy(m.mutable_input(0).toTensor().mutable_data_ptr(), emb, DIM*4); // Mask float*mp=m.mutable_input(1).toTensor().mutable_data_ptr(); for(int p=0;p=C_KV_LEN-1-step)?0.0f:-1e9f; memcpy(m.mutable_input(2).toTensor().mutable_data_ptr(), cCos+step*C_HD, C_HD*4); memcpy(m.mutable_input(3).toTensor().mutable_data_ptr(), cSin+step*C_HD, C_HD*4); for(int i=0;i(), kv.data()+(i*2)*kvElem, kvElem*4); memcpy(m.mutable_input(5+i*2).toTensor().mutable_data_ptr(), kv.data()+(i*2+1)*kvElem, kvElem*4); } m.execute(); const float*h=m.get_output(0).toTensor().const_data_ptr(); if(step>=1&&step-1<15){ codes[step-1]=argmax_head(h, cpHeads+(step-1)*CB_SIZE*DIM, CB_SIZE, DIM); } for(int i=0;i(), kvElem*4); memcpy(kv.data()+(i*2+1)*kvElem, m.get_output(2+i*2).toTensor().const_data_ptr(), kvElem*4); } } } /* nativeRun: returns null when nativeInit has not succeeded. Otherwise it copies every input array out of the JVM into local vectors, runs the talker over the prefill embeddings, returns an empty int array if the first cb0 is negative or CODEC_EOS, and then alternates cpForward and talkerStep, sampling the next cb0 with a repetition penalty and top-k sampling until CODEC_EOS is drawn (other exits are truncated in this copy; presumably maxTokens and the degeneration check). The result is the flat allCodes vector as a new int array. NOTE(review): nTrailing*DIM floats are read from jTrailing, DIM floats from jEosEmbed and jPadEmbed, and tcSize and ccSize floats from jTSin and jCSin without comparing against the actual lengths of those arrays; the relation between nPrefill and the jPrefill length is not visible here - confirm it is validated. */ JNIEXPORT jintArray JNICALL Java_com_kazeia_tts_TtsPipeline_nativeRun( JNIEnv*env,jclass, jfloatArray jPrefill,jint nPrefill, jfloatArray jTrailing,jint nTrailing, jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads, jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin, jfloatArray jEosEmbed, jfloatArray jPadEmbed, jint maxTokens) { if(!gState||!gState->loaded) return nullptr; auto T0=std::chrono::high_resolution_clock::now(); // Copy all data from JNI (then release immediately) int prefillSize = env->GetArrayLength(jPrefill); std::vector prefill(prefillSize); env->GetFloatArrayRegion(jPrefill, 0, prefillSize, prefill.data()); std::vector trailingData; 
if(nTrailing>0){trailingData.resize(nTrailing*DIM);env->GetFloatArrayRegion(jTrailing,0,nTrailing*DIM,trailingData.data());} int codecSize=env->GetArrayLength(jCodecEmb); std::vector codecEmb(codecSize); env->GetFloatArrayRegion(jCodecEmb,0,codecSize,codecEmb.data()); int cpEmbsSize=env->GetArrayLength(jCpEmbs); std::vector cpEmbs(cpEmbsSize); env->GetFloatArrayRegion(jCpEmbs,0,cpEmbsSize,cpEmbs.data()); int headsSize=env->GetArrayLength(jCpHeads); std::vector cpHeads(headsSize); env->GetFloatArrayRegion(jCpHeads,0,headsSize,cpHeads.data()); int tcSize=env->GetArrayLength(jTCos); std::vector tCos(tcSize),tSin(tcSize); env->GetFloatArrayRegion(jTCos,0,tcSize,tCos.data()); env->GetFloatArrayRegion(jTSin,0,tcSize,tSin.data()); int ccSize=env->GetArrayLength(jCCos); std::vector cCos(ccSize),cSin(ccSize); env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data()); env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data()); std::vector eosEmbed(DIM,0), padEmbed(DIM,0); if(jEosEmbed) env->GetFloatArrayRegion(jEosEmbed,0,DIM,eosEmbed.data()); if(jPadEmbed) env->GetFloatArrayRegion(jPadEmbed,0,DIM,padEmbed.data()); // Pipeline state int tkvElem=T_KV*T_KV_LEN*T_HD; std::vector tK(T_L*tkvElem,0), tV(T_L*tkvElem,0); float mask[T_KV_LEN]; for(int i=0;i allCodes; // flat: numTokens × 16 std::vector cb0History; int pos=0, currentCb0=-1; // ===== PREFILL ===== auto tP0=std::chrono::high_resolution_clock::now(); for(int step=0;step=0) mask[mi]=0.0f; talkerStep(*gState->talker, prefill.data()+step*DIM, mask, pos, tCos.data(),tSin.data(), tK.data(), tV.data(), hidden, logits); pos++; if(step==nPrefill-1){ for(int j=CB_SIZE;j(tP1-tP0).count(), nPrefill, currentCb0, hidden[0],hidden[1],hidden[2],hidden[3]); if(currentCb0<0||currentCb0==CODEC_EOS){return env->NewIntArray(0);} // ===== GENERATION ===== float totalTalkerMs=0, totalCpMs=0; int trailingIdx=0; // Pad embedding (zeros + pad token is not available here, use zeros) /* NOTE(review): padEmb is an all-zero local while padEmbed, copied from jPadEmbed above, is also in scope - confirm which of the two the generation loop should feed. */ float padEmb[DIM]={}; // In practice should be tts_pad_embed, passed as 
param for(int gen=0;gencp, hidden, currentCb0, codecEmb.data(), cpEmbs.data(), cpHeads.data(), cCos.data(), cSin.data(), cpCodes); auto tc1=std::chrono::high_resolution_clock::now(); totalCpMs+=std::chrono::duration(tc1-tc0).count(); /* The 15 cp codes fill slots 1..15 of this token; slot 0 is presumably cb0 (its assignment is truncated in this copy). */ for(int i=0;i<15;i++) codes[i+1]=cpCodes[i]; for(int i=0;i=0) mask[mi]=0.0f; auto tt0=std::chrono::high_resolution_clock::now(); talkerStep(*gState->talker, nextEmb, mask, pos, tCos.data(),tSin.data(), tK.data(), tV.data(), hidden, logits); auto tt1=std::chrono::high_resolution_clock::now(); totalTalkerMs+=std::chrono::duration(tt1-tt0).count(); pos++; // Next cb0: suppress non-codec, repetition penalty, top-k sampling /* Every token in seen (built from cb0History) has its logit divided by 1.05 when positive and multiplied by 1.05 otherwise; the next cb0 is then drawn by sample_topk with temperature 0.9 and k of 50, and CODEC_EOS ends the loop. */ for(int j=CB_SIZE;j seen(cb0History.begin(),cb0History.end()); for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f; int nextCb0=sample_topk(logits,VOCAB,0.9f,50); if(nextCb0==CODEC_EOS){LOGI("EOS at step %d",gen+2);break;} // Degeneration check int histSz=(int)cb0History.size(); if(histSz>=9){ bool degen=true; for(int i=histSz-9;i(T1-T0).count()); jintArray result=env->NewIntArray((int)allCodes.size()); env->SetIntArrayRegion(result,0,(int)allCodes.size(),allCodes.data()); return result; } } // extern "C"