kazeia/scripts/prepare_tts_voiceclone.py

110 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Generate voice-cloned TTS embeds: capture the COMPLETE talker input sequence
from a Python voice-cloning run (prefill + every generation step).
Unlike prepare_tts_embeds.py, this version captures multi-token prefill too,
so the NPU has the correct KV-cache context and there is no "tacs"/clicks.
Usage: python3 prepare_tts_voiceclone.py "Your text here" [output.bin] [voice.wav]
"""
import sys, os, struct, warnings
# --- Command-line arguments -------------------------------------------------
# Paths are resolved against the caller's working directory BEFORE the chdir
# below. Otherwise a relative output path would silently land under /tmp, and
# the "adb push {OUTPUT}" hint printed at the end of the script would name a
# file that does not exist where the user ran the command.
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
OUTPUT = os.path.abspath(sys.argv[2]) if len(sys.argv) > 2 else "/tmp/tts_vc.bin"
VOICE = sys.argv[3] if len(sys.argv) > 3 else "/opt/Kazeia/voix/damien_15s_24k.wav"
if os.path.exists(VOICE):
    # Only rewrite the value when it names a file visible from the caller's
    # directory; anything else is handed to the model unchanged, exactly as
    # before (so a path that only exists relative to /tmp keeps working).
    VOICE = os.path.abspath(VOICE)
# Local Hugging Face snapshot of the TTS model (machine-specific path).
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"

# NOTE(review): the reason for running from /tmp is not visible in this file;
# the chdir is kept, it just happens after the paths above are pinned down.
os.chdir("/tmp")
# Silence library warnings; installed before the heavy imports below so it
# also covers anything they emit at import time.
warnings.filterwarnings("ignore")

import torch, numpy as np
from qwen_tts import Qwen3TTSModel

print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
print(f"Voice: {VOICE}")
print("Loading model...")
# local_files_only=True: never hit the network; device_map="cpu": run on CPU.
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
talker = tts.model.talker
# References to the unpatched callables, taken before anything is replaced.
original_forward = talker.model.forward
original_talker_forward = talker.forward

# Recording buffers filled while generation runs. Every talker input is kept,
# together with the token count of the call it came from, so that the first
# call (prefill, multi-token) can later be told apart from the following
# calls (decode, one token each).
captured: list = []            # one float32 vector per input token, in call order
                               # (1024-dim according to the output format; not checked here)
call_shapes: list = []         # number of tokens seen in each intercepted call
codec_ids_per_step: list = []  # one flat int32 array of predicted codes per talker output
                               # (expected to hold 16 codes each; not checked here)
def patched_forward(input_ids=None, inputs_embeds=None, **kwargs):
    """Record the talker's input embeddings, then defer to the real forward.

    Drop-in replacement for ``talker.model.forward``. When ``inputs_embeds``
    is a 3-D tensor, the number of tokens in the call (its second dimension)
    is appended to ``call_shapes`` and one float32 vector per token of batch
    item 0 is appended to ``captured``. Other batch items are not recorded.
    Calls without 3-D ``inputs_embeds`` are passed through untouched.

    Args:
        input_ids: forwarded unchanged to the original forward.
        inputs_embeds: tensor to record (if 3-D); forwarded unchanged.
        **kwargs: forwarded unchanged to the original forward.

    Returns:
        Whatever the original forward returns.
    """
    if inputs_embeds is not None and inputs_embeds.dim() == 3:
        call_shapes.append(inputs_embeds.shape[1])
        # Convert the whole sequence once instead of once per token.
        # .float() lets half/bfloat16 inputs through (.numpy() rejects
        # bfloat16); .astype() always copies, so the stored rows never alias
        # memory the model may still write to.
        rows = inputs_embeds[0].detach().float().cpu().numpy().astype(np.float32)
        captured.extend(rows)
    return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
def talker_output_hook(module, input, output):
    """Forward hook for the talker: store the codec ids found in each output.

    Installed with ``register_forward_hook`` instead of replacing
    ``talker.forward``, so the method signature is left intact and HF's kwarg
    validation still works.

    The ids are read from ``output.hidden_states[1]``. Outputs that have no
    ``hidden_states`` tuple, a tuple shorter than two entries, or a second
    entry that is ``None`` / not tensor-like are ignored. Each accepted entry
    is appended to ``codec_ids_per_step`` as a flat int32 array. The hook
    returns nothing.
    """
    states = getattr(output, 'hidden_states', None)
    if not isinstance(states, tuple) or len(states) < 2:
        return
    codes = states[1]
    if codes is None or not hasattr(codes, 'detach'):
        return
    flat_codes = codes.detach().cpu().numpy().astype(np.int32).reshape(-1)
    codec_ids_per_step.append(flat_codes)
# Install both interception points: the output hook on the talker itself and
# the recording wrapper in place of the inner model's forward.
talker.register_forward_hook(talker_output_hook)
talker.model.forward = patched_forward

print("Running voice-clone generation (captures prefill + decode inputs)...")
clone_options = dict(
    text=TEXT,
    ref_audio=VOICE,
    language="french",
    x_vector_only_mode=True,
    non_streaming_mode=True,
)
audio_list, sr = tts.generate_voice_clone(**clone_options)
# Only the first returned waveform is used further down.
audio = audio_list[0]
# Bail out if the wrapper never saw a 3-D inputs_embeds call.
if not call_shapes:
    print("ERROR: captured nothing")
    raise SystemExit(1)

# The first intercepted call is treated as the (multi-token) prefill; every
# vector captured after it is counted as decode input, on the assumption that
# each later call carries a single token.
nTotal = len(captured)
nPrefill = call_shapes[0]
nDecode = nTotal - nPrefill

print(f"Audio: {len(audio)/sr:.2f}s")
print(f"Captured {nTotal} embeds: {nPrefill} prefill + {nDecode} decode")
print(f"Captured {len(codec_ids_per_step)} code vectors × 16 per step")
print(f"Call shapes: first={nPrefill}, rest={call_shapes[1:4]}... ({len(call_shapes)} calls total)")
# File layout (little-endian):
#   int32 nPrefill | int32 nTotal | raw float32 payload of every captured
#   vector, in capture order (documented as 1024 floats per vector; the
#   width is not written to the file).
with open(OUTPUT, "wb") as f:
    f.write(struct.pack("<ii", nPrefill, nTotal))
    f.writelines(emb.tobytes() for emb in captured)
print(f"\nSaved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f} KB)")
# Also save Python's sampled codes (diagnostic path: decode these directly via
# tablet BigVGAN to isolate whether the tremor comes from code divergence vs
# from our BigVGAN implementation).
#
# The sibling file name is derived from OUTPUT's stem rather than with
# str.replace('.bin', ...): replace() returns OUTPUT unchanged when the name
# has no ".bin" in it (the codes would then overwrite the embeds file that was
# just written) and rewrites every ".bin" occurrence anywhere in the path.
codes_path = os.path.splitext(OUTPUT)[0] + '_codes.bin'
# File layout (little-endian):
#   int32 n_steps | int32 num_codebooks (16) | 16 int32 codes per step
with open(codes_path, "wb") as f:
    n_steps = len(codec_ids_per_step)
    f.write(struct.pack("<ii", n_steps, 16))
    for c in codec_ids_per_step:
        # Exactly 16 int32 codes per step: zero-padded if the captured array
        # is shorter, truncated if it is longer.
        arr = np.zeros(16, dtype=np.int32)
        keep = min(16, len(c))
        arr[:keep] = c[:keep]
        f.write(arr.tobytes())
print(f"Saved Python codes: {codes_path} ({os.path.getsize(codes_path)} bytes, {n_steps} steps)")
# Print ready-to-paste commands for copying both binary files to the tablet.
print("\nPush to tablet:")
print(f" adb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
print(f" adb push {codes_path} /data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin")

# Write the audio produced by the Python run next to the embeds file, as a
# reference to compare against.
import soundfile as sf
# Built from OUTPUT's stem instead of str.replace('.bin', ...), which returns
# OUTPUT itself when the name contains no ".bin" and would make the wav
# writer target the embeds file.
ref_path = os.path.splitext(OUTPUT)[0] + '_ref.wav'
sf.write(ref_path, audio, sr)
print(f" Python ref audio: {ref_path}")