# kazeia/scripts/prepare_tts_segments.py
#
# 137 lines
# 5.0 KiB
# Python

#!/usr/bin/env python3
"""
Generate TTS embeddings for long text, split into sentence segments.

Each segment is generated independently by the Python Qwen3-TTS model (for
maximum quality), and the talker's per-step input embeddings are recorded.

Usage:
    python3 prepare_tts_segments.py "Long text..." [output.bin]
    adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin

When omitted, the text defaults to "Bonjour. Je suis Kazeia." and the output
path defaults to /tmp/tts_segments.bin.

Output format (header integers are little-endian int32):
    int32 n_segments
    for each segment:
        int32 n_prefill
        int32 n_total
        float32[n_total * 1024] embeddings

The first 9 embeddings of every segment are copied from an existing embeds
file (/tmp/existing_embeds.bin, pulled from the device with adb if missing).
The rest are captured from the model while it generates that segment.
"""
import sys, os, struct, re, types, warnings
# Parse the command line BEFORE changing directory. Otherwise a relative output
# path would resolve against /tmp instead of the caller's working directory,
# and the final "adb push" hint would point at a file that does not exist there.
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour. Je suis Kazeia."
OUTPUT = os.path.abspath(sys.argv[2]) if len(sys.argv) > 2 else "/tmp/tts_segments.bin"
# NOTE(review): the reason for running from /tmp is not visible in this file;
# kept as-is.
os.chdir("/tmp")
warnings.filterwarnings("ignore")
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"  # reference audio passed to generate_voice_clone
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
import torch, numpy as np
from qwen_tts import Qwen3TTSModel
def split_sentences(text, max_chars=120):
    """Split *text* into short TTS segments, one sentence per segment.

    The text is first split after every ``.``, ``!``, ``?``, ``;`` or ``:``
    that is followed by whitespace. Any sentence longer than *max_chars* is then
    split after commas. Comma-separated pieces are packed greedily into chunks of
    at most *max_chars* characters, counting the space used to join them. A
    single comma-free piece longer than *max_chars* is kept whole.

    Args:
        text: Input text.
        max_chars: Soft maximum segment length in characters. The default of
            120 aims at roughly 40-50 tokens per segment.

    Returns:
        The list of non-empty, stripped segments. It is empty if *text* is
        blank.
    """
    parts = re.split(r'(?<=[.!?;:])\s+', text.strip())
    final = []
    for part in parts:
        if len(part) <= max_chars:
            final.append(part)
            continue
        # Long sentence: pack comma-separated pieces up to max_chars each.
        current = ""
        for piece in re.split(r'(?<=,)\s+', part):
            # The +1 accounts for the space inserted when appending to `current`.
            if current and len(current) + 1 + len(piece) > max_chars:
                final.append(current.strip())
                current = piece
            else:
                current = (current + " " + piece).strip() if current else piece
        if current.strip():
            final.append(current.strip())
    # re.split always yields at least one element, so `final` is never empty;
    # just drop blank entries (e.g. from empty input).
    return [s for s in final if s.strip()]
# Echo the input and its segmentation, then load the TTS model on CPU and keep
# a handle on its talker sub-model (patched below to capture inputs).
text_tail = '...' if len(TEXT) > 80 else ''
print(f"Text: '{TEXT[:80]}{text_tail}'")
segments = split_sentences(TEXT)
print(f"Split into {len(segments)} segments:")
for seg_idx, seg_preview in enumerate(segments):
    preview_tail = '...' if len(seg_preview) > 60 else ''
    print(f" [{seg_idx}] '{seg_preview[:60]}{preview_tail}'")
print("\nLoading model...")
tts = Qwen3TTSModel.from_pretrained(
    MODEL,
    local_files_only=True,
    device_map="cpu",
)
talker = tts.model.talker
# Capture generation inputs via a monkey-patch on the talker's inner model.
# Every forward call whose inputs_embeds has sequence length 1 has batch item
# 0's embedding appended to captured_inputs as a float32 NumPy array.
captured_inputs = []
original_model_forward = talker.model.forward


def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
    """Record single-step input embeddings, then delegate to the original forward."""
    if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
        # Cast to float32 before .numpy(): NumPy has no bfloat16, so a model
        # loaded in bf16 would otherwise raise here. .astype() also guarantees
        # an independent copy (.float() is a no-op on float32 tensors and
        # .numpy() shares memory).
        emb = inputs_embeds[0, 0, :].detach().float().cpu().numpy()
        captured_inputs.append(emb.astype(np.float32))
    return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


talker.model.forward = patched_model_forward
# Load the shared prefill structure from an existing embeds file laid out as
# int32 nP, int32 nT, then nT float32[1024] embeddings. If there is no local
# copy, pull it from the device.
# NOTE(review): this script's own (multi-segment) output is pushed to the same
# device path. If that file is what gets pulled here, the header is misread.
# Confirm the device still holds the single-segment layout.
EXISTING = "/tmp/existing_embeds.bin"
EMBED_DIM = 1024
if not os.path.exists(EXISTING):
    import subprocess

    # check=True: fail immediately if adb is missing or the pull fails, instead
    # of hitting a confusing FileNotFoundError on the open() below.
    subprocess.run(
        [
            "adb", "pull",
            "/data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin",
            EXISTING,
        ],
        check=True,
    )
