925 lines
37 KiB
C++
925 lines
37 KiB
C++
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
#include <executorch/extension/android/jni/jni_helper.h>
|
|
#include <executorch/extension/android/jni/jni_layer_constants.h>
|
|
|
|
#include <executorch/extension/android/jni/log.h>
|
|
#include <executorch/extension/module/module.h>
|
|
#include <executorch/extension/runner_util/inputs.h>
|
|
#include <executorch/extension/tensor/tensor.h>
|
|
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
|
|
#include <executorch/runtime/core/portable_type/tensor_impl.h>
|
|
#include <executorch/runtime/platform/log.h>
|
|
#include <executorch/runtime/platform/platform.h>
|
|
#include <executorch/runtime/platform/runtime.h>
|
|
#include <cassert>
|
|
#include <chrono>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <vector>
|
|
|
|
#ifdef ET_USE_THREADPOOL
|
|
#include <cpuinfo.h>
|
|
#include <executorch/extension/threadpool/threadpool.h>
|
|
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
|
|
#include <executorch/extension/threadpool/fb/threadpool_use_n_threads.h>
|
|
#endif
|
|
#endif
|
|
|
|
#ifdef EXECUTORCH_ANDROID_PROFILING
|
|
#include <executorch/devtools/etdump/etdump_flatcc.h>
|
|
#include <fcntl.h>
|
|
#include <unistd.h>
|
|
#endif
|
|
|
|
#include <fbjni/ByteBuffer.h>
|
|
#include <fbjni/fbjni.h>
|
|
|
|
using namespace executorch::extension;
|
|
using namespace torch::executor;
|
|
|
|
namespace executorch::extension {
|
|
/**
 * Native peer of the Java class `org.pytorch.executorch.Tensor`.
 *
 * Provides the two conversions used by the rest of this file:
 *  - newJTensorFromTensor: wraps the bytes of a native tensor in a Java
 *    ByteBuffer and builds a Java Tensor through `nativeNewTensor`.
 *  - newTensorFromJTensor: builds a TensorPtr that views the direct buffer
 *    held by a Java Tensor (no copy is made here).
 */
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/Tensor;";

  // The tensor argument is accepted for makeCxxInstance() but is not stored.
  explicit TensorHybrid(executorch::aten::Tensor /*tensor*/) {}

  /**
   * Creates a Java Tensor whose ByteBuffer wraps `tensor`'s data in place.
   *
   * Throws an ExecuTorch exception (InvalidArgument) when the tensor's scalar
   * type has no entry in scalar_type_to_java_dtype.
   */
  static facebook::jni::local_ref<TensorHybrid::javaobject>
  newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
    // Java wrapper currently only supports contiguous tensors.

