234 lines
8.2 KiB
Python
234 lines
8.2 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Generate per-voice <name>_voice_prefix.bin (9 × 1024 fp32) and
|
||
<name>_voice_suffix.bin (2 × 1024 fp32) for Kazeia's on-device TTS
|
||
engine (Qwen3-TTS 0.6B-Base voice-clone mode).
|
||
|
||
The on-device pipeline concatenates prefix + text-embeds + suffix as
|
||
the talker's prefill. The prefix is the voice-conditioning preamble
|
||
produced by the Qwen3TTS model when run with `x_vector_only_mode=True`
|
||
on a short reference phrase — it carries the speaker x-vector and the
|
||
leading ChatML / transcript tokens that precede user text. The suffix
|
||
is the closing tokens that sit right after user text (end-of-turn,
|
||
assistant-ready marker).
|
||
|
||
Approach: run the model once per voice on a fixed short utterance,
temporarily replace the talker's forward method, and capture every
input embedding of the first prefill call — that's the full prefill
sequence. Generation is aborted right after the capture.
|
||
The reference Damien files contain exactly 9 pre-text embeds + 2
|
||
post-text embeds, which corresponds to:
|
||
|
||
[prefix: 9 vectors] [text embeds: N vectors] [suffix: 2 vectors]
|
||
|
||
The script does not locate the text tokens explicitly: it takes the
first N_PREFIX (9) vectors of the captured prefill as the prefix and
the last N_SUFFIX (2) vectors as the suffix. The split therefore does
not depend on how many tokens the utterance produces, but it does
assume the chat template keeps exactly 9 leading and 2 trailing
vectors. The result matches the Damien files bit-identically
(verified during the first run: /tmp/check_damien_*).
|
||
|
||
Usage:
|
||
export_voice_prefix_suffix.py VOICE.wav [VOICE.wav ...]
|
||
--out-dir /path/to/output (default /tmp/voice_prefixes)
|
||
--text "Bonjour." (reference utterance; short is ok)
|
||
|
||
The output file names are `<basename_without_ext>_voice_prefix.bin`
|
||
and `<basename_without_ext>_voice_suffix.bin`. Push them to
|
||
/data/local/tmp/kazeia/models/qwen3-tts-npu/ to activate the voice
|
||
in-app (Qwen3TtsEngine.setVoice reads them from there).
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
import struct
|
||
import sys
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
# Silence all Python warnings for the whole process.
warnings.filterwarnings("ignore")
# NOTE: don't chdir() here — the WAV paths in argv are resolved against
# the user's cwd. Qwen3TTS creates /tmp scratch files internally already.

# Local snapshot directory of the Qwen3-TTS 0.6B-Base checkpoint.
# Defaults to the Hugging Face cache path this script was written
# against; set the QWEN3_TTS_MODEL_PATH environment variable to use a
# snapshot stored elsewhere (an empty value falls back to the default).
MODEL_PATH = os.environ.get("QWEN3_TTS_MODEL_PATH") or (
    "/home/alf/.cache/huggingface/hub/"
    "models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/"
    "5d83992436eae1d760afd27aff78a71d676296fc"
)

# Prefix + suffix sizes taken from the reference damien_voice_prefix.bin /
# damien_voice_suffix.bin shipped on the tablet. If Qwen3TTS ever changes
# its chat template these may need to be re-checked — run the script
# with `--validate-damien damien_voice_prefix.bin` to diff against a
# known-good capture.
N_PREFIX = 9
N_SUFFIX = 2
TALKER_DIM = 1024
|
||
|
||
|
||
def load_model():
    """Load the Qwen3-TTS checkpoint found at MODEL_PATH and return it.

    The heavy imports live inside the function so that merely importing
    this module (or asking for --help) does not pull them in. The model
    is requested with ``local_files_only=True`` and ``device_map="cpu"``.
    """
    import torch  # noqa: F401 — not referenced directly in this function
    from qwen_tts import Qwen3TTSModel

    print(f"Loading Qwen3-TTS model from {MODEL_PATH}...", flush=True)
    load_options = {"local_files_only": True, "device_map": "cpu"}
    return Qwen3TTSModel.from_pretrained(MODEL_PATH, **load_options)
|
||
|
||
|
||
class _PrefillCapturedSentinel(Exception):
    """Control-flow signal that the first prefill has been captured.

    Raising it lets the caller abort generate_voice_clone immediately
    instead of waiting for the (very slow on CPU) full TTS decode.
    """
|
||
|
||
|
||
def capture_prefill(tts, wav_path: str, text: str):
    """Capture the talker's prefill input embeddings, then abort generation.

    ``tts.model.talker.model.forward`` is temporarily replaced. The first
    call that receives a 3-D ``inputs_embeds`` has the rows of its batch
    element 0 copied out, after which a sentinel exception unwinds
    ``generate_voice_clone``. Doing the full non-streaming decode would
    take several minutes per voice on CPU and none of the audio is
    needed — only the prefill vectors.

    Args:
        tts: Loaded model object exposing ``model.talker.model.forward``
            and ``generate_voice_clone``.
        wav_path: Reference audio path, forwarded as ``ref_audio``.
        text: Utterance forwarded as ``text``.

    Returns:
        A list of independent 1-D float32 NumPy arrays, one per sequence
        position of the captured call.

    Raises:
        RuntimeError: If generation finished without any vector being
            captured.
    """
    captured = []
    inner_model = tts.model.talker.model
    original_forward = inner_model.forward

    def patched_forward(input_ids=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is not None and inputs_embeds.dim() == 3:
            # Upcast on the torch side first: NumPy has no bfloat16 dtype,
            # so calling .numpy() on a bf16 tensor would raise.
            rows = inputs_embeds[0].detach().float().cpu().numpy()
            # .numpy() can share memory with the tensor; copy each row so
            # the result stays valid after the model releases its buffers.
            captured.extend(row.copy() for row in rows)
            raise _PrefillCapturedSentinel()
        return original_forward(
            input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs
        )

    inner_model.forward = patched_forward
    try:
        tts.generate_voice_clone(
            text=text,
            ref_audio=wav_path,
            language="french",
            x_vector_only_mode=True,
            non_streaming_mode=True,
        )
    except _PrefillCapturedSentinel:
        pass  # expected — we abort after the first prefill
    finally:
        # Always put the real forward back, even if generation failed.
        inner_model.forward = original_forward

    if not captured:
        raise RuntimeError("No prefill captured — hook wasn't triggered.")
    return captured
|
||
|
||
|
||
def write_bin(path: Path, vectors):
    """Write *vectors* to *path* in the little-endian voice .bin layout.

    Layout: two int32 values (vector count, TALKER_DIM) followed by each
    vector as TALKER_DIM float32 values.

    Every vector is validated and serialized in memory before the file
    is opened, so a bad vector never leaves a truncated file on disk.

    Args:
        path: Destination file; overwritten if it exists.
        vectors: Iterable of sequences of TALKER_DIM numbers.

    Raises:
        RuntimeError: If any vector's length differs from TALKER_DIM.
    """
    rows = list(vectors)
    for index, row in enumerate(rows):
        if len(row) != TALKER_DIM:
            raise RuntimeError(
                f"Expected dim {TALKER_DIM}, got {len(row)} at vector {index}"
            )

    row_format = struct.Struct(f"<{TALKER_DIM}f")
    chunks = [struct.pack("<ii", len(rows), TALKER_DIM)]
    chunks.extend(row_format.pack(*row) for row in rows)
    payload = b"".join(chunks)

    with open(path, "wb") as f:
        f.write(payload)
|
||
|
||
|
||
def process_voice(tts, wav_path: Path, out_dir: Path, text: str):
    """Generate the prefix/suffix .bin pair for one reference WAV.

    The voice name is the lower-cased file stem up to its first
    underscore. If both output files already exist the voice is skipped;
    otherwise the prefill is captured and its first N_PREFIX and last
    N_SUFFIX vectors are written out.

    Raises:
        RuntimeError: If the captured prefill has fewer than
            N_PREFIX + N_SUFFIX + 1 vectors.
    """
    # "damien_15s_24k" → "damien"
    name, _, _ = wav_path.stem.lower().partition("_")
    prefix_path = out_dir / f"{name}_voice_prefix.bin"
    suffix_path = out_dir / f"{name}_voice_suffix.bin"

    already_done = all(target.exists() for target in (prefix_path, suffix_path))
    if already_done:
        print(f" [skip] {name}: prefix/suffix already exist")
        return

    print(f" Capturing prefill for {name} ({wav_path.name})...", flush=True)
    prefill = capture_prefill(tts, str(wav_path), text)

    # At least one text vector must sit between the prefix and the suffix.
    if len(prefill) < N_PREFIX + N_SUFFIX + 1:
        raise RuntimeError(
            f"Prefill too short for {name}: {len(prefill)} < {N_PREFIX + N_SUFFIX + 1}"
        )

    jobs = (
        (prefix_path, prefill[:N_PREFIX]),
        (suffix_path, prefill[-N_SUFFIX:]),
    )
    for destination, vecs in jobs:
        write_bin(destination, vecs)

    print(
        f" Wrote {prefix_path.name} ({N_PREFIX}×{TALKER_DIM}) "
        f"and {suffix_path.name} ({N_SUFFIX}×{TALKER_DIM})",
        flush=True,
    )
|
||
|
||
|
||
def validate_against_damien(tts, wav_path: Path, reference_prefix: Path, text: str):
    """Regenerate a voice prefix and diff it against a reference file.

    Captures the prefill for *wav_path*, takes its first N_PREFIX
    vectors, and prints the max and mean absolute difference against the
    matrix stored in *reference_prefix*. Only the prefix is compared;
    the suffix is not checked here.

    Args:
        tts: Loaded model, passed through to capture_prefill.
        wav_path: Reference WAV to regenerate the prefix from.
        reference_prefix: Known-good ``*_voice_prefix.bin`` file.
        text: Reference utterance, passed through to capture_prefill.

    Raises:
        ValueError: If the reference file is truncated, has a
            non-positive size in its header, or its shape differs from
            the regenerated prefix.
    """
    import numpy as np

    prefill = capture_prefill(tts, str(wav_path), text)
    candidate = np.array(prefill[:N_PREFIX], dtype=np.float32)

    header_size = struct.calcsize("<ii")
    with open(reference_prefix, "rb") as f:
        header = f.read(header_size)
        if len(header) != header_size:
            raise ValueError(f"{reference_prefix}: truncated header")
        n, d = struct.unpack("<ii", header)
        # A negative size would make f.read() below read to end of file.
        if n <= 0 or d <= 0:
            raise ValueError(f"{reference_prefix}: invalid header ({n} x {d})")
        expected_bytes = n * d * 4
        payload = f.read(expected_bytes)

    if len(payload) != expected_bytes:
        raise ValueError(
            f"{reference_prefix}: expected {expected_bytes} payload bytes, "
            f"got {len(payload)}"
        )
    ref = np.frombuffer(payload, dtype=np.float32).reshape(n, d)

    # Compare shapes explicitly: NumPy would otherwise broadcast e.g. a
    # one-row reference against every row and report a meaningless diff.
    if candidate.shape != ref.shape:
        raise ValueError(
            f"Shape mismatch: regenerated prefix is {candidate.shape}, "
            f"reference {reference_prefix} is {ref.shape}"
        )

    diff = np.abs(candidate - ref)
    print(
        f"Damien prefix validation: max|diff|={diff.max():.3e} "
        f"mean|diff|={diff.mean():.3e} (expect ~0 if script is correct)"
    )
|
||
|
||
|
||
def main(argv=None):
    """Command-line entry point: export prefix/suffix files for each WAV.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Exits with status 1 when ``--validate-damien`` is given without a
    damien WAV in the input list, or when any input WAV was missing or
    failed to process. Every WAV is still attempted before exiting, so
    one bad file does not block the others.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("wavs", nargs="+", help="Voice WAV files")
    parser.add_argument(
        "--out-dir", default="/tmp/voice_prefixes", help="Output directory"
    )
    parser.add_argument(
        "--text", default="Bonjour.", help="Reference utterance for prefill"
    )
    parser.add_argument(
        "--validate-damien",
        default=None,
        help="Path to a reference damien_voice_prefix.bin for sanity-check",
    )
    args = parser.parse_args(argv)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tts = load_model()

    if args.validate_damien:
        # Use the first input whose file stem mentions "damien".
        damien_wav = next(
            (Path(w) for w in args.wavs if "damien" in Path(w).stem.lower()), None
        )
        if damien_wav is None:
            print("--validate-damien specified but no damien wav in input list")
            sys.exit(1)
        validate_against_damien(tts, damien_wav, Path(args.validate_damien), args.text)

    problems = 0
    for wav in args.wavs:
        wp = Path(wav)
        if not wp.exists():
            print(f" [miss] {wp}")
            problems += 1
            continue
        try:
            process_voice(tts, wp, out_dir, args.text)
        except Exception as e:
            # Top-level boundary: report and keep going with other voices.
            print(f" [fail] {wp.name}: {e}")
            problems += 1

    print(f"\nDone. Files written under {out_dir}")
    print(
        "Push to the tablet with, e.g.:\n"
        f" adb push {out_dir}/*_voice_prefix.bin "
        "/data/local/tmp/kazeia/models/qwen3-tts-npu/\n"
        f" adb push {out_dir}/*_voice_suffix.bin "
        "/data/local/tmp/kazeia/models/qwen3-tts-npu/"
    )

    if problems:
        # Surface partial failure to shells and calling scripts.
        print(
            f"{problems} of {len(args.wavs)} input(s) were missing or failed",
            file=sys.stderr,
        )
        sys.exit(1)


if __name__ == "__main__":
    main()
|