kazeia/executorch-custom/tts_pipeline_jni.cpp

377 lines
16 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Native TTS pipeline: talker + CP autoregressive loop in C++.
* One JNI call runs the entire generation → returns all codebook codes.
*/
#include <jni.h>
#include <arm_neon.h>
#include <android/log.h>
#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <new>
#include <unordered_set>
#include <utility>
#include <vector>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/runtime.h>
// Logging: every message from this file goes to logcat at INFO level under this tag.
#define TAG "TtsPipeline"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)

// ExecuTorch program / method types.
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::Method;
using executorch::runtime::Program;
using executorch::runtime::Span;
// ExecuTorch memory-management types used when loading a method.
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
// ---- Shared model geometry ----
static constexpr int DIM = 1024;      // floats per input embedding / hidden-state row
static constexpr int VOCAB = 3072;    // talker logits per step
static constexpr int CB_SIZE = 2048;  // entries per codebook
static constexpr int NUM_CB = 16;     // codes per frame: cb0 + 15 code-predictor codes
// Talker KV cache: T_L (K, V) input pairs of T_KV*T_KV_LEN*T_HD floats each.
static constexpr int T_L = 28;
static constexpr int T_KV = 8;
static constexpr int T_HD = 128;
static constexpr int T_KV_LEN = 100;
// Code-predictor KV cache: C_L (K, V) input pairs of C_KV*C_KV_LEN*C_HD floats each.
static constexpr int C_L = 5;
static constexpr int C_KV = 8;
static constexpr int C_HD = 128;
static constexpr int C_KV_LEN = 16;
// cb0 token id that ends generation.
static constexpr int CODEC_EOS = 2150;
// Dot product of a[0..n) and b[0..n).
// Main loop: 16 floats per iteration spread over four independent NEON FMA
// accumulators. The accumulators are then combined pairwise ((0+1)+(2+3)) and
// reduced horizontally, and the remaining n % 16 elements are added with scalar math.
static inline float dot_neon(const float* a, const float* b, int n) {
    float32x4_t acc[4];
    for (int q = 0; q < 4; ++q) acc[q] = vdupq_n_f32(0);

    int k = 0;
    for (; k + 16 <= n; k += 16) {
        for (int q = 0; q < 4; ++q) {
            const int off = k + 4 * q;
            acc[q] = vfmaq_f32(acc[q], vld1q_f32(a + off), vld1q_f32(b + off));
        }
    }

    const float32x4_t pairLo = vaddq_f32(acc[0], acc[1]);
    const float32x4_t pairHi = vaddq_f32(acc[2], acc[3]);
    float total = vaddvq_f32(vaddq_f32(pairLo, pairHi));

    while (k < n) {
        total += a[k] * b[k];
        ++k;
    }
    return total;
}
// Greedy head: scores every row of the row-major matrix W (vocab rows of dim
// floats) against h with dot_neon and returns the index of the highest score.
// Ties keep the lowest index; returns 0 if no score exceeds -FLT_MAX.
static int argmax_head(const float* h, const float* W, int vocab, int dim) {
    int bestRow = 0;
    float bestScore = -FLT_MAX;
    const float* row = W;
    for (int j = 0; j < vocab; ++j, row += dim) {
        const float score = dot_neon(h, row, dim);
        if (score > bestScore) {
            bestScore = score;
            bestRow = j;
        }
    }
    return bestRow;
}
// Top-k sampling with temperature.
//
// Keeps the k largest logits (sorted descending; ties keep the earlier index).
// Each kept logit gets weight exp((logit - max) / temp), and one index is drawn
// with rand() in proportion to those weights.
//
// logits: `vocab` scores. The comparison is strict, so entries equal to -FLT_MAX
//         never enter the candidate set and callers can use that value as a mask.
//         If fewer than k entries qualify, the remaining slots are placeholders for
//         index 0 with (near-)zero weight.
// vocab:  number of logits; a non-positive value returns 0.
// temp:   softmax temperature. Non-positive or NaN falls back to greedy (top-1)
//         without consuming a rand() value.
// k:      candidate count, clamped to [1, vocab].
// Returns the sampled index in [0, vocab).
static int sample_topk(const float* logits, int vocab, float temp, int k) {
    if (vocab <= 0) return 0;
    k = std::max(1, std::min(k, vocab));  // k <= 0 used to index topk[-1]

    // Find top-k with a small insertion sort (k is small, vocab is large).
    struct IV { int i; float v; };
    std::vector<IV> topk(static_cast<size_t>(k), IV{0, -FLT_MAX});
    for (int i = 0; i < vocab; i++) {
        if (logits[i] > topk[k - 1].v) {
            topk[k - 1] = {i, logits[i]};
            // Bubble the new entry up to keep the list sorted descending.
            for (int j = k - 2; j >= 0 && topk[j + 1].v > topk[j].v; j--) {
                std::swap(topk[j], topk[j + 1]);
            }
        }
    }

    // A non-positive temperature would divide by zero / invert the distribution.
    if (!(temp > 0.0f)) return topk[0].i;

    // Softmax with temperature (unnormalized; `sum` is the normalizer).
    const float maxv = topk[0].v;
    float sum = 0.0f;
    for (auto& t : topk) {
        t.v = expf((t.v - maxv) / temp);
        sum += t.v;
    }

    // Sample by walking the cumulative weights.
    const float r = static_cast<float>(rand()) / RAND_MAX * sum;
    float acc = 0.0f;
    for (const auto& t : topk) {
        acc += t.v;
        if (acc >= r) return t.i;
    }
    return topk[0].i;
}
// Process-wide state for the two ExecuTorch models: the talker ("t" prefix / `talker`)
// and the code predictor ("c" prefix / `cp`). nativeInit fills one instance via loadModel.
// Members are destroyed in reverse declaration order, so the loaders go last.
struct PipelineState {
// Data loaders and programs produced by loadModel for each .pte file.
// NOTE(review): presumably each Program reads through its loader, so the loader must
// outlive the program; the declaration order above/below is consistent with that.
std::unique_ptr<executorch::extension::FileDataLoader> tLoader, cLoader;
std::unique_ptr<Program> tProg, cProg;
// Memory managers passed to Program::load_method for each model.
std::unique_ptr<MemoryManager> tMM, cMM;
// Loaded "forward" methods: raw owning pointers (created with `new` in loadModel).
// This struct does not free them; delete them before deleting the struct (nativeDestroy
// does), presumably because they reference this struct's programs, managers and buffers.
Method* talker = nullptr;
Method* cp = nullptr;
// Heap buffers allocated by loadModel for each model (at least its memory-planned buffers).
std::vector<std::unique_ptr<uint8_t[]>> tBufs, cBufs;
// Set by nativeInit once both methods have loaded.
bool loaded = false;
};
// The one live pipeline: created by nativeInit, reset to nullptr by nativeDestroy or by a
// failed nativeInit.
static PipelineState* gState = nullptr;

