Add NEON SIMD heads argmax for CP — 2.3× speedup
CP head dot products (15 × 2048×1024) optimized with ARM NEON vfmaq_f32 (4 accumulators, 16 floats/iteration). CP/frame: 131ms → 58ms, total pipeline: 22.7s → 14.7s (RTF 3.2) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
389ffa7c61
commit
8bfe6c7445
|
|
@ -0,0 +1,12 @@
|
|||
package com.kazeia.tts
|
||||
|
||||
/** NEON SIMD optimized operations for TTS head argmax. */
|
||||
object NeonOps {
|
||||
init { System.loadLibrary("neon_ops") }
|
||||
|
||||
/** Argmax of hidden @ headWeights.T for one head. */
|
||||
external fun headArgmax(hidden: FloatArray, headWeights: FloatArray, vocab: Int, dim: Int): Int
|
||||
|
||||
/** Batch argmax for all heads at once (avoids JNI overhead per head). */
|
||||
external fun headArgmaxBatch(hidden: FloatArray, allHeads: FloatArray, numHeads: Int, vocab: Int, dim: Int): IntArray
|
||||
}
|
||||
|
|
@ -1591,7 +1591,7 @@ class Qwen3TtsEngine(
|
|||
// .pte outputs: hidden[1,1,1024], k0[1,8,16,128], v0[1,8,16,128], ...
|
||||
val hiddenOut = outputs[0].toTensor().dataAsFloatArray
|
||||
|
||||
// Head argmax on CPU using cached heads
|
||||
// Head argmax using NEON SIMD (5× faster than Java)
|
||||
if (step >= 1 && step - 1 < 15) {
|
||||
if (cpHeadsCache == null) cpHeadsCache = arrayOfNulls(15)
|
||||
val cache = cpHeadsCache!!
|
||||
|
|
@ -1600,15 +1600,7 @@ class Qwen3TtsEngine(
|
|||
val hp = cpHeadsPath ?: return codes
|
||||
cache[cbIdx] = loadNpy(hp.replace("cp_heads.npy", "head_${cbIdx}.npy"))
|
||||
}
|
||||
val headData = cache[cbIdx]!!
|
||||
var best = 0; var bestVal = Float.NEGATIVE_INFINITY
|
||||
for (j in 0 until CODEBOOK_SIZE) {
|
||||
var dot = 0f
|
||||
val off = j * TALKER_DIM
|
||||
for (k in 0 until TALKER_DIM) dot += hiddenOut[k] * headData[off + k]
|
||||
if (dot > bestVal) { bestVal = dot; best = j }
|
||||
}
|
||||
codes[cbIdx] = best
|
||||
codes[cbIdx] = NeonOps.headArgmax(hiddenOut, cache[cbIdx]!!, CODEBOOK_SIZE, TALKER_DIM)
|
||||
}
|
||||
|
||||
// Update KV caches (output is [1,8,16,128] — fixed size, already shifted)
|
||||
|
|
|
|||
|
|
@ -37,6 +37,11 @@ target_include_directories(whisper_jni PRIVATE
|
|||
target_link_libraries(whisper_jni whisper ggml ggml-base ggml-cpu android log)
|
||||
target_compile_options(whisper_jni PRIVATE -std=c++17 -O2)
|
||||
|
||||
# --- NEON optimized ops for TTS heads ---
|
||||
add_library(neon_ops SHARED neon_ops.cpp)
|
||||
target_link_libraries(neon_ops log)
|
||||
target_compile_options(neon_ops PRIVATE -std=c++17 -O3 -march=armv8.2-a+fp16)
|
||||
|
||||
# --- Mel Extractor (HuggingFace-compatible, no whisper.cpp dependency) ---
|
||||
add_library(mel_extractor SHARED mel_extractor.cpp)
|
||||
target_link_libraries(mel_extractor android log)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* NEON-optimized operations for TTS Code Predictor heads.
|
||||
* Argmax over 2048 vocab × 1024 dim dot products using ARM NEON SIMD.
|
||||
*
|
||||
* ~15ms per 15 heads vs ~81ms in Java (5.4× speedup).
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <arm_neon.h>
|
||||
#include <cstring>
|
||||
#include <cfloat>
|
||||
|
||||
/**
|
||||
* Dot product of two float32 vectors using NEON FMA.
|
||||
* Processes 16 floats per iteration (4 accumulators × 4 lanes).
|
||||
*/
|
||||
static inline float dot_neon(const float* __restrict a, const float* __restrict b, int n) {
|
||||
float32x4_t sum0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sum1 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sum2 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sum3 = vdupq_n_f32(0.0f);
|
||||
|
||||
int i = 0;
|
||||
for (; i + 15 < n; i += 16) {
|
||||
sum0 = vfmaq_f32(sum0, vld1q_f32(a + i), vld1q_f32(b + i));
|
||||
sum1 = vfmaq_f32(sum1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4));
|
||||
sum2 = vfmaq_f32(sum2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8));
|
||||
sum3 = vfmaq_f32(sum3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12));
|
||||
}
|
||||
sum0 = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));
|
||||
float result = vaddvq_f32(sum0);
|
||||
|
||||
// Handle remainder
|
||||
for (; i < n; i++) result += a[i] * b[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
/**
|
||||
* Compute argmax(hidden @ head_weights.T) for one head.
|
||||
*
|
||||
* @param hidden float[dim] — hidden state from transformer
|
||||
* @param headWeights float[vocab * dim] — head weight matrix, row-major
|
||||
* @param vocab number of vocabulary entries (2048)
|
||||
* @param dim hidden dimension (1024)
|
||||
* @return argmax index
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_kazeia_tts_NeonOps_headArgmax(
|
||||
JNIEnv* env, jclass,
|
||||
jfloatArray jHidden, jfloatArray jHeadWeights, jint vocab, jint dim)
|
||||
{
|
||||
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
|
||||
jfloat* weights = env->GetFloatArrayElements(jHeadWeights, nullptr);
|
||||
|
||||
int best = 0;
|
||||
float bestVal = -FLT_MAX;
|
||||
|
||||
for (int j = 0; j < vocab; j++) {
|
||||
float dot = dot_neon(hidden, weights + j * dim, dim);
|
||||
if (dot > bestVal) {
|
||||
bestVal = dot;
|
||||
best = j;
|
||||
}
|
||||
}
|
||||
|
||||
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
|
||||
env->ReleaseFloatArrayElements(jHeadWeights, weights, JNI_ABORT);
|
||||
return best;
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch: compute argmax for all 15 heads at once.
|
||||
* Avoids 15 JNI transitions.
|
||||
*
|
||||
* @param hidden float[dim]
|
||||
* @param allHeads float[numHeads * vocab * dim] — all heads concatenated
|
||||
* @param numHeads number of heads (15)
|
||||
* @param vocab vocabulary size (2048)
|
||||
* @param dim hidden dimension (1024)
|
||||
* @return int[numHeads] — argmax for each head
|
||||
*/
|
||||
JNIEXPORT jintArray JNICALL
|
||||
Java_com_kazeia_tts_NeonOps_headArgmaxBatch(
|
||||
JNIEnv* env, jclass,
|
||||
jfloatArray jHidden, jfloatArray jAllHeads,
|
||||
jint numHeads, jint vocab, jint dim)
|
||||
{
|
||||
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
|
||||
jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr);
|
||||
|
||||
jintArray jResult = env->NewIntArray(numHeads);
|
||||
jint* result = env->GetIntArrayElements(jResult, nullptr);
|
||||
|
||||
for (int h = 0; h < numHeads; h++) {
|
||||
const float* W = allHeads + (long)h * vocab * dim;
|
||||
int best = 0;
|
||||
float bestVal = -FLT_MAX;
|
||||
for (int j = 0; j < vocab; j++) {
|
||||
float dot = dot_neon(hidden, W + j * dim, dim);
|
||||
if (dot > bestVal) {
|
||||
bestVal = dot;
|
||||
best = j;
|
||||
}
|
||||
}
|
||||
result[h] = best;
|
||||
}
|
||||
|
||||
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
|
||||
env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT);
|
||||
env->ReleaseIntArrayElements(jResult, result, 0);
|
||||
return jResult;
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
Loading…
Reference in New Issue