Multi-segment TTS for long text: split → generate → concatenate
- prepare_tts_segments.py: splits text at sentence boundaries, generates Python pre-computed embeds per segment - Kotlin: detects multi-segment file format, processes each segment independently (fresh KV cache), concatenates audio - Long text tested: 3 segments, 335 tokens, 26.8s audio, RTF 1.67 File format: n_segments, then per segment: nPrefill, nTotal, embeds[] Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
24157c0a68
commit
10a3904d7d
|
|
@ -2282,7 +2282,19 @@ class Qwen3TtsEngine(
|
||||||
val t0 = System.currentTimeMillis()
|
val t0 = System.currentTimeMillis()
|
||||||
val bytes = File(embedsPath).readBytes()
|
val bytes = File(embedsPath).readBytes()
|
||||||
val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
|
val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
val nPrefill = bb.int; val nTotal = bb.int
|
|
||||||
|
// Detect format: multi-segment (first int = n_segments, small) or legacy (first int = nPrefill ~10)
|
||||||
|
val firstInt = bb.int
|
||||||
|
val isMultiSegment = firstInt > 0 && firstInt < 100 && firstInt != 10 && firstInt != 9
|
||||||
|
// Heuristic: legacy format has nPrefill=9 or 10. Multi-segment has n_segments=2..50
|
||||||
|
|
||||||
|
if (isMultiSegment && talkerPteModule != null && cpPteModule != null) {
|
||||||
|
return generateMultiSegment(bb, firstInt, t0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy single-segment format
|
||||||
|
val nPrefill = if (isMultiSegment) { bb.position(0); bb.int; bb.int } else firstInt
|
||||||
|
val nTotal = bb.int
|
||||||
val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
|
val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
|
||||||
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
|
nlog("Loaded $nTotal embeds ($nPrefill prefill + ${nTotal - nPrefill} decode)")
|
||||||
|
|
||||||
|
|
@ -2498,6 +2510,87 @@ class Qwen3TtsEngine(
|
||||||
return allCodes.toTypedArray()
|
return allCodes.toTypedArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Multi-segment pipeline: process each segment independently, concatenate audio. */
|
||||||
|
private fun generateMultiSegment(bb: ByteBuffer, nSegments: Int, t0: Long): ShortArray {
|
||||||
|
nlog("Multi-segment: $nSegments segments")
|
||||||
|
val allAudio = mutableListOf<ShortArray>()
|
||||||
|
|
||||||
|
// Ensure data arrays loaded
|
||||||
|
val mpath = "/data/local/tmp/kazeia/models/qwen3-tts-npu"
|
||||||
|
if (cpAllHeads == null) {
|
||||||
|
val hf = java.io.File("/data/local/tmp/kazeia/models/cp_heads.bin")
|
||||||
|
if (hf.exists()) { val hb = hf.readBytes(); cpAllHeads = FloatArray(hb.size/4); ByteBuffer.wrap(hb).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(cpAllHeads!!) }
|
||||||
|
}
|
||||||
|
if (codecEmbedding == null) codecEmbedding = loadNpy("$mpath/codec_embedding.npy")
|
||||||
|
if (cpEmbeddings == null) cpEmbeddings = loadNpy("$mpath/code_predictor_embeddings.npy")
|
||||||
|
if (cpRotaryCos == null) cpRotaryCos = loadNpy("$mpath/cp_kv_v2/cp_rotary_cos.npy")
|
||||||
|
if (cpRotarySin == null) cpRotarySin = loadNpy("$mpath/cp_kv_v2/cp_rotary_sin.npy")
|
||||||
|
if (talkerPteRotaryCos == null) talkerPteRotaryCos = loadNpy("$mpath/talker_pte_rotary_cos.npy")
|
||||||
|
if (talkerPteRotarySin == null) talkerPteRotarySin = loadNpy("$mpath/talker_pte_rotary_sin.npy")
|
||||||
|
if (ttsEosEmbed == null) { val sp = loadNpy("$mpath/tts_special_embeds.npy"); ttsBosEmbed=sp.sliceArray(0 until TALKER_DIM); ttsEosEmbed=sp.sliceArray(TALKER_DIM until 2*TALKER_DIM); ttsPadEmbed=sp.sliceArray(2*TALKER_DIM until 3*TALKER_DIM) }
|
||||||
|
|
||||||
|
for (seg in 0 until nSegments) {
|
||||||
|
val nPrefill = bb.int; val nTotal = bb.int
|
||||||
|
val embeds = Array(nTotal) { FloatArray(TALKER_DIM).also { arr -> for (j in 0 until TALKER_DIM) arr[j] = bb.float } }
|
||||||
|
nlog("Segment ${seg+1}/$nSegments: $nPrefill prefill + ${nTotal-nPrefill} decode")
|
||||||
|
|
||||||
|
val prefillFlat = FloatArray(nPrefill * TALKER_DIM)
|
||||||
|
for (i in 0 until nPrefill) System.arraycopy(embeds[i], 0, prefillFlat, i * TALKER_DIM, TALKER_DIM)
|
||||||
|
val nTrailing = nTotal - nPrefill
|
||||||
|
val trailingFlat = if (nTrailing > 0) FloatArray(nTrailing * TALKER_DIM).also { arr ->
|
||||||
|
for (i in 0 until nTrailing) System.arraycopy(embeds[nPrefill+i], 0, arr, i*TALKER_DIM, TALKER_DIM)
|
||||||
|
} else null
|
||||||
|
|
||||||
|
val flat = talkerPteModule!!.nativeRunTtsPipeline(
|
||||||
|
prefillFlat, nPrefill, trailingFlat, nTrailing,
|
||||||
|
codecEmbedding ?: FloatArray(0), cpEmbeddings ?: FloatArray(0), cpAllHeads ?: FloatArray(0),
|
||||||
|
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
||||||
|
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
||||||
|
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
|
||||||
|
nTrailing
|
||||||
|
)
|
||||||
|
if (flat == null || flat.isEmpty()) continue
|
||||||
|
val nTokens = flat.size / NUM_CODEBOOKS
|
||||||
|
val segCodes = Array(nTokens) { t -> IntArray(NUM_CODEBOOKS) { cb -> flat[t * NUM_CODEBOOKS + cb] } }
|
||||||
|
nlog(" → $nTokens tokens generated")
|
||||||
|
|
||||||
|
val padLen = maxOf(nTokens, SEQ_LEN)
|
||||||
|
val codebooks = Array(NUM_CODEBOOKS) { cb -> IntArray(padLen) { t -> if (t < nTokens) segCodes[t][cb] else 0 } }
|
||||||
|
val segAudio = decodeChunked(codebooks, nTokens)
|
||||||
|
allAudio.add(segAudio)
|
||||||
|
nlog(" → ${segAudio.size/SR.toFloat()}s audio decoded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concatenate all segments
|
||||||
|
val totalSamples = allAudio.sumOf { it.size }
|
||||||
|
val result = ShortArray(totalSamples)
|
||||||
|
var offset = 0
|
||||||
|
for (seg in allAudio) { System.arraycopy(seg, 0, result, offset, seg.size); offset += seg.size }
|
||||||
|
|
||||||
|
val totalMs = System.currentTimeMillis() - t0
|
||||||
|
nlog("Total: ${totalMs}ms for ${totalSamples/SR.toFloat()}s ($nSegments segments)")
|
||||||
|
|
||||||
|
// Save WAV
|
||||||
|
try {
|
||||||
|
val wavPath = "/data/local/tmp/kazeia/kazeia_PTE_NPU.wav"
|
||||||
|
val fos = java.io.FileOutputStream(wavPath)
|
||||||
|
val dataLen = result.size * 2
|
||||||
|
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
header.put("RIFF".toByteArray()); header.putInt(36 + dataLen)
|
||||||
|
header.put("WAVE".toByteArray()); header.put("fmt ".toByteArray())
|
||||||
|
header.putInt(16); header.putShort(1); header.putShort(1)
|
||||||
|
header.putInt(SR); header.putInt(SR * 2); header.putShort(2); header.putShort(16)
|
||||||
|
header.put("data".toByteArray()); header.putInt(dataLen)
|
||||||
|
fos.write(header.array())
|
||||||
|
val buf = ByteBuffer.allocate(dataLen).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
for (s in result) buf.putShort(s)
|
||||||
|
fos.write(buf.array()); fos.close()
|
||||||
|
nlog("WAV saved: $wavPath ($totalSamples samples)")
|
||||||
|
} catch (e: Exception) { nlog("WAV save failed: ${e.message}") }
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
/** Full pipeline using Hexagon talker + Hexagon CP from pre-computed embeddings. */
|
/** Full pipeline using Hexagon talker + Hexagon CP from pre-computed embeddings. */
|
||||||
private fun generateFromEmbedsHexagon(embedsPath: String): ShortArray {
|
private fun generateFromEmbedsHexagon(embedsPath: String): ShortArray {
|
||||||
nlog("Full pipeline (Hexagon) from: $embedsPath")
|
nlog("Full pipeline (Hexagon) from: $embedsPath")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,147 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Generate TTS embeddings for long text, split into sentence segments.
|
||||||
|
Each segment is generated independently by Python for maximum quality.
|
||||||
|
|
||||||
|
Usage: python3 prepare_tts_segments.py "Long text..." [output.bin]
|
||||||
|
adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin
|
||||||
|
|
||||||
|
Output format:
|
||||||
|
int32 n_segments
|
||||||
|
for each segment:
|
||||||
|
int32 n_prefill
|
||||||
|
int32 n_total
|
||||||
|
float32[n_total * 1024] embeddings
|
||||||
|
"""
|
||||||
|
import sys, os, struct, re, types, warnings
|
||||||
|
os.chdir("/tmp")
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour. Je suis Kazeia."
|
||||||
|
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_segments.bin"
|
||||||
|
VOICE = "/opt/Kazeia/voix/damien_15s_24k.wav"
|
||||||
|
MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc"
|
||||||
|
|
||||||
|
import torch, numpy as np
|
||||||
|
from qwen_tts import Qwen3TTSModel
|
||||||
|
|
||||||
|
def split_sentences(text, max_tokens=60):
|
||||||
|
"""Split text at sentence boundaries, keeping segments short."""
|
||||||
|
# Split at . ! ? ; and keep the punctuation
|
||||||
|
parts = re.split(r'(?<=[.!?;])\s+', text.strip())
|
||||||
|
|
||||||
|
segments = []
|
||||||
|
current = ""
|
||||||
|
for part in parts:
|
||||||
|
if current and len(current) + len(part) > 200: # rough char limit
|
||||||
|
segments.append(current.strip())
|
||||||
|
current = part
|
||||||
|
else:
|
||||||
|
current = (current + " " + part).strip() if current else part
|
||||||
|
if current.strip():
|
||||||
|
segments.append(current.strip())
|
||||||
|
|
||||||
|
# If any segment is still too long, split at commas
|
||||||
|
final = []
|
||||||
|
for seg in segments:
|
||||||
|
if len(seg) > 250:
|
||||||
|
parts = re.split(r'(?<=,)\s+', seg)
|
||||||
|
sub = ""
|
||||||
|
for p in parts:
|
||||||
|
if sub and len(sub) + len(p) > 200:
|
||||||
|
final.append(sub.strip())
|
||||||
|
sub = p
|
||||||
|
else:
|
||||||
|
sub = (sub + " " + p).strip() if sub else p
|
||||||
|
if sub.strip():
|
||||||
|
final.append(sub.strip())
|
||||||
|
else:
|
||||||
|
final.append(seg)
|
||||||
|
|
||||||
|
return final if final else [text]
|
||||||
|
|
||||||
|
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
|
||||||
|
segments = split_sentences(TEXT)
|
||||||
|
print(f"Split into {len(segments)} segments:")
|
||||||
|
for i, s in enumerate(segments):
|
||||||
|
print(f" [{i}] '{s[:60]}{'...' if len(s)>60 else ''}'")
|
||||||
|
|
||||||
|
print("\nLoading model...")
|
||||||
|
tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu")
|
||||||
|
talker = tts.model.talker
|
||||||
|
|
||||||
|
# Capture generation inputs via monkey-patch on inner model
|
||||||
|
captured_inputs = []
|
||||||
|
original_model_forward = talker.model.forward
|
||||||
|
def patched_model_forward(input_ids=None, inputs_embeds=None, **kwargs):
|
||||||
|
if inputs_embeds is not None and inputs_embeds.shape[1] == 1:
|
||||||
|
captured_inputs.append(inputs_embeds[0, 0, :].detach().cpu().numpy().astype(np.float32))
|
||||||
|
return original_model_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
|
||||||
|
talker.model.forward = patched_model_forward
|
||||||
|
|
||||||
|
# Load prefill structure
|
||||||
|
EXISTING = "/tmp/existing_embeds.bin"
|
||||||
|
if not os.path.exists(EXISTING):
|
||||||
|
os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {EXISTING}")
|
||||||
|
with open(EXISTING, "rb") as f:
|
||||||
|
nP = struct.unpack("<i", f.read(4))[0]
|
||||||
|
nT = struct.unpack("<i", f.read(4))[0]
|
||||||
|
old_embeds = [np.frombuffer(f.read(1024*4), dtype=np.float32).copy() for _ in range(nT)]
|
||||||
|
prefill_base = old_embeds[:9] # role+ctrl+spk+bos
|
||||||
|
|
||||||
|
# Generate embeds for each segment
|
||||||
|
all_segments_data = []
|
||||||
|
total_audio_duration = 0
|
||||||
|
|
||||||
|
for i, seg_text in enumerate(segments):
|
||||||
|
print(f"\n--- Segment {i+1}/{len(segments)} ---")
|
||||||
|
print(f" '{seg_text[:60]}...'")
|
||||||
|
|
||||||
|
captured_inputs.clear()
|
||||||
|
audio_list, sr = tts.generate_voice_clone(
|
||||||
|
text=seg_text, ref_audio=VOICE, language="french",
|
||||||
|
x_vector_only_mode=True, non_streaming_mode=True,
|
||||||
|
)
|
||||||
|
audio_dur = len(audio_list[0]) / sr
|
||||||
|
total_audio_duration += audio_dur
|
||||||
|
|
||||||
|
if len(captured_inputs) < 2:
|
||||||
|
print(f" WARNING: Only {len(captured_inputs)} steps, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
nPrefill = 10 # 9 base + first gen input
|
||||||
|
nDecode = len(captured_inputs) - 1
|
||||||
|
nTotal = nPrefill + nDecode
|
||||||
|
|
||||||
|
seg_data = {
|
||||||
|
'nPrefill': nPrefill,
|
||||||
|
'nTotal': nTotal,
|
||||||
|
'prefill': prefill_base.copy(),
|
||||||
|
'first_gen': captured_inputs[0],
|
||||||
|
'decode': captured_inputs[1:],
|
||||||
|
'audio_dur': audio_dur,
|
||||||
|
'n_tokens': len(captured_inputs),
|
||||||
|
}
|
||||||
|
all_segments_data.append(seg_data)
|
||||||
|
print(f" {nTotal} embeds ({len(captured_inputs)} tokens), {audio_dur:.2f}s audio")
|
||||||
|
|
||||||
|
# Write multi-segment file
|
||||||
|
with open(OUTPUT, "wb") as f:
|
||||||
|
f.write(struct.pack("<i", len(all_segments_data)))
|
||||||
|
for seg in all_segments_data:
|
||||||
|
f.write(struct.pack("<i", seg['nPrefill']))
|
||||||
|
f.write(struct.pack("<i", seg['nTotal']))
|
||||||
|
for emb in seg['prefill']:
|
||||||
|
f.write(emb.tobytes())
|
||||||
|
f.write(seg['first_gen'].tobytes())
|
||||||
|
for emb in seg['decode']:
|
||||||
|
f.write(emb.tobytes())
|
||||||
|
|
||||||
|
sz = os.path.getsize(OUTPUT)
|
||||||
|
total_tokens = sum(s['n_tokens'] for s in all_segments_data)
|
||||||
|
print(f"\n=== RESULT ===")
|
||||||
|
print(f"Segments: {len(all_segments_data)}")
|
||||||
|
print(f"Total tokens: {total_tokens}")
|
||||||
|
print(f"Total audio: {total_audio_duration:.2f}s")
|
||||||
|
print(f"File: {OUTPUT} ({sz/1024:.0f}KB)")
|
||||||
|
print(f"\nadb push {OUTPUT} /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin")
|
||||||
Loading…
Reference in New Issue