Restore KV=100 + fix as-is embeds + multi-segment support

- KV_LEN restored to 100 (KV=64 caused quality loss from evicted role tokens)
- C++ uses pre-computed embeds as-is (no double codec_sum)
- Multi-segment format support in Kotlin (detects n_segments header)
- prepare_tts_segments.py: splits text + generates per-segment embeds
- Quality issue: Python-captured embeds differ from original working file
  (original was likely captured on-device, not from Python model.forward)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 22:26:20 +02:00
parent 10a3904d7d
commit 3dcf73aa38
4 changed files with 23 additions and 34 deletions

View File

@ -692,7 +692,7 @@ ExecuTorchJni::runTtsPipelineImpl(
jint maxTokens) jint maxTokens)
{ {
static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16; static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=64; static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16; static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
static const int CODEC_EOS=2150; static const int CODEC_EOS=2150;
@ -839,13 +839,13 @@ ExecuTorchJni::runTtsPipelineImpl(
for(int i=0;i<NUM_CB;i++) allCodes.push_back(codes[i]); for(int i=0;i<NUM_CB;i++) allCodes.push_back(codes[i]);
cb0Hist.push_back(curCb0); cb0Hist.push_back(curCb0);
// Next embed: pre-computed from Python (complete: codec_sum+text) // Next embed: pre-computed from Python (already contains codec_sum+text)
// After exhausted: codec_sum(our codes) + pad
float nextEmb[DIM]={}; float nextEmb[DIM]={};
if(trIdx<nTrailing){ if(trIdx<nTrailing){
memcpy(nextEmb,trailing.data()+trIdx*DIM,DIM*4); memcpy(nextEmb,trailing.data()+trIdx*DIM,DIM*4);
trIdx++; trIdx++;
} else { } else {
// After embeds exhausted: our codec_sum + pad
const float*e0=codecEmb.data()+std::min(std::max(codes[0],0),VOCAB-1)*DIM; const float*e0=codecEmb.data()+std::min(std::max(codes[0],0),VOCAB-1)*DIM;
for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k]; for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k];
for(int cb=0;cb<15;cb++){ for(int cb=0;cb<15;cb++){

View File

@ -68,7 +68,7 @@ class Qwen3TtsEngine(
private const val CP_KV_LEN = 16 // max 16 past positions (17 total with current) private const val CP_KV_LEN = 16 // max 16 past positions (17 total with current)
// Talker .pte constants // Talker .pte constants
private const val TALKER_PTE_KV_LEN = 64 // .pte talker KV window size (reduced from 100) private const val TALKER_PTE_KV_LEN = 100 // must match .pte export (KV=64 caused quality loss)
// Codec special token IDs (in talker's 3072 vocab space) // Codec special token IDs (in talker's 3072 vocab space)
private const val CODEC_EOS = 2150 private const val CODEC_EOS = 2150

View File

@ -11,7 +11,7 @@ warnings.filterwarnings('ignore')
N_L = 28; N_H = 16; N_KV = 8; HD = 128; DIM = 1024; N_REP = 2 N_L = 28; N_H = 16; N_KV = 8; HD = 128; DIM = 1024; N_REP = 2
VOCAB = 3072; FFN = 3072 VOCAB = 3072; FFN = 3072
KV_LEN = 64 # Reduced from 100: saves 36% memcpy, sufficient for ~70 token generation KV_LEN = 100 # Must be >= prefill+maxGen. KV=64 caused quality loss (role tokens evicted)
state = torch.load("/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth", state = torch.load("/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth",
map_location="cpu", weights_only=False) map_location="cpu", weights_only=False)

View File

@ -25,40 +25,29 @@ MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base
import torch, numpy as np import torch, numpy as np
from qwen_tts import Qwen3TTSModel from qwen_tts import Qwen3TTSModel
def split_sentences(text, max_tokens=60): def split_sentences(text, max_chars=120):
"""Split text at sentence boundaries, keeping segments short.""" """Split text into SHORT segments (~40-50 tokens max). Each sentence separate."""
# Split at . ! ? ; and keep the punctuation # Split at every sentence boundary
parts = re.split(r'(?<=[.!?;])\s+', text.strip()) parts = re.split(r'(?<=[.!?;:])\s+', text.strip())
segments = [] # Further split long sentences at commas
current = ""
for part in parts:
if current and len(current) + len(part) > 200: # rough char limit
segments.append(current.strip())
current = part
else:
current = (current + " " + part).strip() if current else part
if current.strip():
segments.append(current.strip())
# If any segment is still too long, split at commas
final = [] final = []
for seg in segments: for part in parts:
if len(seg) > 250: if len(part) > max_chars:
parts = re.split(r'(?<=,)\s+', seg) subs = re.split(r'(?<=,)\s+', part)
sub = "" current = ""
for p in parts: for s in subs:
if sub and len(sub) + len(p) > 200: if current and len(current) + len(s) > max_chars:
final.append(sub.strip()) final.append(current.strip())
sub = p current = s
else: else:
sub = (sub + " " + p).strip() if sub else p current = (current + " " + s).strip() if current else s
if sub.strip(): if current.strip():
final.append(sub.strip()) final.append(current.strip())
else: else:
final.append(seg) final.append(part)
return final if final else [text] return [s for s in final if s.strip()] if final else [text]
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'") print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
segments = split_sentences(TEXT) segments = split_sentences(TEXT)