Add NEON SIMD heads argmax for CP — 2.3× speedup

CP head dot products (15 × 2048×1024) optimized with ARM NEON
vfmaq_f32 (4 accumulators, 16 floats/iteration).

CP/frame: 131ms → 58ms, total pipeline: 22.7s → 14.7s (RTF 3.2)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kazeia Team 2026-04-09 08:55:20 +02:00
parent 389ffa7c61
commit 8bfe6c7445
4 changed files with 134 additions and 10 deletions

View File

@ -0,0 +1,12 @@
package com.kazeia.tts
/** NEON SIMD optimized operations for TTS head argmax. */
object NeonOps {
init { System.loadLibrary("neon_ops") }
/** Argmax of hidden @ headWeights.T for one head. */
external fun headArgmax(hidden: FloatArray, headWeights: FloatArray, vocab: Int, dim: Int): Int
/** Batch argmax for all heads at once (avoids JNI overhead per head). */
external fun headArgmaxBatch(hidden: FloatArray, allHeads: FloatArray, numHeads: Int, vocab: Int, dim: Int): IntArray
}

View File

@ -1591,7 +1591,7 @@ class Qwen3TtsEngine(
// .pte outputs: hidden[1,1,1024], k0[1,8,16,128], v0[1,8,16,128], ...
val hiddenOut = outputs[0].toTensor().dataAsFloatArray
// Head argmax on CPU using cached heads
// Head argmax using NEON SIMD (5× faster than Java)
if (step >= 1 && step - 1 < 15) {
if (cpHeadsCache == null) cpHeadsCache = arrayOfNulls(15)
val cache = cpHeadsCache!!
@ -1600,15 +1600,7 @@ class Qwen3TtsEngine(
val hp = cpHeadsPath ?: return codes
cache[cbIdx] = loadNpy(hp.replace("cp_heads.npy", "head_${cbIdx}.npy"))
}
val headData = cache[cbIdx]!!
var best = 0; var bestVal = Float.NEGATIVE_INFINITY
for (j in 0 until CODEBOOK_SIZE) {
var dot = 0f
val off = j * TALKER_DIM
for (k in 0 until TALKER_DIM) dot += hiddenOut[k] * headData[off + k]
if (dot > bestVal) { bestVal = dot; best = j }
}
codes[cbIdx] = best
codes[cbIdx] = NeonOps.headArgmax(hiddenOut, cache[cbIdx]!!, CODEBOOK_SIZE, TALKER_DIM)
}
// Update KV caches (output is [1,8,16,128] — fixed size, already shifted)

View File

@ -37,6 +37,11 @@ target_include_directories(whisper_jni PRIVATE
target_link_libraries(whisper_jni whisper ggml ggml-base ggml-cpu android log)
target_compile_options(whisper_jni PRIVATE -std=c++17 -O2)
# --- NEON optimized ops for TTS heads ---
add_library(neon_ops SHARED neon_ops.cpp)
target_link_libraries(neon_ops log)
target_compile_options(neon_ops PRIVATE -std=c++17 -O3 -march=armv8.2-a+fp16)
# --- Mel Extractor (HuggingFace-compatible, no whisper.cpp dependency) ---
add_library(mel_extractor SHARED mel_extractor.cpp)
target_link_libraries(mel_extractor android log)

View File

@ -0,0 +1,115 @@
/**
* NEON-optimized operations for TTS Code Predictor heads.
* Argmax over 2048 vocab × 1024 dim dot products using ARM NEON SIMD.
*
* ~15ms per 15 heads vs ~81ms in Java (5.4× speedup).
*/
#include <jni.h>
#include <arm_neon.h>
#include <cstring>
#include <cfloat>
/**
* Dot product of two float32 vectors using NEON FMA.
* Processes 16 floats per iteration (4 accumulators × 4 lanes).
*/
static inline float dot_neon(const float* __restrict a, const float* __restrict b, int n) {
float32x4_t sum0 = vdupq_n_f32(0.0f);
float32x4_t sum1 = vdupq_n_f32(0.0f);
float32x4_t sum2 = vdupq_n_f32(0.0f);
float32x4_t sum3 = vdupq_n_f32(0.0f);
int i = 0;
for (; i + 15 < n; i += 16) {
sum0 = vfmaq_f32(sum0, vld1q_f32(a + i), vld1q_f32(b + i));
sum1 = vfmaq_f32(sum1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4));
sum2 = vfmaq_f32(sum2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8));
sum3 = vfmaq_f32(sum3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12));
}
sum0 = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));
float result = vaddvq_f32(sum0);
// Handle remainder
for (; i < n; i++) result += a[i] * b[i];
return result;
}
extern "C" {
/**
* Compute argmax(hidden @ head_weights.T) for one head.
*
* @param hidden float[dim] hidden state from transformer
* @param headWeights float[vocab * dim] head weight matrix, row-major
* @param vocab number of vocabulary entries (2048)
* @param dim hidden dimension (1024)
* @return argmax index
*/
JNIEXPORT jint JNICALL
Java_com_kazeia_tts_NeonOps_headArgmax(
JNIEnv* env, jclass,
jfloatArray jHidden, jfloatArray jHeadWeights, jint vocab, jint dim)
{
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
jfloat* weights = env->GetFloatArrayElements(jHeadWeights, nullptr);
int best = 0;
float bestVal = -FLT_MAX;
for (int j = 0; j < vocab; j++) {
float dot = dot_neon(hidden, weights + j * dim, dim);
if (dot > bestVal) {
bestVal = dot;
best = j;
}
}
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
env->ReleaseFloatArrayElements(jHeadWeights, weights, JNI_ABORT);
return best;
}
/**
* Batch: compute argmax for all 15 heads at once.
* Avoids 15 JNI transitions.
*
* @param hidden float[dim]
* @param allHeads float[numHeads * vocab * dim] all heads concatenated
* @param numHeads number of heads (15)
* @param vocab vocabulary size (2048)
* @param dim hidden dimension (1024)
* @return int[numHeads] argmax for each head
*/
JNIEXPORT jintArray JNICALL
Java_com_kazeia_tts_NeonOps_headArgmaxBatch(
JNIEnv* env, jclass,
jfloatArray jHidden, jfloatArray jAllHeads,
jint numHeads, jint vocab, jint dim)
{
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr);
jintArray jResult = env->NewIntArray(numHeads);
jint* result = env->GetIntArrayElements(jResult, nullptr);
for (int h = 0; h < numHeads; h++) {
const float* W = allHeads + (long)h * vocab * dim;
int best = 0;
float bestVal = -FLT_MAX;
for (int j = 0; j < vocab; j++) {
float dot = dot_neon(hidden, W + j * dim, dim);
if (dot > bestVal) {
bestVal = dot;
best = j;
}
}
result[h] = best;
}
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT);
env->ReleaseIntArrayElements(jResult, result, 0);
return jResult;
}
} // extern "C"