kazeia/executorch-custom/jni_layer_tts.cpp

908 lines
36 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/android/jni/jni_helper.h>
#include <executorch/extension/android/jni/jni_layer_constants.h>
#include <executorch/extension/android/jni/log.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/portable_type/tensor_impl.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/runtime.h>
#include <algorithm>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#ifdef ET_USE_THREADPOOL
#include <cpuinfo.h>
#include <executorch/extension/threadpool/threadpool.h>
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
#include <executorch/extension/threadpool/fb/threadpool_use_n_threads.h>
#endif
#endif
#ifdef EXECUTORCH_ANDROID_PROFILING
#include <executorch/devtools/etdump/etdump_flatcc.h>
#include <fcntl.h>
#include <unistd.h>
#endif
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
using namespace executorch::extension;
using namespace torch::executor;
namespace executorch::extension {
/**
 * Native peer of the Java class org.pytorch.executorch.Tensor.
 *
 * Offers the two conversions used by the JNI layer:
 *  - newJTensorFromTensor: wraps a native tensor's bytes in a Java Tensor.
 *  - newTensorFromJTensor: views a Java Tensor's direct buffer as a native
 *    tensor.
 * Neither direction copies the tensor data.
 */
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/Tensor;";

  // The tensor argument is accepted but not stored by this object.
  explicit TensorHybrid(executorch::aten::Tensor tensor) {}

  /**
   * Creates a Java Tensor whose ByteBuffer wraps `tensor`'s storage.
   *
   * @param tensor Native tensor to expose to Java.
   * @return The new Java Tensor. If the scalar type has no Java dtype mapping
   *         an ExecuTorch exception (InvalidArgument) is raised instead.
   */
  static facebook::jni::local_ref<TensorHybrid::javaobject>
  newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
    // Java wrapper currently only supports contiguous tensors.
    const auto scalarType = tensor.scalar_type();
    if (scalar_type_to_java_dtype.count(scalarType) == 0) {
      std::stringstream ss;
      ss << "executorch::aten::Tensor scalar type "
         << static_cast<int>(scalarType) << " is not supported on java side";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    int jdtype = scalar_type_to_java_dtype.at(scalarType);

    // Java stores the shape as long[], so widen every dimension.
    const auto& tensor_shape = tensor.sizes();
    std::vector<jlong> tensor_shape_vec;
    tensor_shape_vec.reserve(tensor_shape.size());
    for (const auto& s : tensor_shape) {
      tensor_shape_vec.push_back(s);
    }
    facebook::jni::local_ref<jlongArray> jTensorShape =
        facebook::jni::make_long_array(tensor_shape_vec.size());
    jTensorShape->setRegion(
        0, tensor_shape_vec.size(), tensor_shape_vec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // Note: this is safe as long as the data stored in tensor is valid; the
    // data won't go out of scope as long as the Method for the inference is
    // valid and there is no other inference call. Java layer picks up this
    // value immediately so the data is valid.
    facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            (uint8_t*)tensor.data_ptr(), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
    static const auto jMethodNewTensor =
        cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
            facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
            facebook::jni::alias_ref<jlongArray>,
            jint,
            facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
    return jMethodNewTensor(
        cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
  }

  /**
   * Builds a native tensor that views the direct buffer of a Java Tensor.
   *
   * The returned tensor points at the Java buffer's memory rather than owning
   * a copy.
   *
   * Raises an ExecuTorch exception (InvalidArgument) when:
   *  - a shape dimension is negative, does not fit the native size type, or
   *    makes the element count overflow;
   *  - the dtype code is unknown;
   *  - the buffer is not direct, or its capacity matches neither the element
   *    count nor the byte count implied by the shape;
   *  - the buffer has no native address although the tensor is non-empty.
   *
   * @param jtensor Java Tensor to view.
   * @return Shared pointer to the native tensor view.
   */
  static TensorPtr newTensorFromJTensor(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
    jint jdtype = dtypeMethod(jtensor);
    static const auto shapeField = cls->getField<jlongArray>("shape");
    auto jshape = jtensor->getFieldValue(shapeField);
    static auto dataBufferMethod = cls->getMethod<
        facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
        "getRawDataBuffer");
    facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
        dataBufferMethod(jtensor);

    const auto rank = jshape->size();
    const auto shapeArr = jshape->getRegion(0, rank);
    std::vector<executorch::aten::SizesType> shape_vec;
    shape_vec.reserve(rank);
    int64_t numel = 1;
    for (size_t i = 0; i < static_cast<size_t>(rank); ++i) {
      const jlong dim = shapeArr[i];
      const auto narrowed = static_cast<executorch::aten::SizesType>(dim);
      // Reject anything that would corrupt the element count: negative
      // dimensions (two of them can multiply to a plausible positive count),
      // dimensions truncated by the narrowing cast, and 64-bit overflow.
      const bool truncated = static_cast<jlong>(narrowed) != dim;
      const bool overflows = dim > 0 && numel > INT64_MAX / dim;
      if (dim < 0 || truncated || overflows) {
        std::stringstream ss;
        ss << "Tensor shape has an invalid dimension " << dim << " at index "
           << i;
        jni_helper::throwExecutorchException(
            static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
        return nullptr;
      }
      shape_vec.push_back(narrowed);
      numel *= dim;
    }

    JNIEnv* jni = facebook::jni::Environment::current();
    if (java_dtype_to_scalar_type.count(jdtype) == 0) {
      std::stringstream ss;
      ss << "Unknown Tensor jdtype: [" << jdtype << "]";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);

    const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
    if (dataCapacity < 0) {
      std::stringstream ss;
      ss << "Tensor buffer is not direct or has invalid capacity";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }

    // The capacity is accepted when it equals either the element count or the
    // byte count implied by the shape.
    const size_t elementSize = executorch::runtime::elementSize(scalar_type);
    const jlong expectedElements = static_cast<jlong>(numel);
    const jlong elementBytes = static_cast<jlong>(elementSize);
    const bool bytesOverflow =
        elementBytes > 0 && expectedElements > INT64_MAX / elementBytes;
    const bool matchesElements = dataCapacity == expectedElements;
    const bool matchesBytes =
        !bytesOverflow && dataCapacity == expectedElements * elementBytes;
    if (!matchesElements && !matchesBytes) {
      std::stringstream ss;
      ss << "Tensor dimensions(elements number: " << numel
         << ") inconsistent with buffer capacity " << dataCapacity
         << " (element size bytes: " << elementSize << ")";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }

    void* data = jni->GetDirectBufferAddress(jbuffer.get());
    // A null address is only tolerated for empty tensors, which are never
    // dereferenced.
    if (data == nullptr && numel > 0) {
      std::stringstream ss;
      ss << "Tensor buffer has no accessible native address";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    return from_blob(data, shape_vec, scalar_type);
  }

 private:
  friend HybridBase;
};
/**
 * C++ view of the Java class org.pytorch.executorch.EValue.
 *
 * Converts native EValues into Java EValues (through the Java static factory
 * methods named "from") and extracts a tensor from a Java EValue.
 */
class JEValue : public facebook::jni::JavaClass<JEValue> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/EValue;";

  // Values compared against the Java object's mTypeCode field.
  constexpr static int kTypeCodeTensor = 1;
  constexpr static int kTypeCodeString = 2;
  constexpr static int kTypeCodeDouble = 3;
  constexpr static int kTypeCodeInt = 4;
  constexpr static int kTypeCodeBool = 5;

  /**
   * Wraps a native EValue in a Java EValue.
   *
   * Tensor, int, double, bool and string values are supported; any other tag
   * raises an ExecuTorch exception (InvalidArgument).
   */
  static facebook::jni::local_ref<JEValue> newJEValueFromEValue(EValue evalue) {
    const auto jcls = JEValue::javaClassStatic();

    if (evalue.isTensor()) {
      static auto fromTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(
                  facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
      return fromTensor(
          jcls, TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
    }

    if (evalue.isInt()) {
      static auto fromLong =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jlong)>(
                  "from");
      return fromLong(jcls, evalue.toInt());
    }

    if (evalue.isDouble()) {
      static auto fromDouble =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jdouble)>(
                  "from");
      return fromDouble(jcls, evalue.toDouble());
    }

    if (evalue.isBool()) {
      static auto fromBoolean =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jboolean)>(
                  "from");
      return fromBoolean(jcls, evalue.toBool());
    }

    if (evalue.isString()) {
      static auto fromString =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(
                  facebook::jni::local_ref<jstring>)>("from");
      // Copy the characters into an owning string before handing it to JNI.
      const auto chars = evalue.toString();
      std::string text(chars.begin(), chars.end());
      return fromString(jcls, facebook::jni::make_jstring(text));
    }

    // None of the supported tags matched.
    std::stringstream message;
    message << "Unknown EValue type: [" << static_cast<int>(evalue.tag) << "]";
    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(Error::InvalidArgument), message.str().c_str());
    return {};
  }

