/** * NEON-optimized operations for TTS Code Predictor heads. * Argmax over 2048 vocab × 1024 dim dot products using ARM NEON SIMD. * * ~15ms per 15 heads vs ~81ms in Java (5.4× speedup). */ #include #include #include #include /** * Dot product of two float32 vectors using NEON FMA. * Processes 16 floats per iteration (4 accumulators × 4 lanes). */ static inline float dot_neon(const float* __restrict a, const float* __restrict b, int n) { float32x4_t sum0 = vdupq_n_f32(0.0f); float32x4_t sum1 = vdupq_n_f32(0.0f); float32x4_t sum2 = vdupq_n_f32(0.0f); float32x4_t sum3 = vdupq_n_f32(0.0f); int i = 0; for (; i + 15 < n; i += 16) { sum0 = vfmaq_f32(sum0, vld1q_f32(a + i), vld1q_f32(b + i)); sum1 = vfmaq_f32(sum1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4)); sum2 = vfmaq_f32(sum2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8)); sum3 = vfmaq_f32(sum3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12)); } sum0 = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); float result = vaddvq_f32(sum0); // Handle remainder for (; i < n; i++) result += a[i] * b[i]; return result; } extern "C" { /** * Compute argmax(hidden @ head_weights.T) for one head. * * @param hidden float[dim] — hidden state from transformer * @param headWeights float[vocab * dim] — head weight matrix, row-major * @param vocab number of vocabulary entries (2048) * @param dim hidden dimension (1024) * @return argmax index */ JNIEXPORT jint JNICALL Java_com_kazeia_tts_NeonOps_headArgmax( JNIEnv* env, jclass, jfloatArray jHidden, jfloatArray jHeadWeights, jint vocab, jint dim) { jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr); jfloat* weights = env->GetFloatArrayElements(jHeadWeights, nullptr); int best = 0; float bestVal = -FLT_MAX; for (int j = 0; j < vocab; j++) { float dot = dot_neon(hidden, weights + j * dim, dim); if (dot > bestVal) { bestVal = dot; best = j; } } env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT); env->ReleaseFloatArrayElements(jHeadWeights, weights, JNI_ABORT); return best; } /** * Batch: compute argmax for all 15 heads at once. * Avoids 15 JNI transitions. * * @param hidden float[dim] * @param allHeads float[numHeads * vocab * dim] — all heads concatenated * @param numHeads number of heads (15) * @param vocab vocabulary size (2048) * @param dim hidden dimension (1024) * @return int[numHeads] — argmax for each head */ JNIEXPORT jintArray JNICALL Java_com_kazeia_tts_NeonOps_headArgmaxBatch( JNIEnv* env, jclass, jfloatArray jHidden, jfloatArray jAllHeads, jint numHeads, jint vocab, jint dim) { jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr); jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr); jintArray jResult = env->NewIntArray(numHeads); jint* result = env->GetIntArrayElements(jResult, nullptr); for (int h = 0; h < numHeads; h++) { const float* W = allHeads + (long)h * vocab * dim; int best = 0; float bestVal = -FLT_MAX; for (int j = 0; j < vocab; j++) { float dot = dot_neon(hidden, W + j * dim, dim); if (dot > bestVal) { bestVal = dot; best = j; } } result[h] = best; } env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT); env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT); env->ReleaseIntArrayElements(jResult, result, 0); return jResult; } /** * Argmax with offset into weight buffer (avoids Java array copy). */ JNIEXPORT jint JNICALL Java_com_kazeia_tts_NeonOps_headArgmaxOffset( JNIEnv* env, jclass, jfloatArray jHidden, jfloatArray jAllHeads, jint offset, jint vocab, jint dim) { jfloat* hidden = env->GetFloatArrayElements(jHidden, nullptr); jfloat* allHeads = env->GetFloatArrayElements(jAllHeads, nullptr); const float* W = allHeads + offset; int best = 0; float bestVal = -FLT_MAX; for (int j = 0; j < vocab; j++) { float dot = dot_neon(hidden, W + j * dim, dim); if (dot > bestVal) { bestVal = dot; best = j; } } env->ReleaseFloatArrayElements(jHidden, hidden, JNI_ABORT); env->ReleaseFloatArrayElements(jAllHeads, allHeads, JNI_ABORT); return best; } } // extern "C"