diff --git a/executorch-patches/llm_in_process_jni.patch b/executorch-patches/llm_in_process_jni.patch index 357c7cd..05fd12c 100644 --- a/executorch-patches/llm_in_process_jni.patch +++ b/executorch-patches/llm_in_process_jni.patch @@ -14,10 +14,10 @@ index e93731e..4951e1d 100644 ${EXECUTORCH_SOURCE_DIR}/third-party/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11 diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp -index 45f2414..7c4e1aa 100644 +index 45f2414..ae3d79f 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp -@@ -171,14 +171,35 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { +@@ -171,14 +171,44 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { model_path->toStdString().c_str(), data_files_vector, executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); @@ -40,6 +40,15 @@ index 45f2414..7c4e1aa 100644 + kv_bitwidth = static_cast( + module->get("get_kv_io_bit_width").get().toScalar().to()); + } ++ // Auto-detect eval_mode: kv-only (0) if the .pte only carries ++ // kv_forward, hybrid (1) if it also has prefill_forward (which lets the ++ // runner batch the prompt prefill — TTFT drops from ~52 ms/token to ++ // sub-ms after the one-shot prefill graph). Same JNI binary works with ++ // both export modes, no code change needed when the .pte is upgraded. ++ int eval_mode = 0; ++ if (module->method_names()->count("prefill_forward") > 0) { ++ eval_mode = 1; // EvalMode::kHybrid ++ } + auto make_runner = [&](auto sample) -> std::unique_ptr { + using T = decltype(sample); + return std::make_unique>( @@ -50,7 +59,7 @@ index 45f2414..7c4e1aa 100644 + /* performance_output_path */ "", + /* dump_logits_path */ "", + /* temperature */ 0.0f, // greedy -+ /* eval_mode */ 0, // EvalMode::kKVCached ++ eval_mode, + /* shared_buffer */ true); + }; + if (kv_bitwidth == example::KvBitWidth::kWidth16) {