Reduce talker KV_LEN 100→64: saves 148ms (RTF 1.31)
A KV window of 64 is sufficient for ~70-token generation (10 prefill + 58 gen). 36% less KV memcpy per talker step (28L × 2 × 64×8×128 vs 28L × 2 × 100×8×128). Generation: 3795ms → 3647ms; total: 6438ms → 6093ms. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4dcc4bb8b3
commit
a688edc9ec
|
|
@ -692,7 +692,7 @@ ExecuTorchJni::runTtsPipelineImpl(
|
||||||
jint maxTokens)
|
jint maxTokens)
|
||||||
{
|
{
|
||||||
static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
|
static const int DIM=1024,VOCAB=3072,CB_SIZE=2048,NUM_CB=16;
|
||||||
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=100;
|
static const int T_L=28,T_KV=8,T_HD=128,T_KV_LEN=64;
|
||||||
static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
|
static const int C_L=5,C_KV=8,C_HD=128,C_KV_LEN=16;
|
||||||
static const int CODEC_EOS=2150;
|
static const int CODEC_EOS=2150;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ class Qwen3TtsEngine(
|
||||||
private const val CP_KV_LEN = 16 // max 16 past positions (17 total with current)
|
private const val CP_KV_LEN = 16 // max 16 past positions (17 total with current)
|
||||||
|
|
||||||
// Talker .pte constants
|
// Talker .pte constants
|
||||||
private const val TALKER_PTE_KV_LEN = 100 // .pte talker KV window size
|
private const val TALKER_PTE_KV_LEN = 64 // .pte talker KV window size (reduced from 100)
|
||||||
|
|
||||||
// Codec special token IDs (in talker's 3072 vocab space)
|
// Codec special token IDs (in talker's 3072 vocab space)
|
||||||
private const val CODEC_EOS = 2150
|
private const val CODEC_EOS = 2150
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
N_L = 28; N_H = 16; N_KV = 8; HD = 128; DIM = 1024; N_REP = 2
|
N_L = 28; N_H = 16; N_KV = 8; HD = 128; DIM = 1024; N_REP = 2
|
||||||
VOCAB = 3072; FFN = 3072
|
VOCAB = 3072; FFN = 3072
|
||||||
KV_LEN = 16 # Small KV for testing HTP viability
|
KV_LEN = 64 # Reduced from 100: saves 36% memcpy, sufficient for ~70 token generation
|
||||||
|
|
||||||
state = torch.load("/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth",
|
state = torch.load("/opt/Kazeia/models_qnn/qwen3-tts-export/qwen3_tts_talker.pth",
|
||||||
map_location="cpu", weights_only=False)
|
map_location="cpu", weights_only=False)
|
||||||
|
|
@ -117,7 +117,7 @@ edge = to_edge_transform_and_lower_to_qnn(w, (e, m, c0, s0, *kvs), compiler_spec
|
||||||
print("LOWERED!")
|
print("LOWERED!")
|
||||||
|
|
||||||
pte = edge.to_executorch()
|
pte = edge.to_executorch()
|
||||||
OUT = "/opt/Kazeia/models_qnn/talker_transformer_fp16_kv16.pte"
|
OUT = "/opt/Kazeia/models_qnn/talker_transformer_fp16.pte"
|
||||||
with open(OUT, "wb") as f:
|
with open(OUT, "wb") as f:
|
||||||
pte.write_to_file(f)
|
pte.write_to_file(f)
|
||||||
print(f"SAVED: {OUT} ({os.path.getsize(OUT)/1024/1024:.0f} MB)")
|
print(f"SAVED: {OUT} ({os.path.getsize(OUT)/1024/1024:.0f} MB)")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue