Fix missing eos/pad embeddings in native C++ pipeline
The native pipeline was adding zeros after trailing text tokens instead of tts_eos_embed then tts_pad_embed. This caused the model to mispronounce final words (e.g. "développement" → "devopment"). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
393ce79eb5
commit
3b01302cfb
|
|
@ -222,6 +222,7 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
|||
jfloatArray jTrailing,jint nTrailing,
|
||||
jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads,
|
||||
jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin,
|
||||
jfloatArray jEosEmbed, jfloatArray jPadEmbed,
|
||||
jint maxTokens)
|
||||
{
|
||||
if(!gState||!gState->loaded) return nullptr;
|
||||
|
|
@ -257,6 +258,10 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
|||
env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data());
|
||||
env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data());
|
||||
|
||||
std::vector<float> eosEmbed(DIM,0), padEmbed(DIM,0);
|
||||
if(jEosEmbed) env->GetFloatArrayRegion(jEosEmbed,0,DIM,eosEmbed.data());
|
||||
if(jPadEmbed) env->GetFloatArrayRegion(jPadEmbed,0,DIM,padEmbed.data());
|
||||
|
||||
// Pipeline state
|
||||
int tkvElem=T_KV*T_KV_LEN*T_HD;
|
||||
std::vector<float> tK(T_L*tkvElem,0), tV(T_L*tkvElem,0);
|
||||
|
|
@ -317,13 +322,17 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
|||
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
|
||||
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
|
||||
}
|
||||
// Add trailing text or pad
|
||||
// Add trailing text, then eos, then pad (matches Python/Kotlin pipeline)
|
||||
if(trailingIdx<nTrailing){
|
||||
const float*te=trailingData.data()+trailingIdx*DIM;
|
||||
for(int k=0;k<DIM;k++) nextEmb[k]+=te[k];
|
||||
trailingIdx++;
|
||||
} else if(trailingIdx==nTrailing){
|
||||
for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmbed[k];
|
||||
trailingIdx++;
|
||||
} else {
|
||||
for(int k=0;k<DIM;k++) nextEmb[k]+=padEmbed[k];
|
||||
}
|
||||
// (pad embedding = zeros, already added implicitly)
|
||||
|
||||
// Talker step
|
||||
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
||||
|
|
|
|||
|
|
@ -2315,6 +2315,7 @@ class Qwen3TtsEngine(
|
|||
cpAllHeads ?: FloatArray(0),
|
||||
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
||||
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
||||
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
|
||||
nTotal - nPrefill
|
||||
)
|
||||
if (flat == null || flat.isEmpty()) return ShortArray(0)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ object TtsPipeline {
|
|||
cpHeads: FloatArray,
|
||||
talkerCos: FloatArray, talkerSin: FloatArray,
|
||||
cpCos: FloatArray, cpSin: FloatArray,
|
||||
eosEmbed: FloatArray, padEmbed: FloatArray,
|
||||
maxTokens: Int
|
||||
): IntArray?
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue