/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// JNI bindings for org.pytorch.executorch.Tensor / EValue / Module, plus an
// experimental TTS generation loop (talker module + code-predictor module).
//
// NOTE(review): this copy of the file is damaged. Every span that began with
// `<` followed by a letter and ran to the next `>` has been removed: all
// #include targets, all template arguments (e.g. `std::vector`,
// `static_cast(...)`, `mutable_data_ptr()`), and comparisons written without
// a space such as `i<n`. The bare `#include` lines, argument-less templates
// and truncated `for` headers below are artifacts of that loss, not intended
// code. Restore the file from version control before building; the comments
// here describe only what the surviving tokens show.

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#ifdef ET_USE_THREADPOOL
#include
#include
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
#include
#endif
#endif

#ifdef EXECUTORCH_ANDROID_PROFILING
#include
#include
#include
#endif

#include
#include

using namespace executorch::extension;
using namespace torch::executor;

namespace executorch::extension {

// Native peer of the Java class org.pytorch.executorch.Tensor. Converts
// tensors between the C++ runtime and Java.
class TensorHybrid : public facebook::jni::HybridClass {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/Tensor;";

  // NOTE(review): `tensor` is ignored; the hybrid instance stores no state.
  explicit TensorHybrid(executorch::aten::Tensor tensor) {}

  // Builds a Java Tensor that aliases `tensor`'s storage (no copy): the data
  // is exposed through a direct ByteBuffer wrapping tensor.data_ptr() in
  // native byte order, together with the shape and Java dtype code, via the
  // static Java method `nativeNewTensor`.
  // Throws an ExecuTorch exception (Error::InvalidArgument) and returns null
  // if the scalar type has no entry in scalar_type_to_java_dtype.
  static facebook::jni::local_ref
  newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
    // Java wrapper currently only supports contiguous tensors.

    const auto scalarType = tensor.scalar_type();
    if (scalar_type_to_java_dtype.count(scalarType) == 0) {
      std::stringstream ss;
      ss << "executorch::aten::Tensor scalar type "
         << static_cast(scalarType) << " is not supported on java side";
      jni_helper::throwExecutorchException(
          static_cast(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    int jdtype = scalar_type_to_java_dtype.at(scalarType);

    // Copy the sizes into a Java long[] for the shape argument.
    const auto& tensor_shape = tensor.sizes();
    std::vector tensor_shape_vec;
    for (const auto& s : tensor_shape) {
      tensor_shape_vec.push_back(s);
    }
    facebook::jni::local_ref jTensorShape =
        facebook::jni::make_long_array(tensor_shape_vec.size());
    jTensorShape->setRegion(
        0, tensor_shape_vec.size(), tensor_shape_vec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // Note: this is safe as long as the data stored in tensor is valid; the
    // data won't go out of scope as long as the Method for the inference is
    // valid and there is no other inference call. Java layer picks up this
    // value immediately so the data is valid.
    facebook::jni::local_ref jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            (uint8_t*)tensor.data_ptr(), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());

    static const auto jMethodNewTensor =
        cls->getStaticMethod(
            facebook::jni::alias_ref,
            facebook::jni::alias_ref,
            jint,
            facebook::jni::alias_ref)>("nativeNewTensor");
    return jMethodNewTensor(
        cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
  }

  // Builds a native TensorPtr over the direct buffer returned by the Java
  // Tensor's getRawDataBuffer(), without copying, using its `shape` field and
  // dtypeJniCode(). The Java buffer must outlive the returned tensor.
  // Throws an ExecuTorch exception (Error::InvalidArgument) and returns null
  // if the dtype is unknown, the buffer is not direct, or its capacity does
  // not match the shape.
  static TensorPtr newTensorFromJTensor(
      facebook::jni::alias_ref jtensor) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto dtypeMethod = cls->getMethod("dtypeJniCode");
    jint jdtype = dtypeMethod(jtensor);

    static const auto shapeField = cls->getField("shape");
    auto jshape = jtensor->getFieldValue(shapeField);

    static auto dataBufferMethod = cls->getMethod< facebook::jni::local_ref()>(
        "getRawDataBuffer");
    facebook::jni::local_ref jbuffer =
        dataBufferMethod(jtensor);

    const auto rank = jshape->size();

    const auto shapeArr = jshape->getRegion(0, rank);
    std::vector shape_vec;
    shape_vec.reserve(rank);

    // Element count = product of all dimensions (1 for a rank-0 tensor).
    int64_t numel = 1;
    for (int i = 0; i < rank; i++) {
      shape_vec.push_back(shapeArr[i]);
    }
    for (int i = rank - 1; i >= 0; --i) {
      numel *= shapeArr[i];
    }
    JNIEnv* jni = facebook::jni::Environment::current();
    if (java_dtype_to_scalar_type.count(jdtype) == 0) {
      std::stringstream ss;
      ss << "Unknown Tensor jdtype: [" << jdtype << "]";
      jni_helper::throwExecutorchException(
          static_cast(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
    // Negative capacity means the buffer is not a direct buffer.
    const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
    if (dataCapacity < 0) {
      std::stringstream ss;
      ss << "Tensor buffer is not direct or has invalid capacity";
      jni_helper::throwExecutorchException(
          static_cast(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    // Accept capacity reported either in elements or in bytes, presumably
    // because getRawDataBuffer() may return a typed buffer (capacity in
    // elements) or a ByteBuffer (capacity in bytes).
    // NOTE(review): the element-count match is accepted regardless of the
    // buffer's element width, so e.g. a float tensor of N elements backed by
    // an N-byte ByteBuffer passes this check and from_blob would read past
    // the end of the buffer. Consider checking the buffer type.
    const size_t elementSize = executorch::runtime::elementSize(scalar_type);
    const jlong expectedElements = static_cast(numel);
    const jlong expectedBytes =
        expectedElements * static_cast(elementSize);
    const bool matchesElements = dataCapacity == expectedElements;
    const bool matchesBytes = dataCapacity == expectedBytes;
    if (!matchesElements && !matchesBytes) {
      std::stringstream ss;
      ss << "Tensor dimensions(elements number: " << numel
         << ") inconsistent with buffer capacity " << dataCapacity
         << " (element size bytes: " << elementSize << ")";
      jni_helper::throwExecutorchException(
          static_cast(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    return from_blob(
        jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
  }

 private:
  friend HybridBase;
};

// Wrapper for the Java class org.pytorch.executorch.EValue. The kTypeCode*
// constants mirror the Java object's `mTypeCode` field.
class JEValue : public facebook::jni::JavaClass {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/EValue;";

  constexpr static int kTypeCodeTensor = 1;
  constexpr static int kTypeCodeString = 2;
  constexpr static int kTypeCodeDouble = 3;
  constexpr static int kTypeCodeInt = 4;
  constexpr static int kTypeCodeBool = 5;

  // Converts a runtime EValue (Tensor, Int, Double, Bool or String) to a Java
  // EValue through the matching static `EValue.from(...)` overload. Any other
  // tag throws an ExecuTorch exception (Error::InvalidArgument) and returns
  // an empty ref.
  static facebook::jni::local_ref
  newJEValueFromEValue(EValue evalue) {
    if (evalue.isTensor()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod(
                  facebook::jni::local_ref)>("from");
      return jMethodTensor(
          JEValue::javaClassStatic(),
          TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
    } else if (evalue.isInt()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod(jlong)>(
                  "from");
      return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
    } else if (evalue.isDouble()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod(jdouble)>(
                  "from");
      return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
    } else if (evalue.isBool()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod(jboolean)>(
                  "from");
      return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
    } else if (evalue.isString()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod(
                  facebook::jni::local_ref)>("from");
      std::string str =
          std::string(evalue.toString().begin(), evalue.toString().end());
      return jMethodTensor(
          JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
    }
    std::stringstream ss;
    ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]";
    jni_helper::throwExecutorchException(
        static_cast(Error::InvalidArgument), ss.str().c_str());
    return {};
  }

  // Extracts a native tensor from a Java EValue whose mTypeCode is
  // kTypeCodeTensor; any other type code throws an ExecuTorch exception
  // (Error::InvalidArgument) and returns an empty TensorPtr.
  // Note: the parameter is named `JEValue`; `JEValue::` below still names
  // the class, because lookup before `::` only considers types/namespaces.
  static TensorPtr JEValueToTensorImpl(
      facebook::jni::alias_ref JEValue) {
    static const auto typeCodeField =
        JEValue::javaClassStatic()->getField("mTypeCode");
    const auto typeCode = JEValue->getFieldValue(typeCodeField);
    if (JEValue::kTypeCodeTensor == typeCode) {
      static const auto jMethodGetTensor =
          JEValue::javaClassStatic()
              ->getMethod()>(
                  "toTensor");
      auto jtensor = jMethodGetTensor(JEValue);
      return TensorHybrid::newTensorFromJTensor(jtensor);
    }
    std::stringstream ss;
    ss << "Unknown EValue typeCode: " << typeCode;
    jni_helper::throwExecutorchException(
        static_cast(Error::InvalidArgument), ss.str().c_str());
    return {};
  }
};

// Native peer of org.pytorch.executorch.Module: owns one runtime Module and
// implements the Java `native` methods registered in registerNatives().
class ExecuTorchJni : public facebook::jni::HybridClass {
 private:
  friend HybridBase;
  std::unique_ptr module_;
#if defined(ET_USE_THREADPOOL) && \
    defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
  // Thread count applied per execute() call via UseNThreadsThreadPoolGuard.
  int num_threads_{0};
#endif

 public:
  constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";

  // JNI factory for the hybrid instance; see the constructor for arguments.
  static facebook::jni::local_ref initHybrid(
      facebook::jni::alias_ref,
      facebook::jni::alias_ref modelPath,
      jint loadMode,
      jint numThreads) {
    return makeCxxInstance(modelPath, loadMode, numThreads);
  }

  // Creates the Module for `modelPath`.
  // loadMode: 0 = File, 1 = Mmap, 2 = MmapUseMlock,
  //           3 = MmapUseMlockIgnoreErrors; any other value falls back to
  //           Mmap silently.
  // numThreads: threadpool size; 0 selects cpuinfo_get_processors_count()/2
  //           (only meaningful when built with ET_USE_THREADPOOL).
  // With EXECUTORCH_ANDROID_PROFILING an ETDumpGen event tracer is attached.
  ExecuTorchJni(
      facebook::jni::alias_ref modelPath,
      jint loadMode,
      jint numThreads) {
    Module::LoadMode load_mode = Module::LoadMode::Mmap;
    if (loadMode == 0) {
      load_mode = Module::LoadMode::File;
    } else if (loadMode == 1) {
      load_mode = Module::LoadMode::Mmap;
    } else if (loadMode == 2) {
      load_mode = Module::LoadMode::MmapUseMlock;
    } else if (loadMode == 3) {
      load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
    }
#ifdef EXECUTORCH_ANDROID_PROFILING
    auto etdump_gen = std::make_unique();
#else
    auto etdump_gen = nullptr;
#endif
    module_ = std::make_unique(
        modelPath->toStdString(), load_mode, std::move(etdump_gen));

#ifdef ET_USE_THREADPOOL
    // Default to using cores/2 threadpool threads. The long-term plan is to
    // improve performant core detection in CPUInfo, but for now we can use
    // cores/2 as a sane default.
    //
    // Based on testing, this is almost universally faster than using all
    // cores, as efficiency cores can be quite slow. In extreme cases, using
    // all cores can be 10x slower than using cores/2.
    int thread_count =
        numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
    num_threads_ = thread_count;
#else
    // Without the guard, the threadpool returned by get_threadpool() is
    // resized here, once, at construction time.
    auto threadpool = executorch::extension::threadpool::get_threadpool();
    if (threadpool) {
      if (thread_count > 0) {
        threadpool->_unsafe_reset_threadpool(thread_count);
      }
    }
#endif
#endif
  }

  // Java entry point for running a method; see execute_method().
  facebook::jni::local_ref> execute(
      facebook::jni::alias_ref methodName,
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass::javaobject>
          jinputs) {
    return execute_method(methodName->toStdString(), jinputs);
  }

  // Loads `methodName` and returns the runtime Error code cast to jint.
  jint load_method(facebook::jni::alias_ref methodName) {
    return static_cast(module_->load_method(methodName->toStdString()));
  }

  // Runs `method` and returns its outputs as a Java EValue[]. On failure an
  // ExecuTorch exception carrying the runtime error code is thrown and an
  // empty ref is returned.
  facebook::jni::local_ref> execute_method(
      std::string method,
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass::javaobject>
          jinputs) {
    // If no inputs is given, it will run with sample inputs (ones)
    // NOTE(review): this path returns before the UseNThreadsThreadPoolGuard
    // below is created, so it runs without the configured thread count.
    if (jinputs->size() == 0) {
      auto result = module_->load_method(method);
      if (result != Error::Ok) {
        // Format hex string
        // NOTE(review): the message says "Cannot get method names" but this
        // branch is a load_method() failure.
        std::stringstream ss;
        ss << "Cannot get method names [Native Error: 0x" << std::hex
           << std::uppercase << static_cast(result) << "]";

        jni_helper::throwExecutorchException(
            static_cast(result), ss.str());
        return {};
      }
      auto&& underlying_method = module_->methods_[method].method;
      // `buf` is kept alive until the end of this scope, i.e. across
      // execute(); presumably it owns the prepared input buffers.
      auto&& buf = prepare_input_tensors(*underlying_method);
      result = underlying_method->execute();
      if (result != Error::Ok) {
        jni_helper::throwExecutorchException(
            static_cast(result), "Execution failed for method: " + method);
        return {};
      }
      facebook::jni::local_ref> jresult =
          facebook::jni::JArrayClass::newArray(
              underlying_method->outputs_size());

      for (int i = 0; i < underlying_method->outputs_size(); i++) {
        auto jevalue =
            JEValue::newJEValueFromEValue(underlying_method->get_output(i));
        jresult->setElement(i, *jevalue);
      }
      return jresult;
    }

    // Convert Java inputs to runtime EValues. `tensors` keeps the TensorPtrs
    // alive for the duration of the call.
    std::vector evalues;
    std::vector tensors;

    static const auto typeCodeField =
        JEValue::javaClassStatic()->getField("mTypeCode");

    // NOTE(review): inputs whose type code is not Tensor/Int/Double/Bool
    // (e.g. kTypeCodeString) are skipped silently, which shifts the
    // positions of all later inputs. Consider throwing instead.
    for (int i = 0; i < jinputs->size(); i++) {
      auto jevalue = jinputs->getElement(i);
      const auto typeCode = jevalue->getFieldValue(typeCodeField);
      if (typeCode == JEValue::kTypeCodeTensor) {
        tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
        evalues.emplace_back(tensors.back());
      } else if (typeCode == JEValue::kTypeCodeInt) {
        static const auto toIntMethod =
            JEValue::javaClassStatic()->getMethod("toInt");
        evalues.emplace_back(static_cast(toIntMethod(jevalue)));
      } else if (typeCode == JEValue::kTypeCodeDouble) {
        static const auto toDoubleMethod =
            JEValue::javaClassStatic()->getMethod("toDouble");
        evalues.emplace_back(static_cast(toDoubleMethod(jevalue)));
      } else if (typeCode == JEValue::kTypeCodeBool) {
        static const auto toBoolMethod =
            JEValue::javaClassStatic()->getMethod("toBool");
        evalues.emplace_back(static_cast(toBoolMethod(jevalue)));
      }
    }

#if defined(ET_USE_THREADPOOL) && \
    defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
    ::executorch::extension::threadpool::UseNThreadsThreadPoolGuard
        thread_pool_guard(num_threads_);
#endif

#ifdef EXECUTORCH_ANDROID_PROFILING
    auto start = std::chrono::high_resolution_clock::now();
    auto result = module_->execute(method, evalues);
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast(end - start)
            .count();
    ET_LOG(Debug, "Execution time: %lld ms.", duration);

#else
    auto result = module_->execute(method, evalues);

#endif

    if (!result.ok()) {
      jni_helper::throwExecutorchException(
          static_cast(result.error()),
          "Execution failed for method: " + method);
      return {};
    }

    facebook::jni::local_ref> jresult =
        facebook::jni::JArrayClass::newArray(result.get().size());

    for (int i = 0; i < result.get().size(); i++) {
      auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
      jresult->setElement(i, *jevalue);
    }
    return jresult;
  }

  // Instance and static Java entry points for readLogBufferUtil().
  facebook::jni::local_ref> readLogBuffer() {
    return readLogBufferUtil();
  }

  static facebook::jni::local_ref>
  readLogBufferStatic(facebook::jni::alias_ref) {
    return readLogBufferUtil();
  }

  // Returns the buffered runtime log entries as Java strings formatted
  // "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE". On non-Android builds
  // an empty array is returned.
  // NOTE(review): `ret` is only assigned inside the callback; if
  // access_log_buffer() does not invoke it, a null ref is returned.
  static facebook::jni::local_ref> readLogBufferUtil() {
#ifdef __ANDROID__

    facebook::jni::local_ref> ret;

    access_log_buffer([&](std::vector& buffer) {
      const auto size = buffer.size();
      ret = facebook::jni::JArrayClass::newArray(size);
      for (auto i = 0u; i < size; i++) {
        const auto& entry = buffer[i];
        // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
        // MESSAGE".
        std::stringstream ss;
        ss << "[" << entry.timestamp << " " << entry.function << " "
           << entry.filename << ":" << entry.line << "] "
           << static_cast(entry.level) << " " << entry.message;

        facebook::jni::local_ref jstr_message =
            facebook::jni::make_jstring(ss.str().c_str());
        (*ret)[i] = jstr_message;
      }
    });

    return ret;
#else
    return facebook::jni::JArrayClass::newArray(0);
#endif
  }

  // With EXECUTORCH_ANDROID_PROFILING: writes the collected ETDump to
  // /data/local/tmp/result.etdump and returns true on success. Otherwise
  // (or if there is no data, or on I/O failure) returns false.
  // NOTE(review): on the open() failure path etdump_data.buf is not freed;
  // on the write() failure path neither the fd nor the buffer is released.
  // open() lacks O_TRUNC, so a longer stale file keeps its trailing bytes.
  // A short write is treated as success. `%d` is used for an ssize_t.
  jboolean etdump() {
#ifdef EXECUTORCH_ANDROID_PROFILING
    executorch::etdump::ETDumpGen* etdumpgen =
        (executorch::etdump::ETDumpGen*)module_->event_tracer();
    auto etdump_data = etdumpgen->get_etdump_data();

    if (etdump_data.buf != nullptr && etdump_data.size > 0) {
      int etdump_file =
          open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644);
      if (etdump_file == -1) {
        ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
        return false;
      }
      ssize_t bytes_written =
          write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
      if (bytes_written == -1) {
        ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
        return false;
      } else {
        ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written);
      }
      close(etdump_file);
      free(etdump_data.buf);
      return true;
    } else {
      ET_LOG(Error, "No ETDump data available!");
    }
#endif
    return false;
  }

  // Returns the names of all methods in the program as a Java String[].
  // NOTE(review): failures are reported as Error::InvalidArgument whatever
  // the actual error was (getUsedBackends forwards the real code).
  facebook::jni::local_ref> getMethods() {
    const auto& names_result = module_->method_names();
    if (!names_result.ok()) {
      // Format hex string
      std::stringstream ss;
      ss << "Cannot get load module [Native Error: 0x" << std::hex
         << std::uppercase << static_cast(names_result.error()) << "]";

      jni_helper::throwExecutorchException(
          static_cast(Error::InvalidArgument), ss.str());
      return {};
    }
    const auto& methods = names_result.get();
    facebook::jni::local_ref> ret =
        facebook::jni::JArrayClass::newArray(methods.size());
    int i = 0;
    for (auto s : methods) {
      facebook::jni::local_ref method_name =
          facebook::jni::make_jstring(s.c_str());
      (*ret)[i] = method_name;
      i++;
    }
    return ret;
  }

  // Returns the distinct backend names used by `methodName` (from its
  // MethodMeta) as a Java String[]. Names that fail to resolve are skipped.
  // Order is unspecified, since the names are collected in an unordered_set.
  facebook::jni::local_ref> getUsedBackends(
      facebook::jni::alias_ref methodName) {
    auto method_name = methodName->toStdString();
    auto methodMetaResult = module_->method_meta(method_name);
    if (!methodMetaResult.ok()) {
      std::stringstream ss;
      ss << "Cannot get method meta for '" << method_name
         << "' [Native Error: 0x" << std::hex << std::uppercase
         << static_cast(methodMetaResult.error()) << "]";
      jni_helper::throwExecutorchException(
          static_cast(methodMetaResult.error()), ss.str());
      return {};
    }
    auto methodMeta = methodMetaResult.get();
    std::unordered_set backends;
    for (auto i = 0; i < methodMeta.num_backends(); i++) {
      auto backend_name_result = methodMeta.get_backend_name(i);
      if (backend_name_result.ok()) {
        backends.insert(backend_name_result.get());
      }
    }

    facebook::jni::local_ref> ret =
        facebook::jni::JArrayClass::newArray(backends.size());
    int i = 0;
    for (auto s : backends) {
      facebook::jni::local_ref backend_name =
          facebook::jni::make_jstring(s.c_str());
      (*ret)[i] = backend_name;
      i++;
    }
    return ret;
  }

  // Binds the Java `native` methods of org.pytorch.executorch.Module.
  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
        makeNativeMethod("executeNative", ExecuTorchJni::execute),
        makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
        makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
        makeNativeMethod(
            "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
        makeNativeMethod("etdump", ExecuTorchJni::etdump),
        makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
        makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
        makeNativeMethod("nativeSetCpModule", ExecuTorchJni::setCpModule),
        makeNativeMethod("nativeRunTtsPipeline", ExecuTorchJni::runTtsPipeline),
    });
  }

  // ======= TTS Pipeline =======
  // Code-predictor ("CP") module used by runTtsPipeline().
  // NOTE(review): a single static raw pointer shared by every ExecuTorchJni
  // instance. Nothing in the visible code clears it, so if the instance that
  // called setCpModule() is destroyed, cpModule_ dangles. It is also read
  // and written without synchronisation.
  static Module* cpModule_;

  // Makes this instance's module the code-predictor for runTtsPipeline().
  void setCpModule() {
    cpModule_ = module_.get();
    ET_LOG(Info, "TTS CP module set");
  }

  // Runs talker+CP loop. 'this' is the talker, cpModule_ is the CP.
  facebook::jni::local_ref runTtsPipeline(
      facebook::jni::alias_ref jPrefill,
      jint nPrefill,
      facebook::jni::alias_ref jTrailing,
      jint nTrailing,
      facebook::jni::alias_ref jCodecEmb,
      facebook::jni::alias_ref jCpEmbs,
      facebook::jni::alias_ref jCpHeads,
      facebook::jni::alias_ref jTCos,
      facebook::jni::alias_ref jTSin,
      facebook::jni::alias_ref jCCos,
      facebook::jni::alias_ref jCSin,
      facebook::jni::alias_ref jEos,
      facebook::jni::alias_ref jPad,
      jint maxTokens);

  static facebook::jni::local_ref runTtsPipelineImpl(
      Module* talker,
      Module* cp,
      facebook::jni::alias_ref jPrefill,
      jint nPrefill,
      facebook::jni::alias_ref jTrailing,
      jint nTrailing,
      facebook::jni::alias_ref jCodecEmb,
      facebook::jni::alias_ref jCpEmbs,
      facebook::jni::alias_ref jCpHeads,
      facebook::jni::alias_ref jTCos,
      facebook::jni::alias_ref jTSin,
      facebook::jni::alias_ref jCCos,
      facebook::jni::alias_ref jCSin,
      facebook::jni::alias_ref jEos,
      facebook::jni::alias_ref jPad,
      jint maxTokens);
};

} // namespace executorch::extension

// ======= TTS Pipeline statics and wrapper =======
namespace executorch::extension {
Module* ExecuTorchJni::cpModule_ = nullptr;
}

namespace executorch::extension {

// Forwards to runTtsPipelineImpl() with this module as the talker.
// Unlike the other entry points, a missing CP module is reported by logging
// and returning an empty int array rather than by throwing.
facebook::jni::local_ref
ExecuTorchJni::runTtsPipeline(
    facebook::jni::alias_ref jPrefill,
    jint nPrefill,
    facebook::jni::alias_ref jTrailing,
    jint nTrailing,
    facebook::jni::alias_ref jCodecEmb,
    facebook::jni::alias_ref jCpEmbs,
    facebook::jni::alias_ref jCpHeads,
    facebook::jni::alias_ref jTCos,
    facebook::jni::alias_ref jTSin,
    facebook::jni::alias_ref jCCos,
    facebook::jni::alias_ref jCSin,
    facebook::jni::alias_ref jEos,
    facebook::jni::alias_ref jPad,
    jint maxTokens) {
  if (!cpModule_) {
    ET_LOG(Error, "TTS CP module not set!");
    return facebook::jni::JArrayInt::newArray(0);
  }
  return runTtsPipelineImpl(module_.get(), cpModule_, jPrefill, nPrefill,
                            jTrailing, nTrailing, jCodecEmb, jCpEmbs,
                            jCpHeads, jTCos, jTSin, jCCos, jCSin, jEos, jPad,
                            maxTokens);
}

} // namespace

// NOTE(review): three more #include lines whose targets were stripped.
#include
#include
#include

// Dot product of two float vectors of length n using NEON accumulators.
// NOTE(review): the rest of this function, the start of the next helper
// (a top-k sampler, judging by the later call
// `tts_sample_topk(logits,VOCAB,0.9f,50)`), and the declaration of `tts_rng`
// were lost in the `<...>` stripping. The surviving tokens below are kept
// verbatim and no longer form valid code.
static inline float tts_dot_neon(const float* a, const float* b, int n) {
  float32x4_t s0=vdupq_n_f32(0),s1=vdupq_n_f32(0),s2=vdupq_n_f32(0),s3=vdupq_n_f32(0);
  int i=0;
  for(;i+15 topk(k,{0,-FLT_MAX});
  // Keep `topk` sorted in descending order of logit: a new candidate
  // replaces the last slot and is bubbled towards the front.
  // NOTE(review): loop header text lost here.
  for(int i=0;itopk[k-1].v){topk[k-1]={i,logits[i]};
    for(int j=k-2;j>=0;j--){if(topk[j+1].v>topk[j].v)std::swap(topk[j],topk[j+1]);else break;}} }
  // Weights exp((v - max) / temp) over the k candidates, then draw one
  // index proportionally using a 64-bit LCG held in `tts_rng`.
  // NOTE(review): if `tts_rng` is shared mutable state (its declaration is
  // lost), sampling is not thread-safe.
  float maxv=topk[0].v,sum=0;
  for(auto&t:topk){t.v=expf((t.v-maxv)/temp);sum+=t.v;}
  tts_rng=tts_rng*6364136223846793005ULL+1442695040888963407ULL;
  float r=(float)((tts_rng>>33)&0x7FFFFFFF)/(float)0x7FFFFFFF*sum;
  float acc=0;
  for(auto&t:topk){acc+=t.v;if(acc>=r)return t.i;}
  return topk[0].i;
}

namespace executorch::extension {

// Runs the TTS generation loop on two loaded modules.
// - `talker` ("forward") takes one embedding per step. It writes its output
//   0 into `hidden` (DIM floats) and output 1 into `logits` (VOCAB floats).
// - `cp` ("forward") is run once per generated token to produce 15
//   further codes.
// Returns the generated codes as a flat Java int[]; the final log divides
// its length by NUM_CB to count tokens. An empty array is returned:
// - if either Method cannot be obtained;
// - if the code after prefill is negative or CODEC_EOS.
//
// All model dimensions below are hard-coded and must match the exported
// programs. The visible code does not check them, nor the lengths of the
// Java arrays, against the actual tensor sizes before the raw memcpy calls.
facebook::jni::local_ref
ExecuTorchJni::runTtsPipelineImpl(
    Module* talker,
    Module* cp,
    facebook::jni::alias_ref jPrefill,
    jint nPrefill,
    facebook::jni::alias_ref jTrailing,
    jint nTrailing,
    facebook::jni::alias_ref jCodecEmb,
    facebook::jni::alias_ref jCpEmbs,
    facebook::jni::alias_ref jCpHeads,
    facebook::jni::alias_ref jTCos,
    facebook::jni::alias_ref jTSin,
    facebook::jni::alias_ref jCCos,
    facebook::jni::alias_ref jCSin,
    facebook::jni::alias_ref jEos,
    facebook::jni::alias_ref jPad,
    jint maxTokens) {
  static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
  // T_* / C_*: talker / CP KV-cache geometry. *_L is the number of K/V input
  // pairs (inputs 4.. alternate K,V); *_KV_LEN is the cache/mask length.
  static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
  static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
  static const int CODEC_EOS=2150;

  // Copy JNI arrays
  auto copyArr = [](facebook::jni::alias_ref j) {
    std::vector v(j->size());
    j->getRegion(0, j->size(), v.data());
    return v;
  };
  auto prefill=copyArr(jPrefill);
  auto trailing=nTrailing>0?copyArr(jTrailing):std::vector();
  auto codecEmb=copyArr(jCodecEmb);
  auto cpEmbs=copyArr(jCpEmbs);
  auto cpHeads=copyArr(jCpHeads);
  auto tCos=copyArr(jTCos),tSin=copyArr(jTSin);
  auto cCos=copyArr(jCCos),cSin=copyArr(jCSin);
  auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad);
  int tkvElem=T_KV*T_KV_LEN*T_HD;
  // tK/tV not needed — KV copied directly from output to input each step
  float mask[T_KV_LEN];
  // NOTE(review): text lost here. It presumably held the mask
  // initialisation loop, the declarations of `hidden`/`logits`, and the
  // element type of `allCodes`/`cb0Hist`.
  for(int i=0;i allCodes,cb0Hist;
  int pos=0,curCb0=-1,trIdx=0;

  // Get raw Method pointers for direct execution
  auto tMethodRes = talker->method("forward");
  auto cMethodRes = cp->method("forward");
  if (!tMethodRes.ok() || !cMethodRes.ok()) {
    ET_LOG(Error, "TTS cannot get Method pointers");
    return facebook::jni::JArrayInt::newArray(0);
  }
  Method* tMethod = tMethodRes.get();
  Method* cMethod = cMethodRes.get();

  // Talker: prepare once, cache pointers, reuse for all 58+ steps
  // NOTE(review): `prep` is destroyed at the closing brace of this
  // one-line block, before any of the pointers below are used.
  // execute_method() instead keeps the same call's result alive
  // (`auto&& buf`) across execute(). If that result owns the input buffers
  // (ExecuTorch documents it as a BufferCleanup that frees them on
  // destruction), every cached t*In* pointer dangles and each later memcpy
  // into them is a use-after-free. Verify, and keep `prep` alive for the
  // whole function.
  {auto prep=executorch::extension::prepare_input_tensors(*tMethod);}
  float* tInEmb = tMethod->mutable_input(0).toTensor().mutable_data_ptr();
  float* tInMask = tMethod->mutable_input(1).toTensor().mutable_data_ptr();
  float* tInCos = tMethod->mutable_input(2).toTensor().mutable_data_ptr();
  float* tInSin = tMethod->mutable_input(3).toTensor().mutable_data_ptr();
  float* tInKV[T_L*2];
  // NOTE(review): loop header text lost here.
  for(int i=0;imutable_input(4+i*2).toTensor().mutable_data_ptr();
    tInKV[i*2+1] = tMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr();
  }

  // One talker step for embedding `emb` (DIM floats):
  // - unmasks one more mask slot from the end;
  // - copies the cos/sin rows for position min(pos,249);
  // - feeds the previous step's KV outputs (outputs 2..) back as KV inputs;
  // - executes, then copies outputs 0/1 into hidden/logits and advances pos.
  // NOTE(review): tCos/tSin must hold at least 250*T_HD floats; this is not
  // checked. Once pos reaches T_KV_LEN-1 the mask stops changing.
  // On execute() failure the lambda only logs and returns, without updating
  // hidden/logits or pos; its void return gives the caller no way to notice.
  auto talkerStep = [&](const float* emb) {
    int pi=std::min(pos,249);
    int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
    if(mi>=0) mask[mi]=0.0f;
    memcpy(tInEmb, emb, DIM*4);
    memcpy(tInMask, mask, T_KV_LEN*4);
    memcpy(tInCos, tCos.data()+pi*T_HD, T_HD*4);
    memcpy(tInSin, tSin.data()+pi*T_HD, T_HD*4);
    // KV: copy directly from PREVIOUS output to input (skip intermediate buffer)
    if(pos > 0) {
      // NOTE(review): loop header text lost here.
      for(int i=0;iget_output(2+i*2).toTensor().const_data_ptr(), tkvElem*4);
        memcpy(tInKV[i*2+1], tMethod->get_output(3+i*2).toTensor().const_data_ptr(), tkvElem*4);
      }
    }
    // (pos==0: first call, input KV is already zeros from prepare_input_tensors)
    // NOTE(review): execute_method() says prepare_input_tensors produces
    // sample inputs of ones, not zeros. Confirm the initial KV contents.
    auto status = tMethod->execute();
    if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;}
    memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr(), DIM*4);
    memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr(), VOCAB*4);
    // KV NOT copied to tK/tV — read from output directly next step
    pos++;
  };

  // CP step: prepare once, direct output→input KV copy
  int ckvElem=C_KV*C_KV_LEN*C_HD;
  // NOTE(review): same lifetime concern as the talker's `prep`: the cached
  // cpIn* pointers may dangle once this block's `prep` is destroyed.
  {auto prep=executorch::extension::prepare_input_tensors(*cMethod);}
  float* cpInEmb = cMethod->mutable_input(0).toTensor().mutable_data_ptr();
  float* cpInMask = cMethod->mutable_input(1).toTensor().mutable_data_ptr();
  float* cpInCos = cMethod->mutable_input(2).toTensor().mutable_data_ptr();
  float* cpInSin = cMethod->mutable_input(3).toTensor().mutable_data_ptr();
  float* cpInKV[C_L*2];
  // NOTE(review): loop header text lost here.
  for(int i=0;imutable_input(4+i*2).toTensor().mutable_data_ptr();
    cpInKV[i*2+1] = cMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr();
  }

  // Code-predictor pass for one generated token, given talker hidden state
  // `h` and first code `cb0`; writes codes[0..14]. Each step:
  // - sets the mask and the cos/sin rows for that step;
  // - feeds the previous CP KV outputs (outputs 1..) back as inputs;
  // - executes.
  // For steps 1..15, output 0 is scored against head (step-1) of `cpHeads`
  // (CB_SIZE*DIM floats per head) and the arg-max index is stored.
  // NOTE(review): on execute() failure the loop breaks and the remaining
  // codes are left unwritten. cpHeads/cCos/cSin sizes are not checked.
  auto cpStep = [&](const float* h, int cb0, int* codes) {
    // Reset CP KV to zeros for step 0
    // NOTE(review): text lost here. It presumably held the KV reset, the
    // per-step loop header, and the embedding/mask setup.
    for(int i=0;i=C_KV_LEN-1-step)?0.0f:-1e9f;
      memcpy(cpInCos, cCos.data()+step*C_HD, C_HD*4);
      memcpy(cpInSin, cSin.data()+step*C_HD, C_HD*4);
      // KV: copy from previous output directly to input (skip buffer)
      if(step>0){
        // NOTE(review): loop header text lost here.
        for(int i=0;iget_output(1+i*2).toTensor().const_data_ptr(), ckvElem*4);
          memcpy(cpInKV[i*2+1], cMethod->get_output(2+i*2).toTensor().const_data_ptr(), ckvElem*4);
        }
      }
      auto status=cMethod->execute();
      if(status!=Error::Ok) break;
      const float*ho=cMethod->get_output(0).toTensor().const_data_ptr();
      if(step>=1&&step-1<15){
        const float*W=cpHeads.data()+(long)(step-1)*CB_SIZE*DIM;
        int best=0;float bv=-FLT_MAX;
        // NOTE(review): loop header and score computation text lost here.
        for(int j=0;jbv){bv=d;best=j;}}
        codes[step-1]=best;
      }
    }
  };

  auto t0=std::chrono::high_resolution_clock::now();
  // Prefill
  // NOTE(review): text lost here. It held the prefill loop and the start
  // of a timing log statement.
  for(int s=0;s(tP-t0).count(),curCb0);
  if(curCb0<0||curCb0==CODEC_EOS) return facebook::jni::JArrayInt::newArray(0);

  // Generation
  float totalTalker=0,totalCp=0;
  // NOTE(review): text lost here. It held the generation loop header and
  // the CP call with its timing.
  for(int g=0;g(tc1-tc0).count();
    for(int i=0;i<15;i++) codes[i+1]=cpCodes[i];
    // NOTE(review): text lost here.
    for(int i=0;i(tt1-tt0).count();
    // NOTE(review): text lost here. It presumably masked logits from
    // CB_SIZE upward and declared `seen` from the history.
    for(int j=CB_SIZE;j seen(cb0Hist.begin(),cb0Hist.end());
    // Repetition penalty on previously emitted codes: positive logits are
    // divided by 1.05, non-positive ones multiplied by 1.05.
    for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
    int next=tts_sample_topk(logits,VOCAB,0.9f,50);
    if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d",g+2);break;}
    // Degeneration stop: the last 9 history entries all equal `next`.
    if((int)cb0Hist.size()>=9){
      bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;}
      if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;}
    }
    curCb0=next;
  }

  int nTok=(int)allCodes.size()/NUM_CB;
  auto t1=std::chrono::high_resolution_clock::now();
  ET_LOG(Info,"TTS Generated %d | Talker %.0fms (%.0f/step) | CP %.0fms (%.0f/step) | Total %.0fms",
         nTok,totalTalker,totalTalker/std::max(nTok,1),totalCp,totalCp/std::max(nTok,1),
         std::chrono::duration(t1-t0).count());
  auto result=facebook::jni::JArrayInt::newArray((int)allCodes.size());
  result->setRegion(0,(int)allCodes.size(),allCodes.data());
  return result;
}

} // namespace executorch::extension

// Optional native-registration hooks: when a component is not built, a
// no-op definition is provided so JNI_OnLoad can call it unconditionally.
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
extern void register_natives_for_llm();
#else
// No op if we don't build LLM
void register_natives_for_llm() {}
#endif
extern void register_natives_for_runtime();

#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
extern void register_natives_for_training();
#else
// No op if we don't build training JNI
void register_natives_for_training() {}
#endif

// Library load hook: initialises fbjni and registers all native methods.
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
  return facebook::jni::initialize(vm, [] {
    executorch::extension::ExecuTorchJni::registerNatives();
    register_natives_for_llm();
    register_natives_for_runtime();
    register_natives_for_training();
  });
}