#!/usr/bin/env python3
"""
Generate voice-cloned TTS embeds: capture the COMPLETE talker input sequence
from a Python voice-cloning run (prefill + every generation step).

Unlike prepare_tts_embeds.py, this version also captures the multi-token
prefill, so the NPU has the correct KV-cache context and the output has no
audible clicks ("tacs").

Usage: python3 prepare_tts_voiceclone.py "Your text here" [output.bin] [voice.wav]
"""
import sys
import os
import struct
import warnings

# NOTE: the working directory is switched to /tmp before anything else runs,
# so relative OUTPUT / VOICE paths given on the command line resolve against
# /tmp, not against the directory the script was launched from.
os.chdir("/tmp")
# Silence all warnings before the heavy third-party imports further down.
warnings.filterwarnings("ignore")


def _cli_arg(position, default):
    """Return sys.argv[position] when it was supplied, otherwise *default*."""
    return sys.argv[position] if len(sys.argv) > position else default


# Positional command-line arguments, each with a built-in fallback.
TEXT = _cli_arg(1, "Bonjour, je m'appelle Kazeia.")
OUTPUT = _cli_arg(2, "/tmp/tts_vc.bin")
VOICE = _cli_arg(3, "/opt/Kazeia/voix/damien_15s_24k.wav")
# Local Hugging Face snapshot directory of the TTS checkpoint.
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"

import torch
import numpy as np
from qwen_tts import Qwen3TTSModel

# Echo the request; the text preview is cut at 80 characters.
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
print(f"Voice: {VOICE}")
print("Loading model...")
tts = Qwen3TTSModel.from_pretrained(
    MODEL,
    local_files_only=True,
    device_map="cpu",
)
talker = tts.model.talker

# Recording state filled in while generation runs. Per-call lengths are kept
# so the first call (prefill, multi-token) can be told apart from the later
# calls (decode, 1 token each).
captured = []            # one embedding vector per talker input token, in order
                         # (1024-dim according to the original note)
call_shapes = []         # number of tokens seen in each recorded call
codec_ids_per_step = []  # flattened int32 codes per decode step
                         # (16 per step according to the original note)

# Keep handles on the unpatched callables.
original_forward = talker.model.forward
original_talker_forward = talker.forward

def patched_forward(input_ids=None, inputs_embeds=None, **kwargs):
    """Record the talker's input embeddings, then delegate to the real forward.

    Installed in place of ``talker.model.forward``. When ``inputs_embeds`` is
    a 3-D tensor, the sequence length (``shape[1]``) is appended to
    ``call_shapes`` and every position of batch item 0 is appended to
    ``captured`` as an independent float32 NumPy vector. Calls without a 3-D
    ``inputs_embeds`` are passed through without recording anything.

    Args:
        input_ids: forwarded untouched to the original forward.
        inputs_embeds: tensor to record; only batch index 0 is captured.
        **kwargs: forwarded untouched to the original forward.

    Returns:
        Whatever the original ``talker.model.forward`` returns.
    """
    if inputs_embeds is not None and inputs_embeds.dim() == 3:
        # Upcast to float32 inside torch first: Tensor.numpy() raises on
        # bfloat16, so converting a bf16 tensor directly would abort the run.
        # astype() then makes an owned copy, so later in-place changes to the
        # model's tensor cannot alter what was recorded.
        rows = (
            inputs_embeds[0].detach().cpu().float().numpy().astype(np.float32)
        )
        call_shapes.append(rows.shape[0])
        # Iterating a 2-D array yields one 1-D vector per token position.
        captured.extend(rows)
    return original_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


def talker_output_hook(module, input, output):
    """Forward hook that records the codes carried by each talker output.

    Reads ``output.hidden_states``; when it is a tuple with at least two
    entries and entry 1 is a tensor-like object (has ``detach``), that entry
    is flattened to a 1-D int32 array and appended to ``codec_ids_per_step``.
    Any other output is ignored. Nothing is returned, so the module's output
    is left unchanged.

    Per the original note, a hook is used (instead of replacing
    ``talker.forward``) so the method signature stays intact for HF's kwarg
    validation.
    """
    states = getattr(output, 'hidden_states', None)
    if not isinstance(states, tuple) or len(states) < 2:
        return
    codes = states[1]
    if codes is None or not hasattr(codes, 'detach'):
        return
    flat_codes = codes.detach().cpu().numpy().astype(np.int32).reshape(-1)
    codec_ids_per_step.append(flat_codes)