    const auto scalarType = tensor.scalar_type();
    if (scalar_type_to_java_dtype.count(scalarType) == 0) {
      std::stringstream ss;
      ss << "executorch::aten::Tensor scalar type "
         << static_cast<int>(scalarType) << " is not supported on java side";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    const int jdtype = scalar_type_to_java_dtype.at(scalarType);

    // Widen every dimension to jlong so it can be copied into a long[].
    const auto& tensor_shape = tensor.sizes();
    std::vector<jlong> tensor_shape_vec(
        tensor_shape.begin(), tensor_shape.end());
    facebook::jni::local_ref<jlongArray> jTensorShape =
        facebook::jni::make_long_array(tensor_shape_vec.size());
    jTensorShape->setRegion(
        0, tensor_shape_vec.size(), tensor_shape_vec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // Note: this is safe as long as the data stored in tensor is valid; the
    // data won't go out of scope as long as the Method for the inference is
    // valid and there is no other inference call. Java layer picks up this
    // value immediately so the data is valid.
    facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            (uint8_t*)tensor.data_ptr(), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());

    static const auto jMethodNewTensor =
        cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
            facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
            facebook::jni::alias_ref<jlongArray>,
            jint,
            facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
    return jMethodNewTensor(
        cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
  }

  /**
   * Builds a TensorPtr over the direct buffer of a Java Tensor.
   *
   * Reads the dtype code (`dtypeJniCode`), the `shape` field and the raw data
   * buffer (`getRawDataBuffer`) from `jtensor`. Throws an ExecuTorch exception
   * (InvalidArgument) when:
   *  - a dimension is negative, does not fit SizesType, or the element count
   *    would overflow int64_t;
   *  - the dtype code is not in java_dtype_to_scalar_type;
   *  - the buffer is not direct;
   *  - the buffer capacity matches neither the element count nor the byte
   *    count implied by the shape.
   */
  static TensorPtr newTensorFromJTensor(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
    const jint jdtype = dtypeMethod(jtensor);

    static const auto shapeField = cls->getField<jlongArray>("shape");
    auto jshape = jtensor->getFieldValue(shapeField);

    static auto dataBufferMethod = cls->getMethod<
        facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
        "getRawDataBuffer");
    facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
        dataBufferMethod(jtensor);

    const auto rank = jshape->size();
    const auto shapeArr = jshape->getRegion(0, rank);
    std::vector<executorch::aten::SizesType> shape_vec;
    shape_vec.reserve(rank);

    int64_t numel = 1;
    for (size_t i = 0; i < static_cast<size_t>(rank); ++i) {
      const jlong dim = shapeArr[i];
      const auto narrowed = static_cast<executorch::aten::SizesType>(dim);
      // A dimension must be non-negative and survive the narrowing to
      // SizesType; otherwise two negative dimensions could multiply into a
      // plausible element count, or a large one could be truncated silently.
      bool valid = dim >= 0 && static_cast<jlong>(narrowed) == dim;
      // Guard the running product against signed overflow.
      if (valid && dim != 0 && numel > INT64_MAX / dim) {
        valid = false;
      }
      if (!valid) {
        std::stringstream ss;
        ss << "Invalid Tensor shape: dimension " << i << " has size " << dim;
        jni_helper::throwExecutorchException(
            static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
        return nullptr;
      }
      shape_vec.push_back(narrowed);
      numel *= dim;
    }

    JNIEnv* jni = facebook::jni::Environment::current();
    if (java_dtype_to_scalar_type.count(jdtype) == 0) {
      std::stringstream ss;
      ss << "Unknown Tensor jdtype: [" << jdtype << "]";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);

    const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
    if (dataCapacity < 0) {
      std::stringstream ss;
      ss << "Tensor buffer is not direct or has invalid capacity";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }

    // The capacity is accepted when it equals either the element count or the
    // byte count derived from the shape.
    const size_t elementSize = executorch::runtime::elementSize(scalar_type);
    const jlong expectedElements = static_cast<jlong>(numel);
    const jlong expectedBytes =
        expectedElements * static_cast<jlong>(elementSize);
    const bool matchesElements = dataCapacity == expectedElements;
    const bool matchesBytes = dataCapacity == expectedBytes;
    if (!matchesElements && !matchesBytes) {
      std::stringstream ss;
      ss << "Tensor dimensions(elements number: " << numel
         << ") inconsistent with buffer capacity " << dataCapacity
         << " (element size bytes: " << elementSize << ")";
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
      return nullptr;
    }
    return from_blob(
        jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
  }

 private:
  friend HybridBase;
};
|
|
|
|
/**
 * C++ view of the Java class `org.pytorch.executorch.EValue`.
 *
 * Converts native EValues into Java EValues through the overloaded static
 * `from` factory, and extracts a TensorPtr from a Java EValue that holds a
 * tensor.
 */
class JEValue : public facebook::jni::JavaClass<JEValue> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/EValue;";

  // Values compared against the Java object's `mTypeCode` field.
  constexpr static int kTypeCodeTensor = 1;
  constexpr static int kTypeCodeString = 2;
  constexpr static int kTypeCodeDouble = 3;
  constexpr static int kTypeCodeInt = 4;
  constexpr static int kTypeCodeBool = 5;

  /**
   * Wraps a native EValue (tensor, int, double, bool or string) in a Java
   * EValue. Any other tag raises an ExecuTorch exception (InvalidArgument).
   */
  static facebook::jni::local_ref<JEValue> newJEValueFromEValue(EValue evalue) {
    if (evalue.isTensor()) {
      static auto fromTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(
                  facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
      return fromTensor(
          JEValue::javaClassStatic(),
          TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
    }

    if (evalue.isInt()) {
      static auto fromLong =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jlong)>(
                  "from");
      return fromLong(JEValue::javaClassStatic(), evalue.toInt());
    }

    if (evalue.isDouble()) {
      static auto fromDouble =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jdouble)>(
                  "from");
      return fromDouble(JEValue::javaClassStatic(), evalue.toDouble());
    }

    if (evalue.isBool()) {
      static auto fromBoolean =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jboolean)>(
                  "from");
      return fromBoolean(JEValue::javaClassStatic(), evalue.toBool());
    }

    if (evalue.isString()) {
      static auto fromString =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(
                  facebook::jni::local_ref<jstring>)>("from");
      // Copy the characters into an owning std::string before creating the
      // Java string.
      const std::string text(
          evalue.toString().begin(), evalue.toString().end());
      return fromString(
          JEValue::javaClassStatic(), facebook::jni::make_jstring(text));
    }

    std::stringstream message;
    message << "Unknown EValue type: [" << static_cast<int>(evalue.tag) << "]";
    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(Error::InvalidArgument),
        message.str().c_str());
    return {};
  }

  /**
   * Returns the tensor held by a Java EValue as a TensorPtr. Raises an
   * ExecuTorch exception (InvalidArgument) when `mTypeCode` is not the tensor
   * code.
   */
  static TensorPtr JEValueToTensorImpl(
      facebook::jni::alias_ref<JEValue> jevalue) {
    static const auto typeCodeField =
        JEValue::javaClassStatic()->getField<jint>("mTypeCode");
    const auto typeCode = jevalue->getFieldValue(typeCodeField);

    if (JEValue::kTypeCodeTensor != typeCode) {
      std::stringstream message;
      message << "Unknown EValue typeCode: " << typeCode;
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument),
          message.str().c_str());
      return {};
    }

    static const auto jMethodGetTensor =
        JEValue::javaClassStatic()
            ->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
                "toTensor");
    auto jtensor = jMethodGetTensor(jevalue);
    return TensorHybrid::newTensorFromJTensor(jtensor);
  }
};
|
|
|
|
class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
|
|
private:
|
|
friend HybridBase;
|
|
std::unique_ptr<Module> module_;
|
|
#if defined(ET_USE_THREADPOOL) && \
|
|
defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
|
|
int num_threads_{0};
|
|
#endif
|
|
|
|
public:
|
|
constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";
|
|
|
|
/**
 * JNI entry point registered as "initHybrid": creates the native peer by
 * forwarding all arguments to the ExecuTorchJni constructor.
 */
static facebook::jni::local_ref<jhybriddata> initHybrid(
    facebook::jni::alias_ref<jclass> /*unused*/,
    facebook::jni::alias_ref<jstring> modelPath,
    jint loadMode,
    jint numThreads) {
  auto hybridData = makeCxxInstance(modelPath, loadMode, numThreads);
  return hybridData;
}
|
|
|
|
/**
 * Creates the Module for `modelPath`.
 *
 * `loadMode` selects the Module::LoadMode: 0 = File, 1 = Mmap,
 * 2 = MmapUseMlock, 3 = MmapUseMlockIgnoreErrors; any other value keeps the
 * Mmap default. With EXECUTORCH_ANDROID_PROFILING an ETDumpGen is attached as
 * the event tracer. With ET_USE_THREADPOOL the thread count is `numThreads`,
 * or half the processor count when `numThreads` is 0.
 */
