241 lines
9.3 KiB
Kotlin
241 lines
9.3 KiB
Kotlin
package com.kazeia.service
|
|
|
|
import android.util.Log
|
|
import com.kazeia.core.*
|
|
import kotlinx.coroutines.*
|
|
import kotlinx.coroutines.flow.MutableStateFlow
|
|
import kotlinx.coroutines.flow.StateFlow
|
|
|
|
/**
 * Orchestrates the full pipeline: STT → [Processors chain] → TTS.
 * STT and TTS are independent — they only exchange text.
 * Processors are pluggable and executed in order.
 */
|
|
class KazeiaPipeline {
|
|
|
|
companion object {
    /** Tag passed to android.util.Log for every message emitted by this class. */
    private const val TAG: String = "Pipeline"
}
|
|
|
|
// Engines stay null until injected through setStt / setTts.
private var stt: SttEngine? = null
private var tts: TtsEngine? = null

// Ordered processor chain plus the conversation context handed to each processor.
private val processors: MutableList<MessageProcessor> = mutableListOf()
private val context = ConversationContext()

// Chat transcript: mutable backing flow, exposed through a StateFlow-typed property.
private val _messages: MutableStateFlow<List<ChatMessage>> = MutableStateFlow(emptyList())
val messages: StateFlow<List<ChatMessage>>
    get() = _messages

// Timestamped log lines written by log(), exposed the same way.
private val _logs: MutableStateFlow<List<String>> = MutableStateFlow(emptyList())
val logs: StateFlow<List<String>>
    get() = _logs

// Current pipeline stage; starts at Idle.
private val _pipelineState: MutableStateFlow<PipelineState> = MutableStateFlow(PipelineState.Idle)
val pipelineState: StateFlow<PipelineState>
    get() = _pipelineState
|
|
|
|
/** Installs [engine] as the speech-to-text engine and logs its class name. */
fun setStt(engine: SttEngine) {
    stt = engine
    val engineName = engine::class.simpleName
    log("STT set: $engineName")
}
|
|
/** Installs [engine] as the text-to-speech engine and logs its class name. */
fun setTts(engine: TtsEngine) {
    tts = engine
    val engineName = engine::class.simpleName
    log("TTS set: $engineName")
}
|
|
|
|
/** Appends [processor] to the end of the chain and logs the resulting chain size. */
fun addProcessor(processor: MessageProcessor) {
    processors += processor
    val total = processors.size
    log("Processor added: ${processor.name} ($total total)")
}
|
|
|
|
/**
 * Removes every processor whose name equals [name], then logs the removal.
 * The log line is written whether or not any processor matched.
 */
fun removeProcessor(name: String) {
    val cursor = processors.iterator()
    while (cursor.hasNext()) {
        val candidate = cursor.next()
        if (candidate.name == name) {
            cursor.remove()
        }
    }
    log("Processor removed: $name")
}
|
|
|
|
/** Returns a copy of the current processor chain, in execution order. */
fun getProcessors(): List<MessageProcessor> {
    val snapshot = processors.toList()
    return snapshot
}
|
|
|
|
/**
 * Process text input through the pipeline: [Processors] → TTS
 *
 * Records [text] as a PATIENT chat message, stores it under the
 * "last_input" metadata key, and runs the processor chain. When the
 * chain yields a non-blank response, that response is recorded as a
 * KAZEIA chat message, its metadata is logged, and it is spoken if the
 * result requests speech.
 *
 * The pipeline state is always returned to Idle when this function
 * exits. Without that reset the state stayed at Thinking whenever the
 * response was blank, speech was not requested, or no TTS engine was set.
 *
 * @param text the user input to process.
 */
suspend fun processText(text: String) {
    log("Input: '$text'")
    addMessage(ChatMessage(role = ChatMessage.Role.PATIENT, text = text))
    context.metadata["last_input"] = text

    try {
        val startedAt = System.currentTimeMillis()
        val result = runProcessors(text)
        val processingMs = System.currentTimeMillis() - startedAt
        val hasResponse = result.responseText.isNotBlank()

        if (hasResponse) {
            log("Response: '${result.responseText.take(60)}...' (${processingMs}ms)")
            addMessage(ChatMessage(role = ChatMessage.Role.KAZEIA, text = result.responseText))

            // Log metadata
            result.metadata.forEach { (k, v) -> log("  $k=$v") }

            // TTS
            if (result.shouldSpeak) {
                speak(result.responseText)
            }
        }

        // Update context history
        // NOTE(review): toMutableList() returns a copy and the result of this
        // expression is discarded, so context.history itself is never modified
        // here. The declaration of ConversationContext.history is not visible
        // from this file; confirm its type and write the new turns back
        // (e.g. add to it if it is mutable, or reassign it if it is a var).
        context.history.toMutableList().apply {
            add(ChatMessage(role = ChatMessage.Role.PATIENT, text = text))
            if (hasResponse) {
                add(ChatMessage(role = ChatMessage.Role.KAZEIA, text = result.responseText))
            }
        }
    } finally {
        // runProcessors switches the state to Thinking; make sure it never
        // stays there, whichever path (or exception) ends the turn.
        _pipelineState.value = PipelineState.Idle
    }
}
|
|
|
|
/**
 * Process audio through: STT → [Processors] → TTS
 *
 * Returns immediately, without touching the pipeline state, when no STT
 * engine has been set. Otherwise the audio is transcribed in the context
 * language; a blank transcription is logged as silence and processing
 * stops, while a non-blank one is forwarded to [processText].
 *
 * Once transcription has started, the pipeline state is always returned
 * to Idle on exit. Without that reset the state stayed at Transcribing
 * after silence or after a transcription failure.
 *
 * @param audioData PCM samples handed to the STT engine.
 */
suspend fun processAudio(audioData: ShortArray) {
    val sttEngine = stt ?: return

    _pipelineState.value = PipelineState.Transcribing
    try {
        val startedAt = System.currentTimeMillis()
        val transcription = sttEngine.transcribe(audioData, context.language)
        val sttMs = System.currentTimeMillis() - startedAt

        if (transcription.text.isBlank()) {
            log("STT: (silence) ${sttMs}ms")
            return
        }

        // RTF = transcription time / audio duration. The duration is derived
        // from the sample count with a hard-coded 16000 divisor.
        // NOTE(review): presumably 16 kHz mono input — confirm against the recorder.
        log("STT: '${transcription.text}' ${sttMs}ms (RTF=${"%.2f".format(sttMs.toFloat() / (audioData.size * 1000f / 16000))})")
        _pipelineState.value = PipelineState.Transcribed(transcription.text)

        processText(transcription.text)
    } finally {
        _pipelineState.value = PipelineState.Idle
    }
}
|
|
|
|
/**
 * Run text through all processors in chain.
 * First processor that returns shouldContinueChain=false wins.
 *
 * Sets the pipeline state to Thinking, then offers the same [text] to each
 * ready processor in order. Results from processors that ask to continue
 * the chain are only logged, not returned. A processor that throws is
 * logged and skipped; coroutine cancellation is propagated instead of
 * being treated as a processor failure.
 *
 * @param text the input handed unchanged to every processor.
 * @return the first result that stops the chain, or an echo of [text]
 *         tagged with metadata "mode" = "echo" when no processor stops it.
 */
private suspend fun runProcessors(text: String): ProcessorResult {
    _pipelineState.value = PipelineState.Thinking

    // Iterate over a snapshot: addProcessor/removeProcessor may run while a
    // processor call below is in progress, and mutating the live list during
    // iteration would throw ConcurrentModificationException.
    for (processor in processors.toList()) {
        if (!processor.isReady()) continue
        try {
            val startedAt = System.currentTimeMillis()
            val result = processor.process(text, context)
            val elapsed = System.currentTimeMillis() - startedAt
            log("[${processor.name}] ${elapsed}ms → ${if (result.shouldContinueChain) "continue" else "done"}")

            if (!result.shouldContinueChain) {
                return result
            }
        } catch (e: CancellationException) {
            // Cancellation must reach the caller; swallowing it would let the
            // chain keep running inside a cancelled coroutine.
            throw e
        } catch (e: Exception) {
            log("[${processor.name}] ERROR: ${e.message}")
        }
    }

    // No processor handled it → echo
    return ProcessorResult(responseText = text, metadata = mapOf("mode" to "echo"))
}
|
|
|
|
/** Internal shorthand: speaks [text] through [speakText] without a segment callback. */
private suspend fun speak(text: String) {
    speakText(text)
}
|
|
|
|
/**
 * Public entry point for speaking a full (possibly multi-sentence) text.
 * When TTS is Qwen3, text is sentence-split and fed through a streaming
 * session so first audio arrives after the first sentence rather than
 * after the full response is synthesised. Other TTS backends fall back
 * to the legacy one-shot synthesizeAndPlay call.
 *
 * Made public so KazeiaService can route its voice-command replies and
 * the echo-mode playback through the same path — otherwise each TTS
 * site reimplemented the "streaming-or-fallback" dispatch.
 *
 * Returns immediately, leaving the pipeline state untouched, when no TTS
 * engine is set. Otherwise the state is Speaking for the duration of the
 * call and is returned to Idle on every exit path. TTS exceptions are
 * logged and not rethrown; coroutine cancellation is propagated.
 *
 * @param text the text to speak.
 * @param onSegmentPlaying optional callback handed to the Qwen3 engine
 *        (ignored by other backends). Per the original author's note it
 *        fires the instant each synthesized sentence starts playing
 *        through the speaker, with the sentence text, audio duration, and
 *        a per-ENVELOPE_WINDOW_MS RMS envelope; it is used by
 *        processLlmResponse to defer the KAZEIA chat bubble appearance
 *        until sound is audible, pace word-by-word reveal inside the
 *        bubble, and drive the AudioVisualizerView orb.
 */
suspend fun speakText(
    text: String,
    onSegmentPlaying: ((
        sentence: String,
        durationMs: Long,
        rmsEnvelope: FloatArray,
        spectrogram: Array<FloatArray>
    ) -> Unit)? = null
) {
    val ttsEngine = tts ?: return
    _pipelineState.value = PipelineState.Speaking
    try {
        val qwen = ttsEngine as? com.kazeia.tts.Qwen3TtsEngine
        if (qwen != null) {
            qwen.onSegmentPlaying = onSegmentPlaying
            qwen.startStreamingSession()
            // NOTE(review): if anything below throws, endStreamingSession() is
            // never called. Whether it is safe to call from a finally block is
            // not visible from this file — confirm against Qwen3TtsEngine.
            val streamer = com.kazeia.tts.SentenceStreamer { raw ->
                // Strip emoji / non-speakable pictographs before TTS
                // so a standalone "😊" doesn't become its own noisy
                // segment. The chat bubble keeps the original text —
                // only the audio path sees the cleaned version.
                val spoken = stripNonSpeakable(raw).trim()
                if (spoken.isNotEmpty()) qwen.enqueueSentence(spoken)
            }
            streamer.append(text)
            streamer.flush()
            qwen.endStreamingSession()
        } else {
            ttsEngine.synthesizeAndPlay(text, context.language,
                onComplete = { _pipelineState.value = PipelineState.Idle }
            )
        }
    } catch (e: CancellationException) {
        // Never swallow cancellation: the caller's coroutine must observe it.
        throw e
    } catch (e: Exception) {
        log("TTS error: ${e.message}")
    } finally {
        // Reset on every exit path, including cancellation and errors that
        // are not Exception subclasses, so the state cannot stick at Speaking.
        _pipelineState.value = PipelineState.Idle
    }
}
|
|
|
|
/** Publishes a new transcript list consisting of the current messages followed by [msg]. */
fun addMessage(msg: ChatMessage) {
    val current = _messages.value
    val updated = current + msg
    _messages.value = updated
}
|
|
|
|
/**
 * Drop emoji + dingbat + pictographic characters so the TTS engine
 * doesn't try to synthesize them. The text is walked by Unicode code
 * point, so surrogate pairs are removed or kept as a whole; every code
 * point not matched by [isNonSpeakableCodePoint] is copied unchanged.
 * Keeps everything in the Basic Latin / Latin-1 / Latin Extended
 * ranges + common French punctuation untouched.
 *
 * @param text the raw text, possibly containing emoji sequences.
 * @return [text] with all non-speakable code points removed; whitespace
 *         surrounding a removed symbol is left as is.
 */
private fun stripNonSpeakable(text: String): String {
    val sb = StringBuilder(text.length)
    var i = 0
    while (i < text.length) {
        val cp = text.codePointAt(i)
        if (!isNonSpeakableCodePoint(cp)) sb.appendCodePoint(cp)
        // Advance by 1 or 2 UTF-16 units depending on the code point.
        i += Character.charCount(cp)
    }
    return sb.toString()
}

/**
 * Returns true for code points that belong to emoji / symbol blocks or
 * that only serve to glue emoji sequences together.
 */
private fun isNonSpeakableCodePoint(cp: Int): Boolean = when (cp) {
    0x200D -> true                // zero-width joiner
    0x20E3 -> true                // combining enclosing keycap
    in 0x2300..0x23FF -> true     // miscellaneous technical (watch, hourglass, alarm clock…)
    in 0x2600..0x27BF -> true     // misc symbols + dingbats
    in 0x2B00..0x2BFF -> true     // misc symbols and arrows (star, large squares…)
    in 0xFE00..0xFE0F -> true     // variation selectors
    // Mahjong / cards / enclosed alphanumerics (incl. regional-indicator
    // flags) / pictographs / emoticons / ornamental dingbats / transport /
    // alchemical / geometric extended / supplemental arrows-C /
    // supplemental pictographs / symbols & pictographs extended-A.
    in 0x1F000..0x1FAFF -> true
    in 0xE0020..0xE007F -> true   // tag characters used in subdivision flags
    else -> false
}
|
|
|
|
/**
 * Writes [msg] to Logcat at info level and appends it, prefixed with an
 * HH:mm:ss.SSS timestamp, to the in-memory log flow. Only the 199 most
 * recent entries are kept before the new line is added.
 */
fun log(msg: String) {
    Log.i(TAG, msg)
    val formatter = java.text.SimpleDateFormat("HH:mm:ss.SSS", java.util.Locale.FRANCE)
    val timestamp = formatter.format(java.util.Date())
    val recent = _logs.value.takeLast(199)
    _logs.value = recent + "$timestamp $msg"
}
|
|
|
|
/** Releases the STT engine, the TTS engine (when set) and then every processor in chain order. */
fun release() {
    stt?.release()
    tts?.release()
    for (processor in processors) {
        processor.release()
    }
}
|
|
}
|