kazeia/scripts/export_talker_pte.py

124 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Export Qwen3-TTS talker transformer to ExecuTorch .pte on QNN HTP fp16.
28 layers, 1024 dim, GQA 16/8, M-RoPE, codec_head.
Fixed KV cache with shift (like CP export).
"""
import os, sys, warnings, torch, torch.nn as nn, torch.nn.functional as F
# Drop any sys.path entry inside a 'Kazeia/executorch' tree, plus '.', presumably so the
# executorch import further down resolves to the installed package rather than a local
# checkout (intent inferred from the filter - confirm).
sys.path = [p for p in sys.path if 'Kazeia/executorch' not in p and p != '.']
# Guarantee QNN_SDK_ROOT exists (empty string if unset) before executorch's Qualcomm
# backend is imported below.
os.environ.setdefault('QNN_SDK_ROOT', '')
# Silences ALL Python warnings for the rest of the script, including export diagnostics.
warnings.filterwarnings('ignore')

# Talker architecture; must agree with the checkpoint loaded below.
N_L = 28      # decoder layers
N_H = 16      # query heads
N_KV = 8      # key/value heads (grouped-query attention)
HD = 128      # per-head dimension
DIM = 1024    # hidden size
if N_H % N_KV:
    raise ValueError(f"N_H ({N_H}) must be a multiple of N_KV ({N_KV})")
N_REP = N_H // N_KV  # query heads sharing each KV head (2 for 16/8)
VOCAB = 3072  # codec_head output size
FFN = 3072    # MLP intermediate size
# Fixed KV window, in cache slots. Reduced from 100 (36% less KV memcpy per step).
# forward() shifts the cache by one slot every step, so after more than KV_LEN decoded
# positions the oldest ones are dropped. The original author noted this is enough for
# ~70-token generations - confirm that dropping the earliest positions is acceptable there.
KV_LEN = 64
# NOTE(review): weights_only=False makes torch.load fully unpickle the file, which can run
# code embedded in it. This is only safe if this checkpoint comes from a trusted source.
state = torch.load(
    "/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth",
    map_location=torch.device("cpu"),
    weights_only=False,
)
def rotate_half(x):
    """Split the last dim at its midpoint into (a, b) and return cat(-b, a).

    This is the rotation term of rotary position embedding:
    ``q * cos + rotate_half(q) * sin``.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat([-back, front], dim=-1)
def repeat_kv(x, n):
    """Repeat each KV head ``n`` times along dim 1 (grouped-query attention).

    ``x`` is a 4-D tensor indexed as (batch, kv_heads, seq, head_dim). The result has
    ``kv_heads * n`` heads, and KV head ``j`` appears at heads ``j*n`` .. ``j*n + n - 1``
    (the same layout as ``torch.repeat_interleave(x, n, dim=1)``).
    """
    bsz, kv_heads, seqlen, head_dim = x.shape
    widened = x.unsqueeze(2).expand(bsz, kv_heads, n, seqlen, head_dim)
    return widened.reshape(bsz, kv_heads * n, seqlen, head_dim)
class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dimension with a learned per-channel scale.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight``. There is no mean
    subtraction and no bias.

    Args:
        d: size of the normalized (last) dimension; ``weight`` has shape (d,) and is
            initialised to ones.
        eps: added to the mean square before ``rsqrt`` to avoid division by zero.
            Defaults to 1e-6, the value previously hard-coded, so existing callers
            are unaffected.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
