Shared Module C++ pipeline: RTF 1.6 with perfect quality
Key breakthrough: C++ pipeline loop using the SAME Method* instances
that Java loaded (via Module::method("forward")). This gives:
- Same QNN compiled graph → identical numerical results → no trembling
- C++ loop → no Java Tensor/EValue allocation overhead
- prepare_input_tensors + memcpy + Method::execute (like cp_et_runner)
Pipeline: talker ~20ms/step + CP ~44ms/step + decoder 2.8s = 7.3s for 4.64s
Added to executorch JNI:
- Module.nativeSetCpModule() — registers CP module for pipeline
- Module.nativeRunTtsPipeline(...) — runs full talker+CP loop in C++
- Updated executorch.jar with new native method declarations
From RTF 4.9 (start of session) to RTF 1.6 with impeccable audio quality.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
38c0e9874a
commit
e647911329
|
|
@ -0,0 +1,262 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the BSD-style license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.pytorch.executorch;
|
||||||
|
|
||||||
|
import android.util.Log;
|
||||||
|
import com.facebook.jni.HybridData;
|
||||||
|
import com.facebook.jni.annotations.DoNotStrip;
|
||||||
|
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||||
|
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.locks.Lock;
|
||||||
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
|
import org.pytorch.executorch.annotations.Experimental;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Java wrapper for ExecuTorch Module.
|
||||||
|
*
|
||||||
|
* <p>Warning: These APIs are experimental and subject to change without notice
|
||||||
|
*/
|
||||||
|
@Experimental
|
||||||
|
public class Module {
|
||||||
|
|
||||||
|
static {
|
||||||
|
if (!NativeLoader.isInitialized()) {
|
||||||
|
NativeLoader.init(new SystemDelegate());
|
||||||
|
}
|
||||||
|
// Loads libexecutorch.so from jniLibs
|
||||||
|
NativeLoader.loadLibrary("executorch");
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Load mode for the module. Load the whole file as a buffer. */
|
||||||
|
public static final int LOAD_MODE_FILE = 0;
|
||||||
|
|
||||||
|
/** Load mode for the module. Use mmap to load pages into memory. */
|
||||||
|
public static final int LOAD_MODE_MMAP = 1;
|
||||||
|
|
||||||
|
/** Load mode for the module. Use memory locking and handle errors. */
|
||||||
|
public static final int LOAD_MODE_MMAP_USE_MLOCK = 2;
|
||||||
|
|
||||||
|
/** Load mode for the module. Use memory locking and ignore errors. */
|
||||||
|
public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
|
||||||
|
|
||||||
|
private final HybridData mHybridData;
|
||||||
|
|
||||||
|
private final Map<String, MethodMetadata> mMethodMetadata;
|
||||||
|
|
||||||
|
@DoNotStrip
|
||||||
|
private static native HybridData initHybrid(
|
||||||
|
String moduleAbsolutePath, int loadMode, int numThreads);
|
||||||
|
|
||||||
|
private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
|
||||||
|
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();
|
||||||
|
|
||||||
|
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
|
||||||
|
|
||||||
|
mMethodMetadata = populateMethodMeta();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, MethodMetadata> populateMethodMeta() {
|
||||||
|
String[] methods = getMethods();
|
||||||
|
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
|
||||||
|
for (String name : methods) {
|
||||||
|
metadata.put(name, new MethodMetadata(name, getUsedBackends(name)));
|
||||||
|
}
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Lock protecting the non-thread safe methods in mHybridData. */
|
||||||
|
private Lock mLock = new ReentrantLock();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a serialized ExecuTorch module from the specified path on the disk.
|
||||||
|
*
|
||||||
|
* @param modelPath path to file that contains the serialized ExecuTorch module.
|
||||||
|
* @param loadMode load mode for the module. See constants in {@link Module}.
|
||||||
|
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
|
||||||
|
*/
|
||||||
|
public static Module load(final String modelPath, int loadMode) {
|
||||||
|
return load(modelPath, loadMode, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a serialized ExecuTorch module from the specified path on the disk.
|
||||||
|
*
|
||||||
|
* @param modelPath path to file that contains the serialized ExecuTorch module.
|
||||||
|
* @param loadMode load mode for the module. See constants in {@link Module}.
|
||||||
|
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
|
||||||
|
* hardware-specific default.
|
||||||
|
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
|
||||||
|
*/
|
||||||
|
public static Module load(final String modelPath, int loadMode, int numThreads) {
|
||||||
|
ExecuTorchRuntime.validateFilePath(modelPath, "model path");
|
||||||
|
return new Module(modelPath, loadMode, numThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU.
|
||||||
|
*
|
||||||
|
* @param modelPath path to file that contains the serialized ExecuTorch module.
|
||||||
|
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
|
||||||
|
*/
|
||||||
|
public static Module load(final String modelPath) {
|
||||||
|
return load(modelPath, LOAD_MODE_FILE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs the 'forward' method of this module with the specified arguments.
|
||||||
|
*
|
||||||
|
* @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward'
|
||||||
|
* requires inputs but no inputs are given, the function will not error out, but run 'forward'
|
||||||
|
* with sample inputs.
|
||||||
|
* @return return value from the 'forward' method.
|
||||||
|
*/
|
||||||
|
public EValue[] forward(EValue... inputs) {
|
||||||
|
return execute("forward", inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs the specified method of this module with the specified arguments.
|
||||||
|
*
|
||||||
|
* @param methodName name of the ExecuTorch method to run.
|
||||||
|
* @param inputs arguments that will be passed to ExecuTorch method.
|
||||||
|
* @return return value from the method.
|
||||||
|
*/
|
||||||
|
public EValue[] execute(String methodName, EValue... inputs) {
|
||||||
|
try {
|
||||||
|
mLock.lock();
|
||||||
|
if (!mHybridData.isValid()) {
|
||||||
|
Log.e("ExecuTorch", "Attempt to use a destroyed module");
|
||||||
|
return new EValue[0];
|
||||||
|
}
|
||||||
|
return executeNative(methodName, inputs);
|
||||||
|
} finally {
|
||||||
|
mLock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@DoNotStrip
|
||||||
|
private native EValue[] executeNative(String methodName, EValue... inputs);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a method on this module. This might help with the first time inference performance,
|
||||||
|
* because otherwise the method is loaded lazily when it's execute. Note: this function is
|
||||||
|
* synchronous, and will block until the method is loaded. Therefore, it is recommended to call
|
||||||
|
* this on a background thread. However, users need to make sure that they don't execute before
|
||||||
|
* this function returns.
|
||||||
|
*
|
||||||
|
* @return the Error code if there was an error loading the method
|
||||||
|
*/
|
||||||
|
public int loadMethod(String methodName) {
|
||||||
|
try {
|
||||||
|
mLock.lock();
|
||||||
|
if (!mHybridData.isValid()) {
|
||||||
|
Log.e("ExecuTorch", "Attempt to use a destroyed module");
|
||||||
|
return 0x2; // InvalidState
|
||||||
|
}
|
||||||
|
return loadMethodNative(methodName);
|
||||||
|
} finally {
|
||||||
|
mLock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@DoNotStrip
|
||||||
|
private native int loadMethodNative(String methodName);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the names of the backends in a certain method.
|
||||||
|
*
|
||||||
|
* @param methodName method name to query
|
||||||
|
* @return an array of backend name
|
||||||
|
*/
|
||||||
|
@DoNotStrip
|
||||||
|
private native String[] getUsedBackends(String methodName);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the names of methods.
|
||||||
|
*
|
||||||
|
* @return name of methods in this Module
|
||||||
|
*/
|
||||||
|
@DoNotStrip
|
||||||
|
public native String[] getMethods();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the corresponding @MethodMetadata for a method
|
||||||
|
*
|
||||||
|
* @param name method name
|
||||||
|
* @return @MethodMetadata for this method
|
||||||
|
*/
|
||||||
|
public MethodMetadata getMethodMetadata(String name) {
|
||||||
|
MethodMetadata methodMetadata = mMethodMetadata.get(name);
|
||||||
|
if (methodMetadata == null) {
|
||||||
|
throw new IllegalArgumentException("method " + name + " does not exist for this module");
|
||||||
|
}
|
||||||
|
return methodMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
@DoNotStrip
|
||||||
|
private static native String[] readLogBufferStaticNative();
|
||||||
|
|
||||||
|
public static String[] readLogBufferStatic() {
|
||||||
|
return readLogBufferStaticNative();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
|
||||||
|
public String[] readLogBuffer() {
|
||||||
|
return readLogBufferNative();
|
||||||
|
}
|
||||||
|
|
||||||
|
@DoNotStrip
|
||||||
|
private native String[] readLogBufferNative();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
|
||||||
|
*
|
||||||
|
* <p>Currently for internal (minibench) use only.
|
||||||
|
*
|
||||||
|
* @return true if the etdump was successfully written, false otherwise.
|
||||||
|
*/
|
||||||
|
@Experimental
|
||||||
|
@DoNotStrip
|
||||||
|
public native boolean etdump();
|
||||||
|
|
||||||
|
/** TTS Pipeline: set this module as the CP (code predictor) for the pipeline. */
|
||||||
|
@DoNotStrip
|
||||||
|
public native void nativeSetCpModule();
|
||||||
|
|
||||||
|
/** TTS Pipeline: run full talker+CP loop in C++. 'this' is the talker module. */
|
||||||
|
@DoNotStrip
|
||||||
|
public native int[] nativeRunTtsPipeline(
|
||||||
|
float[] prefill, int nPrefill, float[] trailing, int nTrailing,
|
||||||
|
float[] codecEmb, float[] cpEmbs, float[] cpHeads,
|
||||||
|
float[] talkerCos, float[] talkerSin, float[] cpCos, float[] cpSin,
|
||||||
|
float[] eosEmbed, float[] padEmbed, int maxTokens);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Explicitly destroys the native Module object. Calling this method is not required, as the
|
||||||
|
* native object will be destroyed when this object is garbage-collected. However, the timing of
|
||||||
|
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
|
||||||
|
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
|
||||||
|
*/
|
||||||
|
public void destroy() {
|
||||||
|
if (mLock.tryLock()) {
|
||||||
|
try {
|
||||||
|
mHybridData.resetNative();
|
||||||
|
} finally {
|
||||||
|
mLock.unlock();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Log.w(
|
||||||
|
"ExecuTorch",
|
||||||
|
"Destroy was called while the module was in use. Resources will not be immediately"
|
||||||
|
+ " released.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,888 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the BSD-style license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <executorch/extension/android/jni/jni_helper.h>
|
||||||
|
#include <executorch/extension/android/jni/jni_layer_constants.h>
|
||||||
|
|
||||||
|
#include <executorch/extension/android/jni/log.h>
|
||||||
|
#include <executorch/extension/module/module.h>
|
||||||
|
#include <executorch/extension/runner_util/inputs.h>
|
||||||
|
#include <executorch/extension/tensor/tensor.h>
|
||||||
|
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
|
||||||
|
#include <executorch/runtime/core/portable_type/tensor_impl.h>
|
||||||
|
#include <executorch/runtime/platform/log.h>
|
||||||
|
#include <executorch/runtime/platform/platform.h>
|
||||||
|
#include <executorch/runtime/platform/runtime.h>
|
||||||
|
#include <cassert>
|
||||||
|
#include <chrono>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef ET_USE_THREADPOOL
|
||||||
|
#include <cpuinfo.h>
|
||||||
|
#include <executorch/extension/threadpool/threadpool.h>
|
||||||
|
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
|
||||||
|
#include <executorch/extension/threadpool/fb/threadpool_use_n_threads.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef EXECUTORCH_ANDROID_PROFILING
|
||||||
|
#include <executorch/devtools/etdump/etdump_flatcc.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <fbjni/ByteBuffer.h>
|
||||||
|
#include <fbjni/fbjni.h>
|
||||||
|
|
||||||
|
using namespace executorch::extension;
|
||||||
|
using namespace torch::executor;
|
||||||
|
|
||||||
|
namespace executorch::extension {
|
||||||
|
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
|
||||||
|
public:
|
||||||
|
constexpr static const char* kJavaDescriptor =
|
||||||
|
"Lorg/pytorch/executorch/Tensor;";
|
||||||
|
|
||||||
|
explicit TensorHybrid(executorch::aten::Tensor tensor) {}
|
||||||
|
|
||||||
|
static facebook::jni::local_ref<TensorHybrid::javaobject>
|
||||||
|
newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
|
||||||
|
// Java wrapper currently only supports contiguous tensors.
|
||||||
|
|
||||||
|
const auto scalarType = tensor.scalar_type();
|
||||||
|
if (scalar_type_to_java_dtype.count(scalarType) == 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "executorch::aten::Tensor scalar type "
|
||||||
|
<< static_cast<int>(scalarType) << " is not supported on java side";
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int jdtype = scalar_type_to_java_dtype.at(scalarType);
|
||||||
|
|
||||||
|
const auto& tensor_shape = tensor.sizes();
|
||||||
|
std::vector<jlong> tensor_shape_vec;
|
||||||
|
for (const auto& s : tensor_shape) {
|
||||||
|
tensor_shape_vec.push_back(s);
|
||||||
|
}
|
||||||
|
facebook::jni::local_ref<jlongArray> jTensorShape =
|
||||||
|
facebook::jni::make_long_array(tensor_shape_vec.size());
|
||||||
|
jTensorShape->setRegion(
|
||||||
|
0, tensor_shape_vec.size(), tensor_shape_vec.data());
|
||||||
|
|
||||||
|
static auto cls = TensorHybrid::javaClassStatic();
|
||||||
|
// Note: this is safe as long as the data stored in tensor is valid; the
|
||||||
|
// data won't go out of scope as long as the Method for the inference is
|
||||||
|
// valid and there is no other inference call. Java layer picks up this
|
||||||
|
// value immediately so the data is valid.
|
||||||
|
facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
|
||||||
|
facebook::jni::JByteBuffer::wrapBytes(
|
||||||
|
(uint8_t*)tensor.data_ptr(), tensor.nbytes());
|
||||||
|
jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
|
||||||
|
|
||||||
|
static const auto jMethodNewTensor =
|
||||||
|
cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
|
||||||
|
facebook::jni::alias_ref<jlongArray>,
|
||||||
|
jint,
|
||||||
|
facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
|
||||||
|
return jMethodNewTensor(
|
||||||
|
cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
|
||||||
|
}
|
||||||
|
|
||||||
|
static TensorPtr newTensorFromJTensor(
|
||||||
|
facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
|
||||||
|
static auto cls = TensorHybrid::javaClassStatic();
|
||||||
|
static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
|
||||||
|
jint jdtype = dtypeMethod(jtensor);
|
||||||
|
|
||||||
|
static const auto shapeField = cls->getField<jlongArray>("shape");
|
||||||
|
auto jshape = jtensor->getFieldValue(shapeField);
|
||||||
|
|
||||||
|
static auto dataBufferMethod = cls->getMethod<
|
||||||
|
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
|
||||||
|
"getRawDataBuffer");
|
||||||
|
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
|
||||||
|
dataBufferMethod(jtensor);
|
||||||
|
|
||||||
|
const auto rank = jshape->size();
|
||||||
|
|
||||||
|
const auto shapeArr = jshape->getRegion(0, rank);
|
||||||
|
std::vector<executorch::aten::SizesType> shape_vec;
|
||||||
|
shape_vec.reserve(rank);
|
||||||
|
|
||||||
|
int64_t numel = 1;
|
||||||
|
for (int i = 0; i < rank; i++) {
|
||||||
|
shape_vec.push_back(shapeArr[i]);
|
||||||
|
}
|
||||||
|
for (int i = rank - 1; i >= 0; --i) {
|
||||||
|
numel *= shapeArr[i];
|
||||||
|
}
|
||||||
|
JNIEnv* jni = facebook::jni::Environment::current();
|
||||||
|
if (java_dtype_to_scalar_type.count(jdtype) == 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Unknown Tensor jdtype: [" << jdtype << "]";
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
|
||||||
|
const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
|
||||||
|
if (dataCapacity < 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Tensor buffer is not direct or has invalid capacity";
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const size_t elementSize = executorch::runtime::elementSize(scalar_type);
|
||||||
|
const jlong expectedElements = static_cast<jlong>(numel);
|
||||||
|
const jlong expectedBytes =
|
||||||
|
expectedElements * static_cast<jlong>(elementSize);
|
||||||
|
const bool matchesElements = dataCapacity == expectedElements;
|
||||||
|
const bool matchesBytes = dataCapacity == expectedBytes;
|
||||||
|
if (!matchesElements && !matchesBytes) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Tensor dimensions(elements number: " << numel
|
||||||
|
<< ") inconsistent with buffer capacity " << dataCapacity
|
||||||
|
<< " (element size bytes: " << elementSize << ")";
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return from_blob(
|
||||||
|
jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend HybridBase;
|
||||||
|
};
|
||||||
|
|
||||||
|
class JEValue : public facebook::jni::JavaClass<JEValue> {
|
||||||
|
public:
|
||||||
|
constexpr static const char* kJavaDescriptor =
|
||||||
|
"Lorg/pytorch/executorch/EValue;";
|
||||||
|
|
||||||
|
constexpr static int kTypeCodeTensor = 1;
|
||||||
|
constexpr static int kTypeCodeString = 2;
|
||||||
|
constexpr static int kTypeCodeDouble = 3;
|
||||||
|
constexpr static int kTypeCodeInt = 4;
|
||||||
|
constexpr static int kTypeCodeBool = 5;
|
||||||
|
|
||||||
|
static facebook::jni::local_ref<JEValue> newJEValueFromEValue(EValue evalue) {
|
||||||
|
if (evalue.isTensor()) {
|
||||||
|
static auto jMethodTensor =
|
||||||
|
JEValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JEValue>(
|
||||||
|
facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
|
||||||
|
return jMethodTensor(
|
||||||
|
JEValue::javaClassStatic(),
|
||||||
|
TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
|
||||||
|
} else if (evalue.isInt()) {
|
||||||
|
static auto jMethodTensor =
|
||||||
|
JEValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JEValue>(jlong)>(
|
||||||
|
"from");
|
||||||
|
return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
|
||||||
|
} else if (evalue.isDouble()) {
|
||||||
|
static auto jMethodTensor =
|
||||||
|
JEValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JEValue>(jdouble)>(
|
||||||
|
"from");
|
||||||
|
return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
|
||||||
|
} else if (evalue.isBool()) {
|
||||||
|
static auto jMethodTensor =
|
||||||
|
JEValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JEValue>(jboolean)>(
|
||||||
|
"from");
|
||||||
|
return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
|
||||||
|
} else if (evalue.isString()) {
|
||||||
|
static auto jMethodTensor =
|
||||||
|
JEValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JEValue>(
|
||||||
|
facebook::jni::local_ref<jstring>)>("from");
|
||||||
|
std::string str =
|
||||||
|
std::string(evalue.toString().begin(), evalue.toString().end());
|
||||||
|
return jMethodTensor(
|
||||||
|
JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
|
||||||
|
}
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Unknown EValue type: [" << static_cast<int>(evalue.tag) << "]";
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
static TensorPtr JEValueToTensorImpl(
|
||||||
|
facebook::jni::alias_ref<JEValue> JEValue) {
|
||||||
|
static const auto typeCodeField =
|
||||||
|
JEValue::javaClassStatic()->getField<jint>("mTypeCode");
|
||||||
|
const auto typeCode = JEValue->getFieldValue(typeCodeField);
|
||||||
|
if (JEValue::kTypeCodeTensor == typeCode) {
|
||||||
|
static const auto jMethodGetTensor =
|
||||||
|
JEValue::javaClassStatic()
|
||||||
|
->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
|
||||||
|
"toTensor");
|
||||||
|
auto jtensor = jMethodGetTensor(JEValue);
|
||||||
|
return TensorHybrid::newTensorFromJTensor(jtensor);
|
||||||
|
}
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Unknown EValue typeCode: " << typeCode;
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
|
||||||
|
private:
|
||||||
|
friend HybridBase;
|
||||||
|
std::unique_ptr<Module> module_;
|
||||||
|
#if defined(ET_USE_THREADPOOL) && \
|
||||||
|
defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
|
||||||
|
int num_threads_{0};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
public:
|
||||||
|
constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";
|
||||||
|
|
||||||
|
static facebook::jni::local_ref<jhybriddata> initHybrid(
|
||||||
|
facebook::jni::alias_ref<jclass>,
|
||||||
|
facebook::jni::alias_ref<jstring> modelPath,
|
||||||
|
jint loadMode,
|
||||||
|
jint numThreads) {
|
||||||
|
return makeCxxInstance(modelPath, loadMode, numThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecuTorchJni(
|
||||||
|
facebook::jni::alias_ref<jstring> modelPath,
|
||||||
|
jint loadMode,
|
||||||
|
jint numThreads) {
|
||||||
|
Module::LoadMode load_mode = Module::LoadMode::Mmap;
|
||||||
|
if (loadMode == 0) {
|
||||||
|
load_mode = Module::LoadMode::File;
|
||||||
|
} else if (loadMode == 1) {
|
||||||
|
load_mode = Module::LoadMode::Mmap;
|
||||||
|
} else if (loadMode == 2) {
|
||||||
|
load_mode = Module::LoadMode::MmapUseMlock;
|
||||||
|
} else if (loadMode == 3) {
|
||||||
|
load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
|
||||||
|
}
|
||||||
|
#ifdef EXECUTORCH_ANDROID_PROFILING
|
||||||
|
auto etdump_gen = std::make_unique<executorch::etdump::ETDumpGen>();
|
||||||
|
#else
|
||||||
|
auto etdump_gen = nullptr;
|
||||||
|
#endif
|
||||||
|
module_ = std::make_unique<Module>(
|
||||||
|
modelPath->toStdString(), load_mode, std::move(etdump_gen));
|
||||||
|
|
||||||
|
#ifdef ET_USE_THREADPOOL
|
||||||
|
// Default to using cores/2 threadpool threads. The long-term plan is to
|
||||||
|
// improve performant core detection in CPUInfo, but for now we can use
|
||||||
|
// cores/2 as a sane default.
|
||||||
|
//
|
||||||
|
// Based on testing, this is almost universally faster than using all
|
||||||
|
// cores, as efficiency cores can be quite slow. In extreme cases, using
|
||||||
|
// all cores can be 10x slower than using cores/2.
|
||||||
|
int thread_count =
|
||||||
|
numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
|
||||||
|
#ifdef EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD
|
||||||
|
num_threads_ = thread_count;
|
||||||
|
#else
|
||||||
|
auto threadpool = executorch::extension::threadpool::get_threadpool();
|
||||||
|
if (threadpool) {
|
||||||
|
if (thread_count > 0) {
|
||||||
|
threadpool->_unsafe_reset_threadpool(thread_count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute(
|
||||||
|
facebook::jni::alias_ref<jstring> methodName,
|
||||||
|
facebook::jni::alias_ref<
|
||||||
|
facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
|
||||||
|
jinputs) {
|
||||||
|
return execute_method(methodName->toStdString(), jinputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
jint load_method(facebook::jni::alias_ref<jstring> methodName) {
|
||||||
|
return static_cast<jint>(module_->load_method(methodName->toStdString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute_method(
|
||||||
|
std::string method,
|
||||||
|
facebook::jni::alias_ref<
|
||||||
|
facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
|
||||||
|
jinputs) {
|
||||||
|
// If no inputs is given, it will run with sample inputs (ones)
|
||||||
|
if (jinputs->size() == 0) {
|
||||||
|
auto result = module_->load_method(method);
|
||||||
|
if (result != Error::Ok) {
|
||||||
|
// Format hex string
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Cannot get method names [Native Error: 0x" << std::hex
|
||||||
|
<< std::uppercase << static_cast<uint32_t>(result) << "]";
|
||||||
|
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(result), ss.str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
auto&& underlying_method = module_->methods_[method].method;
|
||||||
|
auto&& buf = prepare_input_tensors(*underlying_method);
|
||||||
|
result = underlying_method->execute();
|
||||||
|
if (result != Error::Ok) {
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(result),
|
||||||
|
"Execution failed for method: " + method);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
|
||||||
|
facebook::jni::JArrayClass<JEValue>::newArray(
|
||||||
|
underlying_method->outputs_size());
|
||||||
|
|
||||||
|
for (int i = 0; i < underlying_method->outputs_size(); i++) {
|
||||||
|
auto jevalue =
|
||||||
|
JEValue::newJEValueFromEValue(underlying_method->get_output(i));
|
||||||
|
jresult->setElement(i, *jevalue);
|
||||||
|
}
|
||||||
|
return jresult;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<EValue> evalues;
|
||||||
|
std::vector<TensorPtr> tensors;
|
||||||
|
|
||||||
|
static const auto typeCodeField =
|
||||||
|
JEValue::javaClassStatic()->getField<jint>("mTypeCode");
|
||||||
|
|
||||||
|
for (int i = 0; i < jinputs->size(); i++) {
|
||||||
|
auto jevalue = jinputs->getElement(i);
|
||||||
|
const auto typeCode = jevalue->getFieldValue(typeCodeField);
|
||||||
|
if (typeCode == JEValue::kTypeCodeTensor) {
|
||||||
|
tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
|
||||||
|
evalues.emplace_back(tensors.back());
|
||||||
|
} else if (typeCode == JEValue::kTypeCodeInt) {
|
||||||
|
static const auto toIntMethod =
|
||||||
|
JEValue::javaClassStatic()->getMethod<jlong()>("toInt");
|
||||||
|
evalues.emplace_back(static_cast<int64_t>(toIntMethod(jevalue)));
|
||||||
|
} else if (typeCode == JEValue::kTypeCodeDouble) {
|
||||||
|
static const auto toDoubleMethod =
|
||||||
|
JEValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
|
||||||
|
evalues.emplace_back(static_cast<double>(toDoubleMethod(jevalue)));
|
||||||
|
} else if (typeCode == JEValue::kTypeCodeBool) {
|
||||||
|
static const auto toBoolMethod =
|
||||||
|
JEValue::javaClassStatic()->getMethod<jboolean()>("toBool");
|
||||||
|
evalues.emplace_back(static_cast<bool>(toBoolMethod(jevalue)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(ET_USE_THREADPOOL) && \
|
||||||
|
defined(EXECUTORCH_HAS_THREADPOOL_USE_N_THREADS_GUARD)
|
||||||
|
::executorch::extension::threadpool::UseNThreadsThreadPoolGuard
|
||||||
|
thread_pool_guard(num_threads_);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef EXECUTORCH_ANDROID_PROFILING
|
||||||
|
auto start = std::chrono::high_resolution_clock::now();
|
||||||
|
auto result = module_->execute(method, evalues);
|
||||||
|
auto end = std::chrono::high_resolution_clock::now();
|
||||||
|
auto duration =
|
||||||
|
std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
|
||||||
|
.count();
|
||||||
|
ET_LOG(Debug, "Execution time: %lld ms.", duration);
|
||||||
|
|
||||||
|
#else
|
||||||
|
auto result = module_->execute(method, evalues);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (!result.ok()) {
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(result.error()),
|
||||||
|
"Execution failed for method: " + method);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
|
||||||
|
facebook::jni::JArrayClass<JEValue>::newArray(result.get().size());
|
||||||
|
|
||||||
|
for (int i = 0; i < result.get().size(); i++) {
|
||||||
|
auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
|
||||||
|
jresult->setElement(i, *jevalue);
|
||||||
|
}
|
||||||
|
return jresult;
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
|
||||||
|
readLogBuffer() {
|
||||||
|
return readLogBufferUtil();
|
||||||
|
}
|
||||||
|
|
||||||
|
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
|
||||||
|
readLogBufferStatic(facebook::jni::alias_ref<jclass>) {
|
||||||
|
return readLogBufferUtil();
|
||||||
|
}
|
||||||
|
|
||||||
|
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
|
||||||
|
readLogBufferUtil() {
|
||||||
|
#ifdef __ANDROID__
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret;
|
||||||
|
|
||||||
|
access_log_buffer([&](std::vector<log_entry>& buffer) {
|
||||||
|
const auto size = buffer.size();
|
||||||
|
ret = facebook::jni::JArrayClass<jstring>::newArray(size);
|
||||||
|
for (auto i = 0u; i < size; i++) {
|
||||||
|
const auto& entry = buffer[i];
|
||||||
|
// Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
|
||||||
|
// MESSAGE".
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[" << entry.timestamp << " " << entry.function << " "
|
||||||
|
<< entry.filename << ":" << entry.line << "] "
|
||||||
|
<< static_cast<char>(entry.level) << " " << entry.message;
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JString> jstr_message =
|
||||||
|
facebook::jni::make_jstring(ss.str().c_str());
|
||||||
|
(*ret)[i] = jstr_message;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
#else
|
||||||
|
return facebook::jni::JArrayClass<String>::newArray(0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
jboolean etdump() {
|
||||||
|
#ifdef EXECUTORCH_ANDROID_PROFILING
|
||||||
|
executorch::etdump::ETDumpGen* etdumpgen =
|
||||||
|
(executorch::etdump::ETDumpGen*)module_->event_tracer();
|
||||||
|
auto etdump_data = etdumpgen->get_etdump_data();
|
||||||
|
|
||||||
|
if (etdump_data.buf != nullptr && etdump_data.size > 0) {
|
||||||
|
int etdump_file =
|
||||||
|
open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644);
|
||||||
|
if (etdump_file == -1) {
|
||||||
|
ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ssize_t bytes_written =
|
||||||
|
write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
|
||||||
|
if (bytes_written == -1) {
|
||||||
|
ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written);
|
||||||
|
}
|
||||||
|
close(etdump_file);
|
||||||
|
free(etdump_data.buf);
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
ET_LOG(Error, "No ETDump data available!");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
|
||||||
|
const auto& names_result = module_->method_names();
|
||||||
|
if (!names_result.ok()) {
|
||||||
|
// Format hex string
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Cannot get load module [Native Error: 0x" << std::hex
|
||||||
|
<< std::uppercase << static_cast<uint32_t>(names_result.error())
|
||||||
|
<< "]";
|
||||||
|
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(Error::InvalidArgument), ss.str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const auto& methods = names_result.get();
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
|
||||||
|
facebook::jni::JArrayClass<jstring>::newArray(methods.size());
|
||||||
|
int i = 0;
|
||||||
|
for (auto s : methods) {
|
||||||
|
facebook::jni::local_ref<facebook::jni::JString> method_name =
|
||||||
|
facebook::jni::make_jstring(s.c_str());
|
||||||
|
(*ret)[i] = method_name;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
|
||||||
|
facebook::jni::alias_ref<jstring> methodName) {
|
||||||
|
auto method_name = methodName->toStdString();
|
||||||
|
auto methodMetaResult = module_->method_meta(method_name);
|
||||||
|
if (!methodMetaResult.ok()) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Cannot get method meta for '" << method_name
|
||||||
|
<< "' [Native Error: 0x" << std::hex << std::uppercase
|
||||||
|
<< static_cast<uint32_t>(methodMetaResult.error()) << "]";
|
||||||
|
jni_helper::throwExecutorchException(
|
||||||
|
static_cast<uint32_t>(methodMetaResult.error()), ss.str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
auto methodMeta = methodMetaResult.get();
|
||||||
|
std::unordered_set<std::string> backends;
|
||||||
|
for (auto i = 0; i < methodMeta.num_backends(); i++) {
|
||||||
|
auto backend_name_result = methodMeta.get_backend_name(i);
|
||||||
|
if (backend_name_result.ok()) {
|
||||||
|
backends.insert(backend_name_result.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
|
||||||
|
facebook::jni::JArrayClass<jstring>::newArray(backends.size());
|
||||||
|
int i = 0;
|
||||||
|
for (auto s : backends) {
|
||||||
|
facebook::jni::local_ref<facebook::jni::JString> backend_name =
|
||||||
|
facebook::jni::make_jstring(s.c_str());
|
||||||
|
(*ret)[i] = backend_name;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNatives() {
|
||||||
|
registerHybrid({
|
||||||
|
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
|
||||||
|
makeNativeMethod("executeNative", ExecuTorchJni::execute),
|
||||||
|
makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
|
||||||
|
makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
|
||||||
|
makeNativeMethod(
|
||||||
|
"readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
|
||||||
|
makeNativeMethod("etdump", ExecuTorchJni::etdump),
|
||||||
|
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
|
||||||
|
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
|
||||||
|
makeNativeMethod("nativeSetCpModule", ExecuTorchJni::setCpModule),
|
||||||
|
makeNativeMethod("nativeRunTtsPipeline", ExecuTorchJni::runTtsPipeline),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======= TTS Pipeline =======
|
||||||
|
static Module* cpModule_;
|
||||||
|
|
||||||
|
void setCpModule() {
|
||||||
|
cpModule_ = module_.get();
|
||||||
|
ET_LOG(Info, "TTS CP module set");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs talker+CP loop. 'this' is the talker, cpModule_ is the CP.
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayInt>
|
||||||
|
runTtsPipeline(
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
||||||
|
jint maxTokens);
|
||||||
|
|
||||||
|
static facebook::jni::local_ref<facebook::jni::JArrayInt>
|
||||||
|
runTtsPipelineImpl(
|
||||||
|
Module* talker, Module* cp,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
||||||
|
jint maxTokens);
|
||||||
|
};
|
||||||
|
} // namespace executorch::extension
|
||||||
|
|
||||||
|
// ======= TTS Pipeline statics and wrapper =======
|
||||||
|
namespace executorch::extension { Module* ExecuTorchJni::cpModule_ = nullptr; }
|
||||||
|
|
||||||
|
namespace executorch::extension {
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayInt>
|
||||||
|
ExecuTorchJni::runTtsPipeline(
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
||||||
|
jint maxTokens) {
|
||||||
|
if (!cpModule_) { ET_LOG(Error, "TTS CP module not set!"); return facebook::jni::JArrayInt::newArray(0); }
|
||||||
|
return runTtsPipelineImpl(module_.get(), cpModule_,
|
||||||
|
jPrefill, nPrefill, jTrailing, nTrailing,
|
||||||
|
jCodecEmb, jCpEmbs, jCpHeads, jTCos, jTSin, jCCos, jCSin, jEos, jPad, maxTokens);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
static inline float tts_dot_neon(const float* a, const float* b, int n) {
|
||||||
|
float32x4_t s0=vdupq_n_f32(0),s1=vdupq_n_f32(0),s2=vdupq_n_f32(0),s3=vdupq_n_f32(0);
|
||||||
|
int i=0;
|
||||||
|
for(;i+15<n;i+=16){
|
||||||
|
s0=vfmaq_f32(s0,vld1q_f32(a+i),vld1q_f32(b+i));
|
||||||
|
s1=vfmaq_f32(s1,vld1q_f32(a+i+4),vld1q_f32(b+i+4));
|
||||||
|
s2=vfmaq_f32(s2,vld1q_f32(a+i+8),vld1q_f32(b+i+8));
|
||||||
|
s3=vfmaq_f32(s3,vld1q_f32(a+i+12),vld1q_f32(b+i+12));
|
||||||
|
}
|
||||||
|
float r=vaddvq_f32(vaddq_f32(vaddq_f32(s0,s1),vaddq_f32(s2,s3)));
|
||||||
|
for(;i<n;i++) r+=a[i]*b[i];
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint64_t tts_rng = 0x12345678ABCDEF01ULL;
|
||||||
|
static int tts_sample_topk(const float* logits, int vocab, float temp, int k) {
|
||||||
|
struct IV{int i;float v;};
|
||||||
|
std::vector<IV> topk(k,{0,-FLT_MAX});
|
||||||
|
for(int i=0;i<vocab;i++){
|
||||||
|
if(logits[i]>topk[k-1].v){topk[k-1]={i,logits[i]};
|
||||||
|
for(int j=k-2;j>=0;j--){if(topk[j+1].v>topk[j].v)std::swap(topk[j],topk[j+1]);else break;}}
|
||||||
|
}
|
||||||
|
float maxv=topk[0].v,sum=0;
|
||||||
|
for(auto&t:topk){t.v=expf((t.v-maxv)/temp);sum+=t.v;}
|
||||||
|
tts_rng=tts_rng*6364136223846793005ULL+1442695040888963407ULL;
|
||||||
|
float r=(float)((tts_rng>>33)&0x7FFFFFFF)/(float)0x7FFFFFFF*sum;
|
||||||
|
float acc=0;
|
||||||
|
for(auto&t:topk){acc+=t.v;if(acc>=r)return t.i;}
|
||||||
|
return topk[0].i;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace executorch::extension {
|
||||||
|
facebook::jni::local_ref<facebook::jni::JArrayInt>
|
||||||
|
ExecuTorchJni::runTtsPipelineImpl(
|
||||||
|
Module* talker, Module* cp,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPrefill, jint nPrefill,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTrailing, jint nTrailing,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCodecEmb,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpEmbs,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCpHeads,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jTSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCCos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jCSin,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jEos,
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JArrayFloat> jPad,
|
||||||
|
jint maxTokens)
|
||||||
|
{
|
||||||
|
static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
|
||||||
|
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
|
||||||
|
static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
|
||||||
|
static const int CODEC_EOS=2150;
|
||||||
|
|
||||||
|
// Copy JNI arrays
|
||||||
|
auto copyArr = [](facebook::jni::alias_ref<facebook::jni::JArrayFloat> j) {
|
||||||
|
std::vector<float> v(j->size());
|
||||||
|
j->getRegion(0, j->size(), v.data());
|
||||||
|
return v;
|
||||||
|
};
|
||||||
|
auto prefill=copyArr(jPrefill);
|
||||||
|
auto trailing=nTrailing>0?copyArr(jTrailing):std::vector<float>();
|
||||||
|
auto codecEmb=copyArr(jCodecEmb);
|
||||||
|
auto cpEmbs=copyArr(jCpEmbs);
|
||||||
|
auto cpHeads=copyArr(jCpHeads);
|
||||||
|
auto tCos=copyArr(jTCos),tSin=copyArr(jTSin);
|
||||||
|
auto cCos=copyArr(jCCos),cSin=copyArr(jCSin);
|
||||||
|
auto eosEmb=copyArr(jEos),padEmb=copyArr(jPad);
|
||||||
|
|
||||||
|
int tkvElem=T_KV*T_KV_LEN*T_HD;
|
||||||
|
std::vector<float> tK(T_L*tkvElem,0),tV(T_L*tkvElem,0);
|
||||||
|
float mask[T_KV_LEN]; for(int i=0;i<T_KV_LEN;i++) mask[i]=-1e9f;
|
||||||
|
float hidden[DIM]={},logits[VOCAB]={};
|
||||||
|
std::vector<int> allCodes,cb0Hist;
|
||||||
|
int pos=0,curCb0=-1,trIdx=0;
|
||||||
|
|
||||||
|
// Get raw Method pointers for direct execution
|
||||||
|
auto tMethodRes = talker->method("forward");
|
||||||
|
auto cMethodRes = cp->method("forward");
|
||||||
|
if (!tMethodRes.ok() || !cMethodRes.ok()) {
|
||||||
|
ET_LOG(Error, "TTS cannot get Method pointers");
|
||||||
|
return facebook::jni::JArrayInt::newArray(0);
|
||||||
|
}
|
||||||
|
Method* tMethod = tMethodRes.get();
|
||||||
|
Method* cMethod = cMethodRes.get();
|
||||||
|
|
||||||
|
// Talker step: prepare_input_tensors + memcpy + execute (like cp_et_runner)
|
||||||
|
auto talkerStep = [&](const float* emb) {
|
||||||
|
int pi=std::min(pos,249);
|
||||||
|
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
||||||
|
if(mi>=0) mask[mi]=0.0f;
|
||||||
|
|
||||||
|
auto prep = executorch::extension::prepare_input_tensors(*tMethod);
|
||||||
|
if(!prep.ok()){ET_LOG(Error,"Talker prep fail");return;}
|
||||||
|
memcpy(tMethod->mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, DIM*4);
|
||||||
|
memcpy(tMethod->mutable_input(1).toTensor().mutable_data_ptr<float>(), mask, T_KV_LEN*4);
|
||||||
|
memcpy(tMethod->mutable_input(2).toTensor().mutable_data_ptr<float>(), tCos.data()+pi*T_HD, T_HD*4);
|
||||||
|
memcpy(tMethod->mutable_input(3).toTensor().mutable_data_ptr<float>(), tSin.data()+pi*T_HD, T_HD*4);
|
||||||
|
for(int i=0;i<T_L;i++){
|
||||||
|
memcpy(tMethod->mutable_input(4+i*2).toTensor().mutable_data_ptr<float>(), tK.data()+i*tkvElem, tkvElem*4);
|
||||||
|
memcpy(tMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr<float>(), tV.data()+i*tkvElem, tkvElem*4);
|
||||||
|
}
|
||||||
|
auto status = tMethod->execute();
|
||||||
|
if(status!=Error::Ok){ET_LOG(Error,"Talker exec fail: %d",(int)status);return;}
|
||||||
|
memcpy(hidden, tMethod->get_output(0).toTensor().const_data_ptr<float>(), DIM*4);
|
||||||
|
memcpy(logits, tMethod->get_output(1).toTensor().const_data_ptr<float>(), VOCAB*4);
|
||||||
|
for(int i=0;i<T_L;i++){
|
||||||
|
memcpy(tK.data()+i*tkvElem, tMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
|
||||||
|
memcpy(tV.data()+i*tkvElem, tMethod->get_output(3+i*2).toTensor().const_data_ptr<float>(), tkvElem*4);
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
};
|
||||||
|
|
||||||
|
// CP step: 17 autoregressive steps using Method directly
|
||||||
|
auto cpStep = [&](const float* h, int cb0, int* codes) {
|
||||||
|
int ckvElem=C_KV*C_KV_LEN*C_HD;
|
||||||
|
std::vector<float> ckv(C_L*2*ckvElem,0);
|
||||||
|
for(int step=0;step<17;step++){
|
||||||
|
const float*emb;
|
||||||
|
if(step==0) emb=h;
|
||||||
|
else if(step==1) emb=codecEmb.data()+std::min(std::max(cb0,0),VOCAB-1)*DIM;
|
||||||
|
else emb=cpEmbs.data()+((long)(step-2)*CB_SIZE+std::min(std::max(codes[step-2],0),CB_SIZE-1))*DIM;
|
||||||
|
|
||||||
|
auto prep=executorch::extension::prepare_input_tensors(*cMethod);
|
||||||
|
if(!prep.ok()) break;
|
||||||
|
memcpy(cMethod->mutable_input(0).toTensor().mutable_data_ptr<float>(), emb, DIM*4);
|
||||||
|
float*mp=cMethod->mutable_input(1).toTensor().mutable_data_ptr<float>();
|
||||||
|
for(int p=0;p<C_KV_LEN;p++) mp[p]=(p>=C_KV_LEN-1-step)?0.0f:-1e9f;
|
||||||
|
memcpy(cMethod->mutable_input(2).toTensor().mutable_data_ptr<float>(), cCos.data()+step*C_HD, C_HD*4);
|
||||||
|
memcpy(cMethod->mutable_input(3).toTensor().mutable_data_ptr<float>(), cSin.data()+step*C_HD, C_HD*4);
|
||||||
|
for(int i=0;i<C_L;i++){
|
||||||
|
memcpy(cMethod->mutable_input(4+i*2).toTensor().mutable_data_ptr<float>(), ckv.data()+(i*2)*ckvElem, ckvElem*4);
|
||||||
|
memcpy(cMethod->mutable_input(5+i*2).toTensor().mutable_data_ptr<float>(), ckv.data()+(i*2+1)*ckvElem, ckvElem*4);
|
||||||
|
}
|
||||||
|
auto status=cMethod->execute();
|
||||||
|
if(status!=Error::Ok) break;
|
||||||
|
const float*ho=cMethod->get_output(0).toTensor().const_data_ptr<float>();
|
||||||
|
if(step>=1&&step-1<15){
|
||||||
|
const float*W=cpHeads.data()+(long)(step-1)*CB_SIZE*DIM;
|
||||||
|
int best=0;float bv=-FLT_MAX;
|
||||||
|
for(int j=0;j<CB_SIZE;j++){float d=tts_dot_neon(ho,W+j*DIM,DIM);if(d>bv){bv=d;best=j;}}
|
||||||
|
codes[step-1]=best;
|
||||||
|
}
|
||||||
|
for(int i=0;i<C_L;i++){
|
||||||
|
memcpy(ckv.data()+(i*2)*ckvElem,cMethod->get_output(1+i*2).toTensor().const_data_ptr<float>(),ckvElem*4);
|
||||||
|
memcpy(ckv.data()+(i*2+1)*ckvElem,cMethod->get_output(2+i*2).toTensor().const_data_ptr<float>(),ckvElem*4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto t0=std::chrono::high_resolution_clock::now();
|
||||||
|
// Prefill
|
||||||
|
for(int s=0;s<nPrefill;s++){
|
||||||
|
talkerStep(prefill.data()+s*DIM);
|
||||||
|
if(s==nPrefill-1){
|
||||||
|
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
||||||
|
curCb0=tts_sample_topk(logits,VOCAB,0.9f,50);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto tP=std::chrono::high_resolution_clock::now();
|
||||||
|
ET_LOG(Info,"TTS Prefill: %.0fms, cb0=%d",std::chrono::duration<float,std::milli>(tP-t0).count(),curCb0);
|
||||||
|
|
||||||
|
if(curCb0<0||curCb0==CODEC_EOS) return facebook::jni::JArrayInt::newArray(0);
|
||||||
|
|
||||||
|
// Generation
|
||||||
|
float totalTalker=0,totalCp=0;
|
||||||
|
for(int g=0;g<maxTokens;g++){
|
||||||
|
int codes[NUM_CB]={}; codes[0]=curCb0;
|
||||||
|
auto tc0=std::chrono::high_resolution_clock::now();
|
||||||
|
int cpCodes[15]={};
|
||||||
|
cpStep(hidden,curCb0,cpCodes);
|
||||||
|
auto tc1=std::chrono::high_resolution_clock::now();
|
||||||
|
totalCp+=std::chrono::duration<float,std::milli>(tc1-tc0).count();
|
||||||
|
for(int i=0;i<15;i++) codes[i+1]=cpCodes[i];
|
||||||
|
for(int i=0;i<NUM_CB;i++) allCodes.push_back(codes[i]);
|
||||||
|
cb0Hist.push_back(curCb0);
|
||||||
|
|
||||||
|
// Next embed: pre-computed trailing OR codec_sum + eos/pad
|
||||||
|
float nextEmb[DIM]={};
|
||||||
|
if(trIdx<nTrailing){
|
||||||
|
memcpy(nextEmb,trailing.data()+trIdx*DIM,DIM*4);
|
||||||
|
trIdx++;
|
||||||
|
} else {
|
||||||
|
const float*e0=codecEmb.data()+std::min(std::max(codes[0],0),VOCAB-1)*DIM;
|
||||||
|
for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k];
|
||||||
|
for(int cb=0;cb<15;cb++){
|
||||||
|
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
|
||||||
|
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
|
||||||
|
}
|
||||||
|
if(trIdx==nTrailing){for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmb[k];trIdx++;}
|
||||||
|
else {for(int k=0;k<DIM;k++) nextEmb[k]+=padEmb[k];}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tt0=std::chrono::high_resolution_clock::now();
|
||||||
|
talkerStep(nextEmb);
|
||||||
|
auto tt1=std::chrono::high_resolution_clock::now();
|
||||||
|
totalTalker+=std::chrono::duration<float,std::milli>(tt1-tt0).count();
|
||||||
|
|
||||||
|
for(int j=CB_SIZE;j<VOCAB;j++) if(j!=CODEC_EOS) logits[j]=-FLT_MAX;
|
||||||
|
std::unordered_set<int> seen(cb0Hist.begin(),cb0Hist.end());
|
||||||
|
for(int tok:seen) logits[tok]=(logits[tok]>0)?logits[tok]/1.05f:logits[tok]*1.05f;
|
||||||
|
int next=tts_sample_topk(logits,VOCAB,0.9f,50);
|
||||||
|
if(next==CODEC_EOS){ET_LOG(Info,"TTS EOS at %d",g+2);break;}
|
||||||
|
if((int)cb0Hist.size()>=9){
|
||||||
|
bool deg=true;for(int i=(int)cb0Hist.size()-9;i<(int)cb0Hist.size();i++)if(cb0Hist[i]!=next){deg=false;break;}
|
||||||
|
if(deg){ET_LOG(Info,"TTS Degen at %d",g+2);break;}
|
||||||
|
}
|
||||||
|
curCb0=next;
|
||||||
|
}
|
||||||
|
int nTok=(int)allCodes.size()/NUM_CB;
|
||||||
|
auto t1=std::chrono::high_resolution_clock::now();
|
||||||
|
ET_LOG(Info,"TTS Generated %d | Talker %.0fms (%.0f/step) | CP %.0fms (%.0f/step) | Total %.0fms",
|
||||||
|
nTok,totalTalker,totalTalker/std::max(nTok,1),totalCp,totalCp/std::max(nTok,1),
|
||||||
|
std::chrono::duration<float,std::milli>(t1-t0).count());
|
||||||
|
|
||||||
|
auto result=facebook::jni::JArrayInt::newArray((int)allCodes.size());
|
||||||
|
result->setRegion(0,(int)allCodes.size(),allCodes.data());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} // namespace executorch::extension
|
||||||
|
|
||||||
|
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
|
||||||
|
extern void register_natives_for_llm();
|
||||||
|
#else
|
||||||
|
// No op if we don't build LLM
|
||||||
|
void register_natives_for_llm() {}
|
||||||
|
#endif
|
||||||
|
extern void register_natives_for_runtime();
|
||||||
|
|
||||||
|
#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
|
||||||
|
extern void register_natives_for_training();
|
||||||
|
#else
|
||||||
|
// No op if we don't build training JNI
|
||||||
|
void register_natives_for_training() {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||||
|
return facebook::jni::initialize(vm, [] {
|
||||||
|
executorch::extension::ExecuTorchJni::registerNatives();
|
||||||
|
register_natives_for_llm();
|
||||||
|
register_natives_for_runtime();
|
||||||
|
register_natives_for_training();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
@ -229,29 +229,11 @@ class Qwen3TtsEngine(
|
||||||
// Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths)
|
// Set ADSP library path for QNN HTP skel libs (needed by both Java and C++ paths)
|
||||||
android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true)
|
android.system.Os.setenv("ADSP_LIBRARY_PATH", "$nativeLibDir;/data/local/tmp/kazeia/qnn_libs;/vendor/dsp/cdsp;/vendor/dsp", true)
|
||||||
|
|
||||||
// Try native C++ pipeline first (single QNN instance, no Java overhead)
|
// Load .pte modules via Java JNI (native pipeline uses same instances)
|
||||||
run {
|
run {
|
||||||
val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
|
val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
|
||||||
val talkerPte = File("/data/local/tmp/kazeia/models/talker_transformer_fp16.pte")
|
if (etModel.exists() && cpPteModule == null) {
|
||||||
if (etModel.exists() && talkerPte.exists()) {
|
|
||||||
try {
|
try {
|
||||||
val tn = System.currentTimeMillis()
|
|
||||||
nativePipelineReady = TtsPipeline.nativeInit(
|
|
||||||
talkerPte.absolutePath, etModel.absolutePath
|
|
||||||
)
|
|
||||||
nlog("Native C++ pipeline: ${if (nativePipelineReady) "OK" else "FAILED"} (${System.currentTimeMillis() - tn}ms)")
|
|
||||||
} catch (e: Exception) {
|
|
||||||
nlog("Native pipeline init failed: ${e.message}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: Load Java .pte modules (only if native pipeline failed)
|
|
||||||
run {
|
|
||||||
val etModel = File("/data/local/tmp/kazeia/models/cp_transformer_fp16.pte")
|
|
||||||
if (!nativePipelineReady && etModel.exists() && cpPteModule == null) {
|
|
||||||
try {
|
|
||||||
nlog("Loading Java .pte modules (native unavailable)...")
|
|
||||||
val t0 = System.currentTimeMillis()
|
val t0 = System.currentTimeMillis()
|
||||||
cpPteModule = org.pytorch.executorch.Module.load(
|
cpPteModule = org.pytorch.executorch.Module.load(
|
||||||
etModel.absolutePath,
|
etModel.absolutePath,
|
||||||
|
|
@ -265,6 +247,10 @@ class Qwen3TtsEngine(
|
||||||
if (lmResult != 0) {
|
if (lmResult != 0) {
|
||||||
nlog("CP .pte loadMethod failed ($lmResult), disabling JNI")
|
nlog("CP .pte loadMethod failed ($lmResult), disabling JNI")
|
||||||
cpPteModule = null
|
cpPteModule = null
|
||||||
|
} else {
|
||||||
|
// Register CP module for native pipeline
|
||||||
|
cpPteModule!!.nativeSetCpModule()
|
||||||
|
nlog("CP module registered for native pipeline")
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
nlog("CP .pte JNI failed: ${e.message}")
|
nlog("CP .pte JNI failed: ${e.message}")
|
||||||
|
|
@ -2298,11 +2284,9 @@ class Qwen3TtsEngine(
|
||||||
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
|
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
|
||||||
|
|
||||||
val allCodes: Array<IntArray>
|
val allCodes: Array<IntArray>
|
||||||
// Native C++ disabled: QNN HTP compilation not deterministic between loads
|
// Native C++ pipeline using SAME Java Module instances (no quality loss)
|
||||||
// Two instances of same .pte give slightly different hidden states → trembling
|
if (talkerPteModule != null && cpPteModule != null) {
|
||||||
// Keep Java pipeline (same QNN instance, validated quality)
|
// C++ loop on Java's Module instances — same QNN compilation, no JNI overhead
|
||||||
if (false && nativePipelineReady) {
|
|
||||||
// Native C++ pipeline — zero Java overhead
|
|
||||||
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
|
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
|
||||||
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
|
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
|
||||||
val nTrailing = nTotal - nPrefill
|
val nTrailing = nTotal - nPrefill
|
||||||
|
|
@ -2336,8 +2320,8 @@ class Qwen3TtsEngine(
|
||||||
ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
|
ttsPadEmbed = sp.sliceArray(2 * TALKER_DIM until 3 * TALKER_DIM)
|
||||||
}
|
}
|
||||||
|
|
||||||
nlog("Running native C++ pipeline...")
|
nlog("Running native C++ pipeline (shared Module)...")
|
||||||
val flat = TtsPipeline.nativeRun(
|
val flat = talkerPteModule!!.nativeRunTtsPipeline(
|
||||||
prefillFlat, nPrefill,
|
prefillFlat, nPrefill,
|
||||||
trailingFlat, nTrailing,
|
trailingFlat, nTrailing,
|
||||||
codecEmbedding ?: FloatArray(0),
|
codecEmbedding ?: FloatArray(0),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue