208 lines
8.9 KiB
C++
208 lines
8.9 KiB
C++
/**
 * TTS Code Predictor Runner — ExecuTorch .pte on NPU HTP.
 * Based on executor_runner.cpp but with socket IPC for the app.
 * Same protocol as the GGUF CP runner.
 */
|
|
|
|
#include <cerrno>
#include <chrono>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>

#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>

#include <gflags/gflags.h>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/runtime.h>
|
|
|
|
// Command-line flags. --model_path has no default and is required by main();
// every other flag defaults to an on-device path.
DEFINE_string(
    model_path,
    "",
    "Path to .pte file");
DEFINE_string(
    sock_path,
    "/data/local/tmp/kazeia/cp_et.sock",
    "Socket path");
DEFINE_string(
    heads_path,
    "/data/local/tmp/kazeia/models/cp_heads.bin",
    "Heads file");
DEFINE_string(
    embs_path,
    "/data/local/tmp/kazeia/models/cp_codec_embs.bin",
    "Codec embs file");
DEFINE_string(
    cos_path,
    "/data/local/tmp/kazeia/models/qwen3-tts-npu/cp_kv_v2/cp_rotary_cos.npy",
    "Cos file");
DEFINE_string(
    sin_path,
    "/data/local/tmp/kazeia/models/qwen3-tts-npu/cp_kv_v2/cp_rotary_sin.npy",
    "Sin file");
|
|
|
|
// ExecuTorch runtime names used unqualified in this file.
// Program loading and execution:
using executorch::runtime::Method;
using executorch::runtime::Program;
// Memory management:
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
using executorch::runtime::Span;
// Values and status reporting:
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::Result;
|
|
|
|
// Fixed model dimensions used to size the buffers and loops in main().
static constexpr int N_EMBD = 1024;   // floats per embedding / hidden vector
static constexpr int N_VOCAB = 2048;  // rows per head matrix and per embedding table
static constexpr int N_CB = 15;       // codes returned per request
static constexpr int N_KV = 8;        // KV cache dim 1 (cache tensors are [1, N_KV, KV_LEN, HD])
static constexpr int HD = 128;        // KV cache dim 3; also floats per rotary cos/sin row
static constexpr int KV_LEN = 16;     // KV cache dim 2; also the attention-mask length
static constexpr int N_L = 5;         // number of K/V cache pairs (presumably one per layer)
|
|
|
|
/**
 * Reads exactly `n` bytes from `fd` into `buf`, looping over short reads.
 *
 * A read interrupted by a signal (EINTR) is retried rather than reported as
 * a failure, so a stray signal does not drop the connection.
 *
 * @param fd   Open file descriptor to read from.
 * @param buf  Destination buffer with room for at least `n` bytes.
 * @param n    Number of bytes to read; 0 succeeds immediately.
 * @return true once all `n` bytes were read; false on end-of-file or on any
 *         read error other than EINTR (`buf` may then be partially filled).
 */
static bool read_exact(int fd, void* buf, size_t n) {
  char* dst = static_cast<char*>(buf);
  size_t done = 0;
  while (done < n) {
    const ssize_t r = read(fd, dst + done, n - done);
    if (r > 0) {
      done += static_cast<size_t>(r);
      continue;
    }
    if (r < 0 && errno == EINTR) {
      continue;  // interrupted before any data arrived: retry
    }
    return false;  // r == 0 (EOF) or a real error
  }
  return true;
}
|
|
/**
 * Writes exactly `n` bytes from `buf` to `fd`, looping over short writes.
 *
 * A write interrupted by a signal (EINTR) is retried rather than reported as
 * a failure.
 *
 * @param fd   Open file descriptor to write to.
 * @param buf  Source buffer holding at least `n` bytes.
 * @param n    Number of bytes to write; 0 succeeds immediately.
 * @return true once all `n` bytes were written; false if write() returns 0
 *         or fails with any error other than EINTR (some bytes may already
 *         have been sent).
 */