ExecuTorchJni(
    facebook::jni::alias_ref<jstring> modelPath,
    jint loadMode,
    jint numThreads) {
  Module::LoadMode load_mode = Module::LoadMode::Mmap;
  switch (loadMode) {
    case 0:
      load_mode = Module::LoadMode::File;
      break;
    case 1:
      load_mode = Module::LoadMode::Mmap;
      break;
    case 2:
      load_mode = Module::LoadMode::MmapUseMlock;
      break;
    case 3:
      load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
      break;
    default:
      // Unrecognized codes fall back to the Mmap default set above.
      break;
  }

#ifdef EXECUTORCH_ANDROID_PROFILING
  auto etdump_gen = std::make_unique<executorch::etdump::ETDumpGen>();
#else
  auto etdump_gen = nullptr;
#endif
  module_ = std::make_unique<Module>(
      modelPath->toStdString(), load_mode, std::move(etdump_gen));

#ifdef ET_USE_THREADPOOL
  // Default to using cores/2 threadpool threads. The long-term plan is to
  // improve performant core detection in CPUInfo, but for now we can use
  // cores/2 as a sane default.
  //
  // Based on testing, this is almost universally faster than using all
  // cores, as efficiency cores can be quite slow. In extreme cases, using
  // all cores can be 10x slower than using cores/2.
  int thread_count =
      numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
  // Remembered here and applied per call by the guard in execute_method().
  num_threads_ = thread_count;
#else
  // Without the guard, resize the shared pool once, and only for a positive
  // count.
  auto threadpool = executorch::extension::threadpool::get_threadpool();
  if (threadpool && thread_count > 0) {
    threadpool->_unsafe_reset_threadpool(thread_count);
  }
#endif
#endif
}
|
|
|
|
/**
 * JNI entry point registered as "executeNative": converts the Java method name
 * to a std::string and delegates to execute_method().
 */
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute(
    facebook::jni::alias_ref<jstring> methodName,
    facebook::jni::alias_ref<
        facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
        jinputs) {
  auto outputs = execute_method(methodName->toStdString(), jinputs);
  return outputs;
}
|
|
|
|
/**
 * JNI entry point registered as "loadMethodNative": loads the named method on
 * the Module and returns the resulting error code as a jint.
 */
jint load_method(facebook::jni::alias_ref<jstring> methodName) {
  const auto status = module_->load_method(methodName->toStdString());
  return static_cast<jint>(status);
}
|
|
|
|
/**
 * Runs `method` on the Module and returns its outputs as Java EValues.
 *
 * With an empty `jinputs` array the method is loaded and executed with the
 * inputs produced by prepare_input_tensors(). Otherwise each Java EValue is
 * converted by its `mTypeCode` (tensor, int, double, bool) and passed to
 * Module::execute().
 *
 * Raises an ExecuTorch exception when loading, input preparation or execution
 * fails, or when an input carries a type code this function cannot convert.
 */
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute_method(
    std::string method,
    facebook::jni::alias_ref<
        facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
        jinputs) {
  // If no inputs is given, it will run with sample inputs (ones)
  if (jinputs->size() == 0) {
    auto result = module_->load_method(method);
    if (result != Error::Ok) {
      // Format hex string
      std::stringstream ss;
      ss << "Cannot get method names [Native Error: 0x" << std::hex
         << std::uppercase << static_cast<uint32_t>(result) << "]";

      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(result), ss.str());
      return {};
    }
    auto&& underlying_method = module_->methods_[method].method;
    // `buf` is kept alive until the outputs have been converted below.
    auto&& buf = prepare_input_tensors(*underlying_method);
    if (!buf.ok()) {
      // Do not execute with inputs that could not be prepared.
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(buf.error()),
          "Failed to prepare sample inputs for method: " + method);
      return {};
    }
    result = underlying_method->execute();
    if (result != Error::Ok) {
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(result),
          "Execution failed for method: " + method);
      return {};
    }
    const size_t num_outputs = underlying_method->outputs_size();
    facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
        facebook::jni::JArrayClass<JEValue>::newArray(num_outputs);

    for (size_t i = 0; i < num_outputs; i++) {
      auto jevalue =
          JEValue::newJEValueFromEValue(underlying_method->get_output(i));
      jresult->setElement(i, *jevalue);
    }
    return jresult;
  }

  std::vector<EValue> evalues;
  // Keeps the TensorPtrs alive for as long as `evalues` refers to them.
  std::vector<TensorPtr> tensors;

  static const auto typeCodeField =
      JEValue::javaClassStatic()->getField<jint>("mTypeCode");

  const size_t num_inputs = jinputs->size();
  for (size_t i = 0; i < num_inputs; i++) {
    auto jevalue = jinputs->getElement(i);
    const auto typeCode = jevalue->getFieldValue(typeCodeField);
    if (typeCode == JEValue::kTypeCodeTensor) {
      tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
      evalues.emplace_back(tensors.back());
    } else if (typeCode == JEValue::kTypeCodeInt) {
      static const auto toIntMethod =
          JEValue::javaClassStatic()->getMethod<jlong()>("toInt");
      evalues.emplace_back(static_cast<int64_t>(toIntMethod(jevalue)));
    } else if (typeCode == JEValue::kTypeCodeDouble) {
      static const auto toDoubleMethod =
          JEValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
      evalues.emplace_back(static_cast<double>(toDoubleMethod(jevalue)));
    } else if (typeCode == JEValue::kTypeCodeBool) {
      static const auto toBoolMethod =
          JEValue::javaClassStatic()->getMethod<jboolean()>("toBool");
      evalues.emplace_back(static_cast<bool>(toBoolMethod(jevalue)));
    } else {
      // Dropping the value would shift every later positional input, so
      // reject it explicitly instead.
      std::stringstream ss;
      ss << "Unsupported EValue typeCode for input " << i << ": " << typeCode;
      jni_helper::throwExecutorchException(
          static_cast<uint32_t>(Error::InvalidArgument), ss.str());
      return {};
    }
  }