class TalkerTransformer(nn.Module):
    """Qwen3-TTS talker decoder stack for one-token decode, including the codec head.

    Each call processes exactly one new token. It runs N_L pre-norm decoder layers:
    grouped-query attention with per-head q/k RMSNorm and caller-supplied rotary
    cos/sin, then a SwiGLU MLP. It then applies the final norm and projects through
    the codec head (``output.weight``) to VOCAB logits.

    The KV cache is a fixed window of KV_LEN slots per layer. Each call drops slot 0,
    writes the new key/value into the last slot, and returns the shifted caches for
    the caller to feed back on the next step.

    Args:
        st: checkpoint mapping from names such as ``layers.{i}.attention.wq.weight``
            to tensors. Each tensor is shape-checked against the module parameter,
            then attached via ``.data`` (its storage and dtype are used as-is, not
            copied). A missing key raises KeyError; a wrong shape raises ValueError.
    """

    @staticmethod
    def _assign(param, st, key):
        """Make ``param`` use checkpoint tensor ``st[key]``, rejecting shape mismatches."""
        tensor = st[key]
        if tuple(tensor.shape) != tuple(param.shape):
            raise ValueError(
                f"checkpoint tensor {key!r} has shape {tuple(tensor.shape)}, "
                f"expected {tuple(param.shape)}"
            )
        # .data assignment shares the checkpoint storage (no copy) and keeps its dtype.
        param.data = tensor

    def __init__(self, st):
        super().__init__()
        # One ModuleList per weight kind; the list index is the layer number.
        self.na = nn.ModuleList()  # attention pre-norm
        self.nf = nn.ModuleList()  # MLP pre-norm
        self.qp = nn.ModuleList()  # query projection
        self.kp = nn.ModuleList()  # key projection
        self.vp = nn.ModuleList()  # value projection
        self.op = nn.ModuleList()  # attention output projection
        self.qn = nn.ModuleList()  # per-head query RMSNorm
        self.kn = nn.ModuleList()  # per-head key RMSNorm
        self.ga = nn.ModuleList()  # MLP gate (w1)
        self.dn = nn.ModuleList()  # MLP down (w2)
        self.up = nn.ModuleList()  # MLP up (w3)

        def norm(d, key):
            layer = RMSNorm(d)
            self._assign(layer.weight, st, key)
            return layer

        def linear(in_f, out_f, key):
            layer = nn.Linear(in_f, out_f, bias=False)
            self._assign(layer.weight, st, key)
            return layer

        for i in range(N_L):
            p = f"layers.{i}."
            self.na.append(norm(DIM, p + "attention_norm.weight"))
            self.nf.append(norm(DIM, p + "ffn_norm.weight"))
            self.qp.append(linear(DIM, N_H * HD, p + "attention.wq.weight"))
            self.kp.append(linear(DIM, N_KV * HD, p + "attention.wk.weight"))
            self.vp.append(linear(DIM, N_KV * HD, p + "attention.wv.weight"))
            self.op.append(linear(N_H * HD, DIM, p + "attention.wo.weight"))
            self.qn.append(norm(HD, p + "attention.q_norm_fn.weight"))
            self.kn.append(norm(HD, p + "attention.k_norm_fn.weight"))
            self.ga.append(linear(DIM, FFN, p + "feed_forward.w1.weight"))
            self.dn.append(linear(FFN, DIM, p + "feed_forward.w2.weight"))
            self.up.append(linear(DIM, FFN, p + "feed_forward.w3.weight"))
        self.fn = norm(DIM, "norm.weight")
        # Include codec_head for CB0 prediction.
        self.head = linear(DIM, VOCAB, "output.weight")

    def forward(self, emb, mask, cos, sin, *kv_args):
        """Decode one token.

        Shapes are those implied by the view/cat calls below and this script's
        example inputs:
            emb:     [1, 1, DIM] embedding of the new token.
            mask:    [1, 1, 1, KV_LEN] additive mask over cache slots, added to the
                     attention scores before softmax (0 = attend, large negative = ignore).
            cos/sin: [1, 1, HD] rotary tables for the new token's position.
            kv_args: 2 * N_L tensors (layer i: key at 2*i, value at 2*i+1), each
                     [1, N_KV, KV_LEN, HD].

        Returns:
            (hidden, logits, *kv), where:
            hidden: the final-norm output, [1, 1, DIM].
            logits: [1, 1, VOCAB].
            kv: the shifted caches, in the same order and shapes as ``kv_args``.
        """
        h = emb
        c = cos.unsqueeze(1)  # [1,1,HD] -> [1,1,1,HD], broadcasts over heads
        sn = sin.unsqueeze(1)
        nk = []
        for i in range(N_L):
            kc = kv_args[i * 2]
            vc = kv_args[i * 2 + 1]
            # Attention sub-block (pre-norm + residual).
            r = h
            hn = self.na[i](h)
            q = self.qp[i](hn).view(1, 1, N_H, HD).transpose(1, 2)
            k = self.kp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            v = self.vp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            q = self.qn[i](q)
            k = self.kn[i](k)
            q = q * c + rotate_half(q) * sn
            k = k * c + rotate_half(k) * sn
            # Shift KV: drop the oldest slot, append the new token's k/v at the end.
            kf = torch.cat([kc[:, :, 1:, :], k], dim=2)
            vf = torch.cat([vc[:, :, 1:, :], v], dim=2)
            ke = repeat_kv(kf, N_REP)
            ve = repeat_kv(vf, N_REP)
            sc = torch.matmul(q, ke.transpose(-2, -1)) * (1.0 / (HD ** 0.5)) + mask
            ao = torch.matmul(F.softmax(sc, dim=-1), ve).transpose(1, 2).contiguous().view(1, 1, -1)
            h = r + self.op[i](ao)
            # MLP sub-block: down(silu(gate(x)) * up(x)), pre-norm + residual.
            r = h
            fn = self.nf[i](h)
            h = r + self.dn[i](F.silu(self.ga[i](fn)) * self.up[i](fn))
            nk.extend([kf, vf])
        h = self.fn(h)
        logits = self.head(h)
        return (h, logits, *nk)
