105 lines
4.1 KiB
Python
105 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Generate pre-computed TTS embeddings for any text.
|
|
Run on PC, push result to tablet, then run pipeline.
|
|
|
|
Usage: python3 prepare_tts_embeds.py "Your text here" [output.bin]
|
|
Then: adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin
|
|
"""
|
|
import sys, os, struct, warnings, types
|
|
# Everything below runs with /tmp as the working directory. Because the
# switch happens before the command line is read, a relative output path
# given by the user is resolved against /tmp, not the directory the script
# was launched from.
os.chdir("/tmp")

# Suppress every Python warning for the rest of the run.
warnings.filterwarnings("ignore")

# Positional command-line arguments: [text] [output path].
_cli_args = sys.argv[1:]

# Text to synthesise (first argument, or a French default sentence).
if _cli_args:
    TEXT = _cli_args[0]
else:
    TEXT = "Bonjour, je m'appelle Kazeia."

# Destination of the generated embeddings file (second argument, or default).
if len(_cli_args) >= 2:
    OUTPUT = _cli_args[1]
else:
    OUTPUT = "/tmp/tts_embeds.bin"

# Reference recording handed to the voice-clone call as ``ref_audio``.
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"

# Local Hugging Face cache snapshot of the Qwen3-TTS base model.
MODEL = (
    "/home/alf/.cache/huggingface/hub/"
    "models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/"
    "snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
)
|
|
|
|
import torch, numpy as np
|
|
from qwen_tts import Qwen3TTSModel
|
|
|
|
# Echo the requested text, shortened to its first 80 characters (followed by
# "...") when it is longer than that.
_shown_text = TEXT if len(TEXT) <= 80 else TEXT[:80] + "..."
print("Text: '" + _shown_text + "'")

print("Loading model...")
# The model is loaded from the snapshot directory with local_files_only=True
# and device_map="cpu".
tts = Qwen3TTSModel.from_pretrained(
    MODEL,
    local_files_only=True,
    device_map="cpu",
)

# The talker sub-model is the component whose inputs are captured later on.
talker = tts.model.talker
|
|
|
|
# Monkey-patch ``forward`` of the talker's *inner* model so that every
# single-token input embedding seen during generation is recorded.
# The wrapper is installed on ``talker.model`` rather than on the talker
# itself; per the original note, patching the talker's own forward runs into
# HuggingFace generate() validation issues.
captured_inputs = []
original_model_forward = talker.model.forward


def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
    """Record single-token ``inputs_embeds`` and delegate to the real forward.

    A call is recorded only when ``inputs_embeds`` is given and its second
    dimension has length 1; the vector at index ``[0, 0, :]`` is appended to
    ``captured_inputs`` as a float32 NumPy array. Multi-token calls are
    passed through without being recorded.

    All arguments are forwarded unchanged to the original ``forward`` and
    its result is returned as-is.
    """
    if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
        # Cast to float32 on the torch side first: Tensor.numpy() rejects
        # bfloat16, so a model running in reduced precision would otherwise
        # crash here. The explicit copy() detaches the snapshot from the
        # tensor's storage, which numpy() would otherwise share.
        step_embed = inputs_embeds[0, 0, :].detach().float().cpu().numpy().copy()
        captured_inputs.append(step_embed)
    return original_model_forward(
        input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs
    )


talker.model.forward = patched_model_forward
|
|
|
|
# Run the voice-clone generation. While it runs, the wrapper installed on
# talker.model.forward fills ``captured_inputs`` with one embedding per
# single-token step.
print("Generating voice clone...")
try:
    audio_list, sr = tts.generate_voice_clone(
        text=TEXT,
        ref_audio=VOICE,
        language="french",
        x_vector_only_mode=True,
        non_streaming_mode=True,
    )
finally:
    # The capture wrapper is only needed for this one call; put the original
    # forward back so nothing is recorded afterwards.
    talker.model.forward = original_model_forward

# Only the first returned waveform is used.
audio = audio_list[0]
print(f"Audio: {len(audio)/sr:.2f}s, {len(captured_inputs)} generation steps captured")

# The output file needs one captured embedding for the last prefill slot and
# at least one more for the decode section.
if len(captured_inputs) < 2:
    # sys.exit(<str>) prints the message to stderr and exits with status 1,
    # keeping the diagnostic out of stdout.
    sys.exit("ERROR: Not enough generation steps captured")
|
|
|
|
# Build embeds file.
#
# The forward wrapper only records single-token (generation) inputs; the
# multi-token prefill inputs are never captured. The fixed leading prefill
# embeddings (role + ctrl + spk + bos) are therefore reused from an existing
# embeddings file, which is pulled from the tablet when no local copy exists.
import subprocess

EXISTING = "/tmp/existing_embeds.bin"


def load_prefill_embeds(path, count=9, dim=1024):
    """Return the first *count* embeddings stored in the embeds file *path*.

    File layout as read here: a little-endian int32 prefill count, a
    little-endian int32 total count, then ``total`` embeddings of *dim*
    float32 values each. Only the first *count* embeddings are read; the
    prefill count in the header is not used.

    Returns a list of *count* independent float32 arrays of shape ``(dim,)``.

    Raises:
        ValueError: if the header is incomplete, the file declares fewer
            than *count* embeddings, or the embedding data is truncated.
    """
    with open(path, "rb") as src:
        header = src.read(8)
        if len(header) != 8:
            raise ValueError("header is truncated")
        _n_prefill, n_total = struct.unpack("<ii", header)
        if n_total < count:
            raise ValueError(
                f"file declares {n_total} embeddings, need at least {count}"
            )
        wanted = count * dim * 4
        raw = src.read(wanted)
    if len(raw) != wanted:
        raise ValueError(
            f"embedding data is truncated ({len(raw)} of {wanted} bytes)"
        )
    table = np.frombuffer(raw, dtype=np.float32).reshape(count, dim)
    # frombuffer yields read-only views of ``raw``; hand out writable copies.
    return [row.copy() for row in table]


if not os.path.exists(EXISTING):
    # Best effort: fetch the file currently deployed on the tablet. The
    # argument-list form avoids going through a shell.
    try:
        subprocess.run(
            [
                "adb",
                "pull",
                "/data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin",
                EXISTING,
            ],
            check=False,
        )
    except OSError as err:
        # adb is not installed or cannot be executed; fall through to the
        # zero-prefill fallback below.
        print(f"WARNING: could not run adb: {err}")

if os.path.exists(EXISTING):
    try:
        prefill_embeds = load_prefill_embeds(EXISTING)  # role+ctrl+spk+bos
    except ValueError as err:
        # A damaged reference file would otherwise produce an output whose
        # header does not match its contents.
        sys.exit(f"ERROR: {EXISTING} is not a usable embeds file: {err}")
else:
    print("WARNING: No existing embeds file, prefill will be zeros")
    prefill_embeds = [np.zeros(1024, dtype=np.float32)] * 9
|
|
|
|
# Assemble the output file: a header followed by 10 prefill embeddings and
# N decode embeddings.
nPrefill = 10  # 9 role/ctrl/spk/bos + first gen embed
nDecode = len(captured_inputs) - 1
nTotal = nPrefill + nDecode

# Header: two little-endian int32 values (prefill count, total count).
_header = struct.pack("<ii", nPrefill, nTotal)

# Payload order: the reused prefill embeddings, then every captured
# embedding. The first captured one fills the last prefill slot and the
# remaining ones form the decode section.
_payload = b"".join(
    emb.tobytes() for emb in (*prefill_embeds, *captured_inputs)
)

with open(OUTPUT, "wb") as out_file:
    out_file.write(_header)
    out_file.write(_payload)
|
|
|
|
# Summarise what was written and show the command that deploys it.
print(f"\nSaved: {OUTPUT}")
print(f" {nPrefill} prefill + {nDecode} decode = {nTotal} total")
print(f" {os.path.getsize(OUTPUT)/1024:.0f} KB")
print(f" Audio: {len(audio)/sr:.2f}s ({len(captured_inputs)} tokens)")
print(f"\nPush to tablet:")
print(f" adb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")

# Also save reference audio
import soundfile as sf

# Derive the WAV name from the output path's stem. A plain
# str.replace('.bin', ...) leaves the path unchanged when OUTPUT has no
# ".bin" (so the WAV would target the embeds file itself), and it rewrites
# every ".bin" occurrence, including ones inside directory names.
ref_path = os.path.splitext(OUTPUT)[0] + "_ref.wav"
sf.write(ref_path, audio, sr)
print(f" Reference audio: {ref_path}")
|