kazeia/executorch-patches/llm_in_process_jni.patch

64 lines
2.9 KiB
Diff

diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt
index e93731e..4951e1d 100644
--- a/backends/qualcomm/CMakeLists.txt
+++ b/backends/qualcomm/CMakeLists.txt
@@ -308,8 +308,8 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES Hexagon)
)
endif()

-# QNN pybind
-if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
+# QNN pybind: host-only Python bindings; skip for Android targets (incl. x86_64 ABI)
+if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "x86_64" AND NOT ANDROID)
add_subdirectory(
${EXECUTORCH_SOURCE_DIR}/third-party/pybind11
${CMAKE_CURRENT_BINARY_DIR}/pybind11
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 45f2414..7c4e1aa 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -171,14 +171,43 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
model_path->toStdString().c_str(),
data_files_vector,
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
- std::string decoder_model = "llama3"; // use llama3 for now
- runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
- std::move(module),
- decoder_model.c_str(),
- model_path->toStdString().c_str(),
- tokenizer_path->toStdString().c_str(),
- "",
- "");
+ std::string decoder_model = "qwen3"; // Kazeia: our .pte was exported with --decoder_model qwen3-4b
+
+ // Mirror qnn_llama_runner.cpp main(): pick the Runner<T> template based
+ // on the model's get_kv_io_bit_width metadata. The 16-bit KV I/O models
+ // were introduced after the 8-bit ones, and using the wrong T treats
+ // KV-cache bytes as the wrong width -> garbage logits -> gibberish output.
+ // Models without that method are assumed to be 8-bit; any other reported
+ // width aborts instead of silently being run as 8-bit.
+ example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+ if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+ kv_bitwidth = static_cast<example::KvBitWidth>(
+ module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+ }
+ // Builds the QNN runner for KV element type T; moves `module`, so call once.
+ auto make_runner = [&](auto sample) -> std::unique_ptr<llm::IRunner> {
+ using T = decltype(sample);
+ return std::make_unique<example::Runner<T>>(
+ std::move(module),
+ decoder_model.c_str(),
+ model_path->toStdString().c_str(),
+ tokenizer_path->toStdString().c_str(),
+ /* performance_output_path */ "",
+ /* dump_logits_path */ "",
+ /* temperature */ 0.0f, // greedy
+ /* eval_mode */ 0, // EvalMode::kKVCached
+ /* shared_buffer */ true);
+ };
+ if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+ runner_ = make_runner(uint8_t{0});
+ } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+ runner_ = make_runner(uint16_t{0});
+ } else {
+ ET_CHECK_MSG(
+ false,
+ "Unsupported KV I/O bit width: %lld",
+ static_cast<long long>(kv_bitwidth));
+ }
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
#endif
#if defined(EXECUTORCH_BUILD_MEDIATEK)