  /**
   * Extracts the tensor held by a Java EValue.
   *
   * Raises an ExecuTorch exception (InvalidArgument) when the Java object's
   * mTypeCode is not the tensor code.
   */
  static TensorPtr JEValueToTensorImpl(
      facebook::jni::alias_ref<JEValue> jevalue) {
    static const auto typeCodeField =
        JEValue::javaClassStatic()->getField<jint>("mTypeCode");
    const auto typeCode = jevalue->getFieldValue(typeCodeField);

    if (typeCode != JEValue::kTypeCodeTensor) {
      std::stringstream message;
      message << "Unknown EValue typeCode: " << typeCode;
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument),
          message.str().c_str());
      return {};
    }

    static const auto toTensorMethod =
        JEValue::javaClassStatic()
            ->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
                "toTensor");
    auto jtensor = toTensorMethod(jevalue);
    return TensorHybrid::newTensorFromJTensor(jtensor);
  }
};
class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
private:
friend HybridBase;
std::unique_ptr<Module> module_;
#if defined(ET_USE_THREADPOOL) && \
defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
int num_threads_{0};
#endif
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";
// Static factory registered as the Java native "initHybrid". Forwards the
// model path, load-mode code and thread count unchanged to the ExecuTorchJni
// constructor through makeCxxInstance and returns the resulting hybrid data.
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
jint loadMode,
jint numThreads) {
return makeCxxInstance(modelPath, loadMode, numThreads);
}
// Creates the native Module for `modelPath`.
//
// loadMode: 0 = File, 1 = Mmap, 2 = MmapUseMlock,
//           3 = MmapUseMlockIgnoreErrors; any other value keeps Mmap.
// numThreads: threadpool size to use; 0 selects half of the processors
//             reported by cpuinfo (only relevant with ET_USE_THREADPOOL).
ExecuTorchJni(
    facebook::jni::alias_ref<jstring> modelPath,
    jint loadMode,
    jint numThreads) {
  Module::LoadMode load_mode = Module::LoadMode::Mmap;
  switch (loadMode) {
    case 0:
      load_mode = Module::LoadMode::File;
      break;
    case 2:
      load_mode = Module::LoadMode::MmapUseMlock;
      break;
    case 3:
      load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
      break;
    case 1:
    default:
      // Mmap is both the explicit choice for 1 and the fallback.
      break;
  }

  // An ETDump generator is attached only in profiling builds.
#ifdef EXECUTORCH_ANDROID_PROFILING
  auto etdump_gen = std::make_unique<executorch::etdump::ETDumpGen>();
#else
  auto etdump_gen = nullptr;
#endif
  module_ = std::make_unique<Module>(
      modelPath->toStdString(), load_mode, std::move(etdump_gen));

#ifdef ET_USE_THREADPOOL
  // When the caller does not choose a thread count, use cores/2. Testing
  // showed this to be almost universally faster than using every core,
  // because efficiency cores can be very slow (up to 10x slower in extreme
  // cases). Better performant-core detection in CPUInfo is the long-term plan.
  int thread_count = numThreads;
  if (thread_count == 0) {
    thread_count = cpuinfo_get_processors_count() / 2;
  }
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
  // The count is applied per execution through the guard in execute_method.
  num_threads_ = thread_count;
#else
  auto threadpool = executorch::extension::threadpool::get_threadpool();
  if (threadpool && thread_count > 0) {
    threadpool->_unsafe_reset_threadpool(thread_count);
  }
#endif
#endif
}
// Registered as the Java native "executeNative". Converts the Java method
// name to a std::string and delegates to execute_method with the same inputs.
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute(
facebook::jni::alias_ref<jstring> methodName,
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
jinputs) {
return execute_method(methodName->toStdString(), jinputs);
}
// Registered as the Java native "loadMethodNative". Loads the named method in
// the wrapped Module and returns the resulting error code cast to jint.
jint load_method(facebook::jni::alias_ref<jstring> methodName) {
return static_cast<jint>(module_->load_method(methodName->toStdString()));
}
/**
 * Executes `method` of the wrapped Module and returns its outputs as Java
 * EValues.
 *
 * With an empty `jinputs` array the method is loaded and run on sample inputs
 * created by prepare_input_tensors. Otherwise every Java EValue is converted
 * to a native EValue (tensor, int, double or bool) and passed to
 * Module::execute.
 *
 * Raises an ExecuTorch exception when loading, input preparation or execution
 * fails, or when an input has a type code that cannot be converted.
 *
 * @param method  Name of the method to run.
 * @param jinputs Java EValue inputs; may be empty.
 * @return Java array holding one EValue per method output.
 */
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute_method(
    std::string method,
    facebook::jni::alias_ref<
        facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
        jinputs) {
  // If no inputs is given, it will run with sample inputs (ones)
  if (jinputs->size() == 0) {
    auto result = module_->load_method(method);
    if (result != Error::Ok) {
      // Format hex string
      std::stringstream ss;
      ss << "Cannot load method '" << method << "' [Native Error: 0x"
         << std::hex << std::uppercase << static_cast<uint32_t>(result) << "]";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(result), ss.str());
      return {};
    }
    auto&& underlying_method = module_->methods_[method].method;
    // The returned object is kept for the rest of this scope, i.e. until the
    // method has run and its outputs have been converted.
    auto input_buffers = prepare_input_tensors(*underlying_method);
    if (!input_buffers.ok()) {
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(input_buffers.error()),
          "Cannot prepare sample inputs for method: " + method);
      return {};
    }
    result = underlying_method->execute();
    if (result != Error::Ok) {
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(result),
          "Execution failed for method: " + method);
      return {};
    }
    const size_t num_outputs = underlying_method->outputs_size();
    facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
        facebook::jni::JArrayClass<JEValue>::newArray(num_outputs);
    for (size_t i = 0; i < num_outputs; i++) {
      auto jevalue =
          JEValue::newJEValueFromEValue(underlying_method->get_output(i));
      jresult->setElement(i, *jevalue);
    }
    return jresult;
  }

  std::vector<EValue> evalues;
  // Owns the tensors referenced by `evalues` for the duration of the call.
  std::vector<TensorPtr> tensors;
  static const auto typeCodeField =
      JEValue::javaClassStatic()->getField<jint>("mTypeCode");
  const size_t num_inputs = jinputs->size();
  for (size_t i = 0; i < num_inputs; i++) {
    auto jevalue = jinputs->getElement(i);
    const auto typeCode = jevalue->getFieldValue(typeCodeField);
    if (typeCode == JEValue::kTypeCodeTensor) {
      tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
      evalues.emplace_back(tensors.back());
    } else if (typeCode == JEValue::kTypeCodeInt) {
      static const auto toIntMethod =
          JEValue::javaClassStatic()->getMethod<jlong()>("toInt");
      evalues.emplace_back(static_cast<int64_t>(toIntMethod(jevalue)));
    } else if (typeCode == JEValue::kTypeCodeDouble) {
      static const auto toDoubleMethod =
          JEValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
      evalues.emplace_back(static_cast<double>(toDoubleMethod(jevalue)));
    } else if (typeCode == JEValue::kTypeCodeBool) {
      static const auto toBoolMethod =
          JEValue::javaClassStatic()->getMethod<jboolean()>("toBool");
      evalues.emplace_back(static_cast<bool>(toBoolMethod(jevalue)));
    } else {
      // Skipping the value would shift every later input by one position,
      // so reject it instead.
      std::stringstream ss;
      ss << "Unsupported input EValue typeCode: " << typeCode << " at index "
         << i;
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str());
      return {};
    }
  }
#if defined(ET_USE_THREADPOOL) && \
    defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
  ::executorch::extension::threadpool::UseNThreadsThreadPoolGuard
      thread_pool_guard(num_threads_);
#endif
#ifdef EXECUTORCH_ANDROID_PROFILING
  auto start = std::chrono::high_resolution_clock::now();
  auto result = module_->execute(method, evalues);
  auto end = std::chrono::high_resolution_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
          .count();
  ET_LOG(Debug, "Execution time: %lld ms.", static_cast<long long>(duration));
#else
  auto result = module_->execute(method, evalues);
#endif
  if (!result.ok()) {
    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(result.error()),
        "Execution failed for method: " + method);
    return {};
  }
  const auto& outputs = result.get();
  facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
      facebook::jni::JArrayClass<JEValue>::newArray(outputs.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    auto jevalue = JEValue::newJEValueFromEValue(outputs[i]);
    jresult->setElement(i, *jevalue);
  }
  return jresult;
}
// Instance entry point registered as "readLogBufferNative"; returns whatever
// readLogBufferUtil() produces.
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBuffer() {
return readLogBufferUtil();
}
// Static entry point registered as "readLogBufferStaticNative"; returns
// whatever readLogBufferUtil() produces. The jclass argument is unused.
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBufferStatic(facebook::jni::alias_ref<jclass>) {
return readLogBufferUtil();
}
/**
 * Returns the buffered log entries as a Java String array.
 *
 * On Android each entry is formatted as
 * "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE". On every other platform an
 * empty array is returned.
 */
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBufferUtil() {
#ifdef __ANDROID__
  facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret;
  access_log_buffer([&](std::vector<log_entry>& buffer) {
    const auto size = buffer.size();
    ret = facebook::jni::JArrayClass<jstring>::newArray(size);
    for (auto i = 0u; i < size; i++) {
      const auto& entry = buffer[i];
      // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
      // MESSAGE".
      std::stringstream ss;
      ss << "[" << entry.timestamp << " " << entry.function << " "
         << entry.filename << ":" << entry.line << "] "
         << static_cast<char>(entry.level) << " " << entry.message;
      facebook::jni::local_ref<facebook::jni::JString> jstr_message =
          facebook::jni::make_jstring(ss.str().c_str());
      (*ret)[i] = jstr_message;
    }
  });
  return ret;
#else
  // The element type must be jstring to match the return type; the previous
  // `String` is not a declared type, so this branch could not compile.
  return facebook::jni::JArrayClass<jstring>::newArray(0);
#endif
}
/**
 * Writes the collected ETDump data to /data/local/tmp/result.etdump.
 *
 * Only active when built with EXECUTORCH_ANDROID_PROFILING; otherwise it
 * always returns false.
 *
 * @return true when the complete dump was written, false when there is no
 *         data or the file could not be created or fully written.
 */