print("Building talker transformer...")
w = TalkerTransformer(state).eval()
n_params = sum(p.numel() for p in w.parameters())
# Size estimate assumes 2 bytes per parameter (fp16 on the HTP).
print(f"Params: {n_params/1e6:.1f}M ({n_params*2/1024/1024:.0f}MB fp16)")

# Smoke test: one decode step at position 0 with an all-zero cache. These tensors
# are reused below as the example inputs that fix the exported graph's shapes.
e = torch.randn(1, 1, DIM)
# Additive mask: only the last slot (where forward() writes the new token) is visible.
m = torch.full((1, 1, 1, KV_LEN), -1e9)
m[0, 0, 0, -1] = 0
# Rotary inverse frequencies with base 1e6. cos/sin are duplicated across both halves
# of the head dim to pair with rotate_half().
inv = 1.0 / (1e6 ** (torch.arange(0, HD, 2, dtype=torch.float32) / HD))
c0 = torch.cos(0 * inv).repeat(2).unsqueeze(0).unsqueeze(0)
s0 = torch.sin(0 * inv).repeat(2).unsqueeze(0).unsqueeze(0)
kvs = [torch.zeros(1, N_KV, KV_LEN, HD) for _ in range(N_L * 2)]
with torch.no_grad():
    out = w(e, m, c0, s0, *kvs)
print(f"Test: hidden={out[0].shape}, logits={out[1].shape}, kv0={out[2].shape}")

# Fail fast on a wrong output layout instead of discovering it during the slow QNN lowering.
if len(out) != 2 + 2 * N_L:
    raise RuntimeError(f"expected {2 + 2 * N_L} outputs, got {len(out)}")
for _name, _got, _want in (
    ("hidden", out[0].shape, (1, 1, DIM)),
    ("logits", out[1].shape, (1, 1, VOCAB)),
    ("kv0", out[2].shape, (1, N_KV, KV_LEN, HD)),
):
    if tuple(_got) != _want:
        raise RuntimeError(f"{_name} shape {tuple(_got)} != expected {_want}")
# ExecuTorch export: lower the model to the QNN HTP backend with fp16 enabled.
# NOTE: generate_htp_compiler_spec, QnnExecuTorchBackendOptions, QnnExecuTorchBackendType,
# generate_qnn_executorch_compiler_spec, QcomChipset and to_edge_transform_and_lower_to_qnn
# are expected to come from this star import (no other import provides them here).
from executorch.backends.qualcomm.utils.utils import *
htp = generate_htp_compiler_spec(use_fp16=True)
bo = QnnExecuTorchBackendOptions(backend_type=QnnExecuTorchBackendType.kHtpBackend, htp_options=htp)
specs = generate_qnn_executorch_compiler_spec(soc_model=QcomChipset.SM8750, backend_options=bo)
print(f"Lowering talker transformer ({N_L} layers, KV={KV_LEN}) to QNN...")
# The smoke-test tensors double as the example inputs for export.
edge = to_edge_transform_and_lower_to_qnn(w, (e, m, c0, s0, *kvs), compiler_specs=specs)
print("LOWERED!")
pte = edge.to_executorch()
OUT = "/opt/Kazeia/models_qnn/talker_transformer_fp16.pte"
# Serialize to a sibling temp file, then rename it into place. This way an
# interrupted or failed write never leaves a truncated .pte at OUT. A stale
# .tmp from a failed run is simply overwritten next time.
tmp_out = OUT + ".tmp"
with open(tmp_out, "wb") as f:
    pte.write_to_file(f)
os.replace(tmp_out, OUT)
print(f"SAVED: {OUT} ({os.path.getsize(OUT)/1024/1024:.0f} MB)")