#!/usr/bin/env python3
"""
Export Qwen3-TTS talker transformer to ExecuTorch .pte on QNN HTP fp16.

28 layers, 1024 dim, GQA 16/8, M-RoPE, codec_head.
Fixed KV cache with shift (like CP export).

Run as a script to perform the export; importing the module only defines the
model classes and constants and has no side effects.
"""
import os
import sys
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model geometry.
N_L = 28                # number of transformer layers
N_H = 16                # query heads
N_KV = 8                # key/value heads (GQA)
HD = 128                # per-head dimension
DIM = 1024              # hidden size
N_REP = N_H // N_KV     # how many query heads share one KV head (= 2)
VOCAB = 3072            # output size of the codec head
FFN = 3072              # feed-forward inner size
KV_LEN = 100  # Must be >= prefill+maxGen. KV=64 caused quality loss (role tokens evicted)

CHECKPOINT = "/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth"
OUT = "/opt/Kazeia/models_qnn/talker_transformer_fp16.pte"


def rotate_half(x):
    """Return ``cat(-x2, x1)`` where ``x1``/``x2`` are the halves of the last dim."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(x, n):
    """Repeat each head of ``x`` ([B, H, T, D]) ``n`` times -> [B, H*n, T, D].

    Copies of the same head are adjacent in the output (h0, h0, h1, h1, ...).
    """
    B, H, T, D = x.shape
    return x[:, :, None, :, :].expand(B, H, n, T, D).reshape(B, H * n, T, D)


class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned scale (eps=1e-6)."""

    def __init__(self, d):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * self.weight


def _rms_norm(dim, weight):
    """Build an RMSNorm of size ``dim`` whose scale is ``weight``."""
    norm = RMSNorm(dim)
    norm.weight.data = weight
    return norm


def _linear(n_in, n_out, weight):
    """Build a bias-free Linear(n_in, n_out) whose weight is ``weight``."""
    layer = nn.Linear(n_in, n_out, bias=False)
    layer.weight.data = weight
    return layer


class TalkerTransformer(nn.Module):
    """Single-token talker transformer step, including the codec_head.

    Per-layer sub-modules are kept in parallel ModuleLists indexed by layer:
    ``na``/``nf`` (attention / FFN input norms), ``qp``/``kp``/``vp``/``op``
    (attention projections), ``qn``/``kn`` (per-head q/k norms) and
    ``ga``/``up``/``dn`` (FFN w1 / w3 / w2). ``fn`` is the final norm and
    ``head`` the codec head producing CB0 logits.
    """

    def __init__(self, st):
        """Build all layers from ``st``, a mapping of checkpoint key -> tensor."""
        super().__init__()
        self.na = nn.ModuleList()
        self.nf = nn.ModuleList()
        self.qp = nn.ModuleList()
        self.kp = nn.ModuleList()
        self.vp = nn.ModuleList()
        self.op = nn.ModuleList()
        self.qn = nn.ModuleList()
        self.kn = nn.ModuleList()
        self.ga = nn.ModuleList()
        self.dn = nn.ModuleList()
        self.up = nn.ModuleList()
        for i in range(N_L):
            p = f"layers.{i}."
            self.na.append(_rms_norm(DIM, st[p + "attention_norm.weight"]))
            self.nf.append(_rms_norm(DIM, st[p + "ffn_norm.weight"]))
            self.qp.append(_linear(DIM, N_H * HD, st[p + "attention.wq.weight"]))
            self.kp.append(_linear(DIM, N_KV * HD, st[p + "attention.wk.weight"]))
            self.vp.append(_linear(DIM, N_KV * HD, st[p + "attention.wv.weight"]))
            self.op.append(_linear(N_H * HD, DIM, st[p + "attention.wo.weight"]))
            self.qn.append(_rms_norm(HD, st[p + "attention.q_norm_fn.weight"]))
            self.kn.append(_rms_norm(HD, st[p + "attention.k_norm_fn.weight"]))
            self.ga.append(_linear(DIM, FFN, st[p + "feed_forward.w1.weight"]))
            self.dn.append(_linear(FFN, DIM, st[p + "feed_forward.w2.weight"]))
            self.up.append(_linear(DIM, FFN, st[p + "feed_forward.w3.weight"]))
        self.fn = _rms_norm(DIM, st["norm.weight"])
        # Include codec_head for CB0 prediction
        self.head = _linear(DIM, VOCAB, st["output.weight"])

    def forward(self, emb, mask, cos, sin, *kv_args):
        """
        emb: [1,1,1024]  mask: [1,1,1,KV_LEN]  cos: [1,1,128]  sin: [1,1,128]
        kv: 28 × (k[1,8,KV_LEN,128], v[1,8,KV_LEN,128])
        Returns: hidden[1,1,1024], logits[1,1,3072],
                 28 × (k[1,8,KV_LEN,128], v[1,8,KV_LEN,128])

        Raises:
            ValueError: if the number of KV tensors is not 2 * N_L.
        """
        if len(kv_args) != 2 * N_L:
            raise ValueError(
                f"expected {2 * N_L} KV cache tensors (k, v per layer), got {len(kv_args)}"
            )

        h = emb
        # Add a head axis so cos/sin broadcast over [1, heads, 1, HD].
        c = cos.unsqueeze(1)
        sn = sin.unsqueeze(1)
        nk = []
        for i in range(N_L):
            kc = kv_args[i * 2]
            vc = kv_args[i * 2 + 1]

            # --- attention ---
            r = h
            hn = self.na[i](h)
            q = self.qp[i](hn).view(1, 1, N_H, HD).transpose(1, 2)
            k = self.kp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            v = self.vp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            q = self.qn[i](q)
            k = self.kn[i](k)
            q = q * c + rotate_half(q) * sn
            k = k * c + rotate_half(k) * sn
            # Shift KV: drop oldest, append new
            kf = torch.cat([kc[:, :, 1:, :], k], dim=2)
            vf = torch.cat([vc[:, :, 1:, :], v], dim=2)
            ke = repeat_kv(kf, N_REP)
            ve = repeat_kv(vf, N_REP)
            sc = torch.matmul(q, ke.transpose(-2, -1)) * (1.0 / (HD ** 0.5)) + mask
            ao = torch.matmul(F.softmax(sc, dim=-1), ve).transpose(1, 2).contiguous().view(1, 1, -1)
            h = r + self.op[i](ao)

            # --- gated feed-forward ---
            r = h
            fn = self.nf[i](h)
            h = r + self.dn[i](F.silu(self.ga[i](fn)) * self.up[i](fn))

            nk.extend([kf, vf])
        h = self.fn(h)
        logits = self.head(h)
        return (h, logits, *nk)


