kazeia/scripts/export_tts_text_embeddings.py

157 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Export everything the tablet needs to build TTS prefill embeds for arbitrary
LLM text, offline, without talking to a PC.
Outputs (pushed to /data/local/tmp/kazeia/models/qwen3-tts-npu/):
- text_embeds_full_fp16.bin : 151936 × 1024 fp16 ≈ 311 MB (≈ 297 MiB), raw rows, no header
Pre-projected text embeddings for the full Qwen3 vocab. Per-token
lookup on-device replaces a lookup + FC1 + SiLU + FC2 + bias. Same
numbers PyTorch produces for text_projection(text_embedding(id)).
- damien_voice_prefix.bin : int32 rows + int32 cols header, then 9 × 1024 fp32 (36 KiB + 8 bytes)
The fixed voice-cloning prefix (positions 0..8) for speaker Damien,
captured from a real voice-clone run. Positions 0..6 = role/control
tokens, position 7 = xvector (L2 norm ~10), position 8 = trailing
voice-marker. Same for every phrase uttered by this speaker, so we
capture once here and reuse indefinitely on-device.
- damien_voice_suffix.bin : int32 rows + int32 cols header, then 2 × 1024 fp32 (8 KiB + 8 bytes)
The fixed voice-cloning SUFFIX (last 2 positions of the prefill)
that Python emits AFTER the text tokens. Verified bit-identical
across segments of different texts → invariant closure marker
for the voice-clone conditioning. Without it the talker misreads
the end of the text and produces garbled output.
- qwen3_tokenizer/ : tokenizer files copied from HF snapshot
tokenizer.json, vocab.json, merges.txt, tokenizer_config.json and special_tokens_map.json (whichever exist in the snapshot).
Kotlin BPE implementation reads vocab + merges at init.
The combination lets the tablet build, for any text, the exact same
prefill tensor PyTorch would build, bit-for-bit at fp16 — which is
what our Hexagon talker consumes anyway.
Usage:
python3 export_tts_text_embeddings.py [output_dir]
"""
import sys, os, struct, shutil, warnings
# Resolve the output directory against the caller's working directory
# BEFORE the chdir below, so a relative argv[1] lands where the user ran the
# script rather than silently under /tmp.
OUTPUT_DIR = os.path.abspath(sys.argv[1] if len(sys.argv) > 1 else "/tmp/kazeia_tts_export")
# NOTE(review): the reason for running from /tmp is not visible here
# (presumably to keep qwen_tts from reading/writing files in the caller's
# cwd) — confirm before removing.
os.chdir("/tmp")
warnings.filterwarnings("ignore")
# Machine-specific inputs: local HF snapshot of the base TTS model and the
# reference recording used for voice cloning.
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "qwen3_tokenizer"), exist_ok=True)
import torch, numpy as np
from qwen_tts import Qwen3TTSModel
# Load the TTS model from the local snapshot only (no network), on CPU.
print("Loading Qwen3-TTS model (~30s, CPU)...")
tts = Qwen3TTSModel.from_pretrained(
    MODEL,
    device_map="cpu",
    local_files_only=True,
)
# Shortcut to the talker sub-module: its embedding/projection weights and
# its inner model's forward are used by the steps below.
talker = tts.model.talker
# ---- 1. Full projected text embeddings ----
# Evaluate text_projection(text_embedding.weight) for EVERY vocab entry.
# Batching keeps peak memory bounded; fp32 matmul then fp16 store preserves
# precision up to the final quantization step. Output is raw fp16 rows,
# one per token id, with no header.
print("\n[1/3] Precomputing projected embeddings for full vocab...")
vocab_size = talker.model.text_embedding.weight.shape[0]
print(f" Vocab size: {vocab_size}")
BATCH = 4096
out_path = f"{OUTPUT_DIR}/text_embeds_full_fp16.bin"
with torch.no_grad():
    W_emb = talker.model.text_embedding.weight  # [vocab, 2048]
    proj = talker.text_projection
    # Upcast the projection weights as well as the activations: F.linear
    # requires matching dtypes, so a bf16/fp16-loaded model would otherwise
    # fail here. `.float()` is a no-op when the model is already fp32.
    fc1_w = proj.linear_fc1.weight.float()  # [2048, 2048]
    fc1_b = proj.linear_fc1.bias.float()    # [2048]
    fc2_w = proj.linear_fc2.weight.float()  # [1024, 2048]
    fc2_b = proj.linear_fc2.bias.float()    # [1024]
    with open(out_path, "wb") as f:
        for start in range(0, vocab_size, BATCH):
            end = min(start + BATCH, vocab_size)
            x = W_emb[start:end].float()                             # [b, 2048]
            h = torch.nn.functional.linear(x, fc1_w, fc1_b)          # [b, 2048]
            h = torch.nn.functional.silu(h)                          # [b, 2048]
            y = torch.nn.functional.linear(h, fc2_w, fc2_b)          # [b, 1024]
            f.write(y.to(torch.float16).numpy().tobytes())
            # Log every 4 batches, and always on the last one so 100% is shown.
            if start % (BATCH * 4) == 0 or end == vocab_size:
                print(f" {end}/{vocab_size} ({end*100//vocab_size}%)", flush=True)
sz_mb = os.path.getsize(out_path) / (1024*1024)
print(f" -> {out_path} ({sz_mb:.1f} MB)")
# Sanity check: re-read one token's row from disk, project it live, compare.
# NOTE(review): the log label says token 1043 is 'Bonjour'; not verified here.
SANITY_TOKEN = 1043
print(f"\n Sanity check (token {SANITY_TOKEN} = 'Bonjour'):")
with torch.no_grad():
    live = talker.text_projection(
        talker.model.text_embedding(torch.tensor([SANITY_TOKEN]))
    )[0].float().numpy()
# Row size derived from the live projection instead of a hard-coded width;
# the file stores fp16 (2 bytes per value).
row_bytes = live.shape[0] * 2
with open(out_path, "rb") as f:
    f.seek(SANITY_TOKEN * row_bytes)
    stored = np.frombuffer(f.read(row_bytes), dtype=np.float16).astype(np.float32)
if stored.shape != live.shape:
    raise RuntimeError(
        f"stored row for token {SANITY_TOKEN} has shape {stored.shape}, expected {live.shape}"
    )
diff = float(np.abs(live - stored).max())
print(f" max abs diff live vs stored fp16: {diff:.2e} (expect < 1e-3)")
# `not diff < ...` also catches NaN.
if not diff < 1e-3:
    print(" WARNING: stored embeddings disagree with the live projection", flush=True)
# ---- 2. Damien voice prefix (positions 0..8) ----
# Run a voice-clone and capture the multi-token prefill call, then keep the
# first 9 rows. Those are fixed per speaker — same for every phrase — so
# one capture suffices for the app's lifetime. The last 2 rows of the same
# prefill are kept as the voice suffix.
print(f"\n[2/3] Capturing Damien voice prefix from {VOICE}...")
N_PREFIX = 9  # positions 0..8: role/control tokens, xvector (7), voice marker (8)
N_SUFFIX = 2  # closing positions emitted after the text tokens
captured = []     # rows of the first multi-token 3-D inputs_embeds (batch item 0)
call_shapes = []  # sequence length of every 3-D inputs_embeds forward call
original_forward = talker.model.forward


def patched(input_ids=None, inputs_embeds=None, **kwargs):
    """Record the prefill embeddings, then delegate to the real forward."""
    if inputs_embeds is not None and inputs_embeds.dim() == 3:
        call_shapes.append(inputs_embeds.shape[1])
        # Keep only the first multi-token call (the prefill); single-token
        # decode steps would otherwise be copied for the whole generation.
        if not captured and inputs_embeds.shape[1] > 1:
            # .float() before .numpy(): numpy has no bfloat16 dtype.
            # astype() copies, so later in-place edits can't alter the capture.
            rows = inputs_embeds[0].detach().float().cpu().numpy().astype(np.float32)
            captured.extend(rows)
    return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


talker.model.forward = patched
try:
    # Any short sentence works — we only keep rows which are text-invariant.
    _ = tts.generate_voice_clone(
        text="Bonjour, je suis Kazeia.", ref_audio=VOICE, language="french",
        x_vector_only_mode=True, non_streaming_mode=True,
    )
finally:
    # Always un-patch, even if generation fails.
    talker.model.forward = original_forward

nP = len(captured)
if nP == 0:
    raise RuntimeError(
        f"no multi-token prefill captured (forward call lengths: {call_shapes})"
    )
print(f" Prefill size: {nP} tokens")
if nP < N_PREFIX + N_SUFFIX:
    raise RuntimeError(
        f"prefill of {nP} tokens is too short for a {N_PREFIX}-row prefix "
        f"plus a {N_SUFFIX}-row suffix"
    )
prefix_9 = np.stack(captured[:N_PREFIX])       # [9, hidden]
suffix_2 = np.stack(captured[nP - N_SUFFIX:])  # [2, hidden]


def _write_rows(path, rows):
    """Write `rows` as int32 row count, int32 column count (little-endian), then fp32 data."""
    with open(path, "wb") as f:
        f.write(struct.pack("<ii", rows.shape[0], rows.shape[1]))
        f.write(rows.astype(np.float32).tobytes())


prefix_path = f"{OUTPUT_DIR}/damien_voice_prefix.bin"
_write_rows(prefix_path, prefix_9)
print(f" prefix -> {prefix_path} ({os.path.getsize(prefix_path)} bytes)")
suffix_path = f"{OUTPUT_DIR}/damien_voice_suffix.bin"
_write_rows(suffix_path, suffix_2)
print(f" suffix -> {suffix_path} ({os.path.getsize(suffix_path)} bytes)")
norms_pref = [float(np.linalg.norm(row)) for row in prefix_9]
norms_suff = [float(np.linalg.norm(row)) for row in suffix_2]
print(f" Prefix norms: {[f'{n:.2f}' for n in norms_pref]} (pos 7 = xvector ~10, others ~1.6-1.8)")
print(f" Suffix norms: {[f'{n:.2f}' for n in norms_suff]}")
# ---- 3. Tokenizer files ----
# Copy the HF tokenizer artefacts next to the embeddings so a Kotlin BPE can
# reproduce Python encode() bit-for-bit. Files absent from the snapshot are
# reported and skipped rather than treated as errors.
print(f"\n[3/3] Copying tokenizer to {OUTPUT_DIR}/qwen3_tokenizer/...")
TOKENIZER_FILES = (
    "tokenizer.json",
    "vocab.json",
    "merges.txt",
    "tokenizer_config.json",
    "special_tokens_map.json",
)
for fname in TOKENIZER_FILES:
    source_file = os.path.join(MODEL, fname)
    if not os.path.exists(source_file):
        print(f" (skipped, not present: {fname})")
        continue
    shutil.copy(source_file, f"{OUTPUT_DIR}/qwen3_tokenizer/{fname}")
    print(f" {fname} ({os.path.getsize(source_file)} bytes)")
# ---- Summary: adb commands that install every artefact on the tablet ----
DEVICE_DIR = "/data/local/tmp/kazeia/models/qwen3-tts-npu/"
print("\n=== DONE ===")
print(f"Files ready in {OUTPUT_DIR}/")
print("\nPush to tablet:")
# The voice suffix must ship with the prefix: without its closing positions
# the talker misreads the end of the text and produces garbled output.
for artefact in (
    "text_embeds_full_fp16.bin",
    "damien_voice_prefix.bin",
    "damien_voice_suffix.bin",
    "qwen3_tokenizer",
):
    print(f" adb push {OUTPUT_DIR}/{artefact} {DEVICE_DIR}")