kazeia/scripts/prepare_tts_embeds.py

105 lines
4.1 KiB
Python

#!/usr/bin/env python3
"""
Generate pre-computed TTS embeddings for any text.
Run on PC, push result to tablet, then run pipeline.
Usage: python3 prepare_tts_embeds.py "Your text here" [output.bin]
Then: adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin
"""
import sys, os, struct, warnings, types
# Command-line inputs. The output path is resolved against the directory the
# script was launched from *before* the chdir below; otherwise a relative path
# such as "output.bin" would be written under /tmp while the usage text and
# the printed `adb push` hint refer to it relative to the caller's directory.
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
OUTPUT = os.path.abspath(sys.argv[2]) if len(sys.argv) > 2 else "/tmp/tts_embeds.bin"

# The rest of the script runs with /tmp as its working directory.
os.chdir("/tmp")
warnings.filterwarnings("ignore")

# Reference voice sample and local model snapshot (machine-specific paths).
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
import torch, numpy as np
from qwen_tts import Qwen3TTSModel
# Echo the requested text (first 80 characters, with an ellipsis if longer).
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
print("Loading model...")

# Load the model from the local snapshot path onto the CPU.
load_options = {"local_files_only": True, "device_map": "cpu"}
tts = Qwen3TTSModel.from_pretrained(MODEL, **load_options)

# The talker sub-module is the one whose inner model gets patched below.
talker = tts.model.talker
# Monkey-patch the talker's inner model ``forward`` to capture its inputs
# (per the original note, this avoids HuggingFace generate() validation
# issues that arise when patching the talker's own forward).
captured_inputs = []
original_model_forward = talker.model.forward


def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
    """Record single-token ``inputs_embeds`` and delegate to the real forward.

    When ``inputs_embeds`` is given and its second dimension is 1, the vector
    at index ``[0, 0, :]`` is appended to ``captured_inputs`` as a float32
    NumPy array. All other calls pass through without being recorded. The
    return value is whatever the original forward returns.
    """
    if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
        # Cast to float32 inside torch first: Tensor.numpy() raises for dtypes
        # NumPy lacks (e.g. bfloat16). The trailing astype() still returns a
        # fresh array, so the stored vector never aliases the live tensor.
        step_embed = inputs_embeds[0, 0, :].detach().float().cpu().numpy().astype(np.float32)
        captured_inputs.append(step_embed)
    return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


talker.model.forward = patched_model_forward
# Run the voice clone once; the patched forward fills captured_inputs as a
# side effect of this call.
print("Generating voice clone...")
clone_options = {
    "text": TEXT,
    "ref_audio": VOICE,
    "language": "french",
    "x_vector_only_mode": True,
    "non_streaming_mode": True,
}
audio_list, sr = tts.generate_voice_clone(**clone_options)
audio = audio_list[0]
print(f"Audio: {len(audio)/sr:.2f}s, {len(captured_inputs)} generation steps captured")

# At least two captured steps are needed: one joins the prefill section of
# the output file and the remainder form the decode section.
have_enough_steps = len(captured_inputs) >= 2
if not have_enough_steps:
    print("ERROR: Not enough generation steps captured")
    sys.exit(1)
# Build embeds file.
# The capture hook only records single-token inputs, i.e. the generation
# steps; the first captured input is the first generation step (after prefill
# is done and the model starts generating codec tokens). The multi-token
# prefill inputs are not captured, so the leading prefill embeddings are
# loaded from an existing reference embeds file instead.
EXISTING = "/tmp/existing_embeds.bin"


def load_prefill_embeds(path, count=9, dim=1024):
    """Read the first ``count`` embedding rows from an embeds file.

    The file starts with two little-endian int32 values (prefill count, total
    row count) followed by rows of ``dim`` float32 values each.

    Args:
        path: Embeds file to read.
        count: Number of leading rows to return.
        dim: Number of float32 values per row.

    Returns:
        A list of ``count`` independent float32 arrays of length ``dim``.

    Raises:
        ValueError: If the header is incomplete, the header announces fewer
            than ``count`` rows, or a requested row is truncated.
    """
    row_bytes = dim * 4
    rows = []
    with open(path, "rb") as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(f"{path}: incomplete header")
        _n_prefill, n_total = struct.unpack("<ii", header)
        if n_total < count:
            raise ValueError(f"{path}: holds {n_total} embeddings, need at least {count}")
        # Only the leading rows are needed, so the rest of the file is not read.
        for index in range(count):
            raw = f.read(row_bytes)
            if len(raw) != row_bytes:
                raise ValueError(f"{path}: embedding {index} is truncated")
            # copy() detaches the array from the read-only bytes buffer.
            rows.append(np.frombuffer(raw, dtype=np.float32).copy())
    return rows


if not os.path.exists(EXISTING):
    # Pull from tablet
    os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {EXISTING}")

if os.path.exists(EXISTING):
    try:
        prefill_embeds = load_prefill_embeds(EXISTING)  # role+ctrl+spk+bos
    except ValueError as err:
        # A malformed reference would otherwise produce an output file whose
        # header disagrees with its payload; stop instead.
        print(f"ERROR: {err} (delete the file to pull a fresh copy from the tablet)")
        sys.exit(1)
else:
    print("WARNING: No existing embeds file, prefill will be zeros")
    prefill_embeds = [np.zeros(1024, dtype=np.float32)] * 9
# Build output: prefill rows followed by decode rows.
# File layout: int32 nPrefill, int32 nTotal (both little-endian), then nTotal
# embedding rows written back to back.
# Prefill = the reference prefill embeddings + the first captured input;
# decode = the remaining captured inputs (complete embeddings from Python).
prefill_rows = list(prefill_embeds) + [captured_inputs[0]]
decode_rows = captured_inputs[1:]

# Derive the counts from what is actually written so the header can never
# disagree with the payload (normally 9 + 1 = 10 prefill rows).
nPrefill = len(prefill_rows)
nDecode = len(decode_rows)
nTotal = nPrefill + nDecode

# Rows of differing byte size would make the file unparseable as fixed-size
# records, so refuse to write it.
row_sizes = {emb.nbytes for emb in prefill_rows} | {emb.nbytes for emb in decode_rows}
if len(row_sizes) != 1:
    print(f"ERROR: Embedding rows have inconsistent sizes (bytes): {sorted(row_sizes)}")
    sys.exit(1)

with open(OUTPUT, "wb") as f:
    f.write(struct.pack("<ii", nPrefill, nTotal))
    f.write(b"".join(emb.tobytes() for emb in prefill_rows))
    f.write(b"".join(emb.tobytes() for emb in decode_rows))

print(f"\nSaved: {OUTPUT}")
print(f" {nPrefill} prefill + {nDecode} decode = {nTotal} total")
print(f" {os.path.getsize(OUTPUT)/1024:.0f} KB")
print(f" Audio: {len(audio)/sr:.2f}s ({len(captured_inputs)} tokens)")
print(f"\nPush to tablet:")
print(f" adb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
# Also save reference audio
import soundfile as sf

# Swap only the file extension. str.replace('.bin', ...) would rewrite every
# ".bin" occurring anywhere in the path and, when OUTPUT contains no ".bin",
# return OUTPUT unchanged so the audio write would target the embeds file.
ref_path = os.path.splitext(OUTPUT)[0] + '_ref.wav'
sf.write(ref_path, audio, sr)
print(f" Reference audio: {ref_path}")