jboolean etdump() {
#ifdef EXECUTORCH_ANDROID_PROFILING
  auto* etdumpgen =
      static_cast<executorch::etdump::ETDumpGen*>(module_->event_tracer());
  if (etdumpgen == nullptr) {
    ET_LOG(Error, "No ETDump data available!");
    return false;
  }
  auto etdump_data = etdumpgen->get_etdump_data();
  if (etdump_data.buf == nullptr || etdump_data.size == 0) {
    ET_LOG(Error, "No ETDump data available!");
    return false;
  }

  // O_TRUNC: a shorter dump must not leave the tail of an older file behind.
  const int etdump_file = open(
      "/data/local/tmp/result.etdump", O_WRONLY | O_CREAT | O_TRUNC, 0644);
  if (etdump_file == -1) {
    ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
    free(etdump_data.buf);
    return false;
  }

  // write() may accept fewer bytes than requested, so loop until everything
  // has been written or an error occurs.
  const uint8_t* cursor = static_cast<const uint8_t*>(etdump_data.buf);
  size_t remaining = etdump_data.size;
  bool written = true;
  while (remaining > 0) {
    const ssize_t n = write(etdump_file, cursor, remaining);
    if (n < 0 && errno == EINTR) {
      continue; // interrupted before any byte was written; retry
    }
    if (n <= 0) {
      ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
      written = false;
      break;
    }
    cursor += n;
    remaining -= static_cast<size_t>(n);
  }
  if (written) {
    ET_LOG(Info, "ETDump written %zu bytes to file.", etdump_data.size);
  }

  // Release the descriptor and the dump buffer on success and failure alike.
  close(etdump_file);
  free(etdump_data.buf);
  return written;
#endif
  return false;
}
// Returns the names of the methods contained in the loaded module as a Java
// String array. Raises an ExecuTorch exception (InvalidArgument) when the
// names cannot be retrieved; the message carries the native error code in hex.
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
  const auto& names_result = module_->method_names();
  if (!names_result.ok()) {
    const auto native_error = static_cast<uint32_t>(names_result.error());
    std::stringstream message;
    message << "Cannot get load module [Native Error: 0x" << std::hex
            << std::uppercase << native_error << "]";
    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(Error::InvalidArgument), message.str());
    return {};
  }

  const auto& names = names_result.get();
  facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> jnames =
      facebook::jni::JArrayClass<jstring>::newArray(names.size());
  int slot = 0;
  for (const auto& name : names) {
    facebook::jni::local_ref<facebook::jni::JString> jname =
        facebook::jni::make_jstring(name.c_str());
    (*jnames)[slot] = jname;
    ++slot;
  }
  return jnames;
}
// Returns the distinct backend names reported by the metadata of the method
// `methodName` as a Java String array. Backend entries whose name lookup fails
// are skipped. Raises an ExecuTorch exception carrying the native error when
// the method metadata itself cannot be obtained.
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
    facebook::jni::alias_ref<jstring> methodName) {
  auto method_name = methodName->toStdString();
  auto meta_result = module_->method_meta(method_name);
  if (!meta_result.ok()) {
    const auto native_error = static_cast<uint32_t>(meta_result.error());
    std::stringstream message;
    message << "Cannot get method meta for '" << method_name
            << "' [Native Error: 0x" << std::hex << std::uppercase
            << native_error << "]";
    jni_helper::throwExecutorchException(native_error, message.str());
    return {};
  }

  // A set removes duplicates when several delegates share one backend.
  auto meta = meta_result.get();
  std::unordered_set<std::string> unique_backends;
  for (size_t idx = 0; idx < meta.num_backends(); ++idx) {
    auto name_result = meta.get_backend_name(idx);
    if (!name_result.ok()) {
      continue;
    }
    unique_backends.insert(name_result.get());
  }

  facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> jbackends =
      facebook::jni::JArrayClass<jstring>::newArray(unique_backends.size());
  int slot = 0;
  for (const auto& backend : unique_backends) {
    facebook::jni::local_ref<facebook::jni::JString> jbackend =
        facebook::jni::make_jstring(backend.c_str());
    (*jbackends)[slot] = jbackend;
    ++slot;
  }
  return jbackends;
}
// Registers the native methods of this class through registerHybrid. In each
// entry the string is the name passed to makeNativeMethod and the second
// argument is the member function that implements it. Called from JNI_OnLoad.
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
makeNativeMethod("executeNative", ExecuTorchJni::execute),
makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
makeNativeMethod(
"readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
makeNativeMethod("etdump", ExecuTorchJni::etdump),
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
// TTS pipeline entry points.
makeNativeMethod("nativeSetCpModule", ExecuTorchJni::setCpModule),
makeNativeMethod("nativeRunTtsPipeline", ExecuTorchJni::runTtsPipeline),
});
}
// ======= TTS Pipeline =======
static Module* cpModule_;
void setCpModule() {
cpModule_ = module_.get();
ET_LOG(Info, "TTS CP module set");
}
// Runs talker+CP loop. 'this' is the talker, cpModule_ is the CP.
//
// Registered as the Java native "nativeRunTtsPipeline". The definition below
// the class forwards to runTtsPipelineImpl with this instance's module as the
// talker, and returns an empty array when no CP module has been set.
//
// Array arguments are flat float arrays; the implementation indexes them as:
//   jPrefill  - nPrefill rows of 1024 floats, fed to the talker one per step
//   jTrailing - nTrailing rows of 1024 floats (only read when nTrailing > 0)
//   jCodecEmb - rows of 1024 floats indexed by the first-codebook code
//   jCpEmbs   - rows of 1024 floats indexed by codebook * 2048 + code
//   jCpHeads  - rows of 1024 floats indexed by head * 2048 + code
//   jTCos / jTSin - rows of 128 floats indexed by talker position (max 249)
//   jCCos / jCSin - rows of 128 floats indexed by CP step (0..16)
//   jEos / jPad   - 1024 floats each
// maxTokens is the upper bound on generated frames. The returned int array
// holds 16 codes per generated frame.
facebook::jni::local_ref<facebook::jni::JArrayInt>
runTtsPipeline(
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
jint maxTokens);
// Static worker behind runTtsPipeline: same arguments, plus the talker and CP
// modules passed explicitly. Defined after the class.
static facebook::jni::local_ref<facebook::jni::JArrayInt>
runTtsPipelineImpl(
Module* talker, Module* cp,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
jint maxTokens);
};
} // namespace executorch::extension
// ======= TTS Pipeline statics and wrapper =======
// Out-of-class definition of the static member ExecuTorchJni::cpModule_. It is
// null until setCpModule() stores a module pointer in it.
namespace executorch::extension { Module* ExecuTorchJni::cpModule_ = nullptr; }
namespace executorch::extension {
/**
 * Instance entry point of the TTS pipeline.
 *
 * Uses this object's module as the talker and the module registered through
 * setCpModule() as the CP model, then hands every argument on unchanged to
 * runTtsPipelineImpl. When no CP module is registered an error is logged and
 * an empty int array is returned.
 */
facebook::jni::local_ref<facebook::jni::JArrayInt>
ExecuTorchJni::runTtsPipeline(
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
    jint maxTokens) {
  if (cpModule_ == nullptr) {
    ET_LOG(Error, "TTS CP module not set!");
    return facebook::jni::JArrayInt::newArray(0);
  }

  Module* const talkerModule = module_.get();
  return runTtsPipelineImpl(
      talkerModule,
      cpModule_,
      jPrefill,
      nPrefill,
      jTrailing,
      nTrailing,
      jCodecEmb,
      jCpEmbs,
      jCpHeads,
      jTCos,
      jTSin,
      jCCos,
      jCSin,
      jEos,
      jPad,
      maxTokens);
}
} // namespace executorch::extension
// The NEON path uses A64-only intrinsics (vaddvq_f32), so the header is pulled
// in for AArch64 builds only; every other target uses the scalar fallback.
#if defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>
#endif
#include <cfloat>
#include <cmath>
/**
 * Dot product of two float arrays.
 *
 * @param a First operand, at least n floats.
 * @param b Second operand, at least n floats.
 * @param n Number of elements; n <= 0 yields 0.
 * @return Sum of a[i] * b[i] for i in [0, n).
 *
 * On AArch64 the sum is accumulated in four independent NEON registers over
 * blocks of 16 elements, with a scalar loop for the remainder. Elsewhere a
 * plain scalar loop is used, so low-order bits may differ between the two
 * paths because the summation order differs.
 */