static bool write_exact(int fd, const void* buf, size_t n) {
  const char* src = static_cast<const char*>(buf);
  size_t done = 0;
  while (done < n) {
    const ssize_t r = write(fd, src + done, n - done);
    if (r > 0) {
      done += static_cast<size_t>(r);
      continue;
    }
    if (r < 0 && errno == EINTR) {
      continue;  // interrupted before any data was written: retry
    }
    return false;  // no progress possible or a real error
  }
  return true;
}
|
|
|
|
/**
 * Loads the first `n` float32 values of the data section of a NumPy .npy file.
 *
 * The header is parsed only far enough to locate the data: the magic string
 * is verified and the header length is read as 2 bytes (format version 1) or
 * 4 bytes (versions 2 and 3). The dtype/shape dictionary is NOT inspected;
 * the payload is assumed to be native-endian float32 — TODO confirm against
 * the exporter.
 *
 * @param p  Path to the .npy file.
 * @param n  Number of floats to read; must be positive.
 * @return A malloc()-allocated array of `n` floats that the caller must
 *         free(), or nullptr if the file cannot be opened, is not a
 *         recognised .npy file, holds fewer than `n` floats, or memory
 *         cannot be allocated.
 */
static float* load_npy(const char* p, int n) {
  if (p == nullptr || n <= 0) return nullptr;

  FILE* f = fopen(p, "rb");
  if (!f) return nullptr;

  static const unsigned char kMagic[6] = {0x93, 'N', 'U', 'M', 'P', 'Y'};
  unsigned char h[12];
  float* data = nullptr;

  // Fixed preamble: 6-byte magic, major, minor, then the header length.
  bool ok = fread(h, 1, 10, f) == 10 && memcmp(h, kMagic, sizeof(kMagic)) == 0;
  unsigned long data_off = 0;
  if (ok) {
    const unsigned char major = h[6];
    if (major == 1) {
      // v1: 2-byte little-endian header length at offset 8.
      const unsigned long hl =
          static_cast<unsigned long>(h[8]) | (static_cast<unsigned long>(h[9]) << 8);
      data_off = 10UL + hl;
    } else if (major == 2 || major == 3) {
      // v2/v3: 4-byte little-endian header length at offset 8.
      ok = fread(h + 10, 1, 2, f) == 2;
      if (ok) {
        const unsigned long hl = static_cast<unsigned long>(h[8]) |
                                 (static_cast<unsigned long>(h[9]) << 8) |
                                 (static_cast<unsigned long>(h[10]) << 16) |
                                 (static_cast<unsigned long>(h[11]) << 24);
        // Keep the offset representable as a (possibly 32-bit) long for fseek.
        ok = hl <= 0x7FFFFFFFUL - 12UL;
        data_off = 12UL + hl;
      }
    } else {
      ok = false;  // unknown format version
    }
  }

  ok = ok && fseek(f, static_cast<long>(data_off), SEEK_SET) == 0;
  if (ok) {
    const size_t count = static_cast<size_t>(n);
    ok = count <= static_cast<size_t>(-1) / sizeof(float);
    if (ok) {
      data = static_cast<float*>(malloc(count * sizeof(float)));
      // A short read means the file holds fewer values than requested.
      ok = data != nullptr && fread(data, sizeof(float), count, f) == count;
    }
  }

  fclose(f);
  if (!ok) {
    free(data);
    data = nullptr;
  }
  return data;
}
|
|
|
|
// 4 MiB of static storage; main() wraps it in the method MemoryAllocator.
static uint8_t method_allocator_pool[4U * 1024U * 1024U];
|
|
// 1 MiB of static storage; main() wraps it in the temp MemoryAllocator.
static uint8_t temp_allocator_pool[1U * 1024U * 1024U];
|
|
|
|
int main(int argc, char** argv) {
|
|
executorch::runtime::runtime_init();
|
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
|
|
|
if (FLAGS_model_path.empty()) {
|
|
fprintf(stderr, "Usage: cp_et_runner --model_path=model.pte\n");
|
|
return 1;
|
|
}
|
|
|
|
// Load program
|
|
auto loader = executorch::extension::FileDataLoader::from(FLAGS_model_path.c_str());
|
|
ET_CHECK_MSG(loader.ok(), "Failed to load %s", FLAGS_model_path.c_str());
|
|
|
|
auto program = Program::load(&loader.get());
|
|
ET_CHECK_MSG(program.ok(), "Failed to parse program");
|
|
|
|
// Setup memory — allocate planned buffers from program metadata
|
|
MemoryAllocator method_allocator(sizeof(method_allocator_pool), method_allocator_pool);
|
|
auto temp_allocator = MemoryAllocator(sizeof(temp_allocator_pool), temp_allocator_pool);
|
|
|
|
auto method_meta = program->method_meta("forward");
|
|
ET_CHECK_MSG(method_meta.ok(), "Failed to get method meta");
|
|
|
|
std::vector<std::unique_ptr<uint8_t[]>> planned_bufs;
|
|
std::vector<Span<uint8_t>> planned_spans;
|
|
size_t n_planned = method_meta->num_memory_planned_buffers();
|
|
for (size_t id = 0; id < n_planned; id++) {
|
|
size_t sz = (size_t)method_meta->memory_planned_buffer_size(id).get();
|
|
planned_bufs.push_back(std::make_unique<uint8_t[]>(sz));
|
|
planned_spans.push_back({planned_bufs.back().get(), sz});
|
|
}
|
|
HierarchicalAllocator planned_memory({planned_spans.data(), planned_spans.size()});
|
|
MemoryManager memory_manager(&method_allocator, &planned_memory, &temp_allocator);
|
|
|
|
// Load method
|
|
auto method = program->load_method("forward", &memory_manager);
|
|
ET_CHECK_MSG(method.ok(), "Failed to load method: 0x%x", (int)method.error());
|
|
|
|
auto meta = method->method_meta();
|
|
fprintf(stderr, "CP_ET: %zu inputs, %zu outputs\n", meta.num_inputs(), meta.num_outputs());
|
|
|
|
// Load heads, embeddings, rotary
|
|
float* heads = (float*)malloc(N_CB * N_VOCAB * N_EMBD * 4);
|
|
float* embs_data = (float*)malloc(N_CB * N_VOCAB * N_EMBD * 4);
|
|
FILE* fh = fopen(FLAGS_heads_path.c_str(), "rb");
|
|
if (fh) { fread(heads, 4, N_CB*N_VOCAB*N_EMBD, fh); fclose(fh); }
|
|
FILE* fe = fopen(FLAGS_embs_path.c_str(), "rb");
|
|
if (fe) { fread(embs_data, 4, N_CB*N_VOCAB*N_EMBD, fe); fclose(fe); }
|
|
float* rcos = load_npy(FLAGS_cos_path.c_str(), 17*HD);
|
|
float* rsin = load_npy(FLAGS_sin_path.c_str(), 17*HD);
|
|
|
|
// Socket setup
|
|
unlink(FLAGS_sock_path.c_str());
|
|
int srv = socket(AF_UNIX, SOCK_STREAM, 0);
|
|
struct sockaddr_un addr = {}; addr.sun_family = AF_UNIX;
|
|
strncpy(addr.sun_path, FLAGS_sock_path.c_str(), sizeof(addr.sun_path)-1);
|
|
bind(srv, (struct sockaddr*)&addr, sizeof(addr));
|
|
chmod(FLAGS_sock_path.c_str(), 0666);
|
|
listen(srv, 1);
|
|
fprintf(stderr, "CP_ET READY on %s\n", FLAGS_sock_path.c_str());
|
|
|
|
while (true) {
|
|
int cli = accept(srv, nullptr, nullptr);
|
|
if (cli < 0) break;
|
|
|
|
float input[2 * N_EMBD];
|
|
while (read_exact(cli, input, sizeof(input))) {
|
|
auto t0 = std::chrono::high_resolution_clock::now();
|
|
float* hidden_in = input;
|
|
float* cb0_emb = input + N_EMBD;
|
|
int kv_elem = N_KV * KV_LEN * HD;
|
|
std::vector<float> kv(N_L * 2 * kv_elem, 0.0f);
|
|
int codes[N_CB] = {};
|
|
float* emb = hidden_in;
|
|
|
|
for (int step = 0; step < 17; step++) {
|
|
if (step == 1) emb = cb0_emb;
|
|
else if (step >= 2) emb = embs_data + ((step-2)*N_VOCAB + codes[step-2]) * N_EMBD;
|
|
|
|
// Prepare input tensors (allocates buffers matching the method's expectations)
|
|
auto prep = executorch::extension::prepare_input_tensors(method.get());
|
|
if (!prep.ok()) { fprintf(stderr, "prep fail %d\n", step); break; }
|
|
|
|
// Copy our data into the prepared tensors
|
|
// Input 0: emb [1,1,1024]
|
|
memcpy(method->mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, N_EMBD*4);
|
|
// Input 1: mask [1,1,1,16]
|
|
float* mp = method->mutable_input(1).toTensor().mutable_data_ptr<float>();
|
|
for (int p = 0; p < KV_LEN; p++) mp[p] = (p >= KV_LEN-1-step) ? 0.0f : -1e9f;
|
|
// Input 2: cos [1,1,128]
|
|
memcpy(method->mutable_input(2).toTensor().mutable_data_ptr<float>(), rcos+step*HD, HD*4);
|
|
// Input 3: sin [1,1,128]
|
|
memcpy(method->mutable_input(3).toTensor().mutable_data_ptr<float>(), rsin+step*HD, HD*4);
|
|
// Inputs 4-13: KV caches [1,8,16,128]
|
|
for (int l = 0; l < N_L; l++) {
|
|
memcpy(method->mutable_input(4+l*2).toTensor().mutable_data_ptr<float>(),
|
|
kv.data()+(l*2)*kv_elem, kv_elem*4);
|
|
memcpy(method->mutable_input(5+l*2).toTensor().mutable_data_ptr<float>(),
|
|
kv.data()+(l*2+1)*kv_elem, kv_elem*4);
|
|
}
|
|
|
|
auto status = method->execute();
|
|
if (status != Error::Ok) {
|
|
fprintf(stderr, "exec fail step %d: %d\n", step, (int)status);
|
|
break;
|
|
}
|
|
|
|
// Get hidden output
|
|
const float* h = method->get_output(0).toTensor().const_data_ptr<float>();
|
|
|
|
// Head argmax on CPU
|
|
if (step >= 1 && step-1 < N_CB) {
|
|
int cb = step - 1;
|
|
const float* W = heads + cb * N_VOCAB * N_EMBD;
|
|
int best = 0; float bv = -1e30f;
|
|
for (int j = 0; j < N_VOCAB; j++) {
|
|
float dot = 0;
|
|
for (int k = 0; k < N_EMBD; k++) dot += h[k] * W[j*N_EMBD+k];
|
|
if (dot > bv) { bv = dot; best = j; }
|
|
}
|
|
codes[cb] = best;
|
|
}
|
|
|
|
// Update KV caches from outputs
|
|
for (int l = 0; l < N_L; l++) {
|
|
const float* ko = method->get_output(1+l*2).toTensor().const_data_ptr<float>();
|
|
const float* vo = method->get_output(2+l*2).toTensor().const_data_ptr<float>();
|
|
memcpy(kv.data()+(l*2)*kv_elem, ko, kv_elem*4);
|
|
memcpy(kv.data()+(l*2+1)*kv_elem, vo, kv_elem*4);
|
|
}
|
|
}
|
|
|
|
auto t1 = std::chrono::high_resolution_clock::now();
|
|
float ms = std::chrono::duration<float, std::milli>(t1-t0).count();
|
|
write_exact(cli, codes, sizeof(codes));
|
|
write_exact(cli, &ms, sizeof(ms));
|
|
}
|
|
close(cli);
|
|
}
|
|
|
|
free(heads); free(embs_data); free(rcos); free(rsin);
|
|
close(srv); unlink(FLAGS_sock_path.c_str());
|
|
return 0;
|
|
}
|