"""Build the Qwen3-TTS code-predictor transformer (decoder stack only, no
output heads) from a raw state dict, smoke-test one eager step, and prepare
Qualcomm QNN HTP (fp16, SM8750) compiler specs for lowering it to ExecuTorch.
"""
import os
import sys
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

# Drop any 'Kazeia/executorch' path and the CWD from the import path before
# executorch is imported below (presumably so the installed package is used
# rather than a source checkout — confirm).
sys.path = [p for p in sys.path if 'Kazeia/executorch' not in p and p != '.']
# Guarantee QNN_SDK_ROOT exists in the environment (empty if unset).
os.environ['QNN_SDK_ROOT'] = os.environ.get('QNN_SDK_ROOT', '')
# Deliberately silence all warnings (export/lowering is very noisy).
warnings.filterwarnings('ignore')

N_L = 5               # number of decoder layers
N_H = 16              # query heads
N_KV = 8              # key/value heads (grouped-query attention)
HD = 128              # per-head dimension
DIM = 1024            # hidden size
N_REP = N_H // N_KV   # query heads sharing each KV head (== 2)
CP_KV = 16            # KV-cache length fed to / returned by the graph
FFN = 3072            # MLP intermediate size

# NOTE(review): weights_only=False unpickles arbitrary Python objects and can
# execute code; acceptable only because this is a trusted local artifact.
# Prefer weights_only=True if the file contains plain tensors.
state = torch.load(
    "/opt/Kazeia/models_qnn/qwen3-tts-native/code_predictor_weights.pt",
    map_location="cpu",
    weights_only=False,
)


