kazeia/scripts/prepare_tts_native.py

128 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Generate text-only TTS embeddings for the FULL C++ native pipeline.

No Python model generation is needed — just tokenize + text_projection.

Usage:
    python3 prepare_tts_native.py "Your text here" [output.bin]
    adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin

Mirrors the Python qwen_tts protocol exactly:
    trailing = text_proj[1:] (no eos padding — C++ adds 1×eos then pad_embed itself)
    Stop = natural codec_eos_token_id (handled in C++)

Output layout (int32 headers packed little-endian, followed by float32 rows):
    single segment: nPrefill, nTotal, then nTotal embedding rows
    multi segment:  nSegments, then per segment nPrefill, nTotal, nTotal rows
Each segment's prefill is the first 9 rows of the golden on-device prefill
plus one row of (first text projection + codec BOS embedding); the remaining
rows are the trailing text projections.
"""
import os
import struct
import subprocess
import sys
import warnings

TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
# Resolve a user-supplied output path against the caller's working directory
# *before* the chdir below; otherwise a relative path would silently end up
# under /tmp.
OUTPUT = os.path.abspath(sys.argv[2]) if len(sys.argv) > 2 else "/tmp/tts_native.bin"
# NOTE(review): the script deliberately runs from /tmp; the reason is not
# visible here — confirm before removing.
os.chdir("/tmp")
warnings.filterwarnings("ignore")

GOLDEN_PREFILL = "/tmp/existing_embeds.bin"
DEVICE_EMBEDS = "/data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin"
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
MAX_SEGMENT_TOKENS = 20  # ~70 audio steps + prefill = ~80, fits KV_LEN=100 with margin
EMBED_DIM = 1024  # width of the golden and codec embedding rows
CODEC_BOS = 2149
N_GOLDEN_ROWS = 9  # golden prefill rows reused at the start of every segment

# Heavy imports come after filterwarnings so their import-time warnings are
# suppressed as well.
import re

import numpy as np
import torch
from qwen_tts import Qwen3TTSModel

print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
talker = tts.model.talker
tokenizer = tts.processor.tokenizer

# Load golden prefill (cached locally, pulled from the device on first run).
if not os.path.exists(GOLDEN_PREFILL):
    # argv list (no shell) and check=True: a failed pull stops here instead of
    # surfacing later as a confusing open() error.
    subprocess.run(["adb", "pull", DEVICE_EMBEDS, GOLDEN_PREFILL], check=True)
with open(GOLDEN_PREFILL, "rb") as f:
    header = f.read(8)
    payload = f.read()
if len(header) != 8:
    sys.exit(f"{GOLDEN_PREFILL}: file too short for the 8-byte header")
nP, nT = struct.unpack("<ii", header)  # nP is read for format parity only
row_bytes = EMBED_DIM * 4
# DEVICE_EMBEDS is also where this script's own output gets pushed, so the
# cached file may be a multi-segment file (first int = segment count) that
# this single-segment reader would misparse. Require the exact single-segment
# size and enough rows for the golden prefix.
if nT < N_GOLDEN_ROWS or len(payload) != nT * row_bytes:
    sys.exit(
        f"{GOLDEN_PREFILL}: expected a single-segment embeds file with at least "
        f"{N_GOLDEN_ROWS} rows (header nT={nT}, payload {len(payload)} bytes); "
        "it may be a multi-segment output — replace it with a known-good prefill."
    )
golden = list(np.frombuffer(payload, dtype=np.float32).reshape(nT, EMBED_DIM).copy())

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects. It
# is only tolerable because this is a locally produced file, and the reshape
# implies a plain numeric array, so the flag can likely be dropped.
# Cast to float32 so proj0 + ce[CODEC_BOS] stays float32 when written out.
ce = np.load("/tmp/ce.npy", allow_pickle=True).reshape(-1, EMBED_DIM).astype(np.float32)
if ce.shape[0] <= CODEC_BOS:
    sys.exit(f"/tmp/ce.npy has {ce.shape[0]} rows; CODEC_BOS={CODEC_BOS} is out of range")
sp = np.load("/tmp/tts_special.npy").reshape(3, EMBED_DIM)
eos = sp[1].astype(np.float32)  # not referenced below: the C++ side appends eos itself
def split_text(text, max_tokens, encode=None):
    """Split *text* into segments of at most *max_tokens* tokens each.

    Text is first cut at sentence boundaries (after ``.``, ``!`` or ``?``
    followed by whitespace). A sentence that is still too long is cut after
    commas and the pieces are greedily re-joined with single spaces while they
    fit. A comma-piece that is itself too long is broken into words before
    re-joining, so every segment fits the budget unless a single word alone
    exceeds it (such a word is kept intact rather than cut mid-token).

    Args:
        text: Input text.
        max_tokens: Per-segment token budget.
        encode: Optional callable ``str -> sequence`` whose length is the token
            count. Defaults to the module-level ``tokenizer`` without special
            tokens.

    Returns:
        Non-empty segments in original order (empty list for blank input).
    """
    if encode is None:
        def encode(s):
            return tokenizer.encode(s, add_special_tokens=False)

    def fits(s):
        return len(encode(s)) <= max_tokens

    segments = []
    # Split at sentence boundaries first
    for sent in re.split(r'(?<=[.!?])\s+', text.strip()):
        if fits(sent):
            segments.append(sent)
            continue
        # Long sentence: cut after commas; a comma-piece that is still too
        # long is broken into words so the packing below can respect the budget.
        pieces = []
        for part in re.split(r'(?<=,)\s+', sent):
            pieces.extend([part] if fits(part) else part.split())
        # Greedily re-join pieces while the candidate still fits.
        current = ""
        for piece in pieces:
            candidate = f"{current} {piece}".strip() if current else piece
            if current and not fits(candidate):
                segments.append(current.strip())
                current = piece
            else:
                current = candidate
        if current.strip():
            segments.append(current.strip())
    return [s for s in segments if s.strip()]
def make_segment(text_segment):
    """Build the projected text embeddings for one segment.

    Mirrors Python qwen_tts: trailing = text_proj[1:] (no padding).
    C++ then adds 1×eos after exhausting trailing, then pad_embed, and stops
    on natural EOS.

    Args:
        text_segment: Text to tokenize (no special tokens are added).

    Returns:
        dict with:
          'tokens':   number of text tokens,
          'proj0':    float32 projection of the first token,
          'trailing': list of float32 projections for tokens[1:].

    Raises:
        ValueError: if the segment tokenizes to nothing, so ``proj[0]`` would
            not exist.
    """
    tokens = tokenizer.encode(text_segment, add_special_tokens=False)
    if not tokens:
        raise ValueError(f"segment {text_segment!r} produced no tokens")
    with torch.no_grad():
        embeds = talker.get_text_embeddings()(torch.tensor([tokens]))
        # Upcast in torch before .numpy(): NumPy has no bfloat16, so a
        # bf16-loaded model would fail on a direct .numpy() call.
        proj = talker.text_projection(embeds)[0].to(torch.float32).numpy()
    return {
        'tokens': len(tokens),
        'proj0': proj[0],
        'trailing': list(proj[1:]),  # text[1:], no eos here
    }
# Split text into segments
segments = split_text(TEXT, MAX_SEGMENT_TOKENS)
if not segments:
    sys.exit("Nothing to synthesize: the input text is empty.")
print(f"Segments: {len(segments)}")
for i, s in enumerate(segments):
    n = len(tokenizer.encode(s, add_special_tokens=False))
    print(f" [{i}] ({n} tok) '{s[:60]}{'...' if len(s)>60 else ''}'")
# Generate embeds per segment
seg_data = [make_segment(s) for s in segments]

# Each record's prefill = the first 9 golden rows + 1 row (proj0 + codec BOS).
N_GOLDEN_ROWS = 9
N_PREFILL = N_GOLDEN_ROWS + 1
if len(golden) < N_GOLDEN_ROWS:
    sys.exit(f"Golden prefill has only {len(golden)} rows; need {N_GOLDEN_ROWS}.")


def _write_segment(f, seg):
    """Write one [nPrefill][nTotal][nTotal float32 rows] record; return nTotal."""
    n_total = N_PREFILL + len(seg['trailing'])
    f.write(struct.pack("<ii", N_PREFILL, n_total))
    for row in golden[:N_GOLDEN_ROWS]:
        f.write(np.asarray(row, dtype=np.float32).tobytes())
    # Explicit cast: a non-float32 `ce` would otherwise upcast the sum and
    # write wider floats, corrupting the record layout.
    f.write((seg['proj0'] + ce[CODEC_BOS]).astype(np.float32).tobytes())
    for e in seg['trailing']:
        f.write(np.asarray(e, dtype=np.float32).tobytes())
    return n_total


with open(OUTPUT, "wb") as f:
    if len(seg_data) == 1:
        # Single segment: legacy format (no segment-count header)
        n_total = _write_segment(f, seg_data[0])
    else:
        # Multi-segment format: segment count, then one legacy record each
        f.write(struct.pack("<i", len(seg_data)))
        for seg in seg_data:
            _write_segment(f, seg)
if len(seg_data) == 1:
    print(f"\nSingle segment: {n_total} embeds")
else:
    print(f"\nMulti-segment: {len(seg_data)} segments")
total_trailing = sum(len(s['trailing']) for s in seg_data)
# Rough estimate only: assumes ~0.08 s of audio per trailing text embedding.
print(f"Total audio: ~{total_trailing * 0.08:.1f}s estimated")
print(f"Saved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f}KB)")
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")