#if defined(ET_USE_THREADPOOL) && \
    defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
  ::executorch::extension::threadpool::UseNThreadsThreadPoolGuard
      thread_pool_guard(num_threads_);
#endif

#ifdef EXECUTORCH_ANDROID_PROFILING
  auto start = std::chrono::high_resolution_clock::now();
  auto result = module_->execute(method, evalues);
  auto end = std::chrono::high_resolution_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
          .count();
  ET_LOG(Debug, "Execution time: %lld ms.", duration);

#else
  auto result = module_->execute(method, evalues);

#endif

  if (!result.ok()) {
    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(result.error()),
        "Execution failed for method: " + method);
    return {};
  }

  const size_t num_results = result.get().size();
  facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
      facebook::jni::JArrayClass<JEValue>::newArray(num_results);

  for (size_t i = 0; i < num_results; i++) {
    auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
    jresult->setElement(i, *jevalue);
  }
  return jresult;
}
|
|
|
|
/**
 * Instance JNI entry point ("readLogBufferNative"); forwards to
 * readLogBufferUtil().
 */
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBuffer() {
  return readLogBufferUtil();
}

/**
 * Static JNI entry point ("readLogBufferStaticNative"); forwards to
 * readLogBufferUtil().
 */
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBufferStatic(facebook::jni::alias_ref<jclass>) {
  return readLogBufferUtil();
}

/**
 * Returns the buffered log entries as a Java String[].
 *
 * On Android every entry handed over by access_log_buffer() is formatted as
 * "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE". On other platforms an empty
 * array is returned.
 */
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBufferUtil() {
#ifdef __ANDROID__

  facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret;

  access_log_buffer([&](std::vector<log_entry>& buffer) {
    const auto size = buffer.size();
    ret = facebook::jni::JArrayClass<jstring>::newArray(size);
    for (auto i = 0u; i < size; i++) {
      const auto& entry = buffer[i];
      // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
      // MESSAGE".
      std::stringstream ss;
      ss << "[" << entry.timestamp << " " << entry.function << " "
         << entry.filename << ":" << entry.line << "] "
         << static_cast<char>(entry.level) << " " << entry.message;

      facebook::jni::local_ref<facebook::jni::JString> jstr_message =
          facebook::jni::make_jstring(ss.str().c_str());
      (*ret)[i] = jstr_message;
    }
  });

  return ret;
#else
  // The element type must be jstring to match the declared return type.
  return facebook::jni::JArrayClass<jstring>::newArray(0);
#endif
}
|
|
|
|
/**
 * Writes the collected ETDump data to /data/local/tmp/result.etdump.
 *
 * Only active when EXECUTORCH_ANDROID_PROFILING is defined; otherwise it
 * always returns false. Returns true only when the complete buffer has been
 * written. The buffer obtained from the ETDumpGen is released with free() on
 * every path, and the file descriptor is always closed.
 */
