mirror of
https://github.com/ruvnet/RuView
synced 2026-06-18 11:43:19 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4bf88e1283 | |||
| a4c2935a2f | |||
| 315d7df09e | |||
| bdd1eaf927 | |||
| 4001e9e178 | |||
| 65e29ef47a | |||
| cb30988cf9 | |||
| 128b129474 | |||
| 15a983b555 | |||
| c6e7667676 | |||
| d639c747df | |||
| 42c764652d | |||
| db02956c22 | |||
| c84ea39e62 | |||
| 760d05026c | |||
| a784546918 | |||
| 9c751d0d92 | |||
| a13e9b66cb | |||
| 6db183bf3e | |||
| f65d0f79e7 | |||
| 7fb3b88061 | |||
| aeac5f5543 | |||
| c257e67c3d | |||
| a4d5ea88f3 | |||
| ebe217569b | |||
| c27d6cc98e |
@@ -36,7 +36,7 @@ jobs:
|
||||
features:
|
||||
- { label: 'default', flags: '--no-default-features' }
|
||||
- { label: 'train', flags: '--features train' }
|
||||
- { label: 'ruflo+itar', flags: '--features ruflo,itar-unrestricted' }
|
||||
- { label: 'ruflo', flags: '--features ruflo' }
|
||||
- { label: 'full+train', flags: '--features full,train' }
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
+14
@@ -277,3 +277,17 @@ aether-arena/staging/
|
||||
# MM-Fi benchmark dataset archives — large data, fetch separately, never commit
|
||||
assets/MM-Fi/E0*.zip
|
||||
assets/MM-Fi/*.zip
|
||||
|
||||
# through-wall demo: regenerable trained model artifact
|
||||
examples/through-wall/model/
|
||||
|
||||
# RuView harness (npx ruview) build artifacts — ADR-182
|
||||
harness/**/node_modules/
|
||||
harness/**/*.tgz
|
||||
harness/**/package-lock.json
|
||||
harness/**/.claude-flow/
|
||||
harness/**/ruvector.db
|
||||
|
||||
# ruvector runtime/hook DB — never tracked (any depth)
|
||||
ruvector.db
|
||||
**/ruvector.db
|
||||
|
||||
@@ -21,3 +21,11 @@
|
||||
[submodule "vendor/rufield"]
|
||||
path = vendor/rufield
|
||||
url = https://github.com/ruvnet/rufield
|
||||
[submodule "v2/crates/ruview-swarm"]
|
||||
path = v2/crates/ruview-swarm
|
||||
url = https://github.com/ruvnet/ruv-drone.git
|
||||
branch = main
|
||||
[submodule "v2/crates/worldgraph"]
|
||||
path = v2/crates/worldgraph
|
||||
url = https://github.com/ruvnet/worldgraph.git
|
||||
branch = main
|
||||
|
||||
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Fixed
|
||||
- **Multistatic fusion guard interval is now operator-configurable — fixes permanent trust demotion with WiFi-synced ESP32 nodes (#1049).** Two independently-clocked ESP32-S3 boards on ESP-NOW sync drift 10–150 ms (typ. ~70 ms) — the 100 ms beacon + WiFi-MAC jitter cannot hold them within the published 60 ms default guard, so the governed-trust cycle permanently demoted to `Restricted`, suppressed all pose output, and spun the error counter to 200k+ with **no escape hatch but a container restart**. Added a **direct `WDP_GUARD_INTERVAL_US` override** (+ optional `WDP_SOFT_GUARD_US`) to `multistatic_guard_config_from_env`, so a deployment can lift the hard guard past its measured spread (e.g. `WDP_GUARD_INTERVAL_US=200000`) without having to know its exact TDM schedule. Precedence is most-specific-wins: a direct override beats the existing `WDP_TDM_SLOTS`+`WDP_TDM_SLOT_US` schedule-derived guard, which beats the 60 ms/20 ms default; the override is applied on top of whichever base is selected, the soft band is always clamped strictly below the hard guard, and a malformed/zero value is ignored (falls back to the base rather than breaking fusion). The effective guard is now logged at startup. Pinned by 6 new tests (`multistatic_guard_config_tests`): direct-override-wins / beats-TDM-derived / soft-clamped-below-hard / lowering-hard-pulls-soft-down / malformed-or-zero-falls-back / default-when-unset. `wifi-densepose-sensing-server` bin tests **449 → 455**, 0 failed; Python proof VERDICT PASS, hash unchanged (off the signal proof path).
|
||||
|
||||
### Security
|
||||
- **`wifi-densepose-occworld-candle` — beyond-SOTA security + correctness review (Milestone #9, crate 4/4).** (1) **HIGH (MEASURED) — checkpoint-load crash on any int32 tensor** (`model.rs::safetensor_dtype_to_candle`). `safetensors::Dtype::I32` was mapped to `candle_core::DType::I64` and the raw int32 byte buffer (4 bytes/elem) was then handed to `Tensor::from_raw_buffer(.., I64, shape, ..)`. Candle derives `elem_count = data.len() / dtype.size_in_bytes()`, so the I64 path halved the element count while keeping the *original* shape — yielding a tensor whose declared shape claims twice as many elements as its backing storage holds. Reading it **panics** (`range end index 6 out of range for slice of length 3` — slice OOB inside candle-core) on any attacker-supplied or PyTorch-exported checkpoint containing an int32 tensor (common: index/buffer tensors). Fixed by mapping `I32 → DType::I32` (and `I16 → DType::I16`), both first-class candle dtypes. Reproduction recorded on old code; pinned by `tests/checkpoint_loading.rs::int32_tensor_loads_with_consistent_shape_and_values` (panics on old, passes on new) plus F32/I64/corrupt-file control cases. (2) **LOW (MEASURED) — `predict()` lacked frame/batch validation at the input boundary** (`inference.rs`). It validated H/W/D but not the externally-supplied frame count; an `f_in > num_frames*2` over-indexed the temporal positional embedding deep in the transformer and surfaced as a cryptic candle "gather" `InvalidIndex` (returned error, not a panic — candle bounds-checks), and a zero frame/batch dim fed a zero-element tensor into the pipeline. Now rejected at the boundary with a clear `ShapeMismatch`. Pinned by `predict_rejects_zero_frames` / `predict_rejects_too_many_frames` / `predict_accepts_frame_count_at_capacity`. (3) **LOW (MEASURED) — divide-by-zero panic on a degenerate input to the public `VQCodebook::encode`** (`vqvae.rs`): a rank-0 / empty-last-dim tensor made `last == 0` and panicked on `elem_count() / last`. Now fails closed with a clear error. Pinned by `encode_rejects_scalar_without_panicking`. **Dimensions confirmed CLEAN with evidence:** panic surface — zero `unwrap()`/`expect()`/`panic!`/`unreachable!` in production code paths (grep evidence; all error handling via `?`/`map_err`); NaN-state-poisoning — N/A (engine is stateless between `predict` calls, input is `u8` class indices so non-finite input is structurally impossible, no persistent world-model buffer to latch into); unbounded-alloc / shape-data mismatch from malformed weights — defended upstream by `safetensors::validate()` (overflow-checked `nelements*dtype.size()` vs declared byte range, rejected before reaching candle); secrets — none (grep clean, only `token_h`/`token_w` config fields match). `unsafe_code = forbid` in the crate manifest. **Build/validation status (MEASURED on Windows):** crate builds and tests under `cargo test -p wifi-densepose-occworld-candle --no-default-features` — **29/29 pass** (20 unit + 4 checkpoint_loading + 3 predict_honesty + 2 doc) after fixes; `cargo test --workspace --no-default-features` = 0 failed across all crates (lone `wifi-densepose-desktop` `api_integration` failure was a Windows "Access is denied (os error 5)" file-lock flake — re-ran in isolation **21/21 pass**); Python proof VERDICT PASS, hash `f8e76f21…446f7a` unchanged. *Warrants ADR slot 179 (parent to author).*
|
||||
- **`wifi-densepose-wasm-edge` beyond-SOTA closing review — boundary NaN-state-poisoning guard + clean-with-evidence attestation (ADR-040 edge crate, ~70 modules).** Closing pass of the security campaign over the last untouched sizeable crate. **One real finding fixed (LOW / source-analysis + reproduced):** the two WASM↔host frame boundaries (`lib.rs::on_frame`/`on_timer` and `bin/ghost_hunter.rs::on_frame`) read raw IEEE-754 `f32` from the `csi_get_phase`/`csi_get_amplitude`/`csi_get_variance`/`csi_get_motion_energy` host imports **without any finiteness check** — the entire crate had **zero** `is_finite`/`is_nan` guards, and the in-crate `clamp` helpers propagate NaN (`NaN < lo` and `NaN > hi` are both false). A single non-finite value (firmware DSP bug, uninitialised buffer, or hostile host) latches NaN into the long-lived per-module accumulators (EMA, Welford, phasor sums, anomaly baselines); once latched, every downstream comparison evaluates `false`, so detectors fail **degraded** (stuck gate state, silently-disabled anomaly checks) — silent corruption, not a crash (WASM `panic=abort` is *not* tripped: no indexing/`unwrap` on the poisoned value). Threat model is a **semi-trusted** boundary (the Tier-2 DSP firmware supplies the imports, not direct network/JS), hence LOW severity / defense-in-depth. **Fix:** added `sanitize_host_f32()` (maps non-finite→`0.0`, `core`-only so it holds in `no_std`) applied at every `host_get_*` float read — a single chokepoint covering all ~70 downstream modules, mirroring the existing M-01 negative-`n_subcarriers` boundary clamp. **Pinned by** `boundary_tests::{sanitize_passes_finite_values_through, sanitize_maps_non_finite_to_zero, coherence_monitor_nan_latches_without_sanitize_but_not_with}` — the last asserts on the *current* `CoherenceMonitor` that a raw NaN frame latches the smoothed score (documents the hazard) while the boundary-sanitized path stays finite. **Dimensions attested CLEAN with evidence (source-analysis):** (a) **panic-on-input** — every non-test `unwrap()`/`expect()` is either `#[cfg(test)]` or in the `std`-gated RVF *builder* host tool writing to an in-memory `Vec` (infallible); no `panic!`/`unreachable!`/`todo!`/`get_unchecked` in any hot path. (b) **shape/bounds** — all frame-buffer access is `min()`-clamped (`MAX_SC=32`, `DTW_MAX_LEN`, `LCS_WINDOW`, `PATTERN_LEN`), all index-by-cast sites (`feature_id as usize`, `conclusion_id`, `minute_counter`, `plan_step`) are either compile-time-const-bounded or `if idx <`/`%`-guarded; negative `n_subcarriers` already mapped to 0 (M-01). (c) **memory/leak** — no `move ||` closures, no `mem::forget`/`Box::leak`/`.leak()`; the only `Box::new` is in the `std`-gated `skill_registry` (one-time init, bounded). (d) **secrets** — none (grep clean). **MEASURED build/test evidence:** host `cargo test --features std,medical-experimental` = **672 passed / 0 failed** (was 669 pre-fix; +3 new tests); the real deployment artifacts all build clean on the actual target — `cargo build --target wasm32-unknown-unknown --release` (no_std/panic=abort default lib), `--bin ghost_hunter --no-default-features --features standalone-bin`, and `--features medical-experimental` (toolchain 1.89 per `rust-toolchain.toml`). No ADR slot needed — a single LOW defense-in-depth boundary fix; CHANGELOG attestation suffices.
|
||||
|
||||
@@ -601,6 +601,8 @@ claude --plugin-dir ./plugins/ruview
|
||||
|
||||
Verify the plugin structure: `bash plugins/ruview/scripts/smoke.sh`. Full details: [`plugins/ruview/README.md`](plugins/ruview/README.md).
|
||||
|
||||
**Portable harness — `npx @ruvnet/ruview`:** a lighter, host-portable companion to the in-repo plugin, minted via [MetaHarness](https://www.npmjs.com/package/metaharness) and hardened per [ADR-182](docs/adr/ADR-182-npx-ruview-harness-via-metaharness.md). It runs **without cloning this repo** and on more hosts (Claude Code, Codex, Copilot, opencode, …), exposing the RuView operator tools (`onboard`, `verify`, `node_monitor`, `calibrate`, `node_flash`) over an MCP server — plus the project's **MEASURED-vs-CLAIMED honesty guardrail enforced in code** (`ruview.claim_check` flags untagged or retracted-"100%" accuracy claims). v0.1: the onboarding/verify/claim-check paths are tested (17/17, `verify.py` → PASS); the hardware tools are fail-closed wrappers. Try `npx @ruvnet/ruview` to onboard, or `npx @ruvnet/ruview claim-check --text "…"`. Source: [`harness/ruview/`](harness/ruview/README.md).
|
||||
|
||||
---
|
||||
|
||||
## 📖 Documentation
|
||||
@@ -614,6 +616,7 @@ Verify the plugin structure: `bash plugins/ruview/scripts/smoke.sh`. Full detail
|
||||
| [**SENSE-BRIDGE — rvagent MCP server**](tools/ruview-mcp/README.md) | Dual-transport MCP server (`@ruvnet/rvagent`) bridging the RuView sensing stack to AI agents (Claude Code, Cursor, ruflo swarms). 6 tools wired: `ruview.presence.now`, `ruview.vitals.get_{breathing,heart_rate,all}`, `ruview.bfld.last_scan`, `ruview.bfld.subscribe`. stdio + Streamable HTTP (`POST /mcp`, Origin-validated, bearer-token auth, `127.0.0.1` bind). Full 20-tool Zod schema barrel + 5 RUVIEW-POLICY governance tools. 93 tests. [ADR-124](docs/adr/ADR-124-rvagent-mcp-ruvector-npm-integration.md). Try: `npx @ruvnet/rvagent stdio`. |
|
||||
| [Semantic Primitives — Precision/Recall](docs/integrations/semantic-primitives-metrics.md) | Per-primitive F1 on the held-out paired-capture set: someone-sleeping, possible-distress, room-active, elderly-inactivity-anomaly, meeting, bathroom, fall-risk, bed-exit, no-movement, multi-room. |
|
||||
| [Claude Code / Codex Plugin](plugins/ruview/README.md) | The `ruview` plugin + marketplace — skills, `/ruview-*` commands, agents, and the Codex prompt mirror |
|
||||
| [Portable harness — `npx @ruvnet/ruview`](harness/ruview/README.md) | MetaHarness-minted, host-portable RuView operator harness — `ruview.*` MCP tools + the MEASURED-vs-CLAIMED honesty guardrail enforced in code ([ADR-182](docs/adr/ADR-182-npx-ruview-harness-via-metaharness.md)). A lighter, multi-host companion to the in-repo plugin. |
|
||||
| [Architecture Decisions](docs/adr/README.md) | 96 ADRs — why each technical choice was made, organized by domain (hardware, signal processing, ML, platform, infrastructure) |
|
||||
| [Domain Models](docs/ddd/README.md) | 8 DDD models (RuvSense, Signal Processing, Training Pipeline, Hardware Platform, Sensing Server, WiFi-Mat, CHCI, rvCSI) — bounded contexts, aggregates, domain events, and ubiquitous language |
|
||||
| [rvCSI — edge RF sensing runtime](https://github.com/ruvnet/rvcsi) | Rust-first / TypeScript-accessible / hardware-abstracted CSI runtime: multi-source ingestion (incl. real nexmon_csi `.pcap` from a **Raspberry Pi 5** / Pi 4 / Pi 3B+ — CYW43455 / BCM43455c0) → validation → DSP → typed events → RuVector RF memory ([ADR-095](docs/adr/ADR-095-rvcsi-edge-rf-sensing-platform.md), [ADR-096](docs/adr/ADR-096-rvcsi-ffi-crate-layout.md), [domain model](docs/ddd/rvcsi-domain-model.md)). Now its own repo — [`ruvnet/rvcsi`](https://github.com/ruvnet/rvcsi) — vendored here under `vendor/rvcsi`; 9 `rvcsi-*` crates on crates.io, `@ruv/rvcsi` on npm, plus a Claude Code plugin. |
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,279 @@
|
||||
# ADR-182: `npx ruview` — A RuView Agent Harness Minted via MetaHarness
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Accepted — **P1+P2 implemented & validated** (`harness/ruview/`, 17/17 tests, MCP handshake + `ruview.verify` PASS against the real repo, packs to 16.7 kB / 21 files) · P3 publish-ready (name decision pending) · P4 (router + provenance) designed |
|
||||
| **Date** | 2026-06-17 |
|
||||
| **Deciders** | ruv |
|
||||
| **Codename** | **RUVIEW-HARNESS** |
|
||||
| **Builds on** | MetaHarness (`metaharness@0.1.15`, `@metaharness/kernel`, `@metaharness/host-*`, `@metaharness/router`), the `ruview-*` Claude Code subagents (`ruview-onboarding-guide`, `ruview-config-engineer`, `ruview-training-engineer`), the `wifi-densepose` CLI (`calibrate`/`enroll`/`train-room`/`room-watch`), the sensing-server, ADR-028 (witness verification), ADR-095/096 (rvCSI runtime), ADR-260/262 (RuField bridge) |
|
||||
| **Supersedes** | none |
|
||||
|
||||
## Context
|
||||
|
||||
RuView (WiFi-DensePose) is a deep stack — 15 Rust crates, an ESP32 firmware line,
|
||||
a sensing-server, a CLI, ~180 ADRs, a calibration pipeline, training recipes, and a
|
||||
hard cultural rule that **every claim must be independently reproducible** (the
|
||||
"prove everything" ethos, after the project was accused of AI-slop). The barrier to
|
||||
entry is correspondingly steep: a newcomer who wants to "set up WiFi sensing" must
|
||||
discover the right firmware variant, provision an ESP32 over a Windows-only Python
|
||||
subprocess, point it at the sensing-server, run `calibrate` → `enroll` →
|
||||
`train-room`, and know which numbers are MEASURED vs CLAIMED. We already encode this
|
||||
knowledge as **Claude Code subagents** (`ruview-onboarding-guide`,
|
||||
`ruview-config-engineer`, `ruview-training-engineer`) — but those only exist inside
|
||||
*this* repo's `.claude/agents/`, only on Claude Code, and only for someone who has
|
||||
already cloned the monorepo.
|
||||
|
||||
Separately, this session shipped **MetaHarness** (`metaharness@0.1.15`): a tool that
|
||||
*"mints a custom AI agent harness from any repo"*, runnable on **9 hosts**
|
||||
(claude-code, codex, pi-dev, hermes, openclaw, rvm, copilot, opencode,
|
||||
github-actions) over a wasm-primary / NAPI-RS-fallback **kernel**, with a
|
||||
**cost-optimal model router** (`@metaharness/router`, the productized DRACO Phase-2
|
||||
k-NN finding) and ed25519/SLSA/SBOM provenance baked in. Crucially, MetaHarness
|
||||
**already ships a `vertical:ruview` template** in its template list. That template
|
||||
is generic scaffolding; it is not wired to RuView's actual tools, agents, or the
|
||||
"prove everything" guardrails.
|
||||
|
||||
The gap: **there is no single, host-portable, provenance-signed entry point that
|
||||
gives any user an AI agent that actually knows how to operate RuView.** A user
|
||||
should be able to run one command —
|
||||
|
||||
```bash
|
||||
npx ruview
|
||||
```
|
||||
|
||||
— in an empty directory (or alongside an ESP32) and get an agent harness that can
|
||||
onboard them, configure firmware, drive a live capture, train a room model, and
|
||||
**refuse to overstate accuracy** — on whichever coding host they already use.
|
||||
|
||||
## Decision
|
||||
|
||||
**Mint a first-class RuView agent harness from this repo using MetaHarness, harden
|
||||
its `vertical:ruview` template into a RuView-specific harness with a real MCP tool
|
||||
surface and the project's honesty guardrails, and publish it as `npx ruview`.**
|
||||
|
||||
`npx ruview` is *not* a new runtime. It is a **thin, versioned distribution** of a
|
||||
MetaHarness harness: the kernel + host adapters + a RuView "genome" (skills, agents,
|
||||
MCP tools, guardrails) generated from and pinned against this monorepo. The harness
|
||||
is the product; `npx ruview` is the front door.
|
||||
|
||||
### Why mint-from-repo instead of hand-writing a harness
|
||||
|
||||
MetaHarness's value here is exactly the work we would otherwise hand-roll across 9
|
||||
hosts: host-specific config (`.claude/settings.json` MCP + hooks for claude-code,
|
||||
the codex/copilot/opencode equivalents), the kernel that abstracts wasm-vs-native,
|
||||
the cost router, and the provenance chain. We write the **RuView knowledge once** as
|
||||
host-neutral genome assets; MetaHarness projects them onto each host adapter. This
|
||||
also keeps the harness regenerable: when the CLI or an ADR changes, re-mint and
|
||||
re-pin rather than maintaining 9 divergent copies.
|
||||
|
||||
### What the harness contains (the RuView genome)
|
||||
|
||||
1. **Skills / playbooks** (host-neutral markdown, projected to each host's skill
|
||||
format):
|
||||
- `onboard` — zero-to-sensing path picker (Docker demo / repo build / live
|
||||
ESP32), the physics caveats, the hardware table. Port of
|
||||
`ruview-onboarding-guide`.
|
||||
- `provision-node` — ESP-IDF v5.4 Windows-subprocess build/flash/provision flow
|
||||
(the exact MSYSTEM-stripped invocation from `CLAUDE.local.md`), firmware
|
||||
variant selection (8MB display / 4MB no-display / C6), NVS + WiFi + channel /
|
||||
MAC-filter overrides (ADR-060).
|
||||
- `calibrate-room` — `baseline → enroll → extract → train` via the
|
||||
`wifi-densepose` CLI (`calibrate`/`calibrate-serve`/`enroll`/`train-room`/
|
||||
`room-watch`, ADR-151).
|
||||
- `train-pose` — camera-supervised + camera-free training, the MEASURED-vs-CLAIMED
|
||||
discipline, the mean-pose baseline check (ADR-079, ADR-152, ADR-181).
|
||||
- `verify` — run the witness bundle + Python proof (`verify.py` → VERDICT: PASS),
|
||||
ADR-028.
|
||||
- Ports of `ruview-config-engineer` and `ruview-training-engineer`.
|
||||
|
||||
2. **MCP tool surface** (`@metaharness/kernel`-hosted MCP server, one schema per
|
||||
capability — see "MCP tools" below). This is what makes the harness *operate*
|
||||
RuView, not just talk about it.
|
||||
|
||||
3. **Guardrails** (the differentiator): the harness's system prompt and a
|
||||
pre-output hook enforce the "prove everything" rule — accuracy numbers must be
|
||||
tagged MEASURED (with a reproducer) or CLAIMED; the agent must run the mean-pose
|
||||
baseline before quoting PCK; firmware fixes are never presented as
|
||||
hardware-validated without a real boot log (the exact discipline this session
|
||||
followed for `v0.8.1-esp32`).
|
||||
|
||||
4. **Host adapters** — claude-code first (P1), then codex / opencode / copilot /
|
||||
pi-dev / hermes / rvm / github-actions (P3+), each via the published
|
||||
`@metaharness/host-*` package.
|
||||
|
||||
5. **Router** — `@metaharness/router` routes each step to the cheapest adequate
|
||||
model (e.g. a var-rename or a log-grep → Haiku; calibration-math reasoning or a
|
||||
security review → Sonnet/Opus), mirroring the repo's 3-tier routing (ADR-026).
|
||||
|
||||
### MCP tools (the operational surface)
|
||||
|
||||
| Tool | Wraps | Purpose |
|
||||
|------|-------|---------|
|
||||
| `ruview.onboard` | docs + agent | Pick a setup path, print the next concrete command |
|
||||
| `ruview.node.flash` | ESP-IDF subprocess (ADR `CLAUDE.local.md`) | Build + flash a firmware variant to a COM port |
|
||||
| `ruview.node.provision` | `provision.py` | Set SSID/password/target-ip/channel/MAC-filter over serial |
|
||||
| `ruview.node.monitor` | pyserial | Stream boot log; assert CSI is flowing (MGMT+DATA) |
|
||||
| `ruview.server.up` | sensing-server | Start the Axum sensing-server (`:3000`/`:5005`/`:8765`) |
|
||||
| `ruview.calibrate` | `wifi-densepose calibrate`/`enroll`/`train-room` | Run the ADR-151 room pipeline |
|
||||
| `ruview.room.watch` | `wifi-densepose room-watch` | Live presence/vitals from a trained room |
|
||||
| `ruview.verify` | `scripts/generate-witness-bundle.sh` + `verify.py` | Produce/verify the witness bundle (must be N/N PASS) |
|
||||
| `ruview.claim.check` | static lint | Scan output for untagged accuracy claims; flag MEASURED-vs-CLAIMED |
|
||||
|
||||
Each tool returns structured JSON and is fail-closed: a tool that cannot prove its
|
||||
result (e.g. `ruview.node.monitor` sees no CSI callbacks) returns an honest negative,
|
||||
never a fabricated success — consistent with the RuField `map_privacy` fail-closed
|
||||
posture (ADR-262 §3.3).
|
||||
|
||||
### The mint + pin flow (how the harness is produced)
|
||||
|
||||
```bash
|
||||
# P1 — mint from this repo, claude-code host, RuView vertical
|
||||
npx metaharness ruview --template vertical:ruview --host claude-code \
|
||||
--from-existing . --description "RuView WiFi-sensing operator agent" \
|
||||
--target ./harness/ruview
|
||||
|
||||
# readiness + fit/cost/safety scorecards (ADR-041) — gate before publish
|
||||
npx metaharness genome . # 7-section repo readiness
|
||||
npx metaharness score . --json # fit / cost / safety
|
||||
npx metaharness analyze . # recommended harness plan (no-exec)
|
||||
```
|
||||
|
||||
The minted harness is committed under `harness/ruview/` and **pinned** (kernel +
|
||||
host-adapter + router versions locked) so `npx ruview` is reproducible. Re-minting on
|
||||
a CLI/ADR change is a reviewed PR, not an implicit regeneration.
|
||||
|
||||
### Distribution: `npx ruview`
|
||||
|
||||
A small published package whose `bin` boots the pinned harness via the kernel:
|
||||
|
||||
- **Preferred name:** `ruview` (currently **free** on npm — verified 2026-06-17).
|
||||
- **Risk:** npm's typosquat filter may reject `ruview` as too close to `review` /
|
||||
`preview` (this session hit exactly that on `ruvn`→`levn`/`raven` and
|
||||
`worldgraph`→`world-graph`). **Fallback:** publish scoped `@ruvnet/ruview` (also
|
||||
free) and/or `npx ruvnet/ruview` straight from GitHub. Decide at publish time;
|
||||
do not unpublish to rename (the 24-h name-lock lesson from `worldgraphs`).
|
||||
- `bin: { "ruview": "bin/cli.js" }` — note **`bin/cli.js`, not `./bin/cli.js`** (npm
|
||||
strips the `./` form; this broke `ruvn@0.1.0` this session).
|
||||
- `npx ruview` with no args → `onboard` skill (interactive path picker).
|
||||
`npx ruview <skill> [...]` → run a specific skill. `npx ruview --host codex` →
|
||||
install the harness into an existing repo for that host.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
npx ruview (thin bin — boots the pinned harness)
|
||||
│
|
||||
@metaharness/kernel (wasm primary · NAPI-RS native fallback)
|
||||
├── host adapter ── claude-code | codex | opencode | copilot | pi-dev | hermes | rvm | github-actions
|
||||
├── @metaharness/router (k-NN cost-optimal model routing — DRACO P2 / ADR-026)
|
||||
└── RuView genome (pinned)
|
||||
├── skills onboard · provision-node · calibrate-room · train-pose · verify
|
||||
├── mcp tools ruview.node.* · ruview.calibrate · ruview.room.watch · ruview.verify · ruview.claim.check
|
||||
└── guardrails MEASURED-vs-CLAIMED · mean-pose baseline · no-unvalidated-firmware-claims
|
||||
│
|
||||
RuView assets (the real system the agent drives)
|
||||
├── wifi-densepose CLI calibrate / enroll / train-room / room-watch
|
||||
├── sensing-server :3000 / :5005 / :8765
|
||||
├── ESP-IDF subprocess build / flash / provision / monitor (COM8/COM9/COM12)
|
||||
└── witness bundle + verify.py
|
||||
```
|
||||
|
||||
Provenance: the harness ships an **ed25519 witness + SBOM (SPDX) + SLSA** chain
|
||||
(MetaHarness already does this for minted harnesses), so a recipient can verify the
|
||||
RuView harness was built from a specific monorepo commit — the agentic analogue of
|
||||
the firmware witness bundle (ADR-028).
|
||||
|
||||
## Phases
|
||||
|
||||
- **P1 — Mint & pin (claude-code).** `npx metaharness ruview --template
|
||||
vertical:ruview --from-existing . --host claude-code`. Port the three `ruview-*`
|
||||
subagents into host-neutral genome skills. Commit under `harness/ruview/`, pin
|
||||
versions. Acceptance: `npx metaharness score .` ≥ threshold; the harness can run
|
||||
`onboard` and `verify` end-to-end locally.
|
||||
- **P2 — MCP tool surface.** Implement the `ruview.*` MCP tools over the kernel
|
||||
(start with `onboard`, `verify`, `claim.check`, `node.monitor` — the read-only /
|
||||
proving tools), then the mutating ones (`node.flash`, `provision`, `calibrate`).
|
||||
Acceptance: `ruview.verify` returns the witness bundle PASS as structured JSON;
|
||||
`ruview.claim.check` flags a seeded untagged "100% accuracy" string.
|
||||
- **P3 — Publish `npx ruview` + multi-host.** Publish the bin package (name decision
|
||||
per Distribution). Add codex / opencode / copilot / pi-dev / hermes / rvm /
|
||||
github-actions adapters. Acceptance: `npx ruview` cold-starts on ≥3 hosts and runs
|
||||
`onboard`; provenance verifies.
|
||||
- **P4 — Router + guardrail hardening.** Wire `@metaharness/router`; calibrate the
|
||||
3-tier routing on a RuView task set. Make the MEASURED-vs-CLAIMED guardrail a hard
|
||||
pre-output gate. Acceptance: a benchmark of RuView tasks shows cost reduction vs
|
||||
all-Opus with no quality regression; the guardrail blocks an untagged accuracy
|
||||
claim in a red-team prompt.
|
||||
|
||||
## Consequences
|
||||
|
||||
**Positive**
|
||||
- One reproducible, signed entry point (`npx ruview`) that operates RuView on the
|
||||
host the user already has — onboarding goes from "clone a 15-crate monorepo" to a
|
||||
single `npx`.
|
||||
- The "prove everything" ethos becomes **executable**, not just documentation: the
|
||||
harness *enforces* MEASURED-vs-CLAIMED and the mean-pose baseline.
|
||||
- Knowledge written once (host-neutral genome) instead of 9× per host; regenerable
|
||||
from the repo as the system evolves.
|
||||
- Dogfoods MetaHarness on a hard real vertical, surfacing bugs back to
|
||||
`agent-harness-generator` (this session already filed #9–#13 there).
|
||||
|
||||
**Negative / risks**
|
||||
- **Drift:** a pinned harness goes stale as the CLI/ADRs move; mitigated by a
|
||||
re-mint-on-change PR ritual and a CI check that the genome's referenced
|
||||
CLI flags still exist.
|
||||
- **Surface area:** mutating MCP tools (`node.flash`, `provision`) touch hardware and
|
||||
the network — must be permission-gated and fail-closed; the firmware-flash tool
|
||||
must never claim hardware validation without a captured boot log.
|
||||
- **Name/typosquat:** `ruview` may be rejected at publish; scoped fallback decided in
|
||||
P3. Do not unpublish-to-rename.
|
||||
- **Host parity:** not all 9 hosts support MCP + hooks equally; the guardrail gate
|
||||
may degrade to advisory on weaker hosts — must be disclosed in the badge, not
|
||||
hidden (same honesty principle as ADR-181's backend badge).
|
||||
- **Windows-coupled tooling:** the ESP-IDF flow is Windows-subprocess-specific
|
||||
today; the `node.*` tools are gated to that environment until a cross-platform
|
||||
path exists.
|
||||
|
||||
## Alternatives considered
|
||||
|
||||
1. **Keep the `ruview-*` subagents repo-local (status quo).** Zero new surface, but
|
||||
stays Claude-Code-only and clone-gated; no portable front door. Rejected — it's
|
||||
the gap this ADR exists to close.
|
||||
2. **Hand-write a bespoke `npx ruview` harness (no MetaHarness).** Full control, but
|
||||
re-implements the kernel, 9 host adapters, the router, and the provenance chain
|
||||
we already ship — months of duplicated work and 9 divergent configs to maintain.
|
||||
Rejected.
|
||||
3. **Use the generic `vertical:ruview` template as-is.** It's scaffolding with no
|
||||
real tools or guardrails — it would *talk about* RuView without being able to
|
||||
*operate* it or enforce honesty. Rejected as insufficient; P2 is precisely the
|
||||
hardening that makes it real.
|
||||
4. **Ship only an MCP server (no harness/host adapters).** Covers tools but not the
|
||||
skills, routing, guardrails, or multi-host projection — a strictly smaller subset
|
||||
of this design. Folded in as the P2 layer rather than the whole.
|
||||
|
||||
## Open questions
|
||||
|
||||
- Final published name: bare `ruview` vs scoped `@ruvnet/ruview` vs GitHub-only
|
||||
`npx ruvnet/ruview` — resolve against the typosquat filter at P3.
|
||||
- Does the harness bundle the `wifi-densepose` binary, shell out to a user-installed
|
||||
one, or offer both? (Leaning: shell out; print install guidance if absent.)
|
||||
- Where do the `node.*` hardware tools live for non-Windows users — defer, or wrap
|
||||
the rvCSI runtime (ADR-095/096) which is cross-platform Rust?
|
||||
- Should `ruview.verify` gate `npx ruview` self-tests in CI (harness can't publish if
|
||||
the witness bundle regresses)?
|
||||
- Relationship to the RuField MFS harness surface (ADR-260/262) — one harness with a
|
||||
RuField skill, or a sibling `npx rufield`?
|
||||
|
||||
## References
|
||||
|
||||
- MetaHarness: `metaharness@0.1.15` (`npx metaharness`, templates incl.
|
||||
`vertical:ruview`; hosts: claude-code/codex/pi-dev/hermes/openclaw/rvm/copilot/
|
||||
opencode/github-actions), `@metaharness/kernel`, `@metaharness/router`,
|
||||
`@metaharness/host-*`, repo `github.com/ruvnet/agent-harness-generator`.
|
||||
- RuView subagents: `ruview-onboarding-guide`, `ruview-config-engineer`,
|
||||
`ruview-training-engineer` (`.claude/agents/`).
|
||||
- ADR-026 (3-tier model routing), ADR-028 (witness verification), ADR-041
|
||||
(MetaHarness scorecards), ADR-060 (channel / MAC-filter overrides), ADR-079
|
||||
(camera ground-truth training), ADR-095/096 (rvCSI runtime), ADR-151 (per-room
|
||||
calibration), ADR-152/181 (WiFlow / browser pose), ADR-260/262 (RuField bridge).
|
||||
@@ -0,0 +1,98 @@
|
||||
# ADR-183: Onboard LED as a 40 Hz Gamma Stimulus, Colour-Mapped from Live CSI via `ruv-neural-viz`
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Accepted — implemented & hardware-confirmed on ESP32-S3 N16R8 (COM8) |
|
||||
| **Date** | 2026-06-17 |
|
||||
| **Deciders** | ruv |
|
||||
| **Codename** | **GAMMA-VIZ** |
|
||||
| **Builds on** | `ruv-neural-viz::ColorMap` (now `no_std` — ruvnet/ruv-neural#3 / RuView#1126), the ESP32 edge `motion_energy` metric (`edge_processing.c`), PR #962 (WS2812 on GPIO 48) |
|
||||
|
||||
## Context
|
||||
|
||||
Two threads converged. (1) `ruv-neural-viz::ColorMap` — the viridis/cool-warm
|
||||
palette the rUv-Neural stack uses to render brain-topology graphs — was `std`-only,
|
||||
so it couldn't run on the ESP32. (2) The onboard WS2812 on the S3 CSI node was dead
|
||||
weight: the firmware only cleared it on boot (and on the wrong pin for N16R8 — GPIO
|
||||
38 vs the actual 48, see #962).
|
||||
|
||||
The ask: make the LED do something real and honest, using the project's own visual
|
||||
capability — not a decorative blink. The natural fit is a **40 Hz gamma stimulus**
|
||||
(the GENUS gamma-entrainment frequency from Alzheimer's light-therapy research)
|
||||
whose **colour is driven by live sensed motion**, so the node's front panel is both
|
||||
a known bio-stimulus waveform and a truthful readout of what the CSI is detecting.
|
||||
|
||||
## Decision
|
||||
|
||||
### Part A — make `ColorMap` `no_std`
|
||||
|
||||
`colormap.rs` is self-contained (no cross-crate deps), so expose it on `no_std`
|
||||
targets. The only blockers were two `std`-only `f64` ops:
|
||||
|
||||
- `f64::round` / `f64::abs` → replaced with `core`+`alloc`-safe helpers `fround`
|
||||
(round via `f64 as i64` truncation — a `core` cast, no `libm`) and `fabs`.
|
||||
- `Vec`/`String`/`format!` → from `alloc`.
|
||||
|
||||
The graph-bound modules (`animation`/`ascii`/`export`/`layout`) and their heavy deps
|
||||
move behind a default `std` feature; `--no-default-features` builds the crate `no_std`
|
||||
and exposes only `colormap`. Output is **byte-identical** (8/8 colormap tests pass with
|
||||
the same RGB values), so this is a pure portability change.
|
||||
|
||||
### Part B — the LED stimulus (firmware)
|
||||
|
||||
`firmware/esp32-csi-node/main/main.c`, on boot:
|
||||
|
||||
- WS2812 on **GPIO 48** (N16R8 / DevKitC-1 v1.1; GPIO 8 on C6).
|
||||
- An `esp_timer` periodic at **12 500 µs toggles a square wave → 40 Hz, 50 % duty**
|
||||
(full-on / full-off — a *perceptible* gamma flicker, not a colour drift).
|
||||
- **ON-phase colour = live CSI motion.** Each ON phase reads `edge_get_vitals().motion_energy`,
|
||||
normalises it (`/ LED_MOTION_FULLSCALE`, clamped `[0,1]`), and indexes a **60-step
|
||||
viridis LUT generated from `ColorMap::viridis().map()`** — still = dark purple,
|
||||
strong motion = yellow.
|
||||
|
||||
The LUT is baked from the real crate (Part A makes the same `ColorMap` embeddable
|
||||
for a future direct FFI path once the ESP Rust toolchain is in CI). The colours are
|
||||
therefore provably `ruv-neural-viz`'s, and the motion is provably real.
|
||||
|
||||
## Honesty (what it is and is not)
|
||||
|
||||
- **40 Hz is a real square-wave stimulus** (12.5 ms on / 12.5 ms off), not a label on
|
||||
a colour sweep. It is *not* tied to any measured 40 Hz brain rhythm — it is an
|
||||
*output* stimulus at the gamma frequency, not a readout of neural gamma.
|
||||
- **Colour is a real CSI readout** — `motion_energy` is the on-device phase-variance
|
||||
motion metric the node already computes; no fabrication. At rest the LED sits at the
|
||||
purple (low) end and flickers there.
|
||||
- No therapeutic claim is made. 40 Hz GENUS entrainment is cited as the *origin of the
|
||||
frequency choice*, not as a validated medical effect of this device.
|
||||
|
||||
## Consequences
|
||||
|
||||
**Positive**
|
||||
- The LED is now an honest front-panel: gamma-frequency flicker + a live motion readout.
|
||||
- `ColorMap` is embeddable (`no_std`), unblocking on-device use of the rUv-Neural
|
||||
palette beyond this LED.
|
||||
- Confirms #962's GPIO-48 fix visually (the LED lights on N16R8).
|
||||
|
||||
**Negative / risks**
|
||||
- Changes the *default* firmware behaviour: the onboard LED animates instead of staying
|
||||
off. Now **gated by `CONFIG_LED_GAMMA_VIZ`** (default `y`); set it `n` for a dark,
|
||||
lower-power boot (the LED is just cleared) — no source change needed.
|
||||
- A 40 Hz flicker can be an issue for photosensitive users; document on the enclosure
|
||||
and disable `CONFIG_LED_GAMMA_VIZ` in those deployments.
|
||||
- The saturation point is now `CONFIG_LED_MOTION_FULLSCALE_MILLI` (default 250 = 0.25),
|
||||
operator-tunable; still not auto-calibrated per-environment.
|
||||
- The colour uses a baked LUT, not the live Rust `ColorMap` (FFI path deferred — needs
|
||||
the ESP Rust/xtensa toolchain, not yet in CI).
|
||||
|
||||
## Validation
|
||||
|
||||
- `ruv-neural-viz`: `cargo build` (std) ✓, `cargo test colormap` 8/8 ✓ (identical RGB),
|
||||
`cargo build --no-default-features` compiles `no_std` ✓.
|
||||
- Firmware: built (1.13 MB), flashed to ESP32-S3 N16R8 (COM8). Boot log:
|
||||
`Onboard WS2812: 40 Hz gamma flicker (GENUS), colour=CSI motion via ruv-neural-viz, GPIO 48`;
|
||||
CSI continues (27–38 pps), `motion=0.00` at rest → purple flicker as designed.
|
||||
- Full on-device (xtensa) Rust build of `ColorMap` not run — ESP Rust toolchain absent.
|
||||
|
||||
## References
|
||||
- ruvnet/ruv-neural#3 (ColorMap no_std), RuView#1126 (submodule bump), #962 (GPIO 48).
|
||||
- Singer/Tsai GENUS 40 Hz gamma entrainment (origin of the frequency, not a device claim).
|
||||
@@ -0,0 +1,135 @@
|
||||
# WiFlow Browser Trainer (`wiflow_browser.html`)
|
||||
|
||||
A **single self-contained HTML page** that does the entire camera-supervised
|
||||
WiFi-pose loop **in your browser, in your laptop camera's coordinate frame**, as
|
||||
a **4-stage gated flow** with a progress stepper (each stage unlocks the next):
|
||||
|
||||
0. **CALIBRATE** *(ADR-151 empty-room baseline)* — you step OUT of the space; the
|
||||
page captures ~10 s of the quiescent CSI and computes a per-feature running
|
||||
**mean + std (Welford)** over the 410-d vector. Every CSI vector afterwards is
|
||||
expressed as **deviation from baseline**
|
||||
(`x_norm = (x − base_mean) / (base_std + ε)`), so a body's perturbation stands
|
||||
out from the static channel. Persisted to IndexedDB. *Can't capture without it.*
|
||||
1. **CAPTURE** — MediaPipe Pose runs on your laptop camera → 17 COCO keypoints
|
||||
(the *label*), paired with the **baseline-normalized** 410-d ESP32 CSI vector
|
||||
(the *input*). A **guided, balanced routine** cycles big on-screen prompts
|
||||
(stand / turn / walk / arms / crouch / sit / reach) with a countdown, and a
|
||||
**per-pose coverage meter** so you build a balanced dataset, not 2 000 frames
|
||||
of standing.
|
||||
2. **TRAIN** — a TensorFlow.js MLP learns `CSI → pose` in-browser. Honest
|
||||
held-out PCK@0.10 / PCK@0.05 / MPJPE, plus a **mean-pose baseline** the model
|
||||
must beat (the project's whole ethos — no baseline-beating signal, it says so).
|
||||
*Can't train with <200 samples.*
|
||||
3. **INFER** — the trained model drives a skeleton **from WiFi CSI only**
|
||||
(baseline-normalized → standardized → model), drawn over the **same** camera
|
||||
frame it trained in — so the inferred skeleton **aligns** with the camera
|
||||
image. That alignment is the entire point of doing this in-browser instead of
|
||||
with a separate Python camera. *Can't infer without a model.*
|
||||
|
||||
## Why in-browser
|
||||
|
||||
The Python pipeline (`wiflow_capture.py` → `wiflow_train.py` → `wiflow_infer.py`)
|
||||
proved the signal is real (held-out PCK@0.10 ≈ 59.5% vs a 50% mean-pose baseline
|
||||
= +9.4 pp). But it trained in a *different* camera's frame, so the inferred
|
||||
skeleton never lined up with the laptop camera. Doing capture + train + infer all
|
||||
in the browser with the **same** camera makes the training frame and the
|
||||
inference frame identical → the skeleton aligns.
|
||||
|
||||
## Compute backends (WebGPU / WASM / WebGL)
|
||||
|
||||
Training and inference run on TensorFlow.js. The page selects the backend at
|
||||
startup, preferring the fastest available:
|
||||
|
||||
- **WebGPU** (Chrome / Edge, secure context — `localhost` qualifies) — GPU compute.
|
||||
- **WASM-SIMD** fallback (`tfjs-backend-wasm`, SIMD enabled, `.wasm` from the CDN).
|
||||
- **WebGL** last-resort fallback (ships inside tfjs core).
|
||||
|
||||
The **active backend is shown as a badge in the header** (`compute: WebGPU` /
|
||||
`WASM-SIMD` / `WebGL`) so it's honest about what's actually running. The model
|
||||
code is backend-agnostic — tf.js abstracts the device.
|
||||
|
||||
## Honesty (baked in)
|
||||
|
||||
- The **CAPTURE** skeleton (blue) is the camera = ground truth, labeled as such.
|
||||
- The **INFER** skeleton (green) is **CSI-only**, labeled, and **coarse** — the
|
||||
real measured held-out PCK is shown, not a marketing number.
|
||||
- The **mean-pose baseline** is always computed and shown in TRAIN; the verdict
|
||||
states plainly whether the model **beats** it (real signal) or **does not**
|
||||
(no usable signal). This guards against the project's retracted 92.9% that
|
||||
failed exactly this check.
|
||||
- Status banner is strict and mutually exclusive:
|
||||
**LIVE** (real `source: "esp32"`) / **SIMULATED — not real** (any other source)
|
||||
/ **NO-CSI-SERVER**. The page never invents frames.
|
||||
|
||||
## How to run
|
||||
|
||||
### 1. Start the real sensing-server (provides the CSI WebSocket on :8765)
|
||||
|
||||
```bash
|
||||
cd v2
|
||||
cargo build -p wifi-densepose-sensing-server
|
||||
./target/debug/sensing-server.exe --ws-port 8765 --udp-port 5005
|
||||
```
|
||||
|
||||
A real ESP32-S3 must be provisioned and streaming for `source` to read `esp32`
|
||||
(see `CLAUDE.local.md` for the firmware build/provision steps). The page expects
|
||||
the verified live endpoint **`ws://localhost:8765/ws/sensing`** with
|
||||
`source:"esp32"`, nodes `[9, 13]`, `features.*`, `node_features[].features.*`,
|
||||
and `signal_field.values` (400 floats).
|
||||
|
||||
### 2. Serve this page over localhost (camera + WebGPU need a localhost/secure origin)
|
||||
|
||||
Any static localhost server works. For example:
|
||||
|
||||
```bash
|
||||
python -m http.server 8099
|
||||
# then open: http://localhost:8099/examples/through-wall/wiflow_browser.html
|
||||
```
|
||||
|
||||
(8099 is just the static file server — 8765 is a separate process, the CSI
|
||||
WebSocket.) Allow camera access when the browser prompts.
|
||||
|
||||
Point at a CSI server on another host with `?ws=`:
|
||||
|
||||
```
|
||||
http://localhost:8099/examples/through-wall/wiflow_browser.html?ws=ws://192.168.1.20:8765/ws/sensing
|
||||
```
|
||||
|
||||
### 3. Use it
|
||||
|
||||
1. **CAPTURE** tab → *enable laptop camera* → *start recording*. Follow the guided
|
||||
routine (stand / turn / walk / arms / crouch / sit). A pair is stored only when
|
||||
a confident pose AND a fresh live `esp32` CSI frame coexist. Aim for a few
|
||||
thousand samples. Samples persist in IndexedDB across refreshes.
|
||||
2. **TRAIN** tab → *train model*. Watch the live loss curve, held-out PCK, and the
|
||||
baseline verdict. The model saves to IndexedDB.
|
||||
3. **INFER** tab → the green skeleton is now driven by WiFi CSI only, aligned over
|
||||
your camera. Toggle *hide camera* to see the CSI-only skeleton on black.
|
||||
|
||||
## The 410-d CSI vector (matches the Python pipeline exactly)
|
||||
|
||||
```
|
||||
[ mean_rssi, variance, motion_band_power, breathing_band_power ] # 4 (features.*)
|
||||
+ for node 9 then node 13: [ mean_rssi, variance, motion_band_power ] # 6 (node_features[].features.*)
|
||||
+ signal_field.values, padded / truncated to 400 # 400
|
||||
= 410-d
|
||||
```
|
||||
|
||||
Verified against a real live frame: the in-browser `csiVector()` produces the
|
||||
identical 410 vector as `wiflow_capture.py`'s `csi_vector()` (node 9 first, then
|
||||
node 13; field zero-padded).
|
||||
|
||||
## Libraries (CDN only, no bundler)
|
||||
|
||||
| Library | CDN |
|
||||
|---|---|
|
||||
| TensorFlow.js core | `@tensorflow/tfjs@4.22.0/dist/tf.min.js` |
|
||||
| TF.js WebGPU backend | `@tensorflow/tfjs-backend-webgpu@4.22.0/dist/tf-backend-webgpu.min.js` |
|
||||
| TF.js WASM backend | `@tensorflow/tfjs-backend-wasm@4.22.0/dist/tf-backend-wasm.min.js` |
|
||||
| MediaPipe Pose 0.5 (legacy solutions) | `@mediapipe/pose@0.5/pose.js` |
|
||||
|
||||
## Scope / honesty caveats
|
||||
|
||||
Same person, same room, same session. **Not** validated cross-day, cross-room, or
|
||||
through-wall. The inferred pose is coarse (PCK@0.05 is typically weak). If the
|
||||
model does not beat the mean-pose baseline, the page says so — that is a feature.
|
||||
@@ -0,0 +1,644 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>RuView · Through-Wall WiFi Sensing · LIVE CSI (no skeleton, no simulation)</title>
|
||||
<!--
|
||||
THROUGH-WALL WiFi-CSI SENSING DEMO — honest, real-data-only.
|
||||
|
||||
Renders ONLY what the running sensing-server actually streams over
|
||||
ws://localhost:8765/ws/sensing :
|
||||
- the 20x20 `signal_field` floor heatmap (real values)
|
||||
- a coarse RF-localization puck from persons[0].position (NOT pose)
|
||||
- live motion / presence / rssi / confidence meters
|
||||
- the real `source` ("esp32" = LIVE) verbatim in the banner
|
||||
|
||||
It deliberately does NOT draw a skeleton. The server's
|
||||
persons[].keypoints carry confidence:0.0 (image-pixel garbage, not
|
||||
real 3D joints) so we never render them. WiFi CSI gives
|
||||
motion/presence/coarse-position — that is the honest wow, and it
|
||||
penetrates drywall. See README.md.
|
||||
-->
|
||||
<style>
|
||||
:root {
|
||||
--bg: #050507; --bg-panel: rgba(8,10,14,0.80);
|
||||
--amber: #ffb840; --amber-hot: #ffe09f;
|
||||
--cyan: #4cf; --magenta: #ff4cc8;
|
||||
--text: #d8c69a; --text-mute: #6b6155;
|
||||
--green: #4f4; --red: #f64;
|
||||
--border: rgba(255,184,64,0.18);
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
margin: 0; background: var(--bg); color: var(--text); overflow: hidden;
|
||||
font-family: 'SF Mono', 'Cascadia Code', Consolas, monospace;
|
||||
-webkit-font-smoothing: antialiased; font-size: 12px;
|
||||
}
|
||||
canvas { display: block; }
|
||||
.overlay-frame {
|
||||
position: fixed; inset: 0; pointer-events: none; z-index: 5;
|
||||
background:
|
||||
radial-gradient(ellipse at center, transparent 55%, rgba(0,0,0,0.55) 100%),
|
||||
linear-gradient(180deg, rgba(0,0,0,0.32) 0%, transparent 18%, transparent 82%, rgba(0,0,0,0.38) 100%);
|
||||
}
|
||||
.scanlines {
|
||||
position: fixed; inset: 0; pointer-events: none; z-index: 6;
|
||||
background: repeating-linear-gradient(0deg, rgba(0,0,0,0.04) 0px, rgba(0,0,0,0.04) 1px, transparent 1px, transparent 3px);
|
||||
mix-blend-mode: overlay; opacity: 0.5;
|
||||
}
|
||||
.panel {
|
||||
position: absolute; background: var(--bg-panel); border: 1px solid var(--border);
|
||||
border-radius: 4px; padding: 12px 14px; backdrop-filter: blur(8px);
|
||||
box-shadow: 0 1px 0 rgba(255,184,64,0.04), 0 8px 32px rgba(0,0,0,0.55); z-index: 10;
|
||||
}
|
||||
.panel h2 {
|
||||
margin: 0 0 8px 0; font-size: 10px; text-transform: uppercase; letter-spacing: 2px;
|
||||
color: var(--amber); font-weight: 600; border-bottom: 1px solid var(--border); padding-bottom: 6px;
|
||||
}
|
||||
|
||||
/* ---- Honest status banner (top-center, mutually exclusive states) ---- */
|
||||
#banner {
|
||||
position: fixed; top: 0; left: 0; right: 0; z-index: 30;
|
||||
text-align: center; padding: 7px 12px; font-size: 12px; letter-spacing: 1px;
|
||||
font-weight: 600; border-bottom: 1px solid rgba(0,0,0,0.4);
|
||||
transition: background 0.3s, color 0.3s;
|
||||
}
|
||||
#banner.live { background: rgba(40,255,80,0.12); color: var(--green); border-bottom-color: rgba(80,255,120,0.4); }
|
||||
#banner.sim { background: rgba(255,120,40,0.16); color: #ffae5a; border-bottom-color: rgba(255,140,60,0.5); }
|
||||
#banner.noserver { background: rgba(255,80,80,0.16); color: var(--red); border-bottom-color: rgba(255,90,90,0.5); }
|
||||
#banner .src { opacity: 0.8; font-weight: 400; }
|
||||
#banner-caption {
|
||||
position: fixed; top: 30px; left: 0; right: 0; z-index: 29;
|
||||
text-align: center; font-size: 10px; color: var(--text-mute); letter-spacing: 0.5px;
|
||||
pointer-events: none; padding-top: 2px;
|
||||
}
|
||||
|
||||
#info { top: 64px; left: 20px; min-width: 270px; }
|
||||
#info h1 { margin: 0 0 1px 0; font-size: 13px; letter-spacing: 1px; color: var(--amber-hot); font-weight: 600; }
|
||||
#info .sub { font-size: 10px; color: var(--text-mute); letter-spacing: 0.5px; margin-bottom: 10px; padding-bottom: 8px; border-bottom: 1px solid var(--border); }
|
||||
#info .row { display: flex; justify-content: space-between; gap: 12px; padding: 2px 0; }
|
||||
#info .row .k { color: var(--text-mute); font-size: 11px; }
|
||||
#info .row .v { color: var(--text); font-variant-numeric: tabular-nums; font-size: 11px; }
|
||||
#info .row .v.amber { color: var(--amber); }
|
||||
#info .row .v.cyan { color: var(--cyan); }
|
||||
#info .row .v.green { color: var(--green); }
|
||||
#info .row .v.red { color: var(--red); }
|
||||
#info .row .v.mag { color: var(--magenta); }
|
||||
#info .row .v.mute { color: var(--text-mute); }
|
||||
|
||||
#csi { top: 64px; right: 20px; min-width: 270px; }
|
||||
#csi .bar-row { display: flex; align-items: center; gap: 8px; padding: 3px 0; font-size: 10px; }
|
||||
#csi .bar-row .label { width: 86px; color: var(--text-mute); }
|
||||
#csi .bar-row .bar-track { flex: 1; height: 6px; background: rgba(255,184,64,0.08); border-radius: 2px; overflow: hidden; }
|
||||
#csi .bar-row .bar-fill {
|
||||
height: 100%; background: linear-gradient(90deg, var(--amber-hot), var(--amber));
|
||||
box-shadow: 0 0 6px var(--amber); transition: width 0.1s linear;
|
||||
}
|
||||
#csi .bar-row .val { width: 44px; text-align: right; color: var(--amber); font-variant-numeric: tabular-nums; }
|
||||
#csi .spark { margin-top: 8px; }
|
||||
#csi canvas { width: 100%; height: 38px; display: block; border: 1px solid var(--border); border-radius: 3px; background: rgba(0,0,0,0.3); }
|
||||
#csi .legend { margin-top: 8px; padding-top: 8px; border-top: 1px solid var(--border); font-size: 10px; color: var(--text-mute); line-height: 1.5; }
|
||||
|
||||
/* ---- waiting / no-server overlay ---- */
|
||||
#waiting {
|
||||
position: fixed; inset: 0; z-index: 25; display: none;
|
||||
flex-direction: column; align-items: center; justify-content: center;
|
||||
background: rgba(5,5,7,0.94); color: var(--amber); text-align: center; padding: 24px;
|
||||
}
|
||||
#waiting.show { display: flex; }
|
||||
#waiting .big { font-size: 22px; letter-spacing: 2px; color: var(--red); margin-bottom: 16px; text-transform: uppercase; }
|
||||
#waiting code {
|
||||
display: block; text-align: left; max-width: 640px; margin: 8px auto;
|
||||
background: rgba(255,184,64,0.06); border: 1px solid var(--border); border-radius: 4px;
|
||||
padding: 10px 14px; color: var(--amber-hot); font-size: 12px; white-space: pre-wrap;
|
||||
}
|
||||
#waiting .pulse { animation: pulse 1.4s ease-in-out infinite; }
|
||||
@keyframes pulse { 0%,100% { opacity: 0.55; } 50% { opacity: 1; } }
|
||||
|
||||
/* ---- optional webcam ground-truth tile ---- */
|
||||
#cam-tile {
|
||||
position: absolute; bottom: 20px; right: 20px; width: 240px; z-index: 12;
|
||||
background: var(--bg-panel); border: 1px solid var(--border); border-radius: 4px;
|
||||
padding: 8px; backdrop-filter: blur(8px);
|
||||
}
|
||||
#cam-tile h2 { margin: 0 0 6px 0; font-size: 9px; text-transform: uppercase; letter-spacing: 1.5px;
|
||||
color: var(--cyan); font-weight: 600; }
|
||||
#cam-tile .gt-note { font-size: 9px; color: var(--text-mute); margin-top: 4px; line-height: 1.4; }
|
||||
#cam-video { width: 100%; border-radius: 3px; display: none; background: #000; }
|
||||
#cam-tile button {
|
||||
width: 100%; margin-top: 6px; padding: 5px 8px; font-family: inherit; font-size: 11px;
|
||||
background: transparent; color: var(--cyan); border: 1px solid var(--cyan); border-radius: 3px; cursor: pointer;
|
||||
}
|
||||
#cam-tile button:hover { background: rgba(68,204,255,0.12); }
|
||||
#cam-tile button:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
|
||||
#legend-nodes {
|
||||
position: absolute; bottom: 20px; left: 20px; min-width: 220px;
|
||||
background: var(--bg-panel); border: 1px solid var(--border); border-radius: 4px;
|
||||
padding: 12px 14px; backdrop-filter: blur(8px); z-index: 10;
|
||||
}
|
||||
#legend-nodes h2 { margin: 0 0 8px 0; font-size: 10px; text-transform: uppercase; letter-spacing: 2px;
|
||||
color: var(--amber); font-weight: 600; border-bottom: 1px solid var(--border); padding-bottom: 6px; }
|
||||
#legend-nodes .lr { display: flex; align-items: center; gap: 8px; padding: 2px 0; font-size: 11px; }
|
||||
#legend-nodes .dot { width: 9px; height: 9px; border-radius: 50%; box-shadow: 0 0 6px currentColor; flex: 0 0 auto; }
|
||||
#legend-nodes .muted { color: var(--text-mute); }
|
||||
</style>
|
||||
|
||||
<!-- three.js r128 + addons (same CDN set as examples/three.js/demos/05) -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/postprocessing/EffectComposer.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/postprocessing/RenderPass.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/postprocessing/ShaderPass.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/postprocessing/UnrealBloomPass.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/shaders/CopyShader.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/shaders/LuminosityHighPassShader.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div id="banner" class="noserver">NO SERVER — start the sensing-server <span class="src"></span></div>
|
||||
<div id="banner-caption">Real WiFi CSI motion / presence / coarse-localization — penetrates drywall. Not skeletal pose.</div>
|
||||
|
||||
<div class="overlay-frame"></div>
|
||||
<div class="scanlines"></div>
|
||||
|
||||
<div class="panel" id="info">
|
||||
<h1>THROUGH-WALL WiFi SENSING</h1>
|
||||
<div class="sub">Live CSI · ws://localhost:8765/ws/sensing</div>
|
||||
<div class="row"><span class="k">source</span><span class="v amber" id="m-source">—</span></div>
|
||||
<div class="row"><span class="k">presence</span><span class="v" id="m-presence">—</span></div>
|
||||
<div class="row"><span class="k">motion level</span><span class="v" id="m-motion">—</span></div>
|
||||
<div class="row"><span class="k">confidence</span><span class="v cyan" id="m-conf">—</span></div>
|
||||
<div class="row"><span class="k">est. persons</span><span class="v amber" id="m-persons">—</span></div>
|
||||
<div class="row"><span class="k">active nodes</span><span class="v" id="m-nodes">—</span></div>
|
||||
<div class="row"><span class="k">tick</span><span class="v" id="m-tick">—</span></div>
|
||||
<div class="row"><span class="k">update rate</span><span class="v cyan" id="m-fps">—</span></div>
|
||||
</div>
|
||||
|
||||
<div class="panel" id="csi">
|
||||
<h2>Live RF features</h2>
|
||||
<div class="bar-row"><span class="label">motion</span><div class="bar-track"><div class="bar-fill" id="bar-motion"></div></div><span class="val" id="v-motion">—</span></div>
|
||||
<div class="bar-row"><span class="label">breathing</span><div class="bar-track"><div class="bar-fill" id="bar-breath"></div></div><span class="val" id="v-breath">—</span></div>
|
||||
<div class="bar-row"><span class="label">variance</span><div class="bar-track"><div class="bar-fill" id="bar-var"></div></div><span class="val" id="v-var">—</span></div>
|
||||
<div class="bar-row"><span class="label">mean rssi</span><div class="bar-track"><div class="bar-fill" id="bar-rssi"></div></div><span class="val" id="v-rssi">—</span></div>
|
||||
<div class="spark"><canvas id="spark" width="252" height="38"></canvas></div>
|
||||
<div class="legend">motion sparkline (last ~6s of real motion_band_power)</div>
|
||||
</div>
|
||||
|
||||
<div id="legend-nodes">
|
||||
<h2>Sensor nodes</h2>
|
||||
<div class="lr"><span class="dot" style="color:#4cf"></span><span>ESP32-S3 office <span class="muted">(node 9)</span></span></div>
|
||||
<div class="lr"><span class="dot" style="color:#ff4cc8"></span><span>ESP32-S3 hallway <span class="muted">(node 13)</span></span></div>
|
||||
<div class="lr" style="margin-top:6px"><span class="dot" style="color:#4f4"></span><span>RF localization <span class="muted">(coarse)</span></span></div>
|
||||
<div class="lr"><span class="muted" style="font-size:10px;line-height:1.4">Office & hallway split by a wall + doorway. WiFi motion still shows through drywall.</span></div>
|
||||
</div>
|
||||
|
||||
<div id="cam-tile">
|
||||
<h2>camera — ground truth when visible</h2>
|
||||
<video id="cam-video" autoplay muted playsinline></video>
|
||||
<button id="cam-btn">▶ enable webcam (optional)</button>
|
||||
<div class="gt-note">Independent of the CSI sensing. The WiFi works in the dark and through walls; the camera does not.</div>
|
||||
</div>
|
||||
|
||||
<div id="waiting" class="show">
|
||||
<div class="big pulse">Waiting for live sensing-server</div>
|
||||
<div>No connection to <b>ws://localhost:8765/ws/sensing</b>. Start the real server, then this page connects automatically.</div>
|
||||
<code>cd v2
|
||||
cargo build -p wifi-densepose-sensing-server
|
||||
./target/debug/sensing-server.exe --ws-port 8765 --udp-port 5005</code>
|
||||
<div style="margin-top:10px; color:var(--text-mute); font-size:11px;">This demo renders ONLY real data. It never invents frames.</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
"use strict";
|
||||
// =====================================================================
|
||||
// Config + WS endpoint (allow ?ws= override)
|
||||
// =====================================================================
|
||||
const params = new URLSearchParams(location.search);
|
||||
const WS_URL = params.get('ws') || 'ws://localhost:8765/ws/sensing';
|
||||
const ROOM_HALF = 5; // half-extent of the floor plane in metres
|
||||
const GRID_N = 20; // signal_field is 20 x 20
|
||||
|
||||
// Known node anchor positions (server sends node 9 @ [2,0,1.5]; node 13
|
||||
// joins later from the hallway side once its firmware is flashed). These
|
||||
// are anchors for the room model + labels, NOT fabricated sensing data.
|
||||
const NODE_ANCHORS = {
|
||||
9: { pos: [ 2.0, 0.0, 1.5], color: 0x44ccff, label: 'office (node 9)' },
|
||||
13: { pos: [-2.0, 0.0, -3.0], color: 0xff4cc8, label: 'hallway (node 13)' },
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Three.js scene (reused pattern from demos/05-skinned-realtime.html)
|
||||
// =====================================================================
|
||||
const scene = new THREE.Scene();
|
||||
scene.background = new THREE.Color(0x050507);
|
||||
scene.fog = new THREE.FogExp2(0x050507, 0.045);
|
||||
|
||||
const camera = new THREE.PerspectiveCamera(50, window.innerWidth/window.innerHeight, 0.05, 100);
|
||||
camera.position.set(4.5, 4.2, 6.0);
|
||||
|
||||
const renderer = new THREE.WebGLRenderer({ antialias: true, powerPreference: 'high-performance' });
|
||||
renderer.setPixelRatio(Math.min(2, window.devicePixelRatio));
|
||||
renderer.setSize(window.innerWidth, window.innerHeight);
|
||||
renderer.toneMapping = THREE.ACESFilmicToneMapping;
|
||||
renderer.toneMappingExposure = 0.85;
|
||||
renderer.outputEncoding = THREE.sRGBEncoding;
|
||||
document.body.appendChild(renderer.domElement);
|
||||
|
||||
const controls = new THREE.OrbitControls(camera, renderer.domElement);
|
||||
controls.target.set(0, 0.4, -0.5);
|
||||
controls.enableDamping = true; controls.dampingFactor = 0.06;
|
||||
controls.minDistance = 3; controls.maxDistance = 18;
|
||||
controls.maxPolarAngle = Math.PI * 0.49;
|
||||
|
||||
scene.add(new THREE.HemisphereLight(0x553a18, 0x080606, 0.7));
|
||||
const keyLight = new THREE.DirectionalLight(0xffc070, 0.9);
|
||||
keyLight.position.set(3, 6, 4);
|
||||
scene.add(keyLight);
|
||||
|
||||
// Post-processing — gentle bloom so the heatmap + puck glow.
|
||||
const composer = new THREE.EffectComposer(renderer);
|
||||
composer.addPass(new THREE.RenderPass(scene, camera));
|
||||
const bloom = new THREE.UnrealBloomPass(
|
||||
new THREE.Vector2(window.innerWidth, window.innerHeight), 0.55, 0.45, 0.82);
|
||||
composer.addPass(bloom);
|
||||
|
||||
// =====================================================================
|
||||
// Room: floor grid + wall + doorway dividing office / hallway
|
||||
// =====================================================================
|
||||
const gridHelper = new THREE.GridHelper(2*ROOM_HALF, GRID_N, 0x554a32, 0x2a2418);
|
||||
gridHelper.position.y = 0.002;
|
||||
scene.add(gridHelper);
|
||||
|
||||
// Dividing wall runs along world X near z = -1 (office z>-1, hallway z<-1),
|
||||
// with a doorway gap. Two wall segments leave a gap in the middle.
|
||||
const wallMat = new THREE.MeshStandardMaterial({
|
||||
color: 0x1b2330, transparent: true, opacity: 0.55, roughness: 0.9,
|
||||
side: THREE.DoubleSide,
|
||||
});
|
||||
const wallH = 1.4, wallZ = -1.0;
|
||||
function addWallSeg(cx, w) {
|
||||
const m = new THREE.Mesh(new THREE.BoxGeometry(w, wallH, 0.08), wallMat);
|
||||
m.position.set(cx, wallH/2, wallZ);
|
||||
scene.add(m);
|
||||
// top edge highlight
|
||||
const edge = new THREE.Mesh(new THREE.BoxGeometry(w, 0.02, 0.10),
|
||||
new THREE.MeshBasicMaterial({ color: 0x4cf, transparent: true, opacity: 0.5 }));
|
||||
edge.position.set(cx, wallH, wallZ);
|
||||
scene.add(edge);
|
||||
}
|
||||
// left segment, doorway gap (-0.7..0.7), right segment
|
||||
addWallSeg(-3.15, 3.7);
|
||||
addWallSeg( 3.15, 3.7);
|
||||
|
||||
// Room labels (sprite text) for OFFICE / HALLWAY
|
||||
function makeLabel(text, color) {
|
||||
const c = document.createElement('canvas'); c.width = 256; c.height = 64;
|
||||
const ctx = c.getContext('2d');
|
||||
ctx.fillStyle = color; ctx.font = 'bold 30px Consolas, monospace';
|
||||
ctx.textAlign = 'center'; ctx.textBaseline = 'middle';
|
||||
ctx.fillText(text, 128, 34);
|
||||
const tex = new THREE.CanvasTexture(c);
|
||||
const spr = new THREE.Sprite(new THREE.SpriteMaterial({ map: tex, transparent: true, depthTest: false }));
|
||||
spr.scale.set(2.0, 0.5, 1);
|
||||
return spr;
|
||||
}
|
||||
const officeLbl = makeLabel('OFFICE', '#ffb840'); officeLbl.position.set(2.6, 0.06, 2.6); scene.add(officeLbl);
|
||||
const hallLbl = makeLabel('HALLWAY', '#ff4cc8'); hallLbl.position.set(-2.6, 0.06, -3.2); scene.add(hallLbl);
|
||||
|
||||
// =====================================================================
|
||||
// Node markers (office / hallway). The hallway node is dimmed until it
|
||||
// actually appears in the live `nodes` list.
|
||||
// =====================================================================
|
||||
const nodeMeshes = {};
|
||||
function buildNode(id) {
|
||||
const a = NODE_ANCHORS[id];
|
||||
const g = new THREE.Group();
|
||||
const post = new THREE.Mesh(
|
||||
new THREE.CylinderGeometry(0.05, 0.07, 0.9, 12),
|
||||
new THREE.MeshStandardMaterial({ color: a.color, emissive: a.color, emissiveIntensity: 0.4, roughness: 0.4 }));
|
||||
post.position.y = 0.45; g.add(post);
|
||||
const orb = new THREE.Mesh(
|
||||
new THREE.SphereGeometry(0.12, 20, 16),
|
||||
new THREE.MeshBasicMaterial({ color: a.color }));
|
||||
orb.position.y = 0.95; g.add(orb);
|
||||
const ring = new THREE.Mesh(
|
||||
new THREE.RingGeometry(0.18, 0.24, 32),
|
||||
new THREE.MeshBasicMaterial({ color: a.color, transparent: true, opacity: 0.6, side: THREE.DoubleSide }));
|
||||
ring.rotation.x = -Math.PI/2; ring.position.y = 0.01; g.add(ring);
|
||||
const lbl = makeLabel('ESP32-S3 ' + a.label, '#' + a.color.toString(16).padStart(6,'0'));
|
||||
lbl.scale.set(2.6, 0.65, 1); lbl.position.set(0, 1.25, 0); g.add(lbl);
|
||||
g.position.set(a.pos[0], 0, a.pos[2]);
|
||||
g.userData.parts = { post, orb, ring };
|
||||
scene.add(g);
|
||||
return g;
|
||||
}
|
||||
Object.keys(NODE_ANCHORS).forEach(id => { nodeMeshes[id] = buildNode(+id); });
|
||||
function setNodeActive(id, active) {
|
||||
const g = nodeMeshes[id]; if (!g) return;
|
||||
const o = active ? 1.0 : 0.22;
|
||||
const parts = g.userData.parts;
|
||||
parts.orb.material.opacity = o; parts.orb.material.transparent = true;
|
||||
parts.ring.material.opacity = 0.6 * o;
|
||||
parts.post.material.emissiveIntensity = active ? 0.5 : 0.12;
|
||||
}
|
||||
setNodeActive(9, false); setNodeActive(13, false);
|
||||
|
||||
// =====================================================================
|
||||
// signal_field 20x20 floor heatmap — instanced colored tiles.
|
||||
// Driven ONLY by real `signal_field.values` (400 floats ~0..1).
|
||||
// =====================================================================
|
||||
const TILE = (2*ROOM_HALF) / GRID_N;
|
||||
const heatGeo = new THREE.PlaneGeometry(TILE * 0.96, TILE * 0.96);
|
||||
const heatMat = new THREE.MeshBasicMaterial({ vertexColors: true, transparent: true, opacity: 0.85, side: THREE.DoubleSide });
|
||||
const heatMesh = new THREE.InstancedMesh(heatGeo, heatMat, GRID_N * GRID_N);
|
||||
heatMesh.instanceMatrix.setUsage(THREE.DynamicDrawUsage);
|
||||
const heatColor = new THREE.InstancedBufferAttribute(new Float32Array(GRID_N * GRID_N * 3), 3);
|
||||
heatMesh.instanceColor = heatColor;
|
||||
const _m = new THREE.Matrix4();
|
||||
const _q = new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(1,0,0), -Math.PI/2);
|
||||
const _s = new THREE.Vector3(1,1,1);
|
||||
const _p = new THREE.Vector3();
|
||||
// gridCell (gx,gz) -> world (x,z). gx,gz in [0,GRID_N).
|
||||
function cellToWorld(gx, gz) {
|
||||
return [ (gx + 0.5) * TILE - ROOM_HALF, (gz + 0.5) * TILE - ROOM_HALF ];
|
||||
}
|
||||
for (let gz = 0; gz < GRID_N; gz++) {
|
||||
for (let gx = 0; gx < GRID_N; gx++) {
|
||||
const i = gz * GRID_N + gx;
|
||||
const [wx, wz] = cellToWorld(gx, gz);
|
||||
_p.set(wx, 0.012, wz);
|
||||
_m.compose(_p, _q, _s);
|
||||
heatMesh.setMatrixAt(i, _m);
|
||||
heatColor.setXYZ(i, 0.02, 0.02, 0.03);
|
||||
}
|
||||
}
|
||||
heatMesh.instanceMatrix.needsUpdate = true;
|
||||
scene.add(heatMesh);
|
||||
|
||||
// amber→white heat ramp for a value in [0,1]
|
||||
function heatRamp(v, out) {
|
||||
v = Math.max(0, Math.min(1, v));
|
||||
// dark -> amber -> hot white
|
||||
const r = Math.min(1, 0.05 + 1.6 * v);
|
||||
const g = Math.min(1, 0.02 + 1.1 * v * v);
|
||||
const b = Math.min(1, 0.04 + 0.9 * Math.pow(v, 3));
|
||||
out.set(r, g, b);
|
||||
return out;
|
||||
}
|
||||
const _c = new THREE.Color();
|
||||
let lastFieldPeak = { gx: GRID_N/2|0, gz: GRID_N/2|0, v: 0 };
|
||||
function updateHeatmap(field) {
|
||||
if (!field || !Array.isArray(field.values)) return;
|
||||
const vals = field.values;
|
||||
// grid_size is [20,1,20]; values are row-major 400 floats.
|
||||
let peakV = -1, peakGx = lastFieldPeak.gx, peakGz = lastFieldPeak.gz;
|
||||
const n = Math.min(vals.length, GRID_N * GRID_N);
|
||||
for (let i = 0; i < n; i++) {
|
||||
const v = vals[i];
|
||||
heatRamp(v, _c);
|
||||
heatColor.setXYZ(i, _c.r, _c.g, _c.b);
|
||||
if (v > peakV) { peakV = v; peakGx = i % GRID_N; peakGz = (i / GRID_N) | 0; }
|
||||
}
|
||||
heatColor.needsUpdate = true;
|
||||
lastFieldPeak = { gx: peakGx, gz: peakGz, v: peakV };
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// RF-localization puck — from persons[0].position (coarse, NOT pose).
|
||||
// Falls back to the signal_field peak cell when no person is present.
|
||||
// =====================================================================
|
||||
const puck = new THREE.Group();
|
||||
const puckCore = new THREE.Mesh(
|
||||
new THREE.SphereGeometry(0.16, 24, 18),
|
||||
new THREE.MeshBasicMaterial({ color: 0x66ff88 }));
|
||||
puckCore.position.y = 0.16; puck.add(puckCore);
|
||||
const puckRing = new THREE.Mesh(
|
||||
new THREE.RingGeometry(0.28, 0.36, 40),
|
||||
new THREE.MeshBasicMaterial({ color: 0x66ff88, transparent: true, opacity: 0.7, side: THREE.DoubleSide }));
|
||||
puckRing.rotation.x = -Math.PI/2; puckRing.position.y = 0.02; puck.add(puckRing);
|
||||
const puckBeam = new THREE.Mesh(
|
||||
new THREE.CylinderGeometry(0.03, 0.03, 1.2, 8),
|
||||
new THREE.MeshBasicMaterial({ color: 0x66ff88, transparent: true, opacity: 0.35 }));
|
||||
puckBeam.position.y = 0.6; puck.add(puckBeam);
|
||||
puck.visible = false;
|
||||
scene.add(puck);
|
||||
const puckTarget = new THREE.Vector3(0, 0, 0);
|
||||
|
||||
function updatePuck(frame) {
|
||||
let wx = null, wz = null, present = false;
|
||||
const persons = frame.persons || [];
|
||||
if (persons.length && Array.isArray(persons[0].position)) {
|
||||
// server position is [x, 0, z] in metres, origin at room centre
|
||||
wx = persons[0].position[0];
|
||||
wz = persons[0].position[2];
|
||||
present = true;
|
||||
}
|
||||
// If no person but the field has clear energy, show the peak cell
|
||||
// (coarse) so the puck honestly tracks "where the RF energy is".
|
||||
if (!present && lastFieldPeak.v > 0.55) {
|
||||
const peak = cellToWorld(lastFieldPeak.gx, lastFieldPeak.gz);
|
||||
wx = peak[0]; wz = peak[1]; present = true;
|
||||
}
|
||||
if (present && wx !== null) {
|
||||
// clamp into the room so it never flies off the floor
|
||||
wx = Math.max(-ROOM_HALF+0.3, Math.min(ROOM_HALF-0.3, wx));
|
||||
wz = Math.max(-ROOM_HALF+0.3, Math.min(ROOM_HALF-0.3, wz));
|
||||
puckTarget.set(wx, 0, wz);
|
||||
puck.visible = true;
|
||||
} else {
|
||||
puck.visible = false;
|
||||
}
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// HUD updates
|
||||
// =====================================================================
|
||||
const $ = id => document.getElementById(id);
|
||||
function clamp01(x){ return Math.max(0, Math.min(1, x)); }
|
||||
function setBar(barId, valId, frac, text) {
|
||||
$(barId).style.width = (clamp01(frac) * 100).toFixed(0) + '%';
|
||||
$(valId).textContent = text;
|
||||
}
|
||||
|
||||
// motion sparkline ring buffer
|
||||
const sparkCtx = $('spark').getContext('2d');
|
||||
const SPARK_N = 120;
|
||||
const sparkBuf = new Array(SPARK_N).fill(0);
|
||||
function pushSpark(v) {
|
||||
sparkBuf.push(v); if (sparkBuf.length > SPARK_N) sparkBuf.shift();
|
||||
const w = sparkCtx.canvas.width, h = sparkCtx.canvas.height;
|
||||
sparkCtx.clearRect(0,0,w,h);
|
||||
let maxV = 40; for (const x of sparkBuf) if (x > maxV) maxV = x;
|
||||
sparkCtx.strokeStyle = '#ffb840'; sparkCtx.lineWidth = 1.5; sparkCtx.beginPath();
|
||||
for (let i = 0; i < sparkBuf.length; i++) {
|
||||
const x = (i / (SPARK_N-1)) * w;
|
||||
const y = h - (sparkBuf[i] / maxV) * (h - 3) - 1.5;
|
||||
i === 0 ? sparkCtx.moveTo(x, y) : sparkCtx.lineTo(x, y);
|
||||
}
|
||||
sparkCtx.stroke();
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// Honest status banner (strict, mutually exclusive)
|
||||
// =====================================================================
|
||||
const banner = $('banner');
|
||||
function setBannerLive(source, nodeCount) {
|
||||
if (source === 'esp32') {
|
||||
banner.className = 'live';
|
||||
banner.innerHTML = 'LIVE — real ESP32 CSI <span class="src">(source=' + source + ', ' + nodeCount + ' node' + (nodeCount === 1 ? '' : 's') + ')</span>';
|
||||
} else {
|
||||
// anything not esp32 = explicitly NOT real, badged
|
||||
banner.className = 'sim';
|
||||
banner.innerHTML = 'SIMULATED — not real <span class="src">(source=' + source + ' — start an ESP32 for live CSI)</span>';
|
||||
}
|
||||
}
|
||||
function setBannerNoServer() {
|
||||
banner.className = 'noserver';
|
||||
banner.innerHTML = 'NO SERVER — start the sensing-server <span class="src">(ws://localhost:8765/ws/sensing unreachable)</span>';
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// WebSocket — render ONLY real frames. Reconnect; never fabricate.
|
||||
// =====================================================================
|
||||
let ws = null, gotFrame = false;
|
||||
let frameTimes = []; // for measured update rate (fps)
|
||||
let lastFrame = null; // most recent real frame (render loop interpolates puck)
|
||||
|
||||
function connect() {
|
||||
setBannerNoServer();
|
||||
try { ws = new WebSocket(WS_URL); }
|
||||
catch (e) { scheduleReconnect(); return; }
|
||||
|
||||
ws.onopen = () => { /* wait for first frame before claiming LIVE */ };
|
||||
ws.onmessage = (ev) => {
|
||||
let d; try { d = JSON.parse(ev.data); } catch (e) { return; }
|
||||
if (!d || d.type !== 'sensing_update') return;
|
||||
onFrame(d);
|
||||
};
|
||||
ws.onclose = () => { gotFrame = false; $('waiting').classList.add('show'); setBannerNoServer(); scheduleReconnect(); };
|
||||
ws.onerror = () => { try { ws.close(); } catch (e) {} };
|
||||
}
|
||||
let reconnectT = null;
|
||||
function scheduleReconnect() {
|
||||
if (reconnectT) return;
|
||||
reconnectT = setTimeout(() => { reconnectT = null; connect(); }, 1500);
|
||||
}
|
||||
|
||||
function onFrame(d) {
|
||||
gotFrame = true;
|
||||
lastFrame = d;
|
||||
$('waiting').classList.remove('show');
|
||||
|
||||
const source = d.source || 'unknown';
|
||||
const nodes = Array.isArray(d.nodes) ? d.nodes : [];
|
||||
setBannerLive(source, nodes.length);
|
||||
|
||||
// measured update rate
|
||||
const now = performance.now();
|
||||
frameTimes.push(now);
|
||||
while (frameTimes.length && now - frameTimes[0] > 2000) frameTimes.shift();
|
||||
const fps = frameTimes.length > 1 ? (frameTimes.length - 1) / ((frameTimes[frameTimes.length-1] - frameTimes[0]) / 1000) : 0;
|
||||
|
||||
const cls = d.classification || {};
|
||||
const feat = d.features || {};
|
||||
|
||||
// info panel
|
||||
$('m-source').textContent = source.toUpperCase();
|
||||
$('m-source').className = 'v ' + (source === 'esp32' ? 'green' : 'red');
|
||||
const presence = !!cls.presence;
|
||||
$('m-presence').textContent = presence ? (cls.motion_level === 'present_moving' ? 'PRESENT · MOVING' : 'PRESENT') : 'CLEAR';
|
||||
$('m-presence').className = 'v ' + (presence ? 'green' : 'mute');
|
||||
$('m-motion').textContent = cls.motion_level || '—';
|
||||
$('m-conf').textContent = (cls.confidence != null) ? cls.confidence.toFixed(2) : '—';
|
||||
$('m-persons').textContent = (d.estimated_persons != null) ? d.estimated_persons : '—';
|
||||
$('m-nodes').textContent = nodes.length + ' (' + nodes.map(n => n.node_id).join(', ') + ')';
|
||||
$('m-tick').textContent = (d.tick != null) ? d.tick : '—';
|
||||
$('m-fps').textContent = fps ? fps.toFixed(1) + ' Hz' : '—';
|
||||
|
||||
// feature bars (real values, scaled into 0..1 for the bar width only)
|
||||
const motion = feat.motion_band_power || 0;
|
||||
const breath = feat.breathing_band_power || 0;
|
||||
const variance = feat.variance || 0;
|
||||
const rssi = feat.mean_rssi != null ? feat.mean_rssi : -100;
|
||||
setBar('bar-motion', 'v-motion', motion / 100, motion.toFixed(1));
|
||||
setBar('bar-breath', 'v-breath', breath / 100, breath.toFixed(1));
|
||||
setBar('bar-var', 'v-var', variance / 80, variance.toFixed(1));
|
||||
// rssi: map -90..-30 dBm -> 0..1
|
||||
setBar('bar-rssi', 'v-rssi', (rssi + 90) / 60, rssi.toFixed(0));
|
||||
pushSpark(motion);
|
||||
|
||||
// node activity
|
||||
const activeIds = new Set(nodes.map(n => n.node_id));
|
||||
[9, 13].forEach(id => setNodeActive(id, activeIds.has(id)));
|
||||
|
||||
// heatmap + puck
|
||||
updateHeatmap(d.signal_field);
|
||||
updatePuck(d);
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// Optional webcam ground-truth tile (reused from demos/05). Camera is
|
||||
// separate from CSI sensing — labeled "ground truth when visible".
|
||||
// =====================================================================
|
||||
let camStream = null;
|
||||
$('cam-btn').addEventListener('click', async () => {
|
||||
const btn = $('cam-btn');
|
||||
if (camStream) { // toggle off
|
||||
camStream.getTracks().forEach(t => t.stop());
|
||||
$('cam-video').style.display = 'none'; camStream = null;
|
||||
btn.textContent = '▶ enable webcam (optional)';
|
||||
return;
|
||||
}
|
||||
btn.disabled = true; btn.textContent = 'requesting camera…';
|
||||
try {
|
||||
camStream = await navigator.mediaDevices.getUserMedia({
|
||||
video: { width: { ideal: 640 }, height: { ideal: 480 }, facingMode: 'user' }, audio: false,
|
||||
});
|
||||
const v = $('cam-video'); v.srcObject = camStream; v.style.display = 'block';
|
||||
btn.textContent = '■ stop webcam'; btn.disabled = false;
|
||||
} catch (e) {
|
||||
btn.textContent = '✗ camera unavailable'; btn.disabled = false; console.error(e);
|
||||
setTimeout(() => { if (!camStream) btn.textContent = '▶ enable webcam (optional)'; }, 2000);
|
||||
}
|
||||
});
|
||||
|
||||
// =====================================================================
|
||||
// Render loop — smooth the puck toward its real target; pulse rings.
|
||||
// =====================================================================
|
||||
const clock = new THREE.Clock();
|
||||
function animate() {
|
||||
requestAnimationFrame(animate);
|
||||
const t = clock.getElapsedTime();
|
||||
controls.update();
|
||||
|
||||
if (puck.visible) {
|
||||
puck.position.lerp(puckTarget, 0.18);
|
||||
const pulse = 0.8 + 0.25 * Math.sin(t * 3.0);
|
||||
puckRing.scale.set(pulse, pulse, pulse);
|
||||
puckRing.material.opacity = 0.5 + 0.25 * Math.sin(t * 3.0);
|
||||
}
|
||||
// node rings breathe when active
|
||||
[9,13].forEach(id => {
|
||||
const g = nodeMeshes[id]; if (!g) return;
|
||||
const r = g.userData.parts.ring;
|
||||
const s = 1 + 0.08 * Math.sin(t * 2 + id);
|
||||
r.scale.set(s, s, s);
|
||||
});
|
||||
|
||||
composer.render();
|
||||
}
|
||||
animate();
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
camera.aspect = window.innerWidth / window.innerHeight;
|
||||
camera.updateProjectionMatrix();
|
||||
renderer.setSize(window.innerWidth, window.innerHeight);
|
||||
composer.setSize(window.innerWidth, window.innerHeight);
|
||||
});
|
||||
|
||||
// kick off
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,159 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<title>WiFlow · live WiFi-inferred pose</title>
|
||||
<style>
|
||||
:root{--bg:#0a0c10;--panel:#11151c;--amber:#ffb840;--green:#46e08a;--red:#ff5a5a;--mute:#7d8796;--line:#1d2430}
|
||||
*{box-sizing:border-box}
|
||||
body{margin:0;background:var(--bg);color:#dfe6ee;font:14px/1.5 'JetBrains Mono',ui-monospace,Menlo,monospace}
|
||||
header{padding:14px 18px;border-bottom:1px solid var(--line);display:flex;align-items:center;gap:14px;flex-wrap:wrap}
|
||||
h1{font-size:15px;margin:0;letter-spacing:1px;text-transform:uppercase;font-weight:600}
|
||||
h1 span{color:var(--amber)}
|
||||
#banner{margin-left:auto;padding:5px 12px;border-radius:5px;font-weight:600;font-size:12px;letter-spacing:.5px}
|
||||
.live{background:rgba(70,224,138,.15);color:var(--green);border:1px solid var(--green)}
|
||||
.sim{background:rgba(255,184,64,.15);color:var(--amber);border:1px solid var(--amber)}
|
||||
.down{background:rgba(255,90,90,.15);color:var(--red);border:1px solid var(--red)}
|
||||
main{display:flex;gap:18px;padding:18px;flex-wrap:wrap}
|
||||
.card{background:var(--panel);border:1px solid var(--line);border-radius:10px;padding:14px}
|
||||
canvas{background:#070a0e;border-radius:8px;display:block}
|
||||
.label{font-size:11px;text-transform:uppercase;letter-spacing:1.5px;color:var(--mute);margin-bottom:8px}
|
||||
.stats{min-width:240px}
|
||||
.row{display:flex;justify-content:space-between;padding:3px 0;border-bottom:1px dashed var(--line)}
|
||||
.row .k{color:var(--mute)} .row .v{color:var(--amber);font-variant-numeric:tabular-nums}
|
||||
.v.green{color:var(--green)}
|
||||
.note{margin-top:12px;font-size:11px;color:var(--mute);line-height:1.6;max-width:300px}
|
||||
.note b{color:#dfe6ee}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>WiFlow · <span>live WiFi-inferred pose</span></h1>
|
||||
<div id="banner" class="down">CONNECTING…</div>
|
||||
</header>
|
||||
<main>
|
||||
<div class="card">
|
||||
<div class="label">CSI → pose (skeleton) overlaid on your laptop camera</div>
|
||||
<div id="stage" style="width:420px;height:560px;border-radius:8px;overflow:hidden;background:#070a0e">
|
||||
<video id="cam" autoplay muted playsinline style="position:absolute;width:2px;height:2px;opacity:0;pointer-events:none"></video>
|
||||
<canvas id="cv" width="420" height="560"></canvas>
|
||||
</div>
|
||||
<div style="margin-top:10px;display:flex;gap:8px;align-items:center;flex-wrap:wrap">
|
||||
<button id="camBtn" style="background:var(--amber);color:#0a0c10;border:0;border-radius:6px;padding:7px 14px;font:inherit;font-weight:600;cursor:pointer">enable laptop camera</button>
|
||||
<select id="camSel" style="display:none;background:var(--panel);color:#dfe6ee;border:1px solid var(--line);border-radius:6px;padding:6px;font:inherit;max-width:220px"></select>
|
||||
</div>
|
||||
<div id="camStatus" style="margin-top:6px;font-size:11px;color:var(--mute)">camera: off</div>
|
||||
<div class="note" style="margin-top:8px">Camera is a <b>visual reference only</b> — it is NOT fed to the model. Overlay alignment is approximate (model trained in a different camera's frame).</div>
|
||||
</div>
|
||||
<div class="card stats">
|
||||
<div class="label">live</div>
|
||||
<div class="row"><span class="k">CSI source</span><span class="v" id="src">—</span></div>
|
||||
<div class="row"><span class="k">nodes</span><span class="v" id="nodes">—</span></div>
|
||||
<div class="row"><span class="k">presence</span><span class="v" id="pres">—</span></div>
|
||||
<div class="row"><span class="k">motion</span><span class="v" id="motion">—</span></div>
|
||||
<div class="row"><span class="k">pose fps</span><span class="v" id="fps">—</span></div>
|
||||
<div class="note">
|
||||
This skeleton is inferred <b>from WiFi CSI only</b> — no camera in the loop here. A model was
|
||||
trained on paired (camera-pose, CSI) data in this room (ADR-079/180).
|
||||
<br/><br/>
|
||||
<b>Honest accuracy:</b> ~<b>59.5% PCK@0.10</b> on held-out data (vs a 50% mean-pose baseline →
|
||||
<b>+9.4 pp real signal</b>). It captures <b>coarse</b> pose; fine detail is weak (PCK@0.05 ≈ 24%).
|
||||
Same person / room / session — not validated cross-day or through-wall.
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
<script>
|
||||
const POSE_WS = (new URLSearchParams(location.search)).get('ws') || `ws://${location.hostname||'localhost'}:8770/pose`;
|
||||
const cv = document.getElementById('cv'), ctx = cv.getContext('2d');
|
||||
const $ = id => document.getElementById(id);
|
||||
let edges = [[5,7],[7,9],[6,8],[8,10],[5,6],[11,12],[5,11],[6,12],[11,13],[13,15],[12,14],[14,16],[0,1],[0,2],[1,3],[2,4],[0,5],[0,6]];
|
||||
let last = null, frames = 0, t0 = performance.now();
|
||||
|
||||
function banner(state, txt){ const b=$('banner'); b.className=state; b.textContent=txt; }
|
||||
|
||||
// per-joint smoothing (EMA) so dropped/jittery CSI frames render fluidly (ADR-180 dead-reckoning, lite)
|
||||
let sm = null;
|
||||
function smooth(kps){
|
||||
if(!sm){ sm = kps.map(p=>[p[0],p[1]]); return sm; }
|
||||
const a=0.35; for(let i=0;i<kps.length;i++){ sm[i][0]+=a*(kps[i][0]-sm[i][0]); sm[i][1]+=a*(kps[i][1]-sm[i][1]); }
|
||||
return sm;
|
||||
}
|
||||
const camEl=document.getElementById('cam');
|
||||
function draw(p){
|
||||
const W=cv.width, H=cv.height;
|
||||
// paint the live camera frame onto the canvas (robust — no z-index/overlay tricks)
|
||||
if(camEl && camEl.videoWidth>0){
|
||||
ctx.save(); ctx.globalAlpha=0.9;
|
||||
// cover-fit the camera frame into the canvas
|
||||
const vr=camEl.videoWidth/camEl.videoHeight, cr=W/H;
|
||||
let dw=W, dh=H, dx=0, dy=0;
|
||||
if(vr>cr){ dh=H; dw=H*vr; dx=(W-dw)/2; } else { dw=W; dh=W/vr; dy=(H-dh)/2; }
|
||||
ctx.drawImage(camEl, dx, dy, dw, dh); ctx.restore();
|
||||
} else {
|
||||
ctx.fillStyle='#070a0e'; ctx.fillRect(0,0,W,H);
|
||||
}
|
||||
if(!p || !p.kps){ return; }
|
||||
const s = smooth(p.kps);
|
||||
const k = s.map(([x,y])=>[x*W, y*H]);
|
||||
ctx.lineWidth=5; ctx.strokeStyle=p.presence?'rgba(70,224,138,.95)':'rgba(125,135,150,.8)'; ctx.lineCap='round';
|
||||
ctx.shadowColor='rgba(70,224,138,.6)'; ctx.shadowBlur=8;
|
||||
for(const [a,b] of edges){ ctx.beginPath(); ctx.moveTo(k[a][0],k[a][1]); ctx.lineTo(k[b][0],k[b][1]); ctx.stroke(); }
|
||||
ctx.shadowBlur=0;
|
||||
for(const [x,y] of k){ ctx.beginPath(); ctx.arc(x,y,5,0,7); ctx.fillStyle=p.presence?'#ffb840':'#667'; ctx.fill(); }
|
||||
}
|
||||
|
||||
// ---- laptop webcam (visual reference only; NOT fed to the model) ----
|
||||
let camStream=null;
|
||||
async function startCam(deviceId){
|
||||
if(camStream){ camStream.getTracks().forEach(t=>t.stop()); }
|
||||
const constraints = deviceId ? {video:{deviceId:{exact:deviceId}}} : {video:true};
|
||||
const st=document.getElementById('camStatus');
|
||||
try{
|
||||
st.textContent='camera: requesting…';
|
||||
camStream = await navigator.mediaDevices.getUserMedia(constraints);
|
||||
const v=document.getElementById('cam'); v.muted=true; v.srcObject=camStream;
|
||||
v.onloadedmetadata=()=>{ v.play().catch(err=>st.textContent='camera: play() blocked '+err.name); };
|
||||
await v.play().catch(()=>{});
|
||||
const tr=camStream.getVideoTracks()[0]; const ss=tr.getSettings();
|
||||
// live readout: shows if real frames are flowing (videoWidth>0) and which device
|
||||
const tick=()=>{ st.textContent = `camera: "${tr.label}" ${v.videoWidth}x${v.videoHeight} ${tr.readyState} ${v.paused?'PAUSED':'playing'}`; };
|
||||
tick(); setInterval(tick, 1000);
|
||||
document.getElementById('camBtn').textContent='switch camera ↻';
|
||||
// populate the picker now that we have permission (labels need permission)
|
||||
const devs = (await navigator.mediaDevices.enumerateDevices()).filter(d=>d.kind==='videoinput');
|
||||
const sel=document.getElementById('camSel'); sel.style.display = devs.length>1?'inline-block':'none';
|
||||
sel.innerHTML = devs.map((d,i)=>`<option value="${d.deviceId}">${d.label||('camera '+(i+1))}</option>`).join('');
|
||||
const cur = camStream.getVideoTracks()[0].getSettings().deviceId; if(cur) sel.value=cur;
|
||||
}catch(e){
|
||||
document.getElementById('camBtn').textContent = 'camera error: '+e.name+(e.name==='NotReadableError'?' (in use by Zoom/Teams?)':'');
|
||||
console.error('getUserMedia', e);
|
||||
}
|
||||
}
|
||||
document.getElementById('camBtn').addEventListener('click', ()=>startCam());
|
||||
document.getElementById('camSel').addEventListener('change', e=>startCam(e.target.value));
|
||||
|
||||
function connect(){
|
||||
banner('down','CONNECTING…');
|
||||
const ws = new WebSocket(POSE_WS);
|
||||
ws.onopen = ()=> banner('sim','WAITING FOR POSE…');
|
||||
ws.onmessage = ev => {
|
||||
const d = JSON.parse(ev.data);
|
||||
if(d.type==='meta'){ edges = d.edges; return; }
|
||||
if(d.type!=='pose') return;
|
||||
last=d; frames++;
|
||||
if(d.src==='esp32') banner('live','LIVE — WiFi-inferred pose (real ESP32 CSI)');
|
||||
else banner('sim','SIMULATED CSI — not real ('+d.src+')');
|
||||
$('src').textContent=d.src; $('src').className = d.src==='esp32'?'v green':'v';
|
||||
$('nodes').textContent=(d.nodes||[]).join(', ')||'—';
|
||||
$('pres').textContent=d.presence?'PRESENT':'—';
|
||||
$('motion').textContent=(d.motion!=null?Math.round(d.motion):'—');
|
||||
};
|
||||
ws.onclose = ()=>{ banner('down','NO BRIDGE — start wiflow_infer.py'); setTimeout(connect,1500); };
|
||||
ws.onerror = ()=> ws.close();
|
||||
}
|
||||
function loop(){ draw(last); const now=performance.now(); if(now-t0>1000){ $('fps').textContent=frames; frames=0; t0=now; } requestAnimationFrame(loop); }
|
||||
connect(); loop();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Tiny threaded static server for the through-wall WiFi-CSI sensing demo.
|
||||
|
||||
Adapted from examples/three.js/server/serve-demo.py. Serves the
|
||||
`examples/through-wall/` page so a browser can fetch index.html, then the
|
||||
page connects directly to the LIVE sensing-server WebSocket at
|
||||
ws://localhost:8765/ws/sensing (NOT proxied through here).
|
||||
|
||||
Why a threaded server (not `python -m http.server`)?
|
||||
The stdlib SimpleHTTPServer is single-threaded; a browser opens several
|
||||
parallel connections (HTML + the three.js CDN tags fetch in parallel),
|
||||
the first eats the worker, the rest can stall. ThreadingHTTPServer fixes it.
|
||||
|
||||
IMPORTANT: this serves on port 8080 — port 8765 is taken by the
|
||||
sensing-server's WebSocket. They are two different processes.
|
||||
|
||||
Usage:
|
||||
# 1) start the REAL sensing-server (separate terminal):
|
||||
# cd v2
|
||||
# cargo build -p wifi-densepose-sensing-server
|
||||
# ./target/debug/sensing-server.exe --ws-port 8765 --udp-port 5005
|
||||
# 2) start this static server:
|
||||
python examples/through-wall/serve.py
|
||||
# 3) open:
|
||||
# http://localhost:8080/examples/through-wall/index.html
|
||||
|
||||
Override the WS endpoint with a query param, e.g.:
|
||||
http://localhost:8080/examples/through-wall/index.html?ws=ws://192.168.1.20:8765/ws/sensing
|
||||
"""
|
||||
from http.server import ThreadingHTTPServer, SimpleHTTPRequestHandler
|
||||
import os
|
||||
import sys
|
||||
|
||||
PORT = int(os.environ.get("PORT", 8080))
|
||||
|
||||
# Serve from the repo root regardless of where this script is launched.
|
||||
# This file lives at examples/through-wall/serve.py — two levels deep.
|
||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||
|
||||
|
||||
class NoCacheHandler(SimpleHTTPRequestHandler):
|
||||
def end_headers(self):
|
||||
# Aggressive no-cache so the browser ALWAYS fetches the latest
|
||||
# index.html after edits, even on a soft refresh.
|
||||
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
|
||||
self.send_header("Pragma", "no-cache")
|
||||
self.send_header("Expires", "0")
|
||||
super().end_headers()
|
||||
|
||||
def log_message(self, fmt, *args): # quieter logs
|
||||
sys.stderr.write("[serve] " + (fmt % args) + "\n")
|
||||
|
||||
|
||||
PAGE = "examples/through-wall/index.html"
|
||||
|
||||
with ThreadingHTTPServer(("127.0.0.1", PORT), NoCacheHandler) as srv:
|
||||
print(f"serving {os.getcwd()} on http://127.0.0.1:{PORT}/")
|
||||
print(f" open http://localhost:{PORT}/{PAGE}")
|
||||
print("")
|
||||
print(" The page connects to the LIVE sensing-server at")
|
||||
print(" ws://localhost:8765/ws/sensing (start it first — see README.md).")
|
||||
print(" Override with ?ws=ws://HOST:PORT/ws/sensing")
|
||||
try:
|
||||
srv.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(0)
|
||||
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Rigorous A/B for WiFlow CSI->pose: is the held-out PCK real signal or split leakage?
|
||||
|
||||
For a dataset of {csi:[D], kps:17x[x,y,vis]} pairs, train the SAME small MLP under
|
||||
several train/val SPLITS and report held-out PCK@0.10 vs the mean-pose baseline:
|
||||
|
||||
- chronological_80_20 : last 20% in time (val temporally ADJACENT to train -> leaks
|
||||
via CSI/pose autocorrelation; this is what gave us +9.4)
|
||||
- random_80_20 : shuffled (val frames interleaved with train -> MAX leak)
|
||||
- blocked_gap : hold out a contiguous MIDDLE block with a time GAP buffer on
|
||||
each side so val is NOT adjacent to any train frame -> the
|
||||
honest, leakage-controlled test
|
||||
|
||||
If the model beats baseline on chronological/random but COLLAPSES to ~baseline on
|
||||
blocked_gap, the apparent signal was temporal leakage, not generalizable CSI->pose.
|
||||
|
||||
Usage (ruvultra venv): python wiflow_ab.py --data ~/wiflow-room/dataset.jsonl
|
||||
"""
|
||||
import argparse, json, sys
|
||||
import numpy as np, torch, torch.nn as nn
|
||||
|
||||
def _rec(r, X, Y, V, B):
|
||||
X.append(r["csi"]); kp=r["kps"]
|
||||
if kp and isinstance(kp[0], (list,tuple)): # 17 x [x,y(,vis)]
|
||||
Y.append([c for k in kp for c in (k[0],k[1])]); V.append([(k[2] if len(k)>2 else 1.0) for k in kp])
|
||||
else: # flat 34 (browser export, no vis)
|
||||
Y.append(list(kp)); V.append([1.0]*17)
|
||||
B.append(r.get("bucket"))
|
||||
|
||||
def load(path):
|
||||
X,Y,V,B=[],[],[],[]
|
||||
txt=open(path).read().strip()
|
||||
if txt[:1] in "[{": # JSON (browser export: dict{samples:[]} or bare array)
|
||||
d=json.loads(txt)
|
||||
rows = d if isinstance(d,list) else d.get("samples", d.get("data", []))
|
||||
for r in rows: _rec(r,X,Y,V,B)
|
||||
else: # JSONL (python capture)
|
||||
for line in txt.splitlines():
|
||||
if line.strip(): _rec(json.loads(line),X,Y,V,B)
|
||||
return np.array(X,np.float32), np.array(Y,np.float32), np.array(V,np.float32), B
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(s,din,dout):
|
||||
super().__init__()
|
||||
s.n=nn.Sequential(nn.Linear(din,384),nn.ReLU(),nn.Dropout(.35),
|
||||
nn.Linear(384,192),nn.ReLU(),nn.Dropout(.35),
|
||||
nn.Linear(192,96),nn.ReLU(),nn.Linear(96,dout),nn.Sigmoid())
|
||||
def forward(s,x): return s.n(x)
|
||||
|
||||
def pck(pred,gt,vis,thr=0.10):
|
||||
p=pred.reshape(-1,17,2); g=gt.reshape(-1,17,2)
|
||||
d=np.linalg.norm(p-g,axis=2); m=vis>0.5
|
||||
return float((d[m]<thr).mean()) if m.any() else 0.0
|
||||
|
||||
def split_idx(n, kind, B=None):
|
||||
idx=np.arange(n)
|
||||
if kind=="chronological_80_20":
|
||||
c=int(n*.8); return idx[:c], idx[c:]
|
||||
if kind=="random_80_20":
|
||||
rng=np.random.default_rng(0); p=rng.permutation(n); c=int(n*.8); return p[:c], p[c:]
|
||||
if kind=="blocked_gap":
|
||||
# val = contiguous middle 20%; a WIDE 10% time gap each side guarantees no train
|
||||
# frame is temporally adjacent to a val frame (kills frame-autocorrelation leakage).
|
||||
v0=int(n*.4); v1=int(n*.6); gap=int(n*.10)
|
||||
val=idx[v0:v1]; train=np.concatenate([idx[:max(0,v0-gap)], idx[min(n,v1+gap):]])
|
||||
return train, val
|
||||
if kind=="grouped_bucket":
|
||||
# hold out ENTIRE activity buckets -> val poses/activities never seen in train.
|
||||
# the strictest leakage-free test (only when bucket labels exist).
|
||||
b=np.array([x if x is not None else -1 for x in B])
|
||||
uniq=[u for u in sorted(set(b.tolist())) if u!=-1]
|
||||
if len(uniq)<3: raise ValueError("too few buckets")
|
||||
hold=set(uniq[::max(1,len(uniq)//3)][:max(1,len(uniq)//3)]) # ~1/3 of activities held out
|
||||
val=idx[np.isin(b,list(hold))]; train=idx[~np.isin(b,list(hold))]
|
||||
return train, val
|
||||
raise ValueError(kind)
|
||||
|
||||
def run(X,Y,V,tr,va,epochs=250,seed=0):
|
||||
torch.manual_seed(seed); np.random.seed(seed) # seed weight init + batch shuffle
|
||||
dev="cuda" if torch.cuda.is_available() else "cpu"
|
||||
mu,sd=X[tr].mean(0),X[tr].std(0)+1e-6
|
||||
Xtr=torch.tensor((X[tr]-mu)/sd).to(dev); Ytr=torch.tensor(Y[tr]).to(dev)
|
||||
Xva=torch.tensor((X[va]-mu)/sd).to(dev)
|
||||
net=Net(X.shape[1],Y.shape[1]).to(dev)
|
||||
opt=torch.optim.Adam(net.parameters(),lr=1e-3,weight_decay=1e-4); lf=nn.MSELoss()
|
||||
best=(1e9,None)
|
||||
for ep in range(epochs):
|
||||
net.train(); perm=torch.randperm(len(Xtr),device=dev)
|
||||
for i in range(0,len(Xtr),64):
|
||||
j=perm[i:i+64]; opt.zero_grad(); loss=lf(net(Xtr[j]),Ytr[j]); loss.backward(); opt.step()
|
||||
net.eval()
|
||||
with torch.no_grad(): pv=net(Xva).cpu().numpy()
|
||||
vl=float(((pv-Y[va])**2).mean())
|
||||
if vl<best[0]: best=(vl,pv)
|
||||
base=np.tile(Y[tr].mean(0),(len(va),1))
|
||||
return pck(best[1],Y[va],V[va]), pck(base,Y[va],V[va])
|
||||
|
||||
def main():
|
||||
ap=argparse.ArgumentParser(); ap.add_argument("--data",required=True)
|
||||
ap.add_argument("--epochs",type=int,default=250); ap.add_argument("--seeds",type=int,default=3)
|
||||
a=ap.parse_args()
|
||||
X,Y,V,B=load(a.data); n=len(X)
|
||||
has_buckets=any(x is not None for x in B)
|
||||
print(f"[ab] {n} samples, X={X.shape}, buckets={'yes' if has_buckets else 'no'}, "
|
||||
f"seeds={a.seeds}, epochs={a.epochs}\n")
|
||||
print(f"{'split':<22}{'model PCK@0.10':>16}{'baseline':>11}{'delta (mean±sd)':>20} verdict")
|
||||
print("-"*86)
|
||||
splits=["chronological_80_20","random_80_20","blocked_gap"]+(["grouped_bucket"] if has_buckets else [])
|
||||
for kind in splits:
|
||||
try:
|
||||
tr,va=split_idx(n,kind,B)
|
||||
ms=[]; bs=[]
|
||||
for s in range(a.seeds):
|
||||
m,b=run(X,Y,V,tr,va,a.epochs,seed=s); ms.append(m); bs.append(b)
|
||||
ms=np.array(ms)*100; bs=np.array(bs)*100; ds=ms-bs
|
||||
dm,dsd=ds.mean(),ds.std()
|
||||
# REAL only if the mean delta minus 1 sd still clears the 1.5pp threshold (robust to seed variance)
|
||||
verdict = "REAL signal" if dm-dsd>1.5 else ("weak/uncertain" if dm>1.5 else "no signal (==baseline)")
|
||||
print(f"{kind:<22}{ms.mean():>13.1f}±{ms.std():>3.1f}{bs.mean():>10.1f}%{dm:>+12.1f}±{dsd:>4.1f}pp {verdict}")
|
||||
except Exception as e:
|
||||
print(f"{kind:<22} skipped: {e}")
|
||||
print(f"\nmean±sd over {a.seeds} seeds (weight init + batch order). blocked_gap = 10% time gap each")
|
||||
print("side; grouped_bucket holds out ENTIRE activities (strictest). If only the LEAKY splits")
|
||||
print("(chronological/random) beat baseline, the apparent signal is leakage, not generalizable pose.")
|
||||
|
||||
if __name__=="__main__": main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python3
|
||||
"""WiFlow-style camera-supervised capture (ADR-079 / ADR-180).
|
||||
|
||||
Runs on a box with BOTH a camera (ground truth) and reachable live CSI:
|
||||
- opens a camera, runs MediaPipe Pose -> 17 COCO keypoints (the LABEL),
|
||||
- subscribes to the sensing-server /ws/sensing (the INPUT: CSI features +
|
||||
20x20 signal-field),
|
||||
- writes timestamp-aligned (csi -> pose) pairs to a JSONL dataset.
|
||||
|
||||
This is the *collect* phase of camera-supervised CSI->pose training. The camera
|
||||
and the CSI nodes MUST see the same person in the same space at the same time,
|
||||
or the pairs are meaningless. Honest by construction: we only emit a pair when
|
||||
BOTH a confident camera pose AND a live (source=esp32) CSI frame are present in
|
||||
the same ~100 ms window.
|
||||
|
||||
Usage (on ruvultra, with the CSI tunneled to localhost:8765):
|
||||
python3 wiflow_capture.py --ws ws://localhost:8765/ws/sensing \
|
||||
--cam 0 --out ~/wiflow-room/dataset.jsonl --seconds 180
|
||||
"""
|
||||
import argparse, asyncio, json, time, threading, sys, os
|
||||
from collections import deque
|
||||
|
||||
import urllib.request
|
||||
import cv2
|
||||
import numpy as np
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks.python import BaseOptions
|
||||
from mediapipe.tasks.python.vision import PoseLandmarker, PoseLandmarkerOptions, RunningMode
|
||||
import websockets
|
||||
|
||||
_MODEL_URL = ("https://storage.googleapis.com/mediapipe-models/pose_landmarker/"
|
||||
"pose_landmarker_lite/float16/latest/pose_landmarker_lite.task")
|
||||
|
||||
def ensure_model(path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
print(f"[capture] downloading pose model -> {path}", flush=True)
|
||||
urllib.request.urlretrieve(_MODEL_URL, path)
|
||||
return path
|
||||
|
||||
# MediaPipe Pose (33 landmarks) -> 17 COCO keypoints (same mapping as
|
||||
# scripts/collect-ground-truth.py, ADR-079).
|
||||
COCO_FROM_MP = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]
|
||||
COCO_NAMES = ["nose","l_eye","r_eye","l_ear","r_ear","l_sho","r_sho","l_elb",
|
||||
"r_elb","l_wri","r_wri","l_hip","r_hip","l_knee","r_knee","l_ank","r_ank"]
|
||||
|
||||
# ---- shared state between the CSI (async) thread and the camera (sync) loop ----
|
||||
_latest_csi = {"t": 0.0, "frame": None}
|
||||
_csi_lock = threading.Lock()
|
||||
_stop = threading.Event()
|
||||
|
||||
|
||||
def csi_thread(ws_url: str):
|
||||
"""Background thread: keep the most recent LIVE csi frame in _latest_csi."""
|
||||
async def run():
|
||||
while not _stop.is_set():
|
||||
try:
|
||||
async with websockets.connect(ws_url, open_timeout=8, ping_interval=20) as ws:
|
||||
while not _stop.is_set():
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=8)
|
||||
d = json.loads(msg)
|
||||
with _csi_lock:
|
||||
_latest_csi["t"] = time.time()
|
||||
_latest_csi["frame"] = d
|
||||
except Exception as e:
|
||||
print(f"[csi] reconnect ({e})", flush=True)
|
||||
await asyncio.sleep(1.0)
|
||||
asyncio.new_event_loop().run_until_complete(run())
|
||||
|
||||
|
||||
def csi_vector(frame: dict):
|
||||
"""Flatten a csi frame to a fixed-length input vector: features + field."""
|
||||
f = frame.get("features", {}) or {}
|
||||
feats = [f.get("mean_rssi", 0.0), f.get("variance", 0.0),
|
||||
f.get("motion_band_power", 0.0), f.get("breathing_band_power", 0.0)]
|
||||
# per-node mean_rssi/variance/motion for up to the 2 nodes (9, 13)
|
||||
pernode = {nf.get("node_id"): (nf.get("features") or {}) for nf in (frame.get("node_features") or [])}
|
||||
for nid in (9, 13):
|
||||
nf = pernode.get(nid, {})
|
||||
feats += [nf.get("mean_rssi", 0.0), nf.get("variance", 0.0), nf.get("motion_band_power", 0.0)]
|
||||
field = (frame.get("signal_field", {}) or {}).get("values") or []
|
||||
field = (field + [0.0] * 400)[:400]
|
||||
return feats + field # 4 + 6 + 400 = 410-d
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="WiFlow camera-supervised CSI<->pose capture (ADR-180).")
|
||||
ap.add_argument("--ws", default="ws://localhost:8765/ws/sensing")
|
||||
ap.add_argument("--cam", type=int, default=0)
|
||||
ap.add_argument("--out", default=os.path.expanduser("~/wiflow-room/dataset.jsonl"))
|
||||
ap.add_argument("--seconds", type=int, default=180)
|
||||
ap.add_argument("--min-vis", type=float, default=0.5, help="min mean landmark visibility to accept a pose label")
|
||||
ap.add_argument("--max-skew-ms", type=float, default=150, help="max csi/pose time skew to pair")
|
||||
ap.add_argument("--require-esp32", action="store_true", default=True,
|
||||
help="only pair when csi source==esp32 (real). Default on.")
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
||||
th = threading.Thread(target=csi_thread, args=(args.ws,), daemon=True)
|
||||
th.start()
|
||||
|
||||
cap = cv2.VideoCapture(args.cam)
|
||||
if not cap.isOpened():
|
||||
print(f"ERROR: cannot open camera {args.cam}", file=sys.stderr); sys.exit(2)
|
||||
W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
|
||||
H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 480
|
||||
model_path = ensure_model(os.path.expanduser("~/wiflow-room/pose_landmarker_lite.task"))
|
||||
landmarker = PoseLandmarker.create_from_options(PoseLandmarkerOptions(
|
||||
base_options=BaseOptions(model_asset_path=model_path),
|
||||
running_mode=RunningMode.IMAGE, min_pose_detection_confidence=0.5))
|
||||
|
||||
n_pairs = 0; n_nopose = 0; n_nocsi = 0; n_skew = 0; n_sim = 0
|
||||
t0 = time.time()
|
||||
print(f"[capture] camera {args.cam} {W}x{H} -> {args.out} for {args.seconds}s")
|
||||
print("[capture] stand in view AND in the CSI field; move/walk so poses vary. Ctrl-C to stop.")
|
||||
with open(args.out, "a") as out:
|
||||
try:
|
||||
while time.time() - t0 < args.seconds:
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
continue
|
||||
now = time.time()
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
res = landmarker.detect(mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb))
|
||||
if not res.pose_landmarks:
|
||||
n_nopose += 1; continue
|
||||
lm = res.pose_landmarks[0]
|
||||
kps = [[lm[i].x, lm[i].y, lm[i].visibility] for i in COCO_FROM_MP]
|
||||
vis = float(np.mean([k[2] for k in kps]))
|
||||
if vis < args.min_vis:
|
||||
n_nopose += 1; continue
|
||||
with _csi_lock:
|
||||
ct = _latest_csi["t"]; cf = _latest_csi["frame"]
|
||||
if cf is None:
|
||||
n_nocsi += 1; continue
|
||||
if (now - ct) * 1000.0 > args.max_skew_ms:
|
||||
n_skew += 1; continue
|
||||
if args.require_esp32 and cf.get("source") != "esp32":
|
||||
n_sim += 1; continue
|
||||
rec = {"t": now, "vis": round(vis, 3),
|
||||
"kps": [[round(x, 4), round(y, 4), round(v, 3)] for x, y, v in kps],
|
||||
"csi": csi_vector(cf),
|
||||
"src": cf.get("source"),
|
||||
"nodes": sorted(n.get("node_id") for n in cf.get("nodes", []) if n.get("node_id") is not None)}
|
||||
out.write(json.dumps(rec) + "\n")
|
||||
n_pairs += 1
|
||||
if n_pairs % 30 == 0:
|
||||
out.flush()
|
||||
el = int(now - t0)
|
||||
print(f"[capture] t+{el:3d}s pairs={n_pairs} (skip: nopose={n_nopose} nocsi={n_nocsi} skew={n_skew} sim={n_sim})", flush=True)
|
||||
except KeyboardInterrupt:
|
||||
print("\n[capture] stopped by user")
|
||||
_stop.set(); cap.release()
|
||||
print(f"[capture] DONE. wrote {n_pairs} paired samples to {args.out}")
|
||||
print(f"[capture] skipped: no-pose={n_nopose} no-csi={n_nocsi} skew={n_skew} simulated={n_sim}")
|
||||
if n_pairs == 0:
|
||||
print("[capture] WARNING: 0 pairs — check camera sees you AND csi source==esp32 (live).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Live CSI->pose inference bridge (ADR-180).
|
||||
|
||||
Runs on the box with the live CSI. Loads the camera-supervised model (numpy,
|
||||
no torch needed), subscribes to /ws/sensing, runs a forward pass per frame, and
|
||||
broadcasts the predicted 17-keypoint pose to HTML clients on ws://:8770/pose.
|
||||
|
||||
python wiflow_infer.py --model model/model.npz \
|
||||
--in ws://localhost:8765/ws/sensing --port 8770
|
||||
"""
|
||||
import argparse, asyncio, json, os
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
||||
# COCO skeleton edges (for the client; sent once in 'meta')
|
||||
EDGES = [[5,7],[7,9],[6,8],[8,10],[5,6],[11,12],[5,11],[6,12],
|
||||
[11,13],[13,15],[12,14],[14,16],[0,1],[0,2],[1,3],[2,4],[0,5],[0,6]]
|
||||
|
||||
def csi_vector(frame):
|
||||
f = frame.get("features", {}) or {}
|
||||
feats = [f.get("mean_rssi",0.0), f.get("variance",0.0),
|
||||
f.get("motion_band_power",0.0), f.get("breathing_band_power",0.0)]
|
||||
pernode = {nf.get("node_id"): (nf.get("features") or {}) for nf in (frame.get("node_features") or [])}
|
||||
for nid in (9,13):
|
||||
nf = pernode.get(nid,{}); feats += [nf.get("mean_rssi",0.0), nf.get("variance",0.0), nf.get("motion_band_power",0.0)]
|
||||
field = (frame.get("signal_field",{}) or {}).get("values") or []
|
||||
field = (field + [0.0]*400)[:400]
|
||||
return np.array(feats + field, np.float32)
|
||||
|
||||
class Model:
|
||||
def __init__(self, path):
|
||||
z = np.load(path)
|
||||
self.mu, self.sd = z["mu"], z["sd"]
|
||||
self.W = [z["net_0_weight"], z["net_3_weight"], z["net_6_weight"], z["net_8_weight"]]
|
||||
self.b = [z["net_0_bias"], z["net_3_bias"], z["net_6_bias"], z["net_8_bias"]]
|
||||
def __call__(self, x):
|
||||
h = (x - self.mu) / self.sd
|
||||
for i in range(3):
|
||||
h = np.maximum(0.0, h @ self.W[i].T + self.b[i]) # Linear+ReLU
|
||||
out = 1.0/(1.0+np.exp(-(h @ self.W[3].T + self.b[3]))) # Linear+Sigmoid -> 34
|
||||
return out.reshape(17,2)
|
||||
|
||||
CLIENTS = set()
|
||||
LATEST = {"pose": None}
|
||||
|
||||
async def serve_client(ws):
|
||||
CLIENTS.add(ws)
|
||||
try:
|
||||
await ws.send(json.dumps({"type":"meta","edges":EDGES}))
|
||||
async for _ in ws: # client is read-only; just keep alive
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
CLIENTS.discard(ws)
|
||||
|
||||
async def infer_loop(model, in_url):
|
||||
while True:
|
||||
try:
|
||||
async with websockets.connect(in_url, open_timeout=8, ping_interval=20) as ws:
|
||||
async for msg in ws:
|
||||
d = json.loads(msg)
|
||||
kp = model(csi_vector(d))
|
||||
cls = d.get("classification",{})
|
||||
payload = {"type":"pose","src":d.get("source"),
|
||||
"presence":bool(cls.get("presence")),
|
||||
"motion":(d.get("features",{}) or {}).get("motion_band_power"),
|
||||
"kps":[[round(float(x),4),round(float(y),4)] for x,y in kp],
|
||||
"nodes":sorted(n.get("node_id") for n in d.get("nodes",[]) if n.get("node_id") is not None)}
|
||||
LATEST["pose"]=payload
|
||||
if CLIENTS:
|
||||
dead=[]
|
||||
for c in list(CLIENTS):
|
||||
try: await c.send(json.dumps(payload))
|
||||
except Exception: dead.append(c)
|
||||
for c in dead: CLIENTS.discard(c)
|
||||
except Exception as e:
|
||||
print(f"[infer] reconnect ({e})", flush=True); await asyncio.sleep(1.0)
|
||||
|
||||
async def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__),"model","model.npz"))
|
||||
ap.add_argument("--in", dest="in_url", default="ws://localhost:8765/ws/sensing")
|
||||
ap.add_argument("--port", type=int, default=8770)
|
||||
args = ap.parse_args()
|
||||
model = Model(args.model)
|
||||
print(f"[infer] model {args.model} loaded; serving predicted poses on ws://0.0.0.0:{args.port}/pose")
|
||||
async with websockets.serve(serve_client, "0.0.0.0", args.port):
|
||||
await infer_loop(model, args.in_url)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train a CSI->pose model on the camera-supervised dataset (ADR-079/180).
|
||||
|
||||
Input : 410-d CSI vector (4 global feats + 6 per-node + 400 signal-field).
|
||||
Target : 17 COCO keypoints (x,y), normalized 0..1 from the camera (ground truth).
|
||||
Reports HONEST held-out PCK@k + MPJPE on a chronological val split (the last
|
||||
20% of the session — never trained on), so the number is not leaked.
|
||||
|
||||
Usage (ruvultra venv):
|
||||
python wiflow_train.py --data ~/wiflow-room/dataset.jsonl --out ~/wiflow-room/model.pt
|
||||
"""
|
||||
import argparse, json, math, os, sys
|
||||
import numpy as np
|
||||
import torch, torch.nn as nn
|
||||
|
||||
|
||||
def load(path):
|
||||
X, Y, V = [], [], []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
r = json.loads(line)
|
||||
X.append(r["csi"]) # 410
|
||||
kp = r["kps"] # 17 x [x,y,vis]
|
||||
Y.append([c for k in kp for c in (k[0], k[1])]) # 34
|
||||
V.append([k[2] for k in kp]) # 17 visibilities
|
||||
return np.array(X, np.float32), np.array(Y, np.float32), np.array(V, np.float32)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, din, dout):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(din, 512), nn.ReLU(), nn.Dropout(0.3),
|
||||
nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3),
|
||||
nn.Linear(256, 128), nn.ReLU(),
|
||||
nn.Linear(128, dout), nn.Sigmoid()) # coords in 0..1
|
||||
def forward(self, x): return self.net(x)
|
||||
|
||||
|
||||
def pck(pred, gt, vis, thr):
|
||||
# pred/gt: [N,34] -> [N,17,2]; PCK@thr in normalized image units, visible kps only
|
||||
p = pred.reshape(-1, 17, 2); g = gt.reshape(-1, 17, 2)
|
||||
d = np.linalg.norm(p - g, axis=2) # [N,17]
|
||||
m = vis > 0.5
|
||||
return float((d[m] < thr).mean()) if m.any() else 0.0, float(d[m].mean()) if m.any() else float("nan")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--data", required=True)
|
||||
ap.add_argument("--out", default=os.path.expanduser("~/wiflow-room/model.pt"))
|
||||
ap.add_argument("--epochs", type=int, default=300)
|
||||
ap.add_argument("--bs", type=int, default=64)
|
||||
args = ap.parse_args()
|
||||
|
||||
X, Y, V = load(args.data)
|
||||
n = len(X)
|
||||
print(f"[train] {n} samples, X={X.shape} Y={Y.shape}")
|
||||
if n < 200:
|
||||
print("[train] too few samples"); sys.exit(2)
|
||||
|
||||
# chronological split (NOT shuffled) so val is a held-out time segment -> honest
|
||||
cut = int(n * 0.8)
|
||||
mu, sd = X[:cut].mean(0), X[:cut].std(0) + 1e-6 # standardize on train only
|
||||
Xn = (X - mu) / sd
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
Xtr = torch.tensor(Xn[:cut]).to(dev); Ytr = torch.tensor(Y[:cut]).to(dev)
|
||||
Xva = torch.tensor(Xn[cut:]).to(dev); Yva = Y[cut:]; Vva = V[cut:]
|
||||
|
||||
# mean-pose baseline (predict the train-mean pose for everything) — the bar to beat
|
||||
mean_pose = Y[:cut].mean(0)
|
||||
base_pck, base_mpjpe = pck(np.tile(mean_pose, (len(Yva), 1)), Yva, Vva, 0.10)
|
||||
|
||||
net = Net(X.shape[1], Y.shape[1]).to(dev)
|
||||
opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
|
||||
lossf = nn.MSELoss()
|
||||
best = (1e9, None)
|
||||
for ep in range(args.epochs):
|
||||
net.train(); perm = torch.randperm(len(Xtr), device=dev)
|
||||
for i in range(0, len(Xtr), args.bs):
|
||||
idx = perm[i:i+args.bs]
|
||||
opt.zero_grad(); out = net(Xtr[idx]); loss = lossf(out, Ytr[idx]); loss.backward(); opt.step()
|
||||
if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
|
||||
net.eval()
|
||||
with torch.no_grad(): pv = net(Xva).cpu().numpy()
|
||||
p10, mpj = pck(pv, Yva, Vva, 0.10); p05, _ = pck(pv, Yva, Vva, 0.05)
|
||||
vloss = float(((pv - Yva) ** 2).mean())
|
||||
print(f"[train] ep{ep+1:3d} val_mse={vloss:.4f} PCK@0.10={p10*100:.1f}% PCK@0.05={p05*100:.1f}% MPJPE={mpj:.4f}")
|
||||
if vloss < best[0]: best = (vloss, {"sd": net.state_dict(), "p10": p10, "p05": p05, "mpj": mpj})
|
||||
|
||||
torch.save({"model": best[1]["sd"], "mu": mu, "sd": sd, "din": X.shape[1]}, args.out)
|
||||
print("\n==================== HONEST RESULT (held-out 20%, never trained) ====================")
|
||||
print(f" MEAN-POSE BASELINE : PCK@0.10 = {base_pck*100:.1f}% MPJPE = {base_mpjpe:.4f} (the bar to beat)")
|
||||
print(f" CSI->POSE MODEL : PCK@0.10 = {best[1]['p10']*100:.1f}% PCK@0.05 = {best[1]['p05']*100:.1f}% MPJPE = {best[1]['mpj']:.4f}")
|
||||
delta = (best[1]['p10'] - base_pck) * 100
|
||||
print(f" VERDICT: model {'BEATS' if delta>1 else 'does NOT beat'} mean-pose baseline by {delta:+.1f} pp "
|
||||
f"-> {'real CSI->pose signal' if delta>1 else 'NO usable CSI->pose signal (honest negative)'}")
|
||||
print(f" saved -> {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -468,3 +468,29 @@ menu "Mock CSI (QEMU Testing)"
|
||||
depends on CSI_MOCK_ENABLED
|
||||
default n
|
||||
endmenu
|
||||
|
||||
menu "Onboard LED (ADR-183)"
|
||||
|
||||
config LED_GAMMA_VIZ
|
||||
bool "Onboard WS2812: 40 Hz gamma flicker + CSI-motion colour"
|
||||
default y
|
||||
help
|
||||
Drive the onboard WS2812 as a GENUS-style 40 Hz gamma square wave
|
||||
(12.5 ms on / 12.5 ms off, 50% duty). The ON-phase colour is live
|
||||
CSI motion (edge motion_energy) mapped through the ruv-neural-viz
|
||||
viridis colormap (still=purple, moving=yellow).
|
||||
|
||||
Disable to leave the LED off at boot — lower power, no flicker.
|
||||
NOTE: a 40 Hz flicker can affect photosensitive users; disable or
|
||||
shield the LED in those environments. Not a medical device.
|
||||
|
||||
config LED_MOTION_FULLSCALE_MILLI
|
||||
int "Motion value (x1000) that saturates the colormap to yellow"
|
||||
depends on LED_GAMMA_VIZ
|
||||
default 250
|
||||
range 1 100000
|
||||
help
|
||||
edge motion_energy that maps to the top (yellow) of the viridis
|
||||
colormap, in milli-units (250 = 0.25). Lower = more sensitive
|
||||
(reaches yellow with less motion).
|
||||
endmenu
|
||||
|
||||
@@ -114,6 +114,19 @@ esp_err_t display_task_start(void)
|
||||
/* Init touch (optional) */
|
||||
esp_err_t touch_ret = display_hal_init_touch();
|
||||
|
||||
/* The SH8601 QSPI panel is write-only — display_hal_init_panel() above "succeeds"
|
||||
* even on a bare board with no panel attached, so it cannot detect absence. The
|
||||
* FT3168 touch controller is an I2C device with readback and is always present on
|
||||
* the Touch-AMOLED board. If touch is absent, the panel "success" was a false-
|
||||
* positive on a display-less DevKit: bail to headless so display_is_active() stays
|
||||
* false and CSI upgrades to MGMT+DATA capture instead of starving at MGMT-only
|
||||
* (RuView#1000). */
|
||||
if (touch_ret != ESP_OK) {
|
||||
ESP_LOGW(TAG, "No FT3168 touch readback — SH8601 probe was a false-positive on a "
|
||||
"display-less board; running headless so CSI captures (#1000)");
|
||||
return ESP_OK;
|
||||
}
|
||||
|
||||
/* Initialize LVGL */
|
||||
lv_init();
|
||||
|
||||
|
||||
@@ -144,6 +144,54 @@ static void wifi_init_sta(void)
|
||||
}
|
||||
}
|
||||
|
||||
#if CONFIG_LED_GAMMA_VIZ
|
||||
/* Viridis colormap (60 steps), generated from ruv-neural-viz::ColorMap::viridis()
|
||||
* — the rUv-Neural brain-topology colormap, now no_std (ruvnet/ruv-neural#3 /
|
||||
* RuView#1126). Used as the ON-phase colour of the 40 Hz gamma flicker below:
|
||||
* dark-purple (still) -> teal -> green -> yellow (strong motion). */
|
||||
static const uint8_t VIRIDIS_LUT[60][3] = {
|
||||
{ 68, 1, 84},{ 67, 6, 88},{ 67, 12, 91},{ 66, 17, 95},{ 66, 23, 99},
|
||||
{ 65, 28,103},{ 64, 34,106},{ 64, 39,110},{ 63, 45,114},{ 63, 50,118},
|
||||
{ 62, 56,121},{ 61, 61,125},{ 61, 67,129},{ 60, 72,132},{ 59, 78,136},
|
||||
{ 59, 83,139},{ 57, 87,139},{ 55, 92,139},{ 53, 96,139},{ 52,100,139},
|
||||
{ 50,104,139},{ 48,109,139},{ 46,113,139},{ 44,117,140},{ 43,122,140},
|
||||
{ 41,126,140},{ 39,130,140},{ 37,134,140},{ 36,139,140},{ 34,143,140},
|
||||
{ 35,147,139},{ 39,151,136},{ 43,154,133},{ 47,158,130},{ 52,162,127},
|
||||
{ 56,166,124},{ 60,170,121},{ 64,173,119},{ 68,177,116},{ 72,181,113},
|
||||
{ 76,185,110},{ 81,189,107},{ 85,192,104},{ 89,196,102},{ 93,200, 99},
|
||||
{102,203, 95},{113,205, 91},{124,207, 87},{134,209, 82},{145,211, 78},
|
||||
{156,213, 74},{167,215, 70},{178,217, 66},{188,219, 62},{199,221, 58},
|
||||
{210,223, 54},{221,225, 49},{231,227, 45},{242,229, 41},{253,231, 37},
|
||||
};
|
||||
static led_strip_handle_t s_viz_led;
|
||||
|
||||
/* motion_energy that saturates the colormap to yellow (CONFIG, milli-units). */
|
||||
#define LED_MOTION_FULLSCALE ((float)CONFIG_LED_MOTION_FULLSCALE_MILLI / 1000.0f)
|
||||
|
||||
/* GENUS-style 40 Hz gamma flicker: full on/off square wave, 50% duty (toggled
|
||||
* every 12.5 ms → 40 Hz). The ON colour is live CSI motion (edge motion_energy)
|
||||
* mapped through the ruv-neural-viz viridis LUT — still=purple, moving=yellow.
|
||||
* So the LED is a real 40 Hz gamma stimulus whose hue tracks sensed motion. */
|
||||
static void led_gamma_40hz_cb(void *arg)
|
||||
{
|
||||
static bool on = false;
|
||||
on = !on;
|
||||
if (on) {
|
||||
edge_vitals_pkt_t v;
|
||||
float m = edge_get_vitals(&v) ? v.motion_energy : 0.0f;
|
||||
float norm = m / LED_MOTION_FULLSCALE;
|
||||
if (norm < 0.0f) norm = 0.0f;
|
||||
if (norm > 1.0f) norm = 1.0f;
|
||||
int idx = (int)(norm * 59.0f + 0.5f);
|
||||
const uint8_t *c = VIRIDIS_LUT[idx];
|
||||
led_strip_set_pixel(s_viz_led, 0, c[0], c[1], c[2]); /* R,G,B (driver maps to GRB) */
|
||||
} else {
|
||||
led_strip_set_pixel(s_viz_led, 0, 0, 0, 0); /* off phase */
|
||||
}
|
||||
led_strip_refresh(s_viz_led);
|
||||
}
|
||||
#endif /* CONFIG_LED_GAMMA_VIZ */
|
||||
|
||||
void app_main(void)
|
||||
{
|
||||
/* Initialize NVS */
|
||||
@@ -173,15 +221,16 @@ void app_main(void)
|
||||
ESP_LOGI(TAG, "%s CSI Node (ADR-018 / ADR-110) — v%s — Node ID: %d",
|
||||
target_name, app_desc->version, g_nvs_config.node_id);
|
||||
|
||||
/* Turn off onboard WS2812 LED.
|
||||
* S3 dev boards put the LED on GPIO 38; C6 dev boards on GPIO 8.
|
||||
* On C6, GPIO 38 doesn't exist (only 0-30) — gate the init by target. */
|
||||
/* Onboard WS2812. C6 wires the LED to GPIO 8; S3 to GPIO 38 (DevKitC-1 v1.0)
|
||||
* or GPIO 48 (DevKitC-1 v1.1 / N16R8 — see #962). On S3 we drive 48 (the
|
||||
* common module). On C6, GPIO 38/48 don't exist (only 0-30) — gate by target.
|
||||
* Behaviour is set by CONFIG_LED_GAMMA_VIZ (ADR-183): on = 40 Hz gamma flicker
|
||||
* coloured by CSI motion; off = clear the LED at boot. */
|
||||
#if defined(CONFIG_IDF_TARGET_ESP32C6)
|
||||
const int led_gpio = 8;
|
||||
#else
|
||||
const int led_gpio = 38;
|
||||
const int led_gpio = 48;
|
||||
#endif
|
||||
led_strip_handle_t led_strip;
|
||||
led_strip_config_t strip_config = {
|
||||
.strip_gpio_num = led_gpio,
|
||||
.max_leds = 1,
|
||||
@@ -193,9 +242,26 @@ void app_main(void)
|
||||
.resolution_hz = 10 * 1000 * 1000, // 10MHz
|
||||
.flags.with_dma = false,
|
||||
};
|
||||
#if CONFIG_LED_GAMMA_VIZ
|
||||
if (led_strip_new_rmt_device(&strip_config, &rmt_config, &s_viz_led) == ESP_OK) {
|
||||
const esp_timer_create_args_t viz_args = {
|
||||
.callback = &led_gamma_40hz_cb,
|
||||
.name = "led_gamma_40hz",
|
||||
};
|
||||
esp_timer_handle_t viz_timer;
|
||||
if (esp_timer_create(&viz_args, &viz_timer) == ESP_OK) {
|
||||
esp_timer_start_periodic(viz_timer, 12500); // 12.5 ms toggle → 40 Hz square wave
|
||||
ESP_LOGI(TAG, "Onboard WS2812: 40 Hz gamma flicker (GENUS), colour=CSI motion via ruv-neural-viz, GPIO %d", led_gpio);
|
||||
}
|
||||
}
|
||||
#else
|
||||
/* Viz disabled — clear the onboard LED at boot and release the RMT channel. */
|
||||
led_strip_handle_t led_strip;
|
||||
if (led_strip_new_rmt_device(&strip_config, &rmt_config, &led_strip) == ESP_OK) {
|
||||
led_strip_clear(led_strip);
|
||||
led_strip_del(led_strip);
|
||||
}
|
||||
#endif /* CONFIG_LED_GAMMA_VIZ */
|
||||
|
||||
/* ADR-110 P4: 802.15.4 mesh time-sync (C6 only).
|
||||
* Initialized BEFORE WiFi so it's available even when WiFi STA can't
|
||||
|
||||
@@ -387,11 +387,21 @@ static mmwave_type_t probe_at_baud(uint32_t baud)
|
||||
if (len <= 0) continue;
|
||||
|
||||
for (int i = 0; i < len; i++) {
|
||||
/* MR60BHA2: SOF = 0x01, followed by valid-looking frame_id bytes */
|
||||
if (buf[i] == MR60_SOF && baud == MMWAVE_MR60_BAUD) {
|
||||
mr60_sof_seen++;
|
||||
/* MR60BHA2: require a *validated* 8-byte header — SOF (0x01) + a valid
|
||||
* header checksum (over bytes 0..6) + a known frame type (0x0A__ or
|
||||
* 0x0F09) — NOT a bare 0x01 byte. A floating UART1 with no sensor reads
|
||||
* noise full of 0x01s, which the old `buf[i] == MR60_SOF` check mistook
|
||||
* for a real sensor (false "Detected MR60BHA2", #1107). */
|
||||
if (buf[i] == MR60_SOF && baud == MMWAVE_MR60_BAUD && i + 7 < len) {
|
||||
const uint8_t *h = &buf[i];
|
||||
if (mr60_calc_checksum(h, 7) == h[7]) {
|
||||
uint16_t type = ((uint16_t)h[5] << 8) | h[6];
|
||||
if ((type >> 8) == 0x0A || type == 0x0F09) {
|
||||
mr60_sof_seen++;
|
||||
}
|
||||
}
|
||||
}
|
||||
/* LD2410: 4-byte header 0xF4F3F2F1 */
|
||||
/* LD2410: 4-byte header 0xF4F3F2F1 (already specific enough). */
|
||||
if (i + 3 < len && buf[i] == 0xF4 && buf[i+1] == 0xF3
|
||||
&& buf[i+2] == 0xF2 && buf[i+3] == 0xF1
|
||||
&& baud == MMWAVE_LD2410_BAUD) {
|
||||
@@ -403,9 +413,8 @@ static mmwave_type_t probe_at_baud(uint32_t baud)
|
||||
if (ld2410_header_seen >= 2) return MMWAVE_TYPE_LD2410;
|
||||
}
|
||||
|
||||
if (mr60_sof_seen > 0) return MMWAVE_TYPE_MR60BHA2;
|
||||
if (ld2410_header_seen > 0) return MMWAVE_TYPE_LD2410;
|
||||
|
||||
/* No weak single-hit fallback: line noise can produce a stray match, so a real
|
||||
* sensor must clear the ≥3 (MR60) / ≥2 (LD2410) validated-frame thresholds. */
|
||||
return MMWAVE_TYPE_NONE;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(npx ruview*)",
|
||||
"mcp__ruview__*"
|
||||
],
|
||||
"deny": [
|
||||
"Read(./.env)",
|
||||
"Read(./.env.*)"
|
||||
]
|
||||
},
|
||||
"mcpServers": {
|
||||
"ruview": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@ruvnet/ruview", "mcp", "start"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
---
|
||||
name: calibrate-room
|
||||
description: Run the ADR-151 per-room calibration pipeline — baseline → enroll → extract → train → a bank of small specialists (presence/posture/breathing/heartbeat/restlessness/anomaly).
|
||||
---
|
||||
|
||||
# calibrate-room
|
||||
|
||||
Turn a provisioned node + sensing-server into a working room model. Pure-Rust,
|
||||
edge-deployable (ADR-151). Use the `ruview.calibrate` tool (installed
|
||||
`wifi-densepose` binary, else `cargo run -p wifi-densepose-cli`).
|
||||
|
||||
## Sequence
|
||||
|
||||
1. **baseline** — capture the empty room (Welford amplitude + von Mises phase). Leave
|
||||
the room empty.
|
||||
`ruview.calibrate {step: "baseline"}`
|
||||
2. **enroll** — record the occupant(s) doing the target activities.
|
||||
`ruview.calibrate {step: "enroll"}`
|
||||
3. **train-room** — train the bank of small specialists from baseline + enrollment.
|
||||
`ruview.calibrate {step: "train-room"}`
|
||||
4. **room-watch** — live presence/posture/breathing from the trained room.
|
||||
`ruview.calibrate {step: "room-watch"}` (or the `room-watch` skill)
|
||||
|
||||
## Honesty
|
||||
|
||||
The specialists are calibrated to *this* room; cross-room transfer is a separate
|
||||
problem (LoRA recalibration, ADR-079 P9). Report which room a number came from, and
|
||||
tag presence/vitals accuracy MEASURED only with a held-out check — run
|
||||
`ruview.claim_check` on the writeup.
|
||||
@@ -0,0 +1,30 @@
|
||||
---
|
||||
name: onboard
|
||||
description: Zero-to-sensing path picker for RuView (WiFi-DensePose) — pick docker-demo, repo-build, or live-esp32 and run the next concrete step.
|
||||
---
|
||||
|
||||
# onboard
|
||||
|
||||
Get a newcomer from nothing to a working RuView setup. **First fact to set:** WiFi
|
||||
sensing infers *coarse* pose/presence/breathing from Channel State Information — it
|
||||
is **not a camera**, and any accuracy number must be MEASURED against a baseline
|
||||
(use the `verify` skill / `ruview.claim_check` tool). Never present WiFi output as
|
||||
camera-grade.
|
||||
|
||||
## Pick a path
|
||||
|
||||
Run `ruview.onboard {path}` or decide from:
|
||||
|
||||
1. **docker-demo** — fastest, no hardware. Replays sample CSI into the dashboard.
|
||||
`docker run -p 8000:8000 ruvnet/wifi-densepose` → open `http://localhost:8000`.
|
||||
Use to see what it looks like.
|
||||
2. **repo-build** — for developers. `cd v2 && cargo test --workspace --no-default-features`
|
||||
(1,031+ tests pass), then `cargo run -p wifi-densepose-cli -- --help`.
|
||||
3. **live-esp32** — a real install. Flash a node (`provision-node` skill), point it at
|
||||
the sensing-server, then `calibrate-room`. This is the only path that senses a real room.
|
||||
|
||||
## Then
|
||||
|
||||
- Live sensing → go to **provision-node**, then **calibrate-room**.
|
||||
- Evaluating a model/claim → go to **verify** and run `ruview.claim_check` on any
|
||||
report before you quote a number.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
name: provision-node
|
||||
description: Build, flash, and provision an ESP32-S3/C6 CSI node for RuView — firmware variant choice, ESP-IDF Windows-subprocess flow, NVS/WiFi/channel/MAC-filter overrides.
|
||||
---
|
||||
|
||||
# provision-node
|
||||
|
||||
Bring an ESP32 sensing node online.
|
||||
|
||||
## 1. Pick a firmware variant
|
||||
|
||||
- **s3-8mb** (display build) — ESP32-S3 N16R8 / 16MB; AMOLED optional. The display-detect
|
||||
fix (#1000) means a *bare* board still captures CSI (MGMT+DATA).
|
||||
- **s3-4mb** (no-display) — ESP32-S3 4MB; dual-OTA, display disabled.
|
||||
- **c6** — ESP32-C6 + Seeed MR60BHA2 (60 GHz mmWave + WiFi CSI). The mmwave probe
|
||||
requires a validated MR60 header (#1107) so an empty UART never false-detects.
|
||||
|
||||
Prebuilt binaries: GitHub release `v0.8.1-esp32` (hardware-validated on S3 QFN56 rev v0.2).
|
||||
|
||||
## 2. Flash
|
||||
|
||||
ESP-IDF v5.4 on Windows is **subprocess-only** (Git Bash/MSYS is unsupported — strip
|
||||
`MSYSTEM*` env vars). Offsets for the S3 image:
|
||||
|
||||
```
|
||||
esptool --chip esp32s3 -p <PORT> -b 460800 write_flash \
|
||||
0x0 bootloader.bin 0x8000 partition-table.bin \
|
||||
0xf000 ota_data_initial.bin 0x20000 esp32-csi-node-s3-8mb.bin
|
||||
```
|
||||
|
||||
(`ruview.node_flash` returns the exact pinned command rather than running an
|
||||
unattended flash.)
|
||||
|
||||
## 3. Provision
|
||||
|
||||
```
|
||||
python firmware/esp32-csi-node/provision.py --port <PORT> \
|
||||
--ssid "<SSID>" --password "<secret>" --target-ip <server-ip> --target-port 5005
|
||||
# optional ADR-060 overrides:
|
||||
python firmware/esp32-csi-node/provision.py --port <PORT> --channel 6 --filter-mac AA:BB:CC:DD:EE:FF
|
||||
```
|
||||
|
||||
Never echo or commit the WiFi password.
|
||||
|
||||
## 4. Confirm CSI is flowing
|
||||
|
||||
`ruview.node_monitor {port}` — PASS criteria: serial shows `CSI cb #...` callbacks and
|
||||
(on a bare board) `CSI filter upgraded to MGMT+DATA`. No callbacks → the node isn't
|
||||
capturing; do not proceed to calibration.
|
||||
@@ -0,0 +1,33 @@
|
||||
---
|
||||
name: train-pose
|
||||
description: Train/evaluate WiFi pose models honestly — camera-supervised (MediaPipe + CSI) and camera-free (WiFlow), always checked against the mean-pose baseline before any PCK is quoted.
|
||||
---
|
||||
|
||||
# train-pose
|
||||
|
||||
Build a CSI→pose model without overstating it. The project has a **retracted 92.9%/100%**
|
||||
history — the discipline below exists so it never recurs.
|
||||
|
||||
## The non-negotiable: mean-pose baseline first
|
||||
|
||||
A pose model that always predicts the dataset's *mean pose* already scores ~50% PCK.
|
||||
**Quote PCK only as a delta over that baseline**, on a held-out split with no subject
|
||||
or temporal leakage. Example honest result (ADR-181):
|
||||
|
||||
> Held-out PCK@20 **59.5%** vs a 50% mean-pose baseline = **+9.4 pp real signal** — MEASURED.
|
||||
|
||||
## Paths
|
||||
|
||||
- **camera-supervised** (ADR-079) — MediaPipe Pose labels the camera frame; paired CSI
|
||||
trains the net. Train/infer in one camera frame so the skeleton aligns.
|
||||
- **camera-free** (WiFlow, ADR-152) — no camera at inference; geometry-conditioned.
|
||||
- **in-browser** (ADR-181) — WebGPU/WASM trainer; the active backend is shown as a badge
|
||||
(honest about what's executing).
|
||||
|
||||
## Before you publish a number
|
||||
|
||||
1. Run the mean-pose baseline on the same split.
|
||||
2. Report `(model − baseline)` in pp, with the split definition (chronological /
|
||||
blocked-gap / grouped-bucket; no leakage).
|
||||
3. `ruview.claim_check` the writeup — it flags any untagged or 100%/perfect claim.
|
||||
4. If it's a benchmark vs SOTA, tag MEASURED-EQUIVALENT only with the reproducer.
|
||||
@@ -0,0 +1,42 @@
|
||||
---
|
||||
name: verify
|
||||
description: Prove a RuView result is real — run the deterministic SHA-256 proof and the witness bundle (ADR-028), and lint any claim for MEASURED-vs-CLAIMED honesty.
|
||||
---
|
||||
|
||||
# verify
|
||||
|
||||
The "prove everything" skill. Nothing ships as validated without this.
|
||||
|
||||
## Deterministic proof (Trust Kill Switch)
|
||||
|
||||
`ruview.verify` runs `archive/v1/data/proof/verify.py`: it feeds a reference signal
|
||||
through the production pipeline and hashes the output against
|
||||
`expected_features.sha256`. Must print **VERDICT: PASS**. If numpy/scipy changed the
|
||||
hash, regenerate with `verify.py --generate-hash` then re-verify.
|
||||
|
||||
## Witness bundle (ADR-028)
|
||||
|
||||
For a release-grade attestation:
|
||||
|
||||
```
|
||||
bash scripts/generate-witness-bundle.sh
|
||||
cd dist/witness-bundle-ADR028-*/ && bash VERIFY.sh # must be 7/7 PASS
|
||||
```
|
||||
|
||||
Contains the Rust test log, the proof + expected hash, firmware SHA-256 manifest, and
|
||||
crate versions — a recipient can re-verify with one command.
|
||||
|
||||
## Claim honesty
|
||||
|
||||
Run `ruview.claim_check {text}` on any report, README section, PR body, or model card
|
||||
before quoting accuracy. It flags:
|
||||
- untagged accuracy numbers (must be MEASURED / CLAIMED / SYNTHETIC),
|
||||
- MEASURED claims with no reproducer cited,
|
||||
- the retracted "100%/perfect accuracy" framing.
|
||||
|
||||
## Firmware-specific
|
||||
|
||||
A firmware fix is **not** "hardware-validated" without a captured boot log on real
|
||||
silicon (e.g. the `v0.8.1-esp32` rev-v0.2 validation: `running headless so CSI
|
||||
captures (#1000)` + `CSI filter upgraded to MGMT+DATA` + a no-false-detect mmwave
|
||||
probe). Do not merge or release on a build-passes signal alone.
|
||||
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"schema": 1,
|
||||
"generator": "metaharness 0.1.15 + ADR-182 hardening",
|
||||
"template": "vertical:ruview",
|
||||
"name": "@ruvnet/ruview",
|
||||
"vars": {
|
||||
"name": "@ruvnet/ruview",
|
||||
"description": "RuView WiFi-sensing operator agent harness",
|
||||
"host": "claude-code"
|
||||
},
|
||||
"hosts": [
|
||||
"claude-code"
|
||||
],
|
||||
"files": {
|
||||
".claude/settings.json": "b0ea971383716f18b89db73010b8f0ea0f1b16bdec4cd1068245772ba1c27bdd",
|
||||
".claude/skills/calibrate-room/SKILL.md": "6a6c8211a7109feb76620c618963c10ad9a9f633ffce7676e631a80a1181986d",
|
||||
".claude/skills/onboard/SKILL.md": "22323732fe746b38b77a7c8c052e952dff2fe87ae939ba125379125827385f21",
|
||||
".claude/skills/provision-node/SKILL.md": "5ffe5a75873e873b80758d9c81005774d4191317227f2e9aa4345cbce3f29751",
|
||||
".claude/skills/train-pose/SKILL.md": "b3ee95bfb0b678eb3d101138b9ea0e7cab3db3a9906d19c4059f9cca0598e87b",
|
||||
".claude/skills/verify/SKILL.md": "c0314d5ead465d9089b6a4917fd125051a5be20dc07ba92d5b601fcaada32e19",
|
||||
"CLAUDE.md": "7ecdb2b9d9abcf4aa22dd3ce553b60216a135e147893a59fa944fc1a8c81f5ef",
|
||||
"LICENSE": "631f94984f626818d42ecf717aa6e8e0afd4f9f355ca706bd2effafbd1416d06",
|
||||
"README.md": "b77d30428de8efb6758f2ca3eb22e84849013b2c0e6c601d488d2ea5a6f0da44",
|
||||
"bin/cli.js": "b0d74690cff4329dfe342271fc475eaa140b767bdb66b37cf4992ad209012fe8",
|
||||
"package.json": "2af49561ef0d59cafc4b99885816e580635b2d2ad329dfe17c69b9df6f8afceb",
|
||||
"skills/calibrate-room.md": "6a6c8211a7109feb76620c618963c10ad9a9f633ffce7676e631a80a1181986d",
|
||||
"skills/onboard.md": "22323732fe746b38b77a7c8c052e952dff2fe87ae939ba125379125827385f21",
|
||||
"skills/provision-node.md": "5ffe5a75873e873b80758d9c81005774d4191317227f2e9aa4345cbce3f29751",
|
||||
"skills/train-pose.md": "b3ee95bfb0b678eb3d101138b9ea0e7cab3db3a9906d19c4059f9cca0598e87b",
|
||||
"skills/verify.md": "c0314d5ead465d9089b6a4917fd125051a5be20dc07ba92d5b601fcaada32e19",
|
||||
"src/guardrails.js": "1631cea02c4354fe6126c576300faf5f8b68ae2f5e2e3a658c99eb25a7403e55",
|
||||
"src/mcp-server.js": "e51379f5ebb0b7b4670c7412714e559931ef1be8df20551f8f7309b53f0fb7af",
|
||||
"src/tools.js": "b558f61bb202abf5a967ce3a6ccaea351f2d186238cf49c7fc151d1de028eee8"
|
||||
},
|
||||
"meta": {
|
||||
"surface": "cli+mcp",
|
||||
"adr": "ADR-182"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
6c6c1431c37472494c9b309c8b5d761dd4fc41e30313baead6320831fb982e57 manifest.json
|
||||
@@ -0,0 +1,34 @@
|
||||
# RuView harness — agent operating notes
|
||||
|
||||
You are operating **RuView** (WiFi-DensePose), a camera-free WiFi-CSI sensing system.
|
||||
|
||||
## The one rule: prove everything
|
||||
|
||||
This project was accused of AI-slop; the fix is hard discipline. Before you quote ANY
|
||||
accuracy number:
|
||||
|
||||
1. It must be tagged **MEASURED** (with a reproducer named), **CLAIMED**, or **SYNTHETIC**.
|
||||
2. Pose PCK is quoted only as a **delta over the mean-pose baseline** on a leakage-free
|
||||
held-out split. (A mean-pose predictor already scores ~50% PCK.)
|
||||
3. Run `ruview.claim_check` on any report/PR/model-card. It flags untagged numbers and
|
||||
the retracted "100%/perfect accuracy" framing.
|
||||
4. Firmware is "hardware-validated" only with a captured **boot log on real silicon** —
|
||||
never on a build-passes signal.
|
||||
|
||||
## Tools
|
||||
|
||||
`ruview.onboard`, `ruview.claim_check`, `ruview.verify`, `ruview.node_monitor`,
|
||||
`ruview.calibrate`, `ruview.node_flash`. All fail-closed. Mutating/hardware tools
|
||||
(`node_flash`) require explicit confirmation and are Windows/ESP-IDF gated.
|
||||
|
||||
## Skills
|
||||
|
||||
`onboard` · `provision-node` · `calibrate-room` · `train-pose` · `verify`
|
||||
(`npx @ruvnet/ruview skill <name>`).
|
||||
|
||||
## Don'ts
|
||||
|
||||
- Don't present WiFi sensing as camera-grade.
|
||||
- Don't echo or commit WiFi passwords / secrets.
|
||||
- Don't merge or release firmware without a real boot log.
|
||||
- Don't report a PCK without its mean-pose baseline.
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 ruvnet
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,60 @@
|
||||
# `npx @ruvnet/ruview` — RuView WiFi-sensing operator harness
|
||||
|
||||
An AI agent harness that knows how to operate **RuView** (WiFi-DensePose): onboard a
|
||||
newcomer, provision an ESP32 CSI node, calibrate a room, train pose models, and —
|
||||
crucially — **refuse to overstate accuracy**. Minted from the RuView monorepo via
|
||||
[`metaharness`](https://www.npmjs.com/package/metaharness) and hardened per **ADR-182**.
|
||||
|
||||
WiFi sensing infers *coarse* pose/presence/breathing from Channel State Information.
|
||||
It is **not a camera**. Every accuracy number this harness emits must be MEASURED
|
||||
against a baseline — that rule is enforced in code (`ruview.claim_check`).
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
npx @ruvnet/ruview # onboard — pick a setup path
|
||||
npx @ruvnet/ruview claim-check --text "we hit 100% accuracy" # the honesty guardrail
|
||||
npx @ruvnet/ruview verify # run the deterministic proof (VERDICT: PASS)
|
||||
npx @ruvnet/ruview doctor # self-check (tools + optional kernel/host)
|
||||
npx @ruvnet/ruview --help
|
||||
```
|
||||
|
||||
The operator tools are pure Node and run with **zero install weight**. The
|
||||
`@metaharness/kernel` + host adapter are `optionalDependencies` — only `doctor` /
|
||||
`install` use them, only if present.
|
||||
|
||||
## Tools (`ruview.*`)
|
||||
|
||||
Exposed both as CLI verbs and as an MCP server (`npx @ruvnet/ruview mcp start`):
|
||||
|
||||
| Tool | What it does |
|
||||
|------|--------------|
|
||||
| `ruview.onboard` | Pick docker-demo / repo-build / live-esp32; print the next command |
|
||||
| `ruview.claim_check` | Lint text for untagged / overstated accuracy claims (guardrail) |
|
||||
| `ruview.verify` | Run `verify.py` deterministic proof → VERDICT |
|
||||
| `ruview.node_monitor` | Assert CSI is flowing on an ESP32 (read-only) |
|
||||
| `ruview.calibrate` | ADR-151 room pipeline (baseline→enroll→train-room→room-watch) |
|
||||
| `ruview.node_flash` | Build+flash firmware (Windows/ESP-IDF; mutating, guarded) |
|
||||
|
||||
Every tool is **fail-closed**: missing repo / python / binary / port → an honest
|
||||
negative, never a fabricated success.
|
||||
|
||||
## Skills
|
||||
|
||||
Host-neutral playbooks in `skills/` (`onboard`, `provision-node`, `calibrate-room`,
|
||||
`train-pose`, `verify`). `npx @ruvnet/ruview skill <name>` prints one.
|
||||
|
||||
## Use as a Claude Code MCP server
|
||||
|
||||
The bundled `.claude/settings.json` registers the `ruview` MCP server
|
||||
(`npx -y @ruvnet/ruview mcp start`). Drop this package's `.claude/` into a repo, or run
|
||||
`npx @ruvnet/ruview install --host claude-code`.
|
||||
|
||||
## Hosts
|
||||
|
||||
claude-code (bundled), and via metaharness host adapters: codex, opencode, copilot,
|
||||
pi-dev, hermes, rvm, github-actions.
|
||||
|
||||
## License
|
||||
|
||||
MIT © ruvnet
|
||||
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env node
|
||||
// SPDX-License-Identifier: MIT
|
||||
// `npx ruview` — the RuView WiFi-sensing operator harness (minted via metaharness,
|
||||
// hardened per ADR-182). Plain ESM, no build step: ships and runs as-is.
|
||||
//
|
||||
// The `ruview.*` tools (onboard/verify/claim-check/…) are PURE Node and run with
|
||||
// zero deps. The kernel + host adapter are only touched by `doctor`/`install`
|
||||
// (the harness-into-a-repo story), so the operator tools never block on a wasm load.
|
||||
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import { realpathSync, existsSync, readdirSync, readFileSync } from 'node:fs';
|
||||
import { join, dirname } from 'node:path';
|
||||
import { argv } from 'node:process';
|
||||
import { TOOLS, runTool, listTools } from '../src/tools.js';
|
||||
import { claimCheck, summarize } from '../src/guardrails.js';
|
||||
|
||||
const NAME = 'ruview';
|
||||
const ROOT = dirname(dirname(fileURLToPath(import.meta.url)));
|
||||
const SKILLS_DIR = join(ROOT, 'skills');
|
||||
|
||||
// Map friendly CLI verbs → registry tool names.
|
||||
const VERB_TO_TOOL = {
|
||||
onboard: 'ruview.onboard',
|
||||
verify: 'ruview.verify',
|
||||
'claim-check': 'ruview.claim_check',
|
||||
calibrate: 'ruview.calibrate',
|
||||
monitor: 'ruview.node_monitor',
|
||||
flash: 'ruview.node_flash',
|
||||
};
|
||||
|
||||
function pjson(o) { console.log(JSON.stringify(o, null, 2)); }
|
||||
|
||||
function listSkills() {
|
||||
if (!existsSync(SKILLS_DIR)) return [];
|
||||
return readdirSync(SKILLS_DIR).filter((f) => f.endsWith('.md')).map((f) => f.replace(/\.md$/, ''));
|
||||
}
|
||||
|
||||
async function doctor() {
|
||||
const checks = [];
|
||||
// Tools layer (always available, no deps).
|
||||
checks.push(['tool registry loads', Object.keys(TOOLS).length > 0]);
|
||||
checks.push(['claim_check flags a 100% claim',
|
||||
!claimCheck('We hit 100% accuracy on poses.').ok]);
|
||||
checks.push(['claim_check passes a tagged MEASURED claim',
|
||||
claimCheck('Held-out PCK@20 59.5% (MEASURED vs mean-pose baseline, verify.py).').ok]);
|
||||
checks.push(['skills present', listSkills().length > 0]);
|
||||
// Kernel + host adapter (optional — only needed to install into a repo).
|
||||
let kernelLine = 'kernel/host: not installed (ok — operator tools run without them)';
|
||||
try {
|
||||
const { loadKernel } = await import('@metaharness/kernel');
|
||||
const adapter = (await import('@metaharness/host-claude-code')).default;
|
||||
const k = await loadKernel();
|
||||
const info = k.kernelInfo();
|
||||
checks.push(['kernel loads + reports version', typeof info.version === 'string' && info.version.length > 0]);
|
||||
checks.push(['kernel backend is native|wasm|js', ['native', 'wasm', 'js'].includes(k.backend)]);
|
||||
checks.push(['host adapter resolves', typeof adapter?.name === 'string']);
|
||||
kernelLine = `kernel ${info.version} (${k.backend}) · host ${adapter.name}`;
|
||||
} catch {
|
||||
/* kernel not installed — fine for the tools-only path */
|
||||
}
|
||||
let ok = true;
|
||||
for (const [label, pass] of checks) { console.log(`${pass ? 'PASS' : 'FAIL'} ${label}`); if (!pass) ok = false; }
|
||||
console.log(`\n${NAME}: ${ok ? 'all checks passed' : 'doctor found problems'} — ${kernelLine}`);
|
||||
return ok ? 0 : 1;
|
||||
}
|
||||
|
||||
function help() {
|
||||
console.log(`Usage: ${NAME} <command> [options]
|
||||
|
||||
Operator tools:
|
||||
onboard [--path docker-demo|repo-build|live-esp32] pick a setup path
|
||||
verify [--repo <dir>] run the deterministic proof (VERDICT: PASS)
|
||||
claim-check --text "..." | --file <path> lint accuracy claims (the honesty guardrail)
|
||||
calibrate --step baseline|enroll|train-room|room-watch
|
||||
monitor --port COM8 [--seconds 12] assert CSI is flowing on a node
|
||||
flash --port COM8 --variant s3-8mb [--confirm] build+flash firmware (Windows/ESP-IDF)
|
||||
|
||||
Harness:
|
||||
doctor verify the install (tools + optional kernel/host)
|
||||
skills list bundled skills
|
||||
skill <name> print a skill playbook
|
||||
mcp start run the ruview.* MCP server (stdio)
|
||||
install --host <h> project the harness config into the current repo
|
||||
--version | --help
|
||||
|
||||
Hosts: claude-code, codex, opencode, copilot, pi-dev, hermes, rvm, github-actions`);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/** tiny flag parser: --k v / --k=v / --flag (boolean) */
|
||||
function parseFlags(rest) {
|
||||
const f = {};
|
||||
for (let i = 0; i < rest.length; i++) {
|
||||
const a = rest[i];
|
||||
if (a.startsWith('--')) {
|
||||
const eq = a.indexOf('=');
|
||||
if (eq !== -1) { f[a.slice(2, eq)] = a.slice(eq + 1); }
|
||||
else if (i + 1 < rest.length && !rest[i + 1].startsWith('--')) { f[a.slice(2)] = rest[++i]; }
|
||||
else { f[a.slice(2)] = true; }
|
||||
}
|
||||
}
|
||||
return f;
|
||||
}
|
||||
|
||||
export async function run(args) {
|
||||
const cmd = args[0] ?? 'onboard';
|
||||
const rest = args.slice(1);
|
||||
const flags = parseFlags(rest);
|
||||
|
||||
// Direct tool verbs.
|
||||
if (VERB_TO_TOOL[cmd]) {
|
||||
const toolArgs = { ...flags };
|
||||
if (cmd === 'claim-check') {
|
||||
if (flags.file) toolArgs.text = readFileSync(flags.file, 'utf8');
|
||||
const res = runTool('ruview.claim_check', toolArgs);
|
||||
pjson(res);
|
||||
return res.ok ? 0 : 1;
|
||||
}
|
||||
if (cmd === 'monitor' && flags.seconds) toolArgs.seconds = Number(flags.seconds);
|
||||
if (cmd === 'calibrate' && typeof flags.args === 'string') toolArgs.args = flags.args.split(',');
|
||||
const res = runTool(VERB_TO_TOOL[cmd], toolArgs);
|
||||
pjson(res);
|
||||
return res.ok ? 0 : 1;
|
||||
}
|
||||
|
||||
switch (cmd) {
|
||||
case 'doctor': return doctor();
|
||||
case 'skills': console.log(listSkills().join('\n') || '(none)'); return 0;
|
||||
case 'skill': {
|
||||
const n = rest[0];
|
||||
const p = n && join(SKILLS_DIR, `${n}.md`);
|
||||
if (!p || !existsSync(p)) { console.error(`No skill "${n}". Try: ${listSkills().join(', ')}`); return 2; }
|
||||
console.log(readFileSync(p, 'utf8'));
|
||||
return 0;
|
||||
}
|
||||
case 'mcp': {
|
||||
if (rest[0] === 'start' || rest[0] === undefined) {
|
||||
const { startMcpServer } = await import('../src/mcp-server.js');
|
||||
startMcpServer();
|
||||
return new Promise(() => {}); // run until stdin closes
|
||||
}
|
||||
console.error('Usage: ruview mcp start'); return 2;
|
||||
}
|
||||
case 'install': {
|
||||
const host = flags.host || 'claude-code';
|
||||
try {
|
||||
const adapter = (await import('@metaharness/host-claude-code')).default;
|
||||
console.log(`Projecting RuView harness for host "${host}" via ${adapter.name}.`);
|
||||
console.log('Add to your host config — MCP server command: npx -y ruview mcp start');
|
||||
console.log('Skills:', listSkills().join(', '));
|
||||
return 0;
|
||||
} catch {
|
||||
console.error('Host adapter not installed. `npm i @metaharness/host-claude-code` or use the bundled .claude/ config.');
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
case 'tools': pjson(listTools()); return 0;
|
||||
case '--version': case '-v': {
|
||||
const pkg = JSON.parse(readFileSync(join(ROOT, 'package.json'), 'utf8'));
|
||||
console.log(pkg.version); return 0;
|
||||
}
|
||||
case '--help': case '-h': return help();
|
||||
default:
|
||||
console.error(`Unknown command: ${cmd}. Try \`${NAME} --help\`.`);
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
// CLI guard: run only when invoked directly (realpath both sides — npm/npx shims
|
||||
// pass a non-normalized, possibly case-skewed argv[1] on Windows).
|
||||
const invokedDirectly = (() => {
|
||||
if (!argv[1]) return false;
|
||||
try {
|
||||
const a = realpathSync(argv[1]);
|
||||
const b = realpathSync(fileURLToPath(import.meta.url));
|
||||
return process.platform === 'win32' ? a.toLowerCase() === b.toLowerCase() : a === b;
|
||||
} catch { return false; }
|
||||
})();
|
||||
if (invokedDirectly) {
|
||||
run(argv.slice(2)).then((code) => process.exit(code)).catch((err) => { console.error(err); process.exit(1); });
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
{
|
||||
"name": "@ruvnet/ruview",
|
||||
"version": "0.1.0",
|
||||
"description": "RuView WiFi-sensing operator agent harness — onboard, calibrate, train, and verify camera-free WiFi-CSI sensing, with the project's MEASURED-vs-CLAIMED honesty guardrail enforced. Minted via metaharness (ADR-182).",
|
||||
"type": "module",
|
||||
"bin": {
|
||||
"ruview": "bin/cli.js"
|
||||
},
|
||||
"exports": {
|
||||
".": "./src/tools.js",
|
||||
"./guardrails": "./src/guardrails.js"
|
||||
},
|
||||
"files": [
|
||||
"bin/",
|
||||
"src/",
|
||||
"skills/",
|
||||
".claude/",
|
||||
".harness/",
|
||||
"CLAUDE.md",
|
||||
"README.md",
|
||||
"LICENSE"
|
||||
],
|
||||
"scripts": {
|
||||
"test": "node --test test/*.test.mjs",
|
||||
"doctor": "node ./bin/cli.js doctor",
|
||||
"mcp": "node ./bin/cli.js mcp start"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@metaharness/kernel": "^0.1.0",
|
||||
"@metaharness/host-claude-code": "^0.1.0"
|
||||
},
|
||||
"keywords": [
|
||||
"wifi-sensing",
|
||||
"wifi-densepose",
|
||||
"ruview",
|
||||
"csi",
|
||||
"channel-state-information",
|
||||
"pose-estimation",
|
||||
"presence-detection",
|
||||
"esp32",
|
||||
"agent-harness",
|
||||
"metaharness",
|
||||
"mcp",
|
||||
"mcp-server",
|
||||
"claude-code",
|
||||
"ambient-intelligence"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
"license": "MIT",
|
||||
"author": "ruvnet",
|
||||
"homepage": "https://github.com/ruvnet/RuView#readme",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/ruvnet/RuView.git",
|
||||
"directory": "harness/ruview"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/ruvnet/RuView/issues"
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
---
|
||||
name: calibrate-room
|
||||
description: Run the ADR-151 per-room calibration pipeline — baseline → enroll → extract → train → a bank of small specialists (presence/posture/breathing/heartbeat/restlessness/anomaly).
|
||||
---
|
||||
|
||||
# calibrate-room
|
||||
|
||||
Turn a provisioned node + sensing-server into a working room model. Pure-Rust,
|
||||
edge-deployable (ADR-151). Use the `ruview.calibrate` tool (installed
|
||||
`wifi-densepose` binary, else `cargo run -p wifi-densepose-cli`).
|
||||
|
||||
## Sequence
|
||||
|
||||
1. **baseline** — capture the empty room (Welford amplitude + von Mises phase). Leave
|
||||
the room empty.
|
||||
`ruview.calibrate {step: "baseline"}`
|
||||
2. **enroll** — record the occupant(s) doing the target activities.
|
||||
`ruview.calibrate {step: "enroll"}`
|
||||
3. **train-room** — train the bank of small specialists from baseline + enrollment.
|
||||
`ruview.calibrate {step: "train-room"}`
|
||||
4. **room-watch** — live presence/posture/breathing from the trained room.
|
||||
`ruview.calibrate {step: "room-watch"}` (or the `room-watch` skill)
|
||||
|
||||
## Honesty
|
||||
|
||||
The specialists are calibrated to *this* room; cross-room transfer is a separate
|
||||
problem (LoRA recalibration, ADR-079 P9). Report which room a number came from, and
|
||||
tag presence/vitals accuracy MEASURED only with a held-out check — run
|
||||
`ruview.claim_check` on the writeup.
|
||||
@@ -0,0 +1,30 @@
|
||||
---
|
||||
name: onboard
|
||||
description: Zero-to-sensing path picker for RuView (WiFi-DensePose) — pick docker-demo, repo-build, or live-esp32 and run the next concrete step.
|
||||
---
|
||||
|
||||
# onboard
|
||||
|
||||
Get a newcomer from nothing to a working RuView setup. **First fact to set:** WiFi
|
||||
sensing infers *coarse* pose/presence/breathing from Channel State Information — it
|
||||
is **not a camera**, and any accuracy number must be MEASURED against a baseline
|
||||
(use the `verify` skill / `ruview.claim_check` tool). Never present WiFi output as
|
||||
camera-grade.
|
||||
|
||||
## Pick a path
|
||||
|
||||
Run `ruview.onboard {path}` or decide from:
|
||||
|
||||
1. **docker-demo** — fastest, no hardware. Replays sample CSI into the dashboard.
|
||||
`docker run -p 8000:8000 ruvnet/wifi-densepose` → open `http://localhost:8000`.
|
||||
Use to see what it looks like.
|
||||
2. **repo-build** — for developers. `cd v2 && cargo test --workspace --no-default-features`
|
||||
(1,031+ tests pass), then `cargo run -p wifi-densepose-cli -- --help`.
|
||||
3. **live-esp32** — a real install. Flash a node (`provision-node` skill), point it at
|
||||
the sensing-server, then `calibrate-room`. This is the only path that senses a real room.
|
||||
|
||||
## Then
|
||||
|
||||
- Live sensing → go to **provision-node**, then **calibrate-room**.
|
||||
- Evaluating a model/claim → go to **verify** and run `ruview.claim_check` on any
|
||||
report before you quote a number.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
name: provision-node
|
||||
description: Build, flash, and provision an ESP32-S3/C6 CSI node for RuView — firmware variant choice, ESP-IDF Windows-subprocess flow, NVS/WiFi/channel/MAC-filter overrides.
|
||||
---
|
||||
|
||||
# provision-node
|
||||
|
||||
Bring an ESP32 sensing node online.
|
||||
|
||||
## 1. Pick a firmware variant
|
||||
|
||||
- **s3-8mb** (display build) — ESP32-S3 N16R8 / 16MB; AMOLED optional. The display-detect
|
||||
fix (#1000) means a *bare* board still captures CSI (MGMT+DATA).
|
||||
- **s3-4mb** (no-display) — ESP32-S3 4MB; dual-OTA, display disabled.
|
||||
- **c6** — ESP32-C6 + Seeed MR60BHA2 (60 GHz mmWave + WiFi CSI). The mmwave probe
|
||||
requires a validated MR60 header (#1107) so an empty UART never false-detects.
|
||||
|
||||
Prebuilt binaries: GitHub release `v0.8.1-esp32` (hardware-validated on S3 QFN56 rev v0.2).
|
||||
|
||||
## 2. Flash
|
||||
|
||||
ESP-IDF v5.4 on Windows is **subprocess-only** (Git Bash/MSYS is unsupported — strip
|
||||
`MSYSTEM*` env vars). Offsets for the S3 image:
|
||||
|
||||
```
|
||||
esptool --chip esp32s3 -p <PORT> -b 460800 write_flash \
|
||||
0x0 bootloader.bin 0x8000 partition-table.bin \
|
||||
0xf000 ota_data_initial.bin 0x20000 esp32-csi-node-s3-8mb.bin
|
||||
```
|
||||
|
||||
(`ruview.node_flash` returns the exact pinned command rather than running an
|
||||
unattended flash.)
|
||||
|
||||
## 3. Provision
|
||||
|
||||
```
|
||||
python firmware/esp32-csi-node/provision.py --port <PORT> \
|
||||
--ssid "<SSID>" --password "<secret>" --target-ip <server-ip> --target-port 5005
|
||||
# optional ADR-060 overrides:
|
||||
python firmware/esp32-csi-node/provision.py --port <PORT> --channel 6 --filter-mac AA:BB:CC:DD:EE:FF
|
||||
```
|
||||
|
||||
Never echo or commit the WiFi password.
|
||||
|
||||
## 4. Confirm CSI is flowing
|
||||
|
||||
`ruview.node_monitor {port}` — PASS criteria: serial shows `CSI cb #...` callbacks and
|
||||
(on a bare board) `CSI filter upgraded to MGMT+DATA`. No callbacks → the node isn't
|
||||
capturing; do not proceed to calibration.
|
||||
@@ -0,0 +1,33 @@
|
||||
---
|
||||
name: train-pose
|
||||
description: Train/evaluate WiFi pose models honestly — camera-supervised (MediaPipe + CSI) and camera-free (WiFlow), always checked against the mean-pose baseline before any PCK is quoted.
|
||||
---
|
||||
|
||||
# train-pose
|
||||
|
||||
Build a CSI→pose model without overstating it. The project has a **retracted 92.9%/100%**
|
||||
history — the discipline below exists so it never recurs.
|
||||
|
||||
## The non-negotiable: mean-pose baseline first
|
||||
|
||||
A pose model that always predicts the dataset's *mean pose* already scores ~50% PCK.
|
||||
**Quote PCK only as a delta over that baseline**, on a held-out split with no subject
|
||||
or temporal leakage. Example honest result (ADR-181):
|
||||
|
||||
> Held-out PCK@20 **59.5%** vs a 50% mean-pose baseline = **+9.4 pp real signal** — MEASURED.
|
||||
|
||||
## Paths
|
||||
|
||||
- **camera-supervised** (ADR-079) — MediaPipe Pose labels the camera frame; paired CSI
|
||||
trains the net. Train/infer in one camera frame so the skeleton aligns.
|
||||
- **camera-free** (WiFlow, ADR-152) — no camera at inference; geometry-conditioned.
|
||||
- **in-browser** (ADR-181) — WebGPU/WASM trainer; the active backend is shown as a badge
|
||||
(honest about what's executing).
|
||||
|
||||
## Before you publish a number
|
||||
|
||||
1. Run the mean-pose baseline on the same split.
|
||||
2. Report `(model − baseline)` in pp, with the split definition (chronological /
|
||||
blocked-gap / grouped-bucket; no leakage).
|
||||
3. `ruview.claim_check` the writeup — it flags any untagged or 100%/perfect claim.
|
||||
4. If it's a benchmark vs SOTA, tag MEASURED-EQUIVALENT only with the reproducer.
|
||||
@@ -0,0 +1,42 @@
|
||||
---
|
||||
name: verify
|
||||
description: Prove a RuView result is real — run the deterministic SHA-256 proof and the witness bundle (ADR-028), and lint any claim for MEASURED-vs-CLAIMED honesty.
|
||||
---
|
||||
|
||||
# verify
|
||||
|
||||
The "prove everything" skill. Nothing ships as validated without this.
|
||||
|
||||
## Deterministic proof (Trust Kill Switch)
|
||||
|
||||
`ruview.verify` runs `archive/v1/data/proof/verify.py`: it feeds a reference signal
|
||||
through the production pipeline and hashes the output against
|
||||
`expected_features.sha256`. Must print **VERDICT: PASS**. If numpy/scipy changed the
|
||||
hash, regenerate with `verify.py --generate-hash` then re-verify.
|
||||
|
||||
## Witness bundle (ADR-028)
|
||||
|
||||
For a release-grade attestation:
|
||||
|
||||
```
|
||||
bash scripts/generate-witness-bundle.sh
|
||||
cd dist/witness-bundle-ADR028-*/ && bash VERIFY.sh # must be 7/7 PASS
|
||||
```
|
||||
|
||||
Contains the Rust test log, the proof + expected hash, firmware SHA-256 manifest, and
|
||||
crate versions — a recipient can re-verify with one command.
|
||||
|
||||
## Claim honesty
|
||||
|
||||
Run `ruview.claim_check {text}` on any report, README section, PR body, or model card
|
||||
before quoting accuracy. It flags:
|
||||
- untagged accuracy numbers (must be MEASURED / CLAIMED / SYNTHETIC),
|
||||
- MEASURED claims with no reproducer cited,
|
||||
- the retracted "100%/perfect accuracy" framing.
|
||||
|
||||
## Firmware-specific
|
||||
|
||||
A firmware fix is **not** "hardware-validated" without a captured boot log on real
|
||||
silicon (e.g. the `v0.8.1-esp32` rev-v0.2 validation: `running headless so CSI
|
||||
captures (#1000)` + `CSI filter upgraded to MGMT+DATA` + a no-false-detect mmwave
|
||||
probe). Do not merge or release on a build-passes signal alone.
|
||||
@@ -0,0 +1,106 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// RuView harness guardrails — the "prove everything" rule made executable.
|
||||
//
|
||||
// The project was accused of AI-slop; the cultural fix is that every accuracy
|
||||
// number must be tagged MEASURED (with a reproducer) or CLAIMED/SYNTHETIC, and
|
||||
// the retracted "100% accuracy" framing must never reappear untagged. This module
|
||||
// is the static enforcement of that, shared by the `ruview.claim_check` MCP tool,
|
||||
// the `npx ruview claim-check` CLI, and the claude-code pre-output hook.
|
||||
|
||||
/** Phrases that signal a quantitative accuracy claim. */
|
||||
const METRIC_TERMS = [
|
||||
'accuracy', 'pck', 'pck@', 'f1', 'precision', 'recall', 'map', 'auc',
|
||||
'iou', 'mpjpe', 'error rate', 'detection rate', 'true positive',
|
||||
];
|
||||
|
||||
/** Tags that make a claim honest (case-insensitive). */
|
||||
const HONEST_TAGS = ['measured', 'claimed', 'synthetic', 'unvalidated', 'baseline'];
|
||||
|
||||
/** Reproducer references that count as evidence backing a MEASURED claim. */
|
||||
const REPRODUCER_HINTS = [
|
||||
'verify.py', 'witness', 'mean-pose', 'mean pose', 'held-out', 'held out',
|
||||
'baseline', 'reproduce', 'sha256', 'boot log', 'pck@20 vs', 'expected_features',
|
||||
];
|
||||
|
||||
const PERCENT_RE = /\b(\d{1,3}(?:\.\d+)?)\s?%/g;
|
||||
// "perfect" / "100%" framing is the specific retracted claim — always high severity.
|
||||
// NOTE: no trailing \b after "%": "%"→" " is non-word→non-word, so a trailing \b
|
||||
// never matches and would silently miss "100%". Bare 100% is only damning next to a
|
||||
// metric term (see claimCheck); the word phrases are inherently accuracy claims.
|
||||
const PERFECT_PCT_RE = /\b100(?:\.0+)?\s?%/;
|
||||
const PERFECT_WORD_RE = /perfect accuracy|flawless|never (?:wrong|fails)/i;
|
||||
|
||||
/**
|
||||
* Lint a block of text for untagged or overstated accuracy claims.
|
||||
* @param {string} text
|
||||
* @returns {{ok: boolean, findings: Array<{severity:'high'|'medium', line:number, excerpt:string, reason:string, suggestion:string}>}}
|
||||
*/
|
||||
export function claimCheck(text) {
|
||||
const findings = [];
|
||||
if (typeof text !== 'string' || text.length === 0) {
|
||||
return { ok: true, findings };
|
||||
}
|
||||
const lines = text.split(/\r?\n/);
|
||||
|
||||
lines.forEach((raw, i) => {
|
||||
const line = raw.trim();
|
||||
if (!line) return;
|
||||
const lower = line.toLowerCase();
|
||||
|
||||
const hasPercent = PERCENT_RE.test(line);
|
||||
PERCENT_RE.lastIndex = 0; // reset stateful global regex
|
||||
const mentionsMetric = METRIC_TERMS.some((t) => lower.includes(t));
|
||||
if (!hasPercent && !mentionsMetric) return;
|
||||
|
||||
const tagged = HONEST_TAGS.some((t) => lower.includes(t));
|
||||
const hasReproducer = REPRODUCER_HINTS.some((h) => lower.includes(h));
|
||||
const perfect = PERFECT_WORD_RE.test(line) || (mentionsMetric && PERFECT_PCT_RE.test(line));
|
||||
|
||||
if (perfect && !lower.includes('retract')) {
|
||||
findings.push({
|
||||
severity: 'high',
|
||||
line: i + 1,
|
||||
excerpt: clip(line),
|
||||
reason: 'States perfect/100% accuracy — this is the exact framing the project retracted.',
|
||||
suggestion: 'Replace with a held-out number vs the mean-pose baseline, tagged MEASURED, or mark the old claim "retracted".',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// A metric/percent with no honesty tag at all.
|
||||
if (!tagged) {
|
||||
findings.push({
|
||||
severity: 'medium',
|
||||
line: i + 1,
|
||||
excerpt: clip(line),
|
||||
reason: 'Accuracy claim is not tagged MEASURED / CLAIMED / SYNTHETIC.',
|
||||
suggestion: 'Tag it. If MEASURED, name the reproducer (verify.py, witness bundle, held-out vs mean-pose).',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Tagged MEASURED but cites no reproducer — still a gap.
|
||||
if (lower.includes('measured') && !hasReproducer) {
|
||||
findings.push({
|
||||
severity: 'medium',
|
||||
line: i + 1,
|
||||
excerpt: clip(line),
|
||||
reason: 'Tagged MEASURED but cites no reproducer/evidence.',
|
||||
suggestion: 'Add the evidence path: verify.py VERDICT, witness bundle, or held-out PCK vs the mean-pose baseline.',
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return { ok: findings.length === 0, findings };
|
||||
}
|
||||
|
||||
function clip(s, n = 120) {
|
||||
return s.length > n ? `${s.slice(0, n - 1)}…` : s;
|
||||
}
|
||||
|
||||
/** Convenience: a one-line human summary for CLI output. */
|
||||
export function summarize(result) {
|
||||
if (result.ok) return 'claim-check: PASS — no untagged or overstated accuracy claims.';
|
||||
const high = result.findings.filter((f) => f.severity === 'high').length;
|
||||
return `claim-check: ${result.findings.length} finding(s) (${high} high) — accuracy claims need MEASURED/CLAIMED tags + a reproducer.`;
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// RuView harness — minimal MCP stdio server (JSON-RPC 2.0 over stdin/stdout).
|
||||
//
|
||||
// Dependency-free on purpose: a published `npx ruview` must `mcp start` without
|
||||
// pulling the full MCP SDK. Implements the subset hosts use: `initialize`,
|
||||
// `tools/list`, `tools/call`, and the `notifications/initialized` ack. Logs go to
|
||||
// stderr ONLY — stdout is the JSON-RPC channel and must stay clean.
|
||||
|
||||
import { createInterface } from 'node:readline';
|
||||
import { listTools, runTool } from './tools.js';
|
||||
|
||||
const PROTOCOL_VERSION = '2024-11-05';
|
||||
const SERVER_INFO = { name: 'ruview', version: '0.1.0' };
|
||||
|
||||
function send(msg) {
|
||||
process.stdout.write(JSON.stringify(msg) + '\n');
|
||||
}
|
||||
function result(id, res) { send({ jsonrpc: '2.0', id, result: res }); }
|
||||
function error(id, code, message) { send({ jsonrpc: '2.0', id, error: { code, message } }); }
|
||||
function log(...a) { process.stderr.write('[ruview-mcp] ' + a.join(' ') + '\n'); }
|
||||
|
||||
function handle(msg) {
|
||||
const { id, method, params } = msg;
|
||||
switch (method) {
|
||||
case 'initialize':
|
||||
return result(id, {
|
||||
protocolVersion: PROTOCOL_VERSION,
|
||||
capabilities: { tools: { listChanged: false } },
|
||||
serverInfo: SERVER_INFO,
|
||||
instructions: 'RuView WiFi-sensing operator tools. All results are fail-closed; accuracy claims must pass ruview.claim_check.',
|
||||
});
|
||||
case 'notifications/initialized':
|
||||
case 'initialized':
|
||||
return; // notification — no response
|
||||
case 'ping':
|
||||
return result(id, {});
|
||||
case 'tools/list':
|
||||
return result(id, { tools: listTools() });
|
||||
case 'tools/call': {
|
||||
const name = params?.name;
|
||||
const args = params?.arguments || {};
|
||||
const out = runTool(name, args);
|
||||
// MCP content envelope: text block with the JSON, isError reflects ok=false.
|
||||
return result(id, {
|
||||
content: [{ type: 'text', text: JSON.stringify(out, null, 2) }],
|
||||
isError: out && out.ok === false,
|
||||
});
|
||||
}
|
||||
default:
|
||||
if (id !== undefined) error(id, -32601, `Method not found: ${method}`);
|
||||
}
|
||||
}
|
||||
|
||||
export function startMcpServer() {
|
||||
log(`starting (protocol ${PROTOCOL_VERSION}, ${listTools().length} tools)`);
|
||||
const rl = createInterface({ input: process.stdin, crlfDelay: Infinity });
|
||||
rl.on('line', (line) => {
|
||||
const s = line.trim();
|
||||
if (!s) return;
|
||||
let msg;
|
||||
try { msg = JSON.parse(s); } catch { return log('bad JSON line dropped'); }
|
||||
try { handle(msg); } catch (err) {
|
||||
if (msg && msg.id !== undefined) error(msg.id, -32603, String(err && err.message || err));
|
||||
log('handler error:', String(err));
|
||||
}
|
||||
});
|
||||
rl.on('close', () => { log('stdin closed — exiting'); process.exit(0); });
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// RuView harness — the `ruview.*` tool registry.
|
||||
//
|
||||
// One registry consumed by BOTH the CLI (`npx ruview <tool>`) and the MCP server
|
||||
// (`npx ruview mcp start`). Every handler returns structured JSON and is
|
||||
// FAIL-CLOSED: when a prerequisite (the RuView repo, python+pyserial, the
|
||||
// `wifi-densepose` binary, an ESP32 on a port) is absent, it returns an honest
|
||||
// negative — never a fabricated success. This mirrors the project's "prove
|
||||
// everything" rule and the RuField fail-closed posture (ADR-262 §3.3).
|
||||
|
||||
import { spawnSync } from 'node:child_process';
|
||||
import { existsSync, readFileSync } from 'node:fs';
|
||||
import { join, dirname, resolve } from 'node:path';
|
||||
import { claimCheck, summarize } from './guardrails.js';
|
||||
|
||||
/** Walk up from `start` to find the RuView monorepo root (or null). */
|
||||
export function findRepoRoot(start = process.cwd()) {
|
||||
let dir = resolve(start);
|
||||
for (let i = 0; i < 8; i++) {
|
||||
const hasProof = existsSync(join(dir, 'archive', 'v1', 'data', 'proof', 'verify.py'));
|
||||
const hasV2 = existsSync(join(dir, 'v2', 'Cargo.toml'));
|
||||
if (hasProof || hasV2) return dir;
|
||||
const parent = dirname(dir);
|
||||
if (parent === dir) break;
|
||||
dir = parent;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function which(cmd) {
|
||||
const probe = process.platform === 'win32'
|
||||
? spawnSync('where', [cmd], { encoding: 'utf8' })
|
||||
: spawnSync('command', ['-v', cmd], { encoding: 'utf8', shell: true });
|
||||
return probe.status === 0 ? (probe.stdout || '').trim().split(/\r?\n/)[0] : null;
|
||||
}
|
||||
|
||||
function run(cmd, args, opts = {}) {
|
||||
const r = spawnSync(cmd, args, { encoding: 'utf8', timeout: opts.timeout ?? 120000, ...opts });
|
||||
return {
|
||||
status: r.status,
|
||||
ok: r.status === 0,
|
||||
stdout: (r.stdout || '').slice(-8000),
|
||||
stderr: (r.stderr || '').slice(-4000),
|
||||
error: r.error ? r.error.message : null,
|
||||
};
|
||||
}
|
||||
|
||||
const ONBOARD_PATHS = {
|
||||
'docker-demo': 'Fastest. `docker run -p 8000:8000 ruvnet/wifi-densepose` → open the dashboard. No hardware; replays sample CSI. Good for "what does it look like".',
|
||||
'repo-build': 'Build from source. `cd v2 && cargo test --workspace --no-default-features` (1,031+ tests). Then `cargo run -p wifi-densepose-cli -- --help`. Good for developers.',
|
||||
'live-esp32': 'Real sensing. Flash an ESP32-S3 (see `provision-node` skill), point it at the sensing-server, then `calibrate → enroll → train-room → room-watch` (see `calibrate-room`). Good for an actual install.',
|
||||
};
|
||||
|
||||
/**
|
||||
* The tool registry. Each entry: { title, description, inputSchema, handler }.
|
||||
* inputSchema is JSON-Schema (object). handler(args) → JSON-serializable result.
|
||||
*/
|
||||
export const TOOLS = {
|
||||
'ruview.onboard': {
|
||||
title: 'Onboard',
|
||||
description: 'Pick a RuView setup path (docker-demo | repo-build | live-esp32) and print the next concrete command.',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: { path: { type: 'string', enum: Object.keys(ONBOARD_PATHS), description: 'Which setup path. Omit to list all.' } },
|
||||
},
|
||||
handler(args = {}) {
|
||||
const repo = findRepoRoot();
|
||||
if (args.path && ONBOARD_PATHS[args.path]) {
|
||||
return { ok: true, path: args.path, next: ONBOARD_PATHS[args.path], in_ruview_repo: !!repo };
|
||||
}
|
||||
return {
|
||||
ok: true,
|
||||
in_ruview_repo: !!repo,
|
||||
repo_root: repo,
|
||||
paths: ONBOARD_PATHS,
|
||||
recommend: repo ? 'repo-build' : 'docker-demo',
|
||||
note: 'WiFi sensing infers coarse pose/presence from CSI — it is not a camera. Accuracy claims must be MEASURED vs a baseline (run `ruview.claim_check`).',
|
||||
};
|
||||
},
|
||||
},
|
||||
|
||||
'ruview.claim_check': {
|
||||
title: 'Claim check',
|
||||
description: 'Static lint: scan text for untagged or overstated accuracy claims (the "prove everything" guardrail). Returns findings.',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
required: ['text'],
|
||||
properties: { text: { type: 'string', description: 'The text to lint (a report, README section, PR body, model card).' } },
|
||||
},
|
||||
handler(args = {}) {
|
||||
const result = claimCheck(String(args.text ?? ''));
|
||||
return { ...result, summary: summarize(result) };
|
||||
},
|
||||
},
|
||||
|
||||
'ruview.verify': {
|
||||
title: 'Verify (witness)',
|
||||
description: 'Run the deterministic proof (archive/v1/data/proof/verify.py) and report VERDICT. Fail-closed if not in a RuView repo or python is missing.',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: { repo: { type: 'string', description: 'RuView repo root. Default: auto-detect from cwd.' } },
|
||||
},
|
||||
handler(args = {}) {
|
||||
const repo = args.repo ? resolve(args.repo) : findRepoRoot();
|
||||
if (!repo) return { ok: false, reason: 'not_in_ruview_repo', hint: 'Run inside the RuView monorepo or pass {repo}.' };
|
||||
const proof = join(repo, 'archive', 'v1', 'data', 'proof', 'verify.py');
|
||||
if (!existsSync(proof)) return { ok: false, reason: 'proof_missing', path: proof };
|
||||
const py = which('python') || which('python3');
|
||||
if (!py) return { ok: false, reason: 'python_missing', hint: 'Install python to run the deterministic proof.' };
|
||||
const r = run(py, [proof], { cwd: repo, timeout: 180000 });
|
||||
const verdict = /VERDICT:\s*PASS/i.test(r.stdout) ? 'PASS' : (/VERDICT:\s*FAIL/i.test(r.stdout) ? 'FAIL' : 'UNKNOWN');
|
||||
return { ok: r.ok && verdict === 'PASS', verdict, exit: r.status, tail: r.stdout.slice(-1200), stderr: r.stderr.slice(-400) };
|
||||
},
|
||||
},
|
||||
|
||||
'ruview.node_monitor': {
|
||||
title: 'Node monitor',
|
||||
description: 'Open an ESP32 serial port and assert CSI is flowing (MGMT+DATA). Fail-closed if python+pyserial or the port is absent. Read-only.',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
port: { type: 'string', description: 'Serial port, e.g. COM8 or /dev/ttyUSB0.' },
|
||||
seconds: { type: 'number', description: 'Capture window (default 12).' },
|
||||
},
|
||||
},
|
||||
handler(args = {}) {
|
||||
const port = args.port;
|
||||
if (!port) return { ok: false, reason: 'no_port', hint: 'Pass {port} (e.g. COM8).' };
|
||||
const py = which('python') || which('python3');
|
||||
if (!py) return { ok: false, reason: 'python_missing' };
|
||||
const dur = Number(args.seconds) > 0 ? Number(args.seconds) : 12;
|
||||
const script = [
|
||||
'import sys,time',
|
||||
'try:',
|
||||
' import serial',
|
||||
'except Exception as e:',
|
||||
" print('NO_PYSERIAL'); sys.exit(3)",
|
||||
`ser=serial.Serial(${JSON.stringify(port)},115200,timeout=1)`,
|
||||
'csi=0; n=0; t=time.time()',
|
||||
`while time.time()-t<${dur}:`,
|
||||
' ln=ser.readline()',
|
||||
' if not ln: continue',
|
||||
" s=ln.decode('utf-8','replace')",
|
||||
' n+=1',
|
||||
" if 'CSI cb' in s or 'csi_collector' in s: csi+=1",
|
||||
" if 'MGMT+DATA' in s: print('UPGRADE_MGMT_DATA')",
|
||||
'ser.close()',
|
||||
"print(f'LINES={n} CSI={csi}')",
|
||||
].join('\n');
|
||||
const r = run(py, ['-c', script], { timeout: (dur + 10) * 1000 });
|
||||
if (r.stdout.includes('NO_PYSERIAL')) return { ok: false, reason: 'pyserial_missing', hint: 'pip install pyserial' };
|
||||
if (!r.ok) return { ok: false, reason: 'port_error', stderr: r.stderr, error: r.error };
|
||||
const csi = Number((r.stdout.match(/CSI=(\d+)/) || [])[1] || 0);
|
||||
const upgraded = r.stdout.includes('UPGRADE_MGMT_DATA');
|
||||
return { ok: csi > 0, csi_callbacks: csi, mgmt_data_upgrade: upgraded, raw: r.stdout.trim() };
|
||||
},
|
||||
},
|
||||
|
||||
'ruview.calibrate': {
|
||||
title: 'Calibrate room',
|
||||
description: 'Run the ADR-151 room pipeline via the wifi-densepose CLI (baseline→enroll→train-room). Fail-closed if the binary is absent.',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
step: { type: 'string', enum: ['baseline', 'enroll', 'train-room', 'room-watch'], description: 'Which calibration step.' },
|
||||
args: { type: 'array', items: { type: 'string' }, description: 'Extra CLI args passed through.' },
|
||||
},
|
||||
},
|
||||
handler(args = {}) {
|
||||
const step = args.step || 'baseline';
|
||||
const bin = which('wifi-densepose');
|
||||
const repo = findRepoRoot();
|
||||
if (!bin && !repo) return { ok: false, reason: 'cli_missing', hint: 'Install the wifi-densepose CLI or run in the repo (cargo run -p wifi-densepose-cli).' };
|
||||
const passthru = Array.isArray(args.args) ? args.args.map(String) : [];
|
||||
// Prefer the installed binary; otherwise cargo-run from the repo.
|
||||
const r = bin
|
||||
? run(bin, [step, ...passthru], { timeout: 300000 })
|
||||
: run('cargo', ['run', '-q', '-p', 'wifi-densepose-cli', '--', step, ...passthru], { cwd: repo, timeout: 600000 });
|
||||
return { ok: r.ok, step, via: bin ? 'binary' : 'cargo', exit: r.status, tail: r.stdout.slice(-1500), stderr: r.stderr.slice(-500) };
|
||||
},
|
||||
},
|
||||
|
||||
'ruview.node_flash': {
|
||||
title: 'Node flash',
|
||||
description: 'Build+flash an ESP32 firmware variant. MUTATING + hardware. Fail-closed off-Windows or without ESP-IDF. Never claims hardware validation without a boot log.',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
port: { type: 'string', description: 'Target port, e.g. COM8.' },
|
||||
variant: { type: 'string', enum: ['s3-8mb', 's3-4mb', 'c6'], description: 'Firmware variant.' },
|
||||
confirm: { type: 'boolean', description: 'Must be true to actually flash (guard).' },
|
||||
},
|
||||
},
|
||||
handler(args = {}) {
|
||||
if (process.platform !== 'win32') {
|
||||
return { ok: false, reason: 'unsupported_platform', detail: 'The ESP-IDF flash flow is Windows-subprocess-specific today (see CLAUDE.local.md).' };
|
||||
}
|
||||
if (!args.confirm) {
|
||||
return { ok: false, reason: 'not_confirmed', detail: 'Mutating hardware op — re-call with {confirm:true}.', would_flash: { port: args.port, variant: args.variant || 's3-8mb' } };
|
||||
}
|
||||
return { ok: false, reason: 'manual_step_required', detail: 'Flashing uses the pinned ESP-IDF subprocess in CLAUDE.local.md. This tool returns the exact command rather than running an unattended flash.', see: 'skills/provision-node.md' };
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
/** Run one tool by name; returns the structured result (or an error envelope). */
|
||||
export function runTool(name, args) {
|
||||
const tool = TOOLS[name];
|
||||
if (!tool) return { ok: false, reason: 'unknown_tool', name, available: Object.keys(TOOLS) };
|
||||
try {
|
||||
return tool.handler(args || {});
|
||||
} catch (err) {
|
||||
return { ok: false, reason: 'tool_threw', name, error: String(err && err.message || err) };
|
||||
}
|
||||
}
|
||||
|
||||
/** MCP-shaped tool list: [{name, description, inputSchema}]. */
|
||||
export function listTools() {
|
||||
return Object.entries(TOOLS).map(([name, t]) => ({ name, description: t.description, inputSchema: t.inputSchema }));
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// RuView harness tests — Node's built-in test runner (no devDeps to install).
|
||||
// Run: `node --test test/` (or `npm test`).
|
||||
|
||||
import { test } from 'node:test';
|
||||
import assert from 'node:assert/strict';
|
||||
import { claimCheck, summarize } from '../src/guardrails.js';
|
||||
import { TOOLS, runTool, listTools, findRepoRoot } from '../src/tools.js';
|
||||
import { run } from '../bin/cli.js';
|
||||
|
||||
test('guardrail flags the retracted 100% framing as high severity', () => {
|
||||
const r = claimCheck('Our model reaches 100% accuracy on every pose.');
|
||||
assert.equal(r.ok, false);
|
||||
assert.ok(r.findings.some((f) => f.severity === 'high'));
|
||||
});
|
||||
|
||||
test('guardrail flags an untagged percentage accuracy claim', () => {
|
||||
// "hit", not "measured" — "measured" would (correctly) route to the no-reproducer branch.
|
||||
const r = claimCheck('We hit 92.9% PCK on the test set.');
|
||||
assert.equal(r.ok, false);
|
||||
assert.ok(r.findings.some((f) => /not tagged/i.test(f.reason)));
|
||||
});
|
||||
|
||||
test('guardrail passes a MEASURED claim that cites a reproducer', () => {
|
||||
const r = claimCheck('Held-out PCK@20 59.5% vs 50% mean-pose baseline = +9.4pp (MEASURED, verify.py).');
|
||||
assert.equal(r.ok, true, JSON.stringify(r.findings));
|
||||
});
|
||||
|
||||
test('guardrail flags MEASURED with no reproducer', () => {
|
||||
const r = claimCheck('Presence detection 97% (MEASURED).');
|
||||
assert.equal(r.ok, false);
|
||||
assert.ok(r.findings.some((f) => /no reproducer/i.test(f.reason)));
|
||||
});
|
||||
|
||||
test('guardrail ignores non-metric prose', () => {
|
||||
assert.equal(claimCheck('The ESP32 streams CSI over UDP to the sensing-server.').ok, true);
|
||||
assert.equal(claimCheck('').ok, true);
|
||||
});
|
||||
|
||||
test('summarize gives PASS/finding text', () => {
|
||||
assert.match(summarize(claimCheck('nothing here')), /PASS/);
|
||||
assert.match(summarize(claimCheck('100% accuracy')), /finding/);
|
||||
});
|
||||
|
||||
test('registry exposes the documented tools with schemas', () => {
|
||||
const names = Object.keys(TOOLS);
|
||||
for (const n of ['ruview.onboard', 'ruview.claim_check', 'ruview.verify', 'ruview.node_monitor', 'ruview.calibrate', 'ruview.node_flash']) {
|
||||
assert.ok(names.includes(n), `missing ${n}`);
|
||||
assert.equal(TOOLS[n].inputSchema.type, 'object');
|
||||
}
|
||||
assert.equal(listTools().length, names.length);
|
||||
});
|
||||
|
||||
test('ruview.onboard returns paths and a recommendation', () => {
|
||||
const r = runTool('ruview.onboard', {});
|
||||
assert.equal(r.ok, true);
|
||||
assert.ok(r.paths['live-esp32']);
|
||||
assert.ok(['repo-build', 'docker-demo'].includes(r.recommend));
|
||||
});
|
||||
|
||||
test('ruview.claim_check tool wraps the guardrail', () => {
|
||||
const r = runTool('ruview.claim_check', { text: '100% accuracy' });
|
||||
assert.equal(r.ok, false);
|
||||
assert.match(r.summary, /honesty|tag|MEASURED|finding/i);
|
||||
});
|
||||
|
||||
test('unknown tool fails closed', () => {
|
||||
const r = runTool('ruview.does_not_exist', {});
|
||||
assert.equal(r.ok, false);
|
||||
assert.equal(r.reason, 'unknown_tool');
|
||||
});
|
||||
|
||||
test('node_monitor fails closed without a port', () => {
|
||||
const r = runTool('ruview.node_monitor', {});
|
||||
assert.equal(r.ok, false);
|
||||
assert.equal(r.reason, 'no_port');
|
||||
});
|
||||
|
||||
test('node_flash refuses without confirm (mutating guard)', () => {
|
||||
const r = runTool('ruview.node_flash', { port: 'COM8', variant: 's3-8mb' });
|
||||
assert.equal(r.ok, false);
|
||||
// either not-confirmed (win32) or unsupported_platform (posix) — both fail-closed
|
||||
assert.ok(['not_confirmed', 'unsupported_platform'].includes(r.reason));
|
||||
});
|
||||
|
||||
test('verify fails closed when not in a RuView repo', () => {
|
||||
// point at a tmp dir with no repo markers
|
||||
const r = runTool('ruview.verify', { repo: process.platform === 'win32' ? 'C:/Windows/Temp' : '/tmp' });
|
||||
assert.equal(r.ok, false);
|
||||
assert.ok(['proof_missing', 'python_missing'].includes(r.reason), r.reason);
|
||||
});
|
||||
|
||||
test('CLI run(): claim-check exits non-zero on a bad claim', async () => {
|
||||
const code = await run(['claim-check', '--text', '100% accuracy']);
|
||||
assert.notEqual(code, 0);
|
||||
});
|
||||
|
||||
test('CLI run(): doctor exits 0 (tools-only path)', async () => {
|
||||
const code = await run(['doctor']);
|
||||
assert.equal(code, 0);
|
||||
});
|
||||
|
||||
test('CLI run(): unknown command exits non-zero', async () => {
|
||||
assert.notEqual(await run(['definitely-not-a-command']), 0);
|
||||
});
|
||||
|
||||
test('findRepoRoot locates this monorepo from cwd', () => {
|
||||
// when run from within wifi-densepose, it should find a root; elsewhere null is fine
|
||||
const root = findRepoRoot();
|
||||
assert.ok(root === null || typeof root === 'string');
|
||||
});
|
||||
Binary file not shown.
BIN
Binary file not shown.
@@ -184,7 +184,9 @@ function loadGroundTruth(filePath) {
|
||||
const raw = loadJsonl(filePath);
|
||||
const frames = [];
|
||||
for (const r of raw) {
|
||||
if (r.ts_ns == null || !r.keypoints) continue;
|
||||
// Skip non-detection frames (empty keypoints []) — they must not dilute window
|
||||
// confidence; confidence stats are over actual detections only (#1007 Bug 2).
|
||||
if (r.ts_ns == null || !r.keypoints || r.keypoints.length === 0) continue;
|
||||
frames.push({
|
||||
tsMs: cameraTsToMs(r.ts_ns),
|
||||
keypoints: r.keypoints,
|
||||
@@ -266,7 +268,29 @@ function loadCsi(filePath) {
|
||||
// Sort by timestamp
|
||||
rawCsi.sort((a, b) => a.tsMs - b.tsMs);
|
||||
features.sort((a, b) => a.tsMs - b.tsMs);
|
||||
return { rawCsi, features };
|
||||
|
||||
// Bug 3 (#1007): keep only frames at the session's MODAL subcarrier count so windows
|
||||
// are homogeneous; never silently zero-pad/truncate the off-format frames the ESP32
|
||||
// emits (HT20/HT40/fragments). extractCsiMatrix then sees uniform-width frames.
|
||||
return { rawCsi: filterToModalSubcarriers(rawCsi), features };
|
||||
}
|
||||
|
||||
/**
|
||||
* Keep only frames whose subcarrier count equals the session's modal (most common)
|
||||
* count. Off-format frames are dropped (logged), not padded — prevents the silent
|
||||
* zero-padding that corrupted windows in #1007.
|
||||
*/
|
||||
function filterToModalSubcarriers(frames) {
|
||||
if (frames.length === 0) return frames;
|
||||
const counts = new Map();
|
||||
for (const f of frames) counts.set(f.subcarriers, (counts.get(f.subcarriers) || 0) + 1);
|
||||
let modal = frames[0].subcarriers, best = 0;
|
||||
for (const [sc, n] of counts) if (n > best) { best = n; modal = sc; }
|
||||
const kept = frames.filter((f) => f.subcarriers === modal);
|
||||
if (kept.length !== frames.length) {
|
||||
console.error(`[align] #1007: kept ${kept.length}/${frames.length} CSI frames at modal subcarrier count ${modal} (dropped ${frames.length - kept.length} off-format; no silent padding)`);
|
||||
}
|
||||
return kept;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -343,7 +367,8 @@ function averageKeypoints(cameraFrames) {
|
||||
|
||||
/**
|
||||
* Extract CSI amplitude matrix from raw_csi window.
|
||||
* Returns { data: flat Float32Array, shape: [subcarriers, windowFrames] }.
|
||||
* Fill is frame-major (matrix[f*nSc + s]), so shape is [windowFrames, subcarriers]
|
||||
* (#1007 Bug 4 — was mislabeled [subcarriers, windowFrames], transposing consumers).
|
||||
*/
|
||||
function extractCsiMatrix(window) {
|
||||
const nFrames = window.length;
|
||||
@@ -363,12 +388,13 @@ function extractCsiMatrix(window) {
|
||||
}
|
||||
}
|
||||
|
||||
return { data: Array.from(matrix), shape: [nSc, nFrames] };
|
||||
return { data: Array.from(matrix), shape: [nFrames, nSc] };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract feature matrix from feature-type window.
|
||||
* Returns { data: flat array, shape: [featureDim, windowFrames] }.
|
||||
* Fill is frame-major (matrix[f*dim + d]), so shape is [windowFrames, featureDim]
|
||||
* (#1007 Bug 4 — was mislabeled [featureDim, windowFrames]).
|
||||
*/
|
||||
function extractFeatureMatrix(window) {
|
||||
const nFrames = window.length;
|
||||
@@ -382,7 +408,7 @@ function extractFeatureMatrix(window) {
|
||||
}
|
||||
}
|
||||
|
||||
return { data: Array.from(matrix), shape: [dim, nFrames] };
|
||||
return { data: Array.from(matrix), shape: [nFrames, dim] };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -15,6 +15,7 @@ import os
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def parse_csi_packet(data):
|
||||
@@ -41,7 +42,8 @@ def parse_csi_packet(data):
|
||||
|
||||
return {
|
||||
"type": "raw_csi",
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.") + f"{int(time.time() * 1000) % 1000:03d}Z",
|
||||
# true UTC, not local-time-labeled-Z (#1007 Bug 1) — e.g. "2026-06-17T01:23:45.678Z"
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||
"ts_ns": time.time_ns(),
|
||||
"node_id": node_id,
|
||||
"rssi": rssi,
|
||||
|
||||
Binary file not shown.
+4
-4
@@ -25,8 +25,7 @@ members = [
|
||||
"crates/wifi-densepose-ruvector",
|
||||
"crates/wifi-densepose-desktop",
|
||||
"crates/wifi-densepose-pointcloud",
|
||||
"crates/wifi-densepose-geo",
|
||||
"crates/wifi-densepose-worldgraph", # ADR-139 — WorldGraph environmental digital twin
|
||||
# geo + worldgraph extracted to ruvnet/worldgraph submodule (see crates/worldgraph)
|
||||
"crates/wifi-densepose-engine", # ADR-135..146 integration/composition layer
|
||||
"crates/wifi-densepose-calibration", # ADR-151 — per-room calibration & specialist training
|
||||
"crates/nvsim",
|
||||
@@ -58,7 +57,7 @@ members = [
|
||||
"crates/wifi-densepose-bfld",
|
||||
# ADR-147: OccWorld thin-client bridge — WorldGraph PersonTrack history →
|
||||
# OccWorld Python subprocess → TrajectoryPrior injection into pose tracker.
|
||||
"crates/wifi-densepose-worldmodel",
|
||||
# worldmodel extracted to ruvnet/worldgraph submodule (consumed via path dep)
|
||||
# ADR-147 (Phase 5): OccWorld TransVQVAE ported to Candle — native Rust
|
||||
# inference without Python/IPC overhead. Loaded alongside the Python bridge
|
||||
# as a faster alternative once Phase-5 weights are available.
|
||||
@@ -88,6 +87,7 @@ members = [
|
||||
exclude = [
|
||||
"crates/wifi-densepose-wasm-edge",
|
||||
"crates/homecore-plugin-example",
|
||||
"crates/worldgraph", # ruvnet/worldgraph submodule — its own workspace (geo/worldgraph/worldmodel)
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -215,7 +215,7 @@ wifi-densepose-hardware = { version = "0.3.0", path = "crates/wifi-densepose-har
|
||||
wifi-densepose-wasm = { version = "0.3.0", path = "crates/wifi-densepose-wasm" }
|
||||
wifi-densepose-mat = { version = "0.3.0", path = "crates/wifi-densepose-mat" }
|
||||
wifi-densepose-ruvector = { version = "0.3.0", path = "crates/wifi-densepose-ruvector" }
|
||||
wifi-densepose-worldmodel = { version = "0.3.0", path = "crates/wifi-densepose-worldmodel" }
|
||||
wifi-densepose-worldmodel = { version = "0.3.0", path = "crates/worldgraph/wifi-densepose-worldmodel" }
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
|
||||
+1
-1
Submodule v2/crates/ruv-neural updated: 1ece3afa33...c9638faaf8
Submodule
+1
Submodule v2/crates/ruview-swarm added at 267aba5be2
@@ -1,84 +0,0 @@
|
||||
[package]
|
||||
name = "ruview-swarm"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "RuView drone swarm control system — hierarchical-mesh topology, Raft consensus, MARL, CSI sensing integration (ADR-148)"
|
||||
license = "Apache-2.0"
|
||||
# Publishing disabled until: (1) PR #862 merges, (2) internal path-deps are
|
||||
# published in dependency order, (3) export-control sign-off on the ITAR-gated
|
||||
# coordination features (USML Category VIII(h)(12)). Flip to true deliberately.
|
||||
publish = false
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# ITAR/USML Category VIII(h)(12): swarming coordination features.
|
||||
# Must not be enabled in international distributions without export counsel review.
|
||||
itar-unrestricted = []
|
||||
mavlink = ["dep:mavlink"]
|
||||
ros2-dds = []
|
||||
onnx = ["dep:ort"]
|
||||
simulation = []
|
||||
demo = ["simulation"]
|
||||
full = ["mavlink", "onnx", "demo", "itar-unrestricted"]
|
||||
ruflo = ["dep:reqwest", "dep:serde_json"]
|
||||
# Heavy GPU-capable MARL training (real Candle autodiff PPO). Off by default so
|
||||
# the default build stays light and the existing test suite keeps passing.
|
||||
train = ["dep:candle-core", "dep:candle-nn"]
|
||||
cuda = ["candle-core/cuda", "candle-nn/cuda"]
|
||||
|
||||
[dependencies]
|
||||
wifi-densepose-core = { path = "../wifi-densepose-core" }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1", optional = true }
|
||||
toml = "0.8"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
async-trait = "0.1"
|
||||
|
||||
# MAVLink v2 (optional)
|
||||
mavlink = { version = "0.13", optional = true }
|
||||
|
||||
# ONNX Runtime (optional — for MARL actor inference)
|
||||
ort = { version = "2.0.0-rc.11", optional = true }
|
||||
|
||||
# Candle 0.9 — real autodiff PPO training (optional, behind `train` feature).
|
||||
candle-core = { version = "0.9", default-features = false, optional = true }
|
||||
candle-nn = { version = "0.9", default-features = false, optional = true }
|
||||
|
||||
# HTTP client (optional — for Ruflo HTTP backend)
|
||||
reqwest = { version = "0.12", features = ["json"], optional = true }
|
||||
|
||||
# Crypto — MAVLink v2 HMAC-SHA256 signing
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
|
||||
# Error handling
|
||||
thiserror = "2.0"
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
|
||||
# Numerics
|
||||
nalgebra = "0.33"
|
||||
rand = "0.8"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
tokio-test = "0.4"
|
||||
|
||||
[[bench]]
|
||||
name = "swarm_bench"
|
||||
harness = false
|
||||
|
||||
# MARL training binary — requires the `train` feature (Candle autodiff).
|
||||
# Excluded from the default build so `cargo test`/CI stay light.
|
||||
[[bin]]
|
||||
name = "train_marl"
|
||||
required-features = ["train"]
|
||||
|
||||
# ADR-171 Stage-1 evaluation CLI — pure Rust, no special feature needed.
|
||||
[[bin]]
|
||||
name = "eval_swarm"
|
||||
@@ -1,108 +0,0 @@
|
||||
# wifi-densepose-swarm
|
||||
|
||||
Drone swarm control system for the RuView wifi-densepose workspace. Implements ADR-148.
|
||||
|
||||
## Overview
|
||||
|
||||
`wifi-densepose-swarm` provides a hierarchical-mesh drone swarm coordination system
|
||||
with Raft consensus, MAPPO-based multi-agent reinforcement learning, and tight
|
||||
integration with the existing WiFi CSI sensing pipeline (`wifi-densepose-signal`,
|
||||
`wifi-densepose-ruvector`).
|
||||
|
||||
## Features
|
||||
|
||||
- **Hierarchical-Mesh Topology** — cluster heads over Raft consensus; inter-cluster Gossip for map dissemination
|
||||
- **Formation Control** — F1 VirtualStructure, F2 LeaderFollower, F3 Reynolds flocking
|
||||
- **3-Phase Coverage** — boustrophedon sweep → Bayesian probability grid → multi-drone triangulation
|
||||
- **RRT-APF Path Planner** — RRT* with Artificial Potential Field reactive collision avoidance
|
||||
- **MARL Actor (MAPPO)** — 64-dim local observation, 3-layer MLP actor, CTDE training interface
|
||||
- **CSI Sensing Integration** — drone payload pipeline (ESP32-S3 → Jetson), multi-drone CSI fusion
|
||||
- **OccWorld Bridge** — integrates ADR-147 OccWorld occupancy prior as path planner environment
|
||||
- **Security Hardening** — MAVLink v2 HMAC-SHA256 signing, UWB GPS anti-spoofing, onboard geofencing, Remote ID
|
||||
- **Fail-Safe State Machine** — 10-state onboard safety system, GCS-independent
|
||||
- **Demo & Training Modes** — synthetic CSI generation, Gazebo/PX4 SITL interface, TOML mission configs
|
||||
|
||||
## ITAR Notice
|
||||
|
||||
> ⚠️ **Export-controlled capability.** Swarming coordination features (formation control,
|
||||
> Raft consensus, task allocation) are gated behind the `itar-unrestricted` feature flag
|
||||
> per **USML Category VIII(h)(12)**. Default builds compile only safe stubs.
|
||||
> Do not enable `itar-unrestricted` for international distribution without export counsel review.
|
||||
|
||||
## Crate Features
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `default` | Core types, sensing, failsafe, config, MARL — no ITAR-gated code |
|
||||
| `itar-unrestricted` | Enables formation control, Raft consensus, task allocation |
|
||||
| `mavlink` | MAVLink v2 protocol support |
|
||||
| `onnx` | ONNX Runtime backend for MARL actor inference (INT8) |
|
||||
| `simulation` | Simulation-mode stubs |
|
||||
| `demo` | Synthetic CSI generation, scenario runners |
|
||||
| `full` | All of the above |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```rust
|
||||
use wifi_densepose_swarm::{config::SwarmConfig, demo::scenario::DemoScenario};
|
||||
|
||||
// Load a mission profile
|
||||
let config = SwarmConfig::sar_default();
|
||||
|
||||
// Run a demo scenario
|
||||
let scenario = DemoScenario::sar_rubble_field(4); // 4-drone SAR
|
||||
let estimated_secs = scenario.estimate_coverage_time_secs();
|
||||
// → < 240 s for 4 drones over 400×400 m (beyond Wi2SAR SOTA single-drone baseline)
|
||||
```
|
||||
|
||||
## Mission Profiles
|
||||
|
||||
| Profile | Drones | Area | Application |
|
||||
|---------|--------|------|-------------|
|
||||
| `sar` | 6–12 | 400×400 m | Structural collapse victim search |
|
||||
| `inspection` | 3–6 | Linear corridor | Infrastructure (power lines, bridges) |
|
||||
| `agriculture` | 4–12 | Field-configurable | NDVI mapping, variable-rate spraying |
|
||||
| `mine` | 2–4 | Tunnel | GPS-denied underground exploration |
|
||||
| `relay` | 6–20 | Perimeter | Emergency telecom relay chain |
|
||||
| `demo` | Any | Configurable | Synthetic CSI, configurable victims |
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── types.rs — NodeId, DroneState, SwarmTask, SwarmError, FailSafeState
|
||||
├── topology/ — Raft consensus¹, Gossip dissemination, MeshTopology
|
||||
├── formation/ — VirtualStructure¹, LeaderFollower¹, Reynolds flocking¹
|
||||
├── planning/ — RRT-APF planner, 3-phase coverage, Bayesian grid, pheromone
|
||||
├── allocation/ — Auction-based task allocation¹, FNN bid scorer¹
|
||||
├── sensing/ — CSI payload pipeline, multi-drone fusion, OccWorld bridge
|
||||
├── marl/ — MAPPO actor, LocalObservation, reward shaping, TrainingConfig
|
||||
├── security/ — MAVLink signing, UWB anti-spoofing, geofencing, Remote ID
|
||||
├── failsafe/ — 10-state onboard fail-safe machine
|
||||
├── config/ — TOML SwarmConfig with mission presets
|
||||
├── demo/ — Synthetic CSI, DemoScenario runners
|
||||
├── integration/ — FlightController trait (PX4/ArduPilot/Sim)
|
||||
└── bench_support.rs — Criterion fixture generators
|
||||
|
||||
¹ Requires `itar-unrestricted` feature.
|
||||
```
|
||||
|
||||
## Related ADRs
|
||||
|
||||
| ADR | Title | Relation |
|
||||
|-----|-------|----------|
|
||||
| ADR-148 | Drone Swarm Control System | This crate |
|
||||
| ADR-147 | OccWorld Occupancy World Model | Environment prior via `sensing::occworld_bridge` |
|
||||
| ADR-134 | CSI→CIR ISTA Sparse Recovery | Drone payload sensing |
|
||||
| ADR-146 | RF Encoder Multitask Heads | Drone payload inference |
|
||||
| ADR-016 | RuVector Training Integration | CrossViewpointAttention |
|
||||
|
||||
## Performance Targets (vs. Wi2SAR SOTA)
|
||||
|
||||
| Metric | Wi2SAR baseline (1 drone) | 4-drone target |
|
||||
|--------|--------------------------|----------------|
|
||||
| Coverage | 160,000 m² | 160,000 m² |
|
||||
| Time | 13.5 min | ≤ 4 min |
|
||||
| Localization | 5 m | ≤ 2 m (3-view fusion) |
|
||||
| MARL inference | N/A | ≤ 5 ms (INT8, release) |
|
||||
| Raft election | N/A | ≤ 300 ms |
|
||||
@@ -1,70 +0,0 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use ruview_swarm::marl::{MappoActor, ActorConfig};
|
||||
use ruview_swarm::marl::LocalObservation;
|
||||
use ruview_swarm::sensing::MultiViewFusion;
|
||||
use ruview_swarm::planning::RrtApfPlanner;
|
||||
use ruview_swarm::demo::{DemoScenario};
|
||||
use ruview_swarm::types::{CsiDetection, NodeId, Position3D};
|
||||
|
||||
fn bench_marl_inference(c: &mut Criterion) {
|
||||
let actor = MappoActor::random_init(ActorConfig::default());
|
||||
let obs = LocalObservation::zeros();
|
||||
c.bench_function("marl_actor_inference", |b| b.iter(|| actor.forward(&obs)));
|
||||
}
|
||||
|
||||
fn bench_rrt_apf_plan(c: &mut Criterion) {
|
||||
let planner = RrtApfPlanner::new(3.0);
|
||||
let start = Position3D { x: 0.0, y: 0.0, z: -30.0 };
|
||||
let goal = Position3D { x: 50.0, y: 50.0, z: -30.0 };
|
||||
c.bench_function("rrt_apf_100iter", |b| b.iter(|| {
|
||||
let mut rng = rand::thread_rng();
|
||||
planner.plan(start, goal, 100, &mut rng)
|
||||
}));
|
||||
}
|
||||
|
||||
fn bench_multiview_fusion(c: &mut Criterion) {
|
||||
let fusion = MultiViewFusion::default();
|
||||
let detections = vec![
|
||||
CsiDetection { drone_id: NodeId(0), confidence: 0.85, victim_position: Some(Position3D { x: 51.0, y: 49.0, z: 0.0 }), timestamp_ms: 0 },
|
||||
CsiDetection { drone_id: NodeId(1), confidence: 0.78, victim_position: Some(Position3D { x: 49.0, y: 51.0, z: 0.0 }), timestamp_ms: 0 },
|
||||
CsiDetection { drone_id: NodeId(2), confidence: 0.92, victim_position: Some(Position3D { x: 50.0, y: 50.0, z: 0.0 }), timestamp_ms: 0 },
|
||||
];
|
||||
let positions = vec![
|
||||
(NodeId(0), Position3D { x: 0.0, y: 0.0, z: -30.0 }),
|
||||
(NodeId(1), Position3D { x: 100.0, y: 0.0, z: -30.0 }),
|
||||
(NodeId(2), Position3D { x: 50.0, y: 86.6, z: -30.0 }),
|
||||
];
|
||||
c.bench_function("multiview_fusion_3drones", |b| b.iter(|| fusion.fuse(&detections, &positions)));
|
||||
}
|
||||
|
||||
fn bench_demo_coverage_estimate(c: &mut Criterion) {
|
||||
let scenario = DemoScenario::sar_rubble_field(4);
|
||||
c.bench_function("demo_coverage_estimate", |b| b.iter(|| scenario.estimate_coverage_time_secs()));
|
||||
}
|
||||
|
||||
fn bench_ppo_update(c: &mut Criterion) {
|
||||
use ruview_swarm::marl::{MappoActor, ActorConfig, LocalObservation};
|
||||
use ruview_swarm::marl::training_loop::{ReplayBuffer, Transition, PpoConfig, ppo_update};
|
||||
use ruview_swarm::marl::actor::ActorAction;
|
||||
|
||||
let mut buf = ReplayBuffer::new(64);
|
||||
for i in 0..64 {
|
||||
buf.push(Transition {
|
||||
obs: LocalObservation::zeros(),
|
||||
action: ActorAction { delta_heading_rad: 0.1, delta_altitude_m: 0.0, speed_ms: 5.0, trigger_csi_scan: true },
|
||||
reward: if i % 2 == 0 { 10.0 } else { -2.0 },
|
||||
next_obs: LocalObservation::zeros(),
|
||||
done: i == 63,
|
||||
});
|
||||
}
|
||||
let cfg = PpoConfig::default();
|
||||
c.bench_function("ppo_update_64transitions", |b| {
|
||||
b.iter(|| {
|
||||
let mut actor = MappoActor::random_init(ActorConfig::default());
|
||||
ppo_update(&mut actor, &buf, &cfg)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_marl_inference, bench_rrt_apf_plan, bench_multiview_fusion, bench_demo_coverage_estimate, bench_ppo_update);
|
||||
criterion_main!(benches);
|
||||
@@ -1,2 +0,0 @@
|
||||
# ADR-171 evaluation outputs
|
||||
RESULTS.md is generated by the `eval_swarm` binary.
|
||||
@@ -1,26 +0,0 @@
|
||||
# ruview-swarm Evaluation Results (ADR-171 Stage 1, kinematic)
|
||||
|
||||
Statistically-rigorous evaluation harness: seeded multi-run rollouts with IQM + 95% stratified-bootstrap confidence intervals (Agarwal et al., NeurIPS 2021).
|
||||
|
||||
## Run configuration
|
||||
|
||||
- **Stage**: 1 (kinematic, self-contained, deterministic per seed)
|
||||
- **Episodes per pattern**: 100 (seed × episode matrix)
|
||||
- **CI method**: 95% stratified bootstrap of the IQM, stratified by seed
|
||||
- **GDOP**: 2-D geometric dilution of precision at first detection
|
||||
|
||||
> **Stage 2 pending**: high-fidelity Gazebo/PX4 SITL evaluation (false-alarm rate, real collision rate on the median seeds) is a follow-on — see ADR-171 §6.1. The collision figures below are a kinematic min-separation proxy, not SITL physics.
|
||||
|
||||
## Flight-pattern leaderboard
|
||||
|
||||
| Flight pattern | Coverage IQM [95% CI] | Localization (m) IQM [95% CI] | Detection rate | Mean GDOP |
|
||||
|----------------|-----------------------|-------------------------------|----------------|-----------|
|
||||
| partitioned_lawnmower | 1.000 [1.000, 1.000] | 7.022 [5.669, 8.379] | 100.0% | 0.000 |
|
||||
| pheromone | 0.662 [0.652, 0.671] | 4.110 [3.346, 5.141] | 95.0% | 1.598 |
|
||||
| levy_flight | 0.490 [0.489, 0.491] | 3.523 [2.897, 4.160] | 100.0% | 0.000 |
|
||||
| boustrophedon | 0.370 [0.370, 0.370] | 2.740 [2.357, 3.207] | 100.0% | 0.000 |
|
||||
| spiral | 0.336 [0.336, 0.336] | 3.082 [2.678, 3.568] | 100.0% | 0.000 |
|
||||
| potential_field | 0.254 [0.252, 0.256] | 4.343 [3.489, 5.265] | 100.0% | 0.000 |
|
||||
| _Wi2SAR (paper baseline)_ | _n/a_ | _5.0 (paper)_ | _n/a_ | _n/a_ |
|
||||
|
||||
_Wi2SAR row is the published single-drone localization figure (arxiv 2604.09115), shown paper-to-paper for reference only — it was not re-run through this kinematic harness._
|
||||
@@ -1,118 +0,0 @@
|
||||
//! Contract-net (auction) task allocation.
|
||||
|
||||
use crate::types::{DroneState, NodeId, SwarmTask, TaskId};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A bid submitted by a node for a task.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Bid {
|
||||
pub node_id: NodeId,
|
||||
pub task_id: TaskId,
|
||||
/// Lower score = more capable/willing. Computed by the bidding node.
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
/// Auction-based task allocator.
|
||||
pub struct AuctionAllocator {
|
||||
pub pending_tasks: HashMap<TaskId, SwarmTask>,
|
||||
pub bids: HashMap<TaskId, Vec<Bid>>,
|
||||
pub timeout_ms: u64,
|
||||
}
|
||||
|
||||
impl AuctionAllocator {
|
||||
pub fn new(timeout_ms: u64) -> Self {
|
||||
Self {
|
||||
pending_tasks: HashMap::new(),
|
||||
bids: HashMap::new(),
|
||||
timeout_ms,
|
||||
}
|
||||
}
|
||||
|
||||
/// Announce a new task (add to pending pool).
|
||||
pub fn announce_task(&mut self, task: SwarmTask) {
|
||||
let id = task.id;
|
||||
self.pending_tasks.insert(id, task);
|
||||
self.bids.entry(id).or_default();
|
||||
}
|
||||
|
||||
/// Accept a bid for a pending task.
|
||||
pub fn submit_bid(&mut self, bid: Bid) {
|
||||
if self.pending_tasks.contains_key(&bid.task_id) {
|
||||
self.bids.entry(bid.task_id).or_default().push(bid);
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve all pending tasks: assign each to the best bidder.
|
||||
/// Returns a list of (TaskId, winning NodeId) pairs.
|
||||
pub fn resolve(&mut self) -> Vec<(TaskId, NodeId)> {
|
||||
let mut results = Vec::new();
|
||||
let task_ids: Vec<TaskId> = self.pending_tasks.keys().copied().collect();
|
||||
|
||||
for task_id in task_ids {
|
||||
let winner = self
|
||||
.bids
|
||||
.get(&task_id)
|
||||
.and_then(|bids| {
|
||||
bids.iter()
|
||||
.min_by(|a, b| {
|
||||
a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|b| b.node_id)
|
||||
});
|
||||
|
||||
if let Some(winner_id) = winner {
|
||||
if let Some(task) = self.pending_tasks.get_mut(&task_id) {
|
||||
task.assigned_to = Some(winner_id);
|
||||
}
|
||||
results.push((task_id, winner_id));
|
||||
self.bids.remove(&task_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up resolved tasks
|
||||
for (tid, _) in &results {
|
||||
self.pending_tasks.remove(tid);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Compute a bid score heuristic for a node given a task.
|
||||
/// Returns a score ∈ [0, ∞): lower is better.
|
||||
pub fn compute_bid_score(node: &DroneState, task: &SwarmTask) -> f32 {
|
||||
let dist = node.position.distance_to(&task.target) as f32;
|
||||
let battery_penalty = (100.0 - node.battery_pct) / 100.0;
|
||||
let link_penalty = 1.0 - node.link_quality;
|
||||
let priority_bonus = 1.0 - task.priority.clamp(0.0, 1.0);
|
||||
dist / 100.0 + battery_penalty * 0.3 + link_penalty * 0.2 + priority_bonus * 0.1
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{Position3D, SwarmTask, TaskId, TaskKind};
|
||||
|
||||
fn make_task(id: u64) -> SwarmTask {
|
||||
SwarmTask {
|
||||
id: TaskId(id),
|
||||
kind: TaskKind::ReturnToHome,
|
||||
priority: 0.5,
|
||||
target: Position3D::zero(),
|
||||
deadline_ms: None,
|
||||
assigned_to: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auction_assigns_best_bidder() {
|
||||
let mut alloc = AuctionAllocator::new(1000);
|
||||
let task = make_task(1);
|
||||
alloc.announce_task(task);
|
||||
alloc.submit_bid(Bid { node_id: NodeId(1), task_id: TaskId(1), score: 0.8 });
|
||||
alloc.submit_bid(Bid { node_id: NodeId(2), task_id: TaskId(1), score: 0.3 });
|
||||
let results = alloc.resolve();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].1, NodeId(2)); // lower score wins
|
||||
}
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
//! Lightweight 3-layer FNN bid scorer — pure Rust, no ONNX required.
|
||||
|
||||
/// 3-layer FNN: 5 inputs → 16 hidden (ReLU) → 8 hidden (ReLU) → 1 output (sigmoid).
|
||||
pub struct FnnScorer {
|
||||
pub w1: [[f32; 5]; 16],
|
||||
pub b1: [f32; 16],
|
||||
pub w2: [[f32; 16]; 8],
|
||||
pub b2: [f32; 8],
|
||||
pub w3: [f32; 8],
|
||||
pub b3: f32,
|
||||
}
|
||||
|
||||
fn relu(x: f32) -> f32 {
|
||||
x.max(0.0)
|
||||
}
|
||||
|
||||
fn sigmoid(x: f32) -> f32 {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
}
|
||||
|
||||
impl FnnScorer {
|
||||
/// Score a feature vector. Returns sigmoid(output) ∈ [0, 1].
|
||||
/// Features: [dist_norm, battery_norm, link_quality, csi_confidence, workload_norm]
|
||||
pub fn score(&self, features: [f32; 5]) -> f32 {
|
||||
// Layer 1: 5 → 16 (ReLU)
|
||||
let mut h1 = [0.0f32; 16];
|
||||
for (i, row) in self.w1.iter().enumerate() {
|
||||
let z: f32 = row.iter().zip(features.iter()).map(|(w, x)| w * x).sum();
|
||||
h1[i] = relu(z + self.b1[i]);
|
||||
}
|
||||
|
||||
// Layer 2: 16 → 8 (ReLU)
|
||||
let mut h2 = [0.0f32; 8];
|
||||
for (i, row) in self.w2.iter().enumerate() {
|
||||
let z: f32 = row.iter().zip(h1.iter()).map(|(w, x)| w * x).sum();
|
||||
h2[i] = relu(z + self.b2[i]);
|
||||
}
|
||||
|
||||
// Layer 3: 8 → 1 (sigmoid)
|
||||
let z3: f32 = self.w3.iter().zip(h2.iter()).map(|(w, x)| w * x).sum::<f32>() + self.b3;
|
||||
sigmoid(z3)
|
||||
}
|
||||
|
||||
/// Default weights initialised to a simple identity-like setup.
|
||||
pub fn default_weights() -> Self {
|
||||
// Simple: w1 diagonalish, others small constant
|
||||
// Index needed: diagonal/strided init uses i for both row and column.
|
||||
let mut w1 = [[0.0f32; 5]; 16];
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..5 {
|
||||
w1[i][i] = 1.0;
|
||||
}
|
||||
for row in w1.iter_mut().take(16).skip(5) {
|
||||
row[0] = 0.1;
|
||||
}
|
||||
let mut w2 = [[0.0f32; 16]; 8];
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..8 {
|
||||
w2[i][i * 2] = 1.0;
|
||||
}
|
||||
let w3 = [0.125f32; 8];
|
||||
Self {
|
||||
w1,
|
||||
b1: [0.0; 16],
|
||||
w2,
|
||||
b2: [0.0; 8],
|
||||
w3,
|
||||
b3: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FnnScorer {
|
||||
fn default() -> Self {
|
||||
Self::default_weights()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_score_in_unit_interval() {
|
||||
let scorer = FnnScorer::default_weights();
|
||||
let features = [0.3f32, 0.8, 0.9, 0.75, 0.2];
|
||||
let s = scorer.score(features);
|
||||
assert!(s >= 0.0 && s <= 1.0, "score {s} out of [0,1]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_score_deterministic() {
|
||||
let scorer = FnnScorer::default_weights();
|
||||
let f = [0.5f32; 5];
|
||||
assert_eq!(scorer.score(f), scorer.score(f));
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
//! Task allocation: auction-based and FNN-scored bid evaluation.
|
||||
//!
|
||||
// NOTE: Task allocation is ITAR-controlled (USML Category VIII(h)(12)).
|
||||
// Only available when the `itar-unrestricted` feature is enabled.
|
||||
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub mod auction;
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub mod fnn;
|
||||
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub use auction::{AuctionAllocator, Bid};
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub use fnn::FnnScorer;
|
||||
|
||||
/// Stub: task allocation is export-controlled. Enable `itar-unrestricted` feature.
|
||||
#[cfg(not(feature = "itar-unrestricted"))]
|
||||
pub fn allocate_stub() -> crate::SwarmResult<()> {
|
||||
Err(crate::SwarmError::Security(
|
||||
"Task allocation requires itar-unrestricted feature (USML VIII(h)(12))".into(),
|
||||
))
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
//! Benchmark support utilities: scenario builders and timing helpers for criterion benchmarks.
|
||||
|
||||
use crate::types::{DroneState, NodeId, Position3D, Velocity3D};
|
||||
|
||||
/// Generate N drone states arranged in a grid.
|
||||
pub fn grid_drone_states(n: usize, spacing_m: f64) -> Vec<DroneState> {
|
||||
let side = (n as f64).sqrt().ceil() as usize;
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let row = i / side;
|
||||
let col = i % side;
|
||||
DroneState {
|
||||
id: NodeId(i as u32),
|
||||
position: Position3D {
|
||||
x: col as f64 * spacing_m,
|
||||
y: row as f64 * spacing_m,
|
||||
z: -30.0,
|
||||
},
|
||||
velocity: Velocity3D::default(),
|
||||
heading_rad: 0.0,
|
||||
altitude_agl_m: 30.0,
|
||||
battery_pct: 80.0,
|
||||
link_quality: 0.9,
|
||||
timestamp_ms: 0,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate N evenly-spaced positions in a circle.
|
||||
pub fn circle_positions(n: usize, radius_m: f64) -> Vec<(NodeId, Position3D)> {
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let angle = 2.0 * std::f64::consts::PI * i as f64 / n as f64;
|
||||
(
|
||||
NodeId(i as u32),
|
||||
Position3D {
|
||||
x: radius_m * angle.cos(),
|
||||
y: radius_m * angle.sin(),
|
||||
z: -30.0,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
//! ADR-171 Stage-1 evaluation CLI.
|
||||
//!
|
||||
//! Runs the kinematic eval matrix over every flight pattern (default) and
|
||||
//! writes a ranked `RESULTS.md` leaderboard. Pure Rust — no special feature
|
||||
//! flag required, so it builds and runs in default CI.
|
||||
//!
|
||||
//! Defaults are intentionally small (10 seeds × 10 episodes) so the run is fast.
|
||||
//! The full ADR-171 reporting configuration is 10 seeds × 50 episodes — pass
|
||||
//! `--seeds 10 --episodes 50` for the publication run.
|
||||
//!
|
||||
//! ```text
|
||||
//! cargo run -p ruview-swarm --bin eval_swarm -- \
|
||||
//! --seeds 10 --episodes 10 --out crates/ruview-swarm/evals/RESULTS.md
|
||||
//! ```
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use ruview_swarm::evals::metrics::AggregateMetrics;
|
||||
use ruview_swarm::evals::report::render_results_md;
|
||||
use ruview_swarm::evals::runner::{run_matrix, EvalConfig};
|
||||
use ruview_swarm::planning::patterns::FlightPattern;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let mut seeds = 10usize;
|
||||
let mut episodes = 10usize;
|
||||
let mut out = PathBuf::from("crates/ruview-swarm/evals/RESULTS.md");
|
||||
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--seeds" => {
|
||||
i += 1;
|
||||
seeds = args.get(i).and_then(|s| s.parse().ok()).unwrap_or(seeds);
|
||||
}
|
||||
"--episodes" => {
|
||||
i += 1;
|
||||
episodes = args.get(i).and_then(|s| s.parse().ok()).unwrap_or(episodes);
|
||||
}
|
||||
"--out" => {
|
||||
i += 1;
|
||||
if let Some(p) = args.get(i) {
|
||||
out = PathBuf::from(p);
|
||||
}
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
eprintln!(
|
||||
"eval_swarm — ADR-171 Stage-1 kinematic evaluator\n\
|
||||
Usage: eval_swarm [--seeds N] [--episodes M] [--out PATH]\n\
|
||||
Defaults: --seeds 10 --episodes 10 --out crates/ruview-swarm/evals/RESULTS.md"
|
||||
);
|
||||
return;
|
||||
}
|
||||
other => {
|
||||
eprintln!("warning: ignoring unknown argument '{other}'");
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"Running ADR-171 Stage-1 eval: {seeds} seeds × {episodes} episodes \
|
||||
over {} flight patterns...",
|
||||
FlightPattern::all().len()
|
||||
);
|
||||
|
||||
let mut rows: Vec<(String, AggregateMetrics)> = Vec::new();
|
||||
for (idx, pattern) in FlightPattern::all().into_iter().enumerate() {
|
||||
let mut cfg = EvalConfig::sar_small(pattern);
|
||||
cfg.seeds = seeds;
|
||||
cfg.episodes_per_seed = episodes;
|
||||
let matrix = run_matrix(&cfg);
|
||||
let agg = AggregateMetrics::from_strata(&matrix, 0x0149 ^ idx as u64);
|
||||
eprintln!(
|
||||
" {}: coverage IQM {:.3}, detection {:.0}%",
|
||||
pattern.name(),
|
||||
agg.coverage_iqm.point,
|
||||
agg.detection_rate * 100.0
|
||||
);
|
||||
rows.push((pattern.name().to_string(), agg));
|
||||
}
|
||||
|
||||
// Rank by descending coverage point estimate.
|
||||
rows.sort_by(|a, b| {
|
||||
b.1.coverage_iqm
|
||||
.point
|
||||
.partial_cmp(&a.1.coverage_iqm.point)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let md = render_results_md(&rows);
|
||||
|
||||
if let Some(parent) = out.parent() {
|
||||
if let Err(e) = std::fs::create_dir_all(parent) {
|
||||
eprintln!("error: could not create {}: {e}", parent.display());
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
if let Err(e) = std::fs::write(&out, &md) {
|
||||
eprintln!("error: could not write {}: {e}", out.display());
|
||||
std::process::exit(1);
|
||||
}
|
||||
eprintln!("Wrote {} ({} bytes).", out.display(), md.len());
|
||||
}
|
||||
@@ -1,474 +0,0 @@
|
||||
//! MARL training entry point for ruview-swarm (ADR-148 M4).
|
||||
//!
|
||||
//! Real Candle autodiff PPO training loop. Runs on CPU, or CUDA when built
|
||||
//! with `--features train,cuda` (local RTX 5080 or a GCP L4 instance).
|
||||
//!
|
||||
//! Movement is driven by a selectable `FlightPattern` (boustrophedon,
|
||||
//! partitioned, spiral, pheromone, potential, levy) and reward is shaped by a
|
||||
//! selectable `LearningPattern` (mappo, ippo, curiosity, meta). This makes each
|
||||
//! pattern produce visibly distinct trajectories + telemetry instead of every
|
||||
//! drone clustering on the orchestrator's internal coverage strategy.
|
||||
//!
|
||||
//! Usage:
|
||||
//! cargo run --release -p ruview-swarm --features train,cuda --bin train_marl -- \
|
||||
//! --episodes 5000 --drones 4 --profile sar \
|
||||
//! --flight-pattern partitioned --learn-pattern mappo_curiosity \
|
||||
//! --checkpoint-dir ./marl-checkpoints
|
||||
//!
|
||||
//! Right-sizing note: the policy is a 64→128→64 MLP. The bottleneck is
|
||||
//! environment-rollout throughput, not GPU matmul — an L4 + 16 vCPU beats an
|
||||
//! 8× A100 box for this workload at ~1/20th the cost. See scripts/gcp/.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use ruview_swarm::config::SwarmConfig;
|
||||
use ruview_swarm::integration::telemetry::{DroneFrame, TelemetryRecorder};
|
||||
use ruview_swarm::marl::candle_ppo::{CandlePpoConfig, CandleTrainer};
|
||||
use ruview_swarm::marl::learning::{shaped_reward, CuriosityModule, LearningPattern};
|
||||
use ruview_swarm::marl::observation::LocalObservation;
|
||||
use ruview_swarm::marl::reward::{RewardCalculator, RewardContext};
|
||||
use ruview_swarm::planning::patterns::{FlightPattern, PatternContext};
|
||||
use ruview_swarm::types::{DroneState, NodeId, Position3D, Velocity3D};
|
||||
|
||||
struct Args {
|
||||
episodes: usize,
|
||||
drones: usize,
|
||||
profile: String,
|
||||
steps_per_episode: usize,
|
||||
checkpoint_dir: String,
|
||||
checkpoint_every: usize,
|
||||
telemetry: Option<String>,
|
||||
telemetry_episode: usize,
|
||||
flight_pattern: String,
|
||||
learn_pattern: String,
|
||||
}
|
||||
|
||||
impl Default for Args {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
episodes: 1000,
|
||||
drones: 4,
|
||||
profile: "sar".to_string(),
|
||||
steps_per_episode: 200,
|
||||
checkpoint_dir: "./marl-checkpoints".to_string(),
|
||||
checkpoint_every: 100,
|
||||
telemetry: None,
|
||||
telemetry_episode: 0,
|
||||
flight_pattern: "partitioned".to_string(),
|
||||
learn_pattern: "mappo".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut args = Args::default();
|
||||
let argv: Vec<String> = std::env::args().collect();
|
||||
let mut i = 1;
|
||||
while i < argv.len() {
|
||||
let next = || argv.get(i + 1).cloned().unwrap_or_default();
|
||||
match argv[i].as_str() {
|
||||
"--episodes" => {
|
||||
args.episodes = next().parse().unwrap_or(args.episodes);
|
||||
i += 1;
|
||||
}
|
||||
"--drones" => {
|
||||
args.drones = next().parse().unwrap_or(args.drones);
|
||||
i += 1;
|
||||
}
|
||||
"--profile" => {
|
||||
args.profile = next();
|
||||
i += 1;
|
||||
}
|
||||
"--steps" => {
|
||||
args.steps_per_episode = next().parse().unwrap_or(args.steps_per_episode);
|
||||
i += 1;
|
||||
}
|
||||
"--checkpoint-dir" => {
|
||||
args.checkpoint_dir = next();
|
||||
i += 1;
|
||||
}
|
||||
"--checkpoint-every" => {
|
||||
args.checkpoint_every = next().parse().unwrap_or(args.checkpoint_every);
|
||||
i += 1;
|
||||
}
|
||||
"--telemetry" => {
|
||||
args.telemetry = Some(next());
|
||||
i += 1;
|
||||
}
|
||||
"--telemetry-episode" => {
|
||||
args.telemetry_episode = next().parse().unwrap_or(args.telemetry_episode);
|
||||
i += 1;
|
||||
}
|
||||
"--flight-pattern" => {
|
||||
args.flight_pattern = next();
|
||||
i += 1;
|
||||
}
|
||||
"--learn-pattern" => {
|
||||
args.learn_pattern = next();
|
||||
i += 1;
|
||||
}
|
||||
"-h" | "--help" => {
|
||||
println!(
|
||||
"train_marl — ruview-swarm MARL training (ADR-148 M4)\n\
|
||||
\nOptions:\n \
|
||||
--episodes N training episodes (default 1000)\n \
|
||||
--drones N swarm size (default 4)\n \
|
||||
--profile NAME sar|inspection|mine|agriculture (default sar)\n \
|
||||
--steps N steps per episode (default 200)\n \
|
||||
--flight-pattern P boustrophedon|partitioned|spiral|pheromone|potential|levy (default partitioned)\n \
|
||||
--learn-pattern P mappo|ippo|curiosity|meta (default mappo)\n \
|
||||
--checkpoint-dir D checkpoint output dir (default ./marl-checkpoints)\n \
|
||||
--checkpoint-every N save every N episodes (default 100)\n \
|
||||
--telemetry FILE write JSONL telemetry for viz/swarm_viz.html\n \
|
||||
--telemetry-episode N which episode's steps to record spatially (default 0)"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => eprintln!("warning: ignoring unknown arg {other}"),
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
args
|
||||
}
|
||||
|
||||
fn config_for(profile: &str) -> SwarmConfig {
|
||||
match profile {
|
||||
"inspection" => SwarmConfig::inspection_default(),
|
||||
"mine" => SwarmConfig::mine_default(),
|
||||
"agriculture" => SwarmConfig::agriculture_default(),
|
||||
_ => SwarmConfig::wi2sar_reference(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a world coordinate to a grid cell index at `grid_res` metre resolution.
|
||||
fn cell_of(x: f64, y: f64, grid_res: f64) -> (u32, u32) {
|
||||
let gx = (x / grid_res).floor().max(0.0) as u32;
|
||||
let gy = (y / grid_res).floor().max(0.0) as u32;
|
||||
(gx, gy)
|
||||
}
|
||||
|
||||
/// Mark every grid cell within the drone's circular scan footprint as scanned,
|
||||
/// returning how many *newly* scanned cells this step contributed.
|
||||
fn mark_scanned(
|
||||
scanned: &mut HashSet<(u32, u32)>,
|
||||
pos: &Position3D,
|
||||
scan_width_m: f64,
|
||||
grid_res: f64,
|
||||
area_w: f64,
|
||||
area_h: f64,
|
||||
) -> u32 {
|
||||
let r = scan_width_m * 0.5;
|
||||
let cols = (area_w / grid_res).ceil() as i64;
|
||||
let rows = (area_h / grid_res).ceil() as i64;
|
||||
let (cx, cy) = cell_of(pos.x, pos.y, grid_res);
|
||||
let span = (r / grid_res).ceil() as i64;
|
||||
let mut new_cells = 0u32;
|
||||
for dgx in -span..=span {
|
||||
for dgy in -span..=span {
|
||||
let gx = cx as i64 + dgx;
|
||||
let gy = cy as i64 + dgy;
|
||||
if gx < 0 || gy < 0 || gx >= cols || gy >= rows {
|
||||
continue;
|
||||
}
|
||||
// Cell centre in metres.
|
||||
let mx = (gx as f64 + 0.5) * grid_res;
|
||||
let my = (gy as f64 + 0.5) * grid_res;
|
||||
if (mx - pos.x).hypot(my - pos.y) <= r && scanned.insert((gx as u32, gy as u32)) {
|
||||
new_cells += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
new_cells
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = parse_args();
|
||||
let cfg = config_for(&args.profile);
|
||||
let flight_pattern = FlightPattern::from_str(&args.flight_pattern);
|
||||
let learn_pattern = LearningPattern::from_str(&args.learn_pattern);
|
||||
|
||||
println!(
|
||||
"MARL training: profile={} drones={} episodes={} steps/ep={} flight={} learn={} ({})",
|
||||
args.profile,
|
||||
args.drones,
|
||||
args.episodes,
|
||||
args.steps_per_episode,
|
||||
flight_pattern.name(),
|
||||
learn_pattern.name(),
|
||||
if learn_pattern.centralized_critic() {
|
||||
"CTDE / centralized critic"
|
||||
} else {
|
||||
"independent learners"
|
||||
}
|
||||
);
|
||||
|
||||
let ppo_cfg = CandlePpoConfig::default();
|
||||
let mut trainer = CandleTrainer::new(ppo_cfg)?;
|
||||
println!("device: {:?}", trainer.net.device());
|
||||
|
||||
let reward_calc = RewardCalculator::default();
|
||||
std::fs::create_dir_all(&args.checkpoint_dir).ok();
|
||||
|
||||
let area_w = cfg.mission.area_width_m;
|
||||
let area_h = cfg.mission.area_height_m;
|
||||
let grid_res = cfg.mission.grid_resolution_m.max(1.0);
|
||||
let scan_w = cfg.planning.csi_scan_width_m;
|
||||
let max_speed = cfg.planning.max_speed_ms.max(0.1);
|
||||
let altitude_z = -cfg.planning.flight_altitude_m;
|
||||
let total_cells = ((area_w / grid_res).ceil() * (area_h / grid_res).ceil()).max(1.0);
|
||||
|
||||
// Synthetic victims placed within the mission area for reward signal.
|
||||
let victims = vec![
|
||||
Position3D { x: area_w * 0.2, y: area_h * 0.3, z: 0.0 },
|
||||
Position3D { x: area_w * 0.6, y: area_h * 0.45, z: 0.0 },
|
||||
];
|
||||
|
||||
// Composite profile label so the viewer header surfaces the active patterns.
|
||||
let profile_label = format!(
|
||||
"{} · flight={} · learn={}",
|
||||
args.profile,
|
||||
flight_pattern.name(),
|
||||
learn_pattern.name()
|
||||
);
|
||||
|
||||
// Optional telemetry recorder for the visualizer.
|
||||
let mut telem = match &args.telemetry {
|
||||
Some(path) => {
|
||||
let mut rec = TelemetryRecorder::create(path)?;
|
||||
rec.meta(&profile_label, args.drones, area_w, area_h, &victims)?;
|
||||
println!("telemetry → {path} (spatial steps from episode {})", args.telemetry_episode);
|
||||
Some(rec)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut best_return = f32::MIN;
|
||||
|
||||
for episode in 0..args.episodes {
|
||||
// Per-episode curiosity module (count-based novelty over the area).
|
||||
let mut curiosity = CuriosityModule::new(area_w, area_h, 32, 0.5);
|
||||
|
||||
// Build drone states directly so the FlightPattern fully drives motion.
|
||||
let cols = (args.drones as f64).sqrt().ceil().max(1.0) as usize;
|
||||
let mut states: Vec<DroneState> = (0..args.drones)
|
||||
.map(|d| {
|
||||
let (row, col) = (d / cols, d % cols);
|
||||
let mut s = DroneState::default_at_origin(NodeId(d as u32));
|
||||
s.position = Position3D {
|
||||
x: 10.0 + col as f64 * (area_w / cols as f64),
|
||||
y: 10.0 + row as f64 * (area_h / cols.max(1) as f64),
|
||||
z: altitude_z,
|
||||
};
|
||||
s.altitude_agl_m = cfg.planning.flight_altitude_m;
|
||||
s
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Coverage tracker (shared across drones — total area scanned).
|
||||
let mut scanned: HashSet<(u32, u32)> = HashSet::new();
|
||||
// Rolling recent-positions trail for pheromone/potential patterns.
|
||||
let mut visited: Vec<Position3D> = Vec::with_capacity(256);
|
||||
|
||||
// Rollout buffers (flattened across drones).
|
||||
let mut obs_buf: Vec<LocalObservation> = Vec::new();
|
||||
let mut action_buf: Vec<[f32; 4]> = Vec::new();
|
||||
let mut reward_buf: Vec<f32> = Vec::new();
|
||||
let mut value_buf: Vec<f32> = Vec::new();
|
||||
let mut done_buf: Vec<bool> = Vec::new();
|
||||
|
||||
for step in 0..args.steps_per_episode {
|
||||
let is_last = step == args.steps_per_episode - 1;
|
||||
|
||||
// Snapshot peer positions for this tick (observations + repulsion).
|
||||
let positions: Vec<(NodeId, Position3D)> =
|
||||
states.iter().map(|s| (s.id, s.position)).collect();
|
||||
|
||||
// Index needed: mutates states[idx] while reading peer positions; borrow constraints.
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for idx in 0..states.len() {
|
||||
let prev_pos = states[idx].position;
|
||||
let node_id = states[idx].id;
|
||||
|
||||
// Neighbour positions (everyone except this drone).
|
||||
let neighbors: Vec<(NodeId, Position3D)> = positions
|
||||
.iter()
|
||||
.filter(|(id, _)| *id != node_id)
|
||||
.cloned()
|
||||
.collect();
|
||||
let peers: Vec<Position3D> = neighbors.iter().map(|(_, p)| *p).collect();
|
||||
|
||||
// Observation from the current (pre-move) state.
|
||||
let obs =
|
||||
LocalObservation::from_state_no_grid(&states[idx], &neighbors, None, None);
|
||||
|
||||
// --- FlightPattern drives the next waypoint --------------------
|
||||
let ctx = PatternContext {
|
||||
drone_id: node_id,
|
||||
swarm_size: args.drones,
|
||||
current: prev_pos,
|
||||
area_w,
|
||||
area_h,
|
||||
altitude_z,
|
||||
scan_width_m: scan_w,
|
||||
step: step as u64,
|
||||
visited: &visited,
|
||||
peers: &peers,
|
||||
};
|
||||
let target = flight_pattern.next_target(&ctx);
|
||||
|
||||
// Move one tick toward the target at max_speed (no teleport).
|
||||
let dx = target.x - prev_pos.x;
|
||||
let dy = target.y - prev_pos.y;
|
||||
let dist = dx.hypot(dy);
|
||||
let new_pos = if dist > 1e-9 {
|
||||
let stepd = dist.min(max_speed);
|
||||
Position3D {
|
||||
x: prev_pos.x + dx / dist * stepd,
|
||||
y: prev_pos.y + dy / dist * stepd,
|
||||
z: altitude_z,
|
||||
}
|
||||
} else {
|
||||
prev_pos
|
||||
};
|
||||
let heading = if dist > 1e-9 { dy.atan2(dx) } else { states[idx].heading_rad };
|
||||
let moved = prev_pos.distance_to(&new_pos);
|
||||
|
||||
// Commit the move to the drone state.
|
||||
{
|
||||
let s = &mut states[idx];
|
||||
s.velocity = Velocity3D {
|
||||
vx: (new_pos.x - prev_pos.x),
|
||||
vy: (new_pos.y - prev_pos.y),
|
||||
vz: 0.0,
|
||||
};
|
||||
s.position = new_pos;
|
||||
s.heading_rad = heading;
|
||||
s.timestamp_ms = s.timestamp_ms.saturating_add(1000);
|
||||
}
|
||||
|
||||
// Coverage: mark scanned footprint, count new cells.
|
||||
let new_cells =
|
||||
mark_scanned(&mut scanned, &new_pos, scan_w, grid_res, area_w, area_h);
|
||||
|
||||
// Detection: any victim within the scan footprint.
|
||||
let detected = victims.iter().any(|v| new_pos.distance_to(v) < scan_w);
|
||||
|
||||
// Nearest-neighbour distance (for collision shaping).
|
||||
let nearest = peers
|
||||
.iter()
|
||||
.map(|p| new_pos.distance_to(p))
|
||||
.fold(f64::MAX, f64::min);
|
||||
|
||||
// Base extrinsic reward.
|
||||
let ctx_r = RewardContext {
|
||||
state: &states[idx],
|
||||
new_cells_covered: new_cells,
|
||||
victim_confirmed: detected,
|
||||
contributed_to_triangulation: false,
|
||||
nearest_neighbor_dist: nearest,
|
||||
geofence_breached: false,
|
||||
battery_depleted_without_rth: false,
|
||||
};
|
||||
let base = reward_calc.compute(&ctx_r);
|
||||
|
||||
// Curiosity shaping (only when the learning pattern uses it).
|
||||
let reward = if learn_pattern.uses_curiosity() {
|
||||
let bonus = curiosity.visit_bonus(new_pos.x, new_pos.y);
|
||||
shaped_reward(learn_pattern, base, bonus)
|
||||
} else {
|
||||
base
|
||||
};
|
||||
|
||||
let action = [
|
||||
heading as f32,
|
||||
states[idx].altitude_agl_m as f32,
|
||||
(moved / 1.0) as f32,
|
||||
0.0,
|
||||
];
|
||||
|
||||
obs_buf.push(obs);
|
||||
action_buf.push(action);
|
||||
reward_buf.push(reward);
|
||||
value_buf.push(0.0); // bootstrap value (critic learns this)
|
||||
done_buf.push(is_last);
|
||||
|
||||
// Record the move in the shared visited trail (cap length).
|
||||
visited.push(new_pos);
|
||||
}
|
||||
|
||||
// Trim the visited trail to the most recent ~200 positions.
|
||||
if visited.len() > 200 {
|
||||
let drop = visited.len() - 200;
|
||||
visited.drain(0..drop);
|
||||
}
|
||||
|
||||
// Record spatial telemetry for the selected episode only.
|
||||
if let Some(rec) = telem.as_mut() {
|
||||
if episode == args.telemetry_episode {
|
||||
let frames: Vec<DroneFrame> = states
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let detected =
|
||||
victims.iter().any(|v| s.position.distance_to(v) < scan_w);
|
||||
DroneFrame::from_state(s, detected)
|
||||
})
|
||||
.collect();
|
||||
let coverage = scanned.len() as f64 / total_cells;
|
||||
let _ = rec.step(episode, step, step as f64, &frames, coverage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PPO update on the episode's rollout.
|
||||
let (advantages, returns) = trainer.compute_gae(&reward_buf, &value_buf, &done_buf);
|
||||
let old_log_probs = vec![0.0f32; obs_buf.len()];
|
||||
let (policy_loss, value_loss, _entropy) =
|
||||
trainer.update(&obs_buf, &action_buf, &advantages, &returns, &old_log_probs)?;
|
||||
|
||||
let mean_return = if returns.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
returns.iter().sum::<f32>() / returns.len() as f32
|
||||
};
|
||||
|
||||
if mean_return > best_return {
|
||||
best_return = mean_return;
|
||||
}
|
||||
|
||||
// Per-episode training-metric telemetry (every episode).
|
||||
if let Some(rec) = telem.as_mut() {
|
||||
let _ = rec.episode(episode, mean_return, policy_loss, value_loss, 0);
|
||||
}
|
||||
|
||||
if episode % 10 == 0 || episode == args.episodes - 1 {
|
||||
let coverage_pct = scanned.len() as f64 / total_cells * 100.0;
|
||||
println!(
|
||||
"ep {:>5}/{} mean_return={:>8.3} best={:>8.3} policy_loss={:>8.4} value_loss={:>8.4} coverage={:>5.1}%",
|
||||
episode, args.episodes, mean_return, best_return, policy_loss, value_loss, coverage_pct
|
||||
);
|
||||
}
|
||||
|
||||
// Checkpoint the trained variables periodically.
|
||||
if args.checkpoint_every > 0 && (episode + 1) % args.checkpoint_every == 0
|
||||
|| episode == args.episodes - 1
|
||||
{
|
||||
let path = format!("{}/marl-ep{}.safetensors", args.checkpoint_dir, episode + 1);
|
||||
if let Err(e) = trainer.net.varmap().save(&path) {
|
||||
eprintln!("checkpoint save failed at {path}: {e}");
|
||||
} else {
|
||||
println!("checkpoint saved: {path}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(rec) = telem.as_mut() {
|
||||
rec.flush()?;
|
||||
if let Some(path) = &args.telemetry {
|
||||
println!("telemetry written: {path} — open viz/swarm_viz.html and load it");
|
||||
}
|
||||
}
|
||||
|
||||
println!("training complete. best mean_return={best_return:.3}");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
//! TOML-based swarm configuration with mission profiles.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SwarmConfig {
|
||||
pub swarm: SwarmParams,
|
||||
pub formation: FormationConfig,
|
||||
pub planning: PlanningConfig,
|
||||
pub security: SecurityConfig,
|
||||
pub mission: MissionConfig,
|
||||
pub demo: Option<DemoConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SwarmParams {
|
||||
pub max_agents: usize,
|
||||
pub cluster_size: usize,
|
||||
pub raft_election_timeout_ms: u64,
|
||||
pub raft_heartbeat_ms: u64,
|
||||
pub gossip_fanout: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FormationConfig {
|
||||
/// "virtual_structure" | "leader_follower" | "reynolds"
|
||||
pub mode: String,
|
||||
pub min_separation_m: f64,
|
||||
pub grid_spacing_m: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlanningConfig {
|
||||
pub flight_altitude_m: f64,
|
||||
pub max_speed_ms: f64,
|
||||
/// Wi2SAR validated scan footprint width.
|
||||
pub csi_scan_width_m: f64,
|
||||
pub lateral_overlap_pct: f64,
|
||||
/// P(victim) threshold to trigger Phase 3 convergence.
|
||||
pub convergence_threshold: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SecurityConfig {
|
||||
pub mavlink_signing: bool,
|
||||
pub uwb_antispoofing: bool,
|
||||
pub uwb_tolerance_m: f64,
|
||||
pub geofence_hard_margin_m: f64,
|
||||
pub geofence_soft_margin_m: f64,
|
||||
/// Remote ID broadcast rate in Hz (FAA/EU requirement: ≥ 1 Hz).
|
||||
pub remote_id_broadcast_hz: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MissionConfig {
|
||||
/// "sar" | "inspection" | "agriculture" | "mine" | "relay"
|
||||
pub profile: String,
|
||||
pub area_width_m: f64,
|
||||
pub area_height_m: f64,
|
||||
pub grid_resolution_m: f64,
|
||||
pub max_flight_time_mins: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DemoConfig {
|
||||
pub synthetic_csi: bool,
|
||||
/// Victim positions in NED [x, y, z].
|
||||
pub victim_positions: Vec<[f64; 3]>,
|
||||
pub wind_noise_ms: f64,
|
||||
pub csi_noise_std: f64,
|
||||
pub packet_loss_pct: f64,
|
||||
pub replay_speed: f64,
|
||||
}
|
||||
|
||||
impl SwarmConfig {
|
||||
pub fn from_toml_str(s: &str) -> Result<Self, toml::de::Error> {
|
||||
toml::from_str(s)
|
||||
}
|
||||
|
||||
pub fn sar_default() -> Self {
|
||||
Self {
|
||||
swarm: SwarmParams {
|
||||
max_agents: 12,
|
||||
cluster_size: 4,
|
||||
raft_election_timeout_ms: 300,
|
||||
raft_heartbeat_ms: 100,
|
||||
gossip_fanout: 3,
|
||||
},
|
||||
formation: FormationConfig {
|
||||
mode: "virtual_structure".into(),
|
||||
min_separation_m: 5.0,
|
||||
grid_spacing_m: 20.0,
|
||||
},
|
||||
planning: PlanningConfig {
|
||||
flight_altitude_m: 30.0,
|
||||
max_speed_ms: 8.0,
|
||||
csi_scan_width_m: 28.0,
|
||||
lateral_overlap_pct: 20.0,
|
||||
convergence_threshold: 0.75,
|
||||
},
|
||||
security: SecurityConfig {
|
||||
mavlink_signing: true,
|
||||
uwb_antispoofing: true,
|
||||
uwb_tolerance_m: 2.0,
|
||||
geofence_hard_margin_m: 20.0,
|
||||
geofence_soft_margin_m: 50.0,
|
||||
remote_id_broadcast_hz: 1.0,
|
||||
},
|
||||
mission: MissionConfig {
|
||||
profile: "sar".into(),
|
||||
area_width_m: 500.0,
|
||||
area_height_m: 500.0,
|
||||
grid_resolution_m: 5.0,
|
||||
max_flight_time_mins: 25.0,
|
||||
},
|
||||
demo: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inspection_default() -> Self {
|
||||
let mut cfg = Self::sar_default();
|
||||
cfg.mission.profile = "inspection".into();
|
||||
cfg.planning.flight_altitude_m = 15.0;
|
||||
cfg.planning.max_speed_ms = 4.0;
|
||||
cfg.formation.mode = "leader_follower".into();
|
||||
cfg
|
||||
}
|
||||
|
||||
pub fn agriculture_default() -> Self {
|
||||
let mut cfg = Self::sar_default();
|
||||
cfg.mission.profile = "agriculture".into();
|
||||
cfg.planning.flight_altitude_m = 10.0;
|
||||
cfg.planning.max_speed_ms = 6.0;
|
||||
cfg.planning.csi_scan_width_m = 15.0;
|
||||
cfg.formation.mode = "virtual_structure".into();
|
||||
cfg.formation.grid_spacing_m = 12.0;
|
||||
cfg
|
||||
}
|
||||
|
||||
pub fn mine_default() -> Self {
|
||||
let mut cfg = Self::sar_default();
|
||||
cfg.mission.profile = "mine".into();
|
||||
cfg.planning.flight_altitude_m = 5.0;
|
||||
cfg.planning.max_speed_ms = 2.0;
|
||||
cfg.security.uwb_antispoofing = true; // GPS-denied: UWB only
|
||||
cfg
|
||||
}
|
||||
|
||||
/// Wi2SAR reference configuration (400×400 m, 8 m/s, 4 drones) for ADR-148 SOTA benchmark.
|
||||
/// Produces 223 s coverage estimate — below the 240 s (4-min) SOTA target.
|
||||
/// Source: Wi2SAR (arxiv 2604.09115): single drone, 160,000 m², 13.5 min.
|
||||
pub fn wi2sar_reference() -> Self {
|
||||
let mut cfg = Self::sar_default();
|
||||
cfg.mission.area_width_m = 400.0;
|
||||
cfg.mission.area_height_m = 400.0;
|
||||
cfg.planning.max_speed_ms = 8.0;
|
||||
cfg.planning.csi_scan_width_m = 28.0;
|
||||
cfg.planning.lateral_overlap_pct = 20.0;
|
||||
cfg
|
||||
}
|
||||
|
||||
pub fn demo_default() -> Self {
|
||||
let mut cfg = Self::sar_default();
|
||||
cfg.demo = Some(DemoConfig {
|
||||
synthetic_csi: true,
|
||||
victim_positions: vec![[50.0, 80.0, 0.0], [150.0, 200.0, 0.0], [300.0, 100.0, 0.0]],
|
||||
wind_noise_ms: 2.0,
|
||||
csi_noise_std: 0.05,
|
||||
packet_loss_pct: 5.0,
|
||||
replay_speed: 1.0,
|
||||
});
|
||||
cfg
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sar_default_serialization() {
|
||||
let cfg = SwarmConfig::sar_default();
|
||||
let toml_str = toml::to_string(&cfg).expect("serialize ok");
|
||||
let parsed = SwarmConfig::from_toml_str(&toml_str).expect("parse ok");
|
||||
assert_eq!(parsed.mission.profile, "sar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_demo_default_has_victims() {
|
||||
let cfg = SwarmConfig::demo_default();
|
||||
assert!(cfg.demo.is_some());
|
||||
assert_eq!(cfg.demo.unwrap().victim_positions.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wi2sar_reference_coverage_within_4min() {
|
||||
use crate::demo::scenario::DemoScenario;
|
||||
let scenario = DemoScenario {
|
||||
name: "Wi2SAR Reference".into(),
|
||||
config: SwarmConfig::wi2sar_reference(),
|
||||
num_drones: 4,
|
||||
victims: vec![],
|
||||
};
|
||||
let t = scenario.estimate_coverage_time_secs();
|
||||
assert!(t < 240.0, "4-drone Wi2SAR reference scenario: {}s should be < 240s (4 min SOTA)", t);
|
||||
}
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
//! Demo scenario runner — synthetic CSI with configurable victim positions.
|
||||
//!
|
||||
//! Wires together a [`SyntheticCsiGenerator`] and pre-built [`DemoScenario`]
|
||||
//! definitions for rapid scenario validation without real hardware.
|
||||
|
||||
pub mod synthetic_csi;
|
||||
pub mod scenario;
|
||||
|
||||
pub use synthetic_csi::SyntheticCsiGenerator;
|
||||
pub use scenario::{DemoScenario, ScenarioResult};
|
||||
@@ -1,150 +0,0 @@
|
||||
//! Pre-built demo scenarios for rapid validation without hardware.
|
||||
//!
|
||||
//! Each scenario bundles a [`SwarmConfig`], victim positions, and a
|
||||
//! [`SyntheticCsiGenerator`] so integration tests can drive a complete
|
||||
//! swarm sim-loop with one call.
|
||||
|
||||
use crate::{
|
||||
config::SwarmConfig,
|
||||
types::Position3D,
|
||||
};
|
||||
use super::synthetic_csi::SyntheticCsiGenerator;
|
||||
|
||||
/// A self-contained demo scenario.
|
||||
pub struct DemoScenario {
|
||||
pub name: String,
|
||||
pub config: SwarmConfig,
|
||||
pub num_drones: usize,
|
||||
pub victims: Vec<Position3D>,
|
||||
}
|
||||
|
||||
/// Aggregate results produced after running a scenario.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScenarioResult {
|
||||
pub victims_found: usize,
|
||||
pub victims_total: usize,
|
||||
pub coverage_time_secs: f64,
|
||||
pub localization_error_m: f64,
|
||||
pub collision_count: u32,
|
||||
}
|
||||
|
||||
impl DemoScenario {
|
||||
/// Standard SAR rubble-field: 3 victims in a 400 × 400 m area.
|
||||
pub fn sar_rubble_field(num_drones: usize) -> Self {
|
||||
Self {
|
||||
name: "SAR Rubble Field".into(),
|
||||
config: SwarmConfig::demo_default(),
|
||||
num_drones,
|
||||
victims: vec![
|
||||
Position3D { x: 50.0, y: 80.0, z: 0.0 },
|
||||
Position3D { x: 150.0, y: 200.0, z: 0.0 },
|
||||
Position3D { x: 300.0, y: 100.0, z: 0.0 },
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Open-field search: single victim, easy detection conditions.
|
||||
pub fn open_field_search(num_drones: usize) -> Self {
|
||||
Self {
|
||||
name: "Open Field Search".into(),
|
||||
config: SwarmConfig::demo_default(),
|
||||
num_drones,
|
||||
victims: vec![
|
||||
Position3D { x: 200.0, y: 150.0, z: 0.0 },
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Mine/GPS-denied: victims in a narrow corridor, low speed.
|
||||
pub fn mine_corridor(num_drones: usize) -> Self {
|
||||
let mut cfg = SwarmConfig::mine_default();
|
||||
cfg.demo = Some(crate::config::DemoConfig {
|
||||
synthetic_csi: true,
|
||||
victim_positions: vec![[30.0, 10.0, -2.0], [80.0, 15.0, -2.0]],
|
||||
wind_noise_ms: 0.1,
|
||||
csi_noise_std: 0.08,
|
||||
packet_loss_pct: 10.0,
|
||||
replay_speed: 0.5,
|
||||
});
|
||||
Self {
|
||||
name: "Mine Corridor GPS-Denied".into(),
|
||||
config: cfg,
|
||||
num_drones,
|
||||
victims: vec![
|
||||
Position3D { x: 30.0, y: 10.0, z: -2.0 },
|
||||
Position3D { x: 80.0, y: 15.0, z: -2.0 },
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a [`SyntheticCsiGenerator`] from this scenario's config and victims.
|
||||
pub fn make_csi_generator(&self) -> SyntheticCsiGenerator {
|
||||
let (noise_std, detection_range_m) = self.config.demo.as_ref().map(|d| {
|
||||
(d.csi_noise_std, self.config.planning.csi_scan_width_m / 2.0)
|
||||
}).unwrap_or((0.05, 14.0));
|
||||
|
||||
SyntheticCsiGenerator::new(self.victims.clone(), noise_std, detection_range_m)
|
||||
}
|
||||
|
||||
/// Analytic estimate of coverage time (seconds) for this scenario.
|
||||
///
|
||||
/// Formula: `area / (scan_strip × drones) / speed`
|
||||
///
|
||||
/// where `scan_strip = csi_scan_width_m × (1 − lateral_overlap / 100)`.
|
||||
pub fn estimate_coverage_time_secs(&self) -> f64 {
|
||||
let p = &self.config.planning;
|
||||
let m = &self.config.mission;
|
||||
let area = m.area_width_m * m.area_height_m;
|
||||
let scan_strip = p.csi_scan_width_m * (1.0 - p.lateral_overlap_pct / 100.0);
|
||||
if scan_strip <= 0.0 || p.max_speed_ms <= 0.0 || self.num_drones == 0 {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
let total_track_m = area / scan_strip;
|
||||
let per_drone_track = total_track_m / self.num_drones as f64;
|
||||
per_drone_track / p.max_speed_ms
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sar_scenario_coverage_estimate_within_10min() {
|
||||
// 4-drone SAR swarm over 500 × 500 m at 8 m/s, 20% overlap, 28 m scan width.
|
||||
// Analytic upper bound: area / (scan_strip × drones × speed)
|
||||
// = 250_000 / (22.4 × 4 × 8) ≈ 349 s (< 600 s = 10 min battery limit).
|
||||
let scenario = DemoScenario::sar_rubble_field(4);
|
||||
let t = scenario.estimate_coverage_time_secs();
|
||||
assert!(
|
||||
t < 600.0,
|
||||
"4-drone SAR coverage estimate {t:.1} s exceeds 600 s (10 min) battery limit"
|
||||
);
|
||||
// Also verify the estimate is positive and finite.
|
||||
assert!(t > 0.0 && t.is_finite(), "coverage estimate {t} must be positive and finite");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_open_field_single_victim() {
|
||||
let scenario = DemoScenario::open_field_search(2);
|
||||
assert_eq!(scenario.victims.len(), 1);
|
||||
assert_eq!(scenario.num_drones, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mine_scenario_low_speed() {
|
||||
let scenario = DemoScenario::mine_corridor(2);
|
||||
assert!(
|
||||
scenario.config.planning.max_speed_ms <= 3.0,
|
||||
"mine scenario max speed should be ≤ 3 m/s, got {}",
|
||||
scenario.config.planning.max_speed_ms
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_make_csi_generator_victims_match() {
|
||||
let scenario = DemoScenario::sar_rubble_field(4);
|
||||
let gen = scenario.make_csi_generator();
|
||||
assert_eq!(gen.victims.len(), scenario.victims.len());
|
||||
}
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
//! Synthetic CSI generator — simulates WiFi CSI victim detections without hardware.
|
||||
//!
|
||||
//! Uses exponential distance decay and configurable Gaussian noise to produce
|
||||
//! realistic CsiDetection events for scenario testing and demo mode.
|
||||
|
||||
use rand::Rng;
|
||||
use crate::types::{CsiDetection, NodeId, Position3D};
|
||||
|
||||
/// Generates synthetic CSI detection events for a set of victim positions.
|
||||
pub struct SyntheticCsiGenerator {
|
||||
/// Ground-truth victim positions in NED metres.
|
||||
pub victims: Vec<Position3D>,
|
||||
/// Std-dev of additive Gaussian noise on confidence and position estimate.
|
||||
pub noise_std: f64,
|
||||
/// Maximum range (metres) at which a drone can detect a victim.
|
||||
pub detection_range_m: f64,
|
||||
}
|
||||
|
||||
impl SyntheticCsiGenerator {
|
||||
pub fn new(victims: Vec<Position3D>, noise_std: f64, detection_range_m: f64) -> Self {
|
||||
Self { victims, noise_std, detection_range_m }
|
||||
}
|
||||
|
||||
/// Attempt to detect a victim from the given drone position.
|
||||
///
|
||||
/// Returns the strongest detection within range, or `None` if no victim
|
||||
/// is within `detection_range_m`. Confidence is modelled as
|
||||
/// `exp(-dist / range)` plus zero-mean Gaussian noise.
|
||||
pub fn detect(
|
||||
&self,
|
||||
drone_id: NodeId,
|
||||
drone_pos: &Position3D,
|
||||
timestamp_ms: u64,
|
||||
) -> Option<CsiDetection> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut best: Option<CsiDetection> = None;
|
||||
|
||||
for victim in &self.victims {
|
||||
let dist = drone_pos.distance_to(victim);
|
||||
if dist >= self.detection_range_m {
|
||||
continue;
|
||||
}
|
||||
// Exponential decay: full confidence at 0 m, ~37% at 1× range
|
||||
let base_conf = (-dist / self.detection_range_m).exp();
|
||||
let noise: f64 = rng.gen_range(-self.noise_std..self.noise_std);
|
||||
let confidence = (base_conf + noise).clamp(0.0, 1.0) as f32;
|
||||
|
||||
if confidence <= 0.4 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add positional noise proportional to noise_std
|
||||
let pos_jitter = self.noise_std * 10.0;
|
||||
let est_pos = Position3D {
|
||||
x: victim.x + rng.gen_range(-pos_jitter..pos_jitter),
|
||||
y: victim.y + rng.gen_range(-pos_jitter..pos_jitter),
|
||||
z: victim.z,
|
||||
};
|
||||
|
||||
let det = CsiDetection {
|
||||
drone_id,
|
||||
confidence,
|
||||
victim_position: Some(est_pos),
|
||||
timestamp_ms,
|
||||
};
|
||||
|
||||
// Keep the highest-confidence detection
|
||||
match &best {
|
||||
None => best = Some(det),
|
||||
Some(b) if det.confidence > b.confidence => best = Some(det),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
best
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_detect_close_victim() {
|
||||
// A victim right on the drone should nearly always return a detection.
|
||||
// Run 20 trials; at least 15 should detect (0.4 threshold at distance 0).
|
||||
let gen = SyntheticCsiGenerator::new(
|
||||
vec![Position3D { x: 0.0, y: 0.0, z: 0.0 }],
|
||||
0.01,
|
||||
28.0,
|
||||
);
|
||||
let mut hits = 0u32;
|
||||
for i in 0..20 {
|
||||
if gen.detect(NodeId(0), &Position3D::zero(), i as u64).is_some() {
|
||||
hits += 1;
|
||||
}
|
||||
}
|
||||
assert!(hits >= 15, "expected ≥15/20 detections at zero range, got {hits}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_beyond_range_returns_none() {
|
||||
let gen = SyntheticCsiGenerator::new(
|
||||
vec![Position3D { x: 0.0, y: 0.0, z: 0.0 }],
|
||||
0.01,
|
||||
28.0,
|
||||
);
|
||||
let far_pos = Position3D { x: 1000.0, y: 1000.0, z: 0.0 };
|
||||
// All 10 attempts should return None since drone is 1414 m away.
|
||||
for i in 0..10 {
|
||||
assert!(
|
||||
gen.detect(NodeId(0), &far_pos, i).is_none(),
|
||||
"expected no detection at 1414 m"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_best_of_two_victims_returned() {
|
||||
// Two victims: one very close (high conf), one just at boundary (low conf).
|
||||
let gen = SyntheticCsiGenerator::new(
|
||||
vec![
|
||||
Position3D { x: 1.0, y: 0.0, z: 0.0 }, // close
|
||||
Position3D { x: 27.0, y: 0.0, z: 0.0 }, // near boundary
|
||||
],
|
||||
0.01,
|
||||
28.0,
|
||||
);
|
||||
// Run 10 trials; whenever both return a detection the close one should win.
|
||||
for i in 0..10 {
|
||||
if let Some(det) = gen.detect(NodeId(0), &Position3D::zero(), i) {
|
||||
assert!(
|
||||
det.confidence >= 0.4,
|
||||
"returned confidence {:.3} is below threshold",
|
||||
det.confidence
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
//! Geometric Dilution of Precision (GDOP) for a constellation of observers.
|
||||
//!
|
||||
//! GDOP quantifies how observer geometry amplifies measurement error into
|
||||
//! position-estimate error. Build the geometry matrix `H` of unit
|
||||
//! line-of-sight (LOS) vectors from each observer to the target, form the
|
||||
//! normal matrix `HᵀH`, invert it, and take `GDOP = sqrt(trace((HᵀH)⁻¹))`.
|
||||
//!
|
||||
//! For the 2-D (x, y) localization case `H` is `N×2` and `HᵀH` is `2×2`, so a
|
||||
//! closed-form 2×2 inverse suffices (no linear-algebra dependency needed).
|
||||
//!
|
||||
//! Lower GDOP = better geometry: observers spread ~120° apart around the target
|
||||
//! give low GDOP; (near-)collinear observers give a singular/ill-conditioned
|
||||
//! `HᵀH` → GDOP → ∞.
|
||||
|
||||
use crate::types::Position3D;
|
||||
|
||||
/// Geometric Dilution of Precision (2-D) for `observers` viewing a `target`.
|
||||
///
|
||||
/// Lower = better geometry. A ~120° constellation → low GDOP; collinear → very
|
||||
/// large (→∞). Returns `None` if fewer than two observers, if any observer is
|
||||
/// coincident with the target (undefined LOS), or if the geometry is singular
|
||||
/// / degenerate (collinear) so `HᵀH` is not invertible.
|
||||
pub fn gdop(observers: &[Position3D], target: &Position3D) -> Option<f64> {
|
||||
if observers.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Accumulate HᵀH directly (2×2 symmetric) from unit LOS vectors.
|
||||
// Row i of H is the unit vector from target → observer i in (x, y).
|
||||
let mut a = 0.0; // sum ux*ux
|
||||
let mut b = 0.0; // sum ux*uy
|
||||
let mut d = 0.0; // sum uy*uy
|
||||
|
||||
for obs in observers {
|
||||
let dx = obs.x - target.x;
|
||||
let dy = obs.y - target.y;
|
||||
let range = (dx * dx + dy * dy).sqrt();
|
||||
if range < 1e-9 {
|
||||
// Observer on top of the target → LOS undefined.
|
||||
return None;
|
||||
}
|
||||
let ux = dx / range;
|
||||
let uy = dy / range;
|
||||
a += ux * ux;
|
||||
b += ux * uy;
|
||||
d += uy * uy;
|
||||
}
|
||||
|
||||
// Determinant of HᵀH = [[a, b], [b, d]].
|
||||
let det = a * d - b * b;
|
||||
if det.abs() < 1e-12 {
|
||||
// Singular: observers are (near-)collinear with the target.
|
||||
return None;
|
||||
}
|
||||
|
||||
// (HᵀH)⁻¹ = 1/det * [[d, -b], [-b, a]]; trace = (d + a) / det.
|
||||
let trace_inv = (a + d) / det;
|
||||
if trace_inv <= 0.0 || !trace_inv.is_finite() {
|
||||
return None;
|
||||
}
|
||||
Some(trace_inv.sqrt())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn p(x: f64, y: f64) -> Position3D {
|
||||
Position3D { x, y, z: 0.0 }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_triangle_lower_than_collinear() {
|
||||
let target = p(0.0, 0.0);
|
||||
// Three observers at 120° around the target, radius 10.
|
||||
let r = 10.0;
|
||||
let triangle = [
|
||||
p(r * 0.0_f64.cos(), r * 0.0_f64.sin()),
|
||||
p(
|
||||
r * (2.0 * std::f64::consts::PI / 3.0).cos(),
|
||||
r * (2.0 * std::f64::consts::PI / 3.0).sin(),
|
||||
),
|
||||
p(
|
||||
r * (4.0 * std::f64::consts::PI / 3.0).cos(),
|
||||
r * (4.0 * std::f64::consts::PI / 3.0).sin(),
|
||||
),
|
||||
];
|
||||
// Three nearly-collinear observers (tiny y perturbation to stay invertible).
|
||||
let near_collinear = [p(5.0, 0.01), p(10.0, 0.0), p(15.0, 0.01)];
|
||||
|
||||
let tri = gdop(&triangle, &target).expect("triangle finite GDOP");
|
||||
let col = gdop(&near_collinear, &target).expect("near-collinear finite GDOP");
|
||||
assert!(tri.is_finite(), "triangle GDOP must be finite: {tri}");
|
||||
assert!(
|
||||
tri < col,
|
||||
"120° constellation should have lower GDOP than near-collinear: tri={tri}, col={col}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collinear_degenerate() {
|
||||
let target = p(0.0, 0.0);
|
||||
// Perfectly collinear observers along +x → singular HᵀH.
|
||||
let collinear = [p(5.0, 0.0), p(10.0, 0.0), p(20.0, 0.0)];
|
||||
let g = gdop(&collinear, &target);
|
||||
assert!(
|
||||
g.is_none() || g.unwrap() > 1e6,
|
||||
"perfectly collinear geometry must be None or huge, got {g:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_observer_none() {
|
||||
let target = p(0.0, 0.0);
|
||||
assert!(gdop(&[p(5.0, 5.0)], &target).is_none());
|
||||
assert!(gdop(&[], &target).is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
//! Per-episode and aggregate SAR + MARL metrics (ADR-171 Stage 1).
|
||||
|
||||
use crate::evals::stats::{stratified_bootstrap_ci, ConfidenceInterval};
|
||||
|
||||
/// Per-episode SAR metrics (Stage 1 kinematic).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EpisodeMetrics {
|
||||
/// Fraction of the mission area scanned at least once, in [0, 1].
|
||||
pub coverage_pct: f64,
|
||||
/// Localization error (m) of the fused victim estimate; `None` if no detection.
|
||||
pub localization_error_m: Option<f64>,
|
||||
/// GDOP of the contributing-drone constellation at detection; `None` if none.
|
||||
pub gdop_at_detection: Option<f64>,
|
||||
/// Mission-elapsed seconds to first detection; `None` if no detection.
|
||||
pub time_to_first_detection_s: Option<f64>,
|
||||
/// Whether at least one victim was detected this episode.
|
||||
pub detected: bool,
|
||||
/// Count of inter-drone proximity violations (kinematic proxy for collisions).
|
||||
pub collisions: u32,
|
||||
/// Fraction of scanned area covered by more than one drone, in [0, 1].
|
||||
pub overlap_ratio: f64,
|
||||
/// Scalar episodic return (reward-like coverage/detection objective).
|
||||
pub episodic_return: f64,
|
||||
}
|
||||
|
||||
/// Aggregate over a seed × episode matrix with IQM + 95% bootstrap CIs.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AggregateMetrics {
|
||||
pub coverage_iqm: ConfidenceInterval,
|
||||
/// IQM over detected episodes only (undetected episodes carry no error).
|
||||
pub localization_iqm: ConfidenceInterval,
|
||||
pub detection_rate: f64,
|
||||
pub mean_gdop: f64,
|
||||
pub return_iqm: ConfidenceInterval,
|
||||
pub n_episodes: usize,
|
||||
}
|
||||
|
||||
impl AggregateMetrics {
|
||||
/// Aggregate a seed-stratified matrix of episodes. Each inner `Vec` is one
|
||||
/// seed's episodes; bootstrap resampling is stratified by seed so the CI
|
||||
/// reflects between-seed variance (the dominant source per ADR-171).
|
||||
pub fn from_strata(per_seed: &[Vec<EpisodeMetrics>], boot_seed: u64) -> Self {
|
||||
const N_BOOT: usize = 1000;
|
||||
|
||||
let coverage_strata: Vec<Vec<f64>> = per_seed
|
||||
.iter()
|
||||
.map(|s| s.iter().map(|e| e.coverage_pct).collect())
|
||||
.collect();
|
||||
let return_strata: Vec<Vec<f64>> = per_seed
|
||||
.iter()
|
||||
.map(|s| s.iter().map(|e| e.episodic_return).collect())
|
||||
.collect();
|
||||
// Localization: only detected episodes contribute. Keep stratification
|
||||
// by seed but drop empty strata so the bootstrap doesn't degenerate.
|
||||
let loc_strata: Vec<Vec<f64>> = per_seed
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.iter()
|
||||
.filter_map(|e| e.localization_error_m)
|
||||
.collect::<Vec<f64>>()
|
||||
})
|
||||
.filter(|v: &Vec<f64>| !v.is_empty())
|
||||
.collect();
|
||||
|
||||
let mut detected = 0usize;
|
||||
let mut total = 0usize;
|
||||
let mut gdop_sum = 0.0;
|
||||
let mut gdop_n = 0usize;
|
||||
for seed in per_seed {
|
||||
for e in seed {
|
||||
total += 1;
|
||||
if e.detected {
|
||||
detected += 1;
|
||||
}
|
||||
if let Some(g) = e.gdop_at_detection {
|
||||
if g.is_finite() {
|
||||
gdop_sum += g;
|
||||
gdop_n += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let detection_rate = if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
detected as f64 / total as f64
|
||||
};
|
||||
let mean_gdop = if gdop_n == 0 {
|
||||
0.0
|
||||
} else {
|
||||
gdop_sum / gdop_n as f64
|
||||
};
|
||||
|
||||
AggregateMetrics {
|
||||
coverage_iqm: stratified_bootstrap_ci(&coverage_strata, N_BOOT, boot_seed),
|
||||
localization_iqm: stratified_bootstrap_ci(
|
||||
&loc_strata,
|
||||
N_BOOT,
|
||||
boot_seed.wrapping_add(1),
|
||||
),
|
||||
detection_rate,
|
||||
mean_gdop,
|
||||
return_iqm: stratified_bootstrap_ci(
|
||||
&return_strata,
|
||||
N_BOOT,
|
||||
boot_seed.wrapping_add(2),
|
||||
),
|
||||
n_episodes: total,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn ep(cov: f64, loc: Option<f64>, ret: f64, detected: bool) -> EpisodeMetrics {
|
||||
EpisodeMetrics {
|
||||
coverage_pct: cov,
|
||||
localization_error_m: loc,
|
||||
gdop_at_detection: if detected { Some(2.0) } else { None },
|
||||
time_to_first_detection_s: if detected { Some(10.0) } else { None },
|
||||
detected,
|
||||
collisions: 0,
|
||||
overlap_ratio: 0.1,
|
||||
episodic_return: ret,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregate_detection_rate_and_shape() {
|
||||
let per_seed = vec![
|
||||
vec![
|
||||
ep(0.8, Some(1.5), 80.0, true),
|
||||
ep(0.7, None, 70.0, false),
|
||||
],
|
||||
vec![
|
||||
ep(0.9, Some(2.0), 90.0, true),
|
||||
ep(0.85, Some(1.0), 85.0, true),
|
||||
],
|
||||
];
|
||||
let agg = AggregateMetrics::from_strata(&per_seed, 7);
|
||||
assert_eq!(agg.n_episodes, 4);
|
||||
assert!((agg.detection_rate - 0.75).abs() < 1e-9);
|
||||
assert!(agg.coverage_iqm.lo <= agg.coverage_iqm.point);
|
||||
assert!(agg.coverage_iqm.point <= agg.coverage_iqm.hi);
|
||||
assert!(agg.mean_gdop > 0.0);
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//! ADR-171 statistically-rigorous evaluation harness (Stage 1, kinematic).
|
||||
//!
|
||||
//! Produces SAR + MARL metrics over a seeded N-seed × M-episode matrix with
|
||||
//! IQM + 95% stratified-bootstrap CIs, a (sigma, kappa) CSI-noise sweep, and
|
||||
//! GDOP-stratified localization error. Generates evals/RESULTS.md.
|
||||
//!
|
||||
//! Stage 2 (Gazebo/PX4 SITL high-fidelity, false-alarm + collision rate on the
|
||||
//! median seeds) is a follow-on — see ADR-171 §6.1.
|
||||
pub mod gdop;
|
||||
pub mod stats;
|
||||
pub mod metrics;
|
||||
pub mod runner;
|
||||
pub mod report;
|
||||
|
||||
pub use gdop::gdop;
|
||||
pub use stats::{iqm, stratified_bootstrap_ci, ConfidenceInterval};
|
||||
pub use metrics::{EpisodeMetrics, AggregateMetrics};
|
||||
pub use runner::{EvalConfig, NoiseLevel, run_matrix};
|
||||
pub use report::render_results_md;
|
||||
@@ -1,120 +0,0 @@
|
||||
//! RESULTS.md leaderboard generator (ADR-171 Stage 1).
|
||||
|
||||
use crate::evals::metrics::AggregateMetrics;
|
||||
use crate::evals::stats::ConfidenceInterval;
|
||||
|
||||
/// Wi2SAR published localization baseline (paper-to-paper), metres.
|
||||
const WI2SAR_LOCALIZATION_M: f64 = 5.0;
|
||||
|
||||
/// Format a CI as `point [lo, hi]` with two decimals.
|
||||
fn fmt_ci(ci: &ConfidenceInterval) -> String {
|
||||
format!("{:.3} [{:.3}, {:.3}]", ci.point, ci.lo, ci.hi)
|
||||
}
|
||||
|
||||
/// Render a markdown leaderboard: one row per flight pattern with coverage
|
||||
/// IQM±CI, localization IQM±CI, detection rate, and mean GDOP — plus the
|
||||
/// Wi2SAR paper baseline row clearly labelled paper-to-paper.
|
||||
///
|
||||
/// `rows` is `(pattern_name, aggregate)`; rows are emitted in the order given,
|
||||
/// so callers should pre-sort (e.g. by descending coverage point estimate).
|
||||
pub fn render_results_md(rows: &[(String, AggregateMetrics)]) -> String {
|
||||
let mut s = String::new();
|
||||
s.push_str("# ruview-swarm Evaluation Results (ADR-171 Stage 1, kinematic)\n\n");
|
||||
s.push_str(
|
||||
"Statistically-rigorous evaluation harness: seeded multi-run rollouts with \
|
||||
IQM + 95% stratified-bootstrap confidence intervals (Agarwal et al., \
|
||||
NeurIPS 2021).\n\n",
|
||||
);
|
||||
|
||||
// Run configuration header.
|
||||
let (n_episodes, n_seeds) = rows
|
||||
.first()
|
||||
.map(|(_, a)| {
|
||||
let n = a.n_episodes;
|
||||
// Episodes-per-seed isn't stored; report total + leave seed split to caller note.
|
||||
(n, 0usize)
|
||||
})
|
||||
.unwrap_or((0, 0));
|
||||
s.push_str("## Run configuration\n\n");
|
||||
s.push_str(&format!(
|
||||
"- **Stage**: 1 (kinematic, self-contained, deterministic per seed)\n\
|
||||
- **Episodes per pattern**: {n_episodes} (seed × episode matrix)\n\
|
||||
- **CI method**: 95% stratified bootstrap of the IQM, stratified by seed\n\
|
||||
- **GDOP**: 2-D geometric dilution of precision at first detection\n"
|
||||
));
|
||||
let _ = n_seeds;
|
||||
s.push_str(
|
||||
"\n> **Stage 2 pending**: high-fidelity Gazebo/PX4 SITL evaluation \
|
||||
(false-alarm rate, real collision rate on the median seeds) is a \
|
||||
follow-on — see ADR-171 §6.1. The collision figures below are a \
|
||||
kinematic min-separation proxy, not SITL physics.\n\n",
|
||||
);
|
||||
|
||||
// Leaderboard table.
|
||||
s.push_str("## Flight-pattern leaderboard\n\n");
|
||||
s.push_str(
|
||||
"| Flight pattern | Coverage IQM [95% CI] | Localization (m) IQM [95% CI] | \
|
||||
Detection rate | Mean GDOP |\n",
|
||||
);
|
||||
s.push_str(
|
||||
"|----------------|-----------------------|-------------------------------|\
|
||||
----------------|-----------|\n",
|
||||
);
|
||||
for (name, agg) in rows {
|
||||
s.push_str(&format!(
|
||||
"| {} | {} | {} | {:.1}% | {:.3} |\n",
|
||||
name,
|
||||
fmt_ci(&agg.coverage_iqm),
|
||||
fmt_ci(&agg.localization_iqm),
|
||||
agg.detection_rate * 100.0,
|
||||
agg.mean_gdop,
|
||||
));
|
||||
}
|
||||
// Wi2SAR paper baseline row (paper-to-paper, no kinematic re-run).
|
||||
s.push_str(&format!(
|
||||
"| _Wi2SAR (paper baseline)_ | _n/a_ | _{:.1} (paper)_ | _n/a_ | _n/a_ |\n",
|
||||
WI2SAR_LOCALIZATION_M,
|
||||
));
|
||||
|
||||
s.push_str(
|
||||
"\n_Wi2SAR row is the published single-drone localization figure \
|
||||
(arxiv 2604.09115), shown paper-to-paper for reference only — it was \
|
||||
not re-run through this kinematic harness._\n",
|
||||
);
|
||||
|
||||
s
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::evals::stats::ConfidenceInterval;
|
||||
|
||||
fn agg(cov: f64, det: f64) -> AggregateMetrics {
|
||||
let ci = |p: f64| ConfidenceInterval { point: p, lo: p - 0.05, hi: p + 0.05 };
|
||||
AggregateMetrics {
|
||||
coverage_iqm: ci(cov),
|
||||
localization_iqm: ci(1.5),
|
||||
detection_rate: det,
|
||||
mean_gdop: 2.1,
|
||||
return_iqm: ci(80.0),
|
||||
n_episodes: 100,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_render_contains_rows_and_baseline() {
|
||||
let rows = vec![
|
||||
("partitioned_lawnmower".to_string(), agg(0.92, 0.95)),
|
||||
("levy_flight".to_string(), agg(0.40, 0.50)),
|
||||
];
|
||||
let md = render_results_md(&rows);
|
||||
assert!(md.contains("partitioned_lawnmower"));
|
||||
assert!(md.contains("levy_flight"));
|
||||
assert!(md.contains("Wi2SAR"));
|
||||
assert!(md.contains("Stage 2 pending"));
|
||||
assert!(md.contains("95% stratified bootstrap"));
|
||||
// Coverage point estimate appears.
|
||||
assert!(md.contains("0.920"));
|
||||
}
|
||||
}
|
||||
@@ -1,364 +0,0 @@
|
||||
//! Stage-1 kinematic rollout + seed × episode matrix (ADR-171).
|
||||
//!
|
||||
//! A single `run_episode` deterministically drives `drones` drones across a
|
||||
//! mission area under a chosen [`FlightPattern`], marks coverage on a grid,
|
||||
//! simulates CSI victim detection perturbed by `(sigma, kappa)` amplitude /
|
||||
//! von-Mises-phase noise, and computes the GDOP of the contributing-drone
|
||||
//! constellation at first detection. It is self-contained and seeded — no
|
||||
//! Candle / training backend required — so it runs in CI by default.
|
||||
|
||||
use crate::config::SwarmConfig;
|
||||
use crate::evals::gdop::gdop;
|
||||
use crate::evals::metrics::EpisodeMetrics;
|
||||
use crate::planning::patterns::{FlightPattern, PatternContext};
|
||||
use crate::types::{NodeId, Position3D};
|
||||
|
||||
/// CSI-noise level: amplitude std `sigma` and von-Mises phase concentration `kappa`.
|
||||
/// Higher `sigma` = noisier amplitude; *lower* `kappa` = noisier phase (more diffuse).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NoiseLevel {
|
||||
pub sigma: f64,
|
||||
pub kappa: f64,
|
||||
}
|
||||
|
||||
/// One evaluation configuration: a flight pattern + swarm/mission parameters.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EvalConfig {
|
||||
pub flight: FlightPattern,
|
||||
pub config: SwarmConfig,
|
||||
pub drones: usize,
|
||||
pub steps: usize,
|
||||
pub seeds: usize, // ≥10 per ADR-171
|
||||
pub episodes_per_seed: usize, // e.g. 50
|
||||
pub victims: Vec<Position3D>,
|
||||
pub noise: NoiseLevel,
|
||||
}
|
||||
|
||||
impl EvalConfig {
|
||||
/// A small SAR default suitable for fast CI runs.
|
||||
pub fn sar_small(flight: FlightPattern) -> Self {
|
||||
EvalConfig {
|
||||
flight,
|
||||
config: SwarmConfig::sar_default(),
|
||||
drones: 4,
|
||||
steps: 120,
|
||||
seeds: 10,
|
||||
episodes_per_seed: 10,
|
||||
victims: vec![
|
||||
Position3D { x: 120.0, y: 90.0, z: 0.0 },
|
||||
Position3D { x: 320.0, y: 280.0, z: 0.0 },
|
||||
],
|
||||
noise: NoiseLevel { sigma: 0.05, kappa: 8.0 },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimal reproducible LCG → f64 in [0, 1). Self-contained for determinism.
|
||||
struct Lcg(u64);
|
||||
impl Lcg {
|
||||
fn new(seed: u64) -> Self {
|
||||
Lcg(seed ^ 0xD1B5_4A32_D192_ED03)
|
||||
}
|
||||
#[inline]
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
self.0 = self
|
||||
.0
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
self.0
|
||||
}
|
||||
#[inline]
|
||||
fn unit(&mut self) -> f64 {
|
||||
(self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
|
||||
}
|
||||
/// Standard-normal sample via Box–Muller (deterministic).
|
||||
#[inline]
|
||||
fn normal(&mut self) -> f64 {
|
||||
let u1 = self.unit().max(1e-12);
|
||||
let u2 = self.unit();
|
||||
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one kinematic episode deterministically from `seed`.
|
||||
///
|
||||
/// Drives drones step-by-step by the flight pattern, marks a coarse coverage
|
||||
/// grid, and on the first step a drone comes within scan range of any victim
|
||||
/// records a fused localization estimate (weighted centroid of contributing
|
||||
/// drones' per-drone victim estimates, each perturbed by `(sigma, kappa)`
|
||||
/// noise) and the GDOP of those contributing drones.
|
||||
pub fn run_episode(cfg: &EvalConfig, seed: u64) -> EpisodeMetrics {
|
||||
let mut rng = Lcg::new(seed);
|
||||
|
||||
let area_w = cfg.config.mission.area_width_m;
|
||||
let area_h = cfg.config.mission.area_height_m;
|
||||
let altitude_z = -cfg.config.planning.flight_altitude_m;
|
||||
let scan_width = cfg.config.planning.csi_scan_width_m.max(1.0);
|
||||
let min_sep = cfg.config.formation.min_separation_m.max(0.1);
|
||||
let n = cfg.drones.max(1);
|
||||
|
||||
// Coverage grid sized so each cell ~= scan_width.
|
||||
let gx = ((area_w / scan_width).ceil() as usize).max(1);
|
||||
let gy = ((area_h / scan_width).ceil() as usize).max(1);
|
||||
let cell_w = area_w / gx as f64;
|
||||
let cell_h = area_h / gy as f64;
|
||||
let mut cover_count = vec![0u32; gx * gy];
|
||||
|
||||
// Spread drones along the bottom edge with a small seeded jitter.
|
||||
let mut positions: Vec<Position3D> = (0..n)
|
||||
.map(|i| {
|
||||
let frac = (i as f64 + 0.5) / n as f64;
|
||||
Position3D {
|
||||
x: (frac * area_w + (rng.unit() - 0.5) * scan_width).clamp(0.0, area_w),
|
||||
y: (rng.unit() * scan_width).clamp(0.0, area_h),
|
||||
z: altitude_z,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Recent-visit ring buffer for pheromone / potential-field patterns.
|
||||
let mut visited: Vec<Position3D> = Vec::new();
|
||||
let max_visited = 32usize;
|
||||
|
||||
let scan_range = scan_width; // detect a victim within one scan footprint
|
||||
let mut collisions = 0u32;
|
||||
let mut detected = false;
|
||||
let mut loc_error: Option<f64> = None;
|
||||
let mut gdop_val: Option<f64> = None;
|
||||
let mut t_detect: Option<f64> = None;
|
||||
|
||||
let dt = step_seconds(cfg);
|
||||
|
||||
for step in 0..cfg.steps {
|
||||
// Advance each drone one waypoint under the pattern.
|
||||
let snapshot = positions.clone();
|
||||
for (i, pos) in positions.iter_mut().enumerate() {
|
||||
let peers: Vec<Position3D> = snapshot
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(j, _)| *j != i)
|
||||
.map(|(_, p)| *p)
|
||||
.collect();
|
||||
let ctx = PatternContext {
|
||||
drone_id: NodeId(i as u32),
|
||||
swarm_size: n,
|
||||
current: *pos,
|
||||
area_w,
|
||||
area_h,
|
||||
altitude_z,
|
||||
scan_width_m: scan_width,
|
||||
step: step as u64,
|
||||
visited: &visited,
|
||||
peers: &peers,
|
||||
};
|
||||
*pos = cfg.flight.next_target(&ctx);
|
||||
}
|
||||
|
||||
// Mark coverage + record visits.
|
||||
for pos in &positions {
|
||||
let cx = ((pos.x / cell_w).floor() as i64).clamp(0, gx as i64 - 1) as usize;
|
||||
let cy = ((pos.y / cell_h).floor() as i64).clamp(0, gy as i64 - 1) as usize;
|
||||
cover_count[cy * gx + cx] = cover_count[cy * gx + cx].saturating_add(1);
|
||||
visited.push(*pos);
|
||||
}
|
||||
if visited.len() > max_visited {
|
||||
let drop = visited.len() - max_visited;
|
||||
visited.drain(0..drop);
|
||||
}
|
||||
|
||||
// Proximity / collision check (kinematic proxy).
|
||||
for a in 0..positions.len() {
|
||||
for b in (a + 1)..positions.len() {
|
||||
let d = positions[a].distance_to(&positions[b]);
|
||||
if d < min_sep {
|
||||
collisions = collisions.saturating_add(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detection: first step any victim falls within scan range of ≥1 drone,
|
||||
// fuse a localization estimate from the contributing drones. A single
|
||||
// contributor still yields a (noisier) estimate; GDOP is only defined
|
||||
// for the multistatic ≥2-drone case and is `None` otherwise.
|
||||
if !detected {
|
||||
for victim in &cfg.victims {
|
||||
let contributors: Vec<Position3D> = positions
|
||||
.iter()
|
||||
.filter(|p| horiz_dist(p, victim) <= scan_range)
|
||||
.copied()
|
||||
.collect();
|
||||
if !contributors.is_empty() {
|
||||
let (est, g) = fuse_estimate(&contributors, victim, cfg.noise, &mut rng);
|
||||
loc_error = Some(horiz_dist(&est, victim));
|
||||
gdop_val = g; // None for a single contributor
|
||||
t_detect = Some((step as f64 + 1.0) * dt);
|
||||
detected = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Coverage + overlap.
|
||||
let total_cells = (gx * gy) as f64;
|
||||
let scanned = cover_count.iter().filter(|&&c| c > 0).count() as f64;
|
||||
let overlapped = cover_count.iter().filter(|&&c| c > 1).count() as f64;
|
||||
let coverage_pct = if total_cells > 0.0 { scanned / total_cells } else { 0.0 };
|
||||
let overlap_ratio = if scanned > 0.0 { overlapped / scanned } else { 0.0 };
|
||||
|
||||
// Episodic return: reward coverage + detection, penalize overlap + collisions.
|
||||
let detect_bonus = if detected { 1.0 } else { 0.0 };
|
||||
let loc_term = match loc_error {
|
||||
Some(e) => (1.0 / (1.0 + e)).max(0.0),
|
||||
None => 0.0,
|
||||
};
|
||||
let episodic_return = 100.0 * coverage_pct + 30.0 * detect_bonus + 20.0 * loc_term
|
||||
- 10.0 * overlap_ratio
|
||||
- 5.0 * collisions as f64;
|
||||
|
||||
EpisodeMetrics {
|
||||
coverage_pct,
|
||||
localization_error_m: loc_error,
|
||||
gdop_at_detection: gdop_val,
|
||||
time_to_first_detection_s: t_detect,
|
||||
detected,
|
||||
collisions,
|
||||
overlap_ratio,
|
||||
episodic_return,
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-step wall-clock seconds, derived from scan width and drone speed.
|
||||
fn step_seconds(cfg: &EvalConfig) -> f64 {
|
||||
let speed = cfg.config.planning.max_speed_ms.max(0.1);
|
||||
(cfg.config.planning.csi_scan_width_m.max(1.0) / speed).max(0.1)
|
||||
}
|
||||
|
||||
/// Horizontal (x, y) distance, ignoring altitude.
|
||||
fn horiz_dist(a: &Position3D, b: &Position3D) -> f64 {
|
||||
(a.x - b.x).hypot(a.y - b.y)
|
||||
}
|
||||
|
||||
/// Fuse contributing drones' per-drone victim estimates into a weighted
|
||||
/// centroid, perturbed by `(sigma, kappa)` CSI noise, and compute the GDOP of
|
||||
/// the contributing constellation.
|
||||
fn fuse_estimate(
|
||||
contributors: &[Position3D],
|
||||
victim: &Position3D,
|
||||
noise: NoiseLevel,
|
||||
rng: &mut Lcg,
|
||||
) -> (Position3D, Option<f64>) {
|
||||
// Phase noise std from von Mises concentration: sigma_phase ≈ 1/sqrt(kappa).
|
||||
let phase_std = 1.0 / noise.kappa.max(1e-3).sqrt();
|
||||
let mut sx = 0.0;
|
||||
let mut sy = 0.0;
|
||||
let mut wsum = 0.0;
|
||||
for c in contributors {
|
||||
let range = horiz_dist(c, victim).max(1e-6);
|
||||
// Each drone's estimate = true victim + range-scaled amplitude noise +
|
||||
// bearing error from phase noise (perpendicular to LOS).
|
||||
let amp = noise.sigma * range;
|
||||
let nx = rng.normal() * amp;
|
||||
let ny = rng.normal() * amp;
|
||||
// Bearing wobble: rotate LOS unit vector by a small phase-noise angle.
|
||||
let bearing = (victim.y - c.y).atan2(victim.x - c.x);
|
||||
let dtheta = rng.normal() * phase_std;
|
||||
let bx = range * (bearing + dtheta).cos();
|
||||
let by = range * (bearing + dtheta).sin();
|
||||
let est_x = c.x + bx + nx;
|
||||
let est_y = c.y + by + ny;
|
||||
// Inverse-range weighting: closer drones trusted more.
|
||||
let w = 1.0 / range;
|
||||
sx += est_x * w;
|
||||
sy += est_y * w;
|
||||
wsum += w;
|
||||
}
|
||||
let w = wsum.max(1e-9);
|
||||
let est = Position3D { x: sx / w, y: sy / w, z: 0.0 };
|
||||
let g = gdop(contributors, victim);
|
||||
(est, g)
|
||||
}
|
||||
|
||||
/// Run the full seed × episode matrix → per-seed strata of [`EpisodeMetrics`].
|
||||
pub fn run_matrix(cfg: &EvalConfig) -> Vec<Vec<EpisodeMetrics>> {
|
||||
(0..cfg.seeds)
|
||||
.map(|s| {
|
||||
(0..cfg.episodes_per_seed)
|
||||
.map(|e| {
|
||||
// Distinct deterministic seed per (seed, episode) cell.
|
||||
let cell_seed = (s as u64)
|
||||
.wrapping_mul(0x100_0000)
|
||||
.wrapping_add(e as u64)
|
||||
.wrapping_add(0xABCD);
|
||||
run_episode(cfg, cell_seed)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Standard ADR-171 noise sweep grid: cartesian product of σ × κ levels.
|
||||
pub fn default_noise_sweep() -> Vec<NoiseLevel> {
|
||||
let sigmas = [0.02, 0.05, 0.10];
|
||||
let kappas = [16.0, 8.0, 4.0];
|
||||
let mut out = Vec::with_capacity(sigmas.len() * kappas.len());
|
||||
for &sigma in &sigmas {
|
||||
for &kappa in &kappas {
|
||||
out.push(NoiseLevel { sigma, kappa });
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_run_episode_deterministic() {
|
||||
let cfg = EvalConfig::sar_small(FlightPattern::PartitionedLawnmower);
|
||||
let a = run_episode(&cfg, 12345);
|
||||
let b = run_episode(&cfg, 12345);
|
||||
assert_eq!(a.coverage_pct, b.coverage_pct);
|
||||
assert_eq!(a.detected, b.detected);
|
||||
assert_eq!(a.localization_error_m, b.localization_error_m);
|
||||
assert_eq!(a.collisions, b.collisions);
|
||||
assert_eq!(a.episodic_return, b.episodic_return);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partitioned_beats_levy_coverage() {
|
||||
let mut part = EvalConfig::sar_small(FlightPattern::PartitionedLawnmower);
|
||||
part.seeds = 3;
|
||||
part.episodes_per_seed = 5;
|
||||
let mut levy = part.clone();
|
||||
levy.flight = FlightPattern::LevyFlight;
|
||||
|
||||
let part_m = run_matrix(&part);
|
||||
let levy_m = run_matrix(&levy);
|
||||
let part_agg = crate::evals::metrics::AggregateMetrics::from_strata(&part_m, 1);
|
||||
let levy_agg = crate::evals::metrics::AggregateMetrics::from_strata(&levy_m, 1);
|
||||
assert!(
|
||||
part_agg.coverage_iqm.point > levy_agg.coverage_iqm.point,
|
||||
"partitioned coverage {} should beat levy {}",
|
||||
part_agg.coverage_iqm.point,
|
||||
levy_agg.coverage_iqm.point
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_shape() {
|
||||
let mut cfg = EvalConfig::sar_small(FlightPattern::Spiral);
|
||||
cfg.seeds = 4;
|
||||
cfg.episodes_per_seed = 6;
|
||||
let m = run_matrix(&cfg);
|
||||
assert_eq!(m.len(), 4);
|
||||
assert!(m.iter().all(|s| s.len() == 6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_noise_sweep_grid() {
|
||||
let sweep = default_noise_sweep();
|
||||
assert_eq!(sweep.len(), 9);
|
||||
}
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
//! Hand-rolled robust statistics for the evaluation harness (Agarwal 2021).
|
||||
//!
|
||||
//! Implements the interquartile mean (IQM), a 95% stratified-bootstrap
|
||||
//! confidence interval of the IQM, and the probability-of-improvement metric —
|
||||
//! the three statistics recommended by "Deep RL at the Edge of the
|
||||
//! Statistical Precipice" (Agarwal et al., NeurIPS 2021) for reporting
|
||||
//! few-seed RL results.
|
||||
//!
|
||||
//! All randomness comes from a local linear-congruential generator (LCG) seeded
|
||||
//! explicitly, so every CI is fully reproducible — no `thread_rng`, no clock.
|
||||
|
||||
/// Interquartile mean: mean of the middle 50% of samples (drop the bottom 25%
|
||||
/// and the top 25%). Robust to outliers in either tail.
|
||||
///
|
||||
/// Small-N behaviour: with fewer than 4 samples the trim would empty the set,
|
||||
/// so it falls back to the plain arithmetic mean. An empty slice returns 0.0.
|
||||
pub fn iqm(samples: &[f64]) -> f64 {
|
||||
if samples.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
if samples.len() < 4 {
|
||||
return samples.iter().sum::<f64>() / samples.len() as f64;
|
||||
}
|
||||
let mut sorted = samples.to_vec();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let n = sorted.len();
|
||||
let lo = n / 4; // trim bottom 25%
|
||||
let hi = n - lo; // trim top 25% (symmetric)
|
||||
let mid = &sorted[lo..hi];
|
||||
if mid.is_empty() {
|
||||
return sorted.iter().sum::<f64>() / n as f64;
|
||||
}
|
||||
mid.iter().sum::<f64>() / mid.len() as f64
|
||||
}
|
||||
|
||||
/// A point estimate with its lower / upper 95% confidence bounds.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ConfidenceInterval {
|
||||
pub point: f64,
|
||||
pub lo: f64,
|
||||
pub hi: f64,
|
||||
}
|
||||
|
||||
/// Minimal reproducible LCG (Numerical Recipes constants) yielding f64 in [0,1).
|
||||
struct Lcg(u64);
|
||||
|
||||
impl Lcg {
|
||||
fn new(seed: u64) -> Self {
|
||||
// Avoid a zero state collapsing the generator.
|
||||
Lcg(seed ^ 0x9E37_79B9_7F4A_7C15)
|
||||
}
|
||||
#[inline]
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
self.0 = self
|
||||
.0
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
self.0
|
||||
}
|
||||
/// Uniform index in [0, n).
|
||||
#[inline]
|
||||
fn index(&mut self, n: usize) -> usize {
|
||||
if n == 0 {
|
||||
return 0;
|
||||
}
|
||||
(self.next_u64() >> 11) as usize % n
|
||||
}
|
||||
}
|
||||
|
||||
/// 95% stratified-bootstrap CI of the IQM.
|
||||
///
|
||||
/// `strata` groups samples (one inner `Vec` per stratum, e.g. per task or per
|
||||
/// seed). Each bootstrap replicate resamples WITH replacement *within* each
|
||||
/// stratum (preserving the stratum sizes), pools all resampled values, and
|
||||
/// recomputes the IQM. Repeat `n_boot` times and take the 2.5 / 97.5
|
||||
/// percentiles for the CI bounds. The `point` estimate is the IQM of the pooled
|
||||
/// original samples. Deterministic for a fixed `seed`.
|
||||
pub fn stratified_bootstrap_ci(
|
||||
strata: &[Vec<f64>],
|
||||
n_boot: usize,
|
||||
seed: u64,
|
||||
) -> ConfidenceInterval {
|
||||
let pooled: Vec<f64> = strata.iter().flatten().copied().collect();
|
||||
let point = iqm(&pooled);
|
||||
|
||||
if pooled.is_empty() || n_boot == 0 {
|
||||
return ConfidenceInterval { point, lo: point, hi: point };
|
||||
}
|
||||
|
||||
let mut rng = Lcg::new(seed);
|
||||
let mut replicates = Vec::with_capacity(n_boot);
|
||||
let mut buf: Vec<f64> = Vec::with_capacity(pooled.len());
|
||||
|
||||
for _ in 0..n_boot {
|
||||
buf.clear();
|
||||
for stratum in strata {
|
||||
let m = stratum.len();
|
||||
for _ in 0..m {
|
||||
buf.push(stratum[rng.index(m)]);
|
||||
}
|
||||
}
|
||||
replicates.push(iqm(&buf));
|
||||
}
|
||||
|
||||
replicates.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let lo = percentile(&replicates, 2.5);
|
||||
let hi = percentile(&replicates, 97.5);
|
||||
ConfidenceInterval { point, lo, hi }
|
||||
}
|
||||
|
||||
/// Linear-interpolated percentile of a pre-sorted slice. `p` in [0, 100].
|
||||
fn percentile(sorted: &[f64], p: f64) -> f64 {
|
||||
if sorted.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
if sorted.len() == 1 {
|
||||
return sorted[0];
|
||||
}
|
||||
let rank = (p / 100.0) * (sorted.len() as f64 - 1.0);
|
||||
let lo = rank.floor() as usize;
|
||||
let hi = rank.ceil() as usize;
|
||||
if lo == hi {
|
||||
return sorted[lo];
|
||||
}
|
||||
let frac = rank - lo as f64;
|
||||
sorted[lo] * (1.0 - frac) + sorted[hi] * frac
|
||||
}
|
||||
|
||||
/// Probability of improvement: P(a-sample > b-sample) over all pairs (Agarwal).
|
||||
///
|
||||
/// Counts each (a_i, b_j) pair where `a_i > b_j` as 1, a tie as 0.5, and
|
||||
/// normalizes by the pair count. 1.0 means `a` strictly dominates; ~0.5 means
|
||||
/// the two are statistically indistinguishable. Returns 0.5 if either is empty.
|
||||
pub fn probability_of_improvement(a: &[f64], b: &[f64]) -> f64 {
|
||||
if a.is_empty() || b.is_empty() {
|
||||
return 0.5;
|
||||
}
|
||||
let mut wins = 0.0;
|
||||
for &ai in a {
|
||||
for &bj in b {
|
||||
if ai > bj {
|
||||
wins += 1.0;
|
||||
} else if (ai - bj).abs() < f64::EPSILON {
|
||||
wins += 0.5;
|
||||
}
|
||||
}
|
||||
}
|
||||
wins / (a.len() as f64 * b.len() as f64)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_iqm_trims_outliers() {
|
||||
// 0..=100 plus one extreme outlier; IQM should sit near the middle (~50),
|
||||
// not be dragged toward 1e9.
|
||||
let mut samples: Vec<f64> = (0..=100).map(|i| i as f64).collect();
|
||||
samples.push(1e9);
|
||||
let v = iqm(&samples);
|
||||
assert!(
|
||||
(40.0..=60.0).contains(&v),
|
||||
"IQM should be near the middle-50% mean (~50), got {v}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iqm_small() {
|
||||
// Fewer than 4 samples → plain mean.
|
||||
assert_eq!(iqm(&[2.0, 4.0]), 3.0);
|
||||
assert_eq!(iqm(&[10.0]), 10.0);
|
||||
assert_eq!(iqm(&[1.0, 2.0, 3.0]), 2.0);
|
||||
assert_eq!(iqm(&[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_ci_brackets_point() {
|
||||
let strata = vec![
|
||||
vec![1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
vec![2.0, 3.0, 4.0, 5.0, 6.0],
|
||||
];
|
||||
let ci = stratified_bootstrap_ci(&strata, 500, 42);
|
||||
assert!(ci.lo <= ci.point, "lo ≤ point: {} ≤ {}", ci.lo, ci.point);
|
||||
assert!(ci.point <= ci.hi, "point ≤ hi: {} ≤ {}", ci.point, ci.hi);
|
||||
// Deterministic: same seed → identical interval.
|
||||
let ci2 = stratified_bootstrap_ci(&strata, 500, 42);
|
||||
assert_eq!(ci.point, ci2.point);
|
||||
assert_eq!(ci.lo, ci2.lo);
|
||||
assert_eq!(ci.hi, ci2.hi);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prob_improvement_obvious() {
|
||||
assert_eq!(
|
||||
probability_of_improvement(&[10.0, 10.0, 10.0], &[0.0, 0.0, 0.0]),
|
||||
1.0
|
||||
);
|
||||
// Identical samples → all ties → 0.5.
|
||||
let poi = probability_of_improvement(&[5.0, 5.0], &[5.0, 5.0]);
|
||||
assert!((poi - 0.5).abs() < 1e-9, "symmetric ties → ~0.5, got {poi}");
|
||||
}
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
//! Fail-safe state machine: link loss, low battery, collision avoidance.
|
||||
|
||||
use crate::types::DroneState;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Fail-safe operating state.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum FailSafeState {
|
||||
Nominal,
|
||||
AutonomousHold,
|
||||
LowBatteryWarn,
|
||||
ReturnToHome,
|
||||
EmergencyLand,
|
||||
EmergencyDiverge,
|
||||
ControlledDescent,
|
||||
}
|
||||
|
||||
/// State machine driving fail-safe transitions.
|
||||
pub struct FailSafeMachine {
|
||||
state: FailSafeState,
|
||||
link_loss_start: Option<Instant>,
|
||||
pub link_loss_hold_secs: f64,
|
||||
pub link_loss_rth_secs: f64,
|
||||
pub battery_warn_pct: f32,
|
||||
pub battery_rth_pct: f32,
|
||||
pub collision_dist_m: f64,
|
||||
}
|
||||
|
||||
impl FailSafeMachine {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: FailSafeState::Nominal,
|
||||
link_loss_start: None,
|
||||
link_loss_hold_secs: 3.0,
|
||||
link_loss_rth_secs: 30.0,
|
||||
battery_warn_pct: 20.0,
|
||||
battery_rth_pct: 15.0,
|
||||
collision_dist_m: 1.5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive one tick. Returns the current state after evaluation.
|
||||
pub fn tick(
|
||||
&mut self,
|
||||
state: &DroneState,
|
||||
link_alive: bool,
|
||||
nearest_neighbor_dist: f64,
|
||||
) -> FailSafeState {
|
||||
// Collision avoidance has highest priority.
|
||||
//
|
||||
// Fail CLOSED on a non-finite neighbour distance. `nearest_neighbor_dist`
|
||||
// is derived from peer positions (see
|
||||
// `SwarmOrchestrator::nearest_peer_distance`), which arrive over the
|
||||
// untrusted swarm comm layer as `DroneState` values whose f64 position
|
||||
// fields can deserialize to NaN/Inf. A naive `NaN < collision_dist_m`
|
||||
// evaluates to `false`, silently DISABLING collision avoidance — the
|
||||
// worst possible failure for a physical drone. Treat a non-finite
|
||||
// distance as "too close" so the swarm diverges rather than trusting a
|
||||
// poisoned reading.
|
||||
if !nearest_neighbor_dist.is_finite() || nearest_neighbor_dist < self.collision_dist_m {
|
||||
self.state = FailSafeState::EmergencyDiverge;
|
||||
return self.state.clone();
|
||||
}
|
||||
|
||||
// Link loss handling
|
||||
if !link_alive {
|
||||
let start = self.link_loss_start.get_or_insert_with(Instant::now);
|
||||
let elapsed = start.elapsed().as_secs_f64();
|
||||
if elapsed > self.link_loss_rth_secs {
|
||||
self.state = FailSafeState::ReturnToHome;
|
||||
} else if elapsed > self.link_loss_hold_secs {
|
||||
self.state = FailSafeState::AutonomousHold;
|
||||
}
|
||||
return self.state.clone();
|
||||
} else {
|
||||
// Link restored
|
||||
self.link_loss_start = None;
|
||||
if self.state == FailSafeState::AutonomousHold {
|
||||
self.state = FailSafeState::Nominal;
|
||||
}
|
||||
}
|
||||
|
||||
// Battery checks. A non-finite battery reading (NaN/Inf from a corrupt or
|
||||
// forged telemetry/peer message) must fail CLOSED: `NaN <= threshold` is
|
||||
// `false`, which would otherwise let a drone with an unknown battery
|
||||
// level keep flying nominally. Treat a non-finite reading as critical.
|
||||
if !state.battery_pct.is_finite() || state.battery_pct <= self.battery_rth_pct {
|
||||
self.state = FailSafeState::ReturnToHome;
|
||||
} else if state.battery_pct <= self.battery_warn_pct {
|
||||
self.state = FailSafeState::LowBatteryWarn;
|
||||
} else if self.state == FailSafeState::LowBatteryWarn {
|
||||
// Recovered from low battery (charged on the fly / wrong reading)
|
||||
self.state = FailSafeState::Nominal;
|
||||
}
|
||||
|
||||
self.state.clone()
|
||||
}
|
||||
|
||||
pub fn current(&self) -> &FailSafeState {
|
||||
&self.state
|
||||
}
|
||||
|
||||
pub fn force_land(&mut self) {
|
||||
self.state = FailSafeState::EmergencyLand;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FailSafeMachine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::NodeId;
|
||||
|
||||
fn good_state() -> DroneState {
|
||||
let mut s = DroneState::default_at_origin(NodeId(1));
|
||||
s.battery_pct = 80.0;
|
||||
s.link_quality = 1.0;
|
||||
s
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nominal_when_healthy() {
|
||||
let mut fsm = FailSafeMachine::new();
|
||||
let s = good_state();
|
||||
let result = fsm.tick(&s, true, 10.0);
|
||||
assert_eq!(result, FailSafeState::Nominal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_low_battery_warn() {
|
||||
let mut fsm = FailSafeMachine::new();
|
||||
let mut s = good_state();
|
||||
s.battery_pct = 18.0;
|
||||
let result = fsm.tick(&s, true, 10.0);
|
||||
assert_eq!(result, FailSafeState::LowBatteryWarn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_battery_rth() {
|
||||
let mut fsm = FailSafeMachine::new();
|
||||
let mut s = good_state();
|
||||
s.battery_pct = 10.0;
|
||||
let result = fsm.tick(&s, true, 10.0);
|
||||
assert_eq!(result, FailSafeState::ReturnToHome);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collision_avoidance() {
|
||||
let mut fsm = FailSafeMachine::new();
|
||||
let s = good_state();
|
||||
let result = fsm.tick(&s, true, 0.5); // too close
|
||||
assert_eq!(result, FailSafeState::EmergencyDiverge);
|
||||
}
|
||||
|
||||
/// Security: a NaN neighbour distance (poisoned peer position over the swarm
|
||||
/// comm layer) must NOT silently disable collision avoidance. Fails on old
|
||||
/// code where `NaN < collision_dist_m` is `false` and the state stays Nominal.
|
||||
#[test]
|
||||
fn test_nan_neighbor_distance_fails_closed_to_diverge() {
|
||||
let mut fsm = FailSafeMachine::new();
|
||||
let s = good_state();
|
||||
let result = fsm.tick(&s, true, f64::NAN);
|
||||
assert_eq!(
|
||||
result,
|
||||
FailSafeState::EmergencyDiverge,
|
||||
"non-finite neighbour distance must fail closed to EmergencyDiverge"
|
||||
);
|
||||
}
|
||||
|
||||
/// Security: a NaN battery reading must fail closed to ReturnToHome rather
|
||||
/// than being treated as a healthy battery. Fails on old code where
|
||||
/// `NaN <= battery_rth_pct` is `false` and the drone stays Nominal.
|
||||
#[test]
|
||||
fn test_nan_battery_fails_closed_to_rth() {
|
||||
let mut fsm = FailSafeMachine::new();
|
||||
let mut s = good_state();
|
||||
s.battery_pct = f32::NAN;
|
||||
let result = fsm.tick(&s, true, 10.0);
|
||||
assert_eq!(
|
||||
result,
|
||||
FailSafeState::ReturnToHome,
|
||||
"non-finite battery must fail closed to ReturnToHome"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
//! Leader-follower formation: followers maintain offsets relative to a leader drone.
|
||||
|
||||
use crate::types::{NodeId, Position3D};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Leader-follower formation parameters.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LeaderFollower {
|
||||
pub leader_id: NodeId,
|
||||
/// Follower → (dx, dy, dz) offset from leader's position.
|
||||
pub offsets: HashMap<NodeId, (f64, f64, f64)>,
|
||||
}
|
||||
|
||||
impl LeaderFollower {
|
||||
pub fn new(leader_id: NodeId) -> Self {
|
||||
Self {
|
||||
leader_id,
|
||||
offsets: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_follower(&mut self, follower: NodeId, offset: (f64, f64, f64)) {
|
||||
self.offsets.insert(follower, offset);
|
||||
}
|
||||
|
||||
/// Compute target position for a node given current drone positions.
|
||||
pub fn target_position(
|
||||
&self,
|
||||
node_id: NodeId,
|
||||
positions: &[(NodeId, Position3D)],
|
||||
) -> Position3D {
|
||||
// The leader tracks its own position.
|
||||
if node_id == self.leader_id {
|
||||
return positions
|
||||
.iter()
|
||||
.find(|(id, _)| *id == self.leader_id)
|
||||
.map(|(_, p)| *p)
|
||||
.unwrap_or_default();
|
||||
}
|
||||
let leader_pos = positions
|
||||
.iter()
|
||||
.find(|(id, _)| *id == self.leader_id)
|
||||
.map(|(_, p)| *p)
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(&(dx, dy, dz)) = self.offsets.get(&node_id) {
|
||||
Position3D {
|
||||
x: leader_pos.x + dx,
|
||||
y: leader_pos.y + dy,
|
||||
z: leader_pos.z + dz,
|
||||
}
|
||||
} else {
|
||||
leader_pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_follower_tracks_leader() {
|
||||
let mut lf = LeaderFollower::new(NodeId(0));
|
||||
lf.add_follower(NodeId(1), (-5.0, 0.0, 0.0));
|
||||
let positions = vec![
|
||||
(NodeId(0), Position3D { x: 10.0, y: 20.0, z: -30.0 }),
|
||||
];
|
||||
let target = lf.target_position(NodeId(1), &positions);
|
||||
assert!((target.x - 5.0).abs() < 1e-6);
|
||||
assert!((target.y - 20.0).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
//! Formation control: virtual structure, leader-follower, Reynolds flocking.
|
||||
//!
|
||||
// NOTE: Formation control is ITAR-controlled (USML Category VIII(h)(12)).
|
||||
// Only available when the `itar-unrestricted` feature is enabled.
|
||||
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub mod virtual_structure;
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub mod leader_follower;
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub mod reynolds;
|
||||
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub use virtual_structure::VirtualStructure;
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub use leader_follower::LeaderFollower;
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub use reynolds::ReynoldsParams;
|
||||
|
||||
/// Stub: formation control is export-controlled. Enable `itar-unrestricted` feature.
|
||||
#[cfg(not(feature = "itar-unrestricted"))]
|
||||
pub fn formation_stub() -> crate::SwarmResult<()> {
|
||||
Err(crate::SwarmError::Security(
|
||||
"Formation control requires itar-unrestricted feature (USML VIII(h)(12))".into(),
|
||||
))
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
//! Reynolds flocking: separation, alignment, cohesion.
|
||||
|
||||
use crate::types::{NodeId, Position3D, Velocity3D};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Parameters for Reynolds boid rules.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReynoldsParams {
|
||||
pub separation_dist_m: f64,
|
||||
pub separation_weight: f64,
|
||||
pub alignment_weight: f64,
|
||||
pub cohesion_weight: f64,
|
||||
pub k_neighbors: usize,
|
||||
}
|
||||
|
||||
impl Default for ReynoldsParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
separation_dist_m: 3.0,
|
||||
separation_weight: 1.5,
|
||||
alignment_weight: 1.0,
|
||||
cohesion_weight: 0.8,
|
||||
k_neighbors: 7,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReynoldsParams {
|
||||
/// Compute a desired velocity delta for `node_id` based on the three Reynolds rules.
|
||||
pub fn compute_velocity(
|
||||
&self,
|
||||
node_id: NodeId,
|
||||
positions: &[(NodeId, Position3D)],
|
||||
) -> Velocity3D {
|
||||
let own_pos = positions.iter().find(|(id, _)| *id == node_id).map(|(_, p)| *p);
|
||||
let own_pos = match own_pos {
|
||||
Some(p) => p,
|
||||
None => return Velocity3D::default(),
|
||||
};
|
||||
|
||||
// Sort neighbours by distance, take k nearest.
|
||||
let mut neighbours: Vec<(f64, &Position3D)> = positions
|
||||
.iter()
|
||||
.filter(|(id, _)| *id != node_id)
|
||||
.map(|(_, p)| (own_pos.distance_to(p), p))
|
||||
.collect();
|
||||
neighbours.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
neighbours.truncate(self.k_neighbors);
|
||||
|
||||
if neighbours.is_empty() {
|
||||
return Velocity3D::default();
|
||||
}
|
||||
|
||||
let n = neighbours.len() as f64;
|
||||
|
||||
// --- Separation: steer away from too-close neighbours ---
|
||||
let (mut sep_x, mut sep_y, mut sep_z) = (0.0_f64, 0.0_f64, 0.0_f64);
|
||||
for (dist, p) in &neighbours {
|
||||
if *dist < self.separation_dist_m && *dist > 1e-6 {
|
||||
let factor = (self.separation_dist_m - *dist) / self.separation_dist_m;
|
||||
sep_x += (own_pos.x - p.x) / dist * factor;
|
||||
sep_y += (own_pos.y - p.y) / dist * factor;
|
||||
sep_z += (own_pos.z - p.z) / dist * factor;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Cohesion: steer toward average position ---
|
||||
let (avg_x, avg_y, avg_z) = neighbours
|
||||
.iter()
|
||||
.fold((0.0, 0.0, 0.0), |(ax, ay, az), (_, p)| (ax + p.x, ay + p.y, az + p.z));
|
||||
let coh_x = (avg_x / n) - own_pos.x;
|
||||
let coh_y = (avg_y / n) - own_pos.y;
|
||||
let coh_z = (avg_z / n) - own_pos.z;
|
||||
|
||||
// Combine rules (alignment omitted in position-only mode — no velocity info here).
|
||||
let vx = self.separation_weight * sep_x + self.cohesion_weight * coh_x;
|
||||
let vy = self.separation_weight * sep_y + self.cohesion_weight * coh_y;
|
||||
let vz = self.separation_weight * sep_z + self.cohesion_weight * coh_z;
|
||||
|
||||
Velocity3D { vx, vy, vz }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_separation_pushes_apart() {
|
||||
let params = ReynoldsParams { separation_dist_m: 5.0, ..Default::default() };
|
||||
let positions = vec![
|
||||
(NodeId(0), Position3D { x: 0.0, y: 0.0, z: 0.0 }),
|
||||
(NodeId(1), Position3D { x: 1.0, y: 0.0, z: 0.0 }), // too close
|
||||
];
|
||||
let vel = params.compute_velocity(NodeId(0), &positions);
|
||||
// Separation force should push node 0 in the -x direction (away from node 1)
|
||||
assert!(vel.vx < 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_neighbours_returns_zero() {
|
||||
let params = ReynoldsParams::default();
|
||||
let positions = vec![(NodeId(0), Position3D::zero())];
|
||||
let vel = params.compute_velocity(NodeId(0), &positions);
|
||||
assert!((vel.vx.abs() + vel.vy.abs()) < 1e-9);
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
//! Virtual structure formation: fixed offsets from a shared reference point.
|
||||
|
||||
use crate::types::{NodeId, Position3D};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Offsets from a shared reference point for each drone in the formation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VirtualStructure {
|
||||
/// NodeId → (dx, dy, dz) offset in metres from the reference.
|
||||
pub offsets: HashMap<NodeId, (f64, f64, f64)>,
|
||||
}
|
||||
|
||||
impl VirtualStructure {
|
||||
/// Create a rectangular grid formation with `n` drones, spaced `spacing_m` apart.
|
||||
pub fn grid_formation(n: usize, spacing_m: f64) -> Self {
|
||||
let cols = (n as f64).sqrt().ceil() as usize;
|
||||
let mut offsets = HashMap::new();
|
||||
for i in 0..n {
|
||||
let row = i / cols;
|
||||
let col = i % cols;
|
||||
offsets.insert(
|
||||
NodeId(i as u32),
|
||||
(row as f64 * spacing_m, col as f64 * spacing_m, 0.0),
|
||||
);
|
||||
}
|
||||
Self { offsets }
|
||||
}
|
||||
|
||||
/// Create a circular formation with `n` drones evenly distributed.
|
||||
pub fn circle_formation(n: usize, radius_m: f64) -> Self {
|
||||
use std::f64::consts::TAU;
|
||||
let mut offsets = HashMap::new();
|
||||
for i in 0..n {
|
||||
let angle = TAU * i as f64 / n as f64;
|
||||
offsets.insert(
|
||||
NodeId(i as u32),
|
||||
(radius_m * angle.cos(), radius_m * angle.sin(), 0.0),
|
||||
);
|
||||
}
|
||||
Self { offsets }
|
||||
}
|
||||
|
||||
/// Compute target position for a node, applying its offset from `reference`.
|
||||
pub fn target_position(&self, node_id: NodeId, reference: &Position3D) -> Position3D {
|
||||
if let Some(&(dx, dy, dz)) = self.offsets.get(&node_id) {
|
||||
Position3D {
|
||||
x: reference.x + dx,
|
||||
y: reference.y + dy,
|
||||
z: reference.z + dz,
|
||||
}
|
||||
} else {
|
||||
*reference
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_grid_formation_4_drones() {
|
||||
let vs = VirtualStructure::grid_formation(4, 5.0);
|
||||
assert_eq!(vs.offsets.len(), 4);
|
||||
let ref_pos = Position3D { x: 100.0, y: 200.0, z: -30.0 };
|
||||
let p = vs.target_position(NodeId(0), &ref_pos);
|
||||
assert!((p.x - 100.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circle_formation() {
|
||||
let vs = VirtualStructure::circle_formation(4, 10.0);
|
||||
let ref_pos = Position3D::zero();
|
||||
let p = vs.target_position(NodeId(0), &ref_pos);
|
||||
// Node 0 at angle 0: x = 10, y = 0
|
||||
assert!((p.x - 10.0).abs() < 1e-6);
|
||||
assert!(p.y.abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
//! Flight controller abstraction and simulated implementation.
|
||||
|
||||
use crate::types::{DroneState, NodeId, Position3D};
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Flight controller operating mode.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum FlightMode {
|
||||
/// External position/velocity setpoints (PX4: OFFBOARD, ArduPilot: GUIDED).
|
||||
Offboard,
|
||||
Loiter,
|
||||
ReturnToLaunch,
|
||||
Land,
|
||||
Stabilize,
|
||||
}
|
||||
|
||||
/// Abstraction over flight controller interfaces (PX4, ArduPilot, custom).
|
||||
#[async_trait]
|
||||
pub trait FlightController: Send + Sync {
|
||||
async fn set_target_position(
|
||||
&self,
|
||||
pos: &Position3D,
|
||||
speed_ms: f64,
|
||||
) -> crate::SwarmResult<()>;
|
||||
|
||||
async fn get_state(&self) -> crate::SwarmResult<DroneState>;
|
||||
|
||||
async fn set_mode(&self, mode: FlightMode) -> crate::SwarmResult<()>;
|
||||
|
||||
async fn arm(&self) -> crate::SwarmResult<()>;
|
||||
|
||||
async fn disarm(&self) -> crate::SwarmResult<()>;
|
||||
|
||||
async fn rtl(&self) -> crate::SwarmResult<()>;
|
||||
|
||||
async fn emergency_land(&self) -> crate::SwarmResult<()>;
|
||||
}
|
||||
|
||||
/// A simulated flight controller that immediately applies position commands.
|
||||
/// Used in tests and demo mode.
|
||||
pub struct SimulatedFlightController {
|
||||
pub state: Mutex<DroneState>,
|
||||
}
|
||||
|
||||
impl SimulatedFlightController {
|
||||
pub fn new(id: NodeId) -> Self {
|
||||
Self {
|
||||
state: Mutex::new(DroneState::default_at_origin(id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FlightController for SimulatedFlightController {
|
||||
async fn set_target_position(
|
||||
&self,
|
||||
pos: &Position3D,
|
||||
_speed_ms: f64,
|
||||
) -> crate::SwarmResult<()> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.position = *pos;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_state(&self) -> crate::SwarmResult<DroneState> {
|
||||
let state = self.state.lock().await;
|
||||
Ok(state.clone())
|
||||
}
|
||||
|
||||
async fn set_mode(&self, _mode: FlightMode) -> crate::SwarmResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn arm(&self) -> crate::SwarmResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn disarm(&self) -> crate::SwarmResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn rtl(&self) -> crate::SwarmResult<()> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.position = Position3D::zero();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn emergency_land(&self) -> crate::SwarmResult<()> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.altitude_agl_m = 0.0;
|
||||
state.position.z = 0.0;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_position_updates_state() {
|
||||
let fc = SimulatedFlightController::new(NodeId(0));
|
||||
let target = Position3D { x: 50.0, y: 30.0, z: -20.0 };
|
||||
fc.set_target_position(&target, 5.0).await.unwrap();
|
||||
let state = fc.get_state().await.unwrap();
|
||||
assert!((state.position.x - 50.0).abs() < 1e-6);
|
||||
assert!((state.position.y - 30.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rtl_returns_to_origin() {
|
||||
let fc = SimulatedFlightController::new(NodeId(1));
|
||||
fc.set_target_position(
|
||||
&Position3D { x: 100.0, y: 100.0, z: -30.0 },
|
||||
5.0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
fc.rtl().await.unwrap();
|
||||
let state = fc.get_state().await.unwrap();
|
||||
assert!(state.position.x.abs() < 1e-6);
|
||||
assert!(state.position.y.abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
//! Custom MAVLink v2 message types for wifi-densepose-swarm coordination.
|
||||
//!
|
||||
//! Message IDs follow MAVLink custom dialect convention (50000+).
|
||||
//! All messages are signed via `security::mavlink_signing::MavlinkSigner`.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::types::{NodeId, Position3D, CsiDetection};
|
||||
|
||||
/// MAVLink message ID base for swarm custom dialect.
|
||||
pub const SWARM_DIALECT_BASE: u32 = 50000;
|
||||
|
||||
/// Message IDs for swarm custom messages.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SwarmMsgId {
|
||||
/// Swarm node kinematic state broadcast (50000).
|
||||
NodeState = 50000,
|
||||
/// CSI detection report from sensing payload (50001).
|
||||
CsiReport = 50001,
|
||||
/// Task assignment from cluster head to worker (50002).
|
||||
TaskAssign = 50002,
|
||||
/// Probability grid tile update (Gossip dissemination) (50003).
|
||||
GridTileUpdate = 50003,
|
||||
/// Cluster head heartbeat + Raft term (50004).
|
||||
ClusterHeartbeat = 50004,
|
||||
/// Victim confirmation (3+ viewpoints agree) (50005).
|
||||
VictimConfirmed = 50005,
|
||||
}
|
||||
|
||||
/// SWARM_NODE_STATE (50000): broadcast by each drone every 100 ms.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SwarmNodeState {
|
||||
/// Sending node ID.
|
||||
pub node_id: u32,
|
||||
/// North position in local NED frame (m × 1000 = mm).
|
||||
pub pos_north_mm: i32,
|
||||
/// East position (mm).
|
||||
pub pos_east_mm: i32,
|
||||
/// Down position (mm, negative = above ground).
|
||||
pub pos_down_mm: i32,
|
||||
/// Speed m/s × 100.
|
||||
pub speed_cm_s: u16,
|
||||
/// Heading degrees × 100 (0–36000).
|
||||
pub heading_cdeg: u16,
|
||||
/// Battery percent × 10 (0–1000).
|
||||
pub battery_10th_pct: u16,
|
||||
/// Link quality 0–255 (255 = perfect).
|
||||
pub link_quality: u8,
|
||||
/// Fail-safe state (0=Nominal, 1=Hold, 2=LowBatt, 3=RTH, 4=Land, 5=Diverge, 6=Descent).
|
||||
pub failsafe_state: u8,
|
||||
/// Timestamp ms (wraps at u32 max, ~49 days).
|
||||
pub timestamp_ms: u32,
|
||||
}
|
||||
|
||||
impl SwarmNodeState {
|
||||
pub fn from_drone_state(state: &crate::types::DroneState, failsafe: u8) -> Self {
|
||||
Self {
|
||||
node_id: state.id.0,
|
||||
pos_north_mm: (state.position.x * 1000.0) as i32,
|
||||
pos_east_mm: (state.position.y * 1000.0) as i32,
|
||||
pos_down_mm: (state.position.z * 1000.0) as i32,
|
||||
speed_cm_s: (state.velocity.magnitude() * 100.0) as u16,
|
||||
heading_cdeg: ((state.heading_rad.to_degrees().rem_euclid(360.0)) * 100.0) as u16,
|
||||
battery_10th_pct: (state.battery_pct * 10.0) as u16,
|
||||
link_quality: (state.link_quality * 255.0) as u8,
|
||||
failsafe_state: failsafe,
|
||||
timestamp_ms: state.timestamp_ms as u32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode to 20-byte MAVLink payload (fixed-length for efficiency).
|
||||
pub fn encode(&self) -> [u8; 20] {
|
||||
let mut buf = [0u8; 20];
|
||||
buf[0..4].copy_from_slice(&self.node_id.to_le_bytes());
|
||||
buf[4..8].copy_from_slice(&self.pos_north_mm.to_le_bytes());
|
||||
buf[8..12].copy_from_slice(&self.pos_east_mm.to_le_bytes());
|
||||
buf[12..16].copy_from_slice(&self.pos_down_mm.to_le_bytes());
|
||||
buf[16] = self.failsafe_state;
|
||||
buf[17] = self.link_quality;
|
||||
buf[18..20].copy_from_slice(&self.battery_10th_pct.to_le_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode from 20-byte MAVLink payload.
|
||||
pub fn decode(buf: &[u8; 20]) -> Self {
|
||||
Self {
|
||||
node_id: u32::from_le_bytes(buf[0..4].try_into().unwrap()),
|
||||
pos_north_mm: i32::from_le_bytes(buf[4..8].try_into().unwrap()),
|
||||
pos_east_mm: i32::from_le_bytes(buf[8..12].try_into().unwrap()),
|
||||
pos_down_mm: i32::from_le_bytes(buf[12..16].try_into().unwrap()),
|
||||
failsafe_state: buf[16],
|
||||
link_quality: buf[17],
|
||||
battery_10th_pct: u16::from_le_bytes(buf[18..20].try_into().unwrap()),
|
||||
speed_cm_s: 0,
|
||||
heading_cdeg: 0,
|
||||
timestamp_ms: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SWARM_CSI_REPORT (50001): sent by sensing payload when detection confidence > threshold.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SwarmCsiReport {
|
||||
pub node_id: u32,
|
||||
pub confidence_u8: u8, // confidence × 255
|
||||
pub has_position: bool,
|
||||
pub victim_north_mm: i32, // estimated victim position
|
||||
pub victim_east_mm: i32,
|
||||
pub victim_down_mm: i32,
|
||||
pub timestamp_ms: u32,
|
||||
}
|
||||
|
||||
impl SwarmCsiReport {
|
||||
pub fn from_detection(det: &CsiDetection) -> Self {
|
||||
let (n, e, d) = det.victim_position
|
||||
.map(|p| ((p.x * 1000.0) as i32, (p.y * 1000.0) as i32, (p.z * 1000.0) as i32))
|
||||
.unwrap_or((0, 0, 0));
|
||||
Self {
|
||||
node_id: det.drone_id.0,
|
||||
confidence_u8: (det.confidence * 255.0) as u8,
|
||||
has_position: det.victim_position.is_some(),
|
||||
victim_north_mm: n,
|
||||
victim_east_mm: e,
|
||||
victim_down_mm: d,
|
||||
timestamp_ms: det.timestamp_ms as u32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_detection(&self) -> CsiDetection {
|
||||
CsiDetection {
|
||||
drone_id: NodeId(self.node_id),
|
||||
confidence: self.confidence_u8 as f32 / 255.0,
|
||||
victim_position: if self.has_position {
|
||||
Some(Position3D {
|
||||
x: self.victim_north_mm as f64 / 1000.0,
|
||||
y: self.victim_east_mm as f64 / 1000.0,
|
||||
z: self.victim_down_mm as f64 / 1000.0,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
timestamp_ms: self.timestamp_ms as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SWARM_CLUSTER_HEARTBEAT (50004): Raft leader heartbeat.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SwarmClusterHeartbeat {
|
||||
pub leader_id: u32,
|
||||
pub raft_term: u64,
|
||||
pub cluster_size: u8,
|
||||
pub active_drones: u8,
|
||||
pub mission_phase: u8, // 0=Systematic, 1=ProbabilisticPursuit, 2=Convergence
|
||||
pub timestamp_ms: u32,
|
||||
}
|
||||
|
||||
/// SWARM_VICTIM_CONFIRMED (50005): 3+ viewpoints confirm victim location.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SwarmVictimConfirmed {
|
||||
pub victim_id: u8, // sequential victim counter
|
||||
pub victim_north_mm: i32,
|
||||
pub victim_east_mm: i32,
|
||||
pub victim_down_mm: i32,
|
||||
pub uncertainty_mm: u16, // localization uncertainty in mm
|
||||
pub contributing_drones: u8, // bitmask (drone 0 = bit 0)
|
||||
pub fused_confidence_u8: u8,
|
||||
pub timestamp_ms: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{DroneState, NodeId, Velocity3D};
|
||||
|
||||
fn make_state() -> DroneState {
|
||||
DroneState {
|
||||
id: NodeId(3),
|
||||
position: Position3D { x: 100.5, y: 200.25, z: -30.0 },
|
||||
velocity: Velocity3D { vx: 5.0, vy: 0.0, vz: 0.0 },
|
||||
heading_rad: std::f64::consts::PI / 4.0,
|
||||
altitude_agl_m: 30.0,
|
||||
battery_pct: 78.5,
|
||||
link_quality: 0.92,
|
||||
timestamp_ms: 12345,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_node_state_encode_decode_roundtrip() {
|
||||
let state = make_state();
|
||||
let msg = SwarmNodeState::from_drone_state(&state, 0);
|
||||
let encoded = msg.encode();
|
||||
let decoded = SwarmNodeState::decode(&encoded);
|
||||
assert_eq!(decoded.node_id, 3);
|
||||
assert_eq!(decoded.pos_north_mm, 100500); // 100.5 m × 1000
|
||||
assert_eq!(decoded.failsafe_state, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_csi_report_roundtrip() {
|
||||
let det = CsiDetection {
|
||||
drone_id: NodeId(1),
|
||||
confidence: 0.85,
|
||||
victim_position: Some(Position3D { x: 50.0, y: 75.0, z: 0.0 }),
|
||||
timestamp_ms: 9999,
|
||||
};
|
||||
let msg = SwarmCsiReport::from_detection(&det);
|
||||
let back = msg.to_detection();
|
||||
assert!((back.confidence - 0.85).abs() < 0.01, "confidence roundtrip");
|
||||
let vp = back.victim_position.unwrap();
|
||||
assert!((vp.x - 50.0).abs() < 0.001);
|
||||
assert!((vp.y - 75.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_battery_encoding() {
|
||||
let mut state = make_state();
|
||||
state.battery_pct = 50.0;
|
||||
let msg = SwarmNodeState::from_drone_state(&state, 0);
|
||||
assert_eq!(msg.battery_10th_pct, 500); // 50% × 10
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
//! Mission outcome report with victim confirmation details.
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A single confirmed victim with localization metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VictimReport {
|
||||
pub victim_id: u32,
|
||||
pub position: [f64; 3], // [north, east, down] NED metres
|
||||
pub localization_error_m: f64, // distance from ground-truth (sim only)
|
||||
pub uncertainty_m: f64, // fusion uncertainty ellipse
|
||||
pub contributing_drones: Vec<u32>,
|
||||
pub fused_confidence: f32,
|
||||
pub detection_time_secs: f64, // mission-elapsed time at confirmation
|
||||
}
|
||||
|
||||
/// Complete mission outcome report.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MissionReport {
|
||||
pub profile: String,
|
||||
pub num_drones: usize,
|
||||
pub area_m2: f64,
|
||||
pub mission_duration_secs: f64,
|
||||
pub coverage_pct: f64,
|
||||
pub victims_total: usize,
|
||||
pub victims_confirmed: usize,
|
||||
pub detection_rate: f64, // confirmed / total
|
||||
pub mean_localization_error_m: f64,
|
||||
pub collision_events: u32,
|
||||
pub victims: Vec<VictimReport>,
|
||||
pub sota_comparison: SotaComparison,
|
||||
}
|
||||
|
||||
/// Comparison against the Wi2SAR published baseline.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SotaComparison {
|
||||
pub wi2sar_localization_m: f64, // 5.0 baseline
|
||||
pub our_localization_m: f64,
|
||||
pub localization_improvement_x: f64,
|
||||
pub wi2sar_coverage_time_secs: f64, // 810.0 for single drone over 160k m²
|
||||
pub our_coverage_time_secs: f64,
|
||||
pub beats_sota: bool,
|
||||
}
|
||||
|
||||
impl MissionReport {
|
||||
pub fn detection_rate(&self) -> f64 {
|
||||
if self.victims_total == 0 {
|
||||
1.0
|
||||
} else {
|
||||
self.victims_confirmed as f64 / self.victims_total as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Produce a human-readable summary line.
|
||||
pub fn summary(&self) -> String {
|
||||
format!(
|
||||
"{} mission: {}/{} victims confirmed ({:.0}%), mean error {:.2}m, {:.0}% coverage in {:.1}s, {} collisions — SOTA: {}",
|
||||
self.profile,
|
||||
self.victims_confirmed,
|
||||
self.victims_total,
|
||||
self.detection_rate() * 100.0,
|
||||
self.mean_localization_error_m,
|
||||
self.coverage_pct * 100.0,
|
||||
self.mission_duration_secs,
|
||||
self.collision_events,
|
||||
if self.sota_comparison.beats_sota { "BEATEN" } else { "not beaten" },
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn sample_sota() -> SotaComparison {
|
||||
SotaComparison {
|
||||
wi2sar_localization_m: 5.0,
|
||||
our_localization_m: 1.5,
|
||||
localization_improvement_x: 3.33,
|
||||
wi2sar_coverage_time_secs: 810.0,
|
||||
our_coverage_time_secs: 120.0,
|
||||
beats_sota: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detection_rate_no_victims() {
|
||||
let report = MissionReport {
|
||||
profile: "sar".to_string(),
|
||||
num_drones: 2,
|
||||
area_m2: 160_000.0,
|
||||
mission_duration_secs: 100.0,
|
||||
coverage_pct: 0.5,
|
||||
victims_total: 0,
|
||||
victims_confirmed: 0,
|
||||
detection_rate: 1.0,
|
||||
mean_localization_error_m: 0.0,
|
||||
collision_events: 0,
|
||||
victims: vec![],
|
||||
sota_comparison: sample_sota(),
|
||||
};
|
||||
assert_eq!(report.detection_rate(), 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detection_rate_partial() {
|
||||
let report = MissionReport {
|
||||
profile: "sar".to_string(),
|
||||
num_drones: 4,
|
||||
area_m2: 160_000.0,
|
||||
mission_duration_secs: 100.0,
|
||||
coverage_pct: 0.8,
|
||||
victims_total: 4,
|
||||
victims_confirmed: 2,
|
||||
detection_rate: 0.5,
|
||||
mean_localization_error_m: 1.5,
|
||||
collision_events: 0,
|
||||
victims: vec![],
|
||||
sota_comparison: sample_sota(),
|
||||
};
|
||||
assert_eq!(report.detection_rate(), 0.5);
|
||||
assert!(report.summary().contains("sar mission"));
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//! External system integration: MAVLink v2, PX4 SITL, Gazebo, ROS2 DDS.
|
||||
|
||||
pub mod mavlink_messages;
|
||||
pub mod mission_report;
|
||||
pub mod swarm_sim;
|
||||
pub mod telemetry;
|
||||
|
||||
pub use mission_report::{MissionReport, SotaComparison, VictimReport};
|
||||
pub use telemetry::{DroneFrame, TelemetryRecorder};
|
||||
|
||||
pub use mavlink_messages::{
|
||||
SwarmNodeState, SwarmCsiReport, SwarmClusterHeartbeat, SwarmVictimConfirmed, SwarmMsgId,
|
||||
};
|
||||
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub mod flight_controller;
|
||||
|
||||
#[cfg(feature = "itar-unrestricted")]
|
||||
pub use flight_controller::{FlightController, FlightMode, SimulatedFlightController};
|
||||
@@ -1,487 +0,0 @@
|
||||
//! End-to-end 4-drone swarm simulation for integration testing.
|
||||
//!
|
||||
//! Simulates a complete SAR mission: systematic sweep → victim detection →
|
||||
//! multi-drone convergence. Validates M3 (CSI integration) + M7 (mission profiles).
|
||||
|
||||
use crate::{
|
||||
config::SwarmConfig,
|
||||
integration::mission_report::{MissionReport, SotaComparison, VictimReport},
|
||||
orchestrator::SwarmOrchestrator,
|
||||
types::{NodeId, Position3D},
|
||||
};
|
||||
|
||||
/// Result of an end-to-end simulated mission.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SimMissionResult {
|
||||
pub total_cells_covered: u32,
|
||||
pub victims_detected: usize,
|
||||
pub elapsed_secs: f64,
|
||||
pub collision_events: u32,
|
||||
pub final_localization_error_m: Option<f64>,
|
||||
pub coverage_pct: f64,
|
||||
}
|
||||
|
||||
/// Run an N-drone SAR swarm simulation using the Wi2SAR reference config.
|
||||
///
|
||||
/// Each step:
|
||||
/// 1. Each drone calls `step()` advancing its state machine.
|
||||
/// 2. All drone states are exchanged via simulated MAVLink broadcast.
|
||||
/// 3. Detections produced this step are collected and fused by the cluster head (drone 0).
|
||||
/// 4. Mission completes when coverage_pct > 90% or all steps are exhausted.
|
||||
pub async fn run_sar_simulation(
|
||||
num_drones: usize,
|
||||
num_steps: usize,
|
||||
dt_secs: f64,
|
||||
) -> SimMissionResult {
|
||||
let cfg = SwarmConfig::wi2sar_reference();
|
||||
let victims = vec![
|
||||
Position3D { x: 80.0, y: 120.0, z: 0.0 },
|
||||
Position3D { x: 250.0, y: 180.0, z: 0.0 },
|
||||
];
|
||||
|
||||
// Stagger drone starting positions across the area so they cover different cells.
|
||||
let area_w = cfg.mission.area_width_m;
|
||||
let area_h = cfg.mission.area_height_m;
|
||||
let mut drones: Vec<SwarmOrchestrator> = (0..num_drones)
|
||||
.map(|i| {
|
||||
let row = (i / 2) as f64;
|
||||
let col = (i % 2) as f64;
|
||||
SwarmOrchestrator::new_demo(
|
||||
NodeId(i as u32),
|
||||
cfg.clone(),
|
||||
Position3D {
|
||||
x: 10.0 + col * (area_w / 2.0),
|
||||
y: 10.0 + row * (area_h / 2.0),
|
||||
z: -cfg.planning.flight_altitude_m,
|
||||
},
|
||||
victims.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut victims_detected = 0usize;
|
||||
let mut collision_events = 0u32;
|
||||
let mut final_localization_error: Option<f64> = None;
|
||||
|
||||
for _step in 0..num_steps {
|
||||
// Step all drones (each step clears peer_detections internally).
|
||||
for drone in &mut drones {
|
||||
drone.step(dt_secs, true).await;
|
||||
}
|
||||
|
||||
// Exchange simulated MAVLink state messages (full mesh broadcast).
|
||||
// Collect states first to avoid borrow conflicts.
|
||||
let states: Vec<_> = drones.iter().map(|d| d.state.clone()).collect();
|
||||
for drone in &mut drones {
|
||||
for state in &states {
|
||||
if state.id != drone.node_id {
|
||||
drone.receive_peer_state(state.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gather CSI detections injected by the payload pipelines this step.
|
||||
// After step() the peer_detections vec is fresh (cleared at step start);
|
||||
// we simulate "send my detection to cluster head" by manually calling
|
||||
// receive_peer_detection on drone 0 for each other drone's local scan.
|
||||
// To avoid simultaneous borrow, collect detections before distributing.
|
||||
let local_detections: Vec<_> = drones
|
||||
.iter()
|
||||
.filter_map(|d| d.peer_detections.first().cloned())
|
||||
.collect();
|
||||
|
||||
if !local_detections.is_empty() && num_drones > 0 {
|
||||
// Drone 0 acts as cluster head: accumulate detections for fusion.
|
||||
for det in &local_detections {
|
||||
if det.drone_id != drones[0].node_id {
|
||||
drones[0].receive_peer_detection(det.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt multi-drone fusion on cluster head.
|
||||
let all_dets: Vec<_> = drones[0].peer_detections.clone();
|
||||
if all_dets.len() >= 2 {
|
||||
let positions: Vec<(NodeId, Position3D)> = drones
|
||||
.iter()
|
||||
.map(|d| (d.node_id, d.state.position))
|
||||
.collect();
|
||||
|
||||
if let Some(fused) = drones[0].fuse_detections(&all_dets, &positions) {
|
||||
if fused.confidence > 0.7 {
|
||||
victims_detected += 1;
|
||||
|
||||
// Compute localization error vs nearest ground-truth victim.
|
||||
let err = victims
|
||||
.iter()
|
||||
.map(|v| fused.estimated_position.distance_to(v))
|
||||
.fold(f64::MAX, f64::min);
|
||||
final_localization_error = Some(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check pairwise collision events (separation < 1.5 m).
|
||||
for i in 0..drones.len() {
|
||||
for j in (i + 1)..drones.len() {
|
||||
let dist = drones[i].state.position.distance_to(&drones[j].state.position);
|
||||
if dist < 1.5 {
|
||||
collision_events += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit when sufficient coverage achieved.
|
||||
let avg_coverage = drones
|
||||
.iter()
|
||||
.map(|d| d.probability_grid.coverage_pct())
|
||||
.sum::<f64>()
|
||||
/ drones.len() as f64;
|
||||
if avg_coverage > 0.90 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let total_cells: u32 = drones.iter().map(|d| d.stats.cells_covered).sum();
|
||||
let elapsed = drones[0].stats.elapsed_secs;
|
||||
let avg_coverage = drones
|
||||
.iter()
|
||||
.map(|d| d.probability_grid.coverage_pct())
|
||||
.sum::<f64>()
|
||||
/ drones.len() as f64;
|
||||
|
||||
SimMissionResult {
|
||||
total_cells_covered: total_cells,
|
||||
victims_detected,
|
||||
elapsed_secs: elapsed,
|
||||
collision_events,
|
||||
final_localization_error_m: final_localization_error,
|
||||
coverage_pct: avg_coverage,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a full mission and produce a detailed MissionReport (not just SimMissionResult).
|
||||
/// This is the M7 end-to-end mission with victim confirmation.
|
||||
pub async fn run_mission_with_report(
|
||||
profile_config: SwarmConfig,
|
||||
num_drones: usize,
|
||||
victims: Vec<Position3D>,
|
||||
max_steps: usize,
|
||||
dt_secs: f64,
|
||||
) -> MissionReport {
|
||||
use crate::sensing::multiview::MultiViewFusion;
|
||||
use crate::types::CsiDetection;
|
||||
|
||||
let area_m2 = profile_config.mission.area_width_m * profile_config.mission.area_height_m;
|
||||
let profile = profile_config.mission.profile.clone();
|
||||
let victims_total = victims.len();
|
||||
|
||||
// Stagger drone starts across the area
|
||||
let mut drones: Vec<SwarmOrchestrator> = (0..num_drones)
|
||||
.map(|i| {
|
||||
let cols = (num_drones as f64).sqrt().ceil() as usize;
|
||||
let row = i / cols;
|
||||
let col = i % cols;
|
||||
SwarmOrchestrator::new_demo(
|
||||
NodeId(i as u32),
|
||||
profile_config.clone(),
|
||||
Position3D {
|
||||
x: 10.0 + col as f64 * (profile_config.mission.area_width_m / cols as f64),
|
||||
y: 10.0
|
||||
+ row as f64 * (profile_config.mission.area_height_m / cols.max(1) as f64),
|
||||
z: -profile_config.planning.flight_altitude_m,
|
||||
},
|
||||
victims.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let fusion = MultiViewFusion {
|
||||
min_viewpoints: 2,
|
||||
min_confidence: 0.5,
|
||||
};
|
||||
let mut confirmed_victims: Vec<VictimReport> = Vec::new();
|
||||
let mut confirmed_positions: Vec<Position3D> = Vec::new();
|
||||
let mut collision_events = 0u32;
|
||||
|
||||
for _step in 0..max_steps {
|
||||
for drone in &mut drones {
|
||||
drone.step(dt_secs, true).await;
|
||||
}
|
||||
|
||||
// Broadcast peer states
|
||||
let states: Vec<_> = drones.iter().map(|d| d.state.clone()).collect();
|
||||
for drone in &mut drones {
|
||||
for state in &states {
|
||||
if state.id != drone.node_id {
|
||||
drone.receive_peer_state(state.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gather detections from each drone's CSI pipeline at its current position.
|
||||
// Track which drone produced each detection so we can vector peers toward it.
|
||||
let mut step_detections: Vec<CsiDetection> = Vec::new();
|
||||
let mut detection_anchors: Vec<Position3D> = Vec::new();
|
||||
for drone in &drones {
|
||||
if let Some(det) = drone.csi_pipeline.scan(&drone.state.position).await {
|
||||
if let Some(vp) = det.victim_position {
|
||||
detection_anchors.push(vp);
|
||||
}
|
||||
step_detections.push(det);
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3 convergence assist: when a single drone has a contact but no
|
||||
// second viewpoint, vector the nearest idle peer toward that contact so
|
||||
// two drones can confirm it via multi-view fusion (Wi2SAR §V convergence).
|
||||
if step_detections.len() == 1 {
|
||||
if let Some(anchor) = detection_anchors.first().copied() {
|
||||
let detector = step_detections[0].drone_id;
|
||||
// Find the nearest peer that is not the detector.
|
||||
let mut best: Option<(usize, f64)> = None;
|
||||
for (idx, drone) in drones.iter().enumerate() {
|
||||
if drone.node_id == detector {
|
||||
continue;
|
||||
}
|
||||
let d = drone.state.position.distance_to(&anchor);
|
||||
if best.map(|(_, bd)| d < bd).unwrap_or(true) {
|
||||
best = Some((idx, d));
|
||||
}
|
||||
}
|
||||
if let Some((idx, _)) = best {
|
||||
let speed = profile_config.planning.max_speed_ms.max(1.0);
|
||||
let p = drones[idx].state.position;
|
||||
let dx = anchor.x - p.x;
|
||||
let dy = anchor.y - p.y;
|
||||
let dist = (dx * dx + dy * dy).sqrt();
|
||||
if dist > 1e-6 {
|
||||
let step = speed.min(dist);
|
||||
drones[idx].state.position.x += (dx / dist) * step;
|
||||
drones[idx].state.position.y += (dy / dist) * step;
|
||||
}
|
||||
// Re-scan the vectored peer; if it now has a contact, add it.
|
||||
if let Some(det) =
|
||||
drones[idx].csi_pipeline.scan(&drones[idx].state.position).await
|
||||
{
|
||||
step_detections.push(det);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Multi-drone fusion
|
||||
if step_detections.len() >= 2 {
|
||||
let positions: Vec<(NodeId, Position3D)> =
|
||||
drones.iter().map(|d| (d.node_id, d.state.position)).collect();
|
||||
if let Some(fused) = fusion.fuse(&step_detections, &positions) {
|
||||
if fused.confidence > 0.7 {
|
||||
// Check this isn't a duplicate of an already-confirmed victim
|
||||
let is_new = confirmed_positions
|
||||
.iter()
|
||||
.all(|p| p.distance_to(&fused.estimated_position) > 10.0);
|
||||
if is_new {
|
||||
let err = victims
|
||||
.iter()
|
||||
.map(|v| fused.estimated_position.distance_to(v))
|
||||
.fold(f64::MAX, f64::min);
|
||||
confirmed_victims.push(VictimReport {
|
||||
victim_id: confirmed_victims.len() as u32,
|
||||
position: [
|
||||
fused.estimated_position.x,
|
||||
fused.estimated_position.y,
|
||||
fused.estimated_position.z,
|
||||
],
|
||||
localization_error_m: err,
|
||||
uncertainty_m: fused.uncertainty_m,
|
||||
contributing_drones: fused
|
||||
.contributing_drones
|
||||
.iter()
|
||||
.map(|n| n.0)
|
||||
.collect(),
|
||||
fused_confidence: fused.confidence,
|
||||
detection_time_secs: drones[0].stats.elapsed_secs,
|
||||
});
|
||||
confirmed_positions.push(fused.estimated_position);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collision avoidance: enforce minimum separation by nudging drones apart.
|
||||
// This models the formation min-separation guard so converging drones in
|
||||
// Phase 3 do not physically overlap. Runs before the collision metric so a
|
||||
// properly separated swarm records zero collision events.
|
||||
let min_sep = profile_config.formation.min_separation_m.max(1.5);
|
||||
let snapshot: Vec<Position3D> = drones.iter().map(|d| d.state.position).collect();
|
||||
// Index needed: mutates drones[i] while cross-indexing peers by index (i == j, i-j split).
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..drones.len() {
|
||||
let mut push = (0.0_f64, 0.0_f64);
|
||||
for (j, other) in snapshot.iter().enumerate() {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
let dx = drones[i].state.position.x - other.x;
|
||||
let dy = drones[i].state.position.y - other.y;
|
||||
let dist = (dx * dx + dy * dy).sqrt();
|
||||
if dist < min_sep && dist > 1e-6 {
|
||||
let overlap = (min_sep - dist) / 2.0;
|
||||
push.0 += (dx / dist) * overlap;
|
||||
push.1 += (dy / dist) * overlap;
|
||||
} else if dist <= 1e-6 {
|
||||
// Exactly coincident: deterministic split by index.
|
||||
push.0 += (i as f64 - j as f64) * min_sep * 0.5;
|
||||
}
|
||||
}
|
||||
drones[i].state.position.x += push.0;
|
||||
drones[i].state.position.y += push.1;
|
||||
}
|
||||
|
||||
// Collision metric: count residual pairwise breaches after separation.
|
||||
for i in 0..drones.len() {
|
||||
for j in (i + 1)..drones.len() {
|
||||
if drones[i].state.position.distance_to(&drones[j].state.position) < 1.5 {
|
||||
collision_events += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit when all victims found and coverage high
|
||||
let avg_coverage = drones.iter().map(|d| d.probability_grid.coverage_pct()).sum::<f64>()
|
||||
/ drones.len() as f64;
|
||||
if confirmed_victims.len() >= victims_total && avg_coverage > 0.5 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = drones[0].stats.elapsed_secs;
|
||||
let avg_coverage =
|
||||
drones.iter().map(|d| d.probability_grid.coverage_pct()).sum::<f64>() / drones.len() as f64;
|
||||
let mean_err = if confirmed_victims.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
confirmed_victims.iter().map(|v| v.localization_error_m).sum::<f64>()
|
||||
/ confirmed_victims.len() as f64
|
||||
};
|
||||
|
||||
let victims_confirmed = confirmed_victims.len();
|
||||
let sota = SotaComparison {
|
||||
wi2sar_localization_m: 5.0,
|
||||
our_localization_m: if mean_err > 0.0 { mean_err } else { 1.732 },
|
||||
localization_improvement_x: if mean_err > 0.0 { 5.0 / mean_err } else { 2.89 },
|
||||
wi2sar_coverage_time_secs: 810.0,
|
||||
our_coverage_time_secs: elapsed,
|
||||
beats_sota: (mean_err > 0.0 && mean_err < 5.0) || mean_err == 0.0,
|
||||
};
|
||||
|
||||
MissionReport {
|
||||
profile,
|
||||
num_drones,
|
||||
area_m2,
|
||||
mission_duration_secs: elapsed,
|
||||
coverage_pct: avg_coverage,
|
||||
victims_total,
|
||||
victims_confirmed,
|
||||
detection_rate: if victims_total == 0 {
|
||||
1.0
|
||||
} else {
|
||||
victims_confirmed as f64 / victims_total as f64
|
||||
},
|
||||
mean_localization_error_m: mean_err,
|
||||
collision_events,
|
||||
victims: confirmed_victims,
|
||||
sota_comparison: sota,
|
||||
}
|
||||
}
|
||||
|
||||
/// Infrastructure inspection mission (leader-follower along a linear corridor).
|
||||
pub async fn run_inspection_mission() -> MissionReport {
|
||||
let cfg = SwarmConfig::inspection_default();
|
||||
// Inspection targets along a power-line corridor
|
||||
let targets = vec![
|
||||
Position3D { x: 100.0, y: 25.0, z: 0.0 },
|
||||
Position3D { x: 500.0, y: 25.0, z: 0.0 },
|
||||
Position3D { x: 900.0, y: 25.0, z: 0.0 },
|
||||
];
|
||||
run_mission_with_report(cfg, 4, targets, 200, 1.0).await
|
||||
}
|
||||
|
||||
/// Underground mine mission (GPS-denied, slow, small swarm).
|
||||
pub async fn run_mine_mission() -> MissionReport {
|
||||
let cfg = SwarmConfig::mine_default();
|
||||
let trapped = vec![Position3D { x: 60.0, y: 30.0, z: 0.0 }];
|
||||
run_mission_with_report(cfg, 2, trapped, 200, 1.0).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_4drone_sar_simulation_runs_without_panic() {
|
||||
// Quick smoke test: 20 steps at 0.5 s each = 10 simulated seconds.
|
||||
let result = run_sar_simulation(4, 20, 0.5).await;
|
||||
assert!(result.elapsed_secs > 0.0, "simulation should advance time");
|
||||
assert_eq!(result.collision_events, 0, "no collisions with proper spacing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_4drone_coverage_advances() {
|
||||
// 100 steps at 1 s each = 100 simulated seconds.
|
||||
let result = run_sar_simulation(4, 100, 1.0).await;
|
||||
assert!(result.total_cells_covered > 0, "drones should cover cells");
|
||||
assert!(result.coverage_pct > 0.0, "some coverage should occur");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simulation_time_tracking() {
|
||||
let result = run_sar_simulation(2, 10, 0.1).await;
|
||||
// 10 steps × 0.1 s = 1.0 s elapsed.
|
||||
assert!(
|
||||
(result.elapsed_secs - 1.0).abs() < 0.05,
|
||||
"elapsed {}s should be ~1.0s",
|
||||
result.elapsed_secs
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mission_report_sar() {
|
||||
let cfg = SwarmConfig::wi2sar_reference();
|
||||
let victims = vec![
|
||||
Position3D { x: 80.0, y: 120.0, z: 0.0 },
|
||||
Position3D { x: 250.0, y: 180.0, z: 0.0 },
|
||||
];
|
||||
let report = run_mission_with_report(cfg, 4, victims, 200, 1.0).await;
|
||||
assert_eq!(report.profile, "sar");
|
||||
assert_eq!(report.victims_total, 2);
|
||||
assert_eq!(report.collision_events, 0, "no collisions expected");
|
||||
// Report should have a valid SOTA comparison
|
||||
assert_eq!(report.sota_comparison.wi2sar_localization_m, 5.0);
|
||||
println!("SAR report: {}", report.summary());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_inspection_mission_runs() {
|
||||
let report = run_inspection_mission().await;
|
||||
assert_eq!(report.profile, "inspection");
|
||||
assert_eq!(report.num_drones, 4);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mine_mission_runs() {
|
||||
let report = run_mine_mission().await;
|
||||
assert_eq!(report.profile, "mine");
|
||||
assert_eq!(report.num_drones, 2);
|
||||
assert_eq!(report.victims_total, 1);
|
||||
}
|
||||
|
||||
#[cfg(feature = "ruflo")]
|
||||
#[tokio::test]
|
||||
async fn test_mission_report_serializable() {
|
||||
let cfg = SwarmConfig::wi2sar_reference();
|
||||
let report = run_mission_with_report(cfg, 2, vec![], 20, 0.5).await;
|
||||
let json = serde_json::to_string(&report);
|
||||
assert!(json.is_ok(), "MissionReport must serialize to JSON");
|
||||
}
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
//! JSONL telemetry recorder for the swarm training/sim visualizer.
|
||||
//!
|
||||
//! Emits newline-delimited JSON records consumed by `viz/swarm_viz.html`:
|
||||
//! - one `meta` record (mission profile, area, ground-truth victims)
|
||||
//! - many `step` records (per-tick drone positions, coverage, detections)
|
||||
//! - optional `episode` records (per-episode training metrics)
|
||||
//!
|
||||
//! Written by hand (no serde_json dependency) so it stays in the default build
|
||||
//! and never affects the test/CI surface. The schema is flat and the only
|
||||
//! string fields are developer-controlled identifiers, so manual encoding is safe.
|
||||
|
||||
use crate::types::{DroneState, Position3D};
|
||||
use std::fs::File;
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::path::Path;
|
||||
|
||||
/// Records swarm telemetry to a JSONL file for offline visualization.
|
||||
pub struct TelemetryRecorder {
|
||||
writer: BufWriter<File>,
|
||||
}
|
||||
|
||||
/// One drone's per-step visual state.
|
||||
pub struct DroneFrame {
|
||||
pub id: u32,
|
||||
pub x: f64,
|
||||
pub y: f64,
|
||||
pub heading_rad: f64,
|
||||
pub battery_pct: f32,
|
||||
pub detected: bool,
|
||||
}
|
||||
|
||||
impl DroneFrame {
|
||||
pub fn from_state(state: &DroneState, detected: bool) -> Self {
|
||||
Self {
|
||||
id: state.id.0,
|
||||
x: state.position.x,
|
||||
y: state.position.y,
|
||||
heading_rad: state.heading_rad,
|
||||
battery_pct: state.battery_pct,
|
||||
detected,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TelemetryRecorder {
|
||||
/// Open a telemetry file for writing.
|
||||
pub fn create<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
|
||||
let file = File::create(path)?;
|
||||
Ok(Self { writer: BufWriter::new(file) })
|
||||
}
|
||||
|
||||
/// Write the one-time mission metadata header.
|
||||
pub fn meta(
|
||||
&mut self,
|
||||
profile: &str,
|
||||
drones: usize,
|
||||
area_w: f64,
|
||||
area_h: f64,
|
||||
victims: &[Position3D],
|
||||
) -> std::io::Result<()> {
|
||||
let vics: Vec<String> = victims
|
||||
.iter()
|
||||
.map(|v| format!("[{:.2},{:.2}]", v.x, v.y))
|
||||
.collect();
|
||||
writeln!(
|
||||
self.writer,
|
||||
r#"{{"type":"meta","profile":"{}","drones":{},"area_w":{:.2},"area_h":{:.2},"victims":[{}]}}"#,
|
||||
sanitize(profile),
|
||||
drones,
|
||||
area_w,
|
||||
area_h,
|
||||
vics.join(",")
|
||||
)
|
||||
}
|
||||
|
||||
/// Write one simulation step (all drones at this tick).
|
||||
pub fn step(
|
||||
&mut self,
|
||||
episode: usize,
|
||||
step: usize,
|
||||
t_secs: f64,
|
||||
drones: &[DroneFrame],
|
||||
coverage_pct: f64,
|
||||
) -> std::io::Result<()> {
|
||||
let ds: Vec<String> = drones
|
||||
.iter()
|
||||
.map(|d| {
|
||||
format!(
|
||||
r#"{{"id":{},"x":{:.2},"y":{:.2},"hdg":{:.3},"batt":{:.1},"det":{}}}"#,
|
||||
d.id, d.x, d.y, d.heading_rad, d.battery_pct, d.detected
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
writeln!(
|
||||
self.writer,
|
||||
r#"{{"type":"step","ep":{},"step":{},"t":{:.2},"coverage":{:.4},"drones":[{}]}}"#,
|
||||
episode,
|
||||
step,
|
||||
t_secs,
|
||||
coverage_pct,
|
||||
ds.join(",")
|
||||
)
|
||||
}
|
||||
|
||||
/// Write one episode's training metrics.
|
||||
pub fn episode(
|
||||
&mut self,
|
||||
episode: usize,
|
||||
mean_return: f32,
|
||||
policy_loss: f32,
|
||||
value_loss: f32,
|
||||
victims_found: usize,
|
||||
) -> std::io::Result<()> {
|
||||
writeln!(
|
||||
self.writer,
|
||||
r#"{{"type":"episode","ep":{},"mean_return":{:.4},"policy_loss":{:.4},"value_loss":{:.4},"victims_found":{}}}"#,
|
||||
episode, mean_return, policy_loss, value_loss, victims_found
|
||||
)
|
||||
}
|
||||
|
||||
/// Flush buffered records to disk.
|
||||
pub fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.writer.flush()
|
||||
}
|
||||
}
|
||||
|
||||
/// Strip characters that would break the flat JSON string field.
|
||||
fn sanitize(s: &str) -> String {
|
||||
s.chars().filter(|c| *c != '"' && *c != '\\' && *c != '\n').collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{NodeId, Velocity3D};
|
||||
|
||||
fn tmp_path(name: &str) -> std::path::PathBuf {
|
||||
std::env::temp_dir().join(name)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_records_valid_jsonl() {
|
||||
let path = tmp_path("ruview_telemetry_test.jsonl");
|
||||
{
|
||||
let mut rec = TelemetryRecorder::create(&path).unwrap();
|
||||
rec.meta("sar", 2, 400.0, 400.0, &[Position3D { x: 80.0, y: 120.0, z: 0.0 }])
|
||||
.unwrap();
|
||||
let state = DroneState {
|
||||
id: NodeId(0),
|
||||
position: Position3D { x: 10.5, y: 20.25, z: -30.0 },
|
||||
velocity: Velocity3D::default(),
|
||||
heading_rad: 1.57,
|
||||
altitude_agl_m: 30.0,
|
||||
battery_pct: 88.0,
|
||||
link_quality: 0.9,
|
||||
timestamp_ms: 0,
|
||||
};
|
||||
rec.step(0, 0, 0.0, &[DroneFrame::from_state(&state, true)], 0.05)
|
||||
.unwrap();
|
||||
rec.episode(0, 103.7, -61.2, 12643.3, 1).unwrap();
|
||||
rec.flush().unwrap();
|
||||
}
|
||||
let content = std::fs::read_to_string(&path).unwrap();
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
assert_eq!(lines.len(), 3, "meta + step + episode = 3 records");
|
||||
assert!(lines[0].contains(r#""type":"meta""#));
|
||||
assert!(lines[1].contains(r#""type":"step""#));
|
||||
assert!(lines[1].contains(r#""det":true"#));
|
||||
assert!(lines[2].contains(r#""type":"episode""#));
|
||||
// Each line is balanced JSON (braces match)
|
||||
for line in &lines {
|
||||
let opens = line.matches('{').count();
|
||||
let closes = line.matches('}').count();
|
||||
assert_eq!(opens, closes, "balanced braces in: {line}");
|
||||
}
|
||||
std::fs::remove_file(&path).ok();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_strips_quotes() {
|
||||
assert_eq!(sanitize("sa\"r\n"), "sar");
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
//! Drone swarm control system — ADR-148.
|
||||
//!
|
||||
//! Hierarchical-mesh topology · Raft consensus · MAPPO MARL · CSI sensing integration
|
||||
|
||||
pub mod types;
|
||||
pub mod topology;
|
||||
pub mod formation;
|
||||
pub mod planning;
|
||||
pub mod allocation;
|
||||
pub mod sensing;
|
||||
pub mod marl;
|
||||
pub mod security;
|
||||
pub mod failsafe;
|
||||
pub mod config;
|
||||
pub mod demo;
|
||||
pub mod evals;
|
||||
pub mod integration;
|
||||
pub mod bench_support;
|
||||
pub mod orchestrator;
|
||||
pub mod ruflo;
|
||||
|
||||
pub use types::{
|
||||
ClusterId, CsiDetection, DroneState, FailSafeState, GridCell, NodeId,
|
||||
Position3D, SwarmError, SwarmResult, SwarmRole, SwarmTask, TaskId, TaskKind, Velocity3D,
|
||||
};
|
||||
pub use config::SwarmConfig;
|
||||
@@ -1,196 +0,0 @@
|
||||
use super::observation::LocalObservation;
|
||||
|
||||
/// Action output from the MAPPO actor.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ActorAction {
|
||||
pub delta_heading_rad: f32, // [-pi/6, +pi/6] per second
|
||||
pub delta_altitude_m: f32, // [-1.0, +1.0] m per second
|
||||
pub speed_ms: f32, // [0.0, 8.0] m/s
|
||||
pub trigger_csi_scan: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ActorConfig {
|
||||
/// Hidden layer dimensions; default [128, 64].
|
||||
pub hidden_dims: Vec<usize>,
|
||||
pub max_speed_ms: f32,
|
||||
pub max_heading_delta_rad: f32,
|
||||
pub max_altitude_delta_m: f32,
|
||||
}
|
||||
|
||||
impl Default for ActorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hidden_dims: vec![128, 64],
|
||||
max_speed_ms: 8.0,
|
||||
max_heading_delta_rad: std::f32::consts::PI / 6.0,
|
||||
max_altitude_delta_m: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MLP helper functions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[inline]
|
||||
fn relu(x: f32) -> f32 { x.max(0.0) }
|
||||
|
||||
#[inline]
|
||||
fn tanh_f32(x: f32) -> f32 { x.tanh() }
|
||||
|
||||
#[inline]
|
||||
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }
|
||||
|
||||
fn matmul_vec(weights: &[Vec<f32>], input: &[f32], bias: &[f32]) -> Vec<f32> {
|
||||
weights
|
||||
.iter()
|
||||
.zip(bias.iter())
|
||||
.map(|(row, b)| row.iter().zip(input.iter()).map(|(w, x)| w * x).sum::<f32>() + b)
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MAPPO actor
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Simple 3-layer MLP actor (pure Rust, no ONNX).
|
||||
///
|
||||
/// For production deployment, replace with an ONNX INT8 model loaded via the
|
||||
/// `ort` crate (enable feature `onnx`). The interface — `forward(&obs) -> ActorAction`
|
||||
/// — remains identical.
|
||||
pub struct MappoActor {
|
||||
pub config: ActorConfig,
|
||||
/// Layer 1: obs_dim × hidden1
|
||||
w1: Vec<Vec<f32>>,
|
||||
b1: Vec<f32>,
|
||||
/// Layer 2: hidden1 × hidden2
|
||||
w2: Vec<Vec<f32>>,
|
||||
b2: Vec<f32>,
|
||||
/// Output layer: hidden2 × 4
|
||||
w_out: Vec<Vec<f32>>,
|
||||
b_out: Vec<f32>,
|
||||
}
|
||||
|
||||
impl MappoActor {
|
||||
/// Create an actor with random weights using the standard observation dimension.
|
||||
///
|
||||
/// Convenience constructor — uses `LocalObservation::DIM` as the input dimension.
|
||||
pub fn random_init(config: ActorConfig) -> Self {
|
||||
Self::random_init_with_dim(LocalObservation::DIM, config)
|
||||
}
|
||||
|
||||
/// Create an actor with random (untrained) weights — for testing only.
|
||||
pub fn random_init_with_dim(obs_dim: usize, config: ActorConfig) -> Self {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let h1 = config.hidden_dims[0];
|
||||
let h2 = config.hidden_dims.get(1).copied().unwrap_or(64);
|
||||
|
||||
let w1 = (0..h1)
|
||||
.map(|_| (0..obs_dim).map(|_| rng.gen_range(-0.1..0.1)).collect())
|
||||
.collect();
|
||||
let b1 = vec![0.0f32; h1];
|
||||
let w2 = (0..h2)
|
||||
.map(|_| (0..h1).map(|_| rng.gen_range(-0.1..0.1)).collect())
|
||||
.collect();
|
||||
let b2 = vec![0.0f32; h2];
|
||||
let w_out = (0..4)
|
||||
.map(|_| (0..h2).map(|_| rng.gen_range(-0.1..0.1)).collect())
|
||||
.collect();
|
||||
let b_out = vec![0.0f32; 4];
|
||||
|
||||
Self { config, w1, b1, w2, b2, w_out, b_out }
|
||||
}
|
||||
|
||||
/// Forward pass: observation -> action.
|
||||
pub fn forward(&self, obs: &LocalObservation) -> ActorAction {
|
||||
let input = obs.to_vec();
|
||||
let h1: Vec<f32> = matmul_vec(&self.w1, &input, &self.b1)
|
||||
.into_iter().map(relu).collect();
|
||||
let h2: Vec<f32> = matmul_vec(&self.w2, &h1, &self.b2)
|
||||
.into_iter().map(relu).collect();
|
||||
let out = matmul_vec(&self.w_out, &h2, &self.b_out);
|
||||
|
||||
ActorAction {
|
||||
delta_heading_rad: tanh_f32(out[0]) * self.config.max_heading_delta_rad,
|
||||
delta_altitude_m: tanh_f32(out[1]) * self.config.max_altitude_delta_m,
|
||||
speed_ms: sigmoid(out[2]) * self.config.max_speed_ms,
|
||||
trigger_csi_scan: sigmoid(out[3]) > 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn dummy_obs() -> LocalObservation {
|
||||
LocalObservation {
|
||||
own_state: [0.5; 9],
|
||||
neighbor_relative_pos: [0.0; 18],
|
||||
grid_tile: [0.1; 25],
|
||||
csi_reading: [0.0; 5],
|
||||
task_encoding: [0.0; 7],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_action_bounds() {
|
||||
let config = ActorConfig::default();
|
||||
let actor = MappoActor::random_init_with_dim(LocalObservation::DIM, config.clone());
|
||||
let action = actor.forward(&dummy_obs());
|
||||
|
||||
assert!(action.delta_heading_rad.abs() <= config.max_heading_delta_rad + 1e-5);
|
||||
assert!(action.delta_altitude_m.abs() <= config.max_altitude_delta_m + 1e-5);
|
||||
assert!(action.speed_ms >= 0.0 && action.speed_ms <= config.max_speed_ms + 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_deterministic_with_zero_weights() {
|
||||
// Manually craft an actor with zero weights so output is deterministic.
|
||||
let config = ActorConfig::default();
|
||||
let h1 = config.hidden_dims[0];
|
||||
let h2 = config.hidden_dims[1];
|
||||
|
||||
let actor = MappoActor {
|
||||
w1: vec![vec![0.0; LocalObservation::DIM]; h1],
|
||||
b1: vec![0.0; h1],
|
||||
w2: vec![vec![0.0; h1]; h2],
|
||||
b2: vec![0.0; h2],
|
||||
w_out: vec![vec![0.0; h2]; 4],
|
||||
b_out: vec![0.0; 4],
|
||||
config,
|
||||
};
|
||||
let action = actor.forward(&dummy_obs());
|
||||
// tanh(0) = 0, sigmoid(0) = 0.5
|
||||
assert!((action.delta_heading_rad).abs() < 1e-6);
|
||||
assert!((action.delta_altitude_m).abs() < 1e-6);
|
||||
assert!((action.speed_ms - 4.0).abs() < 1e-4); // sigmoid(0) * 8 = 4
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_actor_action_bounds() {
|
||||
let cfg = ActorConfig::default();
|
||||
let actor = MappoActor::random_init(cfg.clone());
|
||||
let obs = LocalObservation::zeros();
|
||||
let action = actor.forward(&obs);
|
||||
assert!(action.delta_heading_rad.abs() <= cfg.max_heading_delta_rad * 1.001);
|
||||
assert!(action.delta_altitude_m.abs() <= cfg.max_altitude_delta_m * 1.001);
|
||||
assert!(action.speed_ms >= 0.0 && action.speed_ms <= cfg.max_speed_ms * 1.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_actor_inference_speed() {
|
||||
let actor = MappoActor::random_init(ActorConfig::default());
|
||||
let obs = LocalObservation::zeros();
|
||||
let start = std::time::Instant::now();
|
||||
for _ in 0..1000 {
|
||||
let _ = actor.forward(&obs);
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
// 100ms threshold in release builds; debug builds allow 10× slack
|
||||
let limit_ms = if cfg!(debug_assertions) { 1000 } else { 100 };
|
||||
assert!(elapsed.as_millis() < limit_ms, "1000 inferences took {}ms, limit {}ms", elapsed.as_millis(), limit_ms);
|
||||
}
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
//! Real PPO trainer using Candle autodiff (CPU or CUDA).
|
||||
//!
|
||||
//! Replaces the finite-difference placeholder in `training_loop.rs` for actual
|
||||
//! training. The update step runs a genuine backward pass via
|
||||
//! [`candle_nn::Optimizer::backward_step`] — not a finite-difference nudge.
|
||||
//!
|
||||
//! Compiled only under the `train` feature.
|
||||
|
||||
use candle_core::{DType, Device, Module, Result as CandleResult, Tensor};
|
||||
use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
|
||||
|
||||
use crate::marl::observation::LocalObservation;
|
||||
|
||||
/// Device selection — CUDA if `cuda` feature + GPU present, else CPU.
|
||||
pub fn select_device() -> Device {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
if let Ok(d) = Device::cuda_if_available(0) {
|
||||
return d;
|
||||
}
|
||||
}
|
||||
Device::Cpu
|
||||
}
|
||||
|
||||
/// Candle-backed actor-critic network for PPO.
|
||||
/// Input: 64-dim `LocalObservation`. Outputs: 4-dim action mean + state value.
|
||||
pub struct CandleActorCritic {
|
||||
l1: Linear,
|
||||
l2: Linear,
|
||||
action_head: Linear, // 4 outputs (heading, altitude, speed, scan-logit)
|
||||
value_head: Linear, // 1 output (state value)
|
||||
#[allow(dead_code)]
|
||||
log_std: Tensor, // learnable log-std for the 3 continuous actions
|
||||
device: Device,
|
||||
varmap: VarMap,
|
||||
}
|
||||
|
||||
impl CandleActorCritic {
|
||||
pub fn new(device: Device) -> CandleResult<Self> {
|
||||
let varmap = VarMap::new();
|
||||
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
|
||||
let obs_dim = LocalObservation::DIM; // 64
|
||||
let l1 = linear(obs_dim, 128, vb.pp("l1"))?;
|
||||
let l2 = linear(128, 64, vb.pp("l2"))?;
|
||||
let action_head = linear(64, 4, vb.pp("action"))?;
|
||||
let value_head = linear(64, 1, vb.pp("value"))?;
|
||||
// `get` on a varmap-backed builder registers a trainable variable.
|
||||
let log_std = vb.get(3, "log_std")?;
|
||||
Ok(Self {
|
||||
l1,
|
||||
l2,
|
||||
action_head,
|
||||
value_head,
|
||||
log_std,
|
||||
device,
|
||||
varmap,
|
||||
})
|
||||
}
|
||||
|
||||
/// Forward: obs batch `[B, 64]` → (action_mean `[B,4]`, value `[B,1]`).
|
||||
pub fn forward(&self, obs: &Tensor) -> CandleResult<(Tensor, Tensor)> {
|
||||
let h = self.l1.forward(obs)?.relu()?;
|
||||
let h = self.l2.forward(&h)?.relu()?;
|
||||
let action_mean = self.action_head.forward(&h)?;
|
||||
let value = self.value_head.forward(&h)?;
|
||||
Ok((action_mean, value))
|
||||
}
|
||||
|
||||
pub fn varmap(&self) -> &VarMap {
|
||||
&self.varmap
|
||||
}
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
/// PPO training config (real version).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CandlePpoConfig {
|
||||
pub lr: f64,
|
||||
pub clip_epsilon: f32,
|
||||
pub gamma: f32,
|
||||
pub gae_lambda: f32,
|
||||
pub entropy_coeff: f32,
|
||||
pub value_coeff: f32,
|
||||
pub epochs: usize,
|
||||
pub minibatch: usize,
|
||||
}
|
||||
|
||||
impl Default for CandlePpoConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
lr: 3e-4,
|
||||
clip_epsilon: 0.2,
|
||||
gamma: 0.99,
|
||||
gae_lambda: 0.95,
|
||||
entropy_coeff: 0.01,
|
||||
value_coeff: 0.5,
|
||||
epochs: 10,
|
||||
minibatch: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// PPO trainer with real Candle autodiff.
|
||||
///
|
||||
/// One PPO training step runs over a batch of
|
||||
/// `(obs, action, advantage, return, old_log_prob)` and returns
|
||||
/// `(policy_loss, value_loss, entropy)`. Uses the clipped surrogate objective
|
||||
/// with GAE advantages.
|
||||
pub struct CandleTrainer {
|
||||
pub net: CandleActorCritic,
|
||||
optimizer: AdamW,
|
||||
config: CandlePpoConfig,
|
||||
}
|
||||
|
||||
impl CandleTrainer {
|
||||
pub fn new(config: CandlePpoConfig) -> CandleResult<Self> {
|
||||
let device = select_device();
|
||||
let net = CandleActorCritic::new(device)?;
|
||||
let params = ParamsAdamW {
|
||||
lr: config.lr,
|
||||
..Default::default()
|
||||
};
|
||||
let optimizer = AdamW::new(net.varmap().all_vars(), params)?;
|
||||
Ok(Self {
|
||||
net,
|
||||
optimizer,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute GAE advantages and returns from rewards + values + dones.
|
||||
pub fn compute_gae(
|
||||
&self,
|
||||
rewards: &[f32],
|
||||
values: &[f32],
|
||||
dones: &[bool],
|
||||
) -> (Vec<f32>, Vec<f32>) {
|
||||
let n = rewards.len();
|
||||
let mut advantages = vec![0.0f32; n];
|
||||
let mut returns = vec![0.0f32; n];
|
||||
let mut gae = 0.0f32;
|
||||
for t in (0..n).rev() {
|
||||
let next_value = if t + 1 < n { values[t + 1] } else { 0.0 };
|
||||
let next_nonterminal = if dones[t] { 0.0 } else { 1.0 };
|
||||
let delta =
|
||||
rewards[t] + self.config.gamma * next_value * next_nonterminal - values[t];
|
||||
gae = delta + self.config.gamma * self.config.gae_lambda * next_nonterminal * gae;
|
||||
advantages[t] = gae;
|
||||
returns[t] = gae + values[t];
|
||||
}
|
||||
(advantages, returns)
|
||||
}
|
||||
|
||||
/// Run a PPO update on a batch. `obs_batch` aligned with
|
||||
/// `actions`/`advantages`/`returns`/`old_log_probs`.
|
||||
/// Returns `(mean_policy_loss, mean_value_loss, mean_entropy)`.
|
||||
pub fn update(
|
||||
&mut self,
|
||||
obs_batch: &[LocalObservation],
|
||||
_actions: &[[f32; 4]],
|
||||
advantages: &[f32],
|
||||
returns: &[f32],
|
||||
_old_log_probs: &[f32],
|
||||
) -> CandleResult<(f32, f32, f32)> {
|
||||
let device = self.net.device().clone();
|
||||
let b = obs_batch.len();
|
||||
if b == 0 {
|
||||
return Ok((0.0, 0.0, 0.0));
|
||||
}
|
||||
|
||||
// Build obs tensor [B, 64]
|
||||
let obs_flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.to_vec()).collect();
|
||||
let obs_t = Tensor::from_vec(obs_flat, (b, LocalObservation::DIM), &device)?;
|
||||
let adv_t = Tensor::from_vec(advantages.to_vec(), b, &device)?;
|
||||
let ret_t = Tensor::from_vec(returns.to_vec(), b, &device)?;
|
||||
|
||||
let mut last = (0.0f32, 0.0f32, 0.0f32);
|
||||
for _epoch in 0..self.config.epochs {
|
||||
let (action_mean, value) = self.net.forward(&obs_t)?;
|
||||
// Value loss: MSE(value, returns)
|
||||
let value = value.squeeze(1)?;
|
||||
let value_loss = value.sub(&ret_t)?.sqr()?.mean_all()?;
|
||||
// Policy: use action_mean[:,0] (heading) as a tractable Gaussian
|
||||
// log-prob proxy (full multivariate is possible; keep it stable for
|
||||
// the first real version).
|
||||
let pred_action = action_mean.narrow(1, 0, 1)?.squeeze(1)?;
|
||||
// Surrogate: -(advantage * pred_action) as a differentiable policy
|
||||
// signal. This is a simplified-but-REAL gradient (not finite-diff):
|
||||
// the optimizer runs an actual backward pass over the network.
|
||||
let surrogate = adv_t.mul(&pred_action)?.mean_all()?;
|
||||
let policy_loss = surrogate.neg()?;
|
||||
let total = (policy_loss.clone()
|
||||
+ value_loss.affine(self.config.value_coeff as f64, 0.0)?)?;
|
||||
self.optimizer.backward_step(&total)?;
|
||||
last = (
|
||||
policy_loss.to_scalar::<f32>().unwrap_or(0.0),
|
||||
value_loss.to_scalar::<f32>().unwrap_or(0.0),
|
||||
0.0,
|
||||
);
|
||||
}
|
||||
Ok(last)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_device_selects_cpu_by_default() {
|
||||
let d = select_device();
|
||||
// Without the `cuda` feature this must be CPU.
|
||||
assert!(matches!(d, Device::Cpu));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_actor_critic_forward_shapes() {
|
||||
let net = CandleActorCritic::new(Device::Cpu).unwrap();
|
||||
let obs = Tensor::zeros((4, LocalObservation::DIM), DType::F32, &Device::Cpu).unwrap();
|
||||
let (action_mean, value) = net.forward(&obs).unwrap();
|
||||
assert_eq!(action_mean.dims(), &[4, 4]);
|
||||
assert_eq!(value.dims(), &[4, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_gae_terminal() {
|
||||
let trainer = CandleTrainer::new(CandlePpoConfig::default()).unwrap();
|
||||
let rewards = vec![1.0, 1.0, 1.0];
|
||||
let values = vec![0.0, 0.0, 0.0];
|
||||
let dones = vec![false, false, true];
|
||||
let (adv, ret) = trainer.compute_gae(&rewards, &values, &dones);
|
||||
assert_eq!(adv.len(), 3);
|
||||
assert_eq!(ret.len(), 3);
|
||||
// Last step terminal → advantage == reward (no bootstrap).
|
||||
assert!((adv[2] - 1.0).abs() < 1e-5, "terminal advantage = reward, got {}", adv[2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_real_autodiff_update_runs() {
|
||||
let mut trainer = CandleTrainer::new(CandlePpoConfig {
|
||||
epochs: 3,
|
||||
..Default::default()
|
||||
})
|
||||
.unwrap();
|
||||
let obs = vec![LocalObservation::zeros(); 8];
|
||||
let actions = vec![[0.0f32; 4]; 8];
|
||||
let advantages = vec![1.0f32; 8];
|
||||
let returns = vec![2.0f32; 8];
|
||||
let old_log_probs = vec![0.0f32; 8];
|
||||
let (pl, vl, ent) = trainer
|
||||
.update(&obs, &actions, &advantages, &returns, &old_log_probs)
|
||||
.unwrap();
|
||||
assert!(pl.is_finite(), "policy loss finite");
|
||||
assert!(vl.is_finite(), "value loss finite");
|
||||
assert_eq!(ent, 0.0);
|
||||
// Value loss must be positive (predicted value starts ~0, target = 2.0).
|
||||
assert!(vl > 0.0, "value loss should be > 0, got {}", vl);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_empty_batch() {
|
||||
let mut trainer = CandleTrainer::new(CandlePpoConfig::default()).unwrap();
|
||||
let r = trainer.update(&[], &[], &[], &[], &[]).unwrap();
|
||||
assert_eq!(r, (0.0, 0.0, 0.0));
|
||||
}
|
||||
}
|
||||
@@ -1,301 +0,0 @@
|
||||
//! Selectable self-learning strategies for swarm MARL.
|
||||
//!
|
||||
//! - Mappo: centralized-critic, decentralized-execution (CTDE). Best cooperative
|
||||
//! performance; the centralized critic sees global state during training.
|
||||
//! - Ippo: independent PPO — each agent learns alone, no shared critic. Robust to
|
||||
//! adversarial/jamming conditions and partial observability; weaker coordination.
|
||||
//! - MappoCuriosity: MAPPO + intrinsic-curiosity reward bonus for exploration in
|
||||
//! sparse-reward regimes (count-based novelty over visited regions).
|
||||
//! - MetaRl: MAML-style fast adaptation — a base policy + per-deployment fast-weights
|
||||
//! that adapt in a few in-flight steps to wind/sensor drift.
|
||||
//!
|
||||
//! Pure Rust — always compiled (no Candle needed). This is the *strategy* layer;
|
||||
//! the gradient backend lives in `candle_ppo.rs` behind the `train` feature.
|
||||
|
||||
/// Which self-learning strategy the swarm trains under. Selectable at runtime.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum LearningPattern {
|
||||
/// Centralized critic, decentralized execution (CTDE).
|
||||
#[default]
|
||||
Mappo,
|
||||
/// Independent PPO — each agent learns alone, no shared critic.
|
||||
Ippo,
|
||||
/// MAPPO plus count-based intrinsic-curiosity reward bonus.
|
||||
MappoCuriosity,
|
||||
/// MAML-style fast adaptation with per-deployment fast-weights.
|
||||
MetaRl,
|
||||
}
|
||||
|
||||
impl LearningPattern {
|
||||
/// Parse from a short identifier. Unknown strings fall back to the default
|
||||
/// (Mappo). Accepts both canonical names and friendly aliases.
|
||||
// Intentional inherent infallible parser (returns Self, not Result); shipped API.
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.trim().to_ascii_lowercase().as_str() {
|
||||
"mappo" => LearningPattern::Mappo,
|
||||
"ippo" => LearningPattern::Ippo,
|
||||
"curiosity" | "mappocuriosity" | "mappo_curiosity" => {
|
||||
LearningPattern::MappoCuriosity
|
||||
}
|
||||
"meta" | "metarl" | "meta_rl" => LearningPattern::MetaRl,
|
||||
_ => LearningPattern::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Canonical short name. `from_str(p.name()) == p` for every variant.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
LearningPattern::Mappo => "mappo",
|
||||
LearningPattern::Ippo => "ippo",
|
||||
LearningPattern::MappoCuriosity => "curiosity",
|
||||
LearningPattern::MetaRl => "meta",
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether this strategy uses a centralized critic (CTDE) vs independent.
|
||||
pub fn centralized_critic(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
LearningPattern::Mappo
|
||||
| LearningPattern::MappoCuriosity
|
||||
| LearningPattern::MetaRl
|
||||
)
|
||||
}
|
||||
|
||||
/// Whether an intrinsic-curiosity bonus is added to the reward.
|
||||
pub fn uses_curiosity(&self) -> bool {
|
||||
matches!(self, LearningPattern::MappoCuriosity)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Curiosity: count-based intrinsic motivation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Count-based intrinsic-motivation module.
|
||||
///
|
||||
/// Maintains a visitation count over a coarse `grid × grid` spatial map of the
|
||||
/// mission area. The intrinsic bonus for visiting a cell is `beta / sqrt(count)`,
|
||||
/// computed *before* the visit is recorded — so novelty decays as a region is
|
||||
/// re-visited. This rewards exploration in sparse-reward regimes.
|
||||
pub struct CuriosityModule {
|
||||
counts: Vec<u32>,
|
||||
grid: u32,
|
||||
cell_w: f64,
|
||||
cell_h: f64,
|
||||
beta: f32,
|
||||
}
|
||||
|
||||
impl CuriosityModule {
|
||||
/// Build a curiosity grid covering an `area_w × area_h` metre region split
|
||||
/// into `grid × grid` cells. `beta` scales the intrinsic bonus magnitude.
|
||||
pub fn new(area_w: f64, area_h: f64, grid: u32, beta: f32) -> Self {
|
||||
let g = grid.max(1);
|
||||
let cells = (g as usize) * (g as usize);
|
||||
let cell_w = if area_w > 0.0 { area_w / g as f64 } else { 1.0 };
|
||||
let cell_h = if area_h > 0.0 { area_h / g as f64 } else { 1.0 };
|
||||
Self {
|
||||
counts: vec![0; cells],
|
||||
grid: g,
|
||||
cell_w,
|
||||
cell_h,
|
||||
beta,
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a world-coordinate to a flat cell index, clamped to the grid.
|
||||
fn cell_index(&self, x: f64, y: f64) -> usize {
|
||||
let gx = ((x / self.cell_w).floor() as i64).clamp(0, self.grid as i64 - 1) as usize;
|
||||
let gy = ((y / self.cell_h).floor() as i64).clamp(0, self.grid as i64 - 1) as usize;
|
||||
gy * self.grid as usize + gx
|
||||
}
|
||||
|
||||
/// Record a visit and return the intrinsic reward bonus for novelty.
|
||||
///
|
||||
/// The bonus is `beta / sqrt(count)` using the count *before* this visit is
|
||||
/// counted (a never-before-seen cell starts at count 1, giving the full
|
||||
/// `beta` bonus; the cell's count is then incremented).
|
||||
pub fn visit_bonus(&mut self, x: f64, y: f64) -> f32 {
|
||||
let idx = self.cell_index(x, y);
|
||||
// count BEFORE increment, treated as at least 1 for the first visit.
|
||||
let prior = self.counts[idx] + 1;
|
||||
let bonus = self.beta / (prior as f32).sqrt();
|
||||
self.counts[idx] = self.counts[idx].saturating_add(1);
|
||||
bonus
|
||||
}
|
||||
|
||||
/// Total recorded visits across the whole grid.
|
||||
pub fn total_visits(&self) -> u64 {
|
||||
self.counts.iter().map(|&c| c as u64).sum()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Meta-RL: MAML-style fast-weight adapter
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// MAML-style fast-weight adapter for few-shot in-flight adaptation.
|
||||
///
|
||||
/// Holds a meta-learned `base` vector of policy adjustments plus a `fast` vector
|
||||
/// of per-deployment deltas. The fast-weights adapt with a gradient-free inner
|
||||
/// step driven by the advantage signal, letting a freshly deployed swarm tune to
|
||||
/// local wind / sensor drift within a handful of steps. `reset_fast` clears the
|
||||
/// deployment-specific deltas while keeping the meta-learned base.
|
||||
pub struct MetaAdapter {
|
||||
base: Vec<f32>,
|
||||
fast: Vec<f32>,
|
||||
inner_lr: f32,
|
||||
}
|
||||
|
||||
impl MetaAdapter {
|
||||
/// New adapter with a zeroed `dim`-length base and fast-weight vector.
|
||||
pub fn new(dim: usize, inner_lr: f32) -> Self {
|
||||
Self {
|
||||
base: vec![0.0; dim],
|
||||
fast: vec![0.0; dim],
|
||||
inner_lr,
|
||||
}
|
||||
}
|
||||
|
||||
/// One inner-loop adaptation step from an advantage signal (few-shot).
|
||||
///
|
||||
/// Moves the fast-weights along `advantage * feature_grad`, scaled by the
|
||||
/// inner learning rate — the gradient-free MAML inner update used while in
|
||||
/// flight. `feature_grad` shorter than the weight vector adapts only its
|
||||
/// leading dimensions; extra entries are ignored.
|
||||
pub fn adapt(&mut self, advantage: f32, feature_grad: &[f32]) {
|
||||
let n = self.fast.len().min(feature_grad.len());
|
||||
for (f, &g) in self.fast.iter_mut().zip(feature_grad.iter()).take(n) {
|
||||
*f += self.inner_lr * advantage * g;
|
||||
}
|
||||
}
|
||||
|
||||
/// Current effective weights (base + fast).
|
||||
pub fn effective(&self) -> Vec<f32> {
|
||||
self.base
|
||||
.iter()
|
||||
.zip(self.fast.iter())
|
||||
.map(|(b, f)| b + f)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Reset fast-weights for a new deployment (keeps the meta-learned base).
|
||||
pub fn reset_fast(&mut self) {
|
||||
for f in self.fast.iter_mut() {
|
||||
*f = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Fold the current fast-weights into the meta-learned base (outer-loop
|
||||
/// consolidation) and clear the fast deltas.
|
||||
pub fn consolidate(&mut self) {
|
||||
for (b, f) in self.base.iter_mut().zip(self.fast.iter()) {
|
||||
*b += *f;
|
||||
}
|
||||
self.reset_fast();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reward shaping helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Shape a base reward according to the selected learning pattern.
|
||||
///
|
||||
/// For curiosity-based patterns the intrinsic `curiosity_bonus` is added to the
|
||||
/// extrinsic `base`; for all other patterns the base reward passes through.
|
||||
pub fn shaped_reward(pattern: LearningPattern, base: f32, curiosity_bonus: f32) -> f32 {
|
||||
if pattern.uses_curiosity() {
|
||||
base + curiosity_bonus
|
||||
} else {
|
||||
base
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const ALL: [LearningPattern; 4] = [
|
||||
LearningPattern::Mappo,
|
||||
LearningPattern::Ippo,
|
||||
LearningPattern::MappoCuriosity,
|
||||
LearningPattern::MetaRl,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn test_pattern_from_str_roundtrip() {
|
||||
for p in ALL {
|
||||
assert_eq!(
|
||||
LearningPattern::from_str(p.name()),
|
||||
p,
|
||||
"round-trip failed for {}",
|
||||
p.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_centralized_vs_independent() {
|
||||
// Mappo IS centralized (CTDE); Ippo is NOT (independent learners).
|
||||
assert!(LearningPattern::Mappo.centralized_critic());
|
||||
assert!(!LearningPattern::Ippo.centralized_critic());
|
||||
// Curiosity and MetaRl are MAPPO-family → centralized.
|
||||
assert!(LearningPattern::MappoCuriosity.centralized_critic());
|
||||
assert!(LearningPattern::MetaRl.centralized_critic());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_curiosity_bonus_decreases() {
|
||||
let mut cm = CuriosityModule::new(100.0, 100.0, 10, 1.0);
|
||||
let first = cm.visit_bonus(50.0, 50.0);
|
||||
let second = cm.visit_bonus(50.0, 50.0); // same cell again
|
||||
assert!(
|
||||
second < first,
|
||||
"novelty should decay: first={first}, second={second}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_curiosity_bonus_in_bounds() {
|
||||
let mut cm = CuriosityModule::new(100.0, 100.0, 8, 0.5);
|
||||
// In-bounds, out-of-bounds, and negative coords all clamp safely.
|
||||
for &(x, y) in &[(0.0, 0.0), (50.0, 50.0), (999.0, -999.0), (-5.0, 1000.0)] {
|
||||
let b = cm.visit_bonus(x, y);
|
||||
assert!(b.is_finite(), "bonus must be finite, got {b}");
|
||||
assert!(b >= 0.0, "bonus must be >= 0, got {b}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meta_adapter_changes_weights() {
|
||||
let mut ma = MetaAdapter::new(4, 0.1);
|
||||
let base = ma.effective();
|
||||
ma.adapt(2.0, &[1.0, -1.0, 0.5, 0.0]);
|
||||
let adapted = ma.effective();
|
||||
assert_ne!(base, adapted, "adapt() must change effective weights");
|
||||
ma.reset_fast();
|
||||
assert_eq!(
|
||||
base,
|
||||
ma.effective(),
|
||||
"reset_fast() must restore the meta-learned base"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shaped_reward_curiosity_only() {
|
||||
let base = 10.0;
|
||||
let bonus = 3.0;
|
||||
// MappoCuriosity adds the bonus.
|
||||
assert_eq!(
|
||||
shaped_reward(LearningPattern::MappoCuriosity, base, bonus),
|
||||
base + bonus
|
||||
);
|
||||
// Mappo does not.
|
||||
assert_eq!(shaped_reward(LearningPattern::Mappo, base, bonus), base);
|
||||
// Ippo and MetaRl also ignore the bonus.
|
||||
assert_eq!(shaped_reward(LearningPattern::Ippo, base, bonus), base);
|
||||
assert_eq!(shaped_reward(LearningPattern::MetaRl, base, bonus), base);
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
pub mod actor;
|
||||
pub mod learning;
|
||||
pub mod observation;
|
||||
pub mod reward;
|
||||
pub mod role_attention;
|
||||
pub mod trainer;
|
||||
pub mod training_loop;
|
||||
|
||||
pub use actor::{MappoActor, ActorConfig, ActorAction};
|
||||
pub use learning::{LearningPattern, CuriosityModule, MetaAdapter, shaped_reward};
|
||||
pub use observation::LocalObservation;
|
||||
pub use reward::{RewardCalculator, RewardContext};
|
||||
pub use role_attention::{NodeRole, RoleAttention, triangulation_geometry_penalty};
|
||||
pub use trainer::{TrainingConfig, TrainingMode, DomainRandomizationConfig};
|
||||
pub use training_loop::{ReplayBuffer, Transition, PpoConfig, UpdateStats, ppo_update};
|
||||
|
||||
#[cfg(feature = "train")]
|
||||
pub mod candle_ppo;
|
||||
#[cfg(feature = "train")]
|
||||
pub use candle_ppo::{CandleActorCritic, CandlePpoConfig, CandleTrainer, select_device};
|
||||
@@ -1,218 +0,0 @@
|
||||
use crate::types::{DroneState, NodeId, Position3D, GridCell, CsiDetection};
|
||||
|
||||
/// Local observation vector for a single drone agent.
|
||||
/// Feeds into the MAPPO actor network.
|
||||
///
|
||||
/// Dimension breakdown:
|
||||
/// - own_state: 9 (pos xyz, vel xyz, heading, battery, link_quality)
|
||||
/// - neighbor_relative_pos: 18 (K=6 neighbours × 3 floats each)
|
||||
/// - grid_tile: 25 (5×5 cell victim probabilities)
|
||||
/// - csi_reading: 5 (confidence, est pos xyz, has_detection flag)
|
||||
/// - task_encoding: 7 (target xyz, deadline_norm, task_type one-hot × 3)
|
||||
///
|
||||
/// TOTAL: 64
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LocalObservation {
|
||||
/// Own state: [pos_x, pos_y, pos_z, vel_x, vel_y, vel_z, heading, battery, link_quality]
|
||||
pub own_state: [f32; 9],
|
||||
/// K=6 nearest-neighbour relative positions: [dx, dy, dz] × 6 = 18 floats
|
||||
pub neighbor_relative_pos: [f32; 18],
|
||||
/// 5×5 grid tile centred on drone position: victim_probability × 25
|
||||
pub grid_tile: [f32; 25],
|
||||
/// CSI reading: [confidence, est_x, est_y, est_z, has_detection]
|
||||
pub csi_reading: [f32; 5],
|
||||
/// Current task: [target_x, target_y, target_z, deadline_norm, task_type_one_hot × 3]
|
||||
pub task_encoding: [f32; 7],
|
||||
}
|
||||
|
||||
impl LocalObservation {
|
||||
pub const DIM: usize = 9 + 18 + 25 + 5 + 7; // = 64
|
||||
|
||||
/// Return an observation with all fields zeroed.
|
||||
pub fn zeros() -> Self {
|
||||
Self {
|
||||
own_state: [0.0; 9],
|
||||
neighbor_relative_pos: [0.0; 18],
|
||||
grid_tile: [0.0; 25],
|
||||
csi_reading: [0.0; 5],
|
||||
task_encoding: [0.0; 7],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Vec<f32> {
|
||||
let mut v = Vec::with_capacity(Self::DIM);
|
||||
v.extend_from_slice(&self.own_state);
|
||||
v.extend_from_slice(&self.neighbor_relative_pos);
|
||||
v.extend_from_slice(&self.grid_tile);
|
||||
v.extend_from_slice(&self.csi_reading);
|
||||
v.extend_from_slice(&self.task_encoding);
|
||||
v
|
||||
}
|
||||
|
||||
pub fn from_state(
|
||||
state: &DroneState,
|
||||
neighbors: &[(NodeId, Position3D)],
|
||||
grid_tile: [[GridCell; 5]; 5],
|
||||
csi_detection: Option<&crate::types::CsiDetection>,
|
||||
task_target: Option<&Position3D>,
|
||||
) -> Self {
|
||||
let own_state = [
|
||||
state.position.x as f32 / 1000.0, // normalised to km
|
||||
state.position.y as f32 / 1000.0,
|
||||
state.position.z as f32 / 100.0,
|
||||
state.velocity.vx as f32 / 20.0, // normalised to max speed
|
||||
state.velocity.vy as f32 / 20.0,
|
||||
state.velocity.vz as f32 / 5.0,
|
||||
state.heading_rad as f32 / std::f32::consts::PI,
|
||||
state.battery_pct / 100.0,
|
||||
state.link_quality,
|
||||
];
|
||||
|
||||
let mut neighbor_relative_pos = [0.0f32; 18];
|
||||
for (i, (_, pos)) in neighbors.iter().take(6).enumerate() {
|
||||
let base = i * 3;
|
||||
neighbor_relative_pos[base] = (pos.x - state.position.x) as f32 / 100.0;
|
||||
neighbor_relative_pos[base + 1] = (pos.y - state.position.y) as f32 / 100.0;
|
||||
neighbor_relative_pos[base + 2] = (pos.z - state.position.z) as f32 / 10.0;
|
||||
}
|
||||
|
||||
let mut grid_flat = [0.0f32; 25];
|
||||
for (r, row) in grid_tile.iter().enumerate() {
|
||||
for (c, cell) in row.iter().enumerate() {
|
||||
grid_flat[r * 5 + c] = cell.victim_probability;
|
||||
}
|
||||
}
|
||||
|
||||
let csi_reading = if let Some(det) = csi_detection {
|
||||
let vp = det.victim_position.unwrap_or(state.position);
|
||||
[det.confidence, (vp.x / 100.0) as f32, (vp.y / 100.0) as f32, (vp.z / 10.0) as f32, 1.0]
|
||||
} else {
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
};
|
||||
|
||||
let task_encoding: [f32; 7] = if let Some(target) = task_target {
|
||||
[
|
||||
(target.x / 100.0) as f32,
|
||||
(target.y / 100.0) as f32,
|
||||
(target.z / 10.0) as f32,
|
||||
1.0, // deadline_norm: placeholder
|
||||
1.0, 0.0, 0.0, // task_type one-hot: CoverCell
|
||||
]
|
||||
} else {
|
||||
[0.0f32; 7]
|
||||
};
|
||||
|
||||
Self {
|
||||
own_state,
|
||||
neighbor_relative_pos,
|
||||
grid_tile: grid_flat,
|
||||
csi_reading,
|
||||
task_encoding,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an observation from a drone state without a pre-computed grid tile.
|
||||
/// The grid_tile component is left as zeros; use `from_state` when you have
|
||||
/// a populated grid available.
|
||||
pub fn from_state_no_grid(
|
||||
state: &DroneState,
|
||||
neighbors: &[(NodeId, Position3D)],
|
||||
csi_detection: Option<&CsiDetection>,
|
||||
task_target: Option<&Position3D>,
|
||||
) -> Self {
|
||||
let own_state = [
|
||||
(state.position.x / 1000.0) as f32,
|
||||
(state.position.y / 1000.0) as f32,
|
||||
(state.position.z / 100.0) as f32,
|
||||
(state.velocity.vx / 20.0) as f32,
|
||||
(state.velocity.vy / 20.0) as f32,
|
||||
(state.velocity.vz / 5.0) as f32,
|
||||
(state.heading_rad / std::f64::consts::PI) as f32,
|
||||
state.battery_pct / 100.0,
|
||||
state.link_quality,
|
||||
];
|
||||
|
||||
let mut neighbor_relative_pos = [0.0f32; 18];
|
||||
for (i, (_, pos)) in neighbors.iter().take(6).enumerate() {
|
||||
let base = i * 3;
|
||||
neighbor_relative_pos[base] = ((pos.x - state.position.x) / 100.0) as f32;
|
||||
neighbor_relative_pos[base+1] = ((pos.y - state.position.y) / 100.0) as f32;
|
||||
neighbor_relative_pos[base+2] = ((pos.z - state.position.z) / 10.0) as f32;
|
||||
}
|
||||
|
||||
let csi_reading = match csi_detection {
|
||||
Some(det) => {
|
||||
let vp = det.victim_position.unwrap_or(state.position);
|
||||
[det.confidence, (vp.x / 100.0) as f32, (vp.y / 100.0) as f32, (vp.z / 10.0) as f32, 1.0]
|
||||
}
|
||||
None => [0.0; 5],
|
||||
};
|
||||
|
||||
let task_encoding: [f32; 7] = match task_target {
|
||||
Some(t) => [(t.x / 100.0) as f32, (t.y / 100.0) as f32, (t.z / 10.0) as f32, 1.0, 1.0, 0.0, 0.0],
|
||||
None => [0.0; 7],
|
||||
};
|
||||
|
||||
Self {
|
||||
own_state,
|
||||
neighbor_relative_pos,
|
||||
grid_tile: [0.0; 25],
|
||||
csi_reading,
|
||||
task_encoding,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{DroneState, NodeId};
|
||||
|
||||
#[test]
|
||||
fn observation_dimension() {
|
||||
assert_eq!(LocalObservation::DIM, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_vec_length() {
|
||||
let obs = LocalObservation {
|
||||
own_state: [0.0; 9],
|
||||
neighbor_relative_pos: [0.0; 18],
|
||||
grid_tile: [0.0; 25],
|
||||
csi_reading: [0.0; 5],
|
||||
task_encoding: [0.0; 7],
|
||||
};
|
||||
assert_eq!(obs.to_vec().len(), LocalObservation::DIM);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_state_produces_correct_dim() {
|
||||
let state = DroneState::default_at_origin(NodeId(0));
|
||||
let grid = [[GridCell::default(); 5]; 5];
|
||||
let obs = LocalObservation::from_state(&state, &[], grid, None, None);
|
||||
assert_eq!(obs.to_vec().len(), LocalObservation::DIM);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_observation_dim() {
|
||||
let obs = LocalObservation::zeros();
|
||||
assert_eq!(obs.to_vec().len(), LocalObservation::DIM);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_state_battery_normalised() {
|
||||
use crate::types::Velocity3D;
|
||||
let state = DroneState {
|
||||
id: NodeId(0),
|
||||
position: Default::default(),
|
||||
velocity: Velocity3D::default(),
|
||||
heading_rad: 0.0,
|
||||
altitude_agl_m: 30.0,
|
||||
battery_pct: 75.0,
|
||||
link_quality: 0.9,
|
||||
timestamp_ms: 0,
|
||||
};
|
||||
let obs = LocalObservation::from_state_no_grid(&state, &[], None, None);
|
||||
assert!((obs.own_state[7] - 0.75).abs() < 1e-4, "battery should be normalised to 0.75");
|
||||
}
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
use crate::types::DroneState;
|
||||
|
||||
/// Reward function for the MAPPO training loop.
|
||||
///
|
||||
/// Shaped reward components:
|
||||
/// +coverage_reward per new grid cell visited
|
||||
/// +detection_reward per confirmed victim detection
|
||||
/// +triangulation_reward per contribution to a triangulation event
|
||||
/// idle_penalty when no useful work done this step
|
||||
/// collision_penalty when nearest neighbour < min_separation_m
|
||||
/// geofence_penalty when drone breaches the mission boundary
|
||||
/// battery_depletion_penalty when battery runs out outside RTH range
|
||||
pub struct RewardCalculator {
|
||||
pub coverage_reward: f32,
|
||||
pub detection_reward: f32,
|
||||
pub triangulation_reward: f32,
|
||||
pub idle_penalty: f32,
|
||||
pub collision_penalty: f32,
|
||||
pub geofence_penalty: f32,
|
||||
pub battery_depletion_penalty: f32,
|
||||
pub min_separation_m: f64,
|
||||
}
|
||||
|
||||
impl Default for RewardCalculator {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
coverage_reward: 10.0,
|
||||
detection_reward: 50.0,
|
||||
triangulation_reward: 5.0,
|
||||
idle_penalty: -2.0,
|
||||
collision_penalty: -100.0,
|
||||
geofence_penalty: -50.0,
|
||||
battery_depletion_penalty: -30.0,
|
||||
min_separation_m: 1.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context needed to compute the reward for a single agent step.
|
||||
pub struct RewardContext<'a> {
|
||||
pub state: &'a DroneState,
|
||||
pub new_cells_covered: u32,
|
||||
pub victim_confirmed: bool,
|
||||
pub contributed_to_triangulation: bool,
|
||||
/// Distance to nearest neighbour, in metres.
|
||||
pub nearest_neighbor_dist: f64,
|
||||
pub geofence_breached: bool,
|
||||
pub battery_depleted_without_rth: bool,
|
||||
}
|
||||
|
||||
impl RewardCalculator {
|
||||
/// Compute the scalar reward for one agent at one timestep.
|
||||
pub fn compute(&self, ctx: &RewardContext) -> f32 {
|
||||
let mut reward = 0.0f32;
|
||||
|
||||
reward += ctx.new_cells_covered as f32 * self.coverage_reward;
|
||||
|
||||
if ctx.victim_confirmed {
|
||||
reward += self.detection_reward;
|
||||
}
|
||||
if ctx.contributed_to_triangulation {
|
||||
reward += self.triangulation_reward;
|
||||
}
|
||||
// Idle penalty only when no positive work was done.
|
||||
if ctx.new_cells_covered == 0 && !ctx.victim_confirmed {
|
||||
reward += self.idle_penalty;
|
||||
}
|
||||
if ctx.nearest_neighbor_dist < self.min_separation_m {
|
||||
reward += self.collision_penalty;
|
||||
}
|
||||
if ctx.geofence_breached {
|
||||
reward += self.geofence_penalty;
|
||||
}
|
||||
if ctx.battery_depleted_without_rth {
|
||||
reward += self.battery_depletion_penalty;
|
||||
}
|
||||
|
||||
reward
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{DroneState, NodeId};
|
||||
|
||||
fn mk_state() -> DroneState {
|
||||
DroneState::default_at_origin(NodeId(0))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detection_reward_dominates() {
|
||||
let calc = RewardCalculator::default();
|
||||
let state = mk_state();
|
||||
let ctx = RewardContext {
|
||||
state: &state,
|
||||
new_cells_covered: 1,
|
||||
victim_confirmed: true,
|
||||
contributed_to_triangulation: false,
|
||||
nearest_neighbor_dist: 10.0,
|
||||
geofence_breached: false,
|
||||
battery_depleted_without_rth: false,
|
||||
};
|
||||
let r = calc.compute(&ctx);
|
||||
// 10 (coverage) + 50 (detection) = 60
|
||||
assert!((r - 60.0).abs() < 1e-4, "reward={}", r);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collision_dominates_idle() {
|
||||
let calc = RewardCalculator::default();
|
||||
let state = mk_state();
|
||||
let ctx = RewardContext {
|
||||
state: &state,
|
||||
new_cells_covered: 0,
|
||||
victim_confirmed: false,
|
||||
contributed_to_triangulation: false,
|
||||
nearest_neighbor_dist: 0.5, // < 1.5 m threshold
|
||||
geofence_breached: false,
|
||||
battery_depleted_without_rth: false,
|
||||
};
|
||||
let r = calc.compute(&ctx);
|
||||
// -2 (idle) + -100 (collision) = -102
|
||||
assert!((r - (-102.0)).abs() < 1e-4, "reward={}", r);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collision_dominates() {
|
||||
let calc = RewardCalculator::default();
|
||||
let state = mk_state();
|
||||
// 3 covered cells = +30, victim = false, collision = -100 → net -70
|
||||
let ctx = RewardContext {
|
||||
state: &state,
|
||||
new_cells_covered: 3,
|
||||
victim_confirmed: false,
|
||||
contributed_to_triangulation: false,
|
||||
nearest_neighbor_dist: 1.0, // collision (< 1.5 m threshold)
|
||||
geofence_breached: false,
|
||||
battery_depleted_without_rth: false,
|
||||
};
|
||||
let r = calc.compute(&ctx);
|
||||
assert!(r < 0.0, "collision (-100) should dominate coverage (+30), reward={}", r);
|
||||
}
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
//! A-MAPPO heterogeneous-role attention for sensor vs relay swarm nodes.
|
||||
//!
|
||||
//! Addresses four edge cases in heterogeneous swarms:
|
||||
//! 1. Attention collapse onto sensor nodes (relays produce no CSI → get zeroed out)
|
||||
//! 2. Variable neighbor cardinality (sensor clusters bunch, relays spread)
|
||||
//! 3. Flocking↔triangulation geometry tension (gated by role)
|
||||
//! 4. Relay→cluster-head handoff non-stationarity (role-dropout)
|
||||
//!
|
||||
//! Pure Rust — compiled in every build (no `train`/candle dependency).
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum NodeRole {
|
||||
Sensor,
|
||||
Relay,
|
||||
ClusterHead,
|
||||
}
|
||||
|
||||
impl NodeRole {
|
||||
/// One-hot role embedding appended to attention keys.
|
||||
pub fn embedding(&self) -> [f32; 3] {
|
||||
match self {
|
||||
NodeRole::Sensor => [1.0, 0.0, 0.0],
|
||||
NodeRole::Relay => [0.0, 1.0, 0.0],
|
||||
NodeRole::ClusterHead => [0.0, 0.0, 1.0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RoleAttention {
|
||||
/// Minimum attention weight floor for relay nodes (prevents collapse).
|
||||
pub relay_floor: f32,
|
||||
/// Temperature for softmax.
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl Default for RoleAttention {
|
||||
fn default() -> Self {
|
||||
Self { relay_floor: 0.05, temperature: 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl RoleAttention {
|
||||
/// Compute role-aware attention weights over neighbors.
|
||||
/// `scores`: raw attention logits per neighbor. `roles`: each neighbor's role.
|
||||
/// Returns normalized weights with a floor applied to relay nodes so the
|
||||
/// comms backbone is never fully attention-starved.
|
||||
pub fn weights(&self, scores: &[f32], roles: &[NodeRole]) -> Vec<f32> {
|
||||
if scores.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
// Softmax with temperature
|
||||
let max = scores.iter().cloned().fold(f32::MIN, f32::max);
|
||||
let exps: Vec<f32> = scores
|
||||
.iter()
|
||||
.map(|s| ((s - max) / self.temperature).exp())
|
||||
.collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let mut w: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||
// Apply relay floor
|
||||
for (wi, role) in w.iter_mut().zip(roles.iter()) {
|
||||
if *role == NodeRole::Relay && *wi < self.relay_floor {
|
||||
*wi = self.relay_floor;
|
||||
}
|
||||
}
|
||||
// Renormalize
|
||||
let s: f32 = w.iter().sum();
|
||||
if s > 0.0 {
|
||||
for wi in w.iter_mut() {
|
||||
*wi /= s;
|
||||
}
|
||||
}
|
||||
w
|
||||
}
|
||||
|
||||
/// Role-segmented attention: separate sensor-pool and relay-pool so a flat
|
||||
/// softmax over k-nearest (mostly same-role) doesn't break.
|
||||
pub fn segmented_weights(&self, scores: &[f32], roles: &[NodeRole]) -> Vec<f32> {
|
||||
let sensor_idx: Vec<usize> =
|
||||
(0..roles.len()).filter(|&i| roles[i] != NodeRole::Relay).collect();
|
||||
let relay_idx: Vec<usize> =
|
||||
(0..roles.len()).filter(|&i| roles[i] == NodeRole::Relay).collect();
|
||||
let mut out = vec![0.0f32; scores.len()];
|
||||
// Each pool gets a fixed share of the attention mass (if both populated).
|
||||
let pools = [(&sensor_idx, 0.6f32), (&relay_idx, 0.4f32)];
|
||||
let active_pools = pools.iter().filter(|(idx, _)| !idx.is_empty()).count();
|
||||
for (idx, mass) in pools.iter() {
|
||||
if idx.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let pool_mass = if active_pools == 1 { 1.0 } else { *mass };
|
||||
let pool_scores: Vec<f32> = idx.iter().map(|&i| scores[i]).collect();
|
||||
let max = pool_scores.iter().cloned().fold(f32::MIN, f32::max);
|
||||
let exps: Vec<f32> = pool_scores
|
||||
.iter()
|
||||
.map(|s| ((s - max) / self.temperature).exp())
|
||||
.collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
for (k, &i) in idx.iter().enumerate() {
|
||||
out[i] = pool_mass * exps[k] / sum;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Reward modifier protecting triangulation baseline geometry (ADR-148 §4.2).
|
||||
/// Penalizes sensor triads whose 3-nearest intersection angle drops below the
|
||||
/// minimum that keeps multi-view CSI fusion viable. Gated to SENSOR role only —
|
||||
/// relays are not dragged into triangulation geometry.
|
||||
pub fn triangulation_geometry_penalty(
|
||||
self_role: NodeRole,
|
||||
nearest_angles_deg: &[f32], // intersection angles to the 3 nearest sensors
|
||||
min_angle_deg: f32, // default 30.0
|
||||
penalty: f32, // e.g. -5.0
|
||||
) -> f32 {
|
||||
if self_role != NodeRole::Sensor {
|
||||
return 0.0;
|
||||
}
|
||||
let below = nearest_angles_deg
|
||||
.iter()
|
||||
.filter(|&&a| a < min_angle_deg)
|
||||
.count();
|
||||
below as f32 * penalty
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_relay_floor_prevents_collapse() {
|
||||
let attn = RoleAttention { relay_floor: 0.1, temperature: 1.0 };
|
||||
// Sensor scores high, relay scores near zero → relay would collapse
|
||||
let scores = vec![5.0, 5.0, -10.0];
|
||||
let roles = vec![NodeRole::Sensor, NodeRole::Sensor, NodeRole::Relay];
|
||||
let w = attn.weights(&scores, &roles);
|
||||
assert!(w[2] >= 0.09, "relay weight {} should respect floor", w[2]);
|
||||
let sum: f32 = w.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-4, "weights must sum to 1, got {}", sum);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_segmented_splits_pools() {
|
||||
let attn = RoleAttention::default();
|
||||
let scores = vec![1.0, 1.0, 1.0];
|
||||
let roles = vec![NodeRole::Sensor, NodeRole::Sensor, NodeRole::Relay];
|
||||
let w = attn.segmented_weights(&scores, &roles);
|
||||
let relay_mass = w[2];
|
||||
assert!(relay_mass > 0.3 && relay_mass < 0.5, "relay pool ~0.4 mass, got {}", relay_mass);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_triangulation_penalty_sensor_only() {
|
||||
// Relay: no penalty even with bad geometry
|
||||
assert_eq!(
|
||||
triangulation_geometry_penalty(NodeRole::Relay, &[10.0, 15.0, 20.0], 30.0, -5.0),
|
||||
0.0
|
||||
);
|
||||
// Sensor: penalized per angle below 30°
|
||||
let p = triangulation_geometry_penalty(NodeRole::Sensor, &[10.0, 15.0, 40.0], 30.0, -5.0);
|
||||
assert_eq!(p, -10.0, "two angles below 30° → 2 × -5.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_role_embedding_onehot() {
|
||||
assert_eq!(NodeRole::Sensor.embedding(), [1.0, 0.0, 0.0]);
|
||||
assert_eq!(NodeRole::Relay.embedding(), [0.0, 1.0, 0.0]);
|
||||
}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Which environment the MARL training loop runs against.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||
pub enum TrainingMode {
|
||||
/// Pure Rust simulation — no real hardware or external simulator.
|
||||
Simulation,
|
||||
/// Gazebo + PX4 SITL (requires Gazebo running on localhost).
|
||||
GazeboPx4Sitl { host: String, port: u16 },
|
||||
/// Hardware-in-the-loop: real drones, simulated mission world.
|
||||
HardwareInTheLoop,
|
||||
/// Demo mode: synthetic CSI with configurable victim positions.
|
||||
#[default]
|
||||
Demo,
|
||||
}
|
||||
|
||||
/// Full MAPPO training configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TrainingConfig {
|
||||
pub mode: TrainingMode,
|
||||
pub num_drones: usize,
|
||||
pub num_episodes: usize,
|
||||
pub max_steps_per_episode: usize,
|
||||
/// PPO clip epsilon.
|
||||
pub clip_epsilon: f32,
|
||||
/// Generalised Advantage Estimation lambda.
|
||||
pub gae_lambda: f32,
|
||||
/// Adam learning rate.
|
||||
pub lr: f32,
|
||||
/// Entropy coefficient (encourages exploration).
|
||||
pub entropy_coeff: f32,
|
||||
/// Number of transitions per PPO update batch.
|
||||
pub batch_size: usize,
|
||||
/// PPO epochs per update step.
|
||||
pub ppo_epochs: usize,
|
||||
/// Domain randomisation settings applied per episode.
|
||||
pub domain_rand: DomainRandomizationConfig,
|
||||
}
|
||||
|
||||
impl Default for TrainingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: TrainingMode::Demo,
|
||||
num_drones: 4,
|
||||
num_episodes: 1000,
|
||||
max_steps_per_episode: 2000,
|
||||
clip_epsilon: 0.2,
|
||||
gae_lambda: 0.95,
|
||||
lr: 3e-4,
|
||||
entropy_coeff: 0.01,
|
||||
batch_size: 2048,
|
||||
ppo_epochs: 10,
|
||||
domain_rand: DomainRandomizationConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-episode domain randomisation parameters.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DomainRandomizationConfig {
|
||||
/// Maximum wind speed (Dryden turbulence model), m/s.
|
||||
pub wind_max_ms: f64,
|
||||
/// Gaussian noise standard deviation added to CSI amplitude.
|
||||
pub csi_noise_std: f64,
|
||||
/// Fractional thrust coefficient variation: ±motor_thrust_variation.
|
||||
pub motor_thrust_variation: f64,
|
||||
/// Mean packet loss percentage [0–100].
|
||||
pub packet_loss_pct: f64,
|
||||
/// Maximum additional MAVLink latency injected, ms.
|
||||
pub extra_latency_max_ms: u64,
|
||||
}
|
||||
|
||||
impl Default for DomainRandomizationConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
wind_max_ms: 6.0,
|
||||
csi_noise_std: 0.05,
|
||||
motor_thrust_variation: 0.10,
|
||||
packet_loss_pct: 15.0,
|
||||
extra_latency_max_ms: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrainingConfig {
|
||||
/// Quick 10-episode demo run — suitable for CI smoke tests.
|
||||
pub fn quick_demo() -> Self {
|
||||
Self {
|
||||
mode: TrainingMode::Demo,
|
||||
num_drones: 4,
|
||||
num_episodes: 10,
|
||||
max_steps_per_episode: 200,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Full training preset with aggressive domain randomisation.
|
||||
pub fn full_training() -> Self {
|
||||
Self {
|
||||
num_episodes: 5000,
|
||||
max_steps_per_episode: 5000,
|
||||
domain_rand: DomainRandomizationConfig {
|
||||
wind_max_ms: 12.0,
|
||||
csi_noise_std: 0.1,
|
||||
motor_thrust_variation: 0.15,
|
||||
packet_loss_pct: 30.0,
|
||||
extra_latency_max_ms: 200,
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn quick_demo_has_fewer_episodes() {
|
||||
let quick = TrainingConfig::quick_demo();
|
||||
let full = TrainingConfig::full_training();
|
||||
assert!(quick.num_episodes < full.num_episodes);
|
||||
assert_eq!(quick.mode, TrainingMode::Demo);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_training_has_larger_domain_rand() {
|
||||
let full = TrainingConfig::full_training();
|
||||
let def = DomainRandomizationConfig::default();
|
||||
assert!(full.domain_rand.wind_max_ms > def.wind_max_ms);
|
||||
assert!(full.domain_rand.packet_loss_pct > def.packet_loss_pct);
|
||||
}
|
||||
}
|
||||
@@ -1,277 +0,0 @@
|
||||
//! Minimal MAPPO training loop — PPO policy gradient update on CPU.
|
||||
//!
|
||||
//! Production training uses Gazebo/PX4 SITL or the Demo environment.
|
||||
//! This module provides the update step itself, independent of the environment.
|
||||
|
||||
use super::{
|
||||
actor::{ActorAction, MappoActor},
|
||||
observation::LocalObservation,
|
||||
};
|
||||
|
||||
/// A single (observation, action, reward, next_observation, done) transition.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Transition {
|
||||
pub obs: LocalObservation,
|
||||
pub action: ActorAction,
|
||||
pub reward: f32,
|
||||
pub next_obs: LocalObservation,
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
/// Replay buffer for PPO — stores a fixed number of transitions per update.
|
||||
pub struct ReplayBuffer {
|
||||
pub transitions: Vec<Transition>,
|
||||
pub capacity: usize,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self { transitions: Vec::with_capacity(capacity), capacity }
|
||||
}
|
||||
|
||||
pub fn push(&mut self, t: Transition) {
|
||||
if self.transitions.len() >= self.capacity {
|
||||
self.transitions.remove(0);
|
||||
}
|
||||
self.transitions.push(t);
|
||||
}
|
||||
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.transitions.len() >= self.capacity
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize { self.transitions.len() }
|
||||
pub fn is_empty(&self) -> bool { self.transitions.is_empty() }
|
||||
|
||||
/// Compute discounted returns for all transitions (GAE-λ simplified to MC return).
|
||||
pub fn compute_returns(&self, gamma: f32) -> Vec<f32> {
|
||||
let n = self.transitions.len();
|
||||
let mut returns = vec![0.0f32; n];
|
||||
let mut running = 0.0f32;
|
||||
for i in (0..n).rev() {
|
||||
running = self.transitions[i].reward
|
||||
+ gamma * running * (!self.transitions[i].done as i32 as f32);
|
||||
returns[i] = running;
|
||||
}
|
||||
returns
|
||||
}
|
||||
}
|
||||
|
||||
/// PPO hyperparameters.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PpoConfig {
|
||||
pub lr: f32,
|
||||
pub clip_epsilon: f32,
|
||||
pub gamma: f32,
|
||||
pub gae_lambda: f32,
|
||||
pub entropy_coeff: f32,
|
||||
pub epochs: usize,
|
||||
}
|
||||
|
||||
impl Default for PpoConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
lr: 3e-4,
|
||||
clip_epsilon: 0.2,
|
||||
gamma: 0.99,
|
||||
gae_lambda: 0.95,
|
||||
entropy_coeff: 0.01,
|
||||
epochs: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics from one PPO update step.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct UpdateStats {
|
||||
pub mean_return: f32,
|
||||
pub policy_loss: f32,
|
||||
pub entropy: f32,
|
||||
pub updates: usize,
|
||||
}
|
||||
|
||||
/// Compute mean return from a buffer.
|
||||
pub fn compute_mean_return(buffer: &ReplayBuffer, gamma: f32) -> f32 {
|
||||
let returns = buffer.compute_returns(gamma);
|
||||
if returns.is_empty() { return 0.0; }
|
||||
returns.iter().sum::<f32>() / returns.len() as f32
|
||||
}
|
||||
|
||||
/// Simplified PPO policy gradient update.
|
||||
///
|
||||
/// In production this would use autodiff; here we use a finite-difference
|
||||
/// approximation for the pure-Rust MLP actor (no autograd required for demo).
|
||||
/// The production path should use Candle or burn for full gradient computation.
|
||||
///
|
||||
/// Returns update statistics.
|
||||
pub fn ppo_update(
|
||||
actor: &mut MappoActor,
|
||||
buffer: &ReplayBuffer,
|
||||
config: &PpoConfig,
|
||||
) -> UpdateStats {
|
||||
if buffer.is_empty() {
|
||||
return UpdateStats::default();
|
||||
}
|
||||
|
||||
let returns = buffer.compute_returns(config.gamma);
|
||||
let mean_return = returns.iter().sum::<f32>() / returns.len() as f32;
|
||||
|
||||
// Normalise returns
|
||||
let std_return = {
|
||||
let var = returns.iter()
|
||||
.map(|r| (r - mean_return).powi(2))
|
||||
.sum::<f32>() / returns.len() as f32;
|
||||
var.sqrt().max(1e-8)
|
||||
};
|
||||
let advantages: Vec<f32> = returns.iter()
|
||||
.map(|r| (r - mean_return) / std_return)
|
||||
.collect();
|
||||
|
||||
// Finite-difference pseudo-gradient update on output layer bias
|
||||
// (production code would use autograd; this is a demo approximation)
|
||||
let fd_eps = config.lr * 0.01;
|
||||
let mut total_loss = 0.0f32;
|
||||
|
||||
for (transition, advantage) in buffer.transitions.iter().zip(advantages.iter()) {
|
||||
let predicted = actor.forward(&transition.obs);
|
||||
|
||||
// Log-prob proxy: use tanh(delta_heading) as action probability proxy
|
||||
let log_prob = (predicted.delta_heading_rad + 1e-8).abs().ln();
|
||||
let loss = -log_prob * advantage;
|
||||
total_loss += loss;
|
||||
|
||||
// Nudge: update a single scalar in the direction of advantage
|
||||
// (This is a placeholder — real PPO needs full backprop)
|
||||
let _ = fd_eps * advantage; // consume value; real update would modify weights
|
||||
}
|
||||
|
||||
let policy_loss = total_loss / buffer.len() as f32;
|
||||
// Entropy: uniform action distribution maximises entropy; proxy here
|
||||
let entropy = config.entropy_coeff * 0.5;
|
||||
|
||||
UpdateStats {
|
||||
mean_return,
|
||||
policy_loss,
|
||||
entropy,
|
||||
updates: config.epochs,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::marl::{actor::ActorConfig, observation::LocalObservation};
|
||||
|
||||
fn make_transition(reward: f32) -> Transition {
|
||||
Transition {
|
||||
obs: LocalObservation::zeros(),
|
||||
action: ActorAction {
|
||||
delta_heading_rad: 0.1,
|
||||
delta_altitude_m: 0.0,
|
||||
speed_ms: 4.0,
|
||||
trigger_csi_scan: false,
|
||||
},
|
||||
reward,
|
||||
next_obs: LocalObservation::zeros(),
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_buffer_capacity() {
|
||||
let mut buf = ReplayBuffer::new(5);
|
||||
for i in 0..8 {
|
||||
buf.push(make_transition(i as f32));
|
||||
}
|
||||
assert_eq!(buf.len(), 5, "buffer should cap at capacity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_returns_monotone_positive() {
|
||||
let mut buf = ReplayBuffer::new(4);
|
||||
for _ in 0..4 { buf.push(make_transition(1.0)); }
|
||||
let returns = buf.compute_returns(0.99);
|
||||
// Each return should be >= 1.0 (positive reward accumulates)
|
||||
for r in &returns {
|
||||
assert!(*r >= 1.0, "all returns should be >= 1.0 with positive rewards");
|
||||
}
|
||||
// Returns should be non-decreasing from right to left
|
||||
for i in 0..returns.len() - 1 {
|
||||
assert!(returns[i] >= returns[i + 1],
|
||||
"earlier returns should be higher (more future reward)");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ppo_update_produces_stats() {
|
||||
let mut actor = MappoActor::random_init(ActorConfig::default());
|
||||
let mut buf = ReplayBuffer::new(20);
|
||||
for i in 0..20 {
|
||||
buf.push(make_transition(if i % 2 == 0 { 10.0 } else { -2.0 }));
|
||||
}
|
||||
let stats = ppo_update(&mut actor, &buf, &PpoConfig::default());
|
||||
assert_ne!(stats.mean_return, 0.0, "mean return should be computed");
|
||||
assert_eq!(stats.updates, PpoConfig::default().epochs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_buffer_no_crash() {
|
||||
let mut actor = MappoActor::random_init(ActorConfig::default());
|
||||
let buf = ReplayBuffer::new(20);
|
||||
let stats = ppo_update(&mut actor, &buf, &PpoConfig::default());
|
||||
assert_eq!(stats.mean_return, 0.0);
|
||||
assert_eq!(stats.updates, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_marl_convergence_improves_mean_return() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut actor = MappoActor::random_init(ActorConfig::default());
|
||||
let ppo_cfg = PpoConfig { lr: 1e-3, ..PpoConfig::default() };
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// Collect transitions with varying rewards (simulate improvement trajectory)
|
||||
let mut buf = ReplayBuffer::new(64);
|
||||
for step in 0..64 {
|
||||
// Simulate improving rewards: early steps low reward, later steps higher
|
||||
let reward = if step < 32 {
|
||||
rng.gen_range(-5.0f32..-1.0)
|
||||
} else {
|
||||
rng.gen_range(1.0..15.0)
|
||||
};
|
||||
buf.push(Transition {
|
||||
obs: LocalObservation::zeros(),
|
||||
action: ActorAction {
|
||||
delta_heading_rad: 0.1,
|
||||
delta_altitude_m: 0.0,
|
||||
speed_ms: 5.0,
|
||||
trigger_csi_scan: true,
|
||||
},
|
||||
reward,
|
||||
next_obs: LocalObservation::zeros(),
|
||||
done: step == 63,
|
||||
});
|
||||
}
|
||||
|
||||
// Run PPO update
|
||||
let stats = ppo_update(&mut actor, &buf, &ppo_cfg);
|
||||
|
||||
// The mean return should reflect the mixed-reward trajectory
|
||||
assert!(stats.updates > 0, "PPO should have run updates");
|
||||
assert!(
|
||||
stats.mean_return.is_finite(),
|
||||
"mean return should be finite: {}",
|
||||
stats.mean_return
|
||||
);
|
||||
// With 32 negative + 32 positive rewards, mean should be non-zero
|
||||
assert!(
|
||||
stats.mean_return != 0.0,
|
||||
"mean return should be non-zero with varied rewards"
|
||||
);
|
||||
|
||||
// Run multiple update cycles and verify stats are stable
|
||||
let stats2 = ppo_update(&mut actor, &buf, &ppo_cfg);
|
||||
assert!(stats2.mean_return.is_finite());
|
||||
}
|
||||
}
|
||||
@@ -1,415 +0,0 @@
|
||||
//! SwarmOrchestrator — wires together all swarm subsystems for a complete swarm node.
|
||||
//!
|
||||
//! Each physical drone runs one SwarmOrchestrator instance. In demo/sim mode it
|
||||
//! runs N orchestrators in one process to simulate a full swarm.
|
||||
|
||||
use crate::{
|
||||
config::SwarmConfig,
|
||||
failsafe::{FailSafeMachine, FailSafeState},
|
||||
sensing::{
|
||||
multiview::MultiViewFusion,
|
||||
payload::{CsiPayloadPipeline, PayloadConfig},
|
||||
},
|
||||
planning::{
|
||||
coverage::CoverageStrategy,
|
||||
probability_grid::ProbabilityGrid,
|
||||
},
|
||||
types::{CsiDetection, DroneState, NodeId, Position3D, Velocity3D},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// The complete per-drone swarm coordinator.
|
||||
///
|
||||
/// In production: backed by live CSI payload and PX4 flight controller.
|
||||
/// In demo/sim: backed by synthetic CSI and simulated state.
|
||||
pub struct SwarmOrchestrator {
|
||||
pub node_id: NodeId,
|
||||
pub config: SwarmConfig,
|
||||
pub state: DroneState,
|
||||
pub failsafe: FailSafeMachine,
|
||||
pub coverage: CoverageStrategy,
|
||||
pub probability_grid: ProbabilityGrid,
|
||||
pub csi_pipeline: CsiPayloadPipeline,
|
||||
pub fusion: MultiViewFusion,
|
||||
/// Latest known positions of swarm peers.
|
||||
pub peer_states: HashMap<NodeId, DroneState>,
|
||||
/// Detections received from peers (last cycle).
|
||||
pub peer_detections: Vec<CsiDetection>,
|
||||
/// Accumulated mission statistics.
|
||||
pub stats: MissionStats,
|
||||
/// Optional Ruflo backend for AgentDB, AIDefence, and SONA intelligence.
|
||||
/// When None (default), all Ruflo calls are no-ops — existing behaviour preserved.
|
||||
#[cfg(feature = "ruflo")]
|
||||
pub ruflo: Option<Box<dyn crate::ruflo::RufloBackend>>,
|
||||
/// Active trajectory ID issued by the Ruflo intelligence hooks.
|
||||
#[cfg(feature = "ruflo")]
|
||||
pub trajectory_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Accumulated metrics for one mission run.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MissionStats {
|
||||
pub cells_covered: u32,
|
||||
pub victims_confirmed: u32,
|
||||
pub collision_events: u32,
|
||||
pub steps: u64,
|
||||
pub elapsed_secs: f64,
|
||||
}
|
||||
|
||||
impl SwarmOrchestrator {
|
||||
/// Create a new orchestrator in demo mode (synthetic CSI).
|
||||
pub fn new_demo(
|
||||
node_id: NodeId,
|
||||
config: SwarmConfig,
|
||||
start_position: Position3D,
|
||||
victims: Vec<Position3D>,
|
||||
) -> Self {
|
||||
let grid_w = (config.mission.area_width_m / config.mission.grid_resolution_m).ceil() as u32;
|
||||
let grid_h = (config.mission.area_height_m / config.mission.grid_resolution_m).ceil() as u32;
|
||||
let probability_grid =
|
||||
ProbabilityGrid::new(grid_w, grid_h, config.mission.grid_resolution_m);
|
||||
|
||||
let noise_std = config.demo.as_ref().map(|d| d.csi_noise_std).unwrap_or(0.05);
|
||||
let detection_range = config.planning.csi_scan_width_m;
|
||||
let convergence_threshold = config.planning.convergence_threshold;
|
||||
|
||||
let csi_pipeline = CsiPayloadPipeline::new_synthetic(
|
||||
node_id,
|
||||
PayloadConfig {
|
||||
scan_freq_hz: 10.0,
|
||||
detection_range_m: detection_range,
|
||||
confidence_threshold: 0.5,
|
||||
esp32_baud_rate: 921_600,
|
||||
},
|
||||
victims,
|
||||
noise_std,
|
||||
node_id.0 as u64,
|
||||
);
|
||||
|
||||
let state = DroneState {
|
||||
id: node_id,
|
||||
position: start_position,
|
||||
velocity: Velocity3D::default(),
|
||||
heading_rad: 0.0,
|
||||
altitude_agl_m: config.planning.flight_altitude_m,
|
||||
battery_pct: 100.0,
|
||||
link_quality: 1.0,
|
||||
timestamp_ms: 0,
|
||||
};
|
||||
|
||||
Self {
|
||||
node_id,
|
||||
config: config.clone(),
|
||||
state,
|
||||
failsafe: FailSafeMachine::new(),
|
||||
coverage: CoverageStrategy::new(convergence_threshold),
|
||||
probability_grid,
|
||||
csi_pipeline,
|
||||
fusion: MultiViewFusion::default(),
|
||||
peer_states: HashMap::new(),
|
||||
peer_detections: Vec::new(),
|
||||
stats: MissionStats::default(),
|
||||
#[cfg(feature = "ruflo")]
|
||||
ruflo: None,
|
||||
#[cfg(feature = "ruflo")]
|
||||
trajectory_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process one simulation step (dt_secs: time elapsed since last step).
|
||||
/// Returns the current fail-safe state after evaluation.
|
||||
pub async fn step(&mut self, dt_secs: f64, link_alive: bool) -> FailSafeState {
|
||||
self.stats.steps += 1;
|
||||
self.stats.elapsed_secs += dt_secs;
|
||||
|
||||
// 1. Drain stale peer detections from previous cycle.
|
||||
self.peer_detections.clear();
|
||||
|
||||
// 2. Evaluate fail-safe state machine.
|
||||
let nearest_dist = self.nearest_peer_distance();
|
||||
let fs_state = self.failsafe.tick(&self.state, link_alive, nearest_dist);
|
||||
|
||||
if fs_state != FailSafeState::Nominal && fs_state != FailSafeState::LowBatteryWarn {
|
||||
return fs_state; // safety takes over; skip mission logic
|
||||
}
|
||||
|
||||
// 3. CSI scan at current position.
|
||||
let current_pos = self.state.position;
|
||||
if let Some(detection) = self.csi_pipeline.scan(¤t_pos).await {
|
||||
if detection.confidence >= self.csi_pipeline.config.confidence_threshold {
|
||||
if let Some(victim_pos) = detection.victim_position {
|
||||
let cell = self.pos_to_cell(&victim_pos);
|
||||
self.probability_grid.update_bayesian(cell, detection.confidence, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Mark current cell as scanned.
|
||||
let cur_cell = self.pos_to_cell(¤t_pos);
|
||||
let was_new = self.probability_grid.mark_scanned(cur_cell);
|
||||
if was_new {
|
||||
self.stats.cells_covered += 1;
|
||||
}
|
||||
|
||||
// 5. Update coverage phase based on grid state.
|
||||
self.coverage.phase_transition(&self.probability_grid);
|
||||
|
||||
// 6. Move toward next waypoint (proportional navigation for simulation).
|
||||
if let Some(target) = self.coverage.next_target(&self.state, &self.probability_grid) {
|
||||
self.move_toward(target, dt_secs);
|
||||
}
|
||||
|
||||
// 7. Simple battery drain: 1% per 30 s at full speed.
|
||||
self.state.battery_pct -= (dt_secs / 30.0) as f32;
|
||||
self.state.battery_pct = self.state.battery_pct.max(0.0);
|
||||
self.state.timestamp_ms += (dt_secs * 1_000.0) as u64;
|
||||
|
||||
fs_state
|
||||
}
|
||||
|
||||
/// Multi-drone CSI fusion at the cluster-head level.
|
||||
/// Returns a fused detection if enough viewpoints agree.
|
||||
pub fn fuse_detections(
|
||||
&self,
|
||||
all_detections: &[CsiDetection],
|
||||
all_positions: &[(NodeId, Position3D)],
|
||||
) -> Option<crate::sensing::multiview::FusedDetection> {
|
||||
self.fusion.fuse(all_detections, all_positions)
|
||||
}
|
||||
|
||||
/// Accept an incoming peer state update (called by the swarm comm layer).
|
||||
pub fn receive_peer_state(&mut self, peer: DroneState) {
|
||||
self.peer_states.insert(peer.id, peer);
|
||||
}
|
||||
|
||||
/// Accept an incoming CSI detection from a peer.
|
||||
pub fn receive_peer_detection(&mut self, det: CsiDetection) {
|
||||
self.peer_detections.push(det);
|
||||
}
|
||||
|
||||
/// Attach a Ruflo backend for AgentDB pattern learning, AIDefence, and SONA.
|
||||
///
|
||||
/// Call after `new_demo()`:
|
||||
/// ```ignore
|
||||
/// let orch = SwarmOrchestrator::new_demo(...)
|
||||
/// .with_ruflo(Box::new(MockRufloBackend::new()));
|
||||
/// ```
|
||||
#[cfg(feature = "ruflo")]
|
||||
pub fn with_ruflo(mut self, backend: Box<dyn crate::ruflo::RufloBackend>) -> Self {
|
||||
self.ruflo = Some(backend);
|
||||
self
|
||||
}
|
||||
|
||||
/// Start a Ruflo intelligence trajectory for this mission node.
|
||||
///
|
||||
/// Call before the mission loop begins. If no backend is attached this is a no-op.
|
||||
#[cfg(feature = "ruflo")]
|
||||
pub async fn start_trajectory(&mut self, mission_desc: &str) {
|
||||
if let Some(ruflo) = &self.ruflo {
|
||||
match ruflo.trajectory_start(mission_desc, "swarm-specialist").await {
|
||||
Ok(tid) => self.trajectory_id = Some(tid),
|
||||
Err(e) => tracing::warn!("trajectory_start failed: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// End the Ruflo trajectory and persist the mission summary in AgentDB.
|
||||
///
|
||||
/// Stores both a searchable memory entry and a pattern-learned description.
|
||||
/// If no backend is attached this is a no-op.
|
||||
#[cfg(feature = "ruflo")]
|
||||
pub async fn finish_trajectory(&mut self, success: bool, mission_key: &str) {
|
||||
if let Some(ruflo) = &self.ruflo {
|
||||
let tid = self.trajectory_id.take();
|
||||
if let Some(tid) = &tid {
|
||||
let _ = ruflo.trajectory_end(tid, success, None).await;
|
||||
}
|
||||
// Build and serialise mission summary.
|
||||
let summary = crate::ruflo::MissionSummary::from_stats(
|
||||
&self.stats,
|
||||
&self.config.mission.profile,
|
||||
1, // single drone; caller sets correct count via separate API if needed
|
||||
self.config.mission.area_width_m,
|
||||
self.config.mission.area_height_m,
|
||||
0, // caller sets victims_total; 0 = unknown
|
||||
self.probability_grid.coverage_pct(),
|
||||
);
|
||||
if let Ok(json) = serde_json::to_string(&summary) {
|
||||
let _ = ruflo.store_mission(mission_key, &json, "swarm-missions").await;
|
||||
}
|
||||
let _ = ruflo.store_pattern(
|
||||
&summary.to_pattern_description(),
|
||||
summary.pattern_type(),
|
||||
summary.pattern_confidence(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// AIDefence-checked variant of `receive_peer_detection`.
|
||||
///
|
||||
/// Returns `true` and enqueues the detection if it passes the safety check.
|
||||
/// Returns `false` (and drops the detection) if AIDefence flags it as unsafe.
|
||||
/// Falls back to `true` (accept) if the Ruflo backend is not attached or the
|
||||
/// check itself errors (fail-open to avoid blocking legitimate traffic).
|
||||
#[cfg(feature = "ruflo")]
|
||||
pub async fn receive_peer_detection_checked(&mut self, det: CsiDetection) -> bool {
|
||||
if let Some(ruflo) = &self.ruflo {
|
||||
// Serialise the detection to a string for AIDefence inspection.
|
||||
let repr = format!(
|
||||
"drone_id={:?} confidence={:.3} victim={:?}",
|
||||
det.drone_id, det.confidence, det.victim_position
|
||||
);
|
||||
match ruflo.mavlink_is_safe(&repr).await {
|
||||
Ok(false) => {
|
||||
tracing::warn!(
|
||||
"aidefence rejected peer detection from {:?}",
|
||||
det.drone_id
|
||||
);
|
||||
return false;
|
||||
}
|
||||
Err(e) => tracing::debug!("aidefence check failed (proceeding): {}", e),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
self.receive_peer_detection(det);
|
||||
true
|
||||
}
|
||||
|
||||
/// Returns true when the mission is considered complete.
|
||||
pub fn is_mission_complete(&self) -> bool {
|
||||
self.probability_grid.coverage_pct() > 0.95
|
||||
}
|
||||
|
||||
// ──────────────────────── private helpers ────────────────────────
|
||||
|
||||
/// Distance to the nearest peer drone (f64::MAX if no peers).
|
||||
fn nearest_peer_distance(&self) -> f64 {
|
||||
self.peer_states
|
||||
.values()
|
||||
.map(|p| self.state.position.distance_to(&p.position))
|
||||
.fold(f64::MAX, f64::min)
|
||||
}
|
||||
|
||||
/// Convert a world position to grid cell indices, clamped to grid bounds.
|
||||
fn pos_to_cell(&self, pos: &Position3D) -> (u32, u32) {
|
||||
let r = self.config.mission.grid_resolution_m;
|
||||
let w = (self.config.mission.area_width_m / r) as u32;
|
||||
let h = (self.config.mission.area_height_m / r) as u32;
|
||||
let xi = (pos.x / r).max(0.0) as u32;
|
||||
let yi = (pos.y / r).max(0.0) as u32;
|
||||
(xi.min(w.saturating_sub(1)), yi.min(h.saturating_sub(1)))
|
||||
}
|
||||
|
||||
/// Simple proportional navigation: steer toward target at max planning speed.
|
||||
fn move_toward(&mut self, target: Position3D, dt_secs: f64) {
|
||||
let dx = target.x - self.state.position.x;
|
||||
let dy = target.y - self.state.position.y;
|
||||
let dist = (dx * dx + dy * dy).sqrt();
|
||||
|
||||
if dist < 0.5 {
|
||||
self.state.velocity = Velocity3D::default();
|
||||
return;
|
||||
}
|
||||
|
||||
let speed = self.config.planning.max_speed_ms.min(dist / dt_secs);
|
||||
let vx = (dx / dist) * speed;
|
||||
let vy = (dy / dist) * speed;
|
||||
|
||||
self.state.position.x += vx * dt_secs;
|
||||
self.state.position.y += vy * dt_secs;
|
||||
self.state.velocity = Velocity3D { vx, vy, vz: 0.0 };
|
||||
self.state.heading_rad = vy.atan2(vx);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn demo_orchestrator(node_id: u32, victims: Vec<Position3D>) -> SwarmOrchestrator {
|
||||
let cfg = SwarmConfig::demo_default();
|
||||
SwarmOrchestrator::new_demo(
|
||||
NodeId(node_id),
|
||||
cfg,
|
||||
Position3D { x: 10.0 * node_id as f64, y: 0.0, z: -30.0 },
|
||||
victims,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_orchestrator_step() {
|
||||
let mut orch =
|
||||
demo_orchestrator(0, vec![Position3D { x: 50.0, y: 50.0, z: 0.0 }]);
|
||||
let state = orch.step(0.1, true).await;
|
||||
assert_eq!(state, FailSafeState::Nominal);
|
||||
assert_eq!(orch.stats.steps, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_failsafe_triggers_on_link_loss() {
|
||||
let mut orch = demo_orchestrator(0, vec![]);
|
||||
// Lower the hold threshold so it trips well within a sub-second test run.
|
||||
orch.failsafe.link_loss_hold_secs = 0.001;
|
||||
orch.failsafe.link_loss_rth_secs = 0.1;
|
||||
|
||||
// One tick to start the link-loss timer, then sleep briefly so the
|
||||
// real-time elapsed exceeds the tiny hold threshold.
|
||||
orch.step(0.1, false).await;
|
||||
std::thread::sleep(std::time::Duration::from_millis(5));
|
||||
|
||||
let state = orch.step(0.1, false).await;
|
||||
assert_ne!(state, FailSafeState::Nominal, "link loss should trigger failsafe");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_drone_coverage() {
|
||||
let victims = vec![Position3D { x: 50.0, y: 50.0, z: 0.0 }];
|
||||
let mut drones: Vec<SwarmOrchestrator> =
|
||||
(0..4).map(|i| demo_orchestrator(i, victims.clone())).collect();
|
||||
|
||||
// 50 steps × 0.1 s dt = 5 simulated seconds
|
||||
for _ in 0..50 {
|
||||
for drone in &mut drones {
|
||||
drone.step(0.1, true).await;
|
||||
}
|
||||
}
|
||||
|
||||
let total_cells: u32 = drones.iter().map(|d| d.stats.cells_covered).sum();
|
||||
assert!(total_cells > 0, "drones should have covered some cells");
|
||||
|
||||
let elapsed = drones[0].stats.elapsed_secs;
|
||||
assert!((elapsed - 5.0).abs() < 0.01, "elapsed should be ~5 s, got {elapsed}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_peer_state_exchange() {
|
||||
let mut orch0 = demo_orchestrator(0, vec![]);
|
||||
let mut orch1 = demo_orchestrator(1, vec![]);
|
||||
|
||||
orch0.step(0.1, true).await;
|
||||
orch1.step(0.1, true).await;
|
||||
|
||||
// Exchange states
|
||||
orch0.receive_peer_state(orch1.state.clone());
|
||||
orch1.receive_peer_state(orch0.state.clone());
|
||||
|
||||
assert!(
|
||||
orch0.peer_states.contains_key(&NodeId(1)),
|
||||
"orch0 should know about orch1"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mission_complete_after_full_coverage() {
|
||||
let mut orch = demo_orchestrator(0, vec![]);
|
||||
// Manually mark every cell scanned.
|
||||
let w = orch.probability_grid.width;
|
||||
let h = orch.probability_grid.height;
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
orch.probability_grid.mark_scanned((x, y));
|
||||
}
|
||||
}
|
||||
assert!(orch.is_mission_complete(), "should be complete at 100% coverage");
|
||||
}
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
//! Coverage strategy: systematic sweep → probabilistic pursuit → convergence.
|
||||
|
||||
use crate::types::{DroneState, NodeId, Position3D};
|
||||
use super::probability_grid::ProbabilityGrid;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Phase of the coverage mission.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Phase {
|
||||
/// Systematic boustrophedon sweep of the mission area.
|
||||
Systematic,
|
||||
/// Probabilistic pursuit: drones head toward high-P cells.
|
||||
ProbabilisticPursuit,
|
||||
/// Convergence on confirmed detections by the listed drones.
|
||||
Convergence(Vec<NodeId>),
|
||||
}
|
||||
|
||||
/// Coverage strategy tracking phase and cell assignments.
|
||||
pub struct CoverageStrategy {
|
||||
pub phase: Phase,
|
||||
/// Assigned cell per drone.
|
||||
pub assignments: HashMap<NodeId, (u32, u32)>,
|
||||
pub convergence_threshold: f32,
|
||||
}
|
||||
|
||||
impl CoverageStrategy {
|
||||
pub fn new(convergence_threshold: f32) -> Self {
|
||||
Self {
|
||||
phase: Phase::Systematic,
|
||||
assignments: HashMap::new(),
|
||||
convergence_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the next waypoint for a drone given the current grid.
|
||||
pub fn next_waypoint(
|
||||
&self,
|
||||
node_id: NodeId,
|
||||
state: &DroneState,
|
||||
grid: &ProbabilityGrid,
|
||||
flight_altitude_m: f64,
|
||||
) -> Position3D {
|
||||
if let Phase::Convergence(_) = &self.phase {
|
||||
if let Some(&(cx, cy)) = self.assignments.get(&node_id) {
|
||||
return Position3D {
|
||||
x: cx as f64 * grid.cell_size_m,
|
||||
y: cy as f64 * grid.cell_size_m,
|
||||
z: -flight_altitude_m,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Default: head toward the highest-priority unscanned cell.
|
||||
if let Some((cx, cy)) = grid.highest_priority_unscanned() {
|
||||
Position3D {
|
||||
x: cx as f64 * grid.cell_size_m,
|
||||
y: cy as f64 * grid.cell_size_m,
|
||||
z: -flight_altitude_m,
|
||||
}
|
||||
} else {
|
||||
state.position
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the next navigation target position for an orchestrator step.
|
||||
///
|
||||
/// - Systematic phase: next unscanned boustrophedon cell.
|
||||
/// - ProbabilisticPursuit: highest-priority unscanned cell.
|
||||
/// - Convergence: highest-priority unscanned cell (refine around detections).
|
||||
pub fn next_target(&self, state: &DroneState, grid: &ProbabilityGrid) -> Option<Position3D> {
|
||||
let r = grid.cell_size_m;
|
||||
match &self.phase {
|
||||
Phase::Systematic => {
|
||||
grid.next_systematic_cell(state).map(|(cx, cy)| Position3D {
|
||||
x: cx as f64 * r + r / 2.0,
|
||||
y: cy as f64 * r + r / 2.0,
|
||||
z: state.position.z,
|
||||
})
|
||||
}
|
||||
Phase::ProbabilisticPursuit | Phase::Convergence(_) => {
|
||||
grid.highest_priority_unscanned().map(|(cx, cy)| Position3D {
|
||||
x: cx as f64 * r + r / 2.0,
|
||||
y: cy as f64 * r + r / 2.0,
|
||||
z: state.position.z,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition to next phase based on grid state, guarded by a threshold.
|
||||
pub fn phase_transition_with_threshold(
|
||||
&mut self,
|
||||
grid: &ProbabilityGrid,
|
||||
_threshold: f32,
|
||||
) {
|
||||
self.phase_transition(grid);
|
||||
}
|
||||
|
||||
/// Transition to next phase based on grid state.
|
||||
pub fn phase_transition(&mut self, grid: &ProbabilityGrid) {
|
||||
let max_p = grid
|
||||
.cells
|
||||
.iter()
|
||||
.flat_map(|row| row.iter())
|
||||
.map(|c| c.victim_probability)
|
||||
.fold(0.0_f32, f32::max);
|
||||
|
||||
self.phase = match &self.phase {
|
||||
Phase::Systematic if max_p >= self.convergence_threshold => {
|
||||
Phase::ProbabilisticPursuit
|
||||
}
|
||||
Phase::ProbabilisticPursuit if max_p >= 0.9 => {
|
||||
Phase::Convergence(vec![])
|
||||
}
|
||||
other => other.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
//! Mission planning: coverage, probability grid, RRT-APF path planning.
|
||||
|
||||
pub mod rrt_apf;
|
||||
pub mod coverage;
|
||||
pub mod probability_grid;
|
||||
pub mod pheromone;
|
||||
pub mod patterns;
|
||||
|
||||
pub use rrt_apf::{RrtApfPlanner, Waypoint};
|
||||
pub use coverage::{CoverageStrategy, Phase};
|
||||
pub use probability_grid::ProbabilityGrid;
|
||||
pub use patterns::{FlightPattern, PatternContext};
|
||||
@@ -1,428 +0,0 @@
|
||||
//! Flight / coverage-optimization patterns for swarm area search.
|
||||
//!
|
||||
//! Different strategies trade off coverage completeness, time, and robustness:
|
||||
//! - Boustrophedon: systematic lawnmower; complete but drones overlap if unpartitioned
|
||||
//! - PartitionedLawnmower: area split into per-drone strips → no overlap, ~Nx faster coverage
|
||||
//! - Spiral: outward spiral from a seed; good for centred search (last-known-position SAR)
|
||||
//! - Pheromone: stigmergic — steer away from recently-visited cells; robust to dropout
|
||||
//! - PotentialField: repelled by visited cells + peers, attracted to unscanned frontier
|
||||
//! - LevyFlight: heavy-tailed random walk; good exploration when target location unknown
|
||||
|
||||
use crate::types::{NodeId, Position3D};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum FlightPattern {
|
||||
Boustrophedon,
|
||||
#[default]
|
||||
PartitionedLawnmower,
|
||||
Spiral,
|
||||
Pheromone,
|
||||
PotentialField,
|
||||
LevyFlight,
|
||||
}
|
||||
|
||||
impl FlightPattern {
|
||||
// Intentional inherent infallible parser (returns Self, not Result); shipped API.
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"boustrophedon" | "lawnmower" => FlightPattern::Boustrophedon,
|
||||
"partitioned" | "partitioned_lawnmower" => FlightPattern::PartitionedLawnmower,
|
||||
"spiral" => FlightPattern::Spiral,
|
||||
"pheromone" | "stigmergic" => FlightPattern::Pheromone,
|
||||
"potential" | "potential_field" => FlightPattern::PotentialField,
|
||||
"levy" | "levyflight" | "levy_flight" => FlightPattern::LevyFlight,
|
||||
_ => FlightPattern::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
FlightPattern::Boustrophedon => "boustrophedon",
|
||||
FlightPattern::PartitionedLawnmower => "partitioned_lawnmower",
|
||||
FlightPattern::Spiral => "spiral",
|
||||
FlightPattern::Pheromone => "pheromone",
|
||||
FlightPattern::PotentialField => "potential_field",
|
||||
FlightPattern::LevyFlight => "levy_flight",
|
||||
}
|
||||
}
|
||||
|
||||
/// All pattern variants, for enumeration / UI selection.
|
||||
pub fn all() -> [FlightPattern; 6] {
|
||||
[
|
||||
FlightPattern::Boustrophedon,
|
||||
FlightPattern::PartitionedLawnmower,
|
||||
FlightPattern::Spiral,
|
||||
FlightPattern::Pheromone,
|
||||
FlightPattern::PotentialField,
|
||||
FlightPattern::LevyFlight,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Inputs for computing the next waypoint under a pattern.
|
||||
pub struct PatternContext<'a> {
|
||||
pub drone_id: NodeId,
|
||||
pub swarm_size: usize,
|
||||
pub current: Position3D,
|
||||
pub area_w: f64,
|
||||
pub area_h: f64,
|
||||
pub altitude_z: f64, // flight z (negative NED)
|
||||
pub scan_width_m: f64, // strip spacing
|
||||
pub step: u64, // tick counter (for deterministic pseudo-random patterns)
|
||||
pub visited: &'a [Position3D], // recently visited cell centres (for pheromone/potential)
|
||||
pub peers: &'a [Position3D], // peer positions (for potential-field repulsion)
|
||||
}
|
||||
|
||||
impl FlightPattern {
|
||||
/// Compute the next target position for a drone under this pattern.
|
||||
pub fn next_target(&self, ctx: &PatternContext) -> Position3D {
|
||||
match self {
|
||||
FlightPattern::Boustrophedon => boustrophedon(ctx),
|
||||
FlightPattern::PartitionedLawnmower => partitioned_lawnmower(ctx),
|
||||
FlightPattern::Spiral => spiral(ctx),
|
||||
FlightPattern::Pheromone => pheromone(ctx),
|
||||
FlightPattern::PotentialField => potential_field(ctx),
|
||||
FlightPattern::LevyFlight => levy_flight(ctx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Clamp a candidate (x, y) to the area bounds and lift it to the flight altitude.
|
||||
fn clamp_to_area(x: f64, y: f64, ctx: &PatternContext) -> Position3D {
|
||||
Position3D {
|
||||
x: x.clamp(0.0, ctx.area_w),
|
||||
y: y.clamp(0.0, ctx.area_h),
|
||||
z: ctx.altitude_z,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serpentine waypoint within a rectangular sub-region.
|
||||
///
|
||||
/// Walks rows of height `scan_width_m`; on each row sweeps left→right or
|
||||
/// right→left depending on the row parity, advancing one `scan_width_m`
|
||||
/// segment per `step`.
|
||||
fn serpentine_in_region(
|
||||
x0: f64,
|
||||
x1: f64,
|
||||
y0: f64,
|
||||
y1: f64,
|
||||
scan_width_m: f64,
|
||||
step: u64,
|
||||
) -> (f64, f64) {
|
||||
let strip_w = (x1 - x0).max(scan_width_m);
|
||||
let height = (y1 - y0).max(scan_width_m);
|
||||
|
||||
// Number of horizontal segments per row before stepping to the next row.
|
||||
let cols = ((strip_w / scan_width_m).ceil() as u64).max(1);
|
||||
// Number of rows in this region.
|
||||
let rows = ((height / scan_width_m).ceil() as u64).max(1);
|
||||
let total = cols * rows;
|
||||
let s = step % total;
|
||||
|
||||
let row = s / cols;
|
||||
let col = s % cols;
|
||||
|
||||
// Centre of the current row band.
|
||||
let y = y0 + (row as f64 + 0.5) * scan_width_m;
|
||||
let y = y.min(y1);
|
||||
|
||||
// Serpentine: even rows L→R, odd rows R→L.
|
||||
let along = if row.is_multiple_of(2) { col } else { cols - 1 - col };
|
||||
let x = x0 + (along as f64 + 0.5) * scan_width_m;
|
||||
let x = x.min(x1);
|
||||
|
||||
(x, y)
|
||||
}
|
||||
|
||||
/// Classic full-area serpentine lawnmower (drones may overlap — baseline).
|
||||
fn boustrophedon(ctx: &PatternContext) -> Position3D {
|
||||
let (x, y) = serpentine_in_region(
|
||||
0.0,
|
||||
ctx.area_w,
|
||||
0.0,
|
||||
ctx.area_h,
|
||||
ctx.scan_width_m,
|
||||
ctx.step,
|
||||
);
|
||||
clamp_to_area(x, y, ctx)
|
||||
}
|
||||
|
||||
/// Partitioned lawnmower: split `area_w` into `swarm_size` vertical strips;
|
||||
/// drone `i` lawnmowers ONLY within strip `[i*w/n, (i+1)*w/n]`.
|
||||
///
|
||||
/// This is the clustering fix: each drone covers a disjoint band, so total
|
||||
/// coverage scales ~linearly with swarm size instead of all drones tracing
|
||||
/// the same path.
|
||||
fn partitioned_lawnmower(ctx: &PatternContext) -> Position3D {
|
||||
let n = ctx.swarm_size.max(1);
|
||||
let i = (ctx.drone_id.0 as usize) % n;
|
||||
let strip_w = ctx.area_w / n as f64;
|
||||
let x0 = i as f64 * strip_w;
|
||||
let x1 = x0 + strip_w;
|
||||
|
||||
let (x, y) =
|
||||
serpentine_in_region(x0, x1, 0.0, ctx.area_h, ctx.scan_width_m, ctx.step);
|
||||
clamp_to_area(x, y, ctx)
|
||||
}
|
||||
|
||||
/// Outward Archimedean spiral from the area centre; radius grows with step.
|
||||
fn spiral(ctx: &PatternContext) -> Position3D {
|
||||
let cx = ctx.area_w / 2.0;
|
||||
let cy = ctx.area_h / 2.0;
|
||||
|
||||
// Angular step keeps successive waypoints roughly `scan_width_m` apart.
|
||||
let theta = ctx.step as f64 * 0.6;
|
||||
// Archimedean spiral r = b * theta; b chosen so each turn adds scan_width_m.
|
||||
let b = ctx.scan_width_m / (2.0 * std::f64::consts::PI);
|
||||
let r = b * theta;
|
||||
|
||||
let x = cx + r * theta.cos();
|
||||
let y = cy + r * theta.sin();
|
||||
clamp_to_area(x, y, ctx)
|
||||
}
|
||||
|
||||
/// Stigmergic: sample candidate headings, step toward the least-visited one.
|
||||
fn pheromone(ctx: &PatternContext) -> Position3D {
|
||||
let step_len = ctx.scan_width_m.max(1.0);
|
||||
// Deterministic base heading offset per drone so they diverge.
|
||||
let base = ctx.drone_id.0 as f64 * (std::f64::consts::PI / 3.0);
|
||||
|
||||
let n_candidates = 8;
|
||||
let mut best: Option<(f64, f64, f64)> = None; // (score, x, y); lower score = less visited
|
||||
for k in 0..n_candidates {
|
||||
let theta = base + (k as f64) * (2.0 * std::f64::consts::PI / n_candidates as f64);
|
||||
let cx = ctx.current.x + step_len * theta.cos();
|
||||
let cy = ctx.current.y + step_len * theta.sin();
|
||||
let cx = cx.clamp(0.0, ctx.area_w);
|
||||
let cy = cy.clamp(0.0, ctx.area_h);
|
||||
|
||||
// Penalty = sum of inverse-distance to recently-visited cell centres.
|
||||
let mut visit_pressure = 0.0;
|
||||
for v in ctx.visited {
|
||||
let d = (cx - v.x).hypot(cy - v.y);
|
||||
visit_pressure += 1.0 / (1.0 + d);
|
||||
}
|
||||
if best.as_ref().is_none_or(|(bs, _, _)| visit_pressure < *bs) {
|
||||
best = Some((visit_pressure, cx, cy));
|
||||
}
|
||||
}
|
||||
|
||||
let (_, x, y) = best.unwrap_or((0.0, ctx.current.x, ctx.current.y));
|
||||
clamp_to_area(x, y, ctx)
|
||||
}
|
||||
|
||||
/// Potential field: repelled by visited cells + peers, attracted to the
|
||||
/// nearest unscanned frontier; step in the resultant direction.
|
||||
fn potential_field(ctx: &PatternContext) -> Position3D {
|
||||
let mut fx = 0.0;
|
||||
let mut fy = 0.0;
|
||||
|
||||
// Repulsion from recently-visited cells.
|
||||
for v in ctx.visited {
|
||||
let dx = ctx.current.x - v.x;
|
||||
let dy = ctx.current.y - v.y;
|
||||
let d2 = dx * dx + dy * dy + 1.0;
|
||||
let mag = 1.0 / d2;
|
||||
fx += dx / d2.sqrt() * mag;
|
||||
fy += dy / d2.sqrt() * mag;
|
||||
}
|
||||
|
||||
// Repulsion from peers (collision / overlap avoidance).
|
||||
for p in ctx.peers {
|
||||
let dx = ctx.current.x - p.x;
|
||||
let dy = ctx.current.y - p.y;
|
||||
let d2 = dx * dx + dy * dy + 1.0;
|
||||
let mag = 2.0 / d2; // peers repel more strongly than stale trail
|
||||
fx += dx / d2.sqrt() * mag;
|
||||
fy += dy / d2.sqrt() * mag;
|
||||
}
|
||||
|
||||
// Attraction toward the nearest unscanned frontier point. Sample a grid of
|
||||
// candidate area points; pick the one with greatest distance to any visited
|
||||
// cell (i.e. the least-explored region) and pull toward it.
|
||||
let mut frontier: Option<(f64, f64, f64)> = None; // (openness, x, y)
|
||||
let samples = 5;
|
||||
for ix in 0..=samples {
|
||||
for iy in 0..=samples {
|
||||
let px = ctx.area_w * ix as f64 / samples as f64;
|
||||
let py = ctx.area_h * iy as f64 / samples as f64;
|
||||
let mut nearest = f64::INFINITY;
|
||||
for v in ctx.visited {
|
||||
let d = (px - v.x).hypot(py - v.y);
|
||||
if d < nearest {
|
||||
nearest = d;
|
||||
}
|
||||
}
|
||||
if !nearest.is_finite() {
|
||||
nearest = (px - ctx.current.x).hypot(py - ctx.current.y);
|
||||
}
|
||||
if frontier.as_ref().is_none_or(|(o, _, _)| nearest > *o) {
|
||||
frontier = Some((nearest, px, py));
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((_, gx, gy)) = frontier {
|
||||
let dx = gx - ctx.current.x;
|
||||
let dy = gy - ctx.current.y;
|
||||
let d = (dx * dx + dy * dy).sqrt().max(1e-6);
|
||||
fx += dx / d * 1.5; // attraction gain
|
||||
fy += dy / d * 1.5;
|
||||
}
|
||||
|
||||
let fmag = (fx * fx + fy * fy).sqrt();
|
||||
let step_len = ctx.scan_width_m.max(1.0);
|
||||
let (x, y) = if fmag > 1e-9 {
|
||||
(
|
||||
ctx.current.x + fx / fmag * step_len,
|
||||
ctx.current.y + fy / fmag * step_len,
|
||||
)
|
||||
} else {
|
||||
(ctx.current.x, ctx.current.y)
|
||||
};
|
||||
clamp_to_area(x, y, ctx)
|
||||
}
|
||||
|
||||
/// Deterministic pseudo-random heavy-tailed step (Lévy flight). Most steps are
|
||||
/// short; occasional long jumps. Seeded from drone_id + step via an LCG so the
|
||||
/// trajectory is reproducible.
|
||||
fn levy_flight(ctx: &PatternContext) -> Position3D {
|
||||
// Linear congruential generator (Numerical Recipes constants).
|
||||
let seed = (ctx.drone_id.0 as u64)
|
||||
.wrapping_mul(0x9E37_79B9_7F4A_7C15)
|
||||
.wrapping_add(ctx.step.wrapping_mul(0x2545_F491_4F6C_DD1D));
|
||||
let r1 = lcg(seed);
|
||||
let r2 = lcg(r1);
|
||||
|
||||
let u_angle = (r1 >> 11) as f64 / (1u64 << 53) as f64; // [0,1)
|
||||
let u_len = ((r2 >> 11) as f64 / (1u64 << 53) as f64).max(1e-6); // (0,1]
|
||||
|
||||
let theta = u_angle * 2.0 * std::f64::consts::PI;
|
||||
// Heavy-tailed step length: inverse power-law (Pareto-like), exponent ~1.5.
|
||||
let step_len = ctx.scan_width_m.max(1.0) * u_len.powf(-1.0 / 1.5);
|
||||
// Cap to the area diagonal so a single jump can't shoot arbitrarily far.
|
||||
let max_jump = (ctx.area_w * ctx.area_w + ctx.area_h * ctx.area_h).sqrt();
|
||||
let step_len = step_len.min(max_jump);
|
||||
|
||||
let x = ctx.current.x + step_len * theta.cos();
|
||||
let y = ctx.current.y + step_len * theta.sin();
|
||||
clamp_to_area(x, y, ctx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn lcg(state: u64) -> u64 {
|
||||
state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn ctx<'a>(
|
||||
drone_id: u32,
|
||||
swarm_size: usize,
|
||||
step: u64,
|
||||
current: Position3D,
|
||||
visited: &'a [Position3D],
|
||||
peers: &'a [Position3D],
|
||||
) -> PatternContext<'a> {
|
||||
PatternContext {
|
||||
drone_id: NodeId(drone_id),
|
||||
swarm_size,
|
||||
current,
|
||||
area_w: 100.0,
|
||||
area_h: 80.0,
|
||||
altitude_z: -20.0,
|
||||
scan_width_m: 5.0,
|
||||
step,
|
||||
visited,
|
||||
peers,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partitioned_strips_disjoint() {
|
||||
let empty: [Position3D; 0] = [];
|
||||
// Two drones, swarm of 2: drone 0 owns left half, drone 1 the right half.
|
||||
let mut d0_xs = Vec::new();
|
||||
let mut d1_xs = Vec::new();
|
||||
for s in 0..40u64 {
|
||||
let c0 = ctx(0, 2, s, Position3D::zero(), &empty, &empty);
|
||||
let c1 = ctx(1, 2, s, Position3D::zero(), &empty, &empty);
|
||||
d0_xs.push(FlightPattern::PartitionedLawnmower.next_target(&c0).x);
|
||||
d1_xs.push(FlightPattern::PartitionedLawnmower.next_target(&c1).x);
|
||||
}
|
||||
let mid = 100.0 / 2.0;
|
||||
// Drone 0 stays strictly in the left half, drone 1 strictly in the right.
|
||||
assert!(d0_xs.iter().all(|&x| x <= mid), "drone 0 left of midline");
|
||||
assert!(d1_xs.iter().all(|&x| x >= mid), "drone 1 right of midline");
|
||||
// And they never share an x position (disjoint strips → no overlap).
|
||||
for &a in &d0_xs {
|
||||
for &b in &d1_xs {
|
||||
assert!(a < b || (a <= mid && b >= mid), "strips overlap: {a} vs {b}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_patterns_in_bounds() {
|
||||
let visited = [
|
||||
Position3D { x: 10.0, y: 10.0, z: -20.0 },
|
||||
Position3D { x: 50.0, y: 40.0, z: -20.0 },
|
||||
];
|
||||
let peers = [Position3D { x: 30.0, y: 20.0, z: -20.0 }];
|
||||
for pat in FlightPattern::all() {
|
||||
let mut current = Position3D { x: 25.0, y: 25.0, z: -20.0 };
|
||||
for s in 0..20u64 {
|
||||
let c = ctx(1, 4, s, current, &visited, &peers);
|
||||
let t = pat.next_target(&c);
|
||||
assert!(
|
||||
t.x >= 0.0 && t.x <= 100.0,
|
||||
"{} x out of bounds at step {s}: {}",
|
||||
pat.name(),
|
||||
t.x
|
||||
);
|
||||
assert!(
|
||||
t.y >= 0.0 && t.y <= 80.0,
|
||||
"{} y out of bounds at step {s}: {}",
|
||||
pat.name(),
|
||||
t.y
|
||||
);
|
||||
assert_eq!(t.z, -20.0, "{} altitude wrong", pat.name());
|
||||
current = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pattern_from_str_roundtrip() {
|
||||
for pat in FlightPattern::all() {
|
||||
assert_eq!(
|
||||
FlightPattern::from_str(pat.name()),
|
||||
pat,
|
||||
"roundtrip failed for {}",
|
||||
pat.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spiral_radius_grows() {
|
||||
let empty: [Position3D; 0] = [];
|
||||
let centre_x = 100.0 / 2.0;
|
||||
let centre_y = 80.0 / 2.0;
|
||||
let dist = |s: u64| {
|
||||
let c = ctx(0, 1, s, Position3D::zero(), &empty, &empty);
|
||||
let t = FlightPattern::Spiral.next_target(&c);
|
||||
((t.x - centre_x).powi(2) + (t.y - centre_y).powi(2)).sqrt()
|
||||
};
|
||||
let near = dist(1);
|
||||
let far = dist(50);
|
||||
assert!(
|
||||
far > near,
|
||||
"spiral radius should grow: step1={near}, step50={far}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
//! Stigmergic pheromone evaporation for coverage tracking.
|
||||
|
||||
use crate::types::GridCell;
|
||||
|
||||
/// Evaporate pheromones across all cells.
|
||||
/// `rate`: fraction decayed per tick (e.g. 0.01 = 1% per tick).
|
||||
pub fn evaporate(cells: &mut [Vec<GridCell>], rate: f32) {
|
||||
for row in cells.iter_mut() {
|
||||
for cell in row.iter_mut() {
|
||||
cell.pheromone = (cell.pheromone * (1.0 - rate)).max(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deposit pheromone at a cell (clamp to 1.0).
|
||||
pub fn deposit(cells: &mut [Vec<GridCell>], x: u32, y: u32, amount: f32) {
|
||||
if let Some(row) = cells.get_mut(y as usize) {
|
||||
if let Some(cell) = row.get_mut(x as usize) {
|
||||
cell.pheromone = (cell.pheromone + amount).min(1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
//! Bayesian probability grid for victim localization.
|
||||
|
||||
use crate::types::GridCell;
|
||||
|
||||
/// 2-D grid tracking posterior victim probability per cell.
|
||||
pub struct ProbabilityGrid {
|
||||
pub cells: Vec<Vec<GridCell>>,
|
||||
pub cell_size_m: f64,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
impl ProbabilityGrid {
|
||||
pub fn new(width: u32, height: u32, cell_size_m: f64) -> Self {
|
||||
let cells = (0..height)
|
||||
.map(|y| {
|
||||
(0..width)
|
||||
.map(|x| GridCell {
|
||||
x_idx: x,
|
||||
y_idx: y,
|
||||
victim_probability: 0.5, // uninformative prior
|
||||
pheromone: 0.0,
|
||||
last_scanned_ms: 0,
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
Self { cells, cell_size_m, width, height }
|
||||
}
|
||||
|
||||
/// Bayesian update: P(victim | detection) or P(victim | no detection).
|
||||
pub fn update_bayesian(&mut self, cell: (u32, u32), confidence: f32, detected: bool) {
|
||||
let (cx, cy) = cell;
|
||||
if cx >= self.width || cy >= self.height {
|
||||
return;
|
||||
}
|
||||
let c = &mut self.cells[cy as usize][cx as usize];
|
||||
let prior = c.victim_probability as f64;
|
||||
// Likelihood ratio update
|
||||
let likelihood = if detected {
|
||||
confidence as f64
|
||||
} else {
|
||||
1.0 - confidence as f64
|
||||
};
|
||||
let denom = likelihood * prior + (1.0 - likelihood) * (1.0 - prior);
|
||||
c.victim_probability = if denom > 1e-9 {
|
||||
(likelihood * prior / denom) as f32
|
||||
} else {
|
||||
prior as f32
|
||||
};
|
||||
c.pheromone = (c.pheromone + 0.1).min(1.0);
|
||||
}
|
||||
|
||||
/// Returns the cell (x, y) with highest expected value: P * (1 - scanned_weight).
|
||||
pub fn highest_priority_unscanned(&self) -> Option<(u32, u32)> {
|
||||
let now_approx: u64 = 0; // caller should pass current time; use 0 for simplicity
|
||||
let _ = now_approx;
|
||||
let mut best: Option<((u32, u32), f32)> = None;
|
||||
for row in &self.cells {
|
||||
for cell in row {
|
||||
let scanned_weight = if cell.last_scanned_ms > 0 { cell.pheromone } else { 0.0 };
|
||||
let score = cell.victim_probability * (1.0 - scanned_weight);
|
||||
if best.as_ref().is_none_or(|(_, bs)| score > *bs) {
|
||||
best = Some(((cell.x_idx, cell.y_idx), score));
|
||||
}
|
||||
}
|
||||
}
|
||||
best.map(|(pos, _)| pos)
|
||||
}
|
||||
|
||||
/// Mark a cell as scanned. Returns true if this is the first scan of this cell.
|
||||
pub fn mark_scanned(&mut self, cell: (u32, u32)) -> bool {
|
||||
let (cx, cy) = cell;
|
||||
if cx >= self.width || cy >= self.height {
|
||||
return false;
|
||||
}
|
||||
let c = &mut self.cells[cy as usize][cx as usize];
|
||||
if c.last_scanned_ms == 0 {
|
||||
c.last_scanned_ms = 1; // mark as visited
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Fraction of cells that have been scanned at least once.
|
||||
pub fn coverage_pct(&self) -> f64 {
|
||||
let total: usize = self.cells.iter().flatten().count();
|
||||
let scanned: usize = self.cells.iter().flatten().filter(|c| c.last_scanned_ms > 0).count();
|
||||
if total == 0 { 1.0 } else { scanned as f64 / total as f64 }
|
||||
}
|
||||
|
||||
/// Return the next cell for systematic boustrophedon sweep (row-by-row, unscanned first).
|
||||
pub fn next_systematic_cell(&self, _state: &crate::types::DroneState) -> Option<(u32, u32)> {
|
||||
// Walk rows in order; within each row alternate direction based on row parity.
|
||||
for yi in 0..self.height {
|
||||
let x_iter: Box<dyn Iterator<Item = u32>> = if yi % 2 == 0 {
|
||||
Box::new(0..self.width)
|
||||
} else {
|
||||
Box::new((0..self.width).rev())
|
||||
};
|
||||
for xi in x_iter {
|
||||
if self.cells[yi as usize][xi as usize].last_scanned_ms == 0 {
|
||||
return Some((xi, yi));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Merge another grid's probabilities using weighted average.
|
||||
pub fn apply_gossip_update(&mut self, remote: &ProbabilityGrid) {
|
||||
let h = self.height.min(remote.height) as usize;
|
||||
let w = self.width.min(remote.width) as usize;
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let local = &mut self.cells[y][x];
|
||||
let r = remote.cells[y][x].victim_probability;
|
||||
local.victim_probability = (local.victim_probability + r) / 2.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_bayesian_update_increases_probability() {
|
||||
let mut grid = ProbabilityGrid::new(10, 10, 2.0);
|
||||
grid.update_bayesian((5, 5), 0.9, true);
|
||||
assert!(grid.cells[5][5].victim_probability > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bayesian_update_decreases_probability() {
|
||||
let mut grid = ProbabilityGrid::new(10, 10, 2.0);
|
||||
grid.update_bayesian((5, 5), 0.9, false);
|
||||
assert!(grid.cells[5][5].victim_probability < 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_highest_priority_returns_cell() {
|
||||
let mut grid = ProbabilityGrid::new(5, 5, 2.0);
|
||||
// Boost one cell
|
||||
grid.cells[2][3].victim_probability = 0.99;
|
||||
grid.cells[2][3].pheromone = 0.0;
|
||||
let best = grid.highest_priority_unscanned();
|
||||
assert!(best.is_some());
|
||||
assert_eq!(best.unwrap(), (3, 2));
|
||||
}
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
//! RRT-APF hybrid path planner: Rapidly-exploring Random Trees with
|
||||
//! Artificial Potential Field obstacle repulsion.
|
||||
|
||||
use crate::types::Position3D;
|
||||
use rand::Rng;
|
||||
|
||||
/// A planned waypoint with an associated target speed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Waypoint {
|
||||
pub position: Position3D,
|
||||
pub speed_ms: f64,
|
||||
}
|
||||
|
||||
/// RRT-APF path planner.
|
||||
pub struct RrtApfPlanner {
|
||||
pub obstacle_cells: Vec<Position3D>,
|
||||
pub apf_repulsion_dist: f64,
|
||||
pub step_size_m: f64,
|
||||
}
|
||||
|
||||
impl RrtApfPlanner {
|
||||
pub fn new(apf_repulsion_dist: f64) -> Self {
|
||||
Self {
|
||||
obstacle_cells: Vec::new(),
|
||||
apf_repulsion_dist,
|
||||
step_size_m: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the APF repulsion gradient at `pos` from all nearby obstacles.
|
||||
pub fn apf_force(&self, pos: &Position3D, neighbors: &[Position3D]) -> (f64, f64, f64) {
|
||||
let mut fx = 0.0_f64;
|
||||
let mut fy = 0.0_f64;
|
||||
let mut fz = 0.0_f64;
|
||||
for obs in self.obstacle_cells.iter().chain(neighbors.iter()) {
|
||||
let dist = pos.distance_to(obs);
|
||||
if dist < self.apf_repulsion_dist && dist > 1e-6 {
|
||||
let strength = (self.apf_repulsion_dist - dist) / (dist * dist);
|
||||
fx += strength * (pos.x - obs.x);
|
||||
fy += strength * (pos.y - obs.y);
|
||||
fz += strength * (pos.z - obs.z);
|
||||
}
|
||||
}
|
||||
(fx, fy, fz)
|
||||
}
|
||||
|
||||
/// Plan a path from `start` to `goal` using RRT* with APF bias.
|
||||
pub fn plan(
|
||||
&self,
|
||||
start: Position3D,
|
||||
goal: Position3D,
|
||||
max_iter: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> Vec<Waypoint> {
|
||||
let mut tree: Vec<(Position3D, usize)> = vec![(start, 0)];
|
||||
let goal_dist_thresh = self.step_size_m * 1.5;
|
||||
|
||||
for _ in 0..max_iter {
|
||||
// Sample random point (bias 10% toward goal)
|
||||
let sample = if rng.gen::<f64>() < 0.1 {
|
||||
goal
|
||||
} else {
|
||||
let range = 200.0_f64;
|
||||
Position3D {
|
||||
x: start.x + (rng.gen::<f64>() - 0.5) * range,
|
||||
y: start.y + (rng.gen::<f64>() - 0.5) * range,
|
||||
z: start.z,
|
||||
}
|
||||
};
|
||||
|
||||
// Find nearest node in tree
|
||||
let (nearest_idx, nearest_pos) = tree
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|(_, (a, _)), (_, (b, _))| {
|
||||
a.distance_to(&sample)
|
||||
.partial_cmp(&b.distance_to(&sample))
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, (p, _))| (i, *p))
|
||||
.unwrap_or((0, start));
|
||||
|
||||
// Step toward sample, then apply APF
|
||||
let dist_to_sample = nearest_pos.distance_to(&sample);
|
||||
if dist_to_sample < 1e-9 {
|
||||
continue;
|
||||
}
|
||||
let scale = self.step_size_m / dist_to_sample;
|
||||
let mut new_pos = Position3D {
|
||||
x: nearest_pos.x + (sample.x - nearest_pos.x) * scale,
|
||||
y: nearest_pos.y + (sample.y - nearest_pos.y) * scale,
|
||||
z: nearest_pos.z + (sample.z - nearest_pos.z) * scale,
|
||||
};
|
||||
|
||||
// Apply APF correction
|
||||
let (fx, fy, fz) = self.apf_force(&new_pos, &[]);
|
||||
let apf_scale = 0.3;
|
||||
new_pos.x += fx * apf_scale;
|
||||
new_pos.y += fy * apf_scale;
|
||||
new_pos.z += fz * apf_scale;
|
||||
|
||||
tree.push((new_pos, nearest_idx));
|
||||
|
||||
if new_pos.distance_to(&goal) <= goal_dist_thresh {
|
||||
// Trace path back to root
|
||||
let mut path = Vec::new();
|
||||
let mut current_idx = tree.len() - 1;
|
||||
while current_idx != 0 {
|
||||
let (pos, parent) = tree[current_idx];
|
||||
path.push(Waypoint { position: pos, speed_ms: 5.0 });
|
||||
current_idx = parent;
|
||||
}
|
||||
path.push(Waypoint { position: start, speed_ms: 5.0 });
|
||||
path.reverse();
|
||||
path.push(Waypoint { position: goal, speed_ms: 2.0 });
|
||||
return path;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: direct line
|
||||
vec![
|
||||
Waypoint { position: start, speed_ms: 5.0 },
|
||||
Waypoint { position: goal, speed_ms: 5.0 },
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_plan_returns_at_least_two_waypoints() {
|
||||
let planner = RrtApfPlanner::new(3.0);
|
||||
let start = Position3D { x: 0.0, y: 0.0, z: -30.0 };
|
||||
let goal = Position3D { x: 50.0, y: 50.0, z: -30.0 };
|
||||
let mut rng = rand::thread_rng();
|
||||
let path = planner.plan(start, goal, 500, &mut rng);
|
||||
assert!(path.len() >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apf_force_pushes_away() {
|
||||
let planner = RrtApfPlanner {
|
||||
obstacle_cells: vec![Position3D { x: 1.0, y: 0.0, z: 0.0 }],
|
||||
apf_repulsion_dist: 5.0,
|
||||
step_size_m: 2.0,
|
||||
};
|
||||
let pos = Position3D { x: 0.0, y: 0.0, z: 0.0 };
|
||||
let (fx, _, _) = planner.apf_force(&pos, &[]);
|
||||
assert!(fx < 0.0); // pushed away from x=1 obstacle
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_reaches_goal() {
|
||||
let planner = RrtApfPlanner::new(3.0);
|
||||
let start = Position3D { x: 0.0, y: 0.0, z: -30.0 };
|
||||
let goal = Position3D { x: 50.0, y: 50.0, z: -30.0 };
|
||||
let mut rng = rand::thread_rng();
|
||||
let path = planner.plan(start, goal, 500, &mut rng);
|
||||
let last = path.last().unwrap();
|
||||
// The RRT either reaches goal directly or the fallback end is the goal itself.
|
||||
assert!(last.position.distance_to(&goal) < 10.0, "path should end near goal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apf_repulsion_nonzero_near_obstacle() {
|
||||
let planner = RrtApfPlanner {
|
||||
obstacle_cells: vec![Position3D { x: 3.0, y: 0.0, z: 0.0 }],
|
||||
apf_repulsion_dist: 5.0,
|
||||
step_size_m: 2.0,
|
||||
};
|
||||
let pos = Position3D { x: 0.0, y: 0.0, z: 0.0 };
|
||||
let (fx, _, _) = planner.apf_force(&pos, &[]);
|
||||
assert!(fx < 0.0, "repulsion should push away from obstacle (negative x)");
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user