/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ package org.pytorch.executorch; import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; /** * Java wrapper for ExecuTorch Module. * *

Warning: These APIs are experimental and subject to change without notice */ @Experimental public class Module { static { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } // Loads libexecutorch.so from jniLibs NativeLoader.loadLibrary("executorch"); } /** Load mode for the module. Load the whole file as a buffer. */ public static final int LOAD_MODE_FILE = 0; /** Load mode for the module. Use mmap to load pages into memory. */ public static final int LOAD_MODE_MMAP = 1; /** Load mode for the module. Use memory locking and handle errors. */ public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; /** Load mode for the module. Use memory locking and ignore errors. */ public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; private final HybridData mHybridData; private final Map mMethodMetadata; @DoNotStrip private static native HybridData initHybrid( String moduleAbsolutePath, int loadMode, int numThreads); private Module(String moduleAbsolutePath, int loadMode, int numThreads) { ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); mMethodMetadata = populateMethodMeta(); } private Map populateMethodMeta() { String[] methods = getMethods(); Map metadata = new HashMap(); for (String name : methods) { metadata.put(name, new MethodMetadata(name, getUsedBackends(name))); } return metadata; } /** Lock protecting the non-thread safe methods in mHybridData. */ private Lock mLock = new ReentrantLock(); /** * Loads a serialized ExecuTorch module from the specified path on the disk. * * @param modelPath path to file that contains the serialized ExecuTorch module. * @param loadMode load mode for the module. See constants in {@link Module}. * @return new {@link org.pytorch.executorch.Module} object which owns the model module. */ public static Module load(final String modelPath, int loadMode) { return load(modelPath, loadMode, 0); } /** * Loads a serialized ExecuTorch module from the specified path on the disk. * * @param modelPath path to file that contains the serialized ExecuTorch module. * @param loadMode load mode for the module. See constants in {@link Module}. * @param numThreads the number of threads to use for inference. A value of 0 defaults to a * hardware-specific default. * @return new {@link org.pytorch.executorch.Module} object which owns the model module. */ public static Module load(final String modelPath, int loadMode, int numThreads) { ExecuTorchRuntime.validateFilePath(modelPath, "model path"); return new Module(modelPath, loadMode, numThreads); } /** * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU. * * @param modelPath path to file that contains the serialized ExecuTorch module. * @return new {@link org.pytorch.executorch.Module} object which owns the model module. */ public static Module load(final String modelPath) { return load(modelPath, LOAD_MODE_FILE); } /** * Runs the 'forward' method of this module with the specified arguments. * * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' * requires inputs but no inputs are given, the function will not error out, but run 'forward' * with sample inputs. * @return return value from the 'forward' method. */ public EValue[] forward(EValue... inputs) { return execute("forward", inputs); } /** * Runs the specified method of this module with the specified arguments. * * @param methodName name of the ExecuTorch method to run. * @param inputs arguments that will be passed to ExecuTorch method. * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { try { mLock.lock(); if (!mHybridData.isValid()) { Log.e("ExecuTorch", "Attempt to use a destroyed module"); return new EValue[0]; } return executeNative(methodName, inputs); } finally { mLock.unlock(); } } @DoNotStrip private native EValue[] executeNative(String methodName, EValue... inputs); /** * Load a method on this module. This might help with the first time inference performance, * because otherwise the method is loaded lazily when it's execute. Note: this function is * synchronous, and will block until the method is loaded. Therefore, it is recommended to call * this on a background thread. However, users need to make sure that they don't execute before * this function returns. * * @return the Error code if there was an error loading the method */ public int loadMethod(String methodName) { try { mLock.lock(); if (!mHybridData.isValid()) { Log.e("ExecuTorch", "Attempt to use a destroyed module"); return 0x2; // InvalidState } return loadMethodNative(methodName); } finally { mLock.unlock(); } } @DoNotStrip private native int loadMethodNative(String methodName); /** * Returns the names of the backends in a certain method. * * @param methodName method name to query * @return an array of backend name */ @DoNotStrip private native String[] getUsedBackends(String methodName); /** * Returns the names of methods. * * @return name of methods in this Module */ @DoNotStrip public native String[] getMethods(); /** * Get the corresponding @MethodMetadata for a method * * @param name method name * @return @MethodMetadata for this method */ public MethodMetadata getMethodMetadata(String name) { MethodMetadata methodMetadata = mMethodMetadata.get(name); if (methodMetadata == null) { throw new IllegalArgumentException("method " + name + " does not exist for this module"); } return methodMetadata; } @DoNotStrip private static native String[] readLogBufferStaticNative(); public static String[] readLogBufferStatic() { return readLogBufferStaticNative(); } /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { return readLogBufferNative(); } @DoNotStrip private native String[] readLogBufferNative(); /** * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. * *

Currently for internal (minibench) use only. * * @return true if the etdump was successfully written, false otherwise. */ @Experimental @DoNotStrip public native boolean etdump(); /** TTS Pipeline: set this module as the CP (code predictor) for the pipeline. */ @DoNotStrip public native void nativeSetCpModule(); /** TTS Pipeline: run full talker+CP loop in C++. 'this' is the talker module. */ @DoNotStrip public native int[] nativeRunTtsPipeline( float[] prefill, int nPrefill, float[] trailing, int nTrailing, float[] codecEmb, float[] cpEmbs, float[] cpHeads, float[] talkerCos, float[] talkerSin, float[] cpCos, float[] cpSin, float[] eosEmbed, float[] padEmbed, int maxTokens); /** * Explicitly destroys the native Module object. Calling this method is not required, as the * native object will be destroyed when this object is garbage-collected. However, the timing of * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. */ public void destroy() { if (mLock.tryLock()) { try { mHybridData.resetNative(); } finally { mLock.unlock(); } } else { Log.w( "ExecuTorch", "Destroy was called while the module was in use. Resources will not be immediately" + " released."); } } }