// kazeia/executorch-custom/cp_et_test_client.cpp
//
// 123 lines
// 3.7 KiB
// C++

/**
* CP ET Test Client — reads batch input file, sends to cp_et_runner socket,
* collects output codes. Runs ON DEVICE as root to avoid adb forward issues.
*
* Usage: cp_et_test_client --input=/path/input.bin --output=/path/output.bin
* --sock_path=/data/local/tmp/kazeia/cp_et.sock
*
* Input format: int32 n_frames, then per frame: float32[1024] hidden + float32[1024] cb0_emb
* Output format: int32 n_frames, then per frame: int32[15] codes + float32 timing_ms
*/
#include <cerrno>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
/**
 * Reads exactly n bytes from fd into buf, looping over short reads.
 *
 * A read interrupted by a signal before transferring data (-1 / EINTR) is
 * retried rather than treated as a failure.
 *
 * @return true once all n bytes have been read (trivially true for n == 0);
 *         false on end-of-file or any other read error, in which case buf may
 *         be partially filled.
 */
static bool read_exact(int fd, void* buf, size_t n) {
    char* p = (char*)buf;
    size_t done = 0;
    while (done < n) {
        ssize_t r = read(fd, p + done, n - done);
        if (r < 0 && errno == EINTR) continue;  // interrupted by a signal: retry
        if (r <= 0) return false;               // 0 = peer closed / EOF, <0 = error
        done += (size_t)r;
    }
    return true;
}
/**
 * Writes exactly n bytes from buf to fd, looping over short writes.
 *
 * A write interrupted by a signal before transferring data (-1 / EINTR) is
 * retried rather than treated as a failure.
 *
 * @return true once all n bytes have been written (trivially true for n == 0);
 *         false on any write error or if write() reports zero bytes written.
 */
static bool write_exact(int fd, const void* buf, size_t n) {
    const char* p = (const char*)buf;
    size_t done = 0;
    while (done < n) {
        ssize_t w = write(fd, p + done, n - done);
        if (w < 0 && errno == EINTR) continue;  // interrupted by a signal: retry
        if (w <= 0) return false;               // error, or no progress possible
        done += (size_t)w;
    }
    return true;
}
/**
 * Loads the batch input file: an int32 frame count followed by
 * n_frames * floats_per_frame float32 values.
 *
 * @param path              input file path
 * @param floats_per_frame  float32 values per frame; must be non-zero
 * @param n_frames_out      receives the frame count read from the file
 * @param inputs_out        receives a malloc'd buffer holding all frames
 *                          (caller frees; nullptr when the count is 0)
 * @return true on success. On failure prints a diagnostic, allocates
 *         nothing and returns false.
 */
static bool load_input(const char* path, size_t floats_per_frame,
                       int32_t* n_frames_out, float** inputs_out) {
    FILE* fin = fopen(path, "rb");
    if (!fin) {
        fprintf(stderr, "Cannot open %s\n", path);
        return false;
    }
    int32_t n_frames = 0;
    if (fread(&n_frames, sizeof(n_frames), 1, fin) != 1) {
        fprintf(stderr, "Cannot read frame count from %s\n", path);
        fclose(fin);
        return false;
    }
    fprintf(stderr, "Frames: %d\n", n_frames);
    // Reject negative counts and counts whose byte size cannot be
    // represented in size_t, so the allocation size cannot overflow.
    if (n_frames < 0 ||
        (size_t)n_frames > SIZE_MAX / sizeof(float) / floats_per_frame) {
        fprintf(stderr, "Invalid frame count %d in %s\n", n_frames, path);
        fclose(fin);
        return false;
    }
    const size_t n_floats = (size_t)n_frames * floats_per_frame;
    float* inputs = nullptr;
    if (n_floats > 0) {
        inputs = (float*)malloc(n_floats * sizeof(float));
        if (!inputs) {
            fprintf(stderr, "Out of memory for %d frames\n", n_frames);
            fclose(fin);
            return false;
        }
        // A short file would otherwise send uninitialized memory to the runner.
        if (fread(inputs, sizeof(float), n_floats, fin) != n_floats) {
            fprintf(stderr, "Truncated input %s: expected %d frames\n", path, n_frames);
            free(inputs);
            fclose(fin);
            return false;
        }
    }
    fclose(fin);
    *n_frames_out = n_frames;
    *inputs_out = inputs;
    return true;
}

/**
 * Opens a SOCK_STREAM Unix-domain socket connected to sock_path.
 *
 * @return the connected socket fd, or -1 after printing a diagnostic
 *         (no fd is leaked on failure).
 */
