123 lines
3.7 KiB
C++
123 lines
3.7 KiB
C++
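/**
 * CP ET Test Client — reads a batch input file, streams each frame to the
 * cp_et_runner Unix-domain socket, and collects the returned output codes.
 * Runs ON DEVICE as root to avoid adb forward issues.
 *
 * Usage: cp_et_test_client --input=/path/input.bin --output=/path/output.bin
 *                          [--sock_path=/data/local/tmp/kazeia/cp_et.sock]
 *   (--sock_path is optional; the path shown is the default.)
 *
 * Input format:  int32 n_frames, then per frame: float32[1024] hidden + float32[1024] cb0_emb
 * Output format: int32 n_frames, then per frame: int32[15] codes + float32 timing_ms
 * All values are raw binary in host byte order.
 *
 * Wire protocol, per frame: send 8192 bytes (the frame's 2048 floats), then
 * receive 64 bytes (15 int32 codes followed by one float32 timing in ms).
 */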
|
|
#include <cerrno>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
|
|
|
|
/**
 * Read exactly n bytes from fd into buf, looping over short reads.
 *
 * Returns true once all n bytes have arrived (immediately when n == 0).
 * Returns false on EOF before n bytes, or on any read error other than EINTR;
 * in that case buf holds whatever partial data was received.
 */
static bool read_exact(int fd, void* buf, size_t n) {
    char* p = (char*)buf;
    size_t done = 0;
    while (done < n) {
        ssize_t r = read(fd, p + done, n - done);
        if (r < 0 && errno == EINTR) continue;  // interrupted by a signal: retry
        if (r <= 0) return false;               // real error, or peer closed early
        done += (size_t)r;
    }
    return true;
}
|
|
|
|
/**
 * Write exactly n bytes from buf to fd, looping over short writes.
 *
 * Returns true once all n bytes are written (immediately when n == 0).
 * Returns false on any write error other than EINTR, or if write() makes no
 * progress; some prefix of buf may already have been written in that case.
 */
static bool write_exact(int fd, const void* buf, size_t n) {
    const char* p = (const char*)buf;
    size_t done = 0;
    while (done < n) {
        ssize_t r = write(fd, p + done, n - done);
        if (r < 0 && errno == EINTR) continue;  // interrupted by a signal: retry
        if (r <= 0) return false;               // real error (e.g. EPIPE) or no progress
        done += (size_t)r;
    }
    return true;
}
|
|
|
|
int main(int argc, char** argv) {
|
|
const char* input_path = nullptr;
|
|
const char* output_path = nullptr;
|
|
const char* sock_path = "/data/local/tmp/kazeia/cp_et.sock";
|
|
|
|
for (int i = 1; i < argc; i++) {
|
|
if (strncmp(argv[i], "--input=", 8) == 0) input_path = argv[i] + 8;
|
|
else if (strncmp(argv[i], "--output=", 9) == 0) output_path = argv[i] + 9;
|
|
else if (strncmp(argv[i], "--sock_path=", 12) == 0) sock_path = argv[i] + 12;
|
|
}
|
|
|
|
if (!input_path || !output_path) {
|
|
fprintf(stderr, "Usage: %s --input=in.bin --output=out.bin [--sock_path=...]\n", argv[0]);
|
|
return 1;
|
|
}
|
|
|
|
// Read input file
|
|
FILE* fin = fopen(input_path, "rb");
|
|
if (!fin) { fprintf(stderr, "Cannot open %s\n", input_path); return 1; }
|
|
|
|
int32_t n_frames;
|
|
fread(&n_frames, 4, 1, fin);
|
|
fprintf(stderr, "Frames: %d\n", n_frames);
|
|
|
|
const int N_EMBD = 1024;
|
|
float* inputs = (float*)malloc(n_frames * 2 * N_EMBD * sizeof(float));
|
|
fread(inputs, sizeof(float), n_frames * 2 * N_EMBD, fin);
|
|
fclose(fin);
|
|
|
|
// Connect to socket
|
|
int sock = socket(AF_UNIX, SOCK_STREAM, 0);
|
|
if (sock < 0) { perror("socket"); return 1; }
|
|
|
|
struct sockaddr_un addr = {};
|
|
addr.sun_family = AF_UNIX;
|
|
strncpy(addr.sun_path, sock_path, sizeof(addr.sun_path) - 1);
|
|
|
|
if (connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
|
|
perror("connect");
|
|
return 1;
|
|
}
|
|
fprintf(stderr, "Connected to %s\n", sock_path);
|
|
|
|
// Process frames
|
|
FILE* fout = fopen(output_path, "wb");
|
|
fwrite(&n_frames, 4, 1, fout);
|
|
|
|
float total_ms = 0;
|
|
for (int i = 0; i < n_frames; i++) {
|
|
float* frame = inputs + i * 2 * N_EMBD;
|
|
|
|
// Send 8192 bytes
|
|
if (!write_exact(sock, frame, 2 * N_EMBD * sizeof(float))) {
|
|
fprintf(stderr, "Write failed at frame %d\n", i);
|
|
break;
|
|
}
|
|
|
|
// Read 64 bytes: 15 ints + 1 float
|
|
int32_t codes[15];
|
|
float timing;
|
|
if (!read_exact(sock, codes, sizeof(codes))) {
|
|
fprintf(stderr, "Read codes failed at frame %d\n", i);
|
|
break;
|
|
}
|
|
if (!read_exact(sock, &timing, sizeof(timing))) {
|
|
fprintf(stderr, "Read timing failed at frame %d\n", i);
|
|
break;
|
|
}
|
|
|
|
fwrite(codes, sizeof(int32_t), 15, fout);
|
|
fwrite(&timing, sizeof(float), 1, fout);
|
|
total_ms += timing;
|
|
|
|
fprintf(stderr, " Frame %d: %.1fms codes=[%d,%d,%d,...]\n",
|
|
i, timing, codes[0], codes[1], codes[2]);
|
|
}
|
|
|
|
fclose(fout);
|
|
close(sock);
|
|
free(inputs);
|
|
|
|
fprintf(stderr, "Done! Total: %.0fms (%.1fms/frame)\n", total_ms, total_ms / n_frames);
|
|
return 0;
|
|
}
|