137 lines
5.4 KiB
Python
137 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Generate TTS embeddings for long text, split into sentence segments.
|
|
Each segment is generated independently by Python for maximum quality.
|
|
|
|
Usage: python3 prepare_tts_segments.py "Long text..." [output.bin]
|
|
adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin
|
|
|
|
Output format:
|
|
int32 n_segments
|
|
for each segment:
|
|
int32 n_prefill
|
|
int32 n_total
|
|
float32[n_total * 1024] embeddings
|
|
"""
|
|
import sys, os, struct, re, types, warnings
|
|
# Read the command-line arguments BEFORE changing directory. A relative output
# path such as "output.bin" (the form shown in the usage text) has to refer to
# the directory the script was launched from; resolving it to an absolute path
# here keeps it pointing there after the chdir below.
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour. Je suis Kazeia."
OUTPUT = os.path.abspath(sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_segments.bin")

# The rest of the script runs with /tmp as working directory and with all
# Python warnings suppressed.
os.chdir("/tmp")
warnings.filterwarnings("ignore")

# Reference recording handed to generate_voice_clone() as ref_audio.
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"
# Local model snapshot directory; loaded with local_files_only=True.
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
|
|
|
|
import torch, numpy as np
|
|
from qwen_tts import Qwen3TTSModel
|
|
|
|
def split_sentences(text, max_chars=120):
    """Split text into short segments, one per sentence.

    The text is first cut after every ``.``, ``!``, ``?``, ``;`` or ``:`` that
    is followed by whitespace. A piece longer than ``max_chars`` is cut again
    after each comma, and consecutive comma-clauses are greedily re-joined
    with a single space for as long as the joined string still fits within
    ``max_chars``. A single clause that is itself longer than ``max_chars``
    (no comma to cut at) is kept whole.

    Args:
        text: Input text.
        max_chars: Maximum length, in characters, of a re-joined segment.

    Returns:
        The non-empty segments in reading order. The list is empty when
        ``text`` is empty or contains only whitespace.
    """
    segments = []
    for sentence in re.split(r'(?<=[.!?;:])\s+', text.strip()):
        if len(sentence) <= max_chars:
            segments.append(sentence)
            continue

        # Sentence too long: regroup its comma-separated clauses.
        current = ""
        for clause in re.split(r'(?<=,)\s+', sentence):
            # The "+ 1" is the space inserted when joining; without it a
            # joined segment could end up one character over the limit.
            if current and len(current) + 1 + len(clause) > max_chars:
                segments.append(current.strip())
                current = clause
            else:
                current = f"{current} {clause}".strip() if current else clause
        if current.strip():
            segments.append(current.strip())

    # re.split always yields at least one piece, so only blank pieces need
    # filtering here.
    return [s for s in segments if s.strip()]
|
|
|
|
# Echo a truncated preview of the input text (first 80 characters).
text_suffix = '...' if len(TEXT) > 80 else ''
print(f"Text: '{TEXT[:80]}{text_suffix}'")

# Split the text and list each segment, truncated to 60 characters.
segments = split_sentences(TEXT)
print(f"Split into {len(segments)} segments:")
for idx, sentence in enumerate(segments):
    sentence_suffix = '...' if len(sentence) > 60 else ''
    print(f" [{idx}] '{sentence[:60]}{sentence_suffix}'")

# Load the model from the local snapshot only, placed on CPU, and keep a
# handle on its talker sub-model for the capture hook installed below.
print("\nLoading model...")
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
talker = tts.model.talker
|
|
|
|
# Capture generation inputs. The previous version of this script only captured
# 1-token decode calls and reassembled them with a generic 9-embed "prefill_base"
# pulled from another file, dropping the per-segment prefill that contains the
# xvector conditioning AND the text-encoded embeddings. With that generic prefix
# the talker had no idea which sentence to produce, so Hexagon output was garbled.
# Fix: capture the MULTI-token prefill call too (first call has shape[1] > 1),
# exactly like prepare_tts_voiceclone.py does. Each segment becomes self-contained.
captured_embeds = []  # one float32 vector per token position, in call order
call_shapes = []      # token count (shape[1]) of each captured forward call
original_model_forward = talker.model.forward


def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
    """Record ``inputs_embeds`` and delegate to the original forward.

    When ``inputs_embeds`` is a 3-D tensor, its second dimension is appended
    to ``call_shapes`` and every row of its first batch item is appended to
    ``captured_embeds`` as a float32 NumPy vector. Calls without a 3-D
    ``inputs_embeds`` are passed through without recording anything.

    Returns:
        Whatever the original ``talker.model.forward`` returns.
    """
    if inputs_embeds is not None and inputs_embeds.dim() == 3:
        call_shapes.append(inputs_embeds.shape[1])
        # Cast to float32 inside torch before converting: .numpy() cannot
        # represent every torch dtype (bfloat16 in particular). astype()
        # then makes a private copy, so the captured rows never alias the
        # model's own tensor memory.
        rows = inputs_embeds[0].detach().cpu().float().numpy().astype(np.float32)
        captured_embeds.extend(rows)
    return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)


talker.model.forward = patched_model_forward
|
|
|
|
# Generate embeds for each segment. Every kept segment contributes one dict to
# all_segments_data; total_audio_duration sums the audio of kept segments only,
# so it stays consistent with the segment and token counts reported later.
all_segments_data = []
total_audio_duration = 0

for i, seg_text in enumerate(segments):
    print(f"\n--- Segment {i+1}/{len(segments)} ---")
    # Only show an ellipsis when the preview was actually truncated.
    preview_suffix = '...' if len(seg_text) > 60 else ''
    print(f" '{seg_text[:60]}{preview_suffix}'")

    # Start from empty capture buffers so this segment is self-contained.
    captured_embeds.clear()
    call_shapes.clear()
    audio_list, sr = tts.generate_voice_clone(
        text=seg_text, ref_audio=VOICE, language="french",
        x_vector_only_mode=True, non_streaming_mode=True,
    )

    if not call_shapes:
        print(f" WARNING: no calls captured, skipping")
        continue

    # First call = multi-token prefill (xvector + text encoding + role tokens).
    # Subsequent calls = single-token decode (one per generated codec frame).
    n_prefill = call_shapes[0]
    n_total = len(captured_embeds)
    if n_total <= n_prefill:
        print(f" WARNING: no decode steps captured, skipping")
        continue

    # Duration is accounted for only once the segment is known to be kept.
    audio_dur = len(audio_list[0]) / sr
    total_audio_duration += audio_dur

    all_segments_data.append({
        'nPrefill': n_prefill,
        'nTotal': n_total,
        # Snapshot the list: captured_embeds is cleared on the next iteration.
        'embeds': list(captured_embeds),
        'audio_dur': audio_dur,
    })
    print(f" {n_total} embeds ({n_prefill} prefill + {n_total - n_prefill} decode), {audio_dur:.2f}s audio")
|
|
|
|
# Write multi-segment file. Per-segment layout matches single-segment format so
# the tablet can read either shape with the same parser.
# Layout: int32 segment count, then per segment: int32 prefill length,
# int32 total length, followed by the raw bytes of each captured vector.
int32_le = struct.Struct("<i")
with open(OUTPUT, "wb") as out_file:
    out_file.write(int32_le.pack(len(all_segments_data)))
    for segment in all_segments_data:
        out_file.write(int32_le.pack(segment['nPrefill']))
        out_file.write(int32_le.pack(segment['nTotal']))
        out_file.writelines(vector.tobytes() for vector in segment['embeds'])

# Summary: decode-step count is total minus prefill, summed over segments.
file_size = os.path.getsize(OUTPUT)
total_tokens = 0
for segment in all_segments_data:
    total_tokens += segment['nTotal'] - segment['nPrefill']

print(f"\n=== RESULT ===")
print(f"Segments: {len(all_segments_data)}")
print(f"Total tokens: {total_tokens}")
print(f"Total audio: {total_audio_duration:.2f}s")
print(f"File: {OUTPUT} ({file_size/1024:.0f}KB)")
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
|