382 lines
16 KiB
C++
382 lines
16 KiB
C++
/**
 * Native TTS pipeline: talker + CP autoregressive loop in C++.
 * One JNI call runs the entire generation → returns all codebook codes.
 */
|
||
#include <jni.h>
#include <arm_neon.h>
#include <android/log.h>

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/runtime.h>
|
||
|
||
#define TAG "TtsPipeline"
|
||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||
|
||
using executorch::runtime::Error;
|
||
using executorch::runtime::EValue;
|
||
using executorch::runtime::HierarchicalAllocator;
|
||
using executorch::runtime::MemoryAllocator;
|
||
using executorch::runtime::MemoryManager;
|
||
using executorch::runtime::Method;
|
||
using executorch::runtime::Program;
|
||
using executorch::runtime::Span;
|
||
|
||
// Sizes shared by the whole pipeline.
static constexpr int DIM = 1024;      // floats per embedding / hidden vector
static constexpr int VOCAB = 3072;    // length of the talker logits vector
static constexpr int CB_SIZE = 2048;  // rows per codebook head / embedding table
static constexpr int NUM_CB = 16;     // codes emitted per generated token

// Talker KV-cache geometry: T_L (K,V) input pairs of T_KV * T_KV_LEN * T_HD floats each.
// Presumably layers / KV heads / head dim / cache length — TODO confirm against the exported model.
static constexpr int T_L = 28;
static constexpr int T_KV = 8;
static constexpr int T_HD = 128;
static constexpr int T_KV_LEN = 100;

// Code-predictor (CP) KV-cache geometry, same layout as the talker's.
static constexpr int C_L = 5;
static constexpr int C_KV = 8;
static constexpr int C_HD = 128;
static constexpr int C_KV_LEN = 16;

// Sampling this token as codebook 0 ends generation.
static constexpr int CODEC_EOS = 2150;
|
||
|
||
/**
 * Dot product of two float arrays of length n using NEON fused multiply-add.
 *
 * Processes 16 floats per iteration in four independent 4-lane accumulators,
 * reduces them pairwise, then finishes any remaining elements with scalar code.
 */
static inline float dot_neon(const float* a, const float* b, int n) {
    float32x4_t acc0 = vdupq_n_f32(0);
    float32x4_t acc1 = vdupq_n_f32(0);
    float32x4_t acc2 = vdupq_n_f32(0);
    float32x4_t acc3 = vdupq_n_f32(0);

    int idx = 0;
    while (idx + 15 < n) {
        const float* pa = a + idx;
        const float* pb = b + idx;
        acc0 = vfmaq_f32(acc0, vld1q_f32(pa), vld1q_f32(pb));
        acc1 = vfmaq_f32(acc1, vld1q_f32(pa + 4), vld1q_f32(pb + 4));
        acc2 = vfmaq_f32(acc2, vld1q_f32(pa + 8), vld1q_f32(pb + 8));
        acc3 = vfmaq_f32(acc3, vld1q_f32(pa + 12), vld1q_f32(pb + 12));
        idx += 16;
    }

    // Pairwise reduction of the four accumulators, then a horizontal add.
    const float32x4_t lower = vaddq_f32(acc0, acc1);
    const float32x4_t upper = vaddq_f32(acc2, acc3);
    float total = vaddvq_f32(vaddq_f32(lower, upper));

    // Scalar tail for the last (n % 16) elements.
    while (idx < n) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
}
|
||
|
||
static int argmax_head(const float*h,const float*W,int vocab,int dim){
|
||
int best=0;float bv=-FLT_MAX;
|
||
for(int j=0;j<vocab;j++){float d=dot_neon(h,W+j*dim,dim);if(d>bv){bv=d;best=j;}}
|
||
return best;
|
||
}
|
||
|
||
// Random source for top-k sampling: a 64-bit linear congruential generator
// with a fixed seed, so the sequence of draws is reproducible from process
// start. The multiplier/increment are Knuth's MMIX LCG constants.
// NOTE(review): this is not java.util.Random's 48-bit generator — confirm
// whether bit-exact parity with a Java implementation is actually required.
static uint64_t g_rng_state = 0x12345678ABCDEF01ULL;

