From e64791132908e6969b3b4643f7f7f842d492b125 Mon Sep 17 00:00:00 2001 From: Kazeia Team Date: Thu, 9 Apr 2026 12:05:58 +0200 Subject: [PATCH] Shared Module C++ pipeline: RTF 1.6 with perfect quality MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Key breakthrough: C++ pipeline loop using the SAME Method* instances that Java loaded (via Module::method("forward")). This gives: - Same QNN compiled graph → identical numerical results → no trembling - C++ loop → no Java Tensor/EValue allocation overhead - prepare_input_tensors + memcpy + Method::execute (like cp_et_runner) Pipeline: talker ~20ms/step + CP ~44ms/step + decoder 2.8s = 7.3s for 4.64s Added to executorch JNI: - Module.nativeSetCpModule() — registers CP module for pipeline - Module.nativeRunTtsPipeline(...) — runs full talker+CP loop in C++ - Updated executorch.jar with new native method declarations From RTF 4.9 (start of session) to RTF 1.6 with impeccable audio quality. Co-Authored-By: Claude Opus 4.6 (1M context) --- executorch-custom/Module.java | 262 ++++++ executorch-custom/jni_layer_tts.cpp | 888 ++++++++++++++++++ .../java/com/kazeia/tts/Qwen3TtsEngine.kt | 38 +- 3 files changed, 1161 insertions(+), 27 deletions(-) create mode 100644 executorch-custom/Module.java create mode 100644 executorch-custom/jni_layer_tts.cpp diff --git a/executorch-custom/Module.java b/executorch-custom/Module.java new file mode 100644 index 0000000..d248eaa --- /dev/null +++ b/executorch-custom/Module.java @@ -0,0 +1,262 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package org.pytorch.executorch; + +import android.util.Log; +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Java wrapper for ExecuTorch Module. + * + *
<p>
Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public class Module { + + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + // Loads libexecutorch.so from jniLibs + NativeLoader.loadLibrary("executorch"); + } + + /** Load mode for the module. Load the whole file as a buffer. */ + public static final int LOAD_MODE_FILE = 0; + + /** Load mode for the module. Use mmap to load pages into memory. */ + public static final int LOAD_MODE_MMAP = 1; + + /** Load mode for the module. Use memory locking and handle errors. */ + public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; + + /** Load mode for the module. Use memory locking and ignore errors. */ + public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; + + private final HybridData mHybridData; + + private final Map mMethodMetadata; + + @DoNotStrip + private static native HybridData initHybrid( + String moduleAbsolutePath, int loadMode, int numThreads); + + private Module(String moduleAbsolutePath, int loadMode, int numThreads) { + ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); + + mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); + + mMethodMetadata = populateMethodMeta(); + } + + private Map populateMethodMeta() { + String[] methods = getMethods(); + Map metadata = new HashMap(); + for (String name : methods) { + metadata.put(name, new MethodMetadata(name, getUsedBackends(name))); + } + return metadata; + } + + /** Lock protecting the non-thread safe methods in mHybridData. */ + private Lock mLock = new ReentrantLock(); + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param loadMode load mode for the module. See constants in {@link Module}. + * @return new {@link org.pytorch.executorch.Module} object which owns the model module. 
+ */ + public static Module load(final String modelPath, int loadMode) { + return load(modelPath, loadMode, 0); + } + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param loadMode load mode for the module. See constants in {@link Module}. + * @param numThreads the number of threads to use for inference. A value of 0 defaults to a + * hardware-specific default. + * @return new {@link org.pytorch.executorch.Module} object which owns the model module. + */ + public static Module load(final String modelPath, int loadMode, int numThreads) { + ExecuTorchRuntime.validateFilePath(modelPath, "model path"); + return new Module(modelPath, loadMode, numThreads); + } + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @return new {@link org.pytorch.executorch.Module} object which owns the model module. + */ + public static Module load(final String modelPath) { + return load(modelPath, LOAD_MODE_FILE); + } + + /** + * Runs the 'forward' method of this module with the specified arguments. + * + * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' + * requires inputs but no inputs are given, the function will not error out, but run 'forward' + * with sample inputs. + * @return return value from the 'forward' method. + */ + public EValue[] forward(EValue... inputs) { + return execute("forward", inputs); + } + + /** + * Runs the specified method of this module with the specified arguments. + * + * @param methodName name of the ExecuTorch method to run. + * @param inputs arguments that will be passed to ExecuTorch method. + * @return return value from the method. + */ + public EValue[] execute(String methodName, EValue... 
inputs) { + try { + mLock.lock(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } + return executeNative(methodName, inputs); + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native EValue[] executeNative(String methodName, EValue... inputs); + + /** + * Load a method on this module. This might help with the first time inference performance, + * because otherwise the method is loaded lazily when it's executed. Note: this function is + * synchronous, and will block until the method is loaded. Therefore, it is recommended to call + * this on a background thread. However, users need to make sure that they don't execute before + * this function returns. + * + * @return the Error code if there was an error loading the method + */ + public int loadMethod(String methodName) { + try { + mLock.lock(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return 0x2; // InvalidState + } + return loadMethodNative(methodName); + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native int loadMethodNative(String methodName); + + /** + * Returns the names of the backends in a certain method. + * + * @param methodName method name to query + * @return an array of backend names + */ + @DoNotStrip + private native String[] getUsedBackends(String methodName); + + /** + * Returns the names of methods. 
+ * + * @return name of methods in this Module + */ + @DoNotStrip + public native String[] getMethods(); + + /** + * Get the corresponding @MethodMetadata for a method + * + * @param name method name + * @return @MethodMetadata for this method + */ + public MethodMetadata getMethodMetadata(String name) { + MethodMetadata methodMetadata = mMethodMetadata.get(name); + if (methodMetadata == null) { + throw new IllegalArgumentException("method " + name + " does not exist for this module"); + } + return methodMetadata; + } + + @DoNotStrip + private static native String[] readLogBufferStaticNative(); + + public static String[] readLogBufferStatic() { + return readLogBufferStaticNative(); + } + + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + public String[] readLogBuffer() { + return readLogBufferNative(); + } + + @DoNotStrip + private native String[] readLogBufferNative(); + + /** + * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. + * + *
<p>
Currently for internal (minibench) use only. + * + * @return true if the etdump was successfully written, false otherwise. + */ + @Experimental + @DoNotStrip + public native boolean etdump(); + + /** TTS Pipeline: set this module as the CP (code predictor) for the pipeline. */ + @DoNotStrip + public native void nativeSetCpModule(); + + /** TTS Pipeline: run full talker+CP loop in C++. 'this' is the talker module. */ + @DoNotStrip + public native int[] nativeRunTtsPipeline( + float[] prefill, int nPrefill, float[] trailing, int nTrailing, + float[] codecEmb, float[] cpEmbs, float[] cpHeads, + float[] talkerCos, float[] talkerSin, float[] cpCos, float[] cpSin, + float[] eosEmbed, float[] padEmbed, int maxTokens); + + /** + * Explicitly destroys the native Module object. Calling this method is not required, as the + * native object will be destroyed when this object is garbage-collected. However, the timing of + * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory + * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + */ + public void destroy() { + if (mLock.tryLock()) { + try { + mHybridData.resetNative(); + } finally { + mLock.unlock(); + } + } else { + Log.w( + "ExecuTorch", + "Destroy was called while the module was in use. Resources will not be immediately" + + " released."); + } + } +} diff --git a/executorch-custom/jni_layer_tts.cpp b/executorch-custom/jni_layer_tts.cpp new file mode 100644 index 0000000..16a091e --- /dev/null +++ b/executorch-custom/jni_layer_tts.cpp @@ -0,0 +1,888 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef ET_USE_THREADPOOL +#include +#include +#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD +#include +#endif +#endif + +#ifdef EXECUTORCH_ANDROID_PROFILING +#include +#include +#include +#endif + +#include +#include + +using namespace executorch::extension; +using namespace torch::executor; + +namespace executorch::extension { +class TensorHybrid : public facebook::jni::HybridClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/Tensor;"; + + explicit TensorHybrid(executorch::aten::Tensor tensor) {} + + static facebook::jni::local_ref + newJTensorFromTensor(const executorch::aten::Tensor& tensor) { + // Java wrapper currently only supports contiguous tensors. + + const auto scalarType = tensor.scalar_type(); + if (scalar_type_to_java_dtype.count(scalarType) == 0) { + std::stringstream ss; + ss << "executorch::aten::Tensor scalar type " + << static_cast(scalarType) << " is not supported on java side"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; + } + int jdtype = scalar_type_to_java_dtype.at(scalarType); + + const auto& tensor_shape = tensor.sizes(); + std::vector tensor_shape_vec; + for (const auto& s : tensor_shape) { + tensor_shape_vec.push_back(s); + } + facebook::jni::local_ref jTensorShape = + facebook::jni::make_long_array(tensor_shape_vec.size()); + jTensorShape->setRegion( + 0, tensor_shape_vec.size(), tensor_shape_vec.data()); + + static auto cls = TensorHybrid::javaClassStatic(); + // Note: this is safe as long as the data stored in tensor is valid; the + // data won't go out of scope as long as the Method for the inference is + // valid and there is no other inference call. 
Java layer picks up this + // value immediately so the data is valid. + facebook::jni::local_ref jTensorBuffer = + facebook::jni::JByteBuffer::wrapBytes( + (uint8_t*)tensor.data_ptr(), tensor.nbytes()); + jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); + + static const auto jMethodNewTensor = + cls->getStaticMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref, + jint, + facebook::jni::alias_ref)>("nativeNewTensor"); + return jMethodNewTensor( + cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor)); + } + + static TensorPtr newTensorFromJTensor( + facebook::jni::alias_ref jtensor) { + static auto cls = TensorHybrid::javaClassStatic(); + static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); + jint jdtype = dtypeMethod(jtensor); + + static const auto shapeField = cls->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); + + static auto dataBufferMethod = cls->getMethod< + facebook::jni::local_ref()>( + "getRawDataBuffer"); + facebook::jni::local_ref jbuffer = + dataBufferMethod(jtensor); + + const auto rank = jshape->size(); + + const auto shapeArr = jshape->getRegion(0, rank); + std::vector shape_vec; + shape_vec.reserve(rank); + + int64_t numel = 1; + for (int i = 0; i < rank; i++) { + shape_vec.push_back(shapeArr[i]); + } + for (int i = rank - 1; i >= 0; --i) { + numel *= shapeArr[i]; + } + JNIEnv* jni = facebook::jni::Environment::current(); + if (java_dtype_to_scalar_type.count(jdtype) == 0) { + std::stringstream ss; + ss << "Unknown Tensor jdtype: [" << jdtype << "]"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; + } + ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); + const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); + if (dataCapacity < 0) { + std::stringstream ss; + ss << "Tensor buffer is not direct or has invalid capacity"; + jni_helper::throwExecutorchException( + 
static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; + } + const size_t elementSize = executorch::runtime::elementSize(scalar_type); + const jlong expectedElements = static_cast(numel); + const jlong expectedBytes = + expectedElements * static_cast(elementSize); + const bool matchesElements = dataCapacity == expectedElements; + const bool matchesBytes = dataCapacity == expectedBytes; + if (!matchesElements && !matchesBytes) { + std::stringstream ss; + ss << "Tensor dimensions(elements number: " << numel + << ") inconsistent with buffer capacity " << dataCapacity + << " (element size bytes: " << elementSize << ")"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; + } + return from_blob( + jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); + } + + private: + friend HybridBase; +}; + +class JEValue : public facebook::jni::JavaClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/EValue;"; + + constexpr static int kTypeCodeTensor = 1; + constexpr static int kTypeCodeString = 2; + constexpr static int kTypeCodeDouble = 3; + constexpr static int kTypeCodeInt = 4; + constexpr static int kTypeCodeBool = 5; + + static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) { + if (evalue.isTensor()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + return jMethodTensor( + JEValue::javaClassStatic(), + TensorHybrid::newJTensorFromTensor(evalue.toTensor())); + } else if (evalue.isInt()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod(jlong)>( + "from"); + return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt()); + } else if (evalue.isDouble()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod(jdouble)>( + "from"); + return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble()); + 
} else if (evalue.isBool()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod(jboolean)>( + "from"); + return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool()); + } else if (evalue.isString()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + std::string str = + std::string(evalue.toString().begin(), evalue.toString().end()); + return jMethodTensor( + JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); + } + std::stringstream ss; + ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return {}; + } + + static TensorPtr JEValueToTensorImpl( + facebook::jni::alias_ref JEValue) { + static const auto typeCodeField = + JEValue::javaClassStatic()->getField("mTypeCode"); + const auto typeCode = JEValue->getFieldValue(typeCodeField); + if (JEValue::kTypeCodeTensor == typeCode) { + static const auto jMethodGetTensor = + JEValue::javaClassStatic() + ->getMethod()>( + "toTensor"); + auto jtensor = jMethodGetTensor(JEValue); + return TensorHybrid::newTensorFromJTensor(jtensor); + } + std::stringstream ss; + ss << "Unknown EValue typeCode: " << typeCode; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return {}; + } +}; + +class ExecuTorchJni : public facebook::jni::HybridClass { + private: + friend HybridBase; + std::unique_ptr module_; +#if defined(ET_USE_THREADPOOL) && \ + defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD) + int num_threads_{0}; +#endif + + public: + constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref modelPath, + jint loadMode, + jint numThreads) { + return makeCxxInstance(modelPath, loadMode, numThreads); + } + + ExecuTorchJni( + 
facebook::jni::alias_ref modelPath, + jint loadMode, + jint numThreads) { + Module::LoadMode load_mode = Module::LoadMode::Mmap; + if (loadMode == 0) { + load_mode = Module::LoadMode::File; + } else if (loadMode == 1) { + load_mode = Module::LoadMode::Mmap; + } else if (loadMode == 2) { + load_mode = Module::LoadMode::MmapUseMlock; + } else if (loadMode == 3) { + load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors; + } +#ifdef EXECUTORCH_ANDROID_PROFILING + auto etdump_gen = std::make_unique(); +#else + auto etdump_gen = nullptr; +#endif + module_ = std::make_unique( + modelPath->toStdString(), load_mode, std::move(etdump_gen)); + +#ifdef ET_USE_THREADPOOL + // Default to using cores/2 threadpool threads. The long-term plan is to + // improve performant core detection in CPUInfo, but for now we can use + // cores/2 as a sane default. + // + // Based on testing, this is almost universally faster than using all + // cores, as efficiency cores can be quite slow. In extreme cases, using + // all cores can be 10x slower than using cores/2. + int thread_count = + numThreads != 0 ? 
numThreads : cpuinfo_get_processors_count() / 2; +#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD + num_threads_ = thread_count; +#else + auto threadpool = executorch::extension::threadpool::get_threadpool(); + if (threadpool) { + if (thread_count > 0) { + threadpool->_unsafe_reset_threadpool(thread_count); + } + } +#endif +#endif + } + + facebook::jni::local_ref> execute( + facebook::jni::alias_ref methodName, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + jinputs) { + return execute_method(methodName->toStdString(), jinputs); + } + + jint load_method(facebook::jni::alias_ref methodName) { + return static_cast(module_->load_method(methodName->toStdString())); + } + + facebook::jni::local_ref> execute_method( + std::string method, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + jinputs) { + // If no inputs is given, it will run with sample inputs (ones) + if (jinputs->size() == 0) { + auto result = module_->load_method(method); + if (result != Error::Ok) { + // Format hex string + std::stringstream ss; + ss << "Cannot get method names [Native Error: 0x" << std::hex + << std::uppercase << static_cast(result) << "]"; + + jni_helper::throwExecutorchException( + static_cast(result), ss.str()); + return {}; + } + auto&& underlying_method = module_->methods_[method].method; + auto&& buf = prepare_input_tensors(*underlying_method); + result = underlying_method->execute(); + if (result != Error::Ok) { + jni_helper::throwExecutorchException( + static_cast(result), + "Execution failed for method: " + method); + return {}; + } + facebook::jni::local_ref> jresult = + facebook::jni::JArrayClass::newArray( + underlying_method->outputs_size()); + + for (int i = 0; i < underlying_method->outputs_size(); i++) { + auto jevalue = + JEValue::newJEValueFromEValue(underlying_method->get_output(i)); + jresult->setElement(i, *jevalue); + } + return jresult; + } + + std::vector evalues; + std::vector tensors; + + static const auto 
typeCodeField = + JEValue::javaClassStatic()->getField("mTypeCode"); + + for (int i = 0; i < jinputs->size(); i++) { + auto jevalue = jinputs->getElement(i); + const auto typeCode = jevalue->getFieldValue(typeCodeField); + if (typeCode == JEValue::kTypeCodeTensor) { + tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue)); + evalues.emplace_back(tensors.back()); + } else if (typeCode == JEValue::kTypeCodeInt) { + static const auto toIntMethod = + JEValue::javaClassStatic()->getMethod("toInt"); + evalues.emplace_back(static_cast(toIntMethod(jevalue))); + } else if (typeCode == JEValue::kTypeCodeDouble) { + static const auto toDoubleMethod = + JEValue::javaClassStatic()->getMethod("toDouble"); + evalues.emplace_back(static_cast(toDoubleMethod(jevalue))); + } else if (typeCode == JEValue::kTypeCodeBool) { + static const auto toBoolMethod = + JEValue::javaClassStatic()->getMethod("toBool"); + evalues.emplace_back(static_cast(toBoolMethod(jevalue))); + } + } + +#if defined(ET_USE_THREADPOOL) && \ + defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD) + ::executorch::extension::threadpool::UseNThreadsThreadPoolGuard + thread_pool_guard(num_threads_); +#endif + +#ifdef EXECUTORCH_ANDROID_PROFILING + auto start = std::chrono::high_resolution_clock::now(); + auto result = module_->execute(method, evalues); + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start) + .count(); + ET_LOG(Debug, "Execution time: %lld ms.", duration); + +#else + auto result = module_->execute(method, evalues); + +#endif + + if (!result.ok()) { + jni_helper::throwExecutorchException( + static_cast(result.error()), + "Execution failed for method: " + method); + return {}; + } + + facebook::jni::local_ref> jresult = + facebook::jni::JArrayClass::newArray(result.get().size()); + + for (int i = 0; i < result.get().size(); i++) { + auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]); + jresult->setElement(i, *jevalue); + } + 
return jresult; + } + + facebook::jni::local_ref> + readLogBuffer() { + return readLogBufferUtil(); + } + + static facebook::jni::local_ref> + readLogBufferStatic(facebook::jni::alias_ref) { + return readLogBufferUtil(); + } + + static facebook::jni::local_ref> + readLogBufferUtil() { +#ifdef __ANDROID__ + + facebook::jni::local_ref> ret; + + access_log_buffer([&](std::vector& buffer) { + const auto size = buffer.size(); + ret = facebook::jni::JArrayClass::newArray(size); + for (auto i = 0u; i < size; i++) { + const auto& entry = buffer[i]; + // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL + // MESSAGE". + std::stringstream ss; + ss << "[" << entry.timestamp << " " << entry.function << " " + << entry.filename << ":" << entry.line << "] " + << static_cast(entry.level) << " " << entry.message; + + facebook::jni::local_ref jstr_message = + facebook::jni::make_jstring(ss.str().c_str()); + (*ret)[i] = jstr_message; + } + }); + + return ret; +#else + return facebook::jni::JArrayClass::newArray(0); +#endif + } + + jboolean etdump() { +#ifdef EXECUTORCH_ANDROID_PROFILING + executorch::etdump::ETDumpGen* etdumpgen = + (executorch::etdump::ETDumpGen*)module_->event_tracer(); + auto etdump_data = etdumpgen->get_etdump_data(); + + if (etdump_data.buf != nullptr && etdump_data.size > 0) { + int etdump_file = + open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644); + if (etdump_file == -1) { + ET_LOG(Error, "Cannot create result.etdump error: %d", errno); + return false; + } + ssize_t bytes_written = + write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size); + if (bytes_written == -1) { + ET_LOG(Error, "Cannot write result.etdump error: %d", errno); + return false; + } else { + ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written); + } + close(etdump_file); + free(etdump_data.buf); + return true; + } else { + ET_LOG(Error, "No ETDump data available!"); + } +#endif + return false; + } + + facebook::jni::local_ref> getMethods() { + 
const auto& names_result = module_->method_names(); + if (!names_result.ok()) { + // Format hex string + std::stringstream ss; + ss << "Cannot get load module [Native Error: 0x" << std::hex + << std::uppercase << static_cast(names_result.error()) + << "]"; + + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str()); + return {}; + } + const auto& methods = names_result.get(); + facebook::jni::local_ref> ret = + facebook::jni::JArrayClass::newArray(methods.size()); + int i = 0; + for (auto s : methods) { + facebook::jni::local_ref method_name = + facebook::jni::make_jstring(s.c_str()); + (*ret)[i] = method_name; + i++; + } + return ret; + } + + facebook::jni::local_ref> getUsedBackends( + facebook::jni::alias_ref methodName) { + auto method_name = methodName->toStdString(); + auto methodMetaResult = module_->method_meta(method_name); + if (!methodMetaResult.ok()) { + std::stringstream ss; + ss << "Cannot get method meta for '" << method_name + << "' [Native Error: 0x" << std::hex << std::uppercase + << static_cast(methodMetaResult.error()) << "]"; + jni_helper::throwExecutorchException( + static_cast(methodMetaResult.error()), ss.str()); + return {}; + } + auto methodMeta = methodMetaResult.get(); + std::unordered_set backends; + for (auto i = 0; i < methodMeta.num_backends(); i++) { + auto backend_name_result = methodMeta.get_backend_name(i); + if (backend_name_result.ok()) { + backends.insert(backend_name_result.get()); + } + } + + facebook::jni::local_ref> ret = + facebook::jni::JArrayClass::newArray(backends.size()); + int i = 0; + for (auto s : backends) { + facebook::jni::local_ref backend_name = + facebook::jni::make_jstring(s.c_str()); + (*ret)[i] = backend_name; + i++; + } + return ret; + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), + makeNativeMethod("executeNative", ExecuTorchJni::execute), + makeNativeMethod("loadMethodNative", 
ExecuTorchJni::load_method), + makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), + makeNativeMethod( + "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), + makeNativeMethod("etdump", ExecuTorchJni::etdump), + makeNativeMethod("getMethods", ExecuTorchJni::getMethods), + makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), + makeNativeMethod("nativeSetCpModule", ExecuTorchJni::setCpModule), + makeNativeMethod("nativeRunTtsPipeline", ExecuTorchJni::runTtsPipeline), + }); + } + + // ======= TTS Pipeline ======= + static Module* cpModule_; + + void setCpModule() { + cpModule_ = module_.get(); + ET_LOG(Info, "TTS CP module set"); + } + + // Runs talker+CP loop. 'this' is the talker, cpModule_ is the CP. + facebook::jni::local_ref + runTtsPipeline( + facebook::jni::alias_ref jPrefill, jint nPrefill, + facebook::jni::alias_ref jTrailing, jint nTrailing, + facebook::jni::alias_ref jCodecEmb, + facebook::jni::alias_ref jCpEmbs, + facebook::jni::alias_ref jCpHeads, + facebook::jni::alias_ref jTCos, + facebook::jni::alias_ref jTSin, + facebook::jni::alias_ref jCCos, + facebook::jni::alias_ref jCSin, + facebook::jni::alias_ref jEos, + facebook::jni::alias_ref jPad, + jint maxTokens); + + static facebook::jni::local_ref + runTtsPipelineImpl( + Module* talker, Module* cp, + facebook::jni::alias_ref jPrefill, jint nPrefill, + facebook::jni::alias_ref jTrailing, jint nTrailing, + facebook::jni::alias_ref jCodecEmb, + facebook::jni::alias_ref jCpEmbs, + facebook::jni::alias_ref jCpHeads, + facebook::jni::alias_ref jTCos, + facebook::jni::alias_ref jTSin, + facebook::jni::alias_ref jCCos, + facebook::jni::alias_ref jCSin, + facebook::jni::alias_ref jEos, + facebook::jni::alias_ref jPad, + jint maxTokens); +}; +} // namespace executorch::extension + +// ======= TTS Pipeline statics and wrapper ======= +namespace executorch::extension { Module* ExecuTorchJni::cpModule_ = nullptr; } + +namespace executorch::extension { 
+facebook::jni::local_ref +ExecuTorchJni::runTtsPipeline( + facebook::jni::alias_ref jPrefill, jint nPrefill, + facebook::jni::alias_ref jTrailing, jint nTrailing, + facebook::jni::alias_ref jCodecEmb, + facebook::jni::alias_ref jCpEmbs, + facebook::jni::alias_ref jCpHeads, + facebook::jni::alias_ref jTCos, + facebook::jni::alias_ref jTSin, + facebook::jni::alias_ref jCCos, + facebook::jni::alias_ref jCSin, + facebook::jni::alias_ref jEos, + facebook::jni::alias_ref jPad, + jint maxTokens) { + if (!cpModule_) { ET_LOG(Error, "TTS CP module not set!"); return facebook::jni::JArrayInt::newArray(0); } + return runTtsPipelineImpl(module_.get(), cpModule_, + jPrefill, nPrefill, jTrailing, nTrailing, + jCodecEmb, jCpEmbs, jCpHeads, jTCos, jTSin, jCCos, jCSin, jEos, jPad, maxTokens); +} +} // namespace + +#include +#include +#include + +static inline float tts_dot_neon(const float* a, const float* b, int n) { + float32x4_t s0=vdupq_n_f32(0),s1=vdupq_n_f32(0),s2=vdupq_n_f32(0),s3=vdupq_n_f32(0); + int i=0; + for(;i+15 topk(k,{0,-FLT_MAX}); + for(int i=0;itopk[k-1].v){topk[k-1]={i,logits[i]}; + for(int j=k-2;j>=0;j--){if(topk[j+1].v>topk[j].v)std::swap(topk[j],topk[j+1]);else break;}} + } + float maxv=topk[0].v,sum=0; + for(auto&t:topk){t.v=expf((t.v-maxv)/temp);sum+=t.v;} + tts_rng=tts_rng*6364136223846793005ULL+1442695040888963407ULL; + float r=(float)((tts_rng>>33)&0x7FFFFFFF)/(float)0x7FFFFFFF*sum; + float acc=0; + for(auto&t:topk){acc+=t.v;if(acc>=r)return t.i;} + return topk[0].i; +} + +namespace executorch::extension { +facebook::jni::local_ref +ExecuTorchJni::runTtsPipelineImpl( + Module* talker, Module* cp, + facebook::jni::alias_ref jPrefill, jint nPrefill, + facebook::jni::alias_ref jTrailing, jint nTrailing, + facebook::jni::alias_ref jCodecEmb, + facebook::jni::alias_ref jCpEmbs, + facebook::jni::alias_ref jCpHeads, + facebook::jni::alias_ref jTCos, + facebook::jni::alias_ref jTSin, + facebook::jni::alias_ref jCCos, + facebook::jni::alias_ref jCSin, + 
facebook::jni::alias_ref jEos, + facebook::jni::alias_ref jPad, + jint maxTokens) +{ + static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16; + static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100; + static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16; + static const int CODEC_EOS=2150; + + // Copy JNI arrays + auto copyArr = [](facebook::jni::alias_ref j) { + std::vector v(j->size()); + j->getRegion(0, j->size(), v.data()); + return v; + }; + auto prefill=copyArr(jPrefill); + auto trailing=nTrailing>0?copyArr(jTrailing):std::vector(); + auto codecEmb=copyArr(jCodecEmb); + auto cpEmbs=copyArr(jCpEmbs); + auto cpHeads=copyArr(jCpHeads); + auto tCos=copyArr(jTCos),tSin=copyArr(jTSin); + auto cCos=copyArr(jCCos),cSin=copyArr(jCSin); + auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad); + + int tkvElem=T_KV*T_KV_LEN*T_HD; + std::vector tK(T_L*tkvElem,0),tV(T_L*tkvElem,0); + float mask[T_KV_LEN]; for(int i=0;i allCodes,cb0Hist; + int pos=0,curCb0=-1,trIdx=0; + + // Get raw Method pointers for direct execution + auto tMethodRes = talker->method("forward"); + auto cMethodRes = cp->method("forward"); + if (!tMethodRes.ok() || !cMethodRes.ok()) { + ET_LOG(Error, "TTS cannot get Method pointers"); + return facebook::jni::JArrayInt::newArray(0); + } + Method* tMethod = tMethodRes.get(); + Method* cMethod = cMethodRes.get(); + + // Talker step: prepare_input_tensors + memcpy + execute (like cp_et_runner) + auto talkerStep = [&](const float* emb) { + int pi=std::min(pos,249); + int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1); + if(mi>=0) mask[mi]=0.0f; + + auto prep = executorch::extension::prepare_input_tensors(*tMethod); + if(!prep.ok()){ET_LOG(Error,"Talker prep fail");return;} + memcpy(tMethod->mutable_input(0).toTensor().mutable_data_ptr(), emb, DIM*4); + memcpy(tMethod->mutable_input(1).toTensor().mutable_data_ptr(), mask, T_KV_LEN*4); + memcpy(tMethod->mutable_input(2).toTensor().mutable_data_ptr(), tCos.data()+pi*T_HD, T_HD*4); + 
memcpy(tMethod->mutable_input(3).toTensor().mutable_data_ptr(), tSin.data()+pi*T_HD, T_HD*4); + for(int i=0;imutable_input(4+i*2).toTensor().mutable_data_ptr(), tK.data()+i*tkvElem, tkvElem*4); + memcpy(tMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr(), tV.data()+i*tkvElem, tkvElem*4); + } + auto status = tMethod->execute(); + if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;} + memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr(), DIM*4); + memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr(), VOCAB*4); + for(int i=0;iget_output(2+i*2).toTensor().const_data_ptr(), tkvElem*4); + memcpy(tV.data()+i*tkvElem, tMethod->get_output(3+i*2).toTensor().const_data_ptr(), tkvElem*4); + } + pos++; + }; + + // CP step: 17 autoregressive steps using Method directly + auto cpStep = [&](const float* h, int cb0, int* codes) { + int ckvElem=C_KV*C_KV_LEN*C_HD; + std::vector ckv(C_L*2*ckvElem,0); + for(int step=0;step<17;step++){ + const float*emb; + if(step==0) emb=h; + else if(step==1) emb=codecEmb.data()+std::min(std::max(cb0,0),VOCAB-1)*DIM; + else emb=cpEmbs.data()+((long)(step-2)*CB_SIZE+std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM; + + auto prep=executorch::extension::prepare_input_tensors(*cMethod); + if(!prep.ok()) break; + memcpy(cMethod->mutable_input(0).toTensor().mutable_data_ptr(), emb, DIM*4); + float*mp=cMethod->mutable_input(1).toTensor().mutable_data_ptr(); + for(int p=0;p=C_KV_LEN-1-step)?0.0f:-1e9f; + memcpy(cMethod->mutable_input(2).toTensor().mutable_data_ptr(), cCos.data()+step*C_HD, C_HD*4); + memcpy(cMethod->mutable_input(3).toTensor().mutable_data_ptr(), cSin.data()+step*C_HD, C_HD*4); + for(int i=0;imutable_input(4+i*2).toTensor().mutable_data_ptr(), ckv.data()+(i*2)*ckvElem, ckvElem*4); + memcpy(cMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr(), ckv.data()+(i*2+1)*ckvElem, ckvElem*4); + } + auto status=cMethod->execute(); + if(status!=Error::Ok) break; + const 
float*ho=cMethod->get_output(0).toTensor().const_data_ptr(); + if(step>=1&&step-1<15){ + const float*W=cpHeads.data()+(long)(step-1)*CB_SIZE*DIM; + int best=0;float bv=-FLT_MAX; + for(int j=0;jbv){bv=d;best=j;}} + codes[step-1]=best; + } + for(int i=0;iget_output(1+i*2).toTensor().const_data_ptr(),ckvElem*4); + memcpy(ckv.data()+(i*2+1)*ckvElem,cMethod->get_output(2+i*2).toTensor().const_data_ptr(),ckvElem*4); + } + } + }; + + auto t0=std::chrono::high_resolution_clock::now(); + // Prefill + for(int s=0;s(tP-t0).count(),curCb0); + + if(curCb0<0||curCb0==CODEC_EOS) return facebook::jni::JArrayInt::newArray(0); + + // Generation + float totalTalker=0,totalCp=0; + for(int g=0;g(tc1-tc0).count(); + for(int i=0;i<15;i++) codes[i+1]=cpCodes[i]; + for(int i=0;i(tt1-tt0).count(); + + for(int j=CB_SIZE;j seen(cb0Hist.begin(),cb0Hist.end()); + for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f; + int next=tts_sample_topk(logits,VOCAB,0.9f,50); + if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d",g+2);break;} + if((int)cb0Hist.size()>=9){ + bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;} + if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;} + } + curCb0=next; + } + int nTok=(int)allCodes.size()/NUM_CB; + auto t1=std::chrono::high_resolution_clock::now(); + ET_LOG(Info,"TTS Generated %d | Talker %.0fms (%.0f/step) | CP %.0fms (%.0f/step) | Total %.0fms", + nTok,totalTalker,totalTalker/std::max(nTok,1),totalCp,totalCp/std::max(nTok,1), + std::chrono::duration(t1-t0).count()); + + auto result=facebook::jni::JArrayInt::newArray((int)allCodes.size()); + result->setRegion(0,(int)allCodes.size(),allCodes.data()); + return result; +} +} // namespace executorch::extension + +#ifdef EXECUTORCH_BUILD_LLAMA_JNI +extern void register_natives_for_llm(); +#else +// No op if we don't build LLM +void register_natives_for_llm() {} +#endif +extern void register_natives_for_runtime(); + +#ifdef 
EXECUTORCH_BUILD_EXTENSION_TRAINING +extern void register_natives_for_training(); +#else +// No op if we don't build training JNI +void register_natives_for_training() {} +#endif + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { + return facebook::jni::initialize(vm, [] { + executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_for_llm(); + register_natives_for_runtime(); + register_natives_for_training(); + }); +} diff --git a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt index 0d6bc7e..315bc82 100644 --- a/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt +++ b/kazeia-android/app/src/main/java/com/kazeia/tts/Qwen3TtsEngine.kt @@ -229,29 +229,11 @@ class Qwen3TtsEngine( // Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths) android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true) - // Try native C++ pipeline first (single QNN instance, no Java overhead) + // Load .pte modules via Java JNI (native pipeline uses same instances) run { val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte") - val talkerPte = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte") - if (etModel.exists() && talkerPte.exists()) { + if (etModel.exists() && cpPteModule == null) { try { - val tn = System.currentTimeMillis() - nativePipelineReady = TtsPipeline.nativeInit( - talkerPte.absolutePath, etModel.absolutePath - ) - nlog("Native C++ pipeline: ${if (nativePipelineReady) "OK" else "FAILED"} (${System.currentTimeMillis() - tn}ms)") - } catch (e: Exception) { - nlog("Native pipeline init failed: ${e.message}") - } - } - } - - // Fallback: Load Java .pte modules (only if native pipeline failed) - run { - val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte") - if (!nativePipelineReady && etModel.exists() && 
cpPteModule == null) { - try { - nlog("Loading Java .pte modules (native unavailable)...") val t0 = System.currentTimeMillis() cpPteModule = org.pytorch.executorch.Module.load( etModel.absolutePath, @@ -265,6 +247,10 @@ class Qwen3TtsEngine( if (lmResult != 0) { nlog("CP .pte loadMethod failed ($lmResult), disabling JNI") cpPteModule = null + } else { + // Register CP module for native pipeline + cpPteModule!!.nativeSetCpModule() + nlog("CP module registered for native pipeline") } } catch (e: Exception) { nlog("CP .pte JNI failed: ${e.message}") @@ -2298,11 +2284,9 @@ class Qwen3TtsEngine( nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)") val allCodes: Array - // Native C++ disabled: QNN HTP compilation not deterministic between loads - // Two instances of same .pte give slightly different hidden states → trembling - // Keep Java pipeline (same QNN instance, validated quality) - if (false && nativePipelineReady) { - // Native C++ pipeline — zero Java overhead + // Native C++ pipeline using SAME Java Module instances (no quality loss) + if (talkerPteModule != null && cpPteModule != null) { + // C++ loop on Java's Module instances — same QNN compilation, no JNI overhead val prefillFlat = FloatArray(nPrefill * TALKER_DIM) for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM) val nTrailing = nTotal - nPrefill @@ -2336,8 +2320,8 @@ class Qwen3TtsEngine( ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM) } - nlog("Running native C++ pipeline...") - val flat = TtsPipeline.nativeRun( + nlog("Running native C++ pipeline (shared Module)...") + val flat = talkerPteModule!!.nativeRunTtsPipeline( prefillFlat, nPrefill, trailingFlat, nTrailing, codecEmbedding ?: FloatArray(0),