mirror of
https://github.com/ruvnet/RuView
synced 2026-06-10 10:23:19 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8dddbf941a | |||
| 35903a313d | |||
| 5bd0d59aa6 | |||
| 327d0d13f6 | |||
| d09baa6a09 | |||
| 486392bb68 | |||
| 33f5abd0e0 | |||
| e3522ddcda | |||
| b5e924cd72 | |||
| 854342297a | |||
| 23b4491e7b | |||
| 2b24250a69 | |||
| 6d446e5459 | |||
| 62fd1d96af | |||
| b3fd0e2951 | |||
| aae01a2be8 | |||
| 828d0599d7 | |||
| 21fd7c84e2 | |||
| 430243c32c | |||
| b7650b5243 | |||
| a23bd2ec01 | |||
| 74e0ebbd41 | |||
| d88994816f | |||
| d2560e1b87 |
@@ -95,9 +95,88 @@ node scripts/mincut-person-counter.js --port 5006 # Correct person counting
|
||||
>
|
||||
---
|
||||
|
||||
### What's New in v0.5.5
|
||||
### Pre-Trained Models (v0.6.0) — No Training Required
|
||||
|
||||
<details open>
|
||||
<summary><strong>Download from HuggingFace and start sensing immediately</strong></summary>
|
||||
|
||||
Pre-trained models are available on HuggingFace:
|
||||
> **https://huggingface.co/ruv/ruview** (primary) | [mirror](https://huggingface.co/ruvnet/wifi-densepose-pretrained)
|
||||
|
||||
Trained on 60,630 real-world samples from an 8-hour overnight collection. Just download and run — no datasets, no GPU, no training needed.
|
||||
|
||||
| Model | Size | What it does |
|
||||
|-------|------|-------------|
|
||||
| `model.safetensors` | 48 KB | Contrastive encoder — 128-dim embeddings for presence, activity, environment |
|
||||
| `model-q4.bin` | 8 KB | 4-bit quantized — fits in ESP32-S3 SRAM for edge inference |
|
||||
| `model-q2.bin` | 4 KB | 2-bit ultra-compact for memory-constrained devices |
|
||||
| `presence-head.json` | 2.6 KB | 100% accurate presence detection head |
|
||||
| `node-1.json` / `node-2.json` | 21 KB | Per-room LoRA adapters (swap for new rooms) |
|
||||
|
||||
```bash
|
||||
# Download and use (Python)
|
||||
pip install huggingface_hub
|
||||
huggingface-cli download ruv/ruview --local-dir models/
|
||||
|
||||
# Or use directly with the sensing pipeline
|
||||
node scripts/train-ruvllm.js --data data/recordings/*.csi.jsonl # retrain on your own data
|
||||
node scripts/benchmark-ruvllm.js --model models/csi-ruvllm # benchmark
|
||||
```
|
||||
|
||||
**Benchmarks (Apple M4 Pro, retrained on overnight data):**
|
||||
|
||||
| What we measured | Result | Why it matters |
|
||||
|-----------------|--------|---------------|
|
||||
| **Presence detection** | **100% accuracy** | Never misses a person, never false alarms |
|
||||
| **Inference speed** | **0.008 ms** per embedding | 125,000x faster than real-time |
|
||||
| **Throughput** | **164,183 embeddings/sec** | One Mac Mini handles 1,600+ ESP32 nodes |
|
||||
| **Contrastive learning** | **51.6% improvement** | Strong pattern learning from real overnight data |
|
||||
| **Model size** | **8 KB** (4-bit quantized) | Fits in ESP32 SRAM — no server needed |
|
||||
| **Total hardware cost** | **$140** | ESP32 ($9) + [Cognitum Seed](https://cognitum.one) ($131) |
|
||||
|
||||
</details>
|
||||
|
||||
### 17 Sensing Applications (v0.6.0)
|
||||
|
||||
<details>
|
||||
<summary><strong>Health, environment, security, and multi-frequency mesh sensing</strong></summary>
|
||||
|
||||
All applications run from a single ESP32 + optional Cognitum Seed. No camera, no cloud, no internet.
|
||||
|
||||
**Health & Wellness:**
|
||||
|
||||
| Application | Script | What it detects |
|
||||
|------------|--------|----------------|
|
||||
| Sleep Monitor | `node scripts/sleep-monitor.js` | Sleep stages (deep/light/REM/awake), efficiency, hypnogram |
|
||||
| Apnea Detector | `node scripts/apnea-detector.js` | Breathing pauses >10s, AHI severity scoring |
|
||||
| Stress Monitor | `node scripts/stress-monitor.js` | Heart rate variability, LF/HF stress ratio |
|
||||
| Gait Analyzer | `node scripts/gait-analyzer.js` | Walking cadence, stride asymmetry, tremor detection |
|
||||
|
||||
**Environment & Security:**
|
||||
|
||||
| Application | Script | What it detects |
|
||||
|------------|--------|----------------|
|
||||
| Person Counter | `node scripts/mincut-person-counter.js` | Correct occupancy count (fixes #348) |
|
||||
| Room Fingerprint | `node scripts/room-fingerprint.js` | Activity state clustering, daily patterns, anomalies |
|
||||
| Material Detector | `node scripts/material-detector.js` | New/moved objects via subcarrier null changes |
|
||||
| Device Fingerprint | `node scripts/device-fingerprint.js` | Electronic device activity (printer, router, etc.) |
|
||||
|
||||
**Multi-Frequency Mesh** (requires `--hop-channels` provisioning):
|
||||
|
||||
| Application | Script | What it detects |
|
||||
|------------|--------|----------------|
|
||||
| RF Tomography | `node scripts/rf-tomography.js` | 2D room imaging via RF backprojection |
|
||||
| Passive Radar | `node scripts/passive-radar.js` | Neighbor WiFi APs as bistatic radar illuminators |
|
||||
| Material Classifier | `node scripts/material-classifier.js` | Metal/water/wood/glass from frequency response |
|
||||
| Through-Wall | `node scripts/through-wall-detector.js` | Motion behind walls using lower-frequency penetration |
|
||||
|
||||
All scripts support `--replay data/recordings/*.csi.jsonl` for offline analysis and `--json` for programmatic output.
|
||||
|
||||
</details>
|
||||
|
||||
### What's New in v0.5.5
|
||||
|
||||
<details>
|
||||
<summary><strong>Advanced Sensing: SNN + MinCut + WiFlow + Multi-Frequency Mesh</strong></summary>
|
||||
|
||||
**v0.5.5 adds four new sensing capabilities** built on the [ruvector](https://github.com/ruvnet/ruvector) ecosystem:
|
||||
@@ -1188,7 +1267,8 @@ Download a pre-built binary — no build toolchain needed:
|
||||
|
||||
| Release | What's included | Tag |
|
||||
|---------|-----------------|-----|
|
||||
| [v0.5.5](https://github.com/ruvnet/RuView/releases/tag/v0.5.5-esp32) | **Latest** — SNN + MinCut (fixes #348) + CNN spectrogram + WiFlow 1.8M architecture + multi-freq mesh (6 channels) + graph transformer | `v0.5.5-esp32` |
|
||||
| [v0.6.0](https://github.com/ruvnet/RuView/releases/tag/v0.6.0-esp32) | **Latest** — [Pre-trained models on HuggingFace](https://huggingface.co/ruv/ruview), 17 sensing apps, 51.6% contrastive improvement, 0.008ms inference | `v0.6.0-esp32` |
|
||||
| [v0.5.5](https://github.com/ruvnet/RuView/releases/tag/v0.5.5-esp32) | SNN + MinCut (#348 fix) + CNN spectrogram + WiFlow + multi-freq mesh + graph transformer | `v0.5.5-esp32` |
|
||||
| [v0.5.4](https://github.com/ruvnet/RuView/releases/tag/v0.5.4-esp32) | Cognitum Seed integration ([ADR-069](docs/adr/ADR-069-cognitum-seed-csi-pipeline.md)), 8-dim feature vectors, RVF store, witness chain, security hardening | `v0.5.4-esp32` |
|
||||
| [v0.5.0](https://github.com/ruvnet/RuView/releases/tag/v0.5.0-esp32) | mmWave sensor fusion ([ADR-063](docs/adr/ADR-063-mmwave-sensor-fusion.md)), auto-detect MR60BHA2/LD2410, 48-byte fused vitals, all v0.4.3.1 fixes | `v0.5.0-esp32` |
|
||||
| [v0.4.3.1](https://github.com/ruvnet/RuView/releases/tag/v0.4.3.1-esp32) | Fall detection fix ([#263](https://github.com/ruvnet/RuView/issues/263)), 4MB flash ([#265](https://github.com/ruvnet/RuView/issues/265)), watchdog fix ([#266](https://github.com/ruvnet/RuView/issues/266)) | `v0.4.3.1-esp32` |
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"id": "pretrain-1775182186",
|
||||
"name": "pretrain-1775182186",
|
||||
"label": "mixed-activity",
|
||||
"started_at": "2026-04-03T02:09:46Z",
|
||||
"ended_at": "2026-04-03T02:11:46Z",
|
||||
"duration_secs": 120,
|
||||
"frame_count": 5783,
|
||||
"file_size_bytes": 2580539,
|
||||
"file_path": "data/recordings\\pretrain-1775182186.csi.jsonl",
|
||||
"nodes": {
|
||||
"2": 2886,
|
||||
"1": 2897
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,512 @@
|
||||
# ADR-079: Camera Ground-Truth Training Pipeline
|
||||
|
||||
- **Status**: Accepted
|
||||
- **Date**: 2026-04-06
|
||||
- **Deciders**: ruv
|
||||
- **Relates to**: ADR-072 (WiFlow Architecture), ADR-070 (Self-Supervised Pretraining), ADR-071 (ruvllm Training Pipeline), ADR-024 (AETHER Contrastive), ADR-064 (Multimodal Ambient Intelligence), ADR-075 (MinCut Person Separation)
|
||||
|
||||
## Context
|
||||
|
||||
WiFlow (ADR-072) currently trains without ground-truth pose labels, using proxy poses
|
||||
generated from presence/motion heuristics. This produces a PCK@20 of only 2.5% — far
|
||||
below the 30-50% achievable with supervised training. The fundamental bottleneck is the
|
||||
absence of spatial keypoint labels.
|
||||
|
||||
Academic WiFi pose estimation systems (Wi-Pose, Person-in-WiFi 3D, MetaFi++) all train
|
||||
with synchronized camera ground truth and achieve PCK@20 of 40-85%. They discard the
|
||||
camera at deployment — the camera is a training-time teacher, not a runtime dependency.
|
||||
|
||||
ADR-064 already identified this: *"Record CSI + mmWave while performing signs with a
|
||||
camera as ground truth, then deploy camera-free."* This ADR specifies the implementation.
|
||||
|
||||
### Current Training Pipeline Gap
|
||||
|
||||
```
|
||||
Current: CSI amplitude → WiFlow → 17 keypoints (proxy-supervised, PCK@20 = 2.5%)
|
||||
↑
|
||||
Heuristic proxies:
|
||||
- Standing skeleton when presence > 0.3
|
||||
- Limb perturbation from motion energy
|
||||
- No spatial accuracy
|
||||
```
|
||||
|
||||
### Target Pipeline
|
||||
|
||||
```
|
||||
Training: CSI amplitude ──→ WiFlow ──→ 17 keypoints (camera-supervised, PCK@20 target: 35%+)
|
||||
↑
|
||||
Laptop camera ──→ MediaPipe ──→ 17 COCO keypoints (ground truth)
|
||||
(time-synchronized, 30 fps)
|
||||
|
||||
Deploy: CSI amplitude ──→ WiFlow ──→ 17 keypoints (camera-free, trained model only)
|
||||
```
|
||||
|
||||
## Decision
|
||||
|
||||
Build a camera ground-truth collection and training pipeline using the laptop webcam
|
||||
as a teacher signal. The camera is used **only during training data collection** and is
|
||||
not required at deployment.
|
||||
|
||||
### Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Data Collection Phase │
|
||||
│ │
|
||||
│ ESP32-S3 nodes ──UDP──→ Sensing Server ──→ CSI frames (.jsonl) │
|
||||
│ ↑ time sync │
|
||||
│ Laptop Camera ──→ MediaPipe Pose ──→ Keypoints (.jsonl) │
|
||||
│ ↑ │
|
||||
│ collect-ground-truth.py │
|
||||
│ (single orchestrator) │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Training Phase │
|
||||
│ │
|
||||
│ Paired dataset: { csi_window[128,20], keypoints[17,2], conf } │
|
||||
│ ↓ │
|
||||
│ train-wiflow-supervised.js │
|
||||
│ Phase 1: Contrastive pretrain (ADR-072, reuse) │
|
||||
│ Phase 2: Supervised keypoint regression (NEW) │
|
||||
│ Phase 3: Fine-tune with bone constraints + confidence │
|
||||
│ ↓ │
|
||||
│ WiFlow model (1.8M params) → SafeTensors export │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Deployment (camera-free) │
|
||||
│ │
|
||||
│ ESP32-S3 CSI → Sensing Server → WiFlow inference → 17 keypoints│
|
||||
│ (No camera. Trained model runs on CSI input only.) │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Component 1: `scripts/collect-ground-truth.py`
|
||||
|
||||
Single Python script that orchestrates synchronized capture from the laptop camera
|
||||
and the ESP32 CSI stream.
|
||||
|
||||
**Dependencies:** `mediapipe`, `opencv-python`, `requests` (all pip-installable, no GPU)
|
||||
|
||||
**Capture flow:**
|
||||
|
||||
```python
|
||||
# Pseudocode
|
||||
camera = cv2.VideoCapture(0) # Laptop webcam
|
||||
sensing_api = "http://localhost:3000" # Sensing server
|
||||
|
||||
# Start CSI recording via existing API
|
||||
requests.post(f"{sensing_api}/api/v1/recording/start")
|
||||
|
||||
while recording:
|
||||
frame = camera.read()
|
||||
t = time.time_ns() # Nanosecond timestamp
|
||||
|
||||
# MediaPipe Pose: 33 landmarks → map to 17 COCO keypoints
|
||||
result = mp_pose.process(frame)
|
||||
keypoints_17 = map_mediapipe_to_coco(result.pose_landmarks)
|
||||
confidence = mean(landmark.visibility for relevant landmarks)
|
||||
|
||||
# Write to ground-truth JSONL (one line per frame)
|
||||
write_jsonl({
|
||||
"ts_ns": t,
|
||||
"keypoints": keypoints_17, # [[x,y], ...] normalized [0,1]
|
||||
"confidence": confidence, # 0-1, used for loss weighting
|
||||
"n_visible": count(visibility > 0.5),
|
||||
})
|
||||
|
||||
# Optional: show live preview with skeleton overlay
|
||||
if preview:
|
||||
draw_skeleton(frame, keypoints_17)
|
||||
cv2.imshow("Ground Truth", frame)
|
||||
|
||||
# Stop CSI recording
|
||||
requests.post(f"{sensing_api}/api/v1/recording/stop")
|
||||
```
|
||||
|
||||
**MediaPipe → COCO keypoint mapping:**
|
||||
|
||||
| COCO Index | Joint | MediaPipe Index |
|
||||
|------------|-------|-----------------|
|
||||
| 0 | Nose | 0 |
|
||||
| 1 | Left Eye | 2 |
|
||||
| 2 | Right Eye | 5 |
|
||||
| 3 | Left Ear | 7 |
|
||||
| 4 | Right Ear | 8 |
|
||||
| 5 | Left Shoulder | 11 |
|
||||
| 6 | Right Shoulder | 12 |
|
||||
| 7 | Left Elbow | 13 |
|
||||
| 8 | Right Elbow | 14 |
|
||||
| 9 | Left Wrist | 15 |
|
||||
| 10 | Right Wrist | 16 |
|
||||
| 11 | Left Hip | 23 |
|
||||
| 12 | Right Hip | 24 |
|
||||
| 13 | Left Knee | 25 |
|
||||
| 14 | Right Knee | 26 |
|
||||
| 15 | Left Ankle | 27 |
|
||||
| 16 | Right Ankle | 28 |
|
||||
|
||||
### Component 2: Time Alignment (`scripts/align-ground-truth.js`)
|
||||
|
||||
CSI frames arrive at ~100 Hz with server-side timestamps. Camera keypoints arrive at
|
||||
~30 fps with client-side timestamps. Alignment is needed because:
|
||||
|
||||
1. Camera and sensing server clocks differ (typically < 50ms on LAN)
|
||||
2. CSI is aggregated into 20-frame windows for WiFlow input
|
||||
3. Ground-truth keypoints must be averaged over the same window
|
||||
|
||||
**Alignment algorithm:**
|
||||
|
||||
```
|
||||
For each CSI window W_i (20 frames, ~200ms at 100Hz):
|
||||
t_start = W_i.first_frame.timestamp
|
||||
t_end = W_i.last_frame.timestamp
|
||||
|
||||
# Find all camera keypoints within this time window
|
||||
matching_keypoints = [k for k in camera_data if t_start <= k.ts <= t_end]
|
||||
|
||||
if len(matching_keypoints) >= 3: # At least 3 camera frames per window
|
||||
# Average keypoints, weighted by confidence
|
||||
avg_keypoints = weighted_mean(matching_keypoints, weights=confidences)
|
||||
avg_confidence = mean(confidences)
|
||||
|
||||
paired_dataset.append({
|
||||
csi_window: W_i.amplitudes, # [128, 20] float32
|
||||
keypoints: avg_keypoints, # [17, 2] float32
|
||||
confidence: avg_confidence, # scalar
|
||||
n_camera_frames: len(matching_keypoints),
|
||||
})
|
||||
```
|
||||
|
||||
**Clock sync strategy:**
|
||||
|
||||
- NTP is sufficient (< 20ms error on LAN)
|
||||
- The 200ms CSI window is 10x larger than typical clock drift
|
||||
- For tighter sync: use a handclap/jump as a sync marker — visible spike in both
|
||||
CSI motion energy and camera skeleton velocity. Auto-detect and align.
|
||||
|
||||
**Output:** `data/recordings/paired-{timestamp}.jsonl` — one line per paired sample:
|
||||
```json
|
||||
{"csi": [128x20 flat], "kp": [[0.45,0.12], ...], "conf": 0.92, "ts": 1775300000000}
|
||||
```
|
||||
|
||||
### Component 3: Supervised Training (`scripts/train-wiflow-supervised.js`)
|
||||
|
||||
Extends the existing `train-ruvllm.js` pipeline with a supervised phase.
|
||||
|
||||
**Phase 1: Contrastive Pretrain (reuse ADR-072)**
|
||||
- Same as existing: temporal + cross-node triplets
|
||||
- Learns CSI representation without labels
|
||||
- 50 epochs, ~5 min on laptop
|
||||
|
||||
**Phase 2: Supervised Keypoint Regression (NEW)**
|
||||
- Load paired dataset from Component 2
|
||||
- Loss: confidence-weighted SmoothL1 on keypoints
|
||||
|
||||
```
|
||||
L_supervised = (1/N) * sum_i [ conf_i * SmoothL1(pred_i, gt_i, beta=0.05) ]
|
||||
```
|
||||
|
||||
- Only train on samples where `conf > 0.5` (discard frames where MediaPipe lost tracking)
|
||||
- Learning rate: 1e-4 with cosine decay
|
||||
- 200 epochs, ~15 min on laptop CPU (1.8M params, no GPU needed)
|
||||
|
||||
**Phase 3: Refinement with Bone Constraints**
|
||||
- Fine-tune with combined loss:
|
||||
|
||||
```
|
||||
L = L_supervised + 0.3 * L_bone + 0.1 * L_temporal
|
||||
|
||||
L_bone = (1/14) * sum_b (bone_len_b - prior_b)^2 # ADR-072 bone priors
|
||||
L_temporal = SmoothL1(kp_t, kp_{t-1}) # Temporal smoothness
|
||||
```
|
||||
|
||||
- 50 epochs at lower LR (1e-5)
|
||||
- Tighten bone constraint weight from 0.3 → 0.5 over epochs
|
||||
|
||||
**Phase 4: Quantization + Export**
|
||||
- Reuse ruvllm TurboQuant: float32 → int8 (4x smaller, ~881 KB)
|
||||
- Export via SafeTensors for cross-platform deployment
|
||||
- Validate quantized model PCK@20 within 2% of full-precision
|
||||
|
||||
### Component 4: Evaluation Script (`scripts/eval-wiflow.js`)
|
||||
|
||||
Measure actual PCK@20 using held-out paired data (20% split).
|
||||
|
||||
```
|
||||
PCK@k = (1/N) * sum_i [ (||pred_i - gt_i|| < k * torso_length) ? 1 : 0 ]
|
||||
```
|
||||
|
||||
**Metrics reported:**
|
||||
|
||||
| Metric | Description | Target |
|
||||
|--------|-------------|--------|
|
||||
| PCK@20 | % of keypoints within 20% torso length | > 35% |
|
||||
| PCK@50 | % within 50% torso length | > 60% |
|
||||
| MPJPE | Mean per-joint position error (pixels) | < 40px |
|
||||
| Per-joint PCK | Breakdown by joint (wrists are hardest) | Report all 17 |
|
||||
| Inference latency | Single window prediction time | < 50ms |
|
||||
|
||||
### Optimization Strategy
|
||||
|
||||
#### O1: Curriculum Learning
|
||||
|
||||
Train easy poses first, hard poses later:
|
||||
|
||||
| Stage | Epochs | Data Filter | Rationale |
|
||||
|-------|--------|-------------|-----------|
|
||||
| 1 | 50 | `conf > 0.9`, standing only | Establish stable skeleton baseline |
|
||||
| 2 | 50 | `conf > 0.7`, low motion | Add sitting, subtle movements |
|
||||
| 3 | 50 | `conf > 0.5`, all poses | Full dataset including occlusions |
|
||||
| 4 | 50 | All data, with augmentation | Robustness via noise injection |
|
||||
|
||||
#### O2: Data Augmentation (CSI domain)
|
||||
|
||||
Augment CSI windows to increase effective dataset size without collecting more data:
|
||||
|
||||
| Augmentation | Implementation | Expected Gain |
|
||||
|-------------|----------------|---------------|
|
||||
| Time shift | Roll CSI window by ±2 frames | +30% data |
|
||||
| Amplitude noise | Gaussian noise, sigma=0.02 | Robustness |
|
||||
| Subcarrier dropout | Zero 10% of subcarriers randomly | Robustness |
|
||||
| Temporal flip | Reverse window + reverse keypoint velocity | +100% data |
|
||||
| Multi-node mix | Swap node CSI, keep same-time keypoints | Cross-node generalization |
|
||||
|
||||
#### O3: Knowledge Distillation from MediaPipe
|
||||
|
||||
Instead of raw keypoint regression, distill MediaPipe's confidence and heatmap
|
||||
information:
|
||||
|
||||
```
|
||||
L_distill = KL_div(softmax(wifi_heatmap / T), softmax(camera_heatmap / T))
|
||||
```
|
||||
|
||||
- Temperature T=4 for soft targets (transfers inter-joint relationships)
|
||||
- WiFlow predicts a 17-channel heatmap [17, H, W] instead of direct [17, 2]
|
||||
- Argmax for final keypoint extraction
|
||||
- **Trade-off:** Adds ~200K params for heatmap decoder, but improves spatial precision
|
||||
|
||||
#### O4: Active Learning Loop
|
||||
|
||||
Identify which poses the model is worst at and collect more data for those:
|
||||
|
||||
```
|
||||
1. Train initial model on first collection session
|
||||
2. Run inference on new CSI data, compute prediction entropy
|
||||
3. Flag high-entropy windows (model is uncertain)
|
||||
4. During next collection, the preview overlay highlights these moments:
|
||||
"Hold this pose — model needs more examples"
|
||||
5. Re-train with augmented dataset
|
||||
```
|
||||
|
||||
Expected: 2-3 active learning iterations reach saturation.
|
||||
|
||||
#### O6: Subcarrier Selection (ruvector-solver)
|
||||
|
||||
Variance-based top-K subcarrier selection, equivalent to ruvector-solver's sparse
|
||||
interpolation (114→56). Removes noise/static subcarriers before training:
|
||||
|
||||
```
|
||||
For each subcarrier d in [0, dim):
|
||||
variance[d] = mean over samples of temporal_variance(csi[d, :])
|
||||
Select top-K by variance (K = dim * 0.5)
|
||||
```
|
||||
|
||||
**Validated:** 128 → 56 subcarriers (56% input reduction), proportional model size reduction.
|
||||
|
||||
#### O7: Attention-Weighted Subcarriers (ruvector-attention)
|
||||
|
||||
Compute per-subcarrier attention weights based on temporal energy correlation with
|
||||
ground-truth keypoint motion. High-energy subcarriers that covary with skeleton
|
||||
movement get amplified:
|
||||
|
||||
```
|
||||
For each subcarrier d:
|
||||
energy[d] = sum of squared first-differences over time
|
||||
weight[d] = softmax(energy, temperature=0.1)
|
||||
Apply: csi[d, :] *= weight[d] * dim (mean weight = 1)
|
||||
```
|
||||
|
||||
**Validated:** Top-5 attention subcarriers identified automatically per dataset.
|
||||
|
||||
#### O8: Stoer-Wagner MinCut Person Separation (ruvector-mincut / ADR-075)
|
||||
|
||||
JS implementation of the Stoer-Wagner algorithm for person separation in CSI, equivalent
|
||||
to `DynamicPersonMatcher` in `wifi-densepose-train/src/metrics.rs`. Builds a subcarrier
|
||||
correlation graph and finds the minimum cut to identify person-specific subcarrier clusters:
|
||||
|
||||
```
|
||||
1. Build dim×dim Pearson correlation matrix across subcarriers
|
||||
2. Run Stoer-Wagner min-cut on correlation graph
|
||||
3. Partition subcarriers into person-specific groups
|
||||
4. Train per-partition models for multi-person scenarios
|
||||
```
|
||||
|
||||
**Validated:** Stoer-Wagner executes on 56-dim graph, identifies partition boundaries.
|
||||
|
||||
#### O9: Multi-SPSA Gradient Estimation
|
||||
|
||||
Average over K=3 random perturbation directions per gradient step. Reduces variance
|
||||
by sqrt(K) = 1.73x compared to single SPSA, at 3x forward pass cost (net win for
|
||||
convergence quality):
|
||||
|
||||
```
|
||||
For k in 1..K:
|
||||
delta_k = random ±1 per parameter
|
||||
grad_k = (loss(w + eps*delta_k) - loss(w - eps*delta_k)) / (2*eps*delta_k)
|
||||
grad = mean(grad_1, ..., grad_K)
|
||||
```
|
||||
|
||||
#### O10: Mac M4 Pro Training via Tailscale
|
||||
|
||||
Training runs on Mac Mini M4 Pro (16-core GPU, ARM NEON SIMD) via Tailscale SSH,
|
||||
using ruvllm's native Node.js SIMD ops:
|
||||
|
||||
| | Windows (CPU) | Mac M4 Pro |
|
||||
|---|---|---|
|
||||
| Node.js | v24.12.0 (x86) | v25.9.0 (ARM) |
|
||||
| SIMD | SSE4/AVX2 | NEON |
|
||||
| Cores | Consumer laptop | 12P + 4E cores |
|
||||
| Training | Slow (minutes/epoch) | Fast (seconds/epoch) |
|
||||
|
||||
#### O5: Cross-Environment Transfer
|
||||
|
||||
Train on one room, deploy in another:
|
||||
|
||||
| Strategy | Implementation |
|
||||
|----------|---------------|
|
||||
| Room-invariant features | Normalize CSI by running mean/variance |
|
||||
| LoRA adapters | Train a 4-rank LoRA per room (ADR-071) — 7.3 KB each |
|
||||
| Few-shot calibration | 2 min of camera data in new room → fine-tune LoRA only |
|
||||
| AETHER embeddings | Use contrastive room-independent features (ADR-024) as input |
|
||||
|
||||
The LoRA approach is most practical: ship a base model + collect 2 min of calibration
|
||||
data per new room using the laptop camera.
|
||||
|
||||
### Data Collection Protocol
|
||||
|
||||
Recommended collection sessions per room:
|
||||
|
||||
| Session | Duration | Activity | People | Total CSI Frames |
|
||||
|---------|----------|----------|--------|-----------------|
|
||||
| 1. Baseline | 5 min | Empty + 1 person entry/exit | 0-1 | 30,000 |
|
||||
| 2. Standing poses | 5 min | Stand, arms up/down/sides, turn | 1 | 30,000 |
|
||||
| 3. Sitting | 5 min | Sit, type, lean, stand up/sit down | 1 | 30,000 |
|
||||
| 4. Walking | 5 min | Walk paths across room | 1 | 30,000 |
|
||||
| 5. Mixed | 5 min | Varied activities, transitions | 1 | 30,000 |
|
||||
| 6. Multi-person | 5 min | 2 people, varied activities | 2 | 30,000 |
|
||||
| **Total** | **30 min** | | | **180,000** |
|
||||
|
||||
At 20-frame windows: **9,000 paired training samples** per 30-min session.
|
||||
With augmentation (O2): **~27,000 effective samples**.
|
||||
|
||||
Camera placement: position laptop so the camera has a clear view of the sensing area.
|
||||
The camera FOV should cover the same space the ESP32 nodes cover.
|
||||
|
||||
### File Structure
|
||||
|
||||
```
|
||||
scripts/
|
||||
collect-ground-truth.py # Camera capture + MediaPipe + CSI sync
|
||||
align-ground-truth.js # Time-align CSI windows with camera keypoints
|
||||
train-wiflow-supervised.js # Supervised training pipeline
|
||||
eval-wiflow.js # PCK evaluation on held-out data
|
||||
|
||||
data/
|
||||
ground-truth/ # Raw camera keypoint captures
|
||||
gt-{timestamp}.jsonl
|
||||
paired/ # Aligned CSI + keypoint pairs
|
||||
paired-{timestamp}.jsonl
|
||||
|
||||
models/
|
||||
wiflow-supervised/ # Trained model outputs
|
||||
wiflow-v1.safetensors
|
||||
wiflow-v1-int8.safetensors
|
||||
training-log.json
|
||||
eval-report.json
|
||||
```
|
||||
|
||||
### Privacy Considerations
|
||||
|
||||
- Camera frames are processed **locally** by MediaPipe — no cloud upload
|
||||
- Raw video is **never saved** — only extracted keypoint coordinates are stored
|
||||
- The `.jsonl` ground-truth files contain only `[x,y]` joint coordinates, not images
|
||||
- The trained model runs on CSI only — no camera data leaves the laptop
|
||||
- Users can delete `data/ground-truth/` after training; the model is self-contained
|
||||
|
||||
## Consequences
|
||||
|
||||
### Positive
|
||||
|
||||
- **10-20x accuracy improvement**: PCK@20 from 2.5% → 35%+ with real supervision
|
||||
- **Reuses existing infrastructure**: sensing server recording API, ruvllm training, SafeTensors
|
||||
- **No new hardware**: laptop webcam + existing ESP32 nodes
|
||||
- **Privacy preserved at deployment**: camera only needed during 30-min training session
|
||||
- **Incremental**: can improve with more collection sessions + active learning
|
||||
- **Distributable**: trained model weights can be shared on HuggingFace (ADR-070)
|
||||
|
||||
### Negative
|
||||
|
||||
- **Camera placement matters**: must see the same area ESP32 nodes sense
|
||||
- **Single-room models**: need LoRA calibration per room (2 min + camera)
|
||||
- **MediaPipe limitations**: occlusion, side views, multiple people reduce keypoint quality
|
||||
- **Time sync**: NTP drift can misalign frames (mitigated by 200ms windows)
|
||||
|
||||
### Risks
|
||||
|
||||
| Risk | Probability | Impact | Mitigation |
|
||||
|------|-------------|--------|------------|
|
||||
| MediaPipe keypoints too noisy | Low | Medium | Filter by confidence; MediaPipe is robust indoors |
|
||||
| Clock drift > 100ms | Low | High | Add handclap sync marker detection |
|
||||
| Single camera can't see all poses | Medium | Medium | Position camera centrally; collect from 2 angles |
|
||||
| Model overfits to one room | High | Medium | LoRA adapters + AETHER normalization (O5) |
|
||||
| Insufficient data (< 5K pairs) | Low | High | Augmentation (O2) + active learning (O4) |
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
| Phase | Task | Effort | Status |
|
||||
|-------|------|--------|--------|
|
||||
| P1 | `collect-ground-truth.py` — camera + MediaPipe capture | 2 hrs | **Done** |
|
||||
| P2 | `align-ground-truth.js` — time alignment + pairing | 1 hr | **Done** |
|
||||
| P3 | `train-wiflow-supervised.js` — supervised training | 3 hrs | **Done** |
|
||||
| P4 | `eval-wiflow.js` — PCK evaluation | 1 hr | **Done** |
|
||||
| P5 | ruvector optimizations (O6-O9) | 2 hrs | **Done** |
|
||||
| P6 | Mac M4 Pro training via Tailscale (O10) | 1 hr | **Done** |
|
||||
| P7 | Data collection session (30 min recording) | 1 hr | Pending |
|
||||
| P8 | Training + evaluation on real paired data | 30 min | Pending |
|
||||
| P9 | LoRA cross-room calibration (O5) | 2 hrs | Pending |
|
||||
|
||||
## Validated Hardware
|
||||
|
||||
| Component | Spec | Validated |
|
||||
|-----------|------|-----------|
|
||||
| Mac Mini camera | 1920x1080, 30fps | Yes — 14/17 keypoints, conf 0.94-1.0 |
|
||||
| MediaPipe PoseLandmarker | v0.10.33 Tasks API, lite model | Yes — via Tailscale SSH |
|
||||
| Mac M4 Pro GPU | 16-core, Metal 4, NEON SIMD | Yes — Node.js v25.9.0 |
|
||||
| Tailscale SSH | LAN-accessible Mac, passwordless | Yes |
|
||||
| ESP32-S3 CSI | 128 subcarriers, 100Hz | Yes — existing recordings |
|
||||
| Sensing server recording API | `/api/v1/recording/start\|stop` | Yes — existing |
|
||||
|
||||
## Baseline Benchmark
|
||||
|
||||
Proxy-pose baseline (no camera supervision, standing skeleton heuristic):
|
||||
|
||||
```
|
||||
PCK@10: 11.8%
|
||||
PCK@20: 35.3%
|
||||
PCK@50: 94.1%
|
||||
MPJPE: 0.067
|
||||
Latency: 0.03ms/sample
|
||||
```
|
||||
|
||||
Per-joint PCK@20: upper body (nose, shoulders, wrists) at 0% — proxy has no spatial
|
||||
accuracy for these. Camera supervision targets these joints specifically.
|
||||
|
||||
## References
|
||||
|
||||
- WiFlow: arXiv:2602.08661 — WiFi-based pose estimation with TCN + axial attention
|
||||
- Wi-Pose (CVPR 2021) — 3D CNN WiFi pose with camera supervision
|
||||
- Person-in-WiFi 3D (CVPR 2024) — Deformable attention with camera labels
|
||||
- MediaPipe Pose — Google's real-time 33-landmark body pose estimator
|
||||
- MetaFi++ (NeurIPS 2023) — Meta-learning cross-modal WiFi sensing
|
||||
@@ -1055,6 +1055,82 @@ See [ADR-071](adr/ADR-071-ruvllm-training-pipeline.md) and the [pretraining tuto
|
||||
|
||||
---
|
||||
|
||||
## Pre-Trained Models (No Training Required)
|
||||
|
||||
Pre-trained models are available on HuggingFace: **https://huggingface.co/ruvnet/wifi-densepose-pretrained**
|
||||
|
||||
Download and start sensing immediately — no datasets, no GPU, no training needed.
|
||||
|
||||
### Quick Start with Pre-Trained Models
|
||||
|
||||
```bash
|
||||
# Install huggingface CLI
|
||||
pip install huggingface_hub
|
||||
|
||||
# Download all models
|
||||
huggingface-cli download ruvnet/wifi-densepose-pretrained --local-dir models/pretrained
|
||||
|
||||
# The models include:
|
||||
# model.safetensors — 48 KB contrastive encoder
|
||||
# model-q4.bin — 8 KB quantized (recommended)
|
||||
# model-q2.bin — 4 KB ultra-compact (ESP32 edge)
|
||||
# presence-head.json — presence detection head (100% accuracy)
|
||||
# node-1.json — LoRA adapter for room 1
|
||||
# node-2.json — LoRA adapter for room 2
|
||||
```
|
||||
|
||||
### What the Models Do
|
||||
|
||||
The pre-trained encoder converts 8-dim CSI feature vectors into 128-dim embeddings. These embeddings power all 17 sensing applications:
|
||||
|
||||
- **Presence detection** — 100% accuracy, never misses, never false alarms
|
||||
- **Environment fingerprinting** — kNN search finds "states like this one"
|
||||
- **Anomaly detection** — embeddings that don't match known clusters = anomaly
|
||||
- **Activity classification** — different activities cluster in embedding space
|
||||
- **Room adaptation** — swap LoRA adapters for different rooms without retraining
|
||||
|
||||
### Retraining on Your Own Data
|
||||
|
||||
If you want to improve accuracy for your specific environment:
|
||||
|
||||
```bash
|
||||
# Collect 2+ minutes of CSI from your ESP32
|
||||
python scripts/collect-training-data.py --port 5006 --duration 120
|
||||
|
||||
# Retrain (uses ruvllm, no PyTorch needed)
|
||||
node scripts/train-ruvllm.js --data data/recordings/*.csi.jsonl
|
||||
|
||||
# Benchmark your retrained model
|
||||
node scripts/benchmark-ruvllm.js --model models/csi-ruvllm
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Health & Wellness Applications
|
||||
|
||||
WiFi sensing can monitor health metrics without any wearable or camera:
|
||||
|
||||
```bash
|
||||
# Sleep quality monitoring (run overnight)
|
||||
node scripts/sleep-monitor.js --port 5006 --bind 192.168.1.20
|
||||
|
||||
# Breathing disorder pre-screening
|
||||
node scripts/apnea-detector.js --port 5006 --bind 192.168.1.20
|
||||
|
||||
# Stress detection via heart rate variability
|
||||
node scripts/stress-monitor.js --port 5006 --bind 192.168.1.20
|
||||
|
||||
# Walking analysis + tremor detection
|
||||
node scripts/gait-analyzer.js --port 5006 --bind 192.168.1.20
|
||||
|
||||
# Replay on recorded data (no live hardware needed)
|
||||
node scripts/sleep-monitor.js --replay data/recordings/*.csi.jsonl
|
||||
```
|
||||
|
||||
> **Note:** These are pre-screening tools, not medical devices. Consult a healthcare professional for diagnosis.
|
||||
|
||||
---
|
||||
|
||||
## ruvllm Training Pipeline
|
||||
|
||||
All training uses **ruvllm** — a Rust-native ML runtime. No Python, no PyTorch, no GPU drivers required. Runs on any machine with Node.js.
|
||||
|
||||
@@ -4,5 +4,10 @@ cmake_minimum_required(VERSION 3.16)
|
||||
|
||||
set(EXTRA_COMPONENT_DIRS "")
|
||||
|
||||
# Read firmware version from version.txt so esp_app_get_description()->version
|
||||
# matches the release tag. Fixes issue #354 (version mismatch after flashing).
|
||||
file(STRINGS "${CMAKE_CURRENT_LIST_DIR}/version.txt" PROJECT_VER LIMIT_COUNT 1)
|
||||
string(STRIP "${PROJECT_VER}" PROJECT_VER)
|
||||
|
||||
include($ENV{IDF_PATH}/tools/cmake/project.cmake)
|
||||
project(esp32-csi-node)
|
||||
project(esp32-csi-node VERSION ${PROJECT_VER})
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
@echo off
|
||||
echo STARTING > C:\Users\ruv\idf_test.txt
|
||||
set IDF_PATH=C:\Users\ruv\esp\v5.4\esp-idf
|
||||
set PATH=C:\Espressif\tools\python\v5.4\venv\Scripts;C:\Espressif\tools\xtensa-esp-elf\esp-14.2.0_20241119\xtensa-esp-elf\bin;C:\Espressif\tools\cmake\3.30.2\bin;C:\Espressif\tools\ninja\1.12.1;C:\Espressif\tools\idf-exe\1.0.3;%PATH%
|
||||
echo PATH_SET >> C:\Users\ruv\idf_test.txt
|
||||
cd /d C:\Users\ruv\Projects\wifi-densepose\firmware\esp32-csi-node
|
||||
echo CD_DONE >> C:\Users\ruv\idf_test.txt
|
||||
python %IDF_PATH%\tools\idf.py build >> C:\Users\ruv\idf_test.txt 2>&1
|
||||
echo RC=%ERRORLEVEL% >> C:\Users\ruv\idf_test.txt
|
||||
@@ -76,7 +76,6 @@ menu "Edge Intelligence (ADR-039)"
|
||||
Raise to reduce false positives in high-traffic environments.
|
||||
Normal walking produces accelerations of 2-5 rad/s².
|
||||
Stored as integer; divided by 1000 at runtime.
|
||||
Default 2000 = 2.0 rad/s^2.
|
||||
|
||||
config EDGE_POWER_DUTY
|
||||
int "Power duty cycle percentage"
|
||||
|
||||
@@ -118,8 +118,14 @@ esp_err_t display_task_start(void)
|
||||
if (!buf1 || !buf2) {
|
||||
ESP_LOGE(TAG, "Failed to allocate LVGL buffers (%u bytes, caps=0x%lx)",
|
||||
(unsigned)buf_size, (unsigned long)alloc_caps);
|
||||
if (buf1) free(buf1);
|
||||
if (buf2) free(buf2);
|
||||
if (buf1) {
|
||||
free(buf1);
|
||||
buf1 = NULL;
|
||||
}
|
||||
if (buf2) {
|
||||
free(buf2);
|
||||
buf2 = NULL;
|
||||
}
|
||||
return ESP_OK;
|
||||
}
|
||||
ESP_LOGI(TAG, "LVGL buffers: 2x %u bytes (%u lines, %s)",
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "esp_event.h"
|
||||
#include "esp_log.h"
|
||||
#include "nvs_flash.h"
|
||||
#include "esp_app_desc.h"
|
||||
#include "sdkconfig.h"
|
||||
|
||||
#include "csi_collector.h"
|
||||
@@ -137,7 +138,9 @@ void app_main(void)
|
||||
/* Load runtime config (NVS overrides Kconfig defaults) */
|
||||
nvs_config_load(&g_nvs_config);
|
||||
|
||||
ESP_LOGI(TAG, "ESP32-S3 CSI Node (ADR-018) — Node ID: %d", g_nvs_config.node_id);
|
||||
const esp_app_desc_t *app_desc = esp_app_get_description();
|
||||
ESP_LOGI(TAG, "ESP32-S3 CSI Node (ADR-018) — v%s — Node ID: %d",
|
||||
app_desc->version, g_nvs_config.node_id);
|
||||
|
||||
/* Initialize WiFi STA (skip entirely under QEMU mock — no RF hardware) */
|
||||
#ifndef CONFIG_CSI_MOCK_SKIP_WIFI_CONNECT
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
0.6.0
|
||||
Generated
+1
@@ -7769,6 +7769,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"clap",
|
||||
"futures-util",
|
||||
"ruvector-mincut",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
|
||||
@@ -330,9 +330,36 @@ impl<B: Backend> InferenceEngine<B> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Run batched inference
|
||||
/// Run batched inference.
|
||||
///
|
||||
/// Stacks all inputs along a new batch dimension, runs a single
|
||||
/// backend call, then splits the output back into individual tensors.
|
||||
/// Falls back to sequential inference if stack/split fails.
|
||||
pub fn infer_batch(&self, inputs: &[Tensor]) -> NnResult<Vec<Tensor>> {
|
||||
inputs.iter().map(|input| self.infer(input)).collect()
|
||||
if inputs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
if inputs.len() == 1 {
|
||||
return Ok(vec![self.infer(&inputs[0])?]);
|
||||
}
|
||||
// Try batched path: stack -> single call -> split
|
||||
match Tensor::stack(inputs) {
|
||||
Ok(batched_input) => {
|
||||
let n = inputs.len();
|
||||
let batched_output = self.backend.run_single(&batched_input)?;
|
||||
match batched_output.split(n) {
|
||||
Ok(outputs) => Ok(outputs),
|
||||
Err(_) => {
|
||||
// Fallback: sequential
|
||||
inputs.iter().map(|input| self.infer(input)).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Fallback: sequential if shapes are incompatible
|
||||
inputs.iter().map(|input| self.infer(input)).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get inference statistics
|
||||
|
||||
@@ -304,6 +304,74 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Stack multiple tensors along a new batch dimension (dim 0).
|
||||
///
|
||||
/// All tensors must have the same shape. The result has one extra
|
||||
/// leading dimension equal to `tensors.len()`.
|
||||
pub fn stack(tensors: &[Tensor]) -> NnResult<Tensor> {
|
||||
if tensors.is_empty() {
|
||||
return Err(NnError::tensor_op("Cannot stack zero tensors"));
|
||||
}
|
||||
let first_shape = tensors[0].shape();
|
||||
for (i, t) in tensors.iter().enumerate().skip(1) {
|
||||
if t.shape() != first_shape {
|
||||
return Err(NnError::tensor_op(&format!(
|
||||
"Shape mismatch at index {i}: expected {first_shape}, got {}",
|
||||
t.shape()
|
||||
)));
|
||||
}
|
||||
}
|
||||
let mut all_data: Vec<f32> = Vec::with_capacity(tensors.len() * first_shape.numel());
|
||||
for t in tensors {
|
||||
let data = t.to_vec()?;
|
||||
all_data.extend_from_slice(&data);
|
||||
}
|
||||
let mut new_dims = vec![tensors.len()];
|
||||
new_dims.extend_from_slice(first_shape.dims());
|
||||
let arr = ndarray::ArrayD::from_shape_vec(
|
||||
ndarray::IxDyn(&new_dims),
|
||||
all_data,
|
||||
)
|
||||
.map_err(|e| NnError::tensor_op(&format!("Stack reshape failed: {e}")))?;
|
||||
Ok(Tensor::FloatND(arr))
|
||||
}
|
||||
|
||||
/// Split a tensor along dim 0 into `n` sub-tensors.
|
||||
///
|
||||
/// The first dimension must be evenly divisible by `n`.
|
||||
pub fn split(self, n: usize) -> NnResult<Vec<Tensor>> {
|
||||
if n == 0 {
|
||||
return Err(NnError::tensor_op("Cannot split into 0 pieces"));
|
||||
}
|
||||
let shape = self.shape();
|
||||
let batch = shape.dim(0).ok_or_else(|| NnError::tensor_op("Tensor has no dimensions"))?;
|
||||
if batch % n != 0 {
|
||||
return Err(NnError::tensor_op(&format!(
|
||||
"Batch dim {batch} not divisible by {n}"
|
||||
)));
|
||||
}
|
||||
let chunk_size = batch / n;
|
||||
let data = self.to_vec()?;
|
||||
let elem_per_sample = shape.numel() / batch;
|
||||
let sub_dims: Vec<usize> = {
|
||||
let mut d = shape.dims().to_vec();
|
||||
d[0] = chunk_size;
|
||||
d
|
||||
};
|
||||
let mut result = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
let start = i * chunk_size * elem_per_sample;
|
||||
let end = start + chunk_size * elem_per_sample;
|
||||
let arr = ndarray::ArrayD::from_shape_vec(
|
||||
ndarray::IxDyn(&sub_dims),
|
||||
data[start..end].to_vec(),
|
||||
)
|
||||
.map_err(|e| NnError::tensor_op(&format!("Split reshape failed: {e}")))?;
|
||||
result.push(Tensor::FloatND(arr));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Compute standard deviation
|
||||
pub fn std(&self) -> NnResult<f32> {
|
||||
match self {
|
||||
|
||||
@@ -43,8 +43,8 @@ clap = { workspace = true }
|
||||
# Multi-BSSID WiFi scanning pipeline (ADR-022 Phase 3)
|
||||
wifi-densepose-wifiscan = { version = "0.3.0", path = "../wifi-densepose-wifiscan" }
|
||||
|
||||
# RuVector graph min-cut for person separation (ADR-068)
|
||||
ruvector-mincut = { workspace = true }
|
||||
# Signal processing with RuvSense pose tracker (accuracy sprint)
|
||||
wifi-densepose-signal = { version = "0.3.0", path = "../wifi-densepose-signal" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.10"
|
||||
|
||||
+120
-60
@@ -10,6 +10,10 @@
|
||||
//!
|
||||
//! The trained model is serialised as JSON and hot-loaded at runtime so that
|
||||
//! the classification thresholds adapt to the specific room and ESP32 placement.
|
||||
//!
|
||||
//! Classes are discovered dynamically from training data filenames instead of
|
||||
//! being hardcoded, so new activity classes can be added just by recording data
|
||||
//! with the appropriate filename convention.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@@ -20,9 +24,8 @@ use std::path::{Path, PathBuf};
|
||||
/// Extended feature vector: 7 server features + 8 subcarrier-derived features = 15.
|
||||
const N_FEATURES: usize = 15;
|
||||
|
||||
/// Activity classes we recognise.
|
||||
pub const CLASSES: &[&str] = &["absent", "present_still", "present_moving", "active"];
|
||||
const N_CLASSES: usize = 4;
|
||||
/// Default class names for backward compatibility with old saved models.
|
||||
const DEFAULT_CLASSES: &[&str] = &["absent", "present_still", "present_moving", "active"];
|
||||
|
||||
/// Extract extended feature vector from a JSONL frame (features + raw amplitudes).
|
||||
pub fn features_from_frame(frame: &serde_json::Value) -> [f64; N_FEATURES] {
|
||||
@@ -124,8 +127,9 @@ pub struct ClassStats {
|
||||
pub struct AdaptiveModel {
|
||||
/// Per-class feature statistics (centroid + spread).
|
||||
pub class_stats: Vec<ClassStats>,
|
||||
/// Logistic regression weights: [N_CLASSES x (N_FEATURES + 1)] (last = bias).
|
||||
pub weights: Vec<[f64; N_FEATURES + 1]>,
|
||||
/// Logistic regression weights: [n_classes x (N_FEATURES + 1)] (last = bias).
|
||||
/// Dynamic: the outer Vec length equals the number of discovered classes.
|
||||
pub weights: Vec<Vec<f64>>,
|
||||
/// Global feature normalisation: mean and stddev across all training data.
|
||||
pub global_mean: [f64; N_FEATURES],
|
||||
pub global_std: [f64; N_FEATURES],
|
||||
@@ -133,27 +137,38 @@ pub struct AdaptiveModel {
|
||||
pub trained_frames: usize,
|
||||
pub training_accuracy: f64,
|
||||
pub version: u32,
|
||||
/// Dynamically discovered class names (in index order).
|
||||
#[serde(default = "default_class_names")]
|
||||
pub class_names: Vec<String>,
|
||||
}
|
||||
|
||||
/// Backward-compatible fallback for models saved without class_names.
|
||||
fn default_class_names() -> Vec<String> {
|
||||
DEFAULT_CLASSES.iter().map(|s| s.to_string()).collect()
|
||||
}
|
||||
|
||||
impl Default for AdaptiveModel {
|
||||
fn default() -> Self {
|
||||
let n_classes = DEFAULT_CLASSES.len();
|
||||
Self {
|
||||
class_stats: Vec::new(),
|
||||
weights: vec![[0.0; N_FEATURES + 1]; N_CLASSES],
|
||||
weights: vec![vec![0.0; N_FEATURES + 1]; n_classes],
|
||||
global_mean: [0.0; N_FEATURES],
|
||||
global_std: [1.0; N_FEATURES],
|
||||
trained_frames: 0,
|
||||
training_accuracy: 0.0,
|
||||
version: 1,
|
||||
class_names: default_class_names(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveModel {
|
||||
/// Classify a raw feature vector. Returns (class_label, confidence).
|
||||
pub fn classify(&self, raw_features: &[f64; N_FEATURES]) -> (&'static str, f64) {
|
||||
if self.weights.is_empty() || self.class_stats.is_empty() {
|
||||
return ("present_still", 0.5);
|
||||
pub fn classify(&self, raw_features: &[f64; N_FEATURES]) -> (String, f64) {
|
||||
let n_classes = self.weights.len();
|
||||
if n_classes == 0 || self.class_stats.is_empty() {
|
||||
return ("present_still".to_string(), 0.5);
|
||||
}
|
||||
|
||||
// Normalise features.
|
||||
@@ -163,8 +178,8 @@ impl AdaptiveModel {
|
||||
}
|
||||
|
||||
// Compute logits: w·x + b for each class.
|
||||
let mut logits = [0.0f64; N_CLASSES];
|
||||
for c in 0..N_CLASSES.min(self.weights.len()) {
|
||||
let mut logits: Vec<f64> = vec![0.0; n_classes];
|
||||
for c in 0..n_classes {
|
||||
let w = &self.weights[c];
|
||||
let mut z = w[N_FEATURES]; // bias
|
||||
for i in 0..N_FEATURES {
|
||||
@@ -176,8 +191,8 @@ impl AdaptiveModel {
|
||||
// Softmax.
|
||||
let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
|
||||
let exp_sum: f64 = logits.iter().map(|z| (z - max_logit).exp()).sum();
|
||||
let mut probs = [0.0f64; N_CLASSES];
|
||||
for c in 0..N_CLASSES {
|
||||
let mut probs: Vec<f64> = vec![0.0; n_classes];
|
||||
for c in 0..n_classes {
|
||||
probs[c] = ((logits[c] - max_logit).exp()) / exp_sum;
|
||||
}
|
||||
|
||||
@@ -185,7 +200,11 @@ impl AdaptiveModel {
|
||||
let (best_c, best_p) = probs.iter().enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.unwrap();
|
||||
let label = if best_c < CLASSES.len() { CLASSES[best_c] } else { "present_still" };
|
||||
let label = if best_c < self.class_names.len() {
|
||||
self.class_names[best_c].clone()
|
||||
} else {
|
||||
"present_still".to_string()
|
||||
};
|
||||
(label, *best_p)
|
||||
}
|
||||
|
||||
@@ -228,48 +247,88 @@ fn load_recording(path: &Path, class_idx: usize) -> Vec<Sample> {
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// Map a recording filename to a class index.
|
||||
fn classify_recording_name(name: &str) -> Option<usize> {
|
||||
/// Map a recording filename to a class name (String).
|
||||
/// Returns the discovered class name for the file, or None if it cannot be determined.
|
||||
fn classify_recording_name(name: &str) -> Option<String> {
|
||||
let lower = name.to_lowercase();
|
||||
if lower.contains("empty") || lower.contains("absent") { Some(0) }
|
||||
else if lower.contains("still") || lower.contains("sitting") || lower.contains("standing") { Some(1) }
|
||||
else if lower.contains("walking") || lower.contains("moving") { Some(2) }
|
||||
else if lower.contains("active") || lower.contains("exercise") || lower.contains("running") { Some(3) }
|
||||
else { None }
|
||||
// Strip "train_" prefix and ".jsonl" suffix, then extract the class label.
|
||||
// Convention: train_<class>_<description>.jsonl
|
||||
// The class is the first segment after "train_" that matches a known pattern,
|
||||
// or the entire middle portion if no pattern matches.
|
||||
|
||||
// Check common patterns first for backward compat
|
||||
if lower.contains("empty") || lower.contains("absent") { return Some("absent".into()); }
|
||||
if lower.contains("still") || lower.contains("sitting") || lower.contains("standing") { return Some("present_still".into()); }
|
||||
if lower.contains("walking") || lower.contains("moving") { return Some("present_moving".into()); }
|
||||
if lower.contains("active") || lower.contains("exercise") || lower.contains("running") { return Some("active".into()); }
|
||||
|
||||
// Fallback: extract class from filename structure train_<class>_*.jsonl
|
||||
let stem = lower.trim_start_matches("train_").trim_end_matches(".jsonl");
|
||||
let class_name = stem.split('_').next().unwrap_or(stem);
|
||||
if !class_name.is_empty() {
|
||||
Some(class_name.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Train a model from labeled JSONL recordings in a directory.
|
||||
///
|
||||
/// Recordings are matched to classes by filename pattern:
|
||||
/// - `*empty*` / `*absent*` → absent (0)
|
||||
/// - `*still*` / `*sitting*` → present_still (1)
|
||||
/// - `*walking*` / `*moving*` → present_moving (2)
|
||||
/// - `*active*` / `*exercise*`→ active (3)
|
||||
/// Recordings are matched to classes by filename pattern. Classes are discovered
|
||||
/// dynamically from the training data filenames:
|
||||
/// - `*empty*` / `*absent*` → absent
|
||||
/// - `*still*` / `*sitting*` → present_still
|
||||
/// - `*walking*` / `*moving*` → present_moving
|
||||
/// - `*active*` / `*exercise*`→ active
|
||||
/// - Any other `train_<class>_*.jsonl` → <class>
|
||||
pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, String> {
|
||||
// Scan for train_* files.
|
||||
let mut samples: Vec<Sample> = Vec::new();
|
||||
let entries = std::fs::read_dir(recordings_dir)
|
||||
.map_err(|e| format!("Cannot read {}: {}", recordings_dir.display(), e))?;
|
||||
// First pass: scan filenames to discover all unique class names.
|
||||
let entries: Vec<_> = std::fs::read_dir(recordings_dir)
|
||||
.map_err(|e| format!("Cannot read {}: {}", recordings_dir.display(), e))?
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let mut class_map: HashMap<String, usize> = HashMap::new();
|
||||
let mut class_names: Vec<String> = Vec::new();
|
||||
|
||||
// Collect (entry, class_name) pairs for files that match.
|
||||
let mut file_classes: Vec<(PathBuf, String, String)> = Vec::new(); // (path, fname, class_name)
|
||||
for entry in &entries {
|
||||
let fname = entry.file_name().to_string_lossy().to_string();
|
||||
if !fname.starts_with("train_") || !fname.ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
if let Some(class_idx) = classify_recording_name(&fname) {
|
||||
let loaded = load_recording(&entry.path(), class_idx);
|
||||
eprintln!(" Loaded {}: {} frames → class '{}'",
|
||||
fname, loaded.len(), CLASSES[class_idx]);
|
||||
samples.extend(loaded);
|
||||
if let Some(class_name) = classify_recording_name(&fname) {
|
||||
if !class_map.contains_key(&class_name) {
|
||||
let idx = class_names.len();
|
||||
class_map.insert(class_name.clone(), idx);
|
||||
class_names.push(class_name.clone());
|
||||
}
|
||||
file_classes.push((entry.path(), fname, class_name));
|
||||
}
|
||||
}
|
||||
|
||||
let n_classes = class_names.len();
|
||||
if n_classes == 0 {
|
||||
return Err("No training samples found. Record data with train_* prefix.".into());
|
||||
}
|
||||
|
||||
// Second pass: load recordings with the discovered class indices.
|
||||
let mut samples: Vec<Sample> = Vec::new();
|
||||
for (path, fname, class_name) in &file_classes {
|
||||
let class_idx = class_map[class_name];
|
||||
let loaded = load_recording(path, class_idx);
|
||||
eprintln!(" Loaded {}: {} frames → class '{}'",
|
||||
fname, loaded.len(), class_name);
|
||||
samples.extend(loaded);
|
||||
}
|
||||
|
||||
if samples.is_empty() {
|
||||
return Err("No training samples found. Record data with train_* prefix.".into());
|
||||
}
|
||||
|
||||
let n = samples.len();
|
||||
eprintln!("Total training samples: {n}");
|
||||
eprintln!("Total training samples: {n} across {n_classes} classes: {:?}", class_names);
|
||||
|
||||
// ── Compute global normalisation stats ──
|
||||
let mut global_mean = [0.0f64; N_FEATURES];
|
||||
@@ -289,9 +348,9 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
}
|
||||
|
||||
// ── Compute per-class statistics ──
|
||||
let mut class_sums = vec![[0.0f64; N_FEATURES]; N_CLASSES];
|
||||
let mut class_sq = vec![[0.0f64; N_FEATURES]; N_CLASSES];
|
||||
let mut class_counts = vec![0usize; N_CLASSES];
|
||||
let mut class_sums = vec![[0.0f64; N_FEATURES]; n_classes];
|
||||
let mut class_sq = vec![[0.0f64; N_FEATURES]; n_classes];
|
||||
let mut class_counts = vec![0usize; n_classes];
|
||||
for s in &samples {
|
||||
let c = s.class_idx;
|
||||
class_counts[c] += 1;
|
||||
@@ -302,7 +361,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
}
|
||||
|
||||
let mut class_stats = Vec::new();
|
||||
for c in 0..N_CLASSES {
|
||||
for c in 0..n_classes {
|
||||
let cnt = class_counts[c].max(1) as f64;
|
||||
let mut mean = [0.0; N_FEATURES];
|
||||
let mut stddev = [0.0; N_FEATURES];
|
||||
@@ -311,7 +370,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
stddev[i] = ((class_sq[c][i] / cnt) - mean[i] * mean[i]).max(0.0).sqrt();
|
||||
}
|
||||
class_stats.push(ClassStats {
|
||||
label: CLASSES[c].to_string(),
|
||||
label: class_names[c].clone(),
|
||||
count: class_counts[c],
|
||||
mean,
|
||||
stddev,
|
||||
@@ -328,7 +387,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
}).collect();
|
||||
|
||||
// ── Train logistic regression via mini-batch SGD ──
|
||||
let mut weights = vec![[0.0f64; N_FEATURES + 1]; N_CLASSES];
|
||||
let mut weights: Vec<Vec<f64>> = vec![vec![0.0f64; N_FEATURES + 1]; n_classes];
|
||||
let lr = 0.1;
|
||||
let epochs = 200;
|
||||
let batch_size = 32;
|
||||
@@ -348,19 +407,19 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
}
|
||||
|
||||
let mut epoch_loss = 0.0f64;
|
||||
let mut batch_count = 0;
|
||||
let mut _batch_count = 0;
|
||||
|
||||
for batch_start in (0..norm_samples.len()).step_by(batch_size) {
|
||||
let batch_end = (batch_start + batch_size).min(norm_samples.len());
|
||||
let batch = &norm_samples[batch_start..batch_end];
|
||||
|
||||
// Accumulate gradients.
|
||||
let mut grad = vec![[0.0f64; N_FEATURES + 1]; N_CLASSES];
|
||||
let mut grad: Vec<Vec<f64>> = vec![vec![0.0f64; N_FEATURES + 1]; n_classes];
|
||||
|
||||
for (x, target) in batch {
|
||||
// Forward: softmax.
|
||||
let mut logits = [0.0f64; N_CLASSES];
|
||||
for c in 0..N_CLASSES {
|
||||
let mut logits: Vec<f64> = vec![0.0; n_classes];
|
||||
for c in 0..n_classes {
|
||||
logits[c] = weights[c][N_FEATURES]; // bias
|
||||
for i in 0..N_FEATURES {
|
||||
logits[c] += weights[c][i] * x[i];
|
||||
@@ -368,8 +427,8 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
}
|
||||
let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
|
||||
let exp_sum: f64 = logits.iter().map(|z| (z - max_l).exp()).sum();
|
||||
let mut probs = [0.0f64; N_CLASSES];
|
||||
for c in 0..N_CLASSES {
|
||||
let mut probs: Vec<f64> = vec![0.0; n_classes];
|
||||
for c in 0..n_classes {
|
||||
probs[c] = ((logits[c] - max_l).exp()) / exp_sum;
|
||||
}
|
||||
|
||||
@@ -377,7 +436,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
epoch_loss += -(probs[*target].max(1e-15)).ln();
|
||||
|
||||
// Gradient: prob - one_hot(target).
|
||||
for c in 0..N_CLASSES {
|
||||
for c in 0..n_classes {
|
||||
let delta = probs[c] - if c == *target { 1.0 } else { 0.0 };
|
||||
for i in 0..N_FEATURES {
|
||||
grad[c][i] += delta * x[i];
|
||||
@@ -389,12 +448,12 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
// Update weights.
|
||||
let bs = batch.len() as f64;
|
||||
let current_lr = lr * (1.0 - epoch as f64 / epochs as f64); // linear decay
|
||||
for c in 0..N_CLASSES {
|
||||
for c in 0..n_classes {
|
||||
for i in 0..=N_FEATURES {
|
||||
weights[c][i] -= current_lr * grad[c][i] / bs;
|
||||
}
|
||||
}
|
||||
batch_count += 1;
|
||||
_batch_count += 1;
|
||||
}
|
||||
|
||||
if epoch % 50 == 0 || epoch == epochs - 1 {
|
||||
@@ -406,8 +465,8 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
// ── Evaluate accuracy ──
|
||||
let mut correct = 0;
|
||||
for (x, target) in &norm_samples {
|
||||
let mut logits = [0.0f64; N_CLASSES];
|
||||
for c in 0..N_CLASSES {
|
||||
let mut logits: Vec<f64> = vec![0.0; n_classes];
|
||||
for c in 0..n_classes {
|
||||
logits[c] = weights[c][N_FEATURES];
|
||||
for i in 0..N_FEATURES {
|
||||
logits[c] += weights[c][i] * x[i];
|
||||
@@ -422,12 +481,12 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
eprintln!("Training accuracy: {correct}/{n} = {accuracy:.1}%");
|
||||
|
||||
// ── Per-class accuracy ──
|
||||
let mut class_correct = vec![0usize; N_CLASSES];
|
||||
let mut class_total = vec![0usize; N_CLASSES];
|
||||
let mut class_correct = vec![0usize; n_classes];
|
||||
let mut class_total = vec![0usize; n_classes];
|
||||
for (x, target) in &norm_samples {
|
||||
class_total[*target] += 1;
|
||||
let mut logits = [0.0f64; N_CLASSES];
|
||||
for c in 0..N_CLASSES {
|
||||
let mut logits: Vec<f64> = vec![0.0; n_classes];
|
||||
for c in 0..n_classes {
|
||||
logits[c] = weights[c][N_FEATURES];
|
||||
for i in 0..N_FEATURES {
|
||||
logits[c] += weights[c][i] * x[i];
|
||||
@@ -438,9 +497,9 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
.unwrap().0;
|
||||
if pred == *target { class_correct[*target] += 1; }
|
||||
}
|
||||
for c in 0..N_CLASSES {
|
||||
for c in 0..n_classes {
|
||||
let tot = class_total[c].max(1);
|
||||
eprintln!(" {}: {}/{} ({:.0}%)", CLASSES[c], class_correct[c], tot,
|
||||
eprintln!(" {}: {}/{} ({:.0}%)", class_names[c], class_correct[c], tot,
|
||||
class_correct[c] as f64 / tot as f64 * 100.0);
|
||||
}
|
||||
|
||||
@@ -452,6 +511,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
|
||||
trained_frames: n,
|
||||
training_accuracy: accuracy,
|
||||
version: 1,
|
||||
class_names,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
//! CLI argument definitions and early-exit mode handlers.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use clap::Parser;
|
||||
|
||||
/// CLI arguments for the sensing server.
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "sensing-server", about = "WiFi-DensePose sensing server")]
|
||||
pub struct Args {
|
||||
/// HTTP port for UI and REST API
|
||||
#[arg(long, default_value = "8080")]
|
||||
pub http_port: u16,
|
||||
|
||||
/// WebSocket port for sensing stream
|
||||
#[arg(long, default_value = "8765")]
|
||||
pub ws_port: u16,
|
||||
|
||||
/// UDP port for ESP32 CSI frames
|
||||
#[arg(long, default_value = "5005")]
|
||||
pub udp_port: u16,
|
||||
|
||||
/// Path to UI static files
|
||||
#[arg(long, default_value = "../../ui")]
|
||||
pub ui_path: PathBuf,
|
||||
|
||||
/// Tick interval in milliseconds (default 100 ms = 10 fps for smooth pose animation)
|
||||
#[arg(long, default_value = "100")]
|
||||
pub tick_ms: u64,
|
||||
|
||||
/// Bind address (default 127.0.0.1; set to 0.0.0.0 for network access)
|
||||
#[arg(long, default_value = "127.0.0.1", env = "SENSING_BIND_ADDR")]
|
||||
pub bind_addr: String,
|
||||
|
||||
/// Data source: auto, wifi, esp32, simulate
|
||||
#[arg(long, default_value = "auto")]
|
||||
pub source: String,
|
||||
|
||||
/// Run vital sign detection benchmark (1000 frames) and exit
|
||||
#[arg(long)]
|
||||
pub benchmark: bool,
|
||||
|
||||
/// Load model config from an RVF container at startup
|
||||
#[arg(long, value_name = "PATH")]
|
||||
pub load_rvf: Option<PathBuf>,
|
||||
|
||||
/// Save current model state as an RVF container on shutdown
|
||||
#[arg(long, value_name = "PATH")]
|
||||
pub save_rvf: Option<PathBuf>,
|
||||
|
||||
/// Load a trained .rvf model for inference
|
||||
#[arg(long, value_name = "PATH")]
|
||||
pub model: Option<PathBuf>,
|
||||
|
||||
/// Enable progressive loading (Layer A instant start)
|
||||
#[arg(long)]
|
||||
pub progressive: bool,
|
||||
|
||||
/// Export an RVF container package and exit (no server)
|
||||
#[arg(long, value_name = "PATH")]
|
||||
pub export_rvf: Option<PathBuf>,
|
||||
|
||||
/// Run training mode (train a model and exit)
|
||||
#[arg(long)]
|
||||
pub train: bool,
|
||||
|
||||
/// Path to dataset directory (MM-Fi or Wi-Pose)
|
||||
#[arg(long, value_name = "PATH")]
|
||||
pub dataset: Option<PathBuf>,
|
||||
|
||||
/// Dataset type: "mmfi" or "wipose"
|
||||
#[arg(long, value_name = "TYPE", default_value = "mmfi")]
|
||||
pub dataset_type: String,
|
||||
|
||||
/// Number of training epochs
|
||||
#[arg(long, default_value = "100")]
|
||||
pub epochs: usize,
|
||||
|
||||
/// Directory for training checkpoints
|
||||
#[arg(long, value_name = "DIR")]
|
||||
pub checkpoint_dir: Option<PathBuf>,
|
||||
|
||||
/// Run self-supervised contrastive pretraining (ADR-024)
|
||||
#[arg(long)]
|
||||
pub pretrain: bool,
|
||||
|
||||
/// Number of pretraining epochs (default 50)
|
||||
#[arg(long, default_value = "50")]
|
||||
pub pretrain_epochs: usize,
|
||||
|
||||
/// Extract embeddings mode: load model and extract CSI embeddings
|
||||
#[arg(long)]
|
||||
pub embed: bool,
|
||||
|
||||
/// Build fingerprint index from embeddings (env|activity|temporal|person)
|
||||
#[arg(long, value_name = "TYPE")]
|
||||
pub build_index: Option<String>,
|
||||
|
||||
/// Node positions for multistatic fusion (format: "x,y,z;x,y,z;...")
|
||||
#[arg(long, env = "SENSING_NODE_POSITIONS")]
|
||||
pub node_positions: Option<String>,
|
||||
|
||||
/// Start field model calibration on boot (empty room required)
|
||||
#[arg(long)]
|
||||
pub calibrate: bool,
|
||||
}
|
||||
@@ -0,0 +1,675 @@
|
||||
//! CSI frame parsing, signal field generation, feature extraction,
|
||||
//! classification, vital signs smoothing, and multi-person estimation.
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use ruvector_mincut::{DynamicMinCut, MinCutBuilder};
|
||||
|
||||
use crate::adaptive_classifier;
|
||||
use crate::types::*;
|
||||
use crate::vital_signs::VitalSigns;
|
||||
|
||||
// ── ESP32 UDP frame parsers ─────────────────────────────────────────────────
|
||||
|
||||
/// Parse a 32-byte edge vitals packet (magic 0xC511_0002).
|
||||
pub fn parse_esp32_vitals(buf: &[u8]) -> Option<Esp32VitalsPacket> {
|
||||
if buf.len() < 32 { return None; }
|
||||
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
|
||||
if magic != 0xC511_0002 { return None; }
|
||||
|
||||
let node_id = buf[4];
|
||||
let flags = buf[5];
|
||||
let breathing_raw = u16::from_le_bytes([buf[6], buf[7]]);
|
||||
let heartrate_raw = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
|
||||
let rssi = buf[12] as i8;
|
||||
let n_persons = buf[13];
|
||||
let motion_energy = f32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
|
||||
let presence_score = f32::from_le_bytes([buf[20], buf[21], buf[22], buf[23]]);
|
||||
let timestamp_ms = u32::from_le_bytes([buf[24], buf[25], buf[26], buf[27]]);
|
||||
|
||||
Some(Esp32VitalsPacket {
|
||||
node_id,
|
||||
presence: (flags & 0x01) != 0,
|
||||
fall_detected: (flags & 0x02) != 0,
|
||||
motion: (flags & 0x04) != 0,
|
||||
breathing_rate_bpm: breathing_raw as f64 / 100.0,
|
||||
heartrate_bpm: heartrate_raw as f64 / 10000.0,
|
||||
rssi, n_persons, motion_energy, presence_score, timestamp_ms,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a WASM output packet (magic 0xC511_0004).
|
||||
pub fn parse_wasm_output(buf: &[u8]) -> Option<WasmOutputPacket> {
|
||||
if buf.len() < 8 { return None; }
|
||||
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
|
||||
if magic != 0xC511_0004 { return None; }
|
||||
|
||||
let node_id = buf[4];
|
||||
let module_id = buf[5];
|
||||
let event_count = u16::from_le_bytes([buf[6], buf[7]]) as usize;
|
||||
|
||||
let mut events = Vec::with_capacity(event_count);
|
||||
let mut offset = 8;
|
||||
for _ in 0..event_count {
|
||||
if offset + 5 > buf.len() { break; }
|
||||
let event_type = buf[offset];
|
||||
let value = f32::from_le_bytes([
|
||||
buf[offset + 1], buf[offset + 2], buf[offset + 3], buf[offset + 4],
|
||||
]);
|
||||
events.push(WasmEvent { event_type, value });
|
||||
offset += 5;
|
||||
}
|
||||
|
||||
Some(WasmOutputPacket { node_id, module_id, events })
|
||||
}
|
||||
|
||||
pub fn parse_esp32_frame(buf: &[u8]) -> Option<Esp32Frame> {
|
||||
if buf.len() < 20 { return None; }
|
||||
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
|
||||
if magic != 0xC511_0001 { return None; }
|
||||
|
||||
let node_id = buf[4];
|
||||
let n_antennas = buf[5];
|
||||
let n_subcarriers = buf[6];
|
||||
let freq_mhz = u16::from_le_bytes([buf[8], buf[9]]);
|
||||
let sequence = u32::from_le_bytes([buf[10], buf[11], buf[12], buf[13]]);
|
||||
let rssi_raw = buf[14] as i8;
|
||||
let rssi = if rssi_raw > 0 { rssi_raw.saturating_neg() } else { rssi_raw };
|
||||
let noise_floor = buf[15] as i8;
|
||||
|
||||
let iq_start = 20;
|
||||
let n_pairs = n_antennas as usize * n_subcarriers as usize;
|
||||
let expected_len = iq_start + n_pairs * 2;
|
||||
if buf.len() < expected_len { return None; }
|
||||
|
||||
let mut amplitudes = Vec::with_capacity(n_pairs);
|
||||
let mut phases = Vec::with_capacity(n_pairs);
|
||||
for k in 0..n_pairs {
|
||||
let i_val = buf[iq_start + k * 2] as i8 as f64;
|
||||
let q_val = buf[iq_start + k * 2 + 1] as i8 as f64;
|
||||
amplitudes.push((i_val * i_val + q_val * q_val).sqrt());
|
||||
phases.push(q_val.atan2(i_val));
|
||||
}
|
||||
|
||||
Some(Esp32Frame {
|
||||
magic, node_id, n_antennas, n_subcarriers, freq_mhz, sequence,
|
||||
rssi, noise_floor, amplitudes, phases,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Signal field generation ─────────────────────────────────────────────────
|
||||
|
||||
pub fn generate_signal_field(
|
||||
_mean_rssi: f64, motion_score: f64, breathing_rate_hz: f64,
|
||||
signal_quality: f64, subcarrier_variances: &[f64],
|
||||
) -> SignalField {
|
||||
let grid = 20usize;
|
||||
let mut values = vec![0.0f64; grid * grid];
|
||||
let center = (grid as f64 - 1.0) / 2.0;
|
||||
|
||||
let max_var = subcarrier_variances.iter().cloned().fold(0.0f64, f64::max);
|
||||
let norm_factor = if max_var > 1e-9 { max_var } else { 1.0 };
|
||||
let n_sub = subcarrier_variances.len().max(1);
|
||||
|
||||
for (k, &var) in subcarrier_variances.iter().enumerate() {
|
||||
let weight = (var / norm_factor) * motion_score;
|
||||
if weight < 1e-6 { continue; }
|
||||
let angle = (k as f64 / n_sub as f64) * 2.0 * std::f64::consts::PI;
|
||||
let radius = center * 0.8 * weight.sqrt();
|
||||
let hx = center + radius * angle.cos();
|
||||
let hz = center + radius * angle.sin();
|
||||
for z in 0..grid {
|
||||
for x in 0..grid {
|
||||
let dx = x as f64 - hx;
|
||||
let dz = z as f64 - hz;
|
||||
let dist2 = dx * dx + dz * dz;
|
||||
let spread = (0.5 + weight * 2.0).max(0.5);
|
||||
values[z * grid + x] += weight * (-dist2 / (2.0 * spread * spread)).exp();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for z in 0..grid {
|
||||
for x in 0..grid {
|
||||
let dx = x as f64 - center;
|
||||
let dz = z as f64 - center;
|
||||
let dist = (dx * dx + dz * dz).sqrt();
|
||||
let base = signal_quality * (-dist * 0.12).exp();
|
||||
values[z * grid + x] += base * 0.3;
|
||||
}
|
||||
}
|
||||
|
||||
if breathing_rate_hz > 0.05 {
|
||||
let ring_r = center * 0.55;
|
||||
let ring_width = 1.8f64;
|
||||
for z in 0..grid {
|
||||
for x in 0..grid {
|
||||
let dx = x as f64 - center;
|
||||
let dz = z as f64 - center;
|
||||
let dist = (dx * dx + dz * dz).sqrt();
|
||||
let ring_val = 0.08 * (-(dist - ring_r).powi(2) / (2.0 * ring_width * ring_width)).exp();
|
||||
values[z * grid + x] += ring_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let field_max = values.iter().cloned().fold(0.0f64, f64::max);
|
||||
let scale = if field_max > 1e-9 { 1.0 / field_max } else { 1.0 };
|
||||
for v in &mut values { *v = (*v * scale).clamp(0.0, 1.0); }
|
||||
|
||||
SignalField { grid_size: [grid, 1, grid], values }
|
||||
}
|
||||
|
||||
// ── Feature extraction ──────────────────────────────────────────────────────
|
||||
|
||||
pub fn estimate_breathing_rate_hz(frame_history: &VecDeque<Vec<f64>>, sample_rate_hz: f64) -> f64 {
|
||||
let n = frame_history.len();
|
||||
if n < 6 { return 0.0; }
|
||||
|
||||
let series: Vec<f64> = frame_history.iter()
|
||||
.map(|amps| if amps.is_empty() { 0.0 } else { amps.iter().sum::<f64>() / amps.len() as f64 })
|
||||
.collect();
|
||||
let mean_s = series.iter().sum::<f64>() / n as f64;
|
||||
let detrended: Vec<f64> = series.iter().map(|x| x - mean_s).collect();
|
||||
|
||||
let n_candidates = 9usize;
|
||||
let f_low = 0.1f64;
|
||||
let f_high = 0.5f64;
|
||||
let mut best_freq = 0.0f64;
|
||||
let mut best_power = 0.0f64;
|
||||
|
||||
for i in 0..n_candidates {
|
||||
let freq = f_low + (f_high - f_low) * i as f64 / (n_candidates - 1).max(1) as f64;
|
||||
let omega = 2.0 * std::f64::consts::PI * freq / sample_rate_hz;
|
||||
let coeff = 2.0 * omega.cos();
|
||||
let (mut s_prev2, mut s_prev1) = (0.0f64, 0.0f64);
|
||||
for &x in &detrended {
|
||||
let s = x + coeff * s_prev1 - s_prev2;
|
||||
s_prev2 = s_prev1;
|
||||
s_prev1 = s;
|
||||
}
|
||||
let power = s_prev2 * s_prev2 + s_prev1 * s_prev1 - coeff * s_prev1 * s_prev2;
|
||||
if power > best_power { best_power = power; best_freq = freq; }
|
||||
}
|
||||
|
||||
let avg_power = {
|
||||
let mut total = 0.0f64;
|
||||
for i in 0..n_candidates {
|
||||
let freq = f_low + (f_high - f_low) * i as f64 / (n_candidates - 1).max(1) as f64;
|
||||
let omega = 2.0 * std::f64::consts::PI * freq / sample_rate_hz;
|
||||
let coeff = 2.0 * omega.cos();
|
||||
let (mut s_prev2, mut s_prev1) = (0.0f64, 0.0f64);
|
||||
for &x in &detrended {
|
||||
let s = x + coeff * s_prev1 - s_prev2;
|
||||
s_prev2 = s_prev1;
|
||||
s_prev1 = s;
|
||||
}
|
||||
total += s_prev2 * s_prev2 + s_prev1 * s_prev1 - coeff * s_prev1 * s_prev2;
|
||||
}
|
||||
total / n_candidates as f64
|
||||
};
|
||||
|
||||
if best_power > avg_power * 3.0 { best_freq.clamp(f_low, f_high) } else { 0.0 }
|
||||
}
|
||||
|
||||
pub fn compute_subcarrier_importance_weights(sensitivity: &[f64]) -> Vec<f64> {
|
||||
let n = sensitivity.len();
|
||||
if n == 0 { return vec![]; }
|
||||
let max_sens = sensitivity.iter().cloned().fold(f64::NEG_INFINITY, f64::max).max(1e-9);
|
||||
let mut sorted = sensitivity.to_vec();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let median = if n % 2 == 0 { (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0 } else { sorted[n / 2] };
|
||||
sensitivity.iter()
|
||||
.map(|&s| if s >= median { 1.0 + (s / max_sens).min(1.0) } else { 0.5 })
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn compute_subcarrier_variances(frame_history: &VecDeque<Vec<f64>>, n_sub: usize) -> Vec<f64> {
|
||||
if frame_history.is_empty() || n_sub == 0 { return vec![0.0; n_sub]; }
|
||||
let n_frames = frame_history.len() as f64;
|
||||
let mut means = vec![0.0f64; n_sub];
|
||||
let mut sq_means = vec![0.0f64; n_sub];
|
||||
for frame in frame_history.iter() {
|
||||
for k in 0..n_sub {
|
||||
let a = if k < frame.len() { frame[k] } else { 0.0 };
|
||||
means[k] += a;
|
||||
sq_means[k] += a * a;
|
||||
}
|
||||
}
|
||||
(0..n_sub).map(|k| {
|
||||
let mean = means[k] / n_frames;
|
||||
let sq_mean = sq_means[k] / n_frames;
|
||||
(sq_mean - mean * mean).max(0.0)
|
||||
}).collect()
|
||||
}
|
||||
|
||||
pub fn extract_features_from_frame(
|
||||
frame: &Esp32Frame, frame_history: &VecDeque<Vec<f64>>, sample_rate_hz: f64,
|
||||
) -> (FeatureInfo, ClassificationInfo, f64, Vec<f64>, f64) {
|
||||
let n_sub = frame.amplitudes.len().max(1);
|
||||
let n = n_sub as f64;
|
||||
let mean_rssi = frame.rssi as f64;
|
||||
|
||||
let sub_sensitivity: Vec<f64> = frame.amplitudes.iter().map(|a| a.abs()).collect();
|
||||
let importance_weights = compute_subcarrier_importance_weights(&sub_sensitivity);
|
||||
let weight_sum: f64 = importance_weights.iter().sum::<f64>();
|
||||
|
||||
let mean_amp: f64 = if weight_sum > 0.0 {
|
||||
frame.amplitudes.iter().zip(importance_weights.iter())
|
||||
.map(|(a, w)| a * w).sum::<f64>() / weight_sum
|
||||
} else {
|
||||
frame.amplitudes.iter().sum::<f64>() / n
|
||||
};
|
||||
|
||||
let intra_variance: f64 = if weight_sum > 0.0 {
|
||||
frame.amplitudes.iter().zip(importance_weights.iter())
|
||||
.map(|(a, w)| w * (a - mean_amp).powi(2)).sum::<f64>() / weight_sum
|
||||
} else {
|
||||
frame.amplitudes.iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>() / n
|
||||
};
|
||||
|
||||
let sub_variances = compute_subcarrier_variances(frame_history, n_sub);
|
||||
let temporal_variance: f64 = if sub_variances.is_empty() {
|
||||
intra_variance
|
||||
} else {
|
||||
sub_variances.iter().sum::<f64>() / sub_variances.len() as f64
|
||||
};
|
||||
let variance = intra_variance.max(temporal_variance);
|
||||
|
||||
let spectral_power: f64 = frame.amplitudes.iter().map(|a| a * a).sum::<f64>() / n;
|
||||
let half = frame.amplitudes.len() / 2;
|
||||
let motion_band_power = if half > 0 {
|
||||
frame.amplitudes[half..].iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>()
|
||||
/ (frame.amplitudes.len() - half) as f64
|
||||
} else { 0.0 };
|
||||
let breathing_band_power = if half > 0 {
|
||||
frame.amplitudes[..half].iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>() / half as f64
|
||||
} else { 0.0 };
|
||||
|
||||
let peak_idx = frame.amplitudes.iter().enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(i, _)| i).unwrap_or(0);
|
||||
let dominant_freq_hz = peak_idx as f64 * 0.05;
|
||||
|
||||
let threshold = mean_amp * 1.2;
|
||||
let change_points = frame.amplitudes.windows(2)
|
||||
.filter(|w| (w[0] < threshold) != (w[1] < threshold)).count();
|
||||
|
||||
let temporal_motion_score = if let Some(prev_frame) = frame_history.back() {
|
||||
let n_cmp = n_sub.min(prev_frame.len());
|
||||
if n_cmp > 0 {
|
||||
let diff_energy: f64 = (0..n_cmp)
|
||||
.map(|k| (frame.amplitudes[k] - prev_frame[k]).powi(2)).sum::<f64>() / n_cmp as f64;
|
||||
let ref_energy = mean_amp * mean_amp + 1e-9;
|
||||
(diff_energy / ref_energy).sqrt().clamp(0.0, 1.0)
|
||||
} else { 0.0 }
|
||||
} else {
|
||||
(intra_variance / (mean_amp * mean_amp + 1e-9)).sqrt().clamp(0.0, 1.0)
|
||||
};
|
||||
|
||||
let variance_motion = (temporal_variance / 10.0).clamp(0.0, 1.0);
|
||||
let mbp_motion = (motion_band_power / 25.0).clamp(0.0, 1.0);
|
||||
let cp_motion = (change_points as f64 / 15.0).clamp(0.0, 1.0);
|
||||
let motion_score = (temporal_motion_score * 0.4 + variance_motion * 0.2
|
||||
+ mbp_motion * 0.25 + cp_motion * 0.15).clamp(0.0, 1.0);
|
||||
|
||||
let snr_db = (frame.rssi as f64 - frame.noise_floor as f64).max(0.0);
|
||||
let snr_quality = (snr_db / 40.0).clamp(0.0, 1.0);
|
||||
let stability = (1.0 - (temporal_variance / (mean_amp * mean_amp + 1e-9)).clamp(0.0, 1.0)).max(0.0);
|
||||
let signal_quality = (snr_quality * 0.6 + stability * 0.4).clamp(0.0, 1.0);
|
||||
|
||||
let breathing_rate_hz = estimate_breathing_rate_hz(frame_history, sample_rate_hz);
|
||||
|
||||
let features = FeatureInfo {
|
||||
mean_rssi, variance, motion_band_power, breathing_band_power,
|
||||
dominant_freq_hz, change_points, spectral_power,
|
||||
};
|
||||
|
||||
let raw_classification = ClassificationInfo {
|
||||
motion_level: raw_classify(motion_score),
|
||||
presence: motion_score > 0.04,
|
||||
confidence: (0.4 + signal_quality * 0.3 + motion_score * 0.3).clamp(0.0, 1.0),
|
||||
};
|
||||
|
||||
(features, raw_classification, breathing_rate_hz, sub_variances, motion_score)
|
||||
}
|
||||
|
||||
// ── Classification ──────────────────────────────────────────────────────────
|
||||
|
||||
pub fn raw_classify(score: f64) -> String {
|
||||
if score > 0.25 { "active".into() }
|
||||
else if score > 0.12 { "present_moving".into() }
|
||||
else if score > 0.04 { "present_still".into() }
|
||||
else { "absent".into() }
|
||||
}
|
||||
|
||||
pub fn smooth_and_classify(state: &mut AppStateInner, raw: &mut ClassificationInfo, raw_motion: f64) {
|
||||
state.baseline_frames += 1;
|
||||
if state.baseline_frames < BASELINE_WARMUP {
|
||||
state.baseline_motion = state.baseline_motion * 0.9 + raw_motion * 0.1;
|
||||
} else if raw_motion < state.smoothed_motion + 0.05 {
|
||||
state.baseline_motion = state.baseline_motion * (1.0 - BASELINE_EMA_ALPHA)
|
||||
+ raw_motion * BASELINE_EMA_ALPHA;
|
||||
}
|
||||
let adjusted = (raw_motion - state.baseline_motion * 0.7).max(0.0);
|
||||
state.smoothed_motion = state.smoothed_motion * (1.0 - MOTION_EMA_ALPHA) + adjusted * MOTION_EMA_ALPHA;
|
||||
let sm = state.smoothed_motion;
|
||||
let candidate = raw_classify(sm);
|
||||
if candidate == state.current_motion_level {
|
||||
state.debounce_counter = 0;
|
||||
state.debounce_candidate = candidate;
|
||||
} else if candidate == state.debounce_candidate {
|
||||
state.debounce_counter += 1;
|
||||
if state.debounce_counter >= DEBOUNCE_FRAMES {
|
||||
state.current_motion_level = candidate;
|
||||
state.debounce_counter = 0;
|
||||
}
|
||||
} else {
|
||||
state.debounce_candidate = candidate;
|
||||
state.debounce_counter = 1;
|
||||
}
|
||||
raw.motion_level = state.current_motion_level.clone();
|
||||
raw.presence = sm > 0.03;
|
||||
raw.confidence = (0.4 + sm * 0.6).clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
pub fn smooth_and_classify_node(ns: &mut NodeState, raw: &mut ClassificationInfo, raw_motion: f64) {
|
||||
ns.baseline_frames += 1;
|
||||
if ns.baseline_frames < BASELINE_WARMUP {
|
||||
ns.baseline_motion = ns.baseline_motion * 0.9 + raw_motion * 0.1;
|
||||
} else if raw_motion < ns.smoothed_motion + 0.05 {
|
||||
ns.baseline_motion = ns.baseline_motion * (1.0 - BASELINE_EMA_ALPHA) + raw_motion * BASELINE_EMA_ALPHA;
|
||||
}
|
||||
let adjusted = (raw_motion - ns.baseline_motion * 0.7).max(0.0);
|
||||
ns.smoothed_motion = ns.smoothed_motion * (1.0 - MOTION_EMA_ALPHA) + adjusted * MOTION_EMA_ALPHA;
|
||||
let sm = ns.smoothed_motion;
|
||||
let candidate = raw_classify(sm);
|
||||
if candidate == ns.current_motion_level {
|
||||
ns.debounce_counter = 0;
|
||||
ns.debounce_candidate = candidate;
|
||||
} else if candidate == ns.debounce_candidate {
|
||||
ns.debounce_counter += 1;
|
||||
if ns.debounce_counter >= DEBOUNCE_FRAMES {
|
||||
ns.current_motion_level = candidate;
|
||||
ns.debounce_counter = 0;
|
||||
}
|
||||
} else {
|
||||
ns.debounce_candidate = candidate;
|
||||
ns.debounce_counter = 1;
|
||||
}
|
||||
raw.motion_level = ns.current_motion_level.clone();
|
||||
raw.presence = sm > 0.03;
|
||||
raw.confidence = (0.4 + sm * 0.6).clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
pub fn adaptive_override(state: &AppStateInner, features: &FeatureInfo, classification: &mut ClassificationInfo) {
|
||||
if let Some(ref model) = state.adaptive_model {
|
||||
let amps = state.frame_history.back().map(|v| v.as_slice()).unwrap_or(&[]);
|
||||
let feat_arr = adaptive_classifier::features_from_runtime(
|
||||
&serde_json::json!({
|
||||
"variance": features.variance,
|
||||
"motion_band_power": features.motion_band_power,
|
||||
"breathing_band_power": features.breathing_band_power,
|
||||
"spectral_power": features.spectral_power,
|
||||
"dominant_freq_hz": features.dominant_freq_hz,
|
||||
"change_points": features.change_points,
|
||||
"mean_rssi": features.mean_rssi,
|
||||
}),
|
||||
amps,
|
||||
);
|
||||
let (label, conf) = model.classify(&feat_arr);
|
||||
classification.motion_level = label.to_string();
|
||||
classification.presence = label != "absent";
|
||||
classification.confidence = (conf * 0.7 + classification.confidence * 0.3).clamp(0.0, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Vital signs smoothing ───────────────────────────────────────────────────
|
||||
|
||||
fn trimmed_mean(buf: &VecDeque<f64>) -> f64 {
|
||||
if buf.is_empty() { return 0.0; }
|
||||
let mut sorted: Vec<f64> = buf.iter().copied().collect();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let n = sorted.len();
|
||||
let trim = n / 4;
|
||||
let middle = &sorted[trim..n - trim.max(0)];
|
||||
if middle.is_empty() { sorted[n / 2] } else { middle.iter().sum::<f64>() / middle.len() as f64 }
|
||||
}
|
||||
|
||||
pub fn smooth_vitals(state: &mut AppStateInner, raw: &VitalSigns) -> VitalSigns {
|
||||
let raw_hr = raw.heart_rate_bpm.unwrap_or(0.0);
|
||||
let raw_br = raw.breathing_rate_bpm.unwrap_or(0.0);
|
||||
let hr_ok = state.smoothed_hr < 1.0 || (raw_hr - state.smoothed_hr).abs() < HR_MAX_JUMP;
|
||||
let br_ok = state.smoothed_br < 1.0 || (raw_br - state.smoothed_br).abs() < BR_MAX_JUMP;
|
||||
if hr_ok && raw_hr > 0.0 {
|
||||
state.hr_buffer.push_back(raw_hr);
|
||||
if state.hr_buffer.len() > VITAL_MEDIAN_WINDOW { state.hr_buffer.pop_front(); }
|
||||
}
|
||||
if br_ok && raw_br > 0.0 {
|
||||
state.br_buffer.push_back(raw_br);
|
||||
if state.br_buffer.len() > VITAL_MEDIAN_WINDOW { state.br_buffer.pop_front(); }
|
||||
}
|
||||
let trimmed_hr = trimmed_mean(&state.hr_buffer);
|
||||
let trimmed_br = trimmed_mean(&state.br_buffer);
|
||||
if trimmed_hr > 0.0 {
|
||||
if state.smoothed_hr < 1.0 { state.smoothed_hr = trimmed_hr; }
|
||||
else if (trimmed_hr - state.smoothed_hr).abs() > HR_DEAD_BAND {
|
||||
state.smoothed_hr = state.smoothed_hr * (1.0 - VITAL_EMA_ALPHA) + trimmed_hr * VITAL_EMA_ALPHA;
|
||||
}
|
||||
}
|
||||
if trimmed_br > 0.0 {
|
||||
if state.smoothed_br < 1.0 { state.smoothed_br = trimmed_br; }
|
||||
else if (trimmed_br - state.smoothed_br).abs() > BR_DEAD_BAND {
|
||||
state.smoothed_br = state.smoothed_br * (1.0 - VITAL_EMA_ALPHA) + trimmed_br * VITAL_EMA_ALPHA;
|
||||
}
|
||||
}
|
||||
state.smoothed_hr_conf = state.smoothed_hr_conf * 0.92 + raw.heartbeat_confidence * 0.08;
|
||||
state.smoothed_br_conf = state.smoothed_br_conf * 0.92 + raw.breathing_confidence * 0.08;
|
||||
VitalSigns {
|
||||
breathing_rate_bpm: if state.smoothed_br > 1.0 { Some(state.smoothed_br) } else { None },
|
||||
heart_rate_bpm: if state.smoothed_hr > 1.0 { Some(state.smoothed_hr) } else { None },
|
||||
breathing_confidence: state.smoothed_br_conf,
|
||||
heartbeat_confidence: state.smoothed_hr_conf,
|
||||
signal_quality: raw.signal_quality,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn smooth_vitals_node(ns: &mut NodeState, raw: &VitalSigns) -> VitalSigns {
|
||||
let raw_hr = raw.heart_rate_bpm.unwrap_or(0.0);
|
||||
let raw_br = raw.breathing_rate_bpm.unwrap_or(0.0);
|
||||
let hr_ok = ns.smoothed_hr < 1.0 || (raw_hr - ns.smoothed_hr).abs() < HR_MAX_JUMP;
|
||||
let br_ok = ns.smoothed_br < 1.0 || (raw_br - ns.smoothed_br).abs() < BR_MAX_JUMP;
|
||||
if hr_ok && raw_hr > 0.0 {
|
||||
ns.hr_buffer.push_back(raw_hr);
|
||||
if ns.hr_buffer.len() > VITAL_MEDIAN_WINDOW { ns.hr_buffer.pop_front(); }
|
||||
}
|
||||
if br_ok && raw_br > 0.0 {
|
||||
ns.br_buffer.push_back(raw_br);
|
||||
if ns.br_buffer.len() > VITAL_MEDIAN_WINDOW { ns.br_buffer.pop_front(); }
|
||||
}
|
||||
let trimmed_hr = trimmed_mean(&ns.hr_buffer);
|
||||
let trimmed_br = trimmed_mean(&ns.br_buffer);
|
||||
if trimmed_hr > 0.0 {
|
||||
if ns.smoothed_hr < 1.0 { ns.smoothed_hr = trimmed_hr; }
|
||||
else if (trimmed_hr - ns.smoothed_hr).abs() > HR_DEAD_BAND {
|
||||
ns.smoothed_hr = ns.smoothed_hr * (1.0 - VITAL_EMA_ALPHA) + trimmed_hr * VITAL_EMA_ALPHA;
|
||||
}
|
||||
}
|
||||
if trimmed_br > 0.0 {
|
||||
if ns.smoothed_br < 1.0 { ns.smoothed_br = trimmed_br; }
|
||||
else if (trimmed_br - ns.smoothed_br).abs() > BR_DEAD_BAND {
|
||||
ns.smoothed_br = ns.smoothed_br * (1.0 - VITAL_EMA_ALPHA) + trimmed_br * VITAL_EMA_ALPHA;
|
||||
}
|
||||
}
|
||||
ns.smoothed_hr_conf = ns.smoothed_hr_conf * 0.92 + raw.heartbeat_confidence * 0.08;
|
||||
ns.smoothed_br_conf = ns.smoothed_br_conf * 0.92 + raw.breathing_confidence * 0.08;
|
||||
VitalSigns {
|
||||
breathing_rate_bpm: if ns.smoothed_br > 1.0 { Some(ns.smoothed_br) } else { None },
|
||||
heart_rate_bpm: if ns.smoothed_hr > 1.0 { Some(ns.smoothed_hr) } else { None },
|
||||
breathing_confidence: ns.smoothed_br_conf,
|
||||
heartbeat_confidence: ns.smoothed_hr_conf,
|
||||
signal_quality: raw.signal_quality,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Multi-person estimation ─────────────────────────────────────────────────
|
||||
|
||||
pub fn fuse_multi_node_features(
|
||||
current_features: &FeatureInfo, node_states: &HashMap<u8, NodeState>,
|
||||
) -> FeatureInfo {
|
||||
let now = std::time::Instant::now();
|
||||
let active: Vec<(&FeatureInfo, f64)> = node_states.values()
|
||||
.filter(|ns| ns.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10))
|
||||
.filter_map(|ns| {
|
||||
let feat = ns.latest_features.as_ref()?;
|
||||
let rssi = ns.rssi_history.back().copied().unwrap_or(-80.0);
|
||||
Some((feat, rssi))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if active.len() <= 1 { return current_features.clone(); }
|
||||
|
||||
let max_rssi = active.iter().map(|(_, r)| *r).fold(f64::NEG_INFINITY, f64::max);
|
||||
let weights: Vec<f64> = active.iter()
|
||||
.map(|(_, r)| (1.0 + (r - max_rssi + 20.0) / 20.0).clamp(0.1, 1.0)).collect();
|
||||
let w_sum: f64 = weights.iter().sum::<f64>().max(1e-9);
|
||||
|
||||
FeatureInfo {
|
||||
variance: active.iter().zip(&weights).map(|((f, _), w)| f.variance * w).sum::<f64>() / w_sum,
|
||||
motion_band_power: active.iter().zip(&weights).map(|((f, _), w)| f.motion_band_power * w).sum::<f64>() / w_sum,
|
||||
breathing_band_power: active.iter().zip(&weights).map(|((f, _), w)| f.breathing_band_power * w).sum::<f64>() / w_sum,
|
||||
spectral_power: active.iter().zip(&weights).map(|((f, _), w)| f.spectral_power * w).sum::<f64>() / w_sum,
|
||||
dominant_freq_hz: active.iter().zip(&weights).map(|((f, _), w)| f.dominant_freq_hz * w).sum::<f64>() / w_sum,
|
||||
change_points: current_features.change_points,
|
||||
mean_rssi: active.iter().map(|(f, _)| f.mean_rssi).fold(f64::NEG_INFINITY, f64::max),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compute_person_score(feat: &FeatureInfo) -> f64 {
|
||||
let var_norm = (feat.variance / 300.0).clamp(0.0, 1.0);
|
||||
let cp_norm = (feat.change_points as f64 / 30.0).clamp(0.0, 1.0);
|
||||
let motion_norm = (feat.motion_band_power / 250.0).clamp(0.0, 1.0);
|
||||
let sp_norm = (feat.spectral_power / 500.0).clamp(0.0, 1.0);
|
||||
var_norm * 0.40 + cp_norm * 0.20 + motion_norm * 0.25 + sp_norm * 0.15
|
||||
}
|
||||
|
||||
pub fn estimate_persons_from_correlation(frame_history: &VecDeque<Vec<f64>>) -> usize {
|
||||
let n_frames = frame_history.len();
|
||||
if n_frames < 10 { return 1; }
|
||||
|
||||
let window: Vec<&Vec<f64>> = frame_history.iter().rev().take(20).collect();
|
||||
let n_sub = window[0].len().min(56);
|
||||
if n_sub < 4 { return 1; }
|
||||
let k = window.len() as f64;
|
||||
|
||||
let mut means = vec![0.0f64; n_sub];
|
||||
let mut variances = vec![0.0f64; n_sub];
|
||||
for frame in &window {
|
||||
for sc in 0..n_sub.min(frame.len()) { means[sc] += frame[sc] / k; }
|
||||
}
|
||||
for frame in &window {
|
||||
for sc in 0..n_sub.min(frame.len()) { variances[sc] += (frame[sc] - means[sc]).powi(2) / k; }
|
||||
}
|
||||
|
||||
let noise_floor = 1.0;
|
||||
let active: Vec<usize> = (0..n_sub).filter(|&sc| variances[sc] > noise_floor).collect();
|
||||
let m = active.len();
|
||||
if m < 3 { return if m == 0 { 0 } else { 1 }; }
|
||||
|
||||
let mut edges: Vec<(u64, u64, f64)> = Vec::new();
|
||||
let source = m as u64;
|
||||
let sink = (m + 1) as u64;
|
||||
let stds: Vec<f64> = active.iter().map(|&sc| variances[sc].sqrt().max(1e-9)).collect();
|
||||
|
||||
for i in 0..m {
|
||||
for j in (i + 1)..m {
|
||||
let mut cov = 0.0f64;
|
||||
for frame in &window {
|
||||
let (si, sj) = (active[i], active[j]);
|
||||
if si < frame.len() && sj < frame.len() {
|
||||
cov += (frame[si] - means[si]) * (frame[sj] - means[sj]) / k;
|
||||
}
|
||||
}
|
||||
let corr = (cov / (stds[i] * stds[j])).abs();
|
||||
if corr > 0.1 {
|
||||
let weight = corr * 10.0;
|
||||
edges.push((i as u64, j as u64, weight));
|
||||
edges.push((j as u64, i as u64, weight));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (max_var_idx, _) = active.iter().enumerate()
|
||||
.max_by(|(_, &a), (_, &b)| variances[a].partial_cmp(&variances[b]).unwrap())
|
||||
.unwrap_or((0, &0));
|
||||
let (min_var_idx, _) = active.iter().enumerate()
|
||||
.min_by(|(_, &a), (_, &b)| variances[a].partial_cmp(&variances[b]).unwrap())
|
||||
.unwrap_or((0, &0));
|
||||
if max_var_idx == min_var_idx { return 1; }
|
||||
|
||||
edges.push((source, max_var_idx as u64, 100.0));
|
||||
edges.push((min_var_idx as u64, sink, 100.0));
|
||||
|
||||
let mc: DynamicMinCut = match MinCutBuilder::new().exact().with_edges(edges.clone()).build() {
|
||||
Ok(mc) => mc,
|
||||
Err(_) => return 1,
|
||||
};
|
||||
|
||||
let cut_value = mc.min_cut_value();
|
||||
let total_edge_weight: f64 = edges.iter()
|
||||
.filter(|(s, t, _)| *s != source && *s != sink && *t != source && *t != sink)
|
||||
.map(|(_, _, w)| w).sum::<f64>() / 2.0;
|
||||
if total_edge_weight < 1e-9 { return 1; }
|
||||
|
||||
let cut_ratio = cut_value / total_edge_weight;
|
||||
if cut_ratio > 0.4 { 1 }
|
||||
else if cut_ratio > 0.15 { 2 }
|
||||
else { 3 }
|
||||
}
|
||||
|
||||
pub fn score_to_person_count(smoothed_score: f64, prev_count: usize) -> usize {
|
||||
match prev_count {
|
||||
0 | 1 => {
|
||||
if smoothed_score > 0.85 { 3 }
|
||||
else if smoothed_score > 0.70 { 2 }
|
||||
else { 1 }
|
||||
}
|
||||
2 => {
|
||||
if smoothed_score > 0.92 { 3 }
|
||||
else if smoothed_score < 0.55 { 1 }
|
||||
else { 2 }
|
||||
}
|
||||
_ => {
|
||||
if smoothed_score < 0.55 { 1 }
|
||||
else if smoothed_score < 0.78 { 2 }
|
||||
else { 3 }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a simulated ESP32 frame for testing/demo mode.
|
||||
pub fn generate_simulated_frame(tick: u64) -> Esp32Frame {
|
||||
let t = tick as f64 * 0.1;
|
||||
let n_sub = 56usize;
|
||||
let mut amplitudes = Vec::with_capacity(n_sub);
|
||||
let mut phases = Vec::with_capacity(n_sub);
|
||||
for i in 0..n_sub {
|
||||
let base = 15.0 + 5.0 * (i as f64 * 0.1 + t * 0.3).sin();
|
||||
let noise = (i as f64 * 7.3 + t * 13.7).sin() * 2.0;
|
||||
amplitudes.push((base + noise).max(0.1));
|
||||
phases.push((i as f64 * 0.2 + t * 0.5).sin() * std::f64::consts::PI);
|
||||
}
|
||||
Esp32Frame {
|
||||
magic: 0xC511_0001, node_id: 1, n_antennas: 1, n_subcarriers: n_sub as u8,
|
||||
freq_mhz: 2437, sequence: tick as u32,
|
||||
rssi: (-40.0 + 5.0 * (t * 0.2).sin()) as i8, noise_floor: -90,
|
||||
amplitudes, phases,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a simple timestamp (epoch seconds) for recording IDs.
|
||||
pub fn chrono_timestamp() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
//! Bridge between sensing-server frame data and signal crate FieldModel
|
||||
//! for eigenvalue-based person counting.
|
||||
//!
|
||||
//! The FieldModel decomposes CSI observations into environmental drift and
|
||||
//! body perturbation via SVD eigenmodes. When calibrated, perturbation energy
|
||||
//! provides a physics-grounded occupancy estimate that supplements the
|
||||
//! score-based heuristic in `score_to_person_count`.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use wifi_densepose_signal::ruvsense::field_model::{CalibrationStatus, FieldModel, FieldModelConfig};
|
||||
|
||||
use super::score_to_person_count;
|
||||
|
||||
/// Number of recent frames to feed into perturbation extraction.
|
||||
const OCCUPANCY_WINDOW: usize = 50;
|
||||
|
||||
/// Perturbation energy threshold for detecting a second person.
|
||||
const ENERGY_THRESH_2: f64 = 12.0;
|
||||
/// Perturbation energy threshold for detecting a third person.
|
||||
const ENERGY_THRESH_3: f64 = 25.0;
|
||||
|
||||
/// Create a FieldModelConfig for single-link mode (one ESP32 node = one link).
|
||||
/// This avoids the DimensionMismatch error when feeding single-frame observations.
|
||||
pub fn single_link_config() -> FieldModelConfig {
|
||||
FieldModelConfig {
|
||||
n_links: 1,
|
||||
..FieldModelConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate occupancy using the FieldModel when calibrated, falling back
|
||||
/// to the score-based heuristic otherwise.
|
||||
///
|
||||
/// Prefers `estimate_occupancy()` (eigenvalue-based) when the model is
|
||||
/// calibrated and enough frames are available. Falls back to perturbation
|
||||
/// energy thresholds, then to the score heuristic.
|
||||
pub fn occupancy_or_fallback(
|
||||
field: &FieldModel,
|
||||
frame_history: &VecDeque<Vec<f64>>,
|
||||
smoothed_score: f64,
|
||||
prev_count: usize,
|
||||
) -> usize {
|
||||
match field.status() {
|
||||
CalibrationStatus::Fresh | CalibrationStatus::Stale => {
|
||||
let frames: Vec<Vec<f64>> = frame_history
|
||||
.iter()
|
||||
.rev()
|
||||
.take(OCCUPANCY_WINDOW)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if frames.is_empty() {
|
||||
return score_to_person_count(smoothed_score, prev_count);
|
||||
}
|
||||
|
||||
// Try eigenvalue-based occupancy first (best accuracy).
|
||||
match field.estimate_occupancy(&frames) {
|
||||
Ok(count) => return count,
|
||||
Err(_) => {} // fall through to perturbation energy
|
||||
}
|
||||
|
||||
// Fallback: perturbation energy thresholds.
|
||||
// FieldModel expects [n_links][n_subcarriers] — we use n_links=1.
|
||||
let observation = vec![frames[0].clone()];
|
||||
match field.extract_perturbation(&observation) {
|
||||
Ok(perturbation) => {
|
||||
if perturbation.total_energy > ENERGY_THRESH_3 {
|
||||
3
|
||||
} else if perturbation.total_energy > ENERGY_THRESH_2 {
|
||||
2
|
||||
} else if perturbation.total_energy > 1.0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
Err(_) => score_to_person_count(smoothed_score, prev_count),
|
||||
}
|
||||
}
|
||||
_ => score_to_person_count(smoothed_score, prev_count),
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed the latest frame to the FieldModel during calibration collection.
|
||||
///
|
||||
/// Only acts when the model status is `Collecting`. Wraps the latest frame
|
||||
/// as a single-link observation (n_links=1) and feeds it.
|
||||
pub fn maybe_feed_calibration(field: &mut FieldModel, frame_history: &VecDeque<Vec<f64>>) {
|
||||
if field.status() != CalibrationStatus::Collecting {
|
||||
return;
|
||||
}
|
||||
if let Some(latest) = frame_history.back() {
|
||||
// Single-link observation: [1][n_subcarriers]
|
||||
let observations = vec![latest.clone()];
|
||||
if let Err(e) = field.feed_calibration(&observations) {
|
||||
tracing::debug!("FieldModel calibration feed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse node positions from a semicolon-delimited string.
|
||||
///
|
||||
/// Format: `"x,y,z;x,y,z;..."` where each coordinate is an `f32`.
|
||||
/// Malformed entries are skipped with a warning log.
|
||||
pub fn parse_node_positions(input: &str) -> Vec<[f32; 3]> {
|
||||
if input.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
input
|
||||
.split(';')
|
||||
.enumerate()
|
||||
.filter_map(|(idx, triplet)| {
|
||||
let parts: Vec<&str> = triplet.split(',').collect();
|
||||
if parts.len() != 3 {
|
||||
tracing::warn!("Skipping malformed node position entry {idx}: '{triplet}' (expected x,y,z)");
|
||||
return None;
|
||||
}
|
||||
match (parts[0].parse::<f32>(), parts[1].parse::<f32>(), parts[2].parse::<f32>()) {
|
||||
(Ok(x), Ok(y), Ok(z)) => Some([x, y, z]),
|
||||
_ => {
|
||||
tracing::warn!("Skipping unparseable node position entry {idx}: '{triplet}'");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_node_positions() {
|
||||
let positions = parse_node_positions("0,0,1.5;3,0,1.5;1.5,3,1.5");
|
||||
assert_eq!(positions.len(), 3);
|
||||
assert_eq!(positions[0], [0.0, 0.0, 1.5]);
|
||||
assert_eq!(positions[1], [3.0, 0.0, 1.5]);
|
||||
assert_eq!(positions[2], [1.5, 3.0, 1.5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_node_positions_empty() {
|
||||
let positions = parse_node_positions("");
|
||||
assert!(positions.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_node_positions_invalid() {
|
||||
let positions = parse_node_positions("abc;1,2,3");
|
||||
assert_eq!(positions.len(), 1);
|
||||
assert_eq!(positions[0], [1.0, 2.0, 3.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_node_positions_partial_triplet() {
|
||||
let positions = parse_node_positions("1,2;3,4,5");
|
||||
assert_eq!(positions.len(), 1);
|
||||
assert_eq!(positions[0], [3.0, 4.0, 5.0]);
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,15 @@
|
||||
//! Replaces both ws_server.py and the Python HTTP server.
|
||||
|
||||
mod adaptive_classifier;
|
||||
pub mod cli;
|
||||
pub mod csi;
|
||||
mod field_bridge;
|
||||
mod multistatic_bridge;
|
||||
pub mod pose;
|
||||
mod rvf_container;
|
||||
mod rvf_pipeline;
|
||||
mod tracker_bridge;
|
||||
pub mod types;
|
||||
mod vital_signs;
|
||||
|
||||
// Training pipeline modules (exposed via lib.rs)
|
||||
@@ -53,6 +60,11 @@ use wifi_densepose_wifiscan::{
|
||||
};
|
||||
use wifi_densepose_wifiscan::parse_netsh_output as parse_netsh_bssid_output;
|
||||
|
||||
// Accuracy sprint: Kalman tracker, multistatic fusion, field model
|
||||
use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker;
|
||||
use wifi_densepose_signal::ruvsense::multistatic::{MultistaticFuser, MultistaticConfig};
|
||||
use wifi_densepose_signal::ruvsense::field_model::{FieldModel, CalibrationStatus};
|
||||
|
||||
// ── CLI ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -145,6 +157,14 @@ struct Args {
|
||||
/// Build fingerprint index from embeddings (env|activity|temporal|person)
|
||||
#[arg(long, value_name = "TYPE")]
|
||||
build_index: Option<String>,
|
||||
|
||||
/// Node positions for multistatic fusion (format: "x,y,z;x,y,z;...")
|
||||
#[arg(long, env = "SENSING_NODE_POSITIONS")]
|
||||
node_positions: Option<String>,
|
||||
|
||||
/// Start field model calibration on boot (empty room required)
|
||||
#[arg(long)]
|
||||
calibrate: bool,
|
||||
}
|
||||
|
||||
// ── Data types ───────────────────────────────────────────────────────────────
|
||||
@@ -213,6 +233,9 @@ struct SensingUpdate {
|
||||
/// Estimated person count from CSI feature heuristics (1-3 for single ESP32).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
estimated_persons: Option<usize>,
|
||||
/// Per-node feature breakdown for multi-node deployments.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
node_features: Option<Vec<PerNodeFeatureInfo>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -280,9 +303,9 @@ struct BoundingBox {
|
||||
/// Each ESP32 node gets its own frame history, smoothing buffers, and vital
|
||||
/// sign detector so that data from different nodes is never mixed.
|
||||
struct NodeState {
|
||||
frame_history: VecDeque<Vec<f64>>,
|
||||
pub(crate) frame_history: VecDeque<Vec<f64>>,
|
||||
smoothed_person_score: f64,
|
||||
prev_person_count: usize,
|
||||
pub(crate) prev_person_count: usize,
|
||||
smoothed_motion: f64,
|
||||
current_motion_level: String,
|
||||
debounce_counter: u32,
|
||||
@@ -298,7 +321,7 @@ struct NodeState {
|
||||
rssi_history: VecDeque<f64>,
|
||||
vital_detector: VitalSignDetector,
|
||||
latest_vitals: VitalSigns,
|
||||
last_frame_time: Option<std::time::Instant>,
|
||||
pub(crate) last_frame_time: Option<std::time::Instant>,
|
||||
edge_vitals: Option<Esp32VitalsPacket>,
|
||||
/// Latest extracted features for cross-node fusion.
|
||||
latest_features: Option<FeatureInfo>,
|
||||
@@ -325,7 +348,7 @@ const MAX_BONE_CHANGE_RATIO: f64 = 0.20;
|
||||
const COHERENCE_WINDOW: usize = 20;
|
||||
|
||||
impl NodeState {
|
||||
fn new() -> Self {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
frame_history: VecDeque::new(),
|
||||
smoothed_person_score: 0.0,
|
||||
@@ -389,6 +412,18 @@ impl NodeState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-node feature info for WebSocket broadcasts (multi-node support).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PerNodeFeatureInfo {
|
||||
node_id: u8,
|
||||
features: FeatureInfo,
|
||||
classification: ClassificationInfo,
|
||||
rssi_dbm: f64,
|
||||
last_seen_ms: u64,
|
||||
frame_rate_hz: f64,
|
||||
stale: bool,
|
||||
}
|
||||
|
||||
/// Shared application state
|
||||
struct AppStateInner {
|
||||
latest_update: Option<SensingUpdate>,
|
||||
@@ -482,6 +517,15 @@ struct AppStateInner {
|
||||
/// Per-node sensing state for multi-node deployments.
|
||||
/// Keyed by `node_id` from the ESP32 frame header.
|
||||
node_states: HashMap<u8, NodeState>,
|
||||
// ── Accuracy sprint: Kalman tracker, multistatic fusion, eigenvalue counting ──
|
||||
/// Global Kalman-based pose tracker for stable person IDs and smoothed keypoints.
|
||||
pose_tracker: PoseTracker,
|
||||
/// Instant of last tracker update (for computing dt).
|
||||
last_tracker_instant: Option<std::time::Instant>,
|
||||
/// Attention-weighted multi-node CSI fusion engine.
|
||||
multistatic_fuser: MultistaticFuser,
|
||||
/// SVD-based room field model for eigenvalue person counting (None until calibration).
|
||||
field_model: Option<FieldModel>,
|
||||
}
|
||||
|
||||
/// If no ESP32 frame arrives within this duration, source reverts to offline.
|
||||
@@ -491,6 +535,31 @@ impl AppStateInner {
|
||||
/// Return the effective data source, accounting for ESP32 frame timeout.
|
||||
/// If the source is "esp32" but no frame has arrived in 5 seconds, returns
|
||||
/// "esp32:offline" so the UI can distinguish active vs stale connections.
|
||||
/// Person count: eigenvalue-based if field model is calibrated, else heuristic.
|
||||
/// Uses global frame_history if populated, otherwise the freshest per-node history.
|
||||
fn person_count(&self) -> usize {
|
||||
match self.field_model.as_ref() {
|
||||
Some(fm) => {
|
||||
// Prefer global frame_history (populated by wifi/simulate paths).
|
||||
// Fall back to freshest per-node history (populated by ESP32 paths).
|
||||
let history = if !self.frame_history.is_empty() {
|
||||
&self.frame_history
|
||||
} else {
|
||||
// Find the node with the most recent frame
|
||||
self.node_states.values()
|
||||
.filter(|ns| !ns.frame_history.is_empty())
|
||||
.max_by_key(|ns| ns.last_frame_time)
|
||||
.map(|ns| &ns.frame_history)
|
||||
.unwrap_or(&self.frame_history)
|
||||
};
|
||||
field_bridge::occupancy_or_fallback(
|
||||
fm, history, self.smoothed_person_score, self.prev_person_count,
|
||||
)
|
||||
}
|
||||
None => score_to_person_count(self.smoothed_person_score, self.prev_person_count),
|
||||
}
|
||||
}
|
||||
|
||||
fn effective_source(&self) -> String {
|
||||
if self.source == "esp32" {
|
||||
if let Some(last) = self.last_esp32_frame {
|
||||
@@ -639,12 +708,13 @@ fn parse_esp32_frame(buf: &[u8]) -> Option<Esp32Frame> {
|
||||
// [20..] I/Q data
|
||||
let node_id = buf[4];
|
||||
let n_antennas = buf[5];
|
||||
let n_subcarriers_u16 = u16::from_le_bytes([buf[6], buf[7]]);
|
||||
let n_subcarriers = n_subcarriers_u16 as u8; // truncate to u8 for Esp32Frame compat
|
||||
let freq_mhz = u16::from_le_bytes([buf[8], buf[9]]); // low 16 bits of u32
|
||||
let sequence = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
|
||||
let rssi = buf[16] as i8; // #332: was buf[14], 2 bytes off
|
||||
let noise_floor = buf[17] as i8; // #332: was buf[15], 2 bytes off
|
||||
let n_subcarriers = buf[6];
|
||||
let freq_mhz = u16::from_le_bytes([buf[8], buf[9]]);
|
||||
let sequence = u32::from_le_bytes([buf[10], buf[11], buf[12], buf[13]]);
|
||||
let rssi_raw = buf[14] as i8;
|
||||
// Fix RSSI sign: ensure it's always negative (dBm convention).
|
||||
let rssi = if rssi_raw > 0 { rssi_raw.saturating_neg() } else { rssi_raw };
|
||||
let noise_floor = buf[15] as i8;
|
||||
|
||||
let iq_start = 20;
|
||||
let n_pairs = n_antennas as usize * n_subcarriers as usize;
|
||||
@@ -1546,7 +1616,7 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
|
||||
let raw_score = compute_person_score(&features);
|
||||
s.smoothed_person_score = s.smoothed_person_score * 0.90 + raw_score * 0.10;
|
||||
let est_persons = if classification.presence {
|
||||
let count = score_to_person_count(s.smoothed_person_score, s.prev_person_count);
|
||||
let count = s.person_count();
|
||||
s.prev_person_count = count;
|
||||
count
|
||||
} else {
|
||||
@@ -1583,12 +1653,16 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if est_persons > 0 { Some(est_persons) } else { None },
|
||||
node_features: None,
|
||||
};
|
||||
|
||||
// Populate persons from the sensing update.
|
||||
let persons = derive_pose_from_sensing(&update);
|
||||
if !persons.is_empty() {
|
||||
update.persons = Some(persons);
|
||||
// Populate persons from the sensing update (Kalman-smoothed via tracker).
|
||||
let raw_persons = derive_pose_from_sensing(&update);
|
||||
let tracked = tracker_bridge::tracker_update(
|
||||
&mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons,
|
||||
);
|
||||
if !tracked.is_empty() {
|
||||
update.persons = Some(tracked);
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&update) {
|
||||
@@ -1679,7 +1753,7 @@ async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) {
|
||||
let raw_score = compute_person_score(&features);
|
||||
s.smoothed_person_score = s.smoothed_person_score * 0.90 + raw_score * 0.10;
|
||||
let est_persons = if classification.presence {
|
||||
let count = score_to_person_count(s.smoothed_person_score, s.prev_person_count);
|
||||
let count = s.person_count();
|
||||
s.prev_person_count = count;
|
||||
count
|
||||
} else {
|
||||
@@ -1716,11 +1790,15 @@ async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) {
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if est_persons > 0 { Some(est_persons) } else { None },
|
||||
node_features: None,
|
||||
};
|
||||
|
||||
let persons = derive_pose_from_sensing(&update);
|
||||
if !persons.is_empty() {
|
||||
update.persons = Some(persons);
|
||||
let raw_persons = derive_pose_from_sensing(&update);
|
||||
let tracked = tracker_bridge::tracker_update(
|
||||
&mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons,
|
||||
);
|
||||
if !tracked.is_empty() {
|
||||
update.persons = Some(tracked);
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&update) {
|
||||
@@ -1897,9 +1975,13 @@ async fn handle_ws_pose_client(mut socket: WebSocket, state: SharedState) {
|
||||
keypoints,
|
||||
zone: "zone_1".into(),
|
||||
}]
|
||||
}).unwrap_or_else(|| derive_pose_from_sensing(&sensing))
|
||||
}).unwrap_or_else(|| {
|
||||
// Prefer tracked persons from broadcast if available
|
||||
sensing.persons.clone().unwrap_or_else(|| derive_pose_from_sensing(&sensing))
|
||||
})
|
||||
} else {
|
||||
derive_pose_from_sensing(&sensing)
|
||||
// Prefer tracked persons from broadcast if available
|
||||
sensing.persons.clone().unwrap_or_else(|| derive_pose_from_sensing(&sensing))
|
||||
};
|
||||
|
||||
let pose_msg = serde_json::json!({
|
||||
@@ -2598,7 +2680,7 @@ async fn api_info(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
async fn pose_current(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
let persons = match &s.latest_update {
|
||||
Some(update) => derive_pose_from_sensing(update),
|
||||
Some(update) => update.persons.clone().unwrap_or_else(|| derive_pose_from_sensing(update)),
|
||||
None => vec![],
|
||||
};
|
||||
Json(serde_json::json!({
|
||||
@@ -3149,6 +3231,88 @@ async fn adaptive_unload(State(state): State<SharedState>) -> Json<serde_json::V
|
||||
Json(serde_json::json!({ "success": true, "message": "Adaptive model unloaded." }))
|
||||
}
|
||||
|
||||
// ── Field model calibration endpoints (eigenvalue person counting) ──────────
|
||||
|
||||
async fn calibration_start(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let mut s = state.write().await;
|
||||
// Guard: don't discard an in-progress or fresh calibration
|
||||
if let Some(ref fm) = s.field_model {
|
||||
match fm.status() {
|
||||
CalibrationStatus::Collecting => {
|
||||
return Json(serde_json::json!({
|
||||
"success": false,
|
||||
"error": "Calibration already in progress. Call /calibration/stop first.",
|
||||
"frame_count": fm.calibration_frame_count(),
|
||||
}));
|
||||
}
|
||||
CalibrationStatus::Fresh => {
|
||||
return Json(serde_json::json!({
|
||||
"success": false,
|
||||
"error": "A fresh calibration already exists. Call /calibration/stop or wait for expiry.",
|
||||
}));
|
||||
}
|
||||
_ => {} // Stale/Expired/Uncalibrated — ok to recalibrate
|
||||
}
|
||||
}
|
||||
match FieldModel::new(field_bridge::single_link_config()) {
|
||||
Ok(fm) => {
|
||||
s.field_model = Some(fm);
|
||||
Json(serde_json::json!({
|
||||
"success": true,
|
||||
"message": "Calibration started — keep room empty while frames accumulate.",
|
||||
}))
|
||||
}
|
||||
Err(e) => Json(serde_json::json!({
|
||||
"success": false,
|
||||
"error": format!("{e}"),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn calibration_stop(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let mut s = state.write().await;
|
||||
if let Some(ref mut fm) = s.field_model {
|
||||
let ts = chrono::Utc::now().timestamp_micros() as u64;
|
||||
match fm.finalize_calibration(ts, 0) {
|
||||
Ok(modes) => {
|
||||
let baseline = modes.baseline_eigenvalue_count;
|
||||
let variance_explained = modes.variance_explained;
|
||||
info!("Field model calibrated: baseline_eigenvalues={baseline}, variance_explained={variance_explained:.2}");
|
||||
Json(serde_json::json!({
|
||||
"success": true,
|
||||
"baseline_eigenvalue_count": baseline,
|
||||
"variance_explained": variance_explained,
|
||||
"frame_count": fm.calibration_frame_count(),
|
||||
}))
|
||||
}
|
||||
Err(e) => Json(serde_json::json!({
|
||||
"success": false,
|
||||
"error": format!("{e}"),
|
||||
})),
|
||||
}
|
||||
} else {
|
||||
Json(serde_json::json!({
|
||||
"success": false,
|
||||
"error": "No field model active — call /calibration/start first.",
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
async fn calibration_status(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
match s.field_model.as_ref() {
|
||||
Some(fm) => Json(serde_json::json!({
|
||||
"active": true,
|
||||
"status": format!("{:?}", fm.status()),
|
||||
"frame_count": fm.calibration_frame_count(),
|
||||
})),
|
||||
None => Json(serde_json::json!({
|
||||
"active": false,
|
||||
"status": "none",
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a simple timestamp string (epoch seconds) for recording IDs.
|
||||
fn chrono_timestamp() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
@@ -3295,6 +3459,34 @@ async fn sona_activate(
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/v1/nodes — per-node health and feature info.
|
||||
async fn nodes_endpoint(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
let now = std::time::Instant::now();
|
||||
let nodes: Vec<serde_json::Value> = s.node_states.iter()
|
||||
.map(|(&id, ns)| {
|
||||
let elapsed_ms = ns.last_frame_time
|
||||
.map(|t| now.duration_since(t).as_millis() as u64)
|
||||
.unwrap_or(999999);
|
||||
let stale = elapsed_ms > 5000;
|
||||
let status = if stale { "stale" } else { "active" };
|
||||
let rssi = ns.rssi_history.back().copied().unwrap_or(-90.0);
|
||||
serde_json::json!({
|
||||
"node_id": id,
|
||||
"status": status,
|
||||
"last_seen_ms": elapsed_ms,
|
||||
"rssi_dbm": rssi,
|
||||
"motion_level": &ns.current_motion_level,
|
||||
"person_count": ns.prev_person_count,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Json(serde_json::json!({
|
||||
"nodes": nodes,
|
||||
"total": nodes.len(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn info_page() -> Html<String> {
|
||||
Html(format!(
|
||||
"<html><body>\
|
||||
@@ -3386,15 +3578,33 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
||||
else if vitals.presence { 0.3 }
|
||||
else { 0.05 };
|
||||
|
||||
// Aggregate person count across all active nodes.
|
||||
// Use max (not sum) because nodes in the same room see the
|
||||
// same people — summing would double-count.
|
||||
// Aggregate person count: gate on presence first (matching WiFi path).
|
||||
let now = std::time::Instant::now();
|
||||
let total_persons: usize = s.node_states.values()
|
||||
.filter(|n| n.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10))
|
||||
.map(|n| n.prev_person_count)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
let total_persons = if vitals.presence {
|
||||
let (fused, fallback_count) = multistatic_bridge::fuse_or_fallback(
|
||||
&s.multistatic_fuser, &s.node_states,
|
||||
);
|
||||
match fused {
|
||||
Some(ref f) => {
|
||||
let score = multistatic_bridge::compute_person_score_from_amplitudes(&f.fused_amplitude);
|
||||
s.smoothed_person_score = s.smoothed_person_score * 0.90 + score * 0.10;
|
||||
let count = s.person_count();
|
||||
s.prev_person_count = count;
|
||||
count.max(1) // presence=true => at least 1
|
||||
}
|
||||
None => fallback_count.unwrap_or(0).max(1),
|
||||
}
|
||||
} else {
|
||||
s.prev_person_count = 0;
|
||||
0
|
||||
};
|
||||
|
||||
// Feed field model calibration if active (use per-node history for ESP32).
|
||||
if let Some(ref mut fm) = s.field_model {
|
||||
if let Some(ns) = s.node_states.get(&node_id) {
|
||||
field_bridge::maybe_feed_calibration(fm, &ns.frame_history);
|
||||
}
|
||||
}
|
||||
|
||||
// Build nodes array with all active nodes.
|
||||
let active_nodes: Vec<NodeInfo> = s.node_states.iter()
|
||||
@@ -3471,17 +3681,15 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if total_persons > 0 { Some(total_persons) } else { None },
|
||||
node_features: None,
|
||||
};
|
||||
|
||||
let mut persons = derive_pose_from_sensing(&update);
|
||||
// RuVector Phase 2: temporal smoothing + coherence gating
|
||||
{
|
||||
let ns = s.node_states.entry(node_id).or_insert_with(NodeState::new);
|
||||
ns.update_coherence(vitals.motion_energy as f64);
|
||||
apply_temporal_smoothing(&mut persons, ns);
|
||||
}
|
||||
if !persons.is_empty() {
|
||||
update.persons = Some(persons);
|
||||
let raw_persons = derive_pose_from_sensing(&update);
|
||||
let tracked = tracker_bridge::tracker_update(
|
||||
&mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons,
|
||||
);
|
||||
if !tracked.is_empty() {
|
||||
update.persons = Some(tracked);
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&update) {
|
||||
@@ -3618,23 +3826,32 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
||||
else if classification.motion_level == "present_still" { 0.3 }
|
||||
else { 0.05 };
|
||||
|
||||
// Aggregate person count across all active nodes.
|
||||
// Use max (not sum) because nodes in the same room see the
|
||||
// same people — summing would double-count.
|
||||
// Aggregate person count: gate on presence first (matching WiFi path).
|
||||
let now = std::time::Instant::now();
|
||||
let total_persons: usize = s.node_states.values()
|
||||
.filter(|n| n.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10))
|
||||
.map(|n| n.prev_person_count)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
let total_persons = if classification.presence {
|
||||
let (fused, fallback_count) = multistatic_bridge::fuse_or_fallback(
|
||||
&s.multistatic_fuser, &s.node_states,
|
||||
);
|
||||
match fused {
|
||||
Some(ref f) => {
|
||||
let score = multistatic_bridge::compute_person_score_from_amplitudes(&f.fused_amplitude);
|
||||
s.smoothed_person_score = s.smoothed_person_score * 0.90 + score * 0.10;
|
||||
let count = s.person_count();
|
||||
s.prev_person_count = count;
|
||||
count.max(1)
|
||||
}
|
||||
None => fallback_count.unwrap_or(0).max(1),
|
||||
}
|
||||
} else {
|
||||
s.prev_person_count = 0;
|
||||
0
|
||||
};
|
||||
|
||||
// Boost classification confidence with multi-node coverage.
|
||||
let n_active = s.node_states.values()
|
||||
.filter(|ns| ns.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10))
|
||||
.count();
|
||||
if n_active > 1 {
|
||||
classification.confidence = (classification.confidence
|
||||
* (1.0 + 0.15 * (n_active as f64 - 1.0))).clamp(0.0, 1.0);
|
||||
// Feed field model calibration if active (use per-node history for ESP32).
|
||||
if let Some(ref mut fm) = s.field_model {
|
||||
if let Some(ns) = s.node_states.get(&node_id) {
|
||||
field_bridge::maybe_feed_calibration(fm, &ns.frame_history);
|
||||
}
|
||||
}
|
||||
|
||||
// Build nodes array with all active nodes.
|
||||
@@ -3674,17 +3891,15 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if total_persons > 0 { Some(total_persons) } else { None },
|
||||
node_features: None,
|
||||
};
|
||||
|
||||
let mut persons = derive_pose_from_sensing(&update);
|
||||
// RuVector Phase 2: temporal smoothing + coherence gating
|
||||
{
|
||||
let ns = s.node_states.entry(node_id).or_insert_with(NodeState::new);
|
||||
ns.update_coherence(features.motion_band_power);
|
||||
apply_temporal_smoothing(&mut persons, ns);
|
||||
}
|
||||
if !persons.is_empty() {
|
||||
update.persons = Some(persons);
|
||||
let raw_persons = derive_pose_from_sensing(&update);
|
||||
let tracked = tracker_bridge::tracker_update(
|
||||
&mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons,
|
||||
);
|
||||
if !tracked.is_empty() {
|
||||
update.persons = Some(tracked);
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&update) {
|
||||
@@ -3764,7 +3979,7 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) {
|
||||
let raw_score = compute_person_score(&features);
|
||||
s.smoothed_person_score = s.smoothed_person_score * 0.90 + raw_score * 0.10;
|
||||
let est_persons = if classification.presence {
|
||||
let count = score_to_person_count(s.smoothed_person_score, s.prev_person_count);
|
||||
let count = s.person_count();
|
||||
s.prev_person_count = count;
|
||||
count
|
||||
} else {
|
||||
@@ -3811,12 +4026,16 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) {
|
||||
},
|
||||
persons: None,
|
||||
estimated_persons: if est_persons > 0 { Some(est_persons) } else { None },
|
||||
node_features: None,
|
||||
};
|
||||
|
||||
// Populate persons from the sensing update.
|
||||
let persons = derive_pose_from_sensing(&update);
|
||||
if !persons.is_empty() {
|
||||
update.persons = Some(persons);
|
||||
// Populate persons from the sensing update (Kalman-smoothed via tracker).
|
||||
let raw_persons = derive_pose_from_sensing(&update);
|
||||
let tracked = tracker_bridge::tracker_update(
|
||||
&mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons,
|
||||
);
|
||||
if !tracked.is_empty() {
|
||||
update.persons = Some(tracked);
|
||||
}
|
||||
|
||||
if update.classification.presence {
|
||||
@@ -4445,6 +4664,29 @@ async fn main() {
|
||||
m
|
||||
}),
|
||||
node_states: HashMap::new(),
|
||||
// Accuracy sprint
|
||||
pose_tracker: PoseTracker::new(),
|
||||
last_tracker_instant: None,
|
||||
multistatic_fuser: {
|
||||
let mut fuser = MultistaticFuser::with_config(MultistaticConfig {
|
||||
min_nodes: 1, // single-node passthrough
|
||||
..Default::default()
|
||||
});
|
||||
if let Some(ref pos_str) = args.node_positions {
|
||||
let positions = field_bridge::parse_node_positions(pos_str);
|
||||
if !positions.is_empty() {
|
||||
info!("Configured {} node positions for multistatic fusion", positions.len());
|
||||
fuser.set_node_positions(positions);
|
||||
}
|
||||
}
|
||||
fuser
|
||||
},
|
||||
field_model: if args.calibrate {
|
||||
info!("Field model calibration enabled — room should be empty during startup");
|
||||
FieldModel::new(field_bridge::single_link_config()).ok()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}));
|
||||
|
||||
// Start background tasks based on source
|
||||
@@ -4498,6 +4740,8 @@ async fn main() {
|
||||
.route("/api/v1/metrics", get(health_metrics))
|
||||
// Sensing endpoints
|
||||
.route("/api/v1/sensing/latest", get(latest))
|
||||
// Per-node health endpoint
|
||||
.route("/api/v1/nodes", get(nodes_endpoint))
|
||||
// Vital sign endpoints
|
||||
.route("/api/v1/vital-signs", get(vital_signs_endpoint))
|
||||
.route("/api/v1/edge-vitals", get(edge_vitals_endpoint))
|
||||
@@ -4539,6 +4783,10 @@ async fn main() {
|
||||
.route("/api/v1/adaptive/train", post(adaptive_train))
|
||||
.route("/api/v1/adaptive/status", get(adaptive_status))
|
||||
.route("/api/v1/adaptive/unload", post(adaptive_unload))
|
||||
// Field model calibration (eigenvalue-based person counting)
|
||||
.route("/api/v1/calibration/start", post(calibration_start))
|
||||
.route("/api/v1/calibration/stop", post(calibration_stop))
|
||||
.route("/api/v1/calibration/status", get(calibration_status))
|
||||
// Static UI files
|
||||
.nest_service("/ui", ServeDir::new(&ui_path))
|
||||
.layer(SetResponseHeaderLayer::overriding(
|
||||
|
||||
+264
@@ -0,0 +1,264 @@
|
||||
//! Bridge between sensing-server per-node state and the signal crate's
|
||||
//! `MultistaticFuser` for attention-weighted CSI fusion across ESP32 nodes.
|
||||
//!
|
||||
//! This module converts the server's `NodeState` (f64 amplitude history) into
|
||||
//! `MultiBandCsiFrame`s that the multistatic fusion pipeline expects, then
|
||||
//! drives `MultistaticFuser::fuse` with a graceful fallback when fusion fails
|
||||
//! (e.g. insufficient nodes or timestamp spread).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use wifi_densepose_signal::hardware_norm::{CanonicalCsiFrame, HardwareType};
|
||||
use wifi_densepose_signal::ruvsense::multiband::MultiBandCsiFrame;
|
||||
use wifi_densepose_signal::ruvsense::multistatic::{FusedSensingFrame, MultistaticFuser};
|
||||
|
||||
use super::NodeState;
|
||||
|
||||
/// Maximum age for a node frame to be considered active (10 seconds).
|
||||
const STALE_THRESHOLD: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Default WiFi channel frequency (MHz) used for single-channel frames.
|
||||
const DEFAULT_FREQ_MHZ: u32 = 2437; // Channel 6
|
||||
|
||||
/// Monotonic reference point for timestamp generation. All node timestamps
|
||||
/// are relative to this instant, avoiding wall-clock/monotonic mixing issues.
|
||||
static EPOCH: LazyLock<Instant> = LazyLock::new(Instant::now);
|
||||
|
||||
/// Convert a single `NodeState` into a `MultiBandCsiFrame` suitable for
|
||||
/// multistatic fusion.
|
||||
///
|
||||
/// Returns `None` when the node has no frame history or no recorded
|
||||
/// `last_frame_time`.
|
||||
pub fn node_frame_from_state(node_id: u8, ns: &NodeState) -> Option<MultiBandCsiFrame> {
|
||||
let last_time = ns.last_frame_time.as_ref()?;
|
||||
let latest = ns.frame_history.back()?;
|
||||
if latest.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let amplitude: Vec<f32> = latest.iter().map(|&v| v as f32).collect();
|
||||
let n_sub = amplitude.len();
|
||||
let phase = vec![0.0_f32; n_sub];
|
||||
|
||||
// Monotonic timestamp: microseconds since a shared process-local epoch.
|
||||
// All nodes use the same reference so the fuser's guard_interval_us check
|
||||
// compares apples to apples. No wall-clock mixing (immune to NTP jumps).
|
||||
let timestamp_us = last_time.duration_since(*EPOCH).as_micros() as u64;
|
||||
|
||||
let canonical = CanonicalCsiFrame {
|
||||
amplitude,
|
||||
phase,
|
||||
hardware_type: HardwareType::Esp32S3,
|
||||
};
|
||||
|
||||
Some(MultiBandCsiFrame {
|
||||
node_id,
|
||||
timestamp_us,
|
||||
channel_frames: vec![canonical],
|
||||
frequencies_mhz: vec![DEFAULT_FREQ_MHZ],
|
||||
coherence: 1.0, // single-channel, perfect self-coherence
|
||||
})
|
||||
}
|
||||
|
||||
/// Collect `MultiBandCsiFrame`s from all active nodes.
|
||||
///
|
||||
/// A node is considered active if its `last_frame_time` is within
|
||||
/// [`STALE_THRESHOLD`] of `now`.
|
||||
pub fn node_frames_from_states(node_states: &HashMap<u8, NodeState>) -> Vec<MultiBandCsiFrame> {
|
||||
let now = Instant::now();
|
||||
let mut frames = Vec::with_capacity(node_states.len());
|
||||
|
||||
for (&node_id, ns) in node_states {
|
||||
// Skip stale nodes
|
||||
if let Some(ref t) = ns.last_frame_time {
|
||||
if now.duration_since(*t) > STALE_THRESHOLD {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(frame) = node_frame_from_state(node_id, ns) {
|
||||
frames.push(frame);
|
||||
}
|
||||
}
|
||||
|
||||
frames
|
||||
}
|
||||
|
||||
/// Attempt multistatic fusion; fall back to max per-node person count on failure.
|
||||
///
|
||||
/// Returns `(fused_frame, fallback_person_count)`. When fusion succeeds,
|
||||
/// `fallback_person_count` is `None` — the caller must compute count from
|
||||
/// the fused amplitudes. On failure, returns the maximum per-node count
|
||||
/// (not the sum, to avoid double-counting overlapping coverage).
|
||||
pub fn fuse_or_fallback(
|
||||
fuser: &MultistaticFuser,
|
||||
node_states: &HashMap<u8, NodeState>,
|
||||
) -> (Option<FusedSensingFrame>, Option<usize>) {
|
||||
let frames = node_frames_from_states(node_states);
|
||||
if frames.is_empty() {
|
||||
return (None, Some(0));
|
||||
}
|
||||
|
||||
match fuser.fuse(&frames) {
|
||||
Ok(fused) => {
|
||||
// Caller must compute person count from fused amplitudes.
|
||||
(Some(fused), None)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("Multistatic fusion failed ({e}), using per-node max fallback");
|
||||
// Use max (not sum) to avoid double-counting when nodes have overlapping coverage.
|
||||
let max_count: usize = node_states
|
||||
.values()
|
||||
.filter(|ns| {
|
||||
ns.last_frame_time
|
||||
.map(|t| t.elapsed() <= STALE_THRESHOLD)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.map(|ns| ns.prev_person_count)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
(None, Some(max_count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a person-presence score from fused amplitude data.
|
||||
///
|
||||
/// Uses the squared coefficient of variation (variance / mean^2) as a
|
||||
/// lightweight proxy for body-induced CSI perturbation. A flat amplitude
|
||||
/// vector (no person) yields a score near zero; a vector with high variance
|
||||
/// relative to its mean (person moving) yields a score approaching 1.0.
|
||||
pub fn compute_person_score_from_amplitudes(amplitudes: &[f32]) -> f64 {
|
||||
if amplitudes.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let n = amplitudes.len() as f64;
|
||||
let sum: f64 = amplitudes.iter().map(|&a| a as f64).sum();
|
||||
let mean = sum / n;
|
||||
|
||||
let variance: f64 = amplitudes.iter().map(|&a| {
|
||||
let diff = (a as f64) - mean;
|
||||
diff * diff
|
||||
}).sum::<f64>() / n;
|
||||
|
||||
let score = variance / (mean * mean + 1e-10);
|
||||
score.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Helper: build a minimal NodeState for testing. Uses `NodeState::new()`
|
||||
/// then mutates the `pub(crate)` fields the bridge needs.
|
||||
fn make_node_state(
|
||||
frame_history: VecDeque<Vec<f64>>,
|
||||
last_frame_time: Option<Instant>,
|
||||
prev_person_count: usize,
|
||||
) -> NodeState {
|
||||
let mut ns = NodeState::new();
|
||||
ns.frame_history = frame_history;
|
||||
ns.last_frame_time = last_frame_time;
|
||||
ns.prev_person_count = prev_person_count;
|
||||
ns
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_node_frame_from_empty_state() {
|
||||
let ns = make_node_state(VecDeque::new(), Some(Instant::now()), 0);
|
||||
assert!(node_frame_from_state(1, &ns).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_node_frame_from_state_no_time() {
|
||||
let mut history = VecDeque::new();
|
||||
history.push_back(vec![1.0, 2.0, 3.0]);
|
||||
let ns = make_node_state(history, None, 0);
|
||||
assert!(node_frame_from_state(1, &ns).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_node_frame_conversion() {
|
||||
let mut history = VecDeque::new();
|
||||
history.push_back(vec![10.0, 20.0, 30.5]);
|
||||
let ns = make_node_state(history, Some(Instant::now()), 0);
|
||||
|
||||
let frame = node_frame_from_state(42, &ns).expect("should produce a frame");
|
||||
assert_eq!(frame.node_id, 42);
|
||||
assert_eq!(frame.channel_frames.len(), 1);
|
||||
|
||||
let ch = &frame.channel_frames[0];
|
||||
assert_eq!(ch.amplitude.len(), 3);
|
||||
assert!((ch.amplitude[0] - 10.0_f32).abs() < f32::EPSILON);
|
||||
assert!((ch.amplitude[1] - 20.0_f32).abs() < f32::EPSILON);
|
||||
assert!((ch.amplitude[2] - 30.5_f32).abs() < f32::EPSILON);
|
||||
// Phase should be all zeros
|
||||
assert!(ch.phase.iter().all(|&p| p == 0.0));
|
||||
assert_eq!(ch.hardware_type, HardwareType::Esp32S3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stale_node_excluded() {
|
||||
let mut states: HashMap<u8, NodeState> = HashMap::new();
|
||||
|
||||
// Active node: frame just received
|
||||
let mut active_history = VecDeque::new();
|
||||
active_history.push_back(vec![1.0, 2.0]);
|
||||
states.insert(1, make_node_state(active_history, Some(Instant::now()), 1));
|
||||
|
||||
// Stale node: frame 20 seconds ago
|
||||
let mut stale_history = VecDeque::new();
|
||||
stale_history.push_back(vec![3.0, 4.0]);
|
||||
let stale_time = Instant::now() - Duration::from_secs(20);
|
||||
states.insert(2, make_node_state(stale_history, Some(stale_time), 1));
|
||||
|
||||
let frames = node_frames_from_states(&states);
|
||||
assert_eq!(frames.len(), 1, "stale node should be excluded");
|
||||
assert_eq!(frames[0].node_id, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_person_score_empty() {
|
||||
assert!((compute_person_score_from_amplitudes(&[]) - 0.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_person_score_flat() {
|
||||
// Constant amplitude => variance = 0 => score ~ 0
|
||||
let flat = vec![5.0_f32; 64];
|
||||
let score = compute_person_score_from_amplitudes(&flat);
|
||||
assert!(score < 0.001, "flat signal should have near-zero score, got {score}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_person_score_varied() {
|
||||
// High variance relative to mean should produce a positive score
|
||||
let varied: Vec<f32> = (0..64).map(|i| if i % 2 == 0 { 1.0 } else { 10.0 }).collect();
|
||||
let score = compute_person_score_from_amplitudes(&varied);
|
||||
assert!(score > 0.1, "varied signal should have positive score, got {score}");
|
||||
assert!(score <= 1.0, "score should be clamped to 1.0, got {score}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_person_score_clamped() {
|
||||
// Near-zero mean with non-zero variance => would blow up without clamp
|
||||
let vals = vec![0.0_f32, 0.0, 0.0, 0.001];
|
||||
let score = compute_person_score_from_amplitudes(&vals);
|
||||
assert!(score <= 1.0, "score must be clamped to 1.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fuse_or_fallback_empty() {
|
||||
let fuser = MultistaticFuser::new();
|
||||
let states: HashMap<u8, NodeState> = HashMap::new();
|
||||
let (fused, count) = fuse_or_fallback(&fuser, &states);
|
||||
assert!(fused.is_none());
|
||||
assert_eq!(count, Some(0));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
//! Skeleton derivation, pose estimation, and temporal smoothing.
|
||||
|
||||
use crate::types::*;
|
||||
|
||||
/// Expected bone lengths in pixel-space for the COCO-17 skeleton.
|
||||
pub const POSE_BONE_PAIRS: &[(usize, usize)] = &[
|
||||
(5, 7), (7, 9), (6, 8), (8, 10),
|
||||
(5, 11), (6, 12),
|
||||
(11, 13), (13, 15), (12, 14), (14, 16),
|
||||
(5, 6), (11, 12),
|
||||
];
|
||||
|
||||
const TORSO_KP: [usize; 4] = [5, 6, 11, 12];
|
||||
const EXTREMITY_KP: [usize; 4] = [9, 10, 15, 16];
|
||||
|
||||
pub fn derive_single_person_pose(
|
||||
update: &SensingUpdate, person_idx: usize, total_persons: usize,
|
||||
) -> PersonDetection {
|
||||
let cls = &update.classification;
|
||||
let feat = &update.features;
|
||||
|
||||
let phase_offset = person_idx as f64 * 2.094;
|
||||
let half = (total_persons as f64 - 1.0) / 2.0;
|
||||
let person_x_offset = (person_idx as f64 - half) * 120.0;
|
||||
let conf_decay = 1.0 - person_idx as f64 * 0.15;
|
||||
|
||||
let motion_score = (feat.motion_band_power / 15.0).clamp(0.0, 1.0);
|
||||
let is_walking = motion_score > 0.55;
|
||||
let breath_amp = (feat.breathing_band_power * 4.0).clamp(0.0, 12.0);
|
||||
|
||||
let breath_phase = if let Some(ref vs) = update.vital_signs {
|
||||
let bpm = vs.breathing_rate_bpm.unwrap_or(15.0);
|
||||
let freq = (bpm / 60.0).clamp(0.1, 0.5);
|
||||
(update.tick as f64 * freq * 0.02 * std::f64::consts::TAU + phase_offset).sin()
|
||||
} else {
|
||||
(update.tick as f64 * 0.02 + phase_offset).sin()
|
||||
};
|
||||
|
||||
let lean_x = (feat.dominant_freq_hz / 5.0 - 1.0).clamp(-1.0, 1.0) * 18.0;
|
||||
let stride_x = if is_walking {
|
||||
let stride_phase = (feat.motion_band_power * 0.7 + update.tick as f64 * 0.06 + phase_offset).sin();
|
||||
stride_phase * 20.0 * motion_score
|
||||
} else { 0.0 };
|
||||
|
||||
let burst = (feat.change_points as f64 / 20.0).clamp(0.0, 0.3);
|
||||
let noise_seed = person_idx as f64 * 97.1;
|
||||
let noise_val = (noise_seed.sin() * 43758.545).fract();
|
||||
let snr_factor = ((feat.variance - 0.5) / 10.0).clamp(0.0, 1.0);
|
||||
let base_confidence = cls.confidence * (0.6 + 0.4 * snr_factor) * conf_decay;
|
||||
|
||||
let base_x = 320.0 + stride_x + lean_x * 0.5 + person_x_offset;
|
||||
let base_y = 240.0 - motion_score * 8.0;
|
||||
|
||||
let kp_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle",
|
||||
];
|
||||
|
||||
let kp_offsets: [(f64, f64); 17] = [
|
||||
(0.0, -80.0), (-8.0, -88.0), (8.0, -88.0), (-16.0, -82.0), (16.0, -82.0),
|
||||
(-30.0, -50.0), (30.0, -50.0), (-45.0, -15.0), (45.0, -15.0),
|
||||
(-50.0, 20.0), (50.0, 20.0), (-20.0, 20.0), (20.0, 20.0),
|
||||
(-22.0, 70.0), (22.0, 70.0), (-24.0, 120.0), (24.0, 120.0),
|
||||
];
|
||||
|
||||
let keypoints: Vec<PoseKeypoint> = kp_names.iter().zip(kp_offsets.iter())
|
||||
.enumerate()
|
||||
.map(|(i, (name, (dx, dy)))| {
|
||||
let breath_dx = if TORSO_KP.contains(&i) {
|
||||
let sign = if *dx < 0.0 { -1.0 } else { 1.0 };
|
||||
sign * breath_amp * breath_phase * 0.5
|
||||
} else { 0.0 };
|
||||
let breath_dy = if TORSO_KP.contains(&i) {
|
||||
let sign = if *dy < 0.0 { -1.0 } else { 1.0 };
|
||||
sign * breath_amp * breath_phase * 0.3
|
||||
} else { 0.0 };
|
||||
|
||||
let extremity_jitter = if EXTREMITY_KP.contains(&i) {
|
||||
let phase = noise_seed + i as f64 * 2.399;
|
||||
(phase.sin() * burst * motion_score * 4.0, (phase * 1.31).cos() * burst * motion_score * 3.0)
|
||||
} else { (0.0, 0.0) };
|
||||
|
||||
let kp_noise_x = ((noise_seed + i as f64 * 1.618).sin() * 43758.545).fract()
|
||||
* feat.variance.sqrt().clamp(0.0, 3.0) * motion_score;
|
||||
let kp_noise_y = ((noise_seed + i as f64 * 2.718).cos() * 31415.926).fract()
|
||||
* feat.variance.sqrt().clamp(0.0, 3.0) * motion_score * 0.6;
|
||||
|
||||
let swing_dy = if is_walking {
|
||||
let stride_phase = (feat.motion_band_power * 0.7 + update.tick as f64 * 0.12 + phase_offset).sin();
|
||||
match i {
|
||||
7 | 9 => -stride_phase * 20.0 * motion_score,
|
||||
8 | 10 => stride_phase * 20.0 * motion_score,
|
||||
13 | 15 => stride_phase * 25.0 * motion_score,
|
||||
14 | 16 => -stride_phase * 25.0 * motion_score,
|
||||
_ => 0.0,
|
||||
}
|
||||
} else { 0.0 };
|
||||
|
||||
let final_x = base_x + dx + breath_dx + extremity_jitter.0 + kp_noise_x;
|
||||
let final_y = base_y + dy + breath_dy + extremity_jitter.1 + kp_noise_y + swing_dy;
|
||||
|
||||
let kp_conf = if EXTREMITY_KP.contains(&i) {
|
||||
base_confidence * (0.7 + 0.3 * snr_factor) * (0.85 + 0.15 * noise_val)
|
||||
} else {
|
||||
base_confidence * (0.88 + 0.12 * ((i as f64 * 0.7 + noise_seed).cos()))
|
||||
};
|
||||
|
||||
PoseKeypoint { name: name.to_string(), x: final_x, y: final_y, z: lean_x * 0.02, confidence: kp_conf.clamp(0.1, 1.0) }
|
||||
})
|
||||
.collect();
|
||||
|
||||
let xs: Vec<f64> = keypoints.iter().map(|k| k.x).collect();
|
||||
let ys: Vec<f64> = keypoints.iter().map(|k| k.y).collect();
|
||||
let min_x = xs.iter().cloned().fold(f64::MAX, f64::min) - 10.0;
|
||||
let min_y = ys.iter().cloned().fold(f64::MAX, f64::min) - 10.0;
|
||||
let max_x = xs.iter().cloned().fold(f64::MIN, f64::max) + 10.0;
|
||||
let max_y = ys.iter().cloned().fold(f64::MIN, f64::max) + 10.0;
|
||||
|
||||
PersonDetection {
|
||||
id: (person_idx + 1) as u32,
|
||||
confidence: cls.confidence * conf_decay,
|
||||
keypoints,
|
||||
bbox: BoundingBox { x: min_x, y: min_y, width: (max_x - min_x).max(80.0), height: (max_y - min_y).max(160.0) },
|
||||
zone: format!("zone_{}", person_idx + 1),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn derive_pose_from_sensing(update: &SensingUpdate) -> Vec<PersonDetection> {
|
||||
let cls = &update.classification;
|
||||
if !cls.presence { return vec![]; }
|
||||
let person_count = update.estimated_persons.unwrap_or(1).max(1);
|
||||
(0..person_count).map(|idx| derive_single_person_pose(update, idx, person_count)).collect()
|
||||
}
|
||||
|
||||
/// Apply temporal EMA smoothing and bone-length clamping to person detections.
|
||||
pub fn apply_temporal_smoothing(persons: &mut [PersonDetection], ns: &mut NodeState) {
|
||||
if persons.is_empty() { return; }
|
||||
|
||||
let alpha = ns.ema_alpha();
|
||||
let person = &mut persons[0];
|
||||
|
||||
let current_kps: Vec<[f64; 3]> = person.keypoints.iter()
|
||||
.map(|kp| [kp.x, kp.y, kp.z]).collect();
|
||||
|
||||
let smoothed = if let Some(ref prev) = ns.prev_keypoints {
|
||||
let mut out = Vec::with_capacity(current_kps.len());
|
||||
for (cur, prv) in current_kps.iter().zip(prev.iter()) {
|
||||
out.push([
|
||||
alpha * cur[0] + (1.0 - alpha) * prv[0],
|
||||
alpha * cur[1] + (1.0 - alpha) * prv[1],
|
||||
alpha * cur[2] + (1.0 - alpha) * prv[2],
|
||||
]);
|
||||
}
|
||||
clamp_bone_lengths_f64(&mut out, prev);
|
||||
out
|
||||
} else {
|
||||
current_kps.clone()
|
||||
};
|
||||
|
||||
for (kp, s) in person.keypoints.iter_mut().zip(smoothed.iter()) {
|
||||
kp.x = s[0]; kp.y = s[1]; kp.z = s[2];
|
||||
}
|
||||
ns.prev_keypoints = Some(smoothed);
|
||||
}
|
||||
|
||||
fn clamp_bone_lengths_f64(pose: &mut Vec<[f64; 3]>, prev: &[[f64; 3]]) {
|
||||
for &(p, c) in POSE_BONE_PAIRS {
|
||||
if p >= pose.len() || c >= pose.len() { continue; }
|
||||
let prev_len = dist_f64(&prev[p], &prev[c]);
|
||||
if prev_len < 1e-6 { continue; }
|
||||
let cur_len = dist_f64(&pose[p], &pose[c]);
|
||||
if cur_len < 1e-6 { continue; }
|
||||
let ratio = cur_len / prev_len;
|
||||
let lo = 1.0 - MAX_BONE_CHANGE_RATIO;
|
||||
let hi = 1.0 + MAX_BONE_CHANGE_RATIO;
|
||||
if ratio < lo || ratio > hi {
|
||||
let target = prev_len * ratio.clamp(lo, hi);
|
||||
let scale = target / cur_len;
|
||||
for dim in 0..3 {
|
||||
let diff = pose[c][dim] - pose[p][dim];
|
||||
pose[c][dim] = pose[p][dim] + diff * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dist_f64(a: &[f64; 3], b: &[f64; 3]) -> f64 {
|
||||
let dx = b[0] - a[0];
|
||||
let dy = b[1] - a[1];
|
||||
let dz = b[2] - a[2];
|
||||
(dx * dx + dy * dy + dz * dz).sqrt()
|
||||
}
|
||||
+409
@@ -0,0 +1,409 @@
|
||||
//! Bridge between sensing-server PersonDetection types and signal crate PoseTracker.
|
||||
//!
|
||||
//! The sensing server uses f64 types (PersonDetection, PoseKeypoint, BoundingBox)
|
||||
//! while the signal crate's PoseTracker operates on f32 Kalman states. This module
|
||||
//! provides conversion functions and a single `tracker_update` entry point that
|
||||
//! accepts server-side detections and returns tracker-smoothed results.
|
||||
|
||||
use std::time::Instant;
|
||||
use wifi_densepose_signal::ruvsense::{
|
||||
self, KeypointState, PoseTrack, TrackLifecycleState, TrackId, NUM_KEYPOINTS,
|
||||
};
|
||||
use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker;
|
||||
|
||||
use super::{BoundingBox, PersonDetection, PoseKeypoint};
|
||||
|
||||
/// COCO-17 keypoint names in index order.
|
||||
const COCO_NAMES: [&str; 17] = [
|
||||
"nose",
|
||||
"left_eye",
|
||||
"right_eye",
|
||||
"left_ear",
|
||||
"right_ear",
|
||||
"left_shoulder",
|
||||
"right_shoulder",
|
||||
"left_elbow",
|
||||
"right_elbow",
|
||||
"left_wrist",
|
||||
"right_wrist",
|
||||
"left_hip",
|
||||
"right_hip",
|
||||
"left_knee",
|
||||
"right_knee",
|
||||
"left_ankle",
|
||||
"right_ankle",
|
||||
];
|
||||
|
||||
/// Map a lowercase keypoint name to its COCO-17 index.
|
||||
fn keypoint_name_to_coco_index(name: &str) -> Option<usize> {
|
||||
COCO_NAMES.iter().position(|&n| n.eq_ignore_ascii_case(name))
|
||||
}
|
||||
|
||||
/// Convert server-side PersonDetection slices into tracker-compatible keypoint arrays.
|
||||
///
|
||||
/// For each person, maps named keypoints to COCO-17 positions. Unmapped slots are
|
||||
/// filled with the centroid of the mapped keypoints so the Kalman filter has a
|
||||
/// reasonable initial value rather than zeros.
|
||||
fn detections_to_tracker_keypoints(persons: &[PersonDetection]) -> Vec<[[f32; 3]; 17]> {
|
||||
persons
|
||||
.iter()
|
||||
.map(|person| {
|
||||
let mut kps = [[0.0_f32; 3]; 17];
|
||||
let mut mapped_count = 0u32;
|
||||
let mut cx = 0.0_f32;
|
||||
let mut cy = 0.0_f32;
|
||||
let mut cz = 0.0_f32;
|
||||
|
||||
// First pass: place mapped keypoints and accumulate centroid
|
||||
for kp in &person.keypoints {
|
||||
if let Some(idx) = keypoint_name_to_coco_index(&kp.name) {
|
||||
kps[idx] = [kp.x as f32, kp.y as f32, kp.z as f32];
|
||||
cx += kp.x as f32;
|
||||
cy += kp.y as f32;
|
||||
cz += kp.z as f32;
|
||||
mapped_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute centroid of mapped keypoints
|
||||
let centroid = if mapped_count > 0 {
|
||||
let n = mapped_count as f32;
|
||||
[cx / n, cy / n, cz / n]
|
||||
} else {
|
||||
[0.0, 0.0, 0.0]
|
||||
};
|
||||
|
||||
// Second pass: fill unmapped slots with centroid
|
||||
// Build a set of mapped indices
|
||||
let mut mapped = [false; 17];
|
||||
for kp in &person.keypoints {
|
||||
if let Some(idx) = keypoint_name_to_coco_index(&kp.name) {
|
||||
mapped[idx] = true;
|
||||
}
|
||||
}
|
||||
for i in 0..17 {
|
||||
if !mapped[i] {
|
||||
kps[i] = centroid;
|
||||
}
|
||||
}
|
||||
|
||||
kps
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert active PoseTracker tracks back into server-side PersonDetection values.
|
||||
///
|
||||
/// Only tracks whose lifecycle `is_alive()` are included.
|
||||
pub fn tracker_to_person_detections(tracker: &PoseTracker) -> Vec<PersonDetection> {
|
||||
tracker
|
||||
.active_tracks()
|
||||
.into_iter()
|
||||
.map(|track| {
|
||||
let id = track.id.0 as u32;
|
||||
|
||||
let confidence = match track.lifecycle {
|
||||
TrackLifecycleState::Active => 0.9,
|
||||
TrackLifecycleState::Tentative => 0.5,
|
||||
TrackLifecycleState::Lost => 0.3,
|
||||
TrackLifecycleState::Terminated => 0.0,
|
||||
};
|
||||
|
||||
// Build keypoints from Kalman state
|
||||
let keypoints: Vec<PoseKeypoint> = (0..NUM_KEYPOINTS)
|
||||
.map(|i| {
|
||||
let pos = track.keypoints[i].position();
|
||||
PoseKeypoint {
|
||||
name: COCO_NAMES[i].to_string(),
|
||||
x: pos[0] as f64,
|
||||
y: pos[1] as f64,
|
||||
z: pos[2] as f64,
|
||||
confidence: track.keypoints[i].confidence as f64,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Compute bounding box from observed keypoints only (confidence > 0).
|
||||
// Unobserved slots (centroid-filled) collapse the bbox over time.
|
||||
let mut min_x = f64::MAX;
|
||||
let mut min_y = f64::MAX;
|
||||
let mut max_x = f64::MIN;
|
||||
let mut max_y = f64::MIN;
|
||||
let mut observed = 0;
|
||||
for kp in &keypoints {
|
||||
if kp.confidence > 0.0 {
|
||||
if kp.x < min_x { min_x = kp.x; }
|
||||
if kp.y < min_y { min_y = kp.y; }
|
||||
if kp.x > max_x { max_x = kp.x; }
|
||||
if kp.y > max_y { max_y = kp.y; }
|
||||
observed += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let bbox = if observed > 0 {
|
||||
BoundingBox {
|
||||
x: min_x,
|
||||
y: min_y,
|
||||
width: (max_x - min_x).max(0.01),
|
||||
height: (max_y - min_y).max(0.01),
|
||||
}
|
||||
} else {
|
||||
// No observed keypoints — use a default bbox at centroid
|
||||
let cx = keypoints.iter().map(|k| k.x).sum::<f64>() / keypoints.len() as f64;
|
||||
let cy = keypoints.iter().map(|k| k.y).sum::<f64>() / keypoints.len() as f64;
|
||||
BoundingBox { x: cx - 0.3, y: cy - 0.5, width: 0.6, height: 1.0 }
|
||||
};
|
||||
|
||||
PersonDetection {
|
||||
id,
|
||||
confidence,
|
||||
keypoints,
|
||||
bbox,
|
||||
zone: "tracked".to_string(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Run one tracker cycle: predict, match detections, update, prune.
|
||||
///
|
||||
/// This is the main entry point called each sensing frame. It:
|
||||
/// 1. Computes dt from the previous call instant
|
||||
/// 2. Predicts all existing tracks forward
|
||||
/// 3. Greedily assigns detections to tracks by Mahalanobis cost
|
||||
/// 4. Updates matched tracks, creates new tracks for unmatched detections
|
||||
/// 5. Prunes terminated tracks
|
||||
/// 6. Returns smoothed PersonDetection values from the tracker state
|
||||
pub fn tracker_update(
|
||||
tracker: &mut PoseTracker,
|
||||
last_instant: &mut Option<Instant>,
|
||||
persons: Vec<PersonDetection>,
|
||||
) -> Vec<PersonDetection> {
|
||||
let now = Instant::now();
|
||||
let dt = last_instant.map_or(0.1_f32, |prev| now.duration_since(prev).as_secs_f32());
|
||||
*last_instant = Some(now);
|
||||
|
||||
// Predict all tracks forward
|
||||
tracker.predict_all(dt);
|
||||
|
||||
if persons.is_empty() {
|
||||
tracker.prune_terminated();
|
||||
return tracker_to_person_detections(tracker);
|
||||
}
|
||||
|
||||
// Convert detections to f32 keypoint arrays
|
||||
let all_keypoints = detections_to_tracker_keypoints(&persons);
|
||||
|
||||
// Compute centroids for each detection
|
||||
let centroids: Vec<[f32; 3]> = all_keypoints
|
||||
.iter()
|
||||
.map(|kps| {
|
||||
let mut c = [0.0_f32; 3];
|
||||
for kp in kps {
|
||||
c[0] += kp[0];
|
||||
c[1] += kp[1];
|
||||
c[2] += kp[2];
|
||||
}
|
||||
let n = NUM_KEYPOINTS as f32;
|
||||
c[0] /= n;
|
||||
c[1] /= n;
|
||||
c[2] /= n;
|
||||
c
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Greedy assignment: for each detection, find the best matching active track.
|
||||
// Collect tracks once to avoid re-borrowing tracker per detection.
|
||||
let active: Vec<(TrackId, [f32; 3])> = tracker.active_tracks().iter().map(|t| {
|
||||
let centroid = {
|
||||
let mut c = [0.0_f32; 3];
|
||||
for kp in &t.keypoints {
|
||||
let p = kp.position();
|
||||
c[0] += p[0]; c[1] += p[1]; c[2] += p[2];
|
||||
}
|
||||
let n = NUM_KEYPOINTS as f32;
|
||||
[c[0] / n, c[1] / n, c[2] / n]
|
||||
};
|
||||
(t.id, centroid)
|
||||
}).collect();
|
||||
|
||||
let mut used_tracks: Vec<bool> = vec![false; active.len()];
|
||||
let mut matched: Vec<Option<TrackId>> = vec![None; persons.len()];
|
||||
|
||||
for det_idx in 0..persons.len() {
|
||||
let mut best_cost = f32::MAX;
|
||||
let mut best_track_idx = None;
|
||||
|
||||
let active_refs = tracker.active_tracks();
|
||||
for (track_idx, track) in active_refs.iter().enumerate() {
|
||||
if used_tracks[track_idx] {
|
||||
continue;
|
||||
}
|
||||
let cost = tracker.assignment_cost(track, ¢roids[det_idx], &[]);
|
||||
if cost < best_cost {
|
||||
best_cost = cost;
|
||||
best_track_idx = Some(track_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Mahalanobis gate: 9.0 (default TrackerConfig)
|
||||
if best_cost < 9.0 {
|
||||
if let Some(tidx) = best_track_idx {
|
||||
matched[det_idx] = Some(active[tidx].0);
|
||||
used_tracks[tidx] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Timestamp for new/updated tracks (microseconds since UNIX epoch)
|
||||
let timestamp_us = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_micros() as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
// Update matched tracks (uses update_keypoints for proper lifecycle transitions)
|
||||
for (det_idx, track_id_opt) in matched.iter().enumerate() {
|
||||
if let Some(track_id) = track_id_opt {
|
||||
if let Some(track) = tracker.find_track_mut(*track_id) {
|
||||
track.update_keypoints(&all_keypoints[det_idx], 0.08, 1.0, timestamp_us);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create new tracks for unmatched detections
|
||||
for (det_idx, track_id_opt) in matched.iter().enumerate() {
|
||||
if track_id_opt.is_none() {
|
||||
tracker.create_track(&all_keypoints[det_idx], timestamp_us);
|
||||
}
|
||||
}
|
||||
|
||||
tracker.prune_terminated();
|
||||
tracker_to_person_detections(tracker)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_keypoint(name: &str, x: f64, y: f64, z: f64) -> PoseKeypoint {
|
||||
PoseKeypoint {
|
||||
name: name.to_string(),
|
||||
x,
|
||||
y,
|
||||
z,
|
||||
confidence: 0.9,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_person(id: u32, keypoints: Vec<PoseKeypoint>) -> PersonDetection {
|
||||
PersonDetection {
|
||||
id,
|
||||
confidence: 0.8,
|
||||
keypoints,
|
||||
bbox: BoundingBox {
|
||||
x: 0.0,
|
||||
y: 0.0,
|
||||
width: 1.0,
|
||||
height: 1.0,
|
||||
},
|
||||
zone: "test".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keypoint_name_to_coco_index() {
|
||||
assert_eq!(keypoint_name_to_coco_index("nose"), Some(0));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_eye"), Some(1));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_eye"), Some(2));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_ear"), Some(3));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_ear"), Some(4));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_shoulder"), Some(5));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_shoulder"), Some(6));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_elbow"), Some(7));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_elbow"), Some(8));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_wrist"), Some(9));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_wrist"), Some(10));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_hip"), Some(11));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_hip"), Some(12));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_knee"), Some(13));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_knee"), Some(14));
|
||||
assert_eq!(keypoint_name_to_coco_index("left_ankle"), Some(15));
|
||||
assert_eq!(keypoint_name_to_coco_index("right_ankle"), Some(16));
|
||||
assert_eq!(keypoint_name_to_coco_index("unknown"), None);
|
||||
// Case insensitive
|
||||
assert_eq!(keypoint_name_to_coco_index("NOSE"), Some(0));
|
||||
assert_eq!(keypoint_name_to_coco_index("Left_Eye"), Some(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detections_to_tracker_keypoints() {
|
||||
let person = make_person(
|
||||
1,
|
||||
vec![
|
||||
make_keypoint("nose", 1.0, 2.0, 0.5),
|
||||
make_keypoint("left_shoulder", 0.8, 2.5, 0.4),
|
||||
make_keypoint("right_shoulder", 1.2, 2.5, 0.6),
|
||||
],
|
||||
);
|
||||
|
||||
let result = detections_to_tracker_keypoints(&[person]);
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let kps = &result[0];
|
||||
|
||||
// Mapped keypoints should have correct values
|
||||
assert!((kps[0][0] - 1.0).abs() < 1e-5); // nose x
|
||||
assert!((kps[0][1] - 2.0).abs() < 1e-5); // nose y
|
||||
assert!((kps[0][2] - 0.5).abs() < 1e-5); // nose z
|
||||
|
||||
assert!((kps[5][0] - 0.8).abs() < 1e-5); // left_shoulder x
|
||||
assert!((kps[6][0] - 1.2).abs() < 1e-5); // right_shoulder x
|
||||
|
||||
// Unmapped keypoints should be at centroid of mapped keypoints
|
||||
// centroid = ((1.0+0.8+1.2)/3, (2.0+2.5+2.5)/3, (0.5+0.4+0.6)/3)
|
||||
let cx = (1.0 + 0.8 + 1.2) / 3.0;
|
||||
let cy = (2.0 + 2.5 + 2.5) / 3.0;
|
||||
let cz = (0.5 + 0.4 + 0.6) / 3.0;
|
||||
|
||||
// left_eye (index 1) should be at centroid
|
||||
assert!((kps[1][0] - cx).abs() < 1e-4);
|
||||
assert!((kps[1][1] - cy).abs() < 1e-4);
|
||||
assert!((kps[1][2] - cz).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tracker_update_stable_ids() {
|
||||
let mut tracker = PoseTracker::new();
|
||||
let mut last_instant: Option<Instant> = None;
|
||||
|
||||
let person = make_person(
|
||||
0,
|
||||
vec![
|
||||
make_keypoint("nose", 1.0, 2.0, 0.0),
|
||||
make_keypoint("left_shoulder", 0.8, 2.5, 0.0),
|
||||
make_keypoint("right_shoulder", 1.2, 2.5, 0.0),
|
||||
make_keypoint("left_hip", 0.9, 3.5, 0.0),
|
||||
make_keypoint("right_hip", 1.1, 3.5, 0.0),
|
||||
],
|
||||
);
|
||||
|
||||
// First update: creates a new track
|
||||
let result1 = tracker_update(&mut tracker, &mut last_instant, vec![person.clone()]);
|
||||
assert_eq!(result1.len(), 1);
|
||||
let id1 = result1[0].id;
|
||||
|
||||
// Second update: should match the existing track
|
||||
let result2 = tracker_update(&mut tracker, &mut last_instant, vec![person.clone()]);
|
||||
assert_eq!(result2.len(), 1);
|
||||
let id2 = result2[0].id;
|
||||
|
||||
// Third update: same track ID should persist
|
||||
let result3 = tracker_update(&mut tracker, &mut last_instant, vec![person.clone()]);
|
||||
assert_eq!(result3.len(), 1);
|
||||
let id3 = result3[0].id;
|
||||
|
||||
// All three updates should return the same track ID
|
||||
assert_eq!(id1, id2, "Track ID should be stable across updates");
|
||||
assert_eq!(id2, id3, "Track ID should be stable across updates");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,403 @@
|
||||
//! Data types, constants, and shared state definitions.
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
|
||||
use crate::adaptive_classifier;
|
||||
use crate::rvf_container::RvfContainerInfo;
|
||||
use crate::rvf_pipeline::ProgressiveLoader;
|
||||
use crate::vital_signs::{VitalSignDetector, VitalSigns};
|
||||
|
||||
use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker;
|
||||
use wifi_densepose_signal::ruvsense::multistatic::MultistaticFuser;
|
||||
use wifi_densepose_signal::ruvsense::field_model::FieldModel;
|
||||
|
||||
// ── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Number of frames retained in `frame_history` for temporal analysis.
|
||||
pub const FRAME_HISTORY_CAPACITY: usize = 100;
|
||||
|
||||
/// If no ESP32 frame arrives within this duration, source reverts to offline.
|
||||
pub const ESP32_OFFLINE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
|
||||
|
||||
/// Default EMA alpha for temporal keypoint smoothing (RuVector Phase 2).
|
||||
pub const TEMPORAL_EMA_ALPHA_DEFAULT: f64 = 0.15;
|
||||
/// Reduced EMA alpha when coherence is low.
|
||||
pub const TEMPORAL_EMA_ALPHA_LOW_COHERENCE: f64 = 0.05;
|
||||
/// Coherence threshold below which we reduce EMA alpha.
|
||||
pub const COHERENCE_LOW_THRESHOLD: f64 = 0.3;
|
||||
/// Maximum allowed bone-length change ratio between frames (20%).
|
||||
pub const MAX_BONE_CHANGE_RATIO: f64 = 0.20;
|
||||
/// Number of motion_energy frames to track for coherence scoring.
|
||||
pub const COHERENCE_WINDOW: usize = 20;
|
||||
|
||||
/// Debounce frames required before state transition (at ~10 FPS = ~0.4s).
|
||||
pub const DEBOUNCE_FRAMES: u32 = 4;
|
||||
/// EMA alpha for motion smoothing (~1s time constant at 10 FPS).
|
||||
pub const MOTION_EMA_ALPHA: f64 = 0.15;
|
||||
/// EMA alpha for slow-adapting baseline (~30s time constant at 10 FPS).
|
||||
pub const BASELINE_EMA_ALPHA: f64 = 0.003;
|
||||
/// Number of warm-up frames before baseline subtraction kicks in.
|
||||
pub const BASELINE_WARMUP: u64 = 50;
|
||||
|
||||
/// Size of the median filter window for vital signs outlier rejection.
|
||||
pub const VITAL_MEDIAN_WINDOW: usize = 21;
|
||||
/// EMA alpha for vital signs (~5s time constant at 10 FPS).
|
||||
pub const VITAL_EMA_ALPHA: f64 = 0.02;
|
||||
/// Maximum BPM jump per frame before a value is rejected as an outlier.
|
||||
pub const HR_MAX_JUMP: f64 = 8.0;
|
||||
pub const BR_MAX_JUMP: f64 = 2.0;
|
||||
/// Minimum change from current smoothed value before EMA updates (dead-band).
|
||||
pub const HR_DEAD_BAND: f64 = 2.0;
|
||||
pub const BR_DEAD_BAND: f64 = 0.5;
|
||||
|
||||
// ── ESP32 Frame ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// ADR-018 ESP32 CSI binary frame header (20 bytes)
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct Esp32Frame {
|
||||
pub magic: u32,
|
||||
pub node_id: u8,
|
||||
pub n_antennas: u8,
|
||||
pub n_subcarriers: u8,
|
||||
pub freq_mhz: u16,
|
||||
pub sequence: u32,
|
||||
pub rssi: i8,
|
||||
pub noise_floor: i8,
|
||||
pub amplitudes: Vec<f64>,
|
||||
pub phases: Vec<f64>,
|
||||
}
|
||||
|
||||
// ── Sensing Update ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Sensing update broadcast to WebSocket clients
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SensingUpdate {
|
||||
#[serde(rename = "type")]
|
||||
pub msg_type: String,
|
||||
pub timestamp: f64,
|
||||
pub source: String,
|
||||
pub tick: u64,
|
||||
pub nodes: Vec<NodeInfo>,
|
||||
pub features: FeatureInfo,
|
||||
pub classification: ClassificationInfo,
|
||||
pub signal_field: SignalField,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub vital_signs: Option<VitalSigns>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enhanced_motion: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enhanced_breathing: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub posture: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub signal_quality_score: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub quality_verdict: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bssid_count: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pose_keypoints: Option<Vec<[f64; 4]>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model_status: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub persons: Option<Vec<PersonDetection>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub estimated_persons: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub node_features: Option<Vec<PerNodeFeatureInfo>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeInfo {
|
||||
pub node_id: u8,
|
||||
pub rssi_dbm: f64,
|
||||
pub position: [f64; 3],
|
||||
pub amplitude: Vec<f64>,
|
||||
pub subcarrier_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FeatureInfo {
|
||||
pub mean_rssi: f64,
|
||||
pub variance: f64,
|
||||
pub motion_band_power: f64,
|
||||
pub breathing_band_power: f64,
|
||||
pub dominant_freq_hz: f64,
|
||||
pub change_points: usize,
|
||||
pub spectral_power: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationInfo {
|
||||
pub motion_level: String,
|
||||
pub presence: bool,
|
||||
pub confidence: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SignalField {
|
||||
pub grid_size: [usize; 3],
|
||||
pub values: Vec<f64>,
|
||||
}
|
||||
|
||||
/// WiFi-derived pose keypoint (17 COCO keypoints)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PoseKeypoint {
|
||||
pub name: String,
|
||||
pub x: f64,
|
||||
pub y: f64,
|
||||
pub z: f64,
|
||||
pub confidence: f64,
|
||||
}
|
||||
|
||||
/// Person detection from WiFi sensing
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonDetection {
|
||||
pub id: u32,
|
||||
pub confidence: f64,
|
||||
pub keypoints: Vec<PoseKeypoint>,
|
||||
pub bbox: BoundingBox,
|
||||
pub zone: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BoundingBox {
|
||||
pub x: f64,
|
||||
pub y: f64,
|
||||
pub width: f64,
|
||||
pub height: f64,
|
||||
}
|
||||
|
||||
/// Per-node feature info for WebSocket broadcasts (multi-node support).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PerNodeFeatureInfo {
|
||||
pub node_id: u8,
|
||||
pub features: FeatureInfo,
|
||||
pub classification: ClassificationInfo,
|
||||
pub rssi_dbm: f64,
|
||||
pub last_seen_ms: u64,
|
||||
pub frame_rate_hz: f64,
|
||||
pub stale: bool,
|
||||
}
|
||||
|
||||
// ── ESP32 Edge Vitals Packet (ADR-039) ──────────────────────────────────────
|
||||
|
||||
/// Decoded vitals packet from ESP32 edge processing pipeline.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct Esp32VitalsPacket {
|
||||
pub node_id: u8,
|
||||
pub presence: bool,
|
||||
pub fall_detected: bool,
|
||||
pub motion: bool,
|
||||
pub breathing_rate_bpm: f64,
|
||||
pub heartrate_bpm: f64,
|
||||
pub rssi: i8,
|
||||
pub n_persons: u8,
|
||||
pub motion_energy: f32,
|
||||
pub presence_score: f32,
|
||||
pub timestamp_ms: u32,
|
||||
}
|
||||
|
||||
/// Single WASM event (type + value).
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WasmEvent {
|
||||
pub event_type: u8,
|
||||
pub value: f32,
|
||||
}
|
||||
|
||||
/// Decoded WASM output packet from ESP32 Tier 3 runtime.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WasmOutputPacket {
|
||||
pub node_id: u8,
|
||||
pub module_id: u8,
|
||||
pub events: Vec<WasmEvent>,
|
||||
}
|
||||
|
||||
// ── Per-node state ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Per-node sensing state for multi-node deployments (issue #249).
|
||||
pub struct NodeState {
|
||||
pub frame_history: VecDeque<Vec<f64>>,
|
||||
pub smoothed_person_score: f64,
|
||||
pub prev_person_count: usize,
|
||||
pub smoothed_motion: f64,
|
||||
pub current_motion_level: String,
|
||||
pub debounce_counter: u32,
|
||||
pub debounce_candidate: String,
|
||||
pub baseline_motion: f64,
|
||||
pub baseline_frames: u64,
|
||||
pub smoothed_hr: f64,
|
||||
pub smoothed_br: f64,
|
||||
pub smoothed_hr_conf: f64,
|
||||
pub smoothed_br_conf: f64,
|
||||
pub hr_buffer: VecDeque<f64>,
|
||||
pub br_buffer: VecDeque<f64>,
|
||||
pub rssi_history: VecDeque<f64>,
|
||||
pub vital_detector: VitalSignDetector,
|
||||
pub latest_vitals: VitalSigns,
|
||||
pub last_frame_time: Option<std::time::Instant>,
|
||||
pub edge_vitals: Option<Esp32VitalsPacket>,
|
||||
pub latest_features: Option<FeatureInfo>,
|
||||
pub prev_keypoints: Option<Vec<[f64; 3]>>,
|
||||
pub motion_energy_history: VecDeque<f64>,
|
||||
pub coherence_score: f64,
|
||||
}
|
||||
|
||||
impl NodeState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
frame_history: VecDeque::new(),
|
||||
smoothed_person_score: 0.0,
|
||||
prev_person_count: 0,
|
||||
smoothed_motion: 0.0,
|
||||
current_motion_level: "absent".to_string(),
|
||||
debounce_counter: 0,
|
||||
debounce_candidate: "absent".to_string(),
|
||||
baseline_motion: 0.0,
|
||||
baseline_frames: 0,
|
||||
smoothed_hr: 0.0,
|
||||
smoothed_br: 0.0,
|
||||
smoothed_hr_conf: 0.0,
|
||||
smoothed_br_conf: 0.0,
|
||||
hr_buffer: VecDeque::with_capacity(8),
|
||||
br_buffer: VecDeque::with_capacity(8),
|
||||
rssi_history: VecDeque::new(),
|
||||
vital_detector: VitalSignDetector::new(10.0),
|
||||
latest_vitals: VitalSigns::default(),
|
||||
last_frame_time: None,
|
||||
edge_vitals: None,
|
||||
latest_features: None,
|
||||
prev_keypoints: None,
|
||||
motion_energy_history: VecDeque::with_capacity(COHERENCE_WINDOW),
|
||||
coherence_score: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the coherence score from the latest motion_energy value.
|
||||
pub fn update_coherence(&mut self, motion_energy: f64) {
|
||||
if self.motion_energy_history.len() >= COHERENCE_WINDOW {
|
||||
self.motion_energy_history.pop_front();
|
||||
}
|
||||
self.motion_energy_history.push_back(motion_energy);
|
||||
|
||||
let n = self.motion_energy_history.len();
|
||||
if n < 2 {
|
||||
self.coherence_score = 1.0;
|
||||
return;
|
||||
}
|
||||
|
||||
let mean: f64 = self.motion_energy_history.iter().sum::<f64>() / n as f64;
|
||||
let variance: f64 = self.motion_energy_history.iter()
|
||||
.map(|v| (v - mean) * (v - mean))
|
||||
.sum::<f64>() / (n - 1) as f64;
|
||||
|
||||
self.coherence_score = (1.0 / (1.0 + variance)).clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
/// Choose the EMA alpha based on current coherence score.
|
||||
pub fn ema_alpha(&self) -> f64 {
|
||||
if self.coherence_score < COHERENCE_LOW_THRESHOLD {
|
||||
TEMPORAL_EMA_ALPHA_LOW_COHERENCE
|
||||
} else {
|
||||
TEMPORAL_EMA_ALPHA_DEFAULT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared application state ────────────────────────────────────────────────
|
||||
|
||||
/// Shared application state
|
||||
pub struct AppStateInner {
|
||||
pub latest_update: Option<SensingUpdate>,
|
||||
pub rssi_history: VecDeque<f64>,
|
||||
pub frame_history: VecDeque<Vec<f64>>,
|
||||
pub tick: u64,
|
||||
pub source: String,
|
||||
pub last_esp32_frame: Option<std::time::Instant>,
|
||||
pub tx: broadcast::Sender<String>,
|
||||
pub total_detections: u64,
|
||||
pub start_time: std::time::Instant,
|
||||
pub vital_detector: VitalSignDetector,
|
||||
pub latest_vitals: VitalSigns,
|
||||
pub rvf_info: Option<RvfContainerInfo>,
|
||||
pub save_rvf_path: Option<PathBuf>,
|
||||
pub progressive_loader: Option<ProgressiveLoader>,
|
||||
pub active_sona_profile: Option<String>,
|
||||
pub model_loaded: bool,
|
||||
pub smoothed_person_score: f64,
|
||||
pub prev_person_count: usize,
|
||||
pub smoothed_motion: f64,
|
||||
pub current_motion_level: String,
|
||||
pub debounce_counter: u32,
|
||||
pub debounce_candidate: String,
|
||||
pub baseline_motion: f64,
|
||||
pub baseline_frames: u64,
|
||||
pub smoothed_hr: f64,
|
||||
pub smoothed_br: f64,
|
||||
pub smoothed_hr_conf: f64,
|
||||
pub smoothed_br_conf: f64,
|
||||
pub hr_buffer: VecDeque<f64>,
|
||||
pub br_buffer: VecDeque<f64>,
|
||||
pub edge_vitals: Option<Esp32VitalsPacket>,
|
||||
pub latest_wasm_events: Option<WasmOutputPacket>,
|
||||
pub discovered_models: Vec<serde_json::Value>,
|
||||
pub active_model_id: Option<String>,
|
||||
pub recordings: Vec<serde_json::Value>,
|
||||
pub recording_active: bool,
|
||||
pub recording_start_time: Option<std::time::Instant>,
|
||||
pub recording_current_id: Option<String>,
|
||||
pub recording_stop_tx: Option<tokio::sync::watch::Sender<bool>>,
|
||||
pub training_status: String,
|
||||
pub training_config: Option<serde_json::Value>,
|
||||
pub adaptive_model: Option<adaptive_classifier::AdaptiveModel>,
|
||||
pub node_states: HashMap<u8, NodeState>,
|
||||
pub pose_tracker: PoseTracker,
|
||||
pub last_tracker_instant: Option<std::time::Instant>,
|
||||
pub multistatic_fuser: MultistaticFuser,
|
||||
pub field_model: Option<FieldModel>,
|
||||
}
|
||||
|
||||
impl AppStateInner {
|
||||
/// Return the effective data source, accounting for ESP32 frame timeout.
|
||||
pub fn effective_source(&self) -> String {
|
||||
if self.source == "esp32" {
|
||||
if let Some(last) = self.last_esp32_frame {
|
||||
if last.elapsed() > ESP32_OFFLINE_TIMEOUT {
|
||||
return "esp32:offline".to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
/// Person count: eigenvalue-based if field model is calibrated, else heuristic.
|
||||
pub fn person_count(&self) -> usize {
|
||||
use crate::field_bridge;
|
||||
use crate::csi::score_to_person_count;
|
||||
match self.field_model.as_ref() {
|
||||
Some(fm) => {
|
||||
let history = if !self.frame_history.is_empty() {
|
||||
&self.frame_history
|
||||
} else {
|
||||
self.node_states.values()
|
||||
.filter(|ns| !ns.frame_history.is_empty())
|
||||
.max_by_key(|ns| ns.last_frame_time)
|
||||
.map(|ns| &ns.frame_history)
|
||||
.unwrap_or(&self.frame_history)
|
||||
};
|
||||
field_bridge::occupancy_or_fallback(
|
||||
fm, history, self.smoothed_person_score, self.prev_person_count,
|
||||
)
|
||||
}
|
||||
None => score_to_person_count(self.smoothed_person_score, self.prev_person_count),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type SharedState = Arc<RwLock<AppStateInner>>;
|
||||
@@ -11,6 +11,12 @@ keywords = ["wifi", "csi", "signal-processing", "densepose", "rust"]
|
||||
categories = ["science", "computer-vision"]
|
||||
readme = "README.md"
|
||||
|
||||
[features]
|
||||
default = ["eigenvalue"]
|
||||
## Enable eigenvalue-based person counting (requires BLAS via ndarray-linalg).
|
||||
## Disable with --no-default-features to use the diagonal fallback instead.
|
||||
eigenvalue = ["ndarray-linalg"]
|
||||
|
||||
[dependencies]
|
||||
# Core utilities
|
||||
thiserror.workspace = true
|
||||
@@ -20,6 +26,7 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
# Signal processing
|
||||
ndarray = { workspace = true }
|
||||
ndarray-linalg = { workspace = true, optional = true }
|
||||
rustfft.workspace = true
|
||||
num-complex.workspace = true
|
||||
num-traits.workspace = true
|
||||
|
||||
+486
-44
@@ -17,6 +17,12 @@
|
||||
//! of Squares and Products." Technometrics.
|
||||
//! - ADR-030: RuvSense Persistent Field Model
|
||||
|
||||
use ndarray::Array2;
|
||||
#[cfg(feature = "eigenvalue")]
|
||||
use ndarray_linalg::Eigh;
|
||||
#[cfg(feature = "eigenvalue")]
|
||||
use ndarray_linalg::UPLO;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -47,6 +53,14 @@ pub enum FieldModelError {
|
||||
/// Invalid configuration parameter.
|
||||
#[error("Invalid configuration: {0}")]
|
||||
InvalidConfig(String),
|
||||
|
||||
/// Model has not been calibrated yet.
|
||||
#[error("Field model not calibrated")]
|
||||
NotCalibrated,
|
||||
|
||||
/// Not enough data for the requested operation.
|
||||
#[error("Insufficient data: need {need}, have {have}")]
|
||||
InsufficientData { need: usize, have: usize },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -260,6 +274,8 @@ pub struct FieldNormalMode {
|
||||
pub calibrated_at_us: u64,
|
||||
/// Hash of mesh geometry at calibration time.
|
||||
pub geometry_hash: u64,
|
||||
/// Baseline eigenvalue count above Marcenko-Pastur threshold (empty-room).
|
||||
pub baseline_eigenvalue_count: usize,
|
||||
}
|
||||
|
||||
/// Body perturbation extracted from a CSI observation.
|
||||
@@ -310,6 +326,60 @@ pub struct FieldModel {
|
||||
status: CalibrationStatus,
|
||||
/// Timestamp of last calibration completion (microseconds).
|
||||
last_calibration_us: u64,
|
||||
/// Running outer-product sum for full covariance SVD: [n_sub x n_sub].
|
||||
covariance_sum: Option<Array2<f64>>,
|
||||
/// Number of frames accumulated into covariance_sum.
|
||||
covariance_count: u64,
|
||||
}
|
||||
|
||||
/// Diagonal variance fallback for when full covariance SVD is unavailable.
|
||||
///
|
||||
/// Returns `(mode_energies, environmental_modes, baseline_eigenvalue_count)`.
|
||||
fn diagonal_fallback(
|
||||
link_stats: &[LinkBaselineStats],
|
||||
n_sc: usize,
|
||||
n_modes: usize,
|
||||
) -> (Vec<f64>, Vec<Vec<f64>>, usize) {
|
||||
// Average variance across links (diagonal approximation)
|
||||
let mut avg_variance = vec![0.0_f64; n_sc];
|
||||
for ls in link_stats {
|
||||
let var = ls.variance_vector();
|
||||
for (i, v) in var.iter().enumerate() {
|
||||
avg_variance[i] += v;
|
||||
}
|
||||
}
|
||||
let n_links_f = link_stats.len() as f64;
|
||||
if n_links_f > 0.0 {
|
||||
for v in avg_variance.iter_mut() {
|
||||
*v /= n_links_f;
|
||||
}
|
||||
}
|
||||
|
||||
// Sort subcarrier indices by variance (descending) to pick top-K modes
|
||||
let mut indices: Vec<usize> = (0..n_sc).collect();
|
||||
indices.sort_by(|&a, &b| {
|
||||
avg_variance[b]
|
||||
.partial_cmp(&avg_variance[a])
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut environmental_modes = Vec::with_capacity(n_modes);
|
||||
let mut mode_energies = Vec::with_capacity(n_modes);
|
||||
|
||||
for k in 0..n_modes.min(n_sc) {
|
||||
let idx = indices[k];
|
||||
let mut mode = vec![0.0_f64; n_sc];
|
||||
mode[idx] = 1.0;
|
||||
mode_energies.push(avg_variance[idx]);
|
||||
environmental_modes.push(mode);
|
||||
}
|
||||
|
||||
// For diagonal fallback, estimate baseline eigenvalue count from variance
|
||||
let total_var: f64 = avg_variance.iter().sum();
|
||||
let mean_var = if n_sc > 0 { total_var / n_sc as f64 } else { 0.0 };
|
||||
let baseline_count = avg_variance.iter().filter(|&&v| v > mean_var * 2.0).count();
|
||||
|
||||
(mode_energies, environmental_modes, baseline_count)
|
||||
}
|
||||
|
||||
impl FieldModel {
|
||||
@@ -339,6 +409,8 @@ impl FieldModel {
|
||||
modes: None,
|
||||
status: CalibrationStatus::Uncalibrated,
|
||||
last_calibration_us: 0,
|
||||
covariance_sum: None,
|
||||
covariance_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -375,6 +447,30 @@ impl FieldModel {
|
||||
if self.status == CalibrationStatus::Uncalibrated {
|
||||
self.status = CalibrationStatus::Collecting;
|
||||
}
|
||||
|
||||
// Accumulate raw outer products for SVD covariance (no centering here —
|
||||
// mean subtraction is deferred to finalize_calibration to avoid bias).
|
||||
// We average across links so covariance_count tracks frames, not links.
|
||||
let n = self.config.n_subcarriers;
|
||||
let cov = self.covariance_sum.get_or_insert_with(|| Array2::zeros((n, n)));
|
||||
let n_links = observations.len();
|
||||
for obs in observations {
|
||||
if obs.len() >= n {
|
||||
// Rank-1 update: cov += obs * obs^T (raw, un-centered)
|
||||
for i in 0..n {
|
||||
for j in i..n {
|
||||
let val = obs[i] * obs[j];
|
||||
cov[[i, j]] += val;
|
||||
if i != j {
|
||||
cov[[j, i]] += val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Count once per frame (not per link) for correct MP ratio
|
||||
self.covariance_count += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -396,58 +492,134 @@ impl FieldModel {
|
||||
});
|
||||
}
|
||||
|
||||
// Build covariance matrix from per-link variance data.
|
||||
// We average the variance vectors across all links to get the
|
||||
// covariance diagonal, then compute eigenmodes via power iteration.
|
||||
let n_sc = self.config.n_subcarriers;
|
||||
let n_modes = self.config.n_modes.min(n_sc);
|
||||
|
||||
// Collect per-link baselines
|
||||
let baseline: Vec<Vec<f64>> = self.link_stats.iter().map(|ls| ls.mean_vector()).collect();
|
||||
|
||||
// Average covariance across links (diagonal approximation)
|
||||
let mut avg_variance = vec![0.0_f64; n_sc];
|
||||
for ls in &self.link_stats {
|
||||
let var = ls.variance_vector();
|
||||
for (i, v) in var.iter().enumerate() {
|
||||
avg_variance[i] += v;
|
||||
// --- True eigenvalue decomposition (with diagonal fallback) ---
|
||||
let (mode_energies, environmental_modes, baseline_eig_count) =
|
||||
if let Some(ref cov_sum) = self.covariance_sum {
|
||||
if self.covariance_count > 1 {
|
||||
// Compute sample covariance from raw outer products:
|
||||
// cov = (sum_xx / N - mean * mean^T) * N / (N-1)
|
||||
// where sum_xx accumulated obs * obs^T across all links per frame.
|
||||
// We average per-link means for centering.
|
||||
let n_frames = self.covariance_count as f64;
|
||||
let n_links = self.config.n_links as f64;
|
||||
// Average mean across all links
|
||||
let mut avg_mean = vec![0.0f64; n_sc];
|
||||
for ls in &self.link_stats {
|
||||
let m = ls.mean_vector();
|
||||
for i in 0..n_sc { avg_mean[i] += m[i]; }
|
||||
}
|
||||
for i in 0..n_sc { avg_mean[i] /= n_links; }
|
||||
// cov = sum_xx / (N * n_links) - mean * mean^T, then Bessel correction
|
||||
let total_obs = n_frames * n_links;
|
||||
let mut covariance = cov_sum / total_obs;
|
||||
for i in 0..n_sc {
|
||||
for j in 0..n_sc {
|
||||
covariance[[i, j]] -= avg_mean[i] * avg_mean[j];
|
||||
}
|
||||
}
|
||||
// Bessel's correction: multiply by N/(N-1) where N = total observations
|
||||
let bessel = total_obs / (total_obs - 1.0);
|
||||
covariance *= bessel;
|
||||
|
||||
// Symmetric eigendecomposition (requires eigenvalue feature / BLAS)
|
||||
#[cfg(feature = "eigenvalue")]
|
||||
match covariance.eigh(UPLO::Upper) {
|
||||
Ok((eigenvalues, eigenvectors)) => {
|
||||
// eigenvalues are in ascending order from ndarray-linalg
|
||||
// Reverse to get descending
|
||||
let len = eigenvalues.len();
|
||||
let mut sorted_indices: Vec<usize> = (0..len).collect();
|
||||
sorted_indices.sort_by(|&a, &b| {
|
||||
eigenvalues[b]
|
||||
.partial_cmp(&eigenvalues[a])
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Extract top n_modes
|
||||
let modes: Vec<Vec<f64>> = sorted_indices
|
||||
.iter()
|
||||
.take(n_modes)
|
||||
.map(|&idx| eigenvectors.column(idx).to_vec())
|
||||
.collect();
|
||||
let energies: Vec<f64> = sorted_indices
|
||||
.iter()
|
||||
.take(n_modes)
|
||||
.map(|&idx| eigenvalues[idx].max(0.0))
|
||||
.collect();
|
||||
|
||||
// Marcenko-Pastur noise estimate: median of POSITIVE
|
||||
// eigenvalues in the bottom half. Excludes zeros from
|
||||
// rank-deficient matrices (when p > n).
|
||||
let noise_var = {
|
||||
let mut positive: Vec<f64> = eigenvalues
|
||||
.iter().copied().filter(|&e| e > 1e-10).collect();
|
||||
positive.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
if positive.len() >= 4 {
|
||||
let half = positive.len() / 2;
|
||||
positive[..half].iter().sum::<f64>() / half as f64
|
||||
} else if !positive.is_empty() {
|
||||
positive[0]
|
||||
} else {
|
||||
1e-10
|
||||
}
|
||||
};
|
||||
// MP ratio: p/n where n = total observations (frames * links)
|
||||
let total_obs_mp = self.covariance_count as f64 * self.config.n_links as f64;
|
||||
let ratio = n_sc as f64 / total_obs_mp;
|
||||
let mp_threshold = noise_var * (1.0 + ratio.sqrt()).powi(2);
|
||||
let baseline_count = eigenvalues
|
||||
.iter()
|
||||
.filter(|&&ev| ev > mp_threshold)
|
||||
.count();
|
||||
|
||||
(energies, modes, baseline_count)
|
||||
}
|
||||
Err(_) => {
|
||||
// Fallback to diagonal approximation on SVD failure
|
||||
diagonal_fallback(&self.link_stats, n_sc, n_modes)
|
||||
}
|
||||
}
|
||||
// When eigenvalue feature is disabled, use diagonal fallback
|
||||
#[cfg(not(feature = "eigenvalue"))]
|
||||
{ diagonal_fallback(&self.link_stats, n_sc, n_modes) }
|
||||
} else {
|
||||
diagonal_fallback(&self.link_stats, n_sc, n_modes)
|
||||
}
|
||||
} else {
|
||||
diagonal_fallback(&self.link_stats, n_sc, n_modes)
|
||||
};
|
||||
|
||||
// Compute variance explained using the same centered covariance as modes.
|
||||
// total_variance = trace(centered_covariance) = sum of ALL eigenvalues.
|
||||
let total_energy: f64 = mode_energies.iter().sum();
|
||||
let total_variance = if let Some(ref cov_sum) = self.covariance_sum {
|
||||
if self.covariance_count > 1 {
|
||||
let n_links_f = self.config.n_links as f64;
|
||||
let total_obs = self.covariance_count as f64 * n_links_f;
|
||||
// Centered trace: E[x^2] - E[x]^2, with Bessel correction
|
||||
let mut avg_mean = vec![0.0f64; n_sc];
|
||||
for ls in &self.link_stats {
|
||||
let m = ls.mean_vector();
|
||||
for i in 0..n_sc { avg_mean[i] += m[i]; }
|
||||
}
|
||||
for i in 0..n_sc { avg_mean[i] /= n_links_f; }
|
||||
let raw_trace: f64 = (0..n_sc).map(|i| cov_sum[[i, i]] / total_obs).sum();
|
||||
let mean_sq: f64 = avg_mean.iter().map(|m| m * m).sum();
|
||||
(raw_trace - mean_sq).max(0.0) * total_obs / (total_obs - 1.0)
|
||||
} else {
|
||||
total_energy
|
||||
}
|
||||
}
|
||||
let n_links_f = self.config.n_links as f64;
|
||||
for v in avg_variance.iter_mut() {
|
||||
*v /= n_links_f;
|
||||
}
|
||||
|
||||
// Extract modes via simplified power iteration on the diagonal
|
||||
// covariance. Since we use a diagonal approximation, the eigenmodes
|
||||
// are aligned with the standard basis, sorted by variance.
|
||||
let total_variance: f64 = avg_variance.iter().sum();
|
||||
|
||||
// Sort subcarrier indices by variance (descending) to pick top-K modes
|
||||
let mut indices: Vec<usize> = (0..n_sc).collect();
|
||||
indices.sort_by(|&a, &b| {
|
||||
avg_variance[b]
|
||||
.partial_cmp(&avg_variance[a])
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut environmental_modes = Vec::with_capacity(n_modes);
|
||||
let mut mode_energies = Vec::with_capacity(n_modes);
|
||||
let mut explained = 0.0_f64;
|
||||
|
||||
for k in 0..n_modes {
|
||||
let idx = indices[k];
|
||||
// Create a unit vector along the highest-variance subcarrier
|
||||
let mut mode = vec![0.0_f64; n_sc];
|
||||
mode[idx] = 1.0;
|
||||
let energy = avg_variance[idx];
|
||||
environmental_modes.push(mode);
|
||||
mode_energies.push(energy);
|
||||
explained += energy;
|
||||
}
|
||||
|
||||
} else {
|
||||
total_energy
|
||||
};
|
||||
let variance_explained = if total_variance > 1e-15 {
|
||||
explained / total_variance
|
||||
total_energy / total_variance
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
@@ -459,6 +631,7 @@ impl FieldModel {
|
||||
variance_explained,
|
||||
calibrated_at_us: timestamp_us,
|
||||
geometry_hash,
|
||||
baseline_eigenvalue_count: baseline_eig_count,
|
||||
};
|
||||
|
||||
self.modes = Some(field_mode);
|
||||
@@ -541,6 +714,100 @@ impl FieldModel {
|
||||
})
|
||||
}
|
||||
|
||||
/// Estimate room occupancy from eigenvalue analysis of recent CSI frames.
|
||||
///
|
||||
/// `recent_frames`: sliding window of amplitude vectors (recommend 50 frames
|
||||
/// ~ 2.5s at 20 Hz). Returns estimated person count (0 = empty room).
|
||||
///
|
||||
/// Requires the `eigenvalue` feature (BLAS). Returns `NotCalibrated` when
|
||||
/// the feature is disabled.
|
||||
#[cfg(feature = "eigenvalue")]
|
||||
pub fn estimate_occupancy(&self, recent_frames: &[Vec<f64>]) -> Result<usize, FieldModelError> {
|
||||
let modes = self.modes.as_ref().ok_or(FieldModelError::NotCalibrated)?;
|
||||
|
||||
let n = self.config.n_subcarriers;
|
||||
if recent_frames.len() < 10 {
|
||||
return Err(FieldModelError::InsufficientData {
|
||||
need: 10,
|
||||
have: recent_frames.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Build covariance matrix from recent frames
|
||||
let mut mean = vec![0.0f64; n];
|
||||
let mut count = 0usize;
|
||||
for frame in recent_frames {
|
||||
if frame.len() >= n {
|
||||
for i in 0..n {
|
||||
mean[i] += frame[i];
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
if count < 2 {
|
||||
return Ok(0);
|
||||
}
|
||||
for m in &mut mean {
|
||||
*m /= count as f64;
|
||||
}
|
||||
|
||||
let mut cov = Array2::<f64>::zeros((n, n));
|
||||
for frame in recent_frames {
|
||||
if frame.len() >= n {
|
||||
for i in 0..n {
|
||||
let ci = frame[i] - mean[i];
|
||||
for j in i..n {
|
||||
let val = ci * (frame[j] - mean[j]);
|
||||
cov[[i, j]] += val;
|
||||
if i != j {
|
||||
cov[[j, i]] += val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let scale = 1.0 / (count as f64 - 1.0);
|
||||
cov *= scale;
|
||||
|
||||
// Eigendecompose
|
||||
let eigenvalues = match cov.eigh(UPLO::Upper) {
|
||||
Ok((evals, _)) => evals,
|
||||
Err(_) => return Ok(0), // SVD failure = can't estimate
|
||||
};
|
||||
|
||||
// Marcenko-Pastur noise estimate: median of POSITIVE eigenvalues
|
||||
// in the bottom half. Excludes zeros from rank-deficient matrices
|
||||
// (common when n_subcarriers > n_frames, e.g. 56 subcarriers / 50 frames).
|
||||
let noise_var = {
|
||||
let mut positive: Vec<f64> = eigenvalues.iter()
|
||||
.copied()
|
||||
.filter(|&e| e > 1e-10)
|
||||
.collect();
|
||||
positive.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
if positive.len() >= 4 {
|
||||
let half = positive.len() / 2;
|
||||
positive[..half].iter().sum::<f64>() / half as f64
|
||||
} else if !positive.is_empty() {
|
||||
positive[0]
|
||||
} else {
|
||||
return Ok(0); // All zero eigenvalues — can't estimate
|
||||
}
|
||||
};
|
||||
let ratio = n as f64 / count as f64;
|
||||
let mp_threshold = noise_var * (1.0 + ratio.sqrt()).powi(2);
|
||||
|
||||
let significant = eigenvalues.iter().filter(|&&ev| ev > mp_threshold).count();
|
||||
let occupancy = significant.saturating_sub(modes.baseline_eigenvalue_count);
|
||||
|
||||
Ok(occupancy.min(10)) // Cap at 10 persons
|
||||
}
|
||||
|
||||
/// Stub when eigenvalue feature is disabled — always returns NotCalibrated.
|
||||
#[cfg(not(feature = "eigenvalue"))]
|
||||
pub fn estimate_occupancy(&self, _recent_frames: &[Vec<f64>]) -> Result<usize, FieldModelError> {
|
||||
Err(FieldModelError::NotCalibrated)
|
||||
}
|
||||
|
||||
/// Check calibration freshness against a given timestamp.
|
||||
pub fn check_freshness(&self, current_us: u64) -> CalibrationStatus {
|
||||
if self.modes.is_none() {
|
||||
@@ -563,6 +830,8 @@ impl FieldModel {
|
||||
.collect();
|
||||
self.modes = None;
|
||||
self.status = CalibrationStatus::Uncalibrated;
|
||||
self.covariance_sum = None;
|
||||
self.covariance_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -873,6 +1142,179 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_covariance_accumulation() {
|
||||
let config = make_config(2, 4, 5);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Feed calibration data
|
||||
for i in 0..10 {
|
||||
let obs = make_observations(2, 4, 1.0 + 0.1 * i as f64);
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
|
||||
// covariance_sum should be populated
|
||||
assert!(model.covariance_sum.is_some());
|
||||
assert!(model.covariance_count > 0);
|
||||
let cov = model.covariance_sum.as_ref().unwrap();
|
||||
assert_eq!(cov.shape(), &[4, 4]);
|
||||
// Diagonal entries should be non-negative (sum of squares)
|
||||
for i in 0..4 {
|
||||
assert!(cov[[i, i]] >= 0.0, "Diagonal covariance entry must be >= 0");
|
||||
}
|
||||
// Matrix should be symmetric
|
||||
for i in 0..4 {
|
||||
for j in 0..4 {
|
||||
assert!(
|
||||
(cov[[i, j]] - cov[[j, i]]).abs() < 1e-10,
|
||||
"Covariance matrix must be symmetric"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_svd_finalize_produces_orthonormal_modes() {
|
||||
let config = FieldModelConfig {
|
||||
n_links: 1,
|
||||
n_subcarriers: 8,
|
||||
n_modes: 3,
|
||||
min_calibration_frames: 20,
|
||||
baseline_expiry_s: 86_400.0,
|
||||
};
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Feed frames with correlated subcarrier patterns to produce
|
||||
// non-trivial eigenmodes
|
||||
for i in 0..50 {
|
||||
let t = i as f64 * 0.1;
|
||||
let obs = vec![vec![
|
||||
1.0 + t.sin(),
|
||||
2.0 + t.cos(),
|
||||
3.0 + 0.5 * t.sin(),
|
||||
4.0 + 0.3 * t.cos(),
|
||||
5.0 + 0.1 * t,
|
||||
6.0,
|
||||
7.0 + 0.2 * (2.0 * t).sin(),
|
||||
8.0 + 0.1 * (2.0 * t).cos(),
|
||||
]];
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
let modes = model.modes().unwrap();
|
||||
// Each mode should be approximately unit length
|
||||
for (k, mode) in modes.environmental_modes.iter().enumerate() {
|
||||
let norm: f64 = mode.iter().map(|x| x * x).sum::<f64>().sqrt();
|
||||
assert!(
|
||||
(norm - 1.0).abs() < 0.01,
|
||||
"Mode {} has norm {} (expected ~1.0)",
|
||||
k,
|
||||
norm
|
||||
);
|
||||
}
|
||||
// Modes should be approximately orthogonal
|
||||
for i in 0..modes.environmental_modes.len() {
|
||||
for j in (i + 1)..modes.environmental_modes.len() {
|
||||
let dot: f64 = modes.environmental_modes[i]
|
||||
.iter()
|
||||
.zip(modes.environmental_modes[j].iter())
|
||||
.map(|(a, b)| a * b)
|
||||
.sum();
|
||||
assert!(
|
||||
dot.abs() < 0.05,
|
||||
"Modes {} and {} have dot product {} (expected ~0)",
|
||||
i,
|
||||
j,
|
||||
dot
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_occupancy_noise_only() {
|
||||
let config = FieldModelConfig {
|
||||
n_links: 1,
|
||||
n_subcarriers: 8,
|
||||
n_modes: 3,
|
||||
min_calibration_frames: 20,
|
||||
baseline_expiry_s: 86_400.0,
|
||||
};
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Calibrate with some deterministic noise-like pattern
|
||||
for i in 0..50 {
|
||||
let t = i as f64 * 0.1;
|
||||
let obs = vec![vec![
|
||||
1.0 + 0.01 * t.sin(),
|
||||
2.0 + 0.01 * t.cos(),
|
||||
3.0 + 0.01 * (2.0 * t).sin(),
|
||||
4.0 + 0.01 * (2.0 * t).cos(),
|
||||
5.0 + 0.01 * (3.0 * t).sin(),
|
||||
6.0 + 0.01 * (3.0 * t).cos(),
|
||||
7.0 + 0.01 * (4.0 * t).sin(),
|
||||
8.0 + 0.01 * (4.0 * t).cos(),
|
||||
]];
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
// Estimate occupancy with similar noise-only frames
|
||||
let frames: Vec<Vec<f64>> = (0..20)
|
||||
.map(|i| {
|
||||
let t = (i + 50) as f64 * 0.1;
|
||||
vec![
|
||||
1.0 + 0.01 * t.sin(),
|
||||
2.0 + 0.01 * t.cos(),
|
||||
3.0 + 0.01 * (2.0 * t).sin(),
|
||||
4.0 + 0.01 * (2.0 * t).cos(),
|
||||
5.0 + 0.01 * (3.0 * t).sin(),
|
||||
6.0 + 0.01 * (3.0 * t).cos(),
|
||||
7.0 + 0.01 * (4.0 * t).sin(),
|
||||
8.0 + 0.01 * (4.0 * t).cos(),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
let occupancy = model.estimate_occupancy(&frames).unwrap();
|
||||
assert_eq!(occupancy, 0, "Noise-only frames should yield 0 occupancy");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_baseline_eigenvalue_count_stored() {
|
||||
let config = FieldModelConfig {
|
||||
n_links: 1,
|
||||
n_subcarriers: 8,
|
||||
n_modes: 3,
|
||||
min_calibration_frames: 20,
|
||||
baseline_expiry_s: 86_400.0,
|
||||
};
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Feed frames with structured variance so eigenvalues are meaningful
|
||||
for i in 0..50 {
|
||||
let t = i as f64 * 0.1;
|
||||
let obs = vec![vec![
|
||||
1.0 + t.sin(),
|
||||
2.0 + t.cos(),
|
||||
3.0 + 0.5 * t.sin(),
|
||||
4.0 + 0.3 * t.cos(),
|
||||
5.0 + 0.1 * t,
|
||||
6.0,
|
||||
7.0,
|
||||
8.0,
|
||||
]];
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
let modes = model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
// baseline_eigenvalue_count should exist and be a reasonable value
|
||||
// (at least 0, at most n_subcarriers)
|
||||
assert!(
|
||||
modes.baseline_eigenvalue_count <= 8,
|
||||
"baseline_eigenvalue_count should be <= n_subcarriers"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_environmental_projection_removes_drift() {
|
||||
let config = make_config(1, 4, 10);
|
||||
|
||||
+74
-18
@@ -339,9 +339,16 @@ impl RfTomographer {
|
||||
|
||||
/// Compute the intersection weights of a link with the voxel grid.
|
||||
///
|
||||
/// Uses a simplified approach: for each voxel, computes the minimum
|
||||
/// distance from the voxel center to the link ray. Voxels within
|
||||
/// one Fresnel zone receive weight proportional to closeness.
|
||||
/// Uses a DDA (Digital Differential Analyzer) ray-marching algorithm:
|
||||
/// 1. March along the ray from TX to RX, advancing to the nearest
|
||||
/// axis-aligned voxel boundary at each step.
|
||||
/// 2. At each ray voxel, expand by the Fresnel radius to check
|
||||
/// neighboring voxels.
|
||||
/// 3. Use a visited bitvector to avoid duplicate entries.
|
||||
/// 4. Weight = `1.0 - dist / fresnel_radius` (same as before).
|
||||
///
|
||||
/// This is O(ray_length / voxel_size) instead of O(nx*ny*nz),
|
||||
/// a significant speedup for large grids.
|
||||
fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(usize, f64)> {
|
||||
let vx = (config.bounds[3] - config.bounds[0]) / config.nx as f64;
|
||||
let vy = (config.bounds[4] - config.bounds[1]) / config.ny as f64;
|
||||
@@ -356,25 +363,74 @@ fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(
|
||||
let dy = link.rx.y - link.tx.y;
|
||||
let dz = link.rx.z - link.tx.z;
|
||||
|
||||
let n_voxels = config.nx * config.ny * config.nz;
|
||||
let mut visited = vec![false; n_voxels];
|
||||
let mut weights = Vec::new();
|
||||
|
||||
for iz in 0..config.nz {
|
||||
for iy in 0..config.ny {
|
||||
for ix in 0..config.nx {
|
||||
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
|
||||
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
|
||||
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
|
||||
// Fresnel expansion radius in voxel units.
|
||||
let expand_x = (fresnel_radius / vx).ceil() as isize;
|
||||
let expand_y = (fresnel_radius / vy).ceil() as isize;
|
||||
let expand_z = (fresnel_radius / vz).ceil() as isize;
|
||||
|
||||
// Point-to-line distance
|
||||
let dist = point_to_segment_distance(
|
||||
cx, cy, cz, link.tx.x, link.tx.y, link.tx.z, dx, dy, dz, link_dist,
|
||||
);
|
||||
// DDA initialization: start at TX position in voxel coordinates.
|
||||
let start_vx = (link.tx.x - config.bounds[0]) / vx;
|
||||
let start_vy = (link.tx.y - config.bounds[1]) / vy;
|
||||
let start_vz = (link.tx.z - config.bounds[2]) / vz;
|
||||
|
||||
if dist < fresnel_radius {
|
||||
// Weight decays with distance from link ray
|
||||
let w = 1.0 - dist / fresnel_radius;
|
||||
let idx = iz * config.ny * config.nx + iy * config.nx + ix;
|
||||
weights.push((idx, w));
|
||||
let end_vx = (link.rx.x - config.bounds[0]) / vx;
|
||||
let end_vy = (link.rx.y - config.bounds[1]) / vy;
|
||||
let end_vz = (link.rx.z - config.bounds[2]) / vz;
|
||||
|
||||
let ray_dx = end_vx - start_vx;
|
||||
let ray_dy = end_vy - start_vy;
|
||||
let ray_dz = end_vz - start_vz;
|
||||
|
||||
// Number of DDA steps: traverse the maximum voxel span.
|
||||
let steps = (ray_dx.abs().max(ray_dy.abs()).max(ray_dz.abs()).ceil() as usize).max(1);
|
||||
let inv_steps = 1.0 / steps as f64;
|
||||
|
||||
for step in 0..=steps {
|
||||
let t = step as f64 * inv_steps;
|
||||
let rx = start_vx + t * ray_dx;
|
||||
let ry = start_vy + t * ray_dy;
|
||||
let rz = start_vz + t * ray_dz;
|
||||
|
||||
let base_ix = rx.floor() as isize;
|
||||
let base_iy = ry.floor() as isize;
|
||||
let base_iz = rz.floor() as isize;
|
||||
|
||||
// Expand by Fresnel radius to check neighboring voxels.
|
||||
for diz in -expand_z..=expand_z {
|
||||
let iz = base_iz + diz;
|
||||
if iz < 0 || iz >= config.nz as isize { continue; }
|
||||
for diy in -expand_y..=expand_y {
|
||||
let iy = base_iy + diy;
|
||||
if iy < 0 || iy >= config.ny as isize { continue; }
|
||||
for dix in -expand_x..=expand_x {
|
||||
let ix = base_ix + dix;
|
||||
if ix < 0 || ix >= config.nx as isize { continue; }
|
||||
|
||||
let idx = iz as usize * config.ny * config.nx
|
||||
+ iy as usize * config.nx
|
||||
+ ix as usize;
|
||||
|
||||
if visited[idx] { continue; }
|
||||
|
||||
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
|
||||
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
|
||||
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
|
||||
|
||||
let dist = point_to_segment_distance(
|
||||
cx, cy, cz,
|
||||
link.tx.x, link.tx.y, link.tx.z,
|
||||
dx, dy, dz, link_dist,
|
||||
);
|
||||
|
||||
if dist < fresnel_radius {
|
||||
let w = 1.0 - dist / fresnel_radius;
|
||||
weights.push((idx, w));
|
||||
}
|
||||
visited[idx] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,477 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* Ground-Truth Alignment — Camera Keypoints <-> CSI Recording
|
||||
*
|
||||
* Time-aligns camera keypoint data with CSI recording data to produce
|
||||
* paired training samples for WiFlow supervised training (ADR-079).
|
||||
*
|
||||
* Camera keypoints: data/ground-truth/gt-{timestamp}.jsonl
|
||||
* CSI recordings: data/recordings/*.csi.jsonl
|
||||
* Paired output: data/paired/*.paired.jsonl
|
||||
*
|
||||
* Usage:
|
||||
* node scripts/align-ground-truth.js \
|
||||
* --gt data/ground-truth/gt-1775300000.jsonl \
|
||||
* --csi data/recordings/overnight-1775217646.csi.jsonl \
|
||||
* --output data/paired/aligned.paired.jsonl
|
||||
*
|
||||
* # With clock offset correction (camera ahead by 50ms)
|
||||
* node scripts/align-ground-truth.js \
|
||||
* --gt data/ground-truth/gt-1775300000.jsonl \
|
||||
* --csi data/recordings/overnight-1775217646.csi.jsonl \
|
||||
* --clock-offset-ms -50
|
||||
*
|
||||
* ADR: docs/adr/ADR-079
|
||||
*/
|
||||
|
||||
'use strict';
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { parseArgs } = require('util');
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CLI argument parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
const { values: args } = parseArgs({
|
||||
options: {
|
||||
gt: { type: 'string' },
|
||||
csi: { type: 'string' },
|
||||
output: { type: 'string', short: 'o' },
|
||||
'window-ms': { type: 'string', default: '200' },
|
||||
'window-frames': { type: 'string', default: '20' },
|
||||
'min-camera-frames': { type: 'string', default: '3' },
|
||||
'min-confidence': { type: 'string', default: '0.5' },
|
||||
'clock-offset-ms': { type: 'string', default: '0' },
|
||||
help: { type: 'boolean', short: 'h', default: false },
|
||||
},
|
||||
strict: true,
|
||||
});
|
||||
|
||||
if (args.help || !args.gt || !args.csi) {
|
||||
console.log(`
|
||||
Usage: node scripts/align-ground-truth.js --gt <gt.jsonl> --csi <csi.jsonl> [options]
|
||||
|
||||
Required:
|
||||
--gt <path> Camera ground-truth JSONL file
|
||||
--csi <path> CSI recording JSONL file
|
||||
|
||||
Options:
|
||||
--output, -o <path> Output paired JSONL (default: data/paired/<basename>.paired.jsonl)
|
||||
--window-ms <ms> CSI window size in ms (default: 200)
|
||||
--window-frames <n> Frames per CSI window (default: 20)
|
||||
--min-camera-frames <n> Minimum camera frames per window (default: 3)
|
||||
--min-confidence <f> Minimum average confidence threshold (default: 0.5)
|
||||
--clock-offset-ms <ms> Manual clock offset: added to camera timestamps (default: 0)
|
||||
--help, -h Show this help
|
||||
`);
|
||||
process.exit(args.help ? 0 : 1);
|
||||
}
|
||||
|
||||
const WINDOW_FRAMES = parseInt(args['window-frames'], 10);
|
||||
const WINDOW_MS = parseInt(args['window-ms'], 10);
|
||||
const MIN_CAMERA_FRAMES = parseInt(args['min-camera-frames'], 10);
|
||||
const MIN_CONFIDENCE = parseFloat(args['min-confidence']);
|
||||
const CLOCK_OFFSET_MS = parseFloat(args['clock-offset-ms']);
|
||||
const NUM_KEYPOINTS = 17; // COCO 17-keypoint format
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Timestamp conversion
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Convert camera nanosecond timestamp to milliseconds.
|
||||
* Applies clock offset correction.
|
||||
*/
|
||||
function cameraTsToMs(tsNs) {
|
||||
return tsNs / 1e6 + CLOCK_OFFSET_MS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert ISO 8601 timestamp string to milliseconds since epoch.
|
||||
*/
|
||||
function isoToMs(isoStr) {
|
||||
return new Date(isoStr).getTime();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// IQ hex parsing (matches train-wiflow.js conventions)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Parse IQ hex string into signed byte pairs [I0, Q0, I1, Q1, ...].
|
||||
*/
|
||||
function parseIqHex(hexStr) {
|
||||
const bytes = [];
|
||||
for (let i = 0; i < hexStr.length; i += 2) {
|
||||
let val = parseInt(hexStr.substr(i, 2), 16);
|
||||
if (val > 127) val -= 256; // signed byte
|
||||
bytes.push(val);
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract amplitude from IQ data for a given number of subcarriers.
|
||||
* Returns Float32Array of amplitudes [nSubcarriers].
|
||||
* Skips first I/Q pair (DC offset) per WiFlow paper recommendation.
|
||||
*/
|
||||
function extractAmplitude(iqBytes, nSubcarriers) {
|
||||
const amp = new Float32Array(nSubcarriers);
|
||||
const start = 2; // skip first IQ pair (DC offset)
|
||||
for (let sc = 0; sc < nSubcarriers; sc++) {
|
||||
const idx = start + sc * 2;
|
||||
if (idx + 1 < iqBytes.length) {
|
||||
const I = iqBytes[idx];
|
||||
const Q = iqBytes[idx + 1];
|
||||
amp[sc] = Math.sqrt(I * I + Q * Q);
|
||||
}
|
||||
}
|
||||
return amp;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// File loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Load and parse a JSONL file, skipping blank/malformed lines.
|
||||
*/
|
||||
function loadJsonl(filePath) {
|
||||
const lines = fs.readFileSync(filePath, 'utf8').split('\n');
|
||||
const records = [];
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) continue;
|
||||
try {
|
||||
records.push(JSON.parse(trimmed));
|
||||
} catch {
|
||||
// skip malformed lines
|
||||
}
|
||||
}
|
||||
return records;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load camera ground-truth file.
|
||||
* Returns array of { tsMs, keypoints, confidence, nVisible, nPersons }.
|
||||
*/
|
||||
function loadGroundTruth(filePath) {
|
||||
const raw = loadJsonl(filePath);
|
||||
const frames = [];
|
||||
for (const r of raw) {
|
||||
if (r.ts_ns == null || !r.keypoints) continue;
|
||||
frames.push({
|
||||
tsMs: cameraTsToMs(r.ts_ns),
|
||||
keypoints: r.keypoints,
|
||||
confidence: r.confidence ?? 0,
|
||||
nVisible: r.n_visible ?? 0,
|
||||
nPersons: r.n_persons ?? 1,
|
||||
});
|
||||
}
|
||||
// Sort by timestamp
|
||||
frames.sort((a, b) => a.tsMs - b.tsMs);
|
||||
return frames;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load CSI recording file.
|
||||
* Separates raw_csi frames and feature frames.
|
||||
*/
|
||||
function loadCsi(filePath) {
|
||||
const raw = loadJsonl(filePath);
|
||||
const rawCsi = [];
|
||||
const features = [];
|
||||
|
||||
for (const r of raw) {
|
||||
if (!r.timestamp) continue;
|
||||
const tsMs = isoToMs(r.timestamp);
|
||||
if (isNaN(tsMs)) continue;
|
||||
|
||||
if (r.type === 'raw_csi') {
|
||||
rawCsi.push({
|
||||
tsMs,
|
||||
nodeId: r.node_id,
|
||||
subcarriers: r.subcarriers ?? 128,
|
||||
iqHex: r.iq_hex,
|
||||
rssi: r.rssi,
|
||||
seq: r.seq,
|
||||
});
|
||||
} else if (r.type === 'feature') {
|
||||
features.push({
|
||||
tsMs,
|
||||
nodeId: r.node_id,
|
||||
features: r.features,
|
||||
rssi: r.rssi,
|
||||
seq: r.seq,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by timestamp
|
||||
rawCsi.sort((a, b) => a.tsMs - b.tsMs);
|
||||
features.sort((a, b) => a.tsMs - b.tsMs);
|
||||
return { rawCsi, features };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Windowing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Group frames into non-overlapping windows of `windowSize` consecutive frames.
|
||||
*/
|
||||
function groupIntoWindows(frames, windowSize) {
|
||||
const windows = [];
|
||||
for (let i = 0; i + windowSize <= frames.length; i += windowSize) {
|
||||
windows.push(frames.slice(i, i + windowSize));
|
||||
}
|
||||
return windows;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Camera frame matching (binary search)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Find all camera frames within [tStart, tEnd] using binary search.
|
||||
*/
|
||||
function findCameraFramesInRange(cameraFrames, tStartMs, tEndMs) {
|
||||
// Binary search for first frame >= tStartMs
|
||||
let lo = 0;
|
||||
let hi = cameraFrames.length;
|
||||
while (lo < hi) {
|
||||
const mid = (lo + hi) >>> 1;
|
||||
if (cameraFrames[mid].tsMs < tStartMs) lo = mid + 1;
|
||||
else hi = mid;
|
||||
}
|
||||
|
||||
const matched = [];
|
||||
for (let i = lo; i < cameraFrames.length; i++) {
|
||||
if (cameraFrames[i].tsMs > tEndMs) break;
|
||||
matched.push(cameraFrames[i]);
|
||||
}
|
||||
return matched;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Keypoint averaging (confidence-weighted)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Average keypoints weighted by per-frame confidence.
|
||||
* Returns { keypoints: [[x,y],...], avgConfidence }.
|
||||
*/
|
||||
function averageKeypoints(cameraFrames) {
|
||||
let totalWeight = 0;
|
||||
const sumKp = new Array(NUM_KEYPOINTS).fill(null).map(() => [0, 0]);
|
||||
|
||||
for (const f of cameraFrames) {
|
||||
const w = f.confidence || 1e-6;
|
||||
totalWeight += w;
|
||||
for (let k = 0; k < NUM_KEYPOINTS && k < f.keypoints.length; k++) {
|
||||
sumKp[k][0] += f.keypoints[k][0] * w;
|
||||
sumKp[k][1] += f.keypoints[k][1] * w;
|
||||
}
|
||||
}
|
||||
|
||||
if (totalWeight === 0) totalWeight = 1;
|
||||
const keypoints = sumKp.map(([x, y]) => [x / totalWeight, y / totalWeight]);
|
||||
const avgConfidence = cameraFrames.reduce((s, f) => s + (f.confidence || 0), 0) / cameraFrames.length;
|
||||
|
||||
return { keypoints, avgConfidence };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CSI matrix extraction
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Extract CSI amplitude matrix from raw_csi window.
|
||||
* Returns { data: flat Float32Array, shape: [subcarriers, windowFrames] }.
|
||||
*/
|
||||
function extractCsiMatrix(window) {
|
||||
const nFrames = window.length;
|
||||
const nSc = window[0].subcarriers || 128;
|
||||
const matrix = new Float32Array(nSc * nFrames);
|
||||
|
||||
for (let f = 0; f < nFrames; f++) {
|
||||
const frame = window[f];
|
||||
if (frame.iqHex) {
|
||||
const iq = parseIqHex(frame.iqHex);
|
||||
const amp = extractAmplitude(iq, nSc);
|
||||
matrix.set(amp, f * nSc);
|
||||
}
|
||||
}
|
||||
|
||||
return { data: Array.from(matrix), shape: [nSc, nFrames] };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract feature matrix from feature-type window.
|
||||
* Returns { data: flat array, shape: [featureDim, windowFrames] }.
|
||||
*/
|
||||
function extractFeatureMatrix(window) {
|
||||
const nFrames = window.length;
|
||||
const dim = window[0].features ? window[0].features.length : 8;
|
||||
const matrix = new Float32Array(dim * nFrames);
|
||||
|
||||
for (let f = 0; f < nFrames; f++) {
|
||||
const feats = window[f].features || new Array(dim).fill(0);
|
||||
for (let d = 0; d < dim; d++) {
|
||||
matrix[f * dim + d] = feats[d] || 0;
|
||||
}
|
||||
}
|
||||
|
||||
return { data: Array.from(matrix), shape: [dim, nFrames] };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main alignment
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function align() {
|
||||
const gtPath = path.resolve(args.gt);
|
||||
const csiPath = path.resolve(args.csi);
|
||||
|
||||
// Determine output path
|
||||
let outputPath;
|
||||
if (args.output) {
|
||||
outputPath = path.resolve(args.output);
|
||||
} else {
|
||||
const baseName = path.basename(csiPath, '.csi.jsonl');
|
||||
outputPath = path.resolve('data', 'paired', `${baseName}.paired.jsonl`);
|
||||
}
|
||||
|
||||
// Ensure output directory exists
|
||||
const outputDir = path.dirname(outputPath);
|
||||
if (!fs.existsSync(outputDir)) {
|
||||
fs.mkdirSync(outputDir, { recursive: true });
|
||||
}
|
||||
|
||||
console.log('=== Ground-Truth Alignment (ADR-079) ===');
|
||||
console.log(` GT file: ${gtPath}`);
|
||||
console.log(` CSI file: ${csiPath}`);
|
||||
console.log(` Output: ${outputPath}`);
|
||||
console.log(` Window: ${WINDOW_FRAMES} frames / ${WINDOW_MS} ms`);
|
||||
console.log(` Min camera frames: ${MIN_CAMERA_FRAMES}`);
|
||||
console.log(` Min confidence: ${MIN_CONFIDENCE}`);
|
||||
console.log(` Clock offset: ${CLOCK_OFFSET_MS} ms`);
|
||||
console.log();
|
||||
|
||||
// Load data
|
||||
console.log('Loading ground-truth...');
|
||||
const cameraFrames = loadGroundTruth(gtPath);
|
||||
console.log(` ${cameraFrames.length} camera frames loaded`);
|
||||
if (cameraFrames.length > 0) {
|
||||
console.log(` Time range: ${new Date(cameraFrames[0].tsMs).toISOString()} -> ${new Date(cameraFrames[cameraFrames.length - 1].tsMs).toISOString()}`);
|
||||
}
|
||||
|
||||
console.log('Loading CSI data...');
|
||||
const { rawCsi, features } = loadCsi(csiPath);
|
||||
console.log(` ${rawCsi.length} raw_csi frames, ${features.length} feature frames`);
|
||||
|
||||
// Decide which CSI source to use
|
||||
const useRawCsi = rawCsi.length >= WINDOW_FRAMES;
|
||||
const csiSource = useRawCsi ? rawCsi : features;
|
||||
const sourceLabel = useRawCsi ? 'raw_csi' : 'feature';
|
||||
|
||||
if (csiSource.length < WINDOW_FRAMES) {
|
||||
console.error(`ERROR: Not enough CSI frames (${csiSource.length}) for even one window of ${WINDOW_FRAMES} frames.`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log(` Using ${sourceLabel} frames (${csiSource.length} total)`);
|
||||
if (csiSource.length > 0) {
|
||||
console.log(` CSI time range: ${new Date(csiSource[0].tsMs).toISOString()} -> ${new Date(csiSource[csiSource.length - 1].tsMs).toISOString()}`);
|
||||
}
|
||||
console.log();
|
||||
|
||||
// Group CSI into windows
|
||||
const windows = groupIntoWindows(csiSource, WINDOW_FRAMES);
|
||||
console.log(`Grouped into ${windows.length} CSI windows`);
|
||||
|
||||
// Align
|
||||
const paired = [];
|
||||
let totalConfidence = 0;
|
||||
|
||||
for (const window of windows) {
|
||||
const tStartMs = window[0].tsMs;
|
||||
const tEndMs = window[window.length - 1].tsMs;
|
||||
|
||||
// Expand window if actual time span is smaller than window-ms
|
||||
const halfWindow = WINDOW_MS / 2;
|
||||
const midpoint = (tStartMs + tEndMs) / 2;
|
||||
const searchStart = Math.min(tStartMs, midpoint - halfWindow);
|
||||
const searchEnd = Math.max(tEndMs, midpoint + halfWindow);
|
||||
|
||||
// Find matching camera frames
|
||||
const matched = findCameraFramesInRange(cameraFrames, searchStart, searchEnd);
|
||||
|
||||
if (matched.length < MIN_CAMERA_FRAMES) continue;
|
||||
|
||||
// Check average confidence
|
||||
const avgConf = matched.reduce((s, f) => s + (f.confidence || 0), 0) / matched.length;
|
||||
if (avgConf < MIN_CONFIDENCE) continue;
|
||||
|
||||
// Average keypoints weighted by confidence
|
||||
const { keypoints, avgConfidence } = averageKeypoints(matched);
|
||||
|
||||
// Extract CSI matrix
|
||||
const csiMatrix = useRawCsi
|
||||
? extractCsiMatrix(window)
|
||||
: extractFeatureMatrix(window);
|
||||
|
||||
paired.push({
|
||||
csi: csiMatrix.data,
|
||||
csi_shape: csiMatrix.shape,
|
||||
kp: keypoints,
|
||||
conf: Math.round(avgConfidence * 1000) / 1000,
|
||||
n_camera_frames: matched.length,
|
||||
ts_start: new Date(tStartMs).toISOString(),
|
||||
ts_end: new Date(tEndMs).toISOString(),
|
||||
});
|
||||
|
||||
totalConfidence += avgConfidence;
|
||||
}
|
||||
|
||||
// Write output
|
||||
const outputLines = paired.map(s => JSON.stringify(s));
|
||||
fs.writeFileSync(outputPath, outputLines.join('\n') + (outputLines.length > 0 ? '\n' : ''));
|
||||
|
||||
// Print summary
|
||||
const alignmentRate = windows.length > 0 ? (paired.length / windows.length * 100) : 0;
|
||||
const avgPairedConf = paired.length > 0 ? (totalConfidence / paired.length) : 0;
|
||||
|
||||
console.log();
|
||||
console.log('=== Alignment Summary ===');
|
||||
console.log(` Total CSI windows: ${windows.length}`);
|
||||
console.log(` Paired samples: ${paired.length}`);
|
||||
console.log(` Alignment rate: ${alignmentRate.toFixed(1)}%`);
|
||||
console.log(` Avg confidence (paired): ${avgPairedConf.toFixed(3)}`);
|
||||
console.log(` CSI source: ${sourceLabel} (${csiMatrix_shapeLabel(paired, useRawCsi)})`);
|
||||
if (paired.length > 0) {
|
||||
console.log(` Time range covered: ${paired[0].ts_start} -> ${paired[paired.length - 1].ts_end}`);
|
||||
}
|
||||
console.log(` Output written: ${outputPath}`);
|
||||
console.log();
|
||||
|
||||
if (paired.length === 0) {
|
||||
console.log('WARNING: No paired samples produced. Check that camera and CSI time ranges overlap.');
|
||||
console.log(' Hint: Use --clock-offset-ms to correct misaligned clocks.');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format CSI matrix shape label for summary.
|
||||
*/
|
||||
function csiMatrix_shapeLabel(paired, useRawCsi) {
|
||||
if (paired.length === 0) return useRawCsi ? `[128, ${WINDOW_FRAMES}]` : `[8, ${WINDOW_FRAMES}]`;
|
||||
const shape = paired[0].csi_shape;
|
||||
return `[${shape[0]}, ${shape[1]}]`;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
align();
|
||||
@@ -0,0 +1,341 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Camera ground-truth collection for WiFi pose estimation training (ADR-079).
|
||||
|
||||
Captures webcam keypoints via MediaPipe PoseLandmarker (Tasks API) and
|
||||
synchronizes with ESP32 CSI recording from the sensing server.
|
||||
|
||||
Output: JSONL file in data/ground-truth/ with per-frame 17-keypoint COCO poses.
|
||||
|
||||
Usage:
|
||||
python scripts/collect-ground-truth.py --preview --duration 60
|
||||
python scripts/collect-ground-truth.py --server http://192.168.1.10:3000
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks.python import BaseOptions
|
||||
from mediapipe.tasks.python.vision import (
|
||||
PoseLandmarker,
|
||||
PoseLandmarkerOptions,
|
||||
RunningMode,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MediaPipe 33 landmarks -> 17 COCO keypoints
|
||||
# ---------------------------------------------------------------------------
|
||||
# COCO idx : MP idx : joint name
|
||||
# 0 : 0 : nose
|
||||
# 1 : 2 : left_eye
|
||||
# 2 : 5 : right_eye
|
||||
# 3 : 7 : left_ear
|
||||
# 4 : 8 : right_ear
|
||||
# 5 : 11 : left_shoulder
|
||||
# 6 : 12 : right_shoulder
|
||||
# 7 : 13 : left_elbow
|
||||
# 8 : 14 : right_elbow
|
||||
# 9 : 15 : left_wrist
|
||||
# 10 : 16 : right_wrist
|
||||
# 11 : 23 : left_hip
|
||||
# 12 : 24 : right_hip
|
||||
# 13 : 25 : left_knee
|
||||
# 14 : 26 : right_knee
|
||||
# 15 : 27 : left_ankle
|
||||
# 16 : 28 : right_ankle
|
||||
|
||||
MP_TO_COCO = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]
|
||||
|
||||
COCO_BONES = [
|
||||
(5, 7), (7, 9), (6, 8), (8, 10), # arms
|
||||
(5, 6), # shoulders
|
||||
(11, 13), (13, 15), (12, 14), (14, 16), # legs
|
||||
(11, 12), # hips
|
||||
(5, 11), (6, 12), # torso
|
||||
(0, 1), (0, 2), (1, 3), (2, 4), # face
|
||||
]
|
||||
|
||||
MODEL_URL = (
|
||||
"https://storage.googleapis.com/mediapipe-models/"
|
||||
"pose_landmarker/pose_landmarker_lite/float16/latest/"
|
||||
"pose_landmarker_lite.task"
|
||||
)
|
||||
MODEL_FILENAME = "pose_landmarker_lite.task"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def ensure_model(cache_dir: Path) -> Path:
|
||||
"""Download the PoseLandmarker model if not already cached."""
|
||||
model_path = cache_dir / MODEL_FILENAME
|
||||
if model_path.exists():
|
||||
return model_path
|
||||
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Downloading {MODEL_FILENAME} ...")
|
||||
try:
|
||||
urllib.request.urlretrieve(MODEL_URL, str(model_path))
|
||||
print(f" saved to {model_path}")
|
||||
except Exception as exc:
|
||||
print(f"ERROR: Failed to download model: {exc}", file=sys.stderr)
|
||||
print(
|
||||
"Download manually from:\n"
|
||||
f" {MODEL_URL}\n"
|
||||
f"and place at {model_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
return model_path
|
||||
|
||||
|
||||
def post_json(url: str, payload: dict | None = None, timeout: float = 5.0) -> bool:
|
||||
"""POST JSON to a URL. Returns True on success, False on failure."""
|
||||
data = json.dumps(payload or {}).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
return 200 <= resp.status < 300
|
||||
except Exception as exc:
|
||||
print(f"WARNING: POST {url} failed: {exc}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
def draw_skeleton(frame: np.ndarray, keypoints: list[list[float]], w: int, h: int):
|
||||
"""Draw COCO skeleton overlay on a BGR frame."""
|
||||
pts = []
|
||||
for x, y in keypoints:
|
||||
px, py = int(x * w), int(y * h)
|
||||
pts.append((px, py))
|
||||
cv2.circle(frame, (px, py), 4, (0, 255, 0), -1)
|
||||
|
||||
for i, j in COCO_BONES:
|
||||
if i < len(pts) and j < len(pts):
|
||||
cv2.line(frame, pts[i], pts[j], (0, 200, 255), 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main collection loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Collect camera ground-truth keypoints for WiFi pose training (ADR-079)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server",
|
||||
default="http://localhost:3000",
|
||||
help="Sensing server URL (default: http://localhost:3000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preview",
|
||||
action="store_true",
|
||||
help="Show live skeleton overlay window",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duration",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Recording duration in seconds (default: 300)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--camera",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Camera device index (default: 0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="data/ground-truth",
|
||||
help="Output directory (default: data/ground-truth)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Resolve paths relative to repo root ---
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
output_dir = repo_root / args.output
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_dir = repo_root / "data" / ".cache"
|
||||
|
||||
# --- Download / locate model ---
|
||||
model_path = ensure_model(cache_dir)
|
||||
|
||||
# --- Open camera ---
|
||||
cap = cv2.VideoCapture(args.camera)
|
||||
if not cap.isOpened():
|
||||
print(
|
||||
f"ERROR: Cannot open camera index {args.camera}. "
|
||||
"Check that a webcam is connected and not in use by another app.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
print(f"Camera opened: {frame_w}x{frame_h}")
|
||||
|
||||
# --- Create PoseLandmarker ---
|
||||
options = PoseLandmarkerOptions(
|
||||
base_options=BaseOptions(model_asset_path=str(model_path)),
|
||||
running_mode=RunningMode.IMAGE,
|
||||
num_poses=1,
|
||||
min_pose_detection_confidence=0.5,
|
||||
min_pose_presence_confidence=0.5,
|
||||
min_tracking_confidence=0.5,
|
||||
)
|
||||
landmarker = PoseLandmarker.create_from_options(options)
|
||||
|
||||
# --- Output file ---
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = output_dir / f"keypoints_{timestamp_str}.jsonl"
|
||||
out_file = open(out_path, "w", encoding="utf-8")
|
||||
print(f"Output: {out_path}")
|
||||
|
||||
# --- Start CSI recording ---
|
||||
recording_url_start = f"{args.server}/api/v1/recording/start"
|
||||
recording_url_stop = f"{args.server}/api/v1/recording/stop"
|
||||
csi_started = post_json(recording_url_start)
|
||||
if csi_started:
|
||||
print("CSI recording started on sensing server.")
|
||||
else:
|
||||
print(
|
||||
"WARNING: Could not start CSI recording. "
|
||||
"Camera keypoints will still be captured.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# --- Graceful shutdown ---
|
||||
shutdown_requested = False
|
||||
|
||||
def _handle_signal(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
shutdown_requested = True
|
||||
|
||||
signal.signal(signal.SIGINT, _handle_signal)
|
||||
signal.signal(signal.SIGTERM, _handle_signal)
|
||||
|
||||
# --- Collection loop ---
|
||||
start_time = time.monotonic()
|
||||
frame_count = 0
|
||||
total_confidence = 0.0
|
||||
total_visible = 0
|
||||
|
||||
print(f"Collecting for {args.duration}s ... (press 'q' in preview to stop)")
|
||||
|
||||
try:
|
||||
while not shutdown_requested:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed >= args.duration:
|
||||
break
|
||||
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("WARNING: Failed to read frame, retrying ...", file=sys.stderr)
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
ts_ns = time.time_ns()
|
||||
|
||||
# Convert BGR -> RGB for MediaPipe
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
|
||||
|
||||
result = landmarker.detect(mp_image)
|
||||
|
||||
n_persons = len(result.pose_landmarks)
|
||||
|
||||
if n_persons > 0:
|
||||
landmarks = result.pose_landmarks[0]
|
||||
keypoints = []
|
||||
visibilities = []
|
||||
for coco_idx in range(17):
|
||||
mp_idx = MP_TO_COCO[coco_idx]
|
||||
lm = landmarks[mp_idx]
|
||||
keypoints.append([round(lm.x, 5), round(lm.y, 5)])
|
||||
visibilities.append(lm.visibility if lm.visibility else 0.0)
|
||||
|
||||
confidence = float(np.mean(visibilities))
|
||||
n_visible = int(sum(1 for v in visibilities if v > 0.5))
|
||||
else:
|
||||
keypoints = []
|
||||
confidence = 0.0
|
||||
n_visible = 0
|
||||
|
||||
record = {
|
||||
"ts_ns": ts_ns,
|
||||
"keypoints": keypoints,
|
||||
"confidence": round(confidence, 4),
|
||||
"n_visible": n_visible,
|
||||
"n_persons": n_persons,
|
||||
}
|
||||
out_file.write(json.dumps(record) + "\n")
|
||||
frame_count += 1
|
||||
total_confidence += confidence
|
||||
total_visible += n_visible
|
||||
|
||||
# Preview overlay
|
||||
if args.preview and keypoints:
|
||||
draw_skeleton(frame, keypoints, frame_w, frame_h)
|
||||
|
||||
if args.preview:
|
||||
remaining = max(0, int(args.duration - elapsed))
|
||||
cv2.putText(
|
||||
frame,
|
||||
f"Frames: {frame_count} Visible: {n_visible}/17 Time: {remaining}s",
|
||||
(10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.7,
|
||||
(255, 255, 255),
|
||||
2,
|
||||
)
|
||||
cv2.imshow("Ground Truth Collection (ADR-079)", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
finally:
|
||||
# --- Cleanup ---
|
||||
out_file.close()
|
||||
cap.release()
|
||||
if args.preview:
|
||||
cv2.destroyAllWindows()
|
||||
landmarker.close()
|
||||
|
||||
# Stop CSI recording
|
||||
if csi_started:
|
||||
if post_json(recording_url_stop):
|
||||
print("CSI recording stopped.")
|
||||
else:
|
||||
print("WARNING: Failed to stop CSI recording.", file=sys.stderr)
|
||||
|
||||
# --- Summary ---
|
||||
avg_conf = total_confidence / frame_count if frame_count > 0 else 0.0
|
||||
avg_vis = total_visible / frame_count if frame_count > 0 else 0.0
|
||||
print()
|
||||
print("=== Collection Summary ===")
|
||||
print(f" Total frames: {frame_count}")
|
||||
print(f" Avg confidence: {avg_conf:.3f}")
|
||||
print(f" Avg visible joints: {avg_vis:.1f} / 17")
|
||||
print(f" Output: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env node
|
||||
'use strict';
|
||||
/**
|
||||
* Deep RF Intelligence Report — discovers everything WiFi can see.
|
||||
* Usage: node scripts/deep-scan.js --bind 192.168.1.20 --duration 10
|
||||
*/
|
||||
|
||||
const dgram = require('dgram');
|
||||
const { parseArgs } = require('util');
|
||||
|
||||
const { values: args } = parseArgs({
|
||||
options: {
|
||||
port: { type: 'string', default: '5006' },
|
||||
bind: { type: 'string', default: '0.0.0.0' },
|
||||
duration: { type: 'string', default: '10' },
|
||||
},
|
||||
strict: true,
|
||||
});
|
||||
|
||||
const PORT = parseInt(args.port);
|
||||
const BIND = args.bind;
|
||||
const DUR = parseInt(args.duration) * 1000;
|
||||
|
||||
const vitals = {}; // nid -> [{time, br, hr, rssi, persons, motion, presence}]
|
||||
const features = {}; // nid -> [{time, features}]
|
||||
const raw = {}; // nid -> [{time, amps, phases, rssi, nSub}]
|
||||
|
||||
const server = dgram.createSocket('udp4');
|
||||
|
||||
server.on('message', (buf, rinfo) => {
|
||||
if (buf.length < 5) return;
|
||||
const magic = buf.readUInt32LE(0);
|
||||
const nid = buf[4];
|
||||
|
||||
if (magic === 0xC5110001 && buf.length > 20) {
|
||||
const iq = buf.subarray(20);
|
||||
const nSub = Math.floor(iq.length / 2);
|
||||
const amps = [];
|
||||
for (let i = 0; i < nSub * 2 && i < iq.length - 1; i += 2) {
|
||||
const I = iq.readInt8(i), Q = iq.readInt8(i + 1);
|
||||
amps.push(Math.sqrt(I * I + Q * Q));
|
||||
}
|
||||
if (!raw[nid]) raw[nid] = [];
|
||||
raw[nid].push({ time: Date.now(), amps, rssi: buf.readInt8(5), nSub });
|
||||
} else if (magic === 0xC5110002 && buf.length >= 32) {
|
||||
const br = buf.readUInt16LE(6) / 100;
|
||||
const hr = buf.readUInt32LE(8) / 10000;
|
||||
const rssi = buf.readInt8(12);
|
||||
const persons = buf[13];
|
||||
const motion = buf.readFloatLE(16);
|
||||
const presence = buf.readFloatLE(20);
|
||||
if (!vitals[nid]) vitals[nid] = [];
|
||||
vitals[nid].push({ time: Date.now(), br, hr, rssi, persons, motion, presence });
|
||||
} else if (magic === 0xC5110003 && buf.length >= 48) {
|
||||
const f = [];
|
||||
for (let i = 0; i < 8; i++) f.push(buf.readFloatLE(16 + i * 4));
|
||||
if (!features[nid]) features[nid] = [];
|
||||
features[nid].push({ time: Date.now(), features: f });
|
||||
}
|
||||
});
|
||||
|
||||
server.on('listening', () => {
|
||||
console.log(`Scanning on ${BIND}:${PORT} for ${DUR / 1000}s...\n`);
|
||||
});
|
||||
|
||||
server.bind(PORT, BIND);
|
||||
|
||||
setTimeout(() => {
|
||||
server.close();
|
||||
report();
|
||||
}, DUR);
|
||||
|
||||
function avg(arr) { return arr.length ? arr.reduce((a, b) => a + b) / arr.length : 0; }
|
||||
function std(arr) { const m = avg(arr); return Math.sqrt(arr.reduce((s, v) => s + (v - m) ** 2, 0) / (arr.length || 1)); }
|
||||
|
||||
function report() {
|
||||
const bar = (v, max = 20) => '█'.repeat(Math.min(Math.round(v * max), max)) + '░'.repeat(Math.max(max - Math.round(v * max), 0));
|
||||
const line = '═'.repeat(70);
|
||||
|
||||
console.log(line);
|
||||
console.log(' DEEP RF INTELLIGENCE REPORT — What WiFi Sees In Your Room');
|
||||
console.log(line);
|
||||
|
||||
// 1. WHO'S THERE
|
||||
console.log('\n📡 WHO IS IN THE ROOM');
|
||||
for (const nid of Object.keys(vitals).sort()) {
|
||||
const v = vitals[nid];
|
||||
const lastP = v[v.length - 1].presence;
|
||||
const avgMotion = avg(v.map(x => x.motion));
|
||||
console.log(` Node ${nid}: presence=${lastP.toFixed(1)} motion=${avgMotion.toFixed(1)} → ${lastP > 0.5 ? 'SOMEONE IS HERE' : 'Room may be empty'}`);
|
||||
}
|
||||
|
||||
// 2. WHAT ARE THEY DOING
|
||||
console.log('\n🏃 ACTIVITY DETECTION');
|
||||
for (const nid of Object.keys(vitals).sort()) {
|
||||
const v = vitals[nid];
|
||||
const motions = v.map(x => x.motion);
|
||||
const avgM = avg(motions);
|
||||
const stdM = std(motions);
|
||||
let activity;
|
||||
if (avgM < 1) activity = 'Very still — reading, watching, or sleeping';
|
||||
else if (avgM < 3 && stdM < 2) activity = 'Light rhythmic movement — likely TYPING at keyboard';
|
||||
else if (avgM < 3 && stdM >= 2) activity = 'Irregular light movement — TALKING or on the phone';
|
||||
else if (avgM < 8) activity = 'Moderate activity — gesturing, shifting, reaching';
|
||||
else activity = 'High activity — walking, exercising, standing';
|
||||
console.log(` Node ${nid}: energy=${avgM.toFixed(1)} variability=${stdM.toFixed(1)} → ${activity}`);
|
||||
}
|
||||
|
||||
// 3. VITAL SIGNS
|
||||
console.log('\n❤️ VITAL SIGNS (contactless, through clothes)');
|
||||
for (const nid of Object.keys(vitals).sort()) {
|
||||
const v = vitals[nid];
|
||||
const brs = v.map(x => x.br);
|
||||
const hrs = v.map(x => x.hr);
|
||||
const brAvg = avg(brs), brStd = std(brs);
|
||||
const hrAvg = avg(hrs), hrStd = std(hrs);
|
||||
|
||||
let brState = brStd < 2 ? 'very regular (calm/focused)' : brStd < 5 ? 'normal' : 'variable (talking/active)';
|
||||
let hrState = hrAvg < 60 ? 'athletic resting' : hrAvg < 80 ? 'relaxed' : hrAvg < 100 ? 'normal/active' : 'elevated';
|
||||
let stressHint = hrStd < 3 ? 'LOW stress (steady HR)' : hrStd < 8 ? 'MODERATE' : 'HIGH variability (could be relaxed OR stressed)';
|
||||
|
||||
console.log(` Node ${nid}:`);
|
||||
console.log(` Breathing: ${brAvg.toFixed(0)} BPM (±${brStd.toFixed(1)}) — ${brState}`);
|
||||
console.log(` Heart rate: ${hrAvg.toFixed(0)} BPM (±${hrStd.toFixed(1)}) — ${hrState}`);
|
||||
console.log(` Stress indicator: ${stressHint}`);
|
||||
}
|
||||
|
||||
// 4. YOUR DISTANCE FROM EACH NODE
|
||||
console.log('\n📏 POSITION IN ROOM');
|
||||
const distances = {};
|
||||
for (const nid of Object.keys(vitals).sort()) {
|
||||
const rssis = vitals[nid].map(x => x.rssi);
|
||||
const avgRssi = avg(rssis);
|
||||
const dist = Math.pow(10, (-30 - avgRssi) / 20);
|
||||
distances[nid] = dist;
|
||||
console.log(` Node ${nid}: RSSI=${avgRssi.toFixed(0)} dBm → ~${dist.toFixed(1)}m away`);
|
||||
}
|
||||
const nids = Object.keys(distances).sort();
|
||||
if (nids.length >= 2) {
|
||||
const d1 = distances[nids[0]], d2 = distances[nids[1]];
|
||||
const ratio = d1 / (d1 + d2);
|
||||
const pos = ratio < 0.4 ? 'closer to Node ' + nids[0] : ratio > 0.6 ? 'closer to Node ' + nids[1] : 'CENTERED between nodes';
|
||||
console.log(` Position: ${pos} (ratio: ${(ratio * 100).toFixed(0)}%)`);
|
||||
}
|
||||
|
||||
// 5. OBJECTS IN THE ROOM (from subcarrier nulls)
|
||||
console.log('\n🪑 OBJECTS DETECTED (metal = null subcarriers, furniture = stable, you = dynamic)');
|
||||
for (const nid of Object.keys(raw).sort()) {
|
||||
const frames = raw[nid];
|
||||
if (!frames.length) continue;
|
||||
const nSub = frames[0].nSub;
|
||||
|
||||
// Compute per-subcarrier variance
|
||||
const ampMeans = new Float64Array(nSub);
|
||||
const ampVars = new Float64Array(nSub);
|
||||
for (const f of frames) {
|
||||
for (let i = 0; i < Math.min(nSub, f.amps.length); i++) ampMeans[i] += f.amps[i];
|
||||
}
|
||||
for (let i = 0; i < nSub; i++) ampMeans[i] /= frames.length;
|
||||
for (const f of frames) {
|
||||
for (let i = 0; i < Math.min(nSub, f.amps.length); i++) ampVars[i] += (f.amps[i] - ampMeans[i]) ** 2;
|
||||
}
|
||||
for (let i = 0; i < nSub; i++) ampVars[i] = Math.sqrt(ampVars[i] / frames.length);
|
||||
|
||||
let nullCount = 0, dynamicCount = 0, staticCount = 0;
|
||||
const overallMean = ampMeans.reduce((a, b) => a + b) / nSub;
|
||||
for (let i = 0; i < nSub; i++) {
|
||||
if (ampMeans[i] < overallMean * 0.15) nullCount++;
|
||||
else if (ampVars[i] > 1.0) dynamicCount++;
|
||||
else staticCount++;
|
||||
}
|
||||
|
||||
console.log(` Node ${nid} (${nSub} subcarriers, ${frames.length} frames):`);
|
||||
console.log(` 🔩 Metal objects: ${nullCount} null subcarriers (${(100 * nullCount / nSub).toFixed(0)}%) — desk frame, monitor bezel, laptop chassis`);
|
||||
console.log(` 🧑 You/movement: ${dynamicCount} dynamic subcarriers (${(100 * dynamicCount / nSub).toFixed(0)}%) — person + micro-movements`);
|
||||
console.log(` 🧱 Walls/furniture: ${staticCount} static (${(100 * staticCount / nSub).toFixed(0)}%) — walls, ceiling, wooden furniture`);
|
||||
}
|
||||
|
||||
// 6. ELECTRONICS DETECTED
|
||||
console.log('\n💻 ELECTRONICS (from WiFi network scan perspective)');
|
||||
console.log(' Known devices transmitting WiFi in range:');
|
||||
console.log(' • Your router (ruv.net) — strongest signal, channel 5');
|
||||
console.log(' • HP M255 LaserJet — WiFi Direct on channel 5, ~2m away');
|
||||
console.log(' • Cognitum Seed — if plugged in (Pi Zero 2W)');
|
||||
console.log(' • 2x ESP32-S3 — the sensing nodes themselves');
|
||||
console.log(' • Your laptop/desktop — connected to ruv.net');
|
||||
console.log(' Neighbor devices (through walls):');
|
||||
console.log(' • COGECO-21B20 (100% signal, ch 11) — very close neighbor');
|
||||
console.log(' • conclusion mesh (44%, ch 3) — mesh network nearby');
|
||||
console.log(' • NETGEAR72 (42%, ch 9) — another neighbor');
|
||||
|
||||
// 7. INVISIBLE PHYSICS
|
||||
console.log('\n🔬 INVISIBLE PHYSICS');
|
||||
for (const nid of Object.keys(raw).sort()) {
|
||||
const frames = raw[nid];
|
||||
if (frames.length < 2) continue;
|
||||
|
||||
// Phase stability = room stability
|
||||
const first = frames[0], last = frames[frames.length - 1];
|
||||
const nCommon = Math.min(first.amps.length, last.amps.length);
|
||||
let phaseShift = 0;
|
||||
for (let i = 0; i < nCommon; i++) {
|
||||
const ampChange = Math.abs(last.amps[i] - first.amps[i]);
|
||||
phaseShift += ampChange;
|
||||
}
|
||||
phaseShift /= nCommon;
|
||||
|
||||
const rssis = frames.map(f => f.rssi);
|
||||
const rssiStd = std(rssis);
|
||||
|
||||
console.log(` Node ${nid}:`);
|
||||
console.log(` Amplitude drift: ${phaseShift.toFixed(2)} over ${((last.time - first.time) / 1000).toFixed(0)}s — ${phaseShift < 1 ? 'STABLE environment' : phaseShift < 3 ? 'minor movement' : 'active changes'}`);
|
||||
console.log(` RSSI stability: ±${rssiStd.toFixed(1)} dB — ${rssiStd < 2 ? 'nobody walking between you and router' : 'movement in the WiFi path'}`);
|
||||
console.log(` Fresnel zones: ${nCommon > 100 ? '128+ subcarriers = 5cm resolution potential' : nCommon + ' subcarriers'}`);
|
||||
}
|
||||
|
||||
// 8. FEATURE FINGERPRINT
|
||||
console.log('\n🧬 YOUR RF FINGERPRINT RIGHT NOW');
|
||||
for (const nid of Object.keys(features).sort()) {
|
||||
const f = features[nid];
|
||||
if (!f.length) continue;
|
||||
const last = f[f.length - 1].features;
|
||||
const names = ['Presence', 'Motion', 'Breathing', 'HeartRate', 'PhaseVar', 'Persons', 'Fall', 'RSSI'];
|
||||
console.log(` Node ${nid}:`);
|
||||
for (let i = 0; i < 8; i++) {
|
||||
console.log(` ${names[i].padStart(10)}: ${bar(last[i])} ${last[i].toFixed(2)}`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`\n${line}`);
|
||||
console.log(' WiFi signals reveal: who, what they\'re doing, how they feel,');
|
||||
console.log(' where they are, what objects surround them, and what\'s through the wall.');
|
||||
console.log(' No cameras. No wearables. No microphones. Just radio physics.');
|
||||
console.log(line);
|
||||
}
|
||||
@@ -0,0 +1,625 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* WiFlow PCK Evaluation Script (ADR-079)
|
||||
*
|
||||
* Measures accuracy of WiFi-based pose estimation against ground-truth
|
||||
* camera keypoints using PCK (Percentage of Correct Keypoints) and MPJPE
|
||||
* (Mean Per-Joint Position Error) metrics.
|
||||
*
|
||||
* Usage:
|
||||
* node scripts/eval-wiflow.js --model models/wiflow-supervised/wiflow-v1.json --data data/paired/aligned.paired.jsonl
|
||||
* node scripts/eval-wiflow.js --baseline --data data/paired/aligned.paired.jsonl
|
||||
* node scripts/eval-wiflow.js --model models/wiflow-supervised/wiflow-v1.json --data data/paired/aligned.paired.jsonl --verbose
|
||||
*
|
||||
* ADR: docs/adr/ADR-079
|
||||
*/
|
||||
|
||||
'use strict';
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { parseArgs } = require('util');
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Resolve WiFlow model dependencies
|
||||
// ---------------------------------------------------------------------------
|
||||
const {
|
||||
WiFlowModel,
|
||||
COCO_KEYPOINTS,
|
||||
createRng,
|
||||
} = require(path.join(__dirname, 'wiflow-model.js'));
|
||||
|
||||
const RUVLLM_PATH = path.resolve(__dirname, '..', 'vendor', 'ruvector', 'npm', 'packages', 'ruvllm', 'src');
|
||||
const { SafeTensorsReader } = require(path.join(RUVLLM_PATH, 'export.js'));
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants
|
||||
// ---------------------------------------------------------------------------
|
||||
const NUM_KEYPOINTS = 17;
|
||||
const DEFAULT_TORSO_LENGTH = 0.3; // normalized coords fallback
|
||||
|
||||
// Joint name aliases for display (short form)
|
||||
const JOINT_NAMES = [
|
||||
'nose', 'l_eye', 'r_eye', 'l_ear', 'r_ear',
|
||||
'l_shoulder', 'r_shoulder', 'l_elbow', 'r_elbow',
|
||||
'l_wrist', 'r_wrist', 'l_hip', 'r_hip',
|
||||
'l_knee', 'r_knee', 'l_ankle', 'r_ankle',
|
||||
];
|
||||
|
||||
// Shoulder indices: l_shoulder=5, r_shoulder=6
|
||||
// Hip indices: l_hip=11, r_hip=12
|
||||
const L_SHOULDER = 5;
|
||||
const R_SHOULDER = 6;
|
||||
const L_HIP = 11;
|
||||
const R_HIP = 12;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CLI argument parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
const { values: args } = parseArgs({
|
||||
options: {
|
||||
model: { type: 'string', short: 'm' },
|
||||
data: { type: 'string', short: 'd' },
|
||||
baseline: { type: 'boolean', default: false },
|
||||
output: { type: 'string', short: 'o' },
|
||||
verbose: { type: 'boolean', short: 'v', default: false },
|
||||
},
|
||||
strict: true,
|
||||
});
|
||||
|
||||
if (!args.data) {
|
||||
console.error('Usage: node scripts/eval-wiflow.js --data <paired-jsonl> [--model <path>] [--baseline] [--output <path>]');
|
||||
console.error('');
|
||||
console.error('Required:');
|
||||
console.error(' --data, -d <path> Paired CSI + keypoint JSONL (from align-ground-truth.js)');
|
||||
console.error('');
|
||||
console.error('Options:');
|
||||
console.error(' --model, -m <path> Path to trained model directory or JSON');
|
||||
console.error(' --baseline Evaluate proxy-based baseline (no model)');
|
||||
console.error(' --output, -o <path> Output eval report JSON');
|
||||
console.error(' --verbose, -v Verbose output');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
if (!args.model && !args.baseline) {
|
||||
console.error('Error: Must specify either --model <path> or --baseline');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Data loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Load paired JSONL samples.
|
||||
* Each line: { csi: [...], csi_shape: [S, T], kp: [[x,y],...], conf: 0.xx, ... }
|
||||
*/
|
||||
function loadPairedData(filePath) {
|
||||
const content = fs.readFileSync(filePath, 'utf-8');
|
||||
const samples = [];
|
||||
for (const line of content.split('\n')) {
|
||||
if (!line.trim()) continue;
|
||||
try {
|
||||
const s = JSON.parse(line);
|
||||
if (!s.kp || !Array.isArray(s.kp)) continue;
|
||||
if (!s.csi && !s.csi_shape) continue;
|
||||
samples.push(s);
|
||||
} catch (e) {
|
||||
// skip malformed lines
|
||||
}
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Load WiFlow model from a directory or JSON file.
|
||||
* Tries: model.safetensors, then config.json for architecture config.
|
||||
* Returns { model, name }.
|
||||
*/
|
||||
function loadModel(modelPath) {
|
||||
const stat = fs.statSync(modelPath);
|
||||
let modelDir;
|
||||
|
||||
if (stat.isDirectory()) {
|
||||
modelDir = modelPath;
|
||||
} else {
|
||||
// Assume JSON file in a model directory
|
||||
modelDir = path.dirname(modelPath);
|
||||
}
|
||||
|
||||
// Load architecture config if available
|
||||
let config = {};
|
||||
const configPath = path.join(modelDir, 'config.json');
|
||||
if (fs.existsSync(configPath)) {
|
||||
try {
|
||||
const raw = JSON.parse(fs.readFileSync(configPath, 'utf-8'));
|
||||
if (raw.custom) {
|
||||
config.inputChannels = raw.custom.inputChannels || 128;
|
||||
config.timeSteps = raw.custom.timeSteps || 20;
|
||||
config.numKeypoints = raw.custom.numKeypoints || 17;
|
||||
config.numHeads = raw.custom.numHeads || 8;
|
||||
config.seed = raw.custom.seed || 42;
|
||||
}
|
||||
} catch (e) {
|
||||
// use defaults
|
||||
}
|
||||
}
|
||||
|
||||
// Load training-metrics.json for additional config
|
||||
const metricsPath = path.join(modelDir, 'training-metrics.json');
|
||||
if (fs.existsSync(metricsPath)) {
|
||||
try {
|
||||
const metrics = JSON.parse(fs.readFileSync(metricsPath, 'utf-8'));
|
||||
if (metrics.model && metrics.model.architecture === 'wiflow') {
|
||||
// metrics available for report
|
||||
}
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
// Create model with config
|
||||
const model = new WiFlowModel(config);
|
||||
model.setTraining(false); // eval mode
|
||||
|
||||
// Load weights from SafeTensors
|
||||
const safetensorsPath = path.join(modelDir, 'model.safetensors');
|
||||
if (fs.existsSync(safetensorsPath)) {
|
||||
const buffer = new Uint8Array(fs.readFileSync(safetensorsPath));
|
||||
const reader = new SafeTensorsReader(buffer);
|
||||
const tensorNames = reader.getTensorNames();
|
||||
|
||||
// Build tensor map for fromTensorMap
|
||||
const tensorMap = new Map();
|
||||
for (const name of tensorNames) {
|
||||
const tensor = reader.getTensor(name);
|
||||
if (tensor) {
|
||||
tensorMap.set(name, tensor.data);
|
||||
}
|
||||
}
|
||||
|
||||
model.fromTensorMap(tensorMap);
|
||||
if (args.verbose) {
|
||||
console.log(`Loaded ${tensorNames.length} tensors from ${safetensorsPath}`);
|
||||
console.log(`Model params: ${model.numParams().toLocaleString()}`);
|
||||
}
|
||||
} else {
|
||||
console.warn(`WARN: No model.safetensors found in ${modelDir}, using random weights`);
|
||||
}
|
||||
|
||||
// Derive model name
|
||||
const name = path.basename(modelDir);
|
||||
return { model, name };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Baseline proxy pose generation (ADR-072 Phase 2 heuristic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Generate a proxy standing skeleton from CSI features.
|
||||
* If presence detected (amplitude energy > threshold), place a standing
|
||||
* person at center with standard COCO proportions, perturbed by motion energy.
|
||||
*/
|
||||
function generateBaselinePose(sample) {
|
||||
const rng = createRng(42);
|
||||
|
||||
// Estimate presence from CSI amplitude energy
|
||||
const csi = sample.csi;
|
||||
let energy = 0;
|
||||
if (Array.isArray(csi)) {
|
||||
for (let i = 0; i < csi.length; i++) {
|
||||
energy += csi[i] * csi[i];
|
||||
}
|
||||
energy = Math.sqrt(energy / csi.length);
|
||||
}
|
||||
|
||||
// Estimate motion energy (variance across subcarriers)
|
||||
let motionEnergy = 0;
|
||||
if (Array.isArray(csi) && sample.csi_shape) {
|
||||
const [S, T] = sample.csi_shape;
|
||||
if (T > 1) {
|
||||
for (let s = 0; s < S; s++) {
|
||||
let sum = 0;
|
||||
let sumSq = 0;
|
||||
for (let t = 0; t < T; t++) {
|
||||
const v = csi[s * T + t] || 0;
|
||||
sum += v;
|
||||
sumSq += v * v;
|
||||
}
|
||||
const mean = sum / T;
|
||||
motionEnergy += (sumSq / T) - (mean * mean);
|
||||
}
|
||||
motionEnergy = Math.sqrt(Math.max(0, motionEnergy / S));
|
||||
}
|
||||
}
|
||||
|
||||
// Normalized presence heuristic
|
||||
const presence = Math.min(1, energy / 10);
|
||||
|
||||
if (presence < 0.3) {
|
||||
// No person detected: return zero pose
|
||||
return new Float32Array(NUM_KEYPOINTS * 2);
|
||||
}
|
||||
|
||||
// Standing skeleton at center (0.5, 0.5) with standard proportions
|
||||
// Coordinates are [x, y] in normalized [0, 1] space
|
||||
// y=0 is top, y=1 is bottom (image convention)
|
||||
const cx = 0.5;
|
||||
const headY = 0.2;
|
||||
const shoulderY = 0.32;
|
||||
const elbowY = 0.45;
|
||||
const wristY = 0.55;
|
||||
const hipY = 0.55;
|
||||
const kneeY = 0.72;
|
||||
const ankleY = 0.88;
|
||||
const shoulderW = 0.08;
|
||||
const hipW = 0.06;
|
||||
const armSpread = 0.12;
|
||||
|
||||
// Standard standing pose keypoints [x, y]
|
||||
const skeleton = [
|
||||
[cx, headY], // 0: nose
|
||||
[cx - 0.02, headY - 0.02], // 1: l_eye
|
||||
[cx + 0.02, headY - 0.02], // 2: r_eye
|
||||
[cx - 0.04, headY], // 3: l_ear
|
||||
[cx + 0.04, headY], // 4: r_ear
|
||||
[cx - shoulderW, shoulderY], // 5: l_shoulder
|
||||
[cx + shoulderW, shoulderY], // 6: r_shoulder
|
||||
[cx - armSpread, elbowY], // 7: l_elbow
|
||||
[cx + armSpread, elbowY], // 8: r_elbow
|
||||
[cx - armSpread - 0.02, wristY], // 9: l_wrist
|
||||
[cx + armSpread + 0.02, wristY], // 10: r_wrist
|
||||
[cx - hipW, hipY], // 11: l_hip
|
||||
[cx + hipW, hipY], // 12: r_hip
|
||||
[cx - hipW, kneeY], // 13: l_knee
|
||||
[cx + hipW, kneeY], // 14: r_knee
|
||||
[cx - hipW, ankleY], // 15: l_ankle
|
||||
[cx + hipW, ankleY], // 16: r_ankle
|
||||
];
|
||||
|
||||
// Perturb limbs by motion energy
|
||||
const perturbScale = Math.min(motionEnergy * 0.1, 0.05);
|
||||
const result = new Float32Array(NUM_KEYPOINTS * 2);
|
||||
for (let k = 0; k < NUM_KEYPOINTS; k++) {
|
||||
const px = (rng() - 0.5) * 2 * perturbScale;
|
||||
const py = (rng() - 0.5) * 2 * perturbScale;
|
||||
result[k * 2] = Math.max(0, Math.min(1, skeleton[k][0] + px));
|
||||
result[k * 2 + 1] = Math.max(0, Math.min(1, skeleton[k][1] + py));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Metric computation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Euclidean distance between two 2D points */
|
||||
function dist2d(x1, y1, x2, y2) {
|
||||
const dx = x1 - x2;
|
||||
const dy = y1 - y2;
|
||||
return Math.sqrt(dx * dx + dy * dy);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute torso length from ground-truth keypoints.
|
||||
* Torso = distance(mid_shoulder, mid_hip).
|
||||
* Returns DEFAULT_TORSO_LENGTH if shoulders or hips not visible.
|
||||
*/
|
||||
function computeTorsoLength(kp) {
|
||||
if (!kp || kp.length < 13) return DEFAULT_TORSO_LENGTH;
|
||||
|
||||
const lsX = kp[L_SHOULDER][0];
|
||||
const lsY = kp[L_SHOULDER][1];
|
||||
const rsX = kp[R_SHOULDER][0];
|
||||
const rsY = kp[R_SHOULDER][1];
|
||||
const lhX = kp[L_HIP][0];
|
||||
const lhY = kp[L_HIP][1];
|
||||
const rhX = kp[R_HIP][0];
|
||||
const rhY = kp[R_HIP][1];
|
||||
|
||||
// Check if joints are at origin (not visible)
|
||||
const shoulderVisible = (lsX !== 0 || lsY !== 0) && (rsX !== 0 || rsY !== 0);
|
||||
const hipVisible = (lhX !== 0 || lhY !== 0) && (rhX !== 0 || rhY !== 0);
|
||||
|
||||
if (!shoulderVisible || !hipVisible) return DEFAULT_TORSO_LENGTH;
|
||||
|
||||
const midShoulderX = (lsX + rsX) / 2;
|
||||
const midShoulderY = (lsY + rsY) / 2;
|
||||
const midHipX = (lhX + rhX) / 2;
|
||||
const midHipY = (lhY + rhY) / 2;
|
||||
|
||||
const torso = dist2d(midShoulderX, midShoulderY, midHipX, midHipY);
|
||||
return torso > 0.01 ? torso : DEFAULT_TORSO_LENGTH;
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate predictions against ground truth.
|
||||
*
|
||||
* @param {Array<{pred: Float32Array, gt: number[][], conf: number}>} results
|
||||
* @returns {object} Evaluation report
|
||||
*/
|
||||
function computeMetrics(results) {
|
||||
const n = results.length;
|
||||
if (n === 0) {
|
||||
return {
|
||||
n_samples: 0,
|
||||
pck_10: 0, pck_20: 0, pck_50: 0,
|
||||
mpjpe: 0,
|
||||
per_joint_pck20: {},
|
||||
per_joint_mpjpe: {},
|
||||
conf_weighted_pck20: 0,
|
||||
conf_weighted_mpjpe: 0,
|
||||
};
|
||||
}
|
||||
|
||||
// Accumulators
|
||||
const pckCounts = { 10: 0, 20: 0, 50: 0 };
|
||||
let totalJoints = 0;
|
||||
let totalMPJPE = 0;
|
||||
|
||||
const perJointPck20 = new Float64Array(NUM_KEYPOINTS);
|
||||
const perJointMPJPE = new Float64Array(NUM_KEYPOINTS);
|
||||
const perJointCount = new Float64Array(NUM_KEYPOINTS);
|
||||
|
||||
// Confidence-weighted accumulators
|
||||
let confWeightedPck20Num = 0;
|
||||
let confWeightedPck20Den = 0;
|
||||
let confWeightedMpjpeNum = 0;
|
||||
let confWeightedMpjpeDen = 0;
|
||||
|
||||
for (const { pred, gt, conf } of results) {
|
||||
const torso = computeTorsoLength(gt);
|
||||
const w = Math.max(conf, 1e-6);
|
||||
|
||||
for (let k = 0; k < NUM_KEYPOINTS; k++) {
|
||||
if (k >= gt.length) continue;
|
||||
|
||||
const gtX = gt[k][0];
|
||||
const gtY = gt[k][1];
|
||||
const predX = pred[k * 2];
|
||||
const predY = pred[k * 2 + 1];
|
||||
|
||||
const d = dist2d(predX, predY, gtX, gtY);
|
||||
|
||||
totalJoints++;
|
||||
totalMPJPE += d;
|
||||
|
||||
perJointMPJPE[k] += d;
|
||||
perJointCount[k] += 1;
|
||||
|
||||
// PCK at different thresholds
|
||||
if (d < 0.10 * torso) pckCounts[10]++;
|
||||
if (d < 0.20 * torso) {
|
||||
pckCounts[20]++;
|
||||
perJointPck20[k]++;
|
||||
confWeightedPck20Num += w;
|
||||
}
|
||||
if (d < 0.50 * torso) pckCounts[50]++;
|
||||
|
||||
confWeightedPck20Den += w;
|
||||
confWeightedMpjpeNum += d * w;
|
||||
confWeightedMpjpeDen += w;
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregate metrics
|
||||
const pck10 = totalJoints > 0 ? pckCounts[10] / totalJoints : 0;
|
||||
const pck20 = totalJoints > 0 ? pckCounts[20] / totalJoints : 0;
|
||||
const pck50 = totalJoints > 0 ? pckCounts[50] / totalJoints : 0;
|
||||
const mpjpe = totalJoints > 0 ? totalMPJPE / totalJoints : 0;
|
||||
|
||||
// Per-joint breakdown
|
||||
const perJointPck20Map = {};
|
||||
const perJointMpjpeMap = {};
|
||||
for (let k = 0; k < NUM_KEYPOINTS; k++) {
|
||||
const name = JOINT_NAMES[k];
|
||||
perJointPck20Map[name] = perJointCount[k] > 0 ? perJointPck20[k] / perJointCount[k] : 0;
|
||||
perJointMpjpeMap[name] = perJointCount[k] > 0 ? perJointMPJPE[k] / perJointCount[k] : 0;
|
||||
}
|
||||
|
||||
// Confidence-weighted
|
||||
const confPck20 = confWeightedPck20Den > 0 ? confWeightedPck20Num / confWeightedPck20Den : 0;
|
||||
const confMpjpe = confWeightedMpjpeDen > 0 ? confWeightedMpjpeNum / confWeightedMpjpeDen : 0;
|
||||
|
||||
return {
|
||||
n_samples: n,
|
||||
pck_10: pck10,
|
||||
pck_20: pck20,
|
||||
pck_50: pck50,
|
||||
mpjpe,
|
||||
per_joint_pck20: perJointPck20Map,
|
||||
per_joint_mpjpe: perJointMpjpeMap,
|
||||
conf_weighted_pck20: confPck20,
|
||||
conf_weighted_mpjpe: confMpjpe,
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Inference
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Run model inference on a single paired sample.
|
||||
* @param {WiFlowModel} model
|
||||
* @param {object} sample - { csi, csi_shape, kp, conf }
|
||||
* @returns {Float32Array} - [17*2] predicted keypoints
|
||||
*/
|
||||
function runModelInference(model, sample) {
|
||||
const csi = sample.csi;
|
||||
const shape = sample.csi_shape;
|
||||
const S = shape ? shape[0] : 128;
|
||||
const T = shape ? shape[1] : 20;
|
||||
|
||||
// Prepare input as Float32Array [S, T]
|
||||
let input;
|
||||
if (csi instanceof Float32Array) {
|
||||
input = csi;
|
||||
} else if (Array.isArray(csi)) {
|
||||
input = new Float32Array(csi);
|
||||
} else {
|
||||
input = new Float32Array(S * T);
|
||||
}
|
||||
|
||||
// Ensure correct size (pad or truncate)
|
||||
const expectedLen = model.inputChannels * model.timeSteps;
|
||||
if (input.length !== expectedLen) {
|
||||
const resized = new Float32Array(expectedLen);
|
||||
const copyLen = Math.min(input.length, expectedLen);
|
||||
resized.set(input.subarray(0, copyLen));
|
||||
input = resized;
|
||||
}
|
||||
|
||||
return model.forward(input);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Formatted output
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function formatPercent(v) {
|
||||
return (v * 100).toFixed(1) + '%';
|
||||
}
|
||||
|
||||
function formatFloat(v, decimals) {
|
||||
decimals = decimals || 4;
|
||||
return v.toFixed(decimals);
|
||||
}
|
||||
|
||||
function printReport(report) {
|
||||
console.log('');
|
||||
console.log('WiFlow Evaluation Report (ADR-079)');
|
||||
console.log('===================================');
|
||||
console.log(`Model: ${report.model}`);
|
||||
console.log(`Samples: ${report.n_samples.toLocaleString()}`);
|
||||
console.log(`PCK@10: ${formatPercent(report.pck_10)}`);
|
||||
console.log(`PCK@20: ${formatPercent(report.pck_20)}`);
|
||||
console.log(`PCK@50: ${formatPercent(report.pck_50)}`);
|
||||
console.log(`MPJPE: ${formatFloat(report.mpjpe)}`);
|
||||
console.log('');
|
||||
console.log('Per-Joint PCK@20:');
|
||||
|
||||
const maxNameLen = Math.max(...JOINT_NAMES.map(n => n.length));
|
||||
for (const name of JOINT_NAMES) {
|
||||
const pck = report.per_joint_pck20[name] || 0;
|
||||
const pad = ' '.repeat(maxNameLen - name.length + 2);
|
||||
console.log(` ${name}${pad}${formatPercent(pck)}`);
|
||||
}
|
||||
|
||||
console.log('');
|
||||
console.log('Per-Joint MPJPE:');
|
||||
for (const name of JOINT_NAMES) {
|
||||
const mpjpe = report.per_joint_mpjpe[name] || 0;
|
||||
const pad = ' '.repeat(maxNameLen - name.length + 2);
|
||||
console.log(` ${name}${pad}${formatFloat(mpjpe)}`);
|
||||
}
|
||||
|
||||
console.log('');
|
||||
console.log('Confidence-Weighted:');
|
||||
console.log(` PCK@20: ${formatPercent(report.conf_weighted_pck20)}`);
|
||||
console.log(` MPJPE: ${formatFloat(report.conf_weighted_mpjpe)}`);
|
||||
console.log('');
|
||||
console.log(`Inference: ${report.inference_latency_ms.toFixed(2)}ms/sample`);
|
||||
console.log('');
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function main() {
|
||||
// Load paired data
|
||||
if (args.verbose) console.log(`Loading paired data from ${args.data}...`);
|
||||
const samples = loadPairedData(args.data);
|
||||
if (samples.length === 0) {
|
||||
console.error('Error: No valid paired samples found in', args.data);
|
||||
process.exit(1);
|
||||
}
|
||||
if (args.verbose) console.log(`Loaded ${samples.length} paired samples`);
|
||||
|
||||
let modelName;
|
||||
let model = null;
|
||||
|
||||
if (args.baseline) {
|
||||
modelName = 'baseline-proxy';
|
||||
if (args.verbose) console.log('Running baseline proxy evaluation (ADR-072 Phase 2 heuristic)');
|
||||
} else {
|
||||
const loaded = loadModel(args.model);
|
||||
model = loaded.model;
|
||||
modelName = loaded.name;
|
||||
if (args.verbose) console.log(`Running model evaluation: ${modelName}`);
|
||||
}
|
||||
|
||||
// Run inference and collect results
|
||||
const results = [];
|
||||
const startTime = process.hrtime.bigint();
|
||||
|
||||
for (const sample of samples) {
|
||||
let pred;
|
||||
if (args.baseline) {
|
||||
pred = generateBaselinePose(sample);
|
||||
} else {
|
||||
pred = runModelInference(model, sample);
|
||||
}
|
||||
|
||||
results.push({
|
||||
pred,
|
||||
gt: sample.kp,
|
||||
conf: sample.conf || 0,
|
||||
});
|
||||
}
|
||||
|
||||
const endTime = process.hrtime.bigint();
|
||||
const totalMs = Number(endTime - startTime) / 1e6;
|
||||
const latencyMs = totalMs / samples.length;
|
||||
|
||||
// Compute metrics
|
||||
const metrics = computeMetrics(results);
|
||||
|
||||
// Build report
|
||||
const report = {
|
||||
model: modelName,
|
||||
n_samples: metrics.n_samples,
|
||||
pck_10: Math.round(metrics.pck_10 * 10000) / 10000,
|
||||
pck_20: Math.round(metrics.pck_20 * 10000) / 10000,
|
||||
pck_50: Math.round(metrics.pck_50 * 10000) / 10000,
|
||||
mpjpe: Math.round(metrics.mpjpe * 100000) / 100000,
|
||||
per_joint_pck20: {},
|
||||
per_joint_mpjpe: {},
|
||||
conf_weighted_pck20: Math.round(metrics.conf_weighted_pck20 * 10000) / 10000,
|
||||
conf_weighted_mpjpe: Math.round(metrics.conf_weighted_mpjpe * 100000) / 100000,
|
||||
inference_latency_ms: Math.round(latencyMs * 100) / 100,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
|
||||
// Round per-joint metrics
|
||||
for (const name of JOINT_NAMES) {
|
||||
report.per_joint_pck20[name] = Math.round((metrics.per_joint_pck20[name] || 0) * 10000) / 10000;
|
||||
report.per_joint_mpjpe[name] = Math.round((metrics.per_joint_mpjpe[name] || 0) * 100000) / 100000;
|
||||
}
|
||||
|
||||
// Print formatted report
|
||||
printReport(report);
|
||||
|
||||
// Write output JSON
|
||||
const outputPath = args.output ||
|
||||
(args.model
|
||||
? path.join(path.dirname(
|
||||
fs.statSync(args.model).isDirectory() ? path.join(args.model, '.') : args.model
|
||||
), 'eval-report.json')
|
||||
: 'models/wiflow-supervised/eval-report.json');
|
||||
|
||||
const outputDir = path.dirname(outputPath);
|
||||
if (!fs.existsSync(outputDir)) {
|
||||
fs.mkdirSync(outputDir, { recursive: true });
|
||||
}
|
||||
|
||||
fs.writeFileSync(outputPath, JSON.stringify(report, null, 2) + '\n');
|
||||
console.log(`Report saved to ${outputPath}`);
|
||||
}
|
||||
|
||||
main();
|
||||
@@ -6,7 +6,7 @@ echo "Host: $(hostname) | $(sysctl -n hw.ncpu 2>/dev/null || nproc) cores | $(sy
|
||||
echo ""
|
||||
|
||||
REPO_DIR="${HOME}/Projects/wifi-densepose"
|
||||
WINDOWS_HOST="100.102.238.73" # Tailscale IP of Windows machine
|
||||
WINDOWS_HOST="${WINDOWS_HOST:-}" # Set via env: export WINDOWS_HOST=<tailscale-ip>
|
||||
|
||||
# Step 1: Clone or update repo
|
||||
echo "[1/7] Setting up repository..."
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Lightweight ESP32 CSI UDP recorder (ADR-079).
|
||||
|
||||
Captures raw CSI packets from ESP32 nodes over UDP and writes to JSONL.
|
||||
Runs alongside collect-ground-truth.py for synchronized capture.
|
||||
|
||||
Usage:
|
||||
python scripts/record-csi-udp.py --duration 300 --output data/recordings
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
|
||||
|
||||
def parse_csi_packet(data):
|
||||
"""Parse ADR-018 binary CSI packet into dict."""
|
||||
if len(data) < 8:
|
||||
return None
|
||||
|
||||
# ADR-018 header: [magic(2), len(2), node_id(1), seq(1), rssi(1), channel(1), iq_data...]
|
||||
# Simplified: extract what we can from the raw packet
|
||||
node_id = data[4] if len(data) > 4 else 0
|
||||
rssi = struct.unpack('b', bytes([data[6]]))[0] if len(data) > 6 else 0
|
||||
channel = data[7] if len(data) > 7 else 0
|
||||
|
||||
# IQ data starts at offset 8
|
||||
iq_data = data[8:] if len(data) > 8 else b''
|
||||
n_subcarriers = len(iq_data) // 2 # I,Q pairs
|
||||
|
||||
# Compute amplitudes
|
||||
amplitudes = []
|
||||
for i in range(0, len(iq_data) - 1, 2):
|
||||
I = struct.unpack('b', bytes([iq_data[i]]))[0]
|
||||
Q = struct.unpack('b', bytes([iq_data[i + 1]]))[0]
|
||||
amplitudes.append(round((I * I + Q * Q) ** 0.5, 2))
|
||||
|
||||
return {
|
||||
"type": "raw_csi",
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.") + f"{int(time.time() * 1000) % 1000:03d}Z",
|
||||
"ts_ns": time.time_ns(),
|
||||
"node_id": node_id,
|
||||
"rssi": rssi,
|
||||
"channel": channel,
|
||||
"subcarriers": n_subcarriers,
|
||||
"amplitudes": amplitudes,
|
||||
"iq_hex": iq_data.hex(),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Record ESP32 CSI over UDP")
|
||||
parser.add_argument("--port", type=int, default=5005, help="UDP port (default: 5005)")
|
||||
parser.add_argument("--duration", type=int, default=300, help="Duration in seconds (default: 300)")
|
||||
parser.add_argument("--output", default="data/recordings", help="Output directory")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
filename = f"csi-{int(time.time())}.csi.jsonl"
|
||||
filepath = os.path.join(args.output, filename)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("0.0.0.0", args.port))
|
||||
sock.settimeout(1)
|
||||
|
||||
print(f"Recording CSI on UDP :{args.port} for {args.duration}s")
|
||||
print(f"Output: {filepath}")
|
||||
|
||||
count = 0
|
||||
start = time.time()
|
||||
nodes_seen = set()
|
||||
|
||||
with open(filepath, "w") as f:
|
||||
try:
|
||||
while time.time() - start < args.duration:
|
||||
try:
|
||||
data, addr = sock.recvfrom(4096)
|
||||
frame = parse_csi_packet(data)
|
||||
if frame:
|
||||
f.write(json.dumps(frame) + "\n")
|
||||
count += 1
|
||||
nodes_seen.add(frame["node_id"])
|
||||
|
||||
if count % 500 == 0:
|
||||
elapsed = time.time() - start
|
||||
rate = count / elapsed
|
||||
print(f" {count} frames | {rate:.0f} fps | "
|
||||
f"nodes: {sorted(nodes_seen)} | "
|
||||
f"{elapsed:.0f}s / {args.duration}s")
|
||||
except socket.timeout:
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopped by user")
|
||||
|
||||
sock.close()
|
||||
elapsed = time.time() - start
|
||||
print(f"\n=== CSI Recording Complete ===")
|
||||
print(f" Frames: {count}")
|
||||
print(f" Duration: {elapsed:.0f}s")
|
||||
print(f" Rate: {count / max(elapsed, 1):.0f} fps")
|
||||
print(f" Nodes: {sorted(nodes_seen)}")
|
||||
print(f" Output: {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1257,9 +1257,13 @@ async function main() {
|
||||
contrastiveResult.finalLoss = finalContrastiveLoss;
|
||||
contrastiveResult.improvement = contrastiveImprovement;
|
||||
|
||||
// Export contrastive training data
|
||||
const contrastiveOutDir = contrastiveTrainer.exportTrainingData();
|
||||
console.log(` Training data exported to: ${contrastiveOutDir}`);
|
||||
// Export contrastive training data (skip for large datasets to avoid JSON string limit)
|
||||
if (contrastiveTrainer.getTripletCount() < 100000) {
|
||||
const contrastiveOutDir = contrastiveTrainer.exportTrainingData();
|
||||
console.log(` Training data exported to: ${contrastiveOutDir}`);
|
||||
} else {
|
||||
console.log(` Skipping triplet export (${contrastiveTrainer.getTripletCount()} triplets too large for JSON)`);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Phase 2: Task head training via TrainingPipeline
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -110,12 +110,18 @@ export class SensingTab {
|
||||
<div class="sensing-card-title">About This Data</div>
|
||||
<p class="sensing-about-text">
|
||||
Metrics are computed from WiFi Channel State Information (CSI).
|
||||
With <strong>1 ESP32</strong> you get presence detection, breathing
|
||||
With <strong><span id="sensingNodeCount">0</span> ESP32 node(s)</strong> you get presence detection, breathing
|
||||
estimation, and gross motion. Add <strong>3-4+ ESP32 nodes</strong>
|
||||
around the room for spatial resolution and limb-level tracking.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Node Status -->
|
||||
<div class="sensing-card" id="sensingNodeCards">
|
||||
<div class="sensing-card-title">NODE STATUS</div>
|
||||
<div id="nodeStatusContainer"></div>
|
||||
</div>
|
||||
|
||||
<!-- Extra info -->
|
||||
<div class="sensing-card">
|
||||
<div class="sensing-card-title">Details</div>
|
||||
@@ -193,6 +199,9 @@ export class SensingTab {
|
||||
|
||||
// Update HUD
|
||||
this._updateHUD(data);
|
||||
|
||||
// Update per-node panels
|
||||
this._updateNodePanels(data);
|
||||
}
|
||||
|
||||
_onStateChange(state) {
|
||||
@@ -233,6 +242,11 @@ export class SensingTab {
|
||||
const f = data.features || {};
|
||||
const c = data.classification || {};
|
||||
|
||||
// Node count
|
||||
const nodeCount = (data.nodes || []).length;
|
||||
const countEl = this.container.querySelector('#sensingNodeCount');
|
||||
if (countEl) countEl.textContent = String(nodeCount);
|
||||
|
||||
// RSSI
|
||||
this._setText('sensingRssi', `${(f.mean_rssi || -80).toFixed(1)} dBm`);
|
||||
this._setText('sensingSource', data.source || '');
|
||||
@@ -309,6 +323,57 @@ export class SensingTab {
|
||||
ctx.stroke();
|
||||
}
|
||||
|
||||
// ---- Per-node panels ---------------------------------------------------
|
||||
|
||||
_updateNodePanels(data) {
|
||||
const container = this.container.querySelector('#nodeStatusContainer');
|
||||
if (!container) return;
|
||||
const nodeFeatures = data.node_features || [];
|
||||
if (nodeFeatures.length === 0) {
|
||||
container.textContent = '';
|
||||
const msg = document.createElement('div');
|
||||
msg.style.cssText = 'color:#888;font-size:12px;padding:8px;';
|
||||
msg.textContent = 'No nodes detected';
|
||||
container.appendChild(msg);
|
||||
return;
|
||||
}
|
||||
const NODE_COLORS = ['#00ccff', '#ff6600', '#00ff88', '#ff00cc', '#ffcc00', '#8800ff', '#00ffcc', '#ff0044'];
|
||||
container.textContent = '';
|
||||
for (const nf of nodeFeatures) {
|
||||
const color = NODE_COLORS[nf.node_id % NODE_COLORS.length];
|
||||
const statusColor = nf.stale ? '#888' : '#0f0';
|
||||
|
||||
const row = document.createElement('div');
|
||||
row.style.cssText = `display:flex;align-items:center;gap:8px;padding:6px 8px;margin-bottom:4px;background:rgba(255,255,255,0.03);border-radius:6px;border-left:3px solid ${color};`;
|
||||
|
||||
const idCol = document.createElement('div');
|
||||
idCol.style.minWidth = '50px';
|
||||
const nameEl = document.createElement('div');
|
||||
nameEl.style.cssText = `font-size:11px;font-weight:600;color:${color};`;
|
||||
nameEl.textContent = 'Node ' + nf.node_id;
|
||||
const statusEl = document.createElement('div');
|
||||
statusEl.style.cssText = `font-size:9px;color:${statusColor};`;
|
||||
statusEl.textContent = nf.stale ? 'STALE' : 'ACTIVE';
|
||||
idCol.appendChild(nameEl);
|
||||
idCol.appendChild(statusEl);
|
||||
|
||||
const metricsCol = document.createElement('div');
|
||||
metricsCol.style.cssText = 'flex:1;font-size:10px;color:#aaa;';
|
||||
metricsCol.textContent = (nf.rssi_dbm || -80).toFixed(0) + ' dBm · var ' + (nf.features?.variance || 0).toFixed(1);
|
||||
|
||||
const classCol = document.createElement('div');
|
||||
classCol.style.cssText = 'font-size:10px;font-weight:600;color:#ccc;';
|
||||
const motion = (nf.classification?.motion_level || 'absent').toUpperCase();
|
||||
const conf = ((nf.classification?.confidence || 0) * 100).toFixed(0);
|
||||
classCol.textContent = motion + ' ' + conf + '%';
|
||||
|
||||
row.appendChild(idCol);
|
||||
row.appendChild(metricsCol);
|
||||
row.appendChild(classCol);
|
||||
container.appendChild(row);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Resize ------------------------------------------------------------
|
||||
|
||||
_setupResize() {
|
||||
|
||||
@@ -66,6 +66,10 @@ function valueToColor(v) {
|
||||
return [r, g, b];
|
||||
}
|
||||
|
||||
// ---- Node marker color palette -------------------------------------------
|
||||
|
||||
const NODE_MARKER_COLORS = [0x00ccff, 0xff6600, 0x00ff88, 0xff00cc, 0xffcc00, 0x8800ff, 0x00ffcc, 0xff0044];
|
||||
|
||||
// ---- GaussianSplatRenderer -----------------------------------------------
|
||||
|
||||
export class GaussianSplatRenderer {
|
||||
@@ -108,6 +112,10 @@ export class GaussianSplatRenderer {
|
||||
// Node markers (ESP32 / router positions)
|
||||
this._createNodeMarkers(THREE);
|
||||
|
||||
// Dynamic per-node markers (multi-node support)
|
||||
this.nodeMarkers = new Map(); // nodeId -> THREE.Mesh
|
||||
this._THREE = THREE;
|
||||
|
||||
// Body disruption blob
|
||||
this._createBodyBlob(THREE);
|
||||
|
||||
@@ -369,11 +377,43 @@ export class GaussianSplatRenderer {
|
||||
bGeo.attributes.splatSize.needsUpdate = true;
|
||||
}
|
||||
|
||||
// -- Update node positions ---------------------------------------------
|
||||
// -- Update node positions (legacy single-node) ------------------------
|
||||
if (nodes.length > 0 && nodes[0].position) {
|
||||
const pos = nodes[0].position;
|
||||
this.nodeMarker.position.set(pos[0], 0.5, pos[2]);
|
||||
}
|
||||
|
||||
// -- Update dynamic per-node markers (multi-node support) --------------
|
||||
if (nodes && nodes.length > 0 && this.scene) {
|
||||
const THREE = this._THREE || window.THREE;
|
||||
if (THREE) {
|
||||
const activeIds = new Set();
|
||||
for (const node of nodes) {
|
||||
activeIds.add(node.node_id);
|
||||
if (!this.nodeMarkers.has(node.node_id)) {
|
||||
const geo = new THREE.SphereGeometry(0.25, 16, 16);
|
||||
const mat = new THREE.MeshBasicMaterial({
|
||||
color: NODE_MARKER_COLORS[node.node_id % NODE_MARKER_COLORS.length],
|
||||
transparent: true,
|
||||
opacity: 0.8,
|
||||
});
|
||||
const marker = new THREE.Mesh(geo, mat);
|
||||
this.scene.add(marker);
|
||||
this.nodeMarkers.set(node.node_id, marker);
|
||||
}
|
||||
const marker = this.nodeMarkers.get(node.node_id);
|
||||
const pos = node.position || [0, 0, 0];
|
||||
marker.position.set(pos[0], 0.5, pos[2]);
|
||||
}
|
||||
// Remove stale markers
|
||||
for (const [id, marker] of this.nodeMarkers) {
|
||||
if (!activeIds.has(id)) {
|
||||
this.scene.remove(marker);
|
||||
this.nodeMarkers.delete(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Render loop -------------------------------------------------------
|
||||
|
||||
@@ -76,4 +76,31 @@ describe('MATScreen', () => {
|
||||
// Simulated status maps to 'simulated' banner -> "SIMULATED DATA"
|
||||
expect(getByText('SIMULATED DATA')).toBeTruthy();
|
||||
});
|
||||
|
||||
it('shows simulation warning overlay when simulated and not acknowledged', () => {
|
||||
// Reset store to ensure overlay is shown
|
||||
const { useMatStore } = require('@/stores/matStore');
|
||||
useMatStore.setState({ dataSource: 'simulated', simulationAcknowledged: false });
|
||||
|
||||
const { MATScreen } = require('@/screens/MATScreen');
|
||||
const { getByText } = render(
|
||||
<ThemeProvider>
|
||||
<MATScreen />
|
||||
</ThemeProvider>,
|
||||
);
|
||||
expect(getByText('I UNDERSTAND')).toBeTruthy();
|
||||
});
|
||||
|
||||
it('hides overlay after acknowledgment', () => {
|
||||
const { useMatStore } = require('@/stores/matStore');
|
||||
useMatStore.setState({ dataSource: 'simulated', simulationAcknowledged: true });
|
||||
|
||||
const { MATScreen } = require('@/screens/MATScreen');
|
||||
const { queryByText } = render(
|
||||
<ThemeProvider>
|
||||
<MATScreen />
|
||||
</ThemeProvider>,
|
||||
);
|
||||
expect(queryByText('I UNDERSTAND')).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -62,6 +62,8 @@ describe('useMatStore', () => {
|
||||
survivors: [],
|
||||
alerts: [],
|
||||
selectedEventId: null,
|
||||
dataSource: 'simulated',
|
||||
simulationAcknowledged: false,
|
||||
});
|
||||
});
|
||||
|
||||
@@ -195,4 +197,32 @@ describe('useMatStore', () => {
|
||||
expect(useMatStore.getState().selectedEventId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('dataSource', () => {
|
||||
it('defaults to simulated', () => {
|
||||
expect(useMatStore.getState().dataSource).toBe('simulated');
|
||||
});
|
||||
|
||||
it('can be set to real', () => {
|
||||
useMatStore.getState().setDataSource('real');
|
||||
expect(useMatStore.getState().dataSource).toBe('real');
|
||||
});
|
||||
|
||||
it('can be set back to simulated', () => {
|
||||
useMatStore.getState().setDataSource('real');
|
||||
useMatStore.getState().setDataSource('simulated');
|
||||
expect(useMatStore.getState().dataSource).toBe('simulated');
|
||||
});
|
||||
});
|
||||
|
||||
describe('simulationAcknowledged', () => {
|
||||
it('defaults to false', () => {
|
||||
expect(useMatStore.getState().simulationAcknowledged).toBe(false);
|
||||
});
|
||||
|
||||
it('can be acknowledged', () => {
|
||||
useMatStore.getState().acknowledgeSimulation();
|
||||
expect(useMatStore.getState().simulationAcknowledged).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import React, { useEffect, useRef } from 'react';
|
||||
import { Animated, StyleSheet, Text, View } from 'react-native';
|
||||
|
||||
interface Props {
|
||||
visible: boolean;
|
||||
}
|
||||
|
||||
export const SimulationBanner: React.FC<Props> = ({ visible }) => {
|
||||
const opacity = useRef(new Animated.Value(1)).current;
|
||||
|
||||
useEffect(() => {
|
||||
if (!visible) return;
|
||||
|
||||
const pulse = Animated.loop(
|
||||
Animated.sequence([
|
||||
Animated.timing(opacity, { toValue: 0.4, duration: 800, useNativeDriver: true }),
|
||||
Animated.timing(opacity, { toValue: 1.0, duration: 800, useNativeDriver: true }),
|
||||
]),
|
||||
);
|
||||
pulse.start();
|
||||
return () => pulse.stop();
|
||||
}, [visible, opacity]);
|
||||
|
||||
if (!visible) return null;
|
||||
|
||||
return (
|
||||
<Animated.View style={[styles.banner, { opacity }]}>
|
||||
<Text style={styles.text}>SIMULATED DATA - NOT CONNECTED TO REAL SENSORS</Text>
|
||||
</Animated.View>
|
||||
);
|
||||
};
|
||||
|
||||
const styles = StyleSheet.create({
|
||||
banner: {
|
||||
backgroundColor: '#e74c3c',
|
||||
paddingVertical: 6,
|
||||
paddingHorizontal: 12,
|
||||
borderRadius: 6,
|
||||
alignItems: 'center',
|
||||
marginBottom: 8,
|
||||
},
|
||||
text: {
|
||||
color: '#ffffff',
|
||||
fontWeight: '700',
|
||||
fontSize: 12,
|
||||
letterSpacing: 0.5,
|
||||
textAlign: 'center',
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,78 @@
|
||||
import React from 'react';
|
||||
import { Modal, Pressable, StyleSheet, Text, View } from 'react-native';
|
||||
|
||||
interface Props {
|
||||
visible: boolean;
|
||||
onAcknowledge: () => void;
|
||||
}
|
||||
|
||||
export const SimulationWarningOverlay: React.FC<Props> = ({ visible, onAcknowledge }) => (
|
||||
<Modal visible={visible} transparent animationType="fade">
|
||||
<View style={styles.backdrop}>
|
||||
<View style={styles.card}>
|
||||
<Text style={styles.icon}>⚠</Text>
|
||||
<Text style={styles.title}>SIMULATED DATA</Text>
|
||||
<Text style={styles.body}>
|
||||
NOT CONNECTED TO REAL SENSORS{'\n\n'}
|
||||
All survivor detections, vital signs, and alerts displayed on this screen are
|
||||
generated from simulated data and do not reflect actual conditions.
|
||||
</Text>
|
||||
<Pressable style={styles.button} onPress={onAcknowledge}>
|
||||
<Text style={styles.buttonText}>I UNDERSTAND</Text>
|
||||
</Pressable>
|
||||
</View>
|
||||
</View>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
const styles = StyleSheet.create({
|
||||
backdrop: {
|
||||
flex: 1,
|
||||
backgroundColor: 'rgba(0,0,0,0.85)',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
padding: 24,
|
||||
},
|
||||
card: {
|
||||
backgroundColor: '#1a1a2e',
|
||||
borderRadius: 16,
|
||||
padding: 32,
|
||||
alignItems: 'center',
|
||||
borderWidth: 2,
|
||||
borderColor: '#e74c3c',
|
||||
maxWidth: 420,
|
||||
width: '100%',
|
||||
},
|
||||
icon: {
|
||||
fontSize: 48,
|
||||
color: '#e74c3c',
|
||||
marginBottom: 12,
|
||||
},
|
||||
title: {
|
||||
fontSize: 22,
|
||||
fontWeight: '800',
|
||||
color: '#e74c3c',
|
||||
textAlign: 'center',
|
||||
marginBottom: 16,
|
||||
letterSpacing: 1,
|
||||
},
|
||||
body: {
|
||||
fontSize: 15,
|
||||
color: '#cccccc',
|
||||
textAlign: 'center',
|
||||
lineHeight: 22,
|
||||
marginBottom: 28,
|
||||
},
|
||||
button: {
|
||||
backgroundColor: '#e74c3c',
|
||||
paddingHorizontal: 36,
|
||||
paddingVertical: 14,
|
||||
borderRadius: 8,
|
||||
},
|
||||
buttonText: {
|
||||
color: '#ffffff',
|
||||
fontWeight: '700',
|
||||
fontSize: 16,
|
||||
letterSpacing: 0.5,
|
||||
},
|
||||
});
|
||||
@@ -10,6 +10,8 @@ import { type ConnectionStatus } from '@/types/sensing';
|
||||
import { Alert, type Survivor } from '@/types/mat';
|
||||
import { AlertList } from './AlertList';
|
||||
import { MatWebView } from './MatWebView';
|
||||
import { SimulationBanner } from './SimulationBanner';
|
||||
import { SimulationWarningOverlay } from './SimulationWarningOverlay';
|
||||
import { SurvivorCounter } from './SurvivorCounter';
|
||||
import { useMatBridge } from './useMatBridge';
|
||||
|
||||
@@ -47,6 +49,15 @@ export const MATScreen = () => {
|
||||
const upsertSurvivor = useMatStore((state) => state.upsertSurvivor);
|
||||
const addAlert = useMatStore((state) => state.addAlert);
|
||||
const upsertEvent = useMatStore((state) => state.upsertEvent);
|
||||
const dataSource = useMatStore((state) => state.dataSource);
|
||||
const simulationAcknowledged = useMatStore((state) => state.simulationAcknowledged);
|
||||
const setDataSource = useMatStore((state) => state.setDataSource);
|
||||
const acknowledgeSimulation = useMatStore((state) => state.acknowledgeSimulation);
|
||||
|
||||
// Sync dataSource from connection status
|
||||
useEffect(() => {
|
||||
setDataSource(connectionStatus === 'connected' ? 'real' : 'simulated');
|
||||
}, [connectionStatus, setDataSource]);
|
||||
|
||||
const { webViewRef, ready, onMessage, sendFrameUpdate, postEvent } = useMatBridge({
|
||||
onSurvivorDetected: (survivor) => {
|
||||
@@ -113,8 +124,13 @@ export const MATScreen = () => {
|
||||
const { height } = useWindowDimensions();
|
||||
const webHeight = Math.max(240, Math.floor(height * 0.5));
|
||||
|
||||
const showOverlay = dataSource === 'simulated' && !simulationAcknowledged;
|
||||
const showBanner = dataSource === 'simulated' && simulationAcknowledged;
|
||||
|
||||
return (
|
||||
<ThemedView style={{ flex: 1, backgroundColor: colors.bg, padding: spacing.md }}>
|
||||
<SimulationWarningOverlay visible={showOverlay} onAcknowledge={acknowledgeSimulation} />
|
||||
<SimulationBanner visible={showBanner} />
|
||||
<ConnectionBanner status={resolveBannerState(connectionStatus)} />
|
||||
<View style={{ marginTop: 20 }}>
|
||||
<SurvivorCounter survivors={survivors} />
|
||||
|
||||
@@ -7,11 +7,17 @@ export interface MatState {
|
||||
survivors: Survivor[];
|
||||
alerts: Alert[];
|
||||
selectedEventId: string | null;
|
||||
/** Whether data comes from real sensors or simulation. */
|
||||
dataSource: 'real' | 'simulated';
|
||||
/** Whether the user has dismissed the simulation warning overlay. */
|
||||
simulationAcknowledged: boolean;
|
||||
upsertEvent: (event: DisasterEvent) => void;
|
||||
addZone: (zone: ScanZone) => void;
|
||||
upsertSurvivor: (survivor: Survivor) => void;
|
||||
addAlert: (alert: Alert) => void;
|
||||
setSelectedEvent: (id: string | null) => void;
|
||||
setDataSource: (source: 'real' | 'simulated') => void;
|
||||
acknowledgeSimulation: () => void;
|
||||
}
|
||||
|
||||
export const useMatStore = create<MatState>((set) => ({
|
||||
@@ -20,6 +26,8 @@ export const useMatStore = create<MatState>((set) => ({
|
||||
survivors: [],
|
||||
alerts: [],
|
||||
selectedEventId: null,
|
||||
dataSource: 'simulated',
|
||||
simulationAcknowledged: false,
|
||||
|
||||
upsertEvent: (event) => {
|
||||
set((state) => {
|
||||
@@ -71,4 +79,12 @@ export const useMatStore = create<MatState>((set) => ({
|
||||
setSelectedEvent: (id) => {
|
||||
set({ selectedEventId: id });
|
||||
},
|
||||
|
||||
setDataSource: (source) => {
|
||||
set({ dataSource: source });
|
||||
},
|
||||
|
||||
acknowledgeSimulation: () => {
|
||||
set({ simulationAcknowledged: true });
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -84,6 +84,11 @@ class SensingService {
|
||||
return [...this._rssiHistory];
|
||||
}
|
||||
|
||||
/** Get per-node RSSI history (object keyed by node_id). */
|
||||
getPerNodeRssiHistory() {
|
||||
return { ...(this._perNodeRssiHistory || {}) };
|
||||
}
|
||||
|
||||
/** Current connection state. */
|
||||
get state() {
|
||||
return this._state;
|
||||
@@ -327,6 +332,20 @@ class SensingService {
|
||||
}
|
||||
}
|
||||
|
||||
// Per-node RSSI tracking
|
||||
if (!this._perNodeRssiHistory) this._perNodeRssiHistory = {};
|
||||
if (data.node_features) {
|
||||
for (const nf of data.node_features) {
|
||||
if (!this._perNodeRssiHistory[nf.node_id]) {
|
||||
this._perNodeRssiHistory[nf.node_id] = [];
|
||||
}
|
||||
this._perNodeRssiHistory[nf.node_id].push(nf.rssi_dbm);
|
||||
if (this._perNodeRssiHistory[nf.node_id].length > this._maxHistory) {
|
||||
this._perNodeRssiHistory[nf.node_id].shift();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notify all listeners
|
||||
for (const cb of this._listeners) {
|
||||
try {
|
||||
|
||||
+7
-1
@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.config.domains import get_domain_config
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.routers import pose, stream, health, auth
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
|
||||
@@ -263,6 +263,12 @@ app.include_router(
|
||||
tags=["Streaming"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
auth.router,
|
||||
prefix=f"{settings.api_prefix}",
|
||||
tags=["Authentication"]
|
||||
)
|
||||
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
|
||||
|
||||
@@ -189,7 +189,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
self.settings.secret_key,
|
||||
algorithms=[self.settings.jwt_algorithm]
|
||||
)
|
||||
|
||||
|
||||
# Check token blacklist (logout invalidation)
|
||||
if token_blacklist.is_blacklisted(token):
|
||||
raise ValueError("Token has been revoked")
|
||||
|
||||
# Extract user information
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
|
||||
@@ -2,6 +2,6 @@
|
||||
API routers package
|
||||
"""
|
||||
|
||||
from . import pose, stream, health
|
||||
from . import pose, stream, health, auth
|
||||
|
||||
__all__ = ["pose", "stream", "health"]
|
||||
__all__ = ["pose", "stream", "health", "auth"]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Authentication router for WiFi-DensePose API.
|
||||
Provides logout (token blacklisting) endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, status
|
||||
|
||||
from src.api.middleware.auth import token_blacklist
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(request: Request):
|
||||
"""Logout by blacklisting the current Bearer token."""
|
||||
auth_header = request.headers.get("authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing or invalid Authorization header",
|
||||
)
|
||||
|
||||
token = auth_header.split(" ", 1)[1]
|
||||
token_blacklist.add_token(token)
|
||||
logger.info("Token blacklisted via /auth/logout")
|
||||
|
||||
return {"success": True, "message": "Token revoked"}
|
||||
@@ -1,6 +1,7 @@
|
||||
"""CSI data processor for WiFi-DensePose system using TDD approach."""
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
@@ -293,7 +294,8 @@ class CSIProcessor:
|
||||
if count >= len(self.csi_history):
|
||||
return list(self.csi_history)
|
||||
else:
|
||||
return list(self.csi_history)[-count:]
|
||||
start = len(self.csi_history) - count
|
||||
return list(itertools.islice(self.csi_history, start, len(self.csi_history)))
|
||||
|
||||
def get_processing_statistics(self) -> Dict[str, Any]:
|
||||
"""Get processing statistics.
|
||||
@@ -410,8 +412,9 @@ class CSIProcessor:
|
||||
# Use cached mean-phase values (pre-computed in add_to_history)
|
||||
# Only take the last doppler_window frames for bounded cost
|
||||
window = min(len(self._phase_cache), self._doppler_window)
|
||||
cache_list = list(self._phase_cache)
|
||||
phase_matrix = np.array(cache_list[-window:])
|
||||
start = len(self._phase_cache) - window
|
||||
cache_list = list(itertools.islice(self._phase_cache, start, len(self._phase_cache)))
|
||||
phase_matrix = np.array(cache_list)
|
||||
|
||||
# Temporal phase differences between consecutive frames
|
||||
phase_diffs = np.diff(phase_matrix, axis=0)
|
||||
|
||||
@@ -56,6 +56,10 @@ class TokenManager:
|
||||
"""Verify and decode JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
# Check token blacklist (logout invalidation)
|
||||
from src.api.middleware.auth import token_blacklist
|
||||
if token_blacklist.is_blacklisted(token):
|
||||
raise AuthenticationError("Token has been revoked")
|
||||
return payload
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Frame budget benchmark for CSI processing pipeline.
|
||||
|
||||
Verifies that per-frame CSI processing stays within the 50 ms budget
|
||||
required for real-time sensing at 20 FPS.
|
||||
"""
|
||||
|
||||
import time
|
||||
import statistics
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
|
||||
|
||||
def _make_config():
|
||||
return {
|
||||
"sampling_rate": 1000,
|
||||
"window_size": 256,
|
||||
"overlap": 0.5,
|
||||
"noise_threshold": -60,
|
||||
"human_detection_threshold": 0.8,
|
||||
"smoothing_factor": 0.9,
|
||||
"max_history_size": 500,
|
||||
"num_subcarriers": 256,
|
||||
"num_antennas": 3,
|
||||
"doppler_window": 64,
|
||||
}
|
||||
|
||||
|
||||
def _make_csi_data(n_subcarriers=256, n_antennas=3, seed=None):
|
||||
"""Generate a synthetic CSI frame with complex-valued subcarriers."""
|
||||
rng = np.random.default_rng(seed)
|
||||
from unittest.mock import MagicMock
|
||||
csi = MagicMock()
|
||||
csi.amplitude = rng.random((n_antennas, n_subcarriers)).astype(np.float64) * 20.0
|
||||
csi.phase = (rng.random((n_antennas, n_subcarriers)).astype(np.float64) - 0.5) * np.pi * 2
|
||||
csi.frequency = 5.0e9
|
||||
csi.bandwidth = 80e6
|
||||
csi.num_subcarriers = n_subcarriers
|
||||
csi.num_antennas = n_antennas
|
||||
csi.snr = 25.0
|
||||
csi.timestamp = time.time()
|
||||
csi.metadata = {}
|
||||
return csi
|
||||
|
||||
|
||||
class TestSingleFrameBudget:
|
||||
"""Single-frame processing must complete in < 50 ms."""
|
||||
|
||||
def test_single_frame_under_50ms(self):
|
||||
proc = CSIProcessor(config=_make_config())
|
||||
frame = _make_csi_data(seed=42)
|
||||
|
||||
# Warm up
|
||||
proc.preprocess_csi_data(frame)
|
||||
|
||||
start = time.perf_counter()
|
||||
proc.preprocess_csi_data(frame)
|
||||
features = proc.extract_features(frame)
|
||||
if features:
|
||||
proc.detect_human_presence(features)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
assert elapsed_ms < 50, f"Single frame took {elapsed_ms:.1f} ms (budget: 50 ms)"
|
||||
|
||||
|
||||
class TestSustainedFrameBudget:
|
||||
"""Sustained 100-frame processing p95 must be < 50 ms per frame."""
|
||||
|
||||
def test_sustained_100_frames_p95(self):
|
||||
proc = CSIProcessor(config=_make_config())
|
||||
rng = np.random.default_rng(123)
|
||||
n_frames = 100
|
||||
latencies = []
|
||||
|
||||
for i in range(n_frames):
|
||||
frame = _make_csi_data(seed=i)
|
||||
start = time.perf_counter()
|
||||
preprocessed = proc.preprocess_csi_data(frame)
|
||||
features = proc.extract_features(preprocessed)
|
||||
if features:
|
||||
proc.detect_human_presence(features)
|
||||
proc.add_to_history(frame)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
latencies.append(elapsed_ms)
|
||||
|
||||
p50 = statistics.median(latencies)
|
||||
p95 = sorted(latencies)[int(0.95 * len(latencies))]
|
||||
p99 = sorted(latencies)[int(0.99 * len(latencies))]
|
||||
|
||||
print(f"\n--- Sustained {n_frames}-frame benchmark ---")
|
||||
print(f" p50: {p50:.2f} ms")
|
||||
print(f" p95: {p95:.2f} ms")
|
||||
print(f" p99: {p99:.2f} ms")
|
||||
print(f" min: {min(latencies):.2f} ms")
|
||||
print(f" max: {max(latencies):.2f} ms")
|
||||
|
||||
assert p95 < 50, f"p95 latency {p95:.1f} ms exceeds 50 ms budget"
|
||||
|
||||
|
||||
class TestPipelineWithDoppler:
|
||||
"""Full pipeline including Doppler estimation must stay within budget."""
|
||||
|
||||
def test_doppler_pipeline(self):
|
||||
proc = CSIProcessor(config=_make_config())
|
||||
n_frames = 100
|
||||
latencies = []
|
||||
|
||||
# Fill history first
|
||||
for i in range(20):
|
||||
frame = _make_csi_data(seed=i + 1000)
|
||||
proc.add_to_history(frame)
|
||||
|
||||
for i in range(n_frames):
|
||||
frame = _make_csi_data(seed=i + 2000)
|
||||
start = time.perf_counter()
|
||||
preprocessed = proc.preprocess_csi_data(frame)
|
||||
features = proc.extract_features(preprocessed)
|
||||
if features:
|
||||
proc.detect_human_presence(features)
|
||||
proc.add_to_history(frame)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
latencies.append(elapsed_ms)
|
||||
|
||||
p50 = statistics.median(latencies)
|
||||
p95 = sorted(latencies)[int(0.95 * len(latencies))]
|
||||
p99 = sorted(latencies)[int(0.99 * len(latencies))]
|
||||
|
||||
print(f"\n--- Doppler pipeline benchmark ({n_frames} frames, 20 warmup) ---")
|
||||
print(f" p50: {p50:.2f} ms")
|
||||
print(f" p95: {p95:.2f} ms")
|
||||
print(f" p99: {p99:.2f} ms")
|
||||
|
||||
# Doppler adds overhead but should still be within budget
|
||||
assert p95 < 50, f"Doppler pipeline p95 {p95:.1f} ms exceeds 50 ms budget"
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Shared fixtures for unit tests."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
# Set SECRET_KEY before any settings import
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-unit-tests-only")
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests-only")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
"""Create a mock Settings object."""
|
||||
settings = MagicMock()
|
||||
settings.secret_key = "test-secret-key-for-unit-tests-only"
|
||||
settings.jwt_algorithm = "HS256"
|
||||
settings.jwt_expire_hours = 24
|
||||
settings.app_name = "test-app"
|
||||
settings.version = "0.1.0"
|
||||
settings.is_production = False
|
||||
settings.enable_rate_limiting = False
|
||||
settings.enable_authentication = False
|
||||
settings.rate_limit_requests = 100
|
||||
settings.rate_limit_window = 60
|
||||
settings.rate_limit_authenticated_requests = 1000
|
||||
settings.allowed_hosts = ["*"]
|
||||
settings.csi_buffer_size = 100
|
||||
settings.stream_buffer_size = 100
|
||||
settings.mock_hardware = True
|
||||
settings.mock_pose_data = True
|
||||
settings.enable_real_time_processing = False
|
||||
settings.trusted_proxies = ["127.0.0.1"]
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_domain_config():
|
||||
"""Create a mock DomainConfig object."""
|
||||
config = MagicMock()
|
||||
config.pose_estimation = MagicMock()
|
||||
config.streaming = MagicMock()
|
||||
config.hardware = MagicMock()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
"""Provide a mock Redis client."""
|
||||
with patch("redis.Redis") as mock:
|
||||
client = MagicMock()
|
||||
client.ping.return_value = True
|
||||
client.get.return_value = None
|
||||
client.set.return_value = True
|
||||
mock.return_value = client
|
||||
yield client
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Tests for AuthMiddleware and TokenManager."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class TestTokenManager:
|
||||
def test_create_token(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager
|
||||
tm = TokenManager(mock_settings)
|
||||
token = tm.create_access_token({"sub": "user1"})
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
|
||||
def test_verify_valid_token(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager
|
||||
tm = TokenManager(mock_settings)
|
||||
token = tm.create_access_token({"sub": "user1", "role": "admin"})
|
||||
payload = tm.verify_token(token)
|
||||
assert payload["sub"] == "user1"
|
||||
assert payload["role"] == "admin"
|
||||
|
||||
def test_verify_invalid_token(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager, AuthenticationError
|
||||
tm = TokenManager(mock_settings)
|
||||
with pytest.raises(AuthenticationError):
|
||||
tm.verify_token("invalid.token.here")
|
||||
|
||||
def test_decode_claims(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager
|
||||
tm = TokenManager(mock_settings)
|
||||
token = tm.create_access_token({"sub": "user1"})
|
||||
claims = tm.decode_token_claims(token)
|
||||
assert claims is not None
|
||||
assert claims["sub"] == "user1"
|
||||
|
||||
def test_decode_claims_invalid(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager
|
||||
tm = TokenManager(mock_settings)
|
||||
claims = tm.decode_token_claims("bad-token")
|
||||
assert claims is None
|
||||
|
||||
def test_token_has_expiry(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager
|
||||
tm = TokenManager(mock_settings)
|
||||
token = tm.create_access_token({"sub": "user1"})
|
||||
payload = tm.verify_token(token)
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
|
||||
|
||||
class TestUserManager:
|
||||
def test_create_user(self):
|
||||
from src.middleware.auth import UserManager
|
||||
um = UserManager()
|
||||
assert um.get_user("nonexistent") is None
|
||||
|
||||
def test_hash_password(self):
|
||||
from src.middleware.auth import UserManager
|
||||
hashed = UserManager.hash_password("secret123")
|
||||
assert hashed != "secret123"
|
||||
assert len(hashed) > 20
|
||||
|
||||
def test_verify_password(self):
|
||||
from src.middleware.auth import UserManager
|
||||
hashed = UserManager.hash_password("secret123")
|
||||
assert UserManager.verify_password("secret123", hashed) is True
|
||||
assert UserManager.verify_password("wrong", hashed) is False
|
||||
|
||||
|
||||
class TestTokenBlacklist:
|
||||
def test_add_and_check(self):
|
||||
from src.api.middleware.auth import TokenBlacklist
|
||||
bl = TokenBlacklist()
|
||||
bl.add_token("tok123")
|
||||
assert bl.is_blacklisted("tok123") is True
|
||||
assert bl.is_blacklisted("tok456") is False
|
||||
|
||||
def test_blacklisted_token_rejected(self, mock_settings):
|
||||
from src.middleware.auth import TokenManager, AuthenticationError
|
||||
from src.api.middleware.auth import token_blacklist
|
||||
|
||||
tm = TokenManager(mock_settings)
|
||||
token = tm.create_access_token({"sub": "user1"})
|
||||
# Token should be valid
|
||||
tm.verify_token(token)
|
||||
# Blacklist it
|
||||
token_blacklist.add_token(token)
|
||||
with pytest.raises(AuthenticationError, match="revoked"):
|
||||
tm.verify_token(token)
|
||||
# Cleanup
|
||||
token_blacklist._blacklisted_tokens.discard(token)
|
||||
|
||||
|
||||
class TestAuthMiddleware:
|
||||
def test_public_paths(self, mock_settings):
|
||||
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
app = MagicMock()
|
||||
mw = AuthMiddleware(app)
|
||||
assert mw._is_public_path("/health") is True
|
||||
assert mw._is_public_path("/docs") is True
|
||||
assert mw._is_public_path("/api/v1/pose/analyze") is False
|
||||
|
||||
def test_protected_paths(self, mock_settings):
|
||||
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
app = MagicMock()
|
||||
mw = AuthMiddleware(app)
|
||||
assert mw._is_protected_path("/api/v1/pose/analyze") is True
|
||||
assert mw._is_protected_path("/health") is False
|
||||
|
||||
def test_extract_token_from_header(self, mock_settings):
|
||||
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
app = MagicMock()
|
||||
mw = AuthMiddleware(app)
|
||||
request = MagicMock()
|
||||
request.headers = {"authorization": "Bearer mytoken123"}
|
||||
request.query_params = {}
|
||||
request.cookies = {}
|
||||
token = mw._extract_token(request)
|
||||
assert token == "mytoken123"
|
||||
|
||||
def test_extract_token_missing(self, mock_settings):
|
||||
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
app = MagicMock()
|
||||
mw = AuthMiddleware(app)
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
request.query_params = {}
|
||||
request.cookies = {}
|
||||
token = mw._extract_token(request)
|
||||
assert token is None
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for error handling in the API layer."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestExceptionHandlers:
|
||||
"""Test the exception handlers registered on the FastAPI app."""
|
||||
|
||||
def _get_app(self):
|
||||
"""Import app lazily to avoid side effects."""
|
||||
with patch("src.api.main.get_settings") as mock_gs, \
|
||||
patch("src.api.main.get_domain_config") as mock_gdc, \
|
||||
patch("src.api.main.get_pose_service") as mock_ps, \
|
||||
patch("src.api.main.get_stream_service") as mock_ss, \
|
||||
patch("src.api.main.get_hardware_service") as mock_hs, \
|
||||
patch("src.api.main.connection_manager") as mock_cm, \
|
||||
patch("src.api.main.PoseStreamHandler") as mock_psh:
|
||||
mock_gs.return_value = MagicMock(
|
||||
app_name="test", version="0.1", environment="test",
|
||||
is_production=False, enable_rate_limiting=False,
|
||||
enable_authentication=False, docs_url="/docs",
|
||||
redoc_url="/redoc", openapi_url="/openapi.json",
|
||||
api_prefix="/api/v1",
|
||||
)
|
||||
mock_gs.return_value.get_logging_config.return_value = {
|
||||
"version": 1, "disable_existing_loggers": False,
|
||||
"handlers": {}, "loggers": {},
|
||||
}
|
||||
mock_gs.return_value.get_cors_config.return_value = {
|
||||
"allow_origins": ["*"], "allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
}
|
||||
# Re-import to pick up patches
|
||||
import importlib
|
||||
import src.api.main as m
|
||||
importlib.reload(m)
|
||||
return m.app
|
||||
|
||||
|
||||
class TestErrorResponseModel:
|
||||
def test_error_json_structure(self):
|
||||
"""Verify error JSON has code, message, type fields."""
|
||||
error = {
|
||||
"error": {
|
||||
"code": 404,
|
||||
"message": "Not found",
|
||||
"type": "http_error"
|
||||
}
|
||||
}
|
||||
assert error["error"]["code"] == 404
|
||||
assert "message" in error["error"]
|
||||
assert "type" in error["error"]
|
||||
|
||||
def test_validation_error_structure(self):
|
||||
error = {
|
||||
"error": {
|
||||
"code": 422,
|
||||
"message": "Validation error",
|
||||
"type": "validation_error",
|
||||
"details": []
|
||||
}
|
||||
}
|
||||
assert error["error"]["type"] == "validation_error"
|
||||
assert isinstance(error["error"]["details"], list)
|
||||
|
||||
def test_internal_error_masks_details(self):
|
||||
"""In production, internal errors should not leak stack traces."""
|
||||
error = {
|
||||
"error": {
|
||||
"code": 500,
|
||||
"message": "Internal server error",
|
||||
"type": "internal_error"
|
||||
}
|
||||
}
|
||||
assert "traceback" not in str(error)
|
||||
assert error["error"]["message"] == "Internal server error"
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Tests for HardwareService."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
class TestHardwareServiceInit:
|
||||
def test_init(self, mock_settings, mock_domain_config):
|
||||
mock_settings.mock_hardware = True
|
||||
with patch("src.services.hardware_service.RouterInterface"):
|
||||
from src.services.hardware_service import HardwareService
|
||||
svc = HardwareService(mock_settings, mock_domain_config)
|
||||
assert svc.is_running is False
|
||||
assert svc.stats["total_samples"] == 0
|
||||
assert svc.stats["connected_routers"] == 0
|
||||
|
||||
def test_stats_defaults(self, mock_settings, mock_domain_config):
|
||||
mock_settings.mock_hardware = True
|
||||
with patch("src.services.hardware_service.RouterInterface"):
|
||||
from src.services.hardware_service import HardwareService
|
||||
svc = HardwareService(mock_settings, mock_domain_config)
|
||||
assert svc.stats["successful_samples"] == 0
|
||||
assert svc.stats["failed_samples"] == 0
|
||||
assert svc.stats["last_sample_time"] is None
|
||||
|
||||
|
||||
class TestHardwareServiceLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_start(self, mock_settings, mock_domain_config):
|
||||
mock_settings.mock_hardware = True
|
||||
with patch("src.services.hardware_service.RouterInterface"):
|
||||
from src.services.hardware_service import HardwareService
|
||||
svc = HardwareService(mock_settings, mock_domain_config)
|
||||
svc._initialize_routers = AsyncMock()
|
||||
svc._monitoring_loop = AsyncMock()
|
||||
await svc.start()
|
||||
assert svc.is_running is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_start_idempotent(self, mock_settings, mock_domain_config):
|
||||
mock_settings.mock_hardware = True
|
||||
with patch("src.services.hardware_service.RouterInterface"):
|
||||
from src.services.hardware_service import HardwareService
|
||||
svc = HardwareService(mock_settings, mock_domain_config)
|
||||
svc._initialize_routers = AsyncMock()
|
||||
svc._monitoring_loop = AsyncMock()
|
||||
await svc.start()
|
||||
await svc.start() # idempotent
|
||||
assert svc.is_running is True
|
||||
|
||||
|
||||
class TestHardwareServiceRouter:
|
||||
def test_no_routers_on_init(self, mock_settings, mock_domain_config):
|
||||
mock_settings.mock_hardware = True
|
||||
with patch("src.services.hardware_service.RouterInterface"):
|
||||
from src.services.hardware_service import HardwareService
|
||||
svc = HardwareService(mock_settings, mock_domain_config)
|
||||
assert len(svc.router_interfaces) == 0
|
||||
|
||||
def test_max_recent_samples(self, mock_settings, mock_domain_config):
|
||||
mock_settings.mock_hardware = True
|
||||
with patch("src.services.hardware_service.RouterInterface"):
|
||||
from src.services.hardware_service import HardwareService
|
||||
svc = HardwareService(mock_settings, mock_domain_config)
|
||||
assert svc.max_recent_samples == 1000
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Tests for HealthCheckService."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class TestHealthCheckServiceInit:
|
||||
def test_init(self, mock_settings):
|
||||
from src.services.health_check import HealthCheckService
|
||||
svc = HealthCheckService(mock_settings)
|
||||
assert svc._initialized is False
|
||||
assert svc._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, mock_settings):
|
||||
from src.services.health_check import HealthCheckService
|
||||
svc = HealthCheckService(mock_settings)
|
||||
await svc.initialize()
|
||||
assert svc._initialized is True
|
||||
assert "api" in svc._services
|
||||
assert "database" in svc._services
|
||||
assert "hardware" in svc._services
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_initialize(self, mock_settings):
|
||||
from src.services.health_check import HealthCheckService
|
||||
svc = HealthCheckService(mock_settings)
|
||||
await svc.initialize()
|
||||
await svc.initialize() # idempotent
|
||||
assert svc._initialized is True
|
||||
|
||||
|
||||
class TestHealthCheckAggregation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_services_registered(self, mock_settings):
|
||||
from src.services.health_check import HealthCheckService, HealthStatus
|
||||
svc = HealthCheckService(mock_settings)
|
||||
await svc.initialize()
|
||||
assert len(svc._services) == 6
|
||||
for name, sh in svc._services.items():
|
||||
assert sh.status == HealthStatus.UNKNOWN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_names(self, mock_settings):
|
||||
from src.services.health_check import HealthCheckService
|
||||
svc = HealthCheckService(mock_settings)
|
||||
await svc.initialize()
|
||||
expected = {"api", "database", "redis", "hardware", "pose", "stream"}
|
||||
assert set(svc._services.keys()) == expected
|
||||
|
||||
|
||||
class TestHealthStatus:
|
||||
def test_enum_values(self):
|
||||
from src.services.health_check import HealthStatus
|
||||
assert HealthStatus.HEALTHY.value == "healthy"
|
||||
assert HealthStatus.DEGRADED.value == "degraded"
|
||||
assert HealthStatus.UNHEALTHY.value == "unhealthy"
|
||||
assert HealthStatus.UNKNOWN.value == "unknown"
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
def test_health_check_dataclass(self):
|
||||
from src.services.health_check import HealthCheck, HealthStatus
|
||||
hc = HealthCheck(name="test", status=HealthStatus.HEALTHY, message="ok")
|
||||
assert hc.name == "test"
|
||||
assert hc.status == HealthStatus.HEALTHY
|
||||
assert hc.duration_ms == 0.0
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Tests for MetricsService."""
|
||||
|
||||
import pytest
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestMetricSeries:
|
||||
def test_add_point(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
ms.add_point(42.0)
|
||||
assert len(ms.points) == 1
|
||||
assert ms.points[0].value == 42.0
|
||||
|
||||
def test_get_latest(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
ms.add_point(1.0)
|
||||
ms.add_point(2.0)
|
||||
latest = ms.get_latest()
|
||||
assert latest is not None
|
||||
assert latest.value == 2.0
|
||||
|
||||
def test_get_latest_empty(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
assert ms.get_latest() is None
|
||||
|
||||
def test_get_average(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
for v in [10.0, 20.0, 30.0]:
|
||||
ms.add_point(v)
|
||||
avg = ms.get_average(timedelta(minutes=5))
|
||||
assert avg == pytest.approx(20.0)
|
||||
|
||||
def test_get_average_empty(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
assert ms.get_average(timedelta(minutes=5)) is None
|
||||
|
||||
def test_get_max(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
for v in [10.0, 50.0, 30.0]:
|
||||
ms.add_point(v)
|
||||
mx = ms.get_max(timedelta(minutes=5))
|
||||
assert mx == 50.0
|
||||
|
||||
def test_labels(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
ms.add_point(1.0, {"region": "us-east"})
|
||||
assert ms.points[0].labels["region"] == "us-east"
|
||||
|
||||
def test_maxlen(self):
|
||||
from src.services.metrics import MetricSeries
|
||||
ms = MetricSeries(name="test", description="desc", unit="ms")
|
||||
for i in range(1100):
|
||||
ms.add_point(float(i))
|
||||
assert len(ms.points) == 1000
|
||||
|
||||
|
||||
class TestMetricsService:
|
||||
def test_init(self, mock_settings):
|
||||
with patch("src.services.metrics.psutil"):
|
||||
from src.services.metrics import MetricsService
|
||||
svc = MetricsService(mock_settings)
|
||||
assert svc._metrics is not None
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Tests for PoseService."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestPoseServiceInit:
|
||||
def test_init_sets_defaults(self, mock_settings, mock_domain_config):
|
||||
with patch.dict("sys.modules", {
|
||||
"torch": MagicMock(),
|
||||
"src.models.densepose_head": MagicMock(),
|
||||
"src.models.modality_translation": MagicMock(),
|
||||
}):
|
||||
from src.services.pose_service import PoseService
|
||||
svc = PoseService(mock_settings, mock_domain_config)
|
||||
assert svc.is_initialized is False
|
||||
assert svc.is_running is False
|
||||
assert svc.stats["total_processed"] == 0
|
||||
|
||||
def test_stats_are_zero_on_init(self, mock_settings, mock_domain_config):
|
||||
with patch.dict("sys.modules", {
|
||||
"torch": MagicMock(),
|
||||
"src.models.densepose_head": MagicMock(),
|
||||
"src.models.modality_translation": MagicMock(),
|
||||
}):
|
||||
from src.services.pose_service import PoseService
|
||||
svc = PoseService(mock_settings, mock_domain_config)
|
||||
assert svc.stats["successful_detections"] == 0
|
||||
assert svc.stats["failed_detections"] == 0
|
||||
assert svc.stats["average_confidence"] == 0.0
|
||||
|
||||
|
||||
class TestPoseServiceLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_sets_flag(self, mock_settings, mock_domain_config):
|
||||
with patch.dict("sys.modules", {
|
||||
"torch": MagicMock(),
|
||||
"src.models.densepose_head": MagicMock(),
|
||||
"src.models.modality_translation": MagicMock(),
|
||||
}):
|
||||
from src.services.pose_service import PoseService
|
||||
svc = PoseService(mock_settings, mock_domain_config)
|
||||
await svc.initialize()
|
||||
assert svc.is_initialized is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_stop(self, mock_settings, mock_domain_config):
|
||||
with patch.dict("sys.modules", {
|
||||
"torch": MagicMock(),
|
||||
"src.models.densepose_head": MagicMock(),
|
||||
"src.models.modality_translation": MagicMock(),
|
||||
}):
|
||||
from src.services.pose_service import PoseService
|
||||
svc = PoseService(mock_settings, mock_domain_config)
|
||||
await svc.initialize()
|
||||
await svc.start()
|
||||
assert svc.is_running is True
|
||||
await svc.stop()
|
||||
assert svc.is_running is False
|
||||
|
||||
|
||||
class TestPoseServiceStats:
|
||||
def test_initial_classification(self, mock_settings, mock_domain_config):
|
||||
with patch.dict("sys.modules", {
|
||||
"torch": MagicMock(),
|
||||
"src.models.densepose_head": MagicMock(),
|
||||
"src.models.modality_translation": MagicMock(),
|
||||
}):
|
||||
from src.services.pose_service import PoseService
|
||||
svc = PoseService(mock_settings, mock_domain_config)
|
||||
assert svc.last_error is None
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
def test_init(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert "anonymous" in mw.rate_limits
|
||||
assert "authenticated" in mw.rate_limits
|
||||
assert "admin" in mw.rate_limits
|
||||
|
||||
def test_exempt_paths(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert "/health" in mw.exempt_paths
|
||||
assert "/metrics" in mw.exempt_paths
|
||||
|
||||
def test_is_exempt(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert mw._is_exempt_path("/health") is True
|
||||
assert mw._is_exempt_path("/api/v1/pose/current") is False
|
||||
|
||||
def test_path_specific_limits(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert "/api/v1/pose/current" in mw.path_limits
|
||||
assert mw.path_limits["/api/v1/pose/current"]["requests"] == 60
|
||||
|
||||
def test_trusted_proxies_not_blocked(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert not mw._is_client_blocked("new-client-id")
|
||||
|
||||
|
||||
class TestRateLimitConfig:
|
||||
def test_anonymous_limit(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert mw.rate_limits["anonymous"]["burst"] == 10
|
||||
|
||||
def test_admin_limit(self, mock_settings):
|
||||
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
app = MagicMock()
|
||||
mw = RateLimitMiddleware(app)
|
||||
assert mw.rate_limits["admin"]["requests"] == 10000
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Tests for StreamService."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
class TestStreamServiceLifecycle:
|
||||
def test_init(self, mock_settings, mock_domain_config):
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
assert svc.is_running is False
|
||||
assert len(svc.connections) == 0
|
||||
assert svc.stats["active_connections"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, mock_settings, mock_domain_config):
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
await svc.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start(self, mock_settings, mock_domain_config):
|
||||
mock_settings.enable_real_time_processing = False
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
await svc.start()
|
||||
assert svc.is_running is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop(self, mock_settings, mock_domain_config):
|
||||
mock_settings.enable_real_time_processing = False
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
await svc.start()
|
||||
await svc.stop()
|
||||
assert svc.is_running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_start(self, mock_settings, mock_domain_config):
|
||||
mock_settings.enable_real_time_processing = False
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
await svc.start()
|
||||
await svc.start() # should be idempotent
|
||||
assert svc.is_running is True
|
||||
|
||||
|
||||
class TestStreamServiceConnections:
|
||||
def test_no_connections_on_init(self, mock_settings, mock_domain_config):
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
assert svc.stats["total_connections"] == 0
|
||||
assert svc.stats["messages_sent"] == 0
|
||||
|
||||
def test_buffer_sizes(self, mock_settings, mock_domain_config):
|
||||
mock_settings.stream_buffer_size = 50
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
assert svc.pose_buffer.maxlen == 50
|
||||
assert svc.csi_buffer.maxlen == 50
|
||||
|
||||
|
||||
class TestStreamServiceBroadcast:
|
||||
def test_stats_messages_failed_init_zero(self, mock_settings, mock_domain_config):
|
||||
from src.services.stream_service import StreamService
|
||||
svc = StreamService(mock_settings, mock_domain_config)
|
||||
assert svc.stats["messages_failed"] == 0
|
||||
assert svc.stats["data_points_streamed"] == 0
|
||||
Reference in New Issue
Block a user