124 lines
5.8 KiB
Python
124 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
||
"""
Export the Qwen3-TTS talker transformer to an ExecuTorch .pte for QNN HTP (fp16).

28 layers, 1024 dim, GQA 16/8, M-RoPE, codec_head.
Fixed-size KV cache with shift (like the CP export).
"""
|
||
import os, sys, warnings, torch, torch.nn as nn, torch.nn.functional as F
|
||
# Drop the current directory and any 'Kazeia/executorch' entries from the
# import path (presumably so the installed executorch package is imported
# instead of a local checkout -- confirm).
sys.path = [
    entry for entry in sys.path
    if not ('Kazeia/executorch' in entry or entry == '.')
]
# Make sure QNN_SDK_ROOT exists in the environment; an existing value is kept,
# otherwise it is set to the empty string.
os.environ.setdefault('QNN_SDK_ROOT', '')
warnings.filterwarnings('ignore')

# Talker transformer hyper-parameters.
N_L = 28              # number of transformer layers
N_H = 16              # query heads
N_KV = 8              # key/value heads (GQA)
HD = 128              # per-head dimension
DIM = 1024            # model (residual stream) dimension
N_REP = N_H // N_KV   # query heads served by each KV head (= 2)
VOCAB = 3072          # output size of the codec head
FFN = 3072            # feed-forward hidden dimension
KV_LEN = 100  # Must be >= prefill+maxGen. KV=64 caused quality loss (role tokens evicted)

# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# checkpoints from a trusted source.
state = torch.load(
    "/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth",
    map_location="cpu",
    weights_only=False,
)
|
||
|
||
def rotate_half(x):
    """Return ``x`` with the two halves of its last dimension swapped and the
    (originally) second half negated: ``[a, b] -> [-b, a]``.

    This is the helper used when applying rotary position embeddings.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
||
|
||
def repeat_kv(x, n):
    """Repeat every head of a ``[batch, heads, seq, head_dim]`` tensor ``n`` times.

    The copies of one head are adjacent along the head axis, so the result has
    shape ``[batch, heads * n, seq, head_dim]`` with head ``j`` of the output
    equal to head ``j // n`` of the input.
    """
    batch, heads, seq_len, head_dim = x.shape
    # Insert a repeat axis right after the head axis, broadcast it to n, then
    # fold it into the head axis.
    with_repeat_axis = x[:, :, None, :, :]
    repeated = with_repeat_axis.expand(batch, heads, n, seq_len, head_dim)
    return repeated.reshape(batch, heads * n, seq_len, head_dim)
|
||
|
||
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation over the last dimension.

    Computes ``x / sqrt(mean(x**2) + eps) * weight`` where the mean is taken
    over the last dimension and ``weight`` is a learned per-feature scale.
    """

    def __init__(self, d, eps=1e-6):
        """
        Args:
            d: Size of the normalised (last) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root, for numerical stability.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply the scale."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps) * self.weight
|
||
|
||
class TalkerTransformer(nn.Module):
    """Single-token decode step of the talker transformer, codec_head included.

    The module is written for fixed-shape export: batch size and sequence
    length are hard-wired to 1, and every layer carries a fixed-length KV
    cache that is shifted by one slot per call. Layer hyper-parameters come
    from the module-level constants (N_L, N_H, N_KV, HD, DIM, N_REP, FFN,
    VOCAB).

    The short sub-module attribute names (na, nf, qp, ...) are part of the
    module's parameter names and are therefore kept as they are.
    """

    def __init__(self, st):
        """Build all layers and bind their weights from ``st``.

        Args:
            st: Mapping from checkpoint key (``layers.{i}.attention.wq.weight``
                and friends, ``norm.weight``, ``output.weight``) to the
                corresponding weight tensor.

        Raises:
            KeyError: If a required key is missing from ``st``.
            ValueError: If a tensor's shape does not match its target module.
        """
        super().__init__()
        self.na = nn.ModuleList()  # pre-attention RMSNorm, one per layer
        self.nf = nn.ModuleList()  # pre-FFN RMSNorm
        self.qp = nn.ModuleList()  # query projection
        self.kp = nn.ModuleList()  # key projection
        self.vp = nn.ModuleList()  # value projection
        self.op = nn.ModuleList()  # attention output projection
        self.qn = nn.ModuleList()  # per-head RMSNorm on queries
        self.kn = nn.ModuleList()  # per-head RMSNorm on keys
        self.ga = nn.ModuleList()  # FFN gate projection (w1)
        self.dn = nn.ModuleList()  # FFN down projection (w2)
        self.up = nn.ModuleList()  # FFN up projection (w3)
        for i in range(N_L):
            p = f"layers.{i}."
            self.na.append(self._bind(RMSNorm(DIM), st, p + "attention_norm.weight"))
            self.nf.append(self._bind(RMSNorm(DIM), st, p + "ffn_norm.weight"))
            self.qp.append(self._bind(nn.Linear(DIM, N_H * HD, bias=False), st, p + "attention.wq.weight"))
            self.kp.append(self._bind(nn.Linear(DIM, N_KV * HD, bias=False), st, p + "attention.wk.weight"))
            self.vp.append(self._bind(nn.Linear(DIM, N_KV * HD, bias=False), st, p + "attention.wv.weight"))
            self.op.append(self._bind(nn.Linear(N_H * HD, DIM, bias=False), st, p + "attention.wo.weight"))
            self.qn.append(self._bind(RMSNorm(HD), st, p + "attention.q_norm_fn.weight"))
            self.kn.append(self._bind(RMSNorm(HD), st, p + "attention.k_norm_fn.weight"))
            self.ga.append(self._bind(nn.Linear(DIM, FFN, bias=False), st, p + "feed_forward.w1.weight"))
            self.dn.append(self._bind(nn.Linear(FFN, DIM, bias=False), st, p + "feed_forward.w2.weight"))
            self.up.append(self._bind(nn.Linear(DIM, FFN, bias=False), st, p + "feed_forward.w3.weight"))
        self.fn = self._bind(RMSNorm(DIM), st, "norm.weight")
        # Include codec_head for CB0 prediction
        self.head = self._bind(nn.Linear(DIM, VOCAB, bias=False), st, "output.weight")

    @staticmethod
    def _bind(module, st, key):
        """Attach ``st[key]`` as ``module.weight`` after checking its shape."""
        tensor = st[key]
        expected = tuple(module.weight.shape)
        if tuple(tensor.shape) != expected:
            raise ValueError(
                f"{key}: expected shape {expected}, got {tuple(tensor.shape)}"
            )
        # Rebind the parameter's data instead of copying, so no second copy of
        # the weights is made.
        module.weight.data = tensor
        return module

    def forward(self, emb, mask, cos, sin, *kv_args):
        """Run one decode step over all layers.

        Args:
            emb:  [1,1,1024] input embedding of the current token.
            mask: [1,1,1,KV_LEN] additive attention mask over the cache slots.
            cos:  [1,1,128] rotary cosine for the current position.
            sin:  [1,1,128] rotary sine for the current position.
            kv_args: 28 x (k[1,8,KV_LEN,128], v[1,8,KV_LEN,128]), flattened as
                k0, v0, k1, v1, ...

        Returns:
            Tuple ``(hidden[1,1,1024], logits[1,1,3072], k0, v0, k1, v1, ...)``
            where ``hidden`` is the output of the final norm and each returned
            cache has the oldest slot dropped and the new token appended last.

        Raises:
            ValueError: If the number of cache tensors is not ``2 * N_L``.
        """
        if len(kv_args) != 2 * N_L:
            raise ValueError(
                f"expected {2 * N_L} KV cache tensors (k, v per layer), "
                f"got {len(kv_args)}"
            )

        h = emb
        # Add a head axis so cos/sin broadcast over all heads.
        c = cos.unsqueeze(1)
        sn = sin.unsqueeze(1)
        scale = 1.0 / (HD ** 0.5)
        nk = []
        for i in range(N_L):
            kc = kv_args[i * 2]
            vc = kv_args[i * 2 + 1]

            # --- attention ---
            r = h
            hn = self.na[i](h)
            q = self.qp[i](hn).view(1, 1, N_H, HD).transpose(1, 2)
            k = self.kp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            v = self.vp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            q = self.qn[i](q)
            k = self.kn[i](k)
            # Rotary position embedding.
            q = q * c + rotate_half(q) * sn
            k = k * c + rotate_half(k) * sn
            # Shift KV: drop oldest, append new
            kf = torch.cat([kc[:, :, 1:, :], k], dim=2)
            vf = torch.cat([vc[:, :, 1:, :], v], dim=2)
            # Expand the KV heads so each query head has a matching KV head.
            ke = repeat_kv(kf, N_REP)
            ve = repeat_kv(vf, N_REP)
            sc = torch.matmul(q, ke.transpose(-2, -1)) * scale + mask
            ao = torch.matmul(F.softmax(sc, dim=-1), ve)
            ao = ao.transpose(1, 2).contiguous().view(1, 1, -1)
            h = r + self.op[i](ao)

            # --- gated feed-forward ---
            r = h
            fn = self.nf[i](h)
            h = r + self.dn[i](F.silu(self.ga[i](fn)) * self.up[i](fn))

            nk.extend([kf, vf])

        h = self.fn(h)
        logits = self.head(h)
        return (h, logits, *nk)
|
||
|
||
# Build the export module from the loaded checkpoint and report its size.
print("Building talker transformer...")
w = TalkerTransformer(state).eval()
n_params = sum(p.numel() for p in w.parameters())
# The MB figure assumes 2 bytes per parameter (fp16).
print(f"Params: {n_params/1e6:.1f}M ({n_params*2/1024/1024:.0f}MB fp16)")

# Smoke test on CPU. The tensors built here are reused further down as the
# example inputs for lowering.
e = torch.randn(1,1,DIM)
# Additive attention mask: every cache slot is masked with -1e9 except the
# last one, which is where forward() appends the current token's K/V.
m = torch.full((1,1,1,KV_LEN), -1e9); m[0,0,0,-1] = 0
# Rotary inverse frequencies (base 1e6) evaluated at position 0, so c0 is all
# ones and s0 is all zeros; .repeat(2) duplicates the HD/2 values to length HD.
inv = 1.0/(1e6**(torch.arange(0, HD, 2, dtype=torch.float32)/HD))
c0 = torch.cos(0*inv).repeat(2).unsqueeze(0).unsqueeze(0)
s0 = torch.sin(0*inv).repeat(2).unsqueeze(0).unsqueeze(0)
# Zero-filled K and V caches, two tensors per layer.
kvs = [torch.zeros(1, N_KV, KV_LEN, HD) for _ in range(N_L*2)]

with torch.no_grad():
    out = w(e, m, c0, s0, *kvs)
print(f"Test: hidden={out[0].shape}, logits={out[1].shape}, kv0={out[2].shape}")

# ExecuTorch export: lower the module to the Qualcomm QNN HTP backend in fp16.
# NOTE(review): the wildcard import is what brings every QNN helper used below
# into scope; it runs after the sys.path filtering at the top of the file.
from executorch.backends.qualcomm.utils.utils import *
htp = generate_htp_compiler_spec(use_fp16=True)
bo = QnnExecuTorchBackendOptions(backend_type=QnnExecuTorchBackendType.kHtpBackend, htp_options=htp)
# Compiler spec targets the SM8750 chipset.
specs = generate_qnn_executorch_compiler_spec(soc_model=QcomChipset.SM8750, backend_options=bo)

print(f"Lowering talker transformer ({N_L} layers, KV={KV_LEN}) to QNN...")
# The smoke-test tensors double as the example inputs that fix the exported shapes.
edge = to_edge_transform_and_lower_to_qnn(w, (e, m, c0, s0, *kvs), compiler_specs=specs)
print("LOWERED!")

# Serialise the lowered program to a .pte file and report its size on disk.
pte = edge.to_executorch()
OUT = "/opt/Kazeia/models_qnn/talker_transformer_fp16.pte"
with open(OUT, "wb") as f:
    pte.write_to_file(f)
print(f"SAVED: {OUT} ({os.path.getsize(OUT)/1024/1024:.0f} MB)")
|