with open(EXISTING, "rb") as f:
    # nP (presumably the file's prefill count) is read but not used below;
    # only the first 9 embeddings are reused.
    nP, nT = struct.unpack("<ii", f.read(8))
    if nT < 9:
        raise ValueError(
            f"{EXISTING}: header declares {nT} embeddings, need at least 9 for the prefill"
        )
    raw = f.read(nT * EMBED_DIM * 4)
if len(raw) != nT * EMBED_DIM * 4:
    raise ValueError(
        f"{EXISTING}: truncated, expected {nT} embeddings of {EMBED_DIM} float32"
    )
# One independent float32[EMBED_DIM] array per stored embedding.
old_embeds = list(np.frombuffer(raw, dtype=np.float32).reshape(nT, EMBED_DIM).copy())
prefill_base = old_embeds[:9]  # role+ctrl+spk+bos
# Generate embeddings for each segment. Running the Python TTS model fills
# captured_inputs (via the patched talker forward) with one input embedding per
# single-step forward call. Those become this segment's embeddings after the
# shared prefill.
all_segments_data = []
total_audio_duration = 0
for i, seg_text in enumerate(segments):
    print(f"\n--- Segment {i+1}/{len(segments)} ---")
    print(f" '{seg_text[:60]}...'")
    captured_inputs.clear()
    audio_list, sr = tts.generate_voice_clone(
        text=seg_text, ref_audio=VOICE, language="french",
        x_vector_only_mode=True, non_streaming_mode=True,
    )
    audio_dur = len(audio_list[0]) / sr
    if len(captured_inputs) < 2:
        print(f" WARNING: Only {len(captured_inputs)} steps, skipping")
        continue
    # Accumulate only after the skip check, so "Total audio" covers the same
    # segments as the segment and token counts in the summary.
    total_audio_duration += audio_dur
    # Derive the header from the data actually written (base prefill + first
    # generated input) rather than a hard-coded 10. That keeps n_prefill/n_total
    # consistent with the payload even if prefill_base is not exactly 9 long.
    nPrefill = len(prefill_base) + 1
    nDecode = len(captured_inputs) - 1
    nTotal = nPrefill + nDecode
    seg_data = {
        'nPrefill': nPrefill,
        'nTotal': nTotal,
        'prefill': list(prefill_base),
        'first_gen': captured_inputs[0],
        'decode': captured_inputs[1:],  # slice copies, so later clear() is safe
        'audio_dur': audio_dur,
        'n_tokens': len(captured_inputs),
    }
    all_segments_data.append(seg_data)
    print(f" {nTotal} embeds ({len(captured_inputs)} tokens), {audio_dur:.2f}s audio")
# Write the multi-segment file: int32 n_segments, then per segment int32
# n_prefill, int32 n_total and float32[n_total * 1024] embeddings.
# Validate every segment before opening the file. An inconsistent header or a
# wrong-sized embedding then aborts the run instead of producing a file the
# device would misparse.
EMBED_DIM = 1024
for seg_no, seg in enumerate(all_segments_data):
    seg_embeds = [*seg['prefill'], seg['first_gen'], *seg['decode']]
    if seg['nPrefill'] != len(seg['prefill']) + 1 or seg['nTotal'] != len(seg_embeds):
        raise ValueError(
            f"segment {seg_no}: header says nPrefill={seg['nPrefill']}, nTotal={seg['nTotal']} "
            f"but data has {len(seg['prefill']) + 1} prefill / {len(seg_embeds)} total embeddings"
        )
    sizes = {np.asarray(e).size for e in seg_embeds}
    if sizes != {EMBED_DIM}:
        raise ValueError(
            f"segment {seg_no}: expected {EMBED_DIM}-dim embeddings, got sizes {sorted(sizes)}"
        )
if not all_segments_data:
    print("WARNING: no segment produced embeddings; the output file will contain 0 segments")
with open(OUTPUT, "wb") as f:
    f.write(struct.pack("<i", len(all_segments_data)))
    for seg in all_segments_data:
        f.write(struct.pack("<ii", seg['nPrefill'], seg['nTotal']))
        for emb in (*seg['prefill'], seg['first_gen'], *seg['decode']):
            # '<f4' pins little-endian float32, matching the '<i' header ints.
            f.write(np.asarray(emb, dtype="<f4").tobytes())
sz = os.path.getsize(OUTPUT)
total_tokens = sum(s['n_tokens'] for s in all_segments_data)
print(f"\n=== RESULT ===")
print(f"Segments: {len(all_segments_data)}")
print(f"Total tokens: {total_tokens}")
print(f"Total audio: {total_audio_duration:.2f}s")
print(f"File: {OUTPUT} ({sz/1024:.0f}KB)")
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")