127 lines
5.0 KiB
Python
127 lines
5.0 KiB
Python
#!/usr/bin/env python3
"""
Generate text-only TTS embeddings for the FULL C++ native pipeline.

No Python model generation is needed — the script only tokenizes the text and
runs it through the talker's text embedding + text_projection.

Usage: python3 prepare_tts_native.py "Your text here" [output.bin]
       adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin

Per segment: trailing = text_proj[1:] padded with the EOS embedding up to
max(int(3.2 * n_tokens) + 5, 40) entries.
maxTokens = trailing_count (generation is cut once the trailing embeds are exhausted).
"""

import os
import struct
import sys
import warnings

TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
# Resolve the output path BEFORE changing directory below; otherwise a relative
# path such as "output.bin" would silently be written into /tmp instead of the
# directory the script was launched from.
OUTPUT = os.path.abspath(sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin")

# The rest of the script runs with /tmp as its working directory.
os.chdir("/tmp")
warnings.filterwarnings("ignore")

GOLDEN_PREFILL = "/tmp/existing_embeds.bin"
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
MAX_SEGMENT_TOKENS = 15  # Max text tokens per segment (~50 audio tokens, within NPU quality window)

import re

import numpy as np
import torch
from qwen_tts import Qwen3TTSModel

# Echo the input, truncated to 80 characters so long paragraphs stay readable.
_text_preview = TEXT if len(TEXT) <= 80 else TEXT[:80] + "..."
print(f"Text: '{_text_preview}'")

# Load the checkpoint strictly from the local Hugging Face cache, on CPU.
tts = Qwen3TTSModel.from_pretrained(MODEL, device_map="cpu", local_files_only=True)
# Only the talker (text embeddings + text_projection) and the tokenizer are used afterwards.
talker, tokenizer = tts.model.talker, tts.processor.tokenizer

# Load golden prefill + codec/eos reference tensors.
import subprocess

# NOTE(review): the golden file is pulled from the same device path that the
# generated file is meant to be pushed to. If that path currently holds a
# multi-segment file, its first int32 is a segment count rather than nPrefill
# and the parsing below will be wrong — confirm which file is expected here.
DEVICE_EMBEDS = "/data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin"
EMBED_DIM = 1024  # floats per embedding vector

if not os.path.exists(GOLDEN_PREFILL):
    try:
        pulled = subprocess.run(["adb", "pull", DEVICE_EMBEDS, GOLDEN_PREFILL])
    except FileNotFoundError:
        sys.exit("error: 'adb' not found; cannot fetch the golden prefill")
    if pulled.returncode != 0 or not os.path.exists(GOLDEN_PREFILL):
        sys.exit(f"error: adb pull of {DEVICE_EMBEDS} failed (exit code {pulled.returncode})")

with open(GOLDEN_PREFILL, "rb") as f:
    # Header: two little-endian int32 values, nP (read but unused here) and nT,
    # followed by nT embeddings of EMBED_DIM float32 each.
    nP, nT = struct.unpack("<ii", f.read(8))
    raw = f.read(nT * EMBED_DIM * 4)
if len(raw) != nT * EMBED_DIM * 4:
    sys.exit(f"error: {GOLDEN_PREFILL} is truncated: expected {nT} embeds of {EMBED_DIM} floats")
# .copy() makes the rows writable (np.frombuffer over bytes is read-only).
golden = list(np.frombuffer(raw, dtype=np.float32).reshape(nT, EMBED_DIM).copy())

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; only
# load a trusted local file here, and drop the flag if ce.npy is a plain numeric array.
ce = np.load("/tmp/ce.npy", allow_pickle=True).reshape(-1, EMBED_DIM)
sp = np.load("/tmp/tts_special.npy").reshape(3, EMBED_DIM)
# Row 1 of the special-token table is used as the EOS padding embedding.
eos = sp[1].astype(np.float32)

CODEC_BOS = 2149  # row of `ce` added to proj[0] to form the last prefill embed
if ce.shape[0] <= CODEC_BOS:
    sys.exit(f"error: /tmp/ce.npy has {ce.shape[0]} rows; CODEC_BOS={CODEC_BOS} is out of range")

def split_text(text, max_tokens, count_tokens=None):
    """Split *text* into segments of at most *max_tokens* tokens each.

    Splitting prefers natural boundaries, in this order:
      1. sentence ends (whitespace following '.', '!' or '?');
      2. for a sentence over the limit, whitespace following commas — adjacent
         clauses are greedily re-joined while they still fit;
      3. for a single clause still over the limit, word boundaries.
    A single word longer than *max_tokens* is emitted on its own, because it
    cannot be split further at this level.

    Args:
        text: input text; leading/trailing whitespace is ignored.
        max_tokens: per-segment token budget.
        count_tokens: optional callable ``str -> int``. Defaults to the
            module-level tokenizer's ``encode`` without special tokens.

    Returns:
        List of non-empty, stripped segment strings (empty for blank input).
    """
    if count_tokens is None:
        def count_tokens(segment):
            return len(tokenizer.encode(segment, add_special_tokens=False))

    def pack(pieces):
        # Greedily join pieces with a space while the result stays within budget.
        packed, current = [], ""
        for piece in pieces:
            candidate = f"{current} {piece}" if current else piece
            if current and count_tokens(candidate) > max_tokens:
                packed.append(current)
                current = piece
            else:
                current = candidate
        if current:
            packed.append(current)
        return packed

    segments = []
    for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
        sentence = sentence.strip()
        if not sentence:
            continue
        if count_tokens(sentence) <= max_tokens:
            segments.append(sentence)
            continue
        # Too long: break at commas, and break any clause that alone is still
        # over budget into words, then re-pack greedily.
        pieces = []
        for clause in re.split(r'(?<=,)\s+', sentence):
            if count_tokens(clause) <= max_tokens:
                pieces.append(clause)
            else:
                pieces.extend(clause.split())
        segments.extend(pack(pieces))
    return segments

def make_segment(text_segment, frames_per_token=3.2, extra_frames=5, min_frames=40):
    """Build the embeddings for one text segment.

    Tokenizes *text_segment*, looks up the talker's text embeddings and runs
    them through ``talker.text_projection``. ``proj[0]`` is returned separately;
    ``proj[1:]`` becomes the trailing sequence, padded with the module-level
    ``eos`` embedding up to
    ``max(int(n_tokens * frames_per_token) + extra_frames, min_frames)`` entries.

    Args:
        text_segment: text to encode.
        frames_per_token, extra_frames, min_frames: padding-length parameters
            (defaults reproduce the fixed 3.2 / 5 / 40 heuristic).

    Returns:
        dict with 'tokens' (token count), 'proj0' (first projected embedding,
        float32) and 'trailing' (list of per-step float32 embeddings).

    Raises:
        ValueError: if the segment tokenizes to zero tokens.
    """
    tokens = tokenizer.encode(text_segment, add_special_tokens=False)
    if not tokens:
        raise ValueError(f"segment {text_segment!r} produced no tokens")

    with torch.no_grad():
        embeds = talker.get_text_embeddings()(torch.tensor([tokens]))
        # .float() first: Tensor.numpy() does not support bfloat16, and it is a
        # no-op when the model already runs in float32.
        proj = talker.text_projection(embeds)[0].float().numpy().astype(np.float32)

    target_len = max(int(len(tokens) * frames_per_token) + extra_frames, min_frames)
    trailing = list(proj[1:])
    # A non-positive shortfall multiplies to an empty list, i.e. no padding.
    trailing.extend([eos] * (target_len - len(trailing)))

    return {
        'tokens': len(tokens),
        'proj0': proj[0],
        'trailing': trailing,
    }

# Split text into segments
segments = split_text(TEXT, MAX_SEGMENT_TOKENS)
if not segments:
    sys.exit("error: input text is empty; nothing to synthesize")
print(f"Segments: {len(segments)}")
for i, s in enumerate(segments):
    n = len(tokenizer.encode(s, add_special_tokens=False))
    print(f" [{i}] ({n} tok) '{s[:60]}{'...' if len(s)>60 else ''}'")

# Generate embeds per segment
seg_data = [make_segment(s) for s in segments]

N_GOLDEN_FRAMES = 9  # leading embeds copied verbatim from the golden prefill
if len(golden) < N_GOLDEN_FRAMES:
    sys.exit(f"error: golden prefill has {len(golden)} embeds, need at least {N_GOLDEN_FRAMES}")


def _write_segment(f, seg):
    """Write one segment record and return its nTotal.

    Layout: int32 nPrefill, int32 nTotal (little-endian), then nTotal float32
    embeddings: golden[0:9], proj0 + ce[CODEC_BOS], then the trailing sequence.
    """
    n_prefill = N_GOLDEN_FRAMES + 1
    n_total = n_prefill + len(seg['trailing'])
    f.write(struct.pack("<ii", n_prefill, n_total))
    for frame in golden[:N_GOLDEN_FRAMES]:
        f.write(np.asarray(frame, dtype=np.float32).tobytes())
    # Cast explicitly: if ce were stored as float64, the sum would be float64
    # and write 8-byte values, corrupting the float32 layout.
    f.write((seg['proj0'] + ce[CODEC_BOS]).astype(np.float32).tobytes())
    for e in seg['trailing']:
        f.write(np.asarray(e, dtype=np.float32).tobytes())
    return n_total


# NOTE(review): a multi-segment file starts with the segment count, a legacy
# single-segment file with nPrefill (10), so a 10-segment file shares its first
# int32 with the legacy format — confirm how the native reader tells them apart.
with open(OUTPUT, "wb") as f:
    if len(seg_data) == 1:
        # Single segment: legacy format (no segment-count header).
        nTotal = _write_segment(f, seg_data[0])
    else:
        # Multi-segment format: int32 segment count, then one record per segment.
        f.write(struct.pack("<i", len(seg_data)))
        for s in seg_data:
            _write_segment(f, s)

if len(seg_data) == 1:
    print(f"\nSingle segment: {nTotal} embeds")
else:
    print(f"\nMulti-segment: {len(seg_data)} segments")

total_trailing = sum(len(s['trailing']) for s in seg_data)
# Estimate assumes 0.08 s per trailing embed (one step of the 12.5 Hz codec, presumably — confirm).
print(f"Total audio: ~{total_trailing * 0.08:.1f}s estimated")
print(f"Saved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f}KB)")
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")