Auto-segmentation for long texts + dynamic pipeline
- prepare_tts_native.py: auto-splits long text at sentence/comma boundaries, max 15 tokens per segment - Multi-segment format: each segment gets fresh KV cache - Formula: target_len = n_tokens × 3.2 + 5 per segment - Tested on Edouard Baer monologue: 28 segments, 102s audio Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
199bc4fbc9
commit
ee186e9049
|
|
@ -15,69 +15,112 @@ warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
|
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia."
|
||||||
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin"
|
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin"
|
||||||
GOLDEN_PREFILL = "/tmp/existing_embeds.bin" # Must exist (captured on-device once)
|
GOLDEN_PREFILL = "/tmp/existing_embeds.bin"
|
||||||
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
|
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
|
||||||
|
MAX_SEGMENT_TOKENS = 15 # Max text tokens per segment (~50 audio tokens, within NPU quality window)
|
||||||
|
|
||||||
import torch, numpy as np
|
import torch, numpy as np, re
|
||||||
from qwen_tts import Qwen3TTSModel
|
from qwen_tts import Qwen3TTSModel
|
||||||
|
|
||||||
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
|
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
|
||||||
|
|
||||||
# Load model (just for tokenizer + text_projection)
|
|
||||||
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
|
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
|
||||||
talker = tts.model.talker
|
talker = tts.model.talker
|
||||||
tokenizer = tts.processor.tokenizer
|
tokenizer = tts.processor.tokenizer
|
||||||
|
|
||||||
# Tokenize + project
|
# Load golden prefill + codec/eos
|
||||||
tokens = tokenizer.encode(TEXT, add_special_tokens=False)
|
|
||||||
with torch.no_grad():
|
|
||||||
proj = talker.text_projection(
|
|
||||||
talker.get_text_embeddings()(torch.tensor([tokens]))
|
|
||||||
)[0].numpy().astype(np.float32)
|
|
||||||
print(f"Tokens: {len(tokens)}")
|
|
||||||
|
|
||||||
# Load golden prefill[0:9] (captured on-device, text-independent)
|
|
||||||
if not os.path.exists(GOLDEN_PREFILL):
|
if not os.path.exists(GOLDEN_PREFILL):
|
||||||
os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {GOLDEN_PREFILL}")
|
os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {GOLDEN_PREFILL}")
|
||||||
with open(GOLDEN_PREFILL, "rb") as f:
|
with open(GOLDEN_PREFILL, "rb") as f:
|
||||||
nP = struct.unpack("<i", f.read(4))[0]
|
nP = struct.unpack("<i", f.read(4))[0]
|
||||||
nT = struct.unpack("<i", f.read(4))[0]
|
nT = struct.unpack("<i", f.read(4))[0]
|
||||||
golden = [np.frombuffer(f.read(1024*4), dtype=np.float32).copy() for _ in range(nT)]
|
golden = [np.frombuffer(f.read(1024*4), dtype=np.float32).copy() for _ in range(nT)]
|
||||||
|
|
||||||
# Load codec_bos embedding
|
|
||||||
ce = np.load("/tmp/ce.npy", allow_pickle=True).reshape(-1, 1024)
|
ce = np.load("/tmp/ce.npy", allow_pickle=True).reshape(-1, 1024)
|
||||||
CODEC_BOS = 2149
|
|
||||||
|
|
||||||
# Load eos embedding
|
|
||||||
sp = np.load("/tmp/tts_special.npy").reshape(3, 1024)
|
sp = np.load("/tmp/tts_special.npy").reshape(3, 1024)
|
||||||
eos = sp[1].astype(np.float32)
|
eos = sp[1].astype(np.float32)
|
||||||
|
CODEC_BOS = 2149
|
||||||
|
|
||||||
# Build trailing: text[1:] + eos padding
|
def split_text(text, max_tokens):
|
||||||
# Audio is ~3.5× longer than text tokens. Pad with eos to ensure full coverage.
|
"""Split text at sentence/clause boundaries, keeping each segment under max_tokens."""
|
||||||
target_len = max(int(len(tokens) * 3.2) + 5, 40) # calibrated: 3.2× + 5 buffer
|
# Split at sentence boundaries first
|
||||||
|
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
||||||
|
|
||||||
trailing = [proj[i] for i in range(1, len(proj))] # text[1:]
|
segments = []
|
||||||
while len(trailing) < target_len:
|
for sent in sentences:
|
||||||
|
tokens = tokenizer.encode(sent, add_special_tokens=False)
|
||||||
|
if len(tokens) <= max_tokens:
|
||||||
|
segments.append(sent)
|
||||||
|
else:
|
||||||
|
# Split long sentence at commas
|
||||||
|
parts = re.split(r'(?<=,)\s+', sent)
|
||||||
|
current = ""
|
||||||
|
for part in parts:
|
||||||
|
test = (current + " " + part).strip() if current else part
|
||||||
|
if len(tokenizer.encode(test, add_special_tokens=False)) > max_tokens and current:
|
||||||
|
segments.append(current.strip())
|
||||||
|
current = part
|
||||||
|
else:
|
||||||
|
current = test
|
||||||
|
if current.strip():
|
||||||
|
segments.append(current.strip())
|
||||||
|
return [s for s in segments if s.strip()]
|
||||||
|
|
||||||
|
def make_segment(text_segment):
|
||||||
|
"""Build embeds for one segment."""
|
||||||
|
tokens = tokenizer.encode(text_segment, add_special_tokens=False)
|
||||||
|
with torch.no_grad():
|
||||||
|
proj = talker.text_projection(
|
||||||
|
talker.get_text_embeddings()(torch.tensor([tokens]))
|
||||||
|
)[0].numpy().astype(np.float32)
|
||||||
|
|
||||||
|
target_len = max(int(len(tokens) * 3.2) + 5, 40)
|
||||||
|
trailing = [proj[i] for i in range(1, len(proj))]
|
||||||
|
while len(trailing) < target_len:
|
||||||
trailing.append(eos)
|
trailing.append(eos)
|
||||||
|
|
||||||
# Build file
|
return {
|
||||||
nPrefill = 10
|
'tokens': len(tokens),
|
||||||
nTotal = nPrefill + len(trailing)
|
'proj0': proj[0],
|
||||||
|
'trailing': trailing,
|
||||||
|
}
|
||||||
|
|
||||||
with open(OUTPUT, "wb") as f:
|
# Split text into segments
|
||||||
|
segments = split_text(TEXT, MAX_SEGMENT_TOKENS)
|
||||||
|
print(f"Segments: {len(segments)}")
|
||||||
|
for i, s in enumerate(segments):
|
||||||
|
n = len(tokenizer.encode(s, add_special_tokens=False))
|
||||||
|
print(f" [{i}] ({n} tok) '{s[:60]}{'...' if len(s)>60 else ''}'")
|
||||||
|
|
||||||
|
# Generate embeds per segment
|
||||||
|
seg_data = [make_segment(s) for s in segments]
|
||||||
|
|
||||||
|
if len(seg_data) == 1:
|
||||||
|
# Single segment: legacy format
|
||||||
|
s = seg_data[0]
|
||||||
|
nPrefill = 10
|
||||||
|
nTotal = nPrefill + len(s['trailing'])
|
||||||
|
with open(OUTPUT, "wb") as f:
|
||||||
f.write(struct.pack("<i", nPrefill))
|
f.write(struct.pack("<i", nPrefill))
|
||||||
f.write(struct.pack("<i", nTotal))
|
f.write(struct.pack("<i", nTotal))
|
||||||
# Golden prefill[0:8]
|
for i in range(9): f.write(golden[i].tobytes())
|
||||||
for i in range(9):
|
f.write((s['proj0'] + ce[CODEC_BOS]).tobytes())
|
||||||
f.write(golden[i].tobytes())
|
for e in s['trailing']: f.write(np.array(e, dtype=np.float32).tobytes())
|
||||||
# Prefill[9] = text[0] + codec_bos
|
print(f"\nSingle segment: {nTotal} embeds")
|
||||||
f.write((proj[0] + ce[CODEC_BOS]).tobytes())
|
else:
|
||||||
# Trailing
|
# Multi-segment format
|
||||||
for e in trailing:
|
with open(OUTPUT, "wb") as f:
|
||||||
f.write(np.array(e, dtype=np.float32).tobytes())
|
f.write(struct.pack("<i", len(seg_data)))
|
||||||
|
for s in seg_data:
|
||||||
|
nPrefill = 10
|
||||||
|
nTotal = nPrefill + len(s['trailing'])
|
||||||
|
f.write(struct.pack("<i", nPrefill))
|
||||||
|
f.write(struct.pack("<i", nTotal))
|
||||||
|
for i in range(9): f.write(golden[i].tobytes())
|
||||||
|
f.write((s['proj0'] + ce[CODEC_BOS]).tobytes())
|
||||||
|
for e in s['trailing']: f.write(np.array(e, dtype=np.float32).tobytes())
|
||||||
|
print(f"\nMulti-segment: {len(seg_data)} segments")
|
||||||
|
|
||||||
audio_est = len(trailing) * 0.08
|
total_trailing = sum(len(s['trailing']) for s in seg_data)
|
||||||
print(f"Trailing: {len(trailing)} ({len(tokens)-1} text + {len(trailing)-len(tokens)+1} eos)")
|
print(f"Total audio: ~{total_trailing * 0.08:.1f}s estimated")
|
||||||
print(f"Audio: ~{audio_est:.1f}s estimated")
|
|
||||||
print(f"Saved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f}KB)")
|
print(f"Saved: {OUTPUT} ({os.path.getsize(OUTPUT)/1024:.0f}KB)")
|
||||||
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
|
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue