kazeia/executorch-patches/qwen3_4b_decoder.patch

73 lines
3.0 KiB
Diff
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py
index 963db6e..9ccfdd0 100644
--- a/examples/qualcomm/oss_scripts/llama/__init__.py
+++ b/examples/qualcomm/oss_scripts/llama/__init__.py
@@ -25,9 +25,14 @@ from executorch.examples.models.granite import (
from executorch.examples.models.internvl3 import (
convert_weights as convert_internvl3_weights,
)
-from executorch.examples.models.phi_4_mini import (
- convert_weights as convert_phi_4_mini_weights,
-)
+try:
+ from executorch.examples.models.phi_4_mini import (
+ convert_weights as convert_phi_4_mini_weights,
+ )
+except ImportError:
+ # phi_4_mini pulls in torchtune which conflicts with our torchao pin.
+ # We don't need phi for Qwen3 export, so tolerate the missing dep.
+ convert_phi_4_mini_weights = None
from executorch.examples.models.qwen2_5 import (
convert_weights as convert_qwen2_5_weights,
)
@@ -479,6 +484,37 @@ class Qwen3_1_7B(LLMModelConfig):
quant_recipe = Qwen3_1_7BQuantRecipe
+@register_llm_model("qwen3-4b")
+@dataclass(init=False, frozen=True)
+class Qwen3_4B(LLMModelConfig):
+ # Local Kazeia addition. Mirrors the Qwen3_1_7B registration: the 4B
+ # variant reuses convert_qwen3_weights and Qwen3_1_7BQuantRecipe (16a4w)
+ # but points at its own 4b_config.json params file. With 4B params at
+ # 16a4w the .pte stays under the 4 GB HTP single-context limit on V79
+ # (empirically ~2.5 GB), so num_sharding=1 is fine. Compile time on the
+ # host is the main cost (3-4 h on a 16-core x86_64 machine).
+ repo_id: str = "Qwen/Qwen3-4B"
+ params_path: str = os.path.join(
+ BASE_DIR, "../../../models/qwen3/config/4b_config.json"
+ )
+ convert_weights = convert_qwen3_weights
+ transform_weight = False
+ instruct_model = True
+ # num_sharding=1 for hybrid mode: sharding=2 produces a multi-context
+ # .pte (2 graphs x 2 shards = 4 contexts) that the LlmModule load path
+ # can't restore (error 5010 "Context group 1 does not exist"). With
+ # sharding=1 the hybrid export needs ~46 GB RAM peak; the 192 GB swap
+ # on /swapfile handles this, and compile takes ~80 min wall but completes
+ # cleanly. Single-context .pte loads fine through the JNI runner.
+ num_sharding = 1
+ masked_softmax = True
+ seq_mse_candidates = 0
+ r1 = False
+ r2 = False
+ r3 = True
+ quant_recipe = Qwen3_1_7BQuantRecipe
+
+
@register_llm_model("smollm2_135m")
@dataclass(init=False, frozen=True)
class Smollm2_135M(LLMModelConfig):
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py
index 74e3959..995c498 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py
@@ -55,6 +55,7 @@ DECODER_MODEL_VERSION = {
"qwen2_5-1_5b": "qwen2_5",
"qwen3-0_6b": "qwen3",
"qwen3-1_7b": "qwen3",
+ "qwen3-4b": "qwen3",
"smollm2_135m": "smollm2_135m",
"smollm3-3b": "smollm3",
"glm-1_5b": "glm",