static inline float tts_dot_neon(const float* a, const float* b, int n) {
#if defined(__aarch64__) && defined(__ARM_NEON)
  float32x4_t s0 = vdupq_n_f32(0), s1 = vdupq_n_f32(0);
  float32x4_t s2 = vdupq_n_f32(0), s3 = vdupq_n_f32(0);
  int i = 0;
  for (; i + 15 < n; i += 16) {
    s0 = vfmaq_f32(s0, vld1q_f32(a + i), vld1q_f32(b + i));
    s1 = vfmaq_f32(s1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4));
    s2 = vfmaq_f32(s2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8));
    s3 = vfmaq_f32(s3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12));
  }
  // Fold the four partial sums into one scalar, then handle the tail.
  float r = vaddvq_f32(vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3)));
  for (; i < n; i++) r += a[i] * b[i];
  return r;
#else
  float r = 0.0f;
  for (int i = 0; i < n; i++) r += a[i] * b[i];
  return r;
#endif
}
// State of the 64-bit linear congruential generator used by tts_sample_topk.
// The fixed seed makes the sampling sequence repeat from process start.
// NOTE(review): plain global updated without synchronization; concurrent
// pipelines would race on it - confirm callers are single-threaded.
static uint64_t tts_rng = 0x12345678ABCDEF01ULL;

/**
 * Samples one index from the k largest logits using a temperature softmax.
 *
 * @param logits Array of `vocab` scores.
 * @param vocab  Number of entries in `logits`.
 * @param temp   Softmax temperature. Values that are not > 0 select the
 *               largest logit deterministically.
 * @param k      Number of top candidates; values < 1 are treated as 1.
 * @return Index of the chosen logit; 0 when no logit exceeds -FLT_MAX.
 *
 * The generator is advanced exactly once per call on every path, so the
 * random sequence does not depend on which branch was taken.
 */