// Static backing pools for each model's method allocator and temp allocator (handed to
// loadModel by nativeInit). Talker: 8 MiB + 2 MiB. Code predictor: 4 MiB + 1 MiB.
static uint8_t tMethodPool[8 * 1024 * 1024];
static uint8_t tTempPool[2 * 1024 * 1024];
static uint8_t cMethodPool[4 * 1024 * 1024];
static uint8_t cTempPool[1 * 1024 * 1024];
// Appends a zero-filled heap buffer of `size` bytes to `bufs` and returns its address.
// The buffer is freed together with `bufs`, i.e. with the PipelineState that owns it.
static uint8_t* allocOwnedBytes(std::vector<std::unique_ptr<uint8_t[]>>& bufs, size_t size)
{
    bufs.push_back(std::make_unique<uint8_t[]>(size));
    return bufs.back().get();
}

// Loads the "forward" method of the ExecuTorch program at `path`.
//
// loader/program/mm receive the data loader, program and memory manager the Method is built
// from. `bufs` receives every heap allocation made here:
//   - the memory-planned buffers;
//   - the Span array describing them (the HierarchicalAllocator keeps only a view of it,
//     so it must not live on this function's stack);
//   - the allocator objects handed to `mm` (MemoryManager only stores their pointers).
// Keeping these in `bufs` ties their lifetime to the owning PipelineState instead of leaking them.
// mp/mps and tp/tps are caller-owned pools for the method and temp allocators; they must
// outlive the returned Method.
//
// Returns a heap-allocated Method, which the caller must delete before `bufs` is released.
// Returns nullptr on failure; the failing step and ExecuTorch error code are logged.
static Method* loadModel(const char* path,
        std::unique_ptr<executorch::extension::FileDataLoader>& loader,
        std::unique_ptr<Program>& program,
        std::unique_ptr<MemoryManager>& mm,
        std::vector<std::unique_ptr<uint8_t[]>>& bufs,
        uint8_t* mp, size_t mps, uint8_t* tp, size_t tps)
{
    auto ld = executorch::extension::FileDataLoader::from(path);
    if (!ld.ok()) {
        LOGI("loadModel: cannot open %s (error %u)", path, static_cast<unsigned>(ld.error()));
        return nullptr;
    }
    loader = std::make_unique<executorch::extension::FileDataLoader>(std::move(ld.get()));

    auto prog = Program::load(loader.get());
    if (!prog.ok()) {
        LOGI("loadModel: Program::load failed for %s (error %u)", path, static_cast<unsigned>(prog.error()));
        return nullptr;
    }
    program = std::make_unique<Program>(std::move(prog.get()));

    auto meta = program->method_meta("forward");
    if (!meta.ok()) {
        LOGI("loadModel: no 'forward' method in %s (error %u)", path, static_cast<unsigned>(meta.error()));
        return nullptr;
    }

    // Memory-planned buffers plus the Span array that describes them, all owned by `bufs`.
    // NOTE(review): Span<uint8_t> is assumed to be a trivially destructible pointer+length view,
    // so never running its destructor is fine.
    const size_t nPlanned = meta->num_memory_planned_buffers();
    auto* spans = reinterpret_cast<Span<uint8_t>*>(
        allocOwnedBytes(bufs, nPlanned * sizeof(Span<uint8_t>)));
    for (size_t i = 0; i < nPlanned; i++) {
        auto sz = meta->memory_planned_buffer_size(i);
        if (!sz.ok()) {
            LOGI("loadModel: bad planned buffer %zu in %s (error %u)", i, path, static_cast<unsigned>(sz.error()));
            return nullptr;
        }
        const size_t bytes = static_cast<size_t>(sz.get());
        new (spans + i) Span<uint8_t>(allocOwnedBytes(bufs, bytes), bytes);
    }

    // NOTE(review): these allocators are placement-constructed in `bufs` storage and their
    // destructors are never run. This assumes they release nothing themselves (they only wrap
    // memory provided here); confirm when upgrading ExecuTorch.
    auto* methodAlloc = new (allocOwnedBytes(bufs, sizeof(MemoryAllocator))) MemoryAllocator(mps, mp);
    auto* tempAlloc = new (allocOwnedBytes(bufs, sizeof(MemoryAllocator))) MemoryAllocator(tps, tp);
    auto* plannedAlloc = new (allocOwnedBytes(bufs, sizeof(HierarchicalAllocator)))
        HierarchicalAllocator(Span<Span<uint8_t>>(spans, nPlanned));
    mm = std::make_unique<MemoryManager>(methodAlloc, plannedAlloc, tempAlloc);

    auto method = program->load_method("forward", mm.get());
    if (!method.ok()) {
        LOGI("loadModel: load_method failed for %s (error %u)", path, static_cast<unsigned>(method.error()));
        return nullptr;
    }
    return new Method(std::move(method.get()));
}
extern "C" {
// Loads the talker (jTP) and code-predictor (jCP) .pte files into gState, then runs one
// warm-up execution of each.
// Returns JNI_TRUE if the pipeline was already loaded or loads successfully.
// Returns JNI_FALSE if a path is null, a path string cannot be read, or either model fails to load.
// Warm-up failures are logged but do not fail initialization.
JNIEXPORT jboolean JNICALL
Java_com_kazeia_tts_TtsPipeline_nativeInit(JNIEnv*env,jclass,jstring jTP,jstring jCP){
    executorch::runtime::runtime_init();
    if(gState&&gState->loaded) return JNI_TRUE;
    if(!jTP||!jCP){LOGI("Load failed: null model path");return JNI_FALSE;}
    // GetStringUTFChars returns null (with an exception pending) on failure; stop before
    // making further JNI calls in that state.
    const char*tp=env->GetStringUTFChars(jTP,nullptr);
    if(!tp) return JNI_FALSE;
    const char*cp=env->GetStringUTFChars(jCP,nullptr);
    if(!cp){env->ReleaseStringUTFChars(jTP,tp);return JNI_FALSE;}

    gState=new PipelineState();
    LOGI("Loading talker+CP...");
    auto t0=std::chrono::high_resolution_clock::now();
    gState->talker=loadModel(tp,gState->tLoader,gState->tProg,gState->tMM,gState->tBufs,tMethodPool,sizeof(tMethodPool),tTempPool,sizeof(tTempPool));
    gState->cp=loadModel(cp,gState->cLoader,gState->cProg,gState->cMM,gState->cBufs,cMethodPool,sizeof(cMethodPool),cTempPool,sizeof(cTempPool));
    env->ReleaseStringUTFChars(jTP,tp);
    env->ReleaseStringUTFChars(jCP,cp);
    if(!gState->talker||!gState->cp){
        LOGI("Load failed");
        // Whichever model did load must be freed too, before the state that owns its
        // program and buffers.
        delete gState->talker;
        delete gState->cp;
        delete gState;
        gState=nullptr;
        return JNI_FALSE;
    }
    gState->loaded=true;

    // Warm up both models once with the default inputs prepared by ExecuTorch.
    for(Method* m : {gState->talker, gState->cp}){
        auto inputs=executorch::extension::prepare_input_tensors(*m);
        if(!inputs.ok()){
            LOGI("Warmup: prepare_input_tensors failed (error %u)",static_cast<unsigned>(inputs.error()));
            continue;
        }
        const Error err=m->execute();
        if(err!=Error::Ok) LOGI("Warmup: execute failed (error %u)",static_cast<unsigned>(err));
    }
    auto t1=std::chrono::high_resolution_clock::now();
    LOGI("Loaded+warmup: %.0fms",std::chrono::duration<float,std::milli>(t1-t0).count());
    return JNI_TRUE;
}
// Frees the loaded pipeline, if any; safe to call repeatedly.
JNIEXPORT void JNICALL
Java_com_kazeia_tts_TtsPipeline_nativeDestroy(JNIEnv*,jclass){
    if(!gState) return;
    // The methods go first: they were loaded from programs and memory managers that the
    // state owns.
    delete gState->talker;
    delete gState->cp;
    delete gState;
    gState=nullptr;
}
// Helper: run one talker step
static void talkerStep(Method&m, const float*emb, float*mask, int pos,
const float*tCos, const float*tSin, float*tK, float*tV,
float*outHidden, float*outLogits)
{
int kvElem = T_KV * T_KV_LEN * T_HD;
auto prep = executorch::extension::prepare_input_tensors(m);
memcpy(m.mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, DIM*4);
memcpy(m.mutable_input(1).toTensor().mutable_data_ptr<float>(), mask, T_KV_LEN*4);
int pi = std::min(pos, 249);
memcpy(m.mutable_input(2).toTensor().mutable_data_ptr<float>(), tCos+pi*T_HD, T_HD*4);
memcpy(m.mutable_input(3).toTensor().mutable_data_ptr<float>(), tSin+pi*T_HD, T_HD*4);
for(int i=0;i<T_L;i++){
memcpy(m.mutable_input(4+i*2).toTensor().mutable_data_ptr<float>(), tK+i*kvElem, kvElem*4);
memcpy(m.mutable_input(5+i*2).toTensor().mutable_data_ptr<float>(), tV+i*kvElem, kvElem*4);
}
m.execute();
memcpy(outHidden, m.get_output(0).toTensor().const_data_ptr<float>(), DIM*4);
memcpy(outLogits, m.get_output(1).toTensor().const_data_ptr<float>(), VOCAB*4);
for(int i=0;i<T_L;i++){
memcpy(tK+i*kvElem, m.get_output(2+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
memcpy(tV+i*kvElem, m.get_output(3+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
}
}
// Helper: run full CP (17 steps) → 15 codes
static void cpForward(Method&m, const float*hidden, int cb0,
const float*codecEmb, const float*cpEmbs, const float*cpHeads,
const float*cCos, const float*cSin, int*codes)
{
int kvElem = C_KV * C_KV_LEN * C_HD;
std::vector<float> kv(C_L*2*kvElem, 0.0f);
for(int step=0;step<17;step++){
const float*emb;
if(step==0) emb=hidden;
else if(step==1) emb=codecEmb + std::min(std::max(cb0,0),VOCAB-1)*DIM;
else emb=cpEmbs + ((step-2)*CB_SIZE + std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM;
auto prep=executorch::extension::prepare_input_tensors(m);
memcpy(m.mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, DIM*4);
// Mask
float*mp=m.mutable_input(1).toTensor().mutable_data_ptr<float>();
for(int p=0;p<C_KV_LEN;p++) mp[p]=(p>=C_KV_LEN-1-step)?0.0f:-1e9f;
memcpy(m.mutable_input(2).toTensor().mutable_data_ptr<float>(), cCos+step*C_HD, C_HD*4);
memcpy(m.mutable_input(3).toTensor().mutable_data_ptr<float>(), cSin+step*C_HD, C_HD*4);
for(int i=0;i<C_L;i++){
memcpy(m.mutable_input(4+i*2).toTensor().mutable_data_ptr<float>(), kv.data()+(i*2)*kvElem, kvElem*4);
memcpy(m.mutable_input(5+i*2).toTensor().mutable_data_ptr<float>(), kv.data()+(i*2+1)*kvElem, kvElem*4);
}
m.execute();
const float*h=m.get_output(0).toTensor().const_data_ptr<float>();
if(step>=1&&step-1<15){
codes[step-1]=argmax_head(h, cpHeads+(step-1)*CB_SIZE*DIM, CB_SIZE, DIM);
}
for(int i=0;i<C_L;i++){
memcpy(kv.data()+(i*2)*kvElem, m.get_output(1+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
memcpy(kv.data()+(i*2+1)*kvElem, m.get_output(2+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
}
}
}
// Runs the whole generation for one utterance and returns all codebook codes.
//
// Inputs (float arrays are row-major, DIM floats per embedding row):
//   jPrefill:  >= nPrefill*DIM floats; fed to the talker one row per step.
//   jTrailing: >= nTrailing*DIM floats. Row i is added to the talker input of generation
//              step i, then jEosEmbed is added once, then jPadEmbed on every later step.
//   jCodecEmb: cb0 embedding table (>= CB_SIZE rows).
//   jCpEmbs / jCpHeads: NUM_CB-1 codebooks of CB_SIZE rows each.
//   jTCos/jTSin, jCCos/jCSin: cos/sin tables forwarded to talkerStep / cpForward
//              (each sin table must be at least as long as its cos table).
//   jEosEmbed / jPadEmbed: DIM floats each; null means all zeros.
//   maxTokens: upper bound on generated frames.
// Returns:
//   - a flat int[] of nFrames*NUM_CB codes (frame-major);
//   - an empty int[] if nPrefill is 0 or the first sampled cb0 is EOS;
//   - null if the pipeline is not loaded, an array argument is null or too short, a count
//     is negative, or the result array cannot be allocated.
JNIEXPORT jintArray JNICALL
Java_com_kazeia_tts_TtsPipeline_nativeRun(
        JNIEnv*env,jclass,
        jfloatArray jPrefill,jint nPrefill,
        jfloatArray jTrailing,jint nTrailing,
        jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads,
        jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin,
        jfloatArray jEosEmbed, jfloatArray jPadEmbed,
        jint maxTokens)
{
    if(!gState||!gState->loaded) return nullptr;
    auto T0=std::chrono::high_resolution_clock::now();
    constexpr int kCpCodes = NUM_CB - 1;  // code-predictor codebooks per frame

    // ===== VALIDATE ARGUMENTS =====
    // Every copy below stays in bounds, so GetFloatArrayRegion never raises a Java exception
    // that later JNI calls would run into.
    // Length of a Java float[], or -1 for a null reference.
    auto lengthOf = [env](jfloatArray a) -> long long {
        return a ? static_cast<long long>(env->GetArrayLength(a)) : -1;
    };
    const long long prefillSize = lengthOf(jPrefill);
    const long long codecSize = lengthOf(jCodecEmb);
    const long long cpEmbsSize = lengthOf(jCpEmbs);
    const long long headsSize = lengthOf(jCpHeads);
    const long long tcSize = lengthOf(jTCos);
    const long long ccSize = lengthOf(jCCos);
    const long long cpTableSize = static_cast<long long>(kCpCodes) * CB_SIZE * DIM;
    const bool valid =
        nPrefill >= 0 && nTrailing >= 0 &&
        prefillSize >= static_cast<long long>(nPrefill) * DIM &&
        (nTrailing == 0 || lengthOf(jTrailing) >= static_cast<long long>(nTrailing) * DIM) &&
        // cb0 is only ever sampled from [0, CB_SIZE): other ids are masked and EOS ends the
        // loop. So only the first CB_SIZE rows of the codec table are read.
        codecSize >= static_cast<long long>(CB_SIZE) * DIM &&
        cpEmbsSize >= cpTableSize && headsSize >= cpTableSize &&
        tcSize >= 0 && lengthOf(jTSin) >= tcSize &&
        ccSize >= 0 && lengthOf(jCSin) >= ccSize &&
        (!jEosEmbed || lengthOf(jEosEmbed) >= DIM) &&
        (!jPadEmbed || lengthOf(jPadEmbed) >= DIM);
    if(!valid){
        LOGI("nativeRun: invalid arguments (null or too-short array, or negative count)");
        return nullptr;
    }

    // Copy all data from JNI (then release immediately). `n` must already be validated.
    auto copyPrefix = [env](jfloatArray a, long long n) {
        std::vector<float> out(static_cast<size_t>(n));
        if(n>0) env->GetFloatArrayRegion(a, 0, static_cast<jsize>(n), out.data());
        return out;
    };
    std::vector<float> prefill = copyPrefix(jPrefill, prefillSize);
    std::vector<float> trailingData = copyPrefix(jTrailing, static_cast<long long>(nTrailing) * DIM);
    std::vector<float> codecEmb = copyPrefix(jCodecEmb, codecSize);
    std::vector<float> cpEmbs = copyPrefix(jCpEmbs, cpEmbsSize);
    std::vector<float> cpHeads = copyPrefix(jCpHeads, headsSize);
    std::vector<float> tCos = copyPrefix(jTCos, tcSize), tSin = copyPrefix(jTSin, tcSize);
    std::vector<float> cCos = copyPrefix(jCCos, ccSize), cSin = copyPrefix(jCSin, ccSize);
    std::vector<float> eosEmbed = jEosEmbed ? copyPrefix(jEosEmbed, DIM) : std::vector<float>(DIM, 0.0f);
    std::vector<float> padEmbed = jPadEmbed ? copyPrefix(jPadEmbed, DIM) : std::vector<float>(DIM, 0.0f);

    // ===== PIPELINE STATE =====
    const size_t tkvElem = static_cast<size_t>(T_KV) * T_KV_LEN * T_HD;
    std::vector<float> tK(T_L * tkvElem, 0.0f), tV(T_L * tkvElem, 0.0f);
    // Additive mask over the T_KV_LEN talker cache slots, filled from the end: one more slot
    // is opened (set to 0) before every talker step.
    float mask[T_KV_LEN];
    for(int i=0;i<T_KV_LEN;i++) mask[i]=-1e9f;
    auto openMaskSlot = [&mask](int p) { mask[T_KV_LEN - 1 - std::min(p, T_KV_LEN - 1)] = 0.0f; };
    float hidden[DIM]={}, logits[VOCAB]={};
    std::vector<int> allCodes;          // flat: numTokens x NUM_CB
    std::vector<int> cb0History;        // every emitted cb0, in order (degeneration check)
    std::unordered_set<int> seenCb0;    // distinct emitted cb0 values (repetition penalty)
    int pos=0, currentCb0=-1;

    // ===== PREFILL =====
    auto tP0=std::chrono::high_resolution_clock::now();
    for(int step=0;step<nPrefill;step++){
        openMaskSlot(pos);
        talkerStep(*gState->talker, prefill.data()+static_cast<size_t>(step)*DIM, mask, pos,
                   tCos.data(), tSin.data(), tK.data(), tV.data(), hidden, logits);
        pos++;
    }
    if(nPrefill>0){
        // Only codec ids [0, CB_SIZE) and EOS may be sampled.
        for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
        currentCb0=sample_topk(logits,VOCAB,0.9f,50);
    }
    auto tP1=std::chrono::high_resolution_clock::now();
    LOGI("Prefill: %.0fms, %d steps, cb0=%d",
         std::chrono::duration<float,std::milli>(tP1-tP0).count(), nPrefill, currentCb0);
    if(currentCb0<0||currentCb0==CODEC_EOS){return env->NewIntArray(0);}

    // ===== GENERATION =====
    float totalTalkerMs=0, totalCpMs=0;
    int trailingIdx=0;
    for(int gen=0;gen<maxTokens;gen++){
        int codes[NUM_CB]={};
        codes[0]=currentCb0;

        // Code predictor: codebooks 1..15 for this frame.
        auto tc0=std::chrono::high_resolution_clock::now();
        int cpCodes[kCpCodes]={};
        cpForward(*gState->cp, hidden, currentCb0, codecEmb.data(), cpEmbs.data(), cpHeads.data(),
                  cCos.data(), cSin.data(), cpCodes);
        auto tc1=std::chrono::high_resolution_clock::now();
        totalCpMs+=std::chrono::duration<float,std::milli>(tc1-tc0).count();
        for(int i=0;i<kCpCodes;i++) codes[i+1]=cpCodes[i];
        allCodes.insert(allCodes.end(), codes, codes+NUM_CB);
        cb0History.push_back(currentCb0);
        seenCb0.insert(currentCb0);

        // Next talker input: sum of the frame's codec embeddings...
        float nextEmb[DIM]={};
        const float*e0=codecEmb.data()+static_cast<size_t>(std::min(std::max(codes[0],0),VOCAB-1))*DIM;
        for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k];
        for(int cb=0;cb<kCpCodes;cb++){
            const float*ec=cpEmbs.data()+(static_cast<size_t>(cb)*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
            for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
        }
        // ...plus trailing text, then eos once, then pad (matches Python/Kotlin pipeline).
        if(trailingIdx<nTrailing){
            const float*te=trailingData.data()+static_cast<size_t>(trailingIdx)*DIM;
            for(int k=0;k<DIM;k++) nextEmb[k]+=te[k];
            trailingIdx++;
        } else if(trailingIdx==nTrailing){
            for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmbed[k];
            trailingIdx++;
        } else {
            for(int k=0;k<DIM;k++) nextEmb[k]+=padEmbed[k];
        }

        // Talker step
        openMaskSlot(pos);
        auto tt0=std::chrono::high_resolution_clock::now();
        talkerStep(*gState->talker, nextEmb, mask, pos, tCos.data(), tSin.data(),
                   tK.data(), tV.data(), hidden, logits);
        auto tt1=std::chrono::high_resolution_clock::now();
        totalTalkerMs+=std::chrono::duration<float,std::milli>(tt1-tt0).count();
        pos++;

        // Sample next cb0 (suppress non-codec ids, repetition penalty on every cb0 seen so far)
        for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
        for(int tok:seenCb0) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
        int nextCb0=sample_topk(logits,VOCAB,0.9f,50);
        if(nextCb0==CODEC_EOS){LOGI("EOS at step %d",gen+2);break;}

        // Degeneration check: stop when the same cb0 would occur 10 times in a row.
        int histSz=(int)cb0History.size();
        if(histSz>=9){
            bool degen=true;
            for(int i=histSz-9;i<histSz;i++) if(cb0History[i]!=nextCb0){degen=false;break;}
            if(degen){LOGI("Degeneration at step %d",gen+2);break;}
        }
        currentCb0=nextCb0;
    }
    int nTokens=(int)allCodes.size()/NUM_CB;
    auto T1=std::chrono::high_resolution_clock::now();
    LOGI("Generated %d tokens | Talker: %.0fms (%.0fms/step) | CP: %.0fms (%.0fms/step) | Total: %.0fms",
         nTokens, totalTalkerMs, totalTalkerMs/std::max(nTokens,1),
         totalCpMs, totalCpMs/std::max(nTokens,1),
         std::chrono::duration<float,std::milli>(T1-T0).count());
    jintArray result=env->NewIntArray((jsize)allCodes.size());
    if(!result) return nullptr;  // allocation failed; an OutOfMemoryError is pending
    env->SetIntArrayRegion(result,0,(jsize)allCodes.size(),allCodes.data());
    return result;
}
} // extern "C"