Fix missing eos/pad embeddings in native C++ pipeline
The native pipeline was adding zeros after trailing text tokens instead of tts_eos_embed then tts_pad_embed. This caused the model to mispronounce final words (e.g. "développement" → "devopment"). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
393ce79eb5
commit
3b01302cfb
|
|
@ -222,6 +222,7 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
||||||
jfloatArray jTrailing,jint nTrailing,
|
jfloatArray jTrailing,jint nTrailing,
|
||||||
jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads,
|
jfloatArray jCodecEmb, jfloatArray jCpEmbs, jfloatArray jCpHeads,
|
||||||
jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin,
|
jfloatArray jTCos,jfloatArray jTSin, jfloatArray jCCos,jfloatArray jCSin,
|
||||||
|
jfloatArray jEosEmbed, jfloatArray jPadEmbed,
|
||||||
jint maxTokens)
|
jint maxTokens)
|
||||||
{
|
{
|
||||||
if(!gState||!gState->loaded) return nullptr;
|
if(!gState||!gState->loaded) return nullptr;
|
||||||
|
|
@ -257,6 +258,10 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
||||||
env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data());
|
env->GetFloatArrayRegion(jCCos,0,ccSize,cCos.data());
|
||||||
env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data());
|
env->GetFloatArrayRegion(jCSin,0,ccSize,cSin.data());
|
||||||
|
|
||||||
|
std::vector<float> eosEmbed(DIM,0), padEmbed(DIM,0);
|
||||||
|
if(jEosEmbed) env->GetFloatArrayRegion(jEosEmbed,0,DIM,eosEmbed.data());
|
||||||
|
if(jPadEmbed) env->GetFloatArrayRegion(jPadEmbed,0,DIM,padEmbed.data());
|
||||||
|
|
||||||
// Pipeline state
|
// Pipeline state
|
||||||
int tkvElem=T_KV*T_KV_LEN*T_HD;
|
int tkvElem=T_KV*T_KV_LEN*T_HD;
|
||||||
std::vector<float> tK(T_L*tkvElem,0), tV(T_L*tkvElem,0);
|
std::vector<float> tK(T_L*tkvElem,0), tV(T_L*tkvElem,0);
|
||||||
|
|
@ -317,13 +322,17 @@ Java_com_kazeia_tts_TtsPipeline_nativeRun(
|
||||||
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
|
const float*ec=cpEmbs.data()+((long)cb*CB_SIZE+std::min(std::max(codes[cb+1],0),CB_SIZE-1))*DIM;
|
||||||
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
|
for(int k=0;k<DIM;k++) nextEmb[k]+=ec[k];
|
||||||
}
|
}
|
||||||
// Add trailing text or pad
|
// Add trailing text, then eos, then pad (matches Python/Kotlin pipeline)
|
||||||
if(trailingIdx<nTrailing){
|
if(trailingIdx<nTrailing){
|
||||||
const float*te=trailingData.data()+trailingIdx*DIM;
|
const float*te=trailingData.data()+trailingIdx*DIM;
|
||||||
for(int k=0;k<DIM;k++) nextEmb[k]+=te[k];
|
for(int k=0;k<DIM;k++) nextEmb[k]+=te[k];
|
||||||
trailingIdx++;
|
trailingIdx++;
|
||||||
|
} else if(trailingIdx==nTrailing){
|
||||||
|
for(int k=0;k<DIM;k++) nextEmb[k]+=eosEmbed[k];
|
||||||
|
trailingIdx++;
|
||||||
|
} else {
|
||||||
|
for(int k=0;k<DIM;k++) nextEmb[k]+=padEmbed[k];
|
||||||
}
|
}
|
||||||
// (pad embedding = zeros, already added implicitly)
|
|
||||||
|
|
||||||
// Talker step
|
// Talker step
|
||||||
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
int mi=T_KV_LEN-1-std::min(pos,T_KV_LEN-1);
|
||||||
|
|
|
||||||
|
|
@ -2315,6 +2315,7 @@ class Qwen3TtsEngine(
|
||||||
cpAllHeads ?: FloatArray(0),
|
cpAllHeads ?: FloatArray(0),
|
||||||
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0),
|
||||||
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0),
|
||||||
|
ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM),
|
||||||
nTotal - nPrefill
|
nTotal - nPrefill
|
||||||
)
|
)
|
||||||
if (flat == null || flat.isEmpty()) return ShortArray(0)
|
if (flat == null || flat.isEmpty()) return ShortArray(0)
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ object TtsPipeline {
|
||||||
cpHeads: FloatArray,
|
cpHeads: FloatArray,
|
||||||
talkerCos: FloatArray, talkerSin: FloatArray,
|
talkerCos: FloatArray, talkerSin: FloatArray,
|
||||||
cpCos: FloatArray, cpSin: FloatArray,
|
cpCos: FloatArray, cpSin: FloatArray,
|
||||||
|
eosEmbed: FloatArray, padEmbed: FloatArray,
|
||||||
maxTokens: Int
|
maxTokens: Int
|
||||||
): IntArray?
|
): IntArray?
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue