feat: camera ground-truth training pipeline (ADR-079, #362)

Add 4 scripts for camera-supervised WiFlow pose training:

- collect-ground-truth.py: synchronized webcam + CSI capture via
  MediaPipe PoseLandmarker (17 COCO keypoints at 30fps)
- align-ground-truth.js: time-align camera keypoints with CSI windows
  using binary search, confidence-weighted averaging
- train-wiflow-supervised.js: 3-phase supervised training (contrastive
  pretrain → supervised keypoint regression → bone-constrained
  refinement) with curriculum learning and CSI augmentation
- eval-wiflow.js: PCK@10/20/50, MPJPE, per-joint breakdown, baseline
  proxy mode for benchmarking

Baseline benchmark (proxy poses, no camera supervision):
  PCK@10: 11.8% | PCK@20: 35.3% | PCK@50: 94.1% | MPJPE: 0.067

Camera pipeline validated over Tailscale to Mac Mini M4 Pro
(1920x1080, 14/17 keypoints visible, MediaPipe confidence 0.94-1.0).

Target after camera-supervised training: PCK@20 > 50%

Closes #362

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv
2026-04-06 14:07:25 -04:00
parent b5e924cd72
commit e3522ddcda
5 changed files with 3176 additions and 0 deletions
+477
View File
@@ -0,0 +1,477 @@
#!/usr/bin/env node
/**
* Ground-Truth Alignment — Camera Keypoints <-> CSI Recording
*
* Time-aligns camera keypoint data with CSI recording data to produce
* paired training samples for WiFlow supervised training (ADR-079).
*
* Camera keypoints: data/ground-truth/gt-{timestamp}.jsonl
* CSI recordings: data/recordings/*.csi.jsonl
* Paired output: data/paired/*.paired.jsonl
*
* Usage:
* node scripts/align-ground-truth.js \
* --gt data/ground-truth/gt-1775300000.jsonl \
* --csi data/recordings/overnight-1775217646.csi.jsonl \
* --output data/paired/aligned.paired.jsonl
*
* # With clock offset correction (camera ahead by 50ms)
* node scripts/align-ground-truth.js \
* --gt data/ground-truth/gt-1775300000.jsonl \
* --csi data/recordings/overnight-1775217646.csi.jsonl \
* --clock-offset-ms -50
*
* ADR: docs/adr/ADR-079
*/
'use strict';
const fs = require('fs');
const path = require('path');
const { parseArgs } = require('util');
// ---------------------------------------------------------------------------
// CLI argument parsing
// ---------------------------------------------------------------------------
const { values: args } = parseArgs({
  options: {
    gt: { type: 'string' },
    csi: { type: 'string' },
    output: { type: 'string', short: 'o' },
    'window-ms': { type: 'string', default: '200' },
    'window-frames': { type: 'string', default: '20' },
    'min-camera-frames': { type: 'string', default: '3' },
    'min-confidence': { type: 'string', default: '0.5' },
    'clock-offset-ms': { type: 'string', default: '0' },
    help: { type: 'boolean', short: 'h', default: false },
  },
  strict: true,
});

// Both input files are mandatory; a missing one is a usage error (exit 1),
// whereas an explicit --help is a successful run (exit 0).
if (args.help || !args.gt || !args.csi) {
  console.log(`
Usage: node scripts/align-ground-truth.js --gt <gt.jsonl> --csi <csi.jsonl> [options]
Required:
--gt <path> Camera ground-truth JSONL file
--csi <path> CSI recording JSONL file
Options:
--output, -o <path> Output paired JSONL (default: data/paired/<basename>.paired.jsonl)
--window-ms <ms> CSI window size in ms (default: 200)
--window-frames <n> Frames per CSI window (default: 20)
--min-camera-frames <n> Minimum camera frames per window (default: 3)
--min-confidence <f> Minimum average confidence threshold (default: 0.5)
--clock-offset-ms <ms> Manual clock offset: added to camera timestamps (default: 0)
--help, -h Show this help
`);
  process.exit(args.help ? 0 : 1);
}

/**
 * Read a numeric CLI option and validate it, exiting with status 1 on bad input.
 *
 * Unvalidated values are dangerous here: a window size of 0 makes the
 * windowing loop spin forever, and NaN silently produces zero windows.
 *
 * @param {string} name - Option name as declared in parseArgs (without `--`).
 * @param {{integer?: boolean, min?: number}} [rules] - Require an integer and/or a lower bound.
 * @returns {number} The parsed value.
 */
function parseNumericOption(name, { integer = false, min = -Infinity } = {}) {
  const raw = args[name];
  const value = Number(raw);
  const isValid =
    raw.trim() !== '' &&
    Number.isFinite(value) &&
    (!integer || Number.isInteger(value)) &&
    value >= min;
  if (!isValid) {
    const kind = integer ? 'an integer' : 'a finite number';
    const bound = Number.isFinite(min) ? ` >= ${min}` : '';
    console.error(`ERROR: --${name} must be ${kind}${bound} (got "${raw}")`);
    process.exit(1);
  }
  return value;
}

const WINDOW_FRAMES = parseNumericOption('window-frames', { integer: true, min: 1 });
const WINDOW_MS = parseNumericOption('window-ms', { integer: true, min: 0 });
// At least one camera frame is required, otherwise the confidence average is 0/0.
const MIN_CAMERA_FRAMES = parseNumericOption('min-camera-frames', { integer: true, min: 1 });
const MIN_CONFIDENCE = parseNumericOption('min-confidence');
const CLOCK_OFFSET_MS = parseNumericOption('clock-offset-ms');
const NUM_KEYPOINTS = 17; // COCO 17-keypoint format
// ---------------------------------------------------------------------------
// Timestamp conversion
// ---------------------------------------------------------------------------
/**
 * Translate a camera timestamp from nanoseconds to milliseconds and shift it
 * by the configured clock offset (CLOCK_OFFSET_MS).
 *
 * @param {number} tsNs - Camera timestamp in nanoseconds.
 * @returns {number} Offset-corrected timestamp in milliseconds.
 */
function cameraTsToMs(tsNs) {
  const NS_PER_MS = 1e6;
  const uncorrectedMs = tsNs / NS_PER_MS;
  return uncorrectedMs + CLOCK_OFFSET_MS;
}
/**
 * Parse an ISO 8601 timestamp into epoch milliseconds.
 *
 * @param {string} isoStr - Timestamp string understood by the Date constructor.
 * @returns {number} Milliseconds since the Unix epoch, or NaN when unparseable.
 */
function isoToMs(isoStr) {
  const parsed = new Date(isoStr);
  return parsed.getTime();
}
// ---------------------------------------------------------------------------
// IQ hex parsing (matches train-wiflow.js conventions)
// ---------------------------------------------------------------------------
/**
 * Decode a hex string into signed bytes laid out as [I0, Q0, I1, Q1, ...].
 *
 * Each pair of hex digits is read as an unsigned byte and reinterpreted as a
 * two's-complement signed value (0x80..0xFF -> -128..-1). A trailing single
 * digit is parsed on its own, and non-hex digits yield NaN entries.
 *
 * @param {string} hexStr - Hex-encoded IQ payload.
 * @returns {number[]} Signed byte values, one per hex digit pair.
 */
function parseIqHex(hexStr) {
  const bytes = [];
  for (let offset = 0; offset < hexStr.length; offset += 2) {
    // slice() replaces the deprecated substr(); for these offsets both yield the same text.
    const unsigned = Number.parseInt(hexStr.slice(offset, offset + 2), 16);
    bytes.push(unsigned > 127 ? unsigned - 256 : unsigned);
  }
  return bytes;
}
/**
 * Compute per-subcarrier amplitude sqrt(I^2 + Q^2) from interleaved IQ bytes.
 *
 * The first I/Q pair is skipped (DC offset, per the WiFlow paper
 * recommendation). Subcarriers for which the input has no complete I/Q pair
 * stay at 0.
 *
 * @param {number[]} iqBytes - Interleaved [I0, Q0, I1, Q1, ...] values.
 * @param {number} nSubcarriers - Number of amplitudes to produce.
 * @returns {Float32Array} Amplitudes, length nSubcarriers.
 */
function extractAmplitude(iqBytes, nSubcarriers) {
  const amplitudes = new Float32Array(nSubcarriers);
  let cursor = 2; // index of the first I value after the skipped DC pair
  let subcarrier = 0;
  while (subcarrier < nSubcarriers && cursor + 1 < iqBytes.length) {
    const inPhase = iqBytes[cursor];
    const quadrature = iqBytes[cursor + 1];
    amplitudes[subcarrier] = Math.sqrt(inPhase * inPhase + quadrature * quadrature);
    subcarrier += 1;
    cursor += 2;
  }
  return amplitudes;
}
// ---------------------------------------------------------------------------
// File loading
// ---------------------------------------------------------------------------
/**
 * Read a JSONL file into an array of parsed records.
 *
 * Blank lines are ignored. Lines that are not valid JSON are skipped so a
 * partially written recording remains usable, but the number of skipped lines
 * is reported once on stderr instead of being dropped without a trace.
 *
 * @param {string} filePath - Path of the JSONL file.
 * @returns {any[]} Parsed records in file order.
 * @throws {Error} If the file cannot be read.
 */
function loadJsonl(filePath) {
  const records = [];
  let malformedCount = 0;
  for (const line of fs.readFileSync(filePath, 'utf8').split('\n')) {
    const trimmed = line.trim(); // also strips the '\r' of CRLF files
    if (!trimmed) continue;
    try {
      records.push(JSON.parse(trimmed));
    } catch {
      malformedCount += 1;
    }
  }
  if (malformedCount > 0) {
    console.warn(`WARNING: skipped ${malformedCount} malformed line(s) in ${filePath}`);
  }
  return records;
}
/**
 * Load the camera ground-truth JSONL file.
 *
 * Records are dropped when they have no `ts_ns`, when `keypoints` is not an
 * array, or when the converted timestamp is not a finite number. A non-finite
 * timestamp would corrupt the sort below and the binary search that relies on
 * it; this mirrors the NaN-timestamp check applied to CSI records.
 *
 * @param {string} filePath - Path of the ground-truth JSONL file.
 * @returns {{tsMs: number, keypoints: Array, confidence: number, nVisible: number, nPersons: number}[]}
 *   Frames sorted by ascending timestamp (ms, clock offset applied).
 */
function loadGroundTruth(filePath) {
  const frames = [];
  for (const record of loadJsonl(filePath)) {
    if (record.ts_ns == null || !Array.isArray(record.keypoints)) continue;
    const tsMs = cameraTsToMs(record.ts_ns);
    if (!Number.isFinite(tsMs)) continue;
    frames.push({
      tsMs,
      keypoints: record.keypoints,
      confidence: record.confidence ?? 0,
      nVisible: record.n_visible ?? 0,
      nPersons: record.n_persons ?? 1,
    });
  }
  // Sort by timestamp
  frames.sort((a, b) => a.tsMs - b.tsMs);
  return frames;
}
/**
 * Load a CSI recording and split it by record type.
 *
 * Records without a `timestamp`, or whose timestamp does not parse, are
 * skipped. `raw_csi` records default to 128 subcarriers when the field is
 * absent; records of any other type than `raw_csi` / `feature` are ignored.
 *
 * @param {string} filePath - Path of the CSI JSONL recording.
 * @returns {{rawCsi: object[], features: object[]}} Both lists sorted by ascending tsMs.
 */
function loadCsi(filePath) {
  const rawCsi = [];
  const features = [];
  for (const record of loadJsonl(filePath)) {
    if (!record.timestamp) continue;
    const tsMs = isoToMs(record.timestamp);
    if (Number.isNaN(tsMs)) continue;
    switch (record.type) {
      case 'raw_csi':
        rawCsi.push({
          tsMs,
          nodeId: record.node_id,
          subcarriers: record.subcarriers ?? 128,
          iqHex: record.iq_hex,
          rssi: record.rssi,
          seq: record.seq,
        });
        break;
      case 'feature':
        features.push({
          tsMs,
          nodeId: record.node_id,
          features: record.features,
          rssi: record.rssi,
          seq: record.seq,
        });
        break;
      default:
        break;
    }
  }
  const byTimestamp = (a, b) => a.tsMs - b.tsMs;
  rawCsi.sort(byTimestamp);
  features.sort(byTimestamp);
  return { rawCsi, features };
}
// ---------------------------------------------------------------------------
// Windowing
// ---------------------------------------------------------------------------
/**
 * Split frames into consecutive, non-overlapping windows of exactly
 * `windowSize` frames. Trailing frames that do not fill a window are dropped.
 *
 * @param {Array} frames - Frames in the order they should be windowed.
 * @param {number} windowSize - Frames per window; must be a positive integer.
 * @returns {Array[]} Windows, each a new array of `windowSize` frames.
 * @throws {RangeError} If `windowSize` is not a positive integer. A size of 0
 *   or less would otherwise never advance the loop and exhaust memory.
 */
function groupIntoWindows(frames, windowSize) {
  if (!Number.isInteger(windowSize) || windowSize < 1) {
    throw new RangeError(`windowSize must be a positive integer, got ${windowSize}`);
  }
  const windows = [];
  for (let start = 0; start + windowSize <= frames.length; start += windowSize) {
    windows.push(frames.slice(start, start + windowSize));
  }
  return windows;
}
// ---------------------------------------------------------------------------
// Camera frame matching (binary search)
// ---------------------------------------------------------------------------
/**
 * Collect the camera frames whose timestamp lies in [tStartMs, tEndMs]
 * (both ends inclusive).
 *
 * A binary search locates the first frame at or after tStartMs; frames are
 * then taken in order until one lies past tEndMs. The search is only
 * meaningful when `cameraFrames` is sorted by ascending `tsMs`.
 *
 * @param {{tsMs: number}[]} cameraFrames - Camera frames ordered by time.
 * @param {number} tStartMs - Range start in ms.
 * @param {number} tEndMs - Range end in ms.
 * @returns {object[]} New array holding the matching frames.
 */
function findCameraFramesInRange(cameraFrames, tStartMs, tEndMs) {
  // Lower bound: first index whose timestamp is not before tStartMs.
  let low = 0;
  let high = cameraFrames.length;
  while (low < high) {
    const probe = (low + high) >>> 1;
    if (cameraFrames[probe].tsMs < tStartMs) {
      low = probe + 1;
      continue;
    }
    high = probe;
  }
  // Walk forward to the first frame past the end of the range.
  let end = low;
  while (end < cameraFrames.length && !(cameraFrames[end].tsMs > tEndMs)) {
    end += 1;
  }
  return cameraFrames.slice(low, end);
}
// ---------------------------------------------------------------------------
// Keypoint averaging (confidence-weighted)
// ---------------------------------------------------------------------------
/**
 * Average keypoints across camera frames, weighting each frame by its
 * confidence (frames with zero/missing confidence get a tiny weight of 1e-6).
 *
 * Weights are accumulated per keypoint, so a frame that carries fewer than
 * NUM_KEYPOINTS keypoints (for example an empty list) no longer drags the
 * joints it does not contain toward 0. When every frame has all keypoints the
 * result equals a plain confidence-weighted mean.
 *
 * @param {{confidence?: number, keypoints: number[][]}[]} cameraFrames - Frames to combine.
 * @returns {{keypoints: number[][], avgConfidence: number}} NUM_KEYPOINTS [x, y]
 *   pairs ([0, 0] for joints no frame provided) and the unweighted mean
 *   confidence (0 for an empty input).
 */
function averageKeypoints(cameraFrames) {
  const sums = Array.from({ length: NUM_KEYPOINTS }, () => [0, 0]);
  const weights = new Array(NUM_KEYPOINTS).fill(0);
  let confidenceSum = 0;

  for (const frame of cameraFrames) {
    const weight = frame.confidence || 1e-6;
    confidenceSum += frame.confidence || 0;
    const available = Math.min(NUM_KEYPOINTS, frame.keypoints.length);
    for (let k = 0; k < available; k++) {
      sums[k][0] += frame.keypoints[k][0] * weight;
      sums[k][1] += frame.keypoints[k][1] * weight;
      weights[k] += weight;
    }
  }

  const keypoints = sums.map(([x, y], k) => (
    weights[k] !== 0 ? [x / weights[k], y / weights[k]] : [0, 0]
  ));
  // Guard the empty case: 0 / 0 would otherwise yield NaN.
  const avgConfidence = cameraFrames.length > 0 ? confidenceSum / cameraFrames.length : 0;
  return { keypoints, avgConfidence };
}
// ---------------------------------------------------------------------------
// CSI matrix extraction
// ---------------------------------------------------------------------------
/**
 * Build the CSI amplitude matrix for one window of raw_csi frames.
 *
 * The flat `data` array is row-major for the reported shape
 * [subcarriers, frames]: element (sc, f) lives at index `sc * nFrames + f`.
 * Previously the values were written frame-major while still being labelled
 * [subcarriers, frames], so consumers indexing by `s * T + t` read transposed
 * data.
 *
 * The subcarrier count comes from the first frame (128 when missing); frames
 * without an `iqHex` payload contribute zeros.
 *
 * @param {{subcarriers?: number, iqHex?: string}[]} window - Non-empty list of raw_csi frames.
 * @returns {{data: number[], shape: [number, number]}} Flat amplitudes and [nSc, nFrames].
 */
function extractCsiMatrix(window) {
  const nFrames = window.length;
  const nSc = window[0].subcarriers || 128;
  const matrix = new Float32Array(nSc * nFrames);
  for (let f = 0; f < nFrames; f++) {
    const { iqHex } = window[f];
    if (!iqHex) continue;
    const amp = extractAmplitude(parseIqHex(iqHex), nSc);
    for (let sc = 0; sc < nSc; sc++) {
      matrix[sc * nFrames + f] = amp[sc];
    }
  }
  return { data: Array.from(matrix), shape: [nSc, nFrames] };
}
/**
 * Build the feature matrix for one window of feature-type frames.
 *
 * The flat `data` array is row-major for the reported shape
 * [featureDim, frames]: element (d, f) lives at index `d * nFrames + f`.
 * Previously the values were written frame-major while still being labelled
 * [featureDim, frames].
 *
 * The feature dimension comes from the first frame (8 when it has no
 * features); missing or falsy feature values become 0.
 *
 * @param {{features?: number[]}[]} window - Non-empty list of feature frames.
 * @returns {{data: number[], shape: [number, number]}} Flat features and [dim, nFrames].
 */
function extractFeatureMatrix(window) {
  const nFrames = window.length;
  const dim = window[0].features ? window[0].features.length : 8;
  const matrix = new Float32Array(dim * nFrames);
  for (let f = 0; f < nFrames; f++) {
    const feats = window[f].features;
    if (!feats) continue; // the matrix is already zero-filled
    for (let d = 0; d < dim; d++) {
      matrix[d * nFrames + f] = feats[d] || 0;
    }
  }
  return { data: Array.from(matrix), shape: [dim, nFrames] };
}
// ---------------------------------------------------------------------------
// Main alignment
// ---------------------------------------------------------------------------
/**
 * Run the full alignment: load both inputs, window the CSI stream, pair each
 * window with the camera frames that overlap it in time, write the pairs as
 * JSONL and print a summary.
 *
 * Reads its configuration from the module-level `args` and the constants
 * derived from it. Exits the process with status 1 when the chosen CSI source
 * has fewer frames than one window.
 */
function align() {
const gtPath = path.resolve(args.gt);
const csiPath = path.resolve(args.csi);
// Determine output path
let outputPath;
if (args.output) {
outputPath = path.resolve(args.output);
} else {
// Default name: CSI file name with its '.csi.jsonl' suffix removed, under data/paired/
const baseName = path.basename(csiPath, '.csi.jsonl');
outputPath = path.resolve('data', 'paired', `${baseName}.paired.jsonl`);
}
// Ensure output directory exists
const outputDir = path.dirname(outputPath);
if (!fs.existsSync(outputDir)) {
fs.mkdirSync(outputDir, { recursive: true });
}
console.log('=== Ground-Truth Alignment (ADR-079) ===');
console.log(` GT file: ${gtPath}`);
console.log(` CSI file: ${csiPath}`);
console.log(` Output: ${outputPath}`);
console.log(` Window: ${WINDOW_FRAMES} frames / ${WINDOW_MS} ms`);
console.log(` Min camera frames: ${MIN_CAMERA_FRAMES}`);
console.log(` Min confidence: ${MIN_CONFIDENCE}`);
console.log(` Clock offset: ${CLOCK_OFFSET_MS} ms`);
console.log();
// Load data
console.log('Loading ground-truth...');
const cameraFrames = loadGroundTruth(gtPath);
console.log(` ${cameraFrames.length} camera frames loaded`);
if (cameraFrames.length > 0) {
console.log(` Time range: ${new Date(cameraFrames[0].tsMs).toISOString()} -> ${new Date(cameraFrames[cameraFrames.length - 1].tsMs).toISOString()}`);
}
console.log('Loading CSI data...');
const { rawCsi, features } = loadCsi(csiPath);
console.log(` ${rawCsi.length} raw_csi frames, ${features.length} feature frames`);
// Decide which CSI source to use
// raw_csi is preferred whenever it can fill at least one window; otherwise fall back to feature frames.
const useRawCsi = rawCsi.length >= WINDOW_FRAMES;
const csiSource = useRawCsi ? rawCsi : features;
const sourceLabel = useRawCsi ? 'raw_csi' : 'feature';
if (csiSource.length < WINDOW_FRAMES) {
console.error(`ERROR: Not enough CSI frames (${csiSource.length}) for even one window of ${WINDOW_FRAMES} frames.`);
process.exit(1);
}
console.log(` Using ${sourceLabel} frames (${csiSource.length} total)`);
if (csiSource.length > 0) {
console.log(` CSI time range: ${new Date(csiSource[0].tsMs).toISOString()} -> ${new Date(csiSource[csiSource.length - 1].tsMs).toISOString()}`);
}
console.log();
// Group CSI into windows
const windows = groupIntoWindows(csiSource, WINDOW_FRAMES);
console.log(`Grouped into ${windows.length} CSI windows`);
// Align
const paired = [];
let totalConfidence = 0;
for (const window of windows) {
const tStartMs = window[0].tsMs;
const tEndMs = window[window.length - 1].tsMs;
// Expand window if actual time span is smaller than window-ms
// The search range is centred on the window midpoint and is never narrower
// than the window's own [first, last] timestamps.
const halfWindow = WINDOW_MS / 2;
const midpoint = (tStartMs + tEndMs) / 2;
const searchStart = Math.min(tStartMs, midpoint - halfWindow);
const searchEnd = Math.max(tEndMs, midpoint + halfWindow);
// Find matching camera frames
const matched = findCameraFramesInRange(cameraFrames, searchStart, searchEnd);
if (matched.length < MIN_CAMERA_FRAMES) continue;
// Check average confidence
// Unweighted mean over the matched frames; a missing confidence counts as 0.
const avgConf = matched.reduce((s, f) => s + (f.confidence || 0), 0) / matched.length;
if (avgConf < MIN_CONFIDENCE) continue;
// Average keypoints weighted by confidence
const { keypoints, avgConfidence } = averageKeypoints(matched);
// Extract CSI matrix
const csiMatrix = useRawCsi
? extractCsiMatrix(window)
: extractFeatureMatrix(window);
// ts_start / ts_end record the CSI window's own span, not the expanded search range.
paired.push({
csi: csiMatrix.data,
csi_shape: csiMatrix.shape,
kp: keypoints,
conf: Math.round(avgConfidence * 1000) / 1000,
n_camera_frames: matched.length,
ts_start: new Date(tStartMs).toISOString(),
ts_end: new Date(tEndMs).toISOString(),
});
totalConfidence += avgConfidence;
}
// Write output
// One JSON object per line; the file is left empty (no trailing newline) when nothing was paired.
const outputLines = paired.map(s => JSON.stringify(s));
fs.writeFileSync(outputPath, outputLines.join('\n') + (outputLines.length > 0 ? '\n' : ''));
// Print summary
const alignmentRate = windows.length > 0 ? (paired.length / windows.length * 100) : 0;
const avgPairedConf = paired.length > 0 ? (totalConfidence / paired.length) : 0;
console.log();
console.log('=== Alignment Summary ===');
console.log(` Total CSI windows: ${windows.length}`);
console.log(` Paired samples: ${paired.length}`);
console.log(` Alignment rate: ${alignmentRate.toFixed(1)}%`);
console.log(` Avg confidence (paired): ${avgPairedConf.toFixed(3)}`);
console.log(` CSI source: ${sourceLabel} (${csiMatrix_shapeLabel(paired, useRawCsi)})`);
if (paired.length > 0) {
console.log(` Time range covered: ${paired[0].ts_start} -> ${paired[paired.length - 1].ts_end}`);
}
console.log(` Output written: ${outputPath}`);
console.log();
if (paired.length === 0) {
console.log('WARNING: No paired samples produced. Check that camera and CSI time ranges overlap.');
console.log(' Hint: Use --clock-offset-ms to correct misaligned clocks.');
}
}
/**
 * Render the CSI matrix shape shown in the summary.
 *
 * Uses the shape of the first paired sample when there is one; otherwise
 * falls back to a nominal [128, WINDOW_FRAMES] for raw CSI or
 * [8, WINDOW_FRAMES] for feature frames.
 *
 * @param {{csi_shape: number[]}[]} paired - Paired samples produced so far.
 * @param {boolean} useRawCsi - Whether raw_csi frames were the CSI source.
 * @returns {string} Label of the form "[rows, cols]".
 */
function csiMatrix_shapeLabel(paired, useRawCsi) {
  if (paired.length > 0) {
    const dims = paired[0].csi_shape;
    return `[${dims[0]}, ${dims[1]}]`;
  }
  const nominalRows = useRawCsi ? 128 : 8;
  return `[${nominalRows}, ${WINDOW_FRAMES}]`;
}
// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------
// Run the alignment as soon as the script is loaded; the CLI options parsed above drive it.
align();
+341
View File
@@ -0,0 +1,341 @@
#!/usr/bin/env python3
"""Camera ground-truth collection for WiFi pose estimation training (ADR-079).
Captures webcam keypoints via MediaPipe PoseLandmarker (Tasks API) and
synchronizes with ESP32 CSI recording from the sensing server.
Output: JSONL file in data/ground-truth/ with per-frame 17-keypoint COCO poses.
Usage:
python scripts/collect-ground-truth.py --preview --duration 60
python scripts/collect-ground-truth.py --server http://192.168.1.10:3000
"""
from __future__ import annotations
import argparse
import json
import os
import signal
import sys
import time
import urllib.request
import urllib.error
from pathlib import Path
from datetime import datetime
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks.python import BaseOptions
from mediapipe.tasks.python.vision import (
PoseLandmarker,
PoseLandmarkerOptions,
RunningMode,
)
# ---------------------------------------------------------------------------
# MediaPipe 33 landmarks -> 17 COCO keypoints
# ---------------------------------------------------------------------------
# (COCO joint name, MediaPipe landmark index). The position of an entry in this
# tuple is its COCO keypoint index, so the order must not change.
_COCO_JOINT_SOURCES = (
    ("nose", 0),
    ("left_eye", 2),
    ("right_eye", 5),
    ("left_ear", 7),
    ("right_ear", 8),
    ("left_shoulder", 11),
    ("right_shoulder", 12),
    ("left_elbow", 13),
    ("right_elbow", 14),
    ("left_wrist", 15),
    ("right_wrist", 16),
    ("left_hip", 23),
    ("right_hip", 24),
    ("left_knee", 25),
    ("right_knee", 26),
    ("left_ankle", 27),
    ("right_ankle", 28),
)

# MP_TO_COCO[coco_idx] is the MediaPipe landmark index feeding that COCO joint.
MP_TO_COCO = [mp_idx for _joint, mp_idx in _COCO_JOINT_SOURCES]

# Skeleton edges as pairs of COCO keypoint indices, grouped by body region.
_ARM_BONES = [(5, 7), (7, 9), (6, 8), (8, 10)]
_SHOULDER_BONES = [(5, 6)]
_LEG_BONES = [(11, 13), (13, 15), (12, 14), (14, 16)]
_HIP_BONES = [(11, 12)]
_TORSO_BONES = [(5, 11), (6, 12)]
_FACE_BONES = [(0, 1), (0, 2), (1, 3), (2, 4)]
COCO_BONES = (
    _ARM_BONES
    + _SHOULDER_BONES
    + _LEG_BONES
    + _HIP_BONES
    + _TORSO_BONES
    + _FACE_BONES
)

# PoseLandmarker "lite" model: download location and the file name used for the local copy.
MODEL_FILENAME = "pose_landmarker_lite.task"
MODEL_URL = (
    "https://storage.googleapis.com/mediapipe-models/pose_landmarker/"
    "pose_landmarker_lite/float16/latest/pose_landmarker_lite.task"
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def ensure_model(cache_dir: Path) -> Path:
    """Return the path of the cached PoseLandmarker model, downloading it if absent.

    The download is written to a temporary ``.part`` file and only renamed to
    its final name once it has completed. An interrupted or failed download
    therefore never leaves a truncated file that a later run would mistake for
    a valid cached model.

    Args:
        cache_dir: Directory holding the cached model; created when missing.

    Returns:
        Path to the model file inside ``cache_dir``.

    Raises:
        SystemExit: With status 1 when the download fails.
    """
    model_path = cache_dir / MODEL_FILENAME
    if model_path.exists():
        return model_path
    cache_dir.mkdir(parents=True, exist_ok=True)
    partial_path = model_path.with_name(model_path.name + ".part")
    print(f"Downloading {MODEL_FILENAME} ...")
    try:
        urllib.request.urlretrieve(MODEL_URL, str(partial_path))
        partial_path.replace(model_path)
        print(f" saved to {model_path}")
    except Exception as exc:
        # Remove whatever was written so the next run starts from a clean state.
        if partial_path.exists():
            partial_path.unlink()
        print(f"ERROR: Failed to download model: {exc}", file=sys.stderr)
        print(
            "Download manually from:\n"
            f" {MODEL_URL}\n"
            f"and place at {model_path}",
            file=sys.stderr,
        )
        sys.exit(1)
    return model_path
def post_json(url: str, payload: dict | None = None, timeout: float = 5.0) -> bool:
    """POST a JSON body to ``url`` and report whether the server accepted it.

    Every failure is reported as a warning on stderr and turned into ``False``.
    That includes a malformed URL and a payload that cannot be serialized,
    which previously escaped as exceptions because the request was built
    outside the ``try`` block.

    Args:
        url: Target URL.
        payload: JSON-serializable body; an empty object is sent when omitted.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        True when the response status is 2xx, False otherwise.
    """
    try:
        data = json.dumps(payload or {}).encode("utf-8")
        req = urllib.request.Request(
            url,
            data=data,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return 200 <= resp.status < 300
    except Exception as exc:
        print(f"WARNING: POST {url} failed: {exc}", file=sys.stderr)
        return False
def draw_skeleton(frame: np.ndarray, keypoints: list[list[float]], w: int, h: int):
    """Draw the keypoints and the COCO skeleton onto ``frame`` in place.

    Args:
        frame: Image to draw on, passed straight to the OpenCV drawing calls.
        keypoints: ``[x, y]`` pairs that are multiplied by ``w`` / ``h`` to get
            pixel positions (normalized coordinates).
        w: Width used to scale x values.
        h: Height used to scale y values.

    Bones from ``COCO_BONES`` whose endpoints are not both present in
    ``keypoints`` are skipped.
    """
    pixel_points = []
    for norm_x, norm_y in keypoints:
        point = (int(norm_x * w), int(norm_y * h))
        pixel_points.append(point)
        cv2.circle(frame, point, 4, (0, 255, 0), -1)

    n_points = len(pixel_points)
    for start, end in COCO_BONES:
        if start >= n_points or end >= n_points:
            continue
        cv2.line(frame, pixel_points[start], pixel_points[end], (0, 200, 255), 2)
# ---------------------------------------------------------------------------
# Main collection loop
# ---------------------------------------------------------------------------
def _extract_coco_pose(result):
    """Convert a PoseLandmarker result into COCO keypoints for the first person.

    Returns:
        ``(keypoints, confidence, n_visible, n_persons)`` where ``keypoints`` is
        a list of 17 ``[x, y]`` pairs rounded to 5 decimals (empty when nobody
        was detected), ``confidence`` is the mean landmark visibility and
        ``n_visible`` counts joints with visibility above 0.5.
    """
    n_persons = len(result.pose_landmarks)
    if n_persons == 0:
        return [], 0.0, 0, 0

    landmarks = result.pose_landmarks[0]
    keypoints = []
    visibilities = []
    for mp_idx in MP_TO_COCO:
        lm = landmarks[mp_idx]
        keypoints.append([round(lm.x, 5), round(lm.y, 5)])
        visibilities.append(lm.visibility if lm.visibility else 0.0)
    confidence = float(np.mean(visibilities))
    n_visible = int(sum(1 for v in visibilities if v > 0.5))
    return keypoints, confidence, n_visible, n_persons


def main():
    """Capture webcam frames, extract COCO keypoints and write them as JSONL.

    Asks the sensing server to start a CSI recording before the capture loop
    and to stop it afterwards. Capture ends when the duration elapses, on
    SIGINT/SIGTERM, or when 'q' is pressed in the preview window.

    The camera, the landmarker, the output file and the CSI recording are all
    released in a single ``finally`` block that also covers the setup steps, so
    a failure while creating the landmarker or opening the output file no
    longer leaves the camera open.
    """
    parser = argparse.ArgumentParser(
        description="Collect camera ground-truth keypoints for WiFi pose training (ADR-079)."
    )
    parser.add_argument(
        "--server",
        default="http://localhost:3000",
        help="Sensing server URL (default: http://localhost:3000)",
    )
    parser.add_argument(
        "--preview",
        action="store_true",
        help="Show live skeleton overlay window",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=300,
        help="Recording duration in seconds (default: 300)",
    )
    parser.add_argument(
        "--camera",
        type=int,
        default=0,
        help="Camera device index (default: 0)",
    )
    parser.add_argument(
        "--output",
        default="data/ground-truth",
        help="Output directory (default: data/ground-truth)",
    )
    args = parser.parse_args()

    # --- Resolve paths relative to repo root ---
    repo_root = Path(__file__).resolve().parent.parent
    output_dir = repo_root / args.output
    output_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = repo_root / "data" / ".cache"

    # --- Download / locate model ---
    model_path = ensure_model(cache_dir)

    # --- Open camera ---
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print(
            f"ERROR: Cannot open camera index {args.camera}. "
            "Check that a webcam is connected and not in use by another app.",
            file=sys.stderr,
        )
        sys.exit(1)

    recording_url_start = f"{args.server}/api/v1/recording/start"
    recording_url_stop = f"{args.server}/api/v1/recording/stop"

    # Resources acquired inside the try block; None / False means "not acquired yet".
    landmarker = None
    out_file = None
    csi_started = False

    frame_count = 0
    total_confidence = 0.0
    total_visible = 0

    # --- Graceful shutdown ---
    shutdown_requested = False

    def _handle_signal(signum, frame):
        nonlocal shutdown_requested
        shutdown_requested = True

    try:
        frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Camera opened: {frame_w}x{frame_h}")

        # --- Create PoseLandmarker ---
        options = PoseLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=str(model_path)),
            running_mode=RunningMode.IMAGE,
            num_poses=1,
            min_pose_detection_confidence=0.5,
            min_pose_presence_confidence=0.5,
            min_tracking_confidence=0.5,
        )
        landmarker = PoseLandmarker.create_from_options(options)

        # --- Output file ---
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = output_dir / f"keypoints_{timestamp_str}.jsonl"
        out_file = open(out_path, "w", encoding="utf-8")
        print(f"Output: {out_path}")

        # --- Start CSI recording ---
        csi_started = post_json(recording_url_start)
        if csi_started:
            print("CSI recording started on sensing server.")
        else:
            print(
                "WARNING: Could not start CSI recording. "
                "Camera keypoints will still be captured.",
                file=sys.stderr,
            )

        signal.signal(signal.SIGINT, _handle_signal)
        signal.signal(signal.SIGTERM, _handle_signal)

        # --- Collection loop ---
        # Monotonic clock for the duration check; wall-clock nanoseconds for the
        # per-frame timestamps written to the output file.
        start_time = time.monotonic()
        print(f"Collecting for {args.duration}s ... (press 'q' in preview to stop)")
        while not shutdown_requested:
            elapsed = time.monotonic() - start_time
            if elapsed >= args.duration:
                break
            ret, frame = cap.read()
            if not ret:
                print("WARNING: Failed to read frame, retrying ...", file=sys.stderr)
                time.sleep(0.01)
                continue
            ts_ns = time.time_ns()

            # Convert BGR -> RGB for MediaPipe
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
            result = landmarker.detect(mp_image)
            keypoints, confidence, n_visible, n_persons = _extract_coco_pose(result)

            # Frames without a detection are still recorded (empty keypoints, confidence 0).
            record = {
                "ts_ns": ts_ns,
                "keypoints": keypoints,
                "confidence": round(confidence, 4),
                "n_visible": n_visible,
                "n_persons": n_persons,
            }
            out_file.write(json.dumps(record) + "\n")
            frame_count += 1
            total_confidence += confidence
            total_visible += n_visible

            # Preview overlay
            if args.preview and keypoints:
                draw_skeleton(frame, keypoints, frame_w, frame_h)
            if args.preview:
                remaining = max(0, int(args.duration - elapsed))
                cv2.putText(
                    frame,
                    f"Frames: {frame_count} Visible: {n_visible}/17 Time: {remaining}s",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7,
                    (255, 255, 255),
                    2,
                )
                cv2.imshow("Ground Truth Collection (ADR-079)", frame)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
    finally:
        # --- Cleanup ---
        if out_file is not None:
            out_file.close()
        cap.release()
        if args.preview:
            cv2.destroyAllWindows()
        if landmarker is not None:
            landmarker.close()
        # Stop CSI recording
        if csi_started:
            if post_json(recording_url_stop):
                print("CSI recording stopped.")
            else:
                print("WARNING: Failed to stop CSI recording.", file=sys.stderr)

    # --- Summary ---
    avg_conf = total_confidence / frame_count if frame_count > 0 else 0.0
    avg_vis = total_visible / frame_count if frame_count > 0 else 0.0
    print()
    print("=== Collection Summary ===")
    print(f" Total frames: {frame_count}")
    print(f" Avg confidence: {avg_conf:.3f}")
    print(f" Avg visible joints: {avg_vis:.1f} / 17")
    print(f" Output: {out_path}")


if __name__ == "__main__":
    main()
+625
View File
@@ -0,0 +1,625 @@
#!/usr/bin/env node
/**
* WiFlow PCK Evaluation Script (ADR-079)
*
* Measures accuracy of WiFi-based pose estimation against ground-truth
* camera keypoints using PCK (Percentage of Correct Keypoints) and MPJPE
* (Mean Per-Joint Position Error) metrics.
*
* Usage:
* node scripts/eval-wiflow.js --model models/wiflow-supervised/wiflow-v1.json --data data/paired/aligned.paired.jsonl
* node scripts/eval-wiflow.js --baseline --data data/paired/aligned.paired.jsonl
* node scripts/eval-wiflow.js --model models/wiflow-supervised/wiflow-v1.json --data data/paired/aligned.paired.jsonl --verbose
*
* ADR: docs/adr/ADR-079
*/
'use strict';
const fs = require('fs');
const path = require('path');
const { parseArgs } = require('util');
// ---------------------------------------------------------------------------
// Resolve WiFlow model dependencies
// ---------------------------------------------------------------------------
const {
WiFlowModel,
COCO_KEYPOINTS,
createRng,
} = require(path.join(__dirname, 'wiflow-model.js'));
const RUVLLM_PATH = path.resolve(__dirname, '..', 'vendor', 'ruvector', 'npm', 'packages', 'ruvllm', 'src');
const { SafeTensorsReader } = require(path.join(RUVLLM_PATH, 'export.js'));
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
// Short display names for the 17 COCO keypoints, in COCO index order.
const JOINT_NAMES = [
  'nose',
  'l_eye',
  'r_eye',
  'l_ear',
  'r_ear',
  'l_shoulder',
  'r_shoulder',
  'l_elbow',
  'r_elbow',
  'l_wrist',
  'r_wrist',
  'l_hip',
  'r_hip',
  'l_knee',
  'r_knee',
  'l_ankle',
  'r_ankle',
];

const NUM_KEYPOINTS = 17;

// Fallback torso length, expressed in normalized coordinates.
const DEFAULT_TORSO_LENGTH = 0.3;

// Shoulder and hip indices, looked up by name so they always agree with
// JOINT_NAMES (l_shoulder=5, r_shoulder=6, l_hip=11, r_hip=12).
const L_SHOULDER = JOINT_NAMES.indexOf('l_shoulder');
const R_SHOULDER = JOINT_NAMES.indexOf('r_shoulder');
const L_HIP = JOINT_NAMES.indexOf('l_hip');
const R_HIP = JOINT_NAMES.indexOf('r_hip');
// ---------------------------------------------------------------------------
// CLI argument parsing
// ---------------------------------------------------------------------------
const { values: args } = parseArgs({
  strict: true,
  options: {
    model: { type: 'string', short: 'm' },
    data: { type: 'string', short: 'd' },
    baseline: { type: 'boolean', default: false },
    output: { type: 'string', short: 'o' },
    verbose: { type: 'boolean', short: 'v', default: false },
  },
});

// The paired dataset is the only unconditional requirement; print usage and
// fail when it is missing.
if (!args.data) {
  for (const usageLine of [
    'Usage: node scripts/eval-wiflow.js --data <paired-jsonl> [--model <path>] [--baseline] [--output <path>]',
    '',
    'Required:',
    ' --data, -d <path> Paired CSI + keypoint JSONL (from align-ground-truth.js)',
    '',
    'Options:',
    ' --model, -m <path> Path to trained model directory or JSON',
    ' --baseline Evaluate proxy-based baseline (no model)',
    ' --output, -o <path> Output eval report JSON',
    ' --verbose, -v Verbose output',
  ]) {
    console.error(usageLine);
  }
  process.exit(1);
}

// Something must be evaluated: either a trained model or the proxy baseline.
if (!(args.model || args.baseline)) {
  console.error('Error: Must specify either --model <path> or --baseline');
  process.exit(1);
}
// ---------------------------------------------------------------------------
// Data loading
// ---------------------------------------------------------------------------
/**
 * Read paired CSI/keypoint samples from a JSONL file.
 *
 * Expected line format:
 *   { csi: [...], csi_shape: [S, T], kp: [[x,y],...], conf: 0.xx, ... }
 *
 * A line is kept only when `kp` is an array and at least one of `csi` /
 * `csi_shape` is present. Blank lines and lines that fail to parse are
 * skipped.
 *
 * @param {string} filePath - Path of the paired JSONL file.
 * @returns {object[]} Accepted samples in file order.
 */
function loadPairedData(filePath) {
  const samples = [];
  const rows = fs.readFileSync(filePath, 'utf-8').split('\n');
  for (const row of rows) {
    if (row.trim() === '') continue;
    try {
      const sample = JSON.parse(row);
      const hasKeypoints = Array.isArray(sample.kp);
      const hasCsi = Boolean(sample.csi || sample.csi_shape);
      if (hasKeypoints && hasCsi) {
        samples.push(sample);
      }
    } catch {
      // skip malformed lines
    }
  }
  return samples;
}
// ---------------------------------------------------------------------------
// Model loading
// ---------------------------------------------------------------------------
/**
 * Build a WiFlow model in eval mode and load its weights.
 *
 * `modelPath` may be a directory or a file; for a file, only its containing
 * directory is used. From that directory:
 *   - `config.json` (optional): architecture settings are taken from its
 *     `custom` section.
 *   - `model.safetensors` (optional): weights; when absent, a warning is
 *     printed and the model keeps its initial weights.
 *
 * @param {string} modelPath - Model directory, or a file inside it.
 * @returns {{model: WiFlowModel, name: string}} The model and the directory's base name.
 * @throws {Error} If `modelPath` does not exist (from fs.statSync).
 */
function loadModel(modelPath) {
  const modelDir = fs.statSync(modelPath).isDirectory()
    ? modelPath
    : path.dirname(modelPath);

  // Load architecture config if available
  const config = {};
  const configPath = path.join(modelDir, 'config.json');
  if (fs.existsSync(configPath)) {
    try {
      const { custom } = JSON.parse(fs.readFileSync(configPath, 'utf-8'));
      if (custom) {
        // `||` is intentional for dimensions: 0 is not a usable size.
        config.inputChannels = custom.inputChannels || 128;
        config.timeSteps = custom.timeSteps || 20;
        config.numKeypoints = custom.numKeypoints || 17;
        config.numHeads = custom.numHeads || 8;
        // `??` keeps an explicitly configured seed of 0 instead of replacing it with 42.
        config.seed = custom.seed ?? 42;
      }
    } catch (err) {
      // Keep going with defaults, but say so: weights loaded below may not match them.
      console.warn(`WARN: Could not read ${configPath} (${err.message}), using default architecture`);
    }
  }

  // Create model with config
  const model = new WiFlowModel(config);
  model.setTraining(false); // eval mode

  // Load weights from SafeTensors
  const safetensorsPath = path.join(modelDir, 'model.safetensors');
  if (fs.existsSync(safetensorsPath)) {
    const buffer = new Uint8Array(fs.readFileSync(safetensorsPath));
    const reader = new SafeTensorsReader(buffer);
    const tensorNames = reader.getTensorNames();
    // Build tensor map for fromTensorMap
    const tensorMap = new Map();
    for (const name of tensorNames) {
      const tensor = reader.getTensor(name);
      if (tensor) {
        tensorMap.set(name, tensor.data);
      }
    }
    model.fromTensorMap(tensorMap);
    if (args.verbose) {
      console.log(`Loaded ${tensorNames.length} tensors from ${safetensorsPath}`);
      console.log(`Model params: ${model.numParams().toLocaleString()}`);
    }
  } else {
    console.warn(`WARN: No model.safetensors found in ${modelDir}, using random weights`);
  }

  // Derive model name
  const name = path.basename(modelDir);
  return { model, name };
}
// ---------------------------------------------------------------------------
// Baseline proxy pose generation (ADR-072 Phase 2 heuristic)
// ---------------------------------------------------------------------------
/**
 * Generate a proxy standing skeleton from CSI data (baseline without a model).
 *
 * Presence is estimated from the RMS of all CSI values (saturating at 10).
 * Below a presence of 0.3, or when the sample has no CSI values at all, an
 * all-zero pose is returned. Otherwise a fixed standing skeleton centred at
 * x = 0.5 is returned, each joint jittered by at most
 * min(motionEnergy * 0.1, 0.05) and clamped to [0, 1].
 *
 * Motion energy is the square root of the mean, over subcarriers, of each
 * subcarrier's variance over time. CSI is read as `csi[s * T + t]` with
 * [S, T] = csi_shape.
 *
 * The RNG is re-created with seed 42 on every call, so the jitter pattern is
 * the same for every sample.
 *
 * @param {{csi?: number[], csi_shape?: number[]}} sample - Paired sample.
 * @returns {Float32Array} Flat [x0, y0, x1, y1, ...] of length NUM_KEYPOINTS * 2.
 */
function generateBaselinePose(sample) {
  const rng = createRng(42);
  const csi = sample.csi;
  // An empty array must count as "no signal": dividing by its length would
  // give NaN, which slips past the presence threshold below.
  const hasCsi = Array.isArray(csi) && csi.length > 0;

  // Estimate presence from CSI amplitude energy (RMS over all values)
  let energy = 0;
  if (hasCsi) {
    for (let i = 0; i < csi.length; i++) {
      energy += csi[i] * csi[i];
    }
    energy = Math.sqrt(energy / csi.length);
  }

  // Estimate motion energy (per-subcarrier variance over time, averaged)
  let motionEnergy = 0;
  if (hasCsi && sample.csi_shape) {
    const [S, T] = sample.csi_shape;
    if (T > 1) {
      for (let s = 0; s < S; s++) {
        let sum = 0;
        let sumSq = 0;
        for (let t = 0; t < T; t++) {
          const v = csi[s * T + t] || 0;
          sum += v;
          sumSq += v * v;
        }
        const mean = sum / T;
        motionEnergy += (sumSq / T) - (mean * mean);
      }
      motionEnergy = Math.sqrt(Math.max(0, motionEnergy / S));
    }
  }

  // Normalized presence heuristic
  const presence = Math.min(1, energy / 10);
  if (presence < 0.3) {
    // No person detected: return zero pose
    return new Float32Array(NUM_KEYPOINTS * 2);
  }

  // Standing skeleton at center (0.5, 0.5) with standard proportions
  // Coordinates are [x, y] in normalized [0, 1] space
  // y=0 is top, y=1 is bottom (image convention)
  const cx = 0.5;
  const headY = 0.2;
  const shoulderY = 0.32;
  const elbowY = 0.45;
  const wristY = 0.55;
  const hipY = 0.55;
  const kneeY = 0.72;
  const ankleY = 0.88;
  const shoulderW = 0.08;
  const hipW = 0.06;
  const armSpread = 0.12;

  // Standard standing pose keypoints [x, y]
  const skeleton = [
    [cx, headY], // 0: nose
    [cx - 0.02, headY - 0.02], // 1: l_eye
    [cx + 0.02, headY - 0.02], // 2: r_eye
    [cx - 0.04, headY], // 3: l_ear
    [cx + 0.04, headY], // 4: r_ear
    [cx - shoulderW, shoulderY], // 5: l_shoulder
    [cx + shoulderW, shoulderY], // 6: r_shoulder
    [cx - armSpread, elbowY], // 7: l_elbow
    [cx + armSpread, elbowY], // 8: r_elbow
    [cx - armSpread - 0.02, wristY], // 9: l_wrist
    [cx + armSpread + 0.02, wristY], // 10: r_wrist
    [cx - hipW, hipY], // 11: l_hip
    [cx + hipW, hipY], // 12: r_hip
    [cx - hipW, kneeY], // 13: l_knee
    [cx + hipW, kneeY], // 14: r_knee
    [cx - hipW, ankleY], // 15: l_ankle
    [cx + hipW, ankleY], // 16: r_ankle
  ];

  // Perturb limbs by motion energy
  const perturbScale = Math.min(motionEnergy * 0.1, 0.05);
  const result = new Float32Array(NUM_KEYPOINTS * 2);
  for (let k = 0; k < NUM_KEYPOINTS; k++) {
    const px = (rng() - 0.5) * 2 * perturbScale;
    const py = (rng() - 0.5) * 2 * perturbScale;
    result[k * 2] = Math.max(0, Math.min(1, skeleton[k][0] + px));
    result[k * 2 + 1] = Math.max(0, Math.min(1, skeleton[k][1] + py));
  }
  return result;
}
// ---------------------------------------------------------------------------
// Metric computation
// ---------------------------------------------------------------------------
/**
 * Straight-line distance between the points (x1, y1) and (x2, y2).
 *
 * @param {number} x1
 * @param {number} y1
 * @param {number} x2
 * @param {number} y2
 * @returns {number} Non-negative Euclidean distance.
 */
function dist2d(x1, y1, x2, y2) {
  const deltaX = x1 - x2;
  const deltaY = y1 - y2;
  const squaredLength = deltaX * deltaX + deltaY * deltaY;
  return Math.sqrt(squaredLength);
}
/**
 * Estimate the torso length of a ground-truth skeleton.
 *
 * The torso is measured from the midpoint of the two shoulders to the
 * midpoint of the two hips. DEFAULT_TORSO_LENGTH is returned instead when the
 * keypoint list is missing or has fewer than 13 entries, when any of the four
 * joints sits exactly at the origin (treated as not visible), or when the
 * measured length is not greater than 0.01.
 *
 * @param {number[][]} kp - Keypoints as [x, y] pairs indexed by joint.
 * @returns {number} Torso length, or DEFAULT_TORSO_LENGTH as a fallback.
 */
function computeTorsoLength(kp) {
  if (!kp || kp.length < 13) {
    return DEFAULT_TORSO_LENGTH;
  }

  // Read all four joints up front, shoulders first, then hips.
  const [leftShoulder, rightShoulder, leftHip, rightHip] =
    [L_SHOULDER, R_SHOULDER, L_HIP, R_HIP].map((joint) => ({
      x: kp[joint][0],
      y: kp[joint][1],
    }));

  // A joint parked exactly on (0, 0) is treated as not visible.
  const isHidden = (point) => point.x === 0 && point.y === 0;
  if ([leftShoulder, rightShoulder, leftHip, rightHip].some(isHidden)) {
    return DEFAULT_TORSO_LENGTH;
  }

  const length = dist2d(
    (leftShoulder.x + rightShoulder.x) / 2,
    (leftShoulder.y + rightShoulder.y) / 2,
    (leftHip.x + rightHip.x) / 2,
    (leftHip.y + rightHip.y) / 2,
  );

  // Collapsed skeletons would make PCK thresholds meaninglessly small.
  if (length > 0.01) {
    return length;
  }
  return DEFAULT_TORSO_LENGTH;
}
/**
 * Evaluate predictions against ground truth.
 *
 * For every sample, each joint present in the ground truth contributes one
 * distance between the predicted and the true position. Distances are
 * aggregated into:
 *  - PCK@10/20/50: share of joints whose distance is strictly below
 *    10% / 20% / 50% of that sample's torso length,
 *  - MPJPE: mean distance over all joints,
 *  - per-joint PCK@20 and MPJPE keyed by JOINT_NAMES,
 *  - confidence-weighted PCK@20 and MPJPE, where every joint of a sample is
 *    weighted by the sample confidence (floored at 1e-6).
 *
 * A missing or non-numeric confidence gets the floor weight instead of
 * turning the weighted metrics into NaN.
 *
 * NOTE(review): joints whose ground truth sits at (0, 0) are scored like any
 * other joint, although computeTorsoLength treats the origin as "not
 * visible". Confirm whether such joints should be excluded here.
 *
 * @param {Array<{pred: Float32Array, gt: number[][], conf: number}>} results
 *   `pred` is flat [x0, y0, x1, y1, ...]; `gt` holds [x, y] pairs.
 * @returns {object} Evaluation report. For an empty input all metrics are 0
 *   and the per-joint maps are empty.
 */
function computeMetrics(results) {
  const n = results.length;
  if (n === 0) {
    return {
      n_samples: 0,
      pck_10: 0, pck_20: 0, pck_50: 0,
      mpjpe: 0,
      per_joint_pck20: {},
      per_joint_mpjpe: {},
      conf_weighted_pck20: 0,
      conf_weighted_mpjpe: 0,
    };
  }

  // Accumulators
  const pckCounts = { 10: 0, 20: 0, 50: 0 };
  let totalJoints = 0;
  let totalMPJPE = 0;
  const perJointPck20 = new Float64Array(NUM_KEYPOINTS);
  const perJointMPJPE = new Float64Array(NUM_KEYPOINTS);
  const perJointCount = new Float64Array(NUM_KEYPOINTS);

  // Confidence-weighted accumulators; both metrics share one weight total.
  let confWeightTotal = 0;
  let confWeightedHits20 = 0;
  let confWeightedError = 0;

  for (const { pred, gt, conf } of results) {
    const torso = computeTorsoLength(gt);

    // Floor the weight so zero-confidence samples still count a little, and
    // fall back to the floor when conf is absent or not a number
    // (Math.max(undefined, 1e-6) would be NaN).
    const numericConf = Number(conf);
    const w = Number.isFinite(numericConf) ? Math.max(numericConf, 1e-6) : 1e-6;

    // Only joints that exist in the ground truth are scored.
    const jointCount = Math.min(NUM_KEYPOINTS, gt.length);
    for (let k = 0; k < jointCount; k++) {
      const d = dist2d(pred[k * 2], pred[k * 2 + 1], gt[k][0], gt[k][1]);

      totalJoints++;
      totalMPJPE += d;
      perJointMPJPE[k] += d;
      perJointCount[k] += 1;

      // PCK at different thresholds (relative to torso length)
      if (d < 0.10 * torso) pckCounts[10]++;
      if (d < 0.20 * torso) {
        pckCounts[20]++;
        perJointPck20[k]++;
        confWeightedHits20 += w;
      }
      if (d < 0.50 * torso) pckCounts[50]++;

      confWeightTotal += w;
      confWeightedError += d * w;
    }
  }

  // Aggregate metrics
  const pck10 = totalJoints > 0 ? pckCounts[10] / totalJoints : 0;
  const pck20 = totalJoints > 0 ? pckCounts[20] / totalJoints : 0;
  const pck50 = totalJoints > 0 ? pckCounts[50] / totalJoints : 0;
  const mpjpe = totalJoints > 0 ? totalMPJPE / totalJoints : 0;

  // Per-joint breakdown
  const perJointPck20Map = {};
  const perJointMpjpeMap = {};
  for (let k = 0; k < NUM_KEYPOINTS; k++) {
    const name = JOINT_NAMES[k];
    perJointPck20Map[name] = perJointCount[k] > 0 ? perJointPck20[k] / perJointCount[k] : 0;
    perJointMpjpeMap[name] = perJointCount[k] > 0 ? perJointMPJPE[k] / perJointCount[k] : 0;
  }

  // Confidence-weighted
  const confPck20 = confWeightTotal > 0 ? confWeightedHits20 / confWeightTotal : 0;
  const confMpjpe = confWeightTotal > 0 ? confWeightedError / confWeightTotal : 0;

  return {
    n_samples: n,
    pck_10: pck10,
    pck_20: pck20,
    pck_50: pck50,
    mpjpe,
    per_joint_pck20: perJointPck20Map,
    per_joint_mpjpe: perJointMpjpeMap,
    conf_weighted_pck20: confPck20,
    conf_weighted_mpjpe: confMpjpe,
  };
}
// ---------------------------------------------------------------------------
// Inference
// ---------------------------------------------------------------------------
/**
 * Feed one paired sample through the model and return its prediction.
 *
 * The CSI payload is converted to a Float32Array (a Float32Array is passed
 * through untouched, a plain array is copied, anything else becomes a zero
 * buffer sized from `csi_shape`, defaulting to 128 x 20). If the length does
 * not equal model.inputChannels * model.timeSteps, the data is copied into a
 * buffer of that length: truncated when too long, zero-padded when too short.
 *
 * @param {WiFlowModel} model - Reads `inputChannels`, `timeSteps`; calls `forward`.
 * @param {object} sample - { csi, csi_shape, kp, conf }
 * @returns {Float32Array} Whatever `model.forward` returns; documented by the
 *   original author as [17*2] predicted keypoints.
 */
function runModelInference(model, sample) {
  const { csi, csi_shape: shape } = sample;
  const rows = shape ? shape[0] : 128;
  const cols = shape ? shape[1] : 20;

  // Normalize the payload to a Float32Array.
  let features;
  if (csi instanceof Float32Array) {
    features = csi;
  } else {
    features = Array.isArray(csi)
      ? new Float32Array(csi)
      : new Float32Array(rows * cols);
  }

  // Fast path: the buffer already has the size the model expects.
  const expectedLen = model.inputChannels * model.timeSteps;
  if (features.length === expectedLen) {
    return model.forward(features);
  }

  // Otherwise truncate or zero-pad into a fresh buffer.
  const fitted = new Float32Array(expectedLen);
  const usable = Math.min(features.length, expectedLen);
  fitted.set(features.subarray(0, usable));
  return model.forward(fitted);
}
// ---------------------------------------------------------------------------
// Formatted output
// ---------------------------------------------------------------------------
/**
 * Render a fraction as a percentage with one decimal place.
 *
 * @param {number} v - Fraction, e.g. 0.353.
 * @returns {string} Percentage text such as "35.3%".
 */
function formatPercent(v) {
  const scaled = v * 100;
  return `${scaled.toFixed(1)}%`;
}
/**
 * Format a number with a fixed count of decimal places.
 *
 * @param {number} v - Value to format.
 * @param {number} [decimals=4] - Digits after the decimal point. An explicit
 *   0 is honoured; any other falsy value falls back to 4.
 * @returns {string} Fixed-point text, e.g. "0.0671".
 */
function formatFloat(v, decimals) {
  // `decimals || 4` silently turned a requested 0 into 4, so 0 is checked
  // explicitly before applying the fallback.
  const digits = decimals === 0 || decimals ? decimals : 4;
  return v.toFixed(digits);
}
/**
 * Print the evaluation report to stdout in a human-readable layout.
 *
 * Reads `model`, `n_samples`, `pck_10`, `pck_20`, `pck_50`, `mpjpe`,
 * `per_joint_pck20`, `per_joint_mpjpe`, `conf_weighted_pck20`,
 * `conf_weighted_mpjpe` and `inference_latency_ms` from the report. Joints
 * missing from a per-joint map are shown as 0.
 *
 * @param {object} report - Report object as assembled by main().
 */
function printReport(report) {
  // One per-joint table: heading, one aligned row per joint, blank line.
  const printJointTable = (heading, valuesByJoint, format) => {
    console.log(heading);
    const columnWidth = Math.max(...JOINT_NAMES.map((jointName) => jointName.length)) + 2;
    JOINT_NAMES.forEach((jointName) => {
      const value = valuesByJoint[jointName] || 0;
      console.log(` ${jointName.padEnd(columnWidth)}${format(value)}`);
    });
    console.log('');
  };

  console.log('');
  console.log('WiFlow Evaluation Report (ADR-079)');
  console.log('===================================');
  console.log(`Model: ${report.model}`);
  console.log(`Samples: ${report.n_samples.toLocaleString()}`);
  console.log(`PCK@10: ${formatPercent(report.pck_10)}`);
  console.log(`PCK@20: ${formatPercent(report.pck_20)}`);
  console.log(`PCK@50: ${formatPercent(report.pck_50)}`);
  console.log(`MPJPE: ${formatFloat(report.mpjpe)}`);
  console.log('');

  printJointTable('Per-Joint PCK@20:', report.per_joint_pck20, (value) => formatPercent(value));
  printJointTable('Per-Joint MPJPE:', report.per_joint_mpjpe, (value) => formatFloat(value));

  console.log('Confidence-Weighted:');
  console.log(` PCK@20: ${formatPercent(report.conf_weighted_pck20)}`);
  console.log(` MPJPE: ${formatFloat(report.conf_weighted_mpjpe)}`);
  console.log('');
  console.log(`Inference: ${report.inference_latency_ms.toFixed(2)}ms/sample`);
  console.log('');
}
// ---------------------------------------------------------------------------
// Main
// ---------------------------------------------------------------------------
/**
 * Choose the file the JSON report is written to.
 *
 * An explicit output path always wins. Otherwise the report goes inside the
 * model directory, or next to the model when `modelArg` names a file. A model
 * path that does not exist is treated as a directory, so a finished
 * evaluation is never lost to a failing stat call.
 *
 * @param {string|undefined} outputArg - Value of --output, if given.
 * @param {string|undefined} modelArg - Value of --model, if given.
 * @returns {string} Path of the report file.
 */
function resolveReportPath(outputArg, modelArg) {
  if (outputArg) return outputArg;
  if (!modelArg) return 'models/wiflow-supervised/eval-report.json';

  let modelIsFile = false;
  try {
    modelIsFile = fs.statSync(modelArg).isFile();
  } catch (err) {
    // A missing path is fine (handled as a directory); anything else is not.
    if (err.code !== 'ENOENT') throw err;
  }

  // path.dirname() on a directory path would return its parent, so a
  // directory is used as-is and only a file is reduced to its folder.
  const reportDir = modelIsFile ? path.dirname(modelArg) : modelArg;
  return path.join(reportDir, 'eval-report.json');
}

/**
 * Script entry point: load paired samples, predict a pose for each one
 * (baseline heuristic or loaded model), score the predictions, print the
 * report and save it as JSON.
 *
 * Exits the process with status 1 when no paired samples could be loaded.
 */
function main() {
  // Load paired data
  if (args.verbose) console.log(`Loading paired data from ${args.data}...`);
  const samples = loadPairedData(args.data);
  if (samples.length === 0) {
    console.error('Error: No valid paired samples found in', args.data);
    process.exit(1);
  }
  if (args.verbose) console.log(`Loaded ${samples.length} paired samples`);

  let modelName;
  let model = null;
  if (args.baseline) {
    modelName = 'baseline-proxy';
    if (args.verbose) console.log('Running baseline proxy evaluation (ADR-072 Phase 2 heuristic)');
  } else {
    const loaded = loadModel(args.model);
    model = loaded.model;
    modelName = loaded.name;
    if (args.verbose) console.log(`Running model evaluation: ${modelName}`);
  }

  // Run inference and collect results; the timer covers prediction only.
  const results = [];
  const startTime = process.hrtime.bigint();
  for (const sample of samples) {
    const pred = args.baseline
      ? generateBaselinePose(sample)
      : runModelInference(model, sample);
    results.push({
      pred,
      gt: sample.kp,
      conf: sample.conf || 0,
    });
  }
  const endTime = process.hrtime.bigint();
  // hrtime.bigint() is in nanoseconds.
  const totalMs = Number(endTime - startTime) / 1e6;
  const latencyMs = totalMs / samples.length;

  // Compute metrics
  const metrics = computeMetrics(results);

  // Rates keep 4 decimals, distances keep 5.
  const round4 = (value) => Math.round(value * 10000) / 10000;
  const round5 = (value) => Math.round(value * 100000) / 100000;

  // Build report
  const report = {
    model: modelName,
    n_samples: metrics.n_samples,
    pck_10: round4(metrics.pck_10),
    pck_20: round4(metrics.pck_20),
    pck_50: round4(metrics.pck_50),
    mpjpe: round5(metrics.mpjpe),
    per_joint_pck20: {},
    per_joint_mpjpe: {},
    conf_weighted_pck20: round4(metrics.conf_weighted_pck20),
    conf_weighted_mpjpe: round5(metrics.conf_weighted_mpjpe),
    inference_latency_ms: Math.round(latencyMs * 100) / 100,
    timestamp: new Date().toISOString(),
  };

  // Round per-joint metrics; joints without data are reported as 0.
  for (const name of JOINT_NAMES) {
    report.per_joint_pck20[name] = round4(metrics.per_joint_pck20[name] || 0);
    report.per_joint_mpjpe[name] = round5(metrics.per_joint_mpjpe[name] || 0);
  }

  // Print formatted report
  printReport(report);

  // Write output JSON
  const outputPath = resolveReportPath(args.output, args.model);
  const outputDir = path.dirname(outputPath);
  if (!fs.existsSync(outputDir)) {
    fs.mkdirSync(outputDir, { recursive: true });
  }
  fs.writeFileSync(outputPath, JSON.stringify(report, null, 2) + '\n');
  console.log(`Report saved to ${outputPath}`);
}
// Script entry point: run the evaluation immediately when this file is executed.
main();
File diff suppressed because it is too large Load Diff