Fix missing eos/pad embeddings in native C++ pipeline

The native pipeline was adding zeros after trailing text tokens
instead of tts_eos_embed then tts_pad_embed. This caused the model
to mispronounce final words (e.g. "développement" → "devopment").

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 10:35:05 +02:00
parent 393ce79eb5
commit 3b01302cfb
3 changed files with 13 additions and 2 deletions

View File

@ -222,6 +222,7 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
jfloatArray jTrailing,jint nTrailing,
jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads,
jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin,
jfloatArray jEosEmbed, jfloatArray jPadEmbed,
jint maxTokens)
{
if(!gState||!gState->loaded) return nullptr;
@ -257,6 +258,10 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data());
env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data());
std::vector<float> eosEmbed(DIM,0), padEmbed(DIM,0);
if(jEosEmbed) env->GetFloatArrayRegion(jEosEmbed,0,DIM,eosEmbed.data());
if(jPadEmbed) env->GetFloatArrayRegion(jPadEmbed,0,DIM,padEmbed.data());
// Pipeline state
int tkvElem=T_KV*T_KV_LEN*T_HD;
std::vector<float> tK(T_L*tkvElem,0), tV(T_L*tkvElem,0);
@ -317,13 +322,17 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
}
// Add trailing text or pad
// Add trailing text, then eos, then pad (matches Python/Kotlin pipeline)
if(trailingIdx<nTrailing){
const float*te=trailingData.data()+trailingIdx*DIM;
for(int k=0;k<DIM;k++) nextEmb[k]+=te[k];
trailingIdx++;
} else if(trailingIdx==nTrailing){
for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmbed[k];
trailingIdx++;
} else {
for(int k=0;k<DIM;k++) nextEmb[k]+=padEmbed[k];
}
// (pad embedding = zeros, already added implicitly)
// Talker step
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);

View File

@ -2315,6 +2315,7 @@ class Qwen3TtsEngine(
cpAllHeads ?: FloatArray(0),
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
nTotal - nPrefill
)
if (flat == null || flat.isEmpty()) return ShortArray(0)

View File

@ -22,6 +22,7 @@ object TtsPipeline {
cpHeads: FloatArray,
talkerCos: FloatArray, talkerSin: FloatArray,
cpCos: FloatArray, cpSin: FloatArray,
eosEmbed: FloatArray, padEmbed: FloatArray,
maxTokens: Int
): IntArray?
}