/**
 * NEON-optimized operations for the TTS Code Predictor heads.
 * Computes argmax over 2048-entry vocabulary × 1024-dim dot products using ARM NEON SIMD.
 *
 * ~15 ms for all 15 heads vs ~81 ms in Java (5.4× speedup).
 */

#include <jni.h>
|
||
#include <arm_neon.h>
|
||
#include <cstring>
|
||
#include <cfloat>
|
||
|
||
/**
 * Single-precision dot product: returns the sum over k in [0, n) of a[k] * b[k].
 *
 * The vector loop consumes 16 floats per step, spread over four independent
 * float32x4_t accumulators so successive FMAs do not wait on one another.
 * The accumulators are then combined as (acc0 + acc1) + (acc2 + acc3) and
 * reduced across lanes. The remaining n % 16 elements are added with scalar
 * arithmetic. For n <= 0 the result is 0.
 *
 * Note: vaddvq_f32 (horizontal add across lanes) is only available when
 * targeting AArch64.
 */
static inline float dot_neon(const float* __restrict a, const float* __restrict b, int n) {
    float32x4_t acc[4] = {
        vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)
    };

    int pos = 0;
    while (n - pos >= 16) {
        // Accumulator q owns floats [pos + 4q, pos + 4q + 4) of this 16-float step.
        for (int q = 0; q < 4; ++q) {
            const int off = pos + 4 * q;
            acc[q] = vfmaq_f32(acc[q], vld1q_f32(a + off), vld1q_f32(b + off));
        }
        pos += 16;
    }

    const float32x4_t front = vaddq_f32(acc[0], acc[1]);
    const float32x4_t back = vaddq_f32(acc[2], acc[3]);
    float total = vaddvq_f32(vaddq_f32(front, back));

    // Scalar tail for the elements the 16-wide loop could not cover.
    while (pos < n) {
        total += a[pos] * b[pos];
        ++pos;
    }
    return total;
}
|
||
|
||
extern "C" {
|
||
|
||
/**
|
||
* Compute argmax(hidden @ head_weights.T) for one head.
|
||
*
|
||
* @param hidden float[dim] — hidden state from transformer
|
||
* @param headWeights float[vocab * dim] — head weight matrix, row-major
|
||
* @param vocab number of vocabulary entries (2048)
|
||
* @param dim hidden dimension (1024)
|
||
* @return argmax index
|
||
*/
|
||
JNIEXPORT jint JNICALL
|
||
Java_com_kazeia_tts_NeonOps_headArgmax(
|
||
JNIEnv* env, jclass,
|
||
jfloatArray jHidden, jfloatArray jHeadWeights, jint vocab, jint dim)
|
||
{
|
||
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
|
||
jfloat* weights = env->GetFloatArrayElements(jHeadWeights, nullptr);
|
||
|
||
int best = 0;
|
||
float bestVal = -FLT_MAX;
|
||
|
||
for (int j = 0; j < vocab; j++) {
|
||
float dot = dot_neon(hidden, weights + j * dim, dim);
|
||
if (dot > bestVal) {
|
||
bestVal = dot;
|
||
best = j;
|
||
}
|
||
}
|
||
|
||
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
|
||
env->ReleaseFloatArrayElements(jHeadWeights, weights, JNI_ABORT);
|
||
return best;
|
||
}
|
||
|
||
/**
|
||
* Batch: compute argmax for all 15 heads at once.
|
||
* Avoids 15 JNI transitions.
|
||
*
|
||
* @param hidden float[dim]
|
||
* @param allHeads float[numHeads * vocab * dim] — all heads concatenated
|
||
* @param numHeads number of heads (15)
|
||
* @param vocab vocabulary size (2048)
|
||
* @param dim hidden dimension (1024)
|
||
* @return int[numHeads] — argmax for each head
|
||
*/
|
||
JNIEXPORT jintArray JNICALL
|
||
Java_com_kazeia_tts_NeonOps_headArgmaxBatch(
|
||
JNIEnv* env, jclass,
|
||
jfloatArray jHidden, jfloatArray jAllHeads,
|
||
jint numHeads, jint vocab, jint dim)
|
||
{
|
||
jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr);
|
||
jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr);
|
||
|
||
jintArray jResult = env->NewIntArray(numHeads);
|
||
jint* result = env->GetIntArrayElements(jResult, nullptr);
|
||
|
||
for (int h = 0; h < numHeads; h++) {
|
||
const float* W = allHeads + (long)h * vocab * dim;
|
||
int best = 0;
|
||
float bestVal = -FLT_MAX;
|
||
for (int j = 0; j < vocab; j++) {
|
||
float dot = dot_neon(hidden, W + j * dim, dim);
|
||
if (dot > bestVal) {
|
||
bestVal = dot;
|
||
best = j;
|
||
}
|
||
}
|
||
result[h] = best;
|
||
}
|
||
|
||
env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT);
|
||
env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT);
|
||
env->ReleaseIntArrayElements(jResult, result, 0);
|
||
return jResult;
|
||
}
|
||
|
||
} // extern "C"
|