kazeia/scripts/export_voice_prefix_suffix.py

234 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Generate per-voice <name>_voice_prefix.bin (9 × 1024 fp32) and
<name>_voice_suffix.bin (2 × 1024 fp32) for Kazeia's on-device TTS
engine (Qwen3-TTS 0.6B-Base voice-clone mode).
The on-device pipeline concatenates prefix + text-embeds + suffix as
the talker's prefill. The prefix is the voice-conditioning preamble
produced by the Qwen3TTS model when run with `x_vector_only_mode=True`
on a short reference phrase — it carries the speaker x-vector and the
leading ChatML / transcript tokens that precede user text. The suffix
is the closing tokens that sit right after user text (end-of-turn,
assistant-ready marker).
Approach: run the model once per voice on a fixed short utterance and
capture every talker input embedding of the first prefill call (the
first call whose `inputs_embeds` is 3-D) by temporarily replacing the
talker model's `forward` — that call carries the full prefill sequence.
Generation is aborted right after the capture, so no audio is decoded.
The reference Damien files contain exactly 9 pre-text embeds + 2
post-text embeds, which corresponds to:
[prefix: 9 vectors] [text embeds: N vectors] [suffix: 2 vectors]
The split is purely positional: the first N_PREFIX (9) and the last
N_SUFFIX (2) vectors of the captured prefill are written out; no
tokenization is done to locate the text span. If the Qwen3TTS chat
template changes, these counts must be re-checked — `--validate-damien`
diffs the regenerated prefix against a known-good capture. (The author
reports a bit-identical match with the Damien files on a first run:
/tmp/check_damien_*.)
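Usage:
    export_voice_prefix_suffix.py VOICE.wav [VOICE.wav ...]
        --out-dir /path/to/output   (default /tmp/voice_prefixes)
        --text "Bonjour."           (reference utterance; short is ok)
        --validate-damien damien_voice_prefix.bin
                                    (optional; needs a WAV whose name
                                     contains "damien" in the list)
The output files are `<name>_voice_prefix.bin` and
`<name>_voice_suffix.bin`, where `<name>` is the lower-cased WAV file
name without extension, cut at its first underscore (for example
`damien_15s_24k.wav` → `damien`). Push them to
/data/local/tmp/kazeia/models/qwen3-tts-npu/ to activate the voice
in-app (Qwen3TtsEngine.setVoice reads them from there).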
"""
import argparse
import os
import struct
import sys
import warnings
from pathlib import Path
# Silences every warning process-wide (presumably third-party library
# noise during model load/inference).
warnings.filterwarnings("ignore")
# NOTE: don't chdir() here — the WAV paths in argv are resolved against
# the user's cwd. Qwen3TTS creates /tmp scratch files internally already.

# Local snapshot of the Qwen3-TTS 0.6B-Base checkpoint. The default is the
# original author's Hugging Face cache; set QWEN3_TTS_MODEL_PATH to point at
# a snapshot elsewhere without editing this file (an empty value is ignored).
_DEFAULT_MODEL_PATH = (
    "/home/alf/.cache/huggingface/hub/"
    "models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/"
    "5d83992436eae1d760afd27aff78a71d676296fc"
)
MODEL_PATH = os.environ.get("QWEN3_TTS_MODEL_PATH") or _DEFAULT_MODEL_PATH

# Prefix + suffix sizes taken from the reference damien_voice_prefix.bin /
# damien_voice_suffix.bin shipped on the tablet. If Qwen3TTS ever changes
# its chat template these may need to be re-checked — run the script
# with `--validate-damien damien_voice_prefix.bin` (and a damien WAV in the
# input list) to diff against a known-good capture.
N_PREFIX = 9
N_SUFFIX = 2
# Width of one talker input embedding (floats per vector).
TALKER_DIM = 1024
def load_model(model_path=None):
    """Load the Qwen3-TTS model on CPU from a local snapshot directory.

    Args:
        model_path: Snapshot directory to load. ``None`` (the default) uses
            the module-level ``MODEL_PATH``.

    Returns:
        The ``Qwen3TTSModel`` returned by ``Qwen3TTSModel.from_pretrained``.
    """
    # Imported inside the function so the module itself (and `--help`) can
    # be loaded without qwen_tts installed.
    from qwen_tts import Qwen3TTSModel

    path = MODEL_PATH if model_path is None else model_path
    print(f"Loading Qwen3-TTS model from {path}...", flush=True)
    # local_files_only: never reach for the network; the snapshot must exist.
    return Qwen3TTSModel.from_pretrained(
        path, local_files_only=True, device_map="cpu"
    )
class _PrefillCapturedSentinel(Exception):
    """Control-flow signal raised by the patched talker forward once the
    prefill embeddings have been recorded.

    It lets capture_prefill abort generate_voice_clone without waiting for
    the (very slow on CPU) full TTS decode. capture_prefill always catches
    it, so it never reaches callers.
    """


def capture_prefill(tts, wav_path: str, text: str):
    """Return the talker input embeddings of the first prefill call.

    Runs ``tts.generate_voice_clone`` only until the talker's inner model is
    first called with 3-D ``inputs_embeds``. It copies batch row 0 of that
    tensor as one float32 vector per sequence position, then aborts the
    generation. The full non-streaming decode would take minutes per voice
    on CPU, and none of the audio is needed.

    Args:
        tts: Loaded Qwen3-TTS wrapper (as returned by load_model).
        wav_path: Reference voice WAV, passed as ``ref_audio``.
        text: Utterance to run through the prefill.

    Returns:
        List of 1-D ``numpy.float32`` arrays, one per prefill position.

    Raises:
        RuntimeError: if generation finished without the patched forward
            ever seeing 3-D ``inputs_embeds``.
    """
    import inspect

    import numpy as np

    captured = []
    done = False
    owner = tts.model.talker.model
    original_forward = owner.forward
    # Remember whether `forward` was an instance attribute, so the patch can
    # be undone exactly instead of leaving a bound method stuck on the object.
    had_instance_forward = "forward" in getattr(owner, "__dict__", {})
    try:
        signature = inspect.signature(original_forward)
    except (TypeError, ValueError):
        signature = None  # signature not introspectable (e.g. C callables)

    def _find_inputs_embeds(args, kwargs):
        """Locate ``inputs_embeds`` whether it was passed by keyword or position."""
        if "inputs_embeds" in kwargs:
            return kwargs["inputs_embeds"]
        if not args or signature is None:
            return None
        try:
            bound = signature.bind_partial(*args, **kwargs)
        except TypeError:
            return None
        return bound.arguments.get("inputs_embeds")

    def patched_forward(*args, **kwargs):
        nonlocal done
        if done:
            # Only the first prefill is wanted; should the library swallow
            # the sentinel and keep decoding, keep aborting without recording.
            raise _PrefillCapturedSentinel()
        embeds = _find_inputs_embeds(args, kwargs)
        if embeds is not None and embeds.dim() == 3:
            # .float() first: Tensor.numpy() cannot represent bfloat16.
            # astype() makes a copy detached from the model's tensor memory.
            rows = embeds[0].detach().cpu().float().numpy().astype(np.float32)
            captured.extend(rows)
            done = True
            raise _PrefillCapturedSentinel()
        # Forward every argument untouched so positional calls keep working.
        return original_forward(*args, **kwargs)

    owner.forward = patched_forward
    try:
        tts.generate_voice_clone(
            text=text,
            ref_audio=wav_path,
            language="french",
            x_vector_only_mode=True,
            non_streaming_mode=True,
        )
    except _PrefillCapturedSentinel:
        pass  # expected — we abort after the first prefill
    finally:
        if had_instance_forward:
            owner.forward = original_forward
        else:
            del owner.forward  # re-expose the class-level forward
    if not captured:
        raise RuntimeError("No prefill captured — hook wasn't triggered.")
    return captured
def write_bin(path: Path, vectors):
    """Write ``vectors`` to ``path`` in the on-device prefix/suffix format.

    Layout (all little-endian): int32 row count ``n``, int32 width ``dim``,
    then ``n * dim`` float32 values, one row after another.

    Every row is validated before anything touches the disk. The bytes go
    to a sibling ``.tmp`` file, which is then renamed over ``path``, so a
    failed export never leaves a truncated file behind. That matters
    because process_voice skips voices whose output files already exist.

    Args:
        path: Destination file.
        vectors: Sequence of float sequences (e.g. 1-D numpy arrays), each
            ``TALKER_DIM`` long. May be empty (writes a header with n=0).

    Raises:
        RuntimeError: if any row is not ``TALKER_DIM`` values long.
    """
    path = Path(path)
    n = len(vectors)
    dim = len(vectors[0]) if n else TALKER_DIM
    if dim != TALKER_DIM:
        raise RuntimeError(f"Expected dim {TALKER_DIM}, got {dim}")
    for idx, vec in enumerate(vectors):
        if len(vec) != dim:
            raise RuntimeError(f"Row {idx}: expected dim {dim}, got {len(vec)}")

    row = struct.Struct(f"<{dim}f")
    payload = struct.pack("<ii", n, dim) + b"".join(row.pack(*vec) for vec in vectors)

    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            f.write(payload)
        tmp_path.replace(path)  # rename is atomic on POSIX (same filesystem)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
def process_voice(tts, wav_path: Path, out_dir: Path, text: str):
    """Export ``<name>_voice_prefix.bin`` and ``<name>_voice_suffix.bin`` for one WAV.

    ``<name>`` is the lower-cased WAV stem up to its first underscore
    (``damien_15s_24k.wav`` → ``damien``). If the stem starts with an
    underscore, the whole lower-cased stem is used instead. Nothing is
    done when both output files already exist in ``out_dir``.

    Raises:
        RuntimeError: if the captured prefill is too short to hold the
            prefix, the suffix and at least one text embedding in between.
    """
    stem = wav_path.stem.lower()
    # "damien_15s_24k" → "damien"; without the fallback a stem such as
    # "_take1" would produce files named "_voice_prefix.bin".
    name = stem.split("_", 1)[0] or stem
    prefix_path = out_dir / f"{name}_voice_prefix.bin"
    suffix_path = out_dir / f"{name}_voice_suffix.bin"
    if prefix_path.exists() and suffix_path.exists():
        print(f" [skip] {name}: prefix/suffix already exist")
        return

    print(f" Capturing prefill for {name} ({wav_path.name})...", flush=True)
    prefill = capture_prefill(tts, str(wav_path), text)
    min_len = N_PREFIX + N_SUFFIX + 1  # prefix + suffix + at least one text token
    if len(prefill) < min_len:
        raise RuntimeError(
            f"Prefill too short for {name}: {len(prefill)} < {min_len}"
        )

    write_bin(prefix_path, prefill[:N_PREFIX])
    write_bin(suffix_path, prefill[-N_SUFFIX:])
    print(
        f" Wrote {prefix_path.name} ({N_PREFIX}×{TALKER_DIM}) "
        f"and {suffix_path.name} ({N_SUFFIX}×{TALKER_DIM})",
        flush=True,
    )
def validate_against_damien(tts, wav_path: Path, reference_prefix: Path, text: str):
    """Regenerate Damien's prefix from ``wav_path`` and diff it against a
    reference ``damien_voice_prefix.bin``.

    Only the prefix is compared; the suffix is not checked. Prints the max
    and mean absolute difference, which should both be ~0 if this script's
    slicing reproduces the original format.

    Returns:
        The maximum absolute element-wise difference, as a Python float.

    Raises:
        RuntimeError: if the reference file is truncated or malformed, or if
            its (n, dim) shape differs from the regenerated prefix. Without
            the shape check, a reference with n=1 would silently broadcast
            and report a meaningless diff.
    """
    import numpy as np

    prefill = capture_prefill(tts, str(wav_path), text)
    candidate = np.array(prefill[:N_PREFIX], dtype=np.float32)

    with open(reference_prefix, "rb") as f:
        header = f.read(8)
        if len(header) != 8:
            raise RuntimeError(f"{reference_prefix}: truncated header")
        n, d = struct.unpack("<ii", header)
        if n < 0 or d < 0:
            raise RuntimeError(f"{reference_prefix}: invalid header n={n}, d={d}")
        nbytes = n * d * 4
        payload = f.read(nbytes)
    if len(payload) != nbytes:
        raise RuntimeError(
            f"{reference_prefix}: expected {nbytes} data bytes, got {len(payload)}"
        )
    # Little-endian float32, matching the "<f" rows written by write_bin.
    ref = np.frombuffer(payload, dtype="<f4").reshape(n, d)
    if ref.shape != candidate.shape:
        raise RuntimeError(
            f"Shape mismatch: reference {ref.shape} vs regenerated {candidate.shape}"
        )

    diff = np.abs(candidate - ref)
    print(
        f"Damien prefix validation: max|diff|={diff.max():.3e} "
        f"mean|diff|={diff.mean():.3e} (expect ~0 if script is correct)"
    )
    return float(diff.max())
def main():
    """Command-line entry point.

    Loads the model once and optionally validates against a reference Damien
    prefix. It then exports prefix/suffix files for every WAV given.

    Exits with status 1 in two cases:
      * ``--validate-damien`` is used without a damien WAV in the list
        (checked before the slow model load);
      * any WAV is missing or fails to export. The remaining WAVs are
        still processed first.
    """
    p = argparse.ArgumentParser()
    p.add_argument("wavs", nargs="+", help="Voice WAV files")
    p.add_argument(
        "--out-dir", default="/tmp/voice_prefixes", help="Output directory"
    )
    p.add_argument(
        "--text", default="Bonjour.", help="Reference utterance for prefill"
    )
    p.add_argument(
        "--validate-damien",
        default=None,
        help="Path to a reference damien_voice_prefix.bin for sanity-check",
    )
    args = p.parse_args()

    # Resolve the damien WAV up front so a bad invocation fails immediately
    # rather than after loading the model.
    damien_wav = None
    if args.validate_damien:
        damien_wav = next(
            (Path(w) for w in args.wavs if "damien" in Path(w).stem.lower()), None
        )
        if damien_wav is None:
            print(
                "--validate-damien specified but no damien wav in input list",
                file=sys.stderr,
            )
            sys.exit(1)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    tts = load_model()

    if damien_wav is not None:
        validate_against_damien(tts, damien_wav, Path(args.validate_damien), args.text)

    failures = 0
    for wav in args.wavs:
        wp = Path(wav)
        if not wp.is_file():
            print(f" [miss] {wp}", file=sys.stderr)
            failures += 1
            continue
        try:
            process_voice(tts, wp, out_dir, args.text)
        except Exception as e:
            # One bad voice must not abort the batch: report it and move on.
            print(f" [fail] {wp.name}: {e}", file=sys.stderr)
            failures += 1

    print(f"\nDone. Files written under {out_dir}")
    print(
        "Push to the tablet with, e.g.:\n"
        f"  adb push {out_dir}/*_voice_prefix.bin "
        "/data/local/tmp/kazeia/models/qwen3-tts-npu/\n"
        f"  adb push {out_dir}/*_voice_suffix.bin "
        "/data/local/tmp/kazeia/models/qwen3-tts-npu/"
    )
    if failures:
        print(
            f"{failures} of {len(args.wavs)} WAV file(s) failed; "
            "see the [miss]/[fail] lines above.",
            file=sys.stderr,
        )
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: main() parses argv itself. Its return value
    # (None) becomes exit status 0.
    sys.exit(main())