jboolean etdump() {
#ifdef EXECUTORCH_ANDROID_PROFILING
  auto* etdumpgen =
      static_cast<executorch::etdump::ETDumpGen*>(module_->event_tracer());
  if (etdumpgen == nullptr) {
    ET_LOG(Error, "No ETDump data available!");
    return false;
  }
  auto etdump_data = etdumpgen->get_etdump_data();
  // Take ownership immediately so no early return can leak the buffer.
  std::unique_ptr<void, decltype(&free)> owned_buf(etdump_data.buf, &free);

  if (etdump_data.buf == nullptr || etdump_data.size == 0) {
    ET_LOG(Error, "No ETDump data available!");
    return false;
  }

  // O_TRUNC: an older, longer dump must not leave stale bytes after the new
  // data.
  const int etdump_file = open(
      "/data/local/tmp/result.etdump", O_WRONLY | O_CREAT | O_TRUNC, 0644);
  if (etdump_file == -1) {
    ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
    return false;
  }

  // write() may accept fewer bytes than requested; loop until done.
  const auto* cursor = static_cast<const uint8_t*>(etdump_data.buf);
  size_t remaining = etdump_data.size;
  bool write_ok = true;
  while (remaining > 0) {
    const ssize_t written = write(etdump_file, cursor, remaining);
    if (written < 0 && errno == EINTR) {
      continue; // interrupted before any byte was written: retry
    }
    if (written <= 0) {
      ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
      write_ok = false;
      break;
    }
    cursor += written;
    remaining -= static_cast<size_t>(written);
  }
  close(etdump_file);

  if (!write_ok) {
    return false;
  }
  ET_LOG(Info, "ETDump written %zu bytes to file.", etdump_data.size);
  return true;
#else
  return false;
#endif
}
|
|
|
|
/**
 * Returns the names reported by Module::method_names() as a Java String[].
 * Raises an ExecuTorch exception (InvalidArgument) when the names cannot be
 * obtained; the native error code is included in the message.
 */
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
  const auto& names_result = module_->method_names();
  if (!names_result.ok()) {
    // Render the native error code as upper-case hex.
    std::stringstream message;
    message << "Cannot get load module [Native Error: 0x" << std::hex
            << std::uppercase << static_cast<uint32_t>(names_result.error())
            << "]";

    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(Error::InvalidArgument), message.str());
    return {};
  }

  const auto& methods = names_result.get();
  auto ret = facebook::jni::JArrayClass<jstring>::newArray(methods.size());
  int slot = 0;
  for (const auto& name : methods) {
    facebook::jni::local_ref<facebook::jni::JString> method_name =
        facebook::jni::make_jstring(name.c_str());
    (*ret)[slot] = method_name;
    ++slot;
  }
  return ret;
}
|
|
|
|
/**
 * Returns the distinct backend names listed in the MethodMeta of `methodName`
 * as a Java String[]. Backends whose name lookup fails are skipped. Raises an
 * ExecuTorch exception carrying the native error when the MethodMeta cannot
 * be obtained.
 */
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
    facebook::jni::alias_ref<jstring> methodName) {
  auto method_name = methodName->toStdString();
  auto methodMetaResult = module_->method_meta(method_name);
  if (!methodMetaResult.ok()) {
    std::stringstream message;
    message << "Cannot get method meta for '" << method_name
            << "' [Native Error: 0x" << std::hex << std::uppercase
            << static_cast<uint32_t>(methodMetaResult.error()) << "]";
    jni_helper::throwExecutorchException(
        static_cast<uint32_t>(methodMetaResult.error()), message.str());
    return {};
  }

  // Collect into a set so that each backend is reported once.
  auto methodMeta = methodMetaResult.get();
  std::unordered_set<std::string> backends;
  for (auto idx = 0; idx < methodMeta.num_backends(); ++idx) {
    auto backend_name_result = methodMeta.get_backend_name(idx);
    if (!backend_name_result.ok()) {
      continue;
    }
    backends.insert(backend_name_result.get());
  }

  auto ret = facebook::jni::JArrayClass<jstring>::newArray(backends.size());
  int slot = 0;
  for (const auto& name : backends) {
    facebook::jni::local_ref<facebook::jni::JString> backend_name =
        facebook::jni::make_jstring(name.c_str());
    (*ret)[slot] = backend_name;
    ++slot;
  }
  return ret;
}
|
|
|
|
/**
 * Binds the Java native method names of `org.pytorch.executorch.Module` to
 * the member functions of this class.
 */
static void registerNatives() {
  registerHybrid({
      // Construction, loading and execution.
      makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
      makeNativeMethod("executeNative", ExecuTorchJni::execute),
      makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
      // Log buffer access (instance and static variants).
      makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
      makeNativeMethod(
          "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
      // Profiling and introspection.
      makeNativeMethod("etdump", ExecuTorchJni::etdump),
      makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
      makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
      // TTS pipeline.
      makeNativeMethod("nativeSetCpModule", ExecuTorchJni::setCpModule),
      makeNativeMethod("nativeRunTtsPipeline", ExecuTorchJni::runTtsPipeline),
  });
}
|
|
|
|
// ======= TTS Pipeline =======
|
|
static Module* cpModule_;
|
|
|
|
void setCpModule() {
|
|
cpModule_ = module_.get();
|
|
ET_LOG(Info, "TTS CP module set");
|
|
}
|
|
|
|
// Runs talker+CP loop. 'this' is the talker, cpModule_ is the CP.
|
|
facebook::jni::local_ref<facebook::jni::JArrayInt>
|
|
runTtsPipeline(
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
|
jint maxTokens);
|
|
|
|
static facebook::jni::local_ref<facebook::jni::JArrayInt>
|
|
runTtsPipelineImpl(
|
|
Module* talker, Module* cp,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
|
jint maxTokens);
|
|
};
|
|
} // namespace executorch::extension
|
|
|
|
// ======= TTS Pipeline statics and wrapper =======
|
|
namespace executorch::extension { Module* ExecuTorchJni::cpModule_ = nullptr; }
|
|
|
|
namespace executorch::extension {
|
|
facebook::jni::local_ref<facebook::jni::JArrayInt>
|
|
ExecuTorchJni::runTtsPipeline(
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
|
jint maxTokens) {
|
|
if (!cpModule_) { ET_LOG(Error, "TTS CP module not set!"); return facebook::jni::JArrayInt::newArray(0); }
|
|
return runTtsPipelineImpl(module_.get(), cpModule_,
|
|
jPrefill, nPrefill, jTrailing, nTrailing,
|
|
jCodecEmb, jCpEmbs, jCpHeads, jTCos, jTSin, jCCos, jCSin, jEos, jPad, maxTokens);
|
|
}
|
|
} // namespace
|
|
|
|
// vfmaq_f32 / vaddvq_f32 are only used on AArch64 builds with NEON; every
// other target (x86 emulator ABIs, 32-bit ARM) takes the scalar path below.
#if defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>
#endif
#include <cfloat>
#include <cmath>

/**
 * Dot product of two float arrays of length `n`.
 *
 * @param a first operand, at least `n` floats
 * @param b second operand, at least `n` floats
 * @param n number of elements; values <= 0 yield 0
 * @return sum of a[i] * b[i] for i in [0, n)
 *
 * On AArch64 the bulk is processed 16 floats per iteration with four
 * independent NEON accumulators; the remainder (and the whole input on other
 * targets) is accumulated with a scalar loop.
 */
