kazeia/scripts/prepare_tts_native.py

127 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Generate text-only TTS embeddings for the FULL C++ native pipeline.

No Python model generation is needed - only tokenization + text_projection.

Usage:
    python3 prepare_tts_native.py "Your text here" [output.bin]
    adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin

The text is split into segments (sentence boundaries first, then commas),
aiming for at most MAX_SEGMENT_TOKENS text tokens per segment.

Per segment:
    prefill  = 9 golden prefill embeds + (text_proj[0] + codec BOS embed)
    trailing = text_proj[1:] followed by EOS padding, padded to
               max(int(n_tokens * 3.2) + 5, 40) entries in total
    maxTokens = trailing_count (cut after trailing exhausted)

Output layout (little-endian int32 headers, float32 vectors of 1024 values):
    1 segment : nPrefill, nTotal, then nTotal vectors
    N segments: N, then for each segment: nPrefill, nTotal, nTotal vectors
"""
import sys, os, struct, warnings
# Resolve CLI arguments *before* changing directory: a relative output path
# such as "output.bin" (see the usage line) must refer to the caller's working
# directory, not to /tmp - otherwise the file is silently written under /tmp
# and the final "adb push" hint points at a file that is not there.
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
OUTPUT = os.path.abspath(sys.argv[2]) if len(sys.argv) > 2 else "/tmp/tts_native.bin"
# NOTE(review): presumably chdir'd so any files written by the model libraries
# land in /tmp - confirm before removing.
os.chdir("/tmp")
warnings.filterwarnings("ignore")
GOLDEN_PREFILL = "/tmp/existing_embeds.bin"
# Local HF snapshot of Qwen3-TTS-12Hz-0.6B-Base. Set QWEN3_TTS_MODEL to point at
# a snapshot elsewhere (e.g. on another machine); the default is unchanged.
MODEL = os.environ.get(
    "QWEN3_TTS_MODEL",
    "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc",
)
MAX_SEGMENT_TOKENS = 15  # Max text tokens per segment (~50 audio tokens, within NPU quality window)
import torch, numpy as np, re
from qwen_tts import Qwen3TTSModel
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
talker = tts.model.talker
tokenizer = tts.processor.tokenizer
# Load golden prefill + codec/eos
# NOTE(review): the device path pulled below is the same path this script's
# output is meant to be pushed to. A multi-segment output file starts with a
# segment count instead of nPrefill, so it must not be reused as the golden
# source - if in doubt, delete GOLDEN_PREFILL and pull a known-good file.
if not os.path.exists(GOLDEN_PREFILL):
    import subprocess
    # Argument list (no shell) and check=True: a failed pull stops here
    # instead of surfacing later as a confusing missing-file error.
    subprocess.run(
        [
            "adb", "pull",
            "/data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin",
            GOLDEN_PREFILL,
        ],
        check=True,
    )
# File layout: int32 nPrefill, int32 nTotal, then nTotal vectors of 1024 float32.
with open(GOLDEN_PREFILL, "rb") as f:
    nP = struct.unpack("<i", f.read(4))[0]
    nT = struct.unpack("<i", f.read(4))[0]
    golden = [np.frombuffer(f.read(1024*4), dtype=np.float32).copy() for _ in range(nT)]
# Only the first 9 vectors are copied into the output; make sure they exist and
# are complete (a truncated file would otherwise yield silently short vectors).
if len(golden) < 9 or any(g.size != 1024 for g in golden[:9]):
    raise ValueError(
        f"{GOLDEN_PREFILL}: expected at least 9 complete 1024-dim embeddings "
        f"(header nTotal={nT}); delete the file to re-pull it from the device"
    )
# NOTE(review): allow_pickle=True lets a crafted .npy execute code on load, and
# /tmp is world-writable. Drop the flag if ce.npy is a plain numeric array.
ce = np.load("/tmp/ce.npy", allow_pickle=True).reshape(-1, 1024)
sp = np.load("/tmp/tts_special.npy").reshape(3, 1024)
eos = sp[1].astype(np.float32)  # row 1 of the special table is used as EOS padding
CODEC_BOS = 2149  # row of `ce` added to the first text projection in the prefill
def split_text(text, max_tokens, encode=None):
    """Split *text* into segments of at most *max_tokens* text tokens where possible.

    Splitting is hierarchical: sentence boundaries first (``.``, ``!`` or ``?``
    followed by whitespace), then commas, and finally whitespace between words
    for clauses that are still too long. A single word longer than
    *max_tokens* is kept whole, since it cannot be split without breaking it.

    Args:
        text: Input text.
        max_tokens: Maximum number of text tokens per segment.
        encode: Optional callable ``str -> sequence`` used to count tokens.
            Defaults to the module-level ``tokenizer`` without special tokens.

    Returns:
        List of non-empty, whitespace-stripped segments in original order
        (empty list for empty/whitespace-only input).
    """
    if encode is None:
        def encode(s):
            return tokenizer.encode(s, add_special_tokens=False)

    def n_tokens(s):
        return len(encode(s))

    def pack(pieces):
        # Greedily join pieces with single spaces while the result fits.
        out = []
        current = ""
        for piece in pieces:
            candidate = (current + " " + piece).strip() if current else piece
            if current and n_tokens(candidate) > max_tokens:
                out.append(current.strip())
                current = piece
            else:
                current = candidate
        if current.strip():
            out.append(current.strip())
        return out

    segments = []
    # Split at sentence boundaries first
    for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
        if n_tokens(sentence) <= max_tokens:
            segments.append(sentence)
            continue
        # Split long sentence at commas
        for clause in pack(re.split(r'(?<=,)\s+', sentence)):
            if n_tokens(clause) <= max_tokens:
                segments.append(clause)
            else:
                # No comma left to split on: fall back to word boundaries so
                # the segment stays within the per-segment token budget.
                segments.extend(pack(clause.split()))
    return [s for s in segments if s.strip()]
def make_segment(text_segment):
    """Build the embeddings for one text segment.

    The segment is tokenized (no special tokens), embedded with the talker's
    text embedding table and passed through ``talker.text_projection``. The
    first projected token is returned separately as ``proj0`` because the
    writer adds it to the codec BOS embedding to form the last prefill vector.
    The remaining tokens start the trailing sequence, which is then padded with
    the EOS embedding up to ``max(int(n_tokens * 3.2) + 5, 40)`` entries.

    Args:
        text_segment: Text to encode; must produce at least one token.

    Returns:
        dict with ``'tokens'`` (number of text tokens), ``'proj0'`` (float32
        vector of the first token) and ``'trailing'`` (list of float32 vectors).

    Raises:
        ValueError: if *text_segment* tokenizes to zero tokens.
    """
    tokens = tokenizer.encode(text_segment, add_special_tokens=False)
    if not tokens:
        # torch.tensor([[]]) would be a float tensor and fail inside the
        # embedding lookup with an opaque dtype error.
        raise ValueError(f"segment {text_segment!r} produced no text tokens")
    with torch.no_grad():
        ids = torch.tensor([tokens], dtype=torch.long)
        proj = talker.text_projection(talker.get_text_embeddings()(ids))[0]
        proj = proj.numpy().astype(np.float32)
    # Trailing length heuristic (~3.2 entries per text token + 5, at least 40);
    # generation is cut once the trailing stream is exhausted.
    target_len = max(int(len(tokens) * 3.2) + 5, 40)
    trailing = list(proj[1:])
    trailing.extend([eos] * (target_len - len(trailing)))
    return {
        'tokens': len(tokens),
        'proj0': proj[0],
        'trailing': trailing,
    }
# Fixed prefill layout shared by every segment record.
N_GOLDEN = 9              # leading prefill vectors copied verbatim from the golden file
N_PREFILL = N_GOLDEN + 1  # plus one vector: proj0 + codec BOS embedding


def _write_segment(f, seg):
    """Write one segment record to binary file *f* and return its nTotal.

    Record layout: int32 nPrefill, int32 nTotal, then nTotal float32 vectors
    (golden prefill, proj0 + codec BOS, trailing embeds).
    """
    n_total = N_PREFILL + len(seg['trailing'])
    f.write(struct.pack("<i", N_PREFILL))
    f.write(struct.pack("<i", n_total))
    for i in range(N_GOLDEN):
        f.write(golden[i].tobytes())
    # Force float32: `ce` comes from an external .npy of unknown dtype, and a
    # float64 sum would silently emit 8-byte elements and corrupt the stream.
    f.write((seg['proj0'] + ce[CODEC_BOS]).astype(np.float32).tobytes())
    for e in seg['trailing']:
        f.write(np.asarray(e, dtype=np.float32).tobytes())
    return n_total


# Split text into segments
segments = split_text(TEXT, MAX_SEGMENT_TOKENS)
if not segments:
    # Otherwise empty/whitespace input would produce a 0-segment file.
    sys.exit("Error: no text to synthesize (input is empty after stripping).")
print(f"Segments: {len(segments)}")
for i, s in enumerate(segments):
    n = len(tokenizer.encode(s, add_special_tokens=False))
    print(f" [{i}] ({n} tok) '{s[:60]}{'...' if len(s)>60 else ''}'")
# Generate embeds per segment
seg_data = [make_segment(s) for s in segments]
with open(OUTPUT, "wb") as f:
    if len(seg_data) == 1:
        # Single segment: legacy format (no segment-count header)
        n_total = _write_segment(f, seg_data[0])
        summary = f"\nSingle segment: {n_total} embeds"
    else:
        # Multi-segment format: int32 segment count, then one record per segment
        f.write(struct.pack("<i", len(seg_data)))
        for seg in seg_data:
            _write_segment(f, seg)
        summary = f"\nMulti-segment: {len(seg_data)} segments"
print(summary)
total_trailing = sum(len(seg['trailing']) for seg in seg_data)
# ~0.08 s per trailing entry (presumably one 12.5 Hz codec frame each)
print(f"Total audio: ~{total_trailing * 0.08:.1f}s estimated")
print(f"Saved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f}KB)")
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")