# Wire in both recorders: the forward hook observes talker outputs, and the
# replacement forward observes the embeddings fed to talker.model.
talker.register_forward_hook(talker_output_hook)
talker.model.forward = patched_forward

print("Running voice-clone generation (captures prefill + decode inputs)...")
clone_result = tts.generate_voice_clone(
    text=TEXT,
    ref_audio=VOICE,
    language="french",
    x_vector_only_mode=True,
    non_streaming_mode=True,
)
audio_list, sr = clone_result
# Only the first returned waveform is used from here on.
audio = audio_list[0]

# An empty call log means the patched forward never saw a 3-D inputs_embeds.
if len(call_shapes) == 0:
    print("ERROR: captured nothing")
    sys.exit(1)

# The first recorded call is treated as the multi-token prefill; everything
# captured after it is counted as decode input (one token per call, per the
# original note).
nTotal = len(captured)
nPrefill = call_shapes[0]
nDecode = nTotal - nPrefill

print(f"Audio: {len(audio)/sr:.2f}s")
print(f"Captured {nTotal} embeds: {nPrefill} prefill + {nDecode} decode")
print(f"Captured {len(codec_ids_per_step)} code vectors × 16 per step")
print(f"Call shapes: first={call_shapes[0]}, rest={call_shapes[1:4]}... ({len(call_shapes)} calls total)")

# Binary format: <i nPrefill> <i nTotal> <f32 × 1024 × nTotal>
# (two little-endian int32 header fields, then every captured vector's raw
# float32 bytes in capture order).
with open(OUTPUT, "wb") as f:
    f.write(struct.pack("<ii", nPrefill, nTotal))
    f.writelines(emb.tobytes() for emb in captured)
print(f"\nSaved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f} KB)")

# Also save Python's sampled codes (diagnostic path: decode these directly via
# tablet BigVGAN to isolate whether the tremor comes from code divergence vs
# from our BigVGAN implementation).
#
# The sibling path is derived from OUTPUT's stem. A plain
# OUTPUT.replace('.bin', ...) returns OUTPUT unchanged when the name has no
# ".bin" (so this file would overwrite the embeds file) and also rewrites
# ".bin" inside directory names. For "*.bin" outputs the result is the same
# as before.
codes_path = os.path.splitext(OUTPUT)[0] + "_codes.bin"

# File layout: <i n_steps> <i 16> then 16 int32 codes per step.
with open(codes_path, "wb") as f:
    n_steps = len(codec_ids_per_step)
    f.write(struct.pack("<i", n_steps))
    f.write(struct.pack("<i", 16))  # num codebooks
    for c in codec_ids_per_step:
        # Write exactly 16 int32 codes per step: shorter vectors are padded
        # with 0, longer ones are truncated.
        arr = np.zeros(16, dtype=np.int32)
        n_copy = min(16, len(c))
        arr[:n_copy] = c[:n_copy]
        f.write(arr.tobytes())
print(f"Saved Python codes: {codes_path} ({os.path.getsize(codes_path)} bytes, {n_steps} steps)")

# Print ready-to-paste commands for copying both binary files to the tablet.
print("\nPush to tablet:")
print(f" adb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
print(f" adb push {codes_path} /data/local/tmp/kazeia/models/qwen3-tts-npu/python_codes.bin")

# Save the audio produced by the Python run next to the embeds file.
import soundfile as sf

# Derive the wav name from OUTPUT's stem. OUTPUT.replace('.bin', ...) returns
# OUTPUT unchanged when the name has no ".bin", so the wav would overwrite
# the embeds file written earlier. For "*.bin" outputs the result is the
# same as before.
ref_path = os.path.splitext(OUTPUT)[0] + "_ref.wav"
sf.write(ref_path, audio, sr)
print(f" Python ref audio: {ref_path}")