LLM JNI: auto-detect eval_mode from .pte methods (kv-only vs hybrid)

Replace the hardcoded eval_mode=0 in the QNN_LLAMA branch with a runtime
check on the loaded module's method names: if the .pte exposes a
prefill_forward graph, switch to EvalMode::kHybrid (1) — the runner can
then batch the entire prompt through prefill_forward in one parallel pass
instead of running 52 ms/token sequentially through kv_forward. Falls
back to kKVCached (0) when only kv_forward exists, matching the current
.pte behaviour exactly so this is a safe in-place upgrade ahead of the
hybrid re-export.

Sanity-tested with the kv-only Qwen3-4B .pte already on the tablet:
  Prompt 'Bonjour, ça va ?' → "Bonjour ! Ça va, merci de me demander ça.
  Tu as une question ?", TTFT 2728 ms, total 4158 ms — no change vs the
  hardcoded eval_mode=0 build.

Once the hybrid Qwen3-4B export finishes (~50 min compile, both
prefill_forward + kv_forward graphs), the same JNI binary will pick up
the new .pte and TTFT should drop to <1 s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-14 12:45:10 +02:00
parent 3d435f9cdd
commit f4b15a72a7
1 changed files with 12 additions and 3 deletions

View File

@ -14,10 +14,10 @@ index e93731e..4951e1d 100644
${EXECUTORCH_SOURCE_DIR}/third-party/pybind11
${CMAKE_CURRENT_BINARY_DIR}/pybind11
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 45f2414..7c4e1aa 100644
index 45f2414..ae3d79f 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -171,14 +171,35 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
@@ -171,14 +171,44 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
model_path->toStdString().c_str(),
data_files_vector,
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
@ -40,6 +40,15 @@ index 45f2414..7c4e1aa 100644
+ kv_bitwidth = static_cast<example::KvBitWidth>(
+ module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+ }
+ // Auto-detect eval_mode: kv-only (0) if the .pte only carries
+ // kv_forward, hybrid (1) if it also has prefill_forward (which lets the
+ // runner batch the prompt prefill — TTFT drops from ~52 ms/token to
+ // sub-ms after the one-shot prefill graph). Same JNI binary works with
+ // both export modes, no code change needed when the .pte is upgraded.
+ int eval_mode = 0;
+ if (module->method_names()->count("prefill_forward") > 0) {
+ eval_mode = 1; // EvalMode::kHybrid
+ }
+ auto make_runner = [&](auto sample) -> std::unique_ptr<llm::IRunner> {
+ using T = decltype(sample);
+ return std::make_unique<example::Runner<T>>(
@ -50,7 +59,7 @@ index 45f2414..7c4e1aa 100644
+ /* performance_output_path */ "",
+ /* dump_logits_path */ "",
+ /* temperature */ 0.0f, // greedy
+ /* eval_mode */ 0, // EvalMode::kKVCached
+ eval_mode,
+ /* shared_buffer */ true);
+ };
+ if (kv_bitwidth == example::KvBitWidth::kWidth16) {