static int connect_unix(const char* sock_path) {
    struct sockaddr_un addr = {};
    const size_t len = strlen(sock_path);
    // strncpy would silently truncate an over-long path and connect elsewhere.
    if (len >= sizeof(addr.sun_path)) {
        fprintf(stderr, "Socket path too long: %s\n", sock_path);
        return -1;
    }
    addr.sun_family = AF_UNIX;
    memcpy(addr.sun_path, sock_path, len + 1);
    int sock = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sock < 0) {
        perror("socket");
        return -1;
    }
    if (connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
        perror("connect");
        close(sock);
        return -1;
    }
    return sock;
}

/**
 * Entry point: streams every input frame to the runner socket and records
 * the returned codes and timing in the output file (formats in file header).
 *
 * The output header always equals the number of frames actually written. If
 * the exchange stops early, the header is rewritten and the exit status is 1.
 *
 * @return 0 if all frames were processed and written, 1 otherwise.
 */
int main(int argc, char** argv) {
    const char* input_path = nullptr;
    const char* output_path = nullptr;
    const char* sock_path = "/data/local/tmp/kazeia/cp_et.sock";
    for (int i = 1; i < argc; i++) {
        if (strncmp(argv[i], "--input=", 8) == 0) input_path = argv[i] + 8;
        else if (strncmp(argv[i], "--output=", 9) == 0) output_path = argv[i] + 9;
        else if (strncmp(argv[i], "--sock_path=", 12) == 0) sock_path = argv[i] + 12;
    }
    if (!input_path || !output_path) {
        fprintf(stderr, "Usage: %s --input=in.bin --output=out.bin [--sock_path=...]\n", argv[0]);
        return 1;
    }

    // Per frame: float32 hidden[N_EMBD] followed by float32 cb0_emb[N_EMBD].
    const size_t N_EMBD = 1024;
    const size_t FLOATS_PER_FRAME = 2 * N_EMBD;
    // Per-frame response: 15 int32 codes followed by one float32 timing.
    const size_t N_CODES = 15;

    // If the runner goes away, write() must fail with EPIPE instead of the
    // default SIGPIPE terminating this process before the output is closed.
    signal(SIGPIPE, SIG_IGN);

    int32_t n_frames = 0;
    float* inputs = nullptr;
    if (!load_input(input_path, FLOATS_PER_FRAME, &n_frames, &inputs)) return 1;

    int sock = connect_unix(sock_path);
    if (sock < 0) {
        free(inputs);
        return 1;
    }
    fprintf(stderr, "Connected to %s\n", sock_path);

    FILE* fout = fopen(output_path, "wb");
    if (!fout) {
        fprintf(stderr, "Cannot open %s\n", output_path);
        close(sock);
        free(inputs);
        return 1;
    }
    // Provisional header; corrected below if fewer frames complete.
    fwrite(&n_frames, sizeof(n_frames), 1, fout);

    float total_ms = 0;
    int32_t done = 0;
    for (int32_t i = 0; i < n_frames; i++) {
        // size_t index math: i * FLOATS_PER_FRAME can exceed INT_MAX.
        const float* frame = inputs + (size_t)i * FLOATS_PER_FRAME;
        // Send 8192 bytes
        if (!write_exact(sock, frame, FLOATS_PER_FRAME * sizeof(float))) {
            fprintf(stderr, "Write failed at frame %d\n", i);
            break;
        }
        // Read 64 bytes: 15 ints + 1 float
        int32_t codes[N_CODES];
        float timing;
        if (!read_exact(sock, codes, sizeof(codes))) {
            fprintf(stderr, "Read codes failed at frame %d\n", i);
            break;
        }
        if (!read_exact(sock, &timing, sizeof(timing))) {
            fprintf(stderr, "Read timing failed at frame %d\n", i);
            break;
        }
        fwrite(codes, sizeof(int32_t), N_CODES, fout);
        fwrite(&timing, sizeof(float), 1, fout);
        total_ms += timing;
        done++;
        fprintf(stderr, " Frame %d: %.1fms codes=[%d,%d,%d,...]\n",
                i, timing, codes[0], codes[1], codes[2]);
    }
    close(sock);
    free(inputs);

    bool ok = true;
    if (done != n_frames) {
        ok = false;
        fprintf(stderr, "Processed %d/%d frames\n", done, n_frames);
        // Keep the output self-consistent: the header must match the frames written.
        if (fseek(fout, 0, SEEK_SET) != 0 ||
            fwrite(&done, sizeof(done), 1, fout) != 1) {
            fprintf(stderr, "Cannot rewrite frame count in %s\n", output_path);
        }
    }
    if (ferror(fout)) {
        fprintf(stderr, "Write error on %s\n", output_path);
        ok = false;
    }
    if (fclose(fout) != 0) {
        fprintf(stderr, "Cannot close %s\n", output_path);
        ok = false;
    }
    // Average over frames actually processed; avoids NaN when none were.
    fprintf(stderr, "Done! Total: %.0fms (%.1fms/frame)\n",
            total_ms, done > 0 ? total_ms / done : 0.0f);
    return ok ? 0 : 1;
}