static int tts_sample_topk(const float* logits, int vocab, float temp, int k) {
  struct IV { int i; float v; };

  // k < 1 used to index topk[k-1] on an empty vector.
  if (k < 1) k = 1;

  // topk is kept sorted in descending order; a new candidate replaces the
  // last slot and is bubbled towards the front.
  std::vector<IV> topk(k, {0, -FLT_MAX});
  for (int i = 0; i < vocab; i++) {
    if (logits[i] > topk[k - 1].v) {
      topk[k - 1] = {i, logits[i]};
      for (int j = k - 2; j >= 0; j--) {
        if (topk[j + 1].v > topk[j].v) std::swap(topk[j], topk[j + 1]);
        else break;
      }
    }
  }

  tts_rng = tts_rng * 6364136223846793005ULL + 1442695040888963407ULL;

  // Greedy choice for a non-positive (or NaN) temperature; dividing by it
  // would produce NaN or an inverted distribution.
  if (!(temp > 0.0f)) return topk[0].i;

  // Softmax relative to the maximum for numerical stability.
  float maxv = topk[0].v, sum = 0;
  for (auto& t : topk) { t.v = expf((t.v - maxv) / temp); sum += t.v; }

  // Draw r in [0, sum] and pick the first candidate whose cumulative weight
  // reaches it.
  float r = (float)((tts_rng >> 33) & 0x7FFFFFFF) / (float)0x7FFFFFFF * sum;
  float acc = 0;
  for (auto& t : topk) { acc += t.v; if (acc >= r) return t.i; }
  return topk[0].i;
}
namespace executorch::extension {
/**
 * Runs the talker + CP decode loop natively and returns the generated codes.
 *
 * Flow:
 *  1. Prefill: the talker's "forward" method is executed once per prefill row;
 *     after the last row the first code (cb0) is sampled from its logits.
 *  2. Per generated frame: the CP "forward" method is executed 17 times to
 *     produce 15 further codes, the 16 codes are appended to the result, the
 *     next talker input embedding is built (next trailing row, or the sum of
 *     the code embeddings plus EOS/pad embedding), the talker is stepped, and
 *     the next cb0 is sampled with a 1.05 repetition penalty.
 *  3. Generation stops on the EOS code, after maxTokens frames, when the
 *     sampled code equals the previous nine cb0 values, or when a model
 *     execution fails.
 *
 * All sizes (embedding width, vocabulary, layer and KV-cache dimensions) are
 * the compile-time constants below; every Java array and every model
 * input/output is checked against them before raw memory is copied.
 *
 * @return Int array with 16 codes per generated frame. Empty when the inputs
 *         are invalid, the models cannot be prepared, prefill fails, or the
 *         first sampled code is EOS.
 *
 * NOTE(review): every call copies all Java arrays into native vectors,
 * including the large embedding/head tables; pinning or caching them would
 * avoid that cost.
 * NOTE(review): the output of the 17th CP step is not consumed by this
 * function; it may be redundant - confirm against the model definition.
 */
facebook::jni::local_ref<facebook::jni::JArrayInt>
ExecuTorchJni::runTtsPipelineImpl(
    Module* talker, Module* cp,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
    facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
    jint maxTokens)
{
  static const int DIM = 1024, VOCAB = 3072, CB_SIZE = 2048, NUM_CB = 16;
  static const int T_L = 28, T_KV = 8, T_HD = 128, T_KV_LEN = 100;
  static const int C_L = 5, C_KV = 8, C_HD = 128, C_KV_LEN = 16;
  static const int CODEC_EOS = 2150;
  static const int CP_STEPS = 17;         // CP executions per frame
  static const int CP_CODES = NUM_CB - 1; // codes produced by the CP model
  static const int MAX_ROPE_ROW = 249;    // last usable row of tCos/tSin

  if (talker == nullptr || cp == nullptr) {
    ET_LOG(Error, "TTS talker/CP module is null");
    return facebook::jni::JArrayInt::newArray(0);
  }

  // Copy JNI arrays. A null array becomes an empty vector, which the size
  // checks below then reject.
  auto copyArr = [](facebook::jni::alias_ref<facebook::jni::JArrayFloat> j) {
    std::vector<float> v;
    if (!j) return v;
    v.resize(j->size());
    if (!v.empty()) j->getRegion(0, static_cast<jsize>(v.size()), v.data());
    return v;
  };
  auto prefill = copyArr(jPrefill);
  auto trailing = nTrailing > 0 ? copyArr(jTrailing) : std::vector<float>();
  auto codecEmb = copyArr(jCodecEmb);
  auto cpEmbs = copyArr(jCpEmbs);
  auto cpHeads = copyArr(jCpHeads);
  auto tCos = copyArr(jTCos), tSin = copyArr(jTSin);
  auto cCos = copyArr(jCCos), cSin = copyArr(jCSin);
  auto eosEmb = copyArr(jEos), padEmb = copyArr(jPad);

  // ---- Validate every table against the largest index used below ----
  const size_t nPre = nPrefill > 0 ? static_cast<size_t>(nPrefill) : 0;
  const size_t nTrail = nTrailing > 0 ? static_cast<size_t>(nTrailing) : 0;
  const size_t nGen = maxTokens > 0 ? static_cast<size_t>(maxTokens) : 0;
  // The talker runs at most nPre + nGen steps and its position is clamped to
  // MAX_ROPE_ROW, so that bounds the rows read from tCos/tSin.
  const size_t ropeRows =
      std::min<size_t>(nPre + nGen, static_cast<size_t>(MAX_ROPE_ROW) + 1);
  const size_t cbTableElems = static_cast<size_t>(CP_CODES) * CB_SIZE * DIM;
  auto bigEnough = [](const char* name, const std::vector<float>& v, size_t need) {
    if (v.size() >= need) return true;
    ET_LOG(Error, "TTS input '%s' too small: %zu floats, need %zu",
           name, v.size(), need);
    return false;
  };
  if (!bigEnough("prefill", prefill, nPre * DIM) ||
      !bigEnough("trailing", trailing, nTrail * DIM) ||
      !bigEnough("codecEmb", codecEmb, static_cast<size_t>(CB_SIZE) * DIM) ||
      !bigEnough("cpEmbs", cpEmbs, cbTableElems) ||
      !bigEnough("cpHeads", cpHeads, cbTableElems) ||
      !bigEnough("tCos", tCos, ropeRows * T_HD) ||
      !bigEnough("tSin", tSin, ropeRows * T_HD) ||
      !bigEnough("cCos", cCos, static_cast<size_t>(CP_STEPS) * C_HD) ||
      !bigEnough("cSin", cSin, static_cast<size_t>(CP_STEPS) * C_HD) ||
      !bigEnough("eos", eosEmb, DIM) ||
      !bigEnough("pad", padEmb, DIM)) {
    return facebook::jni::JArrayInt::newArray(0);
  }

  // Row lookups clamp the code exactly as before; the codec table clamp is
  // additionally limited to the rows actually supplied.
  const int codecRows =
      static_cast<int>(std::min<size_t>(codecEmb.size() / DIM, VOCAB));
  auto codecRow = [&](int code) -> const float* {
    const int row = std::min(std::max(code, 0), codecRows - 1);
    return codecEmb.data() + static_cast<size_t>(row) * DIM;
  };
  auto cpEmbRow = [&](int codebook, int code) -> const float* {
    const int row = std::min(std::max(code, 0), CB_SIZE - 1);
    return cpEmbs.data() +
        (static_cast<size_t>(codebook) * CB_SIZE + row) * DIM;
  };

  // ---- Talker state ----
  const int tkvElem = T_KV * T_KV_LEN * T_HD;
  std::vector<float> tK(static_cast<size_t>(T_L) * tkvElem, 0);
  std::vector<float> tV(static_cast<size_t>(T_L) * tkvElem, 0);
  std::vector<float> mask(T_KV_LEN, -1e9f); // unmasked one slot per step
  std::vector<float> hidden(DIM, 0.0f), logits(VOCAB, 0.0f);
  std::vector<int> allCodes, cb0Hist;
  int pos = 0, curCb0 = -1, trIdx = 0;

  // Get raw Method pointers for direct execution
  auto tMethodRes = talker->method("forward");
  auto cMethodRes = cp->method("forward");
  if (!tMethodRes.ok() || !cMethodRes.ok()) {
    ET_LOG(Error, "TTS cannot get Method pointers");
    return facebook::jni::JArrayInt::newArray(0);
  }
  Method* tMethod = tMethodRes.get();
  Method* cMethod = cMethodRes.get();

  // Prepare the inputs of both methods once. The returned objects own the
  // buffers allocated for the inputs and free them when destroyed, so they
  // must stay alive for as long as the cached input pointers below are used;
  // destroying them right away would leave those pointers dangling whenever
  // the method aliases the buffers instead of copying them.
  auto tPrep = executorch::extension::prepare_input_tensors(*tMethod);
  auto cPrep = executorch::extension::prepare_input_tensors(*cMethod);
  if (!tPrep.ok() || !cPrep.ok()) {
    ET_LOG(Error, "TTS cannot prepare model inputs");
    return facebook::jni::JArrayInt::newArray(0);
  }

  // Float storage of input `idx`, or nullptr when that input is missing, is
  // not a float tensor, or holds fewer than `minElems` elements.
  auto inputData = [](Method* m, size_t idx, size_t minElems) -> float* {
    if (idx >= m->inputs_size()) return nullptr;
    auto& ev = m->mutable_input(idx);
    if (!ev.isTensor()) return nullptr;
    auto t = ev.toTensor();
    if (t.scalar_type() != ScalarType::Float ||
        t.nbytes() < minElems * sizeof(float)) {
      return nullptr;
    }
    return t.mutable_data_ptr<float>();
  };
  // Same check for outputs.
  auto outputData = [](Method* m, size_t idx, size_t minElems) -> const float* {
    if (idx >= m->outputs_size()) return nullptr;
    const auto& ev = m->get_output(idx);
    if (!ev.isTensor()) return nullptr;
    const auto t = ev.toTensor();
    if (t.scalar_type() != ScalarType::Float ||
        t.nbytes() < minElems * sizeof(float)) {
      return nullptr;
    }
    return t.const_data_ptr<float>();
  };

  // Talker inputs: 0 = embedding, 1 = mask, 2 = cos, 3 = sin, then one K and
  // one V cache per layer.
  float* tInEmb = inputData(tMethod, 0, DIM);
  float* tInMask = inputData(tMethod, 1, T_KV_LEN);
  float* tInCos = inputData(tMethod, 2, T_HD);
  float* tInSin = inputData(tMethod, 3, T_HD);
  bool wired = tInEmb && tInMask && tInCos && tInSin;
  float* tInKV[T_L * 2];
  for (int i = 0; i < T_L * 2; i++) {
    tInKV[i] = inputData(tMethod, 4 + i, tkvElem);
    wired = wired && tInKV[i] != nullptr;
  }

  // CP inputs follow the same layout with the CP dimensions.
  const int ckvElem = C_KV * C_KV_LEN * C_HD;
  std::vector<float> ckv(static_cast<size_t>(C_L) * 2 * ckvElem, 0);
  float* cpInEmb = inputData(cMethod, 0, DIM);
  float* cpInMask = inputData(cMethod, 1, C_KV_LEN);
  float* cpInCos = inputData(cMethod, 2, C_HD);
  float* cpInSin = inputData(cMethod, 3, C_HD);
  wired = wired && cpInEmb && cpInMask && cpInCos && cpInSin;
  float* cpInKV[C_L * 2];
  for (int i = 0; i < C_L * 2; i++) {
    cpInKV[i] = inputData(cMethod, 4 + i, ckvElem);
    wired = wired && cpInKV[i] != nullptr;
  }
  if (!wired) {
    ET_LOG(Error, "TTS model inputs do not match the expected layout");
    return facebook::jni::JArrayInt::newArray(0);
  }

  // One talker step: writes the inputs, executes, then copies hidden state,
  // logits and the updated KV caches back out. Returns false on any failure,
  // leaving pos unchanged.
  auto talkerStep = [&](const float* emb) -> bool {
    const int ropeRow = std::min(pos, MAX_ROPE_ROW);
    const int maskIdx = T_KV_LEN - 1 - std::min(pos, T_KV_LEN - 1);
    mask[maskIdx] = 0.0f;
    memcpy(tInEmb, emb, DIM * sizeof(float));
    memcpy(tInMask, mask.data(), T_KV_LEN * sizeof(float));
    memcpy(tInCos, tCos.data() + static_cast<size_t>(ropeRow) * T_HD,
           T_HD * sizeof(float));
    memcpy(tInSin, tSin.data() + static_cast<size_t>(ropeRow) * T_HD,
           T_HD * sizeof(float));
    for (int i = 0; i < T_L; i++) {
      memcpy(tInKV[i * 2], tK.data() + static_cast<size_t>(i) * tkvElem,
             tkvElem * sizeof(float));
      memcpy(tInKV[i * 2 + 1], tV.data() + static_cast<size_t>(i) * tkvElem,
             tkvElem * sizeof(float));
    }
    const auto status = tMethod->execute();
    if (status != Error::Ok) {
      ET_LOG(Error, "Talker exec fail: %d", (int)status);
      return false;
    }
    const float* outHidden = outputData(tMethod, 0, DIM);
    const float* outLogits = outputData(tMethod, 1, VOCAB);
    if (!outHidden || !outLogits) {
      ET_LOG(Error, "Talker outputs do not match the expected layout");
      return false;
    }
    memcpy(hidden.data(), outHidden, DIM * sizeof(float));
    memcpy(logits.data(), outLogits, VOCAB * sizeof(float));
    for (int i = 0; i < T_L; i++) {
      const float* outK = outputData(tMethod, 2 + i * 2, tkvElem);
      const float* outV = outputData(tMethod, 3 + i * 2, tkvElem);
      if (!outK || !outV) {
        ET_LOG(Error, "Talker KV outputs do not match the expected layout");
        return false;
      }
      memcpy(tK.data() + static_cast<size_t>(i) * tkvElem, outK,
             tkvElem * sizeof(float));
      memcpy(tV.data() + static_cast<size_t>(i) * tkvElem, outV,
             tkvElem * sizeof(float));
    }
    pos++;
    return true;
  };

  // CP pass for one frame: 17 autoregressive executions starting from the
  // talker hidden state `h` and the first code `cb0`. Writes codes[0..14].
  // The CP KV cache is reset at the start of every frame.
  auto cpStep = [&](const float* h, int cb0, int* codes) -> bool {
    memset(ckv.data(), 0, ckv.size() * sizeof(float));
    for (int step = 0; step < CP_STEPS; step++) {
      const float* emb;
      if (step == 0) emb = h;
      else if (step == 1) emb = codecRow(cb0);
      else emb = cpEmbRow(step - 2, codes[step - 2]);
      memcpy(cpInEmb, emb, DIM * sizeof(float));
      for (int p = 0; p < C_KV_LEN; p++) {
        cpInMask[p] = (p >= C_KV_LEN - 1 - step) ? 0.0f : -1e9f;
      }
      memcpy(cpInCos, cCos.data() + static_cast<size_t>(step) * C_HD,
             C_HD * sizeof(float));
      memcpy(cpInSin, cSin.data() + static_cast<size_t>(step) * C_HD,
             C_HD * sizeof(float));
      for (int i = 0; i < C_L * 2; i++) {
        memcpy(cpInKV[i], ckv.data() + static_cast<size_t>(i) * ckvElem,
               ckvElem * sizeof(float));
      }
      const auto status = cMethod->execute();
      if (status != Error::Ok) {
        ET_LOG(Error, "CP exec fail: %d", (int)status);
        return false;
      }
      const float* ho = outputData(cMethod, 0, DIM);
      if (!ho) {
        ET_LOG(Error, "CP outputs do not match the expected layout");
        return false;
      }
      if (step >= 1 && step - 1 < CP_CODES) {
        // Greedy pick: the head row with the largest dot product.
        const float* W =
            cpHeads.data() + static_cast<size_t>(step - 1) * CB_SIZE * DIM;
        int best = 0;
        float bv = -FLT_MAX;
        for (int j = 0; j < CB_SIZE; j++) {
          const float d =
              tts_dot_neon(ho, W + static_cast<size_t>(j) * DIM, DIM);
          if (d > bv) { bv = d; best = j; }
        }
        codes[step - 1] = best;
      }
      for (int i = 0; i < C_L * 2; i++) {
        const float* outKV = outputData(cMethod, 1 + i, ckvElem);
        if (!outKV) {
          ET_LOG(Error, "CP KV outputs do not match the expected layout");
          return false;
        }
        memcpy(ckv.data() + static_cast<size_t>(i) * ckvElem, outKV,
               ckvElem * sizeof(float));
      }
    }
    return true;
  };

  // Only codes below CB_SIZE and the EOS code may be sampled.
  auto maskNonCodecLogits = [&]() {
    for (int j = CB_SIZE; j < VOCAB; j++) {
      if (j != CODEC_EOS) logits[j] = -FLT_MAX;
    }
  };

  auto t0 = std::chrono::high_resolution_clock::now();
  // Prefill
  for (int s = 0; s < nPrefill; s++) {
    if (!talkerStep(prefill.data() + static_cast<size_t>(s) * DIM)) {
      ET_LOG(Error, "TTS prefill aborted at step %d", s);
      return facebook::jni::JArrayInt::newArray(0);
    }
    if (s == nPrefill - 1) {
      maskNonCodecLogits();
      curCb0 = tts_sample_topk(logits.data(), VOCAB, 0.9f, 50);
    }
  }
  auto tP = std::chrono::high_resolution_clock::now();
  ET_LOG(Info, "TTS Prefill: %.0fms, cb0=%d",
         std::chrono::duration<float, std::milli>(tP - t0).count(), curCb0);
  if (curCb0 < 0 || curCb0 == CODEC_EOS) {
    return facebook::jni::JArrayInt::newArray(0);
  }

  // Generation
  float totalTalker = 0, totalCp = 0;
  std::unordered_set<int> seen; // distinct cb0 values emitted so far
  std::vector<float> nextEmb(DIM, 0.0f);
  for (int g = 0; g < maxTokens; g++) {
    int codes[NUM_CB] = {};
    codes[0] = curCb0;
    auto tc0 = std::chrono::high_resolution_clock::now();
    const bool cpOk = cpStep(hidden.data(), curCb0, codes + 1);
    auto tc1 = std::chrono::high_resolution_clock::now();
    totalCp += std::chrono::duration<float, std::milli>(tc1 - tc0).count();
    if (!cpOk) {
      // Do not emit a frame built from an incomplete CP pass.
      ET_LOG(Error, "TTS aborted at frame %d: CP step failed", g);
      break;
    }
    for (int i = 0; i < NUM_CB; i++) allCodes.push_back(codes[i]);
    cb0Hist.push_back(curCb0);
    seen.insert(curCb0);

    // Next embed: pre-computed trailing OR codec_sum + eos/pad
    nextEmb.assign(DIM, 0.0f);
    if (trIdx < nTrailing) {
      memcpy(nextEmb.data(), trailing.data() + static_cast<size_t>(trIdx) * DIM,
             DIM * sizeof(float));
      trIdx++;
    } else {
      const float* e0 = codecRow(codes[0]);
      for (int k = 0; k < DIM; k++) nextEmb[k] += e0[k];
      for (int cb = 0; cb < CP_CODES; cb++) {
        const float* ec = cpEmbRow(cb, codes[cb + 1]);
        for (int k = 0; k < DIM; k++) nextEmb[k] += ec[k];
      }
      if (trIdx == nTrailing) {
        // First frame after the trailing rows gets the EOS embedding once.
        for (int k = 0; k < DIM; k++) nextEmb[k] += eosEmb[k];
        trIdx++;
      } else {
        for (int k = 0; k < DIM; k++) nextEmb[k] += padEmb[k];
      }
    }

    auto tt0 = std::chrono::high_resolution_clock::now();
    const bool talkerOk = talkerStep(nextEmb.data());
    auto tt1 = std::chrono::high_resolution_clock::now();
    totalTalker += std::chrono::duration<float, std::milli>(tt1 - tt0).count();
    if (!talkerOk) {
      // Sampling from stale logits would produce garbage; stop here and
      // return the frames generated so far.
      ET_LOG(Error, "TTS aborted at frame %d: talker step failed", g);
      break;
    }

    maskNonCodecLogits();
    // Repetition penalty on every cb0 value already emitted.
    for (int tok : seen) {
      logits[tok] = (logits[tok] > 0) ? logits[tok] / 1.05f : logits[tok] * 1.05f;
    }
    const int next = tts_sample_topk(logits.data(), VOCAB, 0.9f, 50);
    if (next == CODEC_EOS) { ET_LOG(Info, "TTS EOS at %d", g + 2); break; }
    if ((int)cb0Hist.size() >= 9) {
      // Degenerate output: the new code equals each of the last nine.
      bool deg = true;
      for (int i = (int)cb0Hist.size() - 9; i < (int)cb0Hist.size(); i++) {
        if (cb0Hist[i] != next) { deg = false; break; }
      }
      if (deg) { ET_LOG(Info, "TTS Degen at %d", g + 2); break; }
    }
    curCb0 = next;
  }

  int nTok = (int)allCodes.size() / NUM_CB;
  auto t1 = std::chrono::high_resolution_clock::now();
  ET_LOG(Info, "TTS Generated %d | Talker %.0fms (%.0f/step) | CP %.0fms (%.0f/step) | Total %.0fms",
         nTok, totalTalker, totalTalker / std::max(nTok, 1), totalCp, totalCp / std::max(nTok, 1),
         std::chrono::duration<float, std::milli>(t1 - t0).count());
  auto result = facebook::jni::JArrayInt::newArray((int)allCodes.size());
  result->setRegion(0, (int)allCodes.size(), allCodes.data());
  return result;
}
} // namespace executorch::extension
// Registration hooks called from JNI_OnLoad. The LLM and training hooks are
// declared extern when their feature macro is defined and are replaced by
// empty stubs here otherwise; the runtime hook is always declared extern.
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
extern void register_natives_for_llm();
#else
// No op if we don't build LLM
void register_natives_for_llm() {}
#endif
extern void register_natives_for_runtime();
#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
extern void register_natives_for_training();
#else
// No op if we don't build training JNI
void register_natives_for_training() {}
#endif
// Library entry point. Passes the JavaVM to facebook::jni::initialize together
// with a callback that registers the ExecuTorchJni natives and then calls the
// LLM, runtime and training registration hooks, in that order. Returns the
// value produced by facebook::jni::initialize.
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
executorch::extension::ExecuTorchJni::registerNatives();
register_natives_for_llm();
register_natives_for_runtime();
register_natives_for_training();
});
}