// Advances the generator and maps its top 31 bits to a float in [0, 1].
static float next_rand() {
    constexpr uint64_t kMultiplier = 6364136223846793005ULL;
    constexpr uint64_t kIncrement = 1442695040888963407ULL;
    g_rng_state = g_rng_state * kMultiplier + kIncrement;
    const uint64_t top31 = (g_rng_state >> 33) & 0x7FFFFFFF;
    return (float)top31 / (float)0x7FFFFFFF;
}
|
||
|
||
static int sample_topk(const float* logits, int vocab, float temp, int k) {
|
||
struct IV { int i; float v; };
|
||
std::vector<IV> topk(k, {0, -FLT_MAX});
|
||
for (int i = 0; i < vocab; i++) {
|
||
if (logits[i] > topk[k-1].v) {
|
||
topk[k-1] = {i, logits[i]};
|
||
for (int j = k-2; j >= 0; j--) {
|
||
if (topk[j+1].v > topk[j].v) std::swap(topk[j], topk[j+1]);
|
||
else break;
|
||
}
|
||
}
|
||
}
|
||
float maxv = topk[0].v;
|
||
float sum = 0;
|
||
for (auto& t : topk) { t.v = expf((t.v - maxv) / temp); sum += t.v; }
|
||
float r = next_rand() * sum;
|
||
float acc = 0;
|
||
for (auto& t : topk) { acc += t.v; if (acc >= r) return t.i; }
|
||
return topk[0].i;
|
||
}
|
||
|
||
struct PipelineState {
|
||
std::unique_ptr<executorch::extension::FileDataLoader> tLoader, cLoader;
|
||
std::unique_ptr<Program> tProg, cProg;
|
||
std::unique_ptr<MemoryManager> tMM, cMM;
|
||
Method* talker = nullptr;
|
||
Method* cp = nullptr;
|
||
std::vector<std::unique_ptr<uint8_t[]>> tBufs, cBufs;
|
||
bool loaded = false;
|
||
};
|
||
static PipelineState* gState = nullptr;
|
||
// Static arenas handed to loadModel as the method and temp allocator pools.
// Talker: 8 MiB method pool, 2 MiB temp pool.
static uint8_t tMethodPool[8 * 1024 * 1024];
static uint8_t tTempPool[2 * 1024 * 1024];
// CP: 4 MiB method pool, 1 MiB temp pool.
static uint8_t cMethodPool[4 * 1024 * 1024];
static uint8_t cTempPool[1 * 1024 * 1024];
|
||
|
||
static Method* loadModel(const char* path,
|
||
std::unique_ptr<executorch::extension::FileDataLoader>& loader,
|
||
std::unique_ptr<Program>& program,
|
||
std::unique_ptr<MemoryManager>& mm,
|
||
std::vector<std::unique_ptr<uint8_t[]>>& bufs,
|
||
uint8_t* mp, size_t mps, uint8_t* tp, size_t tps)
|
||
{
|
||
auto ld = executorch::extension::FileDataLoader::from(path);
|
||
if(!ld.ok()) return nullptr;
|
||
loader=std::make_unique<executorch::extension::FileDataLoader>(std::move(ld.get()));
|
||
auto prog=Program::load(&*loader); if(!prog.ok()) return nullptr;
|
||
program=std::make_unique<Program>(std::move(prog.get()));
|
||
auto meta=program->method_meta("forward"); if(!meta.ok()) return nullptr;
|
||
std::vector<Span<uint8_t>> spans;
|
||
for(size_t i=0;i<meta->num_memory_planned_buffers();i++){
|
||
size_t sz=(size_t)meta->memory_planned_buffer_size(i).get();
|
||
bufs.push_back(std::make_unique<uint8_t[]>(sz));
|
||
spans.push_back({bufs.back().get(),sz});
|
||
}
|
||
auto*ma=new MemoryAllocator(mps,mp);
|
||
auto*ta=new MemoryAllocator(tps,tp);
|
||
auto*ha=new HierarchicalAllocator({spans.data(),spans.size()});
|
||
mm=std::unique_ptr<MemoryManager>(new MemoryManager(ma,ha,ta));
|
||
auto method=program->load_method("forward",mm.get());
|
||
if(!method.ok()) return nullptr;
|
||
return new Method(std::move(method.get()));
|
||
}
|
||
|
||
extern "C" {
|
||
|
||
/**
 * JNI: loads the talker and CP models from the given file paths.
 *
 * @param jTP  path of the talker program
 * @param jCP  path of the CP program
 * @return JNI_TRUE when both models are loaded (or were already loaded),
 *         JNI_FALSE when a path is null/unreadable or either model fails to
 *         load. On failure no pipeline state is kept.
 */
JNIEXPORT jboolean JNICALL
Java_com_kazeia_tts_TtsPipeline_nativeInit(JNIEnv*env,jclass,jstring jTP,jstring jCP){
    executorch::runtime::runtime_init();
    if (gState && gState->loaded) return JNI_TRUE;

    if (jTP == nullptr || jCP == nullptr) {
        LOGI("Load failed");
        return JNI_FALSE;
    }
    // GetStringUTFChars returns null (with a pending exception) on allocation failure.
    const char* talkerPath = env->GetStringUTFChars(jTP, nullptr);
    if (talkerPath == nullptr) {
        LOGI("Load failed");
        return JNI_FALSE;
    }
    const char* cpPath = env->GetStringUTFChars(jCP, nullptr);
    if (cpPath == nullptr) {
        env->ReleaseStringUTFChars(jTP, talkerPath);
        LOGI("Load failed");
        return JNI_FALSE;
    }

    // Drop any half-initialized state before starting over.
    if (gState) {
        delete gState->talker;
        delete gState->cp;
        delete gState;
        gState = nullptr;
    }
    gState = new PipelineState();

    LOGI("Loading talker+CP...");
    const auto t0 = std::chrono::high_resolution_clock::now();
    gState->talker = loadModel(talkerPath, gState->tLoader, gState->tProg, gState->tMM, gState->tBufs,
                               tMethodPool, sizeof(tMethodPool), tTempPool, sizeof(tTempPool));
    // No point loading the second model when the first one already failed.
    if (gState->talker) {
        gState->cp = loadModel(cpPath, gState->cLoader, gState->cProg, gState->cMM, gState->cBufs,
                               cMethodPool, sizeof(cMethodPool), cTempPool, sizeof(cTempPool));
    }
    env->ReleaseStringUTFChars(jTP, talkerPath);
    env->ReleaseStringUTFChars(jCP, cpPath);

    if (!gState->talker || !gState->cp) {
        LOGI("Load failed");
        // The Methods are raw pointers: free whichever one did load before
        // the state (and the programs they came from) goes away.
        delete gState->talker;
        delete gState->cp;
        delete gState;
        gState = nullptr;
        return JNI_FALSE;
    }

    gState->loaded = true;
    const auto t1 = std::chrono::high_resolution_clock::now();
    LOGI("Models loaded: %.0fms (no warmup — first forward will be slower)",std::chrono::duration<float,std::milli>(t1-t0).count());
    return JNI_TRUE;
}
|
||
|
||
/**
 * JNI: releases both loaded methods and the pipeline state.
 * Does nothing when no state exists.
 */
JNIEXPORT void JNICALL
Java_com_kazeia_tts_TtsPipeline_nativeDestroy(JNIEnv*,jclass){
    if (gState == nullptr) {
        return;
    }
    // Methods first, then the state that holds their programs and memory.
    delete gState->talker;
    delete gState->cp;
    delete gState;
    gState = nullptr;
}
|
||
|
||
// Helper: run one talker step
|
||
static void talkerStep(Method&m, const float*emb, float*mask, int pos,
|
||
const float*tCos, const float*tSin, float*tK, float*tV,
|
||
float*outHidden, float*outLogits)
|
||
{
|
||
int kvElem = T_KV * T_KV_LEN * T_HD;
|
||
auto prep = executorch::extension::prepare_input_tensors(m);
|
||
memcpy(m.mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, DIM*4);
|
||
memcpy(m.mutable_input(1).toTensor().mutable_data_ptr<float>(), mask, T_KV_LEN*4);
|
||
int pi = std::min(pos, 249);
|
||
memcpy(m.mutable_input(2).toTensor().mutable_data_ptr<float>(), tCos+pi*T_HD, T_HD*4);
|
||
memcpy(m.mutable_input(3).toTensor().mutable_data_ptr<float>(), tSin+pi*T_HD, T_HD*4);
|
||
for(int i=0;i<T_L;i++){
|
||
memcpy(m.mutable_input(4+i*2).toTensor().mutable_data_ptr<float>(), tK+i*kvElem, kvElem*4);
|
||
memcpy(m.mutable_input(5+i*2).toTensor().mutable_data_ptr<float>(), tV+i*kvElem, kvElem*4);
|
||
}
|
||
m.execute();
|
||
memcpy(outHidden, m.get_output(0).toTensor().const_data_ptr<float>(), DIM*4);
|
||
memcpy(outLogits, m.get_output(1).toTensor().const_data_ptr<float>(), VOCAB*4);
|
||
for(int i=0;i<T_L;i++){
|
||
memcpy(tK+i*kvElem, m.get_output(2+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
|
||
memcpy(tV+i*kvElem, m.get_output(3+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
|
||
}
|
||
}
|
||
|
||
// Helper: run full CP (17 steps) → 15 codes
|
||
static void cpForward(Method&m, const float*hidden, int cb0,
|
||
const float*codecEmb, const float*cpEmbs, const float*cpHeads,
|
||
const float*cCos, const float*cSin, int*codes)
|
||
{
|
||
int kvElem = C_KV * C_KV_LEN * C_HD;
|
||
std::vector<float> kv(C_L*2*kvElem, 0.0f);
|
||
|
||
for(int step=0;step<17;step++){
|
||
const float*emb;
|
||
if(step==0) emb=hidden;
|
||
else if(step==1) emb=codecEmb + std::min(std::max(cb0,0),VOCAB-1)*DIM;
|
||
else emb=cpEmbs + ((step-2)*CB_SIZE + std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM;
|
||
|
||
auto prep=executorch::extension::prepare_input_tensors(m);
|
||
memcpy(m.mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, DIM*4);
|
||
// Mask
|
||
float*mp=m.mutable_input(1).toTensor().mutable_data_ptr<float>();
|
||
for(int p=0;p<C_KV_LEN;p++) mp[p]=(p>=C_KV_LEN-1-step)?0.0f:-1e9f;
|
||
memcpy(m.mutable_input(2).toTensor().mutable_data_ptr<float>(), cCos+step*C_HD, C_HD*4);
|
||
memcpy(m.mutable_input(3).toTensor().mutable_data_ptr<float>(), cSin+step*C_HD, C_HD*4);
|
||
for(int i=0;i<C_L;i++){
|
||
memcpy(m.mutable_input(4+i*2).toTensor().mutable_data_ptr<float>(), kv.data()+(i*2)*kvElem, kvElem*4);
|
||
memcpy(m.mutable_input(5+i*2).toTensor().mutable_data_ptr<float>(), kv.data()+(i*2+1)*kvElem, kvElem*4);
|
||
}
|
||
m.execute();
|
||
const float*h=m.get_output(0).toTensor().const_data_ptr<float>();
|
||
if(step>=1&&step-1<15){
|
||
codes[step-1]=argmax_head(h, cpHeads+(step-1)*CB_SIZE*DIM, CB_SIZE, DIM);
|
||
}
|
||
for(int i=0;i<C_L;i++){
|
||
memcpy(kv.data()+(i*2)*kvElem, m.get_output(1+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
|
||
memcpy(kv.data()+(i*2+1)*kvElem, m.get_output(2+i*2).toTensor().const_data_ptr<float>(), kvElem*4);
|
||
}
|
||
}
|
||
}
|
||
|
||
JNIEXPORT jintArray JNICALL
|
||
Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
||
JNIEnv*env,jclass,
|
||
jfloatArray jPrefill,jint nPrefill,
|
||
jfloatArray jTrailing,jint nTrailing,
|
||
jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads,
|
||
jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin,
|
||
jfloatArray jEosEmbed, jfloatArray jPadEmbed,
|
||
jint maxTokens)
|
||
{
|
||
if(!gState||!gState->loaded) return nullptr;
|
||
auto T0=std::chrono::high_resolution_clock::now();
|
||
|
||
// Copy all data from JNI (then release immediately)
|
||
int prefillSize = env->GetArrayLength(jPrefill);
|
||
std::vector<float> prefill(prefillSize);
|
||
env->GetFloatArrayRegion(jPrefill, 0, prefillSize, prefill.data());
|
||
|
||
std::vector<float> trailingData;
|
||
if(nTrailing>0){trailingData.resize(nTrailing*DIM);env->GetFloatArrayRegion(jTrailing,0,nTrailing*DIM,trailingData.data());}
|
||
|
||
int codecSize=env->GetArrayLength(jCodecEmb);
|
||
std::vector<float> codecEmb(codecSize);
|
||
env->GetFloatArrayRegion(jCodecEmb,0,codecSize,codecEmb.data());
|
||
|
||
int cpEmbsSize=env->GetArrayLength(jCpEmbs);
|
||
std::vector<float> cpEmbs(cpEmbsSize);
|
||
env->GetFloatArrayRegion(jCpEmbs,0,cpEmbsSize,cpEmbs.data());
|
||
|
||
int headsSize=env->GetArrayLength(jCpHeads);
|
||
std::vector<float> cpHeads(headsSize);
|
||
env->GetFloatArrayRegion(jCpHeads,0,headsSize,cpHeads.data());
|
||
|
||
int tcSize=env->GetArrayLength(jTCos);
|
||
std::vector<float> tCos(tcSize),tSin(tcSize);
|
||
env->GetFloatArrayRegion(jTCos,0,tcSize,tCos.data());
|
||
env->GetFloatArrayRegion(jTSin,0,tcSize,tSin.data());
|
||
|
||
int ccSize=env->GetArrayLength(jCCos);
|
||
std::vector<float> cCos(ccSize),cSin(ccSize);
|
||
env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data());
|
||
env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data());
|
||
|
||
std::vector<float> eosEmbed(DIM,0), padEmbed(DIM,0);
|
||
if(jEosEmbed) env->GetFloatArrayRegion(jEosEmbed,0,DIM,eosEmbed.data());
|
||
if(jPadEmbed) env->GetFloatArrayRegion(jPadEmbed,0,DIM,padEmbed.data());
|
||
|
||
// Pipeline state
|
||
int tkvElem=T_KV*T_KV_LEN*T_HD;
|
||
std::vector<float> tK(T_L*tkvElem,0), tV(T_L*tkvElem,0);
|
||
float mask[T_KV_LEN]; for(int i=0;i<T_KV_LEN;i++) mask[i]=-1e9f;
|
||
float hidden[DIM]={}, logits[VOCAB]={};
|
||
|
||
std::vector<int> allCodes; // flat: numTokens × 16
|
||
std::vector<int> cb0History;
|
||
int pos=0, currentCb0=-1;
|
||
|
||
// ===== PREFILL =====
|
||
auto tP0=std::chrono::high_resolution_clock::now();
|
||
for(int step=0;step<nPrefill;step++){
|
||
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
||
if(mi>=0) mask[mi]=0.0f;
|
||
talkerStep(*gState->talker, prefill.data()+step*DIM, mask, pos, tCos.data(),tSin.data(),
|
||
tK.data(), tV.data(), hidden, logits);
|
||
pos++;
|
||
if(step==nPrefill-1){
|
||
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
||
currentCb0=sample_topk(logits,VOCAB,0.9f,50);
|
||
}
|
||
}
|
||
auto tP1=std::chrono::high_resolution_clock::now();
|
||
LOGI("Prefill: %.0fms, %d steps, cb0=%d, hidden[0:4]=[%.6f,%.6f,%.6f,%.6f]",
|
||
std::chrono::duration<float,std::milli>(tP1-tP0).count(), nPrefill, currentCb0,
|
||
hidden[0],hidden[1],hidden[2],hidden[3]);
|
||
|
||
if(currentCb0<0||currentCb0==CODEC_EOS){return env->NewIntArray(0);}
|
||
|
||
// ===== GENERATION =====
|
||
float totalTalkerMs=0, totalCpMs=0;
|
||
int trailingIdx=0;
|
||
// Pad embedding (zeros + pad token is not available here, use zeros)
|
||
float padEmb[DIM]={}; // In practice should be tts_pad_embed, passed as param
|
||
|
||
for(int gen=0;gen<maxTokens;gen++){
|
||
int codes[NUM_CB]={}; codes[0]=currentCb0;
|
||
|
||
// CP
|
||
auto tc0=std::chrono::high_resolution_clock::now();
|
||
int cpCodes[15]={};
|
||
cpForward(*gState->cp, hidden, currentCb0, codecEmb.data(), cpEmbs.data(), cpHeads.data(),
|
||
cCos.data(), cSin.data(), cpCodes);
|
||
auto tc1=std::chrono::high_resolution_clock::now();
|
||
totalCpMs+=std::chrono::duration<float,std::milli>(tc1-tc0).count();
|
||
for(int i=0;i<15;i++) codes[i+1]=cpCodes[i];
|
||
|
||
for(int i=0;i<NUM_CB;i++) allCodes.push_back(codes[i]);
|
||
cb0History.push_back(currentCb0);
|
||
|
||
// Build next talker input: use pre-computed trailing embeds as-is
|
||
float nextEmb[DIM]={};
|
||
if(trailingIdx<nTrailing){
|
||
memcpy(nextEmb, trailingData.data()+trailingIdx*DIM, DIM*4);
|
||
trailingIdx++;
|
||
} else {
|
||
const float*e0=codecEmb.data()+std::min(std::max(codes[0],0),VOCAB-1)*DIM;
|
||
for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k];
|
||
for(int cb=0;cb<15;cb++){
|
||
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
|
||
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
|
||
}
|
||
if(trailingIdx==nTrailing){
|
||
for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmbed[k];
|
||
trailingIdx++;
|
||
} else {
|
||
for(int k=0;k<DIM;k++) nextEmb[k]+=padEmbed[k];
|
||
}
|
||
}
|
||
|
||
// Talker step
|
||
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
||
if(mi>=0) mask[mi]=0.0f;
|
||
auto tt0=std::chrono::high_resolution_clock::now();
|
||
talkerStep(*gState->talker, nextEmb, mask, pos, tCos.data(),tSin.data(),
|
||
tK.data(), tV.data(), hidden, logits);
|
||
auto tt1=std::chrono::high_resolution_clock::now();
|
||
totalTalkerMs+=std::chrono::duration<float,std::milli>(tt1-tt0).count();
|
||
pos++;
|
||
|
||
// Next cb0: suppress non-codec, repetition penalty, top-k sampling
|
||
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
||
std::unordered_set<int> seen(cb0History.begin(),cb0History.end());
|
||
for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
|
||
int nextCb0=sample_topk(logits,VOCAB,0.9f,50);
|
||
|
||
if(nextCb0==CODEC_EOS){LOGI("EOS at step %d",gen+2);break;}
|
||
// Degeneration check
|
||
int histSz=(int)cb0History.size();
|
||
if(histSz>=9){
|
||
bool degen=true;
|
||
for(int i=histSz-9;i<histSz;i++) if(cb0History[i]!=nextCb0){degen=false;break;}
|
||
if(degen){LOGI("Degeneration at step %d",gen+2);break;}
|
||
}
|
||
currentCb0=nextCb0;
|
||
}
|
||
|
||
int nTokens=(int)allCodes.size()/NUM_CB;
|
||
// Log all CB0 codes for quality comparison
|
||
{
|
||
char buf[2048]={};int off=0;
|
||
for(int i=0;i<(int)cb0History.size()&&off<2000;i++) off+=snprintf(buf+off,2048-off,"%d,",cb0History[i]);
|
||
LOGI("CB0 sequence: [%s]",buf);
|
||
}
|
||
auto T1=std::chrono::high_resolution_clock::now();
|
||
LOGI("Generated %d tokens | Talker: %.0fms (%.0fms/step) | CP: %.0fms (%.0fms/step) | Total: %.0fms",
|
||
nTokens, totalTalkerMs, totalTalkerMs/std::max(nTokens,1),
|
||
totalCpMs, totalCpMs/std::max(nTokens,1),
|
||
std::chrono::duration<float,std::milli>(T1-T0).count());
|
||
|
||
jintArray result=env->NewIntArray((int)allCodes.size());
|
||
env->SetIntArrayRegion(result,0,(int)allCodes.size(),allCodes.data());
|
||
return result;
|
||
}
|
||
|
||
} // extern "C"
|