mirror of
https://github.com/ruvnet/RuView
synced 2026-06-12 10:43:19 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d22616c488 | |||
| 17471e93ff |
@@ -121,12 +121,23 @@ jobs:
|
||||
with:
|
||||
workspaces: v2
|
||||
|
||||
# The 38-crate workspace debug build exhausts the runner's disk when built
|
||||
# with full debuginfo (observed: "final link failed: No space left on
|
||||
# device" once the engine/benchmark crates landed; the same tree's local
|
||||
# debug target measured 151 GB). Debuginfo is useless in CI — tests either
|
||||
# pass or print their failure — so build without it; target shrinks ~5-10x.
|
||||
- name: Run Rust tests
|
||||
working-directory: v2
|
||||
env:
|
||||
CARGO_PROFILE_DEV_DEBUG: "0"
|
||||
CARGO_PROFILE_TEST_DEBUG: "0"
|
||||
run: cargo test --workspace --no-default-features
|
||||
|
||||
- name: Run ADR-147 worldmodel tests
|
||||
working-directory: v2
|
||||
env:
|
||||
CARGO_PROFILE_DEV_DEBUG: "0"
|
||||
CARGO_PROFILE_TEST_DEBUG: "0"
|
||||
run: cargo test -p wifi-densepose-worldmodel --no-default-features
|
||||
|
||||
# ADR-134 CIR tests are behind the `cir` feature so the bench dependency
|
||||
|
||||
@@ -11,6 +11,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- **Mesh partition risk now demotes the privacy class and is witnessed (ADR-032).** The dynamic min-cut guard's `at_risk` signal was advisory-only (it fed the recalibration advisor). It now also contributes to the ADR-141 privacy demotion alongside fusion- and array-level contradictions: a mesh close to partitioning makes the fused belief less trustworthy, so the cycle emits at a more restricted class (monotonic — information only removed). Because `effective_class` feeds the BLAKE3 witness, a fragmenting array now shifts the witness — partition risk is auditable, not just logged. The mesh computation moved ahead of the demotion step in `process_cycle`; new `mesh_guard_mut()` exposes risk-threshold tuning. Test proves a forced-risk 3-node cycle demotes PrivateHome Anonymous→Restricted and shifts the witness vs a clean *same-topology* baseline (the only delta between the two cycles is the forced risk).
|
||||
|
||||
### Added
|
||||
- **ADR-152 WiFi-Pose SOTA 2026 intake — verified external benchmark + four Rust integrations.** A 22-source adversarially-verified survey of the 2025–2026 WiFi-sensing SOTA, with every adopted number reproduced or graded before integration:
|
||||
- **WiFlow-STD (DY2434) reproduction (`benchmarks/wiflow-std/`)** — the external "97.25% PCK@20, 2.23M params" claim audited end-to-end: the **shipped checkpoint is REFUTED** (0.08% PCK@20 — wrong keypoint normalization, predates the published code), the released code does not run as published (6 documented defects, incl. an import that fails and an unreachable test phase), and the released dataset's final 13 files are corrupted (9,072 windows of NaN + float32-max garbage that NaN-poisons fp16 BatchNorm training). After repairing both, retraining with upstream defaults on an RTX 5080 reproduced **96.09% PCK@20 (full test) / 96.61% (corruption-free)** — claims graded MEASURED-EQUIVALENT; params (2,225,042) and FLOPs (~0.055 G) verified exactly. Full forensics in `benchmarks/wiflow-std/RESULTS.md`.
|
||||
- **`GeometryEmbedding` (ADR-152 §2.1.2, `wifi-densepose-calibration`)** — 32-slot permutation-invariant, NaN-proof featurization of the §2.1.1 `NodeGeometry` records (centroid/spread, measured-first pairwise distances, circular azimuth stats, covariance-eigenvalue geometric diversity, per-node flags), schema-versioned for the ADR-151 P6 LoRA heads; derived `SpecialistBank::geometry_embedding()` accessor. The PerceptAlign "coordinate overfitting" defense, transplanted to per-room banks.
|
||||
- **MAE pretraining recipe (ADR-152 §2.3, `wifi-densepose-train/src/mae.rs`)** — `MaePretrainConfig` pinning the UNSW-measured recipe (80% masking, (30,3) patches) with pure-Rust patchify/random-mask (exact counts, seed-deterministic, error-not-truncate divisibility, NaN rejection), property-tested; the consumption seam for the future ADR-150 ViT-Small encoder.
|
||||
- **`WiFlowStdModel` Rust port (`wifi-densepose-train/src/wiflow_std/`)** — tch-gated idiomatic port of the verified spatio-temporal-decoupled architecture (grouped causal TCN → asymmetric conv stack → dual axial attention); ungated param formula asserted equal to the reference 2,225,042; 15/17-keypoint variants share weights (enables the ADR-152 §2.2(b) ESP32 fine-tune).
|
||||
- **RuVector vendor sync + §2.6 opportunity survey** — vendor at `a083bd77f`; graded ADOPT/EVALUATE/WATCH table; crates.io bumps applied (mincut/solver 2.0.6, attention 2.1.0, gnn 2.2.0; RUSTSEC #504 audit: no pinned crate affected); top WATCH: unpublished `ruvector-graph-condense` differentiable min-cut for trainable subcarrier grouping.
|
||||
- **ADR-153 IEEE 802.11bf-2025 forward-compatibility protocol model (`wifi-densepose-hardware/src/ieee80211bf/`)** — typed WLAN-sensing procedures (measurement setup/instance/report, SBP, termination) with `SpecProfile` version gates, `SensingCapabilities` negotiation, and **required** `ConsentMode` governance metadata on every setup; deterministic session FSM with rejection/timeout paths; `SensingTransport` seam with `SimTransport` and an `OpportunisticCsiBridge` mapping live ESP32 CSI batches into standardized report shape (a future chipset adapter replaces the bridge without touching RuvSense consumers). Not a certified implementation — simulation-tested protocol surface; OTA binding lands when silicon does. 19 acceptance tests.
|
||||
- **Dynamic min-cut mesh partition guard in the streaming engine (`mesh_guard`).** Maintains a `ruvector-mincut` exact min-cut over the live mesh coupling graph (nodes = sensing nodes, coupling = product of fusion attention weights), surfacing per cycle: the global **cut value** (how close the array is to splitting — a structural measure per-node heuristics miss), the **weak side** (which specific nodes would partition: failure/jamming triage feeding ADR-032 posture), and an **at-risk flag** that counts as a structural event for the drift→recalibration advisor. Surfaced as `TrustedOutput::mesh`. **Measured cost policy** (criterion, 12-node mesh): weights are quantized (1/64; a *nonzero* coupling below one quantum saturates to quantum 1 so quantization never erases a live coupling — without the floor, balanced meshes of ≥ 65 nodes had every ~1/n coupling erased and sat permanently "at risk") and updates change-gated, so the steady-state cycle does zero graph work (~7.3 µs, ~23× cheaper than building); on any real change a full exact rebuild (~171 µs) is used because one `DynamicMinCut` delete+insert measured ~240 µs — the incremental machinery's overhead targets much larger graphs, so rebuild-on-change is the measured optimum at mesh scale (one-edge case −28% after the policy switch). Degenerate cases fail toward risk: a node with zero coupling is reported as already partitioned (cut 0). 9 mesh-guard tests + an engine-level wiring test; full `process_cycle` with the guard: ~33 µs for 4 nodes (50 ms budget).
|
||||
- **Opt-in FFT operator for the CIR ISTA solver (8–14× measured).** Φ is a sub-DFT, so each ISTA mat-vec can run as one length-G FFT (O(G log G)) instead of a dense O(K·G) product. New `CirConfig::fft_operator` (default **false** — the dense path stays the bit-exact witness default; the FFT evaluates the same sums in a different order, so enabling it shifts float results and requires regenerating any pinned witness). `FftOperator` (rustfft, planned once at construction, scratch reused across the ISTA loop) dispatches inside `ista_solve`; warm-start/Lipschitz stay dense at construction. Measured (criterion, same run): ht20 2.22 ms → 265 µs (**8.4×**), ht40 10.26 ms → 717 µs (**14.3×**); the real HE40 grid (K=484, G=1452) scales further. 3 new tests: FFT↔dense matvec equivalence to float tolerance (ht20 + he40 grids), end-to-end dominant-tap agreement on a single-path frame, and all default configs keep FFT off. New `cir_estimate_fft` bench group.
|
||||
- **Per-room adapter provenance + drift→recalibration advisor in the streaming engine.** Closes the trust-chain gap where an ~11 KB per-room LoRA adapter (ADR-150 §3.4) could silently change inference without the witness noticing. `StreamingEngine::set_room_adapter(AdapterInfo)` pins the adapter's content-derived id into provenance `model_version` (`rfenc-v1+adapter:<id>`) — and therefore into the BLAKE3 witness — so swapping or clearing adapter weights always shifts the witness (engine test proves base → adapter → other-adapter → cleared all witness differently, and cleared == base). New `RecalibrationAdvisor` recommends re-running the ADR-135 baseline / refitting the adapter on sustained low fusion coherence (streak threshold, default 60 cycles ≈ 3 s at 20 Hz) or an ADR-142 change-point; surfaced as `TrustedOutput::recalibration_recommended` and recorded on the sensing-server's `EngineBridge` alongside the witness. Bridge plumbing: `EngineBridge::{set_room_adapter, clear_room_adapter}` + live-path test that the adapter id flows into the live witness. *Scope note: this is the deployable provenance/trigger half of the "retrained model" roadmap item — fitting the adapter itself runs in the existing external calibration service (`aether-arena/calibration/`), and a trained RF-encoder checkpoint still does not exist in-tree.*
|
||||
|
||||
@@ -10,9 +10,9 @@ Dual codebase: Python v1 (`v1/`) and Rust port (`v2/`).
|
||||
| `wifi-densepose-core` | Core types, traits, error types, CSI frame primitives |
|
||||
| `wifi-densepose-signal` | SOTA signal processing + RuvSense multistatic sensing (16 modules) |
|
||||
| `wifi-densepose-nn` | Neural network inference (ONNX, PyTorch, Candle backends) |
|
||||
| `wifi-densepose-train` | Training pipeline with ruvector integration + ruview_metrics |
|
||||
| `wifi-densepose-train` | Training pipeline with ruvector integration + ruview_metrics; MAE pretraining recipe (`mae.rs`, ADR-152 §2.3) + WiFlow-STD port (`wiflow_std/`, tch-gated) |
|
||||
| `wifi-densepose-mat` | Mass Casualty Assessment Tool — disaster survivor detection |
|
||||
| `wifi-densepose-hardware` | ESP32 aggregator, TDM protocol, channel hopping firmware |
|
||||
| `wifi-densepose-hardware` | ESP32 aggregator, TDM protocol, channel hopping firmware; `ieee80211bf/` 802.11bf forward-compat protocol model (ADR-153) |
|
||||
| `wifi-densepose-ruvector` | RuVector v2.0.4 integration + cross-viewpoint fusion (5 modules) |
|
||||
| `wifi-densepose-wasm` | WebAssembly bindings for browser deployment |
|
||||
| `wifi-densepose-cli` | CLI tool (`wifi-densepose` binary) — `calibrate`/`calibrate-serve`/`enroll`/`train-room`/`room-watch` + MAT (MAT gated behind the `mat` feature; build `--no-default-features` for the aarch64/appliance calibration binary) |
|
||||
@@ -73,6 +73,8 @@ All 5 ruvector crates integrated in workspace:
|
||||
- ADR-031: RuView sensing-first RF mode (Proposed)
|
||||
- ADR-032: Multistatic mesh security hardening (Proposed)
|
||||
- ADR-148: Drone swarm control system / `ruview-swarm` (In Progress)
|
||||
- ADR-152: WiFi-Pose SOTA 2026 intake — geometry conditioning, WiFlow-STD benchmark (measurement (a) complete: claims MEASURED-EQUIVALENT at ~96% PCK@20), MAE recipe (Proposed; §2.1–2.3, 2.6 implemented)
|
||||
- ADR-153: IEEE 802.11bf-2025 forward-compatibility protocol model (Accepted — amends ADR-152 §2.4)
|
||||
|
||||
### Supported Hardware
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# Upstream clone (WiFlow-STD, DY2434) -- never commit third-party code/weights
|
||||
upstream/
|
||||
|
||||
# Local python env
|
||||
.venv/
|
||||
|
||||
# Downloaded data / artifacts
|
||||
data/
|
||||
downloads/
|
||||
*.pth
|
||||
*.pt
|
||||
*.npy
|
||||
*.npz
|
||||
*.zip
|
||||
*.mat
|
||||
*.safetensors
|
||||
results/parity_fixture.json
|
||||
__pycache__/
|
||||
*.onnx
|
||||
|
||||
# Committed ground truth: corruption masks for the pristine Kaggle download.
|
||||
# remote/clean_v2.py zeroes the corrupted source windows IN PLACE, so these
|
||||
# masks CANNOT be regenerated from a cleaned copy (generate_corruption_masks.py
|
||||
# documents the criteria and reproduces them only from a fresh download).
|
||||
!results/nan_windows_mask.npy
|
||||
!results/big_windows_mask.npy
|
||||
@@ -0,0 +1,486 @@
|
||||
# WiFlow-STD (DY2434) Benchmark Results — ADR-152 §2.2
|
||||
|
||||
Upstream: <https://github.com/DY2434/WiFlow-WiFi-Pose-Estimation-with-Spatio-Temporal-Decoupling>
|
||||
pinned at `06899d29` (2026-04-05), Apache-2.0. Dataset: Kaggle `kaka2434/wiflow-dataset`
|
||||
(12.8 GB archive → 15.5 GB extracted; 360,000 windows of 540×20 CSI + 15-keypoint 2D labels).
|
||||
|
||||
Published claims (README "Setting 1"): PCK@20 97.25%, PCK@30 98.63%, PCK@40 99.16%,
|
||||
PCK@50 99.48%, MPJPE 0.007 m, 2.23M params, 0.07 GFLOPs.
|
||||
|
||||
## Measurement (a): their model on their data
|
||||
|
||||
### Artifact verification (MEASURED, 2026-06-10, this repo `eval_repro.py`)
|
||||
|
||||
| Check | Result |
|
||||
|---|---|
|
||||
| Parameter count | **2,225,042 (2.23M) — matches claim** |
|
||||
| FLOPs (torch profiler, batch 1) | ~0.055 GFLOPs — consistent with 0.07B claim |
|
||||
| CPU latency (Windows box, torch 2.12 CPU) | 13.2 ms/window @ batch 1 (76/s); 2.48 ms/sample @ batch 64 (403/s) |
|
||||
| Checkpoint load | `weights_only=True` (no pickle code execution) |
|
||||
|
||||
### Released checkpoint does NOT reproduce the claims — REFUTED as shipped
|
||||
|
||||
Running the released `best_pose_model.pth` through the released code on the released
|
||||
dataset with the released split procedure (seed-42 file-level 70/15/15; 54,000 test
|
||||
samples) yields:
|
||||
|
||||
| Metric | Published | Measured (shipped checkpoint) |
|
||||
|---|---|---|
|
||||
| PCK@20 | 97.25% | **0.08%** |
|
||||
| PCK@30 | 98.63% | 0.78% |
|
||||
| PCK@40 | 99.16% | 5.53% |
|
||||
| PCK@50 | 99.48% | 15.42% |
|
||||
| MPJPE | 0.007 | **NaN** (dataset contains NaN CSI windows) |
|
||||
|
||||
Raw output: `results/repro_a.json`.
|
||||
|
||||
Diagnostics (on 2,000 NaN-free windows from the first files of the dataset, i.e.
|
||||
mostly would-be *training* data — so this is not a split mismatch):
|
||||
|
||||
- Predictions correlate with targets (Pearson r ≈ 0.76) — the checkpoint is a trained
|
||||
model, but in a **different keypoint normalization/order** than the released data.
|
||||
- Best-case post-hoc global per-axis affine correction: PCK@20 ≈ 20%.
|
||||
- Best-case per-keypoint affine correction (15×2 fitted transforms — generous
|
||||
cheating): PCK@20 ≈ 72%, still far below 97.25%.
|
||||
- Pred↔target keypoint correspondence matrix is degenerate (multiple predicted
|
||||
keypoints best-match the same target joint) — keypoint convention mismatch.
|
||||
|
||||
### Reproducibility defects in the released artifacts
|
||||
|
||||
1. `models/__init__.py` imports `TemporalConvNet`, which `models/tcn.py` does not
|
||||
define — **the published code does not import/run as-is**.
|
||||
2. The released root checkpoint uses pre-rename module names (`att.*`, `final_conv.*`)
|
||||
vs the published code (`attention.*`, `decoder.*`) — same shapes/param count, but
|
||||
confirms the checkpoint predates the published code.
|
||||
3. The second shipped checkpoint (`cross_dataset_test/WiFlow/best_pose_model.pth`) is
|
||||
a **different architecture** (342-channel input = MM-Fi layout, 3 TCN layers,
|
||||
3-channel/3D decoder) — not usable on their own dataset.
|
||||
4. `run.py` ignores `--data_dir` and hardcodes `../preprocessed_csi_data`.
|
||||
5. The released dataset's final 13 files (indices 487–499; 9,072 windows, 2.52%)
|
||||
are corrupted: NaN values plus garbage amplitudes up to 3.4e38 (float32 max) in
|
||||
data that is otherwise [0,1]-normalized. Upstream code has no NaN/inf handling;
|
||||
training as published on this download diverges — the first corrupted batch
|
||||
overflows fp16 autocast and permanently poisons BatchNorm running statistics
|
||||
(GradScaler step-skipping does not protect BN). The authors' training curves
|
||||
show normal convergence, so their local data evidently differed from the
|
||||
Kaggle upload. Window masks: `results/nan_windows_mask.npy`,
|
||||
`results/big_windows_mask.npy`.
|
||||
|
||||
### Reproducing the corruption masks
|
||||
|
||||
The two mask files (9,070 NaN/Inf windows, 9,072 with |amplitude| > 1.5;
|
||||
union 9,072, all in dataset files 487–499) are **committed ground truth**
|
||||
(gitignore-negated, ~352 KB each). They can only be regenerated from a
|
||||
**pristine** Kaggle download: `remote/clean_v2.py` repairs the dataset by
|
||||
zeroing the corrupted windows in place, after which the corruption evidence
|
||||
is gone and a rescan returns all-False. `generate_corruption_masks.py`
|
||||
re-derives them (chunked scan, criteria: any non-finite value OR
|
||||
max |finite| > 1.5 per 540×20 window) and refuses to write all-False masks,
|
||||
which indicate a cleaned copy. Verified 2026-06-11: a regeneration from the
|
||||
local pristine download is bit-identical to the committed masks.
|
||||
|
||||
### Retraining result (MEASURED, 2026-06-10): claims APPROXIMATELY REPRODUCED
|
||||
|
||||
Since the shipped checkpoint is unusable, measurement (a) fell back to retraining
|
||||
with upstream code + defaults (seed 42, batch 64, early-stopped at epoch 41 of 50,
|
||||
best epoch 36, ~75 s/epoch) on ruvultra (RTX 5080). Deviations, all forced and
|
||||
documented: one-line fix for defect (1); torch 2.x+cu128 instead of pinned 2.3.1
|
||||
(Blackwell sm_120 unsupported); the 9,072 corrupted windows (defect 5) zeroed
|
||||
entirely — without this the published pipeline produces NaN from epoch 1 (observed).
|
||||
Scripts mirrored in `remote/`; raw metrics in `results/eval_retrained.json`.
|
||||
|
||||
| Metric | Published | Retrained (full test, 54,000) | Retrained (corruption-free, 52,560) |
|
||||
|---|---|---|---|
|
||||
| PCK@20 | 97.25% | **96.09%** | **96.61%** |
|
||||
| PCK@30 | 98.63% | 97.89% | 98.23% |
|
||||
| PCK@40 | 99.16% | 98.58% | 98.79% |
|
||||
| PCK@50 | 99.48% | 98.99% | 99.11% |
|
||||
| MPJPE | 0.007 | 0.0098 | 0.0094 |
|
||||
|
||||
Within ~0.6–1.2 PCK points of every published figure (single run, corrupted train
|
||||
windows zeroed, different torch/GPU). **Verdict: the accuracy claims are credible
|
||||
and approximately reproducible — but only after repairing the released dataset and
|
||||
code.** Val best: PCK@20 96.99%, MPJPE 0.0086 (epoch 36).
|
||||
|
||||
One more defect found during the run:
|
||||
|
||||
6. `train.py` calls `plot_training_history`, which is not defined anywhere — the
|
||||
built-in post-training test evaluation is unreachable as published (crashes
|
||||
with NameError after training completes).
|
||||
|
||||
## ADR-152 §2.2 citation rule
|
||||
|
||||
Evidence grade for the WiFlow-STD accuracy claims after measurement (a):
|
||||
**MEASURED-EQUIVALENT (96.1–96.6% PCK@20 reproduced by retraining; shipped
|
||||
checkpoint REFUTED; dataset/code require repairs)**. RuView docs may cite
|
||||
"~96% PCK@20 (our reproduction)" — still **not comparable** to our 17-keypoint
|
||||
ESP32 numbers (different hardware, 5 subjects, in-domain random split,
|
||||
15 keypoints).
|
||||
|
||||
## Edge optimization (measured)
|
||||
|
||||
ADR-152 "optimize beyond SOTA" track, 2026-06-10, this Windows box (Windows 11,
|
||||
16 torch threads, torch 2.12.0+cpu, onnxruntime 1.26.0). Subject: the retrained
|
||||
checkpoint `results/retrained_best_pose_model.pth` (2,225,042 fp32 params).
|
||||
Scripts: `quantize_bench.py`, `onnx_bench.py`, `eval_ort_accuracy.py`.
|
||||
Raw numbers: `results/edge_optimization.json`.
|
||||
|
||||
Accuracy is on a **10,000-window seed-42 random subset** of the corruption-free
|
||||
test split (same seed-42 file-level 70/15/15 split as `eval_repro.py`; 54,000
|
||||
test windows, 1,440 corrupted excluded via `results/nan_windows_mask.npy` |
|
||||
`results/big_windows_mask.npy`, leaving 52,560; subset drawn with
|
||||
`np.random.default_rng(42)`). The fp32 subset PCK@20 (96.68%) matches the full
|
||||
clean-test figure (96.61%), so the subset is representative.
|
||||
|
||||
Latency is CPU ms/window, median of repeated runs, 3 interleaved repetitions
|
||||
per variant (medians below; run-to-run spread on this box is large, roughly
|
||||
±20-40% at batch 1 — reps are in the JSON).
|
||||
|
||||
| Variant | Disk size | Batch 1 (ms/win) | Batch 64 (ms/win) | PCK@20 | PCK@50 | MPJPE |
|
||||
|---|---|---|---|---|---|---|
|
||||
| torch fp32 (baseline) | 9.07 MB | 11.0 | 2.27 | 96.68% | 99.15% | 0.00936 |
|
||||
| torch fp16 (`.half()`) | **4.58 MB** | 24.3 | 2.42 | 96.68% | 99.15% | 0.00946 |
|
||||
| torch int8 dynamic | 9.07 MB (unchanged) | 15.6 | 2.06 | 96.68% (identical) | 99.15% | 0.00936 |
|
||||
| ONNX fp32 (onnxruntime) | 8.97 MB | **3.2** | **2.0** | 96.68% | 99.15% | 0.00936 |
|
||||
| ONNX int8 (ORT dynamic, supplementary) | **2.44 MB** | 6.5 | 5.8 | 96.52% | 99.15% | 0.01108 |
|
||||
|
||||
Findings:
|
||||
|
||||
- **torch dynamic INT8 quantizes nothing on this model.** The architecture has
|
||||
**zero `nn.Linear` layers** — it is entirely Conv1d (21) + Conv2d (22) +
|
||||
BatchNorm. `torch.ao.quantization.quantize_dynamic` (requested over
|
||||
`{Linear, Conv1d, Conv2d}`) converted **0 modules / 0.0% of params**: dynamic
|
||||
quantization only has kernels for Linear/RNN-family modules and silently
|
||||
skips convolutions. The "int8" model is bit-identical to fp32 (same outputs,
|
||||
same 9.07 MB). Conv quantization would require static (PTQ) quantization
|
||||
with calibration — out of scope here; the ORT dynamic path below is the
|
||||
honest int8 datapoint.
|
||||
- **fp16 halves size for free accuracy-wise** (PCK@20 −0.005 pt, MPJPE
|
||||
+0.0001) but is *slower* on CPU at batch 1 (~2.2×) — torch CPU fp16 conv
|
||||
kernels are emulated. fp16 is a storage/transport format here, not a CPU
|
||||
runtime win.
|
||||
- **ONNX Runtime is the real batch-1 latency win: ~3.4× faster than torch**
|
||||
(3.2 vs 11.0 ms/window) at identical accuracy (parity 2.4e-7).
|
||||
|
||||
### Verdict on the paper's "~2.2 MB int8" claim
|
||||
|
||||
**Plausible but not free, and unreachable by the obvious PyTorch route.**
|
||||
2,225,042 params × 1 byte ≈ 2.2 MB assumes *every* parameter quantizes.
|
||||
PyTorch dynamic quantization — the one-liner most readers would reach for —
|
||||
yields **9.07 MB (0% quantized)** because the model has no Linear layers.
|
||||
ONNX Runtime dynamic quantization, which does have int8 conv weight support,
|
||||
gets **2.44 MB** (close to the claim; the overhead is BatchNorm params/buffers
|
||||
and quantization scales kept in fp32) at a measurable accuracy cost:
|
||||
PCK@20 96.68 → 96.52% (−0.16 pt) and MPJPE 0.00936 → 0.01108 (+18%), and
|
||||
~2× slower inference than ONNX fp32 (ConvInteger kernels). The paper does not
|
||||
state a method or an int8 accuracy; treat "2.2 MB" as a weight-arithmetic
|
||||
estimate, achievable in practice only via conv-capable quantization toolchains
|
||||
and with a small accuracy penalty.
|
||||
|
||||
### ONNX export status
|
||||
|
||||
**Works.** Exported via the TorchScript exporter (`dynamo=False`), opset 17,
|
||||
with a dynamic batch axis — `results/retrained_fp32_dynamic.onnx` (8.97 MB),
|
||||
verified to run at batch 1/2/64. The axial attention's
|
||||
`view(N*W, C, H)` reshape traced correctly (sizes recorded as graph ops, not
|
||||
baked constants). The dynamo exporter also captures the graph but crashed on
|
||||
this box writing a ✅ to a cp1252 console (cosmetic Windows encoding issue, not
|
||||
a model blocker). Parity vs torch on the stored fixture
|
||||
(`results/parity_fixture.npz`, batch 2, seed 42): **max abs diff 2.4e-7 —
|
||||
PASS** (< 1e-4). ORT-quantized int8 model: `results/retrained_int8_ort_dynamic.onnx`.
|
||||
|
||||
### Static PTQ (calibrated) — follow-up
|
||||
|
||||
Follow-up to the dynamic-int8 row above (2026-06-10, same box, onnxruntime
|
||||
1.26.0): ONNX Runtime **static** post-training quantization
|
||||
(`quantize_static`, QDQ format, per-channel int8 weights + int8 activations)
|
||||
of the same fp32 export, calibrated on **corruption-free TRAINING-split
|
||||
windows only** (seed-42 file-level split, same masks; 1,000 windows for
|
||||
MinMax, 512 for the histogram calibrators; never test windows). Scopes:
|
||||
"conv-only" (`op_types_to_quantize=["Conv"]` — the attention path exports as
|
||||
Einsum/Softmax, which ORT never quantizes anyway, so "all-ops" additionally
|
||||
quantizes the elementwise Mul/Sigmoid/Add/AveragePool glue). Accuracy on the
|
||||
identical 10k-window seed-42 corruption-free test subset; latency median of
|
||||
3 interleaved reps (fp32/dynamic re-benched in-session as references).
|
||||
Script: `static_ptq_bench.py`; raw: `results/edge_optimization.json`
|
||||
(`onnx_static_ptq`).
|
||||
|
||||
| Variant | Disk size | Batch 1 (ms/win) | Batch 64 (ms/win) | PCK@20 | PCK@50 | MPJPE |
|
||||
|---|---|---|---|---|---|---|
|
||||
| ONNX fp32 (reference) | 8.97 MB | 2.5 | 1.9 | 96.68% | 99.15% | 0.00936 |
|
||||
| ORT dynamic int8 (baseline) | **2.44 MB** | 5.7 | 4.6 | 96.52% | 99.15% | 0.01108 |
|
||||
| static QDQ **Percentile(99.99) conv-only** | 2.53 MB | 5.3 | 4.7 | 96.61% | 99.16% | **0.01031** |
|
||||
| static QDQ MinMax conv-only | 2.53 MB | 5.2 | 3.3 | **96.63%** | 99.19% | 0.01084 |
|
||||
| static QDQ Entropy conv-only | 2.53 MB | 5.2 | 3.1 | 96.60% | 99.19% | 0.01078 |
|
||||
| static QDQ MinMax all-ops | 2.60 MB | 6.5 | 3.9 | 95.45% | 99.14% | 0.01486 |
|
||||
| static QDQ Entropy all-ops | 2.60 MB | 5.7 | 4.1 | 95.30% | 99.13% | 0.01510 |
|
||||
| static QDQ Percentile all-ops | 2.60 MB | 5.3 | 4.3 | 96.39% | 99.17% | 0.01218 |
|
||||
|
||||
**Verdict: static PTQ (conv-only) is the new best int8 point on accuracy —
|
||||
but only modestly, and it does not fix int8's latency penalty.**
|
||||
|
||||
- **Accuracy: beats dynamic.** All three conv-only calibrations land at
|
||||
PCK@20 96.60–96.63% (vs dynamic 96.52%, fp32 96.68% — recovers ~⅔ of the
|
||||
dynamic gap) and MPJPE 0.0103–0.0108 (vs dynamic 0.01108). Best MPJPE:
|
||||
Percentile conv-only, +10% over fp32 instead of dynamic's +18%.
|
||||
- **Size: slightly worse.** 2.53 MB vs 2.44 MB (+3.6%) — QDQ nodes and
|
||||
per-channel scales cost a little; BatchNorm stays fp32 in both (the 12 BNs
|
||||
follow Slice/Einsum/Reshape, never Conv, so they cannot be folded).
|
||||
- **Latency: a wash vs dynamic, still ~2× slower than ONNX fp32 at batch 1.**
|
||||
Batch-1 medians 5.2–5.3 vs dynamic 5.7 ms/win in-session — within this
|
||||
box's ±20–40% noise. Batch 64 leans static (3.1–3.3 for MinMax/Entropy
|
||||
conv-only vs 4.6), same caveat.
|
||||
- **All-ops QDQ is strictly worse**: up to −1.4 pt PCK@20 and +60% MPJPE for
|
||||
zero size/latency benefit — int8 activations through the elementwise glue
|
||||
around the attention blocks is where the damage is. Conv-only is the right
|
||||
scope.
|
||||
- Negative result worth recording: **Entropy calibration is a no-op here** —
|
||||
on an identical calibration set it selects full-range thresholds
|
||||
bit-identical to MinMax (all 247 scales equal; verified on a 64-window
|
||||
smoke set). Also, ORT 1.26's `CalibMaxIntermediateOutputs` raises a
|
||||
spurious "No data is collected" when the batch count divides the chunk
|
||||
size (worked around in the script).
|
||||
|
||||
Deployment guidance: need speed → ONNX fp32 (3.2 ms b1). Need int8 weights
|
||||
for size → static QDQ conv-only (Percentile or MinMax,
|
||||
`results/retrained_int8_static_percentile_conv.onnx`), which strictly
|
||||
dominates dynamic int8 on accuracy at ~equal latency and +0.09 MB.
|
||||
|
||||
## Efficiency sweep (MEASURED, overnight 2026-06-10/11)
|
||||
|
||||
ADR-152 beyond-SOTA track: compact purpose-built variants of the WiFlow-STD
|
||||
architecture, trained from scratch on the same cleaned dataset, identical
|
||||
seed-42 file-level split, loss and protocol as the measurement-(a) reference
|
||||
(fp32, batch 64, ≤50 epochs, patience 5; RTX 5080, ~22–29 min/variant).
|
||||
Variant transforms are pure channel/group/stride scalings of an
|
||||
architecture-exact parameterized model (validated: reproduces 2,225,042 params
|
||||
at the reference config). Scripts: `remote/sweep/`; raw:
|
||||
`results/efficiency_sweep.jsonl`; checkpoints `results/{half,quarter,tiny}_best.pth`
|
||||
(gitignored).
|
||||
|
||||
| Variant | Params | vs 2.23M | Clean-test PCK@20 | PCK@50 | MPJPE | Best epoch |
|
||||
|---|---|---|---|---|---|---|
|
||||
| full (reference, meas. a) | 2,225,042 | 1× | 96.61% | 99.11% | 0.0094 | 36 |
|
||||
| **half** | **843,834** | **0.38×** | **96.62%** | **99.47%** | **0.00898** | 23 |
|
||||
| quarter | 338,600 | 0.15× | 96.05% | 99.43% | 0.00928 | 50 |
|
||||
| tiny | 56,290 | 0.025× | 94.11% | 99.36% | 0.0125 | 47 |
|
||||
|
||||
Findings:
|
||||
|
||||
- **The half model (843k params) strictly dominates the full reference** on
|
||||
this dataset — equal PCK@20, better PCK@50 and MPJPE, converges in fewer
|
||||
epochs. The published 2.23M architecture is over-parameterized for its own
|
||||
benchmark.
|
||||
- **tiny (56k params, 1/39.5) holds 94.11% PCK@20** — a ~220 KB fp32 /
|
||||
~60 KB int8-class model in reach of severely constrained edge targets,
|
||||
at −2.5 pt from the full reference.
|
||||
- Caveats: in-domain (5-subject random-file split) like every number on this
|
||||
dataset; single run per variant; corruption-free test subset (52,560).
|
||||
Cross-domain behavior of compact variants is untested — ADR-150's evidence
|
||||
says capacity *hurts* cross-subject, so the compact end may generalize no
|
||||
worse, but that is a hypothesis, not a measurement.
|
||||
|
||||
### Compact-variant edge artifacts (MEASURED, 2026-06-11)
|
||||
|
||||
Edge pipeline for the **tiny** checkpoint (56,290 params), same machinery and
|
||||
protocol as the full-model edge rows above (this Windows box, torch
|
||||
2.12.0+cpu, onnxruntime 1.26.0; dynamic-batch opset-17 TorchScript export;
|
||||
static QDQ **Percentile(99.99) conv-only** int8 calibrated on **512**
|
||||
corruption-free TRAIN-split windows; accuracy on the identical 10k-window
|
||||
seed-42 clean test subset; latency = median ms/window over 3 interleaved
|
||||
reps, with the full-model fp32/int8 sessions interleaved as same-session
|
||||
references). Script: `tiny_edge_bench.py`; raw:
|
||||
`results/edge_optimization.json` (`tiny_variant`). Torch-vs-ORT parity on the
|
||||
stored fixture input: **max abs diff 1.5e-7 — PASS** (< 1e-4). The tiny fp32
|
||||
subset PCK@20 (94.11%) matches the full clean-test sweep figure (94.11%)
|
||||
exactly, so the subset remains representative.
|
||||
|
||||
Two forced deviations, both recorded in the JSON:
|
||||
|
||||
1. **Adaptive-pool export rewrite.** tiny's derived stride schedule
|
||||
`[2,1,1,1]` leaves feature width 16, and the TorchScript exporter rejects
|
||||
`AdaptiveAvgPool2d((15,1))` when 15 is not a factor of the input height
|
||||
(the full model never hit this — its width was exactly 15). Since the
|
||||
pool over a fixed-size map is a fixed linear operator, the export wrapper
|
||||
replaces it with `mean(-1)` (W axis, a factor) + a constant averaging
|
||||
matmul using PyTorch's exact bin rule; the parity check (vs the original
|
||||
torch model with the real pool) proves exactness.
|
||||
2. **Calibration count 512, not "~500"**: ORT 1.26's histogram collector
|
||||
`np.asarray()`'s the per-batch maxima, so the calibration count must be a
|
||||
multiple of the 64-window calibration batch or the ragged last batch
|
||||
crashes it (the earlier static-PTQ run dodged this by using exactly 512).
|
||||
|
||||
| Variant | Disk size | Batch 1 (ms/win) | Batch 64 (ms/win) | PCK@20 | PCK@50 | MPJPE |
|
||||
|---|---|---|---|---|---|---|
|
||||
| full ONNX fp32 (same-session ref) | 8.97 MB | 2.27 | 1.42 | 96.68% | 99.15% | 0.00936 |
|
||||
| full static QDQ Percentile conv-only (same-session ref) | 2.53 MB | 5.53 | 3.82 | 96.61% | 99.16% | 0.01031 |
|
||||
| **tiny ONNX fp32** | **0.295 MB** | **0.66** | **0.24** | **94.11%** | 99.37% | 0.01253 |
|
||||
| tiny static QDQ Percentile conv-only | 0.248 MB | 0.85 | 1.03 | 92.68% | 99.33% | 0.01491 |
|
||||
|
||||
(tiny torch `.pth` checkpoint for reference: 0.34 MB on disk; 56,290 fp32
|
||||
params ≈ 225 KB of weights.)
|
||||
|
||||
Findings:
|
||||
|
||||
- **The smallest deployable WiFlow-class model is the tiny ONNX fp32
|
||||
artifact: ~295 KB on disk, 0.66 ms/window batch-1 CPU (~1,500 windows/s),
|
||||
94.1% PCK@20** — 30× smaller and ~3.4× faster (in-session) than the full
|
||||
ONNX fp32 model for −2.6 pt PCK@20.
|
||||
- **int8 is a bad trade at this scale.** Static QDQ conv-only — the recipe
|
||||
that cost the full model only 0.07 pt — costs tiny **−1.43 pt** PCK@20
|
||||
(94.11 → 92.68%) and +19% MPJPE, saves only 47 KB (−16%; QDQ scales and
|
||||
the fp32 BN/attention glue are proportionally larger in a small graph),
|
||||
and is *slower* than tiny fp32 (0.85 vs 0.66 ms b1; 1.03 vs 0.24 ms b64 —
|
||||
QDQ kernel overhead dominates when the convs are this small). A 56k-param
|
||||
model has little redundancy left to absorb weight+activation rounding.
|
||||
- Deployment guidance, compact edition: ship tiny as **ONNX fp32** — at
|
||||
295 KB the int8 size saving solves no real constraint and costs accuracy
|
||||
and speed. If ~250 KB vs ~295 KB ever matters, weight-only quantization
|
||||
would be the thing to try next, not QDQ.
|
||||
|
||||
## Measurement (b): BLOCKED-ON-DATA (attempted 2026-06-10)
|
||||
|
||||
The fine-tune-on-ESP32 measurement stopped at dataset characterization, per the
|
||||
pre-registered stop rule (<2,000 paired windows). Findings (MEASURED):
|
||||
|
||||
- **Only one trainable paired dataset exists**: `ruvultra:~/work/cog-pose-train/paired.jsonl`
|
||||
— 1,077 windows (one subject, one room, one 29.9-min session, single node;
|
||||
CSI [56, 20]; 17 COCO keypoints, MediaPipe confidence mean 0.44 — only 264
|
||||
windows pass ADR-079's own conf>0.5 training filter). Prior measured attempts
|
||||
on this exact set: 0–3% torso-PCK@20 (temporal splits, three independent
|
||||
pipelines). Fine-tuning a 2.23M-param model on ~860 train windows would
|
||||
measure memorization, not transfer.
|
||||
- **The April session behind the old "92.9% PCK@20" claim is lost** (345
|
||||
samples, 35 subcarriers; raw CSI gone from ruvzen/ruvultra/cognitum-v0; only
|
||||
a 69-sample predictions+GT holdout survives at `models/wiflow-real/eval-holdout.jsonl`).
|
||||
- **Forensic recheck of that holdout RETRACTS the 92.9% figure**: the trainer's
|
||||
`pck()` used an absolute 0.2 image-unit threshold (not torso-normalized) and
|
||||
the model output a **constant pose** (pred std 0.0000 across 69 near-static
|
||||
frames; a mean predictor scores 100% under the same protocol). The
|
||||
torso-normalized PCK@20 on the same holdout is 19.1%. This corroborates the
|
||||
2026-05-11 audit retraction (CHANGELOG, PR #535); stale doc citations were
|
||||
removed 2026-06-10 (user-guide, readme-details, ADR-152 §2.1.3). The §2.2
|
||||
no-citation rule now applies to ADR-079 accuracy claims.
|
||||
|
||||
Unblock criteria: a paired collection session of ≥2k windows (≈35+ min at the
|
||||
observed stride; multi-pose, conf>0.5, ideally with the §2.1.3 two-checkerboard
|
||||
calibration), plus a re-baselined our-pipeline number under torso-PCK@20 on the
|
||||
same split. WiFlow-STD assets stand ready on ruvultra (`~/wiflow-std-bench/`).
|
||||
Also worth investigating: ADR-079's protocol predicts ~9k windows per 30 min;
|
||||
the May session under-delivered ~8× (aligner drop rate?).
|
||||
|
||||
## Measurement (b) (MEASURED 2026-06-10/11)
|
||||
|
||||
The data baseline unblocked: the 2026-06-10 22:10–22:40 collection session produced
|
||||
**2,046 paired windows** (`ruvultra:~/wiflow-std-bench/paired-20260610.jsonl`; ONE
|
||||
subject, ONE room, ONE ESP32 node, varied poses: walk/raise/squat/kick/wave/turn/
|
||||
jump/sit; aligner `scripts/align-ground-truth.js`, non-overlapping 20-frame windows
|
||||
~0.42 s; 17 COCO keypoints in normalized [0,1] camera coords; MediaPipe confidence
|
||||
mean 0.802, min 0.692 — all windows pass the conf>0.5 filter). The −4 h timestamp
|
||||
bug and the empty-frame confidence-dilution aligner findings are recorded
|
||||
separately; results only here. Trained on ruvultra (RTX 5080, torch 2.11+cu128,
|
||||
fp32, batch 32, GPU shared with the efficiency sweep). Scripts mirrored in
|
||||
`remote/measb/`; raw metrics + full training curves in `results/measurement_b.json`.
|
||||
|
||||
### Two new aligner/dataset findings (forced deviations, MEASURED)
|
||||
|
||||
1. **`csi_shape` is heterogeneous, not [70, 20]**: 1,347× [70,20], 284× [134,20],
|
||||
243× [26,20], 130× [12,20], 42× [20,20]. The ESP32 stream emits mixed frame
|
||||
types and `extractCsiMatrix` stamps each window's subcarrier count from
|
||||
`window[0].subcarriers`, zero-padding/truncating the other frames — even
|
||||
native-70 windows contain ~20.4% internally zero-padded short frames
|
||||
(subcarriers 40–69 all-zero). Handling: the primary suite ("all 2,046")
|
||||
linearly resamples every frame's subcarrier axis to 70 bins (identity for
|
||||
native-70 frames) so the pre-registered n and split sizes hold; a secondary
|
||||
suite restricts to the 1,347 native [70,20] windows as a homogeneity check.
|
||||
2. **Aligner layout bug**: `extractCsiMatrix` fills `matrix[f * nSc + s]`
|
||||
(frame-major) but declares `shape: [nSc, nFrames]` — the stored shape label is
|
||||
transposed relative to the data. Confirmed by coherent per-frame zero-tails;
|
||||
corrected on load (`reshape(nFrames, nSc).T`).
|
||||
|
||||
### Protocol (pre-registered, followed)
|
||||
|
||||
Temporal split, no shuffling across time: first 70% train (1,432), next 15% val
|
||||
(307), last 15% test (307); seed 42 elsewhere. Model: learned 1×1 Conv1d 70→540
|
||||
adapter prepended to the upstream WiFlow-STD trunk; K=17 via the parameter-free
|
||||
adaptive pool (`AdaptiveAvgPool2d((17,1))` — pretrained weights load strict for
|
||||
any K). CSI normalized by the TRAIN-split p99 amplitude (129.7 all / 130.9
|
||||
native-70), clipped to [0,1]. Three runs, ≤60 epochs, early-stop patience 8 on
|
||||
val MPJPE, AdamW (adapter lr 1e-4; pretrained trunk lr 1e-5, 10× lower; scratch
|
||||
all 1e-4), fp32. Pretrained init = the measurement-(a) **retrained** checkpoint
|
||||
(`upstream/test/best_pose_model.pth`, ~96% PCK@20 on WiFlow data; the
|
||||
`att.`/`final_conv.` key remap from `eval_repro.py` applied defensively — a no-op,
|
||||
that checkpoint already uses post-rename keys). Frozen-trunk run: trunk
|
||||
`requires_grad=False` **and** held in `.eval()` so BatchNorm running stats cannot
|
||||
drift — a pure transfer probe; only the 70→540 adapter (38,340 params) trains.
|
||||
|
||||
PCK is torso-normalized with **torso = ‖l_shoulder(5) − l_hip(11)‖** (upstream
|
||||
`calculate_pck` math — per-frame norm clamped at 0.01, mean over keypoints ×
|
||||
frames — but upstream's `NECK_IDX/PELVIS_IDX = 2, 12` is a 15-keypoint
|
||||
convention; on 17-kp COCO those indices are right_eye/right_hip, so the indices
|
||||
were replaced, not the math). MPJPE is in normalized image units (not meters).
|
||||
|
||||
### Results — primary suite, all 2,046 windows (test = last 307)
|
||||
|
||||
| Run | PCK@10 | PCK@20 | PCK@30 | PCK@40 | PCK@50 | MPJPE | pred std | best ep |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| **mean-pose baseline** (honesty bar) | **73.1%** | **95.9%** | **98.7%** | 99.3% | 99.3% | **0.0148** | 0 (by constr.) | — |
|
||||
| (i) pretrained-init, full fine-tune | 26.0% | 65.0% | 88.0% | 96.4% | 98.9% | 0.0313 | 0.0113 | 58/60 |
|
||||
| (ii) scratch | 0.0% | 0.0% | 0.0% | 0.0% | 0.0% | 0.2554 | 0.0002 | 4 (stop @13) |
|
||||
| (iii) frozen-trunk (adapter only) | 0.0% | 0.0% | 0.2% | 3.2% | 14.4% | 0.1260 | 0.0073 | 59/60 |
|
||||
|
||||
Secondary suite (native [70,20] windows only, n=1,347, test=202) reproduces the
|
||||
same ordering: mean-baseline 96.0% / pretrained 67.1% / scratch 0.0% /
|
||||
frozen-trunk 0.0% PCK@20 (MPJPE 0.0153 / 0.0318 / 0.2236 / 0.1343) — the
|
||||
subcarrier-resampling choice does not change any conclusion.
|
||||
|
||||
### Interpretation
|
||||
|
||||
- **Did pretraining-transfer happen? Partially — as optimization transfer, not
|
||||
feature transfer, and not past the honesty bar.**
|
||||
- *Pretrained vs scratch*: dramatic (65.0% vs 0.0% PCK@20). The pretrained init
|
||||
is the only configuration that trains at all under the pre-registered budget.
|
||||
- *Frozen-trunk*: near-zero (0.0% PCK@20, 14.4% @50). WiFlow-STD's frozen
|
||||
features do **not** transfer to our ESP32 domain through a linear subcarrier
|
||||
adapter — the pretrained benefit is a well-conditioned initialization (incl.
|
||||
calibrated BN/output scales), not reusable CSI→pose features.
|
||||
- *Everything vs mean-pose baseline*: **no run beats it.** A constant
|
||||
train-mean pose scores 95.9% torso-PCK@20 / 0.0148 MPJPE on this test split,
|
||||
because a single subject in one camera frame barely moves in normalized
|
||||
coordinates. The fine-tuned model is a real, non-constant model
|
||||
(pred std 0.0113 > 0 — passes the constant-pose detector that retracted the
|
||||
old 92.9% figure) but its deviations from the mean hurt: it fits train-period
|
||||
temporal dynamics that do not generalize across the temporal split.
|
||||
- **Verdict for ADR-152 §2.2(b): fine-tuning WiFlow-STD on this dataset does not
|
||||
demonstrate CSI→pose signal beyond the mean pose.** Until a model beats the
|
||||
mean-pose baseline on a temporal split, no PCK number from this line may be
|
||||
cited as pose-estimation capability.
|
||||
|
||||
### Caveats (honest, pre-registered)
|
||||
|
||||
- Single subject, single room, single session (30 min), single ESP32 node —
|
||||
in-domain temporal split only; nothing here speaks to cross-room or
|
||||
cross-subject generalization.
|
||||
- 2k windows vs the 360k-window WiFlow-STD corpus — **NOT comparable** to the
|
||||
~96% in-domain measurement-(a) number, and the published 97.25% even less so.
|
||||
- The scratch run's total collapse (it cannot even reach the mean pose; its
|
||||
output BatchNorm/SiLU head must learn output scale from random init at lr 1e-4)
|
||||
is an optimization outcome under the fixed budget, not proof the architecture
|
||||
cannot learn from scratch — the pretrained-vs-scratch gap partially reflects
|
||||
this conditioning advantage.
|
||||
- Mixed-subcarrier frames (finding 1) mean even the "clean" windows carry ~20%
|
||||
zero-padded frames; collection-side frame-type filtering should precede the
|
||||
next session.
|
||||
- Mean-baseline PCK is inflated by low pose variance relative to torso size
|
||||
(~0.2–0.3 image units); PCK@10 (73.1%) shows the same ceiling effect at a
|
||||
stricter threshold — the bar is the bar, but a livelier dataset would lower it.
|
||||
|
||||
## Pending
|
||||
|
||||
- (b) fine-tune on our ESP32 17-keypoint eval set — **MEASURED 2026-06-10/11**,
|
||||
see above: no run beats the mean-pose baseline; pretraining transfers as
|
||||
optimization aid only.
|
||||
- (c) our internal WiFlow on their dataset (15-keypoint subset mapping) — also
|
||||
affected: there is currently no validated internal pose model to compare
|
||||
(the 92.9% artifact is retracted; the MM-Fi SOTA models in ADR-150 §3 are a
|
||||
different input domain).
|
||||
@@ -0,0 +1,200 @@
|
||||
"""Shared infrastructure for the LOCAL wiflow-std benchmark scripts (ADR-152).
|
||||
|
||||
This module is the single canonical implementation of the helpers that were
|
||||
previously copy-pasted across eval_repro.py / quantize_bench.py /
|
||||
onnx_bench.py / eval_ort_accuracy.py / export_to_safetensors.py:
|
||||
|
||||
- ``import_upstream()`` -- sys.path setup + the models-package stub that
|
||||
works around the upstream import bug, plus the >1GB np.load mmap patch
|
||||
- ``install_np_load_mmap_patch()`` -- the mmap patch on its own
|
||||
- ``remap_legacy_keys()`` / ``load_remapped_state()`` -- checkpoint
|
||||
key remap for the pre-rename released checkpoint
|
||||
- ``load_wiflow_model()`` -- WiFlowPoseModel from a checkpoint, eval mode
|
||||
- ``set_seed()`` -- mirrors upstream run.py seeding exactly
|
||||
- ``evaluate()`` -- THE canonical batch-weighted PCK/MPJPE evaluation loop
|
||||
(thresholds 0.1-0.5, upstream utils/metrics.py math); accepts either a
|
||||
torch nn.Module or an onnxruntime InferenceSession
|
||||
|
||||
The scripts under remote/ deploy to ruvultra as standalone single files and
|
||||
therefore intentionally inline private copies of these helpers; when editing
|
||||
them, treat this module as the reference implementation and keep the copies
|
||||
in sync.
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
UPSTREAM = os.path.join(HERE, "upstream")
|
||||
RESULTS = os.path.join(HERE, "results")
|
||||
|
||||
DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# >1GB np.load mmap patch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# csi_windows.npy is ~13 GB; mmap large arrays instead of loading into RAM
|
||||
# (loading it eagerly needs ~15 GB).
|
||||
_np_load = np.load
|
||||
|
||||
|
||||
def _np_load_mmap(path, *a, **kw):
|
||||
if (isinstance(path, str) and path.endswith(".npy")
|
||||
and os.path.getsize(path) > 1 << 30 and "mmap_mode" not in kw):
|
||||
kw["mmap_mode"] = "r"
|
||||
return _np_load(path, *a, **kw)
|
||||
|
||||
|
||||
def install_np_load_mmap_patch():
|
||||
"""Globally patch np.load so .npy files >1GB are mmap'd read-only.
|
||||
|
||||
Idempotent. Patching the numpy module attribute is equivalent to the
|
||||
historical ``upstream_dataset.np.load = _np_load_mmap`` (dataset.np IS
|
||||
the numpy module), but works regardless of import order.
|
||||
"""
|
||||
np.load = _np_load_mmap
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# upstream import shim
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def import_upstream(mmap_patch=True):
|
||||
"""Make the upstream WiFlow-STD clone importable; returns its path.
|
||||
|
||||
Upstream bug: models/__init__.py imports TemporalConvNet, which
|
||||
models/tcn.py does not define -- the package fails to import as
|
||||
published. Register a stub package so the broken __init__ never
|
||||
executes; submodules (models.pose_model etc.) still resolve via
|
||||
__path__. Idempotent.
|
||||
"""
|
||||
if UPSTREAM not in sys.path:
|
||||
sys.path.insert(0, UPSTREAM)
|
||||
if "models" not in sys.modules:
|
||||
_models_pkg = types.ModuleType("models")
|
||||
_models_pkg.__path__ = [os.path.join(UPSTREAM, "models")]
|
||||
sys.modules["models"] = _models_pkg
|
||||
if mmap_patch:
|
||||
install_np_load_mmap_patch()
|
||||
return UPSTREAM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# checkpoint loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The released checkpoint predates the published code: modules were renamed
|
||||
# att -> attention, final_conv -> decoder (param count identical, 2.23M).
|
||||
LEGACY_RENAMES = {"att.": "attention.", "final_conv.": "decoder."}
|
||||
|
||||
|
||||
def remap_legacy_keys(state):
|
||||
"""Remap pre-rename state_dict keys; no-op for already-new-style keys."""
|
||||
return {next((new + k[len(old):] for old, new in LEGACY_RENAMES.items()
|
||||
if k.startswith(old)), k): v
|
||||
for k, v in state.items()}
|
||||
|
||||
|
||||
def load_remapped_state(path, map_location="cpu"):
|
||||
"""torch.load (weights_only) + legacy key remap."""
|
||||
state = torch.load(path, map_location=map_location, weights_only=True)
|
||||
return remap_legacy_keys(state)
|
||||
|
||||
|
||||
def load_wiflow_model(checkpoint, map_location="cpu", dropout=0.5):
|
||||
"""Full-size WiFlowPoseModel from a checkpoint, strict load, eval mode."""
|
||||
import_upstream()
|
||||
from models.pose_model import WiFlowPoseModel
|
||||
model = WiFlowPoseModel(dropout=dropout)
|
||||
model.load_state_dict(load_remapped_state(checkpoint, map_location),
|
||||
strict=True)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# seeding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def set_seed(seed=42):
|
||||
# mirror upstream run.py exactly
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# THE canonical evaluation loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def evaluate(model, loader, device=None, dtype=None, label="",
|
||||
thresholds=DEFAULT_THRESHOLDS, progress_every=50):
|
||||
"""Batch-weighted PCK/MPJPE over a DataLoader (upstream metrics math).
|
||||
|
||||
``model`` may be a torch nn.Module (optionally evaluated on ``device``
|
||||
with inputs cast to ``dtype``) or an onnxruntime InferenceSession.
|
||||
Per-threshold PCK values are independent in upstream calculate_pck, so
|
||||
evaluating a superset of thresholds never changes any individual value.
|
||||
|
||||
Returns {"samples", "mpjpe", "pck@10".."pck@50", "wall_seconds"}.
|
||||
"""
|
||||
import_upstream()
|
||||
from utils.metrics import calculate_mpjpe, calculate_pck
|
||||
|
||||
is_ort = hasattr(model, "get_inputs") # onnxruntime InferenceSession
|
||||
if is_ort:
|
||||
inp = model.get_inputs()[0].name
|
||||
|
||||
def forward(bx):
|
||||
return torch.from_numpy(model.run(None, {inp: bx.numpy()})[0])
|
||||
else:
|
||||
model.eval()
|
||||
|
||||
def forward(bx):
|
||||
if device is not None:
|
||||
bx = bx.to(device)
|
||||
if dtype is not None:
|
||||
bx = bx.to(dtype)
|
||||
return model(bx).float()
|
||||
|
||||
thresholds = list(thresholds)
|
||||
totals = {t: 0.0 for t in thresholds}
|
||||
total_mpe, n = 0.0, 0
|
||||
t0 = time.time()
|
||||
with torch.no_grad():
|
||||
for batch_idx, (bx, by) in enumerate(loader):
|
||||
out = forward(bx)
|
||||
if device is not None and not is_ort:
|
||||
by = by.to(device)
|
||||
mpe = calculate_mpjpe(out, by)
|
||||
pck = calculate_pck(out, by, thresholds=thresholds)
|
||||
bs = by.size(0)
|
||||
total_mpe += mpe * bs
|
||||
for t in totals:
|
||||
totals[t] += pck[t] * bs
|
||||
n += bs
|
||||
if batch_idx % progress_every == 0:
|
||||
tag = f"[{label}] " if label else ""
|
||||
pck20 = totals.get(0.2)
|
||||
pck20_str = f"pck20={pck20 / n:.4f} " if pck20 is not None else ""
|
||||
print(f" {tag}batch {batch_idx}: n={n} {pck20_str}"
|
||||
f"mpjpe={total_mpe / n:.4f} ({time.time() - t0:.0f}s)",
|
||||
flush=True)
|
||||
return {
|
||||
"samples": n,
|
||||
"mpjpe": total_mpe / n,
|
||||
**{f"pck@{int(t * 100)}": totals[t] / n for t in thresholds},
|
||||
"wall_seconds": time.time() - t0,
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
"""ADR-152 edge optimization: accuracy of the ONNX fp32 and ORT-dynamic-int8
|
||||
models on the same corruption-free 10k test subset used by quantize_bench.py.
|
||||
|
||||
The torch dynamic-int8 path quantizes nothing (no nn.Linear in the model), so
|
||||
the only real int8 datapoint for the paper's "~2.2 MB int8" claim is the
|
||||
onnxruntime dynamically quantized model -- this script measures what that
|
||||
quantization costs in PCK/MPJPE.
|
||||
|
||||
Usage:
|
||||
.venv/Scripts/python.exe eval_ort_accuracy.py \
|
||||
--data-dir <preprocessed_csi_data> [--subset 10000]
|
||||
|
||||
Writes/merges into results/edge_optimization.json under key "onnx_accuracy".
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, HERE)
|
||||
|
||||
from _bench_common import RESULTS, evaluate # noqa: E402
|
||||
from quantize_bench import build_test_subset # noqa: E402 (sets up upstream imports)
|
||||
|
||||
|
||||
def evaluate_ort(sess, loader, label):
|
||||
"""ORT-session evaluation via the canonical _bench_common.evaluate loop."""
|
||||
return evaluate(sess, loader, label=label)
|
||||
|
||||
|
||||
def main():
|
||||
import onnxruntime as ort
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", default=os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
|
||||
"wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
|
||||
parser.add_argument("--subset", type=int, default=10000)
|
||||
parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
|
||||
args = parser.parse_args()
|
||||
|
||||
loader, _n_clean = build_test_subset(args.data_dir, args.subset)
|
||||
results = {}
|
||||
for label, fname in (("onnx_fp32", "retrained_fp32_dynamic.onnx"),
|
||||
("onnx_int8_ort_dynamic", "retrained_int8_ort_dynamic.onnx")):
|
||||
path = os.path.join(RESULTS, fname)
|
||||
if not os.path.exists(path):
|
||||
results[label] = {"error": f"{fname} not found; run onnx_bench.py first"}
|
||||
continue
|
||||
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
|
||||
print(f"=== accuracy: {label} ({fname}) ===")
|
||||
results[label] = evaluate_ort(sess, loader, label)
|
||||
print(json.dumps(results[label], indent=2))
|
||||
|
||||
merged = {}
|
||||
if os.path.exists(args.out):
|
||||
with open(args.out) as f:
|
||||
merged = json.load(f)
|
||||
merged["onnx_accuracy"] = results
|
||||
with open(args.out, "w") as f:
|
||||
json.dump(merged, f, indent=2)
|
||||
print(f"wrote {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,102 @@
|
||||
"""ADR-152 §2.2 measurement (a): reproduce WiFlow-STD (DY2434) published test metrics.
|
||||
|
||||
Runs the released pretrained checkpoint (upstream/best_pose_model.pth) against the
|
||||
released Kaggle dataset (kaka2434/wiflow-dataset) using the upstream code path:
|
||||
identical dataset class, identical file-level 70/15/15 split at seed 42, identical
|
||||
PCK/MPJPE implementations (utils/metrics.py).
|
||||
|
||||
Published claims (README, "Setting 1 random split"):
|
||||
PCK@20 97.25% | PCK@30 98.63% | PCK@40 99.16% | PCK@50 99.48% | MPJPE 0.007 m
|
||||
|
||||
Usage:
|
||||
.venv/Scripts/python.exe eval_repro.py --data-dir <dir containing csi_windows.npy>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from _bench_common import (UPSTREAM, evaluate, import_upstream,
|
||||
load_remapped_state, set_seed)
|
||||
|
||||
import_upstream() # sys.path + models stub + >1GB np.load mmap patch
|
||||
|
||||
from dataset import PreprocessedCSIKeypointsDataset, create_preprocessed_train_val_test_loaders # noqa: E402
|
||||
from models.pose_model import WiFlowPoseModel # noqa: E402
|
||||
|
||||
|
||||
def find_data_dir(root):
|
||||
for dirpath, _dirnames, filenames in os.walk(root):
|
||||
if "csi_windows.npy" in filenames:
|
||||
return dirpath
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", required=True,
|
||||
help="Directory containing csi_windows.npy (searched recursively)")
|
||||
parser.add_argument("--checkpoint", default=os.path.join(UPSTREAM, "best_pose_model.pth"))
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument("--out", default=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"results", "repro_a.json"))
|
||||
args = parser.parse_args()
|
||||
|
||||
data_dir = args.data_dir
|
||||
if not os.path.exists(os.path.join(data_dir, "csi_windows.npy")):
|
||||
located = find_data_dir(data_dir)
|
||||
if located is None:
|
||||
sys.exit(f"csi_windows.npy not found under {data_dir}")
|
||||
data_dir = located
|
||||
print(f"data dir: {data_dir}")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"device: {device}, torch {torch.__version__}")
|
||||
|
||||
set_seed(42)
|
||||
|
||||
dataset = PreprocessedCSIKeypointsDataset(
|
||||
data_dir=data_dir, keypoint_scale=1000.0, enable_temporal_clean=True)
|
||||
|
||||
# split must match upstream: file-level shuffle at random_seed=42, 70/15/15
|
||||
_train_loader, _val_loader, test_loader = create_preprocessed_train_val_test_loaders(
|
||||
dataset=dataset, batch_size=args.batch_size, num_workers=0, random_seed=42)
|
||||
|
||||
model = WiFlowPoseModel(dropout=0.5).to(device)
|
||||
# released checkpoint predates the published code: modules were renamed
|
||||
# att -> attention, final_conv -> decoder (param count identical, 2.23M)
|
||||
state = load_remapped_state(args.checkpoint, map_location=device)
|
||||
model.load_state_dict(state, strict=True)
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"checkpoint: {args.checkpoint} ({n_params/1e6:.2f}M params)")
|
||||
|
||||
# upstream also evaluates with drop_last=True; we report the full test set
|
||||
# (drop_last=False) and the drop_last variant for exact comparability
|
||||
results = {"published": {"pck@20": 0.9725, "pck@30": 0.9863, "pck@40": 0.9916,
|
||||
"pck@50": 0.9948, "mpjpe": 0.007},
|
||||
"params_millions": n_params / 1e6,
|
||||
"data_dir": data_dir,
|
||||
"device": str(device)}
|
||||
|
||||
print("=== test set (full, drop_last=False) ===")
|
||||
results["test_full"] = evaluate(model, test_loader, device=device)
|
||||
print(json.dumps(results["test_full"], indent=2))
|
||||
|
||||
test_loader_dl = DataLoader(test_loader.dataset, batch_size=args.batch_size,
|
||||
shuffle=False, drop_last=True)
|
||||
print("=== test set (drop_last=True, as upstream train.py) ===")
|
||||
results["test_drop_last"] = evaluate(model, test_loader_dl, device=device)
|
||||
print(json.dumps(results["test_drop_last"], indent=2))
|
||||
|
||||
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
||||
with open(args.out, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"wrote {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,174 @@
|
||||
"""ADR-152 §2.2: export the retrained WiFlow-STD PyTorch checkpoint to
|
||||
safetensors with tch-rs (VarStore) variable names, plus a numerical-parity
|
||||
fixture for the Rust port.
|
||||
|
||||
Outputs (all under results/, gitignored):
|
||||
retrained_wiflow_std.safetensors -- 248 f32 tensors named exactly as the
|
||||
Rust WiFlowStdModel VarStore expects
|
||||
(see wiflow_std/model.rs
|
||||
`dump_variable_names` for the
|
||||
authoritative name dump)
|
||||
parity_fixture.npz -- deterministic input (seed 42,
|
||||
shape (2, 540, 20), uniform [0,1]) and
|
||||
the Python model's eval-mode output
|
||||
parity_fixture.json -- same data as flattened f32 lists, for
|
||||
the dependency-free Rust test
|
||||
(tests/test_wiflow_std_parity.rs)
|
||||
|
||||
PyTorch -> tch key mapping (derived from the VarStore dump, not guessed):
|
||||
|
||||
tcn.network.{i}.conv1_group.weight -> tcn{i}.conv1_group.weight
|
||||
tcn.network.{i}.bn*_{group,pw}.<leaf> -> tcn{i}.bn*_{group,pw}.<leaf>
|
||||
tcn.network.{i}.downsample.0.weight -> tcn{i}.ds_conv.weight
|
||||
tcn.network.{i}.downsample.1.<leaf> -> tcn{i}.ds_bn.<leaf>
|
||||
up.block.{0,1,4,5,8,9}.<leaf> -> conv_in.{conv1,bn1,conv2,bn2,conv3,bn3}.<leaf>
|
||||
up.downsample.{0,1}.<leaf> -> conv_in.{ds_conv,ds_bn}.<leaf>
|
||||
residual_blocks.{i}.block.{...}.<leaf> -> conv{i}.{conv1..bn3}.<leaf>
|
||||
residual_blocks.{i}.downsample.{0,1} -> conv{i}.{ds_conv,ds_bn}
|
||||
attention.{width,height}_axis.qkv_transform.weight
|
||||
-> attention.{width,height}.qkv.weight
|
||||
attention.{width,height}_axis.bn_* -> attention.{width,height}.bn_*
|
||||
decoder.{0,1,3,4}.<leaf> -> {dec_conv1,dec_bn1,dec_conv2,dec_bn2}.<leaf>
|
||||
*.num_batches_tracked -> dropped (tch BatchNorm has no such buffer)
|
||||
|
||||
Legacy upstream names (att. -> attention., final_conv. -> decoder.) are
|
||||
remapped first, exactly as eval_repro.py does for the released checkpoint.
|
||||
|
||||
Usage:
|
||||
.venv/Scripts/python.exe export_to_safetensors.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from _bench_common import RESULTS, import_upstream, remap_legacy_keys
|
||||
|
||||
import_upstream() # sys.path + models stub
|
||||
|
||||
from models.pose_model import WiFlowPoseModel # noqa: E402
|
||||
|
||||
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")
|
||||
|
||||
# Sequential index -> tch sub-name inside one ConvBlock1/AsymmetricConvBlock:
|
||||
# [Conv2d(0), BN(1), SiLU(2), Dropout2d(3), Conv2d(4), BN(5), SiLU(6),
|
||||
# Dropout2d(7), Conv2d(8), BN(9)]
|
||||
_BLOCK_IDX = {"0": "conv1", "1": "bn1", "4": "conv2", "5": "bn2",
|
||||
"8": "conv3", "9": "bn3"}
|
||||
_DS_IDX = {"0": "ds_conv", "1": "ds_bn"}
|
||||
_DECODER_IDX = {"0": "dec_conv1", "1": "dec_bn1", "3": "dec_conv2",
|
||||
"4": "dec_bn2"}
|
||||
|
||||
|
||||
def _conv_block(new_prefix: str, rest: str) -> str:
|
||||
m = re.fullmatch(r"block\.(\d+)\.(.+)", rest)
|
||||
if m:
|
||||
return f"{new_prefix}.{_BLOCK_IDX[m.group(1)]}.{m.group(2)}"
|
||||
m = re.fullmatch(r"downsample\.(\d+)\.(.+)", rest)
|
||||
if m:
|
||||
return f"{new_prefix}.{_DS_IDX[m.group(1)]}.{m.group(2)}"
|
||||
raise KeyError(f"unmapped conv-block key: {new_prefix} / {rest}")
|
||||
|
||||
|
||||
def map_key(key: str) -> str:
|
||||
"""Map one PyTorch state_dict key to the tch VarStore name."""
|
||||
m = re.fullmatch(r"tcn\.network\.(\d+)\.(.+)", key)
|
||||
if m:
|
||||
i, rest = m.groups()
|
||||
rest = (rest.replace("downsample.0.", "ds_conv.")
|
||||
.replace("downsample.1.", "ds_bn."))
|
||||
return f"tcn{i}.{rest}"
|
||||
|
||||
m = re.fullmatch(r"up\.(.+)", key)
|
||||
if m:
|
||||
return _conv_block("conv_in", m.group(1))
|
||||
|
||||
m = re.fullmatch(r"residual_blocks\.(\d+)\.(.+)", key)
|
||||
if m:
|
||||
return _conv_block(f"conv{m.group(1)}", m.group(2))
|
||||
|
||||
m = re.fullmatch(r"attention\.(width|height)_axis\.(.+)", key)
|
||||
if m:
|
||||
axis, rest = m.groups()
|
||||
rest = rest.replace("qkv_transform.", "qkv.")
|
||||
return f"attention.{axis}.{rest}"
|
||||
|
||||
m = re.fullmatch(r"decoder\.(\d+)\.(.+)", key)
|
||||
if m:
|
||||
return f"{_DECODER_IDX[m.group(1)]}.{m.group(2)}"
|
||||
|
||||
raise KeyError(f"unmapped checkpoint key: {key}")
|
||||
|
||||
|
||||
def main():
|
||||
state = torch.load(CHECKPOINT, map_location="cpu", weights_only=True)
|
||||
if not isinstance(state, dict) or "tcn.network.0.conv1_group.weight" not in {
|
||||
k for k in state
|
||||
} | {k.replace("att.", "attention.") for k in state}:
|
||||
# tolerate trainer wrappers like {"model_state_dict": ...}
|
||||
for wrapper in ("model_state_dict", "state_dict", "model"):
|
||||
if isinstance(state, dict) and wrapper in state:
|
||||
state = state[wrapper]
|
||||
break
|
||||
|
||||
# Legacy upstream names predate the published code (_bench_common).
|
||||
state = remap_legacy_keys(state)
|
||||
|
||||
mapped = {}
|
||||
dropped = 0
|
||||
for k, v in state.items():
|
||||
if k.endswith("num_batches_tracked"):
|
||||
dropped += 1
|
||||
continue
|
||||
tch_key = map_key(k)
|
||||
if tch_key in mapped:
|
||||
raise KeyError(f"duplicate mapped key: {k} -> {tch_key}")
|
||||
mapped[tch_key] = v.detach().to(torch.float32).contiguous()
|
||||
|
||||
n_params = sum(v.numel() for k, v in mapped.items()
|
||||
if "running_" not in k)
|
||||
print(f"checkpoint tensors: {len(state)} "
|
||||
f"(dropped {dropped} num_batches_tracked)")
|
||||
print(f"mapped tensors: {len(mapped)}, "
|
||||
f"non-buffer params: {n_params/1e6:.6f}M")
|
||||
assert len(mapped) == 248, f"expected 248 tch variables, got {len(mapped)}"
|
||||
assert n_params == 2_225_042, f"param count mismatch: {n_params}"
|
||||
|
||||
st_path = os.path.join(RESULTS, "retrained_wiflow_std.safetensors")
|
||||
save_file(mapped, st_path)
|
||||
print(f"wrote {st_path}")
|
||||
|
||||
# ---- parity fixture --------------------------------------------------
|
||||
model = WiFlowPoseModel(dropout=0.5)
|
||||
model.load_state_dict(state, strict=True)
|
||||
model.eval()
|
||||
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
x = torch.rand(2, 540, 20, generator=gen, dtype=torch.float32)
|
||||
with torch.no_grad():
|
||||
y = model(x)
|
||||
print(f"fixture input {tuple(x.shape)} -> output {tuple(y.shape)}, "
|
||||
f"output range [{y.min().item():.6f}, {y.max().item():.6f}]")
|
||||
|
||||
np.savez(os.path.join(RESULTS, "parity_fixture.npz"),
|
||||
input=x.numpy(), output=y.numpy())
|
||||
fixture = {
|
||||
"seed": 42,
|
||||
"input_shape": list(x.shape),
|
||||
"input": x.flatten().tolist(),
|
||||
"output_shape": list(y.shape),
|
||||
"output": y.flatten().tolist(),
|
||||
}
|
||||
json_path = os.path.join(RESULTS, "parity_fixture.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(fixture, f)
|
||||
print(f"wrote {os.path.join(RESULTS, 'parity_fixture.npz')}")
|
||||
print(f"wrote {json_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Regenerate results/nan_windows_mask.npy + results/big_windows_mask.npy by
|
||||
scanning a PRISTINE kagglehub download of the WiFlow-STD dataset
|
||||
(kaka2434/wiflow-dataset v1, csi_windows.npy, 360,000 windows of 540x20).
|
||||
|
||||
============================ READ THIS FIRST ===============================
|
||||
This script MUST be run against an UNCLEANED copy of the dataset.
|
||||
|
||||
remote/clean_v2.py (and its predecessor clean_nan.py) repair the dataset by
|
||||
zeroing the corrupted windows IN PLACE, with no backup. A cleaned copy
|
||||
contains no non-finite values and no out-of-range amplitudes, so on a cleaned
|
||||
copy this scan produces ALL-FALSE masks -- silently wrong ground truth. The
|
||||
script errors out loudly in that case (see the sanity check in main()).
|
||||
|
||||
That irreversibility is exactly why the two committed mask files under
|
||||
results/ (gitignore-negated) are the canonical ground truth: once a download
|
||||
has been cleaned, the masks can NEVER be regenerated from it. Only run this
|
||||
on a fresh `kagglehub.dataset_download("kaka2434/wiflow-dataset")`.
|
||||
============================================================================
|
||||
|
||||
Criteria (per window; mirrors the original 2026-06-10 scan and the
|
||||
remote/clean_v2.py repair criteria):
|
||||
|
||||
nan mask: any non-finite value (NaN/Inf) anywhere in the 540x20 window
|
||||
big mask: max |finite value| > 1.5 (the data is otherwise [0,1]-normalized;
|
||||
the corrupted files contain garbage up to 3.4e38, float32 max)
|
||||
|
||||
Expected result on the pristine Kaggle download (RESULTS.md defect 5):
|
||||
nan: 9,070 True | big: 9,072 True | union: 9,072 -- all windows in dataset
|
||||
files 487-499 (the final 13 files), window indices 350,922-359,999.
|
||||
|
||||
Usage:
|
||||
PYTHONUTF8=1 .venv/Scripts/python.exe generate_corruption_masks.py \
|
||||
[--data-dir <dir containing csi_windows.npy>] [--out-dir results]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
RESULTS = os.path.join(HERE, "results")
|
||||
|
||||
EXPECTED = {"nan": 9070, "big": 9072, "union": 9072,
|
||||
"files": (487, 499), "windows": (350922, 359999)}
|
||||
|
||||
|
||||
def scan(csi_path, chunk=4000):
|
||||
"""Chunked scan of the (mmap'd) windows array; returns (nan_mask, big_mask)."""
|
||||
csi = np.load(csi_path, mmap_mode="r")
|
||||
n = len(csi)
|
||||
nan_mask = np.zeros(n, dtype=bool)
|
||||
big_mask = np.zeros(n, dtype=bool)
|
||||
for i in range(0, n, chunk):
|
||||
block = np.asarray(csi[i:i + chunk])
|
||||
finite = np.isfinite(block)
|
||||
nan_mask[i:i + chunk] = (~finite).any(axis=(1, 2))
|
||||
big_mask[i:i + chunk] = (
|
||||
np.abs(np.where(finite, block, 0)).max(axis=(1, 2)) > 1.5)
|
||||
if (i // chunk) % 10 == 0:
|
||||
print(f" scanned {min(i + chunk, n):,}/{n:,} windows "
|
||||
f"(nan={int(nan_mask.sum()):,} big={int(big_mask.sum()):,})",
|
||||
flush=True)
|
||||
return nan_mask, big_mask
|
||||
|
||||
|
||||
def describe_files(data_dir, mask):
|
||||
"""Map marked windows to dataset file indices via window_info.npz."""
|
||||
info = os.path.join(data_dir, "window_info.npz")
|
||||
if not os.path.exists(info):
|
||||
return None
|
||||
w2f = np.load(info)["window_to_file"]
|
||||
return np.unique(w2f[mask])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Regenerate the corruption masks from a PRISTINE "
|
||||
"(uncleaned) kagglehub download. See module docstring.")
|
||||
parser.add_argument("--data-dir", default=os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
|
||||
"wiflow-dataset", "versions", "1", "preprocessed_csi_data"),
|
||||
help="Directory containing csi_windows.npy (PRISTINE copy)")
|
||||
parser.add_argument("--out-dir", default=RESULTS,
|
||||
help="Where to write the two .npy masks")
|
||||
parser.add_argument("--chunk", type=int, default=4000,
|
||||
help="Windows per scan chunk (memory/speed tradeoff)")
|
||||
args = parser.parse_args()
|
||||
|
||||
csi_path = os.path.join(args.data_dir, "csi_windows.npy")
|
||||
if not os.path.exists(csi_path):
|
||||
sys.exit(f"csi_windows.npy not found in {args.data_dir}")
|
||||
|
||||
print(f"scanning {csi_path} (chunk={args.chunk}) ...")
|
||||
nan_mask, big_mask = scan(csi_path, args.chunk)
|
||||
union = nan_mask | big_mask
|
||||
print(f"nan: {int(nan_mask.sum()):,} | big: {int(big_mask.sum()):,} | "
|
||||
f"union: {int(union.sum()):,} of {len(union):,} windows")
|
||||
|
||||
# ---- sanity check: an all-False result means a CLEANED copy ------------
|
||||
if not union.any():
|
||||
sys.exit(
|
||||
"ERROR: scan found ZERO corrupted windows.\n"
|
||||
"\n"
|
||||
"The pristine Kaggle download (kaka2434/wiflow-dataset v1) is "
|
||||
"known to contain\n"
|
||||
"9,072 corrupted windows (NaN/Inf + amplitudes up to 3.4e38) in "
|
||||
"dataset files\n"
|
||||
"487-499 (RESULTS.md, reproducibility defect 5). Finding none "
|
||||
"means this copy\n"
|
||||
"has almost certainly already been repaired by remote/clean_v2.py "
|
||||
"(or clean_nan.py),\n"
|
||||
"which zeroes the corrupted windows IN PLACE -- after that the "
|
||||
"corruption evidence\n"
|
||||
"is gone and the masks CANNOT be regenerated from this copy.\n"
|
||||
"\n"
|
||||
"Refusing to overwrite the committed ground-truth masks with "
|
||||
"all-False ones.\n"
|
||||
"Re-download the dataset (kagglehub.dataset_download("
|
||||
"'kaka2434/wiflow-dataset'))\n"
|
||||
"and point --data-dir at the fresh, uncleaned copy.")
|
||||
|
||||
files = describe_files(args.data_dir, union)
|
||||
if files is not None:
|
||||
print(f"marked windows span dataset files {files.min()}-{files.max()}: "
|
||||
f"{files.tolist()}")
|
||||
lo, hi = EXPECTED["files"]
|
||||
if files.min() != lo or files.max() != hi:
|
||||
print(f"WARNING: expected marked files exactly {lo}-{hi} "
|
||||
f"(the pristine v1 download); got {files.min()}-{files.max()}. "
|
||||
f"Different dataset version, or a partially cleaned copy?")
|
||||
for name, mask, exp in (("nan", nan_mask, EXPECTED["nan"]),
|
||||
("big", big_mask, EXPECTED["big"])):
|
||||
if int(mask.sum()) != exp:
|
||||
print(f"WARNING: {name} mask has {int(mask.sum()):,} True windows; "
|
||||
f"the pristine v1 download yields {exp:,}.")
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
for name, mask in (("nan_windows_mask.npy", nan_mask),
|
||||
("big_windows_mask.npy", big_mask)):
|
||||
out = os.path.join(args.out_dir, name)
|
||||
np.save(out, mask)
|
||||
print(f"wrote {out} ({int(mask.sum()):,} True)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,220 @@
|
||||
"""ADR-152 edge optimization: ONNX export + onnxruntime CPU benchmark for the
|
||||
retrained WiFlow-STD checkpoint.
|
||||
|
||||
- Exports fp32 to ONNX. The axial attention reshapes with python ints taken
|
||||
from tensor.size() (view(N*W, C, H)), so a traced graph bakes the batch
|
||||
size; we first try a dynamic-batch export and verify it actually works at
|
||||
batch sizes 1/2/64 -- if not, we fall back to fixed-batch exports.
|
||||
- Verifies output parity vs torch on the stored fixture
|
||||
(results/parity_fixture.npz, batch 2, seed 42): max abs diff < 1e-4.
|
||||
- Measures onnxruntime CPU latency at batch 1 and 64 (median of N runs).
|
||||
- Supplementary: onnxruntime dynamic int8 quantization of the exported model
|
||||
(weight size datapoint for the paper's "~2.2 MB int8" claim).
|
||||
|
||||
Usage:
|
||||
.venv/Scripts/python.exe onnx_bench.py
|
||||
|
||||
Writes/merges into results/edge_optimization.json under key "onnx".
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import statistics
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from _bench_common import RESULTS, import_upstream, load_wiflow_model
|
||||
|
||||
import_upstream() # sys.path + models stub + >1GB np.load mmap patch
|
||||
|
||||
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")
|
||||
OUT_JSON = os.path.join(RESULTS, "edge_optimization.json")
|
||||
|
||||
|
||||
def load_fp32_model():
|
||||
return load_wiflow_model(CHECKPOINT)
|
||||
|
||||
|
||||
def try_export(model, path, batch, dynamic, opset=17):
|
||||
"""Returns (ok, exporter_used, error)."""
|
||||
x = torch.rand(batch, 540, 20)
|
||||
attempts = []
|
||||
if dynamic:
|
||||
attempts.append(("dynamo", dict(dynamo=True,
|
||||
dynamic_shapes={"x": {0: "batch"}})))
|
||||
attempts.append(("torchscript", dict(dynamo=False,
|
||||
dynamic_axes={"input": {0: "batch"},
|
||||
"output": {0: "batch"}})))
|
||||
else:
|
||||
attempts.append(("torchscript", dict(dynamo=False)))
|
||||
attempts.append(("dynamo", dict(dynamo=True)))
|
||||
last_err = None
|
||||
for name, kw in attempts:
|
||||
try:
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(model, (x,), path, opset_version=opset,
|
||||
input_names=["input"], output_names=["output"],
|
||||
**kw)
|
||||
return True, name, None
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_err = f"{name}: {type(e).__name__}: {e}"
|
||||
traceback.print_exc()
|
||||
return False, None, last_err
|
||||
|
||||
|
||||
def ort_session(path):
|
||||
import onnxruntime as ort
|
||||
return ort.InferenceSession(path, providers=["CPUExecutionProvider"])
|
||||
|
||||
|
||||
def ort_run(sess, x):
|
||||
inp = sess.get_inputs()[0].name
|
||||
return sess.run(None, {inp: x})[0]
|
||||
|
||||
|
||||
def bench_ort(sess, batch, n_runs):
|
||||
rng = np.random.default_rng(123)
|
||||
x = rng.random((batch, 540, 20), dtype=np.float32)
|
||||
for _ in range(max(5, n_runs // 10)):
|
||||
ort_run(sess, x)
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
t0 = time.perf_counter()
|
||||
ort_run(sess, x)
|
||||
times.append(time.perf_counter() - t0)
|
||||
med = statistics.median(times)
|
||||
return {
|
||||
"batch_size": batch,
|
||||
"runs": n_runs,
|
||||
"median_ms_per_batch": med * 1e3,
|
||||
"median_ms_per_window": med * 1e3 / batch,
|
||||
"windows_per_second": batch / med,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ONNX export + onnxruntime CPU benchmark for the "
|
||||
"retrained WiFlow-STD checkpoint (no options; see "
|
||||
"module docstring). NB: the published "
|
||||
"retrained_fp32_dynamic.onnx came from the TorchScript "
|
||||
"exporter; on newer torch the dynamo attempt may succeed "
|
||||
"first and produce a different (external-data) artifact.")
|
||||
parser.parse_args()
|
||||
|
||||
import onnxruntime
|
||||
model = load_fp32_model()
|
||||
results = {
|
||||
"env": {
|
||||
"torch": torch.__version__,
|
||||
"onnxruntime": onnxruntime.__version__,
|
||||
"platform": platform.platform(),
|
||||
},
|
||||
}
|
||||
|
||||
fixture = np.load(os.path.join(RESULTS, "parity_fixture.npz"))
|
||||
fx, fy = fixture["input"], fixture["output"] # (2,540,20) -> (2,15,2)
|
||||
|
||||
# ---- export: dynamic batch first, fall back to fixed --------------------
|
||||
dyn_path = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
|
||||
ok, exporter, err = try_export(model, dyn_path, batch=2, dynamic=True)
|
||||
dynamic_works = False
|
||||
if ok:
|
||||
# verify the dynamic graph really runs at other batch sizes
|
||||
try:
|
||||
sess = ort_session(dyn_path)
|
||||
for b in (1, 2, 64):
|
||||
y = ort_run(sess, np.zeros((b, 540, 20), dtype=np.float32))
|
||||
assert y.shape == (b, 15, 2), y.shape
|
||||
dynamic_works = True
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"dynamic-batch model does not generalize: {e}")
|
||||
|
||||
sessions = {}
|
||||
if dynamic_works:
|
||||
results["export"] = {"mode": "dynamic-batch", "exporter": exporter,
|
||||
"file": os.path.basename(dyn_path),
|
||||
"size_mb": os.path.getsize(dyn_path) / 1e6}
|
||||
sess = ort_session(dyn_path)
|
||||
sessions = {1: sess, 2: sess, 64: sess}
|
||||
print(f"dynamic-batch export OK via {exporter}")
|
||||
else:
|
||||
results["export"] = {"mode": "fixed-batch", "fallback_reason": err,
|
||||
"files": {}}
|
||||
for b in (1, 2, 64):
|
||||
p = os.path.join(RESULTS, f"retrained_fp32_b{b}.onnx")
|
||||
ok, exporter, err = try_export(model, p, batch=b, dynamic=False)
|
||||
if not ok:
|
||||
results["export"]["files"][str(b)] = {"error": err}
|
||||
print(f"EXPORT FAILED at batch {b}: {err}")
|
||||
continue
|
||||
results["export"]["files"][str(b)] = {
|
||||
"exporter": exporter, "file": os.path.basename(p),
|
||||
"size_mb": os.path.getsize(p) / 1e6}
|
||||
sessions[b] = ort_session(p)
|
||||
print(f"fixed-batch {b} export OK via {exporter}")
|
||||
|
||||
# ---- parity vs torch on the fixture -------------------------------------
|
||||
if 2 in sessions:
|
||||
y_ort = ort_run(sessions[2], fx)
|
||||
with torch.no_grad():
|
||||
y_torch = model(torch.from_numpy(fx)).numpy()
|
||||
results["parity"] = {
|
||||
"fixture": "results/parity_fixture.npz (batch 2, seed 42)",
|
||||
"max_abs_diff_vs_stored_fixture": float(np.abs(y_ort - fy).max()),
|
||||
"max_abs_diff_vs_torch_now": float(np.abs(y_ort - y_torch).max()),
|
||||
"pass_lt_1e-4": bool(np.abs(y_ort - y_torch).max() < 1e-4),
|
||||
}
|
||||
print("parity:", json.dumps(results["parity"], indent=2))
|
||||
|
||||
# ---- latency -------------------------------------------------------------
|
||||
results["latency"] = {}
|
||||
if 1 in sessions:
|
||||
results["latency"]["batch1"] = bench_ort(sessions[1], 1, 100)
|
||||
print(f"ORT batch 1: {results['latency']['batch1']['median_ms_per_window']:.2f} ms/window")
|
||||
if 64 in sessions:
|
||||
results["latency"]["batch64"] = bench_ort(sessions[64], 64, 30)
|
||||
print(f"ORT batch 64: {results['latency']['batch64']['median_ms_per_window']:.3f} ms/window")
|
||||
|
||||
# ---- supplementary: ORT dynamic int8 (size datapoint for the 2.2MB claim)
|
||||
src = (dyn_path if dynamic_works
|
||||
else os.path.join(RESULTS, "retrained_fp32_b1.onnx"))
|
||||
if os.path.exists(src):
|
||||
try:
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
q_path = os.path.join(RESULTS, "retrained_int8_ort_dynamic.onnx")
|
||||
quantize_dynamic(src, q_path, weight_type=QuantType.QInt8)
|
||||
entry = {"file": os.path.basename(q_path),
|
||||
"size_mb": os.path.getsize(q_path) / 1e6}
|
||||
try:
|
||||
qs = ort_session(q_path)
|
||||
yq = ort_run(qs, fx[:1] if not dynamic_works else fx)
|
||||
ref = fy[:1] if not dynamic_works else fy
|
||||
entry["runs"] = True
|
||||
entry["max_abs_diff_vs_fp32_fixture"] = float(np.abs(yq - ref).max())
|
||||
except Exception as e: # noqa: BLE001
|
||||
entry["runs"] = False
|
||||
entry["run_error"] = f"{type(e).__name__}: {e}"
|
||||
results["ort_int8_dynamic_supplementary"] = entry
|
||||
print("ORT int8:", json.dumps(entry, indent=2))
|
||||
except Exception as e: # noqa: BLE001
|
||||
results["ort_int8_dynamic_supplementary"] = {
|
||||
"error": f"{type(e).__name__}: {e}"}
|
||||
|
||||
merged = {}
|
||||
if os.path.exists(OUT_JSON):
|
||||
with open(OUT_JSON) as f:
|
||||
merged = json.load(f)
|
||||
merged["onnx"] = results
|
||||
with open(OUT_JSON, "w") as f:
|
||||
json.dump(merged, f, indent=2)
|
||||
print(f"wrote {OUT_JSON}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,228 @@
|
||||
"""ADR-152 "optimize beyond SOTA": edge-optimization benchmark for the
|
||||
retrained WiFlow-STD checkpoint (results/retrained_best_pose_model.pth,
|
||||
~96% PCK@20, fp32 params 2,225,042).
|
||||
|
||||
Measures, for fp32 / fp16 / dynamic-int8 torch variants:
|
||||
(a) serialized state_dict size on disk,
|
||||
(b) CPU inference latency per window at batch 1 and batch 64
|
||||
(median of repeated runs, this Windows box),
|
||||
(c) accuracy (PCK@20/50 + MPJPE, upstream metrics) on a corruption-free
|
||||
random subset of the seed-42 file-level 70/15/15 test split
|
||||
(same split as eval_repro.py; corrupted windows 487-499 excluded via
|
||||
results/nan_windows_mask.npy | results/big_windows_mask.npy).
|
||||
|
||||
Also verifies the paper's "~2.2 MB int8" size claim: reports which layer
|
||||
types torch dynamic quantization actually converts (the model contains NO
|
||||
nn.Linear -- it is Conv1d/Conv2d/BatchNorm only) and the real on-disk size.
|
||||
|
||||
Usage:
|
||||
.venv/Scripts/python.exe quantize_bench.py \
|
||||
--data-dir C:/Users/ruv/.cache/kagglehub/datasets/kaka2434/wiflow-dataset/versions/1/preprocessed_csi_data \
|
||||
[--subset 10000] [--skip-accuracy]
|
||||
|
||||
Writes/merges into results/edge_optimization.json under key "torch".
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from _bench_common import HERE, RESULTS, evaluate, import_upstream, load_wiflow_model
|
||||
|
||||
import_upstream() # sys.path + models stub + >1GB np.load mmap patch
|
||||
|
||||
from dataset import ( # noqa: E402
|
||||
PreprocessedCSIKeypointsDataset,
|
||||
create_preprocessed_train_val_test_loaders,
|
||||
)
|
||||
|
||||
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")
|
||||
|
||||
|
||||
def load_fp32_model():
|
||||
# legacy upstream key remap inside is a harmless no-op on this checkpoint
|
||||
return load_wiflow_model(CHECKPOINT)
|
||||
|
||||
|
||||
def state_dict_size_bytes(model, path):
|
||||
torch.save(model.state_dict(), path)
|
||||
return os.path.getsize(path)
|
||||
|
||||
|
||||
def bench_latency(model, batch_size, n_runs, dtype=torch.float32):
|
||||
gen = torch.Generator().manual_seed(123)
|
||||
x = torch.rand(batch_size, 540, 20, generator=gen).to(dtype)
|
||||
with torch.no_grad():
|
||||
for _ in range(max(5, n_runs // 10)): # warmup
|
||||
model(x)
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
t0 = time.perf_counter()
|
||||
model(x)
|
||||
times.append(time.perf_counter() - t0)
|
||||
med = statistics.median(times)
|
||||
return {
|
||||
"batch_size": batch_size,
|
||||
"runs": n_runs,
|
||||
"median_ms_per_batch": med * 1e3,
|
||||
"median_ms_per_window": med * 1e3 / batch_size,
|
||||
"windows_per_second": batch_size / med,
|
||||
}
|
||||
|
||||
|
||||
def build_test_subset(data_dir, subset_size, batch_size=64):
|
||||
"""Seed-42 file-level 70/15/15 test split (exactly as eval_repro.py),
|
||||
minus corrupted windows, then a seed-42 random subset."""
|
||||
dataset = PreprocessedCSIKeypointsDataset(
|
||||
data_dir=data_dir, keypoint_scale=1000.0, enable_temporal_clean=True)
|
||||
_tr, _va, test_loader = create_preprocessed_train_val_test_loaders(
|
||||
dataset=dataset, batch_size=batch_size, num_workers=0, random_seed=42)
|
||||
test_indices = np.asarray(test_loader.dataset.indices)
|
||||
|
||||
corrupted = (np.load(os.path.join(RESULTS, "nan_windows_mask.npy"))
|
||||
| np.load(os.path.join(RESULTS, "big_windows_mask.npy")))
|
||||
clean = test_indices[~corrupted[test_indices]]
|
||||
print(f"test split: {len(test_indices)} windows, "
|
||||
f"{len(test_indices) - len(clean)} corrupted excluded, "
|
||||
f"{len(clean)} clean")
|
||||
|
||||
if subset_size and subset_size < len(clean):
|
||||
rng = np.random.default_rng(42)
|
||||
clean = np.sort(rng.choice(clean, size=subset_size, replace=False))
|
||||
subset = torch.utils.data.Subset(dataset, clean.tolist())
|
||||
loader = DataLoader(subset, batch_size=batch_size, shuffle=False,
|
||||
num_workers=0)
|
||||
return loader, len(clean)
|
||||
|
||||
|
||||
def quantize_int8_dynamic(fp32_model):
|
||||
"""torch.ao.quantization.quantize_dynamic on Linear/Conv where supported.
|
||||
Returns (model, report) where report documents what actually quantized."""
|
||||
qmodel = torch.ao.quantization.quantize_dynamic(
|
||||
fp32_model, {nn.Linear, nn.Conv1d, nn.Conv2d}, dtype=torch.qint8)
|
||||
|
||||
quantized, total_params, quant_params = [], 0, 0
|
||||
for name, mod in qmodel.named_modules():
|
||||
cls = type(mod).__module__ + "." + type(mod).__name__
|
||||
if "quantized" in cls:
|
||||
w = mod.weight() if callable(getattr(mod, "weight", None)) else None
|
||||
numel = w.numel() if w is not None else 0
|
||||
quant_params += numel
|
||||
quantized.append({"module": name, "class": cls, "params": numel})
|
||||
for p in fp32_model.parameters():
|
||||
total_params += p.numel()
|
||||
|
||||
n_linear = sum(isinstance(m, nn.Linear) for m in fp32_model.modules())
|
||||
n_conv1d = sum(isinstance(m, nn.Conv1d) for m in fp32_model.modules())
|
||||
n_conv2d = sum(isinstance(m, nn.Conv2d) for m in fp32_model.modules())
|
||||
report = {
|
||||
"eligible_module_counts": {
|
||||
"nn.Linear": n_linear, "nn.Conv1d": n_conv1d, "nn.Conv2d": n_conv2d},
|
||||
"modules_actually_quantized": quantized,
|
||||
"n_modules_quantized": len(quantized),
|
||||
"params_total": total_params,
|
||||
"params_quantized": quant_params,
|
||||
"params_quantized_fraction": quant_params / total_params,
|
||||
}
|
||||
return qmodel, report
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", default=os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
|
||||
"wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
|
||||
parser.add_argument("--subset", type=int, default=10000)
|
||||
parser.add_argument("--runs-b1", type=int, default=100)
|
||||
parser.add_argument("--runs-b64", type=int, default=30)
|
||||
parser.add_argument("--skip-accuracy", action="store_true")
|
||||
parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(42)
|
||||
results = {
|
||||
"env": {
|
||||
"torch": torch.__version__,
|
||||
"platform": platform.platform(),
|
||||
"processor": platform.processor(),
|
||||
"num_threads": torch.get_num_threads(),
|
||||
"checkpoint": os.path.relpath(CHECKPOINT, HERE),
|
||||
},
|
||||
"variants": {},
|
||||
}
|
||||
|
||||
# ---- build variants ---------------------------------------------------
|
||||
fp32 = load_fp32_model()
|
||||
n_params = sum(p.numel() for p in fp32.parameters())
|
||||
results["env"]["params"] = n_params
|
||||
print(f"fp32 model: {n_params:,} params")
|
||||
|
||||
fp16 = load_fp32_model().half()
|
||||
|
||||
int8, q_report = quantize_int8_dynamic(load_fp32_model())
|
||||
results["int8_dynamic_quant_report"] = q_report
|
||||
print(f"int8 dynamic: {q_report['n_modules_quantized']} modules quantized, "
|
||||
f"{q_report['params_quantized_fraction']*100:.1f}% of params")
|
||||
|
||||
variants = {
|
||||
"fp32": (fp32, torch.float32, "retrained_fp32_resaved.pth"),
|
||||
"fp16": (fp16, torch.float16, "retrained_fp16.pth"),
|
||||
"int8_dynamic": (int8, torch.float32, "retrained_int8_dynamic.pth"),
|
||||
}
|
||||
|
||||
# ---- (a) size + (b) latency -------------------------------------------
|
||||
for name, (model, dtype, fname) in variants.items():
|
||||
path = os.path.join(RESULTS, fname)
|
||||
size = state_dict_size_bytes(model, path)
|
||||
print(f"\n=== {name}: {size/1e6:.3f} MB on disk ({fname}) ===")
|
||||
lat1 = bench_latency(model, 1, args.runs_b1, dtype)
|
||||
lat64 = bench_latency(model, 64, args.runs_b64, dtype)
|
||||
print(f" batch 1: {lat1['median_ms_per_window']:.2f} ms/window "
|
||||
f"({lat1['windows_per_second']:.0f}/s)")
|
||||
print(f" batch 64: {lat64['median_ms_per_window']:.3f} ms/window "
|
||||
f"({lat64['windows_per_second']:.0f}/s)")
|
||||
results["variants"][name] = {
|
||||
"file": fname,
|
||||
"size_bytes": size,
|
||||
"size_mb": size / 1e6,
|
||||
"latency_batch1": lat1,
|
||||
"latency_batch64": lat64,
|
||||
}
|
||||
|
||||
# ---- (c) accuracy ------------------------------------------------------
|
||||
if not args.skip_accuracy:
|
||||
loader, n_clean = build_test_subset(args.data_dir, args.subset)
|
||||
results["accuracy_subset"] = {
|
||||
"description": "seed-42 file-level 70/15/15 test split, corrupted "
|
||||
"windows (files 487-499) excluded, seed-42 random "
|
||||
"subset",
|
||||
"subset_size": min(args.subset, n_clean) if args.subset else n_clean,
|
||||
"clean_test_total": n_clean,
|
||||
}
|
||||
for name, (model, dtype, _f) in variants.items():
|
||||
print(f"\n=== accuracy: {name} ===")
|
||||
results["variants"][name]["accuracy"] = evaluate(
|
||||
model, loader, dtype=dtype, label=name)
|
||||
print(json.dumps(results["variants"][name]["accuracy"], indent=2))
|
||||
|
||||
# ---- merge into edge_optimization.json ---------------------------------
|
||||
merged = {}
|
||||
if os.path.exists(args.out):
|
||||
with open(args.out) as f:
|
||||
merged = json.load(f)
|
||||
merged["torch"] = results
|
||||
with open(args.out, "w") as f:
|
||||
json.dump(merged, f, indent=2)
|
||||
print(f"\nwrote {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,14 @@
|
||||
import numpy as np, os
|
||||
d = os.path.expanduser('~/wiflow-std-bench/preprocessed_csi_data')
|
||||
csi = np.load(os.path.join(d, 'csi_windows.npy'), mmap_mode='r+')
|
||||
zeroed = 0
|
||||
chunk = 4000
|
||||
for i in range(0, len(csi), chunk):
|
||||
block = csi[i:i+chunk]
|
||||
finite = np.isfinite(block)
|
||||
bad = (~finite).any(axis=(1, 2)) | (np.abs(np.where(finite, block, 0)).max(axis=(1, 2)) > 1.5)
|
||||
if bad.any():
|
||||
block[bad] = 0.0
|
||||
zeroed += int(bad.sum())
|
||||
csi.flush()
|
||||
print(f'zeroed {zeroed} corrupted windows entirely')
|
||||
@@ -0,0 +1,112 @@
|
||||
"""Evaluate the retrained WiFlow-STD checkpoint (ADR-152 §2.2a fallback).
|
||||
|
||||
Scores the model produced by run.py (train_output/best_pose_model.pth or similar)
|
||||
on the seed-42 test split: full test set AND NaN-free subset (excluding windows
|
||||
that were zero-filled by clean_nan.py — file indices 487-499).
|
||||
|
||||
NOTE: deployed to ruvultra (~/wiflow-std-bench) as a standalone single file,
|
||||
so it deliberately inlines its helpers. The reference implementations (upstream
|
||||
import shim, >1GB np.load mmap patch, key-remap loader, canonical evaluate
|
||||
loop) live in benchmarks/wiflow-std/_bench_common.py — keep copies in sync.
|
||||
"""
|
||||
import json, os, random, sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
# csi_windows.npy is ~13 GB; mmap large arrays instead of eagerly loading
|
||||
# ~15 GB into RAM (same patch as _bench_common._np_load_mmap).
|
||||
_np_load = np.load
|
||||
|
||||
|
||||
def _np_load_mmap(path, *a, **kw):
|
||||
if (isinstance(path, str) and path.endswith('.npy')
|
||||
and os.path.getsize(path) > 1 << 30 and 'mmap_mode' not in kw):
|
||||
kw['mmap_mode'] = 'r'
|
||||
return _np_load(path, *a, **kw)
|
||||
|
||||
|
||||
np.load = _np_load_mmap
|
||||
|
||||
sys.path.insert(0, os.path.expanduser('~/wiflow-std-bench/upstream'))
|
||||
from dataset import PreprocessedCSIKeypointsDataset, create_preprocessed_train_val_test_loaders
|
||||
from models.pose_model import WiFlowPoseModel
|
||||
from utils.metrics import calculate_pck, calculate_mpjpe
|
||||
|
||||
|
||||
def find_checkpoint():
|
||||
cands = []
|
||||
for root, _, files in os.walk(os.path.expanduser('~/wiflow-std-bench/train_output')):
|
||||
for f in files:
|
||||
if f.endswith('.pth'):
|
||||
cands.append(os.path.join(root, f))
|
||||
# also upstream/test default output dir
|
||||
for root, _, files in os.walk(os.path.expanduser('~/wiflow-std-bench/upstream')):
|
||||
for f in files:
|
||||
if f.endswith('.pth') and 'best' in f and 'cross_dataset' not in root:
|
||||
p = os.path.join(root, f)
|
||||
if os.path.getmtime(p) > os.path.getmtime(os.path.expanduser('~/wiflow-std-bench/train.log')) - 86400 * 2:
|
||||
cands.append(p)
|
||||
cands = [c for c in cands if not c.endswith('upstream/best_pose_model.pth')]
|
||||
if not cands:
|
||||
sys.exit('no retrained checkpoint found')
|
||||
return max(cands, key=os.path.getmtime)
|
||||
|
||||
|
||||
def evaluate(model, loader, device):
|
||||
model.eval()
|
||||
totals = {t: 0.0 for t in (0.1, 0.2, 0.3, 0.4, 0.5)}
|
||||
total_mpe, n = 0.0, 0
|
||||
with torch.no_grad():
|
||||
for bx, by in loader:
|
||||
bx, by = bx.to(device), by.to(device)
|
||||
out = model(bx)
|
||||
bs = by.size(0)
|
||||
total_mpe += calculate_mpjpe(out, by) * bs
|
||||
pck = calculate_pck(out, by, thresholds=list(totals))
|
||||
for t in totals:
|
||||
totals[t] += pck[t] * bs
|
||||
n += bs
|
||||
return {'samples': n, 'mpjpe': total_mpe / n,
|
||||
**{f'pck@{int(t*100)}': totals[t] / n for t in totals}}
|
||||
|
||||
|
||||
random.seed(42); np.random.seed(42); torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
d = os.path.expanduser('~/wiflow-std-bench/preprocessed_csi_data')
|
||||
dataset = PreprocessedCSIKeypointsDataset(data_dir=d, keypoint_scale=1000.0,
|
||||
enable_temporal_clean=True)
|
||||
_, _, test_loader = create_preprocessed_train_val_test_loaders(
|
||||
dataset=dataset, batch_size=256, num_workers=2, random_seed=42)
|
||||
|
||||
device = torch.device('cuda')
|
||||
ckpt = find_checkpoint()
|
||||
print('checkpoint:', ckpt)
|
||||
model = WiFlowPoseModel(dropout=0.5).to(device)
|
||||
state = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
renames = {'att.': 'attention.', 'final_conv.': 'decoder.'}
|
||||
state = {next((new + k[len(old):] for old, new in renames.items()
|
||||
if k.startswith(old)), k): v for k, v in state.items()}
|
||||
model.load_state_dict(state, strict=True)
|
||||
|
||||
results = {'checkpoint': ckpt}
|
||||
print('=== full test set ===')
|
||||
results['test_full'] = evaluate(model, test_loader, device)
|
||||
print(json.dumps(results['test_full'], indent=2))
|
||||
|
||||
# NaN-free subset: exclude windows from corrupted files 487-499
|
||||
test_subset = test_loader.dataset # Subset(dataset, test_indices)
|
||||
w2f = dataset.window_to_file
|
||||
clean_idx = [i for i in test_subset.indices if w2f[i] < 487]
|
||||
print(f'=== NaN-free test subset ({len(clean_idx)} of {len(test_subset.indices)}) ===')
|
||||
clean_loader = DataLoader(Subset(dataset, clean_idx), batch_size=256, shuffle=False)
|
||||
results['test_clean'] = evaluate(model, clean_loader, device)
|
||||
print(json.dumps(results['test_clean'], indent=2))
|
||||
|
||||
out = os.path.expanduser('~/wiflow-std-bench/eval_retrained.json')
|
||||
with open(out, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print('wrote', out)
|
||||
@@ -0,0 +1,374 @@
|
||||
"""ADR-152 SS2.2 measurement (b): WiFlow-STD fine-tuned on our fresh ESP32 paired dataset.
|
||||
|
||||
Dataset: ~/wiflow-std-bench/paired-20260610.jsonl -- 2,046 paired windows collected
|
||||
2026-06-10 22:10-22:40 (ONE subject, ONE room, ONE ESP32 node, varied poses).
|
||||
Per record: csi = flat float32 list, csi_shape, kp = 17 COCO [x, y] normalized [0,1]
|
||||
camera coords, conf (MediaPipe mean confidence, all > 0.5 in this set), ts_start/ts_end.
|
||||
Aligner: scripts/align-ground-truth.js, non-overlapping 20-frame windows (~0.42 s each).
|
||||
|
||||
Dataset findings (MEASURED on this file, 2026-06-10):
|
||||
- csi_shape is HETEROGENEOUS, not uniformly [70, 20]: 1,347x [70,20], 284x [134,20],
|
||||
243x [26,20], 130x [12,20], 42x [20,20]. The ESP32 stream emits mixed frame types
|
||||
and the aligner stamps each window's subcarrier count from frame[0]
|
||||
(extractCsiMatrix: nSc = window[0].subcarriers), zero-padding/truncating the rest.
|
||||
Even native-70 windows contain ~20.4% internally zero-padded short frames
|
||||
(subcarriers 40..69 all-zero for those frames).
|
||||
- LAYOUT BUG: the aligner fills matrix[f * nSc + s] (frame-major) but declares
|
||||
shape [nSc, nFrames]. The true layout is (frame, subcarrier); we reshape
|
||||
(nFrames, nSc) and transpose. Confirmed by coherent per-frame zero-tails.
|
||||
- Handling here (primary suite, "all2046"): every frame's subcarrier axis is
|
||||
linearly resampled to 70 bins (np.interp over a normalized index domain;
|
||||
identity for native-70 frames) so the pre-registered n=2,046 and split sizes
|
||||
hold. Secondary suite ("native70") restricts to the 1,347 native [70,20]
|
||||
windows (temporal 70/15/15 of those) as a homogeneity robustness check.
|
||||
|
||||
Pre-registered protocol (followed exactly):
|
||||
1. TEMPORAL split (records are time-sorted; asserted): first 70% train (1,432),
|
||||
next 15% val (307), last 15% test (307). No shuffling across time. Seed 42
|
||||
for everything else.
|
||||
2. Model: upstream WiFlow-STD trunk (WiFlowPoseModel) with a learned 1x1 Conv1d
|
||||
projection 70->540 prepended, and K=17 via the parameter-free adaptive pool
|
||||
(AdaptiveAvgPool2d((17, 1)) instead of (15, 1)) -- pretrained weights load
|
||||
for any K. CSI normalization: divide by the TRAIN-split 99th-percentile
|
||||
amplitude, clip to [0, 1] (documented in output JSON).
|
||||
3. Three runs, <=60 epochs, early-stop patience 8 on val MPJPE, batch 32,
|
||||
AdamW, fp32 (no autocast):
|
||||
(i) pretrained-init: trunk init from upstream/test/best_pose_model.pth
|
||||
(the measurement-(a) retrained checkpoint, ~96% PCK@20 on WiFlow data;
|
||||
key remap att.->attention. / final_conv.->decoder. applied defensively
|
||||
as in eval_repro.py -- a no-op for this checkpoint, which already uses
|
||||
the new names). Discriminative lr: adapter 1e-4, trunk 1e-5.
|
||||
(ii) scratch: same architecture, random init, all params lr 1e-4.
|
||||
(iii) frozen-trunk: pretrained trunk frozen (requires_grad=False AND held in
|
||||
.eval() so BatchNorm running stats cannot drift -- pure transfer probe);
|
||||
only the 70->540 adapter trains, lr 1e-4.
|
||||
4. Metrics on the temporal TEST split: torso-normalized PCK@10/20/30/40/50 and
|
||||
MPJPE. Upstream utils/metrics.py calculate_pck(use_torso_norm=True) hardcodes
|
||||
NECK_IDX/PELVIS_IDX = 2, 12 -- a 15-keypoint convention that is WRONG for our
|
||||
17 COCO keypoints (2 = right_eye, 12 = right_hip). We therefore reimplement the
|
||||
identical math (per-frame norm distance, clamp min 0.01, mean over all
|
||||
keypoints x frames) with torso = ||l_shoulder(5) - l_hip(11)||.
|
||||
Also reported: prediction std across test frames (constant-pose detector;
|
||||
must be > 0) and the mean-pose-predictor baseline (train-split mean pose
|
||||
evaluated on test -- the honesty bar).
|
||||
|
||||
Usage (on ruvultra):
|
||||
nice -n 10 nohup ~/wiflow-std-bench/venv/bin/python train_measb.py > train_measb.log 2>&1 &
|
||||
|
||||
NOTE: deployed to ruvultra as a standalone single file, so it deliberately
|
||||
inlines its helpers. The reference implementations (upstream import shim,
|
||||
np.load mmap patch, key-remap loader, canonical evaluate loop) live in
|
||||
benchmarks/wiflow-std/_bench_common.py — keep copies in sync.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
BENCH = os.path.expanduser("~/wiflow-std-bench")
|
||||
UPSTREAM = os.path.join(BENCH, "upstream")
|
||||
MEASB = os.path.join(BENCH, "measb")
|
||||
DATA = os.path.join(BENCH, "paired-20260610.jsonl")
|
||||
CHECKPOINT = os.path.join(UPSTREAM, "test", "best_pose_model.pth")
|
||||
|
||||
sys.path.insert(0, UPSTREAM)
|
||||
|
||||
# Upstream defect (1): models/__init__.py imports a name tcn.py does not define.
|
||||
# Register a stub package so the broken __init__ never executes (as eval_repro.py).
|
||||
import types # noqa: E402
|
||||
|
||||
_models_pkg = types.ModuleType("models")
|
||||
_models_pkg.__path__ = [os.path.join(UPSTREAM, "models")]
|
||||
sys.modules["models"] = _models_pkg
|
||||
|
||||
from models.pose_model import WiFlowPoseModel # noqa: E402
|
||||
|
||||
SEED = 42
|
||||
K = 17
|
||||
N_SUBC = 70
|
||||
TRUNK_IN = 540
|
||||
BATCH = 32 # <= 64 per protocol (GPU shared with the efficiency sweep)
|
||||
MAX_EPOCHS = 60
|
||||
PATIENCE = 8
|
||||
LR_ADAPTER = 1e-4
|
||||
LR_TRUNK_FT = 1e-5 # 10x lower for the pretrained trunk vs the fresh adapter
|
||||
L_SHOULDER, L_HIP = 5, 11
|
||||
THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5)
|
||||
|
||||
|
||||
def set_seed(seed=SEED):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def resample_subcarriers(frame_major, n_out=N_SUBC):
|
||||
"""(nFrames, nSc) -> (nFrames, n_out) by per-frame linear interpolation.
|
||||
|
||||
Identity for nSc == n_out. Normalized index domain [0, 1] on both sides.
|
||||
"""
|
||||
nf, nsc = frame_major.shape
|
||||
if nsc == n_out:
|
||||
return frame_major
|
||||
xi = np.linspace(0.0, 1.0, nsc)
|
||||
xo = np.linspace(0.0, 1.0, n_out)
|
||||
return np.stack([np.interp(xo, xi, frame_major[f]) for f in range(nf)]).astype(np.float32)
|
||||
|
||||
|
||||
def load_dataset():
|
||||
csi, kps, confs, ts, native70 = [], [], [], [], []
|
||||
shape_counts = {}
|
||||
with open(DATA) as f:
|
||||
for line in f:
|
||||
r = json.loads(line)
|
||||
nsc, nf = r["csi_shape"]
|
||||
shape_counts[f"{nsc}x{nf}"] = shape_counts.get(f"{nsc}x{nf}", 0) + 1
|
||||
assert nf == 20, r["csi_shape"]
|
||||
# Aligner layout bug: data is frame-major despite the declared
|
||||
# [nSc, nFrames] shape -- reshape (nFrames, nSc), then resample the
|
||||
# subcarrier axis to 70 and transpose to (70 subcarriers, 20 frames).
|
||||
fm = np.asarray(r["csi"], dtype=np.float32).reshape(nf, nsc)
|
||||
csi.append(resample_subcarriers(fm).T)
|
||||
kp = np.asarray(r["kp"], dtype=np.float32)
|
||||
assert kp.shape == (K, 2), kp.shape
|
||||
kps.append(kp)
|
||||
confs.append(r["conf"])
|
||||
ts.append(r["ts_start"])
|
||||
native70.append(nsc == N_SUBC)
|
||||
assert all(ts[i] <= ts[i + 1] for i in range(len(ts) - 1)), "records not time-sorted"
|
||||
return (np.stack(csi), np.stack(kps), np.asarray(confs, dtype=np.float32),
|
||||
np.asarray(native70), shape_counts, ts[0], ts[-1])
|
||||
|
||||
|
||||
def temporal_split(n):
|
||||
n_train = int(round(n * 0.70))
|
||||
n_val = int(round(n * 0.15))
|
||||
return slice(0, n_train), slice(n_train, n_train + n_val), slice(n_train + n_val, n)
|
||||
|
||||
|
||||
class AdaptedWiFlow(nn.Module):
|
||||
"""1x1 Conv1d adapter 70->540 + upstream WiFlow-STD trunk with K=17 pool head."""
|
||||
|
||||
def __init__(self, k=K, dropout=0.5):
|
||||
super().__init__()
|
||||
self.adapter = nn.Conv1d(N_SUBC, TRUNK_IN, kernel_size=1)
|
||||
nn.init.kaiming_normal_(self.adapter.weight, mode="fan_out", nonlinearity="relu")
|
||||
nn.init.constant_(self.adapter.bias, 0)
|
||||
self.trunk = WiFlowPoseModel(dropout=dropout)
|
||||
# K=17 via the parameter-free adaptive pool: decoder emits [B, 2, 15, 20]
|
||||
# spatial maps; pooling H->17 instead of 15 yields [B, 17, 2] with no new
|
||||
# parameters, so the pretrained state_dict loads strict=True for any K.
|
||||
self.trunk.avg_pool = nn.AdaptiveAvgPool2d((k, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.trunk(self.adapter(x))
|
||||
|
||||
|
||||
def load_pretrained_trunk(trunk, path):
|
||||
state = torch.load(path, map_location="cpu", weights_only=True)
|
||||
# Defensive remap as in eval_repro.py (no-op for the retrained checkpoint).
|
||||
renames = {"att.": "attention.", "final_conv.": "decoder."}
|
||||
state = {next((new + k[len(old):] for old, new in renames.items()
|
||||
if k.startswith(old)), k): v
|
||||
for k, v in state.items()}
|
||||
trunk.load_state_dict(state, strict=True)
|
||||
|
||||
|
||||
def pck_torso(pred, target, thresholds=THRESHOLDS):
|
||||
"""Upstream calculate_pck math, torso = l_shoulder(5)<->l_hip(11) for 17-kp COCO."""
|
||||
norm = torch.sqrt(((target[:, L_SHOULDER] - target[:, L_HIP]) ** 2).sum(dim=1))
|
||||
norm = torch.clamp(norm, min=0.01)
|
||||
dist = torch.sqrt(((pred - target) ** 2).sum(dim=2)) / norm.unsqueeze(1)
|
||||
return {f"pck@{int(t * 100)}": (dist <= t).float().mean().item() for t in thresholds}
|
||||
|
||||
|
||||
def mpjpe(pred, target):
|
||||
return torch.sqrt(((pred - target) ** 2).sum(dim=2)).mean().item()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def predict(model, x, batch=256):
|
||||
model.eval()
|
||||
return torch.cat([model(x[i:i + batch]) for i in range(0, len(x), batch)])
|
||||
|
||||
|
||||
def eval_preds(pred, target):
|
||||
out = pck_torso(pred, target)
|
||||
out["mpjpe"] = mpjpe(pred, target)
|
||||
# Constant-pose detector: std across test frames per coordinate, mean over
|
||||
# the 17x2 coordinates. 0.0 == degenerate constant predictor.
|
||||
out["pred_std"] = pred.std(dim=0).mean().item()
|
||||
return out
|
||||
|
||||
|
||||
def train_run(name, x_tr, y_tr, x_va, y_va, device, pretrained, freeze_trunk,
|
||||
lr_trunk):
|
||||
set_seed(SEED)
|
||||
model = AdaptedWiFlow().to(device)
|
||||
if pretrained:
|
||||
load_pretrained_trunk(model.trunk, CHECKPOINT)
|
||||
if freeze_trunk:
|
||||
for p in model.trunk.parameters():
|
||||
p.requires_grad = False
|
||||
groups = [{"params": model.adapter.parameters(), "lr": LR_ADAPTER}]
|
||||
else:
|
||||
groups = [{"params": model.adapter.parameters(), "lr": LR_ADAPTER},
|
||||
{"params": model.trunk.parameters(), "lr": lr_trunk}]
|
||||
opt = torch.optim.AdamW(groups)
|
||||
loss_fn = nn.MSELoss()
|
||||
|
||||
n = len(x_tr)
|
||||
best_val, best_state, best_epoch, bad = float("inf"), None, -1, 0
|
||||
history = []
|
||||
t0 = time.time()
|
||||
for epoch in range(MAX_EPOCHS):
|
||||
model.train()
|
||||
if freeze_trunk:
|
||||
model.trunk.eval() # keep BatchNorm running stats fixed: pure transfer
|
||||
perm = torch.randperm(n, device=device)
|
||||
ep_loss = 0.0
|
||||
for i in range(0, n, BATCH):
|
||||
idx = perm[i:i + BATCH]
|
||||
opt.zero_grad()
|
||||
loss = loss_fn(model(x_tr[idx]), y_tr[idx])
|
||||
loss.backward()
|
||||
opt.step()
|
||||
ep_loss += loss.item() * len(idx)
|
||||
val_mpjpe = mpjpe(predict(model, x_va), y_va)
|
||||
history.append({"epoch": epoch, "train_mse": ep_loss / n, "val_mpjpe": val_mpjpe})
|
||||
marker = ""
|
||||
if val_mpjpe < best_val:
|
||||
best_val, best_epoch, bad = val_mpjpe, epoch, 0
|
||||
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
||||
marker = " *"
|
||||
else:
|
||||
bad += 1
|
||||
print(f"[{name}] epoch {epoch:02d} train_mse {ep_loss / n:.6f} "
|
||||
f"val_mpjpe {val_mpjpe:.5f}{marker}", flush=True)
|
||||
if bad >= PATIENCE:
|
||||
print(f"[{name}] early stop at epoch {epoch} (best {best_epoch})", flush=True)
|
||||
break
|
||||
model.load_state_dict(best_state)
|
||||
torch.save(best_state, os.path.join(MEASB, f"{name}_best.pth"))
|
||||
return model, {"best_epoch": best_epoch, "best_val_mpjpe": best_val,
|
||||
"epochs_run": len(history), "wall_seconds": round(time.time() - t0, 1),
|
||||
"history": history}
|
||||
|
||||
|
||||
def run_suite(tag, csi, kps, device):
|
||||
"""Temporal 70/15/15 split, mean-pose baseline, three training runs."""
|
||||
n = len(csi)
|
||||
tr, va, te = temporal_split(n)
|
||||
print(f"=== suite {tag}: n={n} train={tr.stop} val={va.stop - va.start} "
|
||||
f"test={te.stop - te.start} ===", flush=True)
|
||||
|
||||
# CSI normalization constant from TRAIN split only.
|
||||
train_p99 = float(np.percentile(csi[tr], 99))
|
||||
train_max = float(csi[tr].max())
|
||||
print(f"[{tag}] train p99={train_p99:.3f} max={train_max:.3f} -> /p99, clip [0,1]",
|
||||
flush=True)
|
||||
csi_n = np.clip(csi / train_p99, 0.0, 1.0).astype(np.float32)
|
||||
|
||||
x = torch.from_numpy(csi_n).to(device)
|
||||
y = torch.from_numpy(kps).to(device)
|
||||
x_tr, y_tr = x[tr], y[tr]
|
||||
x_va, y_va = x[va], y[va]
|
||||
x_te, y_te = x[te], y[te]
|
||||
|
||||
suite = {
|
||||
"n_windows": n,
|
||||
"split": {"n_train": int(tr.stop), "n_val": int(va.stop - va.start),
|
||||
"n_test": int(te.stop - te.start)},
|
||||
"csi_norm": {"method": "divide by train-split p99 amplitude, clip [0,1]",
|
||||
"train_p99": train_p99, "train_max": train_max},
|
||||
"runs": {},
|
||||
}
|
||||
|
||||
# Honesty bar: mean-pose predictor fit on TRAIN, evaluated on TEST.
|
||||
mean_pose = y_tr.mean(dim=0, keepdim=True).expand(len(y_te), -1, -1)
|
||||
suite["mean_pose_baseline"] = eval_preds(mean_pose, y_te)
|
||||
suite["mean_pose_baseline"]["note"] = "train-split mean pose; pred_std 0 by construction"
|
||||
print(f"[{tag}] mean-pose baseline:", json.dumps(suite["mean_pose_baseline"]),
|
||||
flush=True)
|
||||
|
||||
configs = [
|
||||
("pretrained", dict(pretrained=True, freeze_trunk=False, lr_trunk=LR_TRUNK_FT)),
|
||||
("scratch", dict(pretrained=False, freeze_trunk=False, lr_trunk=LR_ADAPTER)),
|
||||
("frozen_trunk", dict(pretrained=True, freeze_trunk=True, lr_trunk=0.0)),
|
||||
]
|
||||
for name, cfg in configs:
|
||||
print(f"=== run: {tag}/{name} {cfg} ===", flush=True)
|
||||
model, train_info = train_run(f"{tag}_{name}", x_tr, y_tr, x_va, y_va,
|
||||
device, **cfg)
|
||||
test_metrics = eval_preds(predict(model, x_te), y_te)
|
||||
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
suite["runs"][name] = {"config": cfg, "trainable_params": n_trainable,
|
||||
"train": {k: v for k, v in train_info.items()
|
||||
if k != "history"},
|
||||
"history": train_info["history"],
|
||||
"test": test_metrics}
|
||||
print(f"[{tag}/{name}] TEST:", json.dumps(test_metrics), flush=True)
|
||||
return suite
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"device {device}, torch {torch.__version__}", flush=True)
|
||||
set_seed(SEED)
|
||||
|
||||
csi, kps, confs, native70, shape_counts, ts_first, ts_last = load_dataset()
|
||||
print(f"shape distribution: {shape_counts}", flush=True)
|
||||
|
||||
results = {
|
||||
"protocol": {
|
||||
"dataset": DATA, "n_windows": len(csi),
|
||||
"ts_first": ts_first, "ts_last": ts_last,
|
||||
"conf_mean": float(confs.mean()), "conf_min": float(confs.min()),
|
||||
"csi_shape_distribution": shape_counts,
|
||||
"csi_layout_note": "aligner stores frame-major data under a transposed "
|
||||
"[nSc, nFrames] shape label; corrected on load",
|
||||
"csi_resample": "per-frame linear interp of subcarrier axis to 70 bins "
|
||||
"(identity for native-70 frames); native-70 windows still "
|
||||
"contain ~20.4% internally zero-padded short frames",
|
||||
"split": "temporal 70/15/15 (no shuffle across time)",
|
||||
"model": "1x1 Conv1d 70->540 adapter + WiFlowPoseModel trunk, "
|
||||
"AdaptiveAvgPool2d((17,1)) head (parameter-free K=17)",
|
||||
"checkpoint": CHECKPOINT,
|
||||
"checkpoint_note": "measurement-(a) retrained checkpoint (~96% PCK@20 on "
|
||||
"WiFlow data); att./final_conv. remap applied "
|
||||
"defensively (no-op, already new-style keys)",
|
||||
"optimizer": f"AdamW, adapter lr {LR_ADAPTER}, fine-tuned trunk lr "
|
||||
f"{LR_TRUNK_FT} (10x lower), scratch all {LR_ADAPTER}",
|
||||
"batch": BATCH, "max_epochs": MAX_EPOCHS, "patience": PATIENCE,
|
||||
"precision": "fp32", "seed": SEED,
|
||||
"pck": "torso-normalized, torso = ||l_shoulder(5) - l_hip(11)||, "
|
||||
"clamp min 0.01, mean over keypoints x frames "
|
||||
"(upstream math; upstream 2/12 indices are a 15-kp convention)",
|
||||
},
|
||||
# Primary: all 2,046 windows (pre-registered n), subcarrier axis resampled.
|
||||
"all2046": None,
|
||||
# Secondary robustness check: the 1,347 native [70,20] windows only.
|
||||
"native70": None,
|
||||
}
|
||||
|
||||
results["all2046"] = run_suite("all2046", csi, kps, device)
|
||||
results["native70"] = run_suite("native70", csi[native70], kps[native70], device)
|
||||
|
||||
out = os.path.join(MEASB, "measurement_b.json")
|
||||
with open(out, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"wrote {out}", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
cd ~/wiflow-std-bench
|
||||
|
||||
# 1. clone upstream at the pinned commit
|
||||
if [ ! -d upstream ]; then
|
||||
git clone https://github.com/DY2434/WiFlow-WiFi-Pose-Estimation-with-Spatio-Temporal-Decoupling upstream
|
||||
fi
|
||||
cd upstream && git checkout 06899d294a0f44709d601a53e91dbf24759daefb && cd ..
|
||||
|
||||
# 2. documented deviation: fix upstream import bug (TemporalConvNet does not exist)
|
||||
sed -i 's/from .tcn import TemporalConvNet/from .tcn import TemporalBlock/; s/'"'"'TemporalConvNet'"'"'/'"'"'TemporalBlock'"'"'/' upstream/models/__init__.py
|
||||
|
||||
# 3. venv: torch cu128 (RTX 5080 = sm_120 needs >=2.7; their pin 2.3.1 predates Blackwell)
|
||||
if [ ! -d venv ]; then
|
||||
python3 -m venv venv
|
||||
./venv/bin/pip install -q --upgrade pip
|
||||
./venv/bin/pip install -q torch --index-url https://download.pytorch.org/whl/cu128
|
||||
./venv/bin/pip install -q numpy pandas matplotlib seaborn scikit-learn opencv-python-headless scipy tqdm psutil kagglehub
|
||||
fi
|
||||
./venv/bin/python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_device_name(0))"
|
||||
|
||||
# 4. dataset via kagglehub (anonymous, public dataset)
|
||||
DS=$(./venv/bin/python -c "import kagglehub; print(kagglehub.dataset_download('kaka2434/wiflow-dataset'))")
|
||||
echo "dataset at: $DS"
|
||||
|
||||
# 5. run.py hardcodes ../preprocessed_csi_data relative to upstream/
|
||||
ln -sfn "$DS/preprocessed_csi_data" ~/wiflow-std-bench/preprocessed_csi_data
|
||||
|
||||
# 6. train with upstream defaults (seed 42 set inside run.py)
|
||||
../venv/bin/python ../clean_nan.py 2>/dev/null || venv/bin/python clean_nan.py
|
||||
cd upstream
|
||||
../venv/bin/python run.py --gpu 0 --batch_size 64 --epochs 50 --output_dir ../train_output
|
||||
@@ -0,0 +1,332 @@
|
||||
"""Configurable compact variants of the WiFlow-STD pose model (ADR-152 efficiency sweep).
|
||||
|
||||
This is a parameterized copy of upstream models/{pose_model,tcn,convnet,attention}.py
|
||||
(DY2434/WiFlow @ 06899d29, Apache-2.0). upstream/ is NOT modified. Deviations from
|
||||
upstream, all forced by shrinking channels and documented per variant in run_sweep.py:
|
||||
|
||||
1. TCN grouped-conv groups: upstream hardcodes groups=20, which does not divide
|
||||
the compact channel counts (e.g. 270, 135, 85). Rule here:
|
||||
- groups_mode='gcd20': per-conv groups = gcd(channels, 20) (== 20 wherever
|
||||
upstream's choice is valid, incl. the 540-ch input conv; falls back to the
|
||||
largest common divisor with 20 otherwise).
|
||||
- groups_mode='depthwise': groups = channels (tiny variant only).
|
||||
2. Conv2d downsampling strides: upstream uses 4 stride-(1,2) blocks because
|
||||
240/2^4 = 15 == n_keypoints. With smaller TCN output widths that would leave
|
||||
<15 rows and AdaptiveAvgPool2d((15,1)) would duplicate rows across keypoints.
|
||||
Rule: halve the width only while the result stays >= 15 (stride-2 blocks
|
||||
first, stride-1 after). Full model: 240 -> 4 halvings = upstream exactly.
|
||||
3. input_pw_groups (tiny only): the dense 540->c pointwise + residual downsample
|
||||
in TCN block 1 cost 2*540*c params (a ~117k floor that alone exceeds the
|
||||
tiny <100k budget). tiny groups these two convs (groups=4; 4 | gcd(540, 68)).
|
||||
4. Decoder mid-channels: upstream 64->32; here c_last -> max(c_last // 2, 4).
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def tcn_groups(channels: int, mode: str) -> int:
|
||||
if mode == 'depthwise':
|
||||
return channels
|
||||
if mode == 'gcd20':
|
||||
return math.gcd(channels, 20)
|
||||
raise ValueError(mode)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- TCN (copy of tcn.py)
|
||||
class Chomp1d(nn.Module):
|
||||
def __init__(self, chomp_size):
|
||||
super().__init__()
|
||||
self.chomp_size = chomp_size
|
||||
|
||||
def forward(self, x):
|
||||
return x[:, :, :-self.chomp_size].contiguous()
|
||||
|
||||
|
||||
class CompactGroupedTemporalBlock(nn.Module):
|
||||
"""Upstream InnerGroupedTemporalBlock with parameterized groups."""
|
||||
|
||||
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding,
|
||||
dropout=0.2, groups_mode='gcd20', pw_groups=1):
|
||||
super().__init__()
|
||||
g_in = tcn_groups(n_inputs, groups_mode)
|
||||
g_out = tcn_groups(n_outputs, groups_mode)
|
||||
self.groups = (g_in, g_out)
|
||||
self.pw_groups = pw_groups
|
||||
|
||||
self.conv1_group = nn.Conv1d(n_inputs, n_inputs, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation,
|
||||
groups=g_in, bias=False)
|
||||
self.chomp1 = Chomp1d(padding) if padding > 0 else nn.Identity()
|
||||
self.bn1_group = nn.BatchNorm1d(n_inputs)
|
||||
self.relu1_group = nn.SiLU(inplace=True)
|
||||
|
||||
self.conv1_pw = nn.Conv1d(n_inputs, n_outputs, 1, groups=pw_groups, bias=False)
|
||||
self.bn1_pw = nn.BatchNorm1d(n_outputs)
|
||||
self.relu1_pw = nn.SiLU(inplace=True)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
|
||||
self.conv2_group = nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=1,
|
||||
padding=padding, dilation=dilation,
|
||||
groups=g_out, bias=False)
|
||||
self.chomp2 = Chomp1d(padding) if padding > 0 else nn.Identity()
|
||||
self.bn2_group = nn.BatchNorm1d(n_outputs)
|
||||
self.relu2_group = nn.SiLU(inplace=True)
|
||||
|
||||
self.conv2_pw = nn.Conv1d(n_outputs, n_outputs, 1, bias=False)
|
||||
self.bn2_pw = nn.BatchNorm1d(n_outputs)
|
||||
self.relu2_pw = nn.SiLU(inplace=True)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv1d(n_inputs, n_outputs, 1, groups=pw_groups, bias=False),
|
||||
nn.BatchNorm1d(n_outputs)
|
||||
) if n_inputs != n_outputs else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
res = self.downsample(x)
|
||||
out = self.conv1_group(x)
|
||||
out = self.chomp1(out)
|
||||
out = self.bn1_group(out)
|
||||
out = self.relu1_group(out)
|
||||
out = self.conv1_pw(out)
|
||||
out = self.bn1_pw(out)
|
||||
out = self.relu1_pw(out)
|
||||
out = self.dropout1(out)
|
||||
out = self.conv2_group(out)
|
||||
out = self.chomp2(out)
|
||||
out = self.bn2_group(out)
|
||||
out = self.relu2_group(out)
|
||||
out = self.conv2_pw(out)
|
||||
out = self.bn2_pw(out)
|
||||
out = self.relu2_pw(out)
|
||||
out = self.dropout2(out)
|
||||
return F.silu(out + res)
|
||||
|
||||
|
||||
class CompactTemporalBlock(nn.Module):
|
||||
def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.2,
|
||||
groups_mode='gcd20', input_pw_groups=1):
|
||||
super().__init__()
|
||||
layers = []
|
||||
for i, out_channels in enumerate(num_channels):
|
||||
dilation_size = 2 ** i
|
||||
in_channels = num_inputs if i == 0 else num_channels[i - 1]
|
||||
layers.append(CompactGroupedTemporalBlock(
|
||||
in_channels, out_channels, kernel_size, stride=1,
|
||||
dilation=dilation_size, padding=(kernel_size - 1) * dilation_size,
|
||||
dropout=dropout, groups_mode=groups_mode,
|
||||
pw_groups=input_pw_groups if i == 0 else 1))
|
||||
self.network = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.network(x)
|
||||
|
||||
|
||||
# ------------------------------------------------------- Conv2d path (copy of convnet.py)
|
||||
class AsymmetricConvBlock(nn.Module):
|
||||
"""Upstream block with parameterized width stride (upstream: always (1,2))."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, dropout=0.3, stride_w=2):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3),
|
||||
stride=(1, stride_w), padding=(0, 1)),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Dropout2d(dropout),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Dropout2d(dropout),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1,
|
||||
stride=(1, stride_w), bias=False),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
self.activation = nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.block(x) + self.downsample(x))
|
||||
|
||||
|
||||
class ConvBlock1(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dropout=0.3):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Dropout2d(dropout),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Dropout2d(dropout),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
self.activation = nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.block(x) + self.downsample(x))
|
||||
|
||||
|
||||
# ----------------------------------------------------- attention (verbatim attention.py)
|
||||
class AxialAttention(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, groups=8, stride=1, bias=False, width=False):
|
||||
assert (in_planes % groups == 0) and (out_planes % groups == 0)
|
||||
super().__init__()
|
||||
self.in_planes = in_planes
|
||||
self.out_planes = out_planes
|
||||
self.groups = groups
|
||||
self.group_planes = out_planes // groups
|
||||
self.stride = stride
|
||||
self.bias = bias
|
||||
self.width = width
|
||||
self.qkv_transform = nn.Conv1d(in_planes, out_planes * 3, kernel_size=1,
|
||||
stride=1, padding=0, bias=False)
|
||||
self.bn_qkv = nn.BatchNorm1d(out_planes * 3)
|
||||
self.bn_similarity = nn.BatchNorm2d(groups)
|
||||
self.bn_output = nn.BatchNorm1d(out_planes)
|
||||
if stride > 1:
|
||||
self.pooling = nn.AvgPool2d(stride, stride=stride)
|
||||
nn.init.normal_(self.qkv_transform.weight.data, 0, math.sqrt(1. / self.in_planes))
|
||||
|
||||
def forward(self, x):
|
||||
if self.width:
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
else:
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
N, W, C, H = x.shape
|
||||
x = x.contiguous().view(N * W, C, H)
|
||||
qkv = self.bn_qkv(self.qkv_transform(x))
|
||||
qkv = qkv.reshape(N * W, 3, self.out_planes, H).permute(1, 0, 2, 3)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q = q.reshape(N * W, self.groups, self.group_planes, H)
|
||||
k = k.reshape(N * W, self.groups, self.group_planes, H)
|
||||
v = v.reshape(N * W, self.groups, self.group_planes, H)
|
||||
qk = torch.einsum('bgci, bgcj->bgij', q, k)
|
||||
qk = self.bn_similarity(qk)
|
||||
similarity = F.softmax(qk, dim=-1)
|
||||
sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
|
||||
sv = sv.reshape(N * W, self.out_planes, H)
|
||||
out = self.bn_output(sv)
|
||||
out = out.view(N, W, self.out_planes, H)
|
||||
if self.width:
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
if self.stride > 1:
|
||||
out = self.pooling(out)
|
||||
return out
|
||||
|
||||
|
||||
class DualAxialAttention(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, groups=8, stride=1, bias=False):
|
||||
super().__init__()
|
||||
self.width_axis = AxialAttention(in_planes, out_planes, groups, stride, bias, width=True)
|
||||
self.height_axis = AxialAttention(out_planes, out_planes, groups, stride, bias, width=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.height_axis(self.width_axis(x))
|
||||
|
||||
|
||||
# --------------------------------------------------------------- full model
|
||||
def compute_strides(width: int, n_blocks: int, target: int = 15):
|
||||
"""Halve width while result stays >= target (upstream: 240 -> 4 halvings -> 15)."""
|
||||
strides = []
|
||||
for _ in range(n_blocks):
|
||||
nxt = (width + 1) // 2 # conv k=3 s=2 p=1: out = ceil(in/2)
|
||||
if nxt >= target:
|
||||
strides.append(2)
|
||||
width = nxt
|
||||
else:
|
||||
strides.append(1)
|
||||
return strides, width
|
||||
|
||||
|
||||
class CompactWiFlowPoseModel(nn.Module):
|
||||
"""Parameterized upstream WiFlowPoseModel.
|
||||
|
||||
Upstream config == tcn_channels=[540,440,340,240], conv_channels=[8,16,32,64],
|
||||
attn_groups=8, groups_mode='gcd20' (gcd(c,20)==20 for all upstream channels),
|
||||
input_pw_groups=1 -> identical architecture, 2,225,042 params.
|
||||
"""
|
||||
|
||||
def __init__(self, tcn_channels, conv_channels, attn_groups,
|
||||
groups_mode='gcd20', input_pw_groups=1, dropout=0.3,
|
||||
num_subcarriers=540, num_keypoints=15):
|
||||
super().__init__()
|
||||
self.tcn = CompactTemporalBlock(
|
||||
num_inputs=num_subcarriers, num_channels=tcn_channels, kernel_size=3,
|
||||
dropout=dropout, groups_mode=groups_mode, input_pw_groups=input_pw_groups)
|
||||
|
||||
self.up = ConvBlock1(1, conv_channels[0])
|
||||
|
||||
strides, self.final_width = compute_strides(
|
||||
tcn_channels[-1], len(conv_channels), target=num_keypoints)
|
||||
self.conv_strides = strides
|
||||
self.residual_blocks = nn.ModuleList()
|
||||
in_channels = conv_channels[0]
|
||||
for out_channels, s in zip(conv_channels, strides):
|
||||
self.residual_blocks.append(
|
||||
AsymmetricConvBlock(in_channels, out_channels, stride_w=s))
|
||||
in_channels = out_channels
|
||||
|
||||
c_last = conv_channels[-1]
|
||||
self.attention = DualAxialAttention(c_last, c_last, groups=attn_groups)
|
||||
|
||||
c_mid = max(c_last // 2, 4)
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv2d(c_last, c_mid, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(c_mid),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Conv2d(c_mid, 2, kernel_size=1),
|
||||
nn.BatchNorm2d(2),
|
||||
nn.SiLU(inplace=True)
|
||||
)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((num_keypoints, 1))
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
# [B, 540, 20]
|
||||
x = self.tcn(x) # [B, C_tcn, 20]
|
||||
x = x.transpose(1, 2).unsqueeze(1) # [B, 1, 20, C_tcn]
|
||||
x = self.up(x)
|
||||
for block in self.residual_blocks:
|
||||
x = block(x) # [B, C_conv, 20, W']
|
||||
x = x.permute(0, 1, 3, 2) # [B, C_conv, W', 20]
|
||||
x = self.attention(x)
|
||||
x = self.decoder(x) # [B, 2, W', 20]
|
||||
x = self.avg_pool(x).squeeze(-1) # [B, 2, 15]
|
||||
return x.transpose(1, 2) # [B, 15, 2]
|
||||
|
||||
|
||||
def describe(model: 'CompactWiFlowPoseModel'):
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
tcn_g = [blk.groups for blk in model.tcn.network]
|
||||
return {'params': params, 'tcn_groups_per_block': tcn_g,
|
||||
'conv_strides': model.conv_strides, 'final_width': model.final_width}
|
||||
@@ -0,0 +1,278 @@
|
||||
"""WiFlow-STD compact-variant efficiency sweep (ADR-152) — sequential overnight runner.
|
||||
|
||||
Trains compact variants of the upstream WiFlow-STD architecture on the same
|
||||
data/split as the full-size reference retraining (seed 42, file-level 70/15/15,
|
||||
upstream dataset.py) and evaluates PCK@10..50 + MPJPE on the full test split and
|
||||
the corruption-free test subset (file indices < 487).
|
||||
|
||||
Training mirrors upstream run.py/train.py defaults except:
|
||||
- fp32 only (no fp16 autocast / GradScaler — avoids the BN-poisoning trap
|
||||
documented in RESULTS.md defect 5; data on disk is already cleaned).
|
||||
- batch 64 (kept modest: another GPU job may share the 16 GB card tonight).
|
||||
- scheduler + early stopping keyed on val MPJPE (upstream early-stops on val MPE
|
||||
with patience 5; same here).
|
||||
|
||||
Usage:
|
||||
venv/bin/python sweep/run_sweep.py --dry-run # param counts only
|
||||
nohup venv/bin/python sweep/run_sweep.py > sweep/sweep.log 2>&1 &
|
||||
|
||||
Idempotent: variants already present in sweep/results.jsonl are skipped.
|
||||
|
||||
NOTE: deployed to ruvultra (~/wiflow-std-bench/sweep) as a standalone file, so
|
||||
it deliberately inlines its helpers. The reference implementations (upstream
|
||||
import shim, >1GB np.load mmap patch, key-remap loader, canonical evaluate
|
||||
loop) live in benchmarks/wiflow-std/_bench_common.py — keep copies in sync.
|
||||
"""
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
# csi_windows.npy is ~13 GB; mmap large arrays instead of eagerly loading
|
||||
# ~15 GB into RAM (same patch as _bench_common._np_load_mmap).
|
||||
_np_load = np.load
|
||||
|
||||
|
||||
def _np_load_mmap(path, *a, **kw):
|
||||
if (isinstance(path, str) and path.endswith('.npy')
|
||||
and os.path.getsize(path) > 1 << 30 and 'mmap_mode' not in kw):
|
||||
kw['mmap_mode'] = 'r'
|
||||
return _np_load(path, *a, **kw)
|
||||
|
||||
|
||||
np.load = _np_load_mmap
|
||||
|
||||
BENCH = os.path.expanduser('~/wiflow-std-bench')
|
||||
SWEEP = os.path.join(BENCH, 'sweep')
|
||||
sys.path.insert(0, os.path.join(BENCH, 'upstream'))
|
||||
sys.path.insert(0, SWEEP)
|
||||
|
||||
from dataset import PreprocessedCSIKeypointsDataset, create_preprocessed_train_val_test_loaders # noqa: E402
|
||||
from losses.pose_loss import PoseLoss # noqa: E402
|
||||
from utils.metrics import calculate_pck, calculate_mpjpe # noqa: E402
|
||||
from model_compact import CompactWiFlowPoseModel, describe # noqa: E402
|
||||
|
||||
VARIANTS = [
|
||||
# name, tcn_channels, conv_channels, attn_groups, groups_mode, input_pw_groups
|
||||
dict(name='half', tcn=[270, 220, 170, 120], conv=[4, 8, 16, 32], attn_groups=4,
|
||||
groups_mode='gcd20', input_pw_groups=1),
|
||||
dict(name='quarter', tcn=[135, 110, 85, 60], conv=[2, 4, 8, 16], attn_groups=2,
|
||||
groups_mode='gcd20', input_pw_groups=1),
|
||||
dict(name='tiny', tcn=[68, 56, 44, 32], conv=[2, 4, 8, 16], attn_groups=2,
|
||||
groups_mode='depthwise', input_pw_groups=4),
|
||||
]
|
||||
|
||||
BATCH = 64
|
||||
EPOCHS = 50
|
||||
PATIENCE = 5
|
||||
LR = 1e-4
|
||||
WEIGHT_DECAY = 5e-5
|
||||
SEED = 42
|
||||
CORRUPT_FILE_START = 487 # files 487-499 were zero-filled by clean_nan.py
|
||||
|
||||
|
||||
def set_seed(seed=SEED):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def build_model(v, dropout=0.5):
|
||||
return CompactWiFlowPoseModel(
|
||||
tcn_channels=v['tcn'], conv_channels=v['conv'], attn_groups=v['attn_groups'],
|
||||
groups_mode=v['groups_mode'], input_pw_groups=v['input_pw_groups'],
|
||||
dropout=dropout)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, loader, device):
|
||||
model.eval()
|
||||
totals = {t: 0.0 for t in (0.1, 0.2, 0.3, 0.4, 0.5)}
|
||||
total_mpe, n = 0.0, 0
|
||||
for bx, by in loader:
|
||||
bx, by = bx.to(device), by.to(device)
|
||||
out = model(bx)
|
||||
bs = by.size(0)
|
||||
total_mpe += calculate_mpjpe(out, by) * bs
|
||||
pck = calculate_pck(out, by, thresholds=list(totals))
|
||||
for t in totals:
|
||||
totals[t] += pck[t] * bs
|
||||
n += bs
|
||||
return {'samples': n, 'mpjpe': total_mpe / n,
|
||||
**{f'pck@{int(t * 100)}': totals[t] / n for t in totals}}
|
||||
|
||||
|
||||
def train_variant(v, dataset, device):
|
||||
set_seed(SEED)
|
||||
train_loader, val_loader, test_loader = create_preprocessed_train_val_test_loaders(
|
||||
dataset=dataset, batch_size=BATCH, num_workers=2, random_seed=SEED)
|
||||
|
||||
set_seed(SEED) # re-seed after split so init is split-independent
|
||||
model = build_model(v).to(device)
|
||||
info = describe(model)
|
||||
print(f"[{v['name']}] params={info['params']:,} tcn_groups={info['tcn_groups_per_block']} "
|
||||
f"conv_strides={info['conv_strides']} final_width={info['final_width']}", flush=True)
|
||||
|
||||
criterion = PoseLoss(position_weight=1.0, bone_weight=0.2, loss_type='smooth_l1')
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY,
|
||||
betas=(0.9, 0.999))
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, mode='min', factor=0.5, patience=3, min_lr=LR / 1000,
|
||||
cooldown=1, threshold=1e-4)
|
||||
|
||||
best_val_mpe = float('inf')
|
||||
best_val_pck20 = 0.0
|
||||
best_epoch = 0
|
||||
best_state = None
|
||||
patience_counter = 0
|
||||
t0 = time.time()
|
||||
error = None
|
||||
epochs_run = 0
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
model.train()
|
||||
ep_loss, nb = 0.0, 0
|
||||
te = time.time()
|
||||
for i, (bx, by) in enumerate(train_loader):
|
||||
bx = bx.to(device, non_blocking=True)
|
||||
by = by.to(device, non_blocking=True)
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
out = model(bx)
|
||||
loss, _parts = criterion(out, by)
|
||||
if not torch.isfinite(loss):
|
||||
error = f'non-finite loss at epoch {epoch} step {i}'
|
||||
break
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
ep_loss += loss.item()
|
||||
nb += 1
|
||||
if epoch == 1 and i % 500 == 0:
|
||||
print(f"[{v['name']}] e1 step {i}/{len(train_loader)} loss={loss.item():.5f}",
|
||||
flush=True)
|
||||
if error:
|
||||
break
|
||||
epochs_run = epoch
|
||||
|
||||
val = evaluate(model, val_loader, device)
|
||||
scheduler.step(val['mpjpe'])
|
||||
lr_now = optimizer.param_groups[0]['lr']
|
||||
print(f"[{v['name']}] epoch {epoch}/{EPOCHS} train_loss={ep_loss / max(nb, 1):.5f} "
|
||||
f"val_mpjpe={val['mpjpe']:.5f} val_pck20={val['pck@20'] * 100:.2f}% "
|
||||
f"lr={lr_now:.2e} ({time.time() - te:.0f}s)", flush=True)
|
||||
|
||||
if val['mpjpe'] < best_val_mpe:
|
||||
best_val_mpe = val['mpjpe']
|
||||
best_val_pck20 = val['pck@20']
|
||||
best_epoch = epoch
|
||||
best_state = copy.deepcopy(model.state_dict())
|
||||
patience_counter = 0
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= PATIENCE:
|
||||
print(f"[{v['name']}] early stop at epoch {epoch} (best {best_epoch})", flush=True)
|
||||
break
|
||||
|
||||
train_seconds = time.time() - t0
|
||||
result = {
|
||||
'variant': v['name'], 'params': info['params'],
|
||||
'tcn_channels': v['tcn'], 'conv_channels': v['conv'],
|
||||
'attn_groups': v['attn_groups'], 'groups_mode': v['groups_mode'],
|
||||
'input_pw_groups': v['input_pw_groups'],
|
||||
'tcn_groups_per_block': info['tcn_groups_per_block'],
|
||||
'conv_strides': info['conv_strides'], 'final_width': info['final_width'],
|
||||
'batch_size': BATCH, 'max_epochs': EPOCHS, 'patience': PATIENCE,
|
||||
'lr': LR, 'weight_decay': WEIGHT_DECAY, 'seed': SEED, 'precision': 'fp32',
|
||||
'epochs_run': epochs_run, 'best_epoch': best_epoch,
|
||||
'best_val_mpjpe': best_val_mpe if best_state else None,
|
||||
'best_val_pck20': best_val_pck20 if best_state else None,
|
||||
'train_seconds': round(train_seconds, 1),
|
||||
'torch': torch.__version__, 'error': error,
|
||||
'finished_utc': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()),
|
||||
}
|
||||
|
||||
if best_state is not None:
|
||||
ckpt = os.path.join(SWEEP, f"{v['name']}_best.pth")
|
||||
torch.save(best_state, ckpt)
|
||||
result['checkpoint'] = ckpt
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
eval_loader = DataLoader(test_loader.dataset, batch_size=256, shuffle=False,
|
||||
num_workers=2)
|
||||
result['test_full'] = evaluate(model, eval_loader, device)
|
||||
|
||||
w2f = dataset.window_to_file
|
||||
clean_idx = [i for i in test_loader.dataset.indices if w2f[i] < CORRUPT_FILE_START]
|
||||
clean_loader = DataLoader(Subset(dataset, clean_idx), batch_size=256,
|
||||
shuffle=False, num_workers=2)
|
||||
result['test_clean'] = evaluate(model, clean_loader, device)
|
||||
print(f"[{v['name']}] TEST clean: pck20={result['test_clean']['pck@20'] * 100:.2f}% "
|
||||
f"mpjpe={result['test_clean']['mpjpe']:.5f} | full: "
|
||||
f"pck20={result['test_full']['pck@20'] * 100:.2f}%", flush=True)
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument('--dry-run', action='store_true', help='print param counts and exit')
|
||||
args = ap.parse_args()
|
||||
|
||||
if args.dry_run:
|
||||
for v in VARIANTS:
|
||||
m = build_model(v)
|
||||
info = describe(m)
|
||||
x = torch.randn(2, 540, 20)
|
||||
m.eval()
|
||||
y = m(x)
|
||||
print(f"{v['name']:8s} params={info['params']:>9,} "
|
||||
f"tcn={v['tcn']} conv={v['conv']} attn_g={v['attn_groups']} "
|
||||
f"mode={v['groups_mode']} pw_g={v['input_pw_groups']} "
|
||||
f"tcn_groups={info['tcn_groups_per_block']} strides={info['conv_strides']} "
|
||||
f"W'={info['final_width']} out={tuple(y.shape)}")
|
||||
return
|
||||
|
||||
results_path = os.path.join(SWEEP, 'results.jsonl')
|
||||
done = set()
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path) as f:
|
||||
for line in f:
|
||||
try:
|
||||
done.add(json.loads(line)['variant'])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
device = torch.device('cuda')
|
||||
print(f"torch {torch.__version__} on {torch.cuda.get_device_name(0)}", flush=True)
|
||||
data_dir = os.path.join(BENCH, 'preprocessed_csi_data')
|
||||
dataset = PreprocessedCSIKeypointsDataset(data_dir=data_dir, keypoint_scale=1000.0,
|
||||
enable_temporal_clean=True)
|
||||
|
||||
for v in VARIANTS:
|
||||
if v['name'] in done:
|
||||
print(f"[{v['name']}] already in results.jsonl — skipping", flush=True)
|
||||
continue
|
||||
print(f"\n===== variant: {v['name']} =====", flush=True)
|
||||
try:
|
||||
result = train_variant(v, dataset, device)
|
||||
except Exception as e: # record and move on to next variant
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
result = {'variant': v['name'], 'error': repr(e),
|
||||
'finished_utc': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())}
|
||||
with open(results_path, 'a') as f:
|
||||
f.write(json.dumps(result) + '\n')
|
||||
f.flush()
|
||||
print('\nSWEEP COMPLETE', flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Binary file not shown.
@@ -0,0 +1,772 @@
|
||||
{
|
||||
"torch": {
|
||||
"env": {
|
||||
"torch": "2.12.0+cpu",
|
||||
"platform": "Windows-11-10.0.26200-SP0",
|
||||
"processor": "Intel64 Family 6 Model 197 Stepping 2, GenuineIntel",
|
||||
"num_threads": 16,
|
||||
"checkpoint": "results\\retrained_best_pose_model.pth",
|
||||
"params": 2225042
|
||||
},
|
||||
"variants": {
|
||||
"fp32": {
|
||||
"file": "retrained_fp32_resaved.pth",
|
||||
"size_bytes": 9068948,
|
||||
"size_mb": 9.068948,
|
||||
"latency_batch1": {
|
||||
"batch_size": 1,
|
||||
"runs": 100,
|
||||
"median_ms_per_batch": 24.903650000851485,
|
||||
"median_ms_per_window": 24.903650000851485,
|
||||
"windows_per_second": 40.15475642991324
|
||||
},
|
||||
"latency_batch64": {
|
||||
"batch_size": 64,
|
||||
"runs": 30,
|
||||
"median_ms_per_batch": 184.02919999789447,
|
||||
"median_ms_per_window": 2.875456249967101,
|
||||
"windows_per_second": 347.77089723115813
|
||||
},
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9668200004577636,
|
||||
"pck@50": 0.9915333324432373,
|
||||
"mpjpe": 0.00936222033649683,
|
||||
"wall_seconds": 37.85407733917236
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"file": "retrained_fp16.pth",
|
||||
"size_bytes": 4580332,
|
||||
"size_mb": 4.580332,
|
||||
"latency_batch1": {
|
||||
"batch_size": 1,
|
||||
"runs": 100,
|
||||
"median_ms_per_batch": 23.936699999467237,
|
||||
"median_ms_per_window": 23.936699999467237,
|
||||
"windows_per_second": 41.776853117691964
|
||||
},
|
||||
"latency_batch64": {
|
||||
"batch_size": 64,
|
||||
"runs": 30,
|
||||
"median_ms_per_batch": 102.32584999903338,
|
||||
"median_ms_per_window": 1.5988414062348966,
|
||||
"windows_per_second": 625.4529036465817
|
||||
},
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.966773332977295,
|
||||
"pck@50": 0.9915066654205322,
|
||||
"mpjpe": 0.009460017587244511,
|
||||
"wall_seconds": 21.632277250289917
|
||||
}
|
||||
},
|
||||
"int8_dynamic": {
|
||||
"file": "retrained_int8_dynamic.pth",
|
||||
"size_bytes": 9068948,
|
||||
"size_mb": 9.068948,
|
||||
"latency_batch1": {
|
||||
"batch_size": 1,
|
||||
"runs": 100,
|
||||
"median_ms_per_batch": 18.105350000041653,
|
||||
"median_ms_per_window": 18.105350000041653,
|
||||
"windows_per_second": 55.23229321707117
|
||||
},
|
||||
"latency_batch64": {
|
||||
"batch_size": 64,
|
||||
"runs": 30,
|
||||
"median_ms_per_batch": 168.77549999844632,
|
||||
"median_ms_per_window": 2.6371171874757238,
|
||||
"windows_per_second": 379.20195763359703
|
||||
},
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9668200004577636,
|
||||
"pck@50": 0.9915333324432373,
|
||||
"mpjpe": 0.00936222033649683,
|
||||
"wall_seconds": 45.35376596450806
|
||||
}
|
||||
}
|
||||
},
|
||||
"int8_dynamic_quant_report": {
|
||||
"eligible_module_counts": {
|
||||
"nn.Linear": 0,
|
||||
"nn.Conv1d": 21,
|
||||
"nn.Conv2d": 22
|
||||
},
|
||||
"modules_actually_quantized": [],
|
||||
"n_modules_quantized": 0,
|
||||
"params_total": 2225042,
|
||||
"params_quantized": 0,
|
||||
"params_quantized_fraction": 0.0
|
||||
},
|
||||
"accuracy_subset": {
|
||||
"description": "seed-42 file-level 70/15/15 test split, corrupted windows (files 487-499) excluded, seed-42 random subset",
|
||||
"subset_size": 10000,
|
||||
"clean_test_total": 10000
|
||||
}
|
||||
},
|
||||
"onnx": {
|
||||
"env": {
|
||||
"torch": "2.12.0+cpu",
|
||||
"onnxruntime": "1.26.0",
|
||||
"platform": "Windows-11-10.0.26200-SP0"
|
||||
},
|
||||
"export": {
|
||||
"mode": "dynamic-batch",
|
||||
"exporter": "torchscript",
|
||||
"file": "retrained_fp32_dynamic.onnx",
|
||||
"size_mb": 8.971781
|
||||
},
|
||||
"parity": {
|
||||
"fixture": "results/parity_fixture.npz (batch 2, seed 42)",
|
||||
"max_abs_diff_vs_stored_fixture": 2.384185791015625e-07,
|
||||
"max_abs_diff_vs_torch_now": 2.384185791015625e-07,
|
||||
"pass_lt_1e-4": true
|
||||
},
|
||||
"latency": {
|
||||
"batch1": {
|
||||
"batch_size": 1,
|
||||
"runs": 100,
|
||||
"median_ms_per_batch": 2.5410999987798277,
|
||||
"median_ms_per_window": 2.5410999987798277,
|
||||
"windows_per_second": 393.5303610563043
|
||||
},
|
||||
"batch64": {
|
||||
"batch_size": 64,
|
||||
"runs": 30,
|
||||
"median_ms_per_batch": 181.95204999938142,
|
||||
"median_ms_per_window": 2.8430007812403346,
|
||||
"windows_per_second": 351.7410218803118
|
||||
}
|
||||
},
|
||||
"ort_int8_dynamic_supplementary": {
|
||||
"file": "retrained_int8_ort_dynamic.onnx",
|
||||
"size_mb": 2.438794,
|
||||
"runs": true,
|
||||
"max_abs_diff_vs_fp32_fixture": 0.00827130675315857
|
||||
}
|
||||
},
|
||||
"onnx_accuracy": {
|
||||
"onnx_fp32": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9668200004577636,
|
||||
"pck@50": 0.9915333324432373,
|
||||
"mpjpe": 0.00936222568154335,
|
||||
"wall_seconds": 22.34790802001953
|
||||
},
|
||||
"onnx_int8_ort_dynamic": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.965240001964569,
|
||||
"pck@50": 0.9915466655731201,
|
||||
"mpjpe": 0.01108054072111845,
|
||||
"wall_seconds": 55.742953062057495
|
||||
}
|
||||
},
|
||||
"latency_controlled_rerun": {
|
||||
"note": "3 interleaved repetitions per variant, median ms/window; quiet box",
|
||||
"fp32": {
|
||||
"batch1_ms_per_window_median": 10.969150001983508,
|
||||
"batch1_reps": [
|
||||
10.969150001983508,
|
||||
12.646450000829645,
|
||||
10.49820000116597
|
||||
],
|
||||
"batch64_ms_per_window_median": 2.2734187500077496,
|
||||
"batch64_reps": [
|
||||
2.377234374989712,
|
||||
2.124126562478068,
|
||||
2.2734187500077496
|
||||
]
|
||||
},
|
||||
"fp16": {
|
||||
"batch1_ms_per_window_median": 24.313550000442774,
|
||||
"batch1_reps": [
|
||||
25.1078499986761,
|
||||
21.856999999727122,
|
||||
24.313550000442774
|
||||
],
|
||||
"batch64_ms_per_window_median": 2.414695312495496,
|
||||
"batch64_reps": [
|
||||
2.5705156249955508,
|
||||
1.7137437499741281,
|
||||
2.414695312495496
|
||||
]
|
||||
},
|
||||
"int8_dynamic": {
|
||||
"batch1_ms_per_window_median": 15.627150000000256,
|
||||
"batch1_reps": [
|
||||
17.67525000104797,
|
||||
14.627999998992891,
|
||||
15.627150000000256
|
||||
],
|
||||
"batch64_ms_per_window_median": 2.0546906250160646,
|
||||
"batch64_reps": [
|
||||
2.0546906250160646,
|
||||
2.03407343752815,
|
||||
2.9325796875241394
|
||||
]
|
||||
},
|
||||
"onnx_fp32": {
|
||||
"batch1_ms_per_window_median": 3.186650001225644,
|
||||
"batch1_reps": [
|
||||
2.7332500012562377,
|
||||
3.1995500012271805,
|
||||
3.186650001225644
|
||||
],
|
||||
"batch64_ms_per_window_median": 1.9893374999924163,
|
||||
"batch64_reps": [
|
||||
1.5590843750032946,
|
||||
1.9893374999924163,
|
||||
2.2144343749914697
|
||||
]
|
||||
},
|
||||
"onnx_int8_ort_dynamic": {
|
||||
"batch1_ms_per_window_median": 6.50984999811044,
|
||||
"batch1_reps": [
|
||||
6.50984999811044,
|
||||
6.455249998907675,
|
||||
6.789299999581999
|
||||
],
|
||||
"batch64_ms_per_window_median": 5.770093750015803,
|
||||
"batch64_reps": [
|
||||
5.770093750015803,
|
||||
3.912374999970325,
|
||||
7.8067296875019565
|
||||
]
|
||||
}
|
||||
},
|
||||
"onnx_static_ptq": {
|
||||
"env": {
|
||||
"onnxruntime": "1.26.0",
|
||||
"torch": "2.12.0+cpu",
|
||||
"platform": "Windows-11-10.0.26200-SP0",
|
||||
"source_model": "retrained_fp32_dynamic.onnx",
|
||||
"preprocessed_model": {
|
||||
"file": "retrained_fp32_preproc.onnx",
|
||||
"size_mb": 8.981529
|
||||
}
|
||||
},
|
||||
"variants": {
|
||||
"minmax_all": {
|
||||
"file": "retrained_int8_static_minmax_all.onnx",
|
||||
"size_bytes": 2604286,
|
||||
"size_mb": 2.604286,
|
||||
"calibration": {
|
||||
"method": "minmax",
|
||||
"windows": 1000,
|
||||
"percentile": null,
|
||||
"seconds": 5.052440166473389
|
||||
},
|
||||
"scope": "all",
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {
|
||||
"Add": 9,
|
||||
"AveragePool": 1,
|
||||
"BatchNormalization": 12,
|
||||
"Concat": 10,
|
||||
"Conv": 43,
|
||||
"DequantizeLinear": 283,
|
||||
"Einsum": 4,
|
||||
"Gather": 16,
|
||||
"Mul": 39,
|
||||
"QuantizeLinear": 181,
|
||||
"Reshape": 14,
|
||||
"Shape": 2,
|
||||
"Sigmoid": 37,
|
||||
"Slice": 8,
|
||||
"Softmax": 2,
|
||||
"Squeeze": 1,
|
||||
"Transpose": 7,
|
||||
"Unsqueeze": 11
|
||||
},
|
||||
"max_abs_diff_vs_fp32_fixture": 0.015945255756378174,
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9545266661643982,
|
||||
"pck@50": 0.9913666645050049,
|
||||
"mpjpe": 0.014860070134699345,
|
||||
"wall_seconds": 43.455235958099365
|
||||
}
|
||||
},
|
||||
"minmax_conv": {
|
||||
"file": "retrained_int8_static_minmax_conv.onnx",
|
||||
"size_bytes": 2527421,
|
||||
"size_mb": 2.527421,
|
||||
"calibration": {
|
||||
"method": "minmax",
|
||||
"windows": 1000,
|
||||
"percentile": null,
|
||||
"seconds": 4.380746126174927
|
||||
},
|
||||
"scope": "conv",
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {
|
||||
"Add": 9,
|
||||
"AveragePool": 1,
|
||||
"BatchNormalization": 12,
|
||||
"Concat": 10,
|
||||
"Conv": 43,
|
||||
"DequantizeLinear": 156,
|
||||
"Einsum": 4,
|
||||
"Gather": 16,
|
||||
"Mul": 39,
|
||||
"QuantizeLinear": 78,
|
||||
"Reshape": 14,
|
||||
"Shape": 2,
|
||||
"Sigmoid": 37,
|
||||
"Slice": 8,
|
||||
"Softmax": 2,
|
||||
"Squeeze": 1,
|
||||
"Transpose": 7,
|
||||
"Unsqueeze": 11
|
||||
},
|
||||
"max_abs_diff_vs_fp32_fixture": 0.010693132877349854,
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9663399996757507,
|
||||
"pck@50": 0.9918666641235352,
|
||||
"mpjpe": 0.01084446222037077,
|
||||
"wall_seconds": 35.937947034835815
|
||||
}
|
||||
},
|
||||
"entropy_all": {
|
||||
"file": "retrained_int8_static_entropy_all.onnx",
|
||||
"size_bytes": 2604268,
|
||||
"size_mb": 2.604268,
|
||||
"calibration": {
|
||||
"method": "entropy",
|
||||
"windows": 512,
|
||||
"percentile": null,
|
||||
"seconds": 23.835066318511963
|
||||
},
|
||||
"scope": "all",
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {
|
||||
"Add": 9,
|
||||
"AveragePool": 1,
|
||||
"BatchNormalization": 12,
|
||||
"Concat": 10,
|
||||
"Conv": 43,
|
||||
"DequantizeLinear": 283,
|
||||
"Einsum": 4,
|
||||
"Gather": 16,
|
||||
"Mul": 39,
|
||||
"QuantizeLinear": 181,
|
||||
"Reshape": 14,
|
||||
"Shape": 2,
|
||||
"Sigmoid": 37,
|
||||
"Slice": 8,
|
||||
"Softmax": 2,
|
||||
"Squeeze": 1,
|
||||
"Transpose": 7,
|
||||
"Unsqueeze": 11
|
||||
},
|
||||
"max_abs_diff_vs_fp32_fixture": 0.015280365943908691,
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9530466662406921,
|
||||
"pck@50": 0.9912600006103516,
|
||||
"mpjpe": 0.015098519864678382,
|
||||
"wall_seconds": 51.514281034469604
|
||||
}
|
||||
},
|
||||
"entropy_conv": {
|
||||
"file": "retrained_int8_static_entropy_conv.onnx",
|
||||
"size_bytes": 2527403,
|
||||
"size_mb": 2.527403,
|
||||
"calibration": {
|
||||
"method": "entropy",
|
||||
"windows": 512,
|
||||
"percentile": null,
|
||||
"seconds": 9.634419918060303
|
||||
},
|
||||
"scope": "conv",
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {
|
||||
"Add": 9,
|
||||
"AveragePool": 1,
|
||||
"BatchNormalization": 12,
|
||||
"Concat": 10,
|
||||
"Conv": 43,
|
||||
"DequantizeLinear": 156,
|
||||
"Einsum": 4,
|
||||
"Gather": 16,
|
||||
"Mul": 39,
|
||||
"QuantizeLinear": 78,
|
||||
"Reshape": 14,
|
||||
"Shape": 2,
|
||||
"Sigmoid": 37,
|
||||
"Slice": 8,
|
||||
"Softmax": 2,
|
||||
"Squeeze": 1,
|
||||
"Transpose": 7,
|
||||
"Unsqueeze": 11
|
||||
},
|
||||
"max_abs_diff_vs_fp32_fixture": 0.012535125017166138,
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9659599989891052,
|
||||
"pck@50": 0.9918666648864746,
|
||||
"mpjpe": 0.010778637571632861,
|
||||
"wall_seconds": 41.01180171966553
|
||||
}
|
||||
},
|
||||
"percentile_all": {
|
||||
"file": "retrained_int8_static_percentile_all.onnx",
|
||||
"size_bytes": 2604052,
|
||||
"size_mb": 2.604052,
|
||||
"calibration": {
|
||||
"method": "percentile",
|
||||
"windows": 512,
|
||||
"percentile": 99.99,
|
||||
"seconds": 20.221954584121704
|
||||
},
|
||||
"scope": "all",
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {
|
||||
"Add": 9,
|
||||
"AveragePool": 1,
|
||||
"BatchNormalization": 12,
|
||||
"Concat": 10,
|
||||
"Conv": 43,
|
||||
"DequantizeLinear": 283,
|
||||
"Einsum": 4,
|
||||
"Gather": 16,
|
||||
"Mul": 39,
|
||||
"QuantizeLinear": 181,
|
||||
"Reshape": 14,
|
||||
"Shape": 2,
|
||||
"Sigmoid": 37,
|
||||
"Slice": 8,
|
||||
"Softmax": 2,
|
||||
"Squeeze": 1,
|
||||
"Transpose": 7,
|
||||
"Unsqueeze": 11
|
||||
},
|
||||
"max_abs_diff_vs_fp32_fixture": 0.017689883708953857,
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9639333323478698,
|
||||
"pck@50": 0.9916799991607667,
|
||||
"mpjpe": 0.012176512064039708,
|
||||
"wall_seconds": 49.365190744400024
|
||||
}
|
||||
},
|
||||
"percentile_conv": {
|
||||
"file": "retrained_int8_static_percentile_conv.onnx",
|
||||
"size_bytes": 2527241,
|
||||
"size_mb": 2.527241,
|
||||
"calibration": {
|
||||
"method": "percentile",
|
||||
"windows": 512,
|
||||
"percentile": 99.99,
|
||||
"seconds": 8.223475694656372
|
||||
},
|
||||
"scope": "conv",
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {
|
||||
"Add": 9,
|
||||
"AveragePool": 1,
|
||||
"BatchNormalization": 12,
|
||||
"Concat": 10,
|
||||
"Conv": 43,
|
||||
"DequantizeLinear": 156,
|
||||
"Einsum": 4,
|
||||
"Gather": 16,
|
||||
"Mul": 39,
|
||||
"QuantizeLinear": 78,
|
||||
"Reshape": 14,
|
||||
"Shape": 2,
|
||||
"Sigmoid": 37,
|
||||
"Slice": 8,
|
||||
"Softmax": 2,
|
||||
"Squeeze": 1,
|
||||
"Transpose": 7,
|
||||
"Unsqueeze": 11
|
||||
},
|
||||
"max_abs_diff_vs_fp32_fixture": 0.014725983142852783,
|
||||
"accuracy": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9660599988937378,
|
||||
"pck@50": 0.9916066654205322,
|
||||
"mpjpe": 0.010310938355326652,
|
||||
"wall_seconds": 36.89548587799072
|
||||
}
|
||||
}
|
||||
},
|
||||
"latency": {
|
||||
"note": "3 interleaved repetitions per variant, median ms/window; onnx_fp32 / onnx_int8_ort_dynamic are same-session references",
|
||||
"onnx_fp32": {
|
||||
"batch1_reps": [
|
||||
4.5327999996516155,
|
||||
2.535649999117595,
|
||||
2.167549997466267
|
||||
],
|
||||
"batch64_reps": [
|
||||
1.9354515624740998,
|
||||
2.4948054687854437,
|
||||
1.9334703125082342
|
||||
],
|
||||
"batch1_ms_per_window_median": 2.535649999117595,
|
||||
"batch64_ms_per_window_median": 1.9354515624740998
|
||||
},
|
||||
"onnx_int8_ort_dynamic": {
|
||||
"batch1_reps": [
|
||||
5.698599999959697,
|
||||
5.721350000385428,
|
||||
4.805099997611251
|
||||
],
|
||||
"batch64_reps": [
|
||||
4.096601562508795,
|
||||
4.857628124995017,
|
||||
4.583800000006022
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.698599999959697,
|
||||
"batch64_ms_per_window_median": 4.583800000006022
|
||||
},
|
||||
"entropy_all": {
|
||||
"batch1_reps": [
|
||||
6.444149999879301,
|
||||
5.038299999796436,
|
||||
5.713200000172947
|
||||
],
|
||||
"batch64_reps": [
|
||||
4.149468750028973,
|
||||
3.437125000004926,
|
||||
4.410960937491382
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.713200000172947,
|
||||
"batch64_ms_per_window_median": 4.149468750028973
|
||||
},
|
||||
"entropy_conv": {
|
||||
"batch1_reps": [
|
||||
4.874750000453787,
|
||||
5.169099998965976,
|
||||
5.236699998931726
|
||||
],
|
||||
"batch64_reps": [
|
||||
3.010160156236452,
|
||||
3.1175546875203963,
|
||||
3.516850781238645
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.169099998965976,
|
||||
"batch64_ms_per_window_median": 3.1175546875203963
|
||||
},
|
||||
"percentile_all": {
|
||||
"batch1_reps": [
|
||||
5.184749999898486,
|
||||
5.2898499998264015,
|
||||
5.916899999647285
|
||||
],
|
||||
"batch64_reps": [
|
||||
4.305105468745296,
|
||||
4.460741406262514,
|
||||
4.184502343747454
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.2898499998264015,
|
||||
"batch64_ms_per_window_median": 4.305105468745296
|
||||
},
|
||||
"percentile_conv": {
|
||||
"batch1_reps": [
|
||||
4.916449999655015,
|
||||
7.150899999032845,
|
||||
5.284949998895172
|
||||
],
|
||||
"batch64_reps": [
|
||||
3.855813281262499,
|
||||
4.688969531230214,
|
||||
5.220103124997877
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.284949998895172,
|
||||
"batch64_ms_per_window_median": 4.688969531230214
|
||||
},
|
||||
"minmax_all": {
|
||||
"batch1_reps": [
|
||||
6.463300000177696,
|
||||
7.149449998905766,
|
||||
5.3209000016067876
|
||||
],
|
||||
"batch64_reps": [
|
||||
3.9251343750095202,
|
||||
4.033442187505898,
|
||||
3.428199218745931
|
||||
],
|
||||
"batch1_ms_per_window_median": 6.463300000177696,
|
||||
"batch64_ms_per_window_median": 3.9251343750095202
|
||||
},
|
||||
"minmax_conv": {
|
||||
"batch1_reps": [
|
||||
5.9961499991914025,
|
||||
5.236549999608542,
|
||||
4.854399998293957
|
||||
],
|
||||
"batch64_reps": [
|
||||
4.368359375007458,
|
||||
3.249617187492504,
|
||||
3.0238906249735464
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.236549999608542,
|
||||
"batch64_ms_per_window_median": 3.249617187492504
|
||||
}
|
||||
},
|
||||
"accuracy_subset": {
|
||||
"description": "seed-42 file-level 70/15/15 test split, corrupted windows excluded, seed-42 random subset (same as quantize_bench/eval_ort_accuracy)",
|
||||
"subset_size": 10000
|
||||
}
|
||||
},
|
||||
"tiny_variant": {
|
||||
"env": {
|
||||
"torch": "2.12.0+cpu",
|
||||
"onnxruntime": "1.26.0",
|
||||
"platform": "Windows-11-10.0.26200-SP0",
|
||||
"num_threads": 16,
|
||||
"checkpoint": "results\\tiny_best.pth",
|
||||
"checkpoint_size_bytes": 340555,
|
||||
"params": 56290,
|
||||
"variant_config": {
|
||||
"tcn": [
|
||||
68,
|
||||
56,
|
||||
44,
|
||||
32
|
||||
],
|
||||
"conv": [
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16
|
||||
],
|
||||
"attn_groups": 2,
|
||||
"groups_mode": "depthwise",
|
||||
"input_pw_groups": 4
|
||||
}
|
||||
},
|
||||
"export": {
|
||||
"mode": "dynamic-batch",
|
||||
"exporter": "torchscript",
|
||||
"opset": 17,
|
||||
"file": "tiny_fp32_dynamic.onnx",
|
||||
"size_bytes": 295279,
|
||||
"size_mb": 0.295279,
|
||||
"verified_batches": [
|
||||
1,
|
||||
2,
|
||||
64
|
||||
],
|
||||
"note": "AdaptiveAvgPool2d((15,1)) replaced at export by an exact mean(-1) + constant averaging matmul (final_width 16 is not a multiple of 15, which the TorchScript exporter rejects); exactness proven by the parity check vs the original torch model"
|
||||
},
|
||||
"parity": {
|
||||
"fixture": "results/parity_fixture.npz input (batch 2, seed 42); reference output recomputed with the tiny torch model",
|
||||
"max_abs_diff_vs_torch": 1.4901161193847656e-07,
|
||||
"pass_lt_1e-4": true
|
||||
},
|
||||
"int8_static_percentile_conv": {
|
||||
"file": "tiny_int8_static_percentile_conv.onnx",
|
||||
"size_bytes": 248278,
|
||||
"size_mb": 0.248278,
|
||||
"calibration": {
|
||||
"method": "percentile",
|
||||
"percentile": 99.99,
|
||||
"windows": 512,
|
||||
"scope": "conv-only TRAIN-split corruption-free",
|
||||
"seconds": 1.5347836017608643
|
||||
},
|
||||
"per_channel": true,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"max_abs_diff_vs_fp32_fixture": 0.018491357564926147
|
||||
},
|
||||
"latency": {
|
||||
"note": "3 interleaved repetitions per variant, median ms/window; full-model sessions are same-session references",
|
||||
"tiny_onnx_fp32": {
|
||||
"batch1_reps": [
|
||||
0.6312500008789357,
|
||||
0.6834500018157996,
|
||||
0.6595999984710943
|
||||
],
|
||||
"batch64_reps": [
|
||||
0.37747578119251557,
|
||||
0.24196640623586063,
|
||||
0.2314671875183194
|
||||
],
|
||||
"batch1_ms_per_window_median": 0.6595999984710943,
|
||||
"batch64_ms_per_window_median": 0.24196640623586063
|
||||
},
|
||||
"tiny_onnx_int8_static_percentile_conv": {
|
||||
"batch1_reps": [
|
||||
0.7988500001374632,
|
||||
0.9382499993080273,
|
||||
0.8451000030618161
|
||||
],
|
||||
"batch64_reps": [
|
||||
0.9211476562995813,
|
||||
1.3045390625165965,
|
||||
1.026230468767153
|
||||
],
|
||||
"batch1_ms_per_window_median": 0.8451000030618161,
|
||||
"batch64_ms_per_window_median": 1.026230468767153
|
||||
},
|
||||
"full_onnx_fp32_reference": {
|
||||
"batch1_reps": [
|
||||
2.267249998112675,
|
||||
2.80170000041835,
|
||||
2.132149998942623
|
||||
],
|
||||
"batch64_reps": [
|
||||
1.3050578124875756,
|
||||
1.4244992187855132,
|
||||
1.8014164062947202
|
||||
],
|
||||
"batch1_ms_per_window_median": 2.267249998112675,
|
||||
"batch64_ms_per_window_median": 1.4244992187855132
|
||||
},
|
||||
"full_onnx_int8_static_percentile_conv_reference": {
|
||||
"batch1_reps": [
|
||||
5.529599999135826,
|
||||
4.768399998283712,
|
||||
6.215800000063609
|
||||
],
|
||||
"batch64_reps": [
|
||||
3.815724218725336,
|
||||
3.1025562500417436,
|
||||
4.333318749957016
|
||||
],
|
||||
"batch1_ms_per_window_median": 5.529599999135826,
|
||||
"batch64_ms_per_window_median": 3.815724218725336
|
||||
}
|
||||
},
|
||||
"accuracy_subset": {
|
||||
"description": "seed-42 file-level 70/15/15 test split, corrupted windows excluded, seed-42 random subset (same as quantize_bench/eval_ort_accuracy/static_ptq_bench)",
|
||||
"subset_size": 10000
|
||||
},
|
||||
"accuracy": {
|
||||
"tiny_onnx_fp32": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.941106667804718,
|
||||
"pck@50": 0.99369333152771,
|
||||
"mpjpe": 0.012527281279861927,
|
||||
"wall_seconds": 10.927234888076782
|
||||
},
|
||||
"tiny_onnx_int8_static_percentile_conv": {
|
||||
"samples": 10000,
|
||||
"pck@20": 0.9268133331298828,
|
||||
"pck@50": 0.9932933319091797,
|
||||
"mpjpe": 0.014906252065300942,
|
||||
"wall_seconds": 12.320892333984375
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
{"variant": "half", "params": 843834, "tcn_channels": [270, 220, 170, 120], "conv_channels": [4, 8, 16, 32], "attn_groups": 4, "groups_mode": "gcd20", "input_pw_groups": 1, "tcn_groups_per_block": [[20, 10], [10, 20], [20, 10], [10, 20]], "conv_strides": [2, 2, 2, 1], "final_width": 15, "batch_size": 64, "max_epochs": 50, "patience": 5, "lr": 0.0001, "weight_decay": 5e-05, "seed": 42, "precision": "fp32", "epochs_run": 28, "best_epoch": 23, "best_val_mpjpe": 0.008576328293592842, "best_val_pck20": 0.9690593021534107, "train_seconds": 1346.4, "torch": "2.11.0+cu128", "error": null, "finished_utc": "2026-06-11T03:09:47Z", "checkpoint": "/home/ruvultra/wiflow-std-bench/sweep/half_best.pth", "test_full": {"samples": 54000, "mpjpe": 0.009419974447676428, "pck@10": 0.8740543655289544, "pck@20": 0.9610469643628156, "pck@30": 0.9813556064146537, "pck@40": 0.9896086878246731, "pck@50": 0.9934827546013726}, "test_clean": {"samples": 52560, "mpjpe": 0.008980081718602137, "pck@10": 0.8840944136840205, "pck@20": 0.9662253179869514, "pck@30": 0.9847971080282144, "pck@40": 0.9917795997050618, "pck@50": 0.9946956242600532}}
|
||||
{"variant": "quarter", "params": 338600, "tcn_channels": [135, 110, 85, 60], "conv_channels": [2, 4, 8, 16], "attn_groups": 2, "groups_mode": "gcd20", "input_pw_groups": 1, "tcn_groups_per_block": [[20, 5], [5, 10], [10, 5], [5, 20]], "conv_strides": [2, 2, 1, 1], "final_width": 15, "batch_size": 64, "max_epochs": 50, "patience": 5, "lr": 0.0001, "weight_decay": 5e-05, "seed": 42, "precision": "fp32", "epochs_run": 50, "best_epoch": 50, "best_val_mpjpe": 0.008780752391864856, "best_val_pck20": 0.9672531302240159, "train_seconds": 1754.4, "torch": "2.11.0+cu128", "error": null, "finished_utc": "2026-06-11T03:39:06Z", "checkpoint": "/home/ruvultra/wiflow-std-bench/sweep/quarter_best.pth", "test_full": {"samples": 54000, "mpjpe": 0.009705399298005634, "pck@10": 0.8646123917014511, "pck@20": 0.9553815319449813, "pck@30": 0.979827209190086, "pck@40": 0.9887037501511751, "pck@50": 0.9931309027671814}, "test_clean": {"samples": 52560, "mpjpe": 0.009279253277105465, "pck@10": 0.8742288637923323, "pck@20": 0.9605315079427745, "pck@30": 0.9833016723076865, "pck@40": 0.9908206971631566, "pck@50": 0.9942719799017071}}
|
||||
{"variant": "tiny", "params": 56290, "tcn_channels": [68, 56, 44, 32], "conv_channels": [2, 4, 8, 16], "attn_groups": 2, "groups_mode": "depthwise", "input_pw_groups": 4, "tcn_groups_per_block": [[540, 68], [68, 56], [56, 44], [44, 32]], "conv_strides": [2, 1, 1, 1], "final_width": 16, "batch_size": 64, "max_epochs": 50, "patience": 5, "lr": 0.0001, "weight_decay": 5e-05, "seed": 42, "precision": "fp32", "epochs_run": 50, "best_epoch": 47, "best_val_mpjpe": 0.012602971208592256, "best_val_pck20": 0.9397210340146666, "train_seconds": 1540.1, "torch": "2.11.0+cu128", "error": null, "finished_utc": "2026-06-11T04:04:50Z", "checkpoint": "/home/ruvultra/wiflow-std-bench/sweep/tiny_best.pth", "test_full": {"samples": 54000, "mpjpe": 0.012859782406853305, "pck@10": 0.7640358444319831, "pck@20": 0.9364815320968628, "pck@30": 0.9731568422317505, "pck@40": 0.9866444962642811, "pck@50": 0.992488939108672}, "test_clean": {"samples": 52560, "mpjpe": 0.012502924276904246, "pck@10": 0.770895526488985, "pck@20": 0.9411073559313967, "pck@30": 0.9764840687790962, "pck@40": 0.9886695077067278, "pck@50": 0.9936238432039409}}
|
||||
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"checkpoint": "/home/ruvultra/wiflow-std-bench/upstream/test/best_pose_model.pth",
|
||||
"test_full": {
|
||||
"samples": 54000,
|
||||
"mpjpe": 0.009834060806367133,
|
||||
"pck@10": 0.8686346120127925,
|
||||
"pck@20": 0.9608815324571398,
|
||||
"pck@30": 0.9789111610695168,
|
||||
"pck@40": 0.9857975759682832,
|
||||
"pck@50": 0.9898827553325229
|
||||
},
|
||||
"test_clean": {
|
||||
"samples": 52560,
|
||||
"mpjpe": 0.009432755044379373,
|
||||
"pck@10": 0.876996495807189,
|
||||
"pck@20": 0.9661454100405608,
|
||||
"pck@30": 0.9823453060205306,
|
||||
"pck@40": 0.987909734176537,
|
||||
"pck@50": 0.9911238361167036
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"published": {
|
||||
"pck@20": 0.9725,
|
||||
"pck@30": 0.9863,
|
||||
"pck@40": 0.9916,
|
||||
"pck@50": 0.9948,
|
||||
"mpjpe": 0.007
|
||||
},
|
||||
"params_millions": 2.225042,
|
||||
"data_dir": "C:\\Users\\ruv\\.cache\\kagglehub\\datasets\\kaka2434\\wiflow-dataset\\versions\\1\\preprocessed_csi_data",
|
||||
"device": "cpu",
|
||||
"test_full": {
|
||||
"samples": 54000,
|
||||
"mpjpe": NaN,
|
||||
"pck@10": 5.6790124349020145e-05,
|
||||
"pck@20": 0.0007876543271596785,
|
||||
"pck@30": 0.007780246982971827,
|
||||
"pck@40": 0.05529259262923841,
|
||||
"pck@50": 0.1542370371548114,
|
||||
"wall_seconds": 118.03756999969482
|
||||
},
|
||||
"test_drop_last": {
|
||||
"samples": 53952,
|
||||
"mpjpe": NaN,
|
||||
"pck@10": 5.6840649370682976e-05,
|
||||
"pck@20": 0.0007883550872372227,
|
||||
"pck@30": 0.007787168910892621,
|
||||
"pck@40": 0.055318307667895535,
|
||||
"pck@50": 0.15425316342412276,
|
||||
"wall_seconds": 120.87458372116089
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,333 @@
|
||||
"""ADR-152 edge optimization follow-up: ONNX Runtime STATIC post-training
|
||||
quantization (calibration-based QDQ) of the retrained WiFlow-STD model, to
|
||||
improve on the dynamic-int8 result (2.44 MB, PCK@20 96.52%, 6.5 ms/win b1).
|
||||
|
||||
Static PTQ pre-computes activation ranges from calibration data, so inference
|
||||
uses QLinearConv/QDQ kernels instead of dynamic ConvInteger -- typically both
|
||||
faster and (with good calibration) closer to fp32 accuracy.
|
||||
|
||||
Method:
|
||||
- Calibration set: corruption-free windows drawn ONLY from the seed-42
|
||||
file-level TRAINING split (same split as eval_repro.py; corrupted windows
|
||||
excluded via results/nan_windows_mask.npy | big_windows_mask.npy), chosen
|
||||
with np.random.default_rng(42). Never test windows.
|
||||
- quantize_static, QuantFormat.QDQ, per-channel int8 weights, int8
|
||||
activations; calibration methods MinMax / Entropy / Percentile(99.99);
|
||||
scopes "all" (ORT default op set) vs "conv" (op_types_to_quantize=
|
||||
["Conv"] -- leaves the attention path, which exports as Einsum/Softmax
|
||||
and elementwise ops, in fp32).
|
||||
- Model is pre-processed first (quant_pre_process: symbolic shape
|
||||
inference + ORT graph optimization, folds BatchNormalization into Conv).
|
||||
- Accuracy: identical protocol to eval_ort_accuracy.py -- the 10,000-window
|
||||
seed-42 subset of the corruption-free test split (PCK@20/50, MPJPE).
|
||||
- Latency: median ms/window at batch 1 (100 runs) and batch 64 (30 runs),
|
||||
3 interleaved repetitions across all variants (fp32 and dynamic-int8
|
||||
sessions included as same-session reference points).
|
||||
|
||||
Usage:
|
||||
PYTHONUTF8=1 .venv/Scripts/python.exe static_ptq_bench.py \
|
||||
[--data-dir <preprocessed_csi_data>] [--subset 10000]
|
||||
[--calib-minmax 1000] [--calib-hist 512] [--skip-accuracy]
|
||||
|
||||
Writes/merges into results/edge_optimization.json under key "onnx_static_ptq".
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, HERE)
|
||||
|
||||
from _bench_common import RESULTS # noqa: E402
|
||||
# quantize_bench sets up upstream imports + the np.load mmap patch
|
||||
# (both via _bench_common.import_upstream)
|
||||
from quantize_bench import build_test_subset # noqa: E402
|
||||
import quantize_bench as qb # noqa: E402
|
||||
from eval_ort_accuracy import evaluate_ort # noqa: E402
|
||||
|
||||
FP32_ONNX = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
|
||||
DYN_INT8_ONNX = os.path.join(RESULTS, "retrained_int8_ort_dynamic.onnx")
|
||||
PREPROC_ONNX = os.path.join(RESULTS, "retrained_fp32_preproc.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calibration data: corruption-free TRAINING-split windows only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_calibration_windows(data_dir, n_windows):
|
||||
"""Seed-42 file-level 70/15/15 TRAIN split (exactly as eval_repro.py),
|
||||
minus corrupted windows, then a seed-42 random draw of n_windows."""
|
||||
dataset = qb.PreprocessedCSIKeypointsDataset(
|
||||
data_dir=data_dir, keypoint_scale=1000.0, enable_temporal_clean=True)
|
||||
train_loader, _va, _te = qb.create_preprocessed_train_val_test_loaders(
|
||||
dataset=dataset, batch_size=64, num_workers=0, random_seed=42)
|
||||
train_indices = np.asarray(train_loader.dataset.indices)
|
||||
|
||||
corrupted = (np.load(os.path.join(RESULTS, "nan_windows_mask.npy"))
|
||||
| np.load(os.path.join(RESULTS, "big_windows_mask.npy")))
|
||||
clean = train_indices[~corrupted[train_indices]]
|
||||
print(f"train split: {len(train_indices)} windows, "
|
||||
f"{len(train_indices) - len(clean)} corrupted excluded, "
|
||||
f"{len(clean)} clean")
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
sel = np.sort(rng.choice(clean, size=n_windows, replace=False))
|
||||
xs = np.stack([dataset[int(i)][0].numpy() for i in sel]).astype(np.float32)
|
||||
print(f"calibration tensor: {xs.shape} from {n_windows} clean TRAIN windows")
|
||||
return xs
|
||||
|
||||
|
||||
def make_reader(windows, batch_size=64):
|
||||
from onnxruntime.quantization import CalibrationDataReader
|
||||
|
||||
class WindowReader(CalibrationDataReader):
|
||||
def __init__(self):
|
||||
self._batches = [windows[i:i + batch_size]
|
||||
for i in range(0, len(windows), batch_size)]
|
||||
self._it = iter(self._batches)
|
||||
|
||||
def get_next(self):
|
||||
b = next(self._it, None)
|
||||
return None if b is None else {"input": b}
|
||||
|
||||
def rewind(self):
|
||||
self._it = iter(self._batches)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._batches)
|
||||
|
||||
return WindowReader()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# quantization variants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def preprocess_model():
|
||||
from onnxruntime.quantization.shape_inference import quant_pre_process
|
||||
quant_pre_process(FP32_ONNX, PREPROC_ONNX)
|
||||
return PREPROC_ONNX
|
||||
|
||||
|
||||
def quantize_variant(src, dst, method, scope, calib_windows):
|
||||
from onnxruntime.quantization import (CalibrationMethod, QuantFormat,
|
||||
QuantType, quantize_static)
|
||||
methods = {
|
||||
"minmax": CalibrationMethod.MinMax,
|
||||
"entropy": CalibrationMethod.Entropy,
|
||||
"percentile": CalibrationMethod.Percentile,
|
||||
}
|
||||
# NB: do NOT pass CalibMaxIntermediateOutputs -- in ORT 1.26 the MinMax
|
||||
# calibrater clears its buffer every N batches and then raises
|
||||
# "No data is collected" if the batch count is divisible by N.
|
||||
extra = {}
|
||||
if method == "percentile":
|
||||
extra["CalibPercentile"] = 99.99
|
||||
op_types = ["Conv"] if scope == "conv" else None
|
||||
|
||||
t0 = time.time()
|
||||
quantize_static(
|
||||
src, dst, make_reader(calib_windows),
|
||||
quant_format=QuantFormat.QDQ,
|
||||
op_types_to_quantize=op_types,
|
||||
per_channel=True,
|
||||
activation_type=QuantType.QInt8,
|
||||
weight_type=QuantType.QInt8,
|
||||
calibrate_method=methods[method],
|
||||
extra_options=extra,
|
||||
)
|
||||
secs = time.time() - t0
|
||||
|
||||
import onnx
|
||||
ops = collections.Counter(n.op_type for n in onnx.load(dst).graph.node)
|
||||
return {
|
||||
"file": os.path.basename(dst),
|
||||
"size_bytes": os.path.getsize(dst),
|
||||
"size_mb": os.path.getsize(dst) / 1e6,
|
||||
"calibration": {"method": method,
|
||||
"windows": int(len(calib_windows)),
|
||||
"percentile": extra.get("CalibPercentile"),
|
||||
"seconds": secs},
|
||||
"scope": scope,
|
||||
"per_channel": True,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
"node_counts": {k: v for k, v in sorted(ops.items())},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# latency (3 interleaved reps, like the latency_controlled_rerun)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def ort_session(path):
|
||||
import onnxruntime as ort
|
||||
return ort.InferenceSession(path, providers=["CPUExecutionProvider"])
|
||||
|
||||
|
||||
def bench_ort(sess, batch, n_runs):
|
||||
rng = np.random.default_rng(123)
|
||||
x = rng.random((batch, 540, 20), dtype=np.float32)
|
||||
inp = sess.get_inputs()[0].name
|
||||
for _ in range(max(5, n_runs // 10)):
|
||||
sess.run(None, {inp: x})
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
t0 = time.perf_counter()
|
||||
sess.run(None, {inp: x})
|
||||
times.append(time.perf_counter() - t0)
|
||||
return statistics.median(times) * 1e3 / batch # ms/window
|
||||
|
||||
|
||||
def interleaved_latency(sessions, reps=3, runs_b1=100, runs_b64=30):
|
||||
lat = {name: {"batch1_reps": [], "batch64_reps": []} for name in sessions}
|
||||
for rep in range(reps):
|
||||
for name, sess in sessions.items():
|
||||
lat[name]["batch1_reps"].append(bench_ort(sess, 1, runs_b1))
|
||||
lat[name]["batch64_reps"].append(bench_ort(sess, 64, runs_b64))
|
||||
print(f" rep {rep + 1}/{reps} {name}: "
|
||||
f"b1={lat[name]['batch1_reps'][-1]:.2f} "
|
||||
f"b64={lat[name]['batch64_reps'][-1]:.3f} ms/win", flush=True)
|
||||
for name in lat:
|
||||
lat[name]["batch1_ms_per_window_median"] = statistics.median(
|
||||
lat[name]["batch1_reps"])
|
||||
lat[name]["batch64_ms_per_window_median"] = statistics.median(
|
||||
lat[name]["batch64_reps"])
|
||||
return lat
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
import onnxruntime
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", default=os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
|
||||
"wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
|
||||
parser.add_argument("--subset", type=int, default=10000)
|
||||
parser.add_argument("--calib-minmax", type=int, default=1000)
|
||||
parser.add_argument("--calib-hist", type=int, default=512,
|
||||
help="calibration windows for Entropy/Percentile "
|
||||
"(histogram calibraters hold all intermediate "
|
||||
"activations in RAM)")
|
||||
parser.add_argument("--skip-accuracy", action="store_true")
|
||||
parser.add_argument("--methods", default="minmax,entropy,percentile",
|
||||
help="comma list of calibration methods to (re)run; "
|
||||
"results merge into existing onnx_static_ptq")
|
||||
parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
|
||||
args = parser.parse_args()
|
||||
|
||||
results = {
|
||||
"env": {
|
||||
"onnxruntime": onnxruntime.__version__,
|
||||
"torch": torch.__version__,
|
||||
"platform": platform.platform(),
|
||||
"source_model": os.path.basename(FP32_ONNX),
|
||||
},
|
||||
"variants": {},
|
||||
}
|
||||
|
||||
# ---- calibration data (TRAIN split only) -------------------------------
|
||||
calib_mm = build_calibration_windows(args.data_dir, args.calib_minmax)
|
||||
calib_hist = calib_mm[:args.calib_hist]
|
||||
|
||||
# ---- preprocess + quantize ---------------------------------------------
|
||||
print("\n=== quant_pre_process (shape inference + graph optimization) ===")
|
||||
src = preprocess_model()
|
||||
results["env"]["preprocessed_model"] = {
|
||||
"file": os.path.basename(src),
|
||||
"size_mb": os.path.getsize(src) / 1e6,
|
||||
}
|
||||
|
||||
matrix = [(m, s) for m in args.methods.split(",")
|
||||
for s in ("all", "conv")]
|
||||
for method, scope in matrix:
|
||||
name = f"{method}_{scope}"
|
||||
dst = os.path.join(RESULTS, f"retrained_int8_static_{name}.onnx")
|
||||
calib = calib_mm if method == "minmax" else calib_hist
|
||||
print(f"\n=== quantize_static: {name} "
|
||||
f"({len(calib)} calib windows) ===", flush=True)
|
||||
try:
|
||||
results["variants"][name] = quantize_variant(
|
||||
src, dst, method, scope, calib)
|
||||
print(f" {results['variants'][name]['size_mb']:.3f} MB")
|
||||
except Exception as e: # noqa: BLE001
|
||||
results["variants"][name] = {"error": f"{type(e).__name__}: {e}"}
|
||||
print(f" FAILED: {e}")
|
||||
|
||||
# ---- fixture parity (sanity, batch 2) ----------------------------------
|
||||
fixture = np.load(os.path.join(RESULTS, "parity_fixture.npz"))
|
||||
fx, fy = fixture["input"], fixture["output"]
|
||||
sessions = {}
|
||||
for name, info in results["variants"].items():
|
||||
if "error" in info:
|
||||
continue
|
||||
path = os.path.join(RESULTS, info["file"])
|
||||
try:
|
||||
sess = ort_session(path)
|
||||
yq = sess.run(None, {sess.get_inputs()[0].name: fx})[0]
|
||||
info["max_abs_diff_vs_fp32_fixture"] = float(np.abs(yq - fy).max())
|
||||
sessions[name] = sess
|
||||
except Exception as e: # noqa: BLE001
|
||||
info["run_error"] = f"{type(e).__name__}: {e}"
|
||||
print("\nfixture max-abs-diff vs fp32:",
|
||||
{n: round(results["variants"][n].get("max_abs_diff_vs_fp32_fixture",
|
||||
float("nan")), 5)
|
||||
for n in results["variants"]})
|
||||
|
||||
# ---- latency: 3 interleaved reps incl. fp32 + dynamic-int8 reference ----
|
||||
print("\n=== latency (3 interleaved reps) ===")
|
||||
lat_sessions = {"onnx_fp32": ort_session(FP32_ONNX),
|
||||
"onnx_int8_ort_dynamic": ort_session(DYN_INT8_ONNX)}
|
||||
lat_sessions.update(sessions)
|
||||
results["latency"] = {
|
||||
"note": "3 interleaved repetitions per variant, median ms/window; "
|
||||
"onnx_fp32 / onnx_int8_ort_dynamic are same-session references",
|
||||
**interleaved_latency(lat_sessions),
|
||||
}
|
||||
|
||||
# ---- accuracy on the standard 10k corruption-free test subset ----------
|
||||
if not args.skip_accuracy:
|
||||
loader, n_clean = build_test_subset(args.data_dir, args.subset)
|
||||
results["accuracy_subset"] = {
|
||||
"description": "seed-42 file-level 70/15/15 test split, corrupted "
|
||||
"windows excluded, seed-42 random subset (same as "
|
||||
"quantize_bench/eval_ort_accuracy)",
|
||||
"subset_size": min(args.subset, n_clean) if args.subset else n_clean,
|
||||
}
|
||||
for name, sess in sessions.items():
|
||||
print(f"\n=== accuracy: {name} ===")
|
||||
results["variants"][name]["accuracy"] = evaluate_ort(
|
||||
sess, loader, name)
|
||||
print(json.dumps(results["variants"][name]["accuracy"], indent=2))
|
||||
|
||||
# ---- merge into edge_optimization.json ----------------------------------
|
||||
merged = {}
|
||||
if os.path.exists(args.out):
|
||||
with open(args.out) as f:
|
||||
merged = json.load(f)
|
||||
prev = merged.get("onnx_static_ptq")
|
||||
if prev: # nested merge so partial --methods reruns don't clobber
|
||||
prev["env"] = results["env"]
|
||||
prev["variants"].update(results["variants"])
|
||||
prev.setdefault("latency", {}).update(results["latency"])
|
||||
if "accuracy_subset" in results:
|
||||
prev["accuracy_subset"] = results["accuracy_subset"]
|
||||
else:
|
||||
merged["onnx_static_ptq"] = results
|
||||
with open(args.out, "w") as f:
|
||||
json.dump(merged, f, indent=2)
|
||||
print(f"\nwrote {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,313 @@
|
||||
"""ADR-152 efficiency-sweep follow-up: edge pipeline for the TINY compact
|
||||
WiFlow-STD variant (56,290 params, results/tiny_best.pth, trained overnight
|
||||
2026-06-10/11 -- see RESULTS.md "Efficiency sweep").
|
||||
|
||||
Headline question: what does the smallest deployable WiFlow-class model look
|
||||
like (KB + ms + PCK)? Reuses the onnx_bench.py / static_ptq_bench.py
|
||||
machinery on the tiny checkpoint:
|
||||
|
||||
1. Load tiny_best.pth with remote/sweep/model_compact.py
|
||||
(depthwise TCN groups, input_pw_groups=4, conv [2,4,8,16], attn groups 2).
|
||||
2. Export ONNX: dynamic batch, opset 17, TorchScript exporter (dynamo=False)
|
||||
-- same recipe that worked for the full model; verified at batch 1/2/64.
|
||||
One forced deviation: tiny's stride schedule [2,1,1,1] leaves final_width
|
||||
16, and the TorchScript exporter cannot export AdaptiveAvgPool2d((15,1))
|
||||
when 15 is not a factor of the input height (the full model never hit
|
||||
this -- its width was exactly 15). The adaptive pool over a fixed-size
|
||||
feature map is a fixed linear map, so the export wrapper replaces it with
|
||||
an exact matmul equivalent (PyTorch adaptive-pool bin semantics:
|
||||
bin i averages rows floor(i*H/K)..ceil((i+1)*H/K)); the W axis (20->1,
|
||||
a factor) becomes mean(-1). Exactness is proven by the parity check
|
||||
below, which compares against the ORIGINAL torch model with the real
|
||||
AdaptiveAvgPool2d.
|
||||
3. Torch-vs-ORT parity on the stored fixture input
|
||||
(results/parity_fixture.npz, batch 2, seed 42 -- same 540x20 input layout;
|
||||
reference output recomputed with the tiny torch model). PASS < 1e-4.
|
||||
4. Static QDQ conv-only int8 (quant_pre_process + quantize_static,
|
||||
per-channel QInt8 weights+activations, Percentile(99.99) calibration on
|
||||
512 corruption-free TRAIN-split windows -- the winning recipe and
|
||||
calibration count from static_ptq_bench.py. 512, not "about 500":
|
||||
ORT 1.26's histogram collector np.asarray()'s the per-batch maxima, so
|
||||
the calibration count must be a multiple of the batch size 64 or the
|
||||
ragged last batch crashes it).
|
||||
5. Disk size + CPU latency b1/b64 (3 interleaved reps, median ms/window)
|
||||
for tiny fp32 + tiny int8, with the full-model ONNX fp32 + static-int8
|
||||
sessions interleaved as same-session references.
|
||||
6. Accuracy (PCK@20/50 + MPJPE) on the identical 10k-window seed-42
|
||||
corruption-free test subset for tiny fp32 + tiny int8.
|
||||
|
||||
Usage:
|
||||
PYTHONUTF8=1 .venv/Scripts/python.exe tiny_edge_bench.py \
|
||||
[--data-dir <preprocessed_csi_data>] [--subset 10000] [--calib 512]
|
||||
(--calib must be a multiple of 64; see step 4 above)
|
||||
|
||||
Writes/merges into results/edge_optimization.json under key "tiny_variant".
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
RESULTS = os.path.join(HERE, "results")
|
||||
sys.path.insert(0, HERE)
|
||||
sys.path.insert(0, os.path.join(HERE, "remote", "sweep"))
|
||||
|
||||
# quantize_bench sets up upstream imports + the np.load mmap patch
|
||||
from quantize_bench import build_test_subset # noqa: E402
|
||||
from eval_ort_accuracy import evaluate_ort # noqa: E402
|
||||
from static_ptq_bench import ( # noqa: E402
|
||||
build_calibration_windows,
|
||||
interleaved_latency,
|
||||
make_reader,
|
||||
ort_session,
|
||||
)
|
||||
from model_compact import CompactWiFlowPoseModel, describe # noqa: E402
|
||||
|
||||
TINY_CKPT = os.path.join(RESULTS, "tiny_best.pth")
|
||||
TINY_FP32_ONNX = os.path.join(RESULTS, "tiny_fp32_dynamic.onnx")
|
||||
TINY_PREPROC_ONNX = os.path.join(RESULTS, "tiny_fp32_preproc.onnx")
|
||||
TINY_INT8_ONNX = os.path.join(RESULTS, "tiny_int8_static_percentile_conv.onnx")
|
||||
FULL_FP32_ONNX = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
|
||||
FULL_INT8_ONNX = os.path.join(RESULTS, "retrained_int8_static_percentile_conv.onnx")
|
||||
|
||||
# Exact tiny config from remote/sweep/run_sweep.py VARIANTS (measured 56,290
|
||||
# params, clean-test PCK@20 94.11% -- results/efficiency_sweep.jsonl).
|
||||
TINY = dict(tcn=[68, 56, 44, 32], conv=[2, 4, 8, 16], attn_groups=2,
|
||||
groups_mode="depthwise", input_pw_groups=4)
|
||||
|
||||
|
||||
def load_tiny_model():
|
||||
model = CompactWiFlowPoseModel(
|
||||
tcn_channels=TINY["tcn"], conv_channels=TINY["conv"],
|
||||
attn_groups=TINY["attn_groups"], groups_mode=TINY["groups_mode"],
|
||||
input_pw_groups=TINY["input_pw_groups"], dropout=0.5)
|
||||
state = torch.load(TINY_CKPT, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(state, strict=True)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def adaptive_pool_matrix(h_in, h_out):
|
||||
"""Exact AdaptiveAvgPool1d as a (h_out, h_in) averaging matrix, using
|
||||
PyTorch's bin rule: bin i covers rows floor(i*h_in/h_out) ..
|
||||
ceil((i+1)*h_in/h_out)."""
|
||||
w = torch.zeros(h_out, h_in)
|
||||
for i in range(h_out):
|
||||
s = (i * h_in) // h_out
|
||||
e = -((-(i + 1) * h_in) // h_out) # ceil division
|
||||
w[i, s:e] = 1.0 / (e - s)
|
||||
return w
|
||||
|
||||
|
||||
class ExportWrapper(torch.nn.Module):
|
||||
"""CompactWiFlowPoseModel forward with the AdaptiveAvgPool2d((K,1))
|
||||
replaced by an exact fixed linear map (mean over the factor W axis, then
|
||||
a constant averaging matmul over the non-factor H axis) so the
|
||||
TorchScript ONNX exporter accepts it. Bit-equivalent up to float
|
||||
round-off; proven by the parity check against the original model."""
|
||||
|
||||
def __init__(self, m, num_keypoints=15):
|
||||
super().__init__()
|
||||
self.m = m
|
||||
self.register_buffer(
|
||||
"pool_w_t", adaptive_pool_matrix(m.final_width, num_keypoints).t())
|
||||
|
||||
def forward(self, x):
|
||||
m = self.m
|
||||
x = m.tcn(x)
|
||||
x = x.transpose(1, 2).unsqueeze(1)
|
||||
x = m.up(x)
|
||||
for block in m.residual_blocks:
|
||||
x = block(x)
|
||||
x = x.permute(0, 1, 3, 2)
|
||||
x = m.attention(x)
|
||||
x = m.decoder(x) # [B, 2, H=final_width, T=20]
|
||||
x = x.mean(-1) # W-axis pool (20 -> 1, a factor)
|
||||
x = x.matmul(self.pool_w_t) # exact adaptive H pool: [B, 2, K]
|
||||
return x.transpose(1, 2) # [B, K, 2]
|
||||
|
||||
|
||||
def export_onnx(model):
|
||||
"""Dynamic-batch TorchScript export (the recipe that worked for the full
|
||||
model in onnx_bench.py), verified at batch 1/2/64. Uses ExportWrapper
|
||||
(see docstring) because final_width 16 is not a multiple of 15."""
|
||||
wrapper = ExportWrapper(model).eval()
|
||||
x = torch.rand(2, 540, 20)
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
wrapper, (x,), TINY_FP32_ONNX, opset_version=17,
|
||||
input_names=["input"], output_names=["output"], dynamo=False,
|
||||
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
|
||||
sess = ort_session(TINY_FP32_ONNX)
|
||||
inp = sess.get_inputs()[0].name
|
||||
for b in (1, 2, 64):
|
||||
y = sess.run(None, {inp: np.zeros((b, 540, 20), dtype=np.float32)})[0]
|
||||
assert y.shape == (b, 15, 2), y.shape
|
||||
return {
|
||||
"mode": "dynamic-batch", "exporter": "torchscript", "opset": 17,
|
||||
"file": os.path.basename(TINY_FP32_ONNX),
|
||||
"size_bytes": os.path.getsize(TINY_FP32_ONNX),
|
||||
"size_mb": os.path.getsize(TINY_FP32_ONNX) / 1e6,
|
||||
"verified_batches": [1, 2, 64],
|
||||
"note": "AdaptiveAvgPool2d((15,1)) replaced at export by an exact "
|
||||
"mean(-1) + constant averaging matmul (final_width 16 is not "
|
||||
"a multiple of 15, which the TorchScript exporter rejects); "
|
||||
"exactness proven by the parity check vs the original torch "
|
||||
"model",
|
||||
}
|
||||
|
||||
|
||||
def quantize_tiny(calib_windows):
|
||||
"""quant_pre_process + static QDQ conv-only Percentile(99.99) int8 --
|
||||
the winning recipe from static_ptq_bench.py."""
|
||||
from onnxruntime.quantization import (CalibrationMethod, QuantFormat,
|
||||
QuantType, quantize_static)
|
||||
from onnxruntime.quantization.shape_inference import quant_pre_process
|
||||
|
||||
quant_pre_process(TINY_FP32_ONNX, TINY_PREPROC_ONNX)
|
||||
t0 = time.time()
|
||||
quantize_static(
|
||||
TINY_PREPROC_ONNX, TINY_INT8_ONNX, make_reader(calib_windows),
|
||||
quant_format=QuantFormat.QDQ,
|
||||
op_types_to_quantize=["Conv"],
|
||||
per_channel=True,
|
||||
activation_type=QuantType.QInt8,
|
||||
weight_type=QuantType.QInt8,
|
||||
calibrate_method=CalibrationMethod.Percentile,
|
||||
extra_options={"CalibPercentile": 99.99},
|
||||
)
|
||||
return {
|
||||
"file": os.path.basename(TINY_INT8_ONNX),
|
||||
"size_bytes": os.path.getsize(TINY_INT8_ONNX),
|
||||
"size_mb": os.path.getsize(TINY_INT8_ONNX) / 1e6,
|
||||
"calibration": {"method": "percentile", "percentile": 99.99,
|
||||
"windows": int(len(calib_windows)),
|
||||
"scope": "conv-only TRAIN-split corruption-free",
|
||||
"seconds": time.time() - t0},
|
||||
"per_channel": True,
|
||||
"activation_type": "QInt8",
|
||||
"weight_type": "QInt8",
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
import onnxruntime
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", default=os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
|
||||
"wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
|
||||
parser.add_argument("--subset", type=int, default=10000)
|
||||
parser.add_argument("--calib", type=int, default=512,
|
||||
help="calibration windows; must be a multiple of the "
|
||||
"64-window calibration batch (ORT histogram "
|
||||
"collector rejects ragged batches)")
|
||||
parser.add_argument("--skip-accuracy", action="store_true")
|
||||
parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.calib % 64 != 0:
|
||||
parser.error(
|
||||
f"--calib must be a multiple of 64 (got {args.calib}): ORT 1.26's "
|
||||
f"histogram calibration collector np.asarray()'s the per-batch "
|
||||
f"maxima and crashes on a ragged final batch (calibration batch "
|
||||
f"size is 64)")
|
||||
|
||||
model = load_tiny_model()
|
||||
info = describe(model)
|
||||
print(f"tiny model: {info['params']:,} params, tcn_groups={info['tcn_groups_per_block']}, "
|
||||
f"strides={info['conv_strides']}, final_width={info['final_width']}")
|
||||
assert info["params"] == 56290, info["params"]
|
||||
|
||||
results = {
|
||||
"env": {
|
||||
"torch": torch.__version__,
|
||||
"onnxruntime": onnxruntime.__version__,
|
||||
"platform": platform.platform(),
|
||||
"num_threads": torch.get_num_threads(),
|
||||
"checkpoint": os.path.relpath(TINY_CKPT, HERE),
|
||||
"checkpoint_size_bytes": os.path.getsize(TINY_CKPT),
|
||||
"params": info["params"],
|
||||
"variant_config": TINY,
|
||||
},
|
||||
}
|
||||
|
||||
# ---- export + parity ----------------------------------------------------
|
||||
print("\n=== ONNX export (dynamic batch, opset 17, torchscript) ===")
|
||||
results["export"] = export_onnx(model)
|
||||
print(f" {results['export']['size_mb']:.3f} MB, batches {results['export']['verified_batches']} OK")
|
||||
|
||||
fixture = np.load(os.path.join(RESULTS, "parity_fixture.npz"))
|
||||
fx = fixture["input"] # (2, 540, 20), seed 42 -- same input layout as full model
|
||||
sess_fp32 = ort_session(TINY_FP32_ONNX)
|
||||
y_ort = sess_fp32.run(None, {sess_fp32.get_inputs()[0].name: fx})[0]
|
||||
with torch.no_grad():
|
||||
y_torch = model(torch.from_numpy(fx)).numpy()
|
||||
results["parity"] = {
|
||||
"fixture": "results/parity_fixture.npz input (batch 2, seed 42); "
|
||||
"reference output recomputed with the tiny torch model",
|
||||
"max_abs_diff_vs_torch": float(np.abs(y_ort - y_torch).max()),
|
||||
"pass_lt_1e-4": bool(np.abs(y_ort - y_torch).max() < 1e-4),
|
||||
}
|
||||
print("parity:", json.dumps(results["parity"], indent=2))
|
||||
assert results["parity"]["pass_lt_1e-4"], "torch-vs-ORT parity FAILED"
|
||||
|
||||
# ---- static PTQ int8 ------------------------------------------------------
|
||||
print(f"\n=== static QDQ int8 (Percentile conv-only, {args.calib} calib windows) ===")
|
||||
calib = build_calibration_windows(args.data_dir, args.calib)
|
||||
results["int8_static_percentile_conv"] = quantize_tiny(calib)
|
||||
print(f" {results['int8_static_percentile_conv']['size_mb']:.3f} MB")
|
||||
sess_int8 = ort_session(TINY_INT8_ONNX)
|
||||
yq = sess_int8.run(None, {sess_int8.get_inputs()[0].name: fx})[0]
|
||||
results["int8_static_percentile_conv"]["max_abs_diff_vs_fp32_fixture"] = float(
|
||||
np.abs(yq - y_torch).max())
|
||||
|
||||
# ---- latency (3 interleaved reps, full-model sessions as references) -----
|
||||
print("\n=== latency (3 interleaved reps) ===")
|
||||
lat_sessions = {
|
||||
"tiny_onnx_fp32": sess_fp32,
|
||||
"tiny_onnx_int8_static_percentile_conv": sess_int8,
|
||||
"full_onnx_fp32_reference": ort_session(FULL_FP32_ONNX),
|
||||
"full_onnx_int8_static_percentile_conv_reference": ort_session(FULL_INT8_ONNX),
|
||||
}
|
||||
results["latency"] = {
|
||||
"note": "3 interleaved repetitions per variant, median ms/window; "
|
||||
"full-model sessions are same-session references",
|
||||
**interleaved_latency(lat_sessions),
|
||||
}
|
||||
|
||||
# ---- accuracy on the standard 10k corruption-free test subset ------------
|
||||
if not args.skip_accuracy:
|
||||
loader, n_clean = build_test_subset(args.data_dir, args.subset)
|
||||
results["accuracy_subset"] = {
|
||||
"description": "seed-42 file-level 70/15/15 test split, corrupted "
|
||||
"windows excluded, seed-42 random subset (same as "
|
||||
"quantize_bench/eval_ort_accuracy/static_ptq_bench)",
|
||||
"subset_size": min(args.subset, n_clean) if args.subset else n_clean,
|
||||
}
|
||||
results["accuracy"] = {}
|
||||
for name, sess in (("tiny_onnx_fp32", sess_fp32),
|
||||
("tiny_onnx_int8_static_percentile_conv", sess_int8)):
|
||||
print(f"\n=== accuracy: {name} ===")
|
||||
results["accuracy"][name] = evaluate_ort(sess, loader, name)
|
||||
print(json.dumps(results["accuracy"][name], indent=2))
|
||||
|
||||
# ---- merge into edge_optimization.json -----------------------------------
|
||||
merged = {}
|
||||
if os.path.exists(args.out):
|
||||
with open(args.out) as f:
|
||||
merged = json.load(f)
|
||||
merged["tiny_variant"] = results
|
||||
with open(args.out, "w") as f:
|
||||
json.dump(merged, f, indent=2)
|
||||
print(f"\nwrote {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -47,13 +47,16 @@ Adopt four changes, ordered by effort-vs-gain:
|
||||
|
||||
1. **Record transceiver geometry at enrollment.** `EnrollmentProtocol` gains an optional `NodeGeometry` record per node (position estimate, antenna orientation, inter-node distances where known). Stored alongside the room baseline in the bank; schema-versioned so existing banks remain readable.
|
||||
2. **Fuse geometry embeddings into specialist training.** Where a specialist head consumes the (future, ADR-150) backbone embedding, concatenate a small learned embedding of `NodeGeometry` — the PerceptAlign mechanism, transplanted to our per-room banks. Statistical specialists (current) ignore it; LoRA heads (ADR-151 P6) consume it.
|
||||
3. **Adopt the two-checkerboard alignment for the camera-supervised path (ADR-079).** When MediaPipe supervision is used, calibrate camera↔WiFi into one shared 3D frame before regression (<5 min, two checkerboards, a few photos). This is the direct defense against F1 for our 92.9%-PCK@20 pipeline.
|
||||
3. **Adopt the two-checkerboard alignment for the camera-supervised path (ADR-079).** When MediaPipe supervision is used, calibrate camera↔WiFi into one shared 3D frame before regression (<5 min, two checkerboards, a few photos). This is the direct defense against F1 for our camera-supervised pipeline. ~~92.9%-PCK@20~~ — *that figure was retracted during measurement (b) (2026-06-10): the surviving holdout shows a constant-output model under an absolute (non-torso) threshold on 69 near-static frames; mean predictor scores 100% under the same protocol. The §2.2 no-citation rule now applies to it.*
|
||||
4. **Evaluate on the PerceptAlign cross-domain dataset** (21 subjects / 7 layouts) as the MERIDIAN cross-layout benchmark — *gated on confirming its license and downloadability* (open question; repo per paper: github.com/Trymore-lab/PerceptAlign).
|
||||
> **Gate resolved (2026-06-10, MEASURED by repo inspection):** repo exists, **MIT license**, dataset downloadable from HuggingFace (5 per-scene repos, raw CSI + separate vision keypoints; Intel 5300, 1TX×3RX×3 ant, 57 subcarriers — same order as ESP32 subcarrier counts; Scene3 ships 3 distinct layouts). Code present, no pretrained weights. Benchmark adoption unblocked; dataset-side license terms inherit HF dataset terms (not separately stated — check at download time).
|
||||
|
||||
### 2.2 Benchmark against WiFlow-STD (DY2434) — ACCEPTED
|
||||
|
||||
Pull the Apache-2.0 weights + 360k-sample dataset; run three measurements: (a) their model on their data (reproduce 97.25% claim), (b) their model fine-tuned on our ESP32 17-keypoint eval set, (c) our internal WiFlow on their dataset (15-keypoint subset mapping). Until (a)–(c) are measured, **no RuView doc may cite 97.25% as a comparable number** — different dataset, subjects, keypoints.
|
||||
|
||||
> **Status (2026-06-10, measurement (a) complete — `benchmarks/wiflow-std/RESULTS.md`):** shipped checkpoint REFUTED (0.08% PCK@20 — wrong keypoint normalization, predates published code); released code does not run as published (6 defects, incl. broken package import and an unreachable test phase); released dataset's last 13 files are corrupted (9,072 windows: NaN + float32-max garbage, diverges fp16 training via BatchNorm poisoning). After repairing both, retraining with upstream defaults reproduced **96.09% PCK@20 full-test / 96.61% corruption-free / MPJPE 0.0094–0.0098** (published: 97.25% / 0.007) on an RTX 5080. Accuracy claims graded MEASURED-EQUIVALENT; params (2.23M) and FLOPs (~0.055G) verified. (b)/(c) remain open.
|
||||
|
||||
### 2.3 Apply the UNSW recipe to the ADR-150 encoder — ACCEPTED (amends ADR-150 §2.3)
|
||||
|
||||
- Pretraining corpus: start from the same 14 public datasets (1.3M samples) + our home/MM-Fi frames; data aggregation takes priority over architecture work.
|
||||
@@ -62,7 +65,7 @@ Pull the Apache-2.0 weights + 360k-sample dataset; run three measurements: (a) t
|
||||
|
||||
### 2.4 Hardware watch items — ACCEPTED (no code now)
|
||||
|
||||
- **802.11bf**: track silicon/certification; revisit when any commodity chipset exposes standardized sensing measurements. Our opportunistic CSI extraction remains the mechanism until then.
|
||||
- **802.11bf**: track silicon/certification; OTA binding remains deferred until commodity chipsets expose standardized sensing measurements. **Amended by ADR-153** (2026-06-10): implement a pure Rust forward-compatibility protocol layer now — typed procedure models, a deterministic session FSM, a transport abstraction, simulation tests, and an `OpportunisticCsiBridge` that maps today's ESP32 CSI batches into standardized sensing-report shape.
|
||||
- **esp_wifi_sensing**: benchmark our presence pipeline against the vendor FSM (one afternoon; useful external baseline). Do **not** treat as drop-in (refuted claim).
|
||||
- **ZTECSITool AP**: optional high-resolution anchor node for the ADR-029 multistatic mesh — procurement-gated; only pursue if a 160 MHz anchor materially helps tomography.
|
||||
|
||||
@@ -71,6 +74,29 @@ Pull the Apache-2.0 weights + 360k-sample dataset; run three measurements: (a) t
|
||||
- No pivot toward "wireless foundation model" papers that don't ship WiFi-CSI artifacts (HeterCSI, FMCW pilot, surveys).
|
||||
- No DensePose-UV work item: the field has not demonstrated UV regression from commodity WiFi; keypoints remain our supervised target (F5).
|
||||
|
||||
### 2.6 RuVector vendor sync + integration opportunities (added 2026-06-10)
|
||||
|
||||
**Vendor sync record.** `vendor/ruvector` moved from pin `e38347601` (2026-05-07) to `a083bd77f` (origin/main, 3 commits past tag `ruvector-v0.2.28`; vendored workspace version 2.2.3). 111 commits in the range, roughly half NAPI-binary/lint chores. Substantive: graph condensation + differentiable min-cut (#547), core HNSW correctness fixes v2.2.3 (#502), RUSTSEC/clippy hardening (#504), ONNX embedder API-contract fix (#523/#525 — npm/TypeScript package only), dead parallel-worker import removal (#532). *Evidence: MEASURED (git range + commit-stat inspection).*
|
||||
|
||||
**Opportunity table.** Workspace policy is crates.io versions only, so unpublished crates are WATCH by definition regardless of fit.
|
||||
|
||||
| Crate | What it offers | wifi-densepose target | crates.io | Verdict |
|
||||
|---|---|---|---|---|
|
||||
| `ruvector-graph-condense` (new, #547) | Training-free min-cut graph condensation + **differentiable normalized-cut loss** (`DiffCutCondenser`, analytic MinCutPool-style gradients, gradient-checked tests; provenance-retaining super-nodes) | `subcarrier_selection.rs` (condense 114 subcarriers into cut-preserving regions instead of raw min-cut); auxiliary clustering regularizer for `wifi-densepose-train`; `DynamicPersonMatcher` region structure | **Not published** | **WATCH** — strongest technical fit in the sync; adopt when published. README's "no published method uses graph-cut condensation" is CLAIMED; the diffcut implementation + tests are MEASURED |
|
||||
| `ruvector-attention` 2.1.0 | #304 SOTA modules: MLA, KV-cache, SSM, sparse/MoE, hybrid search, Graph RAG (publish date 2026-03-27 matches the #304 commit — MEASURED) | Supersedes pinned 2.0.4 used by `model.rs` spatial attention + `bvp.rs`; SSM/MLA are candidate pure-Rust edge-inference primitives for the ADR-150 encoder | 2.1.0 (pinned **2.0.4**) | **ADOPT** (minor bump; API-compat check first) |
|
||||
| `ruvector-gnn` 2.2.0 | panic→`Result` constructors, gradient clipping, MSE/CE/BCE losses, seeded-RNG layer init (#495 is post-2.2.0) | `wifi-densepose-train` GNN path (pinned 2.0.5, `default-features = false`) | 2.2.0 (pinned **2.0.5**) | **ADOPT** (bump) |
|
||||
| `ruvector-mincut` / `ruvector-solver` 2.0.6 | Patch-level fixes (workspace republish 2026-03-25) | `metrics.rs` DynamicPersonMatcher, subcarrier interpolation, triangulation | 2.0.6 (pinned **2.0.4** each) | **ADOPT** (routine patch bump) |
|
||||
| `ruvector-core` 2.2.3 (vendor) | HNSW correctness: k=0 guard, sorted results, flat-index fixes, cross-integration helpers (#502 — MEASURED, `index/hnsw.rs` + new integration tests) | `homecore-recorder` `RuvectorSemanticIndex` (real HNSW consumer); `sketch.rs` quantization unaffected | **2.2.0 = latest published**; 2.2.3 unpublished | **WATCH** — bump the moment 2.2.3 publishes |
|
||||
| `ruvector-cnn` 2.0.6 | Pure-Rust SIMD conv kernels (AVX2/NEON/WASM), MobileNetV3, INT8 quantization, contrastive losses (InfoNCE/triplet, #252) | **Not** the WiFlow-STD training port — `wiflow_std/model.rs` is tch/libtorch (MEASURED). Relevant to the *edge inference* path of the trained ~2.2 MB int8 model, and InfoNCE/triplet overlaps AETHER (ADR-024) | 2.0.6 | **EVALUATE** — only if/when we commit to a no-libtorch edge runtime for WiFlow-STD-class models |
|
||||
| `ruvector-acorn` (new-ish) | ACORN predicate-agnostic filtered HNSW (SIGMOD'24 algorithm; γ·M denser graphs for low-selectivity filters) | Metadata-filtered pattern search over ADR-151 calibration banks — speculative; bank sizes are far below where filtered-ANN recall collapse matters | **Not published** | **WATCH** |
|
||||
| `ruvector-cluster` 2.0.6 | Distributed sharding, gossip discovery, DAG consensus | No current need; ADR-029 mesh coordination is ESP32-side, not vector-DB-side | 2.0.6 | **WATCH** |
|
||||
| ONNX embedder fix (#523/#525) | API-contract + packaging fixes in `npm/packages/ruvector` (TypeScript) | None — `wifi-densepose-nn`'s ONNX backend is Rust (ort/tract), untouched by this change (MEASURED: commit touches npm/ only) | n/a | No action |
|
||||
| `ruvector-perception` (new, #547) | "Physical perception substrate" (hypothesis/topology/witness modules) — agent-perception oriented, not RF | None identified | Not published | WATCH (name-overlap only) |
|
||||
|
||||
**Security note (RUSTSEC #504).** The substantive fixes target `ruvllm`, `ruvector-dag`, `prime-radiant`, `rvagent-*`, and the `ruvector-server` HTTP endpoint (NaN-safe `partial_cmp`, input-validation guards, env-allowlisted exec) — **none of which we pin**. The commit states `cargo audit` returns clean across the workspace. *Evidence: MEASURED (commit message + file list). Conclusion: no pinned version has an outstanding advisory; no urgent bump required.* The NaN-sort hardening is panic-robustness hygiene our pinned 2.0.4-era crates predate, which is one more reason for the routine bumps below.
|
||||
|
||||
**Version-bump recommendations (follow-up PR — no Cargo.toml change in this ADR):** `ruvector-mincut` 2.0.4→2.0.6, `ruvector-solver` 2.0.4→2.0.6, `ruvector-attention` 2.0.4→2.1.0, `ruvector-gnn` 2.0.5→2.2.0. Current: `ruvector-core` 2.2.0, `ruvector-attn-mincut` 2.0.4, `ruvector-temporal-tensor` 2.0.6, `ruvector-crv` 0.1.1 — all at latest published. Nothing in the sync changes §2.1.2 geometry conditioning (our `viewpoint/attention.rs` `GeometricBias` already implements the fusion mechanism) or the ADR-150 MAE recipe (training stays in tch).
|
||||
|
||||
## 3. Consequences
|
||||
|
||||
**Positive:** the calibration system gains the one mechanism (geometry conditioning) the 2026 literature identifies as the difference between layout-brittle and layout-robust supervised WiFi pose; ADR-150 gets a measured training recipe instead of a guessed one; we acquire two external benchmarks (WiFlow-STD, PerceptAlign dataset) to keep our claims honest.
|
||||
@@ -82,6 +108,7 @@ Pull the Apache-2.0 weights + 360k-sample dataset; run three measurements: (a) t
|
||||
## 4. Open questions (carried from the research run)
|
||||
|
||||
1. Does WiFlow-STD retain accuracy when fine-tuned on ESP32-S3/C6 CSI (fewer subcarriers, lower SNR), scored on our 17-keypoint set? (§2.2 answers this.)
|
||||
> **Partial answer (MEASURED 2026-06-11, measurement (b) on 2,046 single-room windows — `benchmarks/wiflow-std/RESULTS.md`):** pretrained init shows strong *optimization* transfer (65% PCK@20 vs scratch's 0% collapse under the same budget) but **no feature transfer** (frozen-trunk + linear adapter ≈ 0%). And no run beat the mean-pose baseline (95.9% PCK@20 — single subject, near-static normalized coords), so no CSI→pose capability is citable from this data. A definitive answer needs multi-subject/multi-position data where the mean pose is weak.
|
||||
2. Is the PerceptAlign dataset downloadable under a usable license, and does the two-checkerboard procedure work with ESP32 transceiver geometry? (§2.1.4 gate.)
|
||||
3. Will esp_wifi_sensing evolve toward 802.11bf compliance, replacing opportunistic CSI extraction?
|
||||
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
# ADR-153: IEEE 802.11bf-2025 Forward-Compatibility Protocol Model for wifi-densepose-hardware
|
||||
|
||||
- **Status**: accepted
|
||||
- **Date**: 2026-06-10
|
||||
- **Deciders**: ruv
|
||||
- **Tags**: hardware, protocol, sensing, 802.11bf, forward-compatibility
|
||||
|
||||
## Context
|
||||
|
||||
IEEE 802.11bf-2025 (WLAN Sensing) is an **Active Standard**: board approval
|
||||
2025-05-28, published 2025-09-26 (verified against the IEEE SA record,
|
||||
<https://standards.ieee.org/ieee/802.11bf/11574/>). Its scope modifies the
|
||||
MAC, HE and EHT PHY service interfaces, plus DMG and EDMG PHYs, for WLAN
|
||||
sensing in **1–7.125 GHz** and **above 45 GHz** bands, with formal sensing
|
||||
measurement setup, measurement instance, feedback/reporting, and
|
||||
sensing-by-proxy (SBP) procedures (ADR-152 F4, evidence grade MEASURED).
|
||||
|
||||
No commodity silicon implements the standard yet — ESP32 parts included.
|
||||
ADR-152 §2.4 therefore decided "track silicon; no code now", with RuView's
|
||||
opportunistic CSI extraction remaining the mechanism. That left a gap: when
|
||||
silicon does land, RuView would have no typed model of the standard's
|
||||
procedures to bind to, and the integration would start from zero.
|
||||
|
||||
ADR-152 §2.4 originally classified 802.11bf as a hardware watch item with no
|
||||
implementation work until commodity silicon exposes standardized sensing
|
||||
measurements. This ADR amends that clause: OTA binding remains deferred, but
|
||||
a pure Rust protocol model, session FSM, transport seam, and opportunistic
|
||||
CSI bridge will be implemented now so RuView consumers can target a stable
|
||||
standardized sensing interface before silicon arrives.
|
||||
|
||||
The user directed (2026-06-10) that this **forward-compatibility protocol
|
||||
model** — a protocol surface, not a conformance implementation — be built
|
||||
now.
|
||||
|
||||
## Decision
|
||||
|
||||
Implement an `ieee80211bf` **forward-compatibility protocol model** in
|
||||
`wifi-densepose-hardware` (pure Rust, no internal deps, simulation-testable,
|
||||
no OTA path):
|
||||
|
||||
> This module is not a certified 802.11bf implementation. It models the
|
||||
> public procedure shape needed by RuView and RuvSense, while intentionally
|
||||
> avoiding OTA frame binding until chipset support and vendor APIs exist.
|
||||
|
||||
1. **`types.rs`** — typed structures for the standard's sensing procedures
|
||||
(sub-7 GHz focus; DMG stubbed): Sensing Measurement Setup (setup ID,
|
||||
initiator/responder and transmitter/receiver roles, bandwidth,
|
||||
periodicity, threshold-based reporting parameters), Sensing Measurement
|
||||
Instance, Sensing Measurement Report (CSI-variant payload), SBP
|
||||
request/response, termination. Two future-proofing requirements:
|
||||
|
||||
- **Version gates** — every negotiated surface is tagged with a spec
|
||||
profile, because vendors will expose partial or renamed capabilities
|
||||
first:
|
||||
|
||||
```rust
|
||||
pub enum SpecProfile {
|
||||
DraftCompatible,
|
||||
Ieee80211Bf2025,
|
||||
VendorExtension(String),
|
||||
}
|
||||
```
|
||||
|
||||
- **Capability negotiation** — no hardcoded ESP32 assumptions in the
|
||||
future-silicon path:
|
||||
|
||||
```rust
|
||||
pub struct SensingCapabilities {
|
||||
pub sub_7_ghz: bool,
|
||||
pub dmg: bool,
|
||||
pub edmg: bool,
|
||||
pub csi_report: bool,
|
||||
pub threshold_reporting: bool,
|
||||
pub sensing_by_proxy: bool,
|
||||
pub max_bandwidth_mhz: u16,
|
||||
pub max_period_ms: u32,
|
||||
pub max_active_setups: u16,
|
||||
}
|
||||
```
|
||||
|
||||
- **Privacy and governance fields** — sensing is presence inference, not
|
||||
just radio telemetry. Every `SensingMeasurementSetup` carries policy
|
||||
metadata (required, not optional), for enterprise, elderly-care,
|
||||
retail, workplace, and municipal deployments:
|
||||
|
||||
```rust
|
||||
pub enum ConsentMode {
|
||||
LabOnly,
|
||||
ExplicitConsent,
|
||||
ManagedEnterprisePolicy,
|
||||
Disabled,
|
||||
}
|
||||
```
|
||||
|
||||
2. **`session.rs`** — deterministic event-driven session state machine:
|
||||
`Idle → SetupNegotiating → Active → Terminating → Idle`, with explicit
|
||||
rejection paths (unsupported parameters, setup-ID collision) and timeout
|
||||
handling.
|
||||
3. **`transport.rs`** — a `SensingTransport` trait abstracting frame
|
||||
exchange; a `SimTransport` test double; and an `OpportunisticCsiBridge`
|
||||
adapter mapping today's ESP32 CSI extraction onto the report path
|
||||
(measurement instances ≈ CSI frame batches), so current hardware sits
|
||||
behind the standardized interface. **Replaceability benchmark
|
||||
(acceptance test):** RuvSense must consume either ESP32 opportunistic CSI
|
||||
or future 802.11bf chipset reports through the same `SensingTransport`
|
||||
and `SensingMeasurementReport` path, with no consumer-side rewrite — a
|
||||
future chipset adapter replaces `OpportunisticCsiBridge` without changing
|
||||
consumers.
|
||||
|
||||
Constraints: input validation at boundaries (typed errors, no panics on
|
||||
adversarial input), files under 500 lines, all protocol tests runnable
|
||||
without hardware.
|
||||
|
||||
### Acceptance checklist
|
||||
|
||||
| Area | Acceptance test |
|
||||
| --------------- | -------------------------------------------------------------------- |
|
||||
| Types | Serde round trip for setup, instance, report, SBP, termination |
|
||||
| FSM | Idle → setup → active → terminating → idle |
|
||||
| Rejection | Unsupported bandwidth, invalid period, duplicate setup ID |
|
||||
| Timeout | Negotiation timeout returns typed error and resets to Idle |
|
||||
| Threshold | Report emitted only when threshold condition is crossed |
|
||||
| SBP | Proxy request maps to responder path without direct sensor coupling |
|
||||
| Bridge | ESP32 CSI batch becomes standardized measurement report |
|
||||
| Safety | No panics on malformed inputs |
|
||||
| CI | All protocol tests run without hardware |
|
||||
| Maintainability | Each file under 500 lines |
|
||||
|
||||
### Non-Goals
|
||||
|
||||
This ADR does not claim IEEE 802.11bf conformance, certification, or OTA
|
||||
interoperability. It creates a typed protocol compatibility layer so RuView
|
||||
can consume standardized sensing reports when commodity silicon exposes
|
||||
them. Vendor-specific frame exchange, firmware hooks, trigger-frame
|
||||
sounding, and certification test vectors remain future ADRs.
|
||||
|
||||
## Consequences
|
||||
|
||||
### Positive
|
||||
- RuView can adopt standardized WLAN sensing the day any chipset exposes
|
||||
802.11bf measurements — the data model, session FSM, and transport seam
|
||||
already exist and are tested.
|
||||
- The `OpportunisticCsiBridge` gives current ESP32 nodes a standardized-shape
|
||||
interface now, decoupling RuvSense consumers from the extraction mechanism.
|
||||
- Simulation transport enables protocol-level tests in CI without hardware.
|
||||
- `SpecProfile` + `SensingCapabilities` give a clean escape hatch for the
|
||||
partial/renamed vendor capabilities that will certainly arrive first.
|
||||
- Consent/policy metadata is structural from day one, not retrofitted.
|
||||
|
||||
### Negative
|
||||
- Code written against a standard with zero silicon risks drift: vendor
|
||||
implementations may interpret parameters differently; the layer may need
|
||||
rework at first real binding (drift risk scored 7/10 at acceptance).
|
||||
- Adds maintenance surface to wifi-densepose-hardware before any
|
||||
user-visible benefit (maintenance cost scored 3/10 — small without OTA).
|
||||
|
||||
### Neutral
|
||||
- ADR-152 §2.4's "watch item" remains: revisit when silicon/certification
|
||||
appears (re-check by 2026-12). This ADR changes only the "no code now"
|
||||
clause.
|
||||
|
||||
## Links
|
||||
|
||||
- ADR-152 — WiFi-Pose SOTA 2026 Intake (F4, §2.4 — amended by this ADR)
|
||||
- ADR-028 — ESP32 capability audit (opportunistic CSI extraction baseline)
|
||||
- ADR-029 — RuvSense multistatic sensing mode (consumer of sensing reports)
|
||||
- IEEE 802.11bf-2025 — Active Standard, board approval 2025-05-28, published
|
||||
2025-09-26: <https://standards.ieee.org/ieee/802.11bf/11574/>
|
||||
+15
-10
@@ -50,7 +50,7 @@ See [PR #405](https://github.com/ruvnet/RuView/pull/405) for full details.
|
||||
### What's New in v0.7.0
|
||||
|
||||
<details>
|
||||
<summary><strong>Camera Ground-Truth Training — 92.9% PCK@20</strong></summary>
|
||||
<summary><strong>Camera Ground-Truth Training</strong></summary>
|
||||
|
||||
**v0.7.0 adds camera-supervised pose training** using MediaPipe + real ESP32 CSI data:
|
||||
|
||||
@@ -76,15 +76,20 @@ node scripts/train-wiflow-supervised.js --data data/paired/*.jsonl --scale lite
|
||||
node scripts/eval-wiflow.js --model models/wiflow-real/wiflow-v1.json --data data/paired/*.jsonl
|
||||
```
|
||||
|
||||
**Result: 92.9% PCK@20** from a 5-minute data collection session with one ESP32-S3 and one webcam.
|
||||
> **Accuracy retraction (2026-06-10):** the "92.9% PCK@20" figure previously
|
||||
> shown here is retracted. A forensic recheck of the surviving eval holdout
|
||||
> (69 samples) found a constant-output model scored with an absolute
|
||||
> (non-torso-normalized) threshold on nearly-static frames — a protocol under
|
||||
> which a trivial mean-pose predictor scores 100%. Torso-normalized PCK@20 on
|
||||
> the same holdout is ~19% (from that degenerate predictor). No measured
|
||||
> camera-supervised PCK@20 is currently published (CHANGELOG, PR #535).
|
||||
|
||||
| Metric | Before (proxy) | After (camera-supervised) |
|
||||
|--------|----------------|--------------------------|
|
||||
| PCK@20 | 0% | **92.9%** |
|
||||
| Eval loss | 0.700 | **0.082** |
|
||||
| Bone constraint | N/A | **0.008** |
|
||||
| Training time | N/A | **19 minutes** |
|
||||
| Model size | N/A | **974 KB** |
|
||||
| Metric | Camera-supervised run (protocol retracted) |
|
||||
|--------|--------------------------------------------|
|
||||
| Eval loss | 0.082 |
|
||||
| Bone constraint | 0.008 |
|
||||
| Training time | 19 minutes |
|
||||
| Model size | 974 KB |
|
||||
|
||||
Pre-trained model: [HuggingFace ruv/ruview/wiflow-v1](https://huggingface.co/ruv/ruview)
|
||||
|
||||
@@ -868,7 +873,7 @@ Download a pre-built binary — no build toolchain needed:
|
||||
|
||||
| Release | What's included | Tag |
|
||||
|---------|-----------------|-----|
|
||||
| [v0.7.0](https://github.com/ruvnet/RuView/releases/tag/v0.7.0) | **Latest** — Camera-supervised WiFlow model (92.9% PCK@20), ground-truth training pipeline, ruvector optimizations | `v0.7.0` |
|
||||
| [v0.7.0](https://github.com/ruvnet/RuView/releases/tag/v0.7.0) | **Latest** — Camera-supervised WiFlow model (accuracy figure retracted 2026-06-10, see above), ground-truth training pipeline, ruvector optimizations | `v0.7.0` |
|
||||
| [v0.6.0](https://github.com/ruvnet/RuView/releases/tag/v0.6.0-esp32) | [Pre-trained models on HuggingFace](https://huggingface.co/ruv/ruview), 17 sensing apps, 51.6% contrastive improvement, 0.008ms inference | `v0.6.0-esp32` |
|
||||
| [v0.5.5](https://github.com/ruvnet/RuView/releases/tag/v0.5.5-esp32) | SNN + MinCut (#348 fix) + CNN spectrogram + WiFlow + multi-freq mesh + graph transformer | `v0.5.5-esp32` |
|
||||
| [v0.5.4](https://github.com/ruvnet/RuView/releases/tag/v0.5.4-esp32) | Cognitum Seed integration ([ADR-069](docs/adr/ADR-069-cognitum-seed-csi-pipeline.md)), 8-dim feature vectors, RVF store, witness chain, security hardening | `v0.5.4-esp32` |
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
# We audited a state-of-the-art WiFi pose model. Here's what broke, what reproduced, and the 30× smaller model that nearly matches it.
|
||||
|
||||
*RuView team, June 2026. All numbers measured; full scripts and forensics in the
|
||||
[RuView repo](https://github.com/ruvnet/RuView/tree/main/benchmarks/wiflow-std).*
|
||||
|
||||
## The setup
|
||||
|
||||
WiFi sensing is having a moment: a 2026 preprint ("WiFlow", arXiv 2602.08661)
|
||||
claims **97.25% pose-estimation accuracy (PCK@20) from WiFi signals alone**,
|
||||
with a tiny 2.23M-parameter model — and unlike most papers, it ships
|
||||
everything: code, trained weights, and a 360,000-sample dataset.
|
||||
|
||||
We build WiFi sensing systems, so before adopting any external number we run
|
||||
it through a simple rule: **a claim is "CLAIMED" until we reproduce it, then
|
||||
it's "MEASURED."** Here's what happened when we tried.
|
||||
|
||||
## Day 1: nothing works
|
||||
|
||||
- **The code doesn't run.** The package imports a class that doesn't exist.
|
||||
(One-line fix.)
|
||||
- **The released model scores 0.08%, not 97.25%.** The shipped checkpoint was
|
||||
trained under a different data normalization than the shipped dataset —
|
||||
it's a real trained model, just not *this* pipeline's model. Even letting it
|
||||
cheat with a fitted per-keypoint correction only reaches 72%.
|
||||
- **The dataset is corrupted.** Its last 13 files contain garbage values up to
|
||||
3.4×10³⁸ (float32's maximum). Subtle consequence: the training loop uses
|
||||
fp16 mixed precision with no guards, so the first corrupted batch overflows
|
||||
and **permanently poisons the model's BatchNorm statistics**. Training from
|
||||
the public download produces NaN from epoch 1, every time.
|
||||
- The training script also crashes before its own test phase ever runs
|
||||
(calls an undefined function), and ignores its `--data_dir` flag.
|
||||
|
||||
At this point a less patient reader concludes "fraud." That would be wrong.
|
||||
|
||||
## Day 1, later: actually, the science is real
|
||||
|
||||
We repaired the artifacts — fixed the import, zeroed the 9,072 corrupted
|
||||
windows, retrained from scratch with the authors' own code and
|
||||
hyperparameters on one GPU (~50 minutes):
|
||||
|
||||
| Metric | Published | Our retrain |
|
||||
|---|---|---|
|
||||
| PCK@20 | 97.25% | **96.1–96.6%** |
|
||||
| PCK@50 | 99.48% | 99.0–99.1% |
|
||||
| Params | 2.23M | 2,225,042 (exact) |
|
||||
|
||||
**The claims reproduce.** What didn't survive contact was the *packaging*:
|
||||
wrong checkpoint, corrupted upload, broken glue code. This distinction —
|
||||
**artifact rot vs. bad science** — is the single most useful thing a
|
||||
reproduction can establish, and you can't establish it without actually
|
||||
running the thing.
|
||||
|
||||
(We filed all six defects upstream with fixes:
|
||||
[issue #3](https://github.com/DY2434/WiFlow-WiFi-Pose-Estimation-with-Spatio-Temporal-Decoupling/issues/3).
|
||||
And to be clear: the authors released more than 90% of papers do. That's the
|
||||
only reason this audit was possible.)
|
||||
|
||||
## Day 2: the model is also 2.6× too big
|
||||
|
||||
Once we could train, we asked: does the architecture need 2.23M parameters?
|
||||
|
||||
| Variant | Params | Accuracy (PCK@20) | Size on disk |
|
||||
|---|---|---|---|
|
||||
| Original | 2,225,042 | 96.61% | 8.97 MB |
|
||||
| **Half** | **843,834** | **96.62%** ✨ | — |
|
||||
| Quarter | 338,600 | 96.05% | — |
|
||||
| **Tiny** | **56,290** | **94.11%** | **295 KB** |
|
||||
|
||||
The half-width model **matches the original exactly** (and converges faster).
|
||||
The tiny one — 1/39th the parameters — gives up 2.5 points and runs at
|
||||
**0.66 ms per inference on a laptop CPU** (~1,500 poses/second) as a 295 KB
|
||||
ONNX file. For edge devices, that's the interesting end of the curve.
|
||||
|
||||
Quantization footnote: the paper's "~2.2 MB int8" estimate is reachable
|
||||
(we measured 2.44–2.53 MB) but only via conv-capable toolchains — PyTorch's
|
||||
one-line dynamic quantization converts *literally nothing* on this model
|
||||
(it has no Linear layers), a trap worth knowing about.
|
||||
|
||||
## What we took away
|
||||
|
||||
1. **Run the artifact, not the README.** Every number in a paper is one
|
||||
`git clone` away from being either confirmed or understood. Both outcomes
|
||||
are valuable; only one is publishable by the original authors.
|
||||
2. **fp16 + unvalidated data = silent model death.** Mixed-precision training
|
||||
with no NaN/inf guards doesn't fail loudly — it corrupts BatchNorm buffers
|
||||
and ships a broken model with a green progress bar. Validate inputs, or
|
||||
train in fp32, or guard the autocast.
|
||||
3. **Evidence-grade your own claims too.** Mid-audit, the same forensics
|
||||
tooling caught one of *our own* published accuracy numbers resting on a
|
||||
degenerate evaluation (a constant-output model scored with a flawed
|
||||
metric). We retracted it the same day. The rule has to cut both ways or
|
||||
it's marketing, not measurement.
|
||||
4. **Over-parameterization hides in SOTA tables.** Nobody publishes the
|
||||
half-size ablation that matches their headline model. Run it yourself;
|
||||
it's an hour of GPU time and sometimes it *is* the result.
|
||||
|
||||
*Reproduction scripts, corruption masks, the efficiency-sweep configs, and a
|
||||
numerically parity-proven Rust port (max divergence 1.2e-7) are all in
|
||||
[`benchmarks/wiflow-std/`](https://github.com/ruvnet/RuView/tree/main/benchmarks/wiflow-std).*
|
||||
+76
-16
@@ -1747,7 +1747,14 @@ See [ADR-071](adr/ADR-071-ruvllm-training-pipeline.md) and the [pretraining tuto
|
||||
|
||||
For significantly higher accuracy, use a webcam as a **temporary teacher** during training. The camera captures real 17-keypoint poses via MediaPipe, paired with simultaneous ESP32 CSI data. After training, the camera is no longer needed — the model runs on CSI only.
|
||||
|
||||
**Result: 92.9% PCK@20** from a 5-minute collection session.
|
||||
> **Accuracy note (2026-06-10):** the previously cited "92.9% PCK@20" figure is
|
||||
> retracted — a forensic recheck of the surviving eval holdout showed it came
|
||||
> from a constant-output model scored with an absolute (non-torso-normalized)
|
||||
> threshold on 69 nearly-static frames, a protocol under which a trivial
|
||||
> mean-pose predictor scores 100%. No measured camera-supervised PCK@20 is
|
||||
> currently published (see CHANGELOG, PR #535). Treat this workflow as a data
|
||||
> collection mechanism; accuracy claims will follow a ≥35-minute multi-pose
|
||||
> collection session evaluated with torso-normalized PCK.
|
||||
|
||||
### Requirements
|
||||
|
||||
@@ -1755,50 +1762,103 @@ For significantly higher accuracy, use a webcam as a **temporary teacher** durin
|
||||
- ESP32-S3 node streaming CSI over UDP (port 5005)
|
||||
- A webcam (laptop, USB, or Mac camera via Tailscale)
|
||||
|
||||
### Step 1: Capture Camera + CSI Simultaneously
|
||||
### Step 0: Check your CSI rate and plan the session length
|
||||
|
||||
Window yield is `csi_frames / 20` — **your CSI packet rate sets how long you
|
||||
must record.** Check it first (10-second probe):
|
||||
|
||||
```bash
|
||||
python - <<'EOF'
|
||||
import socket, time
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); s.bind(('0.0.0.0', 5005)); s.settimeout(2)
|
||||
n, t0 = 0, time.time()
|
||||
while time.time() - t0 < 10:
|
||||
try: s.recvfrom(4096); n += 1
|
||||
except socket.timeout: pass
|
||||
print(f"{n/10:.1f} Hz -> {n/10*60/20:.0f} windows/min")
|
||||
EOF
|
||||
```
|
||||
|
||||
| CSI rate | Windows/min | Minutes for 2,000 windows (minimum trainable) |
|
||||
|---|---|---|
|
||||
| ~13 Hz (idle network) | ~39 | ~52 min |
|
||||
| ~53 Hz (active self-ping, #985 firmware) | ~160 | ~13 min — record 35–40 min anyway for pose variety |
|
||||
|
||||
A 5-minute session is **not enough to train on** — it produces a few hundred
|
||||
windows of one pose context, and models trained on it memorize rather than
|
||||
generalize (this is what invalidated the earlier accuracy figure).
|
||||
|
||||
### Step 1: (Recommended) calibrate camera ↔ room
|
||||
|
||||
The two-checkerboard calibration (ADR-152 §2.1.3) puts labels in a shared 3D
|
||||
room frame instead of raw camera coordinates, which is the published defense
|
||||
against layout-brittle "coordinate overfitting" (PerceptAlign, MobiCom'26):
|
||||
|
||||
```bash
|
||||
python scripts/calibrate-camera-room.py # < 5 min, two checkerboards + a few photos
|
||||
```
|
||||
|
||||
Without it, collection still works but labels are camera-frame only and the
|
||||
trained model will not survive camera/node relocation.
|
||||
|
||||
### Step 2: Capture Camera + CSI Simultaneously
|
||||
|
||||
Run both scripts at the same time (in separate terminals):
|
||||
|
||||
```bash
|
||||
# Terminal 1: Record ESP32 CSI
|
||||
python scripts/record-csi-udp.py --duration 300
|
||||
# Terminal 1: Record ESP32 CSI (2400 s = 40 min)
|
||||
python scripts/record-csi-udp.py --duration 2400
|
||||
|
||||
# Terminal 2: Capture camera keypoints
|
||||
python scripts/collect-ground-truth.py --duration 300 --preview
|
||||
python scripts/collect-ground-truth.py --duration 2400 --preview \
|
||||
--calibration data/calibration/camera-room.json # omit if you skipped Step 1
|
||||
```
|
||||
|
||||
Move around naturally in front of the camera for 5 minutes. The `--preview` flag shows a live skeleton overlay.
|
||||
During capture: keep your **full body in frame** with good lighting (MediaPipe
|
||||
confidence must stay above 0.5 — low-confidence frames are dropped at
|
||||
alignment), and **change activity every 1–2 minutes**: walk, raise hands,
|
||||
squat, hands up, kick, wave, turn, jump, sit, stand still. Pose variety is
|
||||
what the model learns from; 40 minutes of sitting produces a constant-pose
|
||||
predictor.
|
||||
|
||||
### Step 2: Align and Train
|
||||
### Step 3: Align and Train
|
||||
|
||||
```bash
|
||||
# Align camera keypoints with CSI windows
|
||||
# Align camera keypoints with CSI windows (prints kept/dropped window counts —
|
||||
# expect roughly csi_frames/20 kept; investigate if far below)
|
||||
node scripts/align-ground-truth.js \
|
||||
--gt data/ground-truth/*.jsonl \
|
||||
--csi data/recordings/csi-*.csi.jsonl
|
||||
|
||||
# Train (start with lite, scale up as you collect more data)
|
||||
# Train (pick the preset matching your window count)
|
||||
node scripts/train-wiflow-supervised.js \
|
||||
--data data/paired/*.jsonl \
|
||||
--scale lite \
|
||||
--scale small \
|
||||
--epochs 50
|
||||
|
||||
# Evaluate
|
||||
# Evaluate — torso-normalized PCK on a TEMPORAL split
|
||||
node scripts/eval-wiflow.js \
|
||||
--model models/wiflow-supervised/wiflow-v1.json \
|
||||
--data data/paired/*.jsonl
|
||||
```
|
||||
|
||||
**Evaluation protocol matters.** Use `eval-wiflow.js` (torso-normalized
|
||||
PCK@20, the metric comparable to published WiFi-pose results) on a temporal
|
||||
hold-out, and sanity-check that predictions actually vary across frames
|
||||
(`pred std > 0`) — a constant-pose model can score deceptively well on
|
||||
near-static data under weaker protocols. See
|
||||
`benchmarks/wiflow-std/RESULTS.md` for the forensic case study.
|
||||
|
||||
### Scale Presets
|
||||
|
||||
| Preset | Params | Training Time | Best For |
|
||||
|--------|--------|---------------|----------|
|
||||
| `--scale lite` | 189K | ~19 min | < 1,000 samples (5 min capture) |
|
||||
| `--scale small` | 474K | ~1 hr | 1K-10K samples |
|
||||
| `--scale medium` | 800K | ~2 hrs | 10K-50K samples |
|
||||
| `--scale full` | 7.7M | ~8 hrs | 50K+ samples (GPU recommended) |
|
||||
| `--scale lite` | 189K | ~19 min | sanity runs only (< 2K windows trains poorly) |
|
||||
| `--scale small` | 474K | ~1 hr | 2K-10K windows (one 40-min session) |
|
||||
| `--scale medium` | 800K | ~2 hrs | 10K-50K windows (multiple sessions/rooms) |
|
||||
| `--scale full` | 7.7M | ~8 hrs | 50K+ windows (GPU recommended) |
|
||||
|
||||
See [ADR-079](adr/ADR-079-camera-ground-truth-training.md) for the full design and optimization details.
|
||||
See [ADR-079](adr/ADR-079-camera-ground-truth-training.md) for the full design and optimization details, and ADR-152 §2.2 for the external WiFlow-STD benchmark these numbers should be read against.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Two-checkerboard camera-room calibration for WiFi pose training (ADR-152 S2.1.3).
|
||||
|
||||
Aligns the ADR-079 ground-truth camera and the ESP32 WiFi transceivers in
|
||||
one shared 3D room frame -- the PerceptAlign (arXiv 2601.12252) defense
|
||||
against "coordinate overfitting", where CSI-to-camera-coordinate regression
|
||||
memorizes the deployment layout and collapses cross-layout.
|
||||
|
||||
Procedure (<5 minutes):
|
||||
1. Print a checkerboard (default 9x6 inner corners, 25 mm squares).
|
||||
2. Tape one board flat on the ORIGIN WALL, tape-measure its top-left inner
|
||||
corner position in room coordinates (+x along wall, +y into room, +z up).
|
||||
3. Lay the second board flat on the FLOOR, measure its near-left inner corner.
|
||||
4. With the collection camera in its final position, photograph each board.
|
||||
5. Run this script; tape-measure each ESP32 node position when prompted
|
||||
(or pass --geometry nodes.json).
|
||||
|
||||
Output: a calibration bundle JSON consumed by
|
||||
scripts/collect-ground-truth.py --calibration <bundle.json>
|
||||
|
||||
Usage:
|
||||
python scripts/calibrate-camera-room.py \\
|
||||
--wall-image photos/wall.jpg --wall-origin 0.50,0.0,1.60 \\
|
||||
--floor-image photos/floor.jpg --floor-origin 1.00,1.00,0.0 \\
|
||||
--calib-images "photos/intrinsics/*.jpg" \\
|
||||
--geometry config/transceivers.json \\
|
||||
--output data/calibration/camera-room.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
import calibration_lib as cal # noqa: E402
|
||||
|
||||
INTRINSICS_CACHE = Path("data") / ".cache" / "camera_intrinsics.json"
|
||||
|
||||
|
||||
def parse_vec3(text: str) -> np.ndarray:
|
||||
parts = [float(p) for p in text.replace(",", " ").split()]
|
||||
if len(parts) != 3:
|
||||
raise argparse.ArgumentTypeError(f"Expected 3 comma-separated numbers, got {text!r}")
|
||||
return np.array(parts, dtype=np.float64)
|
||||
|
||||
|
||||
def detect_corners(image_path: Path, cols: int, rows: int) -> tuple[np.ndarray, tuple[int, int]]:
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
print(f"ERROR: Cannot read image {image_path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
corners = cal.find_board_corners(image, cols, rows)
|
||||
if corners is None:
|
||||
print(
|
||||
f"ERROR: No {cols}x{rows} checkerboard found in {image_path}. "
|
||||
"Check lighting, focus, and the --board-cols/--board-rows flags.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
h, w = image.shape[:2]
|
||||
return corners, (w, h)
|
||||
|
||||
|
||||
def resolve_intrinsics(args, repo_root: Path, board_args: tuple[int, int, float]) -> dict:
|
||||
"""Pre-computed file > cached > computed from --calib-images >
|
||||
last-resort 2-view estimate from the wall+floor photos themselves."""
|
||||
cols, rows, square_m = board_args
|
||||
|
||||
if args.intrinsics:
|
||||
print(f"Intrinsics: loading {args.intrinsics}")
|
||||
return cal.load_intrinsics(Path(args.intrinsics))
|
||||
|
||||
cache_path = repo_root / INTRINSICS_CACHE
|
||||
if cache_path.exists() and not args.recalibrate_intrinsics:
|
||||
print(f"Intrinsics: using cached {cache_path} (pass --recalibrate-intrinsics to redo)")
|
||||
intr = cal.load_intrinsics(cache_path)
|
||||
intr["source"] = "cached"
|
||||
return intr
|
||||
|
||||
if args.calib_images:
|
||||
paths = sorted(glob.glob(args.calib_images))
|
||||
if len(paths) < 3:
|
||||
print(
|
||||
f"ERROR: --calib-images matched only {len(paths)} file(s); "
|
||||
"need >= 3 checkerboard views for stable intrinsics.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
corner_sets, image_size = [], None
|
||||
for p in paths:
|
||||
corners, size = detect_corners(Path(p), cols, rows)
|
||||
if image_size is None:
|
||||
image_size = size
|
||||
elif size != image_size:
|
||||
print(f"ERROR: {p} has size {size}, expected {image_size}.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
corner_sets.append(corners)
|
||||
print(f" corners found: {p}")
|
||||
intr = cal.compute_intrinsics(corner_sets, image_size, cols, rows, square_m)
|
||||
print(f"Intrinsics: computed from {len(paths)} views, "
|
||||
f"reprojection RMS {intr['reprojection_error_px']:.3f} px")
|
||||
cal.save_bundle(intr, cache_path) # plain JSON write; reused on next run
|
||||
print(f" cached to {cache_path}")
|
||||
return intr
|
||||
|
||||
# Last resort: 2-view calibration from the extrinsic photos. Workable but
|
||||
# weak -- warn loudly and recommend a proper multi-view pass.
|
||||
print(
|
||||
"WARNING: no --intrinsics / cache / --calib-images; estimating intrinsics "
|
||||
"from the wall+floor photos alone (2 views, low quality). Prefer "
|
||||
"--calib-images with 5-10 varied board views.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
corner_sets, image_size = [], None
|
||||
for p in (args.wall_image, args.floor_image):
|
||||
corners, size = detect_corners(Path(p), cols, rows)
|
||||
image_size = image_size or size
|
||||
corner_sets.append(corners)
|
||||
intr = cal.compute_intrinsics(corner_sets, image_size, cols, rows, square_m)
|
||||
intr["source"] = "two-view-fallback"
|
||||
return intr
|
||||
|
||||
|
||||
def prompt_transceiver_geometry() -> dict:
|
||||
"""Tape-measure entry of ESP32 node positions in room coordinates."""
|
||||
print()
|
||||
print("Transceiver geometry -- enter one node per line:")
|
||||
print(" <node-id> <x> <y> <z> [yaw_deg] (meters, room frame; blank line to finish)")
|
||||
print(" example: esp32-s3-a 0.10 2.40 1.10 180")
|
||||
nodes = []
|
||||
while True:
|
||||
try:
|
||||
line = input("node> ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if not line:
|
||||
break
|
||||
parts = line.split()
|
||||
if len(parts) not in (4, 5):
|
||||
print(" expected: <node-id> <x> <y> <z> [yaw_deg]", file=sys.stderr)
|
||||
continue
|
||||
try:
|
||||
node = {"id": parts[0], "position_m": [float(parts[1]), float(parts[2]), float(parts[3])]}
|
||||
if len(parts) == 5:
|
||||
node["antenna_yaw_deg"] = float(parts[4])
|
||||
except ValueError:
|
||||
print(" positions must be numeric", file=sys.stderr)
|
||||
continue
|
||||
nodes.append(node)
|
||||
if not nodes:
|
||||
print("WARNING: no transceiver nodes entered; bundle will carry empty geometry.",
|
||||
file=sys.stderr)
|
||||
return {"nodes": nodes, "units": "meters", "source": "tape-measure-prompt"}
|
||||
|
||||
|
||||
def load_geometry_file(path: Path) -> dict:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
nodes = data.get("nodes", data if isinstance(data, list) else None)
|
||||
if nodes is None:
|
||||
raise ValueError(f"{path}: expected {{'nodes': [...]}} or a top-level list")
|
||||
for node in nodes:
|
||||
if "id" not in node or "position_m" not in node:
|
||||
raise ValueError(f"{path}: each node needs 'id' and 'position_m' [x,y,z]")
|
||||
return {"nodes": nodes, "units": "meters", "source": "file"}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Two-checkerboard camera-room calibration (ADR-152 S2.1.3 / ADR-079)."
|
||||
)
|
||||
parser.add_argument("--wall-image", required=True,
|
||||
help="Photo of the checkerboard on the origin wall")
|
||||
parser.add_argument("--floor-image", required=True,
|
||||
help="Photo of the checkerboard on the floor (camera NOT moved)")
|
||||
parser.add_argument("--wall-origin", type=parse_vec3, default="0.5,0.0,1.6",
|
||||
help="Room xyz (m) of the wall board's first inner corner "
|
||||
"(default: 0.5,0.0,1.6)")
|
||||
parser.add_argument("--floor-origin", type=parse_vec3, default="1.0,1.0,0.0",
|
||||
help="Room xyz (m) of the floor board's first inner corner "
|
||||
"(default: 1.0,1.0,0.0)")
|
||||
parser.add_argument("--wall-axes", default="+x,-z",
|
||||
help="Wall board column,row directions in room frame (default: +x,-z)")
|
||||
parser.add_argument("--floor-axes", default="+x,+y",
|
||||
help="Floor board column,row directions in room frame (default: +x,+y)")
|
||||
parser.add_argument("--board-cols", type=int, default=cal.DEFAULT_BOARD_COLS,
|
||||
help=f"Inner corners per row (default: {cal.DEFAULT_BOARD_COLS})")
|
||||
parser.add_argument("--board-rows", type=int, default=cal.DEFAULT_BOARD_ROWS,
|
||||
help=f"Inner corners per column (default: {cal.DEFAULT_BOARD_ROWS})")
|
||||
parser.add_argument("--square-size-mm", type=float, default=cal.DEFAULT_SQUARE_SIZE_MM,
|
||||
help=f"Checkerboard square size in mm (default: {cal.DEFAULT_SQUARE_SIZE_MM})")
|
||||
parser.add_argument("--intrinsics", help="Pre-computed intrinsics JSON (skips computation)")
|
||||
parser.add_argument("--calib-images",
|
||||
help="Glob of >=3 checkerboard photos for intrinsics computation")
|
||||
parser.add_argument("--recalibrate-intrinsics", action="store_true",
|
||||
help="Ignore the cached intrinsics and recompute")
|
||||
parser.add_argument("--geometry",
|
||||
help="Transceiver geometry JSON ({nodes:[{id,position_m,[antenna_yaw_deg]}]}); "
|
||||
"omit to be prompted for tape-measure entry")
|
||||
parser.add_argument("--output", default=None,
|
||||
help="Bundle output path (default: data/calibration/camera-room-<ts>.json)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if isinstance(args.wall_origin, str):
|
||||
args.wall_origin = parse_vec3(args.wall_origin)
|
||||
if isinstance(args.floor_origin, str):
|
||||
args.floor_origin = parse_vec3(args.floor_origin)
|
||||
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
cols, rows = args.board_cols, args.board_rows
|
||||
square_m = args.square_size_mm / 1000.0
|
||||
|
||||
# --- Intrinsics ---
|
||||
intrinsics = resolve_intrinsics(args, repo_root, (cols, rows, square_m))
|
||||
camera_matrix = np.asarray(intrinsics["camera_matrix"], dtype=np.float64)
|
||||
dist_coeffs = np.asarray(intrinsics["dist_coeffs"], dtype=np.float64)
|
||||
|
||||
# --- Corner detection on the two placed boards ---
|
||||
wall_corners, wall_size = detect_corners(Path(args.wall_image), cols, rows)
|
||||
floor_corners, floor_size = detect_corners(Path(args.floor_image), cols, rows)
|
||||
if wall_size != floor_size:
|
||||
print(f"ERROR: wall image {wall_size} and floor image {floor_size} differ in size; "
|
||||
"both must come from the fixed collection camera.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
print(f"Corners detected: wall + floor boards ({cols}x{rows}, {args.square_size_mm} mm)")
|
||||
|
||||
# Re-scale intrinsics if they were computed at a different resolution
|
||||
# than the extrinsic photos (the bundle always stores K at wall_size).
|
||||
intr_size = tuple(intrinsics["image_size"])
|
||||
if intr_size != wall_size:
|
||||
sx, sy = wall_size[0] / intr_size[0], wall_size[1] / intr_size[1]
|
||||
camera_matrix[0, 0] *= sx
|
||||
camera_matrix[0, 2] *= sx
|
||||
camera_matrix[1, 1] *= sy
|
||||
camera_matrix[1, 2] *= sy
|
||||
print(f" intrinsics scaled {intr_size} -> {wall_size}")
|
||||
intrinsics = {**intrinsics, "camera_matrix": camera_matrix.tolist(),
|
||||
"image_size": list(wall_size)}
|
||||
|
||||
# --- Room-frame corner positions from the measured placements ---
|
||||
wall_u, wall_v = (cal.parse_axis(t) for t in args.wall_axes.split(","))
|
||||
floor_u, floor_v = (cal.parse_axis(t) for t in args.floor_axes.split(","))
|
||||
wall_room = cal.board_room_points(cols, rows, square_m, args.wall_origin, wall_u, wall_v)
|
||||
floor_room = cal.board_room_points(cols, rows, square_m, args.floor_origin, floor_u, floor_v)
|
||||
|
||||
# --- Extrinsics: joint two-board solve (resolves per-board corner-order
|
||||
# ambiguity -- a single planar board is centrosymmetric; the pair is not) ---
|
||||
extrinsics = cal.solve_two_board_extrinsics(
|
||||
wall_room, wall_corners, floor_room, floor_corners, camera_matrix, dist_coeffs
|
||||
)
|
||||
wall_rmse = extrinsics["per_board"]["wall"]["rmse_px"]
|
||||
floor_rmse = extrinsics["per_board"]["floor"]["rmse_px"]
|
||||
print(f" joint solve: RMSE {extrinsics['rmse_px']:.3f} px "
|
||||
f"(wall {wall_rmse:.3f} / floor {floor_rmse:.3f})")
|
||||
print(f" camera at room {np.round(extrinsics['translation_m'], 3).tolist()} m")
|
||||
if max(wall_rmse, floor_rmse) > 3.0:
|
||||
print(
|
||||
"WARNING: high per-board reprojection error -- re-check the measured "
|
||||
"board origins/axes and that the camera did not move between photos.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# --- Transceiver geometry ---
|
||||
if args.geometry:
|
||||
geometry = load_geometry_file(Path(args.geometry))
|
||||
print(f"Transceiver geometry: {len(geometry['nodes'])} node(s) from {args.geometry}")
|
||||
else:
|
||||
geometry = prompt_transceiver_geometry()
|
||||
|
||||
# --- Bundle ---
|
||||
bundle = cal.make_bundle(
|
||||
camera_intrinsics=intrinsics,
|
||||
camera_to_room_extrinsics=extrinsics,
|
||||
checkerboard_spec={"cols": cols, "rows": rows, "square_size_mm": args.square_size_mm},
|
||||
transceiver_geometry=geometry,
|
||||
)
|
||||
if args.output:
|
||||
out_path = Path(args.output)
|
||||
else:
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = repo_root / "data" / "calibration" / f"camera-room-{ts}.json"
|
||||
cal.save_bundle(bundle, out_path)
|
||||
|
||||
print()
|
||||
print("=== Calibration bundle written ===")
|
||||
print(f" path: {out_path}")
|
||||
print(f" calibration_id: {cal.calibration_id(bundle)}")
|
||||
print(f" next: python scripts/collect-ground-truth.py --calibration {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,416 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Camera-room calibration library for WiFi pose ground truth (ADR-152 S2.1.3).
|
||||
|
||||
Implements the PerceptAlign-style two-checkerboard alignment adopted in
|
||||
ADR-152 S2.1.3 to defend the ADR-079 camera-supervised pipeline against
|
||||
"coordinate overfitting" (arXiv 2601.12252, MobiCom'26): models regressing
|
||||
CSI to raw camera-frame coordinates memorize the deployment layout and
|
||||
collapse cross-layout. The fix is to express camera AND WiFi transceivers
|
||||
in one shared 3D room frame, and stamp every training label with the
|
||||
calibration + transceiver geometry that produced it.
|
||||
|
||||
Used by:
|
||||
scripts/calibrate-camera-room.py (produces the calibration bundle)
|
||||
scripts/collect-ground-truth.py (consumes it via --calibration)
|
||||
|
||||
Room frame convention (right-handed, meters):
|
||||
origin = a designated wall/floor corner of the room
|
||||
+x = along the origin wall
|
||||
+y = into the room (away from the origin wall)
|
||||
+z = up
|
||||
|
||||
No-depth limitation (IMPORTANT): a single 2D camera keypoint constrains
|
||||
only a *ray* in the room frame, not a 3D point. The transform helpers here
|
||||
therefore return unit bearing rays from the camera center -- a projective
|
||||
alignment. Consumers that need metric 3D points must supply a depth
|
||||
assumption downstream (floor-plane intersection, known subject height,
|
||||
multi-view triangulation, ...). Raw image coordinates are always preserved
|
||||
alongside the room-frame rays so training can choose either representation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
BUNDLE_SCHEMA_VERSION = 1
|
||||
BUNDLE_METHOD = "two-checkerboard"
|
||||
|
||||
# Default checkerboard: 9x6 inner corners, 25 mm squares (a common print).
|
||||
DEFAULT_BOARD_COLS = 9
|
||||
DEFAULT_BOARD_ROWS = 6
|
||||
DEFAULT_SQUARE_SIZE_MM = 25.0
|
||||
|
||||
_AXIS_TOKENS = {
|
||||
"+x": (1.0, 0.0, 0.0), "-x": (-1.0, 0.0, 0.0),
|
||||
"+y": (0.0, 1.0, 0.0), "-y": (0.0, -1.0, 0.0),
|
||||
"+z": (0.0, 0.0, 1.0), "-z": (0.0, 0.0, -1.0),
|
||||
}
|
||||
|
||||
|
||||
def parse_axis(token: str) -> np.ndarray:
|
||||
"""Parse an axis token like '+x' or '-z' into a room-frame unit vector."""
|
||||
key = token.strip().lower()
|
||||
if key in _AXIS_TOKENS:
|
||||
return np.array(_AXIS_TOKENS[key], dtype=np.float64)
|
||||
raise ValueError(f"Invalid axis token {token!r}; expected one of {sorted(_AXIS_TOKENS)}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Checkerboard geometry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def board_object_points(cols: int, rows: int, square_size_m: float) -> np.ndarray:
|
||||
"""Inner-corner positions in the board's own frame (z=0 plane), row-major.
|
||||
|
||||
Matches the corner ordering of cv2.findChessboardCorners for a
|
||||
(cols, rows) pattern: cols varies fastest.
|
||||
"""
|
||||
pts = np.zeros((rows * cols, 3), dtype=np.float64)
|
||||
grid = np.mgrid[0:cols, 0:rows].T.reshape(-1, 2) # (rows*cols, 2), cols fastest
|
||||
pts[:, :2] = grid * square_size_m
|
||||
return pts
|
||||
|
||||
|
||||
def board_room_points(
|
||||
cols: int,
|
||||
rows: int,
|
||||
square_size_m: float,
|
||||
origin: np.ndarray,
|
||||
u_axis: np.ndarray,
|
||||
v_axis: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Inner-corner positions in ROOM coordinates for a board placed at a
|
||||
known position: first corner at `origin`, columns stepping along
|
||||
`u_axis`, rows stepping along `v_axis` (both room-frame unit vectors).
|
||||
"""
|
||||
local = board_object_points(cols, rows, square_size_m)
|
||||
origin = np.asarray(origin, dtype=np.float64)
|
||||
u = np.asarray(u_axis, dtype=np.float64)
|
||||
v = np.asarray(v_axis, dtype=np.float64)
|
||||
return origin[None, :] + local[:, 0:1] * u[None, :] + local[:, 1:2] * v[None, :]
|
||||
|
||||
|
||||
def find_board_corners(image: np.ndarray, cols: int, rows: int) -> np.ndarray | None:
|
||||
"""Detect and sub-pixel-refine checkerboard inner corners.
|
||||
|
||||
Returns (cols*rows, 2) float64 pixel coordinates, or None if not found.
|
||||
"""
|
||||
gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
flags = cv2.CALIB_CB_ADAPTIVE_THRESH | cv2.CALIB_CB_NORMALIZE_IMAGE
|
||||
found, corners = cv2.findChessboardCorners(gray, (cols, rows), flags=flags)
|
||||
if not found:
|
||||
return None
|
||||
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 1e-3)
|
||||
corners = cv2.cornerSubPix(gray, corners, (11, 11), (-1, -1), criteria)
|
||||
return corners.reshape(-1, 2).astype(np.float64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Intrinsics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_intrinsics(
|
||||
corner_sets: list[np.ndarray],
|
||||
image_size: tuple[int, int],
|
||||
cols: int,
|
||||
rows: int,
|
||||
square_size_m: float,
|
||||
) -> dict:
|
||||
"""Camera intrinsics from N checkerboard views via cv2.calibrateCamera.
|
||||
|
||||
corner_sets: list of (cols*rows, 2) pixel corner arrays.
|
||||
image_size: (width, height) of the calibration images.
|
||||
"""
|
||||
obj = board_object_points(cols, rows, square_size_m).astype(np.float32)
|
||||
obj_pts = [obj for _ in corner_sets]
|
||||
img_pts = [c.reshape(-1, 1, 2).astype(np.float32) for c in corner_sets]
|
||||
rms, camera_matrix, dist_coeffs, _, _ = cv2.calibrateCamera(
|
||||
obj_pts, img_pts, tuple(image_size), None, None
|
||||
)
|
||||
return {
|
||||
"image_size": [int(image_size[0]), int(image_size[1])],
|
||||
"camera_matrix": camera_matrix.tolist(),
|
||||
"dist_coeffs": dist_coeffs.ravel().tolist(),
|
||||
"reprojection_error_px": float(rms),
|
||||
"source": "computed",
|
||||
}
|
||||
|
||||
|
||||
def load_intrinsics(path: Path) -> dict:
|
||||
"""Load a pre-computed intrinsics JSON ({camera_matrix, dist_coeffs, image_size})."""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# Accept either a bare intrinsics dict or a full calibration bundle.
|
||||
intr = data.get("camera_intrinsics", data)
|
||||
for key in ("camera_matrix", "dist_coeffs", "image_size"):
|
||||
if key not in intr:
|
||||
raise ValueError(f"Intrinsics file {path} missing key {key!r}")
|
||||
intr = dict(intr)
|
||||
intr["source"] = "file"
|
||||
return intr
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extrinsics (camera -> room rigid transform)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def reprojection_rmse(
|
||||
room_points: np.ndarray,
|
||||
image_points: np.ndarray,
|
||||
rvec: np.ndarray,
|
||||
tvec: np.ndarray,
|
||||
camera_matrix: np.ndarray,
|
||||
dist_coeffs: np.ndarray,
|
||||
) -> float:
|
||||
proj, _ = cv2.projectPoints(room_points, rvec, tvec, camera_matrix, dist_coeffs)
|
||||
err = proj.reshape(-1, 2) - image_points.reshape(-1, 2)
|
||||
return float(np.sqrt(np.mean(np.sum(err**2, axis=1))))
|
||||
|
||||
|
||||
def _solve_pnp(
|
||||
room_points: np.ndarray,
|
||||
image_points: np.ndarray,
|
||||
camera_matrix: np.ndarray,
|
||||
dist_coeffs: np.ndarray,
|
||||
) -> dict | None:
|
||||
"""One solvePnP run (room->camera), inverted to camera->room. Returns
|
||||
{rotation (3x3 camera->room), translation_m (camera center in room
|
||||
frame), rmse_px} or None on failure.
|
||||
"""
|
||||
ok, rvec, tvec = cv2.solvePnP(
|
||||
room_points.reshape(-1, 1, 3),
|
||||
image_points.reshape(-1, 1, 2),
|
||||
camera_matrix,
|
||||
dist_coeffs,
|
||||
flags=cv2.SOLVEPNP_ITERATIVE,
|
||||
)
|
||||
if not ok:
|
||||
return None
|
||||
rmse = reprojection_rmse(room_points, image_points, rvec, tvec, camera_matrix, dist_coeffs)
|
||||
r_room_to_cam, _ = cv2.Rodrigues(rvec)
|
||||
r_cam_to_room = r_room_to_cam.T
|
||||
camera_center_room = (-r_cam_to_room @ tvec).ravel()
|
||||
return {
|
||||
"rotation": r_cam_to_room.tolist(),
|
||||
"translation_m": camera_center_room.tolist(),
|
||||
"rmse_px": rmse,
|
||||
}
|
||||
|
||||
|
||||
def solve_extrinsics(
|
||||
room_points: np.ndarray,
|
||||
image_points: np.ndarray,
|
||||
camera_matrix: np.ndarray,
|
||||
dist_coeffs: np.ndarray,
|
||||
) -> dict:
|
||||
"""Solve the camera->room rigid transform from 3D room-frame points and
|
||||
their 2D pixel observations.
|
||||
|
||||
NOTE: the corner grid of a single planar checkerboard is centrosymmetric,
|
||||
so the corner ordering returned by findChessboardCorners (which may
|
||||
enumerate from either board end) cannot be disambiguated from one board
|
||||
alone -- the reversed ordering fits a ghost pose with identical
|
||||
reprojection error. Use solve_two_board_extrinsics for the full
|
||||
two-checkerboard procedure, where the joint point set breaks the symmetry.
|
||||
"""
|
||||
ext = _solve_pnp(room_points, image_points, camera_matrix, dist_coeffs)
|
||||
if ext is None:
|
||||
raise RuntimeError("solvePnP failed")
|
||||
return ext
|
||||
|
||||
|
||||
def solve_two_board_extrinsics(
|
||||
wall_room: np.ndarray,
|
||||
wall_image: np.ndarray,
|
||||
floor_room: np.ndarray,
|
||||
floor_image: np.ndarray,
|
||||
camera_matrix: np.ndarray,
|
||||
dist_coeffs: np.ndarray,
|
||||
) -> dict:
|
||||
"""Joint camera->room solve over both checkerboards (the ADR-152 S2.1.3
|
||||
two-checkerboard method).
|
||||
|
||||
Tries all 4 per-board corner-ordering combinations: each board's ordering
|
||||
is individually ambiguous (centrosymmetric grid), but the combined
|
||||
wall+floor point set is not, so exactly one combination reaches minimal
|
||||
reprojection error. Returns the solve_extrinsics dict plus
|
||||
{wall_flipped, floor_flipped, per_board: {wall|floor: {rmse_px}}}.
|
||||
"""
|
||||
best = None
|
||||
for wall_flipped in (False, True):
|
||||
for floor_flipped in (False, True):
|
||||
wi = wall_image[::-1].copy() if wall_flipped else wall_image
|
||||
fi = floor_image[::-1].copy() if floor_flipped else floor_image
|
||||
room = np.concatenate([wall_room, floor_room], axis=0)
|
||||
img = np.concatenate([wi, fi], axis=0)
|
||||
ext = _solve_pnp(room, img, camera_matrix, dist_coeffs)
|
||||
if ext is None:
|
||||
continue
|
||||
if best is None or ext["rmse_px"] < best[0]["rmse_px"]:
|
||||
ext["wall_flipped"] = wall_flipped
|
||||
ext["floor_flipped"] = floor_flipped
|
||||
rvec, _ = cv2.Rodrigues(np.asarray(ext["rotation"]).T)
|
||||
tvec = -np.asarray(ext["rotation"]).T @ np.asarray(ext["translation_m"])
|
||||
ext["per_board"] = {
|
||||
"wall": {"rmse_px": reprojection_rmse(
|
||||
wall_room, wi, rvec, tvec, camera_matrix, dist_coeffs)},
|
||||
"floor": {"rmse_px": reprojection_rmse(
|
||||
floor_room, fi, rvec, tvec, camera_matrix, dist_coeffs)},
|
||||
}
|
||||
best = (ext,)
|
||||
if best is None:
|
||||
raise RuntimeError("solvePnP failed for all corner-ordering combinations")
|
||||
return best[0]
|
||||
|
||||
|
||||
def extrinsics_consistency(ext_a: dict, ext_b: dict) -> dict:
|
||||
"""Angular + translational disagreement between two extrinsic solutions
|
||||
(the two single-board solves). Large values mean a mis-entered board
|
||||
placement or a bad corner detection.
|
||||
"""
|
||||
ra = np.asarray(ext_a["rotation"])
|
||||
rb = np.asarray(ext_b["rotation"])
|
||||
r_delta = ra.T @ rb
|
||||
angle = float(np.degrees(np.arccos(np.clip((np.trace(r_delta) - 1.0) / 2.0, -1.0, 1.0))))
|
||||
t_delta = float(
|
||||
np.linalg.norm(np.asarray(ext_a["translation_m"]) - np.asarray(ext_b["translation_m"]))
|
||||
)
|
||||
return {"rotation_deg": angle, "translation_m": t_delta}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Calibration bundle (the artifact written to disk)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_bundle(
|
||||
camera_intrinsics: dict,
|
||||
camera_to_room_extrinsics: dict,
|
||||
checkerboard_spec: dict,
|
||||
transceiver_geometry: dict,
|
||||
) -> dict:
|
||||
return {
|
||||
"schema_version": BUNDLE_SCHEMA_VERSION,
|
||||
"method": BUNDLE_METHOD,
|
||||
"calibrated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"room_frame": {
|
||||
"description": "right-handed; origin at wall/floor corner; "
|
||||
"+x along origin wall, +y into room, +z up",
|
||||
"units": "meters",
|
||||
},
|
||||
"checkerboard_spec": checkerboard_spec,
|
||||
"camera_intrinsics": camera_intrinsics,
|
||||
"camera_to_room_extrinsics": camera_to_room_extrinsics,
|
||||
"transceiver_geometry": transceiver_geometry,
|
||||
}
|
||||
|
||||
|
||||
def calibration_id(bundle: dict) -> str:
|
||||
"""Stable content hash of a bundle -- stamped onto every emitted sample
|
||||
so a label can always be traced to the exact calibration that framed it.
|
||||
"""
|
||||
canonical = json.dumps(bundle, sort_keys=True, separators=(",", ":"))
|
||||
return "sha256:" + hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def save_bundle(bundle: dict, path: Path) -> None:
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(bundle, f, indent=2)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def load_bundle(path: Path) -> dict:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
bundle = json.load(f)
|
||||
for key in ("camera_intrinsics", "camera_to_room_extrinsics", "transceiver_geometry"):
|
||||
if key not in bundle:
|
||||
raise ValueError(f"Calibration bundle {path} missing key {key!r}")
|
||||
return bundle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keypoint transform (image -> room-frame bearing rays)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CalibrationContext:
|
||||
"""Pre-computed transform state for a collection session.
|
||||
|
||||
Scales the bundle's intrinsics to the live capture resolution (MediaPipe
|
||||
keypoints are normalized [0,1], so we need the actual frame size to get
|
||||
back to pixels before undistorting).
|
||||
"""
|
||||
|
||||
def __init__(self, bundle: dict, frame_w: int, frame_h: int):
|
||||
self.bundle = bundle
|
||||
self.calibration_id = calibration_id(bundle)
|
||||
self.transceiver_geometry = bundle["transceiver_geometry"]
|
||||
self.frame_w = int(frame_w)
|
||||
self.frame_h = int(frame_h)
|
||||
|
||||
intr = bundle["camera_intrinsics"]
|
||||
k = np.asarray(intr["camera_matrix"], dtype=np.float64)
|
||||
cal_w, cal_h = intr["image_size"]
|
||||
sx = self.frame_w / float(cal_w)
|
||||
sy = self.frame_h / float(cal_h)
|
||||
k = k.copy()
|
||||
k[0, 0] *= sx
|
||||
k[0, 2] *= sx
|
||||
k[1, 1] *= sy
|
||||
k[1, 2] *= sy
|
||||
self.camera_matrix = k
|
||||
self.dist_coeffs = np.asarray(intr["dist_coeffs"], dtype=np.float64)
|
||||
|
||||
ext = bundle["camera_to_room_extrinsics"]
|
||||
self.r_cam_to_room = np.asarray(ext["rotation"], dtype=np.float64)
|
||||
self.origin_room = np.asarray(ext["translation_m"], dtype=np.float64)
|
||||
|
||||
def transform_keypoints(self, keypoints_norm: list[list[float]]) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Normalized [0,1] image keypoints -> unit bearing rays in the room
|
||||
frame, anchored at the camera center.
|
||||
|
||||
Projective alignment ONLY (no depth): each returned ray is the locus
|
||||
of room positions consistent with the 2D observation. Returns
|
||||
(camera_origin_room (3,), ray_dirs (N, 3) unit vectors).
|
||||
"""
|
||||
pts = np.asarray(keypoints_norm, dtype=np.float64)
|
||||
pts_px = pts * np.array([self.frame_w, self.frame_h], dtype=np.float64)
|
||||
undist = cv2.undistortPoints(
|
||||
pts_px.reshape(-1, 1, 2), self.camera_matrix, self.dist_coeffs
|
||||
).reshape(-1, 2)
|
||||
rays_cam = np.concatenate([undist, np.ones((len(undist), 1))], axis=1)
|
||||
rays_cam /= np.linalg.norm(rays_cam, axis=1, keepdims=True)
|
||||
rays_room = (self.r_cam_to_room @ rays_cam.T).T
|
||||
return self.origin_room, rays_room
|
||||
|
||||
|
||||
def load_calibration_context(path: Path, frame_w: int, frame_h: int) -> CalibrationContext:
|
||||
return CalibrationContext(load_bundle(path), frame_w, frame_h)
|
||||
|
||||
|
||||
def augment_record(record: dict, ctx: CalibrationContext | None) -> dict:
|
||||
"""Stamp a ground-truth record with room-frame rays + calibration metadata.
|
||||
|
||||
With ctx=None this is the identity -- the record (and hence the emitted
|
||||
JSONL line) is byte-identical to the pre-calibration ADR-079 format.
|
||||
Raw image-coordinate keypoints are kept untouched in both cases; the
|
||||
room-frame representation is ADDED, never substituted, so training can
|
||||
choose either (ADR-152 S2.1.3).
|
||||
"""
|
||||
if ctx is None:
|
||||
return record
|
||||
if record.get("keypoints"):
|
||||
_, rays = ctx.transform_keypoints(record["keypoints"])
|
||||
record["keypoints_room"] = [[round(float(v), 5) for v in ray] for ray in rays]
|
||||
else:
|
||||
record["keypoints_room"] = []
|
||||
record["camera_origin_room"] = [round(float(v), 5) for v in ctx.origin_room]
|
||||
record["calibration_id"] = ctx.calibration_id
|
||||
record["transceiver_geometry"] = ctx.transceiver_geometry
|
||||
return record
|
||||
@@ -6,9 +6,19 @@ synchronizes with ESP32 CSI recording from the sensing server.
|
||||
|
||||
Output: JSONL file in data/ground-truth/ with per-frame 17-keypoint COCO poses.
|
||||
|
||||
With --calibration <bundle.json> (produced by scripts/calibrate-camera-room.py,
|
||||
ADR-152 S2.1.3), every record is additionally stamped with room-frame bearing
|
||||
rays for each keypoint, the calibration_id, and the transceiver geometry --
|
||||
the PerceptAlign-style defense against coordinate overfitting. Raw image
|
||||
coordinates are always kept; without depth the room-frame representation is
|
||||
a projective alignment (rays, not 3D points) -- see scripts/calibration_lib.py.
|
||||
Without --calibration the output is byte-identical to the original ADR-079
|
||||
format.
|
||||
|
||||
Usage:
|
||||
python scripts/collect-ground-truth.py --preview --duration 60
|
||||
python scripts/collect-ground-truth.py --server http://192.168.1.10:3000
|
||||
python scripts/collect-ground-truth.py --calibration data/calibration/camera-room.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -168,8 +178,23 @@ def main():
|
||||
default="data/ground-truth",
|
||||
help="Output directory (default: data/ground-truth)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--calibration",
|
||||
default=None,
|
||||
help="Camera-room calibration bundle JSON from scripts/calibrate-camera-room.py "
|
||||
"(ADR-152 S2.1.3); adds room-frame keypoint rays + transceiver geometry "
|
||||
"to every record",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.calibration:
|
||||
print(
|
||||
"WARNING: no --calibration bundle; labels stay in raw camera coordinates "
|
||||
"and are layout-brittle (coordinate overfitting, ADR-152 S2.1.3) -- run "
|
||||
"scripts/calibrate-camera-room.py first.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# --- Resolve paths relative to repo root ---
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
output_dir = repo_root / args.output
|
||||
@@ -193,6 +218,25 @@ def main():
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
print(f"Camera opened: {frame_w}x{frame_h}")
|
||||
|
||||
# --- Load calibration bundle (ADR-152 S2.1.3) ---
|
||||
calib_ctx = None
|
||||
if args.calibration:
|
||||
# Lazy import keeps the no-calibration path identical to the original.
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
import calibration_lib
|
||||
|
||||
try:
|
||||
calib_ctx = calibration_lib.load_calibration_context(
|
||||
Path(args.calibration), frame_w, frame_h
|
||||
)
|
||||
except (OSError, ValueError, json.JSONDecodeError) as exc:
|
||||
print(f"ERROR: Cannot load calibration bundle {args.calibration}: {exc}",
|
||||
file=sys.stderr)
|
||||
sys.exit(1)
|
||||
n_nodes = len(calib_ctx.transceiver_geometry.get("nodes", []))
|
||||
print(f"Calibration: {calib_ctx.calibration_id[:23]}... "
|
||||
f"({n_nodes} transceiver node(s)); emitting room-frame keypoint rays")
|
||||
|
||||
# --- Create PoseLandmarker ---
|
||||
options = PoseLandmarkerOptions(
|
||||
base_options=BaseOptions(model_asset_path=str(model_path)),
|
||||
@@ -287,6 +331,10 @@ def main():
|
||||
"n_visible": n_visible,
|
||||
"n_persons": n_persons,
|
||||
}
|
||||
if calib_ctx is not None:
|
||||
# Adds keypoints_room (bearing rays), camera_origin_room,
|
||||
# calibration_id, transceiver_geometry (ADR-152 S2.1.3).
|
||||
record = calibration_lib.augment_record(record, calib_ctx)
|
||||
out_file.write(json.dumps(record) + "\n")
|
||||
frame_count += 1
|
||||
total_confidence += confidence
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Segmented overnight empty-room CSI capture (ADR-135 baseline / MAE corpus).
|
||||
|
||||
Binds UDP once and writes fixed-duration JSONL segments with explicit names —
|
||||
no post-hoc renaming, no glob collisions with other recordings.
|
||||
|
||||
Usage:
|
||||
python scripts/overnight-empty-capture.py --segments 8 --segment-seconds 3300
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
|
||||
|
||||
def parse_csi_packet(data):
|
||||
"""ADR-018 binary CSI packet → dict (same layout as record-csi-udp.py)."""
|
||||
if len(data) < 8:
|
||||
return None
|
||||
node_id = data[4]
|
||||
rssi = struct.unpack("b", bytes([data[6]]))[0]
|
||||
channel = data[7]
|
||||
iq = data[8:]
|
||||
amplitudes = []
|
||||
for i in range(0, len(iq) - 1, 2):
|
||||
I = struct.unpack("b", bytes([iq[i]]))[0]
|
||||
Q = struct.unpack("b", bytes([iq[i + 1]]))[0]
|
||||
amplitudes.append(round((I * I + Q * Q) ** 0.5, 2))
|
||||
return {
|
||||
"type": "raw_csi",
|
||||
"ts_ns": time.time_ns(),
|
||||
"node_id": node_id,
|
||||
"rssi": rssi,
|
||||
"channel": channel,
|
||||
"subcarriers": len(iq) // 2,
|
||||
"amplitudes": amplitudes,
|
||||
"iq_hex": iq.hex(),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--port", type=int, default=5005)
|
||||
ap.add_argument("--segments", type=int, default=8)
|
||||
ap.add_argument("--segment-seconds", type=int, default=3300)
|
||||
ap.add_argument("--output", default="data/recordings")
|
||||
ap.add_argument("--prefix", default="overnight-empty")
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.bind(("0.0.0.0", args.port))
|
||||
sock.settimeout(2.0)
|
||||
|
||||
for seg in range(1, args.segments + 1):
|
||||
path = os.path.join(
|
||||
args.output, f"{args.prefix}-seg{seg}-{int(time.time())}.csi.jsonl"
|
||||
)
|
||||
n = 0
|
||||
t_end = time.time() + args.segment_seconds
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
while time.time() < t_end:
|
||||
try:
|
||||
data, _ = sock.recvfrom(4096)
|
||||
except socket.timeout:
|
||||
continue
|
||||
rec = parse_csi_packet(data)
|
||||
if rec is not None:
|
||||
f.write(json.dumps(rec) + "\n")
|
||||
n += 1
|
||||
print(f"segment {seg}: {n} frames -> {path}", flush=True)
|
||||
|
||||
print("capture complete", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Make scripts/ importable for the calibration tests (ADR-152 S2.1.3)."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
SCRIPTS_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(SCRIPTS_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(SCRIPTS_DIR))
|
||||
@@ -0,0 +1,326 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Headless tests for the camera-room calibration pipeline (ADR-152 S2.1.3).
|
||||
|
||||
Covers calibration_lib.py end to end on synthetic data -- no camera, no
|
||||
display, no MediaPipe:
|
||||
* known extrinsics recovered from synthetic two-checkerboard corners
|
||||
* calibration bundle JSON round-trip + stable content hash
|
||||
* image->room keypoint transform correctness (rays pass through the
|
||||
original 3D points -- the projective, no-depth alignment of ADR-079
|
||||
labels into the shared room frame)
|
||||
* collect-ground-truth's no-calibration record path is byte-identical
|
||||
(augment_record with ctx=None is the identity)
|
||||
|
||||
Run: python -m pytest scripts/tests/ -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import calibration_lib as cal
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthetic scene fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
IMG_W, IMG_H = 1280, 720
|
||||
K_GT = np.array(
|
||||
[[800.0, 0.0, 640.0],
|
||||
[0.0, 800.0, 360.0],
|
||||
[0.0, 0.0, 1.0]]
|
||||
)
|
||||
DIST_ZERO = np.zeros(5)
|
||||
DIST_MILD = np.array([-0.10, 0.02, 0.001, -0.001, 0.0])
|
||||
|
||||
BOARD_COLS, BOARD_ROWS = 9, 6
|
||||
SQUARE_M = 0.025
|
||||
|
||||
|
||||
def look_at_pose(camera_pos, target):
|
||||
"""Ground-truth camera pose: returns (R_cam_to_room, camera_center_room).
|
||||
|
||||
Camera convention: +z forward (optical axis), +x right, +y down.
|
||||
"""
|
||||
c = np.asarray(camera_pos, dtype=np.float64)
|
||||
fwd = np.asarray(target, dtype=np.float64) - c
|
||||
fwd /= np.linalg.norm(fwd)
|
||||
up_room = np.array([0.0, 0.0, 1.0])
|
||||
x_cam = np.cross(fwd, -up_room)
|
||||
x_cam /= np.linalg.norm(x_cam)
|
||||
y_cam = np.cross(fwd, x_cam)
|
||||
r_cam_to_room = np.stack([x_cam, y_cam, fwd], axis=1) # columns = camera axes in room
|
||||
return r_cam_to_room, c
|
||||
|
||||
|
||||
def room_to_cam(r_cam_to_room, center):
|
||||
"""Invert to the solvePnP (room->camera) convention: rvec, tvec."""
|
||||
r_room_to_cam = r_cam_to_room.T
|
||||
tvec = -r_room_to_cam @ center
|
||||
rvec, _ = cv2.Rodrigues(r_room_to_cam)
|
||||
return rvec, tvec.reshape(3, 1)
|
||||
|
||||
|
||||
def project_room_points(points_room, r_cam_to_room, center, k=K_GT, dist=DIST_ZERO):
|
||||
rvec, tvec = room_to_cam(r_cam_to_room, center)
|
||||
proj, _ = cv2.projectPoints(np.asarray(points_room, dtype=np.float64), rvec, tvec, k, dist)
|
||||
return proj.reshape(-1, 2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scene():
|
||||
"""A camera in the room looking at the wall + floor checkerboards."""
|
||||
r_gt, c_gt = look_at_pose(camera_pos=[1.5, 3.0, 1.3], target=[1.0, 0.5, 0.8])
|
||||
wall_room = cal.board_room_points(
|
||||
BOARD_COLS, BOARD_ROWS, SQUARE_M,
|
||||
origin=[0.5, 0.0, 1.6], u_axis=cal.parse_axis("+x"), v_axis=cal.parse_axis("-z"),
|
||||
)
|
||||
floor_room = cal.board_room_points(
|
||||
BOARD_COLS, BOARD_ROWS, SQUARE_M,
|
||||
origin=[1.0, 1.0, 0.0], u_axis=cal.parse_axis("+x"), v_axis=cal.parse_axis("+y"),
|
||||
)
|
||||
return r_gt, c_gt, wall_room, floor_room
|
||||
|
||||
|
||||
def make_bundle(r_gt, c_gt, dist=DIST_ZERO):
|
||||
return cal.make_bundle(
|
||||
camera_intrinsics={
|
||||
"image_size": [IMG_W, IMG_H],
|
||||
"camera_matrix": K_GT.tolist(),
|
||||
"dist_coeffs": dist.tolist(),
|
||||
"reprojection_error_px": 0.0,
|
||||
"source": "synthetic",
|
||||
},
|
||||
camera_to_room_extrinsics={
|
||||
"rotation": r_gt.tolist(),
|
||||
"translation_m": c_gt.tolist(),
|
||||
"rmse_px": 0.0,
|
||||
},
|
||||
checkerboard_spec={"cols": BOARD_COLS, "rows": BOARD_ROWS, "square_size_mm": 25.0},
|
||||
transceiver_geometry={
|
||||
"nodes": [
|
||||
{"id": "esp32-s3-a", "position_m": [0.1, 2.4, 1.1], "antenna_yaw_deg": 180.0},
|
||||
{"id": "esp32-c6-b", "position_m": [3.2, 0.3, 0.9]},
|
||||
],
|
||||
"units": "meters",
|
||||
"source": "file",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extrinsics recovery from synthetic checkerboard corners
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtrinsicsRecovery:
|
||||
def test_two_board_combined_recovers_known_pose(self, scene):
|
||||
r_gt, c_gt, wall_room, floor_room = scene
|
||||
room_pts = np.concatenate([wall_room, floor_room], axis=0)
|
||||
img_pts = project_room_points(room_pts, r_gt, c_gt)
|
||||
|
||||
ext = cal.solve_extrinsics(room_pts, img_pts, K_GT, DIST_ZERO)
|
||||
|
||||
assert ext["rmse_px"] < 1e-3
|
||||
np.testing.assert_allclose(np.asarray(ext["translation_m"]), c_gt, atol=1e-4)
|
||||
r_delta = np.asarray(ext["rotation"]).T @ r_gt
|
||||
angle_deg = np.degrees(np.arccos(np.clip((np.trace(r_delta) - 1) / 2, -1, 1)))
|
||||
assert angle_deg < 0.01
|
||||
|
||||
def test_single_board_solves_agree(self, scene):
|
||||
# With correct corner ordering, each board alone recovers the same pose.
|
||||
r_gt, c_gt, wall_room, floor_room = scene
|
||||
ext_wall = cal.solve_extrinsics(
|
||||
wall_room, project_room_points(wall_room, r_gt, c_gt), K_GT, DIST_ZERO)
|
||||
ext_floor = cal.solve_extrinsics(
|
||||
floor_room, project_room_points(floor_room, r_gt, c_gt), K_GT, DIST_ZERO)
|
||||
consistency = cal.extrinsics_consistency(ext_wall, ext_floor)
|
||||
assert consistency["rotation_deg"] < 0.1
|
||||
assert consistency["translation_m"] < 1e-3
|
||||
|
||||
def test_reversed_corner_order_auto_recovered(self, scene):
|
||||
# findChessboardCorners may enumerate from either board end. A single
|
||||
# board cannot disambiguate that flip (centrosymmetric grid), but the
|
||||
# joint two-board solve can -- feed it a reversed wall ordering and
|
||||
# require the true pose back.
|
||||
r_gt, c_gt, wall_room, floor_room = scene
|
||||
wall_img = project_room_points(wall_room, r_gt, c_gt)
|
||||
floor_img = project_room_points(floor_room, r_gt, c_gt)
|
||||
ext = cal.solve_two_board_extrinsics(
|
||||
wall_room, wall_img[::-1].copy(), floor_room, floor_img,
|
||||
K_GT, DIST_ZERO)
|
||||
assert ext["wall_flipped"] is True
|
||||
assert ext["floor_flipped"] is False
|
||||
assert ext["rmse_px"] < 1e-3
|
||||
np.testing.assert_allclose(np.asarray(ext["translation_m"]), c_gt, atol=1e-3)
|
||||
|
||||
def test_joint_solver_matches_unflipped(self, scene):
|
||||
r_gt, c_gt, wall_room, floor_room = scene
|
||||
ext = cal.solve_two_board_extrinsics(
|
||||
wall_room, project_room_points(wall_room, r_gt, c_gt),
|
||||
floor_room, project_room_points(floor_room, r_gt, c_gt),
|
||||
K_GT, DIST_ZERO)
|
||||
assert ext["wall_flipped"] is False and ext["floor_flipped"] is False
|
||||
assert ext["per_board"]["wall"]["rmse_px"] < 1e-3
|
||||
assert ext["per_board"]["floor"]["rmse_px"] < 1e-3
|
||||
|
||||
def test_intrinsics_recovered_from_synthetic_views(self):
|
||||
# Several board views from different poses -> calibrateCamera should
|
||||
# get focal length / principal point close to ground truth.
|
||||
obj = cal.board_object_points(BOARD_COLS, BOARD_ROWS, SQUARE_M)
|
||||
poses = [
|
||||
([0.05, 1.2, 0.05], [0.10, 0.0, 0.06]),
|
||||
([-0.25, 1.0, 0.20], [0.10, 0.0, 0.06]),
|
||||
([0.45, 0.9, -0.15], [0.10, 0.0, 0.06]),
|
||||
([0.10, 1.4, 0.30], [0.10, 0.0, 0.06]),
|
||||
([-0.15, 0.8, -0.20], [0.10, 0.0, 0.06]),
|
||||
]
|
||||
corner_sets = []
|
||||
for cam_pos, target in poses:
|
||||
r, c = look_at_pose(cam_pos, target)
|
||||
# Embed the board rigidly in the y=0 plane (u=+x, v=+z) and view it.
|
||||
board_in_room = np.column_stack([obj[:, 0], obj[:, 2], obj[:, 1]])
|
||||
corner_sets.append(project_room_points(board_in_room, r, c))
|
||||
intr = cal.compute_intrinsics(corner_sets, (IMG_W, IMG_H),
|
||||
BOARD_COLS, BOARD_ROWS, SQUARE_M)
|
||||
k = np.asarray(intr["camera_matrix"])
|
||||
assert abs(k[0, 0] - K_GT[0, 0]) / K_GT[0, 0] < 0.05
|
||||
assert abs(k[1, 1] - K_GT[1, 1]) / K_GT[1, 1] < 0.05
|
||||
assert intr["reprojection_error_px"] < 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bundle round-trip + content hash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBundle:
|
||||
def test_save_load_roundtrip(self, scene, tmp_path):
|
||||
r_gt, c_gt, _, _ = scene
|
||||
bundle = make_bundle(r_gt, c_gt)
|
||||
path = tmp_path / "camera-room.json"
|
||||
cal.save_bundle(bundle, path)
|
||||
loaded = cal.load_bundle(path)
|
||||
assert loaded == bundle
|
||||
assert cal.calibration_id(loaded) == cal.calibration_id(bundle)
|
||||
|
||||
def test_bundle_schema_fields(self, scene):
|
||||
r_gt, c_gt, _, _ = scene
|
||||
bundle = make_bundle(r_gt, c_gt)
|
||||
for key in ("schema_version", "method", "calibrated_at", "room_frame",
|
||||
"checkerboard_spec", "camera_intrinsics",
|
||||
"camera_to_room_extrinsics", "transceiver_geometry"):
|
||||
assert key in bundle
|
||||
assert bundle["method"] == "two-checkerboard"
|
||||
|
||||
def test_calibration_id_changes_with_content(self, scene):
|
||||
r_gt, c_gt, _, _ = scene
|
||||
bundle_a = make_bundle(r_gt, c_gt)
|
||||
bundle_b = json.loads(json.dumps(bundle_a))
|
||||
bundle_b["transceiver_geometry"]["nodes"][0]["position_m"] = [0.2, 2.4, 1.1]
|
||||
assert cal.calibration_id(bundle_a) != cal.calibration_id(bundle_b)
|
||||
assert cal.calibration_id(bundle_a).startswith("sha256:")
|
||||
|
||||
def test_load_bundle_rejects_missing_keys(self, tmp_path):
|
||||
path = tmp_path / "bad.json"
|
||||
path.write_text('{"camera_intrinsics": {}}', encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="missing key"):
|
||||
cal.load_bundle(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keypoint transform: image -> room-frame bearing rays (projective alignment)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKeypointTransform:
|
||||
PERSON_POINTS = np.array([
|
||||
[1.2, 1.5, 1.7], # head height
|
||||
[1.1, 1.5, 1.4], # shoulder
|
||||
[1.3, 1.6, 0.9], # hip
|
||||
[1.2, 1.5, 0.1], # ankle
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("dist", [DIST_ZERO, DIST_MILD], ids=["no-distortion", "mild-distortion"])
|
||||
def test_rays_pass_through_original_points(self, scene, dist):
|
||||
r_gt, c_gt, _, _ = scene
|
||||
img = project_room_points(self.PERSON_POINTS, r_gt, c_gt, dist=dist)
|
||||
kps_norm = (img / np.array([IMG_W, IMG_H])).tolist()
|
||||
|
||||
ctx = cal.CalibrationContext(make_bundle(r_gt, c_gt, dist=dist), IMG_W, IMG_H)
|
||||
origin, rays = ctx.transform_keypoints(kps_norm)
|
||||
|
||||
np.testing.assert_allclose(origin, c_gt, atol=1e-9)
|
||||
np.testing.assert_allclose(np.linalg.norm(rays, axis=1), 1.0, atol=1e-9)
|
||||
for point, ray in zip(self.PERSON_POINTS, rays):
|
||||
v = point - origin
|
||||
# Distance from the true 3D point to the recovered ray ~ 0, and
|
||||
# the point sits in FRONT of the camera along the ray.
|
||||
dist_to_ray = np.linalg.norm(v - np.dot(v, ray) * ray)
|
||||
assert dist_to_ray < 1e-4
|
||||
assert np.dot(v, ray) > 0
|
||||
|
||||
def test_resolution_scaling(self, scene):
|
||||
# Collection camera runs 640x360 while the bundle was made at
|
||||
# 1280x720 -- normalized keypoints must land on the same rays.
|
||||
r_gt, c_gt, _, _ = scene
|
||||
img = project_room_points(self.PERSON_POINTS, r_gt, c_gt)
|
||||
kps_norm = (img / np.array([IMG_W, IMG_H])).tolist()
|
||||
|
||||
ctx = cal.CalibrationContext(make_bundle(r_gt, c_gt), 640, 360)
|
||||
origin, rays = ctx.transform_keypoints(kps_norm)
|
||||
for point, ray in zip(self.PERSON_POINTS, rays):
|
||||
v = point - origin
|
||||
assert np.linalg.norm(v - np.dot(v, ray) * ray) < 1e-4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# collect-ground-truth record path (import-level; no camera loop)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRecordAugmentation:
|
||||
LEGACY_RECORD = {
|
||||
"ts_ns": 1775300000000000000,
|
||||
"keypoints": [[0.45, 0.12]] * 17,
|
||||
"confidence": 0.92,
|
||||
"n_visible": 14,
|
||||
"n_persons": 1,
|
||||
}
|
||||
|
||||
def test_no_calibration_is_byte_identical(self):
|
||||
# The collector's no---calibration path must emit exactly the
|
||||
# original ADR-079 JSONL line (back-compat guarantee).
|
||||
record = json.loads(json.dumps(self.LEGACY_RECORD))
|
||||
before = json.dumps(record)
|
||||
out = cal.augment_record(record, None)
|
||||
assert out is record
|
||||
assert json.dumps(out) == before
|
||||
assert set(out.keys()) == {"ts_ns", "keypoints", "confidence",
|
||||
"n_visible", "n_persons"}
|
||||
|
||||
def test_calibrated_record_gains_room_fields(self, scene):
|
||||
r_gt, c_gt, _, _ = scene
|
||||
bundle = make_bundle(r_gt, c_gt)
|
||||
ctx = cal.CalibrationContext(bundle, IMG_W, IMG_H)
|
||||
|
||||
record = json.loads(json.dumps(self.LEGACY_RECORD))
|
||||
out = cal.augment_record(record, ctx)
|
||||
|
||||
# Raw image coords preserved untouched; room representation added.
|
||||
assert out["keypoints"] == self.LEGACY_RECORD["keypoints"]
|
||||
assert len(out["keypoints_room"]) == 17
|
||||
assert all(len(ray) == 3 for ray in out["keypoints_room"])
|
||||
assert out["calibration_id"] == cal.calibration_id(bundle)
|
||||
assert out["transceiver_geometry"] == bundle["transceiver_geometry"]
|
||||
assert len(out["camera_origin_room"]) == 3
|
||||
json.dumps(out) # remains JSONL-serializable
|
||||
|
||||
def test_empty_keypoints_record(self, scene):
|
||||
r_gt, c_gt, _, _ = scene
|
||||
ctx = cal.CalibrationContext(make_bundle(r_gt, c_gt), IMG_W, IMG_H)
|
||||
record = {"ts_ns": 1, "keypoints": [], "confidence": 0.0,
|
||||
"n_visible": 0, "n_persons": 0}
|
||||
out = cal.augment_record(record, ctx)
|
||||
assert out["keypoints_room"] == []
|
||||
assert "calibration_id" in out
|
||||
Generated
+12
-12
@@ -7328,9 +7328,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-attention"
|
||||
version = "2.0.4"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb4233c1cecd0ea826d95b787065b398489328885042247ff5ffcbb774e864ff"
|
||||
checksum = "a92e8e456458188d04aee946579aa7cf96d7b8f276cbf6094532b2c3f6d8cc0b"
|
||||
dependencies = [
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
@@ -7395,14 +7395,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-gnn"
|
||||
version = "2.0.5"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e17c1cf1ff3380026b299ff3c1ba3a5685c3d8d54700e6ab0b585b6cec21d7b"
|
||||
checksum = "a251f9ced8d3231395d922369edc803ef0fc513c7776128f7b4ef21f20dd1f4b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"dashmap",
|
||||
"libc",
|
||||
"ndarray 0.16.1",
|
||||
"ndarray 0.17.2",
|
||||
"parking_lot",
|
||||
"rand 0.8.5",
|
||||
"rand_distr 0.4.3",
|
||||
@@ -7415,9 +7415,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-mincut"
|
||||
version = "2.0.4"
|
||||
version = "2.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d62e10cbb7d80b1e2b72d55c1e3eb7f0c4c5e3f31984bc3baa9b7a02700741e"
|
||||
checksum = "d60947433f740d0f589a2911d7b72a02e07a916e7257e478b14386f0ff068fb7"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"crossbeam",
|
||||
@@ -7437,9 +7437,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-solver"
|
||||
version = "2.0.4"
|
||||
version = "2.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce69cbde4ee5747281edb1d987a8292940397723924262b6218fc19022cbf687"
|
||||
checksum = "9be7c4f61940ae8b451f88b9a629a08ee8ee5c8e6b00ab96ca10ecf59e70f558"
|
||||
dependencies = [
|
||||
"dashmap",
|
||||
"getrandom 0.2.17",
|
||||
@@ -11041,7 +11041,7 @@ version = "0.3.1"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"criterion",
|
||||
"ruvector-attention 2.0.4",
|
||||
"ruvector-attention 2.1.0",
|
||||
"ruvector-attn-mincut",
|
||||
"ruvector-core",
|
||||
"ruvector-crv",
|
||||
@@ -11103,7 +11103,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"proptest",
|
||||
"rustfft",
|
||||
"ruvector-attention 2.0.4",
|
||||
"ruvector-attention 2.1.0",
|
||||
"ruvector-attn-mincut",
|
||||
"ruvector-mincut",
|
||||
"ruvector-solver",
|
||||
@@ -11134,7 +11134,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"petgraph",
|
||||
"proptest",
|
||||
"ruvector-attention 2.0.4",
|
||||
"ruvector-attention 2.1.0",
|
||||
"ruvector-attn-mincut",
|
||||
"ruvector-mincut",
|
||||
"ruvector-solver",
|
||||
|
||||
+6
-5
@@ -187,15 +187,16 @@ midstreamer-temporal-compare = "0.2"
|
||||
midstreamer-attractor = "0.2"
|
||||
|
||||
# ruvector integration (published on crates.io)
|
||||
# Vendored at v2.1.0 in vendor/ruvector; using crates.io versions until published.
|
||||
# Vendored at origin/main (a083bd77f) in vendor/ruvector; using crates.io versions
|
||||
# until published. Bumps per ADR-152 §2.6 (2026-06-10 vendor sync survey).
|
||||
ruvector-core = "2.2.0"
|
||||
ruvector-mincut = "2.0.4"
|
||||
ruvector-mincut = "2.0.6"
|
||||
ruvector-attn-mincut = "2.0.4"
|
||||
ruvector-temporal-tensor = "2.0.6"
|
||||
ruvector-solver = "2.0.4"
|
||||
ruvector-attention = "2.0.4"
|
||||
ruvector-solver = "2.0.6"
|
||||
ruvector-attention = "2.1.0"
|
||||
ruvector-crv = "0.1.1"
|
||||
ruvector-gnn = { version = "2.0.5", default-features = false }
|
||||
ruvector-gnn = { version = "2.2.0", default-features = false }
|
||||
|
||||
|
||||
# Internal crates
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::geometry::NodeGeometry;
|
||||
|
||||
/// Coarse posture an anchor establishes.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Posture {
|
||||
@@ -96,9 +98,7 @@ impl AnchorLabel {
|
||||
/// Suggested capture duration (seconds).
|
||||
pub fn duration_s(&self) -> u32 {
|
||||
match self {
|
||||
AnchorLabel::BreatheSlow
|
||||
| AnchorLabel::BreatheNormal
|
||||
| AnchorLabel::SleepPosture => 30,
|
||||
AnchorLabel::BreatheSlow | AnchorLabel::BreatheNormal | AnchorLabel::SleepPosture => 30,
|
||||
_ => 20,
|
||||
}
|
||||
}
|
||||
@@ -165,6 +165,17 @@ pub enum EnrollmentEvent {
|
||||
/// The accepted anchor.
|
||||
anchor: Anchor,
|
||||
},
|
||||
/// Transceiver geometry recorded for the session's nodes (ADR-152 §2.1.1).
|
||||
/// Typically appended right after `Started`; a later event supersedes an
|
||||
/// earlier one (latest wins), so a geometry correction is an append, not a
|
||||
/// rewrite. Sessions persisted before this variant existed replay cleanly —
|
||||
/// the variant is additive to the externally-tagged event encoding.
|
||||
GeometryRecorded {
|
||||
/// Per-node geometry records.
|
||||
geometry: Vec<NodeGeometry>,
|
||||
/// Unix seconds.
|
||||
at: i64,
|
||||
},
|
||||
/// An anchor failed the gate (re-prompt).
|
||||
AnchorRejected {
|
||||
/// Which anchor.
|
||||
@@ -230,6 +241,21 @@ impl EnrollmentSession {
|
||||
out
|
||||
}
|
||||
|
||||
/// Record the session's transceiver geometry (ADR-152 §2.1.1) — appends a
|
||||
/// [`EnrollmentEvent::GeometryRecorded`] event; the latest recording wins.
|
||||
pub fn record_geometry(&mut self, geometry: Vec<NodeGeometry>, at: i64) {
|
||||
self.apply(EnrollmentEvent::GeometryRecorded { geometry, at });
|
||||
}
|
||||
|
||||
/// The geometry snapshot in effect (latest `GeometryRecorded` event), if
|
||||
/// any was recorded. Derived from the event log, never stored separately.
|
||||
pub fn geometry(&self) -> Option<&[NodeGeometry]> {
|
||||
self.events.iter().rev().find_map(|ev| match ev {
|
||||
EnrollmentEvent::GeometryRecorded { geometry, .. } => Some(geometry.as_slice()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
/// The next anchor in the canonical sequence not yet accepted, if any.
|
||||
pub fn next_anchor(&self) -> Option<AnchorLabel> {
|
||||
let accepted = self.accepted_anchors();
|
||||
@@ -241,10 +267,7 @@ impl EnrollmentSession {
|
||||
|
||||
/// `(accepted, total)` progress.
|
||||
pub fn progress(&self) -> (usize, usize) {
|
||||
(
|
||||
self.accepted_anchors().len(),
|
||||
AnchorLabel::SEQUENCE.len(),
|
||||
)
|
||||
(self.accepted_anchors().len(), AnchorLabel::SEQUENCE.len())
|
||||
}
|
||||
|
||||
/// Whether every anchor in the sequence has been accepted.
|
||||
@@ -340,6 +363,47 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometry_recorded_latest_wins_and_roundtrips() {
|
||||
let mut s = EnrollmentSession::new("r", "b", 0);
|
||||
assert!(s.geometry().is_none(), "no geometry before recording");
|
||||
|
||||
s.record_geometry(vec![NodeGeometry::unknown(1)], 5);
|
||||
let corrected = vec![
|
||||
NodeGeometry::new(1, "tape-measure").with_position(0.0, 0.0, 1.0),
|
||||
NodeGeometry::new(2, "tape-measure")
|
||||
.with_position(3.0, 0.0, 1.0)
|
||||
.with_distance(1, 3.0),
|
||||
];
|
||||
s.record_geometry(corrected.clone(), 10);
|
||||
|
||||
// Latest recording wins, derived from the event log.
|
||||
assert_eq!(s.geometry(), Some(corrected.as_slice()));
|
||||
|
||||
// The whole session (incl. geometry events) survives persistence.
|
||||
let json = serde_json::to_string(&s).unwrap();
|
||||
let back: EnrollmentSession = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back.geometry(), Some(corrected.as_slice()));
|
||||
assert_eq!(back.events.len(), s.events.len());
|
||||
}
|
||||
|
||||
/// Sessions persisted BEFORE the GeometryRecorded variant existed must
|
||||
/// deserialize cleanly and report no geometry (ADR-152 schema-compat rule).
|
||||
#[test]
|
||||
fn pre_geometry_session_json_loads() {
|
||||
let old_json = r#"{
|
||||
"room_id": "r",
|
||||
"baseline_id": "b",
|
||||
"events": [
|
||||
{"Started": {"room_id": "r", "baseline_id": "b", "at": 0}},
|
||||
{"AnchorRejected": {"label": "sit", "reason": "no person", "at": 1}}
|
||||
]
|
||||
}"#;
|
||||
let s: EnrollmentSession = serde_json::from_str(old_json).unwrap();
|
||||
assert!(s.geometry().is_none());
|
||||
assert_eq!(s.next_anchor(), Some(AnchorLabel::Empty));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn posture_mapping() {
|
||||
assert_eq!(AnchorLabel::StandStill.posture(), Some(Posture::Standing));
|
||||
|
||||
@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{CalibrationError, Result};
|
||||
use crate::extract::AnchorFeature;
|
||||
use crate::geometry::NodeGeometry;
|
||||
use crate::specialist::{
|
||||
AnomalySpecialist, BreathingSpecialist, HeartbeatSpecialist, PostureSpecialist,
|
||||
PresenceSpecialist, RestlessnessSpecialist, SpecialistKind,
|
||||
@@ -26,6 +27,13 @@ pub struct SpecialistBank {
|
||||
pub trained_at_unix_s: i64,
|
||||
/// Number of anchors used.
|
||||
pub anchor_count: usize,
|
||||
/// Transceiver geometry snapshot the bank was trained under (ADR-152
|
||||
/// §2.1.1). Empty both for banks persisted before geometry existed (serde
|
||||
/// default — same pattern as `PresenceSpecialist::mean_dist_threshold`) and
|
||||
/// for enrollments where no geometry was recorded. Statistical specialists
|
||||
/// ignore it; the ADR-151 P6 LoRA heads will consume it (ADR-152 §2.1.2).
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub geometry: Vec<NodeGeometry>,
|
||||
|
||||
/// Presence gate (requires the `empty` + an occupied anchor).
|
||||
pub presence: Option<PresenceSpecialist>,
|
||||
@@ -65,6 +73,7 @@ impl SpecialistBank {
|
||||
baseline_id: baseline_id.into(),
|
||||
trained_at_unix_s: at_unix_s,
|
||||
anchor_count: anchors.len(),
|
||||
geometry: Vec::new(),
|
||||
presence: PresenceSpecialist::train(anchors),
|
||||
posture: PostureSpecialist::train(anchors),
|
||||
breathing: BreathingSpecialist::default(),
|
||||
@@ -74,6 +83,22 @@ impl SpecialistBank {
|
||||
})
|
||||
}
|
||||
|
||||
/// Attach the enrollment's transceiver-geometry snapshot (ADR-152 §2.1.1),
|
||||
/// builder style — typically `EnrollmentSession::geometry()` at train time.
|
||||
pub fn with_geometry(mut self, geometry: Vec<NodeGeometry>) -> Self {
|
||||
self.geometry = geometry;
|
||||
self
|
||||
}
|
||||
|
||||
/// The fixed-length geometry embedding of the bank's snapshot (ADR-152
|
||||
/// §2.1.2) — the conditioning vector the ADR-151 P6 LoRA heads concatenate
|
||||
/// with the backbone embedding. Derived on demand from [`Self::geometry`]
|
||||
/// (it is a pure function of the snapshot), so it adds no schema surface;
|
||||
/// a geometry-free bank yields the well-defined all-zero embedding.
|
||||
pub fn geometry_embedding(&self) -> crate::geometry_embedding::GeometryEmbedding {
|
||||
crate::geometry_embedding::GeometryEmbedding::from_nodes(&self.geometry)
|
||||
}
|
||||
|
||||
/// `true` if the bank was trained against a different baseline (it is STALE).
|
||||
pub fn is_stale(&self, current_baseline_id: &str) -> bool {
|
||||
self.baseline_id != current_baseline_id
|
||||
@@ -178,6 +203,70 @@ mod tests {
|
||||
assert_eq!(back.anchor_count, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometry_snapshot_roundtrips() {
|
||||
let geometry = vec![
|
||||
NodeGeometry::new(1, "tape-measure").with_position(0.0, 0.0, 1.0),
|
||||
NodeGeometry::unknown(2),
|
||||
];
|
||||
let bank = SpecialistBank::train("r", "base-1", &full_anchors(), 1000)
|
||||
.unwrap()
|
||||
.with_geometry(geometry.clone());
|
||||
let json = bank.to_json().unwrap();
|
||||
let back = SpecialistBank::from_json(&json).unwrap();
|
||||
assert_eq!(back.geometry, geometry);
|
||||
}
|
||||
|
||||
/// ADR-152 §2.1.2: the embedding is derived from the snapshot — present
|
||||
/// geometry conditions it, absent geometry yields the all-zero vector.
|
||||
#[test]
|
||||
fn geometry_embedding_derives_from_snapshot() {
|
||||
let bare = SpecialistBank::train("r", "base-1", &full_anchors(), 1000).unwrap();
|
||||
assert_eq!(
|
||||
bare.geometry_embedding(),
|
||||
crate::geometry_embedding::GeometryEmbedding::default(),
|
||||
"no geometry → all-zero embedding"
|
||||
);
|
||||
|
||||
let geometry = vec![
|
||||
NodeGeometry::new(1, "tape-measure").with_position(0.0, 0.0, 1.0),
|
||||
NodeGeometry::new(2, "tape-measure").with_position(3.0, 0.0, 1.0),
|
||||
];
|
||||
let bank = bare.with_geometry(geometry.clone());
|
||||
let emb = bank.geometry_embedding();
|
||||
assert_eq!(
|
||||
emb,
|
||||
crate::geometry_embedding::GeometryEmbedding::from_nodes(&geometry),
|
||||
"embedding is a pure function of the snapshot"
|
||||
);
|
||||
assert!(emb.as_slice().iter().any(|&x| x != 0.0));
|
||||
}
|
||||
|
||||
/// ADR-152 schema-compat fixture: bank JSON persisted BEFORE the geometry
|
||||
/// field existed (captured from the pre-ADR-152 serializer shape) must
|
||||
/// deserialize cleanly with an empty geometry snapshot.
|
||||
#[test]
|
||||
fn pre_geometry_bank_json_loads() {
|
||||
let old_json = r#"{
|
||||
"room_id": "living-room",
|
||||
"baseline_id": "base-1",
|
||||
"trained_at_unix_s": 1000,
|
||||
"anchor_count": 2,
|
||||
"presence": {"threshold": 5.5, "occupied_var": 10.0},
|
||||
"posture": null,
|
||||
"breathing": {"min_score": 0.0},
|
||||
"heartbeat": {"min_score": 0.0},
|
||||
"restlessness": null,
|
||||
"anomaly": null
|
||||
}"#;
|
||||
let bank = SpecialistBank::from_json(old_json).unwrap();
|
||||
assert!(bank.geometry.is_empty(), "old banks carry no geometry");
|
||||
assert_eq!(bank.room_id, "living-room");
|
||||
assert!(bank.presence.is_some());
|
||||
// And a geometry-free bank serializes without the field (old shape).
|
||||
assert!(!bank.to_json().unwrap().contains("geometry"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn staleness() {
|
||||
let bank = SpecialistBank::train("r", "base-1", &full_anchors(), 1000).unwrap();
|
||||
|
||||
@@ -203,13 +203,13 @@ impl AnchorRecorder {
|
||||
|
||||
/// Evaluate the capture against the gate and produce an `Anchor` (accepted
|
||||
/// or not) plus a rejection reason.
|
||||
pub fn finalize(
|
||||
&self,
|
||||
gate: &AnchorQualityGate,
|
||||
at_unix_s: i64,
|
||||
) -> (Anchor, Option<String>) {
|
||||
let (quality, reason) =
|
||||
gate.evaluate(self.label, self.presence_z(), self.motion_rate(), self.frames);
|
||||
pub fn finalize(&self, gate: &AnchorQualityGate, at_unix_s: i64) -> (Anchor, Option<String>) {
|
||||
let (quality, reason) = gate.evaluate(
|
||||
self.label,
|
||||
self.presence_z(),
|
||||
self.motion_rate(),
|
||||
self.frames,
|
||||
);
|
||||
(
|
||||
Anchor {
|
||||
label: self.label,
|
||||
@@ -255,7 +255,13 @@ mod tests {
|
||||
/// Alternating z (every frame's |Δz| exceeds Z_DELTA_MOTION ⇒ all motion).
|
||||
fn run_jittery(label: AnchorLabel, z: f32, n: usize) -> (Anchor, Option<String>) {
|
||||
let zs: Vec<f32> = (0..n)
|
||||
.map(|i| if i % 2 == 0 { z } else { z + 2.0 * Z_DELTA_MOTION })
|
||||
.map(|i| {
|
||||
if i % 2 == 0 {
|
||||
z
|
||||
} else {
|
||||
z + 2.0 * Z_DELTA_MOTION
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
run_series(label, &zs)
|
||||
}
|
||||
@@ -268,7 +274,10 @@ mod tests {
|
||||
let (a, reason) = run_still(AnchorLabel::StandStill, 3.0, 400);
|
||||
assert!(a.quality.accepted, "z-band squeeze is back: {reason:?}");
|
||||
assert!(reason.is_none());
|
||||
assert!(a.quality.motion_rate < 0.05, "flat z-series must read still");
|
||||
assert!(
|
||||
a.quality.motion_rate < 0.05,
|
||||
"flat z-series must read still"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -301,7 +310,11 @@ mod tests {
|
||||
let mut r = AnchorRecorder::new(AnchorLabel::LieDown);
|
||||
for i in 0..400 {
|
||||
let mut s = score(1.8);
|
||||
s.phase_drift_median = if i % 2 == 0 { 0.0 } else { PHASE_DELTA_MOTION * 1.5 };
|
||||
s.phase_drift_median = if i % 2 == 0 {
|
||||
0.0
|
||||
} else {
|
||||
PHASE_DELTA_MOTION * 1.5
|
||||
};
|
||||
r.record_score(&s);
|
||||
}
|
||||
let (a, reason) = r.finalize(&AnchorQualityGate::default(), 100);
|
||||
|
||||
@@ -58,7 +58,13 @@ impl Features {
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
[self.mean, self.variance, self.motion, breathing_hz, heart_hz]
|
||||
[
|
||||
self.mean,
|
||||
self.variance,
|
||||
self.motion,
|
||||
breathing_hz,
|
||||
heart_hz,
|
||||
]
|
||||
}
|
||||
|
||||
/// Squared Euclidean distance between two embeddings.
|
||||
@@ -85,8 +91,7 @@ impl Features {
|
||||
};
|
||||
}
|
||||
let mean = series.iter().copied().sum::<f32>() / n as f32;
|
||||
let variance =
|
||||
series.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n as f32;
|
||||
let variance = series.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n as f32;
|
||||
let motion = if n > 1 {
|
||||
series.windows(2).map(|w| (w[1] - w[0]).abs()).sum::<f32>() / (n - 1) as f32
|
||||
} else {
|
||||
@@ -234,8 +239,12 @@ mod tests {
|
||||
#[test]
|
||||
fn motion_distinguishes_still_from_noisy() {
|
||||
let still = vec![1.0f32; 200];
|
||||
let noisy: Vec<f32> = (0..200).map(|i| if i % 2 == 0 { 0.0 } else { 5.0 }).collect();
|
||||
assert!(Features::from_series(&still, 15.0).motion < Features::from_series(&noisy, 15.0).motion);
|
||||
let noisy: Vec<f32> = (0..200)
|
||||
.map(|i| if i % 2 == 0 { 0.0 } else { 5.0 })
|
||||
.collect();
|
||||
assert!(
|
||||
Features::from_series(&still, 15.0).motion < Features::from_series(&noisy, 15.0).motion
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
//! Transceiver-geometry records (ADR-152 §2.1.1, extends ADR-151 Stage 2).
|
||||
//!
|
||||
//! PerceptAlign (ADR-152 F1) diagnosed "coordinate overfitting": pose heads
|
||||
//! trained without an explicit layout model memorise the deployment-specific
|
||||
//! transceiver geometry and break in unseen rooms. The first, cheap half of
|
||||
//! the fix is to *record* the geometry at enrollment so every specialist bank
|
||||
//! knows the layout it was trained under.
|
||||
//!
|
||||
//! This module is the record only. The learned geometry *embeddings* that
|
||||
//! condition specialist heads (ADR-152 §2.1.2) are out of scope until the
|
||||
//! ADR-151 P6 LoRA heads exist — statistical specialists ignore geometry.
|
||||
//!
|
||||
//! Every field is optional **by design**: geometry is captured when the
|
||||
//! operator knows it (tape measure, checkerboard calibration, installer
|
||||
//! floor plan) and omitted when they don't. An all-unknown record is still
|
||||
//! useful — it pins down *which* nodes existed and that geometry was not
|
||||
//! measured, rather than leaving the question open.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Estimated node position in the room frame (meters).
|
||||
///
|
||||
/// The room frame is whatever frame the recording `method` defines (e.g. a
|
||||
/// tape-measure origin at a room corner, or the shared 3D frame of the
|
||||
/// two-checkerboard alignment, ADR-152 §2.1.3). Consistency *within* one
|
||||
/// enrollment is what matters; there is no global frame.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PositionEstimate {
|
||||
/// X coordinate (meters).
|
||||
pub x_m: f32,
|
||||
/// Y coordinate (meters).
|
||||
pub y_m: f32,
|
||||
/// Z coordinate / height (meters).
|
||||
pub z_m: f32,
|
||||
}
|
||||
|
||||
/// Antenna boresight orientation (radians, room frame).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AntennaOrientation {
|
||||
/// Azimuth from the room frame's +X axis, counter-clockwise (radians).
|
||||
pub azimuth_rad: f32,
|
||||
/// Elevation above the horizontal plane (radians).
|
||||
pub elevation_rad: f32,
|
||||
}
|
||||
|
||||
fn unknown_method() -> String {
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
/// Per-node transceiver geometry recorded at enrollment (ADR-152 §2.1.1).
|
||||
///
|
||||
/// Stored in the [`EnrollmentSession`](crate::EnrollmentSession) event log and
|
||||
/// snapshotted into the [`SpecialistBank`](crate::SpecialistBank), so a bank
|
||||
/// always carries the layout it was trained under. Schema-versioned: banks and
|
||||
/// sessions persisted before this record existed deserialize with no geometry
|
||||
/// (serde defaults), same pattern as `PresenceSpecialist::mean_dist_threshold`.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeGeometry {
|
||||
/// Node this record describes (same id space as the multistatic fusion).
|
||||
pub node_id: u8,
|
||||
/// Estimated position, if measured.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub position: Option<PositionEstimate>,
|
||||
/// Antenna orientation, if measured.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub orientation: Option<AntennaOrientation>,
|
||||
/// Known distances to other nodes (node_id → meters). Empty = not measured.
|
||||
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
|
||||
pub distances_m: BTreeMap<u8, f32>,
|
||||
/// How the geometry was obtained — free-form provenance, e.g.
|
||||
/// `"tape-measure"`, `"checkerboard"`, `"floor-plan"`, `"unknown"`.
|
||||
#[serde(default = "unknown_method")]
|
||||
pub method: String,
|
||||
}
|
||||
|
||||
impl NodeGeometry {
|
||||
/// A record with everything unknown except the node id.
|
||||
pub fn unknown(node_id: u8) -> Self {
|
||||
Self::new(node_id, "unknown")
|
||||
}
|
||||
|
||||
/// A record with no measurements yet, tagged with its provenance method.
|
||||
pub fn new(node_id: u8, method: impl Into<String>) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
position: None,
|
||||
orientation: None,
|
||||
distances_m: BTreeMap::new(),
|
||||
method: method.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the position estimate (builder style).
|
||||
pub fn with_position(mut self, x_m: f32, y_m: f32, z_m: f32) -> Self {
|
||||
self.position = Some(PositionEstimate { x_m, y_m, z_m });
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the antenna orientation (builder style).
|
||||
pub fn with_orientation(mut self, azimuth_rad: f32, elevation_rad: f32) -> Self {
|
||||
self.orientation = Some(AntennaOrientation {
|
||||
azimuth_rad,
|
||||
elevation_rad,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Record a known distance to another node (builder style).
|
||||
pub fn with_distance(mut self, other_node_id: u8, meters: f32) -> Self {
|
||||
self.distances_m.insert(other_node_id, meters);
|
||||
self
|
||||
}
|
||||
|
||||
/// `true` when nothing beyond the node id was measured.
|
||||
pub fn is_unmeasured(&self) -> bool {
|
||||
self.position.is_none() && self.orientation.is_none() && self.distances_m.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn full_record_roundtrips() {
|
||||
let g = NodeGeometry::new(1, "tape-measure")
|
||||
.with_position(0.5, 2.0, 1.2)
|
||||
.with_orientation(std::f32::consts::FRAC_PI_2, 0.0)
|
||||
.with_distance(2, 3.4);
|
||||
let json = serde_json::to_string(&g).unwrap();
|
||||
let back: NodeGeometry = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back, g);
|
||||
assert_eq!(back.distances_m.get(&2), Some(&3.4));
|
||||
assert!(!back.is_unmeasured());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_optional_empty_roundtrips() {
|
||||
let g = NodeGeometry::unknown(7);
|
||||
assert!(g.is_unmeasured());
|
||||
let json = serde_json::to_string(&g).unwrap();
|
||||
// Optional fields must be omitted, not serialized as null/empty.
|
||||
assert!(!json.contains("position"));
|
||||
assert!(!json.contains("orientation"));
|
||||
assert!(!json.contains("distances_m"));
|
||||
let back: NodeGeometry = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back, g);
|
||||
assert_eq!(back.method, "unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn minimal_json_defaults_cleanly() {
|
||||
// A record written by a producer that only knew the node id.
|
||||
let g: NodeGeometry = serde_json::from_str(r#"{"node_id":3}"#).unwrap();
|
||||
assert_eq!(g.node_id, 3);
|
||||
assert!(g.is_unmeasured());
|
||||
assert_eq!(g.method, "unknown");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,499 @@
|
||||
//! Geometry embedding — deterministic featurization of transceiver layout
|
||||
//! (ADR-152 §2.1.2, the second half of the PerceptAlign fix).
|
||||
//!
|
||||
//! §2.1.1 ([`geometry`](crate::geometry)) *records* the layout; this module
|
||||
//! turns that record into a fixed-length conditioning vector. PerceptAlign
|
||||
//! fuses transceiver-position embeddings with CSI features so pose heads stop
|
||||
//! memorising the deployment layout; transplanted to our per-room banks, the
|
||||
//! ADR-151 P6 LoRA heads will concatenate this vector with the backbone
|
||||
//! embedding. Statistical specialists (current) ignore it. The crate is pure
|
||||
//! Rust and edge-deployable (no torch/candle), so the "embedding" is **not a
|
||||
//! trained network** — it is a deterministic, well-conditioned featurization;
|
||||
//! the learned part (if any) lives in the head that consumes it.
|
||||
//!
|
||||
//! Properties, by construction: **fixed dimension** ([`GeometryEmbedding::DIM`]
|
||||
//! = 32) for any node count (designed for 1..=8; more nodes still aggregate,
|
||||
//! only the per-node flag slots truncate); **permutation-invariant** (nodes
|
||||
//! sorted by `node_id`; aggregates are order-free); and **total** — missing
|
||||
//! data degrades gracefully: an all-unknown layout (or empty slice) yields a
|
||||
//! well-defined vector, never `NaN`/`inf`; adversarial inputs (non-finite
|
||||
//! coordinates, absurd magnitudes) are treated as unmeasured.
|
||||
//!
|
||||
//! ## Slot layout (v1)
|
||||
//!
|
||||
//! Positions/distances are raw meters (room-scale values are already
|
||||
//! O(1)–O(10)); angles in radians; fractions in `[0, 1]`. Unmeasurable
|
||||
//! slots are `0.0`.
|
||||
//!
|
||||
//! | Slot | Content | Units / range |
|
||||
//! |-------|---------|----------------|
|
||||
//! | 0 | node count / 8 | `[0, 2]` (clamped; 8 nodes → 1.0) |
|
||||
//! | 1 | fraction of nodes with a position | `[0, 1]` |
|
||||
//! | 2 | fraction of nodes with an orientation | `[0, 1]` |
|
||||
//! | 3 | fraction of nodes with ≥1 measured inter-node distance | `[0, 1]` |
|
||||
//! | 4–6 | position centroid (x, y, z) | m, clamped ±[`MAX_COORD_M`] |
|
||||
//! | 7–9 | position std-dev per axis (x, y, z) | m, `[0,` [`MAX_COORD_M`]`]` |
|
||||
//! | 10–12 | pairwise position distance min / mean / max | m |
|
||||
//! | 13–15 | inter-node distance min / mean / max — measured `distances_m`, falling back to position-derived distance per pair | m |
|
||||
//! | 16 | measured-distance pair coverage (measured pairs / possible pairs) | `[0, 1]` |
|
||||
//! | 17–18 | azimuth circular mean resultant vector (cos, sin components) | `[-1, 1]` |
|
||||
//! | 19 | azimuth concentration (mean resultant length `R`; 1 = all boresights parallel) | `[0, 1]` |
|
||||
//! | 20 | mean elevation | rad, `[-π/2, π/2]` |
|
||||
//! | 21–22 | geometric diversity: eigenvalue ratios `λ2/λ1`, `λ3/λ1` of the position covariance — 0 = collinear/degenerate, →1 = isotropic spread (chosen over polygon area: defined for any node count, no 2-D planarity assumption) | `[0, 1]` |
|
||||
//! | 23 | dominant spread scale `sqrt(λ1)` | m |
|
||||
//! | 24–31 | per-node measurement flags, nodes sorted by `node_id`, rank `i` → slot `24+i` (first 8 nodes): `0` = no node at this rank, else `0.25` (node exists) `+0.25` (position) `+0.25` (orientation) `+0.25` (≥1 measured distance) | `{0}` ∪ `[0.25, 1]` |
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::geometry::NodeGeometry;
|
||||
|
||||
/// Coordinates / distances beyond this magnitude (meters) are treated as
|
||||
/// unmeasured — rooms are not kilometer-scale, and the guard keeps
|
||||
/// adversarial values from overflowing the covariance into `inf`.
|
||||
pub const MAX_COORD_M: f32 = 1_000.0;
|
||||
|
||||
/// Number of per-node flag slots (slots 24..32); designed node count 1..=8.
|
||||
const NODE_SLOTS: usize = 8;
|
||||
|
||||
fn schema_v1() -> u32 {
|
||||
GeometryEmbedding::SCHEMA_VERSION
|
||||
}
|
||||
|
||||
/// Fixed-length featurization of a room's transceiver layout (ADR-152 §2.1.2).
|
||||
///
|
||||
/// Computed deterministically from the [`NodeGeometry`] snapshot via
|
||||
/// [`GeometryEmbedding::from_nodes`]; the conditioning input the ADR-151 P6
|
||||
/// LoRA heads concatenate with the backbone embedding. Not stored in the bank
|
||||
/// — derive it via [`SpecialistBank::geometry_embedding`](crate::SpecialistBank::geometry_embedding)
|
||||
/// — but schema-versioned and serde-serializable (the `NodeGeometry` compat
|
||||
/// pattern) for callers that snapshot it alongside trained head weights.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct GeometryEmbedding {
|
||||
/// Slot-layout version; bump when the slot table changes meaning.
|
||||
#[serde(default = "schema_v1")]
|
||||
pub schema_version: u32,
|
||||
/// The embedding vector — see the module docs for the slot table.
|
||||
/// Invariant: every value is finite (never `NaN`/`inf`).
|
||||
pub values: [f32; GeometryEmbedding::DIM],
|
||||
}
|
||||
|
||||
impl Default for GeometryEmbedding {
|
||||
/// All slots zero — the embedding of an empty layout.
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
schema_version: Self::SCHEMA_VERSION,
|
||||
values: [0.0; Self::DIM],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GeometryEmbedding {
|
||||
/// Output dimension. Fixed regardless of node count.
|
||||
pub const DIM: usize = 32;
|
||||
|
||||
/// Current slot-layout version.
|
||||
pub const SCHEMA_VERSION: u32 = 1;
|
||||
|
||||
/// The embedding as a slice (always [`Self::DIM`] long).
|
||||
pub fn as_slice(&self) -> &[f32] {
|
||||
&self.values
|
||||
}
|
||||
|
||||
/// Compute the embedding from a geometry snapshot. Permutation-invariant
|
||||
/// (nodes are sorted by `node_id` internally) and total: any input —
|
||||
/// empty, all-unknown, non-finite — produces a fully finite vector.
|
||||
pub fn from_nodes(nodes: &[NodeGeometry]) -> Self {
|
||||
let mut v = [0.0f32; Self::DIM];
|
||||
|
||||
// Permutation invariance: order by node_id before per-node slots.
|
||||
let mut sorted: Vec<&NodeGeometry> = nodes.iter().collect();
|
||||
sorted.sort_by_key(|g| g.node_id);
|
||||
let n = sorted.len();
|
||||
if n == 0 {
|
||||
return Self::default();
|
||||
}
|
||||
|
||||
// Sanitized views: a measurement with non-finite or absurd components
|
||||
// counts as not taken at all.
|
||||
let positions: Vec<Option<[f32; 3]>> = sorted.iter().map(|g| valid_position(g)).collect();
|
||||
let orientations: Vec<Option<(f32, f32)>> =
|
||||
sorted.iter().map(|g| valid_orientation(g)).collect();
|
||||
let measured = measured_pairs(&sorted);
|
||||
let node_has_dist = |id: u8| measured.keys().any(|&(a, b)| a == id || b == id);
|
||||
let has_dist: Vec<bool> = sorted.iter().map(|g| node_has_dist(g.node_id)).collect();
|
||||
|
||||
// Slots 0–3: node count + measurement-presence fractions.
|
||||
let nf = n as f32;
|
||||
v[0] = (nf / NODE_SLOTS as f32).min(2.0);
|
||||
v[1] = positions.iter().flatten().count() as f32 / nf;
|
||||
v[2] = orientations.iter().flatten().count() as f32 / nf;
|
||||
v[3] = has_dist.iter().filter(|&&d| d).count() as f32 / nf;
|
||||
|
||||
// Slots 4–9: centroid + per-axis std of the known positions.
|
||||
let known: Vec<[f32; 3]> = positions.iter().flatten().copied().collect();
|
||||
if !known.is_empty() {
|
||||
let kf = known.len() as f32;
|
||||
let mut centroid = [0.0f32; 3];
|
||||
for p in &known {
|
||||
for (c, x) in centroid.iter_mut().zip(p) {
|
||||
*c += x / kf;
|
||||
}
|
||||
}
|
||||
for axis in 0..3 {
|
||||
v[4 + axis] = clamp_m(centroid[axis]);
|
||||
let mut var = 0.0;
|
||||
for p in &known {
|
||||
var += (p[axis] - centroid[axis]).powi(2) / kf;
|
||||
}
|
||||
v[7 + axis] = clamp_m(var.max(0.0).sqrt());
|
||||
}
|
||||
|
||||
// Slots 10–12: pairwise position distance stats.
|
||||
let mut dists = Vec::new();
|
||||
for i in 0..known.len() {
|
||||
for j in (i + 1)..known.len() {
|
||||
dists.push(euclidean(&known[i], &known[j]));
|
||||
}
|
||||
}
|
||||
write_min_mean_max(&mut v, 10, &dists);
|
||||
|
||||
// Slots 21–23: geometric diversity from the position covariance
|
||||
// eigenstructure (see module docs for why over polygon area).
|
||||
let (l1, l2, l3) = covariance_eigenvalues(&known, ¢roid);
|
||||
if l1 > f32::EPSILON {
|
||||
v[21] = (l2 / l1).clamp(0.0, 1.0);
|
||||
v[22] = (l3 / l1).clamp(0.0, 1.0);
|
||||
}
|
||||
v[23] = clamp_m(l1.max(0.0).sqrt());
|
||||
}
|
||||
|
||||
// Slots 13–16: inter-node distances — measured first, position fallback.
|
||||
let mut inter = Vec::new();
|
||||
for i in 0..n {
|
||||
for j in (i + 1)..n {
|
||||
let key = pair_key(sorted[i].node_id, sorted[j].node_id);
|
||||
if let Some(&d) = measured.get(&key) {
|
||||
inter.push(d);
|
||||
} else if let (Some(a), Some(b)) = (&positions[i], &positions[j]) {
|
||||
inter.push(euclidean(a, b));
|
||||
}
|
||||
}
|
||||
}
|
||||
write_min_mean_max(&mut v, 13, &inter);
|
||||
let possible_pairs = n * n.saturating_sub(1) / 2;
|
||||
if possible_pairs > 0 {
|
||||
v[16] = (measured.len() as f32 / possible_pairs as f32).clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
// Slots 17–20: orientation statistics (circular mean of azimuth).
|
||||
let known_orient: Vec<(f32, f32)> = orientations.iter().flatten().copied().collect();
|
||||
if !known_orient.is_empty() {
|
||||
let of = known_orient.len() as f32;
|
||||
let c = known_orient.iter().map(|(az, _)| az.cos()).sum::<f32>() / of;
|
||||
let s = known_orient.iter().map(|(az, _)| az.sin()).sum::<f32>() / of;
|
||||
v[17] = c.clamp(-1.0, 1.0);
|
||||
v[18] = s.clamp(-1.0, 1.0);
|
||||
v[19] = (c * c + s * s).sqrt().clamp(0.0, 1.0);
|
||||
let el = known_orient.iter().map(|(_, e)| e).sum::<f32>() / of;
|
||||
v[20] = el.clamp(-std::f32::consts::FRAC_PI_2, std::f32::consts::FRAC_PI_2);
|
||||
}
|
||||
|
||||
// Slots 24–31: per-node measurement flags (first NODE_SLOTS by id).
|
||||
for i in 0..n.min(NODE_SLOTS) {
|
||||
v[24 + i] = 0.25
|
||||
+ 0.25 * f32::from(positions[i].is_some() as u8)
|
||||
+ 0.25 * f32::from(orientations[i].is_some() as u8)
|
||||
+ 0.25 * f32::from(has_dist[i] as u8);
|
||||
}
|
||||
|
||||
// The finite invariant must hold whatever happened above.
|
||||
for x in &mut v {
|
||||
if !x.is_finite() {
|
||||
*x = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
schema_version: Self::SCHEMA_VERSION,
|
||||
values: v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A position whose components are all finite and room-scale, else `None`.
|
||||
fn valid_position(g: &NodeGeometry) -> Option<[f32; 3]> {
|
||||
let p = g.position?;
|
||||
let ok = |c: f32| c.is_finite() && c.abs() <= MAX_COORD_M;
|
||||
(ok(p.x_m) && ok(p.y_m) && ok(p.z_m)).then_some([p.x_m, p.y_m, p.z_m])
|
||||
}
|
||||
|
||||
/// An orientation whose angles are both finite, else `None`.
|
||||
fn valid_orientation(g: &NodeGeometry) -> Option<(f32, f32)> {
|
||||
let o = g.orientation?;
|
||||
let ok = o.azimuth_rad.is_finite() && o.elevation_rad.is_finite();
|
||||
ok.then_some((o.azimuth_rad, o.elevation_rad))
|
||||
}
|
||||
|
||||
/// Canonical unordered pair key.
|
||||
fn pair_key(a: u8, b: u8) -> (u8, u8) {
|
||||
(a.min(b), a.max(b))
|
||||
}
|
||||
|
||||
/// Valid measured distances between *enrolled* nodes, deduplicated to
|
||||
/// unordered pairs (both directions recorded → averaged); distances to
|
||||
/// non-enrolled node ids are ignored.
|
||||
fn measured_pairs(sorted: &[&NodeGeometry]) -> BTreeMap<(u8, u8), f32> {
|
||||
let ids: Vec<u8> = sorted.iter().map(|g| g.node_id).collect();
|
||||
let mut sums: BTreeMap<(u8, u8), (f32, u32)> = BTreeMap::new();
|
||||
for g in sorted {
|
||||
for (&other, &d) in &g.distances_m {
|
||||
let pair_ok = other != g.node_id && ids.contains(&other);
|
||||
if pair_ok && d.is_finite() && d > 0.0 && d <= MAX_COORD_M {
|
||||
let e = sums.entry(pair_key(g.node_id, other)).or_insert((0.0, 0));
|
||||
e.0 += d;
|
||||
e.1 += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
sums.into_iter()
|
||||
.map(|(k, (sum, n))| (k, sum / n as f32))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn euclidean(a: &[f32; 3], b: &[f32; 3]) -> f32 {
|
||||
let mut d2 = 0.0;
|
||||
for k in 0..3 {
|
||||
d2 += (a[k] - b[k]).powi(2);
|
||||
}
|
||||
d2.sqrt()
|
||||
}
|
||||
|
||||
/// Write min/mean/max of a sample into slots `base..base+3` (left at zero
|
||||
/// when the sample is empty), clamped to the meters range.
|
||||
fn write_min_mean_max(v: &mut [f32; GeometryEmbedding::DIM], base: usize, xs: &[f32]) {
|
||||
if xs.is_empty() {
|
||||
return;
|
||||
}
|
||||
let (mut min, mut max, mut sum) = (f32::INFINITY, f32::NEG_INFINITY, 0.0);
|
||||
for &x in xs {
|
||||
min = min.min(x);
|
||||
max = max.max(x);
|
||||
sum += x;
|
||||
}
|
||||
v[base] = clamp_m(min);
|
||||
v[base + 1] = clamp_m(sum / xs.len() as f32);
|
||||
v[base + 2] = clamp_m(max);
|
||||
}
|
||||
|
||||
/// Clamp a meters-valued slot into ±[`MAX_COORD_M`], mapping non-finite to 0.
|
||||
fn clamp_m(x: f32) -> f32 {
|
||||
if x.is_finite() {
|
||||
x.clamp(-MAX_COORD_M, MAX_COORD_M)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Eigenvalues `λ1 ≥ λ2 ≥ λ3 ≥ 0` of the 3×3 position covariance, via the
|
||||
/// closed-form trigonometric solution for symmetric matrices (no linear-
|
||||
/// algebra dependency; f64 internally for conditioning).
|
||||
fn covariance_eigenvalues(points: &[[f32; 3]], centroid: &[f32; 3]) -> (f32, f32, f32) {
|
||||
let nf = points.len() as f64;
|
||||
// Upper triangle of the symmetric covariance: (xx, yy, zz, xy, xz, yz).
|
||||
const IJ: [(usize, usize); 6] = [(0, 0), (1, 1), (2, 2), (0, 1), (0, 2), (1, 2)];
|
||||
let mut m = [0.0f64; 6];
|
||||
for p in points {
|
||||
let d: [f64; 3] = std::array::from_fn(|i| (p[i] - centroid[i]) as f64);
|
||||
for (k, &(i, j)) in IJ.iter().enumerate() {
|
||||
m[k] += d[i] * d[j] / nf;
|
||||
}
|
||||
}
|
||||
let (a, b, c, d, e, f) = (m[0], m[1], m[2], m[3], m[4], m[5]);
|
||||
let p1 = d * d + e * e + f * f;
|
||||
let q = (a + b + c) / 3.0;
|
||||
let p2 = (a - q).powi(2) + (b - q).powi(2) + (c - q).powi(2) + 2.0 * p1;
|
||||
let p = (p2 / 6.0).sqrt();
|
||||
let (l1, l2, l3) = if p < 1e-12 {
|
||||
(q, q, q) // (Near-)isotropic: all eigenvalues equal — diagonal incl.
|
||||
} else {
|
||||
// r = det((M - qI)/p) / 2, clamped into acos' domain.
|
||||
let (ba, bb, bc) = ((a - q) / p, (b - q) / p, (c - q) / p);
|
||||
let (bd, be, bf) = (d / p, e / p, f / p);
|
||||
let det = ba * (bb * bc - bf * bf) - bd * (bd * bc - bf * be) + be * (bd * bf - bb * be);
|
||||
let phi = (det / 2.0).clamp(-1.0, 1.0).acos() / 3.0;
|
||||
let e1 = q + 2.0 * p * phi.cos();
|
||||
let e3 = q + 2.0 * p * (phi + 2.0 * std::f64::consts::PI / 3.0).cos();
|
||||
(e1, 3.0 * q - e1 - e3, e3)
|
||||
};
|
||||
// PSD matrix: tiny negatives are numerical noise — clamp.
|
||||
(l1.max(0.0) as f32, l2.max(0.0) as f32, l3.max(0.0) as f32)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// A fully-measured node at `(x, y, 1)` with boresight toward +Y.
|
||||
fn node(id: u8, x: f32, y: f32) -> NodeGeometry {
|
||||
NodeGeometry::new(id, "tape-measure")
|
||||
.with_position(x, y, 1.0)
|
||||
.with_orientation(std::f32::consts::FRAC_PI_2, 0.1)
|
||||
}
|
||||
|
||||
/// 3 nodes on a 3-4-5 triangle; the (1,2) edge also measured by tape.
|
||||
fn full_layout() -> Vec<NodeGeometry> {
|
||||
vec![
|
||||
node(1, 0.0, 0.0).with_distance(2, 3.0),
|
||||
node(2, 3.0, 0.0).with_distance(1, 3.0),
|
||||
node(3, 0.0, 4.0),
|
||||
]
|
||||
}
|
||||
|
||||
fn assert_all_finite(e: &GeometryEmbedding) {
|
||||
for (i, x) in e.values.iter().enumerate() {
|
||||
assert!(x.is_finite(), "slot {i} is not finite: {x}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dimension_stable_and_empty_input_is_all_zero() {
|
||||
assert_eq!(GeometryEmbedding::DIM, 32);
|
||||
let full = GeometryEmbedding::from_nodes(&full_layout());
|
||||
assert_eq!(full.as_slice().len(), GeometryEmbedding::DIM);
|
||||
let empty = GeometryEmbedding::from_nodes(&[]);
|
||||
assert_eq!(empty, GeometryEmbedding::default(), "all-zero");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_unknown_layout_degrades_gracefully() {
|
||||
let nodes = vec![NodeGeometry::unknown(1), NodeGeometry::unknown(2)];
|
||||
let e = GeometryEmbedding::from_nodes(&nodes);
|
||||
assert_all_finite(&e);
|
||||
assert!((e.values[0] - 2.0 / 8.0).abs() < 1e-6, "node count slot");
|
||||
// No measurements: presence fractions and all stats at zero …
|
||||
for slot in 1..24 {
|
||||
assert_eq!(e.values[slot], 0.0, "slot {slot} should be 0");
|
||||
}
|
||||
// … but the per-node existence flags still say two nodes were there.
|
||||
assert_eq!(&e.values[24..27], &[0.25, 0.25, 0.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_node_has_no_pairwise_stats() {
|
||||
let n = NodeGeometry::new(5, "t")
|
||||
.with_position(1.0, 2.0, 1.5)
|
||||
.with_orientation(0.0, 0.0);
|
||||
let e = GeometryEmbedding::from_nodes(&[n]);
|
||||
assert_all_finite(&e);
|
||||
assert_eq!(&e.values[4..7], &[1.0, 2.0, 1.5], "centroid = the node");
|
||||
assert_eq!(&e.values[7..10], &[0.0, 0.0, 0.0], "no spread");
|
||||
assert_eq!(&e.values[10..17], &[0.0; 7], "no pairs");
|
||||
assert_eq!(e.values[17], 1.0, "cos(0)");
|
||||
assert_eq!(e.values[19], 1.0, "single boresight is fully concentrated");
|
||||
assert_eq!(e.values[24], 0.75, "position + orientation, no distances");
|
||||
}
|
||||
|
||||
/// Full-measurement layout: every slot family lands where the geometry
|
||||
/// says it should, and shuffling node order changes nothing.
|
||||
#[test]
|
||||
fn full_layout_statistics_and_permutation_invariance() {
|
||||
let nodes = full_layout();
|
||||
let e = GeometryEmbedding::from_nodes(&nodes);
|
||||
assert!((e.values[1] - 1.0).abs() < 1e-6, "all positioned");
|
||||
assert!((e.values[2] - 1.0).abs() < 1e-6, "all oriented");
|
||||
// 3-4-5 triangle: position-pair distances {3, 4, 5}.
|
||||
assert!((e.values[10] - 3.0).abs() < 1e-5, "min dist");
|
||||
assert!((e.values[11] - 4.0).abs() < 1e-5, "mean dist");
|
||||
assert!((e.values[12] - 5.0).abs() < 1e-5, "max dist");
|
||||
// Inter-node stats: pair (1,2) measured, (1,3)/(2,3) from positions.
|
||||
assert!((e.values[14] - 4.0).abs() < 1e-5, "mean inter-node dist");
|
||||
assert!((e.values[16] - 1.0 / 3.0).abs() < 1e-6, "1 of 3 measured");
|
||||
// Parallel boresights: fully concentrated, pointing +Y.
|
||||
assert!(e.values[17].abs() < 1e-6, "cos(π/2)");
|
||||
assert!((e.values[18] - 1.0).abs() < 1e-5, "sin(π/2)");
|
||||
assert!((e.values[19] - 1.0).abs() < 1e-5, "concentration");
|
||||
assert!((e.values[20] - 0.1).abs() < 1e-5, "mean elevation");
|
||||
// Coplanar triangle: λ1 ≈ 4.32, λ2 ≈ 1.23 (3-4-5 covariance), λ3 = 0.
|
||||
assert!((e.values[21] - 0.286).abs() < 0.01, "λ2/λ1 planar");
|
||||
assert!(e.values[22] < 1e-5, "λ3/λ1 ≈ 0 — coplanar nodes");
|
||||
assert!(e.values[23] > 0.5, "dominant spread is meter-scale");
|
||||
// Node 3 (rank 2) recorded no distances; nodes 1, 2 did.
|
||||
assert_eq!(&e.values[24..27], &[1.0, 1.0, 0.75]);
|
||||
|
||||
let mut shuffled = nodes;
|
||||
shuffled.rotate_left(1);
|
||||
shuffled.swap(0, 1);
|
||||
assert_eq!(e, GeometryEmbedding::from_nodes(&shuffled));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn measured_distance_overrides_position_distance() {
|
||||
// Positions say 3 m apart, the tape measure said 2.5 m: measured wins.
|
||||
let nodes = vec![
|
||||
NodeGeometry::new(1, "t")
|
||||
.with_position(0.0, 0.0, 1.0)
|
||||
.with_distance(2, 2.5),
|
||||
NodeGeometry::new(2, "t").with_position(3.0, 0.0, 1.0),
|
||||
];
|
||||
let e = GeometryEmbedding::from_nodes(&nodes);
|
||||
assert!((e.values[10] - 3.0).abs() < 1e-5, "position pair stat raw");
|
||||
assert!((e.values[14] - 2.5).abs() < 1e-5, "measured wins");
|
||||
assert!((e.values[16] - 1.0).abs() < 1e-6, "full pair coverage");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_inputs_never_produce_nan() {
|
||||
let nodes = vec![
|
||||
NodeGeometry::new(1, "garbage")
|
||||
.with_position(f32::NAN, f32::INFINITY, -0.0)
|
||||
.with_orientation(f32::NAN, f32::NEG_INFINITY)
|
||||
.with_distance(2, f32::NAN)
|
||||
.with_distance(3, -5.0)
|
||||
.with_distance(1, 1.0), // self-distance: ignored
|
||||
NodeGeometry::new(2, "garbage")
|
||||
.with_position(1e30, 1e30, 1e30)
|
||||
.with_distance(99, 4.0), // unknown node: ignored
|
||||
NodeGeometry::new(3, "garbage").with_position(2.0, 0.0, 1.0),
|
||||
];
|
||||
let e = GeometryEmbedding::from_nodes(&nodes);
|
||||
assert_all_finite(&e);
|
||||
// Only node 3's position survived sanitization.
|
||||
assert!((e.values[1] - 1.0 / 3.0).abs() < 1e-6);
|
||||
assert_eq!(e.values[2], 0.0, "no valid orientations");
|
||||
assert_eq!(e.values[16], 0.0, "no valid measured pairs");
|
||||
assert!(e.values.iter().all(|x| x.abs() <= MAX_COORD_M), "bounded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn more_than_eight_nodes_still_aggregates() {
|
||||
let nodes: Vec<NodeGeometry> = (0..12)
|
||||
.map(|i| NodeGeometry::new(i, "plan").with_position(i as f32, 0.0, 1.0))
|
||||
.collect();
|
||||
let e = GeometryEmbedding::from_nodes(&nodes);
|
||||
assert!((e.values[0] - 12.0 / 8.0).abs() < 1e-6);
|
||||
// All 8 flag slots filled (positions known, ranks 0..8 by node_id).
|
||||
assert!(e.values[24..32].iter().all(|&f| f == 0.5));
|
||||
// Collinear nodes: zero planar/volume diversity, meter-scale spread.
|
||||
assert!(e.values[21] < 1e-5);
|
||||
assert!(e.values[22] < 1e-5);
|
||||
assert!(e.values[23] > 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_roundtrip_and_schema_default() {
|
||||
let e = GeometryEmbedding::from_nodes(&full_layout());
|
||||
let json = serde_json::to_string(&e).unwrap();
|
||||
let back: GeometryEmbedding = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back, e);
|
||||
assert_eq!(back.schema_version, GeometryEmbedding::SCHEMA_VERSION);
|
||||
// JSON written by a pre-versioning producer (no version field)
|
||||
// defaults to the current schema — the NodeGeometry pattern.
|
||||
let vals = serde_json::to_string(&e.values).unwrap();
|
||||
let bare = format!("{{\"values\":{vals}}}");
|
||||
let from_bare: GeometryEmbedding = serde_json::from_str(&bare).unwrap();
|
||||
assert_eq!(from_bare.schema_version, 1);
|
||||
assert_eq!(from_bare.values, e.values);
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,10 @@
|
||||
//!
|
||||
//! Stages (ADR-151 §1.3):
|
||||
//! 1. **baseline** — empty-room environmental fingerprint (ADR-135; consumed here).
|
||||
//! 2. **enroll** — guided anchors with an adaptive quality gate ([`anchor`], [`enrollment`]).
|
||||
//! 2. **enroll** — guided anchors with an adaptive quality gate ([`anchor`],
|
||||
//! [`enrollment`]) plus an optional transceiver-geometry record ([`geometry`],
|
||||
//! ADR-152 §2.1.1) and its fixed-length conditioning featurization
|
||||
//! ([`geometry_embedding`], ADR-152 §2.1.2).
|
||||
//! 3. **extract** — labelled feature records from anchor captures ([`extract`]).
|
||||
//! 4. **train** — a bank of small specialist models ([`specialist`], [`bank`]) and a
|
||||
//! confidence-gated mixture runtime ([`runtime`]).
|
||||
@@ -19,19 +22,23 @@
|
||||
#![warn(missing_docs)]
|
||||
|
||||
pub mod anchor;
|
||||
pub mod bank;
|
||||
pub mod enrollment;
|
||||
pub mod error;
|
||||
pub mod extract;
|
||||
pub mod specialist;
|
||||
pub mod bank;
|
||||
pub mod runtime;
|
||||
pub mod geometry;
|
||||
pub mod geometry_embedding;
|
||||
pub mod multistatic;
|
||||
pub mod runtime;
|
||||
pub mod specialist;
|
||||
|
||||
pub use anchor::{Anchor, AnchorLabel, AnchorQuality, EnrollmentEvent, EnrollmentSession, Posture};
|
||||
pub use bank::SpecialistBank;
|
||||
pub use enrollment::{AnchorQualityGate, AnchorRecorder};
|
||||
pub use error::{CalibrationError, Result};
|
||||
pub use extract::AnchorFeature;
|
||||
pub use geometry::{AntennaOrientation, NodeGeometry, PositionEstimate};
|
||||
pub use geometry_embedding::GeometryEmbedding;
|
||||
pub use multistatic::MultiNodeMixture;
|
||||
pub use runtime::{MixtureOfSpecialists, RoomState};
|
||||
pub use specialist::{Specialist, SpecialistKind, SpecialistReading};
|
||||
|
||||
@@ -20,6 +20,7 @@ use std::collections::BTreeMap;
|
||||
|
||||
use crate::bank::SpecialistBank;
|
||||
use crate::extract::Features;
|
||||
use crate::geometry::NodeGeometry;
|
||||
use crate::runtime::{MixtureOfSpecialists, RoomState};
|
||||
use crate::specialist::SpecialistReading;
|
||||
|
||||
@@ -45,7 +46,12 @@ impl MultiNodeMixture {
|
||||
|
||||
/// Register a node's bank. `current_baseline_id` is the baseline the node is
|
||||
/// observing now (drift vs the bank's training baseline → STALE).
|
||||
pub fn add_node(&mut self, node_id: u8, bank: SpecialistBank, current_baseline_id: impl Into<String>) {
|
||||
pub fn add_node(
|
||||
&mut self,
|
||||
node_id: u8,
|
||||
bank: SpecialistBank,
|
||||
current_baseline_id: impl Into<String>,
|
||||
) {
|
||||
self.nodes.insert(
|
||||
node_id,
|
||||
NodeEntry {
|
||||
@@ -60,6 +66,26 @@ impl MultiNodeMixture {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// The transceiver-geometry snapshot a node's bank was trained under
|
||||
/// (ADR-152 §2.1.1), if its enrollment recorded one. Threaded through for
|
||||
/// the fusion logic; **not used algorithmically yet** — geometry-aware
|
||||
/// fusion is the §2.1.2 learned-embedding work (ADR-151 P6).
|
||||
pub fn node_geometry(&self, node_id: u8) -> Option<&[NodeGeometry]> {
|
||||
self.nodes
|
||||
.get(&node_id)
|
||||
.map(|e| e.mixture.bank().geometry.as_slice())
|
||||
.filter(|g| !g.is_empty())
|
||||
}
|
||||
|
||||
/// All registered nodes' geometry snapshots, keyed by node id. Nodes whose
|
||||
/// banks carry no geometry are omitted.
|
||||
pub fn geometries(&self) -> BTreeMap<u8, &[NodeGeometry]> {
|
||||
self.nodes
|
||||
.keys()
|
||||
.filter_map(|&id| self.node_geometry(id).map(|g| (id, g)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Fuse per-node feature windows into one room state. Nodes without a feature
|
||||
/// entry this window are skipped.
|
||||
pub fn infer(&self, per_node: &BTreeMap<u8, Features>) -> RoomState {
|
||||
@@ -109,15 +135,13 @@ impl MultiNodeMixture {
|
||||
|
||||
/// Presence: a person is present if ANY node sees one; confidence = max.
|
||||
fn fuse_presence(states: &[RoomState]) -> Option<SpecialistReading> {
|
||||
let readings: Vec<&SpecialistReading> = states.iter().filter_map(|s| s.presence.as_ref()).collect();
|
||||
let readings: Vec<&SpecialistReading> =
|
||||
states.iter().filter_map(|s| s.presence.as_ref()).collect();
|
||||
if readings.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let any_present = readings.iter().any(|r| r.value > 0.5);
|
||||
let confidence = readings
|
||||
.iter()
|
||||
.map(|r| r.confidence)
|
||||
.fold(0.0f32, f32::max);
|
||||
let confidence = readings.iter().map(|r| r.confidence).fold(0.0f32, f32::max);
|
||||
Some(SpecialistReading {
|
||||
kind: readings[0].kind,
|
||||
value: if any_present { 1.0 } else { 0.0 },
|
||||
@@ -205,6 +229,22 @@ mod tests {
|
||||
assert_eq!(m.node_count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometry_threads_through_to_fusion() {
|
||||
let geo1 = vec![NodeGeometry::new(1, "tape-measure")
|
||||
.with_position(0.0, 0.0, 1.0)
|
||||
.with_distance(2, 3.0)];
|
||||
let mut m = MultiNodeMixture::new();
|
||||
m.add_node(1, bank("b1").with_geometry(geo1.clone()), "b1");
|
||||
m.add_node(2, bank("b1"), "b1"); // no geometry recorded for node 2
|
||||
assert_eq!(m.node_geometry(1), Some(geo1.as_slice()));
|
||||
assert_eq!(m.node_geometry(2), None, "geometry-free bank reads None");
|
||||
assert_eq!(m.node_geometry(9), None, "unknown node reads None");
|
||||
let all = m.geometries();
|
||||
assert_eq!(all.len(), 1);
|
||||
assert_eq!(all.get(&1), Some(&geo1.as_slice()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn presence_or_across_nodes() {
|
||||
let mut m = MultiNodeMixture::new();
|
||||
|
||||
@@ -123,9 +123,7 @@ impl Specialist for PresenceSpecialist {
|
||||
fn infer(&self, f: &Features) -> Option<SpecialistReading> {
|
||||
let by_variance = f.variance > self.threshold;
|
||||
let mean_dist = (f.mean - self.empty_mean).abs();
|
||||
let by_mean = self
|
||||
.mean_dist_threshold
|
||||
.is_some_and(|thr| mean_dist > thr);
|
||||
let by_mean = self.mean_dist_threshold.is_some_and(|thr| mean_dist > thr);
|
||||
let present = by_variance || by_mean;
|
||||
|
||||
// Confidence: strongest margin among the channels that are enabled.
|
||||
@@ -228,7 +226,11 @@ impl Specialist for BreathingSpecialist {
|
||||
SpecialistKind::Breathing
|
||||
}
|
||||
fn infer(&self, f: &Features) -> Option<SpecialistReading> {
|
||||
let min = if self.min_score > 0.0 { self.min_score } else { 0.25 };
|
||||
let min = if self.min_score > 0.0 {
|
||||
self.min_score
|
||||
} else {
|
||||
0.25
|
||||
};
|
||||
if f.breathing_score < min || f.breathing_hz <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
@@ -253,7 +255,11 @@ impl Specialist for HeartbeatSpecialist {
|
||||
SpecialistKind::Heartbeat
|
||||
}
|
||||
fn infer(&self, f: &Features) -> Option<SpecialistReading> {
|
||||
let min = if self.min_score > 0.0 { self.min_score } else { 0.3 };
|
||||
let min = if self.min_score > 0.0 {
|
||||
self.min_score
|
||||
} else {
|
||||
0.3
|
||||
};
|
||||
if f.heart_score < min || f.heart_hz <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ use num_complex::Complex64;
|
||||
use wifi_densepose_calibration::extract::Features;
|
||||
use wifi_densepose_calibration::{
|
||||
AnchorFeature, AnchorLabel, AnchorQualityGate, AnchorRecorder, EnrollmentEvent,
|
||||
EnrollmentSession, MixtureOfSpecialists, SpecialistBank, SpecialistKind,
|
||||
EnrollmentSession, MixtureOfSpecialists, NodeGeometry, SpecialistBank, SpecialistKind,
|
||||
};
|
||||
use wifi_densepose_core::types::{AntennaConfig, CsiFrame, CsiMetadata, DeviceId, FrequencyBand};
|
||||
use wifi_densepose_signal::{BaselineCalibration, CalibrationConfig, CalibrationRecorder};
|
||||
@@ -271,6 +271,19 @@ fn full_loop_baseline_enroll_extract_train_infer() {
|
||||
// -- Stage 2: guided-anchor enrollment with the quality gate -------------
|
||||
let gate = AnchorQualityGate::default();
|
||||
let mut session = EnrollmentSession::new(room_id, &baseline_id, 1_700_000_000);
|
||||
|
||||
// Transceiver geometry recorded at session start (ADR-152 §2.1.1): a
|
||||
// two-node layout, one tape-measured, one unknown — all fields optional.
|
||||
let geometry = vec![
|
||||
NodeGeometry::new(1, "tape-measure")
|
||||
.with_position(0.0, 0.0, 1.2)
|
||||
.with_orientation(0.0, 0.0)
|
||||
.with_distance(2, 3.5),
|
||||
NodeGeometry::unknown(2),
|
||||
];
|
||||
session.record_geometry(geometry.clone(), 1_700_000_000);
|
||||
assert_eq!(session.geometry(), Some(geometry.as_slice()));
|
||||
|
||||
let mut features: Vec<AnchorFeature> = Vec::new();
|
||||
|
||||
for (i, label) in AnchorLabel::SEQUENCE.into_iter().enumerate() {
|
||||
@@ -345,8 +358,10 @@ fn full_loop_baseline_enroll_extract_train_infer() {
|
||||
);
|
||||
|
||||
// -- Stage 4: train the specialist bank + JSON persistence round-trip ----
|
||||
// The bank snapshots the geometry the enrollment recorded (ADR-152 §2.1.1).
|
||||
let bank = SpecialistBank::train(room_id, &baseline_id, &features, 1_700_000_400)
|
||||
.expect("bank training");
|
||||
.expect("bank training")
|
||||
.with_geometry(session.geometry().map(<[_]>::to_vec).unwrap_or_default());
|
||||
assert_eq!(bank.room_id, room_id);
|
||||
assert_eq!(bank.anchor_count, 8);
|
||||
let kinds = bank.trained_kinds();
|
||||
@@ -373,6 +388,10 @@ fn full_loop_baseline_enroll_extract_train_infer() {
|
||||
bank.presence.as_ref().map(|p| p.threshold),
|
||||
"presence threshold must survive persistence"
|
||||
);
|
||||
assert_eq!(
|
||||
reloaded.geometry, geometry,
|
||||
"the enrollment geometry snapshot must survive bank persistence"
|
||||
);
|
||||
|
||||
// -- Stage 5: runtime inference through the mixture ----------------------
|
||||
let mix = MixtureOfSpecialists::new(reloaded);
|
||||
|
||||
@@ -39,7 +39,8 @@ use tokio::sync::{mpsc, oneshot, RwLock};
|
||||
use tower_http::cors::CorsLayer;
|
||||
use wifi_densepose_calibration::extract::{AnchorFeature, Features};
|
||||
use wifi_densepose_calibration::{
|
||||
AnchorLabel, AnchorQualityGate, AnchorRecorder, MixtureOfSpecialists, SpecialistBank,
|
||||
AnchorLabel, AnchorQualityGate, AnchorRecorder, MixtureOfSpecialists, NodeGeometry,
|
||||
SpecialistBank,
|
||||
};
|
||||
use wifi_densepose_core::types::CsiFrame;
|
||||
use wifi_densepose_signal::{BaselineCalibration, CalibrationRecorder};
|
||||
@@ -207,6 +208,9 @@ struct RoomEnroll {
|
||||
baseline_id: String,
|
||||
fs_hz: f32,
|
||||
anchors: Vec<AnchorFeature>,
|
||||
/// Transceiver geometry recorded via `POST /enroll/geometry` (ADR-152
|
||||
/// §2.1.1); latest recording wins. Snapshotted into the bank at train time.
|
||||
geometry: Vec<NodeGeometry>,
|
||||
}
|
||||
|
||||
/// Result of capturing one anchor (`POST /enroll/anchor`).
|
||||
@@ -299,6 +303,7 @@ fn build_router(state: ApiState) -> Router {
|
||||
.route("/api/v1/room/state", get(room_state))
|
||||
.route("/api/v1/room/train", post(train_room))
|
||||
.route("/api/v1/enroll/anchor", post(enroll_anchor))
|
||||
.route("/api/v1/enroll/geometry", post(enroll_geometry))
|
||||
.route("/api/v1/enroll/status", get(enroll_status))
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state)
|
||||
@@ -670,8 +675,9 @@ async fn descriptor() -> impl IntoResponse {
|
||||
"GET /api/v1/calibration/result": "last finalized baseline summary",
|
||||
"GET /api/v1/calibration/baselines": "list persisted baseline files",
|
||||
"GET /api/v1/room/state?bank=<name>": "live mixture-of-specialists RoomState over the CSI window",
|
||||
"POST /api/v1/room/train": "{ room_id, baseline_id, anchors[]? } → train + persist a specialist bank (anchors[] optional if enrolled in-server)",
|
||||
"POST /api/v1/room/train": "{ room_id, baseline_id, anchors[]?, geometry[]? } → train + persist a specialist bank (anchors[]/geometry[] optional if enrolled in-server)",
|
||||
"POST /api/v1/enroll/anchor": "{ room_id, baseline, label, duration_s? } → capture one guided anchor (blocks for the capture)",
|
||||
"POST /api/v1/enroll/geometry": "{ room_id, geometry: [NodeGeometry…] } → record transceiver geometry for the room (ADR-152 §2.1.1; latest wins)",
|
||||
"GET /api/v1/enroll/status?room=<id>": "enrollment progress (accepted anchors, next, complete)"
|
||||
}
|
||||
}))
|
||||
@@ -740,11 +746,18 @@ struct TrainRequest {
|
||||
baseline_id: String,
|
||||
#[serde(default)]
|
||||
anchors: Vec<AnchorFeature>,
|
||||
/// Optional transceiver geometry (ADR-152 §2.1.1). Falls back to the
|
||||
/// geometry recorded in-server via `POST /enroll/geometry`; absent both,
|
||||
/// the bank trains geometry-free (valid, but no geometry conditioning).
|
||||
#[serde(default)]
|
||||
geometry: Vec<NodeGeometry>,
|
||||
}
|
||||
|
||||
/// Train a per-room specialist bank and persist it as `<output_dir>/<room_id>.json`
|
||||
/// (the name `room-state` reads back). Uses the posted `anchors` if present, else
|
||||
/// falls back to the in-server enrollment accumulated via `POST /enroll/anchor`.
|
||||
/// The enrollment's transceiver-geometry snapshot (posted `geometry` or the
|
||||
/// `POST /enroll/geometry` record) is threaded into the bank (ADR-152 §2.1.1).
|
||||
async fn train_room(State(st): State<ApiState>, Json(req): Json<TrainRequest>) -> impl IntoResponse {
|
||||
let (anchors, baseline_id) = if !req.anchors.is_empty() {
|
||||
(req.anchors.clone(), req.baseline_id.clone())
|
||||
@@ -756,11 +769,25 @@ async fn train_room(State(st): State<ApiState>, Json(req): Json<TrainRequest>) -
|
||||
}
|
||||
}
|
||||
};
|
||||
let geometry = if !req.geometry.is_empty() {
|
||||
req.geometry.clone()
|
||||
} else {
|
||||
st.enroll.read().await.get(&req.room_id).map(|re| re.geometry.clone()).unwrap_or_default()
|
||||
};
|
||||
let at = (unix_ms() / 1000) as i64;
|
||||
let bank = match SpecialistBank::train(&req.room_id, &baseline_id, &anchors, at) {
|
||||
Ok(b) => b,
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("training failed: {e}")}))).into_response(),
|
||||
};
|
||||
let bank = if geometry.is_empty() {
|
||||
eprintln!(
|
||||
"[calibrate-serve] no transceiver geometry recorded for room '{}' — bank will not support geometry conditioning (ADR-152 §2.1.2)",
|
||||
req.room_id
|
||||
);
|
||||
bank
|
||||
} else {
|
||||
bank.with_geometry(geometry)
|
||||
};
|
||||
let name = sanitize_room_id(&req.room_id);
|
||||
let dir = { st.status.read().await.output_dir.clone() };
|
||||
let path = format!("{dir}/{name}.json");
|
||||
@@ -777,10 +804,37 @@ async fn train_room(State(st): State<ApiState>, Json(req): Json<TrainRequest>) -
|
||||
"bank": name, // pass as ?bank=<name> to /room/state
|
||||
"anchor_count": bank.anchor_count,
|
||||
"specialists": kinds,
|
||||
"geometry_nodes": bank.geometry.len(),
|
||||
"path": path,
|
||||
}))).into_response()
|
||||
}
|
||||
|
||||
/// Body for `POST /api/v1/enroll/geometry`.
|
||||
#[derive(Deserialize)]
|
||||
struct EnrollGeometryBody {
|
||||
room_id: String,
|
||||
/// Per-node transceiver geometry records (ADR-152 §2.1.1).
|
||||
geometry: Vec<NodeGeometry>,
|
||||
}
|
||||
|
||||
/// Record the room's transceiver geometry (ADR-152 §2.1.1) into the in-server
|
||||
/// enrollment; the next `POST /room/train` snapshots it into the bank. A later
|
||||
/// POST supersedes an earlier one (latest wins), mirroring
|
||||
/// `EnrollmentSession::record_geometry`.
|
||||
async fn enroll_geometry(State(st): State<ApiState>, Json(b): Json<EnrollGeometryBody>) -> impl IntoResponse {
|
||||
if b.geometry.is_empty() {
|
||||
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":"geometry must be a non-empty array of NodeGeometry records"}))).into_response();
|
||||
}
|
||||
let nodes = b.geometry.len();
|
||||
{
|
||||
let mut map = st.enroll.write().await;
|
||||
let re = map.entry(b.room_id.clone()).or_insert_with(RoomEnroll::default);
|
||||
re.geometry = b.geometry;
|
||||
}
|
||||
eprintln!("[calibrate-serve] enroll geometry room={} nodes={nodes}", b.room_id);
|
||||
(StatusCode::OK, Json(serde_json::json!({"room_id": b.room_id, "geometry_nodes": nodes}))).into_response()
|
||||
}
|
||||
|
||||
/// Body for `POST /api/v1/enroll/anchor`.
|
||||
#[derive(Deserialize)]
|
||||
struct EnrollAnchorBody {
|
||||
@@ -1086,6 +1140,59 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
/// ADR-152 §2.1.1: geometry threads into the trained bank through both API
|
||||
/// paths — inline in the train request, or recorded via /enroll/geometry —
|
||||
/// and a geometry-free train still produces a valid (unconditioned) bank.
|
||||
#[tokio::test]
|
||||
async fn train_threads_geometry_into_bank() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let app = build_router(test_state(dir.path().to_str().unwrap()));
|
||||
let anchors = r#"[
|
||||
{"room_id":"g","label":"empty","features":{"mean":1.0,"variance":1.0,"motion":0.1,"breathing_score":0.0,"breathing_hz":0.0,"heart_score":0.0,"heart_hz":0.0}},
|
||||
{"room_id":"g","label":"stand_still","features":{"mean":1.0,"variance":10.0,"motion":0.2,"breathing_score":0.0,"breathing_hz":0.0,"heart_score":0.0,"heart_hz":0.0}}
|
||||
]"#;
|
||||
let load_bank = |name: &str| {
|
||||
let raw = std::fs::read_to_string(dir.path().join(format!("{name}.json"))).unwrap();
|
||||
SpecialistBank::from_json(&raw).unwrap()
|
||||
};
|
||||
|
||||
// (1) geometry inline in the train request.
|
||||
let body = format!(
|
||||
r#"{{"room_id":"g1","baseline_id":"b","anchors":{anchors},
|
||||
"geometry":[{{"node_id":1,"position":{{"x_m":0.0,"y_m":0.0,"z_m":1.0}},"method":"tape-measure"}},{{"node_id":2}}]}}"#
|
||||
);
|
||||
assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(&body)).await, StatusCode::OK);
|
||||
let bank = load_bank("g1");
|
||||
assert_eq!(bank.geometry.len(), 2);
|
||||
assert_eq!(bank.geometry[0].method, "tape-measure");
|
||||
assert_eq!(bank.geometry[1].node_id, 2);
|
||||
|
||||
// (2) geometry recorded via /enroll/geometry; train body omits it.
|
||||
assert_eq!(
|
||||
req(app.clone(), "POST", "/api/v1/enroll/geometry",
|
||||
Some(r#"{"room_id":"g2","geometry":[{"node_id":7,"method":"floor-plan"}]}"#)).await,
|
||||
StatusCode::OK
|
||||
);
|
||||
let body2 = format!(r#"{{"room_id":"g2","baseline_id":"b","anchors":{anchors}}}"#);
|
||||
assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(&body2)).await, StatusCode::OK);
|
||||
let bank2 = load_bank("g2");
|
||||
assert_eq!(bank2.geometry.len(), 1);
|
||||
assert_eq!(bank2.geometry[0].node_id, 7);
|
||||
|
||||
// (3) no geometry anywhere → valid geometry-free bank (note logged).
|
||||
let body3 = format!(r#"{{"room_id":"g3","baseline_id":"b","anchors":{anchors}}}"#);
|
||||
assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(&body3)).await, StatusCode::OK);
|
||||
let bank3 = load_bank("g3");
|
||||
assert!(bank3.geometry.is_empty());
|
||||
assert!(bank3.presence.is_some(), "bank still trains without geometry");
|
||||
|
||||
// (4) empty geometry array is rejected.
|
||||
assert_eq!(
|
||||
req(app, "POST", "/api/v1/enroll/geometry", Some(r#"{"room_id":"g4","geometry":[]}"#)).await,
|
||||
StatusCode::BAD_REQUEST
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enroll_status_empty_and_bad_label() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
|
||||
@@ -11,7 +11,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
use tokio::net::UdpSocket;
|
||||
use wifi_densepose_calibration::{
|
||||
Anchor, AnchorLabel, AnchorQualityGate, AnchorRecorder, EnrollmentEvent, EnrollmentSession,
|
||||
MixtureOfSpecialists, MultiNodeMixture, SpecialistBank,
|
||||
MixtureOfSpecialists, MultiNodeMixture, NodeGeometry, SpecialistBank,
|
||||
};
|
||||
use wifi_densepose_calibration::extract::{AnchorFeature, Features};
|
||||
use wifi_densepose_core::types::CsiFrame;
|
||||
@@ -226,20 +226,50 @@ pub struct TrainRoomArgs {
|
||||
/// Output specialist-bank file.
|
||||
#[arg(long, default_value = "./room-bank.json")]
|
||||
pub output: String,
|
||||
/// Optional transceiver-geometry file: a JSON array of `NodeGeometry`
|
||||
/// records (ADR-152 §2.1.1). Recorded into the enrollment session before
|
||||
/// training so the bank carries the layout it was trained under.
|
||||
#[arg(long)]
|
||||
pub geometry: Option<String>,
|
||||
}
|
||||
|
||||
/// Execute `train-room`.
|
||||
///
|
||||
/// If the enrollment session carries a transceiver-geometry snapshot (recorded
|
||||
/// at enroll time or supplied here via `--geometry`), it is threaded into the
|
||||
/// bank (ADR-152 §2.1.1); a geometry-free enrollment still trains a valid bank.
|
||||
pub async fn train_room(args: TrainRoomArgs) -> Result<()> {
|
||||
let raw = std::fs::read_to_string(&args.enrollment)
|
||||
.map_err(|e| anyhow::anyhow!("cannot read {}: {e} — run `enroll` first", args.enrollment))?;
|
||||
let data: EnrollmentData =
|
||||
let mut data: EnrollmentData =
|
||||
serde_json::from_str(&raw).map_err(|e| anyhow::anyhow!("invalid enrollment: {e}"))?;
|
||||
if data.anchors.is_empty() {
|
||||
bail!("no accepted anchors in {} — re-run enroll", args.enrollment);
|
||||
}
|
||||
|
||||
let bank = SpecialistBank::train(&data.room_id, &data.baseline_id, &data.anchors, now_unix())
|
||||
if let Some(path) = &args.geometry {
|
||||
let graw = std::fs::read_to_string(path)
|
||||
.map_err(|e| anyhow::anyhow!("cannot read geometry {path}: {e}"))?;
|
||||
let geometry: Vec<NodeGeometry> = serde_json::from_str(&graw).map_err(|e| {
|
||||
anyhow::anyhow!("invalid geometry {path}: {e} (expected a JSON array of NodeGeometry records)")
|
||||
})?;
|
||||
data.session.record_geometry(geometry, now_unix());
|
||||
}
|
||||
|
||||
let mut bank = SpecialistBank::train(&data.room_id, &data.baseline_id, &data.anchors, now_unix())
|
||||
.map_err(|e| anyhow::anyhow!("training failed: {e}"))?;
|
||||
match data.session.geometry() {
|
||||
Some(g) if !g.is_empty() => {
|
||||
bank = bank.with_geometry(g.to_vec());
|
||||
eprintln!(
|
||||
"[train-room] geometry: {} node(s) snapshotted into the bank (ADR-152 §2.1.1)",
|
||||
bank.geometry.len()
|
||||
);
|
||||
}
|
||||
_ => eprintln!(
|
||||
"[train-room] no transceiver geometry recorded — bank will not support geometry conditioning (ADR-152 §2.1.2)"
|
||||
),
|
||||
}
|
||||
std::fs::write(&args.output, bank.to_json().map_err(|e| anyhow::anyhow!("{e}"))?)
|
||||
.map_err(|e| anyhow::anyhow!("cannot write {}: {e}", args.output))?;
|
||||
|
||||
@@ -456,3 +486,141 @@ async fn room_watch_multi(args: RoomWatchArgs) -> Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn feature(label: AnchorLabel, variance: f32, motion: f32) -> AnchorFeature {
|
||||
AnchorFeature {
|
||||
room_id: "t".into(),
|
||||
label,
|
||||
features: Features {
|
||||
mean: 1.0,
|
||||
variance,
|
||||
motion,
|
||||
breathing_score: 0.0,
|
||||
breathing_hz: 0.0,
|
||||
heart_score: 0.0,
|
||||
heart_hz: 0.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a minimal valid enrollment file (two anchors, no geometry event).
|
||||
fn write_enrollment(dir: &std::path::Path) -> String {
|
||||
let data = EnrollmentData {
|
||||
room_id: "t".into(),
|
||||
baseline_id: "base-1".into(),
|
||||
fs_hz: 15.0,
|
||||
anchors: vec![
|
||||
feature(AnchorLabel::Empty, 1.0, 0.1),
|
||||
feature(AnchorLabel::StandStill, 10.0, 0.2),
|
||||
],
|
||||
session: EnrollmentSession::new("t", "base-1", 1000),
|
||||
};
|
||||
let path = dir.join("enrollment.json");
|
||||
std::fs::write(&path, serde_json::to_string(&data).unwrap()).unwrap();
|
||||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
fn trained_bank(out: &std::path::Path) -> SpecialistBank {
|
||||
SpecialistBank::from_json(&std::fs::read_to_string(out).unwrap()).unwrap()
|
||||
}
|
||||
|
||||
/// ADR-152 §2.1.1: `--geometry` records into the session and the bank
|
||||
/// snapshots it — enrollment geometry reaches the trained bank.
|
||||
#[tokio::test]
|
||||
async fn train_room_threads_geometry_when_provided() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let enrollment = write_enrollment(dir.path());
|
||||
let geometry = vec![
|
||||
NodeGeometry::new(1, "tape-measure").with_position(0.0, 0.0, 1.0),
|
||||
NodeGeometry::unknown(2),
|
||||
];
|
||||
let gpath = dir.path().join("geometry.json");
|
||||
std::fs::write(&gpath, serde_json::to_string(&geometry).unwrap()).unwrap();
|
||||
let out = dir.path().join("bank.json");
|
||||
|
||||
train_room(TrainRoomArgs {
|
||||
enrollment,
|
||||
output: out.to_string_lossy().into_owned(),
|
||||
geometry: Some(gpath.to_string_lossy().into_owned()),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(trained_bank(&out).geometry, geometry);
|
||||
}
|
||||
|
||||
/// A geometry-free enrollment still trains a valid bank (optional by
|
||||
/// design) — it just carries no snapshot.
|
||||
#[tokio::test]
|
||||
async fn train_room_without_geometry_yields_geometry_free_bank() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let enrollment = write_enrollment(dir.path());
|
||||
let out = dir.path().join("bank.json");
|
||||
|
||||
train_room(TrainRoomArgs {
|
||||
enrollment,
|
||||
output: out.to_string_lossy().into_owned(),
|
||||
geometry: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let bank = trained_bank(&out);
|
||||
assert!(bank.geometry.is_empty());
|
||||
assert!(bank.presence.is_some(), "bank still trains without geometry");
|
||||
}
|
||||
|
||||
/// Geometry recorded at enroll time (in the session event log) is picked up
|
||||
/// without the `--geometry` flag.
|
||||
#[tokio::test]
|
||||
async fn train_room_uses_session_geometry() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let geometry = vec![NodeGeometry::new(3, "floor-plan").with_position(1.0, 2.0, 1.5)];
|
||||
let mut session = EnrollmentSession::new("t", "base-1", 1000);
|
||||
session.record_geometry(geometry.clone(), 1000);
|
||||
let data = EnrollmentData {
|
||||
room_id: "t".into(),
|
||||
baseline_id: "base-1".into(),
|
||||
fs_hz: 15.0,
|
||||
anchors: vec![
|
||||
feature(AnchorLabel::Empty, 1.0, 0.1),
|
||||
feature(AnchorLabel::StandStill, 10.0, 0.2),
|
||||
],
|
||||
session,
|
||||
};
|
||||
let epath = dir.path().join("enrollment.json");
|
||||
std::fs::write(&epath, serde_json::to_string(&data).unwrap()).unwrap();
|
||||
let out = dir.path().join("bank.json");
|
||||
|
||||
train_room(TrainRoomArgs {
|
||||
enrollment: epath.to_string_lossy().into_owned(),
|
||||
output: out.to_string_lossy().into_owned(),
|
||||
geometry: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(trained_bank(&out).geometry, geometry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn train_room_rejects_invalid_geometry_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let enrollment = write_enrollment(dir.path());
|
||||
let gpath = dir.path().join("geometry.json");
|
||||
std::fs::write(&gpath, r#"{"not":"an array"}"#).unwrap();
|
||||
|
||||
let err = train_room(TrainRoomArgs {
|
||||
enrollment,
|
||||
output: dir.path().join("bank.json").to_string_lossy().into_owned(),
|
||||
geometry: Some(gpath.to_string_lossy().into_owned()),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("invalid geometry"), "{err}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
//! Session FSM I/O types for the 802.11bf sensing model: events in
|
||||
//! ([`SessionEvent`]), actions out ([`Action`]), close reasons, static
|
||||
//! configuration, and the state enum.
|
||||
//!
|
||||
//! Split from [`super::session`] to keep each file under the ADR-153
|
||||
//! 500-line maintainability cap; the canonical public path re-exports
|
||||
//! these from [`super::session`].
|
||||
|
||||
use super::messages::{
|
||||
CsiReportPayload, SbpRequest, SbpResponse, SbpStatus, SensingMeasurementInstance,
|
||||
SensingMeasurementReport, SensingMeasurementSetupRequest, SensingMeasurementSetupResponse,
|
||||
SensingSessionTermination, TerminationReason,
|
||||
};
|
||||
use super::types::{MeasurementInstanceId, SensingCapabilities, SetupStatus, SpecProfile};
|
||||
|
||||
/// Session FSM states.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SessionState {
|
||||
Idle,
|
||||
SetupNegotiating,
|
||||
Active,
|
||||
Terminating,
|
||||
}
|
||||
|
||||
/// Inputs to the session FSM. `Start*` are local commands; `*Received` are
|
||||
/// frames from the peer; `Timeout`/`InstanceElapsed` are scheduler ticks.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SessionEvent {
|
||||
/// Local command (initiator): begin setup negotiation.
|
||||
StartSetup(SensingMeasurementSetupRequest),
|
||||
/// Local command (initiator): request sensing-by-proxy from an AP.
|
||||
StartSbp(SbpRequest),
|
||||
SetupRequestReceived(SensingMeasurementSetupRequest),
|
||||
SetupResponseReceived(SensingMeasurementSetupResponse),
|
||||
SbpRequestReceived(SbpRequest),
|
||||
SbpResponseReceived(SbpResponse),
|
||||
/// Scheduler tick: the negotiated periodicity elapsed (the
|
||||
/// measurement-driving endpoint — initiator or SBP proxy — emits the
|
||||
/// next measurement-instance trigger).
|
||||
InstanceElapsed,
|
||||
/// A sensing receiver captured a measurement for an instance (payload is
|
||||
/// fed by the transport/bridge — see `OpportunisticCsiBridge`).
|
||||
MeasurementCaptured {
|
||||
instance_id: MeasurementInstanceId,
|
||||
payload: CsiReportPayload,
|
||||
},
|
||||
ReportReceived(SensingMeasurementReport),
|
||||
/// Generic timeout tick for the current state.
|
||||
Timeout,
|
||||
/// Local command: terminate the session.
|
||||
Terminate(TerminationReason),
|
||||
TerminationReceived(SensingSessionTermination),
|
||||
}
|
||||
|
||||
/// Outputs of the session FSM. `Send*`/`TriggerInstance`/`RelaySbpReport`
|
||||
/// go to the transport; `DeliverReport`/`SessionClosed` go to the local
|
||||
/// consumer.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Action {
|
||||
SendSetupRequest(SensingMeasurementSetupRequest),
|
||||
SendSetupResponse(SensingMeasurementSetupResponse),
|
||||
SendSbpRequest(SbpRequest),
|
||||
SendSbpResponse(SbpResponse),
|
||||
TriggerInstance(SensingMeasurementInstance),
|
||||
SendReport(SensingMeasurementReport),
|
||||
DeliverReport(SensingMeasurementReport),
|
||||
/// SBP proxy mode: forward a report received from the sensing responder
|
||||
/// to the SBP client. The transport maps this to a frame toward the
|
||||
/// client (`SensingFrame::SbpReport`), distinct from `SendReport`,
|
||||
/// which travels toward the sensing initiator.
|
||||
RelaySbpReport(SensingMeasurementReport),
|
||||
SendTermination(SensingSessionTermination),
|
||||
SessionClosed(CloseReason),
|
||||
}
|
||||
|
||||
/// Why a session returned to Idle.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum CloseReason {
|
||||
SetupRejected(SetupStatus),
|
||||
SbpRejected(SbpStatus),
|
||||
Terminated(TerminationReason),
|
||||
/// Terminating-state quiescence completed (no peer echo required).
|
||||
Completed,
|
||||
}
|
||||
|
||||
/// Static configuration for a sensing session.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SessionConfig {
|
||||
/// Spec profile this endpoint advertises/accepts.
|
||||
pub profile: SpecProfile,
|
||||
/// Capability set used to evaluate inbound setups.
|
||||
pub capabilities: SensingCapabilities,
|
||||
/// Consecutive negotiation timeouts before aborting to Idle.
|
||||
pub max_setup_timeouts: u8,
|
||||
/// Consecutive missed instances (Active timeouts) before terminating.
|
||||
pub max_missed_instances: u8,
|
||||
}
|
||||
|
||||
impl Default for SessionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
profile: SpecProfile::Ieee80211Bf2025,
|
||||
capabilities: SensingCapabilities::sim_full(),
|
||||
max_setup_timeouts: 3,
|
||||
max_missed_instances: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
//! Procedure message types for the 802.11bf sensing model: measurement
|
||||
//! setup request/response, measurement instance, CSI-variant measurement
|
||||
//! report, sensing-by-proxy (SBP) exchange, session termination, and the
|
||||
//! minimal DMG (>45 GHz) stubs. Negotiation-core types (identifiers,
|
||||
//! parameters, capabilities, statuses) live in [`super::types`].
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::types::{
|
||||
BfError, MeasurementInstanceId, MeasurementSetupId, MeasurementSetupParams, SetupStatus,
|
||||
SpecProfile, MAX_REPORT_SUBCARRIERS,
|
||||
};
|
||||
|
||||
/// Sensing measurement setup request (initiator → responder).
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SensingMeasurementSetupRequest {
|
||||
/// Version gate for the negotiated surface.
|
||||
pub profile: SpecProfile,
|
||||
pub setup_id: MeasurementSetupId,
|
||||
pub params: MeasurementSetupParams,
|
||||
}
|
||||
|
||||
impl SensingMeasurementSetupRequest {
|
||||
pub fn validate(&self) -> Result<(), BfError> {
|
||||
self.params.validate()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sensing measurement setup response (responder → initiator).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SensingMeasurementSetupResponse {
|
||||
pub setup_id: MeasurementSetupId,
|
||||
pub status: SetupStatus,
|
||||
}
|
||||
|
||||
/// One scheduled sensing measurement instance within an active setup.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SensingMeasurementInstance {
|
||||
pub setup_id: MeasurementSetupId,
|
||||
pub instance_id: MeasurementInstanceId,
|
||||
/// Deterministic schedule offset of this instance (µs since setup
|
||||
/// activation; synthesized from the negotiated periodicity).
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
/// CSI-variant sensing measurement report payload (amplitude/phase per
|
||||
/// usable subcarrier, averaged over the measurement instance).
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CsiReportPayload {
|
||||
pub n_subcarriers: u16,
|
||||
pub amplitudes: Vec<f32>,
|
||||
pub phases: Vec<f32>,
|
||||
}
|
||||
|
||||
impl CsiReportPayload {
|
||||
/// Boundary validation: shape coherence and value sanity. Rejects NaN,
|
||||
/// infinities, and negative amplitudes from adversarial peers.
|
||||
pub fn validate(&self) -> Result<(), BfError> {
|
||||
if self.n_subcarriers == 0 {
|
||||
return Err(BfError::EmptyPayload);
|
||||
}
|
||||
if self.n_subcarriers > MAX_REPORT_SUBCARRIERS {
|
||||
return Err(BfError::PayloadTooLarge {
|
||||
count: self.n_subcarriers,
|
||||
});
|
||||
}
|
||||
let declared = self.n_subcarriers as usize;
|
||||
if self.amplitudes.len() != declared || self.phases.len() != declared {
|
||||
return Err(BfError::PayloadLengthMismatch {
|
||||
declared,
|
||||
amplitudes: self.amplitudes.len(),
|
||||
phases: self.phases.len(),
|
||||
});
|
||||
}
|
||||
for (index, a) in self.amplitudes.iter().enumerate() {
|
||||
if !a.is_finite() || *a < 0.0 {
|
||||
return Err(BfError::PayloadValueInvalid { index });
|
||||
}
|
||||
}
|
||||
for (index, p) in self.phases.iter().enumerate() {
|
||||
if !p.is_finite() {
|
||||
return Err(BfError::PayloadValueInvalid { index });
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mean amplitude across subcarriers (threshold-trigger metric).
|
||||
pub fn mean_amplitude(&self) -> f64 {
|
||||
if self.amplitudes.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
self.amplitudes.iter().map(|a| *a as f64).sum::<f64>() / self.amplitudes.len() as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Sensing measurement report (sensing receiver → initiator).
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SensingMeasurementReport {
|
||||
pub setup_id: MeasurementSetupId,
|
||||
pub instance_id: MeasurementInstanceId,
|
||||
pub payload: CsiReportPayload,
|
||||
}
|
||||
|
||||
impl SensingMeasurementReport {
|
||||
pub fn validate(&self) -> Result<(), BfError> {
|
||||
self.payload.validate()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sensing-by-Proxy (SBP) request: a non-AP STA asks an AP to act as sensing
|
||||
/// initiator on its behalf and forward the resulting reports.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SbpRequest {
|
||||
pub profile: SpecProfile,
|
||||
/// Setup ID the proxy uses for the sensing it conducts on our behalf.
|
||||
pub proxy_setup_id: MeasurementSetupId,
|
||||
pub params: MeasurementSetupParams,
|
||||
}
|
||||
|
||||
impl SbpRequest {
|
||||
pub fn validate(&self) -> Result<(), BfError> {
|
||||
self.params.validate()
|
||||
}
|
||||
}
|
||||
|
||||
/// Status carried by an SBP response.
|
||||
///
|
||||
/// Mirrors [`SetupStatus`] 1:1 (see the `From<SetupStatus>` impl): an SBP
|
||||
/// request is validated through the same chain as a direct setup, so every
|
||||
/// rejection class must survive the proxy translation.
|
||||
/// `RejectedNotSupported` additionally covers a proxy that lacks the SBP
|
||||
/// capability itself.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SbpStatus {
|
||||
Accepted,
|
||||
RejectedNotSupported,
|
||||
RejectedUnsupportedParams,
|
||||
RejectedSetupIdCollision,
|
||||
RejectedIncompatibleProfile,
|
||||
RejectedByPolicy,
|
||||
RejectedCapacity,
|
||||
}
|
||||
|
||||
impl From<SetupStatus> for SbpStatus {
|
||||
/// 1:1 mapping from the direct-setup status space, keeping the SBP path
|
||||
/// on the single `evaluate_setup` validation chain (no SBP-only policy
|
||||
/// drift or bypass).
|
||||
fn from(status: SetupStatus) -> Self {
|
||||
match status {
|
||||
SetupStatus::Accepted => SbpStatus::Accepted,
|
||||
SetupStatus::RejectedNotSupported => SbpStatus::RejectedNotSupported,
|
||||
SetupStatus::RejectedUnsupportedParams => SbpStatus::RejectedUnsupportedParams,
|
||||
SetupStatus::RejectedSetupIdCollision => SbpStatus::RejectedSetupIdCollision,
|
||||
SetupStatus::RejectedIncompatibleProfile => SbpStatus::RejectedIncompatibleProfile,
|
||||
SetupStatus::RejectedByPolicy => SbpStatus::RejectedByPolicy,
|
||||
SetupStatus::RejectedCapacity => SbpStatus::RejectedCapacity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sensing-by-Proxy (SBP) response (proxy AP → requesting STA).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SbpResponse {
|
||||
pub proxy_setup_id: MeasurementSetupId,
|
||||
pub status: SbpStatus,
|
||||
}
|
||||
|
||||
/// Reason carried by a sensing session termination.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum TerminationReason {
|
||||
InitiatorRequested,
|
||||
ResponderRequested,
|
||||
Timeout,
|
||||
PolicyChange,
|
||||
}
|
||||
|
||||
/// Sensing measurement setup termination (either side may send).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SensingSessionTermination {
|
||||
pub setup_id: MeasurementSetupId,
|
||||
pub reason: TerminationReason,
|
||||
}
|
||||
|
||||
/// Minimal stub for DMG/EDMG (>45 GHz) sensing types. The standard also
|
||||
/// covers directional multi-gigabit sensing; this model does not elaborate
|
||||
/// it beyond a typed placeholder (ADR-153 scope: sub-7 GHz focus).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DmgSensingType {
|
||||
Monostatic,
|
||||
Bistatic,
|
||||
Multistatic,
|
||||
}
|
||||
|
||||
/// Placeholder for a future DMG sensing setup surface.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct DmgSensingSetupStub {
|
||||
pub setup_id: MeasurementSetupId,
|
||||
pub sensing_type: DmgSensingType,
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
//! IEEE 802.11bf-2025 WLAN sensing — forward-compatibility protocol model
|
||||
//! (ADR-153, amending ADR-152 §2.4).
|
||||
//!
|
||||
//! # Why this exists
|
||||
//!
|
||||
//! IEEE 802.11bf-2025 ("WLAN Sensing") was **published 2025-09-26** (verified
|
||||
//! against the IEEE SA record — ADR-152 §1.1 F4, evidence grade MEASURED).
|
||||
//! Sensing standardization is complete for sub-7 GHz and >45 GHz (DMG) bands,
|
||||
//! with formal sensing measurement setup, measurement instance,
|
||||
//! feedback/reporting, and sensing-by-proxy (SBP) procedures.
|
||||
//!
|
||||
//! **No commodity silicon — ESP32 parts included — implements the standard
|
||||
//! yet.** ADR-152 §2.4 originally decided "track silicon; no code now";
|
||||
//! ADR-153 amends that clause: build the typed protocol surface now, so
|
||||
//! RuView can adopt standardized sensing the day any chipset exposes it.
|
||||
//! This layer is simulation-tested forward compatibility — the OTA binding
|
||||
//! lands when silicon does. Today's opportunistic CSI extraction (ADR-018 /
|
||||
//! ADR-028) remains the backend, mapped onto the standardized report path by
|
||||
//! [`transport::OpportunisticCsiBridge`].
|
||||
//!
|
||||
//! > This module is not a certified 802.11bf implementation. It models the
|
||||
//! > public procedure shape needed by RuView and RuvSense, while intentionally
|
||||
//! > avoiding OTA frame binding until chipset support and vendor APIs exist.
|
||||
//!
|
||||
//! # Layout
|
||||
//!
|
||||
//! - [`types`] — typed structures for the sensing procedures (setup, roles,
|
||||
//! measurement instances, CSI-variant reports, SBP, termination), plus the
|
||||
//! ADR-153 future-proofing surfaces: [`types::SpecProfile`] version gates,
|
||||
//! [`types::SensingCapabilities`] negotiation, and required
|
||||
//! [`types::ConsentMode`] governance metadata on every setup.
|
||||
//! - [`messages`] — the procedure message types (setup request/response,
|
||||
//! measurement instance, CSI-variant report, SBP exchange, termination).
|
||||
//! - [`session`] — deterministic event-driven session FSM:
|
||||
//! `Idle → SetupNegotiating → Active → Terminating → Idle`, with explicit
|
||||
//! rejection paths, timeout handling, single-role enforcement, and the
|
||||
//! first-class SBP proxy mode. No async, no clocks.
|
||||
//! - [`events`] — the FSM I/O types ([`events::SessionEvent`],
|
||||
//! [`events::Action`], close reasons, configuration), re-exported via
|
||||
//! [`session`].
|
||||
//! - [`table`] — responder-side setup registry (setup-ID collision and
|
||||
//! capacity rejection paths, for direct setups and SBP alike).
|
||||
//! - [`transport`] — the [`transport::SensingTransport`] seam, the
|
||||
//! [`transport::SimTransport`] test double, and the ESP32 bridge.
|
||||
|
||||
pub mod events;
|
||||
pub mod messages;
|
||||
pub mod session;
|
||||
pub mod table;
|
||||
pub mod transport;
|
||||
pub mod types;
|
||||
|
||||
pub use messages::{
|
||||
CsiReportPayload, DmgSensingSetupStub, DmgSensingType, SbpRequest, SbpResponse, SbpStatus,
|
||||
SensingMeasurementInstance, SensingMeasurementReport, SensingMeasurementSetupRequest,
|
||||
SensingMeasurementSetupResponse, SensingSessionTermination, TerminationReason,
|
||||
};
|
||||
pub use session::{Action, CloseReason, SensingSession, SessionConfig, SessionEvent, SessionState};
|
||||
pub use table::SessionTable;
|
||||
pub use transport::{
|
||||
action_to_frame, frame_to_event, OpportunisticCsiBridge, SensingFrame, SensingTransport,
|
||||
SimTransport, TransportError,
|
||||
};
|
||||
pub use types::{
|
||||
bandwidth_mhz, BfError, ConsentMode, MeasurementInstanceId, MeasurementSetupId,
|
||||
MeasurementSetupParams, ReportingConfig, SensingCapabilities, SensingRole, SetupStatus,
|
||||
SpecProfile, ThresholdParams, TransceiverRole, MAX_BURST_INSTANCES, MAX_PERIOD_MS,
|
||||
MAX_REPORT_SUBCARRIERS, MAX_SETUP_ID, MIN_PERIOD_MS,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
#[cfg(test)]
|
||||
mod tests_fsm;
|
||||
#[cfg(test)]
|
||||
mod tests_sbp;
|
||||
#[cfg(test)]
|
||||
mod testutil;
|
||||
@@ -0,0 +1,499 @@
|
||||
//! Sensing session state machine for the 802.11bf forward-compatibility model.
|
||||
//!
|
||||
//! Deterministic, event-driven, no async, no clocks: callers inject
|
||||
//! [`SessionEvent`]s (including `Timeout` ticks) and act on the returned
|
||||
//! [`Action`]s. State flow (ADR-153):
|
||||
//!
|
||||
//! ```text
|
||||
//! Idle → SetupNegotiating → Active → Terminating → Idle
|
||||
//! ```
|
||||
//!
|
||||
//! Rejection paths: unsupported parameters / incompatible profile / policy
|
||||
//! (responder responds with a rejected setup status), setup-ID collision
|
||||
//! ([`super::table::SessionTable`]), and negotiation timeout (typed
|
||||
//! [`BfError::NegotiationTimeout`] + reset to Idle).
|
||||
//!
|
||||
//! **Single-role design:** a session is constructed as initiator or responder
|
||||
//! and keeps that role for its whole lifetime. An initiator-role session
|
||||
//! receiving a peer's setup or SBP request answers `RejectedNotSupported`
|
||||
//! instead of accepting — a peer must never be able to hijack a session out
|
||||
//! of its configured role. Endpoints that play both roles run one session per
|
||||
//! role (or a [`super::table::SessionTable`] for the responder side).
|
||||
//!
|
||||
//! **SBP proxy mode:** a responder session that accepts an SBP request
|
||||
//! becomes a first-class proxy ([`SensingSession::is_sbp_proxy`]): it drives
|
||||
//! the standard initiator path toward the actual sensing responder —
|
||||
//! including re-triggering measurement instances on
|
||||
//! [`SessionEvent::InstanceElapsed`] — and relays every received report to
|
||||
//! the SBP client via [`Action::RelaySbpReport`], in addition to local
|
||||
//! [`Action::DeliverReport`] delivery.
|
||||
//!
|
||||
//! Local `Start*` commands issued outside Idle are caller bugs and surface
|
||||
//! as typed [`BfError::InvalidStateForCommand`]; genuinely ignorable stray
|
||||
//! frames/ticks remain silent no-ops. The FSM I/O types live in
|
||||
//! [`super::events`] and are re-exported here.
|
||||
|
||||
use super::messages::{
|
||||
SbpRequest, SbpResponse, SbpStatus, SensingMeasurementInstance, SensingMeasurementReport,
|
||||
SensingMeasurementSetupRequest, SensingMeasurementSetupResponse, SensingSessionTermination,
|
||||
TerminationReason,
|
||||
};
|
||||
use super::types::{
|
||||
BfError, MeasurementInstanceId, MeasurementSetupId, MeasurementSetupParams, ReportingConfig,
|
||||
SensingRole, SetupStatus,
|
||||
};
|
||||
|
||||
pub use super::events::{Action, CloseReason, SessionConfig, SessionEvent, SessionState};
|
||||
|
||||
/// One sensing session (one measurement setup) on one endpoint.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SensingSession {
|
||||
role: SensingRole,
|
||||
state: SessionState,
|
||||
config: SessionConfig,
|
||||
/// Last setup request we sent (for negotiation re-sends).
|
||||
pending_request: Option<SensingMeasurementSetupRequest>,
|
||||
/// Negotiated (or in-negotiation) setup.
|
||||
setup: Option<(MeasurementSetupId, MeasurementSetupParams)>,
|
||||
/// True when this session awaits proxied sensing (SBP client).
|
||||
sbp_client: bool,
|
||||
/// True when this responder-role session proxies sensing for an SBP
|
||||
/// client: it drives the initiator path toward the sensing responder
|
||||
/// and relays received reports back to the client.
|
||||
sbp_proxy: bool,
|
||||
setup_timeouts: u8,
|
||||
missed_instances: u8,
|
||||
instance_counter: u32,
|
||||
/// Mean amplitude of the last *reported* measurement (threshold trigger).
|
||||
last_reported_mean: Option<f64>,
|
||||
}
|
||||
|
||||
impl SensingSession {
|
||||
pub fn new_initiator(config: SessionConfig) -> Self {
|
||||
Self::new(SensingRole::Initiator, config)
|
||||
}
|
||||
|
||||
pub fn new_responder(config: SessionConfig) -> Self {
|
||||
Self::new(SensingRole::Responder, config)
|
||||
}
|
||||
|
||||
fn new(role: SensingRole, config: SessionConfig) -> Self {
|
||||
Self {
|
||||
role,
|
||||
state: SessionState::Idle,
|
||||
config,
|
||||
pending_request: None,
|
||||
setup: None,
|
||||
sbp_client: false,
|
||||
sbp_proxy: false,
|
||||
setup_timeouts: 0,
|
||||
missed_instances: 0,
|
||||
instance_counter: 0,
|
||||
last_reported_mean: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state(&self) -> SessionState {
|
||||
self.state
|
||||
}
|
||||
|
||||
pub fn role(&self) -> SensingRole {
|
||||
self.role
|
||||
}
|
||||
|
||||
/// True when this session is acting as an SBP proxy (accepted via
|
||||
/// [`SessionEvent::SbpRequestReceived`]); cleared on reset to Idle.
|
||||
pub fn is_sbp_proxy(&self) -> bool {
|
||||
self.sbp_proxy
|
||||
}
|
||||
|
||||
pub fn setup_id(&self) -> Option<MeasurementSetupId> {
|
||||
self.setup.as_ref().map(|(id, _)| *id)
|
||||
}
|
||||
|
||||
/// Drive the FSM with one event. Protocol-level rejections surface as
|
||||
/// `Ok` actions (responses to the peer); malformed/adversarial input,
|
||||
/// out-of-state local commands, and negotiation timeout surface as typed
|
||||
/// `Err` (never a panic).
|
||||
pub fn handle(&mut self, event: SessionEvent) -> Result<Vec<Action>, BfError> {
|
||||
match self.state {
|
||||
SessionState::Idle => self.handle_idle(event),
|
||||
SessionState::SetupNegotiating => self.handle_negotiating(event),
|
||||
SessionState::Active => self.handle_active(event),
|
||||
SessionState::Terminating => self.handle_terminating(event),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_idle(&mut self, event: SessionEvent) -> Result<Vec<Action>, BfError> {
|
||||
match event {
|
||||
SessionEvent::StartSetup(req) => {
|
||||
if self.role != SensingRole::Initiator {
|
||||
return Err(BfError::InvalidStateForCommand {
|
||||
state: "Idle (responder cannot StartSetup)",
|
||||
});
|
||||
}
|
||||
req.validate()?;
|
||||
self.setup = Some((req.setup_id, req.params.clone()));
|
||||
self.pending_request = Some(req.clone());
|
||||
self.setup_timeouts = 0;
|
||||
self.state = SessionState::SetupNegotiating;
|
||||
Ok(vec![Action::SendSetupRequest(req)])
|
||||
}
|
||||
SessionEvent::StartSbp(sbp) => {
|
||||
if self.role != SensingRole::Initiator {
|
||||
return Err(BfError::InvalidStateForCommand {
|
||||
state: "Idle (responder cannot StartSbp)",
|
||||
});
|
||||
}
|
||||
sbp.validate()?;
|
||||
self.setup = Some((sbp.proxy_setup_id, sbp.params.clone()));
|
||||
self.sbp_client = true;
|
||||
self.setup_timeouts = 0;
|
||||
self.state = SessionState::SetupNegotiating;
|
||||
Ok(vec![Action::SendSbpRequest(sbp)])
|
||||
}
|
||||
SessionEvent::SetupRequestReceived(req) => {
|
||||
let response = |status| {
|
||||
Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
setup_id: req.setup_id,
|
||||
status,
|
||||
})
|
||||
};
|
||||
// Single-role design (module docs): an initiator-role
|
||||
// session never accepts a peer's setup request — accepting
|
||||
// here would let a peer hijack the session into the
|
||||
// responder path.
|
||||
if self.role != SensingRole::Responder {
|
||||
return Ok(vec![response(SetupStatus::RejectedNotSupported)]);
|
||||
}
|
||||
match self.evaluate_setup(&req) {
|
||||
SetupStatus::Accepted => {
|
||||
self.setup = Some((req.setup_id, req.params.clone()));
|
||||
self.missed_instances = 0;
|
||||
self.last_reported_mean = None;
|
||||
self.state = SessionState::Active;
|
||||
Ok(vec![response(SetupStatus::Accepted)])
|
||||
}
|
||||
status => Ok(vec![response(status)]),
|
||||
}
|
||||
}
|
||||
SessionEvent::SbpRequestReceived(sbp) => {
|
||||
// Single-role design: only responder-role sessions proxy.
|
||||
if self.role != SensingRole::Responder {
|
||||
return Ok(vec![Action::SendSbpResponse(SbpResponse {
|
||||
proxy_setup_id: sbp.proxy_setup_id,
|
||||
status: SbpStatus::RejectedNotSupported,
|
||||
})]);
|
||||
}
|
||||
Ok(self.handle_sbp_request(sbp))
|
||||
}
|
||||
// Stray frames/ticks in Idle are ignored, not errors.
|
||||
_ => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
/// SBP proxy path: accept the request, then run the *standard initiator
|
||||
/// path* toward the actual sensing responder. No direct sensor coupling —
|
||||
/// the proxied setup is an ordinary `SendSetupRequest` on the transport.
|
||||
///
|
||||
/// Validation is the single [`Self::evaluate_setup`] chain: the proxied
|
||||
/// setup request is built first and evaluated exactly as a direct setup
|
||||
/// would be, with the resulting [`SetupStatus`] mapped 1:1 onto
|
||||
/// [`SbpStatus`] — no SBP-only re-implementation that could drift from
|
||||
/// (or bypass) the setup policy.
|
||||
fn handle_sbp_request(&mut self, sbp: SbpRequest) -> Vec<Action> {
|
||||
let respond = |status| {
|
||||
Action::SendSbpResponse(SbpResponse {
|
||||
proxy_setup_id: sbp.proxy_setup_id,
|
||||
status,
|
||||
})
|
||||
};
|
||||
// SBP-specific capability gate; everything else is the setup chain.
|
||||
if !self.config.capabilities.sensing_by_proxy {
|
||||
return vec![respond(SbpStatus::RejectedNotSupported)];
|
||||
}
|
||||
let req = SensingMeasurementSetupRequest {
|
||||
profile: sbp.profile.clone(),
|
||||
setup_id: sbp.proxy_setup_id,
|
||||
params: sbp.params.clone(),
|
||||
};
|
||||
match self.evaluate_setup(&req) {
|
||||
SetupStatus::Accepted => {}
|
||||
status => return vec![respond(SbpStatus::from(status))],
|
||||
}
|
||||
self.setup = Some((req.setup_id, req.params.clone()));
|
||||
self.pending_request = Some(req.clone());
|
||||
self.sbp_proxy = true;
|
||||
self.setup_timeouts = 0;
|
||||
self.state = SessionState::SetupNegotiating;
|
||||
vec![respond(SbpStatus::Accepted), Action::SendSetupRequest(req)]
|
||||
}
|
||||
|
||||
fn evaluate_setup(&self, req: &SensingMeasurementSetupRequest) -> SetupStatus {
|
||||
if !self.config.profile.accepts(&req.profile) {
|
||||
return SetupStatus::RejectedIncompatibleProfile;
|
||||
}
|
||||
match req.validate() {
|
||||
Err(BfError::SensingDisabledByPolicy) => return SetupStatus::RejectedByPolicy,
|
||||
Err(_) => return SetupStatus::RejectedUnsupportedParams,
|
||||
Ok(()) => {}
|
||||
}
|
||||
match self.config.capabilities.evaluate(&req.params) {
|
||||
Err(status) => status,
|
||||
Ok(()) => SetupStatus::Accepted,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_negotiating(&mut self, event: SessionEvent) -> Result<Vec<Action>, BfError> {
|
||||
match event {
|
||||
SessionEvent::SetupResponseReceived(resp) => {
|
||||
let expected = match self.setup_id() {
|
||||
Some(id) => id,
|
||||
None => return Ok(vec![]),
|
||||
};
|
||||
if resp.setup_id != expected {
|
||||
return Err(BfError::SetupIdMismatch {
|
||||
expected: expected.value(),
|
||||
got: resp.setup_id.value(),
|
||||
});
|
||||
}
|
||||
match resp.status {
|
||||
SetupStatus::Accepted => {
|
||||
self.setup_timeouts = 0;
|
||||
self.missed_instances = 0;
|
||||
self.state = SessionState::Active;
|
||||
match self.next_instance_record() {
|
||||
Some(instance) => Ok(vec![Action::TriggerInstance(instance)]),
|
||||
None => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
status => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::SetupRejected(
|
||||
status,
|
||||
))])
|
||||
}
|
||||
}
|
||||
}
|
||||
SessionEvent::SbpResponseReceived(resp) if self.sbp_client => {
|
||||
let expected = match self.setup_id() {
|
||||
Some(id) => id,
|
||||
None => return Ok(vec![]),
|
||||
};
|
||||
if resp.proxy_setup_id != expected {
|
||||
return Err(BfError::SetupIdMismatch {
|
||||
expected: expected.value(),
|
||||
got: resp.proxy_setup_id.value(),
|
||||
});
|
||||
}
|
||||
match resp.status {
|
||||
SbpStatus::Accepted => {
|
||||
// Proxied reports will arrive via ReportReceived.
|
||||
self.setup_timeouts = 0;
|
||||
self.state = SessionState::Active;
|
||||
Ok(vec![])
|
||||
}
|
||||
status => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::SbpRejected(
|
||||
status,
|
||||
))])
|
||||
}
|
||||
}
|
||||
}
|
||||
SessionEvent::Timeout => {
|
||||
self.setup_timeouts = self.setup_timeouts.saturating_add(1);
|
||||
if self.setup_timeouts >= self.config.max_setup_timeouts {
|
||||
let setup_id = self.setup_id().map(|id| id.value()).unwrap_or(0);
|
||||
let attempts = self.setup_timeouts;
|
||||
self.reset();
|
||||
Err(BfError::NegotiationTimeout { setup_id, attempts })
|
||||
} else if let Some(req) = &self.pending_request {
|
||||
Ok(vec![Action::SendSetupRequest(req.clone())])
|
||||
} else {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
SessionEvent::Terminate(reason) => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::Terminated(reason))])
|
||||
}
|
||||
SessionEvent::TerminationReceived(term) => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::Terminated(
|
||||
term.reason,
|
||||
))])
|
||||
}
|
||||
// Local Start* outside Idle is a caller bug — typed error.
|
||||
SessionEvent::StartSetup(_) | SessionEvent::StartSbp(_) => {
|
||||
Err(BfError::InvalidStateForCommand {
|
||||
state: "SetupNegotiating",
|
||||
})
|
||||
}
|
||||
// Genuinely ignorable stray frames/ticks are no-ops.
|
||||
_ => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_active(&mut self, event: SessionEvent) -> Result<Vec<Action>, BfError> {
|
||||
match event {
|
||||
SessionEvent::InstanceElapsed => {
|
||||
// The measurement-driving endpoint re-triggers here: the
|
||||
// initiator, or an SBP proxy running the initiator path
|
||||
// toward the sensing responder. SBP *clients* only consume
|
||||
// proxied reports and never trigger instances.
|
||||
let drives_instances =
|
||||
(self.role == SensingRole::Initiator || self.sbp_proxy) && !self.sbp_client;
|
||||
if drives_instances {
|
||||
match self.next_instance_record() {
|
||||
Some(instance) => Ok(vec![Action::TriggerInstance(instance)]),
|
||||
None => Ok(vec![]),
|
||||
}
|
||||
} else {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
SessionEvent::MeasurementCaptured {
|
||||
instance_id,
|
||||
payload,
|
||||
} => {
|
||||
payload.validate()?;
|
||||
let (setup_id, params) = match &self.setup {
|
||||
Some((id, p)) => (*id, p.clone()),
|
||||
None => return Ok(vec![]),
|
||||
};
|
||||
// A successful capture means this instance was not missed —
|
||||
// the missed-instance budget counts *consecutive* misses,
|
||||
// so it resets here even when threshold-based reporting
|
||||
// suppresses the report below.
|
||||
self.missed_instances = 0;
|
||||
let mean = payload.mean_amplitude();
|
||||
let should_report = match params.reporting {
|
||||
ReportingConfig::EveryInstance => true,
|
||||
ReportingConfig::ThresholdBased(threshold) => match self.last_reported_mean {
|
||||
None => true,
|
||||
Some(previous) => threshold.exceeds(previous, mean),
|
||||
},
|
||||
};
|
||||
if !should_report {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
self.last_reported_mean = Some(mean);
|
||||
Ok(vec![Action::SendReport(SensingMeasurementReport {
|
||||
setup_id,
|
||||
instance_id,
|
||||
payload,
|
||||
})])
|
||||
}
|
||||
SessionEvent::ReportReceived(report) => {
|
||||
report.validate()?;
|
||||
let expected = match self.setup_id() {
|
||||
Some(id) => id,
|
||||
None => return Ok(vec![]),
|
||||
};
|
||||
if report.setup_id != expected {
|
||||
return Err(BfError::SetupIdMismatch {
|
||||
expected: expected.value(),
|
||||
got: report.setup_id.value(),
|
||||
});
|
||||
}
|
||||
self.missed_instances = 0;
|
||||
if self.sbp_proxy {
|
||||
// Proxy mode: deliver to the local consumer *and* relay
|
||||
// toward the SBP client on the transport.
|
||||
Ok(vec![
|
||||
Action::DeliverReport(report.clone()),
|
||||
Action::RelaySbpReport(report),
|
||||
])
|
||||
} else {
|
||||
Ok(vec![Action::DeliverReport(report)])
|
||||
}
|
||||
}
|
||||
SessionEvent::Timeout => {
|
||||
self.missed_instances = self.missed_instances.saturating_add(1);
|
||||
if self.missed_instances >= self.config.max_missed_instances {
|
||||
self.state = SessionState::Terminating;
|
||||
Ok(self.termination_actions(TerminationReason::Timeout))
|
||||
} else {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
SessionEvent::Terminate(reason) => {
|
||||
self.state = SessionState::Terminating;
|
||||
Ok(self.termination_actions(reason))
|
||||
}
|
||||
SessionEvent::TerminationReceived(term) => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::Terminated(
|
||||
term.reason,
|
||||
))])
|
||||
}
|
||||
// Local Start* outside Idle is a caller bug — typed error.
|
||||
SessionEvent::StartSetup(_) | SessionEvent::StartSbp(_) => {
|
||||
Err(BfError::InvalidStateForCommand { state: "Active" })
|
||||
}
|
||||
// Genuinely ignorable stray frames (duplicate setup/SBP traffic)
|
||||
// are no-ops.
|
||||
_ => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_terminating(&mut self, event: SessionEvent) -> Result<Vec<Action>, BfError> {
|
||||
match event {
|
||||
SessionEvent::TerminationReceived(term) => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::Terminated(
|
||||
term.reason,
|
||||
))])
|
||||
}
|
||||
// No peer echo is required: a quiescence tick completes teardown.
|
||||
SessionEvent::Timeout => {
|
||||
self.reset();
|
||||
Ok(vec![Action::SessionClosed(CloseReason::Completed)])
|
||||
}
|
||||
// Local Start* outside Idle is a caller bug — typed error.
|
||||
SessionEvent::StartSetup(_) | SessionEvent::StartSbp(_) => {
|
||||
Err(BfError::InvalidStateForCommand {
|
||||
state: "Terminating",
|
||||
})
|
||||
}
|
||||
_ => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
fn termination_actions(&self, reason: TerminationReason) -> Vec<Action> {
|
||||
match self.setup_id() {
|
||||
Some(setup_id) => vec![Action::SendTermination(SensingSessionTermination {
|
||||
setup_id,
|
||||
reason,
|
||||
})],
|
||||
None => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn next_instance_record(&mut self) -> Option<SensingMeasurementInstance> {
|
||||
let (setup_id, params) = match &self.setup {
|
||||
Some((id, p)) => (*id, p.clone()),
|
||||
None => return None,
|
||||
};
|
||||
let n = self.instance_counter;
|
||||
self.instance_counter = self.instance_counter.wrapping_add(1);
|
||||
Some(SensingMeasurementInstance {
|
||||
setup_id,
|
||||
instance_id: MeasurementInstanceId::new((n % 256) as u8),
|
||||
timestamp_us: u64::from(n) * u64::from(params.period_ms) * 1_000,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.state = SessionState::Idle;
|
||||
self.pending_request = None;
|
||||
self.setup = None;
|
||||
self.sbp_client = false;
|
||||
self.sbp_proxy = false;
|
||||
self.setup_timeouts = 0;
|
||||
self.missed_instances = 0;
|
||||
self.instance_counter = 0;
|
||||
self.last_reported_mean = None;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
//! Responder-side setup registry for the 802.11bf sensing model — enforces
|
||||
//! the setup-ID-collision and capacity rejection paths a single session
|
||||
//! cannot see on its own (ADR-153 acceptance: duplicate setup ID rejected).
|
||||
//! Both entry points — direct setups ([`SessionTable::handle_setup_request`])
|
||||
//! and sensing-by-proxy ([`SessionTable::handle_sbp_request`]) — share the
|
||||
//! same guards and the same per-setup session storage.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use super::messages::{
|
||||
SbpRequest, SbpResponse, SbpStatus, SensingMeasurementSetupRequest,
|
||||
SensingMeasurementSetupResponse,
|
||||
};
|
||||
use super::session::{Action, SensingSession, SessionConfig, SessionEvent, SessionState};
|
||||
use super::types::{BfError, MeasurementSetupId, SetupStatus};
|
||||
|
||||
/// Responder-side registry of sensing sessions keyed by setup ID.
|
||||
///
|
||||
/// Enforces the setup-ID-collision and capacity rejection paths the single
|
||||
/// session cannot see on its own.
|
||||
#[derive(Debug)]
|
||||
pub struct SessionTable {
|
||||
config: SessionConfig,
|
||||
sessions: BTreeMap<u8, SensingSession>,
|
||||
/// Events dropped because no session owned the setup ID (see
|
||||
/// [`Self::handle_for`]).
|
||||
unknown_setup_drops: u64,
|
||||
}
|
||||
|
||||
impl SessionTable {
|
||||
pub fn new(config: SessionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
sessions: BTreeMap::new(),
|
||||
unknown_setup_drops: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of setups not in Idle.
|
||||
pub fn active_setups(&self) -> usize {
|
||||
self.sessions
|
||||
.values()
|
||||
.filter(|s| s.state() != SessionState::Idle)
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn session(&self, setup_id: MeasurementSetupId) -> Option<&SensingSession> {
|
||||
self.sessions.get(&setup_id.value())
|
||||
}
|
||||
|
||||
/// Count of events dropped by [`Self::handle_for`] because the setup ID
|
||||
/// was unknown — lets an AP spot peers addressing setups it never
|
||||
/// accepted without turning stray frames into errors.
|
||||
pub fn unknown_setup_drops(&self) -> u64 {
|
||||
self.unknown_setup_drops
|
||||
}
|
||||
|
||||
/// Route an inbound setup request, rejecting setup-ID collisions and
|
||||
/// capacity overruns before delegating to a responder session.
|
||||
pub fn handle_setup_request(
|
||||
&mut self,
|
||||
req: SensingMeasurementSetupRequest,
|
||||
) -> Result<Vec<Action>, BfError> {
|
||||
let reject = |setup_id, status| {
|
||||
Ok(vec![Action::SendSetupResponse(
|
||||
SensingMeasurementSetupResponse { setup_id, status },
|
||||
)])
|
||||
};
|
||||
if self.is_collision(req.setup_id) {
|
||||
return reject(req.setup_id, SetupStatus::RejectedSetupIdCollision);
|
||||
}
|
||||
if self.at_capacity() {
|
||||
return reject(req.setup_id, SetupStatus::RejectedCapacity);
|
||||
}
|
||||
let key = req.setup_id.value();
|
||||
let mut session = SensingSession::new_responder(self.config.clone());
|
||||
let actions = session.handle(SessionEvent::SetupRequestReceived(req))?;
|
||||
self.sessions.insert(key, session);
|
||||
Ok(actions)
|
||||
}
|
||||
|
||||
/// Route an inbound SBP request, rejecting proxy-setup-ID collisions and
|
||||
/// capacity overruns before delegating to a (new) proxy session — the
|
||||
/// SBP mirror of [`Self::handle_setup_request`], so a table-driven AP
|
||||
/// accepts SBP end-to-end instead of silently dropping it.
|
||||
pub fn handle_sbp_request(&mut self, sbp: SbpRequest) -> Result<Vec<Action>, BfError> {
|
||||
let reject = |proxy_setup_id, status| {
|
||||
Ok(vec![Action::SendSbpResponse(SbpResponse {
|
||||
proxy_setup_id,
|
||||
status,
|
||||
})])
|
||||
};
|
||||
if self.is_collision(sbp.proxy_setup_id) {
|
||||
return reject(sbp.proxy_setup_id, SbpStatus::RejectedSetupIdCollision);
|
||||
}
|
||||
if self.at_capacity() {
|
||||
return reject(sbp.proxy_setup_id, SbpStatus::RejectedCapacity);
|
||||
}
|
||||
let key = sbp.proxy_setup_id.value();
|
||||
let mut session = SensingSession::new_responder(self.config.clone());
|
||||
let actions = session.handle(SessionEvent::SbpRequestReceived(sbp))?;
|
||||
self.sessions.insert(key, session);
|
||||
Ok(actions)
|
||||
}
|
||||
|
||||
/// Route any other event to the session owning `setup_id`.
|
||||
///
|
||||
/// Frames addressing an unknown setup are dropped *by design* (stray
|
||||
/// frames are ignored, not errors), but the drop is observable through
|
||||
/// [`Self::unknown_setup_drops`].
|
||||
pub fn handle_for(
|
||||
&mut self,
|
||||
setup_id: MeasurementSetupId,
|
||||
event: SessionEvent,
|
||||
) -> Result<Vec<Action>, BfError> {
|
||||
match self.sessions.get_mut(&setup_id.value()) {
|
||||
Some(session) => session.handle(event),
|
||||
None => {
|
||||
self.unknown_setup_drops = self.unknown_setup_drops.saturating_add(1);
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A non-Idle session already owns this setup ID.
|
||||
fn is_collision(&self, setup_id: MeasurementSetupId) -> bool {
|
||||
self.sessions
|
||||
.get(&setup_id.value())
|
||||
.is_some_and(|existing| existing.state() != SessionState::Idle)
|
||||
}
|
||||
|
||||
/// The active-setup budget is exhausted.
|
||||
fn at_capacity(&self) -> bool {
|
||||
self.active_setups() >= self.config.capabilities.max_active_setups as usize
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
//! ADR-153 acceptance tests — types (serde round trips, boundary
|
||||
//! validation), the SimTransport double, and the ESP32 CSI bridge.
|
||||
//! FSM/timeout/threshold/SBP coverage lives in [`super::tests_fsm`].
|
||||
//! All tests are hardware-free (simulation only).
|
||||
|
||||
use super::messages::*;
|
||||
use super::testutil::{csi_frame, params, payload, setup_request};
|
||||
use super::transport::{
|
||||
OpportunisticCsiBridge, SensingFrame, SensingTransport, SimTransport, TransportError,
|
||||
};
|
||||
use super::types::*;
|
||||
|
||||
// ---------- serde round trips ----------
|
||||
|
||||
#[test]
|
||||
fn serde_round_trips_setup_instance_report_sbp_termination() {
|
||||
let req = setup_request(7);
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::from_str::<SensingMeasurementSetupRequest>(&json).unwrap(),
|
||||
req
|
||||
);
|
||||
|
||||
let resp = SensingMeasurementSetupResponse {
|
||||
setup_id: req.setup_id,
|
||||
status: SetupStatus::Accepted,
|
||||
};
|
||||
let json = serde_json::to_string(&resp).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::from_str::<SensingMeasurementSetupResponse>(&json).unwrap(),
|
||||
resp
|
||||
);
|
||||
|
||||
let instance = SensingMeasurementInstance {
|
||||
setup_id: req.setup_id,
|
||||
instance_id: MeasurementInstanceId::new(3),
|
||||
timestamp_us: 300_000,
|
||||
};
|
||||
let json = serde_json::to_string(&instance).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::from_str::<SensingMeasurementInstance>(&json).unwrap(),
|
||||
instance
|
||||
);
|
||||
|
||||
let report = SensingMeasurementReport {
|
||||
setup_id: req.setup_id,
|
||||
instance_id: MeasurementInstanceId::new(3),
|
||||
payload: payload(42.0),
|
||||
};
|
||||
let json = serde_json::to_string(&report).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::from_str::<SensingMeasurementReport>(&json).unwrap(),
|
||||
report
|
||||
);
|
||||
|
||||
let sbp = SbpRequest {
|
||||
profile: SpecProfile::VendorExtension("acme-presensing".into()),
|
||||
proxy_setup_id: req.setup_id,
|
||||
params: params(),
|
||||
};
|
||||
let json = serde_json::to_string(&sbp).unwrap();
|
||||
assert_eq!(serde_json::from_str::<SbpRequest>(&json).unwrap(), sbp);
|
||||
|
||||
let sbp_resp = SbpResponse {
|
||||
proxy_setup_id: req.setup_id,
|
||||
status: SbpStatus::Accepted,
|
||||
};
|
||||
let json = serde_json::to_string(&sbp_resp).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::from_str::<SbpResponse>(&json).unwrap(),
|
||||
sbp_resp
|
||||
);
|
||||
|
||||
let term = SensingSessionTermination {
|
||||
setup_id: req.setup_id,
|
||||
reason: TerminationReason::InitiatorRequested,
|
||||
};
|
||||
let json = serde_json::to_string(&term).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::from_str::<SensingSessionTermination>(&json).unwrap(),
|
||||
term
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_rejects_out_of_range_setup_id() {
|
||||
assert!(serde_json::from_str::<MeasurementSetupId>("200").is_err());
|
||||
assert!(serde_json::from_str::<MeasurementSetupId>("127").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_rejects_out_of_range_threshold_params() {
|
||||
assert!(serde_json::from_str::<ThresholdParams>(r#"{"delta_percent":255}"#).is_err());
|
||||
let ok = serde_json::from_str::<ThresholdParams>(r#"{"delta_percent":100}"#).unwrap();
|
||||
assert_eq!(ok.delta_percent(), 100);
|
||||
}
|
||||
|
||||
// ---------- validation, no panics ----------
|
||||
|
||||
#[test]
|
||||
fn setup_id_construction_never_panics_and_bounds_hold() {
|
||||
for v in 0u8..=255 {
|
||||
let result = MeasurementSetupId::new(v);
|
||||
assert_eq!(result.is_ok(), v <= MAX_SETUP_ID);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn params_validation_rejects_malformed() {
|
||||
let mut p = params();
|
||||
p.period_ms = MIN_PERIOD_MS - 1;
|
||||
assert!(matches!(p.validate(), Err(BfError::InvalidPeriod { .. })));
|
||||
p = params();
|
||||
p.period_ms = MAX_PERIOD_MS + 1;
|
||||
assert!(matches!(p.validate(), Err(BfError::InvalidPeriod { .. })));
|
||||
p = params();
|
||||
p.burst_instances = 0;
|
||||
assert!(matches!(
|
||||
p.validate(),
|
||||
Err(BfError::InvalidBurstInstances { .. })
|
||||
));
|
||||
p = params();
|
||||
p.burst_instances = MAX_BURST_INSTANCES + 1;
|
||||
assert!(matches!(
|
||||
p.validate(),
|
||||
Err(BfError::InvalidBurstInstances { .. })
|
||||
));
|
||||
p = params();
|
||||
p.initiator_role = TransceiverRole::Receiver; // no transmitter anywhere
|
||||
assert!(matches!(
|
||||
p.validate(),
|
||||
Err(BfError::InvalidTransceiverRoles)
|
||||
));
|
||||
p = params();
|
||||
p.consent = ConsentMode::Disabled;
|
||||
assert!(matches!(
|
||||
p.validate(),
|
||||
Err(BfError::SensingDisabledByPolicy)
|
||||
));
|
||||
assert!(ThresholdParams::new(101).is_err());
|
||||
assert!(ThresholdParams::new(100).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn payload_validation_rejects_adversarial_values_without_panic() {
|
||||
let adversarial = [
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 0,
|
||||
amplitudes: vec![],
|
||||
phases: vec![],
|
||||
},
|
||||
CsiReportPayload {
|
||||
n_subcarriers: u16::MAX,
|
||||
amplitudes: vec![1.0; 4],
|
||||
phases: vec![0.0; 4],
|
||||
},
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 4,
|
||||
amplitudes: vec![1.0; 3],
|
||||
phases: vec![0.0; 4],
|
||||
},
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 2,
|
||||
amplitudes: vec![f32::NAN, 1.0],
|
||||
phases: vec![0.0; 2],
|
||||
},
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 2,
|
||||
amplitudes: vec![1.0, f32::INFINITY],
|
||||
phases: vec![0.0; 2],
|
||||
},
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 2,
|
||||
amplitudes: vec![-1.0, 1.0],
|
||||
phases: vec![0.0; 2],
|
||||
},
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 2,
|
||||
amplitudes: vec![1.0; 2],
|
||||
phases: vec![f32::NEG_INFINITY, 0.0],
|
||||
},
|
||||
];
|
||||
for p in adversarial {
|
||||
assert!(p.validate().is_err());
|
||||
}
|
||||
assert!(payload(5.0).validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_profile_compatibility() {
|
||||
let published = SpecProfile::Ieee80211Bf2025;
|
||||
assert!(published.accepts(&SpecProfile::DraftCompatible));
|
||||
assert!(published.accepts(&SpecProfile::Ieee80211Bf2025));
|
||||
assert!(!published.accepts(&SpecProfile::VendorExtension("x".into())));
|
||||
let vendor = SpecProfile::VendorExtension("x".into());
|
||||
assert!(vendor.accepts(&SpecProfile::VendorExtension("x".into())));
|
||||
assert!(!vendor.accepts(&SpecProfile::VendorExtension("y".into())));
|
||||
}
|
||||
|
||||
// ---------- bridge: ESP32 CSI → standardized report ----------
|
||||
|
||||
#[test]
|
||||
fn bridge_maps_csi_batches_to_measurement_reports() {
|
||||
let setup_id = MeasurementSetupId::new(1).unwrap();
|
||||
let mut bridge = OpportunisticCsiBridge::new(setup_id, 4).unwrap();
|
||||
assert!(OpportunisticCsiBridge::new(setup_id, 0).is_err());
|
||||
|
||||
// 3 frames: no report yet. 4th completes the instance batch.
|
||||
for _ in 0..3 {
|
||||
assert!(bridge.ingest(&csi_frame(8, 30, 40)).is_none());
|
||||
}
|
||||
let report = bridge
|
||||
.ingest(&csi_frame(8, 30, 40))
|
||||
.expect("batch complete");
|
||||
assert_eq!(report.setup_id, setup_id);
|
||||
assert_eq!(report.instance_id.value(), 0);
|
||||
assert_eq!(report.payload.n_subcarriers, 8);
|
||||
assert!(report.payload.validate().is_ok());
|
||||
// |30 + 40i| = 50 on every subcarrier of every frame.
|
||||
assert!(report
|
||||
.payload
|
||||
.amplitudes
|
||||
.iter()
|
||||
.all(|a| (a - 50.0).abs() < 1e-3));
|
||||
|
||||
// Invalid (all-zero) frames are skipped and do not advance the batch.
|
||||
for _ in 0..10 {
|
||||
assert!(bridge.ingest(&csi_frame(8, 0, 0)).is_none());
|
||||
}
|
||||
// A mid-batch subcarrier-shape change restarts the batch on the new shape.
|
||||
assert!(bridge.ingest(&csi_frame(8, 10, 0)).is_none());
|
||||
assert!(bridge.ingest(&csi_frame(4, 10, 0)).is_none()); // restart at n=4
|
||||
for _ in 0..2 {
|
||||
assert!(bridge.ingest(&csi_frame(4, 10, 0)).is_none());
|
||||
}
|
||||
let report = bridge.ingest(&csi_frame(4, 10, 0)).expect("second batch");
|
||||
assert_eq!(report.instance_id.value(), 1); // instance counter advanced
|
||||
assert_eq!(report.payload.n_subcarriers, 4);
|
||||
}
|
||||
|
||||
// ---------- transport ----------
|
||||
|
||||
#[test]
|
||||
fn sim_transport_scripted_responses_and_failures() {
|
||||
let mut t = SimTransport::new();
|
||||
let resp = SensingMeasurementSetupResponse {
|
||||
setup_id: MeasurementSetupId::new(7).unwrap(),
|
||||
status: SetupStatus::Accepted,
|
||||
};
|
||||
t.script_response(SensingFrame::SetupResponse(resp));
|
||||
assert!(t.poll_frame().is_none());
|
||||
t.send_setup_request(setup_request(7)).unwrap();
|
||||
assert_eq!(t.poll_frame(), Some(SensingFrame::SetupResponse(resp)));
|
||||
assert_eq!(t.sent().len(), 1);
|
||||
|
||||
let mut tiny = SimTransport::with_capacity(1);
|
||||
tiny.send_setup_request(setup_request(1)).unwrap();
|
||||
assert_eq!(
|
||||
tiny.send_setup_request(setup_request(2)),
|
||||
Err(TransportError::QueueFull { capacity: 1 })
|
||||
);
|
||||
|
||||
let mut down = SimTransport::new();
|
||||
down.set_link_down(true);
|
||||
assert_eq!(
|
||||
down.send_setup_request(setup_request(1)),
|
||||
Err(TransportError::LinkDown)
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,489 @@
|
||||
//! ADR-153 acceptance tests — session FSM full cycle, rejection paths,
|
||||
//! timeout handling, threshold-based reporting, single-role enforcement,
|
||||
//! and adversarial no-panic coverage. SBP flows live in [`super::tests_sbp`];
|
||||
//! type/serde/transport/bridge tests in [`super::tests`]. All tests are
|
||||
//! hardware-free (simulation only).
|
||||
|
||||
use super::messages::*;
|
||||
use super::session::{
|
||||
Action, CloseReason, SensingSession, SessionConfig, SessionEvent, SessionState,
|
||||
};
|
||||
use super::table::SessionTable;
|
||||
use super::testutil::{dispatch, ferry, params, payload, pump, setup_request};
|
||||
use super::transport::{SensingFrame, SimTransport};
|
||||
use super::types::*;
|
||||
use crate::csi_frame::Bandwidth;
|
||||
|
||||
// ---------- FSM: full cycle ----------
|
||||
|
||||
#[test]
|
||||
fn fsm_full_cycle_setup_measure_report_terminate() {
|
||||
let cfg = SessionConfig::default();
|
||||
let mut initiator = SensingSession::new_initiator(cfg.clone());
|
||||
let mut responder = SensingSession::new_responder(cfg);
|
||||
let mut wire_i = SimTransport::new();
|
||||
let mut wire_r = SimTransport::new();
|
||||
|
||||
// Idle → SetupNegotiating
|
||||
dispatch(
|
||||
&mut initiator,
|
||||
SessionEvent::StartSetup(setup_request(7)),
|
||||
&mut wire_i,
|
||||
);
|
||||
assert_eq!(initiator.state(), SessionState::SetupNegotiating);
|
||||
|
||||
// Responder accepts → Active
|
||||
ferry(&mut wire_i, &mut wire_r);
|
||||
pump(&mut responder, &mut wire_r);
|
||||
assert_eq!(responder.state(), SessionState::Active);
|
||||
|
||||
// Initiator sees Accepted → Active + first instance trigger on the wire
|
||||
ferry(&mut wire_r, &mut wire_i);
|
||||
pump(&mut initiator, &mut wire_i);
|
||||
assert_eq!(initiator.state(), SessionState::Active);
|
||||
assert!(wire_i
|
||||
.sent()
|
||||
.iter()
|
||||
.any(|f| matches!(f, SensingFrame::InstanceTrigger(i) if i.setup_id.value() == 7)));
|
||||
|
||||
// Responder captures a measurement → report on the wire
|
||||
wire_i.drain_sent();
|
||||
let actions = dispatch(
|
||||
&mut responder,
|
||||
SessionEvent::MeasurementCaptured {
|
||||
instance_id: MeasurementInstanceId::new(0),
|
||||
payload: payload(10.0),
|
||||
},
|
||||
&mut wire_r,
|
||||
);
|
||||
assert!(actions.iter().any(|a| matches!(a, Action::SendReport(_))));
|
||||
|
||||
// Initiator delivers the report to its consumer
|
||||
ferry(&mut wire_r, &mut wire_i);
|
||||
let actions = pump(&mut initiator, &mut wire_i);
|
||||
assert!(actions
|
||||
.iter()
|
||||
.any(|a| matches!(a, Action::DeliverReport(_))));
|
||||
|
||||
// Active → Terminating → Idle (peer notified, quiescence completes)
|
||||
wire_i.drain_sent();
|
||||
dispatch(
|
||||
&mut initiator,
|
||||
SessionEvent::Terminate(TerminationReason::InitiatorRequested),
|
||||
&mut wire_i,
|
||||
);
|
||||
assert_eq!(initiator.state(), SessionState::Terminating);
|
||||
ferry(&mut wire_i, &mut wire_r);
|
||||
let actions = pump(&mut responder, &mut wire_r);
|
||||
assert!(actions.iter().any(|a| matches!(
|
||||
a,
|
||||
Action::SessionClosed(CloseReason::Terminated(
|
||||
TerminationReason::InitiatorRequested
|
||||
))
|
||||
)));
|
||||
assert_eq!(responder.state(), SessionState::Idle);
|
||||
let actions = initiator.handle(SessionEvent::Timeout).unwrap();
|
||||
assert!(actions
|
||||
.iter()
|
||||
.any(|a| matches!(a, Action::SessionClosed(CloseReason::Completed))));
|
||||
assert_eq!(initiator.state(), SessionState::Idle);
|
||||
}
|
||||
|
||||
// ---------- FSM: rejection paths ----------
|
||||
|
||||
#[test]
|
||||
fn responder_rejects_unsupported_bandwidth_and_initiator_resets() {
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.capabilities = SensingCapabilities::esp32_opportunistic(); // max 40 MHz
|
||||
let mut responder = SensingSession::new_responder(cfg);
|
||||
let mut initiator = SensingSession::new_initiator(SessionConfig::default());
|
||||
|
||||
let mut req = setup_request(3);
|
||||
req.params.bandwidth = Bandwidth::Bw80;
|
||||
initiator
|
||||
.handle(SessionEvent::StartSetup(req.clone()))
|
||||
.unwrap();
|
||||
|
||||
let actions = responder
|
||||
.handle(SessionEvent::SetupRequestReceived(req))
|
||||
.unwrap();
|
||||
let resp = match &actions[..] {
|
||||
[Action::SendSetupResponse(r)] => *r,
|
||||
other => panic!("expected single rejection response, got {other:?}"),
|
||||
};
|
||||
assert_eq!(resp.status, SetupStatus::RejectedUnsupportedParams);
|
||||
assert_eq!(responder.state(), SessionState::Idle);
|
||||
|
||||
let actions = initiator
|
||||
.handle(SessionEvent::SetupResponseReceived(resp))
|
||||
.unwrap();
|
||||
assert!(actions.iter().any(|a| matches!(
|
||||
a,
|
||||
Action::SessionClosed(CloseReason::SetupRejected(
|
||||
SetupStatus::RejectedUnsupportedParams
|
||||
))
|
||||
)));
|
||||
assert_eq!(initiator.state(), SessionState::Idle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_period_rejected_on_both_sides() {
|
||||
let mut req = setup_request(4);
|
||||
req.params.period_ms = 1; // below MIN_PERIOD_MS
|
||||
let mut initiator = SensingSession::new_initiator(SessionConfig::default());
|
||||
assert!(matches!(
|
||||
initiator.handle(SessionEvent::StartSetup(req.clone())),
|
||||
Err(BfError::InvalidPeriod { period_ms: 1 })
|
||||
));
|
||||
assert_eq!(initiator.state(), SessionState::Idle);
|
||||
|
||||
let mut responder = SensingSession::new_responder(SessionConfig::default());
|
||||
let actions = responder
|
||||
.handle(SessionEvent::SetupRequestReceived(req))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::RejectedUnsupportedParams,
|
||||
..
|
||||
})]
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate_setup_id_rejected_by_session_table() {
|
||||
let mut table = SessionTable::new(SessionConfig::default());
|
||||
let actions = table.handle_setup_request(setup_request(9)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::Accepted,
|
||||
..
|
||||
})]
|
||||
));
|
||||
let actions = table.handle_setup_request(setup_request(9)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::RejectedSetupIdCollision,
|
||||
..
|
||||
})]
|
||||
));
|
||||
assert_eq!(table.active_setups(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capacity_and_policy_and_profile_rejections() {
|
||||
// Capacity
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.capabilities.max_active_setups = 1;
|
||||
let mut table = SessionTable::new(cfg);
|
||||
table.handle_setup_request(setup_request(1)).unwrap();
|
||||
let actions = table.handle_setup_request(setup_request(2)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::RejectedCapacity,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// Consent policy
|
||||
let mut responder = SensingSession::new_responder(SessionConfig::default());
|
||||
let mut req = setup_request(5);
|
||||
req.params.consent = ConsentMode::Disabled;
|
||||
let actions = responder
|
||||
.handle(SessionEvent::SetupRequestReceived(req))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::RejectedByPolicy,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// Incompatible profile
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.profile = SpecProfile::VendorExtension("acme".into());
|
||||
let mut responder = SensingSession::new_responder(cfg);
|
||||
let actions = responder
|
||||
.handle(SessionEvent::SetupRequestReceived(setup_request(6)))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::RejectedIncompatibleProfile,
|
||||
..
|
||||
})]
|
||||
));
|
||||
}
|
||||
|
||||
// ---------- FSM: timeouts ----------
|
||||
|
||||
#[test]
|
||||
fn negotiation_timeout_returns_typed_error_and_resets_to_idle() {
|
||||
let mut initiator = SensingSession::new_initiator(SessionConfig::default()); // 3 timeouts
|
||||
initiator
|
||||
.handle(SessionEvent::StartSetup(setup_request(7)))
|
||||
.unwrap();
|
||||
|
||||
// First two timeouts re-send the pending request.
|
||||
for _ in 0..2 {
|
||||
let actions = initiator.handle(SessionEvent::Timeout).unwrap();
|
||||
assert!(matches!(actions[..], [Action::SendSetupRequest(_)]));
|
||||
assert_eq!(initiator.state(), SessionState::SetupNegotiating);
|
||||
}
|
||||
// Third gives up: typed error + Idle.
|
||||
assert_eq!(
|
||||
initiator.handle(SessionEvent::Timeout),
|
||||
Err(BfError::NegotiationTimeout {
|
||||
setup_id: 7,
|
||||
attempts: 3
|
||||
})
|
||||
);
|
||||
assert_eq!(initiator.state(), SessionState::Idle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_missed_instance_timeouts_terminate_session() {
|
||||
let mut responder = SensingSession::new_responder(SessionConfig::default()); // 5 missed max
|
||||
responder
|
||||
.handle(SessionEvent::SetupRequestReceived(setup_request(2)))
|
||||
.unwrap();
|
||||
assert_eq!(responder.state(), SessionState::Active);
|
||||
for _ in 0..4 {
|
||||
assert!(responder.handle(SessionEvent::Timeout).unwrap().is_empty());
|
||||
}
|
||||
let actions = responder.handle(SessionEvent::Timeout).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendTermination(SensingSessionTermination {
|
||||
reason: TerminationReason::Timeout,
|
||||
..
|
||||
})]
|
||||
));
|
||||
assert_eq!(responder.state(), SessionState::Terminating);
|
||||
let actions = responder.handle(SessionEvent::Timeout).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SessionClosed(CloseReason::Completed)]
|
||||
));
|
||||
assert_eq!(responder.state(), SessionState::Idle);
|
||||
}
|
||||
|
||||
// ---------- threshold-based reporting ----------
|
||||
|
||||
#[test]
|
||||
fn threshold_report_emitted_only_when_threshold_crossed() {
|
||||
let mut responder = SensingSession::new_responder(SessionConfig::default());
|
||||
let mut req = setup_request(8);
|
||||
req.params.reporting = ReportingConfig::ThresholdBased(ThresholdParams::new(20).unwrap());
|
||||
responder
|
||||
.handle(SessionEvent::SetupRequestReceived(req))
|
||||
.unwrap();
|
||||
|
||||
let capture = |mean: f32| SessionEvent::MeasurementCaptured {
|
||||
instance_id: MeasurementInstanceId::new(0),
|
||||
payload: payload(mean),
|
||||
};
|
||||
// First measurement always reported (establishes the baseline).
|
||||
let actions = responder.handle(capture(100.0)).unwrap();
|
||||
assert!(matches!(actions[..], [Action::SendReport(_)]));
|
||||
// +10% — below threshold, suppressed; baseline stays at 100.
|
||||
assert!(responder.handle(capture(110.0)).unwrap().is_empty());
|
||||
// +19% vs the *reported* baseline — still suppressed.
|
||||
assert!(responder.handle(capture(119.0)).unwrap().is_empty());
|
||||
// +50% — crossed, reported, baseline moves to 150.
|
||||
let actions = responder.handle(capture(150.0)).unwrap();
|
||||
assert!(matches!(actions[..], [Action::SendReport(_)]));
|
||||
// 150 → 125 is ~16.7% — suppressed against the new baseline.
|
||||
assert!(responder.handle(capture(125.0)).unwrap().is_empty());
|
||||
}
|
||||
|
||||
// ---------- consecutive missed-instance semantics ----------
|
||||
|
||||
#[test]
|
||||
fn missed_instance_budget_is_consecutive_not_cumulative() {
|
||||
// Review finding 2: a successful measurement must reset the
|
||||
// missed-instance counter — `max_missed_instances` bounds *consecutive*
|
||||
// misses (as documented on SessionConfig), not cumulative ones.
|
||||
let mut responder = SensingSession::new_responder(SessionConfig::default()); // 5 missed max
|
||||
responder
|
||||
.handle(SessionEvent::SetupRequestReceived(setup_request(2)))
|
||||
.unwrap();
|
||||
assert_eq!(responder.state(), SessionState::Active);
|
||||
let capture = || SessionEvent::MeasurementCaptured {
|
||||
instance_id: MeasurementInstanceId::new(0),
|
||||
payload: payload(10.0),
|
||||
};
|
||||
|
||||
// Miss 4, then succeed once...
|
||||
for _ in 0..4 {
|
||||
assert!(responder.handle(SessionEvent::Timeout).unwrap().is_empty());
|
||||
}
|
||||
let actions = responder.handle(capture()).unwrap();
|
||||
assert!(matches!(actions[..], [Action::SendReport(_)]));
|
||||
|
||||
// ...so 4 more misses still leave the session alive.
|
||||
for _ in 0..4 {
|
||||
assert!(responder.handle(SessionEvent::Timeout).unwrap().is_empty());
|
||||
assert_eq!(responder.state(), SessionState::Active);
|
||||
}
|
||||
// The 5th consecutive miss terminates.
|
||||
let actions = responder.handle(SessionEvent::Timeout).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendTermination(SensingSessionTermination {
|
||||
reason: TerminationReason::Timeout,
|
||||
..
|
||||
})]
|
||||
));
|
||||
assert_eq!(responder.state(), SessionState::Terminating);
|
||||
}
|
||||
|
||||
// ---------- single-role enforcement & out-of-state commands ----------
|
||||
|
||||
#[test]
|
||||
fn initiator_role_session_rejects_inbound_setup_and_sbp_requests() {
|
||||
// Review finding 4a: single-role design — a peer must not be able to
|
||||
// hijack an initiator-role session into the responder path.
|
||||
let mut initiator = SensingSession::new_initiator(SessionConfig::default());
|
||||
let actions = initiator
|
||||
.handle(SessionEvent::SetupRequestReceived(setup_request(3)))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSetupResponse(SensingMeasurementSetupResponse {
|
||||
status: SetupStatus::RejectedNotSupported,
|
||||
..
|
||||
})]
|
||||
));
|
||||
assert_eq!(initiator.state(), SessionState::Idle);
|
||||
|
||||
let sbp = SbpRequest {
|
||||
profile: SpecProfile::Ieee80211Bf2025,
|
||||
proxy_setup_id: MeasurementSetupId::new(4).unwrap(),
|
||||
params: params(),
|
||||
};
|
||||
let actions = initiator
|
||||
.handle(SessionEvent::SbpRequestReceived(sbp))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedNotSupported,
|
||||
..
|
||||
})]
|
||||
));
|
||||
assert_eq!(initiator.state(), SessionState::Idle);
|
||||
assert!(!initiator.is_sbp_proxy());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_start_commands_error_outside_idle() {
|
||||
// Review finding 4b: StartSetup/StartSbp outside Idle are caller bugs
|
||||
// and must surface as typed errors, not silent no-ops.
|
||||
let sbp = SbpRequest {
|
||||
profile: SpecProfile::Ieee80211Bf2025,
|
||||
proxy_setup_id: MeasurementSetupId::new(13).unwrap(),
|
||||
params: params(),
|
||||
};
|
||||
let start_err = |s: &mut SensingSession, expected: SessionState| {
|
||||
assert!(matches!(
|
||||
s.handle(SessionEvent::StartSetup(setup_request(8))),
|
||||
Err(BfError::InvalidStateForCommand { .. })
|
||||
));
|
||||
assert!(matches!(
|
||||
s.handle(SessionEvent::StartSbp(sbp.clone())),
|
||||
Err(BfError::InvalidStateForCommand { .. })
|
||||
));
|
||||
// The rejected commands must not disturb the session.
|
||||
assert_eq!(s.state(), expected);
|
||||
};
|
||||
|
||||
let mut s = SensingSession::new_initiator(SessionConfig::default());
|
||||
s.handle(SessionEvent::StartSetup(setup_request(7)))
|
||||
.unwrap();
|
||||
start_err(&mut s, SessionState::SetupNegotiating);
|
||||
|
||||
s.handle(SessionEvent::SetupResponseReceived(
|
||||
SensingMeasurementSetupResponse {
|
||||
setup_id: MeasurementSetupId::new(7).unwrap(),
|
||||
status: SetupStatus::Accepted,
|
||||
},
|
||||
))
|
||||
.unwrap();
|
||||
start_err(&mut s, SessionState::Active);
|
||||
// Genuinely ignorable stray frames remain no-ops in Active.
|
||||
assert!(s
|
||||
.handle(SessionEvent::SbpResponseReceived(SbpResponse {
|
||||
proxy_setup_id: MeasurementSetupId::new(7).unwrap(),
|
||||
status: SbpStatus::Accepted,
|
||||
}))
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
|
||||
s.handle(SessionEvent::Terminate(
|
||||
TerminationReason::InitiatorRequested,
|
||||
))
|
||||
.unwrap();
|
||||
start_err(&mut s, SessionState::Terminating);
|
||||
}
|
||||
|
||||
// ---------- adversarial: no panics anywhere ----------
|
||||
|
||||
#[test]
|
||||
fn malformed_and_out_of_state_events_never_panic() {
|
||||
let junk_payload = CsiReportPayload {
|
||||
n_subcarriers: 3,
|
||||
amplitudes: vec![f32::NAN, -5.0, f32::INFINITY],
|
||||
phases: vec![f32::NAN],
|
||||
};
|
||||
let bad_report = SensingMeasurementReport {
|
||||
setup_id: MeasurementSetupId::new(99).unwrap(),
|
||||
instance_id: MeasurementInstanceId::new(255),
|
||||
payload: junk_payload.clone(),
|
||||
};
|
||||
let events: Vec<SessionEvent> = vec![
|
||||
SessionEvent::StartSetup(setup_request(0)),
|
||||
SessionEvent::StartSbp(SbpRequest {
|
||||
profile: SpecProfile::DraftCompatible,
|
||||
proxy_setup_id: MeasurementSetupId::new(0).unwrap(),
|
||||
params: params(),
|
||||
}),
|
||||
SessionEvent::SetupRequestReceived(setup_request(127)),
|
||||
SessionEvent::SetupResponseReceived(SensingMeasurementSetupResponse {
|
||||
setup_id: MeasurementSetupId::new(50).unwrap(),
|
||||
status: SetupStatus::RejectedCapacity,
|
||||
}),
|
||||
SessionEvent::SbpResponseReceived(SbpResponse {
|
||||
proxy_setup_id: MeasurementSetupId::new(50).unwrap(),
|
||||
status: SbpStatus::RejectedByPolicy,
|
||||
}),
|
||||
SessionEvent::InstanceElapsed,
|
||||
SessionEvent::MeasurementCaptured {
|
||||
instance_id: MeasurementInstanceId::new(0),
|
||||
payload: junk_payload,
|
||||
},
|
||||
SessionEvent::ReportReceived(bad_report),
|
||||
SessionEvent::Timeout,
|
||||
SessionEvent::Terminate(TerminationReason::PolicyChange),
|
||||
SessionEvent::TerminationReceived(SensingSessionTermination {
|
||||
setup_id: MeasurementSetupId::new(1).unwrap(),
|
||||
reason: TerminationReason::Timeout,
|
||||
}),
|
||||
];
|
||||
// Drive both roles through every event repeatedly from whatever state
|
||||
// each lands in; typed errors are fine, panics are not.
|
||||
for session in [
|
||||
&mut SensingSession::new_initiator(SessionConfig::default()),
|
||||
&mut SensingSession::new_responder(SessionConfig::default()),
|
||||
] {
|
||||
for _ in 0..4 {
|
||||
for event in &events {
|
||||
let _ = session.handle(event.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,338 @@
|
||||
//! ADR-153 sensing-by-proxy (SBP) acceptance tests — proxy lifecycle
|
||||
//! (re-triggering + report relay), client flow, table-driven AP entry
|
||||
//! point, and the single-validation-path status mapping. Other FSM tests
|
||||
//! live in [`super::tests_fsm`]; type/serde/transport/bridge tests in
|
||||
//! [`super::tests`]. All tests are hardware-free (simulation only).
|
||||
|
||||
use super::messages::*;
|
||||
use super::session::{
|
||||
Action, CloseReason, SensingSession, SessionConfig, SessionEvent, SessionState,
|
||||
};
|
||||
use super::table::SessionTable;
|
||||
use super::testutil::{params, payload};
|
||||
use super::transport::{action_to_frame, frame_to_event, SensingFrame};
|
||||
use super::types::*;
|
||||
use crate::csi_frame::Bandwidth;
|
||||
|
||||
fn sbp_request(id: u8) -> SbpRequest {
|
||||
SbpRequest {
|
||||
profile: SpecProfile::Ieee80211Bf2025,
|
||||
proxy_setup_id: MeasurementSetupId::new(id).unwrap(),
|
||||
params: params(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbp_proxy_request_maps_to_standard_responder_path() {
|
||||
// Proxy AP: accepts the SBP request and initiates an ordinary setup
|
||||
// toward the sensing responder — no direct sensor coupling.
|
||||
let mut proxy = SensingSession::new_responder(SessionConfig::default());
|
||||
let actions = proxy
|
||||
.handle(SessionEvent::SbpRequestReceived(sbp_request(11)))
|
||||
.unwrap();
|
||||
let forwarded = match &actions[..] {
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::Accepted,
|
||||
..
|
||||
}), Action::SendSetupRequest(req)] => req.clone(),
|
||||
other => panic!("expected SBP accept + setup request, got {other:?}"),
|
||||
};
|
||||
assert_eq!(proxy.state(), SessionState::SetupNegotiating);
|
||||
assert_eq!(forwarded.setup_id.value(), 11);
|
||||
|
||||
// The forwarded request drives a *normal* responder session.
|
||||
let mut responder = SensingSession::new_responder(SessionConfig::default());
|
||||
let actions = responder
|
||||
.handle(SessionEvent::SetupRequestReceived(forwarded))
|
||||
.unwrap();
|
||||
let resp = match &actions[..] {
|
||||
[Action::SendSetupResponse(r)] => *r,
|
||||
other => panic!("expected accept, got {other:?}"),
|
||||
};
|
||||
assert_eq!(resp.status, SetupStatus::Accepted);
|
||||
proxy
|
||||
.handle(SessionEvent::SetupResponseReceived(resp))
|
||||
.unwrap();
|
||||
assert_eq!(proxy.state(), SessionState::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbp_client_flow_and_rejections() {
|
||||
let mut client = SensingSession::new_initiator(SessionConfig::default());
|
||||
let sbp = sbp_request(12);
|
||||
let actions = client.handle(SessionEvent::StartSbp(sbp.clone())).unwrap();
|
||||
assert!(matches!(actions[..], [Action::SendSbpRequest(_)]));
|
||||
let accept = SbpResponse {
|
||||
proxy_setup_id: sbp.proxy_setup_id,
|
||||
status: SbpStatus::Accepted,
|
||||
};
|
||||
client
|
||||
.handle(SessionEvent::SbpResponseReceived(accept))
|
||||
.unwrap();
|
||||
assert_eq!(client.state(), SessionState::Active);
|
||||
// Proxied report is delivered to the local consumer.
|
||||
let report = SensingMeasurementReport {
|
||||
setup_id: sbp.proxy_setup_id,
|
||||
instance_id: MeasurementInstanceId::new(0),
|
||||
payload: payload(1.0),
|
||||
};
|
||||
let actions = client.handle(SessionEvent::ReportReceived(report)).unwrap();
|
||||
assert!(matches!(actions[..], [Action::DeliverReport(_)]));
|
||||
|
||||
// A proxy without SBP capability rejects.
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.capabilities.sensing_by_proxy = false;
|
||||
let mut no_sbp = SensingSession::new_responder(cfg);
|
||||
let actions = no_sbp
|
||||
.handle(SessionEvent::SbpRequestReceived(sbp))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedNotSupported,
|
||||
..
|
||||
})]
|
||||
));
|
||||
assert_eq!(no_sbp.state(), SessionState::Idle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbp_proxy_full_lifecycle_retriggers_and_relays() {
|
||||
// Review finding 1: the SBP proxy is a first-class mode — after the
|
||||
// proxied setup is accepted it keeps driving measurement instances on
|
||||
// InstanceElapsed (like an initiator) and relays every received report
|
||||
// to the SBP client in addition to local delivery.
|
||||
let mut proxy = SensingSession::new_responder(SessionConfig::default());
|
||||
|
||||
// Accept: SBP response to the client + proxied setup to the responder.
|
||||
let actions = proxy
|
||||
.handle(SessionEvent::SbpRequestReceived(sbp_request(21)))
|
||||
.unwrap();
|
||||
let forwarded = match &actions[..] {
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::Accepted,
|
||||
..
|
||||
}), Action::SendSetupRequest(req)] => req.clone(),
|
||||
other => panic!("expected SBP accept + setup request, got {other:?}"),
|
||||
};
|
||||
assert!(proxy.is_sbp_proxy());
|
||||
|
||||
// Responder accepts → proxy Active, instance 0 triggered.
|
||||
let actions = proxy
|
||||
.handle(SessionEvent::SetupResponseReceived(
|
||||
SensingMeasurementSetupResponse {
|
||||
setup_id: forwarded.setup_id,
|
||||
status: SetupStatus::Accepted,
|
||||
},
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(proxy.state(), SessionState::Active);
|
||||
match &actions[..] {
|
||||
[Action::TriggerInstance(i)] => assert_eq!(i.instance_id.value(), 0),
|
||||
other => panic!("expected instance 0 trigger, got {other:?}"),
|
||||
}
|
||||
|
||||
// InstanceElapsed re-triggers instance 1+ (proxy drives the schedule).
|
||||
let actions = proxy.handle(SessionEvent::InstanceElapsed).unwrap();
|
||||
match &actions[..] {
|
||||
[Action::TriggerInstance(i)] => assert_eq!(i.instance_id.value(), 1),
|
||||
other => panic!("expected instance 1 trigger, got {other:?}"),
|
||||
}
|
||||
|
||||
// A report from the sensing responder is delivered locally AND relayed.
|
||||
let report = SensingMeasurementReport {
|
||||
setup_id: forwarded.setup_id,
|
||||
instance_id: MeasurementInstanceId::new(1),
|
||||
payload: payload(5.0),
|
||||
};
|
||||
let actions = proxy
|
||||
.handle(SessionEvent::ReportReceived(report.clone()))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
Action::DeliverReport(report.clone()),
|
||||
Action::RelaySbpReport(report.clone()),
|
||||
]
|
||||
);
|
||||
// The relay action maps to a frame toward the SBP client, which
|
||||
// consumes it through the standard report path.
|
||||
let frame = action_to_frame(&Action::RelaySbpReport(report.clone())).unwrap();
|
||||
assert_eq!(frame, SensingFrame::SbpReport(report.clone()));
|
||||
assert_eq!(
|
||||
frame_to_event(frame),
|
||||
Some(SessionEvent::ReportReceived(report))
|
||||
);
|
||||
|
||||
// Terminate cleanly: notify the responder, quiesce back to Idle.
|
||||
let actions = proxy
|
||||
.handle(SessionEvent::Terminate(
|
||||
TerminationReason::InitiatorRequested,
|
||||
))
|
||||
.unwrap();
|
||||
assert!(matches!(actions[..], [Action::SendTermination(_)]));
|
||||
assert_eq!(proxy.state(), SessionState::Terminating);
|
||||
let actions = proxy.handle(SessionEvent::Timeout).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SessionClosed(CloseReason::Completed)]
|
||||
));
|
||||
assert_eq!(proxy.state(), SessionState::Idle);
|
||||
assert!(!proxy.is_sbp_proxy());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_table_routes_sbp_end_to_end() {
|
||||
// Review finding 3: the table has a first-class SBP entry point with
|
||||
// the same collision/capacity guards as direct setups — a table-driven
|
||||
// AP accepts SBP instead of silently dropping it.
|
||||
let mut table = SessionTable::new(SessionConfig::default());
|
||||
let actions = table.handle_sbp_request(sbp_request(31)).unwrap();
|
||||
let forwarded = match &actions[..] {
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::Accepted,
|
||||
..
|
||||
}), Action::SendSetupRequest(req)] => req.clone(),
|
||||
other => panic!("expected SBP accept + setup request, got {other:?}"),
|
||||
};
|
||||
let setup_id = forwarded.setup_id;
|
||||
assert_eq!(table.active_setups(), 1);
|
||||
assert!(table.session(setup_id).unwrap().is_sbp_proxy());
|
||||
|
||||
// Proxy-setup-ID collision while the first proxy is live.
|
||||
let actions = table.handle_sbp_request(sbp_request(31)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedSetupIdCollision,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// Drive the proxied negotiation to Active through the table.
|
||||
let actions = table
|
||||
.handle_for(
|
||||
setup_id,
|
||||
SessionEvent::SetupResponseReceived(SensingMeasurementSetupResponse {
|
||||
setup_id,
|
||||
status: SetupStatus::Accepted,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
assert!(matches!(actions[..], [Action::TriggerInstance(_)]));
|
||||
assert_eq!(
|
||||
table.session(setup_id).unwrap().state(),
|
||||
SessionState::Active
|
||||
);
|
||||
|
||||
// Reports relay to the SBP client through the table-owned proxy.
|
||||
let report = SensingMeasurementReport {
|
||||
setup_id,
|
||||
instance_id: MeasurementInstanceId::new(0),
|
||||
payload: payload(2.0),
|
||||
};
|
||||
let actions = table
|
||||
.handle_for(setup_id, SessionEvent::ReportReceived(report.clone()))
|
||||
.unwrap();
|
||||
assert!(actions.contains(&Action::RelaySbpReport(report)));
|
||||
|
||||
// Capacity guard mirrors the direct-setup path.
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.capabilities.max_active_setups = 1;
|
||||
let mut small = SessionTable::new(cfg);
|
||||
small.handle_sbp_request(sbp_request(1)).unwrap();
|
||||
let actions = small.handle_sbp_request(sbp_request(2)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedCapacity,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// Unknown-setup drops are observable, not silent (finding 3).
|
||||
assert_eq!(table.unknown_setup_drops(), 0);
|
||||
let actions = table
|
||||
.handle_for(MeasurementSetupId::new(99).unwrap(), SessionEvent::Timeout)
|
||||
.unwrap();
|
||||
assert!(actions.is_empty());
|
||||
assert_eq!(table.unknown_setup_drops(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbp_validation_shares_setup_chain_with_one_to_one_status_mapping() {
|
||||
// Review finding 5: SBP requests are validated by building the proxied
|
||||
// setup request first and running it through the single evaluate_setup
|
||||
// chain — statuses map 1:1, so no rejection class is folded away and no
|
||||
// setup policy can be bypassed via SBP.
|
||||
|
||||
// Incompatible profile now surfaces as its own status (the old
|
||||
// duplicated SBP chain folded it into RejectedUnsupportedParams).
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.profile = SpecProfile::VendorExtension("acme".into());
|
||||
let mut proxy = SensingSession::new_responder(cfg);
|
||||
let actions = proxy
|
||||
.handle(SessionEvent::SbpRequestReceived(sbp_request(41)))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedIncompatibleProfile,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// Consent policy rejection passes through unchanged.
|
||||
let mut proxy = SensingSession::new_responder(SessionConfig::default());
|
||||
let mut sbp = sbp_request(42);
|
||||
sbp.params.consent = ConsentMode::Disabled;
|
||||
let actions = proxy.handle(SessionEvent::SbpRequestReceived(sbp)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedByPolicy,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// Capability rejection (bandwidth beyond the advertised maximum).
|
||||
let mut cfg = SessionConfig::default();
|
||||
cfg.capabilities.max_bandwidth_mhz = 40;
|
||||
let mut proxy = SensingSession::new_responder(cfg);
|
||||
let mut sbp = sbp_request(43);
|
||||
sbp.params.bandwidth = Bandwidth::Bw80;
|
||||
let actions = proxy.handle(SessionEvent::SbpRequestReceived(sbp)).unwrap();
|
||||
assert!(matches!(
|
||||
actions[..],
|
||||
[Action::SendSbpResponse(SbpResponse {
|
||||
status: SbpStatus::RejectedUnsupportedParams,
|
||||
..
|
||||
})]
|
||||
));
|
||||
|
||||
// The status translation itself is exhaustive and 1:1.
|
||||
let pairs = [
|
||||
(SetupStatus::Accepted, SbpStatus::Accepted),
|
||||
(
|
||||
SetupStatus::RejectedNotSupported,
|
||||
SbpStatus::RejectedNotSupported,
|
||||
),
|
||||
(
|
||||
SetupStatus::RejectedUnsupportedParams,
|
||||
SbpStatus::RejectedUnsupportedParams,
|
||||
),
|
||||
(
|
||||
SetupStatus::RejectedSetupIdCollision,
|
||||
SbpStatus::RejectedSetupIdCollision,
|
||||
),
|
||||
(
|
||||
SetupStatus::RejectedIncompatibleProfile,
|
||||
SbpStatus::RejectedIncompatibleProfile,
|
||||
),
|
||||
(SetupStatus::RejectedByPolicy, SbpStatus::RejectedByPolicy),
|
||||
(SetupStatus::RejectedCapacity, SbpStatus::RejectedCapacity),
|
||||
];
|
||||
for (setup, sbp) in pairs {
|
||||
assert_eq!(SbpStatus::from(setup), sbp);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
//! Shared helpers for the ADR-153 acceptance tests (hardware-free).
|
||||
|
||||
use chrono::Utc;
|
||||
|
||||
use super::messages::{CsiReportPayload, SensingMeasurementSetupRequest};
|
||||
use super::session::{Action, SensingSession, SessionEvent};
|
||||
use super::transport::{action_to_frame, frame_to_event, SensingTransport, SimTransport};
|
||||
use super::types::{
|
||||
ConsentMode, MeasurementSetupId, MeasurementSetupParams, ReportingConfig, SpecProfile,
|
||||
TransceiverRole,
|
||||
};
|
||||
use crate::csi_frame::{
|
||||
Adr018Flags, AntennaConfig, Bandwidth, CsiFrame, CsiMetadata, PpduType, SubcarrierData,
|
||||
};
|
||||
|
||||
pub(super) fn params() -> MeasurementSetupParams {
|
||||
MeasurementSetupParams {
|
||||
bandwidth: Bandwidth::Bw20,
|
||||
period_ms: 100,
|
||||
burst_instances: 4,
|
||||
reporting: ReportingConfig::EveryInstance,
|
||||
initiator_role: TransceiverRole::Transmitter,
|
||||
responder_role: TransceiverRole::Receiver,
|
||||
consent: ConsentMode::ExplicitConsent,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn setup_request(id: u8) -> SensingMeasurementSetupRequest {
|
||||
SensingMeasurementSetupRequest {
|
||||
profile: SpecProfile::Ieee80211Bf2025,
|
||||
setup_id: MeasurementSetupId::new(id).unwrap(),
|
||||
params: params(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn payload(mean: f32) -> CsiReportPayload {
|
||||
CsiReportPayload {
|
||||
n_subcarriers: 4,
|
||||
amplitudes: vec![mean; 4],
|
||||
phases: vec![0.25; 4],
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn csi_frame(n: usize, i: i16, q: i16) -> CsiFrame {
|
||||
CsiFrame {
|
||||
metadata: CsiMetadata {
|
||||
timestamp: Utc::now(),
|
||||
node_id: 1,
|
||||
n_antennas: 1,
|
||||
n_subcarriers: n as u16,
|
||||
channel_freq_mhz: 2437,
|
||||
rssi_dbm: -50,
|
||||
noise_floor_dbm: -95,
|
||||
bandwidth: Bandwidth::Bw20,
|
||||
antenna_config: AntennaConfig::default(),
|
||||
sequence: 0,
|
||||
ppdu_type: PpduType::HtLegacy,
|
||||
adr018_flags: Adr018Flags::default(),
|
||||
},
|
||||
subcarriers: (0..n)
|
||||
.map(|k| SubcarrierData {
|
||||
i,
|
||||
q,
|
||||
index: k as i16,
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive a session, forwarding wire-bound actions onto a transport.
|
||||
pub(super) fn dispatch(
|
||||
s: &mut SensingSession,
|
||||
event: SessionEvent,
|
||||
out: &mut SimTransport,
|
||||
) -> Vec<Action> {
|
||||
let actions = s.handle(event).expect("handle must not error");
|
||||
for a in &actions {
|
||||
if let Some(f) = action_to_frame(a) {
|
||||
out.send_frame(f).expect("send must not error");
|
||||
}
|
||||
}
|
||||
actions
|
||||
}
|
||||
|
||||
pub(super) fn ferry(from: &mut SimTransport, to: &mut SimTransport) {
|
||||
for f in from.drain_sent() {
|
||||
to.push_inbound(f);
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume inbound frames on `wire`, sending any resulting outbound frames
|
||||
/// back onto the same transport's sent log.
|
||||
pub(super) fn pump(s: &mut SensingSession, wire: &mut SimTransport) -> Vec<Action> {
|
||||
let mut all = Vec::new();
|
||||
while let Some(frame) = wire.poll_frame() {
|
||||
if let Some(event) = frame_to_event(frame) {
|
||||
all.extend(dispatch(s, event, wire));
|
||||
}
|
||||
}
|
||||
all
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
//! Transport abstraction for the 802.11bf forward-compatibility model.
|
||||
//!
|
||||
//! [`SensingTransport`] is the seam where a real chipset binding will land
|
||||
//! when commodity silicon implements IEEE 802.11bf-2025 (none does today —
|
||||
//! ADR-152 F4, ADR-153). Until then:
|
||||
//!
|
||||
//! - [`SimTransport`] is a scriptable in-memory test double for protocol
|
||||
//! tests in CI (no hardware).
|
||||
//! - [`OpportunisticCsiBridge`] maps today's opportunistic ESP32 CSI
|
||||
//! extraction (ADR-018 frames parsed by [`crate::Esp32CsiParser`] and
|
||||
//! delivered by [`crate::aggregator::Esp32Aggregator`]) onto the
|
||||
//! standardized report path: one measurement instance ≈ one batch of
|
||||
//! [`CsiFrame`]s.
|
||||
//!
|
||||
//! **Replaceability benchmark (ADR-153):** consumers must depend only on
|
||||
//! `SensingTransport` plus the report types in [`super::types`] — a future
|
||||
//! chipset adapter replaces `OpportunisticCsiBridge` without touching them.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use super::messages::{
|
||||
CsiReportPayload, SbpRequest, SbpResponse, SensingMeasurementInstance,
|
||||
SensingMeasurementReport, SensingMeasurementSetupRequest, SensingMeasurementSetupResponse,
|
||||
SensingSessionTermination,
|
||||
};
|
||||
use super::session::Action;
|
||||
use super::types::{BfError, MeasurementInstanceId, MeasurementSetupId, MAX_REPORT_SUBCARRIERS};
|
||||
use crate::csi_frame::CsiFrame;
|
||||
|
||||
/// Frames exchanged between sensing endpoints. This is a *logical* frame
|
||||
/// set — no OTA encoding is defined until silicon exists to bind to.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SensingFrame {
|
||||
SetupRequest(SensingMeasurementSetupRequest),
|
||||
SetupResponse(SensingMeasurementSetupResponse),
|
||||
InstanceTrigger(SensingMeasurementInstance),
|
||||
Report(SensingMeasurementReport),
|
||||
SbpRequest(SbpRequest),
|
||||
SbpResponse(SbpResponse),
|
||||
/// Proxied measurement report forwarded by an SBP proxy toward its SBP
|
||||
/// client ([`Action::RelaySbpReport`]) — distinct from [`Self::Report`],
|
||||
/// which travels toward the sensing initiator.
|
||||
SbpReport(SensingMeasurementReport),
|
||||
Termination(SensingSessionTermination),
|
||||
}
|
||||
|
||||
/// Errors surfaced by a sensing transport.
|
||||
#[derive(Debug, Clone, PartialEq, Error)]
|
||||
pub enum TransportError {
|
||||
#[error("transport link down")]
|
||||
LinkDown,
|
||||
#[error("transport queue full (capacity {capacity})")]
|
||||
QueueFull { capacity: usize },
|
||||
}
|
||||
|
||||
/// Frame-exchange abstraction for sensing endpoints.
|
||||
///
|
||||
/// The required surface is deliberately tiny (`send_frame`/`poll_frame`);
|
||||
/// the named helpers are convenience wrappers so call sites read like the
|
||||
/// standard's procedures.
|
||||
pub trait SensingTransport {
|
||||
/// Queue one logical frame toward the peer.
|
||||
fn send_frame(&mut self, frame: SensingFrame) -> Result<(), TransportError>;
|
||||
|
||||
/// Pop the next inbound frame, if any.
|
||||
fn poll_frame(&mut self) -> Option<SensingFrame>;
|
||||
|
||||
fn send_setup_request(
|
||||
&mut self,
|
||||
req: SensingMeasurementSetupRequest,
|
||||
) -> Result<(), TransportError> {
|
||||
self.send_frame(SensingFrame::SetupRequest(req))
|
||||
}
|
||||
|
||||
fn send_setup_response(
|
||||
&mut self,
|
||||
resp: SensingMeasurementSetupResponse,
|
||||
) -> Result<(), TransportError> {
|
||||
self.send_frame(SensingFrame::SetupResponse(resp))
|
||||
}
|
||||
|
||||
fn trigger_measurement_instance(
|
||||
&mut self,
|
||||
instance: SensingMeasurementInstance,
|
||||
) -> Result<(), TransportError> {
|
||||
self.send_frame(SensingFrame::InstanceTrigger(instance))
|
||||
}
|
||||
|
||||
fn send_report(&mut self, report: SensingMeasurementReport) -> Result<(), TransportError> {
|
||||
self.send_frame(SensingFrame::Report(report))
|
||||
}
|
||||
|
||||
fn send_termination(
|
||||
&mut self,
|
||||
termination: SensingSessionTermination,
|
||||
) -> Result<(), TransportError> {
|
||||
self.send_frame(SensingFrame::Termination(termination))
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a session [`Action`] to the frame it puts on the wire, if any.
|
||||
/// `DeliverReport`/`SessionClosed` are local-consumer actions and map to `None`.
|
||||
pub fn action_to_frame(action: &Action) -> Option<SensingFrame> {
|
||||
match action {
|
||||
Action::SendSetupRequest(req) => Some(SensingFrame::SetupRequest(req.clone())),
|
||||
Action::SendSetupResponse(resp) => Some(SensingFrame::SetupResponse(*resp)),
|
||||
Action::SendSbpRequest(req) => Some(SensingFrame::SbpRequest(req.clone())),
|
||||
Action::SendSbpResponse(resp) => Some(SensingFrame::SbpResponse(*resp)),
|
||||
Action::TriggerInstance(instance) => Some(SensingFrame::InstanceTrigger(*instance)),
|
||||
Action::SendReport(report) => Some(SensingFrame::Report(report.clone())),
|
||||
Action::RelaySbpReport(report) => Some(SensingFrame::SbpReport(report.clone())),
|
||||
Action::SendTermination(term) => Some(SensingFrame::Termination(*term)),
|
||||
Action::DeliverReport(_) | Action::SessionClosed(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Map an inbound frame to the session event it raises on the receiver.
|
||||
///
|
||||
/// `InstanceTrigger` maps to `None`: a sensing receiver pairs the trigger
|
||||
/// with locally captured CSI and raises `MeasurementCaptured` itself (see
|
||||
/// [`OpportunisticCsiBridge`]).
|
||||
pub fn frame_to_event(frame: SensingFrame) -> Option<super::session::SessionEvent> {
|
||||
use super::session::SessionEvent as E;
|
||||
match frame {
|
||||
SensingFrame::SetupRequest(req) => Some(E::SetupRequestReceived(req)),
|
||||
SensingFrame::SetupResponse(resp) => Some(E::SetupResponseReceived(resp)),
|
||||
SensingFrame::Report(report) => Some(E::ReportReceived(report)),
|
||||
// The SBP client consumes proxied reports through the standard
|
||||
// report path (its session is in sbp_client mode).
|
||||
SensingFrame::SbpReport(report) => Some(E::ReportReceived(report)),
|
||||
SensingFrame::SbpRequest(req) => Some(E::SbpRequestReceived(req)),
|
||||
SensingFrame::SbpResponse(resp) => Some(E::SbpResponseReceived(resp)),
|
||||
SensingFrame::Termination(term) => Some(E::TerminationReceived(term)),
|
||||
SensingFrame::InstanceTrigger(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory scriptable transport test double.
|
||||
///
|
||||
/// Every successful `send_frame` is recorded in [`SimTransport::sent`]; if a
|
||||
/// scripted response is queued, it is moved to the inbound queue so the next
|
||||
/// `poll_frame` returns it — letting tests script a peer without one.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SimTransport {
|
||||
sent: Vec<SensingFrame>,
|
||||
inbound: VecDeque<SensingFrame>,
|
||||
scripted: VecDeque<SensingFrame>,
|
||||
link_down: bool,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl SimTransport {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
capacity: 1024,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
capacity,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Frames sent so far, in order.
|
||||
pub fn sent(&self) -> &[SensingFrame] {
|
||||
&self.sent
|
||||
}
|
||||
|
||||
/// Drain the sent log (useful when ferrying frames between two doubles).
|
||||
pub fn drain_sent(&mut self) -> Vec<SensingFrame> {
|
||||
std::mem::take(&mut self.sent)
|
||||
}
|
||||
|
||||
/// Queue a frame as if the peer transmitted it.
|
||||
pub fn push_inbound(&mut self, frame: SensingFrame) {
|
||||
self.inbound.push_back(frame);
|
||||
}
|
||||
|
||||
/// Script a response: the next successful send moves it to the inbound
|
||||
/// queue (one scripted frame consumed per send).
|
||||
pub fn script_response(&mut self, frame: SensingFrame) {
|
||||
self.scripted.push_back(frame);
|
||||
}
|
||||
|
||||
pub fn set_link_down(&mut self, down: bool) {
|
||||
self.link_down = down;
|
||||
}
|
||||
}
|
||||
|
||||
impl SensingTransport for SimTransport {
|
||||
fn send_frame(&mut self, frame: SensingFrame) -> Result<(), TransportError> {
|
||||
if self.link_down {
|
||||
return Err(TransportError::LinkDown);
|
||||
}
|
||||
if self.sent.len() >= self.capacity {
|
||||
return Err(TransportError::QueueFull {
|
||||
capacity: self.capacity,
|
||||
});
|
||||
}
|
||||
self.sent.push(frame);
|
||||
if let Some(response) = self.scripted.pop_front() {
|
||||
self.inbound.push_back(response);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_frame(&mut self) -> Option<SensingFrame> {
|
||||
self.inbound.pop_front()
|
||||
}
|
||||
}
|
||||
|
||||
/// Adapter mapping today's opportunistic ESP32 CSI extraction onto the
|
||||
/// standardized sensing report path.
|
||||
///
|
||||
/// A "measurement instance" is approximated by one batch of `batch_size`
|
||||
/// ADR-018 [`CsiFrame`]s from a node (as produced by
|
||||
/// [`crate::aggregator::Esp32Aggregator`]'s mpsc channel). Amplitudes are
|
||||
/// averaged arithmetically; phases via the circular mean (consistent with
|
||||
/// the RuvSense `phase_align` treatment of LO phase). Invalid frames
|
||||
/// ([`CsiFrame::is_valid`] false) are skipped; a mid-batch subcarrier-shape
|
||||
/// change (node reconfiguration) restarts the batch on the new shape.
|
||||
///
|
||||
/// This is the *interim backend*: when 802.11bf silicon exists, a chipset
|
||||
/// adapter producing the same [`SensingMeasurementReport`]s replaces this
|
||||
/// bridge with no change to consumers (ADR-153 replaceability benchmark).
|
||||
#[derive(Debug)]
|
||||
pub struct OpportunisticCsiBridge {
|
||||
setup_id: MeasurementSetupId,
|
||||
batch_size: usize,
|
||||
instance_counter: u32,
|
||||
amp_accum: Vec<f64>,
|
||||
phase_cos_accum: Vec<f64>,
|
||||
phase_sin_accum: Vec<f64>,
|
||||
frames_in_batch: usize,
|
||||
}
|
||||
|
||||
impl OpportunisticCsiBridge {
|
||||
pub fn new(setup_id: MeasurementSetupId, batch_size: usize) -> Result<Self, BfError> {
|
||||
if batch_size == 0 {
|
||||
return Err(BfError::InvalidBatchSize { got: 0 });
|
||||
}
|
||||
Ok(Self {
|
||||
setup_id,
|
||||
batch_size,
|
||||
instance_counter: 0,
|
||||
amp_accum: Vec::new(),
|
||||
phase_cos_accum: Vec::new(),
|
||||
phase_sin_accum: Vec::new(),
|
||||
frames_in_batch: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn setup_id(&self) -> MeasurementSetupId {
|
||||
self.setup_id
|
||||
}
|
||||
|
||||
pub fn batch_size(&self) -> usize {
|
||||
self.batch_size
|
||||
}
|
||||
|
||||
/// Feed one parsed CSI frame; returns a standardized measurement report
|
||||
/// when a batch completes. Never panics on malformed frames.
|
||||
pub fn ingest(&mut self, frame: &CsiFrame) -> Option<SensingMeasurementReport> {
|
||||
if !frame.is_valid() || frame.subcarrier_count() > MAX_REPORT_SUBCARRIERS as usize {
|
||||
return None;
|
||||
}
|
||||
let (amplitudes, phases) = frame.to_amplitude_phase();
|
||||
if self.frames_in_batch == 0 || amplitudes.len() != self.amp_accum.len() {
|
||||
// Fresh batch (or node reconfigured mid-batch — restart on the
|
||||
// new subcarrier shape, dropping the partial batch).
|
||||
self.amp_accum = vec![0.0; amplitudes.len()];
|
||||
self.phase_cos_accum = vec![0.0; amplitudes.len()];
|
||||
self.phase_sin_accum = vec![0.0; amplitudes.len()];
|
||||
self.frames_in_batch = 0;
|
||||
}
|
||||
for (i, (a, p)) in amplitudes.iter().zip(phases.iter()).enumerate() {
|
||||
self.amp_accum[i] += a;
|
||||
self.phase_cos_accum[i] += p.cos();
|
||||
self.phase_sin_accum[i] += p.sin();
|
||||
}
|
||||
self.frames_in_batch += 1;
|
||||
if self.frames_in_batch < self.batch_size {
|
||||
return None;
|
||||
}
|
||||
|
||||
let scale = self.frames_in_batch as f64;
|
||||
let payload = CsiReportPayload {
|
||||
n_subcarriers: self.amp_accum.len() as u16,
|
||||
amplitudes: self.amp_accum.iter().map(|a| (a / scale) as f32).collect(),
|
||||
phases: self
|
||||
.phase_sin_accum
|
||||
.iter()
|
||||
.zip(self.phase_cos_accum.iter())
|
||||
.map(|(s, c)| s.atan2(*c) as f32)
|
||||
.collect(),
|
||||
};
|
||||
self.amp_accum.clear();
|
||||
self.phase_cos_accum.clear();
|
||||
self.phase_sin_accum.clear();
|
||||
self.frames_in_batch = 0;
|
||||
|
||||
let n = self.instance_counter;
|
||||
self.instance_counter = self.instance_counter.wrapping_add(1);
|
||||
let report = SensingMeasurementReport {
|
||||
setup_id: self.setup_id,
|
||||
instance_id: MeasurementInstanceId::new((n % 256) as u8),
|
||||
payload,
|
||||
};
|
||||
// Boundary check before handing to consumers; drop instead of panic.
|
||||
report.validate().ok()?;
|
||||
Some(report)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,398 @@
|
||||
//! Typed structures for IEEE 802.11bf-2025 WLAN sensing procedures.
|
||||
//!
|
||||
//! Sub-7 GHz focus; DMG (>45 GHz) types are stubbed minimally. Concept names
|
||||
//! follow the standard's procedure vocabulary descriptively — "Sensing
|
||||
//! Measurement Setup", "Sensing Measurement Instance", "Sensing Measurement
|
||||
//! Report", "Sensing by Proxy (SBP)", session termination — without claiming
|
||||
//! clause-level conformance. See [`crate::ieee80211bf`] module docs and
|
||||
//! ADR-153 for framing; ADR-152 §1.1 F4 for the standards-body evidence.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::csi_frame::Bandwidth;
|
||||
|
||||
/// Largest measurement setup identifier accepted by this model (7-bit space;
|
||||
/// chosen conservatively — the standard encodes the Measurement Setup ID in a
|
||||
/// compact identifier field).
|
||||
pub const MAX_SETUP_ID: u8 = 127;
|
||||
/// Minimum measurement-instance periodicity accepted by this model.
|
||||
pub const MIN_PERIOD_MS: u32 = 10;
|
||||
/// Maximum measurement-instance periodicity accepted by this model (1 hour).
|
||||
pub const MAX_PERIOD_MS: u32 = 3_600_000;
|
||||
/// Maximum measurement instances per burst accepted by this model.
|
||||
pub const MAX_BURST_INSTANCES: u8 = 64;
|
||||
/// Maximum subcarriers in a CSI-variant report payload (matches the 160 MHz
|
||||
/// usable-subcarrier count, [`Bandwidth::Bw160`]).
|
||||
pub const MAX_REPORT_SUBCARRIERS: u16 = 484;
|
||||
|
||||
/// Errors produced by validation at the protocol-model boundary.
|
||||
///
|
||||
/// Adversarial or malformed input must surface as one of these — never a
|
||||
/// panic (crate rule: input validation at system boundaries).
|
||||
#[derive(Debug, Clone, PartialEq, Error)]
|
||||
pub enum BfError {
|
||||
/// Measurement setup ID outside the accepted identifier space.
|
||||
#[error("invalid measurement setup ID {value} (valid 0..={MAX_SETUP_ID})")]
|
||||
InvalidSetupId { value: u8 },
|
||||
/// Measurement periodicity outside the accepted range.
|
||||
#[error("measurement period {period_ms} ms out of range ({MIN_PERIOD_MS}..={MAX_PERIOD_MS})")]
|
||||
InvalidPeriod { period_ms: u32 },
|
||||
/// Instances-per-burst outside the accepted range.
|
||||
#[error("burst instance count {count} out of range (1..={MAX_BURST_INSTANCES})")]
|
||||
InvalidBurstInstances { count: u8 },
|
||||
/// Threshold-based reporting parameter outside 0..=100 percent.
|
||||
#[error("reporting threshold {value}% out of range (0..=100)")]
|
||||
InvalidThreshold { value: u8 },
|
||||
/// The initiator/responder transceiver roles leave the measurement with
|
||||
/// no sensing transmitter or no sensing receiver.
|
||||
#[error("transceiver roles leave no sensing transmitter/receiver pair")]
|
||||
InvalidTransceiverRoles,
|
||||
/// Setup carries [`ConsentMode::Disabled`] — sensing must not start.
|
||||
#[error("sensing disabled by consent policy")]
|
||||
SensingDisabledByPolicy,
|
||||
/// Report payload declares zero subcarriers.
|
||||
#[error("report payload empty")]
|
||||
EmptyPayload,
|
||||
/// Report payload claims more subcarriers than this model supports.
|
||||
#[error("report payload claims {count} subcarriers (max {MAX_REPORT_SUBCARRIERS})")]
|
||||
PayloadTooLarge { count: u16 },
|
||||
/// Declared subcarrier count and vector lengths disagree.
|
||||
#[error(
|
||||
"report payload length mismatch: declared {declared}, amplitudes {amplitudes}, phases {phases}"
|
||||
)]
|
||||
PayloadLengthMismatch {
|
||||
declared: usize,
|
||||
amplitudes: usize,
|
||||
phases: usize,
|
||||
},
|
||||
/// A payload value is NaN/infinite, or an amplitude is negative.
|
||||
#[error("report payload value at index {index} is not finite (or negative amplitude)")]
|
||||
PayloadValueInvalid { index: usize },
|
||||
/// A frame referenced a setup ID that does not match the session.
|
||||
#[error("setup ID mismatch: session {expected}, frame {got}")]
|
||||
SetupIdMismatch { expected: u8, got: u8 },
|
||||
/// Sensing measurement setup negotiation timed out (session resets to Idle).
|
||||
#[error("negotiation timed out for setup {setup_id} after {attempts} attempts")]
|
||||
NegotiationTimeout { setup_id: u8, attempts: u8 },
|
||||
/// A local command (`StartSetup`/`StartSbp`) was issued in a state or
|
||||
/// role that cannot accept it.
|
||||
#[error("command not valid in state {state}")]
|
||||
InvalidStateForCommand { state: &'static str },
|
||||
/// CSI bridge batch size must be at least one frame.
|
||||
#[error("invalid CSI batch size {got} (must be >= 1)")]
|
||||
InvalidBatchSize { got: usize },
|
||||
}
|
||||
|
||||
/// Version gate for every negotiated surface (ADR-153).
|
||||
///
|
||||
/// Vendors will expose partial or renamed capabilities before full
|
||||
/// IEEE 802.11bf-2025 conformance; tagging setups and capability
|
||||
/// advertisements with a profile keeps that drift explicit.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum SpecProfile {
|
||||
/// Pre-publication draft semantics (D-series compatible behavior).
|
||||
DraftCompatible,
|
||||
/// Published standard semantics (IEEE 802.11bf-2025, published 2025-09-26).
|
||||
Ieee80211Bf2025,
|
||||
/// Vendor-specific extension or renamed capability set.
|
||||
VendorExtension(String),
|
||||
}
|
||||
|
||||
impl SpecProfile {
|
||||
/// Whether a peer advertising `self` accepts a setup tagged `requested`.
|
||||
///
|
||||
/// Published-standard peers accept draft-compatible requests; vendor
|
||||
/// extensions must match exactly.
|
||||
pub fn accepts(&self, requested: &SpecProfile) -> bool {
|
||||
self == requested
|
||||
|| matches!(
|
||||
(self, requested),
|
||||
(SpecProfile::Ieee80211Bf2025, SpecProfile::DraftCompatible)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Consent/governance mode carried by every sensing measurement setup
|
||||
/// (ADR-153: sensing is presence inference, not just radio telemetry).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum ConsentMode {
|
||||
/// Lab/bench use only; not a deployment consent basis.
|
||||
LabOnly,
|
||||
/// Sensed persons gave explicit consent.
|
||||
ExplicitConsent,
|
||||
/// Enterprise-managed policy authorizes sensing.
|
||||
ManagedEnterprisePolicy,
|
||||
/// Sensing administratively disabled — setups must be rejected.
|
||||
Disabled,
|
||||
}
|
||||
|
||||
/// WLAN sensing procedure role: sensing initiator or sensing responder.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SensingRole {
|
||||
Initiator,
|
||||
Responder,
|
||||
}
|
||||
|
||||
/// Per-measurement-instance role: sensing transmitter, sensing receiver,
|
||||
/// or both (a STA may act as either within a measurement instance).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum TransceiverRole {
|
||||
Transmitter,
|
||||
Receiver,
|
||||
TransmitterReceiver,
|
||||
}
|
||||
|
||||
impl TransceiverRole {
|
||||
pub fn is_transmitter(self) -> bool {
|
||||
matches!(self, Self::Transmitter | Self::TransmitterReceiver)
|
||||
}
|
||||
pub fn is_receiver(self) -> bool {
|
||||
matches!(self, Self::Receiver | Self::TransmitterReceiver)
|
||||
}
|
||||
}
|
||||
|
||||
/// Identifier of a sensing measurement setup ("Measurement Setup ID").
|
||||
///
|
||||
/// Validated newtype: construction and deserialization both reject values
|
||||
/// above [`MAX_SETUP_ID`].
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
#[serde(try_from = "u8", into = "u8")]
|
||||
pub struct MeasurementSetupId(u8);
|
||||
|
||||
impl MeasurementSetupId {
|
||||
pub fn new(value: u8) -> Result<Self, BfError> {
|
||||
if value > MAX_SETUP_ID {
|
||||
Err(BfError::InvalidSetupId { value })
|
||||
} else {
|
||||
Ok(Self(value))
|
||||
}
|
||||
}
|
||||
pub fn value(self) -> u8 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for MeasurementSetupId {
|
||||
type Error = BfError;
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MeasurementSetupId> for u8 {
|
||||
fn from(id: MeasurementSetupId) -> u8 {
|
||||
id.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Identifier of a sensing measurement instance within a setup
|
||||
/// ("Measurement Instance ID"). Wraps modulo 256.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct MeasurementInstanceId(u8);
|
||||
|
||||
impl MeasurementInstanceId {
|
||||
pub fn new(value: u8) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
pub fn value(self) -> u8 {
|
||||
self.0
|
||||
}
|
||||
pub fn wrapping_next(self) -> Self {
|
||||
Self(self.0.wrapping_add(1))
|
||||
}
|
||||
}
|
||||
|
||||
/// Channel width of a bandwidth variant in MHz (capability comparisons).
|
||||
pub fn bandwidth_mhz(bw: Bandwidth) -> u16 {
|
||||
match bw {
|
||||
Bandwidth::Bw20 => 20,
|
||||
Bandwidth::Bw40 => 40,
|
||||
Bandwidth::Bw80 => 80,
|
||||
Bandwidth::Bw160 => 160,
|
||||
}
|
||||
}
|
||||
|
||||
/// Threshold-based reporting parameters: a report is generated only when the
|
||||
/// measurement changes by at least `delta_percent` relative to the last
|
||||
/// reported measurement (normalized-change trigger).
|
||||
///
|
||||
/// Deserialization validates through [`ThresholdParams::new`] so the
|
||||
/// `delta_percent <= 100` invariant holds on every construction path,
|
||||
/// including untrusted wire/persisted payloads (same convention as
|
||||
/// [`MeasurementSetupId`]).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(try_from = "RawThresholdParams")]
|
||||
pub struct ThresholdParams {
|
||||
delta_percent: u8,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RawThresholdParams {
|
||||
delta_percent: u8,
|
||||
}
|
||||
|
||||
impl TryFrom<RawThresholdParams> for ThresholdParams {
|
||||
type Error = BfError;
|
||||
|
||||
fn try_from(raw: RawThresholdParams) -> Result<Self, Self::Error> {
|
||||
Self::new(raw.delta_percent)
|
||||
}
|
||||
}
|
||||
|
||||
impl ThresholdParams {
|
||||
pub fn new(delta_percent: u8) -> Result<Self, BfError> {
|
||||
if delta_percent > 100 {
|
||||
Err(BfError::InvalidThreshold {
|
||||
value: delta_percent,
|
||||
})
|
||||
} else {
|
||||
Ok(Self { delta_percent })
|
||||
}
|
||||
}
|
||||
pub fn delta_percent(self) -> u8 {
|
||||
self.delta_percent
|
||||
}
|
||||
/// Whether the change from `previous` to `current` crosses the threshold.
|
||||
pub fn exceeds(self, previous: f64, current: f64) -> bool {
|
||||
let denom = previous.abs().max(f64::EPSILON);
|
||||
((current - previous).abs() / denom) * 100.0 >= self.delta_percent as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Reporting discipline negotiated in the sensing measurement setup.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum ReportingConfig {
|
||||
/// Report every measurement instance.
|
||||
EveryInstance,
|
||||
/// Threshold-based reporting (report only on significant change).
|
||||
ThresholdBased(ThresholdParams),
|
||||
}
|
||||
|
||||
/// Parameters of a sensing measurement setup ("Sensing Measurement Setup
|
||||
/// element" parameters, sub-7 GHz). Consent metadata is **required**.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MeasurementSetupParams {
|
||||
/// Sounding bandwidth.
|
||||
pub bandwidth: Bandwidth,
|
||||
/// Periodicity of measurement instances, in milliseconds.
|
||||
pub period_ms: u32,
|
||||
/// Measurement instances per burst.
|
||||
pub burst_instances: u8,
|
||||
/// Reporting discipline (per-instance or threshold-based).
|
||||
pub reporting: ReportingConfig,
|
||||
/// Transceiver role the initiator takes during measurement instances.
|
||||
pub initiator_role: TransceiverRole,
|
||||
/// Transceiver role the responder takes during measurement instances.
|
||||
pub responder_role: TransceiverRole,
|
||||
/// Required governance metadata (ADR-153 privacy requirement).
|
||||
pub consent: ConsentMode,
|
||||
}
|
||||
|
||||
impl MeasurementSetupParams {
|
||||
/// Boundary validation: range checks plus role/consent coherence.
|
||||
pub fn validate(&self) -> Result<(), BfError> {
|
||||
if self.period_ms < MIN_PERIOD_MS || self.period_ms > MAX_PERIOD_MS {
|
||||
return Err(BfError::InvalidPeriod {
|
||||
period_ms: self.period_ms,
|
||||
});
|
||||
}
|
||||
if self.burst_instances == 0 || self.burst_instances > MAX_BURST_INSTANCES {
|
||||
return Err(BfError::InvalidBurstInstances {
|
||||
count: self.burst_instances,
|
||||
});
|
||||
}
|
||||
let has_tx = self.initiator_role.is_transmitter() || self.responder_role.is_transmitter();
|
||||
let has_rx = self.initiator_role.is_receiver() || self.responder_role.is_receiver();
|
||||
if !has_tx || !has_rx {
|
||||
return Err(BfError::InvalidTransceiverRoles);
|
||||
}
|
||||
if self.consent == ConsentMode::Disabled {
|
||||
return Err(BfError::SensingDisabledByPolicy);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Capability advertisement for capability negotiation (ADR-153): no
|
||||
/// hardcoded ESP32 assumptions in the future-silicon path.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct SensingCapabilities {
|
||||
pub sub_7_ghz: bool,
|
||||
pub dmg: bool,
|
||||
pub edmg: bool,
|
||||
pub csi_report: bool,
|
||||
pub threshold_reporting: bool,
|
||||
pub sensing_by_proxy: bool,
|
||||
pub max_bandwidth_mhz: u16,
|
||||
pub max_period_ms: u32,
|
||||
pub max_active_setups: u16,
|
||||
}
|
||||
|
||||
impl SensingCapabilities {
|
||||
/// Permissive capability set for simulation and tests.
|
||||
pub fn sim_full() -> Self {
|
||||
Self {
|
||||
sub_7_ghz: true,
|
||||
dmg: false,
|
||||
edmg: false,
|
||||
csi_report: true,
|
||||
threshold_reporting: true,
|
||||
sensing_by_proxy: true,
|
||||
max_bandwidth_mhz: 160,
|
||||
max_period_ms: MAX_PERIOD_MS,
|
||||
max_active_setups: 8,
|
||||
}
|
||||
}
|
||||
|
||||
/// What today's opportunistic ESP32 CSI extraction (ADR-018/ADR-028) can
|
||||
/// honor when mapped through [`crate::ieee80211bf::transport::OpportunisticCsiBridge`].
|
||||
pub fn esp32_opportunistic() -> Self {
|
||||
Self {
|
||||
sub_7_ghz: true,
|
||||
dmg: false,
|
||||
edmg: false,
|
||||
csi_report: true,
|
||||
threshold_reporting: true,
|
||||
sensing_by_proxy: false,
|
||||
max_bandwidth_mhz: 40,
|
||||
max_period_ms: 60_000,
|
||||
max_active_setups: 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate setup parameters against this capability set; `Err` carries
|
||||
/// the protocol-level rejection status to return to the peer.
|
||||
pub fn evaluate(&self, params: &MeasurementSetupParams) -> Result<(), SetupStatus> {
|
||||
if !self.sub_7_ghz || !self.csi_report {
|
||||
return Err(SetupStatus::RejectedUnsupportedParams);
|
||||
}
|
||||
if bandwidth_mhz(params.bandwidth) > self.max_bandwidth_mhz {
|
||||
return Err(SetupStatus::RejectedUnsupportedParams);
|
||||
}
|
||||
if params.period_ms > self.max_period_ms {
|
||||
return Err(SetupStatus::RejectedUnsupportedParams);
|
||||
}
|
||||
if matches!(params.reporting, ReportingConfig::ThresholdBased(_))
|
||||
&& !self.threshold_reporting
|
||||
{
|
||||
return Err(SetupStatus::RejectedUnsupportedParams);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Status carried by a sensing measurement setup response.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SetupStatus {
|
||||
Accepted,
|
||||
/// The receiving endpoint does not act as a sensing responder for this
|
||||
/// request — e.g. an initiator-role session received a setup request
|
||||
/// (single-role design, see [`crate::ieee80211bf::session`]).
|
||||
RejectedNotSupported,
|
||||
RejectedUnsupportedParams,
|
||||
RejectedSetupIdCollision,
|
||||
RejectedIncompatibleProfile,
|
||||
RejectedByPolicy,
|
||||
RejectedCapacity,
|
||||
}
|
||||
@@ -40,6 +40,12 @@ mod csi_frame;
|
||||
mod error;
|
||||
pub mod esp32;
|
||||
mod esp32_parser;
|
||||
// ADR-153: IEEE 802.11bf-2025 forward-compatibility protocol model
|
||||
// (sensing setup / measurement instance / report / SBP / termination).
|
||||
// Simulation-tested; no commodity silicon implements the standard yet —
|
||||
// the OpportunisticCsiBridge maps today's ESP32 CSI extraction onto the
|
||||
// standardized report path until an OTA binding exists.
|
||||
pub mod ieee80211bf;
|
||||
pub mod sync_packet;
|
||||
|
||||
// ADR-081: Rust mirror of the firmware radio abstraction layer (L1) and
|
||||
|
||||
@@ -11,7 +11,8 @@
|
||||
//! TrainError (top-level)
|
||||
//! ├── ConfigError (config validation / file loading)
|
||||
//! ├── DatasetError (data loading, I/O, format)
|
||||
//! └── SubcarrierError (frequency-axis resampling)
|
||||
//! ├── SubcarrierError (frequency-axis resampling)
|
||||
//! └── MaeError (MAE patchify / masking — ADR-152 §2.3)
|
||||
//! ```
|
||||
|
||||
use std::path::PathBuf;
|
||||
@@ -44,6 +45,10 @@ pub enum TrainError {
|
||||
#[error("Dataset error: {0}")]
|
||||
Dataset(#[from] DatasetError),
|
||||
|
||||
/// A MAE pretraining patchify / masking error (ADR-152 §2.3).
|
||||
#[error("MAE pretraining error: {0}")]
|
||||
Mae(#[from] MaeError),
|
||||
|
||||
/// JSON (de)serialization error.
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
@@ -373,3 +378,85 @@ impl SubcarrierError {
|
||||
SubcarrierError::NumericalError(msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MaeError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by the MAE pretraining patchify / masking functions
|
||||
/// ([`crate::mae`], ADR-152 §2.3).
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MaeError {
|
||||
/// The flat window buffer does not match the declared `time × subc` shape.
|
||||
#[error(
|
||||
"Window length {actual} does not match time × subcarriers = \
|
||||
{time} × {subc} = {expected}"
|
||||
)]
|
||||
WindowShapeMismatch {
|
||||
/// Declared time dimension.
|
||||
time: usize,
|
||||
/// Declared subcarrier dimension.
|
||||
subc: usize,
|
||||
/// Expected buffer length (`time * subc`).
|
||||
expected: usize,
|
||||
/// Actual buffer length.
|
||||
actual: usize,
|
||||
},
|
||||
|
||||
/// A patch dimension is larger than the window along that axis.
|
||||
#[error("Patch {axis} extent {patch} exceeds window {axis} extent {window}")]
|
||||
PatchExceedsWindow {
|
||||
/// Axis name (`"time"` or `"subcarrier"`).
|
||||
axis: &'static str,
|
||||
/// Patch extent along the axis.
|
||||
patch: usize,
|
||||
/// Window extent along the axis.
|
||||
window: usize,
|
||||
},
|
||||
|
||||
/// The window is not an exact multiple of the patch extent along an axis.
|
||||
///
|
||||
/// Patchification never silently truncates; crop the window to `crop`
|
||||
/// (the largest divisible extent) or change the patch size.
|
||||
#[error(
|
||||
"Window {axis} extent {window} is not divisible by patch {axis} extent \
|
||||
{patch} (remainder {remainder}); crop the window to {crop} or change \
|
||||
the patch size"
|
||||
)]
|
||||
NotDivisible {
|
||||
/// Axis name (`"time"` or `"subcarrier"`).
|
||||
axis: &'static str,
|
||||
/// Window extent along the axis.
|
||||
window: usize,
|
||||
/// Patch extent along the axis.
|
||||
patch: usize,
|
||||
/// `window % patch`.
|
||||
remainder: usize,
|
||||
/// Largest divisible extent (`window - remainder`).
|
||||
crop: usize,
|
||||
},
|
||||
|
||||
/// The mask ratio is not a finite value strictly inside `(0, 1)` — the
|
||||
/// same rule as [`MaePretrainConfig::validate`]. A NaN ratio must never
|
||||
/// silently mask zero patches, and ratios ≤ 0 / ≥ 1 degenerate to
|
||||
/// all-visible / all-masked grids.
|
||||
///
|
||||
/// [`MaePretrainConfig::validate`]: crate::mae::MaePretrainConfig::validate
|
||||
#[error("Invalid mask ratio {ratio}: must be finite and strictly inside (0, 1)")]
|
||||
InvalidMaskRatio {
|
||||
/// The offending ratio.
|
||||
ratio: f64,
|
||||
},
|
||||
|
||||
/// A NaN or ±inf CSI value was found; corrupted input must be cleaned
|
||||
/// upstream, never masked over.
|
||||
#[error("Non-finite CSI value {value} at (t={row}, sc={col})")]
|
||||
NonFiniteValue {
|
||||
/// Time index of the offending value.
|
||||
row: usize,
|
||||
/// Subcarrier index of the offending value.
|
||||
col: usize,
|
||||
/// The non-finite value itself.
|
||||
value: f32,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -49,11 +49,13 @@ pub mod domain;
|
||||
pub mod error;
|
||||
pub mod eval;
|
||||
pub mod geometry;
|
||||
pub mod mae;
|
||||
pub mod rapid_adapt;
|
||||
pub mod ruview_metrics;
|
||||
pub mod signal_features;
|
||||
pub mod subcarrier;
|
||||
pub mod virtual_aug;
|
||||
pub mod wiflow_std;
|
||||
|
||||
// The following modules use `tch` (PyTorch Rust bindings) for GPU-accelerated
|
||||
// training and are only compiled when the `tch-backend` feature is enabled.
|
||||
@@ -81,7 +83,7 @@ pub use config::TrainingConfig;
|
||||
pub use dataset::{
|
||||
CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticConfig, SyntheticCsiDataset,
|
||||
};
|
||||
pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError};
|
||||
pub use error::{ConfigError, DatasetError, MaeError, SubcarrierError, TrainError};
|
||||
// TrainResult<T> is the generic Result alias from error.rs; the concrete
|
||||
// TrainResult struct from trainer.rs is accessed via trainer::TrainResult.
|
||||
pub use error::TrainResult as TrainResultAlias;
|
||||
@@ -89,6 +91,14 @@ pub use subcarrier::{
|
||||
compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance,
|
||||
};
|
||||
|
||||
// ADR-152 §2.3 — UNSW MAE pretraining recipe re-exports.
|
||||
pub use mae::{patchify, random_mask, unpatchify, MaePretrainConfig, MaskIndices, PatchGrid};
|
||||
|
||||
// ADR-152 §2.2 — WiFlow-STD (DY2434) spatio-temporal-decoupled pose model.
|
||||
pub use wiflow_std::WiFlowStdConfig;
|
||||
#[cfg(feature = "tch-backend")]
|
||||
pub use wiflow_std::WiFlowStdModel;
|
||||
|
||||
// MERIDIAN (ADR-027) re-exports.
|
||||
pub use domain::{AdversarialSchedule, DomainClassifier, DomainFactorizer, GradientReversalLayer};
|
||||
pub use eval::CrossDomainEvaluator;
|
||||
|
||||
@@ -118,7 +118,7 @@ impl WiFiDensePoseLoss {
|
||||
// Normalise by number of visible joints in the batch.
|
||||
let n_visible = visibility.sum(Kind::Float);
|
||||
// Guard against division by zero (entire batch may have no labels).
|
||||
let safe_n = n_visible.clamp(1.0, f64::MAX);
|
||||
let safe_n = n_visible.clamp_min(1.0);
|
||||
|
||||
masked.sum(Kind::Float) / safe_n
|
||||
}
|
||||
@@ -165,7 +165,7 @@ impl WiFiDensePoseLoss {
|
||||
let masked_target_uv = target_uv * &fg_mask_f;
|
||||
|
||||
// Count foreground pixels × 48 channels to normalise.
|
||||
let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX);
|
||||
let n_fg = fg_mask_f.sum(Kind::Float).clamp_min(1.0);
|
||||
|
||||
// Smooth-L1 with beta=1.0, reduction=Sum then divide by fg count.
|
||||
let uv_loss_sum = masked_pred_uv.smooth_l1_loss(&masked_target_uv, Reduction::Sum, 1.0);
|
||||
@@ -234,7 +234,7 @@ impl WiFiDensePoseLoss {
|
||||
// UV loss (foreground masked)
|
||||
let fg_mask = target_int.not_equal(0_i64);
|
||||
let fg_mask_f = fg_mask.unsqueeze(1).expand_as(pu).to_kind(Kind::Float);
|
||||
let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX);
|
||||
let n_fg = fg_mask_f.sum(Kind::Float).clamp_min(1.0);
|
||||
let uv_loss =
|
||||
(pu * &fg_mask_f).smooth_l1_loss(&(tu * &fg_mask_f), Reduction::Sum, 1.0)
|
||||
/ n_fg;
|
||||
@@ -743,10 +743,11 @@ mod tests {
|
||||
}
|
||||
|
||||
// Visible batch (index 1) should have non-zero heatmaps.
|
||||
let heatmaps_ref = &heatmaps;
|
||||
let batch1_sum: f32 = (0..num_joints)
|
||||
.map(|j| {
|
||||
(0..size)
|
||||
.flat_map(|r| (0..size).map(move |c| heatmaps[[1, j, r, c]]))
|
||||
.flat_map(|r| (0..size).map(move |c| heatmaps_ref[[1, j, r, c]]))
|
||||
.sum::<f32>()
|
||||
})
|
||||
.sum();
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
//! Masked-autoencoder (MAE) pretraining recipe for the ADR-150 RF foundation
|
||||
//! encoder — ADR-152 §2.3 (amends ADR-150 §2.3).
|
||||
//!
|
||||
//! Implements the *measured* tokenization recipe from the UNSW MAE pretraining
|
||||
//! study (arXiv [2511.18792](https://arxiv.org/abs/2511.18792), Nov 2025), the
|
||||
//! largest heterogeneous CSI pretraining run to date (1,320,892 samples, 14
|
||||
//! public datasets, 4 devices, 2.4/5/6 GHz, 20–160 MHz):
|
||||
//!
|
||||
//! - **80% masking ratio** over the patch grid.
|
||||
//! - **Small (30, 3) patches** — 30 time steps × 3 subcarriers — measured
|
||||
//! **+4.7%** over (40, 5) patches by preserving fine temporal dynamics.
|
||||
//! - Encoder capacity stays **ViT-Small-class (~15M params)**: ViT-Base adds
|
||||
//! only +0.4–0.9% over ViT-Small in-study, corroborating ADR-150's own
|
||||
//! finding that capacity hurts cross-subject transfer.
|
||||
//! - Unseen-domain performance scales **log-linearly with pretraining data,
|
||||
//! unsaturated at 1.3M samples** — data aggregation outranks architecture
|
||||
//! work (ADR-152 §2.3).
|
||||
//!
|
||||
//! This module provides the GPU-free half of the recipe: configuration,
|
||||
//! patchification, and deterministic random masking. The (future, ADR-150)
|
||||
//! encoder consumes [`PatchGrid`] + [`MaskIndices`] to compute the masked
|
||||
//! reconstruction loss (`L_masked_csi` in ADR-150 §2.3's loss stack).
|
||||
//!
|
||||
//! ## Axis convention
|
||||
//!
|
||||
//! A CSI window is `time × subcarriers`, row-major (`index = t * subc + sc`),
|
||||
//! matching the crate's `[T, …, n_sc]` dataset layout (time first, subcarriers
|
||||
//! last) and the UNSW "(30 time steps, 3 subcarriers)" patch framing. Patches
|
||||
//! are indexed row-major over the patch grid (`p = pt * n_patches_subc + ps`),
|
||||
//! and values within a patch are row-major time-major
|
||||
//! (`local = lt * patch_subc + lsc`).
|
||||
//!
|
||||
//! ## Divisibility policy: error, never truncate
|
||||
//!
|
||||
//! Window dimensions **must** be exact multiples of the patch dimensions.
|
||||
//! Non-divisible shapes return [`MaeError::NotDivisible`] instead of silently
|
||||
//! truncating trailing samples (this crate never silently drops data). The
|
||||
//! error names the largest divisible crop; use
|
||||
//! [`MaePretrainConfig::cropped_window_shape`] to compute it and crop
|
||||
//! explicitly before calling [`patchify`].
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust
|
||||
//! use wifi_densepose_train::mae::MaePretrainConfig;
|
||||
//!
|
||||
//! let cfg = MaePretrainConfig::default(); // 0.80 masking, (30, 3) patches
|
||||
//! cfg.validate().expect("default recipe is valid");
|
||||
//!
|
||||
//! // 90 frames × 54 subcarriers → a 3 × 18 grid of (30, 3) patches.
|
||||
//! let window = vec![0.25_f32; 90 * 54];
|
||||
//! let (grid, mask) = cfg.mask_window(&window, 90, 54).unwrap();
|
||||
//! assert_eq!(grid.n_patches(), 54);
|
||||
//! assert_eq!(mask.masked.len(), 43); // round(0.80 * 54)
|
||||
//! assert_eq!(mask.visible.len(), 11);
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{ConfigError, MaeError};
|
||||
use crate::virtual_aug::Xorshift64;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MaePretrainConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Hyper-parameters for masked-CSI pretraining (ADR-152 §2.3).
|
||||
///
|
||||
/// Defaults are the measured-optimal UNSW recipe (arXiv 2511.18792); change
|
||||
/// them only with benchmark evidence. Serializable so the recipe is recorded
|
||||
/// in checkpoint metadata alongside [`crate::config::TrainingConfig`].
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MaePretrainConfig {
|
||||
/// Fraction of patches hidden from the encoder, in `(0, 1)`.
|
||||
///
|
||||
/// Default: **0.80** (UNSW measured optimum).
|
||||
pub mask_ratio: f64,
|
||||
|
||||
/// Patch extent along the time axis, in frames. Default: **30**.
|
||||
pub patch_time: usize,
|
||||
|
||||
/// Patch extent along the subcarrier axis. Default: **3**.
|
||||
pub patch_subc: usize,
|
||||
|
||||
/// Base seed for the deterministic mask sampler. Default: **42**.
|
||||
///
|
||||
/// For per-sample masks derive a child seed (e.g.
|
||||
/// `seed ^ sample_idx as u64`) and pass it to [`random_mask`]; reusing one
|
||||
/// seed yields the identical mask for every sample.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
impl Default for MaePretrainConfig {
|
||||
fn default() -> Self {
|
||||
MaePretrainConfig {
|
||||
mask_ratio: 0.80,
|
||||
patch_time: 30,
|
||||
patch_subc: 3,
|
||||
seed: 42,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MaePretrainConfig {
|
||||
/// Validate the shape-independent fields.
|
||||
///
|
||||
/// # Validated invariants
|
||||
///
|
||||
/// - `mask_ratio` must be strictly inside `(0, 1)` and finite.
|
||||
/// - `patch_time` and `patch_subc` must be at least 1.
|
||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||
if !self.mask_ratio.is_finite() || self.mask_ratio <= 0.0 || self.mask_ratio >= 1.0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"mask_ratio",
|
||||
format!("must be in (0.0, 1.0), got {}", self.mask_ratio),
|
||||
));
|
||||
}
|
||||
if self.patch_time == 0 {
|
||||
return Err(ConfigError::invalid_value("patch_time", "must be >= 1"));
|
||||
}
|
||||
if self.patch_subc == 0 {
|
||||
return Err(ConfigError::invalid_value("patch_subc", "must be >= 1"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check this recipe against a concrete `time × subc` window shape.
|
||||
///
|
||||
/// Errors if a patch dimension exceeds the window or if either axis is
|
||||
/// not an exact multiple of the patch extent (divisibility policy above).
|
||||
pub fn validate_for_window(&self, time: usize, subc: usize) -> Result<(), MaeError> {
|
||||
check_axis("time", time, self.patch_time)?;
|
||||
check_axis("subcarrier", subc, self.patch_subc)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Largest `(time, subc)` crop of the given window that is exactly
|
||||
/// divisible by the patch dimensions. Either component may be 0 when the
|
||||
/// window is smaller than one patch.
|
||||
#[must_use]
|
||||
pub fn cropped_window_shape(&self, time: usize, subc: usize) -> (usize, usize) {
|
||||
(
|
||||
(time / self.patch_time) * self.patch_time,
|
||||
(subc / self.patch_subc) * self.patch_subc,
|
||||
)
|
||||
}
|
||||
|
||||
/// Number of patches a `time × subc` window yields under this recipe.
|
||||
pub fn num_patches(&self, time: usize, subc: usize) -> Result<usize, MaeError> {
|
||||
self.validate_for_window(time, subc)?;
|
||||
Ok((time / self.patch_time) * (subc / self.patch_subc))
|
||||
}
|
||||
|
||||
/// Exact number of masked patches for a grid of `n_patches`:
|
||||
/// `round(mask_ratio * n_patches)`, clamped to `[0, n_patches]`.
|
||||
#[must_use]
|
||||
pub fn num_masked(&self, n_patches: usize) -> usize {
|
||||
((self.mask_ratio * n_patches as f64).round() as usize).min(n_patches)
|
||||
}
|
||||
|
||||
/// Patchify `window` and draw the deterministic random mask in one step,
|
||||
/// using `self.seed`. See [`patchify`] and [`random_mask`].
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Everything [`patchify`] rejects, plus [`MaeError::InvalidMaskRatio`]
|
||||
/// if `self.mask_ratio` is not finite or outside `(0, 1)` (the
|
||||
/// [`Self::validate`] rule) — a NaN ratio must never silently mask zero
|
||||
/// patches.
|
||||
pub fn mask_window(
|
||||
&self,
|
||||
window: &[f32],
|
||||
time: usize,
|
||||
subc: usize,
|
||||
) -> Result<(PatchGrid, MaskIndices), MaeError> {
|
||||
let grid = patchify(window, time, subc, self)?;
|
||||
let mask = random_mask(grid.n_patches(), self.mask_ratio, self.seed)?;
|
||||
Ok((grid, mask))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PatchGrid / MaskIndices
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A CSI window decomposed into non-overlapping `patch_time × patch_subc`
|
||||
/// patches (see the module-level axis convention).
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PatchGrid {
|
||||
/// Patch extent along the time axis.
|
||||
pub patch_time: usize,
|
||||
/// Patch extent along the subcarrier axis.
|
||||
pub patch_subc: usize,
|
||||
/// Number of patch rows (`time / patch_time`).
|
||||
pub n_patches_time: usize,
|
||||
/// Number of patch columns (`subc / patch_subc`).
|
||||
pub n_patches_subc: usize,
|
||||
/// Flattened patches, row-major over the grid; each inner `Vec` is one
|
||||
/// patch of length `patch_time * patch_subc`, row-major time-major.
|
||||
pub patches: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
impl PatchGrid {
|
||||
/// Total number of patches in the grid.
|
||||
#[must_use]
|
||||
pub fn n_patches(&self) -> usize {
|
||||
self.n_patches_time * self.n_patches_subc
|
||||
}
|
||||
|
||||
/// Number of scalar values per patch.
|
||||
#[must_use]
|
||||
pub fn patch_len(&self) -> usize {
|
||||
self.patch_time * self.patch_subc
|
||||
}
|
||||
|
||||
/// Window shape `(time, subc)` this grid reconstructs to.
|
||||
#[must_use]
|
||||
pub fn window_shape(&self) -> (usize, usize) {
|
||||
(
|
||||
self.n_patches_time * self.patch_time,
|
||||
self.n_patches_subc * self.patch_subc,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sorted, disjoint patch-index sets produced by [`random_mask`]. Together
|
||||
/// they cover `0..n_patches` exactly.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MaskIndices {
|
||||
/// Indices of patches hidden from the encoder (`round(ratio * n)` of them).
|
||||
pub masked: Vec<usize>,
|
||||
/// Indices of patches the encoder sees.
|
||||
pub visible: Vec<usize>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// patchify / unpatchify
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decompose a row-major `time × subc` CSI window into the patch grid defined
|
||||
/// by `cfg`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// - [`MaeError::WindowShapeMismatch`] if `window.len() != time * subc`.
|
||||
/// - [`MaeError::PatchExceedsWindow`] / [`MaeError::NotDivisible`] per the
|
||||
/// module-level divisibility policy.
|
||||
/// - [`MaeError::NonFiniteValue`] on the first NaN/±inf encountered —
|
||||
/// corrupted CSI must be cleaned upstream, never masked over (cf. the
|
||||
/// WiFlow-STD NaN-poisoning incident, ADR-152 §2.2).
|
||||
pub fn patchify(
|
||||
window: &[f32],
|
||||
time: usize,
|
||||
subc: usize,
|
||||
cfg: &MaePretrainConfig,
|
||||
) -> Result<PatchGrid, MaeError> {
|
||||
let expected = time * subc;
|
||||
if window.len() != expected {
|
||||
return Err(MaeError::WindowShapeMismatch {
|
||||
time,
|
||||
subc,
|
||||
expected,
|
||||
actual: window.len(),
|
||||
});
|
||||
}
|
||||
cfg.validate_for_window(time, subc)?;
|
||||
if let Some(idx) = window.iter().position(|v| !v.is_finite()) {
|
||||
return Err(MaeError::NonFiniteValue {
|
||||
row: idx / subc,
|
||||
col: idx % subc,
|
||||
value: window[idx],
|
||||
});
|
||||
}
|
||||
|
||||
let n_patches_time = time / cfg.patch_time;
|
||||
let n_patches_subc = subc / cfg.patch_subc;
|
||||
let mut patches = Vec::with_capacity(n_patches_time * n_patches_subc);
|
||||
for pt in 0..n_patches_time {
|
||||
for ps in 0..n_patches_subc {
|
||||
let mut patch = Vec::with_capacity(cfg.patch_time * cfg.patch_subc);
|
||||
for lt in 0..cfg.patch_time {
|
||||
let t = pt * cfg.patch_time + lt;
|
||||
let row_start = t * subc + ps * cfg.patch_subc;
|
||||
patch.extend_from_slice(&window[row_start..row_start + cfg.patch_subc]);
|
||||
}
|
||||
patches.push(patch);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PatchGrid {
|
||||
patch_time: cfg.patch_time,
|
||||
patch_subc: cfg.patch_subc,
|
||||
n_patches_time,
|
||||
n_patches_subc,
|
||||
patches,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reassemble the full row-major `time × subc` window from a [`PatchGrid`].
|
||||
/// Exact inverse of [`patchify`].
|
||||
#[must_use]
|
||||
pub fn unpatchify(grid: &PatchGrid) -> Vec<f32> {
|
||||
unpatchify_select(grid, None, 0.0)
|
||||
}
|
||||
|
||||
/// Reassemble the window keeping only the patches listed in `visible`;
|
||||
/// every other patch's region is filled with `fill` (the standard MAE
|
||||
/// "visible tokens + mask token" view of the input).
|
||||
#[must_use]
|
||||
pub fn unpatchify_visible(grid: &PatchGrid, visible: &[usize], fill: f32) -> Vec<f32> {
|
||||
unpatchify_select(grid, Some(visible), fill)
|
||||
}
|
||||
|
||||
fn unpatchify_select(grid: &PatchGrid, keep: Option<&[usize]>, fill: f32) -> Vec<f32> {
|
||||
let (time, subc) = grid.window_shape();
|
||||
let mut window = vec![fill; time * subc];
|
||||
for (p, patch) in grid.patches.iter().enumerate() {
|
||||
if let Some(keep) = keep {
|
||||
if !keep.contains(&p) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let pt = p / grid.n_patches_subc;
|
||||
let ps = p % grid.n_patches_subc;
|
||||
for lt in 0..grid.patch_time {
|
||||
let t = pt * grid.patch_time + lt;
|
||||
let row_start = t * subc + ps * grid.patch_subc;
|
||||
let local_start = lt * grid.patch_subc;
|
||||
window[row_start..row_start + grid.patch_subc]
|
||||
.copy_from_slice(&patch[local_start..local_start + grid.patch_subc]);
|
||||
}
|
||||
}
|
||||
window
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// random_mask
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Draw a deterministic random mask over `n_patches` patches.
|
||||
///
|
||||
/// Exactly `round(mask_ratio * n_patches)` patches (clamped to
|
||||
/// `[0, n_patches]`) are masked, chosen by a seeded Fisher–Yates shuffle
|
||||
/// ([`Xorshift64`]), so the same `(n_patches, mask_ratio, seed)` triple always
|
||||
/// yields the same mask. Both index lists are sorted ascending, disjoint, and
|
||||
/// together cover `0..n_patches`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// [`MaeError::InvalidMaskRatio`] if `mask_ratio` is not finite or outside
|
||||
/// the open interval `(0, 1)` — the same rule as
|
||||
/// [`MaePretrainConfig::validate`]. Erroring (never clamping) keeps the
|
||||
/// module's error-not-silent policy: a NaN ratio would otherwise silently
|
||||
/// mask zero patches and a ratio ≥ 1 would mask everything.
|
||||
pub fn random_mask(n_patches: usize, mask_ratio: f64, seed: u64) -> Result<MaskIndices, MaeError> {
|
||||
if !mask_ratio.is_finite() || mask_ratio <= 0.0 || mask_ratio >= 1.0 {
|
||||
return Err(MaeError::InvalidMaskRatio { ratio: mask_ratio });
|
||||
}
|
||||
let n_masked = ((mask_ratio * n_patches as f64).round() as usize).min(n_patches);
|
||||
let mut order: Vec<usize> = (0..n_patches).collect();
|
||||
let mut rng = Xorshift64::new(seed);
|
||||
for i in (1..n_patches).rev() {
|
||||
let j = (rng.next_u64() % (i as u64 + 1)) as usize;
|
||||
order.swap(i, j);
|
||||
}
|
||||
let mut masked: Vec<usize> = order[..n_masked].to_vec();
|
||||
let mut visible: Vec<usize> = order[n_masked..].to_vec();
|
||||
masked.sort_unstable();
|
||||
visible.sort_unstable();
|
||||
Ok(MaskIndices { masked, visible })
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn check_axis(axis: &'static str, window: usize, patch: usize) -> Result<(), MaeError> {
|
||||
if patch > window {
|
||||
return Err(MaeError::PatchExceedsWindow {
|
||||
axis,
|
||||
patch,
|
||||
window,
|
||||
});
|
||||
}
|
||||
let remainder = window % patch;
|
||||
if remainder != 0 {
|
||||
return Err(MaeError::NotDivisible {
|
||||
axis,
|
||||
window,
|
||||
patch,
|
||||
remainder,
|
||||
crop: window - remainder,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
|
||||
use petgraph::graph::{DiGraph, NodeIndex};
|
||||
use petgraph::visit::EdgeRef;
|
||||
use ruvector_mincut::{DynamicMinCut, MinCutBuilder};
|
||||
use std::collections::VecDeque;
|
||||
|
||||
@@ -106,6 +107,24 @@ impl Default for MetricsResult {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EvalMetrics
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Per-evaluation pose metrics.
|
||||
///
|
||||
/// Plain value container produced by evaluation runs: lower `mpjpe`/`gps`
|
||||
/// and higher `pck_at_05` indicate better predictions.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq)]
|
||||
pub struct EvalMetrics {
|
||||
/// Mean Per-Joint Position Error (normalised units).
|
||||
pub mpjpe: f64,
|
||||
/// Percentage of Correct Keypoints at threshold 0.05 (0-1 scale).
|
||||
pub pck_at_05: f64,
|
||||
/// Geodesic Point Similarity error for DensePose surface predictions.
|
||||
pub gps: f64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MetricsAccumulator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -126,7 +126,15 @@ impl WiFiDensePoseModel {
|
||||
tch::no_grad(|| self.forward_impl(amplitude, phase, false))
|
||||
}
|
||||
|
||||
/// Save model weights to a file (tch safetensors / .pt format).
|
||||
/// Save model weights to a file. The tch `VarStore` dispatches the format
|
||||
/// on the file extension: `.safetensors` → safetensors, anything else →
|
||||
/// torch `.pt`.
|
||||
///
|
||||
/// **Platform constraint:** prefer `.safetensors`. The `.pt` path
|
||||
/// (`_save_parameters`/`_load_parameters`) is broken on Windows with
|
||||
/// torch 2.11 (GenericDict internal assert on the load roundtrip — see
|
||||
/// `wiflow_std/model.rs::save_and_load_roundtrip`), which is why
|
||||
/// [`crate::trainer::Trainer`] writes `.safetensors` checkpoints.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
@@ -137,7 +145,8 @@ impl WiFiDensePoseModel {
|
||||
.map_err(|e| TrainError::training_step(format!("save failed: {e}")))
|
||||
}
|
||||
|
||||
/// Load model weights from a file.
|
||||
/// Load model weights from a file (format dispatched on extension; see
|
||||
/// the `.pt`-on-Windows caveat on [`Self::save`]).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
@@ -182,7 +191,7 @@ impl WiFiDensePoseModel {
|
||||
self.vs
|
||||
.trainable_variables()
|
||||
.iter()
|
||||
.map(|t| t.numel())
|
||||
.map(|t| t.numel() as i64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
@@ -297,7 +306,12 @@ fn apply_antenna_attention(x: &Tensor, lambda: f32) -> Tensor {
|
||||
let xi = x.select(0, bi as i64); // [n_ant, n_sc]
|
||||
|
||||
// Move to CPU and convert to f32 for the pure-Rust attention kernel.
|
||||
let flat: Vec<f32> = Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous());
|
||||
let flat: Vec<f32> = Vec::<f32>::try_from(
|
||||
xi.to_kind(Kind::Float)
|
||||
.to_device(Device::Cpu)
|
||||
.flatten(0, -1),
|
||||
)
|
||||
.expect("antenna tensor to vec");
|
||||
|
||||
// Q = K = V = the antenna features (self-attention over antenna paths).
|
||||
let out = attn_mincut(
|
||||
@@ -350,7 +364,12 @@ fn apply_spatial_attention(x: &Tensor) -> Tensor {
|
||||
for bi in 0..b {
|
||||
// Extract [C, H*W] and transpose to [H*W, C].
|
||||
let xi = x.select(0, bi).reshape([c, h * w]).transpose(0, 1); // [H*W, C]
|
||||
let flat: Vec<f32> = Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous());
|
||||
let flat: Vec<f32> = Vec::<f32>::try_from(
|
||||
xi.to_kind(Kind::Float)
|
||||
.to_device(Device::Cpu)
|
||||
.flatten(0, -1),
|
||||
)
|
||||
.expect("spatial tensor to vec");
|
||||
|
||||
// Build token slices — one per spatial position.
|
||||
let tokens: Vec<&[f32]> = (0..n_spatial).map(|i| &flat[i * d..(i + 1) * d]).collect();
|
||||
@@ -973,7 +992,9 @@ mod tests {
|
||||
let mut model = WiFiDensePoseModel::new(&cfg, Device::Cpu);
|
||||
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let path = tmp.path().join("weights.pt");
|
||||
// safetensors, not .pt: this torch build's .pt roundtrip is broken on
|
||||
// Windows (torch 2.11 GenericDict internal assert).
|
||||
let path = tmp.path().join("weights.safetensors");
|
||||
|
||||
model.save(&path).expect("save should succeed");
|
||||
model.load(&path).expect("load should succeed");
|
||||
|
||||
@@ -153,11 +153,11 @@ pub fn run_proof(proof_dir: &Path) -> Result<ProofResult, Box<dyn std::error::Er
|
||||
let num_kp = kp.size()[1] as usize;
|
||||
let hm_size = cfg.heatmap_size;
|
||||
|
||||
let kp_vec: Vec<f32> = Vec::<f64>::from(kp.to_kind(Kind::Double).flatten(0, -1))
|
||||
let kp_vec: Vec<f32> = Vec::<f64>::try_from(kp.to_kind(Kind::Double).flatten(0, -1))?
|
||||
.iter()
|
||||
.map(|&x| x as f32)
|
||||
.collect();
|
||||
let vis_vec: Vec<f32> = Vec::<f64>::from(vis.to_kind(Kind::Double).flatten(0, -1))
|
||||
let vis_vec: Vec<f32> = Vec::<f64>::try_from(vis.to_kind(Kind::Double).flatten(0, -1))?
|
||||
.iter()
|
||||
.map(|&x| x as f32)
|
||||
.collect();
|
||||
@@ -261,7 +261,7 @@ pub fn hash_model_weights(model: &WiFiDensePoseModel) -> String {
|
||||
.flatten(0, -1)
|
||||
.to_kind(Kind::Float)
|
||||
.to_device(Device::Cpu);
|
||||
let values: Vec<f32> = Vec::<f32>::from(&flat);
|
||||
let values: Vec<f32> = Vec::<f32>::try_from(&flat).expect("param tensor to vec");
|
||||
let mut buf = vec![0u8; values.len() * 4];
|
||||
for (i, v) in values.iter().enumerate() {
|
||||
let bytes = v.to_le_bytes();
|
||||
@@ -292,6 +292,15 @@ pub fn load_expected_hash(proof_dir: &Path) -> Result<Option<String>, std::io::E
|
||||
Ok(if hash.is_empty() { None } else { Some(hash) })
|
||||
}
|
||||
|
||||
/// Verify that `path` is a valid checkpoint directory.
|
||||
///
|
||||
/// Returns `true` only when the path exists and is a directory. Deterministic
|
||||
/// and side-effect free — repeated calls always return the same result for an
|
||||
/// unchanged filesystem.
|
||||
pub fn verify_checkpoint_dir(path: &Path) -> bool {
|
||||
path.is_dir()
|
||||
}
|
||||
|
||||
/// Save the expected model hash to `<proof_dir>/expected_proof.sha256`.
|
||||
///
|
||||
/// Creates `proof_dir` if it does not already exist.
|
||||
|
||||
@@ -286,7 +286,12 @@ impl Trainer {
|
||||
best_epoch = epoch;
|
||||
patience_counter = 0;
|
||||
|
||||
let ckpt_name = format!("best_epoch{epoch:04}_pck{val_pck:.4}.pt");
|
||||
// .safetensors, not .pt: VarStore dispatches the format on
|
||||
// the extension, and this torch build's .pt
|
||||
// _save_parameters/_load_parameters roundtrip is broken on
|
||||
// Windows (torch 2.11 GenericDict internal assert — see
|
||||
// wiflow_std/model.rs save_and_load_roundtrip).
|
||||
let ckpt_name = format!("best_epoch{epoch:04}_pck{val_pck:.4}.safetensors");
|
||||
let ckpt_path = self.config.checkpoint_dir.join(&ckpt_name);
|
||||
|
||||
match self.model.save(&ckpt_path) {
|
||||
@@ -339,8 +344,8 @@ impl Trainer {
|
||||
}
|
||||
}
|
||||
|
||||
// Save final model regardless.
|
||||
let final_ckpt = self.config.checkpoint_dir.join("final.pt");
|
||||
// Save final model regardless (.safetensors — see checkpoint note above).
|
||||
let final_ckpt = self.config.checkpoint_dir.join("final.safetensors");
|
||||
if let Err(e) = self.model.save(&final_ckpt) {
|
||||
warn!("Failed to save final model: {e}");
|
||||
}
|
||||
@@ -413,7 +418,8 @@ impl Trainer {
|
||||
.load(path)
|
||||
.map_err(|e| TrainError::checkpoint(e.to_string(), path))?;
|
||||
|
||||
// Try to parse the epoch from the filename (e.g. "best_epoch0042_pck0.7842.pt").
|
||||
// Try to parse the epoch from the filename, extension-agnostic
|
||||
// (e.g. "best_epoch0042_pck0.7842.safetensors").
|
||||
let epoch = path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
@@ -582,11 +588,13 @@ fn kp_to_heatmap_tensor(
|
||||
let num_kp = kp_tensor.size()[1] as usize;
|
||||
|
||||
// Convert to ndarray for generate_target_heatmaps.
|
||||
let kp_vec: Vec<f32> = Vec::<f64>::from(kp_tensor.to_kind(Kind::Double).flatten(0, -1))
|
||||
let kp_vec: Vec<f32> = Vec::<f64>::try_from(kp_tensor.to_kind(Kind::Double).flatten(0, -1))
|
||||
.expect("kp tensor to vec")
|
||||
.iter()
|
||||
.map(|&x| x as f32)
|
||||
.collect();
|
||||
let vis_vec: Vec<f32> = Vec::<f64>::from(vis_tensor.to_kind(Kind::Double).flatten(0, -1))
|
||||
let vis_vec: Vec<f32> = Vec::<f64>::try_from(vis_tensor.to_kind(Kind::Double).flatten(0, -1))
|
||||
.expect("vis tensor to vec")
|
||||
.iter()
|
||||
.map(|&x| x as f32)
|
||||
.collect();
|
||||
@@ -622,8 +630,8 @@ fn heatmap_to_keypoints(heatmaps: &Tensor) -> Tensor {
|
||||
let arg = flat.argmax(-1, false);
|
||||
|
||||
// Decompose linear index into (row, col).
|
||||
let row = (&arg / w).to_kind(Kind::Float); // [B, 17]
|
||||
let col = (&arg % w).to_kind(Kind::Float); // [B, 17]
|
||||
let row = arg.divide_scalar_mode(w, "floor").to_kind(Kind::Float); // [B, 17]
|
||||
let col = arg.remainder(w).to_kind(Kind::Float); // [B, 17]
|
||||
|
||||
// Normalize to [0, 1]
|
||||
let x = col / (w - 1) as f64;
|
||||
@@ -639,7 +647,8 @@ fn heatmap_to_keypoints(heatmaps: &Tensor) -> Tensor {
|
||||
fn extract_kp_ndarray(kp_tensor: &Tensor, batch_idx: usize) -> Array2<f32> {
|
||||
let num_kp = kp_tensor.size()[1] as usize;
|
||||
let row = kp_tensor.select(0, batch_idx as i64);
|
||||
let data: Vec<f32> = Vec::<f64>::from(row.to_kind(Kind::Double).flatten(0, -1))
|
||||
let data: Vec<f32> = Vec::<f64>::try_from(row.to_kind(Kind::Double).flatten(0, -1))
|
||||
.expect("kp tensor to vec")
|
||||
.iter()
|
||||
.map(|&v| v as f32)
|
||||
.collect();
|
||||
@@ -652,7 +661,8 @@ fn extract_kp_ndarray(kp_tensor: &Tensor, batch_idx: usize) -> Array2<f32> {
|
||||
fn extract_vis_ndarray(vis_tensor: &Tensor, batch_idx: usize) -> Array1<f32> {
|
||||
let num_kp = vis_tensor.size()[1] as usize;
|
||||
let row = vis_tensor.select(0, batch_idx as i64);
|
||||
let data: Vec<f32> = Vec::<f64>::from(row.to_kind(Kind::Double))
|
||||
let data: Vec<f32> = Vec::<f64>::try_from(row.to_kind(Kind::Double))
|
||||
.expect("vis tensor to vec")
|
||||
.iter()
|
||||
.map(|&v| v as f32)
|
||||
.collect();
|
||||
|
||||
@@ -0,0 +1,899 @@
|
||||
//! Configuration and pure-Rust shape/parameter math for WiFlow-STD
|
||||
//! (ADR-152 §2.2). See the [module docs](crate::wiflow_std) for provenance.
|
||||
//!
|
||||
//! Everything here compiles without the `tch-backend` feature so the
|
||||
//! architecture's invariants (parameter count, output shapes, divisibility
|
||||
//! constraints) are unit-testable under `--no-default-features`. The
|
||||
//! 15-keypoint default must yield exactly **2,225,042** parameters — the
|
||||
//! count verified against the upstream reference (`RESULTS.md`).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// TCN kernel size — fixed at 3 in the reference architecture.
|
||||
pub const TCN_KERNEL: usize = 3;
|
||||
|
||||
/// Dropout used inside the 2-D conv blocks (`Dropout2d`). The reference
|
||||
/// hardcodes 0.3 in `convnet.py` (the model-level `dropout` argument is only
|
||||
/// forwarded to the TCN), so it is a constant here rather than a config field.
|
||||
pub const CONV_BLOCK_DROPOUT: f64 = 0.3;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TcnGroupsMode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// How the group count of each depthwise-grouped TCN convolution is chosen
|
||||
/// (ADR-152 efficiency sweep, `benchmarks/wiflow-std/remote/sweep/model_compact.py`).
|
||||
///
|
||||
/// The upstream reference hardcodes `groups = 20`, which does not divide the
|
||||
/// compact variants' channel counts (e.g. 270, 135, 85). The sweep's rules:
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TcnGroupsMode {
|
||||
/// Every grouped conv uses [`WiFlowStdConfig::tcn_groups`] verbatim
|
||||
/// (upstream behavior; requires divisibility). Default.
|
||||
#[default]
|
||||
Fixed,
|
||||
/// Per-conv groups = `gcd(channels, tcn_groups)` — equals `tcn_groups`
|
||||
/// wherever the upstream choice is valid (incl. the 540-channel input
|
||||
/// conv) and falls back to the largest common divisor otherwise.
|
||||
/// The sweep's `gcd20` mode (`half` / `quarter` presets).
|
||||
Gcd,
|
||||
/// Per-conv groups = channels (fully depthwise; `tiny` preset).
|
||||
Depthwise,
|
||||
}
|
||||
|
||||
fn gcd(a: usize, b: usize) -> usize {
|
||||
let (mut a, mut b) = (a, b);
|
||||
while b != 0 {
|
||||
(a, b) = (b, a % b);
|
||||
}
|
||||
a
|
||||
}
|
||||
|
||||
fn default_input_pw_groups() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_min_feature_width() -> usize {
|
||||
15
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WiFlowStdConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Hyper-parameters for the WiFlow-STD pose model (ADR-152 §2.2).
|
||||
///
|
||||
/// Defaults reproduce the verified upstream architecture exactly (2,225,042
|
||||
/// parameters, 15 keypoints). For RuView's ESP32 17-keypoint eval set
|
||||
/// (ADR-152 §2.2(b)) use [`WiFlowStdConfig::for_keypoints`]`(17)` — the
|
||||
/// keypoint count only changes the final adaptive pooling, not the parameter
|
||||
/// count, so retrained 15-keypoint weights remain shape-compatible.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WiFlowStdConfig {
|
||||
/// CSI input feature dimension (subcarriers × antenna paths flattened).
|
||||
/// Must be divisible by [`Self::tcn_groups`]. Default: **540**.
|
||||
pub subcarriers: usize,
|
||||
|
||||
/// Temporal window length in CSI frames. Default: **20**.
|
||||
pub window: usize,
|
||||
|
||||
/// Output channels of each TCN level (dilation doubles per level:
|
||||
/// 1, 2, 4, 8, …). Every entry must be divisible by [`Self::tcn_groups`].
|
||||
/// Default: **[540, 440, 340, 240]** — the `models/` code values, *not*
|
||||
/// upstream `config.py`'s stale `[480, 360, 240]`.
|
||||
pub tcn_channels: Vec<usize>,
|
||||
|
||||
/// Group count for the depthwise-grouped TCN convolutions. The reference
|
||||
/// hardcodes **20**; exposed so non-540 subcarrier layouts can keep the
|
||||
/// divisibility invariant. Default: **20**. Interpreted per
|
||||
/// [`Self::tcn_groups_mode`]: the verbatim group count in `Fixed` mode,
|
||||
/// the gcd base in `Gcd` mode, ignored in `Depthwise` mode.
|
||||
pub tcn_groups: usize,
|
||||
|
||||
/// Group-selection rule for the TCN's grouped convolutions
|
||||
/// (ADR-152 efficiency sweep). Default: [`TcnGroupsMode::Fixed`]
|
||||
/// (upstream behavior — every grouped conv uses [`Self::tcn_groups`]).
|
||||
#[serde(default)]
|
||||
pub tcn_groups_mode: TcnGroupsMode,
|
||||
|
||||
/// Group count for the **first** TCN block's pointwise (1×1) and residual
|
||||
/// downsample convs (`subcarriers → tcn_channels[0]`). The sweep's `tiny`
|
||||
/// variant uses **4** to break the dense-540-input parameter floor
|
||||
/// (~117k params, which alone exceeds tiny's budget); every other config
|
||||
/// uses **1** (upstream behavior). Must divide both `subcarriers` and
|
||||
/// `tcn_channels[0]`. Default: **1**.
|
||||
#[serde(default = "default_input_pw_groups")]
|
||||
pub input_pw_groups: usize,
|
||||
|
||||
/// Output channels of the 2-D conv encoder blocks. The first entry is
|
||||
/// also `ConvBlock1`'s output; each subsequent block downsamples the
|
||||
/// subcarrier axis by 2. Default: **[8, 16, 32, 64]**.
|
||||
pub conv_channels: Vec<usize>,
|
||||
|
||||
/// Attention head groups for the dual axial attention. Must divide the
|
||||
/// last entry of [`Self::conv_channels`]. Default: **8**.
|
||||
pub attention_groups: usize,
|
||||
|
||||
/// Number of 2-D keypoints produced. Default: **15** (upstream skeleton);
|
||||
/// use **17** for RuView's COCO-skeleton ESP32 eval set. Only changes the
|
||||
/// parameter-free final adaptive pool — never the trunk: the stride
|
||||
/// schedule is governed by [`Self::min_feature_width`], so 15- and
|
||||
/// 17-keypoint variants share the identical conv graph and weights
|
||||
/// (matching the validated Python protocol,
|
||||
/// `benchmarks/wiflow-std/remote/measb/train_measb.py`, which swaps only
|
||||
/// `avg_pool` and loads the pretrained state_dict `strict=True`).
|
||||
pub keypoints: usize,
|
||||
|
||||
/// Floor for the conv encoder's width downsampling: each
|
||||
/// `AsymmetricConvBlock` halves the width only while the result stays
|
||||
/// ≥ this value (see [`Self::conv_strides`]).
|
||||
///
|
||||
/// Default: **15** — the upstream constant. Provenance: the reference's
|
||||
/// four hardcoded stride-2 blocks exist because its 240-channel TCN
|
||||
/// output halves cleanly four times, 240 / 2⁴ = 15. The compact presets'
|
||||
/// schedules were derived with this same floor. Override only when
|
||||
/// designing a new trunk; do **not** couple it to [`Self::keypoints`] —
|
||||
/// the adaptive pool maps the decoder height to any keypoint count.
|
||||
#[serde(default = "default_min_feature_width")]
|
||||
pub min_feature_width: usize,
|
||||
|
||||
/// Elementwise dropout probability inside the TCN blocks, in `[0, 1)`.
|
||||
/// Default: **0.5** (the value used by our verified retraining run).
|
||||
pub dropout: f64,
|
||||
}
|
||||
|
||||
impl Default for WiFlowStdConfig {
|
||||
fn default() -> Self {
|
||||
WiFlowStdConfig {
|
||||
subcarriers: 540,
|
||||
window: 20,
|
||||
tcn_channels: vec![540, 440, 340, 240],
|
||||
tcn_groups: 20,
|
||||
tcn_groups_mode: TcnGroupsMode::Fixed,
|
||||
input_pw_groups: 1,
|
||||
conv_channels: vec![8, 16, 32, 64],
|
||||
attention_groups: 8,
|
||||
keypoints: 15,
|
||||
min_feature_width: 15,
|
||||
dropout: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WiFlowStdConfig {
|
||||
/// Default architecture with a different keypoint count (e.g. 17 for the
|
||||
/// ESP32 COCO-skeleton eval set, ADR-152 §2.2(b)).
|
||||
///
|
||||
/// The trunk is untouched: [`Self::min_feature_width`] stays at the
|
||||
/// upstream floor of 15, so e.g. `for_keypoints(17)` keeps the trained
|
||||
/// `[2, 2, 2, 2]` stride schedule (feature width 15) and the adaptive
|
||||
/// pool maps 15 → 17 — exactly the validated Python protocol
|
||||
/// (`benchmarks/wiflow-std/remote/measb/train_measb.py`).
|
||||
pub fn for_keypoints(keypoints: usize) -> Self {
|
||||
WiFlowStdConfig {
|
||||
keypoints,
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// **half** compact preset (ADR-152 efficiency sweep, trained
|
||||
/// 2026-06-10/11): **843,834** parameters (0.38×), clean-test PCK@20
|
||||
/// **96.62%** — strictly dominates the full reference on its own
|
||||
/// benchmark. Per-conv groups = `gcd(channels, 20)`; stride schedule
|
||||
/// derives to `[2, 2, 2, 1]`. See
|
||||
/// `benchmarks/wiflow-std/results/efficiency_sweep.jsonl`.
|
||||
pub fn half() -> Self {
|
||||
WiFlowStdConfig {
|
||||
tcn_channels: vec![270, 220, 170, 120],
|
||||
tcn_groups_mode: TcnGroupsMode::Gcd,
|
||||
conv_channels: vec![4, 8, 16, 32],
|
||||
attention_groups: 4,
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// **quarter** compact preset (ADR-152 efficiency sweep): **338,600**
|
||||
/// parameters (0.15×), clean-test PCK@20 **96.05%**. Per-conv groups =
|
||||
/// `gcd(channels, 20)`; stride schedule derives to `[2, 2, 1, 1]`.
|
||||
pub fn quarter() -> Self {
|
||||
WiFlowStdConfig {
|
||||
tcn_channels: vec![135, 110, 85, 60],
|
||||
tcn_groups_mode: TcnGroupsMode::Gcd,
|
||||
conv_channels: vec![2, 4, 8, 16],
|
||||
attention_groups: 2,
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// **tiny** compact preset (ADR-152 efficiency sweep): **56,290**
|
||||
/// parameters (0.025×), clean-test PCK@20 **94.11%** — the smallest
|
||||
/// deployable WiFlow-class model (~220 KB fp32). Fully depthwise TCN
|
||||
/// groups plus `input_pw_groups = 4` on the first block's pointwise /
|
||||
/// downsample convs; stride schedule derives to `[2, 1, 1, 1]`
|
||||
/// (feature width 16).
|
||||
pub fn tiny() -> Self {
|
||||
WiFlowStdConfig {
|
||||
tcn_channels: vec![68, 56, 44, 32],
|
||||
tcn_groups_mode: TcnGroupsMode::Depthwise,
|
||||
input_pw_groups: 4,
|
||||
conv_channels: vec![2, 4, 8, 16],
|
||||
attention_groups: 2,
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate all architectural invariants.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ConfigError::InvalidValue`] naming the offending field.
|
||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||
if self.subcarriers == 0 {
|
||||
return Err(ConfigError::invalid_value("subcarriers", "must be >= 1"));
|
||||
}
|
||||
if self.window == 0 {
|
||||
return Err(ConfigError::invalid_value("window", "must be >= 1"));
|
||||
}
|
||||
if self.tcn_groups == 0 {
|
||||
return Err(ConfigError::invalid_value("tcn_groups", "must be >= 1"));
|
||||
}
|
||||
// In Gcd mode the per-conv group count is gcd(channels, tcn_groups)
|
||||
// and in Depthwise mode it is the channel count itself, so the
|
||||
// divisibility invariant holds by construction; only Fixed mode
|
||||
// (upstream behavior) needs the explicit checks.
|
||||
let fixed = self.tcn_groups_mode == TcnGroupsMode::Fixed;
|
||||
if fixed && self.subcarriers % self.tcn_groups != 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"subcarriers",
|
||||
format!(
|
||||
"{} is not divisible by tcn_groups={} (grouped conv requirement)",
|
||||
self.subcarriers, self.tcn_groups
|
||||
),
|
||||
));
|
||||
}
|
||||
if self.tcn_channels.is_empty() {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"tcn_channels",
|
||||
"must contain at least one level",
|
||||
));
|
||||
}
|
||||
for (i, &c) in self.tcn_channels.iter().enumerate() {
|
||||
if c == 0 || (fixed && c % self.tcn_groups != 0) {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"tcn_channels",
|
||||
format!(
|
||||
"level {i} has {c} channels; must be > 0 and divisible by tcn_groups={}",
|
||||
self.tcn_groups
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
if self.input_pw_groups == 0
|
||||
|| self.subcarriers % self.input_pw_groups != 0
|
||||
|| self.tcn_channels[0] % self.input_pw_groups != 0
|
||||
{
|
||||
return Err(ConfigError::invalid_value(
|
||||
"input_pw_groups",
|
||||
format!(
|
||||
"{} must be >= 1 and divide both subcarriers={} and tcn_channels[0]={}",
|
||||
self.input_pw_groups, self.subcarriers, self.tcn_channels[0]
|
||||
),
|
||||
));
|
||||
}
|
||||
if self.conv_channels.is_empty() {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"conv_channels",
|
||||
"must contain at least one block",
|
||||
));
|
||||
}
|
||||
if self.conv_channels.iter().any(|&c| c == 0) {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"conv_channels",
|
||||
"all blocks must have > 0 channels",
|
||||
));
|
||||
}
|
||||
let c_last = *self.conv_channels.last().expect("non-empty checked above");
|
||||
if self.attention_groups == 0 || c_last % self.attention_groups != 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"attention_groups",
|
||||
format!(
|
||||
"{} must be >= 1 and divide the last conv channel count {c_last}",
|
||||
self.attention_groups
|
||||
),
|
||||
));
|
||||
}
|
||||
if c_last < 2 || c_last % 2 != 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"conv_channels",
|
||||
format!("last block has {c_last} channels; decoder needs an even count >= 2"),
|
||||
));
|
||||
}
|
||||
if self.keypoints == 0 {
|
||||
return Err(ConfigError::invalid_value("keypoints", "must be >= 1"));
|
||||
}
|
||||
if self.min_feature_width == 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"min_feature_width",
|
||||
"must be >= 1",
|
||||
));
|
||||
}
|
||||
if !self.dropout.is_finite() || !(0.0..1.0).contains(&self.dropout) {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"dropout",
|
||||
format!("{} is outside [0, 1)", self.dropout),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Shape inference
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Channel count produced by the TCN stack (last TCN level). This is the
|
||||
/// *width* of the image-like tensor fed to the 2-D encoder.
|
||||
pub fn tcn_output_channels(&self) -> usize {
|
||||
*self.tcn_channels.last().unwrap_or(&0)
|
||||
}
|
||||
|
||||
/// Group count of a grouped TCN conv over `channels` channels, per
|
||||
/// [`Self::tcn_groups_mode`].
|
||||
pub fn tcn_conv_groups(&self, channels: usize) -> usize {
|
||||
match self.tcn_groups_mode {
|
||||
TcnGroupsMode::Fixed => self.tcn_groups,
|
||||
TcnGroupsMode::Gcd => gcd(channels, self.tcn_groups),
|
||||
TcnGroupsMode::Depthwise => channels,
|
||||
}
|
||||
}
|
||||
|
||||
/// Width stride of each `AsymmetricConvBlock`, derived with the sweep's
|
||||
/// rule (`model_compact.py::compute_strides`): halve the width
|
||||
/// (`w → ceil(w / 2)`, the `(1,3)`-kernel stride-2 output size) only
|
||||
/// while the result stays ≥ [`Self::min_feature_width`]. At the upstream
|
||||
/// default (240 TCN channels, floor 15) this derives `[2, 2, 2, 2]` —
|
||||
/// the hardcoded upstream schedule, exactly.
|
||||
///
|
||||
/// Deliberately independent of [`Self::keypoints`]: the keypoint count
|
||||
/// only changes the parameter-free adaptive pool, so retargeting the
|
||||
/// skeleton (e.g. [`Self::for_keypoints`]`(17)`) keeps the trained graph
|
||||
/// and the pool maps `feature_width() → keypoints`.
|
||||
pub fn conv_strides(&self) -> Vec<usize> {
|
||||
let mut w = self.tcn_output_channels();
|
||||
let mut strides = Vec::with_capacity(self.conv_channels.len());
|
||||
for _ in &self.conv_channels {
|
||||
let next = w.div_ceil(2);
|
||||
if next >= self.min_feature_width {
|
||||
strides.push(2);
|
||||
w = next;
|
||||
} else {
|
||||
strides.push(1);
|
||||
}
|
||||
}
|
||||
strides
|
||||
}
|
||||
|
||||
/// Width of the encoder feature map after the conv blocks.
|
||||
///
|
||||
/// `ConvBlock1` preserves width; each `AsymmetricConvBlock` applies a
|
||||
/// `(1, 3)` kernel with padding `(0, 1)` and the per-block stride from
|
||||
/// [`Self::conv_strides`]. Default: 240 → 120 → 60 → 30 → **15**.
|
||||
pub fn feature_width(&self) -> usize {
|
||||
let mut w = self.tcn_output_channels();
|
||||
for s in self.conv_strides() {
|
||||
if s == 2 {
|
||||
w = w.div_ceil(2);
|
||||
}
|
||||
}
|
||||
w
|
||||
}
|
||||
|
||||
/// Mid-channel count of the decoder's 3×3 conv:
|
||||
/// `max(conv_channels.last() / 2, 4)` (the sweep's floor of 4 keeps the
|
||||
/// decoder viable at very small widths; identical to the upstream `c / 2`
|
||||
/// for every channel count ≥ 8, including the default 64 → 32).
|
||||
pub fn decoder_mid(&self) -> usize {
|
||||
(self.conv_channels.last().unwrap_or(&0) / 2).max(4)
|
||||
}
|
||||
|
||||
/// Output tensor shape `(batch, keypoints, 2)`. The adaptive average pool
|
||||
/// maps the feature height to `keypoints` regardless of its size, so the
|
||||
/// keypoint count is free (15 and 17 share identical weights).
|
||||
pub fn output_shape(&self, batch: usize) -> (usize, usize, usize) {
|
||||
(batch, self.keypoints, 2)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Parameter-count formula
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Total trainable parameter count, derived layer-by-layer from the
|
||||
/// architecture (BatchNorm weight+bias counted; running stats are buffers
|
||||
/// and excluded, matching PyTorch's `numel` convention).
|
||||
///
|
||||
/// Pins the port against the verified reference: the 15-keypoint default
|
||||
/// must equal **2,225,042** (`RESULTS.md` artifact verification).
|
||||
///
|
||||
/// Returns **0** for any config that fails [`Self::validate`]: the
|
||||
/// formula is only meaningful for buildable architectures (an invalid
|
||||
/// config would otherwise index an empty `conv_channels` or divide by a
|
||||
/// zero group count). Call `validate()` first when you need the reason.
|
||||
pub fn param_count(&self) -> usize {
|
||||
if self.validate().is_err() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let mut total = 0;
|
||||
|
||||
// TCN stack: per-conv groups follow tcn_groups_mode; only the first
|
||||
// block's pointwise/downsample convs use input_pw_groups.
|
||||
let mut c_in = self.subcarriers;
|
||||
for (i, &c_out) in self.tcn_channels.iter().enumerate() {
|
||||
let pw_groups = if i == 0 { self.input_pw_groups } else { 1 };
|
||||
total += tcn_block_params(
|
||||
c_in,
|
||||
c_out,
|
||||
TCN_KERNEL,
|
||||
self.tcn_conv_groups(c_in),
|
||||
self.tcn_conv_groups(c_out),
|
||||
pw_groups,
|
||||
);
|
||||
c_in = c_out;
|
||||
}
|
||||
|
||||
// ConvBlock1 (1 → conv_channels[0]) + asymmetric blocks. Both block
|
||||
// kinds have identical parameter shapes (stride changes nothing).
|
||||
let mut c_in = 1;
|
||||
total += conv_block_params(c_in, self.conv_channels[0]);
|
||||
c_in = self.conv_channels[0];
|
||||
for &c_out in &self.conv_channels {
|
||||
total += conv_block_params(c_in, c_out);
|
||||
c_in = c_out;
|
||||
}
|
||||
|
||||
// Dual axial attention: width axis + height axis, both c_in → c_in.
|
||||
total += 2 * axial_attention_params(c_in, self.attention_groups);
|
||||
|
||||
// Decoder: 3×3 conv (c → decoder_mid) + BN + 1×1 conv (mid → 2) + BN.
|
||||
total += decoder_params(c_in, self.decoder_mid());
|
||||
|
||||
total
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-component parameter formulas
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// One `InnerGroupedTemporalBlock`: two (depthwise-grouped conv → BN →
|
||||
/// pointwise conv → BN) stages plus a 1×1 + BN residual projection when the
|
||||
/// channel count changes. All convs are bias-free. `g_in`/`g_out` are the
|
||||
/// group counts of the two grouped convs (each conv groups over its own
|
||||
/// channel count — they differ in `Gcd`/`Depthwise` mode); `pw_groups`
|
||||
/// groups the first pointwise conv and the residual projection (the sweep's
|
||||
/// `input_pw_groups`, block 0 only — 1 everywhere else).
|
||||
fn tcn_block_params(
|
||||
c_in: usize,
|
||||
c_out: usize,
|
||||
k: usize,
|
||||
g_in: usize,
|
||||
g_out: usize,
|
||||
pw_groups: usize,
|
||||
) -> usize {
|
||||
let grouped1 = c_in * (c_in / g_in) * k; // depthwise-grouped, c_in → c_in
|
||||
let bn1g = 2 * c_in;
|
||||
let pw1 = c_out * (c_in / pw_groups); // pointwise 1×1
|
||||
let bn1p = 2 * c_out;
|
||||
let grouped2 = c_out * (c_out / g_out) * k;
|
||||
let bn2g = 2 * c_out;
|
||||
let pw2 = c_out * c_out;
|
||||
let bn2p = 2 * c_out;
|
||||
let downsample = if c_in != c_out {
|
||||
(c_in / pw_groups) * c_out + 2 * c_out
|
||||
} else {
|
||||
0
|
||||
};
|
||||
grouped1 + bn1g + pw1 + bn1p + grouped2 + bn2g + pw2 + bn2p + downsample
|
||||
}
|
||||
|
||||
/// One `ConvBlock1` / `AsymmetricConvBlock`: three (1, 3) convs **with bias**
|
||||
/// + BN each, plus a bias-free 1×1 + BN residual projection.
|
||||
fn conv_block_params(c_in: usize, c_out: usize) -> usize {
|
||||
let conv1 = c_out * c_in * 3 + c_out;
|
||||
let conv_rest = 2 * (c_out * c_out * 3 + c_out);
|
||||
let bns = 3 * 2 * c_out;
|
||||
let downsample = c_in * c_out + 2 * c_out;
|
||||
conv1 + conv_rest + bns + downsample
|
||||
}
|
||||
|
||||
/// One `AxialAttention` axis: bias-free 1×1 qkv conv (c → 3c), BN over the
|
||||
/// 3c qkv channels, BN over the `groups` similarity maps, BN over the output.
|
||||
fn axial_attention_params(c: usize, groups: usize) -> usize {
|
||||
let qkv = c * 3 * c;
|
||||
let bn_qkv = 2 * (3 * c);
|
||||
let bn_similarity = 2 * groups;
|
||||
let bn_output = 2 * c;
|
||||
qkv + bn_qkv + bn_similarity + bn_output
|
||||
}
|
||||
|
||||
/// Decoder: `Conv2d(c → mid, 3×3, bias)` + BN + `Conv2d(mid → 2, 1×1, bias)`
|
||||
/// + BN, where `mid` = [`WiFlowStdConfig::decoder_mid`].
|
||||
fn decoder_params(c: usize, mid: usize) -> usize {
|
||||
let conv1 = mid * c * 9 + mid;
|
||||
let bn1 = 2 * mid;
|
||||
let conv2 = 2 * mid + 2;
|
||||
let bn2 = 2 * 2;
|
||||
conv1 + bn1 + conv2 + bn2
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests (pure Rust — run under --no-default-features)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Reference parameter count verified against the upstream checkpoint
|
||||
/// and `torchinfo` (benchmarks/wiflow-std/RESULTS.md, 2026-06-10).
|
||||
const REFERENCE_PARAMS: usize = 2_225_042;
|
||||
|
||||
#[test]
|
||||
fn default_config_is_valid() {
|
||||
WiFlowStdConfig::default()
|
||||
.validate()
|
||||
.expect("default config must validate");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_param_count_matches_verified_reference() {
|
||||
assert_eq!(WiFlowStdConfig::default().param_count(), REFERENCE_PARAMS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn param_count_is_independent_of_keypoints() {
|
||||
// The keypoint count only changes the parameter-free adaptive pool,
|
||||
// so 15- and 17-keypoint variants share identical weights.
|
||||
let kp17 = WiFlowStdConfig::for_keypoints(17);
|
||||
kp17.validate().expect("17-keypoint config must validate");
|
||||
assert_eq!(kp17.param_count(), REFERENCE_PARAMS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_component_breakdown_matches_hand_calculation() {
|
||||
// TCN levels (hand-verified against the reference layer shapes).
|
||||
assert_eq!(tcn_block_params(540, 540, 3, 20, 20, 1), 675_000);
|
||||
assert_eq!(tcn_block_params(540, 440, 3, 20, 20, 1), 746_180);
|
||||
assert_eq!(tcn_block_params(440, 340, 3, 20, 20, 1), 464_780);
|
||||
assert_eq!(tcn_block_params(340, 240, 3, 20, 20, 1), 249_380);
|
||||
// Conv encoder.
|
||||
assert_eq!(conv_block_params(1, 8), 504);
|
||||
assert_eq!(conv_block_params(8, 8), 728);
|
||||
assert_eq!(conv_block_params(8, 16), 2_224);
|
||||
assert_eq!(conv_block_params(16, 32), 8_544);
|
||||
assert_eq!(conv_block_params(32, 64), 33_472);
|
||||
// Attention + decoder.
|
||||
assert_eq!(axial_attention_params(64, 8), 12_816);
|
||||
assert_eq!(decoder_params(64, 32), 18_598);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ADR-152 efficiency-sweep compact presets. The parameter pins are
|
||||
// GROUND TRUTH measured from the trained PyTorch checkpoints
|
||||
// (benchmarks/wiflow-std/results/efficiency_sweep.jsonl, 2026-06-11):
|
||||
// any mismatch means the Rust formula or config mapping is wrong.
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn half_preset_param_count_matches_trained_checkpoint() {
|
||||
let cfg = WiFlowStdConfig::half();
|
||||
cfg.validate().expect("half preset must validate");
|
||||
assert_eq!(cfg.param_count(), 843_834);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quarter_preset_param_count_matches_trained_checkpoint() {
|
||||
let cfg = WiFlowStdConfig::quarter();
|
||||
cfg.validate().expect("quarter preset must validate");
|
||||
assert_eq!(cfg.param_count(), 338_600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tiny_preset_param_count_matches_trained_checkpoint() {
|
||||
let cfg = WiFlowStdConfig::tiny();
|
||||
cfg.validate().expect("tiny preset must validate");
|
||||
assert_eq!(cfg.param_count(), 56_290);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preset_tcn_groups_match_sweep_per_block_record() {
|
||||
// efficiency_sweep.jsonl "tcn_groups_per_block": (conv1, conv2) of
|
||||
// each block — conv1 groups over c_in, conv2 over c_out.
|
||||
let half = WiFlowStdConfig::half();
|
||||
let groups: Vec<(usize, usize)> = {
|
||||
let mut c_in = half.subcarriers;
|
||||
half.tcn_channels
|
||||
.iter()
|
||||
.map(|&c_out| {
|
||||
let g = (half.tcn_conv_groups(c_in), half.tcn_conv_groups(c_out));
|
||||
c_in = c_out;
|
||||
g
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
assert_eq!(groups, [(20, 10), (10, 20), (20, 10), (10, 20)]);
|
||||
|
||||
let tiny = WiFlowStdConfig::tiny();
|
||||
assert_eq!(tiny.tcn_conv_groups(540), 540); // depthwise input conv
|
||||
assert_eq!(tiny.tcn_conv_groups(68), 68);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preset_stride_schedules_match_sweep_record() {
|
||||
// efficiency_sweep.jsonl "conv_strides" / "final_width".
|
||||
assert_eq!(WiFlowStdConfig::default().conv_strides(), [2, 2, 2, 2]);
|
||||
assert_eq!(WiFlowStdConfig::half().conv_strides(), [2, 2, 2, 1]);
|
||||
assert_eq!(WiFlowStdConfig::quarter().conv_strides(), [2, 2, 1, 1]);
|
||||
assert_eq!(WiFlowStdConfig::tiny().conv_strides(), [2, 1, 1, 1]);
|
||||
assert_eq!(WiFlowStdConfig::half().feature_width(), 15);
|
||||
assert_eq!(WiFlowStdConfig::quarter().feature_width(), 15);
|
||||
assert_eq!(WiFlowStdConfig::tiny().feature_width(), 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn for_keypoints_17_keeps_trained_trunk_and_pools_15_to_17() {
|
||||
// Pin against the validated Python protocol (train_measb.py): K=17
|
||||
// swaps only the adaptive pool, never the stride schedule. A derived
|
||||
// [2, 2, 2, 1]/width-30 graph here would silently diverge from the
|
||||
// trained [2, 2, 2, 2]/width-15 checkpoint.
|
||||
let cfg = WiFlowStdConfig::for_keypoints(17);
|
||||
assert_eq!(cfg.min_feature_width, 15);
|
||||
assert_eq!(cfg.conv_strides(), [2, 2, 2, 2]);
|
||||
assert_eq!(cfg.feature_width(), 15);
|
||||
assert_eq!(cfg.output_shape(1), (1, 17, 2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_feature_width_override_changes_schedule_as_designed() {
|
||||
// Raising the floor stops the downsampling earlier (240 → 30).
|
||||
let cfg = WiFlowStdConfig {
|
||||
min_feature_width: 30,
|
||||
..Default::default()
|
||||
};
|
||||
cfg.validate().expect("floor 30 validates");
|
||||
assert_eq!(cfg.conv_strides(), [2, 2, 2, 1]);
|
||||
assert_eq!(cfg.feature_width(), 30);
|
||||
|
||||
// Lowering it lets a small trunk halve further (tiny: 32 → 8).
|
||||
let cfg = WiFlowStdConfig {
|
||||
min_feature_width: 8,
|
||||
..WiFlowStdConfig::tiny()
|
||||
};
|
||||
cfg.validate().expect("floor 8 validates");
|
||||
assert_eq!(cfg.conv_strides(), [2, 2, 1, 1]);
|
||||
assert_eq!(cfg.feature_width(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_zero_min_feature_width() {
|
||||
let cfg = WiFlowStdConfig {
|
||||
min_feature_width: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn param_count_returns_zero_for_invalid_configs() {
|
||||
// Documented total behavior: configs that fail validate() yield 0
|
||||
// instead of panicking (OOB index / division by zero).
|
||||
for cfg in [
|
||||
WiFlowStdConfig {
|
||||
conv_channels: vec![],
|
||||
..Default::default()
|
||||
},
|
||||
WiFlowStdConfig {
|
||||
tcn_groups: 0,
|
||||
..Default::default()
|
||||
},
|
||||
WiFlowStdConfig {
|
||||
input_pw_groups: 0,
|
||||
..Default::default()
|
||||
},
|
||||
WiFlowStdConfig {
|
||||
tcn_channels: vec![],
|
||||
..Default::default()
|
||||
},
|
||||
] {
|
||||
assert!(cfg.validate().is_err(), "precondition: {cfg:?} is invalid");
|
||||
assert_eq!(cfg.param_count(), 0, "no panic, returns 0: {cfg:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fixed_mode_with_defaults_is_unchanged_by_new_knobs() {
|
||||
// The new fields default to upstream behavior: gcd(c, 20) == 20 for
|
||||
// every default channel count, so Gcd mode is also a no-op there.
|
||||
let mut cfg = WiFlowStdConfig::default();
|
||||
assert_eq!(cfg.param_count(), REFERENCE_PARAMS);
|
||||
cfg.tcn_groups_mode = TcnGroupsMode::Gcd;
|
||||
cfg.validate().expect("gcd mode validates at defaults");
|
||||
assert_eq!(cfg.param_count(), REFERENCE_PARAMS);
|
||||
assert_eq!(WiFlowStdConfig::default().decoder_mid(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_bad_input_pw_groups() {
|
||||
// 7 divides neither 540 nor 540's first TCN level.
|
||||
let cfg = WiFlowStdConfig {
|
||||
input_pw_groups: 7,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
// 27 divides subcarriers=540 but not tiny's tcn_channels[0]=68.
|
||||
let cfg = WiFlowStdConfig {
|
||||
input_pw_groups: 27,
|
||||
..WiFlowStdConfig::tiny()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
let zero = WiFlowStdConfig {
|
||||
input_pw_groups: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(zero.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_defaults_for_new_fields_are_backward_compatible() {
|
||||
// A config serialized before the compact-variant knobs existed must
|
||||
// deserialize to upstream behavior (Fixed mode, input_pw_groups 1).
|
||||
let legacy = r#"{
|
||||
"subcarriers": 540, "window": 20,
|
||||
"tcn_channels": [540, 440, 340, 240], "tcn_groups": 20,
|
||||
"conv_channels": [8, 16, 32, 64], "attention_groups": 8,
|
||||
"keypoints": 15, "dropout": 0.5
|
||||
}"#;
|
||||
let cfg: WiFlowStdConfig = serde_json::from_str(legacy).expect("deserialize");
|
||||
assert_eq!(cfg, WiFlowStdConfig::default());
|
||||
assert_eq!(cfg.param_count(), REFERENCE_PARAMS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_roundtrip_preserves_presets() {
|
||||
for cfg in [
|
||||
WiFlowStdConfig::half(),
|
||||
WiFlowStdConfig::quarter(),
|
||||
WiFlowStdConfig::tiny(),
|
||||
] {
|
||||
let json = serde_json::to_string(&cfg).expect("serialize");
|
||||
let back: WiFlowStdConfig = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(back, cfg);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_shape_default_and_esp32() {
|
||||
assert_eq!(WiFlowStdConfig::default().output_shape(4), (4, 15, 2));
|
||||
assert_eq!(
|
||||
WiFlowStdConfig::for_keypoints(17).output_shape(1),
|
||||
(1, 17, 2)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feature_width_default_is_15() {
|
||||
// 240 → 120 → 60 → 30 → 15 (four stride-(1,2) blocks).
|
||||
assert_eq!(WiFlowStdConfig::default().feature_width(), 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcn_output_channels_default_is_240() {
|
||||
assert_eq!(WiFlowStdConfig::default().tcn_output_channels(), 240);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_subcarriers_not_divisible_by_groups() {
|
||||
let cfg = WiFlowStdConfig {
|
||||
subcarriers: 541,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_zero_dimensions() {
|
||||
for cfg in [
|
||||
WiFlowStdConfig {
|
||||
subcarriers: 0,
|
||||
..Default::default()
|
||||
},
|
||||
WiFlowStdConfig {
|
||||
window: 0,
|
||||
..Default::default()
|
||||
},
|
||||
WiFlowStdConfig {
|
||||
keypoints: 0,
|
||||
..Default::default()
|
||||
},
|
||||
WiFlowStdConfig {
|
||||
tcn_groups: 0,
|
||||
..Default::default()
|
||||
},
|
||||
] {
|
||||
assert!(cfg.validate().is_err(), "expected rejection: {cfg:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_empty_or_indivisible_tcn_channels() {
|
||||
let empty = WiFlowStdConfig {
|
||||
tcn_channels: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(empty.validate().is_err());
|
||||
|
||||
let indivisible = WiFlowStdConfig {
|
||||
tcn_channels: vec![540, 441],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(indivisible.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_bad_conv_channels() {
|
||||
let empty = WiFlowStdConfig {
|
||||
conv_channels: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(empty.validate().is_err());
|
||||
|
||||
let zero = WiFlowStdConfig {
|
||||
conv_channels: vec![8, 0, 64],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(zero.validate().is_err());
|
||||
|
||||
// Odd last channel breaks the c → c/2 decoder split.
|
||||
let odd_last = WiFlowStdConfig {
|
||||
conv_channels: vec![8, 16, 33],
|
||||
attention_groups: 1,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(odd_last.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_attention_group_mismatch() {
|
||||
let cfg = WiFlowStdConfig {
|
||||
attention_groups: 7, // 64 % 7 != 0
|
||||
..Default::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
let zero = WiFlowStdConfig {
|
||||
attention_groups: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(zero.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_out_of_range_dropout() {
|
||||
for d in [1.0, 1.5, -0.1, f64::NAN] {
|
||||
let cfg = WiFlowStdConfig {
|
||||
dropout: d,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err(), "dropout {d} must be rejected");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_roundtrip_preserves_config() {
|
||||
let cfg = WiFlowStdConfig::for_keypoints(17);
|
||||
let json = serde_json::to_string(&cfg).expect("serialize");
|
||||
let back: WiFlowStdConfig = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(back, cfg);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,334 @@
|
||||
//! Building-block layers for the WiFlow-STD model (tch backend, ADR-152 §2.2):
|
||||
//! grouped causal TCN blocks, asymmetric residual conv blocks, and dual axial
|
||||
//! attention. Internal to [`super::model`]; see the module docs for provenance.
|
||||
|
||||
use tch::{nn, nn::Module, Tensor};
|
||||
|
||||
use super::config::{CONV_BLOCK_DROPOUT, TCN_KERNEL};
|
||||
|
||||
/// BatchNorm config matching the reference: gamma = 1 (PyTorch default; the
|
||||
/// reference additionally pins BatchNorm1d weight=1/bias=0). tch-0.24's
|
||||
/// `BatchNormConfig::default()` would draw gamma from Uniform(0,1), silently
|
||||
/// halving activations on average in from-scratch training.
|
||||
pub(super) fn bn_cfg() -> nn::BatchNormConfig {
|
||||
nn::BatchNormConfig {
|
||||
ws_init: nn::Init::Const(1.0),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GroupedTemporalBlock (TCN level)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// One TCN level: two (depthwise-grouped causal conv → BN → SiLU → pointwise
|
||||
/// conv → BN → SiLU → dropout) stages with a residual connection (1×1 + BN
|
||||
/// projection when channels change) and a final SiLU.
|
||||
///
|
||||
/// Causality: each grouped conv pads by `(k-1)·dilation` and the trailing
|
||||
/// padding is chomped off afterwards, exactly like the reference `Chomp1d`.
|
||||
pub(super) struct GroupedTemporalBlock {
|
||||
conv1_group: nn::Conv1D,
|
||||
bn1_group: nn::BatchNorm,
|
||||
conv1_pw: nn::Conv1D,
|
||||
bn1_pw: nn::BatchNorm,
|
||||
conv2_group: nn::Conv1D,
|
||||
bn2_group: nn::BatchNorm,
|
||||
conv2_pw: nn::Conv1D,
|
||||
bn2_pw: nn::BatchNorm,
|
||||
downsample: Option<(nn::Conv1D, nn::BatchNorm)>,
|
||||
dropout: f64,
|
||||
}
|
||||
|
||||
impl GroupedTemporalBlock {
|
||||
/// `g_in`/`g_out`: group counts of the two grouped convs (each conv
|
||||
/// groups over its own channel count — they differ under the ADR-152
|
||||
/// compact variants' `Gcd`/`Depthwise` modes). `pw_groups` groups the
|
||||
/// first pointwise conv and the residual projection (`input_pw_groups`
|
||||
/// on block 0; 1 everywhere else).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) fn new(
|
||||
vs: nn::Path,
|
||||
c_in: i64,
|
||||
c_out: i64,
|
||||
dilation: i64,
|
||||
g_in: i64,
|
||||
g_out: i64,
|
||||
pw_groups: i64,
|
||||
dropout: f64,
|
||||
) -> Self {
|
||||
let k = TCN_KERNEL as i64;
|
||||
let padding = (k - 1) * dilation;
|
||||
let grouped_cfg = |groups| nn::ConvConfig {
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias: false,
|
||||
..Default::default()
|
||||
};
|
||||
let pointwise_cfg = |groups| nn::ConvConfig {
|
||||
groups,
|
||||
bias: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv1_group = nn::conv1d(&vs / "conv1_group", c_in, c_in, k, grouped_cfg(g_in));
|
||||
let bn1_group = nn::batch_norm1d(&vs / "bn1_group", c_in, bn_cfg());
|
||||
let conv1_pw = nn::conv1d(&vs / "conv1_pw", c_in, c_out, 1, pointwise_cfg(pw_groups));
|
||||
let bn1_pw = nn::batch_norm1d(&vs / "bn1_pw", c_out, bn_cfg());
|
||||
|
||||
let conv2_group = nn::conv1d(&vs / "conv2_group", c_out, c_out, k, grouped_cfg(g_out));
|
||||
let bn2_group = nn::batch_norm1d(&vs / "bn2_group", c_out, bn_cfg());
|
||||
let conv2_pw = nn::conv1d(&vs / "conv2_pw", c_out, c_out, 1, pointwise_cfg(1));
|
||||
let bn2_pw = nn::batch_norm1d(&vs / "bn2_pw", c_out, bn_cfg());
|
||||
|
||||
let downsample = (c_in != c_out).then(|| {
|
||||
(
|
||||
nn::conv1d(&vs / "ds_conv", c_in, c_out, 1, pointwise_cfg(pw_groups)),
|
||||
nn::batch_norm1d(&vs / "ds_bn", c_out, bn_cfg()),
|
||||
)
|
||||
});
|
||||
|
||||
GroupedTemporalBlock {
|
||||
conv1_group,
|
||||
bn1_group,
|
||||
conv1_pw,
|
||||
bn1_pw,
|
||||
conv2_group,
|
||||
bn2_group,
|
||||
conv2_pw,
|
||||
bn2_pw,
|
||||
downsample,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||
let res = match &self.downsample {
|
||||
Some((conv, bn)) => conv.forward(x).apply_t(bn, train),
|
||||
None => x.shallow_clone(),
|
||||
};
|
||||
let t = x.size()[2];
|
||||
|
||||
// Stage 1: grouped causal conv (chomp trailing padding) + pointwise.
|
||||
let out = self
|
||||
.conv1_group
|
||||
.forward(x)
|
||||
.narrow(2, 0, t) // Chomp1d
|
||||
.apply_t(&self.bn1_group, train)
|
||||
.silu()
|
||||
.apply(&self.conv1_pw)
|
||||
.apply_t(&self.bn1_pw, train)
|
||||
.silu()
|
||||
.dropout(self.dropout, train);
|
||||
|
||||
// Stage 2.
|
||||
let out = self
|
||||
.conv2_group
|
||||
.forward(&out)
|
||||
.narrow(2, 0, t) // Chomp1d
|
||||
.apply_t(&self.bn2_group, train)
|
||||
.silu()
|
||||
.apply(&self.conv2_pw)
|
||||
.apply_t(&self.bn2_pw, train)
|
||||
.silu()
|
||||
.dropout(self.dropout, train);
|
||||
|
||||
(out + res).silu()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ConvBlock (ConvBlock1 / AsymmetricConvBlock)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Asymmetric residual conv block: three `(1, 3)` convs (only the subcarrier
|
||||
/// axis is convolved) with BN, SiLU and channel dropout, plus a 1×1 + BN
|
||||
/// residual projection. `stride_w == 1` reproduces the reference `ConvBlock1`,
|
||||
/// `stride_w == 2` the downsampling `AsymmetricConvBlock`.
|
||||
pub(super) struct ConvBlock {
|
||||
conv1: nn::Conv2D,
|
||||
bn1: nn::BatchNorm,
|
||||
conv2: nn::Conv2D,
|
||||
bn2: nn::BatchNorm,
|
||||
conv3: nn::Conv2D,
|
||||
bn3: nn::BatchNorm,
|
||||
ds_conv: nn::Conv2D,
|
||||
ds_bn: nn::BatchNorm,
|
||||
}
|
||||
|
||||
impl ConvBlock {
|
||||
pub(super) fn new(vs: nn::Path, c_in: i64, c_out: i64, stride_w: i64) -> Self {
|
||||
let asym = |stride_w| nn::ConvConfigND::<[i64; 2]> {
|
||||
stride: [1, stride_w],
|
||||
padding: [0, 1],
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = nn::conv(&vs / "conv1", c_in, c_out, [1, 3], asym(stride_w));
|
||||
let bn1 = nn::batch_norm2d(&vs / "bn1", c_out, bn_cfg());
|
||||
let conv2 = nn::conv(&vs / "conv2", c_out, c_out, [1, 3], asym(1));
|
||||
let bn2 = nn::batch_norm2d(&vs / "bn2", c_out, bn_cfg());
|
||||
let conv3 = nn::conv(&vs / "conv3", c_out, c_out, [1, 3], asym(1));
|
||||
let bn3 = nn::batch_norm2d(&vs / "bn3", c_out, bn_cfg());
|
||||
|
||||
let ds_conv = nn::conv(
|
||||
&vs / "ds_conv",
|
||||
c_in,
|
||||
c_out,
|
||||
[1, 1],
|
||||
nn::ConvConfigND::<[i64; 2]> {
|
||||
stride: [1, stride_w],
|
||||
bias: false,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
let ds_bn = nn::batch_norm2d(&vs / "ds_bn", c_out, bn_cfg());
|
||||
|
||||
ConvBlock {
|
||||
conv1,
|
||||
bn1,
|
||||
conv2,
|
||||
bn2,
|
||||
conv3,
|
||||
bn3,
|
||||
ds_conv,
|
||||
ds_bn,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||
let identity = self.ds_conv.forward(x).apply_t(&self.ds_bn, train);
|
||||
let out = x
|
||||
.apply(&self.conv1)
|
||||
.apply_t(&self.bn1, train)
|
||||
.silu()
|
||||
.feature_dropout(CONV_BLOCK_DROPOUT, train) // Dropout2d
|
||||
.apply(&self.conv2)
|
||||
.apply_t(&self.bn2, train)
|
||||
.silu()
|
||||
.feature_dropout(CONV_BLOCK_DROPOUT, train)
|
||||
.apply(&self.conv3)
|
||||
.apply_t(&self.bn3, train);
|
||||
(out + identity).silu()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Axial attention
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Single-axis self-attention with BN-normalised qkv, BN-normalised
|
||||
/// similarity logits and BN-normalised output. `width == true` attends along
|
||||
/// the last (W) axis, otherwise along the H axis; the other spatial axis is
|
||||
/// folded into the batch.
|
||||
pub(super) struct AxialAttention {
|
||||
qkv: nn::Conv1D,
|
||||
bn_qkv: nn::BatchNorm,
|
||||
bn_similarity: nn::BatchNorm,
|
||||
bn_output: nn::BatchNorm,
|
||||
out_planes: i64,
|
||||
groups: i64,
|
||||
width: bool,
|
||||
}
|
||||
|
||||
impl AxialAttention {
|
||||
pub(super) fn new(vs: nn::Path, planes: i64, groups: i64, width: bool) -> Self {
|
||||
// Reference init: N(0, sqrt(1 / in_planes)).
|
||||
let qkv = nn::conv1d(
|
||||
&vs / "qkv",
|
||||
planes,
|
||||
planes * 3,
|
||||
1,
|
||||
nn::ConvConfig {
|
||||
bias: false,
|
||||
ws_init: nn::Init::Randn {
|
||||
mean: 0.0,
|
||||
stdev: (1.0 / planes as f64).sqrt(),
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
let bn_qkv = nn::batch_norm1d(&vs / "bn_qkv", planes * 3, bn_cfg());
|
||||
let bn_similarity = nn::batch_norm2d(&vs / "bn_similarity", groups, bn_cfg());
|
||||
let bn_output = nn::batch_norm1d(&vs / "bn_output", planes, bn_cfg());
|
||||
|
||||
AxialAttention {
|
||||
qkv,
|
||||
bn_qkv,
|
||||
bn_similarity,
|
||||
bn_output,
|
||||
out_planes: planes,
|
||||
groups,
|
||||
width,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||
// Fold the non-attended spatial axis into the batch:
|
||||
// width: [B,C,H,W] → [B,H,C,W]; height: [B,C,H,W] → [B,W,C,H].
|
||||
let x = if self.width {
|
||||
x.permute([0, 2, 1, 3])
|
||||
} else {
|
||||
x.permute([0, 3, 1, 2])
|
||||
};
|
||||
let (n, outer, c, axis) = {
|
||||
let s = x.size();
|
||||
(s[0], s[1], s[2], s[3])
|
||||
};
|
||||
let flat = x.contiguous().view([n * outer, c, axis]);
|
||||
|
||||
// BN-normalised qkv: [N', 3·C, axis] → grouped q, k, v.
|
||||
let gp = self.out_planes / self.groups; // group planes
|
||||
let qkv = flat.apply(&self.qkv).apply_t(&self.bn_qkv, train).reshape([
|
||||
n * outer,
|
||||
3,
|
||||
self.groups,
|
||||
gp,
|
||||
axis,
|
||||
]);
|
||||
let q = qkv.select(1, 0); // [N', g, gp, axis]
|
||||
let k = qkv.select(1, 1);
|
||||
let v = qkv.select(1, 2);
|
||||
|
||||
// similarity[b,g,i,j] = Σ_c q[b,g,c,i]·k[b,g,c,j], BN over the g maps.
|
||||
let logits = q.transpose(2, 3).matmul(&k); // [N', g, axis, axis]
|
||||
let similarity = logits
|
||||
.apply_t(&self.bn_similarity, train)
|
||||
.softmax(-1, logits.kind());
|
||||
|
||||
// out[b,g,c,i] = Σ_j similarity[b,g,i,j]·v[b,g,c,j].
|
||||
let sv = v.matmul(&similarity.transpose(2, 3)); // [N', g, gp, axis]
|
||||
let out = sv
|
||||
.reshape([n * outer, self.out_planes, axis])
|
||||
.apply_t(&self.bn_output, train)
|
||||
.view([n, outer, self.out_planes, axis]);
|
||||
|
||||
// Restore [B, C, H, W].
|
||||
if self.width {
|
||||
out.permute([0, 2, 1, 3])
|
||||
} else {
|
||||
out.permute([0, 2, 3, 1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Width-axis then height-axis axial attention (the reference
|
||||
/// `DualAxialAttention`, stride 1).
|
||||
pub(super) struct DualAxialAttention {
|
||||
width_axis: AxialAttention,
|
||||
height_axis: AxialAttention,
|
||||
}
|
||||
|
||||
impl DualAxialAttention {
|
||||
pub(super) fn new(vs: nn::Path, planes: i64, groups: i64) -> Self {
|
||||
DualAxialAttention {
|
||||
width_axis: AxialAttention::new(&vs / "width", planes, groups, true),
|
||||
height_axis: AxialAttention::new(&vs / "height", planes, groups, false),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||
let x = self.width_axis.forward_t(x, train);
|
||||
self.height_axis.forward_t(&x, train)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
//! WiFlow-STD — spatio-temporal-decoupled CSI pose estimation (ADR-152 §2.2).
|
||||
//!
|
||||
//! Native Rust port of the **WiFlow-STD** architecture by DY2434
|
||||
//! (<https://github.com/DY2434/WiFlow-WiFi-Pose-Estimation-with-Spatio-Temporal-Decoupling>,
|
||||
//! Apache-2.0), reimplemented idiomatically from the vendored read-only
|
||||
//! reference in `benchmarks/wiflow-std/upstream/models/`.
|
||||
//!
|
||||
//! ## Evidence grade (ADR-152 §2.2 citation rule)
|
||||
//!
|
||||
//! Per `benchmarks/wiflow-std/RESULTS.md`, the upstream accuracy claims are
|
||||
//! **MEASURED-EQUIVALENT**: our retraining of the reference implementation on
|
||||
//! the released dataset reproduced **~96% PCK@20** (96.09% full test / 96.61%
|
||||
//! corruption-free; published claim 97.25%). The *shipped* upstream checkpoint
|
||||
//! was REFUTED (0.08% PCK@20 — keypoint-convention mismatch), and the released
|
||||
//! dataset/code required repairs before training converged. Cite this port as
|
||||
//! "~96% PCK@20 (our reproduction)" — **not comparable** to RuView's
|
||||
//! 17-keypoint ESP32 numbers (different hardware, subjects, split, skeleton).
|
||||
//!
|
||||
//! ## Name collision
|
||||
//!
|
||||
//! WiFlow-STD (this module) is the *external* DY2434 architecture. It is
|
||||
//! **distinct from RuView's internal WiFlow** camera-free pose pipeline; the
|
||||
//! `_std` suffix (Spatio-Temporal Decoupling) disambiguates the two.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! CSI window [B, 540 sub, 20 t]
|
||||
//! │ TCN stack: 4 × grouped TemporalBlock (groups=20, k=3, dilation 1/2/4/8,
|
||||
//! │ depthwise-grouped + pointwise convs, causal Chomp1d padding)
|
||||
//! ▼ channels 540 → 540 → 440 → 340 → 240
|
||||
//! [B, 240, 20] ── transpose+unsqueeze ──► [B, 1, 20, 240] (image-like)
|
||||
//! │ ConvBlock1 (1→8, asymmetric 1×3 kernels, no downsampling)
|
||||
//! │ 4 × AsymmetricConvBlock (8→8→16→32→64, stride (1,2) on subcarrier axis)
|
||||
//! ▼
|
||||
//! [B, 64, 20, 15] ── permute ──► [B, 64, 15, 20]
|
||||
//! │ DualAxialAttention (64 ch, 8 groups, width- then height-axial
|
||||
//! │ self-attention with BN-normalised qkv and BN-normalised similarity)
|
||||
//! │ Decoder convs 64 → 32 → 2 (3×3 then 1×1, BN + SiLU)
|
||||
//! ▼
|
||||
//! [B, 2, 15, 20] ── adaptive avg-pool (K, 1) ──► [B, K, 2] keypoints
|
||||
//! ```
|
||||
//!
|
||||
//! 2,225,042 parameters / ~0.055 GFLOPs at the 15-keypoint default
|
||||
//! (both verified against the reference — see `RESULTS.md`).
|
||||
//!
|
||||
//! Note: upstream `config.py` lists `TCN_CHANNELS = [480, 360, 240]`, but the
|
||||
//! released checkpoint and `models/` code use `[540, 440, 340, 240]`. This
|
||||
//! port follows the `models/` code, which we verified loads the released
|
||||
//! weights after key remapping.
|
||||
//!
|
||||
//! ## Feature gating
|
||||
//!
|
||||
//! [`WiFlowStdConfig`] (validation, parameter-count formula, output-shape
|
||||
//! inference) is pure Rust and always available. [`model::WiFlowStdModel`]
|
||||
//! (the tch / LibTorch forward pass) requires the `tch-backend` feature,
|
||||
//! matching [`crate::model`]'s gating.
|
||||
|
||||
pub mod config;
|
||||
|
||||
#[cfg(feature = "tch-backend")]
|
||||
mod layers;
|
||||
#[cfg(feature = "tch-backend")]
|
||||
pub mod model;
|
||||
|
||||
pub use config::{TcnGroupsMode, WiFlowStdConfig};
|
||||
|
||||
#[cfg(feature = "tch-backend")]
|
||||
pub use model::WiFlowStdModel;
|
||||
@@ -0,0 +1,360 @@
|
||||
//! WiFlow-STD forward pass (tch-rs / LibTorch backend, ADR-152 §2.2).
|
||||
//!
|
||||
//! Idiomatic reimplementation of the DY2434 reference (Apache-2.0); see the
|
||||
//! [module docs](crate::wiflow_std) for provenance and the evidence grade.
|
||||
//! From-scratch init: BN gamma is pinned to 1 (see `layers::bn_cfg`); the
|
||||
//! axial-attention qkv conv uses `N(0, sqrt(1/in_planes))` per the
|
||||
//! reference's `attention.py` intent (note the reference's *effective* init
|
||||
//! differs — its `_initialize_weights` re-inits every `nn.Conv1d`, qkv
|
||||
//! included, with `kaiming_normal(fan_out)`); conv weights keep tch defaults
|
||||
//! (kaiming-uniform fan_in), which differ in scale from PyTorch's defaults.
|
||||
//! These divergences affect from-scratch training dynamics only — BN absorbs
|
||||
//! them at init, and loaded checkpoints overwrite everything. The
|
||||
//! retrained PyTorch checkpoint loads via [`WiFlowStdModel::load`] after
|
||||
//! key-remapped safetensors export
|
||||
//! (`benchmarks/wiflow-std/export_to_safetensors.py`); numerical parity with
|
||||
//! the PyTorch forward pass is proven by
|
||||
//! `tests/test_wiflow_std_parity.rs` (max abs diff ~1.2e-7).
|
||||
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
use super::config::WiFlowStdConfig;
|
||||
use super::layers::{ConvBlock, DualAxialAttention, GroupedTemporalBlock};
|
||||
use crate::error::TrainError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WiFlowStdModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// WiFlow-STD pose model: TCN temporal encoder → asymmetric 2-D conv encoder
|
||||
/// → dual axial attention → conv decoder → adaptive pool to `(K, 2)` keypoints.
|
||||
///
|
||||
/// Input: `[B, subcarriers, window]` CSI amplitudes.
|
||||
/// Output: `[B, keypoints, 2]` normalised 2-D keypoint coordinates.
|
||||
pub struct WiFlowStdModel {
|
||||
vs: nn::VarStore,
|
||||
tcn: Vec<GroupedTemporalBlock>,
|
||||
conv_in: ConvBlock,
|
||||
conv_blocks: Vec<ConvBlock>,
|
||||
attention: DualAxialAttention,
|
||||
dec_conv1: nn::Conv2D,
|
||||
dec_bn1: nn::BatchNorm,
|
||||
dec_conv2: nn::Conv2D,
|
||||
dec_bn2: nn::BatchNorm,
|
||||
/// Active model configuration.
|
||||
pub config: WiFlowStdConfig,
|
||||
}
|
||||
|
||||
impl WiFlowStdModel {
|
||||
/// Build a new model with randomly-initialised weights on `device`.
|
||||
///
|
||||
/// Call `tch::manual_seed(seed)` before this for reproducibility.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`TrainError::Config`] if `config.validate()` fails.
|
||||
pub fn new(config: &WiFlowStdConfig, device: Device) -> Result<Self, TrainError> {
|
||||
config.validate()?;
|
||||
|
||||
let vs = nn::VarStore::new(device);
|
||||
let root = vs.root();
|
||||
|
||||
// TCN stack: dilation doubles per level, causal padding. Per-conv
|
||||
// groups follow `config.tcn_groups_mode`; only block 0's pointwise/
|
||||
// downsample convs use `config.input_pw_groups` (ADR-152 sweep).
|
||||
let mut tcn = Vec::with_capacity(config.tcn_channels.len());
|
||||
let mut c_in = config.subcarriers;
|
||||
for (i, &c_out) in config.tcn_channels.iter().enumerate() {
|
||||
let dilation = 1_i64 << i;
|
||||
let pw_groups = if i == 0 { config.input_pw_groups } else { 1 };
|
||||
tcn.push(GroupedTemporalBlock::new(
|
||||
&root / format!("tcn{i}"),
|
||||
c_in as i64,
|
||||
c_out as i64,
|
||||
dilation,
|
||||
config.tcn_conv_groups(c_in) as i64,
|
||||
config.tcn_conv_groups(c_out) as i64,
|
||||
pw_groups as i64,
|
||||
config.dropout,
|
||||
));
|
||||
c_in = c_out;
|
||||
}
|
||||
|
||||
// 2-D conv encoder: ConvBlock1 (stride 1) + asymmetric blocks with
|
||||
// the derived stride schedule ([2, 2, 2, 2] at the upstream default).
|
||||
let c0 = config.conv_channels[0] as i64;
|
||||
let conv_in = ConvBlock::new(&root / "conv_in", 1, c0, 1);
|
||||
let mut conv_blocks = Vec::with_capacity(config.conv_channels.len());
|
||||
let strides = config.conv_strides();
|
||||
let mut c_in = c0;
|
||||
for (i, &c_out) in config.conv_channels.iter().enumerate() {
|
||||
conv_blocks.push(ConvBlock::new(
|
||||
&root / format!("conv{i}"),
|
||||
c_in,
|
||||
c_out as i64,
|
||||
strides[i] as i64,
|
||||
));
|
||||
c_in = c_out as i64;
|
||||
}
|
||||
|
||||
let attention =
|
||||
DualAxialAttention::new(&root / "attention", c_in, config.attention_groups as i64);
|
||||
|
||||
// Decoder: c → decoder_mid (3×3) → 2 (1×1), BN + SiLU after each conv.
|
||||
let mid = config.decoder_mid() as i64;
|
||||
let dec_conv1 = nn::conv2d(
|
||||
&root / "dec_conv1",
|
||||
c_in,
|
||||
mid,
|
||||
3,
|
||||
nn::ConvConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
let dec_bn1 = nn::batch_norm2d(&root / "dec_bn1", mid, super::layers::bn_cfg());
|
||||
let dec_conv2 = nn::conv2d(&root / "dec_conv2", mid, 2, 1, Default::default());
|
||||
let dec_bn2 = nn::batch_norm2d(&root / "dec_bn2", 2, super::layers::bn_cfg());
|
||||
|
||||
Ok(WiFlowStdModel {
|
||||
vs,
|
||||
tcn,
|
||||
conv_in,
|
||||
conv_blocks,
|
||||
attention,
|
||||
dec_conv1,
|
||||
dec_bn1,
|
||||
dec_conv2,
|
||||
dec_bn2,
|
||||
config: config.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Forward pass in training mode (dropout active, BN in train mode).
|
||||
///
|
||||
/// `csi`: `[B, subcarriers, window]` → `[B, keypoints, 2]`.
|
||||
pub fn forward_t(&self, csi: &Tensor) -> Tensor {
|
||||
self.forward_impl(csi, true)
|
||||
}
|
||||
|
||||
/// Forward pass without gradient tracking (inference mode).
|
||||
pub fn forward_inference(&self, csi: &Tensor) -> Tensor {
|
||||
tch::no_grad(|| self.forward_impl(csi, false))
|
||||
}
|
||||
|
||||
/// Save model weights. The tch `VarStore` dispatches the format on the
|
||||
/// file extension: `.safetensors` → safetensors, anything else → torch
|
||||
/// `.pt`.
|
||||
///
|
||||
/// **Platform constraint:** prefer `.safetensors`. The `.pt` path
|
||||
/// (`_save_parameters`/`_load_parameters`) is broken on Windows with
|
||||
/// torch 2.11 (GenericDict internal assert on the load roundtrip — see
|
||||
/// the `save_and_load_roundtrip` test below), and the verified retrained
|
||||
/// checkpoint is shipped as key-remapped safetensors anyway
|
||||
/// (`benchmarks/wiflow-std/export_to_safetensors.py`).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`TrainError::TrainingStep`] if the file cannot be written.
|
||||
pub fn save(&self, path: &std::path::Path) -> Result<(), TrainError> {
|
||||
self.vs
|
||||
.save(path)
|
||||
.map_err(|e| TrainError::training_step(format!("save failed: {e}")))
|
||||
}
|
||||
|
||||
/// Load model weights from a file (format dispatched on extension; see
|
||||
/// the `.pt`-on-Windows caveat on [`Self::save`]).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`TrainError::TrainingStep`] if the file cannot be read or the
|
||||
/// weights are incompatible with this architecture.
|
||||
pub fn load(&mut self, path: &std::path::Path) -> Result<(), TrainError> {
|
||||
self.vs
|
||||
.load(path)
|
||||
.map_err(|e| TrainError::training_step(format!("load failed: {e}")))
|
||||
}
|
||||
|
||||
/// Reference to the internal `VarStore` (e.g. to build an optimiser).
|
||||
pub fn var_store(&self) -> &nn::VarStore {
|
||||
&self.vs
|
||||
}
|
||||
|
||||
/// Mutable access to the internal `VarStore`.
|
||||
pub fn var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.vs
|
||||
}
|
||||
|
||||
/// Total number of trainable scalar parameters. Must equal
|
||||
/// [`WiFlowStdConfig::param_count`] (2,225,042 at the default config).
|
||||
pub fn num_parameters(&self) -> i64 {
|
||||
self.vs
|
||||
.trainable_variables()
|
||||
.iter()
|
||||
.map(|t| t.numel() as i64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn forward_impl(&self, csi: &Tensor, train: bool) -> Tensor {
|
||||
// TCN: [B, subcarriers, T] → [B, c_tcn, T].
|
||||
let mut h = csi.shallow_clone();
|
||||
for block in &self.tcn {
|
||||
h = block.forward_t(&h, train);
|
||||
}
|
||||
|
||||
// Image-like reshape: [B, c_tcn, T] → [B, 1, T, c_tcn].
|
||||
let h = h.transpose(1, 2).unsqueeze(1);
|
||||
|
||||
// 2-D conv encoder: [B, 1, T, S] → [B, C, T, S'].
|
||||
let mut h = self.conv_in.forward_t(&h, train);
|
||||
for block in &self.conv_blocks {
|
||||
h = block.forward_t(&h, train);
|
||||
}
|
||||
|
||||
// Swap to [B, C, S', T] for the axial attention + decoder.
|
||||
let h = h.permute([0, 1, 3, 2]);
|
||||
let h = self.attention.forward_t(&h, train);
|
||||
|
||||
// Decoder: [B, C, S', T] → [B, 2, S', T].
|
||||
let h = h
|
||||
.apply(&self.dec_conv1)
|
||||
.apply_t(&self.dec_bn1, train)
|
||||
.silu()
|
||||
.apply(&self.dec_conv2)
|
||||
.apply_t(&self.dec_bn2, train)
|
||||
.silu();
|
||||
|
||||
// [B, 2, S', T] → pool (K, 1) → [B, 2, K] → [B, K, 2].
|
||||
let k = self.config.keypoints as i64;
|
||||
h.adaptive_avg_pool2d([k, 1])
|
||||
.squeeze_dim(-1)
|
||||
.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests (require the tch-backend feature + LibTorch)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tch::Kind;
|
||||
|
||||
fn random_csi(cfg: &WiFlowStdConfig, batch: i64) -> Tensor {
|
||||
Tensor::rand(
|
||||
[batch, cfg.subcarriers as i64, cfg.window as i64],
|
||||
(Kind::Float, Device::Cpu),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn param_count_matches_pure_rust_formula() {
|
||||
tch::manual_seed(0);
|
||||
let cfg = WiFlowStdConfig::default();
|
||||
let model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("default config builds");
|
||||
// Pins the tch graph against the verified reference (2,225,042).
|
||||
assert_eq!(model.num_parameters(), cfg.param_count() as i64);
|
||||
assert_eq!(model.num_parameters(), 2_225_042);
|
||||
}
|
||||
|
||||
/// ADR-152 efficiency-sweep compact presets: the tch graph must realise
|
||||
/// exactly the trained checkpoints' measured parameter counts
|
||||
/// (benchmarks/wiflow-std/results/efficiency_sweep.jsonl) and produce
|
||||
/// the standard [B, 15, 2] output.
|
||||
#[test]
|
||||
fn compact_preset_param_counts_and_shapes() {
|
||||
for (cfg, expected) in [
|
||||
(WiFlowStdConfig::half(), 843_834_i64),
|
||||
(WiFlowStdConfig::quarter(), 338_600),
|
||||
(WiFlowStdConfig::tiny(), 56_290),
|
||||
] {
|
||||
tch::manual_seed(0);
|
||||
let model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("preset builds");
|
||||
assert_eq!(model.num_parameters(), expected);
|
||||
assert_eq!(model.num_parameters(), cfg.param_count() as i64);
|
||||
let out = model.forward_inference(&random_csi(&cfg, 2));
|
||||
assert_eq!(out.size(), &[2, 15, 2]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_output_shape_15_keypoints() {
|
||||
tch::manual_seed(0);
|
||||
let cfg = WiFlowStdConfig::default();
|
||||
let model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
|
||||
let out = model.forward_t(&random_csi(&cfg, 2));
|
||||
assert_eq!(out.size(), &[2, 15, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_output_shape_17_keypoints_esp32() {
|
||||
tch::manual_seed(0);
|
||||
let cfg = WiFlowStdConfig::for_keypoints(17);
|
||||
let model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
|
||||
let out = model.forward_inference(&random_csi(&cfg, 1));
|
||||
assert_eq!(out.size(), &[1, 17, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inference_outputs_are_finite_and_deterministic() {
|
||||
tch::manual_seed(7);
|
||||
let cfg = WiFlowStdConfig::default();
|
||||
let model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
|
||||
let csi = random_csi(&cfg, 1);
|
||||
let a = model.forward_inference(&csi);
|
||||
let b = model.forward_inference(&csi);
|
||||
assert!(
|
||||
bool::try_from(a.isfinite().all()).unwrap(),
|
||||
"non-finite output"
|
||||
);
|
||||
assert!(
|
||||
bool::try_from(a.eq_tensor(&b).all()).unwrap(),
|
||||
"inference must be deterministic (dropout disabled)"
|
||||
);
|
||||
}
|
||||
|
||||
/// Dumps the authoritative tch `VarStore` variable names + shapes. This is
|
||||
/// the source of truth for the PyTorch→tch key mapping implemented in
|
||||
/// `benchmarks/wiflow-std/export_to_safetensors.py` — rerun it (with
|
||||
/// `--nocapture`) whenever the architecture changes.
|
||||
#[test]
|
||||
fn dump_variable_names() {
|
||||
let cfg = WiFlowStdConfig::default();
|
||||
let model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
|
||||
let vars = model.var_store().variables();
|
||||
let mut names: Vec<(String, Vec<i64>)> =
|
||||
vars.iter().map(|(n, t)| (n.clone(), t.size())).collect();
|
||||
names.sort();
|
||||
for (name, shape) in &names {
|
||||
println!("{name} {shape:?}");
|
||||
}
|
||||
println!("total: {} variables", names.len());
|
||||
assert!(!names.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_config_is_rejected() {
|
||||
let cfg = WiFlowStdConfig {
|
||||
subcarriers: 541, // not divisible by tcn_groups
|
||||
..Default::default()
|
||||
};
|
||||
assert!(WiFlowStdModel::new(&cfg, Device::Cpu).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_and_load_roundtrip() {
|
||||
use tempfile::tempdir;
|
||||
tch::manual_seed(42);
|
||||
let cfg = WiFlowStdConfig::default();
|
||||
let mut model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
// safetensors, not .pt: this torch build's _save_parameters/_load_parameters
|
||||
// .pt roundtrip is broken on Windows (GenericDict internal assert)
|
||||
let path = tmp.path().join("wiflow_std.safetensors");
|
||||
model.save(&path).expect("save");
|
||||
model.load(&path).expect("load");
|
||||
let out = model.forward_inference(&random_csi(&cfg, 1));
|
||||
assert_eq!(out.size(), &[1, 15, 2]);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
//! Integration + property tests for [`wifi_densepose_train::mae`]
|
||||
//! (ADR-152 §2.3 — UNSW MAE pretraining recipe).
|
||||
//!
|
||||
//! All deterministic tests use fixed seeds; property tests use `proptest`
|
||||
//! with its default deterministic-replay machinery.
|
||||
|
||||
use proptest::prelude::*;
|
||||
use wifi_densepose_train::mae::{
|
||||
patchify, random_mask, unpatchify, unpatchify_visible, MaePretrainConfig,
|
||||
};
|
||||
use wifi_densepose_train::MaeError;
|
||||
|
||||
/// Deterministic test window: value = t * 1000 + sc (every cell unique).
|
||||
fn window(time: usize, subc: usize) -> Vec<f32> {
|
||||
(0..time * subc)
|
||||
.map(|i| ((i / subc) * 1000 + i % subc) as f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config defaults + validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn default_config_matches_unsw_recipe() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
assert!((cfg.mask_ratio - 0.80).abs() < 1e-12);
|
||||
assert_eq!(cfg.patch_time, 30);
|
||||
assert_eq!(cfg.patch_subc, 3);
|
||||
assert_eq!(cfg.seed, 42);
|
||||
cfg.validate().expect("default recipe is valid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_json_round_trip() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let json = serde_json::to_string(&cfg).unwrap();
|
||||
let back: MaePretrainConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back, cfg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_mask_ratio_rejected() {
|
||||
for ratio in [0.0, 1.0, -0.1, 1.5, f64::NAN] {
|
||||
let cfg = MaePretrainConfig {
|
||||
mask_ratio: ratio,
|
||||
..MaePretrainConfig::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err(), "ratio {ratio} should be invalid");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_patch_dims_rejected() {
|
||||
let cfg = MaePretrainConfig {
|
||||
patch_time: 0,
|
||||
..MaePretrainConfig::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
let cfg = MaePretrainConfig {
|
||||
patch_subc: 0,
|
||||
..MaePretrainConfig::default()
|
||||
};
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Divisibility policy: error, never truncate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn non_divisible_window_errors_with_crop_hint() {
|
||||
let cfg = MaePretrainConfig::default(); // (30, 3)
|
||||
// Default TrainingConfig window 100 × 56 is NOT divisible by (30, 3).
|
||||
let err = cfg.validate_for_window(100, 56).unwrap_err();
|
||||
match err {
|
||||
MaeError::NotDivisible {
|
||||
axis,
|
||||
window,
|
||||
patch,
|
||||
remainder,
|
||||
crop,
|
||||
} => {
|
||||
assert_eq!(axis, "time");
|
||||
assert_eq!(window, 100);
|
||||
assert_eq!(patch, 30);
|
||||
assert_eq!(remainder, 10);
|
||||
assert_eq!(crop, 90);
|
||||
}
|
||||
other => panic!("expected NotDivisible, got {other:?}"),
|
||||
}
|
||||
assert_eq!(cfg.cropped_window_shape(100, 56), (90, 54));
|
||||
// The hinted crop validates cleanly.
|
||||
cfg.validate_for_window(90, 54).expect("crop is divisible");
|
||||
assert_eq!(cfg.num_patches(90, 54).unwrap(), 3 * 18);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn patch_larger_than_window_errors() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let err = cfg.validate_for_window(20, 3).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
MaeError::PatchExceedsWindow { axis: "time", .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_length_mismatch_errors() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let buf = vec![0.0_f32; 89 * 54]; // declared 90 × 54
|
||||
let err = patchify(&buf, 90, 54, &cfg).unwrap_err();
|
||||
assert!(matches!(err, MaeError::WindowShapeMismatch { .. }));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NaN handling
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn nan_and_inf_input_rejected_with_location() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let mut buf = window(90, 54);
|
||||
buf[2 * 54 + 7] = f32::NAN;
|
||||
match patchify(&buf, 90, 54, &cfg).unwrap_err() {
|
||||
MaeError::NonFiniteValue { row, col, .. } => {
|
||||
assert_eq!((row, col), (2, 7));
|
||||
}
|
||||
other => panic!("expected NonFiniteValue, got {other:?}"),
|
||||
}
|
||||
buf[2 * 54 + 7] = f32::INFINITY;
|
||||
assert!(matches!(
|
||||
patchify(&buf, 90, 54, &cfg),
|
||||
Err(MaeError::NonFiniteValue { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finite_input_is_nan_free_after_round_trip() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let buf = window(90, 54);
|
||||
let grid = patchify(&buf, 90, 54, &cfg).unwrap();
|
||||
assert!(grid.patches.iter().flatten().all(|v| v.is_finite()));
|
||||
assert!(unpatchify(&grid).iter().all(|v| v.is_finite()));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Patchify / unpatchify round trip
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn patchify_unpatchify_identity_default_recipe() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let buf = window(90, 54);
|
||||
let grid = patchify(&buf, 90, 54, &cfg).unwrap();
|
||||
assert_eq!(grid.n_patches(), 54);
|
||||
assert_eq!(grid.patch_len(), 90);
|
||||
assert_eq!(grid.window_shape(), (90, 54));
|
||||
assert_eq!(unpatchify(&grid), buf);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn patch_layout_is_time_major() {
|
||||
// 4 × 4 window, (2, 2) patches → patch 0 is rows 0–1 × cols 0–1.
|
||||
let cfg = MaePretrainConfig {
|
||||
patch_time: 2,
|
||||
patch_subc: 2,
|
||||
..MaePretrainConfig::default()
|
||||
};
|
||||
let buf = window(4, 4);
|
||||
let grid = patchify(&buf, 4, 4, &cfg).unwrap();
|
||||
assert_eq!(grid.patches[0], vec![0.0, 1.0, 1000.0, 1001.0]);
|
||||
// Patch index 1 is the next subcarrier block on the same time rows.
|
||||
assert_eq!(grid.patches[1], vec![2.0, 3.0, 1002.0, 1003.0]);
|
||||
// Patch index n_patches_subc starts the second time row of patches.
|
||||
assert_eq!(grid.patches[2], vec![2000.0, 2001.0, 3000.0, 3001.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpatchify_visible_restores_visible_and_fills_masked() {
|
||||
let cfg = MaePretrainConfig::default();
|
||||
let buf = window(90, 54);
|
||||
let (grid, mask) = cfg.mask_window(&buf, 90, 54).unwrap();
|
||||
let fill = -1.0_f32;
|
||||
let recon = unpatchify_visible(&grid, &mask.visible, fill);
|
||||
|
||||
// Visible patch regions are identical to the input; masked regions = fill.
|
||||
let full = unpatchify(&grid);
|
||||
assert_eq!(full, buf);
|
||||
let mut n_fill = 0usize;
|
||||
for (i, (&r, &orig)) in recon.iter().zip(buf.iter()).enumerate() {
|
||||
if r == fill && orig != fill {
|
||||
n_fill += 1;
|
||||
} else {
|
||||
assert_eq!(r, orig, "visible value at flat index {i} must round-trip");
|
||||
}
|
||||
}
|
||||
assert_eq!(n_fill, mask.masked.len() * grid.patch_len());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Random mask: exact count, determinism, disjointness
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn mask_count_is_exact_for_default_recipe() {
|
||||
// 54 patches @ 0.80 → round(43.2) = 43 masked, 11 visible.
|
||||
let cfg = MaePretrainConfig::default();
|
||||
assert_eq!(cfg.num_masked(54), 43);
|
||||
let mask = random_mask(54, cfg.mask_ratio, cfg.seed).unwrap();
|
||||
assert_eq!(mask.masked.len(), 43);
|
||||
assert_eq!(mask.visible.len(), 11);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn same_seed_same_mask_different_seed_differs() {
|
||||
let a = random_mask(100, 0.80, 7).unwrap();
|
||||
let b = random_mask(100, 0.80, 7).unwrap();
|
||||
assert_eq!(a, b, "same (n, ratio, seed) must reproduce the mask");
|
||||
|
||||
let c = random_mask(100, 0.80, 8).unwrap();
|
||||
assert_ne!(a.masked, c.masked, "different seeds must differ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random_mask_rejects_invalid_ratios() {
|
||||
// Error-not-silent: NaN must not silently mask 0 patches; ratios outside
|
||||
// (0, 1) must not degenerate to all-visible / all-masked grids.
|
||||
for ratio in [
|
||||
f64::NAN,
|
||||
f64::INFINITY,
|
||||
f64::NEG_INFINITY,
|
||||
1.0,
|
||||
1.5,
|
||||
0.0,
|
||||
-0.1,
|
||||
] {
|
||||
let err = random_mask(54, ratio, 42).unwrap_err();
|
||||
assert!(
|
||||
matches!(err, MaeError::InvalidMaskRatio { .. }),
|
||||
"ratio {ratio} must be rejected, got {err:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_window_rejects_invalid_ratio_before_masking() {
|
||||
let cfg = MaePretrainConfig {
|
||||
mask_ratio: f64::NAN,
|
||||
..MaePretrainConfig::default()
|
||||
};
|
||||
let buf = window(90, 54);
|
||||
assert!(matches!(
|
||||
cfg.mask_window(&buf, 90, 54),
|
||||
Err(MaeError::InvalidMaskRatio { .. })
|
||||
));
|
||||
}
|
||||
|
||||
proptest! {
|
||||
/// Exact count, sortedness, range, disjointness, and full coverage hold
|
||||
/// for arbitrary grid sizes, ratios, and seeds.
|
||||
#[test]
|
||||
fn prop_mask_invariants(
|
||||
n in 1usize..600,
|
||||
ratio in 0.01f64..0.99,
|
||||
seed in any::<u64>(),
|
||||
) {
|
||||
let mask = random_mask(n, ratio, seed).unwrap();
|
||||
let expected_masked = ((ratio * n as f64).round() as usize).min(n);
|
||||
prop_assert_eq!(mask.masked.len(), expected_masked);
|
||||
prop_assert_eq!(mask.masked.len() + mask.visible.len(), n);
|
||||
|
||||
// In range, sorted, strictly increasing (no duplicates).
|
||||
for set in [&mask.masked, &mask.visible] {
|
||||
for w in set.windows(2) {
|
||||
prop_assert!(w[0] < w[1]);
|
||||
}
|
||||
if let Some(&last) = set.last() {
|
||||
prop_assert!(last < n);
|
||||
}
|
||||
}
|
||||
// Disjoint + complete: merged sets are exactly 0..n.
|
||||
let mut all: Vec<usize> = mask.masked.iter().chain(&mask.visible).copied().collect();
|
||||
all.sort_unstable();
|
||||
prop_assert_eq!(all, (0..n).collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
/// Determinism by seed for arbitrary inputs.
|
||||
#[test]
|
||||
fn prop_mask_deterministic(n in 1usize..400, seed in any::<u64>()) {
|
||||
prop_assert_eq!(
|
||||
random_mask(n, 0.80, seed).unwrap(),
|
||||
random_mask(n, 0.80, seed).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
/// Round-trip identity for arbitrary divisible window/patch geometries.
|
||||
#[test]
|
||||
fn prop_patchify_round_trip(
|
||||
pt in 1usize..8,
|
||||
ps in 1usize..8,
|
||||
nt in 1usize..6,
|
||||
ns in 1usize..6,
|
||||
seed in any::<u64>(),
|
||||
) {
|
||||
let (time, subc) = (pt * nt, ps * ns);
|
||||
let cfg = MaePretrainConfig {
|
||||
patch_time: pt,
|
||||
patch_subc: ps,
|
||||
seed,
|
||||
..MaePretrainConfig::default()
|
||||
};
|
||||
let buf = window(time, subc);
|
||||
let grid = patchify(&buf, time, subc, &cfg).unwrap();
|
||||
prop_assert_eq!(grid.n_patches(), nt * ns);
|
||||
prop_assert_eq!(unpatchify(&grid), buf);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
//! Numerical parity between the Rust WiFlow-STD port and the retrained
|
||||
//! PyTorch checkpoint (ADR-152 §2.2).
|
||||
//!
|
||||
//! The fixtures are produced by `benchmarks/wiflow-std/export_to_safetensors.py`
|
||||
//! (gitignored — they derive from the retrained checkpoint, which is itself
|
||||
//! gitignored):
|
||||
//!
|
||||
//! - `results/retrained_wiflow_std.safetensors` — the epoch-36 checkpoint
|
||||
//! (val PCK@20 96.99%) remapped to tch `VarStore` variable names
|
||||
//! - `results/parity_fixture.json` — a deterministic input (seed 42, shape
|
||||
//! `(2, 540, 20)`, uniform `[0, 1]`) and the upstream `WiFlowPoseModel`'s
|
||||
//! eval-mode output on it
|
||||
//!
|
||||
//! Run explicitly (needs LibTorch, e.g. `LIBTORCH_USE_PYTORCH=1` with the
|
||||
//! torch DLL directory on `PATH`):
|
||||
//!
|
||||
//! ```text
|
||||
//! cargo test -p wifi-densepose-train --features tch-backend \
|
||||
//! --test test_wiflow_std_parity -- --ignored --nocapture
|
||||
//! ```
|
||||
|
||||
#![cfg(feature = "tch-backend")]
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use tch::{Device, Tensor};
|
||||
use wifi_densepose_train::{WiFlowStdConfig, WiFlowStdModel};
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ParityFixture {
|
||||
input_shape: Vec<i64>,
|
||||
input: Vec<f32>,
|
||||
output_shape: Vec<i64>,
|
||||
output: Vec<f32>,
|
||||
}
|
||||
|
||||
fn results_dir() -> PathBuf {
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("benchmarks")
|
||||
.join("wiflow-std")
|
||||
.join("results")
|
||||
}
|
||||
|
||||
/// Loads the retrained checkpoint into the Rust model and asserts the forward
|
||||
/// pass matches PyTorch to within 1e-4 max absolute difference.
|
||||
///
|
||||
/// `#[ignore]`d by default: it needs the gitignored fixtures above plus a
|
||||
/// working LibTorch environment, neither of which exist in CI.
|
||||
#[test]
|
||||
#[ignore = "needs gitignored fixtures (run export_to_safetensors.py) + LibTorch env; run with --ignored"]
|
||||
fn retrained_checkpoint_matches_pytorch_forward() {
|
||||
let dir = results_dir();
|
||||
let weights = dir.join("retrained_wiflow_std.safetensors");
|
||||
let fixture_path = dir.join("parity_fixture.json");
|
||||
for p in [&weights, &fixture_path] {
|
||||
assert!(
|
||||
p.exists(),
|
||||
"missing fixture {} — run benchmarks/wiflow-std/export_to_safetensors.py first",
|
||||
p.display()
|
||||
);
|
||||
}
|
||||
|
||||
let fixture: ParityFixture = serde_json::from_reader(BufReader::new(
|
||||
File::open(&fixture_path).expect("open parity_fixture.json"),
|
||||
))
|
||||
.expect("parse parity_fixture.json");
|
||||
assert_eq!(fixture.input_shape, vec![2, 540, 20]);
|
||||
assert_eq!(fixture.output_shape, vec![2, 15, 2]);
|
||||
|
||||
let cfg = WiFlowStdConfig::default();
|
||||
let mut model = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build default model");
|
||||
model
|
||||
.load(&weights)
|
||||
.expect("safetensors load: every VarStore variable must match by name and shape");
|
||||
|
||||
let input = Tensor::from_slice(&fixture.input).reshape(&fixture.input_shape[..]);
|
||||
let expected = Tensor::from_slice(&fixture.output).reshape(&fixture.output_shape[..]);
|
||||
|
||||
let output = model.forward_inference(&input);
|
||||
assert_eq!(output.size(), fixture.output_shape);
|
||||
|
||||
let max_diff = (&output - &expected).abs().max().double_value(&[]);
|
||||
println!("max |rust - python| = {max_diff:.3e}");
|
||||
assert!(
|
||||
max_diff < 1e-4,
|
||||
"Rust forward pass diverges from PyTorch: max abs diff {max_diff:.3e} >= 1e-4"
|
||||
);
|
||||
}
|
||||
Vendored
+1
-1
Submodule vendor/ruvector updated: e383476014...a083bd77fa
Reference in New Issue
Block a user