diff --git a/executorch-custom/Module.java b/executorch-custom/Module.java
new file mode 100644
index 0000000..d248eaa
--- /dev/null
+++ b/executorch-custom/Module.java
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch;
+
+import android.util.Log;
+import com.facebook.jni.HybridData;
+import com.facebook.jni.annotations.DoNotStrip;
+import com.facebook.soloader.nativeloader.NativeLoader;
+import com.facebook.soloader.nativeloader.SystemDelegate;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import org.pytorch.executorch.annotations.Experimental;
+
+/**
+ * Java wrapper for ExecuTorch Module.
+ *
+ * <p>Warning: These APIs are experimental and subject to change without notice
+ */
+@Experimental
+public class Module {
+
+  static {
+    if (!NativeLoader.isInitialized()) {
+      NativeLoader.init(new SystemDelegate());
+    }
+    // Loads libexecutorch.so from jniLibs
+    NativeLoader.loadLibrary("executorch");
+  }
+
+  /** Load mode for the module. Load the whole file as a buffer. */
+  public static final int LOAD_MODE_FILE = 0;
+
+  /** Load mode for the module. Use mmap to load pages into memory. */
+  public static final int LOAD_MODE_MMAP = 1;
+
+  /** Load mode for the module. Use memory locking and handle errors. */
+  public static final int LOAD_MODE_MMAP_USE_MLOCK = 2;
+
+  /** Load mode for the module. Use memory locking and ignore errors. */
+  public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
+
+  private final HybridData mHybridData;
+  /** Per-method metadata keyed by method name; populated once by the constructor. */
+  private final Map<String, MethodMetadata> mMethodMetadata;
+
+  @DoNotStrip
+  private static native HybridData initHybrid(
+      String moduleAbsolutePath, int loadMode, int numThreads);
+
+  private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
+    // Result intentionally unused. NOTE(review): presumably forces runtime setup first; confirm.
+    ExecuTorchRuntime.getRuntime();
+    mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
+
+    mMethodMetadata = populateMethodMeta();
+  }
+  /** Builds the name-to-metadata map by querying the native module for every method. */
+  private Map<String, MethodMetadata> populateMethodMeta() {
+    String[] methods = getMethods();
+    Map<String, MethodMetadata> metadata = new HashMap<>();
+    for (String name : methods) {
+      metadata.put(name, new MethodMetadata(name, getUsedBackends(name)));
+    }
+    return metadata;
+  }
+
+  /** Lock protecting the non-thread safe methods in mHybridData. */
+  private final Lock mLock = new ReentrantLock();
+
+  /**
+   * Loads a serialized ExecuTorch module from the specified path on the disk.
+   *
+   * @param modelPath path to file that contains the serialized ExecuTorch module.
+   * @param loadMode load mode for the module. See constants in {@link Module}.
+   * @return new {@link org.pytorch.executorch.Module} object which owns the model module.
+   */
+  public static Module load(final String modelPath, int loadMode) {
+    return load(modelPath, loadMode, 0);
+  }
+
+  /**
+   * Loads a serialized ExecuTorch module from the specified path on the disk.
+   *
+   * @param modelPath path to file that contains the serialized ExecuTorch module.
+   * @param loadMode load mode for the module. See constants in {@link Module}.
+   * @param numThreads the number of threads to use for inference. A value of 0 defaults to a
+   *     hardware-specific default.
+   * @return new {@link org.pytorch.executorch.Module} object which owns the model module.
+   */
+  public static Module load(final String modelPath, int loadMode, int numThreads) {
+    ExecuTorchRuntime.validateFilePath(modelPath, "model path");
+    return new Module(modelPath, loadMode, numThreads);
+  }
+
+  /**
+   * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU.
+   *
+   * @param modelPath path to file that contains the serialized ExecuTorch module.
+   * @return new {@link org.pytorch.executorch.Module} object which owns the model module.
+   */
+  public static Module load(final String modelPath) {
+    return load(modelPath, LOAD_MODE_FILE);
+  }
+
+  /**
+   * Runs the 'forward' method of this module with the specified arguments.
+   *
+   * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward'
+   *     requires inputs but no inputs are given, the function will not error out, but run 'forward'
+   *     with sample inputs.
+   * @return return value from the 'forward' method.
+   */
+  public EValue[] forward(EValue... inputs) {
+    return execute("forward", inputs);
+  }
+
+  /**
+   * Runs the specified method of this module with the specified arguments.
+   *
+   * @param methodName name of the ExecuTorch method to run.
+   * @param inputs arguments that will be passed to ExecuTorch method.
+   * @return return value from the method, or an empty array if this module has been destroyed.
+   */
+  public EValue[] execute(String methodName, EValue... inputs) {
+    mLock.lock(); // acquire outside try so finally never unlocks a lock we do not hold
+    try {
+      if (!mHybridData.isValid()) {
+        Log.e("ExecuTorch", "Attempt to use a destroyed module");
+        return new EValue[0];
+      }
+      return executeNative(methodName, inputs);
+    } finally {
+      mLock.unlock();
+    }
+  }
+
+  @DoNotStrip
+  private native EValue[] executeNative(String methodName, EValue... inputs);
+
+  /**
+   * Load a method on this module. This might help with the first time inference performance,
+   * because otherwise the method is loaded lazily when it's executed. Note: this function is
+   * synchronous, and will block until the method is loaded. Therefore, it is recommended to call
+   * this on a background thread. However, users need to make sure that they don't execute before
+   * this function returns.
+   *
+   * @return the native error code for the load, or 0x2 (InvalidState) if the module is destroyed
+   */
+  public int loadMethod(String methodName) {
+    mLock.lock(); // acquire outside try so finally never unlocks a lock we do not hold
+    try {
+      if (!mHybridData.isValid()) {
+        Log.e("ExecuTorch", "Attempt to use a destroyed module");
+        return 0x2; // InvalidState
+      }
+      return loadMethodNative(methodName);
+    } finally {
+      mLock.unlock();
+    }
+  }
+
+  @DoNotStrip
+  private native int loadMethodNative(String methodName);
+
+  /**
+   * Returns the names of the backends in a certain method.
+   *
+   * @param methodName method name to query
+   * @return an array of backend names
+   */
+  @DoNotStrip
+  private native String[] getUsedBackends(String methodName);
+
+  /**
+   * Returns the names of methods.
+   *
+   * @return name of methods in this Module
+   */
+  @DoNotStrip
+  public native String[] getMethods();
+
+  /**
+   * Get the corresponding {@link MethodMetadata} for a method.
+   * @param name method name
+   * @return {@link MethodMetadata} for this method; never null
+   * @throws IllegalArgumentException if this module has no method called {@code name}
+   */
+  public MethodMetadata getMethodMetadata(String name) {
+    MethodMetadata methodMetadata = mMethodMetadata.get(name);
+    if (methodMetadata == null) {
+      throw new IllegalArgumentException("method " + name + " does not exist for this module");
+    }
+    return methodMetadata;
+  }
+
+  @DoNotStrip
+  private static native String[] readLogBufferStaticNative();
+  /** Static variant of {@link #readLogBuffer()}; usable without a Module instance. */
+  public static String[] readLogBufferStatic() {
+    return readLogBufferStaticNative();
+  }
+
+  /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
+  public String[] readLogBuffer() {
+    return readLogBufferNative();
+  }
+
+  @DoNotStrip
+  private native String[] readLogBufferNative();
+
+  /**
+   * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
+   *
+   * <p>Currently for internal (minibench) use only.
+   *
+   * @return true if the etdump was successfully written, false otherwise.
+   */
+  @Experimental
+  @DoNotStrip
+  public native boolean etdump();
+  // NOTE(review): the TTS natives below bypass mLock and the isValid() check used by execute().
+  /** TTS Pipeline: set this module as the CP (code predictor) for the pipeline. */
+  @DoNotStrip
+  public native void nativeSetCpModule();
+
+  /** TTS Pipeline: run full talker+CP loop in C++. 'this' is the talker module. */
+  @DoNotStrip
+  public native int[] nativeRunTtsPipeline(
+      float[] prefill, int nPrefill, float[] trailing, int nTrailing,
+      float[] codecEmb, float[] cpEmbs, float[] cpHeads,
+      float[] talkerCos, float[] talkerSin, float[] cpCos, float[] cpSin,
+      float[] eosEmbed, float[] padEmbed, int maxTokens);
+
+  /**
+   * Explicitly destroys the native Module object. Calling this method is not required, as the
+   * native object will be destroyed when this object is garbage-collected. However, the timing of
+   * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
+   * more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
+   */
+  public void destroy() {
+    if (mLock.tryLock()) {
+      try {
+        mHybridData.resetNative();
+      } finally {
+        mLock.unlock();
+      }
+    } else {
+      Log.w(
+          "ExecuTorch",
+          "Destroy was called while the module was in use. Resources will not be immediately"
+              + " released.");
+    }
+  }
+}
diff --git a/executorch-custom/jni_layer_tts.cpp b/executorch-custom/jni_layer_tts.cpp
new file mode 100644
index 0000000..16a091e
--- /dev/null
+++ b/executorch-custom/jni_layer_tts.cpp
@@ -0,0 +1,888 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef ET_USE_THREADPOOL
+#include
+#include
+#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
+#include
+#endif
+#endif
+
+#ifdef EXECUTORCH_ANDROID_PROFILING
+#include
+#include
+#include
+#endif
+
+#include
+#include
+
+using namespace executorch::extension;
+using namespace torch::executor;
+
+namespace executorch::extension {
+class TensorHybrid : public facebook::jni::HybridClass { // JNI peer of org.pytorch.executorch.Tensor (see kJavaDescriptor)
+ public:
+ constexpr static const char* kJavaDescriptor =
+ "Lorg/pytorch/executorch/Tensor;";
+ // The tensor argument is accepted but not stored: the constructor body is empty.
+ explicit TensorHybrid(executorch::aten::Tensor tensor) {}
+ // Converts a native tensor into a Java Tensor whose ByteBuffer wraps tensor.data_ptr() (see the lifetime note below).
+ static facebook::jni::local_ref
+ newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
+ // Java wrapper currently only supports contiguous tensors.
+
+ const auto scalarType = tensor.scalar_type();
+ if (scalar_type_to_java_dtype.count(scalarType) == 0) {
+ std::stringstream ss;
+ ss << "executorch::aten::Tensor scalar type "
+ << static_cast(scalarType) << " is not supported on java side";
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr; // reached once the unsupported dtype has been reported through throwExecutorchException()
+ }
+ int jdtype = scalar_type_to_java_dtype.at(scalarType);
+
+ const auto& tensor_shape = tensor.sizes();
+ std::vector tensor_shape_vec;
+ for (const auto& s : tensor_shape) {
+ tensor_shape_vec.push_back(s);
+ }
+ facebook::jni::local_ref jTensorShape =
+ facebook::jni::make_long_array(tensor_shape_vec.size());
+ jTensorShape->setRegion(
+ 0, tensor_shape_vec.size(), tensor_shape_vec.data());
+
+ static auto cls = TensorHybrid::javaClassStatic();
+ // Note: this is safe as long as the data stored in tensor is valid; the
+ // data won't go out of scope as long as the Method for the inference is
+ // valid and there is no other inference call. Java layer picks up this
+ // value immediately so the data is valid.
+ facebook::jni::local_ref jTensorBuffer =
+ facebook::jni::JByteBuffer::wrapBytes(
+ (uint8_t*)tensor.data_ptr(), tensor.nbytes());
+ jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
+
+ static const auto jMethodNewTensor =
+ cls->getStaticMethod(
+ facebook::jni::alias_ref,
+ facebook::jni::alias_ref,
+ jint,
+ facebook::jni::alias_ref)>("nativeNewTensor");
+ return jMethodNewTensor(
+ cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
+ }
+ // Builds a TensorPtr from a Java Tensor's dtype code, "shape" field and raw data buffer; returns nullptr on each validation failure below.
+ static TensorPtr newTensorFromJTensor(
+ facebook::jni::alias_ref jtensor) {
+ static auto cls = TensorHybrid::javaClassStatic();
+ static const auto dtypeMethod = cls->getMethod("dtypeJniCode");
+ jint jdtype = dtypeMethod(jtensor);
+
+ static const auto shapeField = cls->getField("shape");
+ auto jshape = jtensor->getFieldValue(shapeField);
+
+ static auto dataBufferMethod = cls->getMethod<
+ facebook::jni::local_ref()>(
+ "getRawDataBuffer");
+ facebook::jni::local_ref jbuffer =
+ dataBufferMethod(jtensor);
+
+ const auto rank = jshape->size();
+
+ const auto shapeArr = jshape->getRegion(0, rank);
+ std::vector shape_vec;
+ shape_vec.reserve(rank);
+
+ int64_t numel = 1; // product of all dimensions; NOTE(review): not checked for overflow or negative sizes
+ for (int i = 0; i < rank; i++) {
+ shape_vec.push_back(shapeArr[i]);
+ }
+ for (int i = rank - 1; i >= 0; --i) {
+ numel *= shapeArr[i];
+ }
+ JNIEnv* jni = facebook::jni::Environment::current();
+ if (java_dtype_to_scalar_type.count(jdtype) == 0) {
+ std::stringstream ss;
+ ss << "Unknown Tensor jdtype: [" << jdtype << "]";
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+ }
+ ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
+ const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
+ if (dataCapacity < 0) {
+ std::stringstream ss;
+ ss << "Tensor buffer is not direct or has invalid capacity";
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+ }
+ const size_t elementSize = executorch::runtime::elementSize(scalar_type);
+ const jlong expectedElements = static_cast(numel);
+ const jlong expectedBytes =
+ expectedElements * static_cast(elementSize);
+ const bool matchesElements = dataCapacity == expectedElements; // capacity counted in elements
+ const bool matchesBytes = dataCapacity == expectedBytes; // capacity counted in bytes
+ if (!matchesElements && !matchesBytes) { // either interpretation of the capacity is accepted
+ std::stringstream ss;
+ ss << "Tensor dimensions(elements number: " << numel
+ << ") inconsistent with buffer capacity " << dataCapacity
+ << " (element size bytes: " << elementSize << ")";
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+ }
+ return from_blob( // NOTE(review): presumably a non-owning view, so the Java buffer must outlive the TensorPtr - confirm
+ jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
+ }
+
+ private:
+ friend HybridBase;
+};
+
+class JEValue : public facebook::jni::JavaClass { // JNI view of org.pytorch.executorch.EValue (see kJavaDescriptor)
+ public:
+ constexpr static const char* kJavaDescriptor =
+ "Lorg/pytorch/executorch/EValue;";
+ // Type codes compared against the Java object's mTypeCode field.
+ constexpr static int kTypeCodeTensor = 1;
+ constexpr static int kTypeCodeString = 2;
+ constexpr static int kTypeCodeDouble = 3;
+ constexpr static int kTypeCodeInt = 4;
+ constexpr static int kTypeCodeBool = 5;
+ // Converts a native EValue into a Java EValue via the matching static "from" overload; an unhandled tag is reported and yields an empty ref.
+ static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) {
+ if (evalue.isTensor()) {
+ static auto jMethodTensor =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(
+ facebook::jni::local_ref)>("from");
+ return jMethodTensor(
+ JEValue::javaClassStatic(),
+ TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
+ } else if (evalue.isInt()) {
+ static auto jMethodTensor = // the local name is reused in every branch; here it is the jlong overload
+ JEValue::javaClassStatic()
+ ->getStaticMethod(jlong)>(
+ "from");
+ return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
+ } else if (evalue.isDouble()) {
+ static auto jMethodTensor =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(jdouble)>(
+ "from");
+ return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
+ } else if (evalue.isBool()) {
+ static auto jMethodTensor =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(jboolean)>(
+ "from");
+ return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
+ } else if (evalue.isString()) {
+ static auto jMethodTensor =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(
+ facebook::jni::local_ref)>("from");
+ std::string str =
+ std::string(evalue.toString().begin(), evalue.toString().end());
+ return jMethodTensor(
+ JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
+ }
+ std::stringstream ss;
+ ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]";
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str().c_str());
+ return {};
+ }
+ // Extracts the tensor from a Java EValue; any type code other than kTypeCodeTensor is reported as an error and yields an empty TensorPtr.
+ static TensorPtr JEValueToTensorImpl(
+ facebook::jni::alias_ref JEValue) { // NOTE: the parameter shadows the class name; "JEValue::" below still names the class
+ static const auto typeCodeField =
+ JEValue::javaClassStatic()->getField("mTypeCode");
+ const auto typeCode = JEValue->getFieldValue(typeCodeField);
+ if (JEValue::kTypeCodeTensor == typeCode) {
+ static const auto jMethodGetTensor =
+ JEValue::javaClassStatic()
+ ->getMethod()>(
+ "toTensor");
+ auto jtensor = jMethodGetTensor(JEValue);
+ return TensorHybrid::newTensorFromJTensor(jtensor);
+ }
+ std::stringstream ss;
+ ss << "Unknown EValue typeCode: " << typeCode;
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str().c_str());
+ return {};
+ }
+};
+
+class ExecuTorchJni : public facebook::jni::HybridClass { // JNI peer of org.pytorch.executorch.Module (see kJavaDescriptor)
+ private:
+ friend HybridBase;
+ std::unique_ptr module_; // the native module owned by this instance
+#if defined(ET_USE_THREADPOOL) && \
+ defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
+ int num_threads_{0}; // thread count handed to the per-execute threadpool guard
+#endif
+
+ public:
+ constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";
+ // Native factory bound to Java initHybrid; forwards its arguments to the constructor.
+ static facebook::jni::local_ref initHybrid(
+ facebook::jni::alias_ref,
+ facebook::jni::alias_ref modelPath,
+ jint loadMode,
+ jint numThreads) {
+ return makeCxxInstance(modelPath, loadMode, numThreads);
+ }
+ // Maps loadMode 0..3 onto Module::LoadMode (any other value keeps Mmap), creates the module, then configures threading.
+ ExecuTorchJni(
+ facebook::jni::alias_ref modelPath,
+ jint loadMode,
+ jint numThreads) {
+ Module::LoadMode load_mode = Module::LoadMode::Mmap;
+ if (loadMode == 0) {
+ load_mode = Module::LoadMode::File;
+ } else if (loadMode == 1) {
+ load_mode = Module::LoadMode::Mmap;
+ } else if (loadMode == 2) {
+ load_mode = Module::LoadMode::MmapUseMlock;
+ } else if (loadMode == 3) {
+ load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
+ }
+#ifdef EXECUTORCH_ANDROID_PROFILING
+ auto etdump_gen = std::make_unique();
+#else
+ auto etdump_gen = nullptr;
+#endif
+ module_ = std::make_unique(
+ modelPath->toStdString(), load_mode, std::move(etdump_gen));
+
+#ifdef ET_USE_THREADPOOL
+ // Default to using cores/2 threadpool threads. The long-term plan is to
+ // improve performant core detection in CPUInfo, but for now we can use
+ // cores/2 as a sane default.
+ //
+ // Based on testing, this is almost universally faster than using all
+ // cores, as efficiency cores can be quite slow. In extreme cases, using
+ // all cores can be 10x slower than using cores/2.
+ int thread_count =
+ numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
+#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
+ num_threads_ = thread_count;
+#else
+ auto threadpool = executorch::extension::threadpool::get_threadpool();
+ if (threadpool) {
+ if (thread_count > 0) {
+ threadpool->_unsafe_reset_threadpool(thread_count);
+ }
+ }
+#endif
+#endif
+ }
+ // JNI entry point for executeNative: converts the method name and delegates to execute_method().
+ facebook::jni::local_ref> execute(
+ facebook::jni::alias_ref methodName,
+ facebook::jni::alias_ref<
+ facebook::jni::JArrayClass::javaobject>
+ jinputs) {
+ return execute_method(methodName->toStdString(), jinputs);
+ }
+ // JNI entry point for loadMethodNative: returns the module's load_method() result as a jint.
+ jint load_method(facebook::jni::alias_ref methodName) {
+ return static_cast(module_->load_method(methodName->toStdString()));
+ }
+ // Runs `method`: with no inputs it executes on whatever prepare_input_tensors() sets up, otherwise on the converted Java inputs.
+ facebook::jni::local_ref> execute_method(
+ std::string method,
+ facebook::jni::alias_ref<
+ facebook::jni::JArrayClass::javaobject>
+ jinputs) {
+ // If no inputs is given, it will run with sample inputs (ones)
+ if (jinputs->size() == 0) {
+ auto result = module_->load_method(method);
+ if (result != Error::Ok) {
+ // Format hex string
+ std::stringstream ss;
+ ss << "Cannot get method names [Native Error: 0x" << std::hex
+ << std::uppercase << static_cast(result) << "]";
+
+ jni_helper::throwExecutorchException(
+ static_cast(result), ss.str());
+ return {};
+ }
+ auto&& underlying_method = module_->methods_[method].method;
+ auto&& buf = prepare_input_tensors(*underlying_method);
+ result = underlying_method->execute();
+ if (result != Error::Ok) {
+ jni_helper::throwExecutorchException(
+ static_cast(result),
+ "Execution failed for method: " + method);
+ return {};
+ }
+ facebook::jni::local_ref> jresult =
+ facebook::jni::JArrayClass::newArray(
+ underlying_method->outputs_size());
+
+ for (int i = 0; i < underlying_method->outputs_size(); i++) {
+ auto jevalue =
+ JEValue::newJEValueFromEValue(underlying_method->get_output(i));
+ jresult->setElement(i, *jevalue);
+ }
+ return jresult;
+ }
+
+ std::vector evalues;
+ std::vector tensors;
+
+ static const auto typeCodeField =
+ JEValue::javaClassStatic()->getField("mTypeCode");
+
+ for (int i = 0; i < jinputs->size(); i++) {
+ auto jevalue = jinputs->getElement(i);
+ const auto typeCode = jevalue->getFieldValue(typeCodeField);
+ if (typeCode == JEValue::kTypeCodeTensor) {
+ tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
+ evalues.emplace_back(tensors.back());
+ } else if (typeCode == JEValue::kTypeCodeInt) {
+ static const auto toIntMethod =
+ JEValue::javaClassStatic()->getMethod("toInt");
+ evalues.emplace_back(static_cast(toIntMethod(jevalue)));
+ } else if (typeCode == JEValue::kTypeCodeDouble) {
+ static const auto toDoubleMethod =
+ JEValue::javaClassStatic()->getMethod("toDouble");
+ evalues.emplace_back(static_cast(toDoubleMethod(jevalue)));
+ } else if (typeCode == JEValue::kTypeCodeBool) {
+ static const auto toBoolMethod =
+ JEValue::javaClassStatic()->getMethod("toBool");
+ evalues.emplace_back(static_cast(toBoolMethod(jevalue)));
+ } // NOTE(review): any other type code (e.g. kTypeCodeString) is skipped, so fewer inputs reach the method
+ }
+
+#if defined(ET_USE_THREADPOOL) && \
+ defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
+ ::executorch::extension::threadpool::UseNThreadsThreadPoolGuard
+ thread_pool_guard(num_threads_);
+#endif
+
+#ifdef EXECUTORCH_ANDROID_PROFILING
+ auto start = std::chrono::high_resolution_clock::now();
+ auto result = module_->execute(method, evalues);
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration =
+ std::chrono::duration_cast(end - start)
+ .count();
+ ET_LOG(Debug, "Execution time: %lld ms.", duration);
+
+#else
+ auto result = module_->execute(method, evalues);
+
+#endif
+
+ if (!result.ok()) {
+ jni_helper::throwExecutorchException(
+ static_cast(result.error()),
+ "Execution failed for method: " + method);
+ return {};
+ }
+
+ facebook::jni::local_ref> jresult =
+ facebook::jni::JArrayClass::newArray(result.get().size());
+
+ for (int i = 0; i < result.get().size(); i++) {
+ auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
+ jresult->setElement(i, *jevalue);
+ }
+ return jresult;
+ }
+ // The instance and static entry points below both return readLogBufferUtil().
+ facebook::jni::local_ref>
+ readLogBuffer() {
+ return readLogBufferUtil();
+ }
+
+ static facebook::jni::local_ref>
+ readLogBufferStatic(facebook::jni::alias_ref) {
+ return readLogBufferUtil();
+ }
+ // Formats every buffered log entry into a Java string array; outside __ANDROID__ builds the array is empty.
+ static facebook::jni::local_ref>
+ readLogBufferUtil() {
+#ifdef __ANDROID__
+
+ facebook::jni::local_ref> ret;
+
+ access_log_buffer([&](std::vector& buffer) {
+ const auto size = buffer.size();
+ ret = facebook::jni::JArrayClass::newArray(size);
+ for (auto i = 0u; i < size; i++) {
+ const auto& entry = buffer[i];
+ // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
+ // MESSAGE".
+ std::stringstream ss;
+ ss << "[" << entry.timestamp << " " << entry.function << " "
+ << entry.filename << ":" << entry.line << "] "
+ << static_cast(entry.level) << " " << entry.message;
+
+ facebook::jni::local_ref jstr_message =
+ facebook::jni::make_jstring(ss.str().c_str());
+ (*ret)[i] = jstr_message;
+ }
+ });
+
+ return ret;
+#else
+ return facebook::jni::JArrayClass::newArray(0);
+#endif
+ }
+ // Writes the collected ETDump to /data/local/tmp/result.etdump; always false unless EXECUTORCH_ANDROID_PROFILING is defined.
+ jboolean etdump() {
+#ifdef EXECUTORCH_ANDROID_PROFILING
+ executorch::etdump::ETDumpGen* etdumpgen =
+ (executorch::etdump::ETDumpGen*)module_->event_tracer();
+ auto etdump_data = etdumpgen->get_etdump_data();
+ if (etdump_data.buf == nullptr || etdump_data.size == 0) {
+ ET_LOG(Error, "No ETDump data available!");
+ return false;
+ }
+ int etdump_file = open( // O_TRUNC so a shorter dump cannot keep stale bytes from an earlier run
+ "/data/local/tmp/result.etdump", O_WRONLY | O_CREAT | O_TRUNC, 0644);
+ if (etdump_file == -1) {
+ ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
+ free(etdump_data.buf); // the dump buffer is released on every path once obtained
+ return false;
+ }
+ ssize_t bytes_written =
+ write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
+ const bool ok = bytes_written == (ssize_t)etdump_data.size;
+ if (ok) {
+ ET_LOG(Info, "ETDump written %zd bytes to file.", bytes_written);
+ } else { // -1 or a short write; logged before close() can overwrite errno
+ ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
+ }
+ close(etdump_file); // previously leaked when write() failed
+ free(etdump_data.buf);
+ return ok;
+#endif
+ return false;
+ }
+ // Returns the module's method names as a Java string array; on failure the error is reported and an empty ref is returned.
+ facebook::jni::local_ref> getMethods() {
+ const auto& names_result = module_->method_names();
+ if (!names_result.ok()) {
+ // Format hex string
+ std::stringstream ss;
+ ss << "Cannot get load module [Native Error: 0x" << std::hex
+ << std::uppercase << static_cast(names_result.error())
+ << "]";
+
+ jni_helper::throwExecutorchException(
+ static_cast(Error::InvalidArgument), ss.str());
+ return {};
+ }
+ const auto& methods = names_result.get();
+ facebook::jni::local_ref> ret =
+ facebook::jni::JArrayClass::newArray(methods.size());
+ int i = 0;
+ for (auto s : methods) {
+ facebook::jni::local_ref method_name =
+ facebook::jni::make_jstring(s.c_str());
+ (*ret)[i] = method_name;
+ i++;
+ }
+ return ret;
+ }
+ // Returns the distinct backend names reported by the method's metadata; order is unspecified (collected in an unordered_set).
+ facebook::jni::local_ref> getUsedBackends(
+ facebook::jni::alias_ref methodName) {
+ auto method_name = methodName->toStdString();
+ auto methodMetaResult = module_->method_meta(method_name);
+ if (!methodMetaResult.ok()) {
+ std::stringstream ss;
+ ss << "Cannot get method meta for '" << method_name
+ << "' [Native Error: 0x" << std::hex << std::uppercase
+ << static_cast(methodMetaResult.error()) << "]";
+ jni_helper::throwExecutorchException(
+ static_cast(methodMetaResult.error()), ss.str());
+ return {};
+ }
+ auto methodMeta = methodMetaResult.get();
+ std::unordered_set backends;
+ for (auto i = 0; i < methodMeta.num_backends(); i++) {
+ auto backend_name_result = methodMeta.get_backend_name(i);
+ if (backend_name_result.ok()) { // backends whose name lookup fails are skipped
+ backends.insert(backend_name_result.get());
+ }
+ }
+
+ facebook::jni::local_ref> ret =
+ facebook::jni::JArrayClass::newArray(backends.size());
+ int i = 0;
+ for (auto s : backends) {
+ facebook::jni::local_ref backend_name =
+ facebook::jni::make_jstring(s.c_str());
+ (*ret)[i] = backend_name;
+ i++;
+ }
+ return ret;
+ }
+ // Binds each Java native method name on Module to its C++ implementation.
+ static void registerNatives() {
+ registerHybrid({
+ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
+ makeNativeMethod("executeNative", ExecuTorchJni::execute),
+ makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
+ makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
+ makeNativeMethod(
+ "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
+ makeNativeMethod("etdump", ExecuTorchJni::etdump),
+ makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
+ makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
+ makeNativeMethod("nativeSetCpModule", ExecuTorchJni::setCpModule),
+ makeNativeMethod("nativeRunTtsPipeline", ExecuTorchJni::runTtsPipeline),
+ });
+ }
+
+ // ======= TTS Pipeline =======
+ static Module* cpModule_; // non-owning and shared by all instances; NOTE(review): read and written without synchronization
+ ~ExecuTorchJni() { if (cpModule_ == module_.get()) cpModule_ = nullptr; } // never leave cpModule_ pointing at a freed module
+ void setCpModule() {
+ cpModule_ = module_.get();
+ ET_LOG(Info, "TTS CP module set");
+ }
+
+ // Runs talker+CP loop. 'this' is the talker, cpModule_ is the CP.
+ facebook::jni::local_ref
+ runTtsPipeline(
+ facebook::jni::alias_ref jPrefill, jint nPrefill,
+ facebook::jni::alias_ref jTrailing, jint nTrailing,
+ facebook::jni::alias_ref jCodecEmb,
+ facebook::jni::alias_ref jCpEmbs,
+ facebook::jni::alias_ref jCpHeads,
+ facebook::jni::alias_ref jTCos,
+ facebook::jni::alias_ref jTSin,
+ facebook::jni::alias_ref jCCos,
+ facebook::jni::alias_ref jCSin,
+ facebook::jni::alias_ref jEos,
+ facebook::jni::alias_ref jPad,
+ jint maxTokens);
+ // Static implementation behind runTtsPipeline(); receives the talker and CP modules explicitly.
+ static facebook::jni::local_ref
+ runTtsPipelineImpl(
+ Module* talker, Module* cp,
+ facebook::jni::alias_ref jPrefill, jint nPrefill,
+ facebook::jni::alias_ref jTrailing, jint nTrailing,
+ facebook::jni::alias_ref jCodecEmb,
+ facebook::jni::alias_ref jCpEmbs,
+ facebook::jni::alias_ref jCpHeads,
+ facebook::jni::alias_ref jTCos,
+ facebook::jni::alias_ref jTSin,
+ facebook::jni::alias_ref jCCos,
+ facebook::jni::alias_ref jCSin,
+ facebook::jni::alias_ref jEos,
+ facebook::jni::alias_ref jPad,
+ jint maxTokens);
+};
+} // namespace executorch::extension
+
+// ======= TTS Pipeline statics and wrapper: storage for cpModule_ plus the JNI-facing runTtsPipeline() =======
+namespace executorch::extension { Module* ExecuTorchJni::cpModule_ = nullptr; } // unset until setCpModule() assigns it
+
+namespace executorch::extension {
+facebook::jni::local_ref
+ExecuTorchJni::runTtsPipeline(
+ facebook::jni::alias_ref jPrefill, jint nPrefill,
+ facebook::jni::alias_ref jTrailing, jint nTrailing,
+ facebook::jni::alias_ref jCodecEmb,
+ facebook::jni::alias_ref jCpEmbs,
+ facebook::jni::alias_ref jCpHeads,
+ facebook::jni::alias_ref jTCos,
+ facebook::jni::alias_ref jTSin,
+ facebook::jni::alias_ref jCCos,
+ facebook::jni::alias_ref jCSin,
+ facebook::jni::alias_ref jEos,
+ facebook::jni::alias_ref jPad,
+ jint maxTokens) {
+ if (!cpModule_) { ET_LOG(Error, "TTS CP module not set!"); return facebook::jni::JArrayInt::newArray(0); } // an empty array is the only signal that no CP module was registered
+ return runTtsPipelineImpl(module_.get(), cpModule_, // this instance is the talker; cpModule_ is the shared CP pointer
+ jPrefill, nPrefill, jTrailing, nTrailing,
+ jCodecEmb, jCpEmbs, jCpHeads, jTCos, jTSin, jCCos, jCSin, jEos, jPad, maxTokens);
+}
+} // namespace executorch::extension
+
+#include
+#include
+#include
+
+static inline float tts_dot_neon(const float* a, const float* b, int n) { // NOTE(review): the text between this header and the top-k code below is damaged (function boundaries lost); code left untouched
+ float32x4_t s0=vdupq_n_f32(0),s1=vdupq_n_f32(0),s2=vdupq_n_f32(0),s3=vdupq_n_f32(0); // four zero-initialised 4-lane float accumulators
+ int i=0;
+ for(;i+15 topk(k,{0,-FLT_MAX});
+ for(int i=0;itopk[k-1].v){topk[k-1]={i,logits[i]};
+ for(int j=k-2;j>=0;j--){if(topk[j+1].v>topk[j].v)std::swap(topk[j],topk[j+1]);else break;}}
+ }
+ float maxv=topk[0].v,sum=0; // topk is kept in descending order by the swap loop above, so [0] holds the largest logit
+ for(auto&t:topk){t.v=expf((t.v-maxv)/temp);sum+=t.v;} // temperature-scaled softmax numerators over the kept entries only
+ tts_rng=tts_rng*6364136223846793005ULL+1442695040888963407ULL; // advance the linear congruential generator state
+ float r=(float)((tts_rng>>33)&0x7FFFFFFF)/(float)0x7FFFFFFF*sum; // draw in [0, sum] from the upper state bits
+ float acc=0;
+ for(auto&t:topk){acc+=t.v;if(acc>=r)return t.i;} // first entry whose cumulative weight reaches r
+ return topk[0].i; // fallback when the cumulative sum never reaches r
+}
+
+namespace executorch::extension {
+facebook::jni::local_ref
+ExecuTorchJni::runTtsPipelineImpl(
+ Module* talker, Module* cp,
+ facebook::jni::alias_ref jPrefill, jint nPrefill,
+ facebook::jni::alias_ref jTrailing, jint nTrailing,
+ facebook::jni::alias_ref jCodecEmb,
+ facebook::jni::alias_ref jCpEmbs,
+ facebook::jni::alias_ref jCpHeads,
+ facebook::jni::alias_ref jTCos,
+ facebook::jni::alias_ref jTSin,
+ facebook::jni::alias_ref jCCos,
+ facebook::jni::alias_ref jCSin,
+ facebook::jni::alias_ref jEos,
+ facebook::jni::alias_ref jPad,
+ jint maxTokens)
+{
+ static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
+ static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
+ static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
+ static const int CODEC_EOS=2150;
+
+ // Copy JNI arrays
+ auto copyArr = [](facebook::jni::alias_ref j) {
+ std::vector v(j->size());
+ j->getRegion(0, j->size(), v.data());
+ return v;
+ };
+ auto prefill=copyArr(jPrefill);
+ auto trailing=nTrailing>0?copyArr(jTrailing):std::vector();
+ auto codecEmb=copyArr(jCodecEmb);
+ auto cpEmbs=copyArr(jCpEmbs);
+ auto cpHeads=copyArr(jCpHeads);
+ auto tCos=copyArr(jTCos),tSin=copyArr(jTSin);
+ auto cCos=copyArr(jCCos),cSin=copyArr(jCSin);
+ auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad);
+
+ int tkvElem=T_KV*T_KV_LEN*T_HD;
+ std::vector tK(T_L*tkvElem,0),tV(T_L*tkvElem,0);
+ float mask[T_KV_LEN]; for(int i=0;i allCodes,cb0Hist;
+ int pos=0,curCb0=-1,trIdx=0;
+
+ // Get raw Method pointers for direct execution
+ auto tMethodRes = talker->method("forward");
+ auto cMethodRes = cp->method("forward");
+ if (!tMethodRes.ok() || !cMethodRes.ok()) {
+ ET_LOG(Error, "TTS cannot get Method pointers");
+ return facebook::jni::JArrayInt::newArray(0);
+ }
+ Method* tMethod = tMethodRes.get();
+ Method* cMethod = cMethodRes.get();
+
+ // Talker step: prepare_input_tensors + memcpy + execute (like cp_et_runner)
+ auto talkerStep = [&](const float* emb) {
+ int pi=std::min(pos,249);
+ int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
+ if(mi>=0) mask[mi]=0.0f;
+
+ auto prep = executorch::extension::prepare_input_tensors(*tMethod);
+ if(!prep.ok()){ET_LOG(Error,"Talker prep fail");return;}
+ memcpy(tMethod->mutable_input(0).toTensor().mutable_data_ptr(), emb, DIM*4);
+ memcpy(tMethod->mutable_input(1).toTensor().mutable_data_ptr(), mask, T_KV_LEN*4);
+ memcpy(tMethod->mutable_input(2).toTensor().mutable_data_ptr(), tCos.data()+pi*T_HD, T_HD*4);
+ memcpy(tMethod->mutable_input(3).toTensor().mutable_data_ptr(), tSin.data()+pi*T_HD, T_HD*4);
+ for(int i=0;imutable_input(4+i*2).toTensor().mutable_data_ptr(), tK.data()+i*tkvElem, tkvElem*4);
+ memcpy(tMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr(), tV.data()+i*tkvElem, tkvElem*4);
+ }
+ auto status = tMethod->execute();
+ if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;}
+ memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr(), DIM*4);
+ memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr(), VOCAB*4);
+ for(int i=0;iget_output(2+i*2).toTensor().const_data_ptr(), tkvElem*4);
+ memcpy(tV.data()+i*tkvElem, tMethod->get_output(3+i*2).toTensor().const_data_ptr(), tkvElem*4);
+ }
+ pos++;
+ };
+
+ // CP step: 17 autoregressive steps using Method directly
+ auto cpStep = [&](const float* h, int cb0, int* codes) {
+ int ckvElem=C_KV*C_KV_LEN*C_HD;
+ std::vector ckv(C_L*2*ckvElem,0);
+ for(int step=0;step<17;step++){
+ const float*emb;
+ if(step==0) emb=h;
+ else if(step==1) emb=codecEmb.data()+std::min(std::max(cb0,0),VOCAB-1)*DIM;
+ else emb=cpEmbs.data()+((long)(step-2)*CB_SIZE+std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM;
+
+ auto prep=executorch::extension::prepare_input_tensors(*cMethod);
+ if(!prep.ok()) break;
+ memcpy(cMethod->mutable_input(0).toTensor().mutable_data_ptr(), emb, DIM*4);
+ float*mp=cMethod->mutable_input(1).toTensor().mutable_data_ptr();
+ for(int p=0;p=C_KV_LEN-1-step)?0.0f:-1e9f;
+ memcpy(cMethod->mutable_input(2).toTensor().mutable_data_ptr(), cCos.data()+step*C_HD, C_HD*4);
+ memcpy(cMethod->mutable_input(3).toTensor().mutable_data_ptr(), cSin.data()+step*C_HD, C_HD*4);
+ for(int i=0;imutable_input(4+i*2).toTensor().mutable_data_ptr(), ckv.data()+(i*2)*ckvElem, ckvElem*4);
+ memcpy(cMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr(), ckv.data()+(i*2+1)*ckvElem, ckvElem*4);
+ }
+ auto status=cMethod->execute();
+ if(status!=Error::Ok) break;
+ const float*ho=cMethod->get_output(0).toTensor().const_data_ptr();
+ if(step>=1&&step-1<15){
+ const float*W=cpHeads.data()+(long)(step-1)*CB_SIZE*DIM;
+ int best=0;float bv=-FLT_MAX;
+ for(int j=0;jbv){bv=d;best=j;}}
+ codes[step-1]=best;
+ }
+ for(int i=0;iget_output(1+i*2).toTensor().const_data_ptr(),ckvElem*4);
+ memcpy(ckv.data()+(i*2+1)*ckvElem,cMethod->get_output(2+i*2).toTensor().const_data_ptr(),ckvElem*4);
+ }
+ }
+ };
+
+ auto t0=std::chrono::high_resolution_clock::now();
+ // Prefill
+ for(int s=0;s(tP-t0).count(),curCb0);
+
+ if(curCb0<0||curCb0==CODEC_EOS) return facebook::jni::JArrayInt::newArray(0);
+
+ // Generation
+ float totalTalker=0,totalCp=0;
+ for(int g=0;g(tc1-tc0).count();
+ for(int i=0;i<15;i++) codes[i+1]=cpCodes[i];
+ for(int i=0;i(tt1-tt0).count();
+
+ for(int j=CB_SIZE;j seen(cb0Hist.begin(),cb0Hist.end());
+ for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
+ int next=tts_sample_topk(logits,VOCAB,0.9f,50);
+ if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d",g+2);break;}
+ if((int)cb0Hist.size()>=9){
+ bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;}
+ if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;}
+ }
+ curCb0=next;
+ }
+ int nTok=(int)allCodes.size()/NUM_CB;
+ auto t1=std::chrono::high_resolution_clock::now();
+ ET_LOG(Info,"TTS Generated %d | Talker %.0fms (%.0f/step) | CP %.0fms (%.0f/step) | Total %.0fms",
+ nTok,totalTalker,totalTalker/std::max(nTok,1),totalCp,totalCp/std::max(nTok,1),
+ std::chrono::duration(t1-t0).count());
+
+ auto result=facebook::jni::JArrayInt::newArray((int)allCodes.size());
+ result->setRegion(0,(int)allCodes.size(),allCodes.data());
+ return result;
+}
+} // namespace executorch::extension
+
+#ifdef EXECUTORCH_BUILD_LLAMA_JNI
+extern void register_natives_for_llm(); // declaration only; the definition is not in this file
+#else
+// No-op stub so JNI_OnLoad can call register_natives_for_llm() unconditionally when EXECUTORCH_BUILD_LLAMA_JNI is not defined.
+void register_natives_for_llm() {}
+#endif
+extern void register_natives_for_runtime(); // declared unconditionally; no stub is provided for it here
+
+#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
+extern void register_natives_for_training(); // declaration only; the definition is not in this file
+#else
+// No-op stub so JNI_OnLoad can call register_natives_for_training() unconditionally when EXECUTORCH_BUILD_EXTENSION_TRAINING is not defined.
+void register_natives_for_training() {}
+#endif
+
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { // JNI library-load entry point; the second (reserved) argument is unused
+ return facebook::jni::initialize(vm, [] { // returns whatever initialize() returns; the lambda performs all native registration
+ executorch::extension::ExecuTorchJni::registerNatives();
+ register_natives_for_llm(); // empty stub unless EXECUTORCH_BUILD_LLAMA_JNI is defined
+ register_natives_for_runtime();
+ register_natives_for_training(); // empty stub unless EXECUTORCH_BUILD_EXTENSION_TRAINING is defined
+ });
+}
diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
index 0d6bc7e..315bc82 100644
--- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
+++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt
@@ -229,29 +229,11 @@ class Qwen3TtsEngine(
// Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths)
android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true)
- // Try native C++ pipeline first (single QNN instance, no Java overhead)
+ // Load .pte modules via Java JNI (native pipeline uses same instances)
run {
val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
- val talkerPte = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte")
- if (etModel.exists() && talkerPte.exists()) {
+ if (etModel.exists() && cpPteModule == null) {
try {
- val tn = System.currentTimeMillis()
- nativePipelineReady = TtsPipeline.nativeInit(
- talkerPte.absolutePath, etModel.absolutePath
- )
- nlog("Native C++ pipeline: ${if (nativePipelineReady) "OK" else "FAILED"} (${System.currentTimeMillis() - tn}ms)")
- } catch (e: Exception) {
- nlog("Native pipeline init failed: ${e.message}")
- }
- }
- }
-
- // Fallback: Load Java .pte modules (only if native pipeline failed)
- run {
- val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
- if (!nativePipelineReady && etModel.exists() && cpPteModule == null) {
- try {
- nlog("Loading Java .pte modules (native unavailable)...")
val t0 = System.currentTimeMillis()
cpPteModule = org.pytorch.executorch.Module.load(
etModel.absolutePath,
@@ -265,6 +247,10 @@ class Qwen3TtsEngine(
if (lmResult != 0) {
nlog("CP .pte loadMethod failed ($lmResult), disabling JNI")
cpPteModule = null
+ } else {
+ // Register CP module for native pipeline
+ cpPteModule!!.nativeSetCpModule()
+ nlog("CP module registered for native pipeline")
}
} catch (e: Exception) {
nlog("CP .pte JNI failed: ${e.message}")
@@ -2298,11 +2284,9 @@ class Qwen3TtsEngine(
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
val allCodes: Array
- // Native C++ disabled: QNN HTP compilation not deterministic between loads
- // Two instances of same .pte give slightly different hidden states → trembling
- // Keep Java pipeline (same QNN instance, validated quality)
- if (false && nativePipelineReady) {
- // Native C++ pipeline — zero Java overhead
+ // Native C++ pipeline using SAME Java Module instances (no quality loss)
+ if (talkerPteModule != null && cpPteModule != null) {
+ // C++ loop on Java's Module instances — same QNN compilation, no JNI overhead
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
val nTrailing = nTotal - nPrefill
@@ -2336,8 +2320,8 @@ class Qwen3TtsEngine(
ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
}
- nlog("Running native C++ pipeline...")
- val flat = TtsPipeline.nativeRun(
+ nlog("Running native C++ pipeline (shared Module)...")
+ val flat = talkerPteModule!!.nativeRunTtsPipeline(
prefillFlat, nPrefill,
trailingFlat, nTrailing,
codecEmbedding ?: FloatArray(0),