static inline float tts_dot_neon(const float* a, const float* b, int n) {
  int i = 0;
  float r = 0.0f;
#if defined(__aarch64__) && defined(__ARM_NEON)
  float32x4_t s0 = vdupq_n_f32(0);
  float32x4_t s1 = vdupq_n_f32(0);
  float32x4_t s2 = vdupq_n_f32(0);
  float32x4_t s3 = vdupq_n_f32(0);
  for (; i + 15 < n; i += 16) {
    s0 = vfmaq_f32(s0, vld1q_f32(a + i), vld1q_f32(b + i));
    s1 = vfmaq_f32(s1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4));
    s2 = vfmaq_f32(s2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8));
    s3 = vfmaq_f32(s3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12));
  }
  // Horizontal add of the four partial sums.
  r = vaddvq_f32(vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3)));
#endif
  // Scalar tail (or the full range when NEON is unavailable).
  for (; i < n; i++) {
    r += a[i] * b[i];
  }
  return r;
}
|
|
|
|
// State of the linear congruential generator used by tts_sample_topk().
// It is a plain global: calls from several threads would race on it.
static uint64_t tts_rng = 0x12345678ABCDEF01ULL;

/**
 * Samples one index among the `k` largest logits, weighted by the softmax of
 * those logits at temperature `temp`.
 *
 * @param logits array of `vocab` scores
 * @param vocab  number of entries in `logits`; returns 0 when <= 0 or when
 *               `logits` is null
 * @param temp   softmax temperature; a value that is not > 0 selects the
 *               largest logit deterministically
 * @param k      number of candidates; clamped to [1, vocab]
 * @return index into `logits` of the chosen entry
 */
static int tts_sample_topk(const float* logits, int vocab, float temp, int k) {
  if (logits == nullptr || vocab <= 0) {
    return 0;
  }
  // Indexing topk[k - 1] needs at least one slot, and slots beyond `vocab`
  // could never be filled with a real candidate.
  if (k < 1) {
    k = 1;
  }
  if (k > vocab) {
    k = vocab;
  }

  struct Candidate {
    int index;
    float value;
  };
  // Kept sorted by descending value; the last slot holds the weakest entry.
  std::vector<Candidate> topk(k, Candidate{0, -FLT_MAX});
  for (int i = 0; i < vocab; i++) {
    if (logits[i] > topk[k - 1].value) {
      topk[k - 1] = Candidate{i, logits[i]};
      // Bubble the new entry up to its sorted position.
      for (int j = k - 2; j >= 0 && topk[j + 1].value > topk[j].value; j--) {
        std::swap(topk[j], topk[j + 1]);
      }
    }
  }

  // Dividing by a non-positive temperature is meaningless; fall back to the
  // best candidate.
  if (!(temp > 0.0f)) {
    return topk[0].index;
  }

  // Softmax weights relative to the maximum, for numerical stability.
  const float maxv = topk[0].value;
  float sum = 0;
  for (auto& t : topk) {
    t.value = expf((t.value - maxv) / temp);
    sum += t.value;
  }

  // One LCG step, then map the upper 31 bits onto [0, sum].
  tts_rng = tts_rng * 6364136223846793005ULL + 1442695040888963407ULL;
  const float r = (float)((tts_rng >> 33) & 0x7FFFFFFF) / (float)0x7FFFFFFF * sum;

  float acc = 0;
  for (const auto& t : topk) {
    acc += t.value;
    if (acc >= r) {
      return t.index;
    }
  }
  return topk[0].index;
}
|
|
|
|
namespace executorch::extension {
|
|
facebook::jni::local_ref<facebook::jni::JArrayInt>
|
|
ExecuTorchJni::runTtsPipelineImpl(
|
|
Module* talker, Module* cp,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
|
jint maxTokens)
|
|
{
|
|
static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
|
|
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
|
|
static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
|
|
static const int CODEC_EOS=2150;
|
|
|
|
// Copy JNI arrays
|
|
auto copyArr = [](facebook::jni::alias_ref<facebook::jni::JArrayFloat> j) {
|
|
std::vector<float> v(j->size());
|
|
j->getRegion(0, j->size(), v.data());
|
|
return v;
|
|
};
|
|
auto prefill=copyArr(jPrefill);
|
|
auto trailing=nTrailing>0?copyArr(jTrailing):std::vector<float>();
|
|
auto codecEmb=copyArr(jCodecEmb);
|
|
auto cpEmbs=copyArr(jCpEmbs);
|
|
auto cpHeads=copyArr(jCpHeads);
|
|
auto tCos=copyArr(jTCos),tSin=copyArr(jTSin);
|
|
auto cCos=copyArr(jCCos),cSin=copyArr(jCSin);
|
|
auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad);
|
|
|
|
int tkvElem=T_KV*T_KV_LEN*T_HD;
|
|
std::vector<float> tKV(T_L*2*tkvElem,0); // intermediate KV buffer (avoids output overwrite race)
|
|
float mask[T_KV_LEN]; for(int i=0;i<T_KV_LEN;i++) mask[i]=-1e9f;
|
|
float hidden[DIM]={},logits[VOCAB]={};
|
|
std::vector<int> allCodes,cb0Hist;
|
|
int pos=0,curCb0=-1,trIdx=0;
|
|
|
|
// Get raw Method pointers for direct execution
|
|
auto tMethodRes = talker->method("forward");
|
|
auto cMethodRes = cp->method("forward");
|
|
if (!tMethodRes.ok() || !cMethodRes.ok()) {
|
|
ET_LOG(Error, "TTS cannot get Method pointers");
|
|
return facebook::jni::JArrayInt::newArray(0);
|
|
}
|
|
Method* tMethod = tMethodRes.get();
|
|
Method* cMethod = cMethodRes.get();
|
|
|
|
// Talker: prepare once, cache pointers, reuse for all 58+ steps
|
|
{auto prep=executorch::extension::prepare_input_tensors(*tMethod);}
|
|
float* tInEmb = tMethod->mutable_input(0).toTensor().mutable_data_ptr<float>();
|
|
float* tInMask = tMethod->mutable_input(1).toTensor().mutable_data_ptr<float>();
|
|
float* tInCos = tMethod->mutable_input(2).toTensor().mutable_data_ptr<float>();
|
|
float* tInSin = tMethod->mutable_input(3).toTensor().mutable_data_ptr<float>();
|
|
float* tInKV[T_L*2];
|
|
for(int i=0;i<T_L;i++){
|
|
tInKV[i*2] = tMethod->mutable_input(4+i*2).toTensor().mutable_data_ptr<float>();
|
|
tInKV[i*2+1] = tMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr<float>();
|
|
}
|
|
|
|
auto talkerStep = [&](const float* emb) {
|
|
int pi=std::min(pos,249);
|
|
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
|
if(mi>=0) mask[mi]=0.0f;
|
|
|
|
memcpy(tInEmb, emb, DIM*4);
|
|
memcpy(tInMask, mask, T_KV_LEN*4);
|
|
memcpy(tInCos, tCos.data()+pi*T_HD, T_HD*4);
|
|
memcpy(tInSin, tSin.data()+pi*T_HD, T_HD*4);
|
|
for(int i=0;i<T_L;i++){
|
|
memcpy(tInKV[i*2], tKV.data()+(i*2)*tkvElem, tkvElem*4);
|
|
memcpy(tInKV[i*2+1], tKV.data()+(i*2+1)*tkvElem, tkvElem*4);
|
|
}
|
|
|
|
auto status = tMethod->execute();
|
|
if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;}
|
|
memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr<float>(), DIM*4);
|
|
memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr<float>(), VOCAB*4);
|
|
for(int i=0;i<T_L;i++){
|
|
memcpy(tKV.data()+(i*2)*tkvElem, tMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
|
|
memcpy(tKV.data()+(i*2+1)*tkvElem, tMethod->get_output(3+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
|
|
}
|
|
pos++;
|
|
};
|
|
|
|
// CP step: prepare once, direct output→input KV copy
|
|
int ckvElem=C_KV*C_KV_LEN*C_HD;
|
|
{auto prep=executorch::extension::prepare_input_tensors(*cMethod);}
|
|
float* cpInEmb = cMethod->mutable_input(0).toTensor().mutable_data_ptr<float>();
|
|
float* cpInMask = cMethod->mutable_input(1).toTensor().mutable_data_ptr<float>();
|
|
float* cpInCos = cMethod->mutable_input(2).toTensor().mutable_data_ptr<float>();
|
|
float* cpInSin = cMethod->mutable_input(3).toTensor().mutable_data_ptr<float>();
|
|
float* cpInKV[C_L*2];
|
|
for(int i=0;i<C_L;i++){
|
|
cpInKV[i*2] = cMethod->mutable_input(4+i*2).toTensor().mutable_data_ptr<float>();
|
|
cpInKV[i*2+1] = cMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr<float>();
|
|
}
|
|
|
|
auto cpStep = [&](const float* h, int cb0, int* codes) {
|
|
// Reset CP KV to zeros for step 0
|
|
for(int i=0;i<C_L*2;i++) memset(cpInKV[i], 0, ckvElem*4);
|
|
|
|
for(int step=0;step<17;step++){
|
|
const float*emb;
|
|
if(step==0) emb=h;
|
|
else if(step==1) emb=codecEmb.data()+std::min(std::max(cb0,0),VOCAB-1)*DIM;
|
|
else emb=cpEmbs.data()+((long)(step-2)*CB_SIZE+std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM;
|
|
|
|
memcpy(cpInEmb, emb, DIM*4);
|
|
for(int p=0;p<C_KV_LEN;p++) cpInMask[p]=(p>=C_KV_LEN-1-step)?0.0f:-1e9f;
|
|
memcpy(cpInCos, cCos.data()+step*C_HD, C_HD*4);
|
|
memcpy(cpInSin, cSin.data()+step*C_HD, C_HD*4);
|
|
// KV: copy from previous output directly to input (skip buffer)
|
|
if(step>0){
|
|
for(int i=0;i<C_L;i++){
|
|
memcpy(cpInKV[i*2], cMethod->get_output(1+i*2).toTensor().const_data_ptr<float>(), ckvElem*4);
|
|
memcpy(cpInKV[i*2+1], cMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(), ckvElem*4);
|
|
}
|
|
}
|
|
|
|
auto status=cMethod->execute();
|
|
if(status!=Error::Ok) break;
|
|
const float*ho=cMethod->get_output(0).toTensor().const_data_ptr<float>();
|
|
if(step>=1&&step-1<15){
|
|
const float*W=cpHeads.data()+(long)(step-1)*CB_SIZE*DIM;
|
|
int best=0;float bv=-FLT_MAX;
|
|
for(int j=0;j<CB_SIZE;j++){float d=tts_dot_neon(ho,W+j*DIM,DIM);if(d>bv){bv=d;best=j;}}
|
|
codes[step-1]=best;
|
|
}
|
|
}
|
|
};
|
|
|
|
auto t0=std::chrono::high_resolution_clock::now();
|
|
// Prefill
|
|
for(int s=0;s<nPrefill;s++){
|
|
talkerStep(prefill.data()+s*DIM);
|
|
if(s==nPrefill-1){
|
|
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
|
curCb0=tts_sample_topk(logits,VOCAB,0.9f,50);
|
|
}
|
|
}
|
|
auto tP=std::chrono::high_resolution_clock::now();
|
|
ET_LOG(Info,"TTS Prefill: %.0fms, cb0=%d",std::chrono::duration<float,std::milli>(tP-t0).count(),curCb0);
|
|
|
|
if(curCb0<0||curCb0==CODEC_EOS) return facebook::jni::JArrayInt::newArray(0);
|
|
|
|
// Generation
|
|
float totalTalker=0,totalCp=0;
|
|
for(int g=0;g<maxTokens;g++){
|
|
int codes[NUM_CB]={}; codes[0]=curCb0;
|
|
auto tc0=std::chrono::high_resolution_clock::now();
|
|
int cpCodes[15]={};
|
|
cpStep(hidden,curCb0,cpCodes);
|
|
auto tc1=std::chrono::high_resolution_clock::now();
|
|
totalCp+=std::chrono::duration<float,std::milli>(tc1-tc0).count();
|
|
for(int i=0;i<15;i++) codes[i+1]=cpCodes[i];
|
|
for(int i=0;i<NUM_CB;i++) allCodes.push_back(codes[i]);
|
|
cb0Hist.push_back(curCb0);
|
|
|
|
// Next embed: OUR codec_sum + trailing text/eos/pad
|
|
// With shared Module, codec_sum is self-consistent (same QNN graph)
|
|
float nextEmb[DIM]={};
|
|
const float*e0=codecEmb.data()+std::min(std::max(codes[0],0),VOCAB-1)*DIM;
|
|
for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k];
|
|
for(int cb=0;cb<15;cb++){
|
|
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
|
|
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
|
|
}
|
|
if(trIdx<nTrailing){
|
|
const float*te=trailing.data()+trIdx*DIM;
|
|
for(int k=0;k<DIM;k++) nextEmb[k]+=te[k];
|
|
trIdx++;
|
|
} else if(trIdx==nTrailing){
|
|
// eos once after text
|
|
for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmb[k];
|
|
trIdx++;
|
|
} else {
|
|
// pad after eos
|
|
for(int k=0;k<DIM;k++) nextEmb[k]+=padEmb[k];
|
|
}
|
|
|
|
auto tt0=std::chrono::high_resolution_clock::now();
|
|
talkerStep(nextEmb);
|
|
auto tt1=std::chrono::high_resolution_clock::now();
|
|
totalTalker+=std::chrono::duration<float,std::milli>(tt1-tt0).count();
|
|
|
|
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
|
std::unordered_set<int> seen(cb0Hist.begin(),cb0Hist.end());
|
|
for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
|
|
int next;
|
|
if(trIdx>nTrailing){
|
|
// In pad zone: greedy argmax to give EOS its honest chance.
|
|
// top-k sampling at temp 0.9 keeps producing audio even when EOS is the
|
|
// model's preferred choice; Python's seeded sampler hits EOS, ours doesn't.
|
|
int best=0;float bv=logits[0];
|
|
for(int j=1;j<VOCAB;j++) if(logits[j]>bv){bv=logits[j];best=j;}
|
|
next=best;
|
|
} else {
|
|
next=tts_sample_topk(logits,VOCAB,0.9f,50);
|
|
}
|
|
if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d (trIdx=%d nTr=%d)",g+2,trIdx,nTrailing);break;}
|
|
if((int)cb0Hist.size()>=9){
|
|
bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;}
|
|
if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;}
|
|
}
|
|
curCb0=next;
|
|
}
|
|
int nTok=(int)allCodes.size()/NUM_CB;
|
|
auto t1=std::chrono::high_resolution_clock::now();
|
|
ET_LOG(Info,"TTS Generated %d | Talker %.0fms (%.0f/step) | CP %.0fms (%.0f/step) | Total %.0fms",
|
|
nTok,totalTalker,totalTalker/std::max(nTok,1),totalCp,totalCp/std::max(nTok,1),
|
|
std::chrono::duration<float,std::milli>(t1-t0).count());
|
|
|
|
auto result=facebook::jni::JArrayInt::newArray((int)allCodes.size());
|
|
result->setRegion(0,(int)allCodes.size(),allCodes.data());
|
|
return result;
|
|
}
|
|
} // namespace executorch::extension
|
|
|
|
// The LLM JNI bindings are optional. When EXECUTORCH_BUILD_LLAMA_JNI is not
// defined, the registration hook is an empty function defined right here, so
// JNI_OnLoad can call it unconditionally; otherwise it is only declared.
#if !defined(EXECUTORCH_BUILD_LLAMA_JNI)
void register_natives_for_llm() {}
#else
extern void register_natives_for_llm();
#endif
|
|
// Registration hook for the runtime JNI bindings. Only declared here (no body
// is visible in this part of the file); JNI_OnLoad calls it unconditionally.
extern void register_natives_for_runtime();
|
|
|
|
// The training JNI bindings are optional. When
// EXECUTORCH_BUILD_EXTENSION_TRAINING is not defined, the registration hook is
// an empty function defined right here, so JNI_OnLoad can call it
// unconditionally; otherwise it is only declared.
#if !defined(EXECUTORCH_BUILD_EXTENSION_TRAINING)
void register_natives_for_training() {}
#else
extern void register_natives_for_training();
#endif
|
|
|
|
// JNI load hook for this library. Passes the VM to fbjni's initialize()
// together with a callback that registers each group of native methods:
// the ExecuTorchJni class first, then the LLM, runtime and training hooks.
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* /*reserved*/) {
  const auto registerAllNatives = [] {
    executorch::extension::ExecuTorchJni::registerNatives();
    register_natives_for_llm();
    register_natives_for_runtime();
    register_natives_for_training();
  };
  return facebook::jni::initialize(vm, registerAllNatives);
}
|