def rotate_half(x):
    """Split the last dim into halves (a, b) and return cat(-b, a) (RoPE helper)."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def repeat_kv(x, n):
    """Repeat each KV head ``n`` times: (B, H, T, D) -> (B, H*n, T, D)."""
    B, H, T, D = x.shape
    return x[:, :, None, :, :].expand(B, H, n, T, D).reshape(B, H * n, T, D)


class RMSNorm(nn.Module):
    """RMS normalisation over the last dim with a learned scale (eps = 1e-6)."""

    def __init__(self, d):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * self.weight


def _weight(st, key, shape):
    """Return ``st[key]`` as fp32 after checking it has exactly ``shape``.

    Assigning ``.weight.data`` directly skips nn.Module's shape checks, so a
    mismatched checkpoint would otherwise only fail later (or not at all).
    ``.float()`` is a no-op (no copy) for tensors that are already fp32.

    Raises:
        KeyError: ``key`` is missing from ``st``.
        ValueError: the stored tensor has a different shape.
    """
    t = st[key]
    if tuple(t.shape) != tuple(shape):
        raise ValueError(f"{key}: expected shape {tuple(shape)}, got {tuple(t.shape)}")
    return t.float()


def _linear(st, key, d_in, d_out):
    """Bias-free nn.Linear(d_in, d_out) whose weight is taken from ``st[key]``."""
    lin = nn.Linear(d_in, d_out, bias=False)
    lin.weight.data = _weight(st, key, (d_out, d_in))
    return lin


def _rmsnorm(st, key, d):
    """RMSNorm(d) whose scale is taken from ``st[key]``."""
    norm = RMSNorm(d)
    norm.weight.data = _weight(st, key, (d,))
    return norm


class CPTransformer(nn.Module):
    """Code-predictor decoder stack only; the output heads are NOT included.

    Performs a single-token decoding step with explicit per-layer K/V cache
    tensors passed in and returned, so the graph has fixed shapes for export.
    """

    def __init__(self, st):
        super().__init__()
        self.na = nn.ModuleList()   # input_layernorm (pre-attention)
        self.nf = nn.ModuleList()   # post_attention_layernorm (pre-MLP)
        self.qp = nn.ModuleList()   # q_proj
        self.kp = nn.ModuleList()   # k_proj
        self.vp = nn.ModuleList()   # v_proj
        self.op = nn.ModuleList()   # o_proj
        self.qn = nn.ModuleList()   # per-head q_norm
        self.kn = nn.ModuleList()   # per-head k_norm
        self.ga = nn.ModuleList()   # mlp.gate_proj
        self.dn2 = nn.ModuleList()  # mlp.down_proj
        self.up = nn.ModuleList()   # mlp.up_proj
        for i in range(N_L):
            p = f"model.layers.{i}."
            # Construction order matches the original script (same RNG use).
            self.na.append(_rmsnorm(st, p + "input_layernorm.weight", DIM))
            self.nf.append(_rmsnorm(st, p + "post_attention_layernorm.weight", DIM))
            self.qp.append(_linear(st, p + "self_attn.q_proj.weight", DIM, N_H * HD))
            self.kp.append(_linear(st, p + "self_attn.k_proj.weight", DIM, N_KV * HD))
            self.vp.append(_linear(st, p + "self_attn.v_proj.weight", DIM, N_KV * HD))
            self.op.append(_linear(st, p + "self_attn.o_proj.weight", N_H * HD, DIM))
            self.qn.append(_rmsnorm(st, p + "self_attn.q_norm.weight", HD))
            self.kn.append(_rmsnorm(st, p + "self_attn.k_norm.weight", HD))
            self.ga.append(_linear(st, p + "mlp.gate_proj.weight", DIM, FFN))
            self.dn2.append(_linear(st, p + "mlp.down_proj.weight", FFN, DIM))
            self.up.append(_linear(st, p + "mlp.up_proj.weight", DIM, FFN))
        self.fn = _rmsnorm(st, "model.norm.weight", DIM)

    def forward(self, emb, mask, cos, sin, k0, v0, k1, v1, k2, v2, k3, v3, k4, v4):
        """Run one decoding step through all layers.

        Args:
            emb: hidden state of the new token. Projections are reshaped with
                ``view(1, 1, ...)``, so batch and sequence length must be 1.
            mask: additive attention mask, broadcast against scores of shape
                (1, N_H, 1, T), where T is the cache length.
            cos, sin: RoPE tables for the new token's position. ``unsqueeze(1)``
                inserts the head axis.
            k0..v4: per-layer K/V caches, (1, N_KV, T, HD) each.

        Returns:
            Tuple ``(hidden, k0', v0', ..., k4', v4')``. ``hidden`` is the
            final-normed hidden state. Each cache is shifted left by one slot,
            dropping the oldest entry, with the new token's K/V appended last.
        """
        c = cos.unsqueeze(1)
        sn = sin.unsqueeze(1)
        caches = (k0, v0, k1, v1, k2, v2, k3, v3, k4, v4)
        new_caches = []
        scale = 1.0 / (HD ** 0.5)
        h = emb
        for i in range(N_L):
            k_cache, v_cache = caches[2 * i], caches[2 * i + 1]

            # Self-attention.
            residual = h
            hn = self.na[i](h)
            q = self.qp[i](hn).view(1, 1, N_H, HD).transpose(1, 2)
            k = self.kp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            v = self.vp[i](hn).view(1, 1, N_KV, HD).transpose(1, 2)
            # Per-head RMSNorm on q/k, then rotary position embedding.
            q = self.qn[i](q)
            k = self.kn[i](k)
            q = q * c + rotate_half(q) * sn
            k = k * c + rotate_half(k) * sn
            # Fixed-length cache: drop the oldest slot, append the new token.
            kf = torch.cat([k_cache[:, :, 1:, :], k], dim=2)
            vf = torch.cat([v_cache[:, :, 1:, :], v], dim=2)
            ke = repeat_kv(kf, N_REP)
            ve = repeat_kv(vf, N_REP)
            scores = torch.matmul(q, ke.transpose(-2, -1)) * scale + mask
            ao = torch.matmul(F.softmax(scores, dim=-1), ve)
            ao = ao.transpose(1, 2).contiguous().view(1, 1, -1)
            h = residual + self.op[i](ao)

            # SwiGLU MLP.
            residual = h
            fn = self.nf[i](h)
            h = residual + self.dn2[i](F.silu(self.ga[i](fn)) * self.up[i](fn))

            new_caches.extend([kf, vf])
        return (self.fn(h), *new_caches)


print("Building (no heads)...")
w = CPTransformer(state).eval()
print(f"Params: {sum(p.numel() for p in w.parameters())/1e6:.1f}M")

# Example inputs: one step at position 0 with empty caches. Only the last
# cache slot (where the new token's K/V lands) is unmasked.
# NOTE(review): -1e9 is not representable in fp16 and becomes -inf there.
# That is harmless while at least one slot is unmasked; confirm runtime masks
# always keep one.
e = torch.randn(1, 1, DIM)
m = torch.full((1, 1, 1, CP_KV), -1e9)
m[0, 0, 0, -1] = 0
inv = 1.0 / (1e6 ** (torch.arange(0, HD, 2, dtype=torch.float32) / HD))  # RoPE inverse freqs, base 1e6
c0 = torch.cos(0 * inv).repeat(2).unsqueeze(0).unsqueeze(0)
s0 = torch.sin(0 * inv).repeat(2).unsqueeze(0).unsqueeze(0)
kvs = [torch.zeros(1, N_KV, CP_KV, HD) for _ in range(2 * N_L)]  # (k, v) per layer
with torch.no_grad():
    out = w(e, m, c0, s0, *kvs)
print(f"Test: h={out[0].shape}")
# Fail fast on broken weights before spending minutes lowering the graph.
if not torch.isfinite(out[0]).all():
    raise RuntimeError("Eager smoke test produced non-finite hidden states")

# Explicit imports (not `import *`): a star import here could silently rebind
# the short module globals above (w, e, m, ...). This import must stay after
# the sys.path / QNN_SDK_ROOT setup at the top of the file.
from executorch.backends.qualcomm.utils.utils import (  # noqa: E402
    QcomChipset,
    QnnExecuTorchBackendOptions,
    QnnExecuTorchBackendType,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    to_edge_transform_and_lower_to_qnn,
)

htp = generate_htp_compiler_spec(use_fp16=True)
bo = QnnExecuTorchBackendOptions(
    backend_type=QnnExecuTorchBackendType.kHtpBackend,
    htp_options=htp,
)
specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8750,
    backend_options=bo,
)
print("Lowering CP transformer (no heads) to QNN...")
# Lower the eager model to the QNN HTP backend using the example inputs
# (their shapes fix the exported graph's signature), then serialize.
edge = to_edge_transform_and_lower_to_qnn(w, (e, m, c0, s0, *kvs), compiler_specs=specs)
print("LOWERED!")
pte = edge.to_executorch()
OUT = "/opt/Kazeia/models_qnn/cp_transformer_fp16.pte"
# Write to a sibling temp file, then atomically rename into place. A failed or
# interrupted write then never leaves a truncated .pte at OUT, and never
# clobbers a previously good artifact.
_tmp_out = OUT + ".tmp"
try:
    with open(_tmp_out, "wb") as f:
        pte.write_to_file(f)
    os.replace(_tmp_out, OUT)
finally:
    # Only still present if the write or rename failed.
    if os.path.exists(_tmp_out):
        os.remove(_tmp_out)
print(f"SAVED: {OUT} ({os.path.getsize(OUT)/1024/1024:.0f} MB)")