Restore KV=100 + fix as-is embeds + multi-segment support
- KV_LEN restored to 100 (KV=64 caused quality loss from evicted role tokens)
- C++ uses pre-computed embeds as-is (no double codec_sum)
- Multi-segment format support in Kotlin (detects n_segments header)
- prepare_tts_segments.py: splits text + generates per-segment embeds
- Quality issue: Python-captured embeds differ from original working file (original was likely captured on-device, not from Python model.forward)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
10a3904d7d
commit
3dcf73aa38
|
|
@ -692,7 +692,7 @@ ExecuTorchJni::runTtsPipelineImpl(
|
|||
jint maxTokens)
|
||||
{
|
||||
static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
|
||||
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=64;
|
||||
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
|
||||
static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
|
||||
static const int CODEC_EOS=2150;
|
||||
|
||||
|
|
@ -839,13 +839,13 @@ ExecuTorchJni::runTtsPipelineImpl(
|
|||
for(int i=0;i<NUM_CB;i++) allCodes.push_back(codes[i]);
|
||||
cb0Hist.push_back(curCb0);
|
||||
|
||||
// Next embed: pre-computed from Python (complete: codec_sum+text)
|
||||
// After exhausted: codec_sum(our codes) + pad
|
||||
// Next embed: pre-computed from Python (already contains codec_sum+text)
|
||||
float nextEmb[DIM]={};
|
||||
if(trIdx<nTrailing){
|
||||
memcpy(nextEmb,trailing.data()+trIdx*DIM,DIM*4);
|
||||
trIdx++;
|
||||
} else {
|
||||
// After embeds exhausted: our codec_sum + pad
|
||||
const float*e0=codecEmb.data()+std::min(std::max(codes[0],0),VOCAB-1)*DIM;
|
||||
for(int k=0;k<DIM;k++) nextEmb[k]+=e0[k];
|
||||
for(int cb=0;cb<15;cb++){
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class Qwen3TtsEngine(
|
|||
private const val CP_KV_LEN = 16 // max 16 past positions (17 total with current)
|
||||
|
||||
// Talker .pte constants
|
||||
private const val TALKER_PTE_KV_LEN = 64 // .pte talker KV window size (reduced from 100)
|
||||
private const val TALKER_PTE_KV_LEN = 100 // must match .pte export (KV=64 caused quality loss)
|
||||
|
||||
// Codec special token IDs (in talker's 3072 vocab space)
|
||||
private const val CODEC_EOS = 2150
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ warnings.filterwarnings('ignore')
|
|||
|
||||
N_L = 28; N_H = 16; N_KV = 8; HD = 128; DIM = 1024; N_REP = 2
|
||||
VOCAB = 3072; FFN = 3072
|
||||
KV_LEN = 64 # Reduced from 100: saves 36% memcpy, sufficient for ~70 token generation
|
||||
KV_LEN = 100 # Must be >= prefill+maxGen. KV=64 caused quality loss (role tokens evicted)
|
||||
|
||||
state = torch.load("/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth",
|
||||
map_location="cpu", weights_only=False)
|
||||
|
|
|
|||
|
|
@ -25,40 +25,29 @@ MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base
|
|||
import torch, numpy as np
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
def split_sentences(text, max_tokens=60):
|
||||
"""Split text at sentence boundaries, keeping segments short."""
|
||||
# Split at . ! ? ; and keep the punctuation
|
||||
parts = re.split(r'(?<=[.!?;])\s+', text.strip())
|
||||
def split_sentences(text, max_chars=120):
|
||||
"""Split text into SHORT segments (~40-50 tokens max). Each sentence separate."""
|
||||
# Split at every sentence boundary
|
||||
parts = re.split(r'(?<=[.!?;:])\s+', text.strip())
|
||||
|
||||
segments = []
|
||||
current = ""
|
||||
for part in parts:
|
||||
if current and len(current) + len(part) > 200: # rough char limit
|
||||
segments.append(current.strip())
|
||||
current = part
|
||||
else:
|
||||
current = (current + " " + part).strip() if current else part
|
||||
if current.strip():
|
||||
segments.append(current.strip())
|
||||
|
||||
# If any segment is still too long, split at commas
|
||||
# Further split long sentences at commas
|
||||
final = []
|
||||
for seg in segments:
|
||||
if len(seg) > 250:
|
||||
parts = re.split(r'(?<=,)\s+', seg)
|
||||
sub = ""
|
||||
for p in parts:
|
||||
if sub and len(sub) + len(p) > 200:
|
||||
final.append(sub.strip())
|
||||
sub = p
|
||||
for part in parts:
|
||||
if len(part) > max_chars:
|
||||
subs = re.split(r'(?<=,)\s+', part)
|
||||
current = ""
|
||||
for s in subs:
|
||||
if current and len(current) + len(s) > max_chars:
|
||||
final.append(current.strip())
|
||||
current = s
|
||||
else:
|
||||
sub = (sub + " " + p).strip() if sub else p
|
||||
if sub.strip():
|
||||
final.append(sub.strip())
|
||||
current = (current + " " + s).strip() if current else s
|
||||
if current.strip():
|
||||
final.append(current.strip())
|
||||
else:
|
||||
final.append(seg)
|
||||
final.append(part)
|
||||
|
||||
return final if final else [text]
|
||||
return [s for s in final if s.strip()] if final else [text]
|
||||
|
||||
print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'")
|
||||
segments = split_sentences(TEXT)
|
||||
|
|
|
|||
Loading…
Reference in New Issue