diff --git a/scripts/prepare_tts_native.py b/scripts/prepare_tts_native.py index d772c76..0c57a4b 100644 --- a/scripts/prepare_tts_native.py +++ b/scripts/prepare_tts_native.py @@ -15,69 +15,112 @@ warnings.filterwarnings("ignore") TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia." OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin" -GOLDEN_PREFILL = "/tmp/existing_embeds.bin" # Must exist (captured on-device once) +GOLDEN_PREFILL = "/tmp/existing_embeds.bin" MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc" +MAX_SEGMENT_TOKENS = 15 # Max text tokens per segment (~50 audio tokens, within NPU quality window) -import torch, numpy as np +import torch, numpy as np, re from qwen_tts import Qwen3TTSModel print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'") -# Load model (just for tokenizer + text_projection) tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu") talker = tts.model.talker tokenizer = tts.processor.tokenizer -# Tokenize + project -tokens = tokenizer.encode(TEXT, add_special_tokens=False) -with torch.no_grad(): - proj = talker.text_projection( - talker.get_text_embeddings()(torch.tensor([tokens])) - )[0].numpy().astype(np.float32) -print(f"Tokens: {len(tokens)}") - -# Load golden prefill[0:9] (captured on-device, text-independent) +# Load golden prefill + codec/eos if not os.path.exists(GOLDEN_PREFILL): os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {GOLDEN_PREFILL}") with open(GOLDEN_PREFILL, "rb") as f: nP = struct.unpack(" max_tokens and current: + segments.append(current.strip()) + current = part + else: + current = test + if current.strip(): + segments.append(current.strip()) + return [s for s in segments if s.strip()] -# Build file -nPrefill = 10 -nTotal = nPrefill + len(trailing) +def make_segment(text_segment): + """Build embeds for one segment.""" + tokens = tokenizer.encode(text_segment, add_special_tokens=False) + with torch.no_grad(): + proj = talker.text_projection( + talker.get_text_embeddings()(torch.tensor([tokens])) + )[0].numpy().astype(np.float32) -with open(OUTPUT, "wb") as f: - f.write(struct.pack("60 else ''}'") + +# Generate embeds per segment +seg_data = [make_segment(s) for s in segments] + +if len(seg_data) == 1: + # Single segment: legacy format + s = seg_data[0] + nPrefill = 10 + nTotal = nPrefill + len(s['trailing']) + with open(OUTPUT, "wb") as f: + f.write(struct.pack("