def _example_inputs():
    """Build the example inputs used for the smoke test and as export trace inputs.

    Returns ``(emb, mask, cos, sin, kvs)``: a random embedding, a mask that
    exposes only the last cache slot, rotary cos/sin for position 0, and
    zero-filled KV caches for every layer.
    """
    e = torch.randn(1, 1, DIM)
    m = torch.full((1, 1, 1, KV_LEN), -1e9)
    m[0, 0, 0, -1] = 0
    inv = 1.0 / (1e6 ** (torch.arange(0, HD, 2, dtype=torch.float32) / HD))
    c0 = torch.cos(0 * inv).repeat(2).unsqueeze(0).unsqueeze(0)
    s0 = torch.sin(0 * inv).repeat(2).unsqueeze(0).unsqueeze(0)
    kvs = [torch.zeros(1, N_KV, KV_LEN, HD) for _ in range(N_L * 2)]
    return e, m, c0, s0, kvs


def main():
    """Load the checkpoint, smoke-test the model, lower it to QNN and write the .pte."""
    # Drop 'Kazeia/executorch' entries and '.' from the import path before
    # executorch is imported (presumably so a local source checkout does not
    # shadow the installed package - confirm).
    sys.path = [p for p in sys.path if 'Kazeia/executorch' not in p and p != '.']
    os.environ['QNN_SDK_ROOT'] = os.environ.get('QNN_SDK_ROOT', '')
    warnings.filterwarnings('ignore')

    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    state = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)

    print("Building talker transformer...")
    w = TalkerTransformer(state).eval()
    n_params = sum(p.numel() for p in w.parameters())
    print(f"Params: {n_params/1e6:.1f}M ({n_params*2/1024/1024:.0f}MB fp16)")

    # Test
    e, m, c0, s0, kvs = _example_inputs()
    with torch.no_grad():
        out = w(e, m, c0, s0, *kvs)
    print(f"Test: hidden={out[0].shape}, logits={out[1].shape}, kv0={out[2].shape}")

    # ExecuTorch export (imported here, after the sys.path filtering above)
    from executorch.backends.qualcomm.utils.utils import (
        QcomChipset,
        QnnExecuTorchBackendOptions,
        QnnExecuTorchBackendType,
        generate_htp_compiler_spec,
        generate_qnn_executorch_compiler_spec,
        to_edge_transform_and_lower_to_qnn,
    )

    htp = generate_htp_compiler_spec(use_fp16=True)
    bo = QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kHtpBackend, htp_options=htp
    )
    specs = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8750, backend_options=bo
    )

    print(f"Lowering talker transformer ({N_L} layers, KV={KV_LEN}) to QNN...")
    edge = to_edge_transform_and_lower_to_qnn(w, (e, m, c0, s0, *kvs), compiler_specs=specs)
    print("LOWERED!")

    pte = edge.to_executorch()
    with open(OUT, "wb") as f:
        pte.write_to_file(f)
    print(f"SAVED: {OUT} ({os.path.getsize(OUT)/1024/1024:.0f} MB)")


if __name__ == "__main__":
    main()