diff --git a/scripts/export_voice_prefix_suffix.py b/scripts/export_voice_prefix_suffix.py new file mode 100644 index 0000000..4e5c875 --- /dev/null +++ b/scripts/export_voice_prefix_suffix.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +""" +Generate per-voice _voice_prefix.bin (9 × 1024 fp32) and +_voice_suffix.bin (2 × 1024 fp32) for Kazeia's on-device TTS +engine (Qwen3-TTS 0.6B-Base voice-clone mode). + +The on-device pipeline concatenates prefix + text-embeds + suffix as +the talker's prefill. The prefix is the voice-conditioning preamble +produced by the Qwen3TTS model when run with `x_vector_only_mode=True` +on a short reference phrase — it carries the speaker x-vector and the +leading ChatML / transcript tokens that precede user text. The suffix +is the closing tokens that sit right after user text (end-of-turn, +assistant-ready marker). + +Approach: run the model once per voice on a fixed short utterance, +capture every talker input embedding of the first (multi-token) +prefill call via a forward hook — that's the full prefill sequence. +The reference Damien files contain exactly 9 pre-text embeds + 2 +post-text embeds, which corresponds to: + + [prefix: 9 vectors] [text embeds: N vectors] [suffix: 2 vectors] + +We BPE-tokenize the same utterance with Qwen3TTS's own tokenizer to +find where the text tokens start and end inside the prefill, then +slice out the preceding 9 and trailing 2 vectors. This makes the +split robust to tokenizer changes and matches the Damien files +bit-identically (verified during the first run: /tmp/check_damien_*). + +Usage: + export_voice_prefix_suffix.py VOICE.wav [VOICE.wav ...] + --out-dir /path/to/output (default /tmp/voice_prefixes) + --text "Bonjour." (reference utterance; short is ok) + +The output file names are `_voice_prefix.bin` +and `_voice_suffix.bin`. Push them to +/data/local/tmp/kazeia/models/qwen3-tts-npu/ to activate the voice +in-app (Qwen3TtsEngine.setVoice reads them from there). +""" + +import argparse +import os +import struct +import sys +import warnings +from pathlib import Path + +warnings.filterwarnings("ignore") +# NOTE: don't chdir() here — the WAV paths in argv are resolved against +# the user's cwd. Qwen3TTS creates /tmp scratch files internally already. + +MODEL_PATH = ( + "/home/alf/.cache/huggingface/hub/" + "models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/" + "5d83992436eae1d760afd27aff78a71d676296fc" +) + +# Prefix + suffix sizes taken from the reference damien_voice_prefix.bin / +# damien_voice_suffix.bin shipped on the tablet. If Qwen3TTS ever changes +# its chat template these may need to be re-checked — run the script +# with `--validate-damien damien_voice_prefix.bin` to diff against a +# known-good capture. +N_PREFIX = 9 +N_SUFFIX = 2 +TALKER_DIM = 1024 + + +def load_model(): + import torch + from qwen_tts import Qwen3TTSModel + + print(f"Loading Qwen3-TTS model from {MODEL_PATH}...", flush=True) + tts = Qwen3TTSModel.from_pretrained( + MODEL_PATH, local_files_only=True, device_map="cpu" + ) + return tts + + +class _PrefillCapturedSentinel(Exception): + """Raised after the first prefill so we can abort generate_voice_clone + without waiting for the (very slow on CPU) full TTS decode.""" + + +def capture_prefill(tts, wav_path: str, text: str): + """Run generate_voice_clone just far enough to capture the first + (prefill) call's talker input embeddings, then abort. Doing the full + non-streaming decode would take several minutes per voice on CPU and + we don't need any of the audio — only the prefill vectors.""" + import numpy as np + + captured = [] + talker = tts.model.talker + original_forward = talker.model.forward + + def patched_forward(input_ids=None, inputs_embeds=None, **kwargs): + if inputs_embeds is not None and inputs_embeds.dim() == 3: + t = inputs_embeds.shape[1] + for i in range(t): + captured.append( + inputs_embeds[0, i, :].detach().cpu().numpy().astype(np.float32) + ) + raise _PrefillCapturedSentinel() + return original_forward( + input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs + ) + + talker.model.forward = patched_forward + try: + try: + tts.generate_voice_clone( + text=text, + ref_audio=wav_path, + language="french", + x_vector_only_mode=True, + non_streaming_mode=True, + ) + except _PrefillCapturedSentinel: + pass # expected — we abort after the first prefill + finally: + talker.model.forward = original_forward + + if not captured: + raise RuntimeError("No prefill captured — hook wasn't triggered.") + return captured + + +def write_bin(path: Path, vectors): + n = len(vectors) + dim = len(vectors[0]) if n else TALKER_DIM + if dim != TALKER_DIM: + raise RuntimeError(f"Expected dim {TALKER_DIM}, got {dim}") + with open(path, "wb") as f: + f.write(struct.pack("