263 lines
8.6 KiB
Java
263 lines
8.6 KiB
Java
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD-style license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
package org.pytorch.executorch;
|
|
|
|
import android.util.Log;
|
|
import com.facebook.jni.HybridData;
|
|
import com.facebook.jni.annotations.DoNotStrip;
|
|
import com.facebook.soloader.nativeloader.NativeLoader;
|
|
import com.facebook.soloader.nativeloader.SystemDelegate;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
import java.util.concurrent.locks.Lock;
|
|
import java.util.concurrent.locks.ReentrantLock;
|
|
import org.pytorch.executorch.annotations.Experimental;
|
|
|
|
/**
|
|
* Java wrapper for ExecuTorch Module.
|
|
*
|
|
* <p>Warning: These APIs are experimental and subject to change without notice.
|
|
*/
|
|
@Experimental
|
|
public class Module {
|
|
|
|
static {
|
|
if (!NativeLoader.isInitialized()) {
|
|
NativeLoader.init(new SystemDelegate());
|
|
}
|
|
// Loads libexecutorch.so from jniLibs
|
|
NativeLoader.loadLibrary("executorch");
|
|
}
|
|
|
|
/** Load mode for the module. Load the whole file as a buffer. */
|
|
public static final int LOAD_MODE_FILE = 0;
|
|
|
|
/** Load mode for the module. Use mmap to load pages into memory. */
|
|
public static final int LOAD_MODE_MMAP = 1;
|
|
|
|
/** Load mode for the module. Use memory locking and handle errors. */
|
|
public static final int LOAD_MODE_MMAP_USE_MLOCK = 2;
|
|
|
|
/** Load mode for the module. Use memory locking and ignore errors. */
|
|
public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
|
|
|
|
private final HybridData mHybridData;
|
|
|
|
private final Map<String, MethodMetadata> mMethodMetadata;
|
|
|
|
@DoNotStrip
|
|
private static native HybridData initHybrid(
|
|
String moduleAbsolutePath, int loadMode, int numThreads);
|
|
|
|
/**
 * Creates the native module and caches the metadata of each of its methods.
 *
 * @param moduleAbsolutePath path of the serialized module, forwarded to the native initializer.
 * @param loadMode one of the {@code LOAD_MODE_*} constants of this class.
 * @param numThreads thread count forwarded to the native initializer.
 */
private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
  // Invoked for its side effects only; the returned runtime object is not used here.
  // NOTE(review): presumably this makes sure the runtime is set up before the native peer is
  // created -- confirm against ExecuTorchRuntime.
  ExecuTorchRuntime.getRuntime();

  mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);

  // populateMethodMeta() calls native instance methods, so keep it after the peer is created.
  mMethodMetadata = populateMethodMeta();
}
|
|
|
|
private Map<String, MethodMetadata> populateMethodMeta() {
|
|
String[] methods = getMethods();
|
|
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
|
|
for (String name : methods) {
|
|
metadata.put(name, new MethodMetadata(name, getUsedBackends(name)));
|
|
}
|
|
return metadata;
|
|
}
|
|
|
|
/** Lock protecting the non-thread safe methods in mHybridData. */
|
|
private Lock mLock = new ReentrantLock();
|
|
|
|
/**
|
|
* Loads a serialized ExecuTorch module from the specified path on the disk.
|
|
*
|
|
* @param modelPath path to file that contains the serialized ExecuTorch module.
|
|
* @param loadMode load mode for the module. See constants in {@link Module}.
|
|
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
|
|
*/
|
|
public static Module load(final String modelPath, int loadMode) {
|
|
return load(modelPath, loadMode, 0);
|
|
}
|
|
|
|
/**
|
|
* Loads a serialized ExecuTorch module from the specified path on the disk.
|
|
*
|
|
* @param modelPath path to file that contains the serialized ExecuTorch module.
|
|
* @param loadMode load mode for the module. See constants in {@link Module}.
|
|
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
|
|
* hardware-specific default.
|
|
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
|
|
*/
|
|
public static Module load(final String modelPath, int loadMode, int numThreads) {
|
|
ExecuTorchRuntime.validateFilePath(modelPath, "model path");
|
|
return new Module(modelPath, loadMode, numThreads);
|
|
}
|
|
|
|
/**
|
|
* Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU.
|
|
*
|
|
* @param modelPath path to file that contains the serialized ExecuTorch module.
|
|
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
|
|
*/
|
|
public static Module load(final String modelPath) {
|
|
return load(modelPath, LOAD_MODE_FILE);
|
|
}
|
|
|
|
/**
|
|
* Runs the 'forward' method of this module with the specified arguments.
|
|
*
|
|
* @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward'
|
|
* requires inputs but no inputs are given, the function will not error out, but run 'forward'
|
|
* with sample inputs.
|
|
* @return return value from the 'forward' method.
|
|
*/
|
|
public EValue[] forward(EValue... inputs) {
|
|
return execute("forward", inputs);
|
|
}
|
|
|
|
/**
|
|
* Runs the specified method of this module with the specified arguments.
|
|
*
|
|
* @param methodName name of the ExecuTorch method to run.
|
|
* @param inputs arguments that will be passed to ExecuTorch method.
|
|
* @return return value from the method.
|
|
*/
|
|
public EValue[] execute(String methodName, EValue... inputs) {
|
|
try {
|
|
mLock.lock();
|
|
if (!mHybridData.isValid()) {
|
|
Log.e("ExecuTorch", "Attempt to use a destroyed module");
|
|
return new EValue[0];
|
|
}
|
|
return executeNative(methodName, inputs);
|
|
} finally {
|
|
mLock.unlock();
|
|
}
|
|
}
|
|
|
|
@DoNotStrip
|
|
private native EValue[] executeNative(String methodName, EValue... inputs);
|
|
|
|
/**
|
|
* Load a method on this module. This might help with the first time inference performance,
|
|
* because otherwise the method is loaded lazily when it's execute. Note: this function is
|
|
* synchronous, and will block until the method is loaded. Therefore, it is recommended to call
|
|
* this on a background thread. However, users need to make sure that they don't execute before
|
|
* this function returns.
|
|
*
|
|
* @return the Error code if there was an error loading the method
|
|
*/
|
|
public int loadMethod(String methodName) {
|
|
try {
|
|
mLock.lock();
|
|
if (!mHybridData.isValid()) {
|
|
Log.e("ExecuTorch", "Attempt to use a destroyed module");
|
|
return 0x2; // InvalidState
|
|
}
|
|
return loadMethodNative(methodName);
|
|
} finally {
|
|
mLock.unlock();
|
|
}
|
|
}
|
|
|
|
@DoNotStrip
|
|
private native int loadMethodNative(String methodName);
|
|
|
|
/**
|
|
* Returns the names of the backends in a certain method.
|
|
*
|
|
* @param methodName method name to query
|
|
* @return an array of backend name
|
|
*/
|
|
@DoNotStrip
|
|
private native String[] getUsedBackends(String methodName);
|
|
|
|
/**
|
|
* Returns the names of methods.
|
|
*
|
|
* @return name of methods in this Module
|
|
*/
|
|
@DoNotStrip
|
|
public native String[] getMethods();
|
|
|
|
/**
|
|
* Get the corresponding @MethodMetadata for a method
|
|
*
|
|
* @param name method name
|
|
* @return @MethodMetadata for this method
|
|
*/
|
|
public MethodMetadata getMethodMetadata(String name) {
|
|
MethodMetadata methodMetadata = mMethodMetadata.get(name);
|
|
if (methodMetadata == null) {
|
|
throw new IllegalArgumentException("method " + name + " does not exist for this module");
|
|
}
|
|
return methodMetadata;
|
|
}
|
|
|
|
@DoNotStrip
|
|
private static native String[] readLogBufferStaticNative();
|
|
|
|
public static String[] readLogBufferStatic() {
|
|
return readLogBufferStaticNative();
|
|
}
|
|
|
|
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
|
|
public String[] readLogBuffer() {
|
|
return readLogBufferNative();
|
|
}
|
|
|
|
@DoNotStrip
|
|
private native String[] readLogBufferNative();
|
|
|
|
/**
|
|
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
|
|
*
|
|
* <p>Currently for internal (minibench) use only.
|
|
*
|
|
* @return true if the etdump was successfully written, false otherwise.
|
|
*/
|
|
@Experimental
|
|
@DoNotStrip
|
|
public native boolean etdump();
|
|
|
|
/** TTS Pipeline: set this module as the CP (code predictor) for the pipeline. */
|
|
@DoNotStrip
|
|
public native void nativeSetCpModule();
|
|
|
|
/** TTS Pipeline: run full talker+CP loop in C++. 'this' is the talker module. */
|
|
@DoNotStrip
|
|
public native int[] nativeRunTtsPipeline(
|
|
float[] prefill, int nPrefill, float[] trailing, int nTrailing,
|
|
float[] codecEmb, float[] cpEmbs, float[] cpHeads,
|
|
float[] talkerCos, float[] talkerSin, float[] cpCos, float[] cpSin,
|
|
float[] eosEmbed, float[] padEmbed, int maxTokens);
|
|
|
|
/**
|
|
* Explicitly destroys the native Module object. Calling this method is not required, as the
|
|
* native object will be destroyed when this object is garbage-collected. However, the timing of
|
|
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
|
|
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
|
|
*/
|
|
public void destroy() {
|
|
if (mLock.tryLock()) {
|
|
try {
|
|
mHybridData.resetNative();
|
|
} finally {
|
|
mLock.unlock();
|
|
}
|
|
} else {
|
|
Log.w(
|
|
"ExecuTorch",
|
|
"Destroy was called while the module was in use. Resources will not be immediately"
|
|
+ " released.");
|
|
}
|
|
}
|
|
}
|