diff --git a/executorch-custom/jni_layer_tts.cpp b/executorch-custom/jni_layer_tts.cpp index a52db0e..bd4f46f 100644 --- a/executorch-custom/jni_layer_tts.cpp +++ b/executorch-custom/jni_layer_tts.cpp @@ -839,19 +839,25 @@ ExecuTorchJni::runTtsPipelineImpl( for(int i=0;i= 0) { - var energy = 0.0 - for (i in pos until minOf(pos + windowSize, audio.size)) { - energy += audio[i].toDouble() * audio[i] - } - val rms = kotlin.math.sqrt(energy / windowSize) - if (rms > threshold) { - lastSpeechEnd = pos + windowSize - break - } - pos -= windowSize - } - - val trimEnd = minOf(lastSpeechEnd + marginSamples, audio.size) - val result = audio.copyOf(trimEnd) - - // Apply fade-out - val fadeStart = maxOf(0, result.size - fadeSamples) - for (i in fadeStart until result.size) { - val alpha = 1f - (i - fadeStart).toFloat() / (result.size - fadeStart) - result[i] = (result[i] * alpha).toInt().toShort() - } - return result - } /** Sample from logits with temperature scaling and top-K filtering */ private fun sampleTopK(logits: FloatArray, temperature: Float = 0.9f, topK: Int = 50): Int { @@ -2372,7 +2333,7 @@ class Qwen3TtsEngine( talkerPteRotaryCos ?: FloatArray(0), talkerPteRotarySin ?: FloatArray(0), cpRotaryCos ?: FloatArray(0), cpRotarySin ?: FloatArray(0), ttsEosEmbed ?: FloatArray(TALKER_DIM), ttsPadEmbed ?: FloatArray(TALKER_DIM), - maxOf(200, (nTotal - nPrefill) * 4) // maxTokens: audio is ~3-4× longer than text + nTotal - nPrefill // maxTokens = trailing count (no pad generation) ) if (flat == null || flat.isEmpty()) return ShortArray(0) val nTokens = flat.size / NUM_CODEBOOKS @@ -2393,9 +2354,13 @@ class Qwen3TtsEngine( } val t3 = System.currentTimeMillis() - val audio = decodeChunked(allCodebooks, numRealTokens) + val rawAudio = decodeChunked(allCodebooks, numRealTokens) nlog("Decode: ${System.currentTimeMillis() - t3}ms") + // Trim trailing noise/silence: scan from end, find last loud frame + val audio = trimTrailingSilence(rawAudio) + nlog("Trimmed: ${rawAudio.size} → 
${audio.size} samples (${(rawAudio.size-audio.size)/SR.toFloat()}s removed)")
+
         val totalMs = System.currentTimeMillis() - t0
         val audioDur = audio.size.toFloat() / SR
         nlog("Total: ${totalMs}ms for ${audioDur}s")
@@ -2641,6 +2606,49 @@ class Qwen3TtsEngine(
         return result
     }
 
+    /**
+     * Trim trailing garbage from audio by detecting an RMS drop.
+     *
+     * The audio is split into 100 ms windows. The peak RMS of the first half is
+     * used as the speech reference level. Windows from the 60% mark onwards are
+     * scanned forward; the cut is placed at the end of the first window whose
+     * RMS falls below 15% of that peak.
+     *
+     * @param audio PCM samples to trim.
+     * @return [audio] unchanged when it is shorter than four windows or when no
+     *   RMS drop is found; otherwise a truncated copy.
+     */
+    private fun trimTrailingSilence(audio: ShortArray): ShortArray {
+        val windowSamples = SR / 10  // 100ms windows
+        if (audio.size < windowSamples * 4) return audio
+
+        // Compute RMS per complete window; a partial window at the end is not measured
+        val nWindows = audio.size / windowSamples
+        val rmsValues = FloatArray(nWindows) { w ->
+            val start = w * windowSamples
+            var sum = 0.0
+            for (i in start until start + windowSamples) {
+                val s = audio[i].toDouble()
+                sum += s * s
+            }
+            kotlin.math.sqrt(sum / windowSamples).toFloat()
+        }
+
+        // Find peak RMS in first half (speech region)
+        val peakRms = rmsValues.take(nWindows / 2).maxOrNull() ?: return audio
+
+        // Scan from 60% onwards, find first window where RMS drops below 15% of peak
+        // (speech ended, garbage/silence started)
+        val threshold = peakRms * 0.15f
+        val firstQuiet = (nWindows * 3 / 5 until nWindows)
+            .firstOrNull { rmsValues[it] < threshold }
+            ?: return audio  // no drop found: keep everything, including a partial last window
+
+        // Cut after the quiet window itself so a short tail is kept
+        val trimPoint = minOf((firstQuiet + 1) * windowSamples, audio.size)
+        return if (trimPoint < audio.size) audio.copyOf(trimPoint) else audio
+    }
+
     /** Full pipeline using Hexagon talker + Hexagon CP from pre-computed embeddings. */
     private fun generateFromEmbedsHexagon(embedsPath: String): ShortArray {
         nlog("Full pipeline (Hexagon) from: $embedsPath")
diff --git a/scripts/prepare_tts_native.py b/scripts/prepare_tts_native.py
new file mode 100644
index 0000000..e126ebd
--- /dev/null
+++ b/scripts/prepare_tts_native.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+"""
+Generate text-only TTS embeddings for FULL C++ native pipeline.
+No Python model generation needed — just tokenize + text_projection. + +Usage: python3 prepare_tts_native.py "Your text here" [output.bin] + adb push output.bin /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin + +Formula: trailing = text_proj[1:] + eos_padding(n_tokens × 4 total) + maxTokens = trailing_count (cut after trailing exhausted) +""" +import sys, os, struct, warnings +os.chdir("/tmp") +warnings.filterwarnings("ignore") + +TEXT = sys.argv[1] if len(sys.argv) > 1 else "Bonjour, je m'appelle Kazeia." +OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/tts_native.bin" +GOLDEN_PREFILL = "/tmp/existing_embeds.bin" # Must exist (captured on-device once) +MODEL = "/home/alf/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-12Hz-0.6B-Base/snapshots/5d83992436eae1d760afd27aff78a71d676296fc" + +import torch, numpy as np +from qwen_tts import Qwen3TTSModel + +print(f"Text: '{TEXT[:80]}{'...' if len(TEXT)>80 else ''}'") + +# Load model (just for tokenizer + text_projection) +tts = Qwen3TTSModel.from_pretrained(MODEL, local_files_only=True, device_map="cpu") +talker = tts.model.talker +tokenizer = tts.processor.tokenizer + +# Tokenize + project +tokens = tokenizer.encode(TEXT, add_special_tokens=False) +with torch.no_grad(): + proj = talker.text_projection( + talker.get_text_embeddings()(torch.tensor([tokens])) + )[0].numpy().astype(np.float32) +print(f"Tokens: {len(tokens)}") + +# Load golden prefill[0:9] (captured on-device, text-independent) +if not os.path.exists(GOLDEN_PREFILL): + os.system(f"adb pull /data/local/tmp/kazeia/models/qwen3-tts-npu/full_pipeline_embeds.bin {GOLDEN_PREFILL}") +with open(GOLDEN_PREFILL, "rb") as f: + nP = struct.unpack("