// Source path: kazeia/kazeia-android/app/src/main/jni/neon_ops.cpp
// Note: comments below intentionally use non-ASCII characters (×, —).

/**
* NEON-optimized operations for TTS Code Predictor heads.
* Argmax over 2048 vocab × 1024 dim dot products using ARM NEON SIMD.
*
* ~15ms per 15 heads vs ~81ms in Java (5.4× speedup).
*/
#include <jni.h>
#include <arm_neon.h>
#include <cstring>
#include <cfloat>
/**
 * Float32 dot product of a[0..n) and b[0..n) using NEON fused multiply-add.
 *
 * The main loop consumes 16 floats per step, spread over four independent
 * 4-lane accumulators so consecutive FMAs do not wait on each other. The four
 * accumulators are combined as (acc0 + acc1) + (acc2 + acc3), reduced
 * horizontally, and then the n % 16 leftover elements are added one at a time.
 * Because of this reassociation, the result may differ in the last bits from a
 * plain sequential scalar sum.
 *
 * NOTE(review): vaddvq_f32 is an A64 intrinsic, so this presumably only
 * builds for arm64-v8a. Confirm the ABI filters in the Gradle config.
 *
 * @param a  first vector, at least n floats
 * @param b  second vector, at least n floats
 * @param n  element count; n <= 0 yields 0.0f
 */
static inline float dot_neon(const float* __restrict a, const float* __restrict b, int n) {
    const float* pa = a;
    const float* pb = b;
    float32x4_t acc0 = vdupq_n_f32(0.0f), acc1 = acc0, acc2 = acc0, acc3 = acc0;

    int left = n;
    while (left >= 16) {
        acc0 = vfmaq_f32(acc0, vld1q_f32(pa),      vld1q_f32(pb));
        acc1 = vfmaq_f32(acc1, vld1q_f32(pa + 4),  vld1q_f32(pb + 4));
        acc2 = vfmaq_f32(acc2, vld1q_f32(pa + 8),  vld1q_f32(pb + 8));
        acc3 = vfmaq_f32(acc3, vld1q_f32(pa + 12), vld1q_f32(pb + 12));
        pa += 16;
        pb += 16;
        left -= 16;
    }

    const float32x4_t low = vaddq_f32(acc0, acc1);
    const float32x4_t high = vaddq_f32(acc2, acc3);
    float total = vaddvq_f32(vaddq_f32(low, high));

    // Scalar tail: the final n % 16 elements.
    for (; left > 0; --left) total += *pa++ * *pb++;
    return total;
}
extern "C" {
/**
* Compute argmax(hidden @ head_weights.T) for one head.
*
* @param hidden float[dim] — hidden state from transformer
* @param headWeights float[vocab * dim] — head weight matrix, row-major
* @param vocab number of vocabulary entries (2048)
* @param dim hidden dimension (1024)
* @return argmax index
*/
JNIEXPORT jint JNICALL
Java_com_kazeia_tts_NeonOps_headArgmax(
JNIEnv* env, jclass,
jfloatArray jHidden, jfloatArray jHeadWeights, jint vocab, jint dim)
{
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
jfloat* weights = env->GetFloatArrayElements(jHeadWeights, nullptr);
int best = 0;
float bestVal = -FLT_MAX;
for (int j = 0; j < vocab; j++) {
float dot = dot_neon(hidden, weights + j * dim, dim);
if (dot > bestVal) {
bestVal = dot;
best = j;
}
}
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
env->ReleaseFloatArrayElements(jHeadWeights, weights, JNI_ABORT);
return best;
}
/**
* Batch: compute argmax for all 15 heads at once.
* Avoids 15 JNI transitions.
*
* @param hidden float[dim]
* @param allHeads float[numHeads * vocab * dim] — all heads concatenated
* @param numHeads number of heads (15)
* @param vocab vocabulary size (2048)
* @param dim hidden dimension (1024)
* @return int[numHeads] — argmax for each head
*/
JNIEXPORT jintArray JNICALL
Java_com_kazeia_tts_NeonOps_headArgmaxBatch(
JNIEnv* env, jclass,
jfloatArray jHidden, jfloatArray jAllHeads,
jint numHeads, jint vocab, jint dim)
{
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr);
jintArray jResult = env->NewIntArray(numHeads);
jint* result = env->GetIntArrayElements(jResult, nullptr);
for (int h = 0; h < numHeads; h++) {
const float* W = allHeads + (long)h * vocab * dim;
int best = 0;
float bestVal = -FLT_MAX;
for (int j = 0; j < vocab; j++) {
float dot = dot_neon(hidden, W + j * dim, dim);
if (dot > bestVal) {
bestVal = dot;
best = j;
}
}
result[h] = best;
}
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT);
env->ReleaseIntArrayElements(jResult, result, 0);
return jResult;
}
} // extern "C"