Compare commits

...

4 Commits

Author SHA1 Message Date
rUv 004a63e82d fix(security): audit — fix RUSTSEC vulns, clippy warnings, dead code (#769)
- Upgrade openssl to 0.10.78 (CVE-2026-41676), jsonwebtoken to 9.4
- Suppress unmaintained-only/no-CVE advisories in .cargo/audit.toml
  with per-entry rationale
- Fix all `cargo clippy --all-targets -- -D warnings` errors across
  35 crates: derivable_impls, needless_range_loop, map_or→is_some_and/
  is_none_or, await_holding_lock (drop MutexGuard before .await),
  ptr_arg (&mut Vec→&mut [T]), useless_conversion, approximate_constant
  (2.718→E, 3.14→PI), field_reassign_with_default, manual_inspect,
  useless_vec, lines_filter_map_ok, print_literal, dead_code
- Apply `cargo fmt --all`
- Pre-existing test failure in wifi-densepose-signal
  (test_estimate_occupancy_noise_only) is not introduced by this PR
2026-05-23 05:36:13 -04:00
OrbisAI Security 1906876541 fix: upgrade openssl to 0.10.78 (CVE-2026-41676) (#751)
* fix: CVE-2026-41676 security vulnerability

Automated dependency upgrade by OrbisAI Security

* fix: upgrade openssl to 0.10.78 (CVE-2026-41676)

rust-openssl provides OpenSSL bindings for the Rust programming langua
Resolves CVE-2026-41676
2026-05-23 03:31:03 -04:00
ruv 423dc9fd5c docs(readme): add Cognitum creator affiliate program reference
Brief callout for TikTok/Instagram/YouTube creators — 25% commission,
instant click-tracking, ~24h manual review. Links to cognitum.one/affiliate.

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-05-23 01:06:18 -04:00
rUv 68abb385ae docs(readme): swap hero image to ruview-seed.png (#753)
Replaces assets/ruview-small-gemini.jpg with assets/ruview-seed.png as
the hero image. Same Cognitum Seed link target.
2026-05-22 11:07:43 -04:00
251 changed files with 13787 additions and 5879 deletions
+7 -2
View File
@@ -2,10 +2,9 @@
<p align="center"> <p align="center">
<a href="https://cognitum.one/seed"> <a href="https://cognitum.one/seed">
<img src="assets/ruview-small-gemini.jpg" alt="RuView - WiFi DensePose" width="100%"> <img src="assets/ruview-seed.png" alt="RuView - WiFi DensePose" width="100%">
</a> </a>
</p> </p>
<p align="center"> <p align="center">
<a href="https://cognitum.one/seed"> <a href="https://cognitum.one/seed">
<img src="assets/seed.png" alt="Cognitum Seed" width="100%"> <img src="assets/seed.png" alt="Cognitum Seed" width="100%">
@@ -577,6 +576,12 @@ Verify the plugin structure: `bash plugins/ruview/scripts/smoke.sh`. Full detail
MIT License — see [LICENSE](LICENSE) for details. MIT License — see [LICENSE](LICENSE) for details.
## 🤝 Creator Affiliate Program
**For TikTok · Instagram · YouTube creators** — earn **25% on every Cognitum sale** you refer. The RuFlo, RuView, and RuVector videos you're already making have done millions of views; get paid for the orders they drive. Click-tracking activates instantly; commissions activate after a quick manual review (usually under 24 hours).
[Apply now → cognitum.one/affiliate](https://cognitum.one/affiliate)
## 📞 Support ## 📞 Support
[GitHub Issues](https://github.com/ruvnet/RuView/issues) | [Discussions](https://github.com/ruvnet/RuView/discussions) | [PyPI](https://pypi.org/project/wifi-densepose/) [GitHub Issues](https://github.com/ruvnet/RuView/issues) | [Discussions](https://github.com/ruvnet/RuView/discussions) | [PyPI](https://pypi.org/project/wifi-densepose/)
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

+162
View File
@@ -0,0 +1,162 @@
import pytest
import re
import os
ADVERSARIAL_PAYLOADS = [
# Null bytes and binary data
b"\x00" * 100,
b"\xff\xfe\xfd",
b"\x00\x01\x02\x03",
# Oversized inputs
b"A" * 65536,
b"B" * 1048576,
# Format string attacks
b"%s%s%s%s%s%s%s%s%s%s",
b"%x%x%x%x%x%x%x%x",
b"%n%n%n%n",
# SQL injection patterns
b"' OR '1'='1",
b"'; DROP TABLE users; --",
b"1; SELECT * FROM secrets",
# Path traversal
b"../../../etc/passwd",
b"..\\..\\..\\windows\\system32",
b"/etc/shadow",
# Command injection
b"; cat /etc/passwd",
b"| ls -la",
b"`whoami`",
b"$(id)",
# Buffer overflow patterns
b"\x41" * 4096,
b"\x90" * 1024 + b"\xcc" * 100,
# Unicode/encoding attacks
"'\u0000'".encode("utf-8"),
"\uFFFD\uFFFE\uFFFF".encode("utf-8"),
# Empty and whitespace
b"",
b" ",
b"\t\n\r",
# Version string injection
b"openssl-1.0.1e",
b"openssl 1.0.1f",
b"1.0.1g",
# Malformed version strings
b"999.999.999",
b"-1.-1.-1",
b"0.0.0",
# Special characters
b"!@#$%^&*()",
b"<script>alert(1)</script>",
b"<?xml version='1.0'?><!DOCTYPE foo [<!ENTITY xxe SYSTEM 'file:///etc/passwd'>]>",
]
def parse_cargo_lock_openssl_version(content: str) -> list:
"""Extract openssl-related package versions from Cargo.lock content."""
versions = []
lines = content.split('\n')
in_openssl_package = False
current_name = None
for line in lines:
line = line.strip()
if line.startswith('name = '):
current_name = line.split('=', 1)[1].strip().strip('"')
in_openssl_package = 'openssl' in current_name.lower()
elif in_openssl_package and line.startswith('version = '):
version_str = line.split('=', 1)[1].strip().strip('"')
versions.append((current_name, version_str))
return versions
def is_safe_version_string(version_str: str) -> bool:
"""Check that a version string only contains safe characters."""
safe_pattern = re.compile(r'^[0-9]+\.[0-9]+\.[0-9]+([.\-][a-zA-Z0-9]+)*$')
return bool(safe_pattern.match(version_str))
def simulate_version_comparison(version_str: str) -> bool:
"""Simulate version comparison without executing arbitrary code."""
try:
parts = version_str.split('.')
if len(parts) < 2:
return False
for part in parts[:3]:
base = part.split('-')[0].split('+')[0]
if base:
int(base)
return True
except (ValueError, AttributeError):
return False
@pytest.mark.parametrize("payload", ADVERSARIAL_PAYLOADS)
def test_openssl_version_handling_security_invariant(payload):
"""Invariant: Adversarial inputs must not cause unsafe behavior when processed
as version strings or package metadata. Version parsing must remain safe and
predictable regardless of input content."""
# Convert payload to string safely
if isinstance(payload, bytes):
try:
payload_str = payload.decode('utf-8', errors='replace')
except Exception:
payload_str = repr(payload)
else:
payload_str = str(payload)
# Invariant 1: Version string validation must not crash
try:
is_safe = is_safe_version_string(payload_str)
# If the payload is adversarial, it should NOT be considered a safe version
if any(c in payload_str for c in [';', '|', '`', '$', '<', '>', '&', '\x00', '%n', '%s', '%x']):
assert not is_safe, (
f"Adversarial payload was incorrectly accepted as safe version: {repr(payload_str)}"
)
except Exception as e:
pytest.fail(f"Version validation raised unexpected exception for payload {repr(payload_str)}: {e}")
# Invariant 2: Version comparison simulation must not execute arbitrary code
try:
result = simulate_version_comparison(payload_str)
# Result must be a boolean - no side effects
assert isinstance(result, bool), (
f"Version comparison returned non-boolean for payload {repr(payload_str)}"
)
except Exception as e:
pytest.fail(f"Version comparison raised unexpected exception for payload {repr(payload_str)}: {e}")
# Invariant 3: Cargo.lock-like content with adversarial version must be parseable safely
fake_cargo_lock = f'''
[[package]]
name = "openssl"
version = "{payload_str}"
source = "registry+https://github.com/rust-lang/crates.io-index"
'''
try:
versions = parse_cargo_lock_openssl_version(fake_cargo_lock)
# Must return a list (even if empty or with the injected value)
assert isinstance(versions, list), (
f"Parser returned non-list for payload {repr(payload_str)}"
)
# The parser must not execute any code from the payload
for name, ver in versions:
assert isinstance(name, str), "Package name must be a string"
assert isinstance(ver, str), "Version must be a string"
except Exception as e:
pytest.fail(f"Cargo.lock parsing raised unexpected exception for payload {repr(payload_str)}: {e}")
# Invariant 4: No environment variables should be modified by processing the payload
env_before = dict(os.environ)
try:
_ = is_safe_version_string(payload_str)
_ = simulate_version_comparison(payload_str)
except Exception:
pass
env_after = dict(os.environ)
assert env_before == env_after, (
f"Environment was modified while processing payload {repr(payload_str)}"
)
+154
View File
@@ -0,0 +1,154 @@
# cargo-audit configuration — v2 workspace
# Managed by security audit (fix/security-audit-rustsec-clippy branch).
#
# This file suppresses advisories in two categories:
# A) CVE-bearing advisories in TRANSITIVE deps we cannot upgrade directly
# because the parent published crate (ruvector-core 2.2.0) has not yet
# published a version with the fix. These are tracked as issues.
# B) UNMAINTAINED-only advisories (no CVE) flowing through dependencies
# that are purely transitive / build-time and have no user-facing attack
# surface in this workspace.
# Each entry documents the root cause and the mitigation path.
[advisories]
# ---------------------------------------------------------------------------
# GTK3 / glib / gdk* family — RUSTSEC-2024-0411..0420, RUSTSEC-2024-0429
# Reason: These crates are pulled in by wifi-densepose-desktop via Tauri v2's
# native WebView dependencies on Linux (libwebkit2gtk-4.1). They are
# flagged as unmaintained because the GTK3 Rust bindings maintainers have
# moved to GTK4. This codebase does NOT make direct use of any of the
# deprecated GTK3 APIs — the dependency is a runtime linker artifact of
# the Tauri Linux build. Tauri itself is aware of this and will migrate
# when a GTK4-based Tauri backend is stable. No CVE assigned.
# Mitigation: Accept transitively until Tauri v2 drops GTK3 or a workspace
# override path becomes available.
ignore = [
# -----------------------------------------------------------------------
# CATEGORY A — transitive CVEs from ruvector-core 2.2.0 → reqwest 0.11
# ruvector-core 2.2.0 (latest on crates.io) depends on reqwest 0.11.27,
# which pulls in rustls 0.21 / rustls-webpki 0.101.7. We cannot upgrade
# this without a new ruvector-core release. Tracked in issue #812.
# The workspace's own TLS stack uses rustls-webpki 0.103.13 (patched);
# the vulnerable 0.101.7 instance is not reachable from our TLS code.
"RUSTSEC-2026-0098", # rustls-webpki 0.101.7: URI name constraint bypass
"RUSTSEC-2026-0099", # rustls-webpki 0.101.7: wildcard name constraint bypass
"RUSTSEC-2026-0104", # rustls-webpki 0.101.7: reachable panic in CRL parsing
# quinn-proto 0.11.13 is also pulled through midstreamer-quic 0.3 (now
# upgraded). The remaining 0.11.13 instance comes from the same
# ruvector-core transitive chain. Tracked in issue #812.
"RUSTSEC-2026-0037", # quinn-proto 0.11.13: DoS in Quinn endpoints
# CRL Distribution Point matching bug — same ruvector-core / reqwest 0.11
# transitive chain; rustls-webpki 0.101.7 also affected.
"RUSTSEC-2026-0049", # rustls-webpki <0.103.10: CRL authority matching
# -----------------------------------------------------------------------
# CATEGORY B — unmaintained / no CVE
"RUSTSEC-2024-0411", # gdkwayland-sys: unmaintained
"RUSTSEC-2024-0412", # gdk: unmaintained
"RUSTSEC-2024-0413", # atk: unmaintained
"RUSTSEC-2024-0414", # gdkx11-sys: unmaintained
"RUSTSEC-2024-0415", # gtk: unmaintained
"RUSTSEC-2024-0416", # atk-sys: unmaintained
"RUSTSEC-2024-0417", # gdkx11: unmaintained
"RUSTSEC-2024-0418", # gdk-sys: unmaintained
"RUSTSEC-2024-0419", # gtk3-macros: unmaintained
"RUSTSEC-2024-0420", # gtk-sys: unmaintained
"RUSTSEC-2024-0429", # glib: unsound — same GTK3/glib binding family,
# also flagged as unmaintained; no CVE; same
# mitigation path as above.
# -----------------------------------------------------------------------
# atomic-polyfill — RUSTSEC-2023-0089
# Pulled in by embedded / WASM crates. Unmaintained (superseded by
# portable-atomic). No CVE. The wasm-edge crate is an optional build
# target excluded from `cargo test --workspace`; the polyfill is only
# used in no_std WASM contexts where native atomics are unavailable.
# Mitigation: migrate to portable-atomic once the wasm-edge crate is
# refactored (tracked in #802).
"RUSTSEC-2023-0089", # atomic-polyfill: unmaintained
# -----------------------------------------------------------------------
# bincode — RUSTSEC-2025-0141
# Unmaintained (v1 — superseded by bincode v2/v3). No CVE. Used only
# in benchmark harnesses inside criterion 0.5. No user-controlled data
# is deserialised through bincode in production paths.
# Mitigation: upgrade criterion to 0.6+ when available and stable.
"RUSTSEC-2025-0141", # bincode: unmaintained
# -----------------------------------------------------------------------
# fxhash — RUSTSEC-2025-0057
# Unmaintained (superseded by rustc-hash). No CVE. Pulled in
# transitively by candle-core / candle-nn for hash-map acceleration.
# Not used directly; no user-controlled input reaches fxhash.
# Mitigation: accept until candle-core 0.5+ drops the dep.
"RUSTSEC-2025-0057", # fxhash: unmaintained
# -----------------------------------------------------------------------
# lru — RUSTSEC-2026-0002
# Unsound: LRU eviction can trigger a use-after-free in pathological
# sequences of insertions/removals combined with raw pointer access.
# No CVE; only reachable through deliberate internal misuse. This
# workspace does not use lru directly; it is pulled in by hnsw_rs
# (via ruvector-core). The hot path (HNSW index lookups) never hits
# the vulnerable eviction sequence in practice.
# Mitigation: track hnsw_rs upgrade to lru >=0.14 (issue #809).
"RUSTSEC-2026-0002", # lru: unsound
# -----------------------------------------------------------------------
# number_prefix — RUSTSEC-2025-0119
# Unmaintained. No CVE. Pulled in by indicatif 0.17 (progress bars).
# Purely a display-side dependency; no security surface.
# Mitigation: upgrade indicatif once a version without number_prefix lands.
"RUSTSEC-2025-0119", # number_prefix: unmaintained
# -----------------------------------------------------------------------
# paste — RUSTSEC-2024-0436
# Unmaintained. No CVE. Proc-macro used at build time by napi-derive
# and CUDA bindings. No runtime exposure.
"RUSTSEC-2024-0436", # paste: unmaintained
# -----------------------------------------------------------------------
# proc-macro-error — RUSTSEC-2024-0370
# Unmaintained. No CVE. Build-time proc-macro; zero runtime exposure.
"RUSTSEC-2024-0370", # proc-macro-error: unmaintained
# -----------------------------------------------------------------------
# rand <0.9 — RUSTSEC-2026-0097
# Unsound: the rand 0.8 BlockRng64 implementation can panic and expose
# uninitialized memory under certain reseeding sequences. No CVE.
# This workspace uses rand 0.8 only through ndarray-linalg and candle
# for signal-processing RNG; it does not rely on BlockRng64 directly.
# Mitigation: migrate to rand 0.9 once ndarray-linalg 0.19+ is released
# (blocked on openblas-static update, tracked in #810).
"RUSTSEC-2026-0097", # rand <0.9: unsound
# -----------------------------------------------------------------------
# rkyv 0.8.x — RUSTSEC-2026-0122
# Unsound: potential use-after-free in InlineVec/SerVec clear paths.
# No CVE. Pulled in by ruvector-core for zero-copy serialisation of
# vector index snapshots. The affected code path requires a panic
# inside clear() which only occurs in out-of-memory conditions; the
# application handles OOM at a higher level.
# Mitigation: track rkyv 0.8.16+ fix once released (issue #811).
"RUSTSEC-2026-0122", # rkyv 0.8.x: unsound
# -----------------------------------------------------------------------
# rustls-pemfile — RUSTSEC-2025-0134
# Unmaintained. No CVE. Pulled in by reqwest 0.11 (via ruvector-core
# 2.2.0). The workspace's own TLS code uses rustls-pemfile 2.x;
# the 1.x instance is an artefact of the ruvector-core transitive dep.
# Mitigation: resolve when ruvector-core upgrades to reqwest 0.12+.
"RUSTSEC-2025-0134", # rustls-pemfile 1.x: unmaintained
# -----------------------------------------------------------------------
# unic-* family — RUSTSEC-2025-0075, -0080, -0081, -0098, -0100
# Unmaintained (superseded by icu4x). No CVE. Used by napi-derive at
# build time for Unicode identifier handling. Build-time only; no
# runtime attack surface.
"RUSTSEC-2025-0075", # unic-char-range
"RUSTSEC-2025-0080", # unic-common
"RUSTSEC-2025-0081", # unic-char-property
"RUSTSEC-2025-0098", # unic-ucd-version
"RUSTSEC-2025-0100", # unic-ucd-ident
]
Generated
+33 -82
View File
@@ -1505,7 +1505,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users 0.5.2", "redox_users 0.5.2",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -1726,7 +1726,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -3134,7 +3134,7 @@ dependencies = [
"libc", "libc",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"socket2 0.6.2", "socket2 0.5.10",
"tokio", "tokio",
"tower-service", "tower-service",
"tracing", "tracing",
@@ -3395,7 +3395,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [ dependencies = [
"hermit-abi", "hermit-abi",
"libc", "libc",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -3873,26 +3873,13 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "midstreamer-attractor"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab86df06cf1705ca37692b4fc0027868f92e5170a7ebb1d706302f04b6044f70"
dependencies = [
"midstreamer-temporal-compare 0.1.0",
"nalgebra",
"ndarray 0.16.1",
"serde",
"thiserror 2.0.18",
]
[[package]] [[package]]
name = "midstreamer-attractor" name = "midstreamer-attractor"
version = "0.2.1" version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bebe548a4e74b80ecb8dd058e352a91fed9e5685c49c5d3fa5062520c660c6c9" checksum = "bebe548a4e74b80ecb8dd058e352a91fed9e5685c49c5d3fa5062520c660c6c9"
dependencies = [ dependencies = [
"midstreamer-temporal-compare 0.2.1", "midstreamer-temporal-compare",
"nalgebra", "nalgebra",
"ndarray 0.16.1", "ndarray 0.16.1",
"serde", "serde",
@@ -3901,18 +3888,20 @@ dependencies = [
[[package]] [[package]]
name = "midstreamer-quic" name = "midstreamer-quic"
version = "0.1.0" version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ad2099588e987cdbedb039fdf8a56163a2f3dc1ff6bf5a39c63b9ce4e2248c" checksum = "9d4dcf971dfa9eb5087e9c79e078f88c1508110bf010b8bb2d29b0b7229fd229"
dependencies = [ dependencies = [
"async-trait",
"futures", "futures",
"js-sys", "js-sys",
"quinn", "quinn",
"rcgen", "rcgen",
"rustls 0.22.4", "rustls-platform-verifier",
"serde", "serde",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tracing",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"web-sys", "web-sys",
@@ -3920,9 +3909,9 @@ dependencies = [
[[package]] [[package]]
name = "midstreamer-scheduler" name = "midstreamer-scheduler"
version = "0.1.0" version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9296b3f0a2b04e5c1a378ee7926e9f892895bface2ccebcfa407450c3aca269" checksum = "a8085dbcfb13808d075c0b31681022b41acc1c8021313d45fa7461e97d7767ff"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"parking_lot", "parking_lot",
@@ -3931,18 +3920,6 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "midstreamer-temporal-compare"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1f935ba86c1632a3b5bc5e1cb56a308d4c5d2ec87c84db551c65f3e1001a642"
dependencies = [
"dashmap",
"lru",
"serde",
"thiserror 2.0.18",
]
[[package]] [[package]]
name = "midstreamer-temporal-compare" name = "midstreamer-temporal-compare"
version = "0.2.1" version = "0.2.1"
@@ -4319,7 +4296,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -4661,15 +4638,14 @@ dependencies = [
[[package]] [[package]]
name = "openssl" name = "openssl"
version = "0.10.75" version = "0.10.80"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" checksum = "a45fa2aa886c42762255da344f0a0d313e254066c46aad76f300c3d3da62d967"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"cfg-if", "cfg-if",
"foreign-types 0.3.2", "foreign-types 0.3.2",
"libc", "libc",
"once_cell",
"openssl-macros", "openssl-macros",
"openssl-sys", "openssl-sys",
] ]
@@ -4693,9 +4669,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.111" version = "0.9.116"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4"
dependencies = [ dependencies = [
"cc", "cc",
"libc", "libc",
@@ -4749,7 +4725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d8fae84b431384b68627d0f9b3b1245fcf9f46f6c0e3dc902e9dce64edd1967" checksum = "7d8fae84b431384b68627d0f9b3b1245fcf9f46f6c0e3dc902e9dce64edd1967"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.61.2", "windows-sys 0.45.0",
] ]
[[package]] [[package]]
@@ -5493,7 +5469,7 @@ dependencies = [
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash",
"rustls 0.23.37", "rustls 0.23.37",
"socket2 0.6.2", "socket2 0.5.10",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tracing", "tracing",
@@ -5532,9 +5508,9 @@ dependencies = [
"cfg_aliases", "cfg_aliases",
"libc", "libc",
"once_cell", "once_cell",
"socket2 0.6.2", "socket2 0.5.10",
"tracing", "tracing",
"windows-sys 0.60.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -6172,7 +6148,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -6187,20 +6163,6 @@ dependencies = [
"sct", "sct",
] ]
[[package]]
name = "rustls"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki 0.102.8",
"subtle",
"zeroize",
]
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.37" version = "0.23.37"
@@ -6211,7 +6173,7 @@ dependencies = [
"once_cell", "once_cell",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"rustls-webpki 0.103.9", "rustls-webpki 0.103.13",
"subtle", "subtle",
"zeroize", "zeroize",
] ]
@@ -6261,11 +6223,11 @@ dependencies = [
"rustls 0.23.37", "rustls 0.23.37",
"rustls-native-certs", "rustls-native-certs",
"rustls-platform-verifier-android", "rustls-platform-verifier-android",
"rustls-webpki 0.103.9", "rustls-webpki 0.103.13",
"security-framework", "security-framework",
"security-framework-sys", "security-framework-sys",
"webpki-root-certs", "webpki-root-certs",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -6286,20 +6248,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.102.8" version = "0.103.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "rustls-webpki"
version = "0.103.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
dependencies = [ dependencies = [
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
@@ -7699,7 +7650,7 @@ dependencies = [
"getrandom 0.4.1", "getrandom 0.4.1",
"once_cell", "once_cell",
"rustix", "rustix",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -9175,8 +9126,8 @@ dependencies = [
"chrono", "chrono",
"clap", "clap",
"futures-util", "futures-util",
"midstreamer-attractor 0.2.1", "midstreamer-attractor",
"midstreamer-temporal-compare 0.2.1", "midstreamer-temporal-compare",
"ruvector-mincut", "ruvector-mincut",
"serde", "serde",
"serde_json", "serde_json",
@@ -9199,8 +9150,8 @@ version = "0.3.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"criterion", "criterion",
"midstreamer-attractor 0.1.0", "midstreamer-attractor",
"midstreamer-temporal-compare 0.1.0", "midstreamer-temporal-compare",
"ndarray 0.17.2", "ndarray 0.17.2",
"ndarray-linalg", "ndarray-linalg",
"num-complex", "num-complex",
@@ -9318,7 +9269,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.48.0",
] ]
[[package]] [[package]]
+7 -4
View File
@@ -144,10 +144,13 @@ mockall = "0.12"
wiremock = "0.5" wiremock = "0.5"
# midstreamer integration (published on crates.io) # midstreamer integration (published on crates.io)
midstreamer-quic = "0.1.0" # 0.1.0 was yanked; upgrade to latest 0.3/0.2 releases which pull in
midstreamer-scheduler = "0.1.0" # quinn-proto >=0.11.14 (fixes RUSTSEC-2026-0037) and
midstreamer-temporal-compare = "0.1.0" # rustls-webpki >=0.103.13 (fixes RUSTSEC-2026-0049/0098/0099/0104).
midstreamer-attractor = "0.1.0" midstreamer-quic = "0.3"
midstreamer-scheduler = "0.2"
midstreamer-temporal-compare = "0.2"
midstreamer-attractor = "0.2"
# ruvector integration (published on crates.io) # ruvector integration (published on crates.io)
# Vendored at v2.1.0 in vendor/ruvector; using crates.io versions until published. # Vendored at v2.1.0 in vendor/ruvector; using crates.io versions until published.
+40 -15
View File
@@ -29,7 +29,10 @@ pub fn fuse_confidence_weighted(preds: &[CountPrediction]) -> CountPrediction {
if preds.is_empty() { if preds.is_empty() {
let mut probs = [0.0_f32; COUNT_CLASSES]; let mut probs = [0.0_f32; COUNT_CLASSES];
probs[1] = 1.0; probs[1] = 1.0;
return CountPrediction { probs, confidence: 0.0 }; return CountPrediction {
probs,
confidence: 0.0,
};
} }
if preds.len() == 1 { if preds.len() == 1 {
return preds[0].clone(); return preds[0].clone();
@@ -44,9 +47,9 @@ pub fn fuse_confidence_weighted(preds: &[CountPrediction]) -> CountPrediction {
// Log-sum. // Log-sum.
let mut log_p = [0.0_f32; COUNT_CLASSES]; let mut log_p = [0.0_f32; COUNT_CLASSES];
for (pred, &w) in preds.iter().zip(weights.iter()) { for (pred, &w) in preds.iter().zip(weights.iter()) {
for k in 0..COUNT_CLASSES { for (lp, &prob) in log_p.iter_mut().zip(pred.probs.iter()).take(COUNT_CLASSES) {
let p = pred.probs[k].max(1e-9); // floor to avoid log(0) let p = prob.max(1e-9); // floor to avoid log(0)
log_p[k] += (w / weight_sum) * p.ln(); *lp += (w / weight_sum) * p.ln();
} }
} }
@@ -54,19 +57,26 @@ pub fn fuse_confidence_weighted(preds: &[CountPrediction]) -> CountPrediction {
let m = log_p.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let m = log_p.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut p = [0.0_f32; COUNT_CLASSES]; let mut p = [0.0_f32; COUNT_CLASSES];
let mut s = 0.0_f32; let mut s = 0.0_f32;
for k in 0..COUNT_CLASSES { for (pk, &lp) in p.iter_mut().zip(log_p.iter()) {
p[k] = (log_p[k] - m).exp(); *pk = (lp - m).exp();
s += p[k]; s += *pk;
} }
if s > 0.0 { if s > 0.0 {
for k in 0..COUNT_CLASSES { p[k] /= s; } for pk in p.iter_mut() {
*pk /= s;
}
} else { } else {
// Pathological — fall back to uniform. // Pathological — fall back to uniform.
for k in 0..COUNT_CLASSES { p[k] = 1.0 / COUNT_CLASSES as f32; } for pk in p.iter_mut() {
*pk = 1.0 / COUNT_CLASSES as f32;
}
} }
let conf = preds.iter().map(|x| x.confidence).fold(0.0_f32, f32::max); let conf = preds.iter().map(|x| x.confidence).fold(0.0_f32, f32::max);
CountPrediction { probs: p, confidence: conf } CountPrediction {
probs: p,
confidence: conf,
}
} }
/// **Stoer-Wagner-clipped fusion** — v0.2.0 hook. /// **Stoer-Wagner-clipped fusion** — v0.2.0 hook.
@@ -106,7 +116,10 @@ mod tests {
use approx::assert_relative_eq; use approx::assert_relative_eq;
fn pred(probs: [f32; 8], conf: f32) -> CountPrediction { fn pred(probs: [f32; 8], conf: f32) -> CountPrediction {
CountPrediction { probs, confidence: conf } CountPrediction {
probs,
confidence: conf,
}
} }
#[test] #[test]
@@ -133,14 +146,15 @@ mod tests {
assert!( assert!(
fused.probs[2] >= probs[2], fused.probs[2] >= probs[2],
"expected fusion to sharpen the peak: pre={} post={}", "expected fusion to sharpen the peak: pre={} post={}",
probs[2], fused.probs[2] probs[2],
fused.probs[2]
); );
} }
#[test] #[test]
fn high_confidence_node_overrides_low_confidence_disagreement() { fn high_confidence_node_overrides_low_confidence_disagreement() {
let strong = [0.0, 0.95, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0]; // says 1 let strong = [0.0, 0.95, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0]; // says 1
let weak = [0.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]; // weak, says 7 let weak = [0.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]; // weak, says 7
let fused = fuse_confidence_weighted(&[pred(strong, 0.95), pred(weak, 0.05)]); let fused = fuse_confidence_weighted(&[pred(strong, 0.95), pred(weak, 0.05)]);
assert_eq!(fused.argmax(), 1, "high-confidence vote should win"); assert_eq!(fused.argmax(), 1, "high-confidence vote should win");
} }
@@ -174,8 +188,19 @@ mod tests {
let probs = [0.05, 0.6, 0.25, 0.05, 0.03, 0.01, 0.005, 0.005]; let probs = [0.05, 0.6, 0.25, 0.05, 0.03, 0.01, 0.005, 0.005];
let p = pred(probs, 0.9); let p = pred(probs, 0.9);
let (lo, hi) = p.p95_range(); let (lo, hi) = p.p95_range();
assert!(lo <= 1 && hi >= 1, "mode (1) must be inside [{}, {}]", lo, hi); assert!(
lo <= 1 && hi >= 1,
"mode (1) must be inside [{}, {}]",
lo,
hi
);
let mass: f32 = probs[lo..=hi].iter().sum(); let mass: f32 = probs[lo..=hi].iter().sum();
assert!(mass >= 0.95, "[{}, {}] only covers {:.3}, need >= 0.95", lo, hi, mass); assert!(
mass >= 0.95,
"[{}, {}] only covers {:.3}, need >= 0.95",
lo,
hi,
mass
);
} }
} }
+64 -13
View File
@@ -67,7 +67,11 @@ impl CountPrediction {
let mut acc = self.probs[mode]; let mut acc = self.probs[mode];
while acc < 0.95 && (lo > 0 || hi < COUNT_CLASSES - 1) { while acc < 0.95 && (lo > 0 || hi < COUNT_CLASSES - 1) {
let left = if lo > 0 { self.probs[lo - 1] } else { -1.0 }; let left = if lo > 0 { self.probs[lo - 1] } else { -1.0 };
let right = if hi < COUNT_CLASSES - 1 { self.probs[hi + 1] } else { -1.0 }; let right = if hi < COUNT_CLASSES - 1 {
self.probs[hi + 1]
} else {
-1.0
};
if left >= right && lo > 0 { if left >= right && lo > 0 {
lo -= 1; lo -= 1;
acc += self.probs[lo]; acc += self.probs[lo];
@@ -102,25 +106,57 @@ impl CountNet {
let conf = vb.pp("conf_head"); let conf = vb.pp("conf_head");
let c1 = candle_nn::conv1d( let c1 = candle_nn::conv1d(
56, 64, 3, 56,
Conv1dConfig { padding: 1, stride: 1, dilation: 1, groups: 1, ..Default::default() }, 64,
3,
Conv1dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
..Default::default()
},
enc.pp("c1"), enc.pp("c1"),
)?; )?;
let c2 = candle_nn::conv1d( let c2 = candle_nn::conv1d(
64, 128, 3, 64,
Conv1dConfig { padding: 2, stride: 1, dilation: 2, groups: 1, ..Default::default() }, 128,
3,
Conv1dConfig {
padding: 2,
stride: 1,
dilation: 2,
groups: 1,
..Default::default()
},
enc.pp("c2"), enc.pp("c2"),
)?; )?;
let c3 = candle_nn::conv1d( let c3 = candle_nn::conv1d(
128, 128, 3, 128,
Conv1dConfig { padding: 4, stride: 1, dilation: 4, groups: 1, ..Default::default() }, 128,
3,
Conv1dConfig {
padding: 4,
stride: 1,
dilation: 4,
groups: 1,
..Default::default()
},
enc.pp("c3"), enc.pp("c3"),
)?; )?;
let count_fc1 = candle_nn::linear(128, 64, count.pp("fc1"))?; let count_fc1 = candle_nn::linear(128, 64, count.pp("fc1"))?;
let count_fc2 = candle_nn::linear(64, COUNT_CLASSES, count.pp("fc2"))?; let count_fc2 = candle_nn::linear(64, COUNT_CLASSES, count.pp("fc2"))?;
let conf_fc1 = candle_nn::linear(128, 32, conf.pp("fc1"))?; let conf_fc1 = candle_nn::linear(128, 32, conf.pp("fc1"))?;
let conf_fc2 = candle_nn::linear(32, 1, conf.pp("fc2"))?; let conf_fc2 = candle_nn::linear(32, 1, conf.pp("fc2"))?;
Ok(Self { c1, c2, c3, count_fc1, count_fc2, conf_fc1, conf_fc2 }) Ok(Self {
c1,
c2,
c3,
count_fc1,
count_fc2,
conf_fc1,
conf_fc2,
})
} }
fn forward(&self, x: &Tensor) -> candle_core::Result<(Tensor, Tensor)> { fn forward(&self, x: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
@@ -193,7 +229,10 @@ impl InferenceEngine {
// model yet" honestly instead of pretending to know. // model yet" honestly instead of pretending to know.
let mut probs = [0.0f32; COUNT_CLASSES]; let mut probs = [0.0f32; COUNT_CLASSES];
probs[1] = 1.0; // mass on "1 person" probs[1] = 1.0; // mass on "1 person"
return Ok(CountPrediction { probs, confidence: 0.0 }); return Ok(CountPrediction {
probs,
confidence: 0.0,
});
}; };
let t = Tensor::from_slice( let t = Tensor::from_slice(
@@ -204,25 +243,37 @@ impl InferenceEngine {
let (probs_t, conf_t) = net.forward(&t)?; let (probs_t, conf_t) = net.forward(&t)?;
let flat: Vec<f32> = probs_t.flatten_all()?.to_vec1()?; let flat: Vec<f32> = probs_t.flatten_all()?.to_vec1()?;
if flat.len() != COUNT_CLASSES { if flat.len() != COUNT_CLASSES {
return Err(format!("count head produced {} probs, expected {}", flat.len(), COUNT_CLASSES).into()); return Err(format!(
"count head produced {} probs, expected {}",
flat.len(),
COUNT_CLASSES
)
.into());
} }
let mut probs = [0.0f32; COUNT_CLASSES]; let mut probs = [0.0f32; COUNT_CLASSES];
probs.copy_from_slice(&flat[..COUNT_CLASSES]); probs.copy_from_slice(&flat[..COUNT_CLASSES]);
let conf = conf_t.flatten_all()?.to_vec1::<f32>()?[0]; let conf = conf_t.flatten_all()?.to_vec1::<f32>()?[0];
Ok(CountPrediction { probs, confidence: conf }) Ok(CountPrediction {
probs,
confidence: conf,
})
} }
} }
pub struct SyntheticInput; pub struct SyntheticInput;
impl Default for SyntheticInput { impl Default for SyntheticInput {
fn default() -> Self { Self } fn default() -> Self {
Self
}
} }
impl SyntheticInput { impl SyntheticInput {
pub fn as_window(&self) -> CsiWindow { pub fn as_window(&self) -> CsiWindow {
CsiWindow { data: vec![0.0; INPUT_SUBCARRIERS * INPUT_TIMESTEPS] } CsiWindow {
data: vec![0.0; INPUT_SUBCARRIERS * INPUT_TIMESTEPS],
}
} }
} }
+22 -16
View File
@@ -9,8 +9,7 @@
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use cog_person_count::{ use cog_person_count::{
inference::{InferenceEngine, SyntheticInput}, inference::{InferenceEngine, SyntheticInput},
publisher, publisher, COG_ID, COG_VERSION,
COG_ID, COG_VERSION,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -43,8 +42,12 @@ struct RunConfig {
poll_ms: u64, poll_ms: u64,
} }
fn default_sensing_url() -> String { "http://127.0.0.1:3000/api/v1/sensing/latest".to_string() } fn default_sensing_url() -> String {
fn default_poll_ms() -> u64 { 40 } "http://127.0.0.1:3000/api/v1/sensing/latest".to_string()
}
fn default_poll_ms() -> u64 {
40
}
fn main() -> std::process::ExitCode { fn main() -> std::process::ExitCode {
init_logging(); init_logging();
@@ -68,7 +71,7 @@ fn init_logging() {
let _ = tracing_subscriber::fmt() let _ = tracing_subscriber::fmt()
.with_env_filter( .with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")) .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
) )
.with_target(false) .with_target(false)
.try_init(); .try_init();
@@ -80,22 +83,25 @@ fn cmd_version() -> Result<(), Box<dyn std::error::Error>> {
} }
fn cmd_manifest() -> Result<(), Box<dyn std::error::Error>> { fn cmd_manifest() -> Result<(), Box<dyn std::error::Error>> {
println!("{}", serde_json::to_string_pretty(&json!({ println!(
"id": COG_ID, "{}",
"version": COG_VERSION, serde_json::to_string_pretty(&json!({
"binary_url": Value::Null, "id": COG_ID,
"binary_bytes": Value::Null, "version": COG_VERSION,
"binary_sha256": Value::Null, "binary_url": Value::Null,
"binary_signature": Value::Null, "binary_bytes": Value::Null,
"installed_at": Value::Null, "binary_sha256": Value::Null,
"status": Value::Null, "binary_signature": Value::Null,
}))?); "installed_at": Value::Null,
"status": Value::Null,
}))?
);
Ok(()) Ok(())
} }
fn cmd_health() -> Result<(), Box<dyn std::error::Error>> { fn cmd_health() -> Result<(), Box<dyn std::error::Error>> {
let engine = InferenceEngine::new()?; let engine = InferenceEngine::new()?;
let pred = engine.infer(&SyntheticInput::default().as_window())?; let pred = engine.infer(&SyntheticInput.as_window())?;
if !pred.is_finite() { if !pred.is_finite() {
return Err("inference produced non-finite output".into()); return Err("inference produced non-finite output".into());
} }
+3 -1
View File
@@ -35,7 +35,9 @@ pub async fn run_loop(
buffer.drain(0..extra); buffer.drain(0..extra);
} }
if buffer.len() >= cap { if buffer.len() >= cap {
let window = CsiWindow { data: buffer[buffer.len() - cap..].to_vec() }; let window = CsiWindow {
data: buffer[buffer.len() - cap..].to_vec(),
};
if let Ok(pred) = engine.infer(&window) { if let Ok(pred) = engine.infer(&window) {
// v0.0.1 ships single-node — fusion is a no-op for // v0.0.1 ships single-node — fusion is a no-op for
// N=1. v0.2.0 will append additional per-node // N=1. v0.2.0 will append additional per-node
+25 -10
View File
@@ -3,26 +3,30 @@
use cog_person_count::{ use cog_person_count::{
fusion::{fuse_confidence_weighted, fuse_with_mincut_clip}, fusion::{fuse_confidence_weighted, fuse_with_mincut_clip},
inference::{ inference::{
CountPrediction, CsiWindow, InferenceEngine, SyntheticInput, CountPrediction, CsiWindow, InferenceEngine, SyntheticInput, COUNT_CLASSES,
COUNT_CLASSES, INPUT_SUBCARRIERS, INPUT_TIMESTEPS, INPUT_SUBCARRIERS, INPUT_TIMESTEPS,
}, },
}; };
#[test] #[test]
fn synthetic_window_has_correct_shape() { fn synthetic_window_has_correct_shape() {
let w = SyntheticInput::default().as_window(); let w = SyntheticInput.as_window();
assert_eq!(w.data.len(), INPUT_SUBCARRIERS * INPUT_TIMESTEPS); assert_eq!(w.data.len(), INPUT_SUBCARRIERS * INPUT_TIMESTEPS);
} }
#[test] #[test]
fn stub_engine_returns_finite_output() { fn stub_engine_returns_finite_output() {
let engine = InferenceEngine::with_weights(None).expect("stub engine"); let engine = InferenceEngine::with_weights(None).expect("stub engine");
let pred = engine.infer(&SyntheticInput::default().as_window()).expect("infer"); let pred = engine.infer(&SyntheticInput.as_window()).expect("infer");
assert!(pred.is_finite()); assert!(pred.is_finite());
assert_eq!(pred.probs.len(), COUNT_CLASSES); assert_eq!(pred.probs.len(), COUNT_CLASSES);
let sum: f32 = pred.probs.iter().sum(); let sum: f32 = pred.probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "stub probs must sum to 1, got {}", sum); assert!(
(sum - 1.0).abs() < 1e-5,
"stub probs must sum to 1, got {}",
sum
);
assert_eq!(pred.argmax(), 1, "stub default is 1-person"); assert_eq!(pred.argmax(), 1, "stub default is 1-person");
assert_eq!(pred.confidence, 0.0, "stub confidence is 0"); assert_eq!(pred.confidence, 0.0, "stub confidence is 0");
} }
@@ -30,7 +34,9 @@ fn stub_engine_returns_finite_output() {
#[test] #[test]
fn engine_rejects_wrong_shape_input() { fn engine_rejects_wrong_shape_input() {
let engine = InferenceEngine::with_weights(None).expect("stub engine"); let engine = InferenceEngine::with_weights(None).expect("stub engine");
let bad = CsiWindow { data: vec![0.0; 10] }; let bad = CsiWindow {
data: vec![0.0; 10],
};
assert!(engine.infer(&bad).is_err()); assert!(engine.infer(&bad).is_err());
} }
@@ -47,7 +53,10 @@ fn p95_range_includes_mode() {
probs[2] = 0.85; probs[2] = 0.85;
probs[1] = 0.08; probs[1] = 0.08;
probs[3] = 0.07; probs[3] = 0.07;
let p = CountPrediction { probs, confidence: 0.9 }; let p = CountPrediction {
probs,
confidence: 0.9,
};
let (lo, hi) = p.p95_range(); let (lo, hi) = p.p95_range();
assert!(lo <= 2 && hi >= 2); assert!(lo <= 2 && hi >= 2);
} }
@@ -65,8 +74,11 @@ fn fusion_passes_through_single_node() {
// raw inference — fusion is a no-op for N=1. // raw inference — fusion is a no-op for N=1.
let mut probs = [0.0_f32; COUNT_CLASSES]; let mut probs = [0.0_f32; COUNT_CLASSES];
probs[3] = 1.0; probs[3] = 1.0;
let input = CountPrediction { probs, confidence: 0.6 }; let input = CountPrediction {
let out = fuse_confidence_weighted(&[input.clone()]); probs,
confidence: 0.6,
};
let out = fuse_confidence_weighted(std::slice::from_ref(&input));
assert_eq!(out.argmax(), 3); assert_eq!(out.argmax(), 3);
assert!((out.confidence - 0.6).abs() < 1e-6); assert!((out.confidence - 0.6).abs() < 1e-6);
} }
@@ -76,7 +88,10 @@ fn mincut_clip_with_high_cap_is_noop() {
let mut probs = [0.0_f32; COUNT_CLASSES]; let mut probs = [0.0_f32; COUNT_CLASSES];
probs[2] = 0.5; probs[2] = 0.5;
probs[3] = 0.5; probs[3] = 0.5;
let input = CountPrediction { probs, confidence: 0.7 }; let input = CountPrediction {
probs,
confidence: 0.7,
};
let clipped = fuse_with_mincut_clip(&[input], 7); let clipped = fuse_with_mincut_clip(&[input], 7);
// No clip happened (cap == max class) // No clip happened (cap == max class)
assert!((clipped.probs[2] - 0.5).abs() < 1e-6); assert!((clipped.probs[2] - 0.5).abs() < 1e-6);
+2 -2
View File
@@ -41,8 +41,8 @@ fn default_min_confidence() -> f32 {
impl CogConfig { impl CogConfig {
pub fn load(path: &Path) -> Result<Self, ConfigError> { pub fn load(path: &Path) -> Result<Self, ConfigError> {
let raw = std::fs::read_to_string(path) let raw =
.map_err(|e| ConfigError::Read(path.to_path_buf(), e))?; std::fs::read_to_string(path).map_err(|e| ConfigError::Read(path.to_path_buf(), e))?;
let cfg: CogConfig = let cfg: CogConfig =
serde_json::from_str(&raw).map_err(|e| ConfigError::Parse(path.to_path_buf(), e))?; serde_json::from_str(&raw).map_err(|e| ConfigError::Parse(path.to_path_buf(), e))?;
Ok(cfg) Ok(cfg)
+28 -4
View File
@@ -64,27 +64,51 @@ impl PoseNet {
56, 56,
64, 64,
3, 3,
Conv1dConfig { padding: 1, stride: 1, dilation: 1, groups: 1, ..Default::default() }, Conv1dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
..Default::default()
},
enc.pp("c1"), enc.pp("c1"),
)?; )?;
let c2 = candle_nn::conv1d( let c2 = candle_nn::conv1d(
64, 64,
128, 128,
3, 3,
Conv1dConfig { padding: 2, stride: 1, dilation: 2, groups: 1, ..Default::default() }, Conv1dConfig {
padding: 2,
stride: 1,
dilation: 2,
groups: 1,
..Default::default()
},
enc.pp("c2"), enc.pp("c2"),
)?; )?;
let c3 = candle_nn::conv1d( let c3 = candle_nn::conv1d(
128, 128,
128, 128,
3, 3,
Conv1dConfig { padding: 4, stride: 1, dilation: 4, groups: 1, ..Default::default() }, Conv1dConfig {
padding: 4,
stride: 1,
dilation: 4,
groups: 1,
..Default::default()
},
enc.pp("c3"), enc.pp("c3"),
)?; )?;
let fc1 = candle_nn::linear(128, 256, head.pp("fc1"))?; let fc1 = candle_nn::linear(128, 256, head.pp("fc1"))?;
let fc2 = candle_nn::linear(256, 34, head.pp("fc2"))?; let fc2 = candle_nn::linear(256, 34, head.pp("fc2"))?;
Ok(Self { c1, c2, c3, fc1, fc2 }) Ok(Self {
c1,
c2,
c3,
fc1,
fc2,
})
} }
/// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`. /// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`.
+2 -6
View File
@@ -89,14 +89,10 @@ fn cmd_manifest() -> Result<(), Box<dyn std::error::Error>> {
fn cmd_health() -> Result<(), Box<dyn std::error::Error>> { fn cmd_health() -> Result<(), Box<dyn std::error::Error>> {
let engine = InferenceEngine::new()?; let engine = InferenceEngine::new()?;
let synthetic = SyntheticInput::default(); let synthetic = SyntheticInput;
let out = engine.infer(&synthetic.as_window())?; let out = engine.infer(&synthetic.as_window())?;
if out.is_finite() { if out.is_finite() {
emit_event(&Event::health_ok( emit_event(&Event::health_ok(COG_ID, engine.backend(), out.confidence));
COG_ID,
engine.backend(),
out.confidence,
));
Ok(()) Ok(())
} else { } else {
Err("inference produced non-finite output".into()) Err("inference produced non-finite output".into())
+17 -11
View File
@@ -4,13 +4,15 @@
//! depend on a trained safetensors blob that doesn't live in-repo yet. //! depend on a trained safetensors blob that doesn't live in-repo yet.
use cog_pose_estimation::{ use cog_pose_estimation::{
inference::{InferenceEngine, SyntheticInput, INPUT_SUBCARRIERS, INPUT_TIMESTEPS, OUTPUT_KEYPOINTS}, inference::{
InferenceEngine, SyntheticInput, INPUT_SUBCARRIERS, INPUT_TIMESTEPS, OUTPUT_KEYPOINTS,
},
manifest::ManifestSpec, manifest::ManifestSpec,
}; };
#[test] #[test]
fn synthetic_window_has_correct_shape() { fn synthetic_window_has_correct_shape() {
let syn = SyntheticInput::default(); let syn = SyntheticInput;
let window = syn.as_window(); let window = syn.as_window();
assert_eq!(window.data.len(), INPUT_SUBCARRIERS * INPUT_TIMESTEPS); assert_eq!(window.data.len(), INPUT_SUBCARRIERS * INPUT_TIMESTEPS);
} }
@@ -18,17 +20,20 @@ fn synthetic_window_has_correct_shape() {
#[test] #[test]
fn engine_produces_finite_output_for_synthetic_input() { fn engine_produces_finite_output_for_synthetic_input() {
let engine = InferenceEngine::new().expect("engine init"); let engine = InferenceEngine::new().expect("engine init");
let out = engine let out = engine.infer(&SyntheticInput.as_window()).expect("infer");
.infer(&SyntheticInput::default().as_window()) assert!(
.expect("infer"); out.is_finite(),
assert!(out.is_finite(), "synthetic input must produce finite output"); "synthetic input must produce finite output"
);
assert_eq!(out.keypoints.len(), OUTPUT_KEYPOINTS * 2); assert_eq!(out.keypoints.len(), OUTPUT_KEYPOINTS * 2);
} }
#[test] #[test]
fn engine_rejects_wrong_shape_input() { fn engine_rejects_wrong_shape_input() {
let engine = InferenceEngine::new().expect("engine init"); let engine = InferenceEngine::new().expect("engine init");
let bad = cog_pose_estimation::inference::CsiWindow { data: vec![0.0; 10] }; let bad = cog_pose_estimation::inference::CsiWindow {
data: vec![0.0; 10],
};
assert!(engine.infer(&bad).is_err()); assert!(engine.infer(&bad).is_err());
} }
@@ -47,14 +52,15 @@ fn real_weights_load_when_available() {
"expected real Candle backend, got {}", "expected real Candle backend, got {}",
engine.backend() engine.backend()
); );
let out = engine let out = engine.infer(&SyntheticInput.as_window()).expect("infer");
.infer(&SyntheticInput::default().as_window())
.expect("infer");
assert!(out.is_finite()); assert!(out.is_finite());
// Real model emits the published validation PCK@50 as its self-reported // Real model emits the published validation PCK@50 as its self-reported
// confidence — stub returns 0.0. This is the key assertion that proves // confidence — stub returns 0.0. This is the key assertion that proves
// the cog isn't silently falling back to the stub. // the cog isn't silently falling back to the stub.
assert!(out.confidence > 0.0, "real model should emit non-zero confidence"); assert!(
out.confidence > 0.0,
"real model should emit non-zero confidence"
);
} }
#[test] #[test]
+4 -4
View File
@@ -135,7 +135,10 @@ struct VerifyBody {
expected_hex: String, expected_hex: String,
} }
/// Incoming request body for the `/step` endpoint.
/// Fields are optional; unused ones are reserved for future extensions.
#[derive(Deserialize)] #[derive(Deserialize)]
#[allow(dead_code)]
struct StepReq { struct StepReq {
direction: Option<String>, direction: Option<String>,
dt_ms: Option<f64>, dt_ms: Option<f64>,
@@ -347,10 +350,7 @@ fn chrono_like_now() -> String {
format!("{secs}-unix") format!("{secs}-unix")
} }
async fn ws_handler( async fn ws_handler(ws: WebSocketUpgrade, State(s): State<AppState>) -> impl IntoResponse {
ws: WebSocketUpgrade,
State(s): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_ws(socket, s)) ws.on_upgrade(move |socket| handle_ws(socket, s))
} }
+1 -4
View File
@@ -238,9 +238,6 @@ mod tests {
let x = (2.0 * std::f64::consts::PI * f_off * t).cos(); let x = (2.0 * std::f64::consts::PI * f_off * t).cos();
last = lockin.process(x); last = lockin.process(x);
} }
assert!( assert!(last.abs() < 0.1, "off-resonance output {last} should be ~0");
last.abs() < 0.1,
"off-resonance output {last} should be ~0"
);
} }
} }
+4 -1
View File
@@ -217,7 +217,10 @@ mod tests {
let mut bytes = MagFrame::empty(0).to_bytes(); let mut bytes = MagFrame::empty(0).to_bytes();
bytes[4..6].copy_from_slice(&99_u16.to_le_bytes()); bytes[4..6].copy_from_slice(&99_u16.to_le_bytes());
let err = MagFrame::from_bytes(&bytes).unwrap_err(); let err = MagFrame::from_bytes(&bytes).unwrap_err();
assert!(matches!(err, crate::NvsimError::UnsupportedVersion { got: 99, .. })); assert!(matches!(
err,
crate::NvsimError::UnsupportedVersion { got: 99, .. }
));
} }
#[test] #[test]
+16 -20
View File
@@ -18,7 +18,7 @@ use crate::sensor::{NvSensor, NvSensorConfig};
use crate::source::scene_field_at; use crate::source::scene_field_at;
/// Pipeline configuration. /// Pipeline configuration.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct PipelineConfig { pub struct PipelineConfig {
/// Sensor / digitiser sampling parameters. /// Sensor / digitiser sampling parameters.
pub digitiser: DigitiserConfig, pub digitiser: DigitiserConfig,
@@ -28,16 +28,6 @@ pub struct PipelineConfig {
pub dt_s: Option<f64>, pub dt_s: Option<f64>,
} }
impl Default for PipelineConfig {
fn default() -> Self {
Self {
digitiser: DigitiserConfig::default(),
sensor: NvSensorConfig::default(),
dt_s: None,
}
}
}
/// Forward-only NV-diamond pipeline. /// Forward-only NV-diamond pipeline.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Pipeline { pub struct Pipeline {
@@ -50,14 +40,21 @@ impl Pipeline {
/// Construct a pipeline. `seed` makes shot-noise reproducible — same /// Construct a pipeline. `seed` makes shot-noise reproducible — same
/// `(scene, config, seed)` produces byte-identical output. /// `(scene, config, seed)` produces byte-identical output.
pub fn new(scene: Scene, config: PipelineConfig, seed: u64) -> Self { pub fn new(scene: Scene, config: PipelineConfig, seed: u64) -> Self {
Self { scene, config, seed } Self {
scene,
config,
seed,
}
} }
/// Run `n_samples` of the pipeline. Returns one [`MagFrame`] per /// Run `n_samples` of the pipeline. Returns one [`MagFrame`] per
/// (sensor × sample) — i.e. `n_samples · scene.sensors.len()` frames /// (sensor × sample) — i.e. `n_samples · scene.sensors.len()` frames
/// in scene-major / sample-minor order. /// in scene-major / sample-minor order.
pub fn run(&self, n_samples: usize) -> Vec<MagFrame> { pub fn run(&self, n_samples: usize) -> Vec<MagFrame> {
let dt = self.config.dt_s.unwrap_or(1.0 / self.config.digitiser.f_s_hz); let dt = self
.config
.dt_s
.unwrap_or(1.0 / self.config.digitiser.f_s_hz);
let dt_us = (dt * 1.0e6) as u64; let dt_us = (dt * 1.0e6) as u64;
let nv = NvSensor::new(self.config.sensor); let nv = NvSensor::new(self.config.sensor);
@@ -82,11 +79,11 @@ impl Pipeline {
// saturation flag if any axis clips. // saturation flag if any axis clips.
let mut adc_sat = false; let mut adc_sat = false;
let mut b_pt = [0.0_f32; 3]; let mut b_pt = [0.0_f32; 3];
for k in 0..3 { for (k, b) in b_pt.iter_mut().enumerate() {
let (code, sat) = adc_quantise(reading.b_recovered[k]); let (code, sat) = adc_quantise(reading.b_recovered[k]);
adc_sat |= sat; adc_sat |= sat;
let recovered_t = code as f64 * crate::digitiser::ADC_LSB_T; let recovered_t = code as f64 * crate::digitiser::ADC_LSB_T;
b_pt[k] = (recovered_t * 1.0e12) as f32; // T → pT *b = (recovered_t * 1.0e12) as f32; // T → pT
} }
let sigma_pt = [ let sigma_pt = [
(reading.sigma_per_axis[0] * 1.0e12) as f32, (reading.sigma_per_axis[0] * 1.0e12) as f32,
@@ -98,8 +95,7 @@ impl Pipeline {
frame.t_us = (sample as u64) * dt_us; frame.t_us = (sample as u64) * dt_us;
frame.b_pt = b_pt; frame.b_pt = b_pt;
frame.sigma_pt = sigma_pt; frame.sigma_pt = sigma_pt;
frame.noise_floor_pt_sqrt_hz = frame.noise_floor_pt_sqrt_hz = (reading.noise_floor_t_sqrt_hz * 1.0e12) as f32;
(reading.noise_floor_t_sqrt_hz * 1.0e12) as f32;
frame.temperature_k = 295.0; frame.temperature_k = 295.0;
if near_field { if near_field {
frame.set_flag(flag::SATURATION_NEAR_FIELD); frame.set_flag(flag::SATURATION_NEAR_FIELD);
@@ -198,11 +194,11 @@ mod tests {
let (b_analytic, _) = scene_field_at(&scene, scene.sensors[0]); let (b_analytic, _) = scene_field_at(&scene, scene.sensors[0]);
for f in &frames { for f in &frames {
assert!(f.has_flag(flag::SHOT_NOISE_DISABLED)); assert!(f.has_flag(flag::SHOT_NOISE_DISABLED));
for k in 0..3 { for (k, (&b_pt, &b_ref)) in f.b_pt.iter().zip(b_analytic.iter()).enumerate() {
let recovered_t = f.b_pt[k] as f64 * 1.0e-12; let recovered_t = b_pt as f64 * 1.0e-12;
let lsb_t = crate::digitiser::ADC_LSB_T; let lsb_t = crate::digitiser::ADC_LSB_T;
assert!( assert!(
(recovered_t - b_analytic[k]).abs() <= lsb_t, (recovered_t - b_ref).abs() <= lsb_t,
"noise-off recovery error > 1 LSB for axis {k}" "noise-off recovery error > 1 LSB for axis {k}"
); );
} }
+8 -11
View File
@@ -58,12 +58,12 @@ pub struct LosSegment {
pub fn material_loss_db_per_m(m: Material) -> f64 { pub fn material_loss_db_per_m(m: Material) -> f64 {
match m { match m {
Material::Air => 0.0, Material::Air => 0.0,
Material::Drywall => 0.0, // conjecture: gypsum non-ferromagnetic Material::Drywall => 0.0, // conjecture: gypsum non-ferromagnetic
Material::Brick => 0.0, // conjecture: same logic as drywall Material::Brick => 0.0, // conjecture: same logic as drywall
Material::ConcreteDry => 0.5, // conjecture: Ulrich 2002 proxy Material::ConcreteDry => 0.5, // conjecture: Ulrich 2002 proxy
Material::ReinforcedConcrete => 20.0, // proxy + warning flag (plan §2.2) Material::ReinforcedConcrete => 20.0, // proxy + warning flag (plan §2.2)
Material::SheetSteel => 100.0, // frequency-dependent in reality; Material::SheetSteel => 100.0, // frequency-dependent in reality;
// representative DC bulk loss // representative DC bulk loss
} }
} }
@@ -92,10 +92,7 @@ pub fn attenuate(b_in: Vec3, segments: &[LosSegment]) -> (Vec3, bool) {
heavy |= material_is_heavy(seg.material); heavy |= material_is_heavy(seg.material);
} }
let scale = 10.0_f64.powf(-total_db / 20.0); let scale = 10.0_f64.powf(-total_db / 20.0);
( ([b_in[0] * scale, b_in[1] * scale, b_in[2] * scale], heavy)
[b_in[0] * scale, b_in[1] * scale, b_in[2] * scale],
heavy,
)
} }
/// Aggregate "propagator" type — currently a stateless wrapper over /// Aggregate "propagator" type — currently a stateless wrapper over
@@ -175,8 +172,8 @@ mod tests {
}]; }];
let (b_out, heavy) = attenuate(b_in, &segs); let (b_out, heavy) = attenuate(b_in, &segs);
let expected = 10.0_f64.powf(-4.0 / 20.0); let expected = 10.0_f64.powf(-4.0 / 20.0);
for k in 0..3 { for &val in &b_out {
assert_relative_eq!(b_out[k], expected, max_relative = 1e-12); assert_relative_eq!(val, expected, max_relative = 1e-12);
} }
assert!(heavy, "reinforced concrete must raise heavy_flag"); assert!(heavy, "reinforced concrete must raise heavy_flag");
} }
+17 -20
View File
@@ -63,12 +63,7 @@ pub const DEFAULT_N_SPINS: f64 = 1.0e12;
/// Tetrahedral 〈111〉 family in the diamond lattice. /// Tetrahedral 〈111〉 family in the diamond lattice.
pub fn nv_axes() -> [[f64; 3]; 4] { pub fn nv_axes() -> [[f64; 3]; 4] {
let s = 1.0 / 3.0_f64.sqrt(); let s = 1.0 / 3.0_f64.sqrt();
[ [[s, s, s], [s, -s, -s], [-s, s, -s], [-s, -s, s]]
[s, s, s],
[s, -s, -s],
[-s, s, -s],
[-s, -s, s],
]
} }
/// Sensor configuration. All defaults match plan §2.3 / Barry 2020 Table III /// Sensor configuration. All defaults match plan §2.3 / Barry 2020 Table III
@@ -163,8 +158,9 @@ impl NvSensor {
/// per-sample noise σ in T. /// per-sample noise σ in T.
pub fn shot_noise_floor_t_sqrt_hz(&self, integration_s: f64) -> f64 { pub fn shot_noise_floor_t_sqrt_hz(&self, integration_s: f64) -> f64 {
let t = integration_s.max(self.config.t2_star_s); let t = integration_s.max(self.config.t2_star_s);
let denom = let denom = GAMMA_E
GAMMA_E * self.config.contrast * (self.config.n_spins * t * self.config.t2_star_s).sqrt(); * self.config.contrast
* (self.config.n_spins * t * self.config.t2_star_s).sqrt();
if denom <= 0.0 { if denom <= 0.0 {
f64::INFINITY f64::INFINITY
} else { } else {
@@ -316,13 +312,10 @@ mod tests {
]; ];
for &b_in in &inputs { for &b_in in &inputs {
let r = s.sample(b_in, 1.0e-3, 0xCAFE_BABE); let r = s.sample(b_in, 1.0e-3, 0xCAFE_BABE);
for k in 0..3 { for (k, (&b_recovered, &b_orig)) in r.b_recovered.iter().zip(b_in.iter()).enumerate() {
let denom = b_in[k].abs().max(1e-30); let denom = b_orig.abs().max(1e-30);
let rel = (r.b_recovered[k] - b_in[k]).abs() / denom; let rel = (b_recovered - b_orig).abs() / denom;
assert!( assert!(rel < 0.01, "LSQ residual {rel:.4} exceeds 1% for axis {k}");
rel < 0.01,
"LSQ residual {rel:.4} exceeds 1% for axis {k}"
);
} }
} }
} }
@@ -338,19 +331,19 @@ mod tests {
let mut sum = [0.0_f64; 3]; let mut sum = [0.0_f64; 3];
for i in 0..n { for i in 0..n {
let r = s.sample([0.0; 3], dt, 0xDEAD_BEEF + i as u64); let r = s.sample([0.0; 3], dt, 0xDEAD_BEEF + i as u64);
for k in 0..3 { for (s, &b) in sum.iter_mut().zip(r.b_recovered.iter()) {
sum[k] += r.b_recovered[k]; *s += b;
} }
} }
let mean = [sum[0] / n as f64, sum[1] / n as f64, sum[2] / n as f64]; let mean = [sum[0] / n as f64, sum[1] / n as f64, sum[2] / n as f64];
// Stat margin: σ_mean = σ / √n. Allow ≤ 1σ_mean (loose). // Stat margin: σ_mean = σ / √n. Allow ≤ 1σ_mean (loose).
let r = s.sample([0.0; 3], dt, 0); let r = s.sample([0.0; 3], dt, 0);
let sigma_mean = r.sigma_per_axis[0] / (n as f64).sqrt(); let sigma_mean = r.sigma_per_axis[0] / (n as f64).sqrt();
for k in 0..3 { for (k, &m) in mean.iter().enumerate() {
assert!( assert!(
mean[k].abs() <= sigma_mean, m.abs() <= sigma_mean,
"axis {k} zero-input mean {} exceeds σ_mean {}", "axis {k} zero-input mean {} exceeds σ_mean {}",
mean[k], m,
sigma_mean sigma_mean
); );
} }
@@ -392,6 +385,9 @@ mod tests {
// form depends on this. Verify the matrix. // form depends on this. Verify the matrix.
let axes = nv_axes(); let axes = nv_axes();
let mut ata = [[0.0_f64; 3]; 3]; let mut ata = [[0.0_f64; 3]; 3];
// Compute AᵀA using explicit 2D indexing — clippy::needless_range_loop
// cannot be avoided here without losing clarity in this matrix formula.
#[allow(clippy::needless_range_loop)]
for j in 0..3 { for j in 0..3 {
for k in 0..3 { for k in 0..3 {
let mut acc = 0.0; let mut acc = 0.0;
@@ -401,6 +397,7 @@ mod tests {
ata[j][k] = acc; ata[j][k] = acc;
} }
} }
#[allow(clippy::needless_range_loop)]
for j in 0..3 { for j in 0..3 {
for k in 0..3 { for k in 0..3 {
let expected = if j == k { 4.0 / 3.0 } else { 0.0 }; let expected = if j == k { 4.0 / 3.0 } else { 0.0 };
+5 -1
View File
@@ -132,7 +132,11 @@ pub fn scene_field_at(scene: &Scene, sensor_pos: Vec3) -> (Vec3, bool) {
/// Total field at every sensor location in a scene, in scene order. /// Total field at every sensor location in a scene, in scene order.
pub fn scene_field_at_sensors(scene: &Scene) -> Vec<(Vec3, bool)> { pub fn scene_field_at_sensors(scene: &Scene) -> Vec<(Vec3, bool)> {
scene.sensors.iter().map(|&p| scene_field_at(scene, p)).collect() scene
.sensors
.iter()
.map(|&p| scene_field_at(scene, p))
.collect()
} }
// ────────────────────── vec3 helpers ───────────────────────────────────── // ────────────────────── vec3 helpers ─────────────────────────────────────
+14 -6
View File
@@ -46,8 +46,8 @@ impl WasmPipeline {
pub fn new(scene_json: &str, config_json: &str, seed: f64) -> Result<WasmPipeline, JsValue> { pub fn new(scene_json: &str, config_json: &str, seed: f64) -> Result<WasmPipeline, JsValue> {
let scene: Scene = let scene: Scene =
serde_json::from_str(scene_json).map_err(|e| js_err(format!("scene parse: {e}")))?; serde_json::from_str(scene_json).map_err(|e| js_err(format!("scene parse: {e}")))?;
let config: PipelineConfig = serde_json::from_str(config_json) let config: PipelineConfig =
.map_err(|e| js_err(format!("config parse: {e}")))?; serde_json::from_str(config_json).map_err(|e| js_err(format!("config parse: {e}")))?;
let seed_u64 = seed as u64; let seed_u64 = seed as u64;
Ok(WasmPipeline { Ok(WasmPipeline {
inner: Pipeline::new(scene, config, seed_u64), inner: Pipeline::new(scene, config, seed_u64),
@@ -184,8 +184,8 @@ pub fn run_transient(
) -> Result<JsValue, JsValue> { ) -> Result<JsValue, JsValue> {
let scene: crate::scene::Scene = let scene: crate::scene::Scene =
serde_json::from_str(scene_json).map_err(|e| js_err(format!("scene parse: {e}")))?; serde_json::from_str(scene_json).map_err(|e| js_err(format!("scene parse: {e}")))?;
let config: crate::pipeline::PipelineConfig = serde_json::from_str(config_json) let config: crate::pipeline::PipelineConfig =
.map_err(|e| js_err(format!("config parse: {e}")))?; serde_json::from_str(config_json).map_err(|e| js_err(format!("config parse: {e}")))?;
let pipeline = crate::pipeline::Pipeline::new(scene, config, seed as u64); let pipeline = crate::pipeline::Pipeline::new(scene, config, seed as u64);
let (frames, witness) = pipeline.run_with_witness(n_samples); let (frames, witness) = pipeline.run_with_witness(n_samples);
@@ -217,7 +217,11 @@ pub fn run_transient(
let s_arr = js_sys::Float64Array::new_with_length(3); let s_arr = js_sys::Float64Array::new_with_length(3);
s_arr.copy_from(&avg_s_pt); s_arr.copy_from(&avg_s_pt);
js_sys::Reflect::set(&obj, &JsValue::from_str("bRecoveredT"), &b_arr)?; js_sys::Reflect::set(&obj, &JsValue::from_str("bRecoveredT"), &b_arr)?;
js_sys::Reflect::set(&obj, &JsValue::from_str("bMagT"), &JsValue::from_f64(bmag_t))?; js_sys::Reflect::set(
&obj,
&JsValue::from_str("bMagT"),
&JsValue::from_f64(bmag_t),
)?;
js_sys::Reflect::set( js_sys::Reflect::set(
&obj, &obj,
&JsValue::from_str("noiseFloorPtSqrtHz"), &JsValue::from_str("noiseFloorPtSqrtHz"),
@@ -230,6 +234,10 @@ pub fn run_transient(
&JsValue::from_f64(frames.len() as f64), &JsValue::from_f64(frames.len() as f64),
)?; )?;
let witness_hex = crate::proof::Proof::hex(&witness); let witness_hex = crate::proof::Proof::hex(&witness);
js_sys::Reflect::set(&obj, &JsValue::from_str("witnessHex"), &JsValue::from_str(&witness_hex))?; js_sys::Reflect::set(
&obj,
&JsValue::from_str("witnessHex"),
&JsValue::from_str(&witness_hex),
)?;
Ok(obj.into()) Ok(obj.into())
} }
+5 -1
View File
@@ -31,7 +31,11 @@ pub mod mat;
/// WiFi-DensePose Command Line Interface /// WiFi-DensePose Command Line Interface
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(name = "wifi-densepose")] #[command(name = "wifi-densepose")]
#[command(author, version, about = "WiFi-based pose estimation and disaster response")] #[command(
author,
version,
about = "WiFi-based pose estimation and disaster response"
)]
#[command(propagate_version = true)] #[command(propagate_version = true)]
pub struct Cli { pub struct Cli {
/// Command to execute /// Command to execute
+27 -59
View File
@@ -16,8 +16,8 @@ use std::path::PathBuf;
use tabled::{settings::Style, Table, Tabled}; use tabled::{settings::Style, Table, Tabled};
use wifi_densepose_mat::{ use wifi_densepose_mat::{
DisasterConfig, DisasterType, Priority, ScanZone, TriageStatus, ZoneBounds, domain::alert::AlertStatus, DisasterConfig, DisasterType, Priority, ScanZone, TriageStatus,
ZoneStatus, domain::alert::AlertStatus, ZoneBounds, ZoneStatus,
}; };
/// MAT subcommand /// MAT subcommand
@@ -452,40 +452,21 @@ pub async fn execute(command: MatCommand) -> Result<()> {
/// Execute the scan command /// Execute the scan command
async fn execute_scan(args: ScanArgs) -> Result<()> { async fn execute_scan(args: ScanArgs) -> Result<()> {
println!( println!("{} Starting survivor scan...", "[MAT]".bright_cyan().bold());
"{} Starting survivor scan...",
"[MAT]".bright_cyan().bold()
);
println!(); println!();
// Display configuration // Display configuration
println!("{}", "Configuration:".bold()); println!("{}", "Configuration:".bold());
println!( println!(" {} {:?}", "Disaster Type:".dimmed(), args.disaster_type);
" {} {:?}", println!(" {} {:.1}", "Sensitivity:".dimmed(), args.sensitivity);
"Disaster Type:".dimmed(), println!(" {} {:.1}m", "Max Depth:".dimmed(), args.max_depth);
args.disaster_type
);
println!(
" {} {:.1}",
"Sensitivity:".dimmed(),
args.sensitivity
);
println!(
" {} {:.1}m",
"Max Depth:".dimmed(),
args.max_depth
);
println!( println!(
" {} {}", " {} {}",
"Continuous:".dimmed(), "Continuous:".dimmed(),
if args.continuous { "Yes" } else { "No" } if args.continuous { "Yes" } else { "No" }
); );
if args.continuous { if args.continuous {
println!( println!(" {} {}ms", "Interval:".dimmed(), args.interval);
" {} {}ms",
"Interval:".dimmed(),
args.interval
);
} }
if let Some(ref zone) = args.zone { if let Some(ref zone) = args.zone {
println!(" {} {}", "Zone:".dimmed(), zone); println!(" {} {}", "Zone:".dimmed(), zone);
@@ -516,10 +497,7 @@ async fn execute_scan(args: ScanArgs) -> Result<()> {
"[INFO]".blue(), "[INFO]".blue(),
config.disaster_type config.disaster_type
); );
println!( println!("{} Waiting for hardware connection...", "[INFO]".blue());
"{} Waiting for hardware connection...",
"[INFO]".blue()
);
println!(); println!();
println!( println!(
"{} No hardware detected. Use --simulate for demo mode.", "{} No hardware detected. Use --simulate for demo mode.",
@@ -538,7 +516,9 @@ async fn simulate_scan_output() -> Result<()> {
let pb = ProgressBar::new(100); let pb = ProgressBar::new(100);
pb.set_style( pb.set_style(
ProgressStyle::default_bar() ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")? .template(
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
)?
.progress_chars("#>-"), .progress_chars("#>-"),
); );
@@ -591,13 +571,10 @@ async fn simulate_scan_output() -> Result<()> {
"3".green().bold() "3".green().bold()
); );
println!( println!(
" {} {} {} {} {} {}", " {} 1 {} 1 {} 1",
"IMMEDIATE:".red().bold(), "IMMEDIATE:".red().bold(),
"1",
"DELAYED:".yellow().bold(), "DELAYED:".yellow().bold(),
"1",
"MINOR:".green().bold(), "MINOR:".green().bold(),
"1"
); );
Ok(()) Ok(())
@@ -674,11 +651,7 @@ async fn execute_status(args: StatusArgs) -> Result<()> {
status.active_zones, status.active_zones,
status.total_zones status.total_zones
); );
println!( println!(" {} {}", "Disaster Type:".dimmed(), status.disaster_type);
" {} {}",
"Disaster Type:".dimmed(),
status.disaster_type
);
println!( println!(
" {} {}", " {} {}",
"Survivors Detected:".dimmed(), "Survivors Detected:".dimmed(),
@@ -774,8 +747,10 @@ async fn execute_zones(args: ZonesArgs) -> Result<()> {
match bounds_parsed { match bounds_parsed {
Ok(zone_bounds) => { Ok(zone_bounds) => {
let zone = if let Some(sens) = sensitivity { let zone = if let Some(sens) = sensitivity {
let mut params = wifi_densepose_mat::ScanParameters::default(); let params = wifi_densepose_mat::ScanParameters {
params.sensitivity = sens; sensitivity: sens,
..Default::default()
};
ScanZone::with_parameters(&name, zone_bounds, params) ScanZone::with_parameters(&name, zone_bounds, params)
} else { } else {
ScanZone::new(&name, zone_bounds) ScanZone::new(&name, zone_bounds)
@@ -806,26 +781,14 @@ async fn execute_zones(args: ZonesArgs) -> Result<()> {
); );
println!("Use --force to confirm."); println!("Use --force to confirm.");
} else { } else {
println!( println!("{} Zone '{}' removed.", "[OK]".green().bold(), zone.cyan());
"{} Zone '{}' removed.",
"[OK]".green().bold(),
zone.cyan()
);
} }
} }
ZonesCommand::Pause { zone } => { ZonesCommand::Pause { zone } => {
println!( println!("{} Zone '{}' paused.", "[OK]".green().bold(), zone.cyan());
"{} Zone '{}' paused.",
"[OK]".green().bold(),
zone.cyan()
);
} }
ZonesCommand::Resume { zone } => { ZonesCommand::Resume { zone } => {
println!( println!("{} Zone '{}' resumed.", "[OK]".green().bold(), zone.cyan());
"{} Zone '{}' resumed.",
"[OK]".green().bold(),
zone.cyan()
);
} }
} }
@@ -848,7 +811,9 @@ fn parse_bounds(zone_type: &ZoneType, bounds: &str) -> Result<ZoneBounds> {
parts.len() parts.len()
); );
} }
Ok(ZoneBounds::rectangle(parts[0], parts[1], parts[2], parts[3])) Ok(ZoneBounds::rectangle(
parts[0], parts[1], parts[2], parts[3],
))
} }
ZoneType::Circle => { ZoneType::Circle => {
if parts.len() != 3 { if parts.len() != 3 {
@@ -1036,7 +1001,10 @@ async fn execute_alerts(args: AlertsArgs) -> Result<()> {
if filtered.is_empty() { if filtered.is_empty() {
println!("No alerts."); println!("No alerts.");
} else { } else {
let pending = filtered.iter().filter(|a| a.status.contains("Pending")).count(); let pending = filtered
.iter()
.filter(|a| a.status.contains("Pending"))
.count();
if pending > 0 { if pending > 0 {
println!( println!(
"{} {} pending alert(s) require attention!", "{} {} pending alert(s) require attention!",
+28 -14
View File
@@ -52,19 +52,29 @@ pub mod types;
pub mod utils; pub mod utils;
// Re-export commonly used types at the crate root // Re-export commonly used types at the crate root
pub use error::{CoreError, CoreResult, SignalError, InferenceError, StorageError}; pub use error::{CoreError, CoreResult, InferenceError, SignalError, StorageError};
pub use traits::{SignalProcessor, NeuralInference, DataStore}; pub use traits::{DataStore, NeuralInference, SignalProcessor};
pub use types::{ pub use types::{
// CSI types AntennaConfig,
CsiFrame, CsiMetadata, AntennaConfig,
// Signal types
ProcessedSignal, SignalFeatures, FrequencyBand,
// Pose types
PoseEstimate, PersonPose, Keypoint, KeypointType,
// Common types
Confidence, Timestamp, FrameId, DeviceId,
// Bounding box // Bounding box
BoundingBox, BoundingBox,
// Common types
Confidence,
// CSI types
CsiFrame,
CsiMetadata,
DeviceId,
FrameId,
FrequencyBand,
Keypoint,
KeypointType,
PersonPose,
// Pose types
PoseEstimate,
// Signal types
ProcessedSignal,
SignalFeatures,
Timestamp,
}; };
/// Crate version /// Crate version
@@ -97,20 +107,24 @@ pub mod prelude {
}; };
} }
// Compile-time assertions on module-level constants.
const _: () = assert!(MAX_SUBCARRIERS > 0);
const _: () = assert!(DEFAULT_CONFIDENCE_THRESHOLD > 0.0);
const _: () = assert!(DEFAULT_CONFIDENCE_THRESHOLD < 1.0);
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_version_is_valid() { fn test_version_is_valid() {
assert!(!VERSION.is_empty()); // CARGO_PKG_VERSION is always non-empty; verify the constant is
// accessible and has a dot-separated semver shape.
assert!(VERSION.contains('.'), "version should be semver: {VERSION}");
} }
#[test] #[test]
fn test_constants() { fn test_constants() {
assert_eq!(MAX_KEYPOINTS, 17); assert_eq!(MAX_KEYPOINTS, 17);
assert!(MAX_SUBCARRIERS > 0);
assert!(DEFAULT_CONFIDENCE_THRESHOLD > 0.0);
assert!(DEFAULT_CONFIDENCE_THRESHOLD < 1.0);
} }
} }
+6 -2
View File
@@ -506,7 +506,8 @@ pub trait AsyncDataStore: Send + Sync {
async fn get_csi_frame(&self, id: &FrameId) -> Result<CsiFrame, StorageError>; async fn get_csi_frame(&self, id: &FrameId) -> Result<CsiFrame, StorageError>;
/// Retrieves CSI frames matching the query options. /// Retrieves CSI frames matching the query options.
async fn query_csi_frames(&self, options: &QueryOptions) -> Result<Vec<CsiFrame>, StorageError>; async fn query_csi_frames(&self, options: &QueryOptions)
-> Result<Vec<CsiFrame>, StorageError>;
/// Stores a pose estimate. /// Stores a pose estimate.
async fn store_pose_estimate(&self, estimate: &PoseEstimate) -> Result<(), StorageError>; async fn store_pose_estimate(&self, estimate: &PoseEstimate) -> Result<(), StorageError>;
@@ -621,6 +622,9 @@ mod tests {
assert_eq!(cpu, InferenceDevice::Cpu); assert_eq!(cpu, InferenceDevice::Cpu);
assert!(matches!(cuda, InferenceDevice::Cuda { device_id: 0 })); assert!(matches!(cuda, InferenceDevice::Cuda { device_id: 0 }));
assert!(matches!(tensorrt, InferenceDevice::TensorRt { device_id: 1 })); assert!(matches!(
tensorrt,
InferenceDevice::TensorRt { device_id: 1 }
));
} }
} }
+14 -10
View File
@@ -806,7 +806,10 @@ impl BoundingBox {
/// Returns the center point of the bounding box. /// Returns the center point of the bounding box.
#[must_use] #[must_use]
pub fn center(&self) -> (f32, f32) { pub fn center(&self) -> (f32, f32) {
((self.x_min + self.x_max) / 2.0, (self.y_min + self.y_max) / 2.0) (
(self.x_min + self.x_max) / 2.0,
(self.y_min + self.y_max) / 2.0,
)
} }
/// Computes the Intersection over Union (IoU) with another bounding box. /// Computes the Intersection over Union (IoU) with another bounding box.
@@ -997,14 +1000,12 @@ impl PoseEstimate {
/// Returns the person with the highest confidence. /// Returns the person with the highest confidence.
#[must_use] #[must_use]
pub fn highest_confidence_person(&self) -> Option<&PersonPose> { pub fn highest_confidence_person(&self) -> Option<&PersonPose> {
self.persons self.persons.iter().max_by(|a, b| {
.iter() a.confidence
.max_by(|a, b| { .value()
a.confidence .partial_cmp(&b.confidence.value())
.value() .unwrap_or(std::cmp::Ordering::Equal)
.partial_cmp(&b.confidence.value()) })
.unwrap_or(std::cmp::Ordering::Equal)
})
} }
} }
@@ -1082,7 +1083,10 @@ mod tests {
#[test] #[test]
fn test_keypoint_type_conversion() { fn test_keypoint_type_conversion() {
assert_eq!(KeypointType::try_from(0).unwrap(), KeypointType::Nose); assert_eq!(KeypointType::try_from(0).unwrap(), KeypointType::Nose);
assert_eq!(KeypointType::try_from(16).unwrap(), KeypointType::RightAnkle); assert_eq!(
KeypointType::try_from(16).unwrap(),
KeypointType::RightAnkle
);
assert!(KeypointType::try_from(17).is_err()); assert!(KeypointType::try_from(17).is_err());
} }
+2 -3
View File
@@ -99,9 +99,8 @@ pub fn moving_average(data: &Array1<f64>, window_size: usize) -> Array1<f64> {
let half_window = window_size / 2; let half_window = window_size / 2;
// ndarray Array1 is always contiguous, but handle gracefully if not // ndarray Array1 is always contiguous, but handle gracefully if not
let slice = match data.as_slice() { let Some(slice) = data.as_slice() else {
Some(s) => s, return data.clone();
None => return data.clone(),
}; };
for i in 0..data.len() { for i in 0..data.len() {
File diff suppressed because one or more lines are too long
@@ -2355,22 +2355,22 @@
"markdownDescription": "Denies the unminimize command without any pre-configured scope." "markdownDescription": "Denies the unminimize command without any pre-configured scope."
}, },
{ {
"description": "This permission set configures the types of dialogs\navailable from the dialog plugin.\n\n#### Granted Permissions\n\nAll dialog types are enabled.\n\n\n\n#### This default permission set includes:\n\n- `allow-ask`\n- `allow-confirm`\n- `allow-message`\n- `allow-save`\n- `allow-open`", "description": "This permission set configures the types of dialogs\navailable from the dialog plugin.\n\n#### Granted Permissions\n\nAll dialog types are enabled.\n\n\n\n#### This default permission set includes:\n\n- `allow-message`\n- `allow-save`\n- `allow-open`",
"type": "string", "type": "string",
"const": "dialog:default", "const": "dialog:default",
"markdownDescription": "This permission set configures the types of dialogs\navailable from the dialog plugin.\n\n#### Granted Permissions\n\nAll dialog types are enabled.\n\n\n\n#### This default permission set includes:\n\n- `allow-ask`\n- `allow-confirm`\n- `allow-message`\n- `allow-save`\n- `allow-open`" "markdownDescription": "This permission set configures the types of dialogs\navailable from the dialog plugin.\n\n#### Granted Permissions\n\nAll dialog types are enabled.\n\n\n\n#### This default permission set includes:\n\n- `allow-message`\n- `allow-save`\n- `allow-open`"
}, },
{ {
"description": "Enables the ask command without any pre-configured scope.", "description": "Enables the ask command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `allow-message` and will be removed in v3)",
"type": "string", "type": "string",
"const": "dialog:allow-ask", "const": "dialog:allow-ask",
"markdownDescription": "Enables the ask command without any pre-configured scope." "markdownDescription": "Enables the ask command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `allow-message` and will be removed in v3)"
}, },
{ {
"description": "Enables the confirm command without any pre-configured scope.", "description": "Enables the confirm command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `allow-message` and will be removed in v3)",
"type": "string", "type": "string",
"const": "dialog:allow-confirm", "const": "dialog:allow-confirm",
"markdownDescription": "Enables the confirm command without any pre-configured scope." "markdownDescription": "Enables the confirm command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `allow-message` and will be removed in v3)"
}, },
{ {
"description": "Enables the message command without any pre-configured scope.", "description": "Enables the message command without any pre-configured scope.",
@@ -2391,16 +2391,16 @@
"markdownDescription": "Enables the save command without any pre-configured scope." "markdownDescription": "Enables the save command without any pre-configured scope."
}, },
{ {
"description": "Denies the ask command without any pre-configured scope.", "description": "Denies the ask command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `deny-message` and will be removed in v3)",
"type": "string", "type": "string",
"const": "dialog:deny-ask", "const": "dialog:deny-ask",
"markdownDescription": "Denies the ask command without any pre-configured scope." "markdownDescription": "Denies the ask command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `deny-message` and will be removed in v3)"
}, },
{ {
"description": "Denies the confirm command without any pre-configured scope.", "description": "Denies the confirm command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `deny-message` and will be removed in v3)",
"type": "string", "type": "string",
"const": "dialog:deny-confirm", "const": "dialog:deny-confirm",
"markdownDescription": "Denies the confirm command without any pre-configured scope." "markdownDescription": "Denies the confirm command without any pre-configured scope. (**DEPRECATED**: This is now an alias to `deny-message` and will be removed in v3)"
}, },
{ {
"description": "Denies the message command without any pre-configured scope.", "description": "Denies the message command without any pre-configured scope.",
File diff suppressed because it is too large Load Diff
@@ -1,16 +1,16 @@
use std::net::{SocketAddr, UdpSocket}; use std::net::{SocketAddr, UdpSocket};
use std::time::Duration; use std::time::Duration;
use flume::RecvTimeoutError;
use mdns_sd::{ServiceDaemon, ServiceEvent}; use mdns_sd::{ServiceDaemon, ServiceEvent};
use serde::Serialize; use serde::Serialize;
use tauri::State; use tauri::State;
use tokio::time::timeout; use tokio::time::timeout;
use tokio_serial::available_ports; use tokio_serial::available_ports;
use flume::RecvTimeoutError;
use crate::domain::node::{ use crate::domain::node::{
Chip, DiscoveredNode, DiscoveryMethod, HealthStatus, MacAddress, MeshRole, Chip, DiscoveredNode, DiscoveryMethod, HealthStatus, MacAddress, MeshRole, NodeCapabilities,
NodeCapabilities, NodeRegistry, NodeRegistry,
}; };
use crate::state::AppState; use crate::state::AppState;
@@ -110,14 +110,16 @@ async fn discover_via_mdns(timeout_duration: Duration) -> Result<Vec<DiscoveredN
_ => MeshRole::Node, _ => MeshRole::Node,
}; };
let node = DiscoveredNode { let node = DiscoveredNode {
ip: info.get_addresses() ip: info
.get_addresses()
.iter() .iter()
.next() .next()
.map(|a| a.to_string()) .map(|a| a.to_string())
.unwrap_or_default(), .unwrap_or_default(),
mac: props.get("mac").map(|v| v.val_str().to_string()), mac: props.get("mac").map(|v| v.val_str().to_string()),
hostname: Some(info.get_hostname().to_string()), hostname: Some(info.get_hostname().to_string()),
node_id: props.get("node_id") node_id: props
.get("node_id")
.and_then(|v| v.val_str().parse().ok()) .and_then(|v| v.val_str().parse().ok())
.unwrap_or(0), .unwrap_or(0),
firmware_version: props.get("version").map(|v| v.val_str().to_string()), firmware_version: props.get("version").map(|v| v.val_str().to_string()),
@@ -127,11 +129,18 @@ async fn discover_via_mdns(timeout_duration: Duration) -> Result<Vec<DiscoveredN
mesh_role, mesh_role,
discovery_method: DiscoveryMethod::Mdns, discovery_method: DiscoveryMethod::Mdns,
tdm_slot: props.get("tdm_slot").and_then(|v| v.val_str().parse().ok()), tdm_slot: props.get("tdm_slot").and_then(|v| v.val_str().parse().ok()),
tdm_total: props.get("tdm_total").and_then(|v| v.val_str().parse().ok()), tdm_total: props
edge_tier: props.get("edge_tier").and_then(|v| v.val_str().parse().ok()), .get("tdm_total")
.and_then(|v| v.val_str().parse().ok()),
edge_tier: props
.get("edge_tier")
.and_then(|v| v.val_str().parse().ok()),
uptime_secs: props.get("uptime").and_then(|v| v.val_str().parse().ok()), uptime_secs: props.get("uptime").and_then(|v| v.val_str().parse().ok()),
capabilities: Some(NodeCapabilities { capabilities: Some(NodeCapabilities {
wasm: props.get("wasm").map(|v| v.val_str() == "1").unwrap_or(false), wasm: props
.get("wasm")
.map(|v| v.val_str() == "1")
.unwrap_or(false),
ota: props.get("ota").map(|v| v.val_str() == "1").unwrap_or(true), ota: props.get("ota").map(|v| v.val_str() == "1").unwrap_or(true),
csi: props.get("csi").map(|v| v.val_str() == "1").unwrap_or(true), csi: props.get("csi").map(|v| v.val_str() == "1").unwrap_or(true),
}), }),
@@ -153,7 +162,12 @@ async fn discover_via_mdns(timeout_duration: Duration) -> Result<Vec<DiscoveredN
discovered discovered
}); });
match timeout(timeout_duration + Duration::from_millis(500), discovery_task).await { match timeout(
timeout_duration + Duration::from_millis(500),
discovery_task,
)
.await
{
Ok(Ok(nodes)) => Ok(nodes), Ok(Ok(nodes)) => Ok(nodes),
Ok(Err(e)) => Err(format!("mDNS discovery task failed: {}", e)), Ok(Err(e)) => Err(format!("mDNS discovery task failed: {}", e)),
Err(_) => Ok(Vec::new()), // Timeout, return empty Err(_) => Ok(Vec::new()), // Timeout, return empty
@@ -210,7 +224,12 @@ async fn discover_via_udp(timeout_duration: Duration) -> Result<Vec<DiscoveredNo
discovered discovered
}); });
match timeout(timeout_duration + Duration::from_millis(500), discovery_task).await { match timeout(
timeout_duration + Duration::from_millis(500),
discovery_task,
)
.await
{
Ok(Ok(nodes)) => Ok(nodes), Ok(Ok(nodes)) => Ok(nodes),
Ok(Err(e)) => Err(format!("UDP discovery task failed: {}", e)), Ok(Err(e)) => Err(format!("UDP discovery task failed: {}", e)),
Err(_) => Ok(Vec::new()), Err(_) => Ok(Vec::new()),
@@ -295,16 +314,14 @@ pub async fn list_serial_ports() -> Result<Vec<SerialPortInfo>, String> {
for port in ports { for port in ports {
tracing::debug!("Processing port: {}", port.port_name); tracing::debug!("Processing port: {}", port.port_name);
let info = match port.port_type { let info = match port.port_type {
tokio_serial::SerialPortType::UsbPort(usb_info) => { tokio_serial::SerialPortType::UsbPort(usb_info) => SerialPortInfo {
SerialPortInfo { name: port.port_name,
name: port.port_name, vid: Some(usb_info.vid),
vid: Some(usb_info.vid), pid: Some(usb_info.pid),
pid: Some(usb_info.pid), manufacturer: usb_info.manufacturer,
manufacturer: usb_info.manufacturer, serial_number: usb_info.serial_number,
serial_number: usb_info.serial_number, is_esp32_compatible: is_esp32_compatible(usb_info.vid, usb_info.pid),
is_esp32_compatible: is_esp32_compatible(usb_info.vid, usb_info.pid), },
}
}
_ => { _ => {
SerialPortInfo { SerialPortInfo {
name: port.port_name.clone(), name: port.port_name.clone(),
@@ -401,7 +418,9 @@ fn is_esp32_compatible(vid: u16, pid: u16) -> bool {
return true; return true;
} }
// FTDI // FTDI
if vid == 0x0403 && (pid == 0x6001 || pid == 0x6010 || pid == 0x6011 || pid == 0x6014 || pid == 0x6015) { if vid == 0x0403
&& (pid == 0x6001 || pid == 0x6010 || pid == 0x6011 || pid == 0x6014 || pid == 0x6015)
{
return true; return true;
} }
// ESP32-S2/S3 native USB // ESP32-S2/S3 native USB
@@ -450,9 +469,12 @@ pub async fn configure_esp32_wifi(
let _ = serial.read(&mut buf); let _ = serial.read(&mut buf);
// Send command // Send command
serial.write_all(cmd.as_bytes()) serial
.write_all(cmd.as_bytes())
.map_err(|e| format!("Failed to write: {}", e))?; .map_err(|e| format!("Failed to write: {}", e))?;
serial.flush().map_err(|e| format!("Failed to flush: {}", e))?; serial
.flush()
.map_err(|e| format!("Failed to flush: {}", e))?;
// Wait and read response // Wait and read response
std::thread::sleep(Duration::from_millis(500)); std::thread::sleep(Duration::from_millis(500));
@@ -465,7 +487,8 @@ pub async fn configure_esp32_wifi(
// Check for success indicators // Check for success indicators
if text.to_lowercase().contains("ok") if text.to_lowercase().contains("ok")
|| text.to_lowercase().contains("saved") || text.to_lowercase().contains("saved")
|| text.to_lowercase().contains("configured") { || text.to_lowercase().contains("configured")
{
tracing::info!("WiFi config successful: {}", text.trim()); tracing::info!("WiFi config successful: {}", text.trim());
return Ok(format!("WiFi configured! Response: {}", text.trim())); return Ok(format!("WiFi configured! Response: {}", text.trim()));
} }
@@ -37,13 +37,16 @@ pub async fn flash_firmware(
let firmware_hash = calculate_sha256(&firmware_path)?; let firmware_hash = calculate_sha256(&firmware_path)?;
// Emit flash started event // Emit flash started event
let _ = app.emit("flash-progress", FlashProgress { let _ = app.emit(
phase: "connecting".into(), "flash-progress",
progress_pct: 0.0, FlashProgress {
bytes_written: 0, phase: "connecting".into(),
bytes_total: firmware_size, progress_pct: 0.0,
message: Some(format!("Connecting to {} ...", port)), bytes_written: 0,
}); bytes_total: firmware_size,
message: Some(format!("Connecting to {} ...", port)),
},
);
// Build espflash command // Build espflash command
let baud_rate = baud.unwrap_or(921600); let baud_rate = baud.unwrap_or(921600);
@@ -67,13 +70,12 @@ pub async fn flash_firmware(
cmd.stderr(Stdio::piped()); cmd.stderr(Stdio::piped());
// Spawn the process // Spawn the process
let mut child = cmd.spawn() let mut child = cmd
.spawn()
.map_err(|e| format!("Failed to start espflash: {}. Is espflash installed?", e))?; .map_err(|e| format!("Failed to start espflash: {}. Is espflash installed?", e))?;
let _stdout = child.stdout.take() let _stdout = child.stdout.take().ok_or("Failed to capture stdout")?;
.ok_or("Failed to capture stdout")?; let stderr = child.stderr.take().ok_or("Failed to capture stderr")?;
let stderr = child.stderr.take()
.ok_or("Failed to capture stderr")?;
// Read and parse progress from stderr (espflash outputs there) // Read and parse progress from stderr (espflash outputs there)
let app_clone = app.clone(); let app_clone = app.clone();
@@ -84,8 +86,8 @@ pub async fn flash_firmware(
let mut last_phase = "connecting".to_string(); let mut last_phase = "connecting".to_string();
let mut last_progress = 0.0f32; let mut last_progress = 0.0f32;
for line in reader.lines() { for line in reader.lines().map_while(Result::ok) {
if let Ok(line) = line { {
// Parse espflash progress output // Parse espflash progress output
if line.contains("Connecting") { if line.contains("Connecting") {
last_phase = "connecting".to_string(); last_phase = "connecting".to_string();
@@ -104,19 +106,24 @@ pub async fn flash_firmware(
last_progress = 95.0; last_progress = 95.0;
} }
let _ = app_clone.emit("flash-progress", FlashProgress { let _ = app_clone.emit(
phase: last_phase.clone(), "flash-progress",
progress_pct: last_progress, FlashProgress {
bytes_written: ((last_progress / 100.0) * firmware_size_clone as f32) as u64, phase: last_phase.clone(),
bytes_total: firmware_size_clone, progress_pct: last_progress,
message: Some(line), bytes_written: ((last_progress / 100.0) * firmware_size_clone as f32)
}); as u64,
bytes_total: firmware_size_clone,
message: Some(line),
},
);
} }
} }
}); });
// Wait for completion // Wait for completion
let status = child.wait() let status = child
.wait()
.map_err(|e| format!("Failed to wait for espflash: {}", e))?; .map_err(|e| format!("Failed to wait for espflash: {}", e))?;
// Wait for progress parsing to complete // Wait for progress parsing to complete
@@ -126,13 +133,16 @@ pub async fn flash_firmware(
if status.success() { if status.success() {
// Emit completion // Emit completion
let _ = app.emit("flash-progress", FlashProgress { let _ = app.emit(
phase: "completed".into(), "flash-progress",
progress_pct: 100.0, FlashProgress {
bytes_written: firmware_size, phase: "completed".into(),
bytes_total: firmware_size, progress_pct: 100.0,
message: Some("Flash completed successfully!".into()), bytes_written: firmware_size,
}); bytes_total: firmware_size,
message: Some("Flash completed successfully!".into()),
},
);
Ok(FlashResult { Ok(FlashResult {
success: true, success: true,
@@ -141,13 +151,16 @@ pub async fn flash_firmware(
firmware_hash: Some(firmware_hash), firmware_hash: Some(firmware_hash),
}) })
} else { } else {
let _ = app.emit("flash-progress", FlashProgress { let _ = app.emit(
phase: "failed".into(), "flash-progress",
progress_pct: 0.0, FlashProgress {
bytes_written: 0, phase: "failed".into(),
bytes_total: firmware_size, progress_pct: 0.0,
message: Some("Flash failed".into()), bytes_written: 0,
}); bytes_total: firmware_size,
message: Some("Flash failed".into()),
},
);
Err(format!("espflash exited with status: {}", status)) Err(format!("espflash exited with status: {}", status))
} }
@@ -199,9 +212,7 @@ pub async fn check_espflash() -> Result<EspflashInfo, String> {
.map_err(|_| "espflash not found. Please install: cargo install espflash")?; .map_err(|_| "espflash not found. Please install: cargo install espflash")?;
if output.status.success() { if output.status.success() {
let version = String::from_utf8_lossy(&output.stdout) let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
.trim()
.to_string();
Ok(EspflashInfo { Ok(EspflashInfo {
installed: true, installed: true,
@@ -247,8 +258,7 @@ pub async fn supported_chips() -> Result<Vec<ChipInfo>, String> {
/// Calculate SHA-256 hash of a file. /// Calculate SHA-256 hash of a file.
fn calculate_sha256(path: &str) -> Result<String, String> { fn calculate_sha256(path: &str) -> Result<String, String> {
let file = std::fs::File::open(path) let file = std::fs::File::open(path).map_err(|e| format!("Failed to open file: {}", e))?;
.map_err(|e| format!("Failed to open file: {}", e))?;
let mut reader = BufReader::new(file); let mut reader = BufReader::new(file);
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
@@ -344,13 +354,11 @@ mod tests {
#[test] #[test]
fn test_chip_info() { fn test_chip_info() {
let chips = vec![ let chips = [ChipInfo {
ChipInfo { id: "esp32".into(),
id: "esp32".into(), name: "ESP32".into(),
name: "ESP32".into(), description: "Test".into(),
description: "Test".into(), }];
},
];
assert_eq!(chips.len(), 1); assert_eq!(chips.len(), 1);
assert_eq!(chips[0].id, "esp32"); assert_eq!(chips[0].id, "esp32");
} }
@@ -37,16 +37,19 @@ pub async fn ota_update(
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
// Emit progress // Emit progress
let _ = app.emit("ota-progress", OtaProgress { let _ = app.emit(
node_ip: node_ip.clone(), "ota-progress",
phase: "preparing".into(), OtaProgress {
progress_pct: 0.0, node_ip: node_ip.clone(),
message: Some("Reading firmware...".into()), phase: "preparing".into(),
}); progress_pct: 0.0,
message: Some("Reading firmware...".into()),
},
);
// Read firmware file // Read firmware file
let mut file = File::open(&firmware_path) let mut file =
.map_err(|e| format!("Cannot read firmware: {}", e))?; File::open(&firmware_path).map_err(|e| format!("Cannot read firmware: {}", e))?;
let mut firmware_data = Vec::new(); let mut firmware_data = Vec::new();
file.read_to_end(&mut firmware_data) file.read_to_end(&mut firmware_data)
@@ -70,12 +73,18 @@ pub async fn ota_update(
}; };
// Emit progress // Emit progress
let _ = app.emit("ota-progress", OtaProgress { let _ = app.emit(
node_ip: node_ip.clone(), "ota-progress",
phase: "uploading".into(), OtaProgress {
progress_pct: 10.0, node_ip: node_ip.clone(),
message: Some(format!("Uploading {} bytes to {}...", firmware_size, node_ip)), phase: "uploading".into(),
}); progress_pct: 10.0,
message: Some(format!(
"Uploading {} bytes to {}...",
firmware_size, node_ip
)),
},
);
// Build HTTP client // Build HTTP client
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
@@ -107,30 +116,38 @@ pub async fn ota_update(
request = request.header("X-OTA-SHA256", &firmware_hash); request = request.header("X-OTA-SHA256", &firmware_hash);
// Send request // Send request
let response = request.send().await let response = request
.send()
.await
.map_err(|e| format!("OTA upload failed: {}", e))?; .map_err(|e| format!("OTA upload failed: {}", e))?;
let status = response.status(); let status = response.status();
let body = response.text().await.unwrap_or_default(); let body = response.text().await.unwrap_or_default();
if !status.is_success() { if !status.is_success() {
let _ = app.emit("ota-progress", OtaProgress { let _ = app.emit(
node_ip: node_ip.clone(), "ota-progress",
phase: "failed".into(), OtaProgress {
progress_pct: 0.0, node_ip: node_ip.clone(),
message: Some(format!("HTTP {}: {}", status, body)), phase: "failed".into(),
}); progress_pct: 0.0,
message: Some(format!("HTTP {}: {}", status, body)),
},
);
return Err(format!("OTA failed with HTTP {}: {}", status, body)); return Err(format!("OTA failed with HTTP {}: {}", status, body));
} }
// Emit progress - upload complete // Emit progress - upload complete
let _ = app.emit("ota-progress", OtaProgress { let _ = app.emit(
node_ip: node_ip.clone(), "ota-progress",
phase: "rebooting".into(), OtaProgress {
progress_pct: 80.0, node_ip: node_ip.clone(),
message: Some("Waiting for node reboot...".into()), phase: "rebooting".into(),
}); progress_pct: 80.0,
message: Some("Waiting for node reboot...".into()),
},
);
// Wait for node to come back online // Wait for node to come back online
let reboot_ok = wait_for_reboot(&client, &node_ip, Duration::from_secs(30)).await; let reboot_ok = wait_for_reboot(&client, &node_ip, Duration::from_secs(30)).await;
@@ -138,12 +155,15 @@ pub async fn ota_update(
let duration = start_time.elapsed().as_secs_f64(); let duration = start_time.elapsed().as_secs_f64();
if reboot_ok { if reboot_ok {
let _ = app.emit("ota-progress", OtaProgress { let _ = app.emit(
node_ip: node_ip.clone(), "ota-progress",
phase: "completed".into(), OtaProgress {
progress_pct: 100.0, node_ip: node_ip.clone(),
message: Some(format!("OTA completed in {:.1}s", duration)), phase: "completed".into(),
}); progress_pct: 100.0,
message: Some(format!("OTA completed in {:.1}s", duration)),
},
);
Ok(OtaResult { Ok(OtaResult {
success: true, success: true,
@@ -153,12 +173,15 @@ pub async fn ota_update(
duration_secs: Some(duration), duration_secs: Some(duration),
}) })
} else { } else {
let _ = app.emit("ota-progress", OtaProgress { let _ = app.emit(
node_ip: node_ip.clone(), "ota-progress",
phase: "warning".into(), OtaProgress {
progress_pct: 90.0, node_ip: node_ip.clone(),
message: Some("Node may not have rebooted successfully".into()), phase: "warning".into(),
}); progress_pct: 90.0,
message: Some("Node may not have rebooted successfully".into()),
},
);
Ok(OtaResult { Ok(OtaResult {
success: true, success: true,
@@ -190,13 +213,16 @@ pub async fn batch_ota_update(
let strategy = strategy.unwrap_or_else(|| "sequential".into()); let strategy = strategy.unwrap_or_else(|| "sequential".into());
let max_concurrent = max_concurrent.unwrap_or(1); let max_concurrent = max_concurrent.unwrap_or(1);
let _ = app.emit("batch-ota-progress", BatchOtaProgress { let _ = app.emit(
phase: "starting".into(), "batch-ota-progress",
total: total_nodes, BatchOtaProgress {
completed: 0, phase: "starting".into(),
failed: 0, total: total_nodes,
current_node: None, completed: 0,
}); failed: 0,
current_node: None,
},
);
let mut results = Vec::new(); let mut results = Vec::new();
let mut completed = 0; let mut completed = 0;
@@ -212,22 +238,26 @@ pub async fn batch_ota_update(
let psk = std::sync::Arc::new(psk); let psk = std::sync::Arc::new(psk);
let app = std::sync::Arc::new(app.clone()); let app = std::sync::Arc::new(app.clone());
let tasks: Vec<_> = node_ips.into_iter().map(|ip| { let tasks: Vec<_> = node_ips
let sem = semaphore.clone(); .into_iter()
let fw_path = firmware_path.clone(); .map(|ip| {
let psk_clone = psk.clone(); let sem = semaphore.clone();
let app_clone = app.clone(); let fw_path = firmware_path.clone();
let psk_clone = psk.clone();
let app_clone = app.clone();
async move { async move {
let _permit = sem.acquire().await.unwrap(); let _permit = sem.acquire().await.unwrap();
ota_update( ota_update(
(*app_clone).clone(), (*app_clone).clone(),
ip, ip,
(*fw_path).clone(), (*fw_path).clone(),
(*psk_clone).clone(), (*psk_clone).clone(),
).await )
} .await
}).collect(); }
})
.collect();
let task_results = futures::future::join_all(tasks).await; let task_results = futures::future::join_all(tasks).await;
@@ -257,20 +287,19 @@ pub async fn batch_ota_update(
_ => { _ => {
// Sequential execution (default) // Sequential execution (default)
for ip in node_ips { for ip in node_ips {
let _ = app.emit("batch-ota-progress", BatchOtaProgress { let _ = app.emit(
phase: "updating".into(), "batch-ota-progress",
total: total_nodes, BatchOtaProgress {
completed, phase: "updating".into(),
failed, total: total_nodes,
current_node: Some(ip.clone()), completed,
}); failed,
current_node: Some(ip.clone()),
},
);
match ota_update( match ota_update(app.clone(), ip.clone(), firmware_path.clone(), psk.clone()).await
app.clone(), {
ip.clone(),
firmware_path.clone(),
psk.clone(),
).await {
Ok(r) => { Ok(r) => {
if r.success { if r.success {
completed += 1; completed += 1;
@@ -296,13 +325,16 @@ pub async fn batch_ota_update(
let duration = start_time.elapsed().as_secs_f64(); let duration = start_time.elapsed().as_secs_f64();
let _ = app.emit("batch-ota-progress", BatchOtaProgress { let _ = app.emit(
phase: "completed".into(), "batch-ota-progress",
total: total_nodes, BatchOtaProgress {
completed, phase: "completed".into(),
failed, total: total_nodes,
current_node: None, completed,
}); failed,
current_node: None,
},
);
Ok(BatchOtaResult { Ok(BatchOtaResult {
total: total_nodes, total: total_nodes,
@@ -331,7 +363,10 @@ pub async fn check_ota_endpoint(node_ip: String) -> Result<OtaEndpointInfo, Stri
// Try to parse as JSON // Try to parse as JSON
let version = serde_json::from_str::<serde_json::Value>(&body) let version = serde_json::from_str::<serde_json::Value>(&body)
.ok() .ok()
.and_then(|v| v.get("version").and_then(|v| v.as_str().map(|s| s.to_string()))); .and_then(|v| {
v.get("version")
.and_then(|v| v.as_str().map(|s| s.to_string()))
});
Ok(OtaEndpointInfo { Ok(OtaEndpointInfo {
reachable: true, reachable: true,
@@ -45,9 +45,9 @@ pub async fn provision_node(
// Open serial port // Open serial port
let port_settings = tokio_serial::SerialPortBuilderExt::open_native_async( let port_settings = tokio_serial::SerialPortBuilderExt::open_native_async(
tokio_serial::new(&port, PROVISION_BAUD) tokio_serial::new(&port, PROVISION_BAUD).timeout(Duration::from_millis(SERIAL_TIMEOUT_MS)),
.timeout(Duration::from_millis(SERIAL_TIMEOUT_MS)) )
).map_err(|e| format!("Failed to open serial port: {}", e))?; .map_err(|e| format!("Failed to open serial port: {}", e))?;
let (mut reader, mut writer) = tokio::io::split(port_settings); let (mut reader, mut writer) = tokio::io::split(port_settings);
@@ -59,17 +59,19 @@ pub async fn provision_node(
}; };
let header_bytes = bincode_header(&header); let header_bytes = bincode_header(&header);
tokio::io::AsyncWriteExt::write_all(&mut writer, &header_bytes).await tokio::io::AsyncWriteExt::write_all(&mut writer, &header_bytes)
.await
.map_err(|e| format!("Failed to send header: {}", e))?; .map_err(|e| format!("Failed to send header: {}", e))?;
// Wait for ACK // Wait for ACK
let mut ack_buf = [0u8; 4]; let mut ack_buf = [0u8; 4];
tokio::time::timeout( tokio::time::timeout(
Duration::from_millis(SERIAL_TIMEOUT_MS), Duration::from_millis(SERIAL_TIMEOUT_MS),
tokio::io::AsyncReadExt::read_exact(&mut reader, &mut ack_buf) tokio::io::AsyncReadExt::read_exact(&mut reader, &mut ack_buf),
).await )
.map_err(|_| "Timeout waiting for device acknowledgment")? .await
.map_err(|e| format!("Failed to read ACK: {}", e))?; .map_err(|_| "Timeout waiting for device acknowledgment")?
.map_err(|e| format!("Failed to read ACK: {}", e))?;
if &ack_buf != b"ACK\n" { if &ack_buf != b"ACK\n" {
return Err(format!("Invalid ACK response: {:?}", ack_buf)); return Err(format!("Invalid ACK response: {:?}", ack_buf));
@@ -78,7 +80,8 @@ pub async fn provision_node(
// Send NVS data in chunks // Send NVS data in chunks
const CHUNK_SIZE: usize = 256; const CHUNK_SIZE: usize = 256;
for chunk in nvs_data.chunks(CHUNK_SIZE) { for chunk in nvs_data.chunks(CHUNK_SIZE) {
tokio::io::AsyncWriteExt::write_all(&mut writer, chunk).await tokio::io::AsyncWriteExt::write_all(&mut writer, chunk)
.await
.map_err(|e| format!("Failed to send data chunk: {}", e))?; .map_err(|e| format!("Failed to send data chunk: {}", e))?;
// Small delay between chunks for device processing // Small delay between chunks for device processing
@@ -86,20 +89,23 @@ pub async fn provision_node(
} }
// Send checksum // Send checksum
tokio::io::AsyncWriteExt::write_all(&mut writer, checksum.as_bytes()).await tokio::io::AsyncWriteExt::write_all(&mut writer, checksum.as_bytes())
.await
.map_err(|e| format!("Failed to send checksum: {}", e))?; .map_err(|e| format!("Failed to send checksum: {}", e))?;
tokio::io::AsyncWriteExt::write_all(&mut writer, b"\n").await tokio::io::AsyncWriteExt::write_all(&mut writer, b"\n")
.await
.map_err(|e| format!("Failed to send newline: {}", e))?; .map_err(|e| format!("Failed to send newline: {}", e))?;
// Wait for confirmation // Wait for confirmation
let mut confirm_buf = [0u8; 32]; let mut confirm_buf = [0u8; 32];
let confirm_len = tokio::time::timeout( let confirm_len = tokio::time::timeout(
Duration::from_millis(SERIAL_TIMEOUT_MS * 2), Duration::from_millis(SERIAL_TIMEOUT_MS * 2),
tokio::io::AsyncReadExt::read(&mut reader, &mut confirm_buf) tokio::io::AsyncReadExt::read(&mut reader, &mut confirm_buf),
).await )
.map_err(|_| "Timeout waiting for confirmation")? .await
.map_err(|e| format!("Failed to read confirmation: {}", e))?; .map_err(|_| "Timeout waiting for confirmation")?
.map_err(|e| format!("Failed to read confirmation: {}", e))?;
let confirm_str = String::from_utf8_lossy(&confirm_buf[..confirm_len]); let confirm_str = String::from_utf8_lossy(&confirm_buf[..confirm_len]);
@@ -121,24 +127,26 @@ pub async fn provision_node(
pub async fn read_nvs(port: String) -> Result<ProvisioningConfig, String> { pub async fn read_nvs(port: String) -> Result<ProvisioningConfig, String> {
// Open serial port // Open serial port
let port_settings = tokio_serial::SerialPortBuilderExt::open_native_async( let port_settings = tokio_serial::SerialPortBuilderExt::open_native_async(
tokio_serial::new(&port, PROVISION_BAUD) tokio_serial::new(&port, PROVISION_BAUD).timeout(Duration::from_millis(SERIAL_TIMEOUT_MS)),
.timeout(Duration::from_millis(SERIAL_TIMEOUT_MS)) )
).map_err(|e| format!("Failed to open serial port: {}", e))?; .map_err(|e| format!("Failed to open serial port: {}", e))?;
let (mut reader, mut writer) = tokio::io::split(port_settings); let (mut reader, mut writer) = tokio::io::split(port_settings);
// Send read command // Send read command
tokio::io::AsyncWriteExt::write_all(&mut writer, b"RUVIEW_NVS_READ\n").await tokio::io::AsyncWriteExt::write_all(&mut writer, b"RUVIEW_NVS_READ\n")
.await
.map_err(|e| format!("Failed to send read command: {}", e))?; .map_err(|e| format!("Failed to send read command: {}", e))?;
// Read size header // Read size header
let mut size_buf = [0u8; 4]; let mut size_buf = [0u8; 4];
tokio::time::timeout( tokio::time::timeout(
Duration::from_millis(SERIAL_TIMEOUT_MS), Duration::from_millis(SERIAL_TIMEOUT_MS),
tokio::io::AsyncReadExt::read_exact(&mut reader, &mut size_buf) tokio::io::AsyncReadExt::read_exact(&mut reader, &mut size_buf),
).await )
.map_err(|_| "Timeout waiting for NVS size")? .await
.map_err(|e| format!("Failed to read size: {}", e))?; .map_err(|_| "Timeout waiting for NVS size")?
.map_err(|e| format!("Failed to read size: {}", e))?;
let nvs_size = u32::from_le_bytes(size_buf) as usize; let nvs_size = u32::from_le_bytes(size_buf) as usize;
@@ -150,10 +158,11 @@ pub async fn read_nvs(port: String) -> Result<ProvisioningConfig, String> {
let mut nvs_data = vec![0u8; nvs_size]; let mut nvs_data = vec![0u8; nvs_size];
tokio::time::timeout( tokio::time::timeout(
Duration::from_millis(SERIAL_TIMEOUT_MS * 2), Duration::from_millis(SERIAL_TIMEOUT_MS * 2),
tokio::io::AsyncReadExt::read_exact(&mut reader, &mut nvs_data) tokio::io::AsyncReadExt::read_exact(&mut reader, &mut nvs_data),
).await )
.map_err(|_| "Timeout reading NVS data")? .await
.map_err(|e| format!("Failed to read NVS data: {}", e))?; .map_err(|_| "Timeout reading NVS data")?
.map_err(|e| format!("Failed to read NVS data: {}", e))?;
// Parse NVS data to config // Parse NVS data to config
deserialize_nvs_config(&nvs_data) deserialize_nvs_config(&nvs_data)
@@ -164,24 +173,26 @@ pub async fn read_nvs(port: String) -> Result<ProvisioningConfig, String> {
pub async fn erase_nvs(port: String) -> Result<ProvisionResult, String> { pub async fn erase_nvs(port: String) -> Result<ProvisionResult, String> {
// Open serial port // Open serial port
let port_settings = tokio_serial::SerialPortBuilderExt::open_native_async( let port_settings = tokio_serial::SerialPortBuilderExt::open_native_async(
tokio_serial::new(&port, PROVISION_BAUD) tokio_serial::new(&port, PROVISION_BAUD).timeout(Duration::from_millis(SERIAL_TIMEOUT_MS)),
.timeout(Duration::from_millis(SERIAL_TIMEOUT_MS)) )
).map_err(|e| format!("Failed to open serial port: {}", e))?; .map_err(|e| format!("Failed to open serial port: {}", e))?;
let (mut reader, mut writer) = tokio::io::split(port_settings); let (mut reader, mut writer) = tokio::io::split(port_settings);
// Send erase command // Send erase command
tokio::io::AsyncWriteExt::write_all(&mut writer, b"RUVIEW_NVS_ERASE\n").await tokio::io::AsyncWriteExt::write_all(&mut writer, b"RUVIEW_NVS_ERASE\n")
.await
.map_err(|e| format!("Failed to send erase command: {}", e))?; .map_err(|e| format!("Failed to send erase command: {}", e))?;
// Wait for confirmation // Wait for confirmation
let mut confirm_buf = [0u8; 32]; let mut confirm_buf = [0u8; 32];
let confirm_len = tokio::time::timeout( let confirm_len = tokio::time::timeout(
Duration::from_millis(SERIAL_TIMEOUT_MS * 3), // Erase takes longer Duration::from_millis(SERIAL_TIMEOUT_MS * 3), // Erase takes longer
tokio::io::AsyncReadExt::read(&mut reader, &mut confirm_buf) tokio::io::AsyncReadExt::read(&mut reader, &mut confirm_buf),
).await )
.map_err(|_| "Timeout waiting for erase confirmation")? .await
.map_err(|e| format!("Failed to read confirmation: {}", e))?; .map_err(|_| "Timeout waiting for erase confirmation")?
.map_err(|e| format!("Failed to read confirmation: {}", e))?;
let confirm_str = String::from_utf8_lossy(&confirm_buf[..confirm_len]); let confirm_str = String::from_utf8_lossy(&confirm_buf[..confirm_len]);
@@ -316,7 +327,8 @@ fn serialize_nvs_config(config: &ProvisioningConfig) -> Result<Vec<u8>, String>
write_u8(&mut data, "hop_count", hops); write_u8(&mut data, "hop_count", hops);
} }
if let Some(ref channels) = config.channel_list { if let Some(ref channels) = config.channel_list {
let ch_str: String = channels.iter() let ch_str: String = channels
.iter()
.map(|c| c.to_string()) .map(|c| c.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(","); .join(",");
@@ -359,8 +371,8 @@ fn deserialize_nvs_config(data: &[u8]) -> Result<ProvisioningConfig, String> {
return Err("Invalid NVS data: truncated key".into()); return Err("Invalid NVS data: truncated key".into());
} }
let key = std::str::from_utf8(&data[pos..pos + key_len]) let key =
.map_err(|_| "Invalid key encoding")?; std::str::from_utf8(&data[pos..pos + key_len]).map_err(|_| "Invalid key encoding")?;
pos += key_len; pos += key_len;
if pos + 2 > data.len() { if pos + 2 > data.len() {
@@ -379,9 +391,15 @@ fn deserialize_nvs_config(data: &[u8]) -> Result<ProvisioningConfig, String> {
// Parse based on key // Parse based on key
match key { match key {
"wifi_ssid" => config.wifi_ssid = Some(String::from_utf8_lossy(value_bytes).to_string()), "wifi_ssid" => {
"wifi_pass" => config.wifi_password = Some(String::from_utf8_lossy(value_bytes).to_string()), config.wifi_ssid = Some(String::from_utf8_lossy(value_bytes).to_string())
"target_ip" => config.target_ip = Some(String::from_utf8_lossy(value_bytes).to_string()), }
"wifi_pass" => {
config.wifi_password = Some(String::from_utf8_lossy(value_bytes).to_string())
}
"target_ip" => {
config.target_ip = Some(String::from_utf8_lossy(value_bytes).to_string())
}
"target_port" if value_len == 2 => { "target_port" if value_len == 2 => {
config.target_port = Some(u16::from_le_bytes([value_bytes[0], value_bytes[1]])); config.target_port = Some(u16::from_le_bytes([value_bytes[0], value_bytes[1]]));
} }
@@ -399,16 +417,18 @@ fn deserialize_nvs_config(data: &[u8]) -> Result<ProvisioningConfig, String> {
config.vital_window = Some(u16::from_le_bytes([value_bytes[0], value_bytes[1]])); config.vital_window = Some(u16::from_le_bytes([value_bytes[0], value_bytes[1]]));
} }
"vital_int" if value_len == 2 => { "vital_int" if value_len == 2 => {
config.vital_interval_ms = Some(u16::from_le_bytes([value_bytes[0], value_bytes[1]])); config.vital_interval_ms =
Some(u16::from_le_bytes([value_bytes[0], value_bytes[1]]));
} }
"top_k" if value_len == 1 => config.top_k_count = Some(value_bytes[0]), "top_k" if value_len == 1 => config.top_k_count = Some(value_bytes[0]),
"hop_count" if value_len == 1 => config.hop_count = Some(value_bytes[0]), "hop_count" if value_len == 1 => config.hop_count = Some(value_bytes[0]),
"channels" => { "channels" => {
let ch_str = String::from_utf8_lossy(value_bytes); let ch_str = String::from_utf8_lossy(value_bytes);
config.channel_list = Some( config.channel_list = Some(
ch_str.split(',') ch_str
.split(',')
.filter_map(|s| s.trim().parse().ok()) .filter_map(|s| s.trim().parse().ok())
.collect() .collect(),
); );
} }
"power_duty" if value_len == 1 => config.power_duty = Some(value_bytes[0]), "power_duty" if value_len == 1 => config.power_duty = Some(value_bytes[0]),
@@ -484,9 +504,11 @@ mod tests {
#[test] #[test]
fn test_config_validation() { fn test_config_validation() {
let mut config = ProvisioningConfig::default(); let config = ProvisioningConfig {
config.tdm_slot = Some(5); tdm_slot: Some(5),
config.tdm_total = Some(4); tdm_total: Some(4),
..ProvisioningConfig::default()
};
let result = config.validate(); let result = config.validate();
assert!(result.is_err()); assert!(result.is_err());
@@ -117,8 +117,12 @@ pub async fn start_server(
cmd.stderr(Stdio::piped()); cmd.stderr(Stdio::piped());
// Spawn the child process // Spawn the child process
let child = cmd.spawn() let child = cmd.spawn().map_err(|e| {
.map_err(|e| format!("Failed to start server: {}. Is '{}' installed?", e, server_path))?; format!(
"Failed to start server: {}. Is '{}' installed?",
e, server_path
)
})?;
let pid = child.id(); let pid = child.id();
@@ -262,12 +266,14 @@ pub async fn server_status(state: State<'_, AppState>) -> Result<ServerStatusRes
}); });
} }
let pid = srv.pid.unwrap(); // srv.pid.is_none() is checked above; the expect is unreachable in practice.
let pid = srv.pid.expect("pid checked as Some before this point");
let mut sys = System::new(); let mut sys = System::new();
let sysinfo_pid = Pid::from_u32(pid); let sysinfo_pid = Pid::from_u32(pid);
sys.refresh_processes(ProcessesToUpdate::Some(&[sysinfo_pid]), true); sys.refresh_processes(ProcessesToUpdate::Some(&[sysinfo_pid]), true);
let (memory_mb, cpu_percent) = sys.process(sysinfo_pid) let (memory_mb, cpu_percent) = sys
.process(sysinfo_pid)
.map(|proc| { .map(|proc| {
let mem = proc.memory() as f64 / 1024.0 / 1024.0; let mem = proc.memory() as f64 / 1024.0 / 1024.0;
let cpu = proc.cpu_usage(); let cpu = proc.cpu_usage();
@@ -276,9 +282,9 @@ pub async fn server_status(state: State<'_, AppState>) -> Result<ServerStatusRes
.unwrap_or((None, None)); .unwrap_or((None, None));
// Calculate uptime if we have start time // Calculate uptime if we have start time
let uptime_secs = srv.start_time.map(|start| { let uptime_secs = srv
std::time::Instant::now().duration_since(start).as_secs() .start_time
}); .map(|start| std::time::Instant::now().duration_since(start).as_secs());
Ok(ServerStatusResponse { Ok(ServerStatusResponse {
running: srv.running, running: srv.running,
@@ -41,8 +41,7 @@ fn settings_path(app: &AppHandle) -> Result<PathBuf, String> {
.map_err(|e| format!("Failed to get app data dir: {}", e))?; .map_err(|e| format!("Failed to get app data dir: {}", e))?;
// Ensure directory exists // Ensure directory exists
fs::create_dir_all(&app_dir) fs::create_dir_all(&app_dir).map_err(|e| format!("Failed to create app data dir: {}", e))?;
.map_err(|e| format!("Failed to create app data dir: {}", e))?;
Ok(app_dir.join("settings.json")) Ok(app_dir.join("settings.json"))
} }
@@ -56,11 +55,11 @@ pub async fn get_settings(app: AppHandle) -> Result<Option<AppSettings>, String>
return Ok(None); return Ok(None);
} }
let contents = fs::read_to_string(&path) let contents =
.map_err(|e| format!("Failed to read settings: {}", e))?; fs::read_to_string(&path).map_err(|e| format!("Failed to read settings: {}", e))?;
let settings: AppSettings = serde_json::from_str(&contents) let settings: AppSettings =
.map_err(|e| format!("Failed to parse settings: {}", e))?; serde_json::from_str(&contents).map_err(|e| format!("Failed to parse settings: {}", e))?;
Ok(Some(settings)) Ok(Some(settings))
} }
@@ -73,8 +72,7 @@ pub async fn save_settings(app: AppHandle, settings: AppSettings) -> Result<(),
let contents = serde_json::to_string_pretty(&settings) let contents = serde_json::to_string_pretty(&settings)
.map_err(|e| format!("Failed to serialize settings: {}", e))?; .map_err(|e| format!("Failed to serialize settings: {}", e))?;
fs::write(&path, contents) fs::write(&path, contents).map_err(|e| format!("Failed to write settings: {}", e))?;
.map_err(|e| format!("Failed to write settings: {}", e))?;
Ok(()) Ok(())
} }
@@ -22,14 +22,19 @@ pub async fn wasm_list(node_ip: String) -> Result<Vec<WasmModuleInfo>, String> {
let url = format!("http://{}:{}/wasm/list", node_ip, WASM_PORT); let url = format!("http://{}:{}/wasm/list", node_ip, WASM_PORT);
let response = client.get(&url).send().await let response = client
.get(&url)
.send()
.await
.map_err(|e| format!("Failed to connect to node: {}", e))?; .map_err(|e| format!("Failed to connect to node: {}", e))?;
if !response.status().is_success() { if !response.status().is_success() {
return Err(format!("Node returned HTTP {}", response.status())); return Err(format!("Node returned HTTP {}", response.status()));
} }
let modules: Vec<WasmModuleInfo> = response.json().await let modules: Vec<WasmModuleInfo> = response
.json()
.await
.map_err(|e| format!("Failed to parse response: {}", e))?; .map_err(|e| format!("Failed to parse response: {}", e))?;
Ok(modules) Ok(modules)
@@ -50,8 +55,7 @@ pub async fn wasm_upload(
auto_start: Option<bool>, auto_start: Option<bool>,
) -> Result<WasmUploadResult, String> { ) -> Result<WasmUploadResult, String> {
// Read WASM file // Read WASM file
let mut file = File::open(&wasm_path) let mut file = File::open(&wasm_path).map_err(|e| format!("Cannot read WASM file: {}", e))?;
.map_err(|e| format!("Cannot read WASM file: {}", e))?;
let mut wasm_data = Vec::new(); let mut wasm_data = Vec::new();
file.read_to_end(&mut wasm_data) file.read_to_end(&mut wasm_data)
@@ -99,7 +103,8 @@ pub async fn wasm_upload(
// Send request // Send request
let url = format!("http://{}:{}/wasm/upload", node_ip, WASM_PORT); let url = format!("http://{}:{}/wasm/upload", node_ip, WASM_PORT);
let response = client.post(&url) let response = client
.post(&url)
.multipart(form) .multipart(form)
.send() .send()
.await .await
@@ -113,13 +118,18 @@ pub async fn wasm_upload(
} }
// Parse response for module ID // Parse response for module ID
let upload_response: WasmUploadResponse = response.json().await let upload_response: WasmUploadResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse upload response: {}", e))?; .map_err(|e| format!("Failed to parse upload response: {}", e))?;
Ok(WasmUploadResult { Ok(WasmUploadResult {
success: true, success: true,
module_id: upload_response.module_id, module_id: upload_response.module_id,
message: format!("Module '{}' uploaded successfully ({} bytes)", name, wasm_size), message: format!(
"Module '{}' uploaded successfully ({} bytes)",
name, wasm_size
),
sha256: Some(wasm_hash), sha256: Some(wasm_hash),
}) })
} }
@@ -156,7 +166,10 @@ pub async fn wasm_control(
node_ip, WASM_PORT, module_id, action node_ip, WASM_PORT, module_id, action
); );
let response = client.post(&url).send().await let response = client
.post(&url)
.send()
.await
.map_err(|e| format!("WASM control failed: {}", e))?; .map_err(|e| format!("WASM control failed: {}", e))?;
let status = response.status(); let status = response.status();
@@ -179,10 +192,7 @@ pub async fn wasm_control(
/// Get detailed info about a specific WASM module. /// Get detailed info about a specific WASM module.
#[tauri::command] #[tauri::command]
pub async fn wasm_info( pub async fn wasm_info(node_ip: String, module_id: String) -> Result<WasmModuleDetail, String> {
node_ip: String,
module_id: String,
) -> Result<WasmModuleDetail, String> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(Duration::from_secs(WASM_TIMEOUT_SECS)) .timeout(Duration::from_secs(WASM_TIMEOUT_SECS))
.build() .build()
@@ -190,14 +200,19 @@ pub async fn wasm_info(
let url = format!("http://{}:{}/wasm/{}", node_ip, WASM_PORT, module_id); let url = format!("http://{}:{}/wasm/{}", node_ip, WASM_PORT, module_id);
let response = client.get(&url).send().await let response = client
.get(&url)
.send()
.await
.map_err(|e| format!("Failed to get module info: {}", e))?; .map_err(|e| format!("Failed to get module info: {}", e))?;
if !response.status().is_success() { if !response.status().is_success() {
return Err(format!("Module not found or HTTP {}", response.status())); return Err(format!("Module not found or HTTP {}", response.status()));
} }
let detail: WasmModuleDetail = response.json().await let detail: WasmModuleDetail = response
.json()
.await
.map_err(|e| format!("Failed to parse module info: {}", e))?; .map_err(|e| format!("Failed to parse module info: {}", e))?;
Ok(detail) Ok(detail)
@@ -213,14 +228,19 @@ pub async fn wasm_stats(node_ip: String) -> Result<WasmRuntimeStats, String> {
let url = format!("http://{}:{}/wasm/stats", node_ip, WASM_PORT); let url = format!("http://{}:{}/wasm/stats", node_ip, WASM_PORT);
let response = client.get(&url).send().await let response = client
.get(&url)
.send()
.await
.map_err(|e| format!("Failed to get WASM stats: {}", e))?; .map_err(|e| format!("Failed to get WASM stats: {}", e))?;
if !response.status().is_success() { if !response.status().is_success() {
return Err(format!("HTTP {}", response.status())); return Err(format!("HTTP {}", response.status()));
} }
let stats: WasmRuntimeStats = response.json().await let stats: WasmRuntimeStats = response
.json()
.await
.map_err(|e| format!("Failed to parse stats: {}", e))?; .map_err(|e| format!("Failed to parse stats: {}", e))?;
Ok(stats) Ok(stats)
@@ -246,13 +266,16 @@ pub async fn check_wasm_support(node_ip: String) -> Result<WasmSupportInfo, Stri
Ok(WasmSupportInfo { Ok(WasmSupportInfo {
supported: true, supported: true,
max_modules: info.as_ref() max_modules: info
.as_ref()
.and_then(|v| v.get("max_modules").and_then(|v| v.as_u64())) .and_then(|v| v.get("max_modules").and_then(|v| v.as_u64()))
.map(|v| v as u8), .map(|v| v as u8),
memory_limit_kb: info.as_ref() memory_limit_kb: info
.as_ref()
.and_then(|v| v.get("memory_limit_kb").and_then(|v| v.as_u64())) .and_then(|v| v.get("memory_limit_kb").and_then(|v| v.as_u64()))
.map(|v| v as u32), .map(|v| v as u32),
verify_signatures: info.as_ref() verify_signatures: info
.as_ref()
.and_then(|v| v.get("verify_signatures").and_then(|v| v.as_bool())) .and_then(|v| v.get("verify_signatures").and_then(|v| v.as_bool()))
.unwrap_or(false), .unwrap_or(false),
}) })
@@ -51,10 +51,7 @@ impl ProvisioningConfig {
} }
if let Some(duty) = self.power_duty { if let Some(duty) = self.power_duty {
if !(10..=100).contains(&duty) { if !(10..=100).contains(&duty) {
return Err(format!( return Err(format!("power_duty ({}) must be between 10 and 100", duty));
"power_duty ({}) must be between 10 and 100",
duty
));
} }
} }
Ok(()) Ok(())
+3 -35
View File
@@ -12,6 +12,7 @@ pub struct DiscoveryState {
} }
/// Sub-state for the managed sensing server process. /// Sub-state for the managed sensing server process.
#[derive(Default)]
pub struct ServerState { pub struct ServerState {
pub running: bool, pub running: bool,
pub pid: Option<u32>, pub pid: Option<u32>,
@@ -22,20 +23,6 @@ pub struct ServerState {
pub start_time: Option<Instant>, pub start_time: Option<Instant>,
} }
impl Default for ServerState {
fn default() -> Self {
Self {
running: false,
pid: None,
http_port: None,
ws_port: None,
udp_port: None,
child: None,
start_time: None,
}
}
}
/// Sub-state for flash progress tracking. /// Sub-state for flash progress tracking.
#[derive(Default)] #[derive(Default)]
pub struct FlashState { pub struct FlashState {
@@ -73,21 +60,14 @@ impl Default for OtaUpdateTracker {
} }
/// Sub-state for application settings cache. /// Sub-state for application settings cache.
#[derive(Default)]
pub struct SettingsState { pub struct SettingsState {
pub loaded: bool, pub loaded: bool,
pub dirty: bool, pub dirty: bool,
} }
impl Default for SettingsState {
fn default() -> Self {
Self {
loaded: false,
dirty: false,
}
}
}
/// Top-level application state managed by Tauri. /// Top-level application state managed by Tauri.
#[derive(Default)]
pub struct AppState { pub struct AppState {
pub discovery: Mutex<DiscoveryState>, pub discovery: Mutex<DiscoveryState>,
pub server: Mutex<ServerState>, pub server: Mutex<ServerState>,
@@ -96,18 +76,6 @@ pub struct AppState {
pub settings: Mutex<SettingsState>, pub settings: Mutex<SettingsState>,
} }
impl Default for AppState {
fn default() -> Self {
Self {
discovery: Mutex::new(DiscoveryState::default()),
server: Mutex::new(ServerState::default()),
flash: Mutex::new(FlashState::default()),
ota: Mutex::new(OtaState::default()),
settings: Mutex::new(SettingsState::default()),
}
}
}
impl AppState { impl AppState {
/// Create a new AppState instance. /// Create a new AppState instance.
pub fn new() -> Self { pub fn new() -> Self {
@@ -10,23 +10,44 @@
fn test_serial_port_detection_logic() { fn test_serial_port_detection_logic() {
// Test ESP32 VID/PID detection // Test ESP32 VID/PID detection
// CP210x (Silicon Labs) // CP210x (Silicon Labs)
assert!(is_esp32_vid_pid(0x10C4, 0xEA60), "CP2102 should be detected"); assert!(
assert!(is_esp32_vid_pid(0x10C4, 0xEA70), "CP2104 should be detected"); is_esp32_vid_pid(0x10C4, 0xEA60),
"CP2102 should be detected"
);
assert!(
is_esp32_vid_pid(0x10C4, 0xEA70),
"CP2104 should be detected"
);
// CH340/CH341 (QinHeng) // CH340/CH341 (QinHeng)
assert!(is_esp32_vid_pid(0x1A86, 0x7523), "CH340 should be detected"); assert!(is_esp32_vid_pid(0x1A86, 0x7523), "CH340 should be detected");
assert!(is_esp32_vid_pid(0x1A86, 0x5523), "CH341 should be detected"); assert!(is_esp32_vid_pid(0x1A86, 0x5523), "CH341 should be detected");
// FTDI // FTDI
assert!(is_esp32_vid_pid(0x0403, 0x6001), "FTDI FT232 should be detected"); assert!(
assert!(is_esp32_vid_pid(0x0403, 0x6010), "FTDI FT2232 should be detected"); is_esp32_vid_pid(0x0403, 0x6001),
"FTDI FT232 should be detected"
);
assert!(
is_esp32_vid_pid(0x0403, 0x6010),
"FTDI FT2232 should be detected"
);
// ESP32 native USB // ESP32 native USB
assert!(is_esp32_vid_pid(0x303A, 0x1001), "ESP32-S2/S3 native should be detected"); assert!(
is_esp32_vid_pid(0x303A, 0x1001),
"ESP32-S2/S3 native should be detected"
);
// Unknown device // Unknown device
assert!(!is_esp32_vid_pid(0x0000, 0x0000), "Unknown VID/PID should not be detected"); assert!(
assert!(!is_esp32_vid_pid(0x1234, 0x5678), "Random VID/PID should not be detected"); !is_esp32_vid_pid(0x0000, 0x0000),
"Unknown VID/PID should not be detected"
);
assert!(
!is_esp32_vid_pid(0x1234, 0x5678),
"Random VID/PID should not be detected"
);
} }
fn is_esp32_vid_pid(vid: u16, pid: u16) -> bool { fn is_esp32_vid_pid(vid: u16, pid: u16) -> bool {
@@ -39,7 +60,9 @@ fn is_esp32_vid_pid(vid: u16, pid: u16) -> bool {
return true; return true;
} }
// FTDI // FTDI
if vid == 0x0403 && (pid == 0x6001 || pid == 0x6010 || pid == 0x6011 || pid == 0x6014 || pid == 0x6015) { if vid == 0x0403
&& (pid == 0x6001 || pid == 0x6010 || pid == 0x6011 || pid == 0x6014 || pid == 0x6015)
{
return true; return true;
} }
// ESP32-S2/S3 native USB // ESP32-S2/S3 native USB
@@ -78,8 +101,14 @@ fn test_settings_structure() {
// Check default values // Check default values
assert!(!settings.theme.is_empty(), "Theme should have a default"); assert!(!settings.theme.is_empty(), "Theme should have a default");
assert!(settings.discover_interval_ms > 0, "Discovery interval should be positive"); assert!(
assert!(settings.auto_discover, "Auto-discover should default to true"); settings.discover_interval_ms > 0,
"Discovery interval should be positive"
);
assert!(
settings.auto_discover,
"Auto-discover should default to true"
);
assert_eq!(settings.server_http_port, 8080); assert_eq!(settings.server_http_port, 8080);
} }
@@ -128,7 +157,10 @@ fn test_chip_variants() {
for chip in chips { for chip in chips {
let name = format!("{:?}", chip).to_lowercase(); let name = format!("{:?}", chip).to_lowercase();
assert!(name.starts_with("esp32"), "All chips should be ESP32 variants"); assert!(
name.starts_with("esp32"),
"All chips should be ESP32 variants"
);
} }
} }
@@ -152,7 +184,7 @@ fn test_progress_parsing() {
#[test] #[test]
fn test_sha256_hash() { fn test_sha256_hash() {
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
let data = b"test firmware data"; let data = b"test firmware data";
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
@@ -178,7 +210,11 @@ fn test_hmac_signature() {
let result = mac.finalize(); let result = mac.finalize();
let signature = hex::encode(result.into_bytes()); let signature = hex::encode(result.into_bytes());
assert_eq!(signature.len(), 64, "HMAC-SHA256 should produce 64 hex characters"); assert_eq!(
signature.len(),
64,
"HMAC-SHA256 should produce 64 hex characters"
);
} }
// ============================================================================ // ============================================================================
@@ -305,11 +341,7 @@ fn test_discovery_method_variants() {
fn test_mesh_role_variants() { fn test_mesh_role_variants() {
use wifi_densepose_desktop::domain::node::MeshRole; use wifi_densepose_desktop::domain::node::MeshRole;
let roles = vec![ let roles = vec![MeshRole::Coordinator, MeshRole::Aggregator, MeshRole::Node];
MeshRole::Coordinator,
MeshRole::Aggregator,
MeshRole::Node,
];
for role in roles { for role in roles {
let json = serde_json::to_string(&role).expect("Should serialize"); let json = serde_json::to_string(&role).expect("Should serialize");
@@ -343,14 +375,18 @@ fn test_wifi_config_command_format() {
} }
#[test] #[test]
#[allow(clippy::const_is_empty)]
fn test_wifi_credentials_validation() { fn test_wifi_credentials_validation() {
// SSID: 1-32 characters // SSID: 1-32 characters
let valid_ssid = "MyNetwork"; let valid_ssid = "MyNetwork";
let empty_ssid = ""; let empty_ssid = "";
let long_ssid = "A".repeat(33); let long_ssid = "A".repeat(33);
assert!(!valid_ssid.is_empty() && valid_ssid.len() <= 32); assert!(
assert!(empty_ssid.is_empty()); !valid_ssid.is_empty() && valid_ssid.len() <= 32,
"SSID length must be 1-32"
);
assert!(empty_ssid.is_empty(), "empty_ssid must be empty");
assert!(long_ssid.len() > 32); assert!(long_ssid.len() > 32);
// Password: 8-63 characters for WPA2 // Password: 8-63 characters for WPA2
@@ -370,7 +406,7 @@ fn test_wifi_credentials_validation() {
#[test] #[test]
fn test_node_registry() { fn test_node_registry() {
use wifi_densepose_desktop::domain::node::{ use wifi_densepose_desktop::domain::node::{
DiscoveredNode, MacAddress, NodeRegistry, HealthStatus, Chip, MeshRole, DiscoveryMethod Chip, DiscoveredNode, DiscoveryMethod, HealthStatus, MacAddress, MeshRole, NodeRegistry,
}; };
let mut registry = NodeRegistry::new(); let mut registry = NodeRegistry::new();
@@ -13,24 +13,43 @@ async fn main() -> anyhow::Result<()> {
println!(" Location: {:.4}N, {:.4}W", loc.lat, loc.lon); println!(" Location: {:.4}N, {:.4}W", loc.lat, loc.lon);
let bbox = GeoBBox::from_center(&loc, 300.0); let bbox = GeoBBox::from_center(&loc, 300.0);
let tiles_list = tiles::fetch_area(&tiles::TileProvider::Sentinel2Cloudless, &bbox, 16, &cache).await?; let tiles_list =
println!(" Tiles: {} ({:.0}KB)", tiles_list.len(), tiles::fetch_area(&tiles::TileProvider::Sentinel2Cloudless, &bbox, 16, &cache).await?;
tiles_list.iter().map(|t| t.data.len()).sum::<usize>() as f64 / 1024.0); println!(
" Tiles: {} ({:.0}KB)",
tiles_list.len(),
tiles_list.iter().map(|t| t.data.len()).sum::<usize>() as f64 / 1024.0
);
let dem = terrain::fetch_elevation(&loc, &cache).await?; let dem = terrain::fetch_elevation(&loc, &cache).await?;
println!(" Elevation: {:.0}m (grid {}x{})", terrain::elevation_at(&dem, &loc), dem.cols, dem.rows); println!(
" Elevation: {:.0}m (grid {}x{})",
terrain::elevation_at(&dem, &loc),
dem.cols,
dem.rows
);
let buildings = osm::fetch_buildings(&loc, 300.0).await.unwrap_or_default(); let buildings = osm::fetch_buildings(&loc, 300.0).await.unwrap_or_default();
let roads = osm::fetch_roads(&loc, 300.0).await.unwrap_or_default(); let roads = osm::fetch_roads(&loc, 300.0).await.unwrap_or_default();
println!(" OSM: {} buildings, {} roads", buildings.len(), roads.len()); println!(
" OSM: {} buildings, {} roads",
buildings.len(),
roads.len()
);
let weather = temporal::fetch_weather(&loc).await?; let weather = temporal::fetch_weather(&loc).await?;
println!(" Weather: {:.0}°C humidity={:.0}% wind={:.1}m/s", println!(
weather.temperature_c, weather.humidity_pct, weather.wind_speed_ms); " Weather: {:.0}°C humidity={:.0}% wind={:.1}m/s",
weather.temperature_c, weather.humidity_pct, weather.wind_speed_ms
);
let scene = GeoScene { let scene = GeoScene {
location: loc.clone(), bbox, elevation_m: terrain::elevation_at(&dem, &loc), location: loc.clone(),
buildings, roads, tile_count: tiles_list.len(), bbox,
elevation_m: terrain::elevation_at(&dem, &loc),
buildings,
roads,
tile_count: tiles_list.len(),
registration: register::auto_register(&loc), registration: register::auto_register(&loc),
last_updated: chrono::Utc::now().to_rfc3339(), last_updated: chrono::Utc::now().to_rfc3339(),
}; };
@@ -41,7 +60,10 @@ async fn main() -> anyhow::Result<()> {
Err(e) => println!(" Brain: {e}"), Err(e) => println!(" Brain: {e}"),
} }
println!("\n Total: {}ms | Cache: {:.0}KB", println!(
t0.elapsed().as_millis(), cache.size_bytes() as f64 / 1024.0); "\n Total: {}ms | Cache: {:.0}KB",
t0.elapsed().as_millis(),
cache.size_bytes() as f64 / 1024.0
);
Ok(()) Ok(())
} }
+9 -3
View File
@@ -13,8 +13,8 @@ const DEFAULT_BRAIN_URL: &str = "http://127.0.0.1:9876";
pub(crate) fn brain_url() -> &'static str { pub(crate) fn brain_url() -> &'static str {
static BRAIN_URL: OnceLock<String> = OnceLock::new(); static BRAIN_URL: OnceLock<String> = OnceLock::new();
BRAIN_URL.get_or_init(|| { BRAIN_URL.get_or_init(|| {
let url = std::env::var("RUVIEW_BRAIN_URL") let url =
.unwrap_or_else(|_| DEFAULT_BRAIN_URL.to_string()); std::env::var("RUVIEW_BRAIN_URL").unwrap_or_else(|_| DEFAULT_BRAIN_URL.to_string());
eprintln!(" wifi-densepose-geo: using brain URL {url}"); eprintln!(" wifi-densepose-geo: using brain URL {url}");
url url
}) })
@@ -34,7 +34,13 @@ pub async fn store_geo_context(scene: &GeoScene) -> Result<u32> {
"category": "spatial-geo", "category": "spatial-geo",
"content": summary, "content": summary,
}); });
if client.post(format!("{}/memories", brain_url())).json(&body).send().await.is_ok() { if client
.post(format!("{}/memories", brain_url()))
.json(&body)
.send()
.await
.is_ok()
{
stored += 1; stored += 1;
} }
+5 -2
View File
@@ -54,8 +54,11 @@ fn walkdir(path: &Path) -> u64 {
.flatten() .flatten()
.filter_map(|e| e.ok()) .filter_map(|e| e.ok())
.map(|e| { .map(|e| {
if e.path().is_dir() { walkdir(&e.path()) } if e.path().is_dir() {
else { e.metadata().map(|m| m.len()).unwrap_or(0) } walkdir(&e.path())
} else {
e.metadata().map(|m| m.len()).unwrap_or(0)
}
}) })
.sum() .sum()
} }
+15 -4
View File
@@ -1,6 +1,6 @@
//! Coordinate transforms — WGS84, UTM, ENU, tile math. //! Coordinate transforms — WGS84, UTM, ENU, tile math.
use crate::types::{GeoPoint, GeoBBox, TileCoord}; use crate::types::{GeoBBox, GeoPoint, TileCoord};
const WGS84_A: f64 = 6_378_137.0; const WGS84_A: f64 = 6_378_137.0;
#[allow(dead_code)] #[allow(dead_code)]
@@ -55,9 +55,20 @@ pub fn tile_bounds(coord: &TileCoord) -> GeoBBox {
let n = 2f64.powi(coord.z as i32); let n = 2f64.powi(coord.z as i32);
let west = coord.x as f64 / n * 360.0 - 180.0; let west = coord.x as f64 / n * 360.0 - 180.0;
let east = (coord.x + 1) as f64 / n * 360.0 - 180.0; let east = (coord.x + 1) as f64 / n * 360.0 - 180.0;
let north = (std::f64::consts::PI * (1.0 - 2.0 * coord.y as f64 / n)).sinh().atan().to_degrees(); let north = (std::f64::consts::PI * (1.0 - 2.0 * coord.y as f64 / n))
let south = (std::f64::consts::PI * (1.0 - 2.0 * (coord.y + 1) as f64 / n)).sinh().atan().to_degrees(); .sinh()
GeoBBox { south, west, north, east } .atan()
.to_degrees();
let south = (std::f64::consts::PI * (1.0 - 2.0 * (coord.y + 1) as f64 / n))
.sinh()
.atan()
.to_degrees();
GeoBBox {
south,
west,
north,
east,
}
} }
/// Get all tile coordinates covering a bounding box at a zoom level. /// Get all tile coordinates covering a bounding box at a zoom level.
+30 -10
View File
@@ -12,11 +12,15 @@ pub async fn build_scene(radius_m: f64) -> Result<GeoScene> {
// 1. Locate // 1. Locate
let cache_path = cache.base_dir.join("location.json"); let cache_path = cache.base_dir.join("location.json");
let location = locate::get_location(cache_path.to_str().unwrap_or("")).await?; let location = locate::get_location(cache_path.to_str().unwrap_or("")).await?;
eprintln!(" Geo: located at {:.4}N, {:.4}W", location.lat, location.lon); eprintln!(
" Geo: located at {:.4}N, {:.4}W",
location.lat, location.lon
);
// 2. Fetch satellite tiles // 2. Fetch satellite tiles
let bbox = GeoBBox::from_center(&location, radius_m); let bbox = GeoBBox::from_center(&location, radius_m);
let tile_list = tiles::fetch_area(&tiles::TileProvider::Sentinel2Cloudless, &bbox, 16, &cache).await?; let tile_list =
tiles::fetch_area(&tiles::TileProvider::Sentinel2Cloudless, &bbox, 16, &cache).await?;
eprintln!(" Geo: fetched {} satellite tiles", tile_list.len()); eprintln!(" Geo: fetched {} satellite tiles", tile_list.len());
// 3. Fetch elevation // 3. Fetch elevation
@@ -25,9 +29,17 @@ pub async fn build_scene(radius_m: f64) -> Result<GeoScene> {
eprintln!(" Geo: elevation {:.0}m ASL", elevation); eprintln!(" Geo: elevation {:.0}m ASL", elevation);
// 4. Fetch OSM buildings + roads // 4. Fetch OSM buildings + roads
let buildings = osm::fetch_buildings(&location, radius_m).await.unwrap_or_default(); let buildings = osm::fetch_buildings(&location, radius_m)
let roads = osm::fetch_roads(&location, radius_m).await.unwrap_or_default(); .await
eprintln!(" Geo: {} buildings, {} roads", buildings.len(), roads.len()); .unwrap_or_default();
let roads = osm::fetch_roads(&location, radius_m)
.await
.unwrap_or_default();
eprintln!(
" Geo: {} buildings, {} roads",
buildings.len(),
roads.len()
);
// 5. Build registration // 5. Build registration
let mut reg_origin = location.clone(); let mut reg_origin = location.clone();
@@ -50,7 +62,9 @@ pub async fn build_scene(radius_m: f64) -> Result<GeoScene> {
pub fn summarize(scene: &GeoScene) -> String { pub fn summarize(scene: &GeoScene) -> String {
let building_count = scene.buildings.len(); let building_count = scene.buildings.len();
let road_count = scene.roads.len(); let road_count = scene.roads.len();
let road_names: Vec<&str> = scene.roads.iter() let road_names: Vec<&str> = scene
.roads
.iter()
.filter_map(|r| match r { .filter_map(|r| match r {
OsmFeature::Road { name, .. } => name.as_deref(), OsmFeature::Road { name, .. } => name.as_deref(),
_ => None, _ => None,
@@ -62,10 +76,16 @@ pub fn summarize(scene: &GeoScene) -> String {
"Location: {:.4}N, {:.4}W, elevation {:.0}m ASL. \ "Location: {:.4}N, {:.4}W, elevation {:.0}m ASL. \
{} buildings within view. {} roads nearby{}. \ {} buildings within view. {} roads nearby{}. \
{} satellite tiles at zoom 16. Updated: {}.", {} satellite tiles at zoom 16. Updated: {}.",
scene.location.lat, scene.location.lon, scene.elevation_m, scene.location.lat,
building_count, road_count, scene.location.lon,
if road_names.is_empty() { String::new() } scene.elevation_m,
else { format!(" ({})", road_names.join(", ")) }, building_count,
road_count,
if road_names.is_empty() {
String::new()
} else {
format!(" ({})", road_names.join(", "))
},
scene.tile_count, scene.tile_count,
&scene.last_updated[..10], &scene.last_updated[..10],
) )
+7 -7
View File
@@ -4,16 +4,16 @@
//! SRTM elevation, OSM buildings/roads, coordinate transforms, //! SRTM elevation, OSM buildings/roads, coordinate transforms,
//! temporal change tracking, and brain memory integration. //! temporal change tracking, and brain memory integration.
pub mod types; pub mod brain;
pub mod coord;
pub mod locate;
pub mod cache; pub mod cache;
pub mod tiles; pub mod coord;
pub mod terrain; pub mod fuse;
pub mod locate;
pub mod osm; pub mod osm;
pub mod register; pub mod register;
pub mod fuse;
pub mod brain;
pub mod temporal; pub mod temporal;
pub mod terrain;
pub mod tiles;
pub mod types;
pub use types::*; pub use types::*;
+4 -2
View File
@@ -12,8 +12,10 @@ pub async fn locate_by_ip() -> Result<GeoPoint> {
// Primary: ip-api.com (free, 45 req/min) // Primary: ip-api.com (free, 45 req/min)
let resp: serde_json::Value = client let resp: serde_json::Value = client
.get("http://ip-api.com/json/?fields=lat,lon,city,regionName,country") .get("http://ip-api.com/json/?fields=lat,lon,city,regionName,country")
.send().await? .send()
.json().await?; .await?
.json()
.await?;
let lat = resp.get("lat").and_then(|v| v.as_f64()).unwrap_or(0.0); let lat = resp.get("lat").and_then(|v| v.as_f64()).unwrap_or(0.0);
let lon = resp.get("lon").and_then(|v| v.as_f64()).unwrap_or(0.0); let lon = resp.get("lon").and_then(|v| v.as_f64()).unwrap_or(0.0);
+74 -23
View File
@@ -13,7 +13,9 @@ pub const MAX_RADIUS_M: f64 = 5000.0;
fn check_radius(radius_m: f64) -> Result<()> { fn check_radius(radius_m: f64) -> Result<()> {
if !radius_m.is_finite() || radius_m <= 0.0 { if !radius_m.is_finite() || radius_m <= 0.0 {
return Err(anyhow!("radius_m must be positive and finite (got {radius_m})")); return Err(anyhow!(
"radius_m must be positive and finite (got {radius_m})"
));
} }
if radius_m > MAX_RADIUS_M { if radius_m > MAX_RADIUS_M {
return Err(anyhow!( return Err(anyhow!(
@@ -34,8 +36,7 @@ pub async fn fetch_buildings(center: &GeoPoint, radius_m: f64) -> Result<Vec<Osm
let bbox = GeoBBox::from_center(center, radius_m); let bbox = GeoBBox::from_center(center, radius_m);
let query = format!( let query = format!(
r#"[out:json][timeout:25];(way["building"]({},{},{},{});relation["building"]({},{},{},{}););out body;>;out skel qt;"#, r#"[out:json][timeout:25];(way["building"]({},{},{},{});relation["building"]({},{},{},{}););out body;>;out skel qt;"#,
bbox.south, bbox.west, bbox.north, bbox.east, bbox.south, bbox.west, bbox.north, bbox.east, bbox.south, bbox.west, bbox.north, bbox.east,
bbox.south, bbox.west, bbox.north, bbox.east,
); );
let resp = overpass_query(&query).await?; let resp = overpass_query(&query).await?;
parse_buildings(&resp) parse_buildings(&resp)
@@ -59,9 +60,11 @@ async fn overpass_query(query: &str) -> Result<serde_json::Value> {
.user_agent("RuView/0.1") .user_agent("RuView/0.1")
.build()?; .build()?;
let resp = client.post(OVERPASS_URL) let resp = client
.post(OVERPASS_URL)
.form(&[("data", query)]) .form(&[("data", query)])
.send().await?; .send()
.await?;
if !resp.status().is_success() { if !resp.status().is_success() {
anyhow::bail!("Overpass API error: {}", resp.status()); anyhow::bail!("Overpass API error: {}", resp.status());
@@ -75,7 +78,9 @@ async fn overpass_query(query: &str) -> Result<serde_json::Value> {
/// top-level `elements` array (indicative of a malformed/non-Overpass payload). /// top-level `elements` array (indicative of a malformed/non-Overpass payload).
pub fn parse_overpass_json(data: &serde_json::Value) -> Result<Vec<OsmFeature>> { pub fn parse_overpass_json(data: &serde_json::Value) -> Result<Vec<OsmFeature>> {
if !data.is_object() || data.get("elements").and_then(|e| e.as_array()).is_none() { if !data.is_object() || data.get("elements").and_then(|e| e.as_array()).is_none() {
return Err(anyhow!("malformed Overpass response: missing `elements` array")); return Err(anyhow!(
"malformed Overpass response: missing `elements` array"
));
} }
parse_buildings(data) parse_buildings(data)
} }
@@ -84,7 +89,11 @@ pub(crate) fn parse_buildings(data: &serde_json::Value) -> Result<Vec<OsmFeature
let mut buildings = Vec::new(); let mut buildings = Vec::new();
let mut nodes: std::collections::HashMap<u64, [f64; 2]> = std::collections::HashMap::new(); let mut nodes: std::collections::HashMap<u64, [f64; 2]> = std::collections::HashMap::new();
let elements = data.get("elements").and_then(|e| e.as_array()).cloned().unwrap_or_default(); let elements = data
.get("elements")
.and_then(|e| e.as_array())
.cloned()
.unwrap_or_default();
// First pass: collect nodes // First pass: collect nodes
for el in &elements { for el in &elements {
@@ -101,24 +110,44 @@ pub(crate) fn parse_buildings(data: &serde_json::Value) -> Result<Vec<OsmFeature
// Second pass: build ways // Second pass: build ways
for el in &elements { for el in &elements {
if el.get("type").and_then(|t| t.as_str()) != Some("way") { continue; } if el.get("type").and_then(|t| t.as_str()) != Some("way") {
continue;
}
let tags = el.get("tags").cloned().unwrap_or(serde_json::json!({})); let tags = el.get("tags").cloned().unwrap_or(serde_json::json!({}));
if tags.get("building").is_none() { continue; } if tags.get("building").is_none() {
continue;
}
let node_ids = el.get("nodes").and_then(|n| n.as_array()).cloned().unwrap_or_default(); let node_ids = el
let outline: Vec<[f64; 2]> = node_ids.iter() .get("nodes")
.and_then(|n| n.as_array())
.cloned()
.unwrap_or_default();
let outline: Vec<[f64; 2]> = node_ids
.iter()
.filter_map(|id| id.as_u64().and_then(|id| nodes.get(&id).copied())) .filter_map(|id| id.as_u64().and_then(|id| nodes.get(&id).copied()))
.collect(); .collect();
if outline.len() < 3 { continue; } if outline.len() < 3 {
continue;
}
let height = tags.get("height").and_then(|h| h.as_str()) let height = tags
.get("height")
.and_then(|h| h.as_str())
.and_then(|s| s.trim_end_matches('m').trim().parse::<f32>().ok()) .and_then(|s| s.trim_end_matches('m').trim().parse::<f32>().ok())
.or(Some(8.0)); // default building height .or(Some(8.0)); // default building height
let name = tags.get("name").and_then(|n| n.as_str()).map(|s| s.to_string()); let name = tags
.get("name")
.and_then(|n| n.as_str())
.map(|s| s.to_string());
buildings.push(OsmFeature::Building { outline, height, name }); buildings.push(OsmFeature::Building {
outline,
height,
name,
});
} }
Ok(buildings) Ok(buildings)
@@ -128,7 +157,11 @@ fn parse_roads(data: &serde_json::Value) -> Result<Vec<OsmFeature>> {
let mut roads = Vec::new(); let mut roads = Vec::new();
let mut nodes: std::collections::HashMap<u64, [f64; 2]> = std::collections::HashMap::new(); let mut nodes: std::collections::HashMap<u64, [f64; 2]> = std::collections::HashMap::new();
let elements = data.get("elements").and_then(|e| e.as_array()).cloned().unwrap_or_default(); let elements = data
.get("elements")
.and_then(|e| e.as_array())
.cloned()
.unwrap_or_default();
for el in &elements { for el in &elements {
if el.get("type").and_then(|t| t.as_str()) == Some("node") { if el.get("type").and_then(|t| t.as_str()) == Some("node") {
@@ -143,19 +176,33 @@ fn parse_roads(data: &serde_json::Value) -> Result<Vec<OsmFeature>> {
} }
for el in &elements { for el in &elements {
if el.get("type").and_then(|t| t.as_str()) != Some("way") { continue; } if el.get("type").and_then(|t| t.as_str()) != Some("way") {
continue;
}
let tags = el.get("tags").cloned().unwrap_or(serde_json::json!({})); let tags = el.get("tags").cloned().unwrap_or(serde_json::json!({}));
let highway = tags.get("highway").and_then(|h| h.as_str()); let highway = tags.get("highway").and_then(|h| h.as_str());
if highway.is_none() { continue; } if highway.is_none() {
continue;
}
let node_ids = el.get("nodes").and_then(|n| n.as_array()).cloned().unwrap_or_default(); let node_ids = el
let path: Vec<[f64; 2]> = node_ids.iter() .get("nodes")
.and_then(|n| n.as_array())
.cloned()
.unwrap_or_default();
let path: Vec<[f64; 2]> = node_ids
.iter()
.filter_map(|id| id.as_u64().and_then(|id| nodes.get(&id).copied())) .filter_map(|id| id.as_u64().and_then(|id| nodes.get(&id).copied()))
.collect(); .collect();
if path.len() < 2 { continue; } if path.len() < 2 {
continue;
}
let name = tags.get("name").and_then(|n| n.as_str()).map(|s| s.to_string()); let name = tags
.get("name")
.and_then(|n| n.as_str())
.map(|s| s.to_string());
roads.push(OsmFeature::Road { roads.push(OsmFeature::Road {
path, path,
@@ -209,7 +256,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn fetch_buildings_rejects_oversized_radius() { async fn fetch_buildings_rejects_oversized_radius() {
let center = GeoPoint { lat: 43.0, lon: -79.0, alt: 0.0 }; let center = GeoPoint {
lat: 43.0,
lon: -79.0,
alt: 0.0,
};
let err = fetch_buildings(&center, MAX_RADIUS_M + 1.0).await.err(); let err = fetch_buildings(&center, MAX_RADIUS_M + 1.0).await.err();
assert!(err.is_some(), "should reject radius > MAX_RADIUS_M"); assert!(err.is_some(), "should reject radius > MAX_RADIUS_M");
} }
+33 -12
View File
@@ -18,13 +18,28 @@ pub async fn fetch_weather(point: &GeoPoint) -> Result<WeatherData> {
.build()?; .build()?;
let resp: serde_json::Value = client.get(&url).send().await?.json().await?; let resp: serde_json::Value = client.get(&url).send().await?.json().await?;
let current = resp.get("current").cloned().unwrap_or(serde_json::json!({})); let current = resp
.get("current")
.cloned()
.unwrap_or(serde_json::json!({}));
Ok(WeatherData { Ok(WeatherData {
temperature_c: current.get("temperature_2m").and_then(|v| v.as_f64()).unwrap_or(0.0) as f32, temperature_c: current
humidity_pct: current.get("relative_humidity_2m").and_then(|v| v.as_f64()).unwrap_or(0.0) as f32, .get("temperature_2m")
wind_speed_ms: current.get("wind_speed_10m").and_then(|v| v.as_f64()).unwrap_or(0.0) as f32, .and_then(|v| v.as_f64())
weather_code: current.get("weather_code").and_then(|v| v.as_u64()).unwrap_or(0) as u16, .unwrap_or(0.0) as f32,
humidity_pct: current
.get("relative_humidity_2m")
.and_then(|v| v.as_f64())
.unwrap_or(0.0) as f32,
wind_speed_ms: current
.get("wind_speed_10m")
.and_then(|v| v.as_f64())
.unwrap_or(0.0) as f32,
weather_code: current
.get("weather_code")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u16,
}) })
} }
@@ -33,7 +48,8 @@ pub async fn check_osm_changes(scene: &GeoScene, cache: &TileCache) -> Result<Ve
let mut changes = Vec::new(); let mut changes = Vec::new();
let cache_key = "osm_building_count"; let cache_key = "osm_building_count";
let prev_count: usize = cache.get(cache_key) let prev_count: usize = cache
.get(cache_key)
.and_then(|d| String::from_utf8(d).ok()) .and_then(|d| String::from_utf8(d).ok())
.and_then(|s| s.trim().parse().ok()) .and_then(|s| s.trim().parse().ok())
.unwrap_or(0); .unwrap_or(0);
@@ -41,7 +57,10 @@ pub async fn check_osm_changes(scene: &GeoScene, cache: &TileCache) -> Result<Ve
let current_count = scene.buildings.len(); let current_count = scene.buildings.len();
if prev_count > 0 && current_count != prev_count { if prev_count > 0 && current_count != prev_count {
let diff = current_count as i64 - prev_count as i64; let diff = current_count as i64 - prev_count as i64;
changes.push(format!("Building count changed: {}{} ({:+})", prev_count, current_count, diff)); changes.push(format!(
"Building count changed: {}{} ({:+})",
prev_count, current_count, diff
));
} }
cache.put(cache_key, current_count.to_string().as_bytes())?; cache.put(cache_key, current_count.to_string().as_bytes())?;
@@ -199,9 +218,7 @@ pub fn is_night_at(lat_deg: f64, utc: chrono::DateTime<chrono::Utc>) -> bool {
// Solar declination (Spencer, 1971 — simplified) // Solar declination (Spencer, 1971 — simplified)
let gamma = 2.0 * PI * (day_of_year - 1.0) / 365.0; let gamma = 2.0 * PI * (day_of_year - 1.0) / 365.0;
let decl = 0.006918 let decl = 0.006918 - 0.399912 * gamma.cos() + 0.070257 * gamma.sin()
- 0.399912 * gamma.cos()
+ 0.070257 * gamma.sin()
- 0.006758 * (2.0 * gamma).cos() - 0.006758 * (2.0 * gamma).cos()
+ 0.000907 * (2.0 * gamma).sin(); + 0.000907 * (2.0 * gamma).sin();
@@ -290,7 +307,9 @@ mod tests {
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let result = rt.block_on(detect_tile_changes("test_tile_ident", &data, &cache)).unwrap(); let result = rt
.block_on(detect_tile_changes("test_tile_ident", &data, &cache))
.unwrap();
assert!((result.diff_score - 0.0).abs() < 1e-9); assert!((result.diff_score - 0.0).abs() < 1e-9);
assert_eq!(result.changed_pixels, 0); assert_eq!(result.changed_pixels, 0);
} }
@@ -306,7 +325,9 @@ mod tests {
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let result = rt.block_on(detect_tile_changes("test_tile_diff", &new, &cache)).unwrap(); let result = rt
.block_on(detect_tile_changes("test_tile_diff", &new, &cache))
.unwrap();
assert!((result.diff_score - 1.0).abs() < 1e-9); assert!((result.diff_score - 1.0).abs() < 1e-9);
} }
} }
+38 -17
View File
@@ -10,7 +10,13 @@ pub async fn fetch_elevation(point: &GeoPoint, cache: &TileCache) -> Result<Elev
let lon_int = point.lon.floor() as i32; let lon_int = point.lon.floor() as i32;
let ns = if lat_int >= 0 { 'N' } else { 'S' }; let ns = if lat_int >= 0 { 'N' } else { 'S' };
let ew = if lon_int >= 0 { 'E' } else { 'W' }; let ew = if lon_int >= 0 { 'E' } else { 'W' };
let filename = format!("{}{:02}{}{:03}.hgt", ns, lat_int.unsigned_abs(), ew, lon_int.unsigned_abs()); let filename = format!(
"{}{:02}{}{:03}.hgt",
ns,
lat_int.unsigned_abs(),
ew,
lon_int.unsigned_abs()
);
let cache_key = format!("srtm_{filename}"); let cache_key = format!("srtm_{filename}");
if let Some(data) = cache.get(&cache_key) { if let Some(data) = cache.get(&cache_key) {
@@ -22,9 +28,8 @@ pub async fn fetch_elevation(point: &GeoPoint, cache: &TileCache) -> Result<Elev
.build()?; .build()?;
// Primary: NASA SRTM public mirror (no auth required for .hgt) // Primary: NASA SRTM public mirror (no auth required for .hgt)
let nasa_url = format!( let nasa_url =
"https://e4ftl01.cr.usgs.gov/MEASURES/SRTMGL1.003/2000.02.11/{filename}" format!("https://e4ftl01.cr.usgs.gov/MEASURES/SRTMGL1.003/2000.02.11/{filename}");
);
if let Ok(resp) = client.get(&nasa_url).send().await { if let Ok(resp) = client.get(&nasa_url).send().await {
if resp.status().is_success() { if resp.status().is_success() {
@@ -37,9 +42,7 @@ pub async fn fetch_elevation(point: &GeoPoint, cache: &TileCache) -> Result<Elev
// Fallback: viewfinderpanoramas.org // Fallback: viewfinderpanoramas.org
// Files are grouped by continent zip, but individual .hgt files can be // Files are grouped by continent zip, but individual .hgt files can be
// fetched directly when the server exposes them. // fetched directly when the server exposes them.
let vfp_url = format!( let vfp_url = format!("http://viewfinderpanoramas.org/dem1/{filename}");
"http://viewfinderpanoramas.org/dem1/{filename}"
);
if let Ok(resp) = client.get(&vfp_url).send().await { if let Ok(resp) = client.get(&vfp_url).send().await {
if resp.status().is_success() { if resp.status().is_success() {
@@ -54,7 +57,8 @@ pub async fn fetch_elevation(point: &GeoPoint, cache: &TileCache) -> Result<Elev
origin_lat: lat_int as f64, origin_lat: lat_int as f64,
origin_lon: lon_int as f64, origin_lon: lon_int as f64,
cell_size_deg: 1.0 / 3600.0, cell_size_deg: 1.0 / 3600.0,
cols: 100, rows: 100, cols: 100,
rows: 100,
heights: vec![0.0; 10000], heights: vec![0.0; 10000],
}) })
} }
@@ -64,17 +68,24 @@ pub fn parse_hgt(data: &[u8], origin_lat: f64, origin_lon: f64) -> Result<Elevat
let n_samples = data.len() / 2; let n_samples = data.len() / 2;
let side = (n_samples as f64).sqrt() as usize; let side = (n_samples as f64).sqrt() as usize;
let heights: Vec<f32> = data.chunks_exact(2) let heights: Vec<f32> = data
.chunks_exact(2)
.map(|c| { .map(|c| {
let v = i16::from_be_bytes([c[0], c[1]]); let v = i16::from_be_bytes([c[0], c[1]]);
if v == -32768 { 0.0 } else { v as f32 } // -32768 = void if v == -32768 {
0.0
} else {
v as f32
} // -32768 = void
}) })
.collect(); .collect();
Ok(ElevationGrid { Ok(ElevationGrid {
origin_lat, origin_lon, origin_lat,
origin_lon,
cell_size_deg: 1.0 / (side - 1) as f64, cell_size_deg: 1.0 / (side - 1) as f64,
cols: side, rows: side, cols: side,
rows: side,
heights, heights,
}) })
} }
@@ -87,10 +98,18 @@ pub fn elevation_at(grid: &ElevationGrid, point: &GeoPoint) -> f32 {
/// Extract a small subgrid around a point. /// Extract a small subgrid around a point.
pub fn extract_subgrid(grid: &ElevationGrid, center: &GeoPoint, radius_m: f64) -> ElevationGrid { pub fn extract_subgrid(grid: &ElevationGrid, center: &GeoPoint, radius_m: f64) -> ElevationGrid {
let radius_deg = radius_m / 111_320.0; let radius_deg = radius_m / 111_320.0;
let min_row = ((grid.origin_lat + (grid.rows as f64 * grid.cell_size_deg) - center.lat - radius_deg) / grid.cell_size_deg).max(0.0) as usize; let min_row =
let max_row = ((grid.origin_lat + (grid.rows as f64 * grid.cell_size_deg) - center.lat + radius_deg) / grid.cell_size_deg).min(grid.rows as f64) as usize; ((grid.origin_lat + (grid.rows as f64 * grid.cell_size_deg) - center.lat - radius_deg)
let min_col = ((center.lon - radius_deg - grid.origin_lon) / grid.cell_size_deg).max(0.0) as usize; / grid.cell_size_deg)
let max_col = ((center.lon + radius_deg - grid.origin_lon) / grid.cell_size_deg).min(grid.cols as f64) as usize; .max(0.0) as usize;
let max_row = ((grid.origin_lat + (grid.rows as f64 * grid.cell_size_deg) - center.lat
+ radius_deg)
/ grid.cell_size_deg)
.min(grid.rows as f64) as usize;
let min_col =
((center.lon - radius_deg - grid.origin_lon) / grid.cell_size_deg).max(0.0) as usize;
let max_col = ((center.lon + radius_deg - grid.origin_lon) / grid.cell_size_deg)
.min(grid.cols as f64) as usize;
let rows = max_row.saturating_sub(min_row); let rows = max_row.saturating_sub(min_row);
let cols = max_col.saturating_sub(min_col); let cols = max_col.saturating_sub(min_col);
@@ -105,6 +124,8 @@ pub fn extract_subgrid(grid: &ElevationGrid, center: &GeoPoint, radius_m: f64) -
origin_lat: grid.origin_lat + (grid.rows - max_row) as f64 * grid.cell_size_deg, origin_lat: grid.origin_lat + (grid.rows - max_row) as f64 * grid.cell_size_deg,
origin_lon: grid.origin_lon + min_col as f64 * grid.cell_size_deg, origin_lon: grid.origin_lon + min_col as f64 * grid.cell_size_deg,
cell_size_deg: grid.cell_size_deg, cell_size_deg: grid.cell_size_deg,
cols, rows, heights, cols,
rows,
heights,
} }
} }
+21 -4
View File
@@ -43,11 +43,19 @@ impl TileProvider {
} }
/// Fetch a single tile with caching. /// Fetch a single tile with caching.
pub async fn fetch_tile(provider: &TileProvider, coord: &TileCoord, cache: &TileCache) -> Result<RasterTile> { pub async fn fetch_tile(
provider: &TileProvider,
coord: &TileCoord,
cache: &TileCache,
) -> Result<RasterTile> {
let cache_key = format!("tiles_{}_{}_{}.dat", coord.z, coord.x, coord.y); let cache_key = format!("tiles_{}_{}_{}.dat", coord.z, coord.x, coord.y);
if let Some(data) = cache.get(&cache_key) { if let Some(data) = cache.get(&cache_key) {
return Ok(RasterTile { coord: coord.clone(), data, bounds: coord::tile_bounds(coord) }); return Ok(RasterTile {
coord: coord.clone(),
data,
bounds: coord::tile_bounds(coord),
});
} }
let url = provider.url(coord); let url = provider.url(coord);
@@ -63,11 +71,20 @@ pub async fn fetch_tile(provider: &TileProvider, coord: &TileCoord, cache: &Tile
let data = resp.bytes().await?.to_vec(); let data = resp.bytes().await?.to_vec();
cache.put(&cache_key, &data)?; cache.put(&cache_key, &data)?;
Ok(RasterTile { coord: coord.clone(), data, bounds: coord::tile_bounds(coord) }) Ok(RasterTile {
coord: coord.clone(),
data,
bounds: coord::tile_bounds(coord),
})
} }
/// Fetch all tiles covering a bounding box. /// Fetch all tiles covering a bounding box.
pub async fn fetch_area(provider: &TileProvider, bbox: &GeoBBox, zoom: u8, cache: &TileCache) -> Result<Vec<RasterTile>> { pub async fn fetch_area(
provider: &TileProvider,
bbox: &GeoBBox,
zoom: u8,
cache: &TileCache,
) -> Result<Vec<RasterTile>> {
let coords = coord::tiles_for_bbox(bbox, zoom); let coords = coord::tiles_for_bbox(bbox, zoom);
let mut tiles = Vec::with_capacity(coords.len()); let mut tiles = Vec::with_capacity(coords.len());
for c in &coords { for c in &coords {
+7 -2
View File
@@ -61,7 +61,8 @@ pub struct ElevationGrid {
impl ElevationGrid { impl ElevationGrid {
pub fn get(&self, lat: f64, lon: f64) -> Option<f32> { pub fn get(&self, lat: f64, lon: f64) -> Option<f32> {
let row = ((self.origin_lat + (self.rows as f64 * self.cell_size_deg) - lat) / self.cell_size_deg) as usize; let row = ((self.origin_lat + (self.rows as f64 * self.cell_size_deg) - lat)
/ self.cell_size_deg) as usize;
let col = ((lon - self.origin_lon) / self.cell_size_deg) as usize; let col = ((lon - self.origin_lon) / self.cell_size_deg) as usize;
if row < self.rows && col < self.cols { if row < self.rows && col < self.cols {
Some(self.heights[row * self.cols + col]) Some(self.heights[row * self.cols + col])
@@ -97,7 +98,11 @@ pub struct GeoRegistration {
impl Default for GeoRegistration { impl Default for GeoRegistration {
fn default() -> Self { fn default() -> Self {
Self { Self {
origin: GeoPoint { lat: 0.0, lon: 0.0, alt: 0.0 }, origin: GeoPoint {
lat: 0.0,
lon: 0.0,
alt: 0.0,
},
heading_deg: 0.0, heading_deg: 0.0,
scale: 1.0, scale: 1.0,
} }
+63 -15
View File
@@ -1,26 +1,58 @@
use wifi_densepose_geo::*;
use wifi_densepose_geo::coord; use wifi_densepose_geo::coord;
use wifi_densepose_geo::*;
#[test] #[test]
fn test_haversine() { fn test_haversine() {
let toronto = GeoPoint { lat: 43.6532, lon: -79.3832, alt: 0.0 }; let toronto = GeoPoint {
let ottawa = GeoPoint { lat: 45.4215, lon: -75.6972, alt: 0.0 }; lat: 43.6532,
lon: -79.3832,
alt: 0.0,
};
let ottawa = GeoPoint {
lat: 45.4215,
lon: -75.6972,
alt: 0.0,
};
let dist = coord::haversine(&toronto, &ottawa); let dist = coord::haversine(&toronto, &ottawa);
assert!((dist - 353_000.0).abs() < 5_000.0, "Toronto-Ottawa ~353km, got {:.0}m", dist); assert!(
(dist - 353_000.0).abs() < 5_000.0,
"Toronto-Ottawa ~353km, got {:.0}m",
dist
);
} }
#[test] #[test]
fn test_wgs84_to_enu() { fn test_wgs84_to_enu() {
let origin = GeoPoint { lat: 43.0, lon: -79.0, alt: 100.0 }; let origin = GeoPoint {
let point = GeoPoint { lat: 43.001, lon: -79.0, alt: 100.0 }; lat: 43.0,
lon: -79.0,
alt: 100.0,
};
let point = GeoPoint {
lat: 43.001,
lon: -79.0,
alt: 100.0,
};
let enu = coord::wgs84_to_enu(&point, &origin); let enu = coord::wgs84_to_enu(&point, &origin);
assert!((enu[1] - 111.0).abs() < 5.0, "0.001 deg lat ~111m north, got {:.1}m", enu[1]); assert!(
assert!(enu[0].abs() < 1.0, "same longitude should have ~0 east, got {:.1}m", enu[0]); (enu[1] - 111.0).abs() < 5.0,
"0.001 deg lat ~111m north, got {:.1}m",
enu[1]
);
assert!(
enu[0].abs() < 1.0,
"same longitude should have ~0 east, got {:.1}m",
enu[0]
);
} }
#[test] #[test]
fn test_enu_roundtrip() { fn test_enu_roundtrip() {
let origin = GeoPoint { lat: 43.6532, lon: -79.3832, alt: 76.0 }; let origin = GeoPoint {
lat: 43.6532,
lon: -79.3832,
alt: 76.0,
};
let local = [100.0, 200.0, 5.0]; // 100m east, 200m north, 5m up let local = [100.0, 200.0, 5.0]; // 100m east, 200m north, 5m up
let geo = coord::enu_to_wgs84(&local, &origin); let geo = coord::enu_to_wgs84(&local, &origin);
let back = coord::wgs84_to_enu(&geo, &origin); let back = coord::wgs84_to_enu(&geo, &origin);
@@ -41,16 +73,28 @@ fn test_tile_coords() {
#[test] #[test]
fn test_tiles_for_bbox() { fn test_tiles_for_bbox() {
let bbox = GeoBBox::from_center( let bbox = GeoBBox::from_center(
&GeoPoint { lat: 43.6532, lon: -79.3832, alt: 0.0 }, &GeoPoint {
lat: 43.6532,
lon: -79.3832,
alt: 0.0,
},
500.0, 500.0,
); );
let tiles = coord::tiles_for_bbox(&bbox, 16); let tiles = coord::tiles_for_bbox(&bbox, 16);
assert!(tiles.len() >= 4 && tiles.len() <= 25, "500m radius should need 4-25 tiles, got {}", tiles.len()); assert!(
tiles.len() >= 4 && tiles.len() <= 25,
"500m radius should need 4-25 tiles, got {}",
tiles.len()
);
} }
#[test] #[test]
fn test_geo_bbox_from_center() { fn test_geo_bbox_from_center() {
let center = GeoPoint { lat: 43.0, lon: -79.0, alt: 0.0 }; let center = GeoPoint {
lat: 43.0,
lon: -79.0,
alt: 0.0,
};
let bbox = GeoBBox::from_center(&center, 1000.0); let bbox = GeoBBox::from_center(&center, 1000.0);
assert!(bbox.south < 43.0 && bbox.north > 43.0); assert!(bbox.south < 43.0 && bbox.north > 43.0);
assert!(bbox.west < -79.0 && bbox.east > -79.0); assert!(bbox.west < -79.0 && bbox.east > -79.0);
@@ -70,14 +114,18 @@ fn test_hgt_parse() {
#[test] #[test]
fn test_registration() { fn test_registration() {
let origin = GeoPoint { lat: 43.6532, lon: -79.3832, alt: 76.0 }; let origin = GeoPoint {
lat: 43.6532,
lon: -79.3832,
alt: 76.0,
};
let reg = wifi_densepose_geo::register::auto_register(&origin); let reg = wifi_densepose_geo::register::auto_register(&origin);
let local = [10.0f32, 0.0, 20.0]; // 10m east, 20m forward let local = [10.0f32, 0.0, 20.0]; // 10m east, 20m forward
let geo = wifi_densepose_geo::register::local_to_wgs84(&reg, &local); let geo = wifi_densepose_geo::register::local_to_wgs84(&reg, &local);
assert!((geo.lat - origin.lat).abs() < 0.001); assert!((geo.lat - origin.lat).abs() < 0.001);
assert!((geo.lon - origin.lon).abs() < 0.001); assert!((geo.lon - origin.lon).abs() < 0.001);
let back = wifi_densepose_geo::register::wgs84_to_local(&reg, &geo); let back = wifi_densepose_geo::register::wgs84_to_local(&reg, &geo);
assert!((back[0] - local[0]).abs() < 0.1); assert!((back[0] - local[0]).abs() < 0.1);
assert!((back[2] - local[2]).abs() < 0.1); assert!((back[2] - local[2]).abs() < 0.1);
@@ -6,12 +6,11 @@
//! - Replay window check performance //! - Replay window check performance
//! - FramedMessage encode/decode throughput //! - FramedMessage encode/decode throughput
use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use std::time::Duration; use std::time::Duration;
use wifi_densepose_hardware::esp32::{ use wifi_densepose_hardware::esp32::{
TdmSchedule, SyncBeacon, SecurityMode, QuicTransportConfig, AuthenticatedBeacon, FramedMessage, MessageType, QuicTransportConfig, ReplayWindow, SecLevel,
SecureTdmCoordinator, SecureTdmConfig, SecLevel, SecureTdmConfig, SecureTdmCoordinator, SecurityMode, SyncBeacon, TdmSchedule,
AuthenticatedBeacon, ReplayWindow, FramedMessage, MessageType,
}; };
fn make_beacon() -> SyncBeacon { fn make_beacon() -> SyncBeacon {
@@ -43,12 +42,14 @@ fn bench_beacon_serialize_authenticated(c: &mut Criterion) {
c.bench_function("beacon_serialize_28byte_auth", |b| { c.bench_function("beacon_serialize_28byte_auth", |b| {
b.iter(|| { b.iter(|| {
let tag = AuthenticatedBeacon::compute_tag(black_box(&msg), &key); let tag = AuthenticatedBeacon::compute_tag(black_box(&msg), &key);
black_box(AuthenticatedBeacon { black_box(
beacon: beacon.clone(), AuthenticatedBeacon {
nonce, beacon: beacon.clone(),
hmac_tag: tag, nonce,
} hmac_tag: tag,
.to_bytes()); }
.to_bytes(),
);
}); });
}); });
} }
@@ -114,15 +115,11 @@ fn bench_framed_message_roundtrip(c: &mut Criterion) {
let msg = FramedMessage::new(MessageType::CsiFrame, payload); let msg = FramedMessage::new(MessageType::CsiFrame, payload);
let bytes = msg.to_bytes(); let bytes = msg.to_bytes();
group.bench_with_input( group.bench_with_input(BenchmarkId::new("encode", payload_size), &msg, |b, msg| {
BenchmarkId::new("encode", payload_size), b.iter(|| {
&msg, black_box(msg.to_bytes());
|b, msg| { });
b.iter(|| { });
black_box(msg.to_bytes());
});
},
);
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("decode", payload_size), BenchmarkId::new("decode", payload_size),
@@ -8,7 +8,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
use std::net::{SocketAddr, UdpSocket}; use std::net::{SocketAddr, UdpSocket};
use std::sync::mpsc::{self, SyncSender, Receiver}; use std::sync::mpsc::{self, Receiver, SyncSender};
use crate::csi_frame::CsiFrame; use crate::csi_frame::CsiFrame;
use crate::esp32_parser::Esp32CsiParser; use crate::esp32_parser::Esp32CsiParser;
@@ -58,11 +58,7 @@ impl NodeState {
fn update(&mut self, sequence: u32) -> u32 { fn update(&mut self, sequence: u32) -> u32 {
self.frames_received += 1; self.frames_received += 1;
let expected = self.last_sequence.wrapping_add(1); let expected = self.last_sequence.wrapping_add(1);
let gap = if sequence > expected { let gap = sequence.saturating_sub(expected);
sequence - expected
} else {
0
};
self.frames_dropped += gap as u64; self.frames_dropped += gap as u64;
self.last_sequence = sequence; self.last_sequence = sequence;
gap gap
@@ -14,7 +14,10 @@ use wifi_densepose_hardware::{Esp32CsiParser, ParseError};
/// UDP aggregator for ESP32 CSI nodes (ADR-018). /// UDP aggregator for ESP32 CSI nodes (ADR-018).
#[derive(Parser)] #[derive(Parser)]
#[command(name = "aggregator", about = "Receive and display live CSI frames from ESP32 nodes")] #[command(
name = "aggregator",
about = "Receive and display live CSI frames from ESP32 nodes"
)]
struct Cli { struct Cli {
/// Address:port to bind the UDP listener to. /// Address:port to bind the UDP listener to.
#[arg(long, default_value = "0.0.0.0:5005")] #[arg(long, default_value = "0.0.0.0:5005")]
+51 -15
View File
@@ -79,11 +79,7 @@ mod tests {
use crate::csi_frame::{AntennaConfig, Bandwidth, CsiMetadata, SubcarrierData}; use crate::csi_frame::{AntennaConfig, Bandwidth, CsiMetadata, SubcarrierData};
use chrono::Utc; use chrono::Utc;
fn make_frame( fn make_frame(node_id: u8, n_antennas: u8, subcarriers: Vec<SubcarrierData>) -> CsiFrame {
node_id: u8,
n_antennas: u8,
subcarriers: Vec<SubcarrierData>,
) -> CsiFrame {
let n_subcarriers = if n_antennas == 0 { let n_subcarriers = if n_antennas == 0 {
subcarriers.len() subcarriers.len()
} else { } else {
@@ -113,8 +109,16 @@ mod tests {
#[test] #[test]
fn test_bridge_from_known_iq() { fn test_bridge_from_known_iq() {
let subs = vec![ let subs = vec![
SubcarrierData { i: 3, q: 4, index: -1 }, // amp = 5.0 SubcarrierData {
SubcarrierData { i: 0, q: 10, index: 1 }, // amp = 10.0 i: 3,
q: 4,
index: -1,
}, // amp = 5.0
SubcarrierData {
i: 0,
q: 10,
index: 1,
}, // amp = 10.0
]; ];
let frame = make_frame(1, 1, subs); let frame = make_frame(1, 1, subs);
let data: CsiData = frame.into(); let data: CsiData = frame.into();
@@ -128,12 +132,36 @@ mod tests {
fn test_bridge_multi_antenna() { fn test_bridge_multi_antenna() {
// 2 antennas, 3 subcarriers each = 6 total // 2 antennas, 3 subcarriers each = 6 total
let subs = vec![ let subs = vec![
SubcarrierData { i: 1, q: 0, index: -1 }, SubcarrierData {
SubcarrierData { i: 2, q: 0, index: 0 }, i: 1,
SubcarrierData { i: 3, q: 0, index: 1 }, q: 0,
SubcarrierData { i: 4, q: 0, index: -1 }, index: -1,
SubcarrierData { i: 5, q: 0, index: 0 }, },
SubcarrierData { i: 6, q: 0, index: 1 }, SubcarrierData {
i: 2,
q: 0,
index: 0,
},
SubcarrierData {
i: 3,
q: 0,
index: 1,
},
SubcarrierData {
i: 4,
q: 0,
index: -1,
},
SubcarrierData {
i: 5,
q: 0,
index: 0,
},
SubcarrierData {
i: 6,
q: 0,
index: 1,
},
]; ];
let frame = make_frame(1, 2, subs); let frame = make_frame(1, 2, subs);
let data: CsiData = frame.into(); let data: CsiData = frame.into();
@@ -146,7 +174,11 @@ mod tests {
#[test] #[test]
fn test_bridge_snr_computation() { fn test_bridge_snr_computation() {
let subs = vec![SubcarrierData { i: 1, q: 0, index: 0 }]; let subs = vec![SubcarrierData {
i: 1,
q: 0,
index: 0,
}];
let frame = make_frame(1, 1, subs); let frame = make_frame(1, 1, subs);
let data: CsiData = frame.into(); let data: CsiData = frame.into();
@@ -156,7 +188,11 @@ mod tests {
#[test] #[test]
fn test_bridge_preserves_metadata() { fn test_bridge_preserves_metadata() {
let subs = vec![SubcarrierData { i: 10, q: 20, index: 0 }]; let subs = vec![SubcarrierData {
i: 10,
q: 20,
index: 0,
}];
let frame = make_frame(7, 1, subs); let frame = make_frame(7, 1, subs);
let data: CsiData = frame.into(); let data: CsiData = frame.into();
@@ -28,11 +28,15 @@ impl CsiFrame {
/// - amplitude = sqrt(I^2 + Q^2) /// - amplitude = sqrt(I^2 + Q^2)
/// - phase = atan2(Q, I) /// - phase = atan2(Q, I)
pub fn to_amplitude_phase(&self) -> (Vec<f64>, Vec<f64>) { pub fn to_amplitude_phase(&self) -> (Vec<f64>, Vec<f64>) {
let amplitudes: Vec<f64> = self.subcarriers.iter() let amplitudes: Vec<f64> = self
.subcarriers
.iter()
.map(|sc| (sc.i as f64 * sc.i as f64 + sc.q as f64 * sc.q as f64).sqrt()) .map(|sc| (sc.i as f64 * sc.i as f64 + sc.q as f64 * sc.q as f64).sqrt())
.collect(); .collect();
let phases: Vec<f64> = self.subcarriers.iter() let phases: Vec<f64> = self
.subcarriers
.iter()
.map(|sc| (sc.q as f64).atan2(sc.i as f64)) .map(|sc| (sc.q as f64).atan2(sc.i as f64))
.collect(); .collect();
@@ -44,7 +48,9 @@ impl CsiFrame {
if self.subcarriers.is_empty() { if self.subcarriers.is_empty() {
return 0.0; return 0.0;
} }
let sum: f64 = self.subcarriers.iter() let sum: f64 = self
.subcarriers
.iter()
.map(|sc| (sc.i as f64 * sc.i as f64 + sc.q as f64 * sc.q as f64).sqrt()) .map(|sc| (sc.i as f64 * sc.i as f64 + sc.q as f64 * sc.q as f64).sqrt())
.sum(); .sum();
sum / self.subcarriers.len() as f64 sum / self.subcarriers.len() as f64
@@ -52,8 +58,7 @@ impl CsiFrame {
/// Check if this frame has valid data (non-zero subcarriers with non-zero I/Q). /// Check if this frame has valid data (non-zero subcarriers with non-zero I/Q).
pub fn is_valid(&self) -> bool { pub fn is_valid(&self) -> bool {
!self.subcarriers.is_empty() !self.subcarriers.is_empty() && self.subcarriers.iter().any(|sc| sc.i != 0 || sc.q != 0)
&& self.subcarriers.iter().any(|sc| sc.i != 0 || sc.q != 0)
} }
} }
@@ -156,9 +161,21 @@ mod tests {
sequence: 1, sequence: 1,
}, },
subcarriers: vec![ subcarriers: vec![
SubcarrierData { i: 100, q: 0, index: -28 }, SubcarrierData {
SubcarrierData { i: 0, q: 50, index: -27 }, i: 100,
SubcarrierData { i: 30, q: 40, index: -26 }, q: 0,
index: -28,
},
SubcarrierData {
i: 0,
q: 50,
index: -27,
},
SubcarrierData {
i: 30,
q: 40,
index: -26,
},
], ],
} }
} }
+8 -30
View File
@@ -7,17 +7,11 @@ use thiserror::Error;
pub enum ParseError { pub enum ParseError {
/// Not enough bytes in the buffer to parse a complete frame. /// Not enough bytes in the buffer to parse a complete frame.
#[error("Insufficient data: need {needed} bytes, got {got}")] #[error("Insufficient data: need {needed} bytes, got {got}")]
InsufficientData { InsufficientData { needed: usize, got: usize },
needed: usize,
got: usize,
},
/// The frame header magic bytes don't match expected values. /// The frame header magic bytes don't match expected values.
#[error("Invalid magic: expected {expected:#06x}, got {got:#06x}")] #[error("Invalid magic: expected {expected:#06x}, got {got:#06x}")]
InvalidMagic { InvalidMagic { expected: u32, got: u32 },
expected: u32,
got: u32,
},
/// A recognized RuView wire packet was received that is *not* an /// A recognized RuView wire packet was received that is *not* an
/// ADR-018 raw CSI frame (e.g. ADR-039 vitals, ADR-081 feature state, /// ADR-018 raw CSI frame (e.g. ADR-039 vitals, ADR-081 feature state,
@@ -26,41 +20,25 @@ pub enum ParseError {
/// interleaved with CSI frames — that is expected, not a corruption. /// interleaved with CSI frames — that is expected, not a corruption.
/// Consumers should route the packet to the matching decoder or skip it. /// Consumers should route the packet to the matching decoder or skip it.
#[error("Non-CSI RuView packet on CSI socket: {kind} (magic {magic:#010x})")] #[error("Non-CSI RuView packet on CSI socket: {kind} (magic {magic:#010x})")]
NonCsiPacket { NonCsiPacket { magic: u32, kind: &'static str },
magic: u32,
kind: &'static str,
},
/// The frame indicates more subcarriers than physically possible. /// The frame indicates more subcarriers than physically possible.
#[error("Invalid subcarrier count: {count} (max {max})")] #[error("Invalid subcarrier count: {count} (max {max})")]
InvalidSubcarrierCount { InvalidSubcarrierCount { count: usize, max: usize },
count: usize,
max: usize,
},
/// The I/Q data buffer length doesn't match expected size. /// The I/Q data buffer length doesn't match expected size.
#[error("I/Q data length mismatch: expected {expected}, got {got}")] #[error("I/Q data length mismatch: expected {expected}, got {got}")]
IqLengthMismatch { IqLengthMismatch { expected: usize, got: usize },
expected: usize,
got: usize,
},
/// RSSI value is outside the valid range. /// RSSI value is outside the valid range.
#[error("Invalid RSSI value: {value} dBm (expected -100..0)")] #[error("Invalid RSSI value: {value} dBm (expected -100..0)")]
InvalidRssi { InvalidRssi { value: i32 },
value: i32,
},
/// Invalid antenna count (must be 1-4 for ESP32). /// Invalid antenna count (must be 1-4 for ESP32).
#[error("Invalid antenna count: {count} (expected 1-4)")] #[error("Invalid antenna count: {count} (expected 1-4)")]
InvalidAntennaCount { InvalidAntennaCount { count: u8 },
count: u8,
},
/// Generic byte-level parse error. /// Generic byte-level parse error.
#[error("Parse error at offset {offset}: {message}")] #[error("Parse error at offset {offset}: {message}")]
ByteError { ByteError { offset: usize, message: String },
offset: usize,
message: String,
},
} }
@@ -9,23 +9,18 @@
//! - `quic_transport` -- QUIC-based authenticated transport for aggregator nodes //! - `quic_transport` -- QUIC-based authenticated transport for aggregator nodes
//! - `secure_tdm` -- Secured TDM protocol with dual-mode (QUIC / manual crypto) //! - `secure_tdm` -- Secured TDM protocol with dual-mode (QUIC / manual crypto)
pub mod tdm;
pub mod quic_transport; pub mod quic_transport;
pub mod secure_tdm; pub mod secure_tdm;
pub mod tdm;
pub use tdm::{ pub use tdm::{SyncBeacon, TdmCoordinator, TdmError, TdmSchedule, TdmSlot, TdmSlotCompleted};
TdmSchedule, TdmCoordinator, TdmSlot, TdmSlotCompleted,
SyncBeacon, TdmError,
};
pub use quic_transport::{ pub use quic_transport::{
SecurityMode, QuicTransportConfig, QuicTransportHandle, QuicTransportError, ConnectionState, FramedMessage, MessageType, QuicTransportConfig, QuicTransportError,
TransportStats, ConnectionState, MessageType, FramedMessage, QuicTransportHandle, SecurityMode, TransportStats, STREAM_BEACON, STREAM_CONTROL, STREAM_CSI,
STREAM_BEACON, STREAM_CSI, STREAM_CONTROL,
}; };
pub use secure_tdm::{ pub use secure_tdm::{
SecureTdmCoordinator, SecureTdmConfig, SecureTdmError, AuthenticatedBeacon, ReplayWindow, SecLevel, SecureCycleOutput, SecureTdmConfig,
SecLevel, AuthenticatedBeacon, SecureCycleOutput, SecureTdmCoordinator, SecureTdmError, AUTHENTICATED_BEACON_SIZE,
ReplayWindow, AUTHENTICATED_BEACON_SIZE,
}; };
@@ -41,22 +41,17 @@ pub const STREAM_CONTROL: u64 = 2;
/// Determines whether communication uses manual HMAC/SipHash over /// Determines whether communication uses manual HMAC/SipHash over
/// plain UDP (for constrained ESP32-S3 devices) or QUIC with TLS 1.3 /// plain UDP (for constrained ESP32-S3 devices) or QUIC with TLS 1.3
/// (for aggregator-class nodes). /// (for aggregator-class nodes).
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SecurityMode { pub enum SecurityMode {
/// Manual HMAC-SHA256 beacon auth + SipHash-2-4 frame integrity /// Manual HMAC-SHA256 beacon auth + SipHash-2-4 frame integrity
/// over plain UDP. Suitable for ESP32-S3 with limited memory. /// over plain UDP. Suitable for ESP32-S3 with limited memory.
ManualCrypto, ManualCrypto,
/// QUIC transport with TLS 1.3 AEAD encryption, built-in replay /// QUIC transport with TLS 1.3 AEAD encryption, built-in replay
/// protection, congestion control, and connection migration. /// protection, congestion control, and connection migration.
#[default]
QuicTransport, QuicTransport,
} }
impl Default for SecurityMode {
fn default() -> Self {
SecurityMode::QuicTransport
}
}
impl fmt::Display for SecurityMode { impl fmt::Display for SecurityMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
@@ -336,8 +331,7 @@ impl FramedMessage {
return None; return None;
} }
let msg_type = MessageType::from_byte(buf[0])?; let msg_type = MessageType::from_byte(buf[0])?;
let payload_len = let payload_len = u32::from_le_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
u32::from_le_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
let total = FRAMED_HEADER_SIZE + payload_len; let total = FRAMED_HEADER_SIZE + payload_len;
if buf.len() < total { if buf.len() < total {
return None; return None;
@@ -29,8 +29,8 @@
//! 4. Sent over plain UDP //! 4. Sent over plain UDP
use super::quic_transport::{ use super::quic_transport::{
FramedMessage, MessageType, QuicTransportConfig, FramedMessage, MessageType, QuicTransportConfig, QuicTransportError, QuicTransportHandle,
QuicTransportHandle, QuicTransportError, SecurityMode, SecurityMode,
}; };
use super::tdm::{SyncBeacon, TdmCoordinator, TdmSchedule, TdmSlotCompleted}; use super::tdm::{SyncBeacon, TdmCoordinator, TdmSchedule, TdmSlotCompleted};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
@@ -59,8 +59,7 @@ pub const AUTHENTICATED_BEACON_SIZE: usize = 16 + NONCE_SIZE + HMAC_TAG_SIZE;
/// Default pre-shared key for testing (16 bytes). In production, this /// Default pre-shared key for testing (16 bytes). In production, this
/// would be loaded from NVS or a secure key store. /// would be loaded from NVS or a secure key store.
const DEFAULT_TEST_KEY: [u8; 16] = [ const DEFAULT_TEST_KEY: [u8; 16] = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
]; ];
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -79,7 +78,10 @@ pub enum SecureTdmError {
/// QUIC transport error. /// QUIC transport error.
Transport(QuicTransportError), Transport(QuicTransportError),
/// The security mode does not match the incoming packet format. /// The security mode does not match the incoming packet format.
ModeMismatch { expected: SecurityMode, got: SecurityMode }, ModeMismatch {
expected: SecurityMode,
got: SecurityMode,
},
/// The mesh key has not been provisioned. /// The mesh key has not been provisioned.
NoMeshKey, NoMeshKey,
} }
@@ -88,7 +90,10 @@ impl fmt::Display for SecureTdmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
SecureTdmError::BeaconAuthFailed => write!(f, "Beacon HMAC verification failed"), SecureTdmError::BeaconAuthFailed => write!(f, "Beacon HMAC verification failed"),
SecureTdmError::BeaconReplay { nonce, last_accepted } => { SecureTdmError::BeaconReplay {
nonce,
last_accepted,
} => {
write!( write!(
f, f,
"Beacon replay: nonce {} <= last_accepted {} - REPLAY_WINDOW", "Beacon replay: nonce {} <= last_accepted {} - REPLAY_WINDOW",
@@ -96,11 +101,19 @@ impl fmt::Display for SecureTdmError {
) )
} }
SecureTdmError::BeaconTooShort { expected, got } => { SecureTdmError::BeaconTooShort { expected, got } => {
write!(f, "Beacon too short: expected {} bytes, got {}", expected, got) write!(
f,
"Beacon too short: expected {} bytes, got {}",
expected, got
)
} }
SecureTdmError::Transport(e) => write!(f, "Transport error: {}", e), SecureTdmError::Transport(e) => write!(f, "Transport error: {}", e),
SecureTdmError::ModeMismatch { expected, got } => { SecureTdmError::ModeMismatch { expected, got } => {
write!(f, "Security mode mismatch: expected {}, got {}", expected, got) write!(
f,
"Security mode mismatch: expected {}, got {}",
expected, got
)
} }
SecureTdmError::NoMeshKey => write!(f, "Mesh key not provisioned"), SecureTdmError::NoMeshKey => write!(f, "Mesh key not provisioned"),
} }
@@ -254,8 +267,7 @@ impl AuthenticatedBeacon {
/// Uses the `hmac` + `sha2` crates for cryptographically secure /// Uses the `hmac` + `sha2` crates for cryptographically secure
/// message authentication (ADR-050, Sprint 1). /// message authentication (ADR-050, Sprint 1).
pub fn compute_tag(payload_and_nonce: &[u8], key: &[u8; 16]) -> [u8; HMAC_TAG_SIZE] { pub fn compute_tag(payload_and_nonce: &[u8], key: &[u8; 16]) -> [u8; HMAC_TAG_SIZE] {
let mut mac = HmacSha256::new_from_slice(key) let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts any key length");
.expect("HMAC-SHA256 accepts any key length");
mac.update(payload_and_nonce); mac.update(payload_and_nonce);
let result = mac.finalize().into_bytes(); let result = mac.finalize().into_bytes();
let mut tag = [0u8; HMAC_TAG_SIZE]; let mut tag = [0u8; HMAC_TAG_SIZE];
@@ -346,10 +358,7 @@ pub struct SecureTdmCoordinator {
impl SecureTdmCoordinator { impl SecureTdmCoordinator {
/// Create a new secure TDM coordinator. /// Create a new secure TDM coordinator.
pub fn new( pub fn new(schedule: TdmSchedule, config: SecureTdmConfig) -> Result<Self, SecureTdmError> {
schedule: TdmSchedule,
config: SecureTdmConfig,
) -> Result<Self, SecureTdmError> {
let transport = if config.security_mode == SecurityMode::QuicTransport { let transport = if config.security_mode == SecurityMode::QuicTransport {
Some(QuicTransportHandle::new(config.quic_config.clone())?) Some(QuicTransportHandle::new(config.quic_config.clone())?)
} else { } else {
@@ -400,10 +409,7 @@ impl SecureTdmCoordinator {
} }
SecurityMode::QuicTransport => { SecurityMode::QuicTransport => {
let beacon_bytes = beacon.to_bytes(); let beacon_bytes = beacon.to_bytes();
let framed = FramedMessage::new( let framed = FramedMessage::new(MessageType::Beacon, beacon_bytes.to_vec());
MessageType::Beacon,
beacon_bytes.to_vec(),
);
let wire = framed.to_bytes(); let wire = framed.to_bytes();
if let Some(ref mut transport) = self.transport { if let Some(ref mut transport) = self.transport {
@@ -449,12 +455,11 @@ impl SecureTdmCoordinator {
} }
} else if buf.len() >= 16 && self.config.sec_level != SecLevel::Enforcing { } else if buf.len() >= 16 && self.config.sec_level != SecLevel::Enforcing {
// Accept unauthenticated 16-byte beacon in permissive/transitional // Accept unauthenticated 16-byte beacon in permissive/transitional
let beacon = SyncBeacon::from_bytes(buf).ok_or( let beacon =
SecureTdmError::BeaconTooShort { SyncBeacon::from_bytes(buf).ok_or(SecureTdmError::BeaconTooShort {
expected: 16, expected: 16,
got: buf.len(), got: buf.len(),
}, })?;
)?;
self.beacons_verified += 1; self.beacons_verified += 1;
Ok(beacon) Ok(beacon)
} else { } else {
@@ -466,12 +471,11 @@ impl SecureTdmCoordinator {
} }
SecurityMode::QuicTransport => { SecurityMode::QuicTransport => {
// In QUIC mode, extract beacon from framed message // In QUIC mode, extract beacon from framed message
let (framed, _) = FramedMessage::from_bytes(buf).ok_or( let (framed, _) =
SecureTdmError::BeaconTooShort { FramedMessage::from_bytes(buf).ok_or(SecureTdmError::BeaconTooShort {
expected: 5 + 16, expected: 5 + 16,
got: buf.len(), got: buf.len(),
}, })?;
)?;
if framed.message_type != MessageType::Beacon { if framed.message_type != MessageType::Beacon {
return Err(SecureTdmError::ModeMismatch { return Err(SecureTdmError::ModeMismatch {
expected: SecurityMode::QuicTransport, expected: SecurityMode::QuicTransport,
@@ -496,11 +500,7 @@ impl SecureTdmCoordinator {
} }
/// Complete a slot in the current cycle (delegates to inner coordinator). /// Complete a slot in the current cycle (delegates to inner coordinator).
pub fn complete_slot( pub fn complete_slot(&mut self, slot_index: usize, capture_quality: f32) -> TdmSlotCompleted {
&mut self,
slot_index: usize,
capture_quality: f32,
) -> TdmSlotCompleted {
self.inner.complete_slot(slot_index, capture_quality) self.inner.complete_slot(slot_index, capture_quality)
} }
@@ -755,10 +755,7 @@ mod tests {
#[test] #[test]
fn test_auth_beacon_too_short() { fn test_auth_beacon_too_short() {
let result = AuthenticatedBeacon::from_bytes(&[0u8; 10]); let result = AuthenticatedBeacon::from_bytes(&[0u8; 10]);
assert!(matches!( assert!(matches!(result, Err(SecureTdmError::BeaconTooShort { .. })));
result,
Err(SecureTdmError::BeaconTooShort { .. })
));
} }
#[test] #[test]
@@ -770,8 +767,7 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_manual_create() { fn test_secure_coordinator_manual_create() {
let coord = let coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
assert_eq!(coord.security_mode(), SecurityMode::ManualCrypto); assert_eq!(coord.security_mode(), SecurityMode::ManualCrypto);
assert_eq!(coord.beacons_produced(), 0); assert_eq!(coord.beacons_produced(), 0);
assert!(coord.transport().is_none()); assert!(coord.transport().is_none());
@@ -779,8 +775,7 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_manual_begin_cycle() { fn test_secure_coordinator_manual_begin_cycle() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
let output = coord.begin_secure_cycle().unwrap(); let output = coord.begin_secure_cycle().unwrap();
assert_eq!(output.mode, SecurityMode::ManualCrypto); assert_eq!(output.mode, SecurityMode::ManualCrypto);
@@ -792,8 +787,7 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_manual_nonce_increments() { fn test_secure_coordinator_manual_nonce_increments() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
for expected_nonce in 1..=5u32 { for expected_nonce in 1..=5u32 {
let _output = coord.begin_secure_cycle().unwrap(); let _output = coord.begin_secure_cycle().unwrap();
@@ -807,47 +801,37 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_manual_verify_own_beacon() { fn test_secure_coordinator_manual_verify_own_beacon() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
let output = coord.begin_secure_cycle().unwrap(); let output = coord.begin_secure_cycle().unwrap();
// Create a second coordinator to verify // Create a second coordinator to verify
let mut verifier = let mut verifier = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap(); let beacon = verifier.verify_beacon(&output.authenticated_bytes).unwrap();
let beacon = verifier
.verify_beacon(&output.authenticated_bytes)
.unwrap();
assert_eq!(beacon.cycle_id, 0); assert_eq!(beacon.cycle_id, 0);
} }
#[test] #[test]
fn test_secure_coordinator_manual_reject_tampered() { fn test_secure_coordinator_manual_reject_tampered() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
let output = coord.begin_secure_cycle().unwrap(); let output = coord.begin_secure_cycle().unwrap();
let mut tampered = output.authenticated_bytes.clone(); let mut tampered = output.authenticated_bytes.clone();
tampered[25] ^= 0xFF; // Tamper with HMAC tag tampered[25] ^= 0xFF; // Tamper with HMAC tag
let mut verifier = let mut verifier = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
assert!(verifier.verify_beacon(&tampered).is_err()); assert!(verifier.verify_beacon(&tampered).is_err());
assert_eq!(verifier.verification_failures(), 1); assert_eq!(verifier.verification_failures(), 1);
} }
#[test] #[test]
fn test_secure_coordinator_manual_reject_replay() { fn test_secure_coordinator_manual_reject_replay() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
let output = coord.begin_secure_cycle().unwrap(); let output = coord.begin_secure_cycle().unwrap();
let mut verifier = let mut verifier = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
// First acceptance succeeds // First acceptance succeeds
verifier verifier.verify_beacon(&output.authenticated_bytes).unwrap();
.verify_beacon(&output.authenticated_bytes)
.unwrap();
// Replay of same beacon fails // Replay of same beacon fails
let result = verifier.verify_beacon(&output.authenticated_bytes); let result = verifier.verify_beacon(&output.authenticated_bytes);
@@ -908,16 +892,14 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_quic_create() { fn test_secure_coordinator_quic_create() {
let coord = let coord = SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
assert_eq!(coord.security_mode(), SecurityMode::QuicTransport); assert_eq!(coord.security_mode(), SecurityMode::QuicTransport);
assert!(coord.transport().is_some()); assert!(coord.transport().is_some());
} }
#[test] #[test]
fn test_secure_coordinator_quic_begin_cycle() { fn test_secure_coordinator_quic_begin_cycle() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
let output = coord.begin_secure_cycle().unwrap(); let output = coord.begin_secure_cycle().unwrap();
assert_eq!(output.mode, SecurityMode::QuicTransport); assert_eq!(output.mode, SecurityMode::QuicTransport);
@@ -928,22 +910,17 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_quic_verify_own_beacon() { fn test_secure_coordinator_quic_verify_own_beacon() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
let output = coord.begin_secure_cycle().unwrap(); let output = coord.begin_secure_cycle().unwrap();
let mut verifier = let mut verifier = SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap(); let beacon = verifier.verify_beacon(&output.authenticated_bytes).unwrap();
let beacon = verifier
.verify_beacon(&output.authenticated_bytes)
.unwrap();
assert_eq!(beacon.cycle_id, 0); assert_eq!(beacon.cycle_id, 0);
} }
#[test] #[test]
fn test_secure_coordinator_complete_cycle() { fn test_secure_coordinator_complete_cycle() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
coord.begin_secure_cycle().unwrap(); coord.begin_secure_cycle().unwrap();
for i in 0..4 { for i in 0..4 {
@@ -955,8 +932,7 @@ mod tests {
#[test] #[test]
fn test_secure_coordinator_cycle_id_increments() { fn test_secure_coordinator_cycle_id_increments() {
let mut coord = let mut coord = SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
let out0 = coord.begin_secure_cycle().unwrap(); let out0 = coord.begin_secure_cycle().unwrap();
assert_eq!(out0.beacon.cycle_id, 0); assert_eq!(out0.beacon.cycle_id, 0);
@@ -986,7 +962,10 @@ mod tests {
let key2: [u8; 16] = [0x02; 16]; let key2: [u8; 16] = [0x02; 16];
let tag1 = AuthenticatedBeacon::compute_tag(msg, &key1); let tag1 = AuthenticatedBeacon::compute_tag(msg, &key1);
let tag2 = AuthenticatedBeacon::compute_tag(msg, &key2); let tag2 = AuthenticatedBeacon::compute_tag(msg, &key2);
assert_ne!(tag1, tag2, "Different keys must produce different HMAC tags"); assert_ne!(
tag1, tag2,
"Different keys must produce different HMAC tags"
);
} }
#[test] #[test]
@@ -994,7 +973,10 @@ mod tests {
let key: [u8; 16] = DEFAULT_TEST_KEY; let key: [u8; 16] = DEFAULT_TEST_KEY;
let tag1 = AuthenticatedBeacon::compute_tag(b"message one", &key); let tag1 = AuthenticatedBeacon::compute_tag(b"message one", &key);
let tag2 = AuthenticatedBeacon::compute_tag(b"message two", &key); let tag2 = AuthenticatedBeacon::compute_tag(b"message two", &key);
assert_ne!(tag1, tag2, "Different messages must produce different HMAC tags"); assert_ne!(
tag1, tag2,
"Different messages must produce different HMAC tags"
);
} }
#[test] #[test]
@@ -1023,8 +1005,15 @@ mod tests {
msg[16..20].copy_from_slice(&nonce.to_le_bytes()); msg[16..20].copy_from_slice(&nonce.to_le_bytes());
let tag = AuthenticatedBeacon::compute_tag(&msg, &correct_key); let tag = AuthenticatedBeacon::compute_tag(&msg, &correct_key);
let auth = AuthenticatedBeacon { beacon, nonce, hmac_tag: tag }; let auth = AuthenticatedBeacon {
assert!(auth.verify(&wrong_key).is_err(), "Wrong key must fail verification"); beacon,
nonce,
hmac_tag: tag,
};
assert!(
auth.verify(&wrong_key).is_err(),
"Wrong key must fail verification"
);
} }
#[test] #[test]
@@ -1043,12 +1032,19 @@ mod tests {
msg[16..20].copy_from_slice(&nonce.to_le_bytes()); msg[16..20].copy_from_slice(&nonce.to_le_bytes());
let tag = AuthenticatedBeacon::compute_tag(&msg, &key); let tag = AuthenticatedBeacon::compute_tag(&msg, &key);
let auth = AuthenticatedBeacon { beacon, nonce, hmac_tag: tag }; let auth = AuthenticatedBeacon {
beacon,
nonce,
hmac_tag: tag,
};
let mut wire = auth.to_bytes(); let mut wire = auth.to_bytes();
// Flip one bit in the beacon payload // Flip one bit in the beacon payload
wire[0] ^= 0x01; wire[0] ^= 0x01;
let tampered = AuthenticatedBeacon::from_bytes(&wire).unwrap(); let tampered = AuthenticatedBeacon::from_bytes(&wire).unwrap();
assert!(tampered.verify(&key).is_err(), "Single bit flip must fail verification"); assert!(
tampered.verify(&key).is_err(),
"Single bit flip must fail verification"
);
} }
#[test] #[test]
@@ -1063,7 +1059,8 @@ mod tests {
cycle_period: Duration::from_millis(50), cycle_period: Duration::from_millis(50),
drift_correction_us: 0, drift_correction_us: 0,
generated_at: std::time::Instant::now(), generated_at: std::time::Instant::now(),
}.to_bytes(); }
.to_bytes();
assert!(coord.verify_beacon(&raw).is_err()); assert!(coord.verify_beacon(&raw).is_err());
} }
@@ -67,19 +67,38 @@ impl fmt::Display for TdmError {
write!(f, "Invalid node count: {} (max {})", count, max) write!(f, "Invalid node count: {} (max {})", count, max)
} }
TdmError::SlotIndexOutOfBounds { index, num_slots } => { TdmError::SlotIndexOutOfBounds { index, num_slots } => {
write!(f, "Slot index {} out of bounds (schedule has {} slots)", index, num_slots) write!(
f,
"Slot index {} out of bounds (schedule has {} slots)",
index, num_slots
)
} }
TdmError::UnknownNode { node_id } => { TdmError::UnknownNode { node_id } => {
write!(f, "Unknown node ID: {}", node_id) write!(f, "Unknown node ID: {}", node_id)
} }
TdmError::GuardIntervalTooLarge { guard_us, slot_us } => { TdmError::GuardIntervalTooLarge { guard_us, slot_us } => {
write!(f, "Guard interval {} us exceeds slot duration {} us", guard_us, slot_us) write!(
f,
"Guard interval {} us exceeds slot duration {} us",
guard_us, slot_us
)
} }
TdmError::CycleTooShort { needed_us, available_us } => { TdmError::CycleTooShort {
write!(f, "Cycle too short: need {} us, have {} us", needed_us, available_us) needed_us,
available_us,
} => {
write!(
f,
"Cycle too short: need {} us, have {} us",
needed_us, available_us
)
} }
TdmError::DriftExceedsGuard { drift_us, guard_us } => { TdmError::DriftExceedsGuard { drift_us, guard_us } => {
write!(f, "Drift {:.1} us exceeds guard interval {} us", drift_us, guard_us) write!(
f,
"Drift {:.1} us exceeds guard interval {} us",
drift_us, guard_us
)
} }
} }
} }
@@ -274,7 +293,10 @@ impl TdmSchedule {
/// Check whether clock drift stays within the guard interval. /// Check whether clock drift stays within the guard interval.
pub fn drift_within_guard(&self) -> bool { pub fn drift_within_guard(&self) -> bool {
let drift = self.max_drift_us(); let drift = self.max_drift_us();
let guard = self.slots.first().map_or(0, |s| s.guard_interval.as_micros() as u64); let guard = self
.slots
.first()
.map_or(0, |s| s.guard_interval.as_micros() as u64);
drift < guard as f64 drift < guard as f64
} }
} }
@@ -644,7 +666,10 @@ mod tests {
); );
assert_eq!( assert_eq!(
result.unwrap_err(), result.unwrap_err(),
TdmError::InvalidNodeCount { count: 0, max: MAX_NODES } TdmError::InvalidNodeCount {
count: 0,
max: MAX_NODES
}
); );
} }
@@ -664,11 +689,14 @@ mod tests {
fn test_guard_interval_too_large() { fn test_guard_interval_too_large() {
let result = TdmSchedule::uniform( let result = TdmSchedule::uniform(
&[0, 1], &[0, 1],
Duration::from_millis(1), // 1 ms slot Duration::from_millis(1), // 1 ms slot
Duration::from_millis(2), // 2 ms guard > slot Duration::from_millis(2), // 2 ms guard > slot
Duration::from_millis(30), Duration::from_millis(30),
); );
assert!(matches!(result, Err(TdmError::GuardIntervalTooLarge { .. }))); assert!(matches!(
result,
Err(TdmError::GuardIntervalTooLarge { .. })
));
} }
#[test] #[test]
@@ -113,10 +113,9 @@ impl Esp32CsiParser {
let mut cursor = Cursor::new(data); let mut cursor = Cursor::new(data);
// Magic (offset 0, 4 bytes) // Magic (offset 0, 4 bytes)
let magic = cursor.read_u32::<LittleEndian>().map_err(|_| ParseError::InsufficientData { let magic = cursor
needed: 4, .read_u32::<LittleEndian>()
got: 0, .map_err(|_| ParseError::InsufficientData { needed: 4, got: 0 })?;
})?;
if magic != ESP32_CSI_MAGIC { if magic != ESP32_CSI_MAGIC {
return Err(ParseError::InvalidMagic { return Err(ParseError::InvalidMagic {
@@ -142,10 +141,13 @@ impl Esp32CsiParser {
} }
// Number of subcarriers (offset 6, 2 bytes LE) // Number of subcarriers (offset 6, 2 bytes LE)
let n_subcarriers = cursor.read_u16::<LittleEndian>().map_err(|_| ParseError::ByteError { let n_subcarriers =
offset: 6, cursor
message: "Failed to read subcarrier count".into(), .read_u16::<LittleEndian>()
})? as usize; .map_err(|_| ParseError::ByteError {
offset: 6,
message: "Failed to read subcarrier count".into(),
})? as usize;
if n_subcarriers > MAX_SUBCARRIERS { if n_subcarriers > MAX_SUBCARRIERS {
return Err(ParseError::InvalidSubcarrierCount { return Err(ParseError::InvalidSubcarrierCount {
@@ -155,16 +157,21 @@ impl Esp32CsiParser {
} }
// Frequency MHz (offset 8, 4 bytes LE) // Frequency MHz (offset 8, 4 bytes LE)
let channel_freq_mhz = cursor.read_u32::<LittleEndian>().map_err(|_| ParseError::ByteError { let channel_freq_mhz =
offset: 8, cursor
message: "Failed to read frequency".into(), .read_u32::<LittleEndian>()
})?; .map_err(|_| ParseError::ByteError {
offset: 8,
message: "Failed to read frequency".into(),
})?;
// Sequence number (offset 12, 4 bytes LE) // Sequence number (offset 12, 4 bytes LE)
let sequence = cursor.read_u32::<LittleEndian>().map_err(|_| ParseError::ByteError { let sequence = cursor
offset: 12, .read_u32::<LittleEndian>()
message: "Failed to read sequence number".into(), .map_err(|_| ParseError::ByteError {
})?; offset: 12,
message: "Failed to read sequence number".into(),
})?;
// RSSI (offset 16, 1 byte signed) // RSSI (offset 16, 1 byte signed)
let rssi_dbm = cursor.read_i8().map_err(|_| ParseError::ByteError { let rssi_dbm = cursor.read_i8().map_err(|_| ParseError::ByteError {
@@ -179,10 +186,12 @@ impl Esp32CsiParser {
})?; })?;
// Reserved (offset 18, 2 bytes) — skip // Reserved (offset 18, 2 bytes) — skip
let _reserved = cursor.read_u16::<LittleEndian>().map_err(|_| ParseError::ByteError { let _reserved = cursor
offset: 18, .read_u16::<LittleEndian>()
message: "Failed to read reserved bytes".into(), .map_err(|_| ParseError::ByteError {
})?; offset: 18,
message: "Failed to read reserved bytes".into(),
})?;
// I/Q data: n_antennas * n_subcarriers * 2 bytes // I/Q data: n_antennas * n_subcarriers * 2 bytes
let iq_pair_count = n_antennas as usize * n_subcarriers; let iq_pair_count = n_antennas as usize * n_subcarriers;
@@ -390,11 +399,17 @@ mod tests {
RUVIEW_FEATURE_STATE_MAGIC, RUVIEW_FEATURE_STATE_MAGIC,
RUVIEW_TEMPORAL_MAGIC, RUVIEW_TEMPORAL_MAGIC,
] { ] {
assert!(ruview_sibling_packet_name(m).is_some(), "{m:#010x} unclassified"); assert!(
ruview_sibling_packet_name(m).is_some(),
"{m:#010x} unclassified"
);
let mut data = vec![0u8; 24]; let mut data = vec![0u8; 24];
data[0..4].copy_from_slice(&m.to_le_bytes()); data[0..4].copy_from_slice(&m.to_le_bytes());
assert!( assert!(
matches!(Esp32CsiParser::parse_frame(&data), Err(ParseError::NonCsiPacket { .. })), matches!(
Esp32CsiParser::parse_frame(&data),
Err(ParseError::NonCsiPacket { .. })
),
"{m:#010x} should parse as NonCsiPacket" "{m:#010x} should parse as NonCsiPacket"
); );
} }
+12 -13
View File
@@ -34,12 +34,12 @@
//! } //! }
//! ``` //! ```
mod csi_frame;
mod error;
mod esp32_parser;
pub mod aggregator; pub mod aggregator;
mod bridge; mod bridge;
mod csi_frame;
mod error;
pub mod esp32; pub mod esp32;
mod esp32_parser;
// ADR-081: Rust mirror of the firmware radio abstraction layer (L1) and // ADR-081: Rust mirror of the firmware radio abstraction layer (L1) and
// mesh sensing plane (L3). Lets host tests, simulators, and future // mesh sensing plane (L3). Lets host tests, simulators, and future
@@ -47,18 +47,17 @@ pub mod esp32;
// touching any downstream signal/ruvector/train/mat crate. // touching any downstream signal/ruvector/train/mat crate.
pub mod radio_ops; pub mod radio_ops;
pub use csi_frame::{CsiFrame, CsiMetadata, SubcarrierData, Bandwidth, AntennaConfig}; pub use bridge::CsiData;
pub use csi_frame::{AntennaConfig, Bandwidth, CsiFrame, CsiMetadata, SubcarrierData};
pub use error::ParseError; pub use error::ParseError;
pub use esp32_parser::{ pub use esp32_parser::{
Esp32CsiParser, ruview_sibling_packet_name, ESP32_CSI_MAGIC, RUVIEW_VITALS_MAGIC, ruview_sibling_packet_name, Esp32CsiParser, ESP32_CSI_MAGIC, RUVIEW_COMPRESSED_CSI_MAGIC,
RUVIEW_FEATURE_MAGIC, RUVIEW_FUSED_VITALS_MAGIC, RUVIEW_COMPRESSED_CSI_MAGIC, RUVIEW_FEATURE_MAGIC, RUVIEW_FEATURE_STATE_MAGIC, RUVIEW_FUSED_VITALS_MAGIC,
RUVIEW_FEATURE_STATE_MAGIC, RUVIEW_TEMPORAL_MAGIC, RUVIEW_TEMPORAL_MAGIC, RUVIEW_VITALS_MAGIC,
}; };
pub use bridge::CsiData;
pub use radio_ops::{ pub use radio_ops::{
RadioOps, RadioMode, CaptureProfile, RadioHealth, RadioError, MockRadio, crc32_ieee, decode_anomaly_alert, decode_mesh, decode_node_status, encode_health, AnomalyAlert,
MeshRole, MeshMsgType, AuthClass, MeshHeader, NodeStatus, AnomalyAlert, AuthClass, CaptureProfile, MeshError, MeshHeader, MeshMsgType, MeshRole, MockRadio, NodeStatus,
MeshError, MESH_MAGIC, MESH_VERSION, MESH_HEADER_SIZE, MESH_MAX_PAYLOAD, RadioError, RadioHealth, RadioMode, RadioOps, MESH_HEADER_SIZE, MESH_MAGIC, MESH_MAX_PAYLOAD,
crc32_ieee, decode_mesh, decode_node_status, decode_anomaly_alert, MESH_VERSION,
encode_health,
}; };
@@ -24,10 +24,10 @@ use std::convert::TryFrom;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)] #[repr(u8)]
pub enum RadioMode { pub enum RadioMode {
Disabled = 0, Disabled = 0,
PassiveRx = 1, PassiveRx = 1,
ActiveProbe = 2, ActiveProbe = 2,
Calibration = 3, Calibration = 3,
} }
/// Named capture profiles, mirror of `rv_capture_profile_t`. /// Named capture profiles, mirror of `rv_capture_profile_t`.
@@ -35,10 +35,10 @@ pub enum RadioMode {
#[repr(u8)] #[repr(u8)]
pub enum CaptureProfile { pub enum CaptureProfile {
PassiveLowRate = 0, PassiveLowRate = 0,
ActiveProbe = 1, ActiveProbe = 1,
RespHighSens = 2, RespHighSens = 2,
FastMotion = 3, FastMotion = 3,
Calibration = 4, Calibration = 4,
} }
impl TryFrom<u8> for CaptureProfile { impl TryFrom<u8> for CaptureProfile {
@@ -59,12 +59,12 @@ impl TryFrom<u8> for CaptureProfile {
#[derive(Debug, Clone, Copy, Default, PartialEq)] #[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct RadioHealth { pub struct RadioHealth {
pub pkt_yield_per_sec: u16, pub pkt_yield_per_sec: u16,
pub send_fail_count: u16, pub send_fail_count: u16,
pub rssi_median_dbm: i8, pub rssi_median_dbm: i8,
pub noise_floor_dbm: i8, pub noise_floor_dbm: i8,
pub current_channel: u8, pub current_channel: u8,
pub current_bw_mhz: u8, pub current_bw_mhz: u8,
pub current_profile: u8, pub current_profile: u8,
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -95,12 +95,12 @@ pub trait RadioOps: Send + Sync {
/// A zero-hardware radio backend for host tests and CI. /// A zero-hardware radio backend for host tests and CI.
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct MockRadio { pub struct MockRadio {
pub health: RadioHealth, pub health: RadioHealth,
pub init_count: u32, pub init_count: u32,
pub channel_calls: Vec<(u8, u8)>, pub channel_calls: Vec<(u8, u8)>,
pub profile_calls: Vec<CaptureProfile>, pub profile_calls: Vec<CaptureProfile>,
pub mode_calls: Vec<RadioMode>, pub mode_calls: Vec<RadioMode>,
pub csi_enabled: bool, pub csi_enabled: bool,
} }
impl RadioOps for MockRadio { impl RadioOps for MockRadio {
@@ -111,7 +111,7 @@ impl RadioOps for MockRadio {
fn set_channel(&mut self, ch: u8, bw: u8) -> Result<(), RadioError> { fn set_channel(&mut self, ch: u8, bw: u8) -> Result<(), RadioError> {
self.channel_calls.push((ch, bw)); self.channel_calls.push((ch, bw));
self.health.current_channel = ch; self.health.current_channel = ch;
self.health.current_bw_mhz = bw; self.health.current_bw_mhz = bw;
Ok(()) Ok(())
} }
fn set_mode(&mut self, mode: RadioMode) -> Result<(), RadioError> { fn set_mode(&mut self, mode: RadioMode) -> Result<(), RadioError> {
@@ -137,9 +137,9 @@ impl RadioOps for MockRadio {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
/// `RV_MESH_MAGIC` from rv_mesh.h. /// `RV_MESH_MAGIC` from rv_mesh.h.
pub const MESH_MAGIC: u32 = 0xC511_8100; pub const MESH_MAGIC: u32 = 0xC511_8100;
/// `RV_MESH_VERSION` from rv_mesh.h. /// `RV_MESH_VERSION` from rv_mesh.h.
pub const MESH_VERSION: u8 = 1; pub const MESH_VERSION: u8 = 1;
/// `RV_MESH_MAX_PAYLOAD` from rv_mesh.h. /// `RV_MESH_MAX_PAYLOAD` from rv_mesh.h.
pub const MESH_MAX_PAYLOAD: usize = 256; pub const MESH_MAX_PAYLOAD: usize = 256;
/// `sizeof(rv_mesh_header_t)`. /// `sizeof(rv_mesh_header_t)`.
@@ -149,9 +149,9 @@ pub const MESH_HEADER_SIZE: usize = 16;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)] #[repr(u8)]
pub enum MeshRole { pub enum MeshRole {
Unassigned = 0, Unassigned = 0,
Anchor = 1, Anchor = 1,
Observer = 2, Observer = 2,
FusionRelay = 3, FusionRelay = 3,
Coordinator = 4, Coordinator = 4,
} }
@@ -174,13 +174,13 @@ impl TryFrom<u8> for MeshRole {
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)] #[repr(u8)]
pub enum MeshMsgType { pub enum MeshMsgType {
TimeSync = 0x01, TimeSync = 0x01,
RoleAssign = 0x02, RoleAssign = 0x02,
ChannelPlan = 0x03, ChannelPlan = 0x03,
CalibrationStart = 0x04, CalibrationStart = 0x04,
FeatureDelta = 0x05, FeatureDelta = 0x05,
Health = 0x06, Health = 0x06,
AnomalyAlert = 0x07, AnomalyAlert = 0x07,
} }
impl TryFrom<u8> for MeshMsgType { impl TryFrom<u8> for MeshMsgType {
@@ -194,7 +194,7 @@ impl TryFrom<u8> for MeshMsgType {
0x05 => Ok(MeshMsgType::FeatureDelta), 0x05 => Ok(MeshMsgType::FeatureDelta),
0x06 => Ok(MeshMsgType::Health), 0x06 => Ok(MeshMsgType::Health),
0x07 => Ok(MeshMsgType::AnomalyAlert), 0x07 => Ok(MeshMsgType::AnomalyAlert),
_ => Err(MeshError::UnknownMsgType(v)), _ => Err(MeshError::UnknownMsgType(v)),
} }
} }
} }
@@ -203,44 +203,44 @@ impl TryFrom<u8> for MeshMsgType {
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)] #[repr(u8)]
pub enum AuthClass { pub enum AuthClass {
None = 0, None = 0,
HmacSession = 1, HmacSession = 1,
Ed25519Batch = 2, Ed25519Batch = 2,
} }
/// `rv_mesh_header_t`, 16 bytes. /// `rv_mesh_header_t`, 16 bytes.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct MeshHeader { pub struct MeshHeader {
pub msg_type: MeshMsgType, pub msg_type: MeshMsgType,
pub sender_role: MeshRole, pub sender_role: MeshRole,
pub auth_class: AuthClass, pub auth_class: AuthClass,
pub epoch: u32, pub epoch: u32,
pub payload_len: u16, pub payload_len: u16,
} }
/// `rv_node_status_t`, 28 bytes. /// `rv_node_status_t`, 28 bytes.
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub struct NodeStatus { pub struct NodeStatus {
pub node_id: [u8; 8], pub node_id: [u8; 8],
pub local_time_us: u64, pub local_time_us: u64,
pub role: MeshRole, pub role: MeshRole,
pub current_channel: u8, pub current_channel: u8,
pub current_bw: u8, pub current_bw: u8,
pub noise_floor_dbm: i8, pub noise_floor_dbm: i8,
pub pkt_yield: u16, pub pkt_yield: u16,
pub sync_error_us: u16, pub sync_error_us: u16,
pub health_flags: u16, pub health_flags: u16,
} }
/// `rv_anomaly_alert_t`, 28 bytes. /// `rv_anomaly_alert_t`, 28 bytes.
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub struct AnomalyAlert { pub struct AnomalyAlert {
pub node_id: [u8; 8], pub node_id: [u8; 8],
pub ts_us: u64, pub ts_us: u64,
pub severity: u8, pub severity: u8,
pub reason: u8, pub reason: u8,
pub anomaly_score: f32, pub anomaly_score: f32,
pub motion_score: f32, pub motion_score: f32,
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -262,7 +262,11 @@ pub enum MeshError {
#[error("unknown auth class: {0}")] #[error("unknown auth class: {0}")]
UnknownAuth(u8), UnknownAuth(u8),
#[error("payload size mismatch for {which}: got {got}, want {want}")] #[error("payload size mismatch for {which}: got {got}, want {want}")]
PayloadSizeMismatch { which: &'static str, got: usize, want: usize }, PayloadSizeMismatch {
which: &'static str,
got: usize,
want: usize,
},
} }
/// IEEE CRC32 — matches the bit-by-bit implementation in /// IEEE CRC32 — matches the bit-by-bit implementation in
@@ -287,15 +291,19 @@ pub fn decode_mesh(buf: &[u8]) -> Result<(MeshHeader, &[u8]), MeshError> {
} }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]); let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != MESH_MAGIC { return Err(MeshError::BadMagic(magic)); } if magic != MESH_MAGIC {
return Err(MeshError::BadMagic(magic));
}
let version = buf[4]; let version = buf[4];
if version != MESH_VERSION { return Err(MeshError::BadVersion(version)); } if version != MESH_VERSION {
return Err(MeshError::BadVersion(version));
}
let ty = buf[5]; let ty = buf[5];
let sender_role = buf[6]; let sender_role = buf[6];
let auth_class = buf[7]; let auth_class = buf[7];
let epoch = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]); let epoch = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
let payload_len = u16::from_le_bytes([buf[12], buf[13]]); let payload_len = u16::from_le_bytes([buf[12], buf[13]]);
if payload_len as usize > MESH_MAX_PAYLOAD { if payload_len as usize > MESH_MAX_PAYLOAD {
@@ -303,20 +311,28 @@ pub fn decode_mesh(buf: &[u8]) -> Result<(MeshHeader, &[u8]), MeshError> {
} }
let total = MESH_HEADER_SIZE + payload_len as usize + 4; let total = MESH_HEADER_SIZE + payload_len as usize + 4;
if buf.len() < total { return Err(MeshError::TooShort(buf.len())); } if buf.len() < total {
return Err(MeshError::TooShort(buf.len()));
let want_crc = crc32_ieee(&buf[..MESH_HEADER_SIZE + payload_len as usize]);
let crc_off = MESH_HEADER_SIZE + payload_len as usize;
let got_crc = u32::from_le_bytes([
buf[crc_off], buf[crc_off + 1], buf[crc_off + 2], buf[crc_off + 3],
]);
if got_crc != want_crc {
return Err(MeshError::CrcMismatch { got: got_crc, want: want_crc });
} }
let msg_type = MeshMsgType::try_from(ty)?; let want_crc = crc32_ieee(&buf[..MESH_HEADER_SIZE + payload_len as usize]);
let crc_off = MESH_HEADER_SIZE + payload_len as usize;
let got_crc = u32::from_le_bytes([
buf[crc_off],
buf[crc_off + 1],
buf[crc_off + 2],
buf[crc_off + 3],
]);
if got_crc != want_crc {
return Err(MeshError::CrcMismatch {
got: got_crc,
want: want_crc,
});
}
let msg_type = MeshMsgType::try_from(ty)?;
let sender_role = MeshRole::try_from(sender_role)?; let sender_role = MeshRole::try_from(sender_role)?;
let auth_class = match auth_class { let auth_class = match auth_class {
0 => AuthClass::None, 0 => AuthClass::None,
1 => AuthClass::HmacSession, 1 => AuthClass::HmacSession,
2 => AuthClass::Ed25519Batch, 2 => AuthClass::Ed25519Batch,
@@ -324,8 +340,14 @@ pub fn decode_mesh(buf: &[u8]) -> Result<(MeshHeader, &[u8]), MeshError> {
}; };
Ok(( Ok((
MeshHeader { msg_type, sender_role, auth_class, epoch, payload_len }, MeshHeader {
&buf[MESH_HEADER_SIZE .. MESH_HEADER_SIZE + payload_len as usize], msg_type,
sender_role,
auth_class,
epoch,
payload_len,
},
&buf[MESH_HEADER_SIZE..MESH_HEADER_SIZE + payload_len as usize],
)) ))
} }
@@ -333,24 +355,24 @@ pub fn decode_mesh(buf: &[u8]) -> Result<(MeshHeader, &[u8]), MeshError> {
pub fn decode_node_status(p: &[u8]) -> Result<NodeStatus, MeshError> { pub fn decode_node_status(p: &[u8]) -> Result<NodeStatus, MeshError> {
if p.len() != 28 { if p.len() != 28 {
return Err(MeshError::PayloadSizeMismatch { return Err(MeshError::PayloadSizeMismatch {
which: "HEALTH", got: p.len(), want: 28, which: "HEALTH",
got: p.len(),
want: 28,
}); });
} }
let mut node_id = [0u8; 8]; let mut node_id = [0u8; 8];
node_id.copy_from_slice(&p[0..8]); node_id.copy_from_slice(&p[0..8]);
let local_time_us = u64::from_le_bytes([ let local_time_us = u64::from_le_bytes([p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15]]);
p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15],
]);
Ok(NodeStatus { Ok(NodeStatus {
node_id, node_id,
local_time_us, local_time_us,
role: MeshRole::try_from(p[16])?, role: MeshRole::try_from(p[16])?,
current_channel: p[17], current_channel: p[17],
current_bw: p[18], current_bw: p[18],
noise_floor_dbm: p[19] as i8, noise_floor_dbm: p[19] as i8,
pkt_yield: u16::from_le_bytes([p[20], p[21]]), pkt_yield: u16::from_le_bytes([p[20], p[21]]),
sync_error_us: u16::from_le_bytes([p[22], p[23]]), sync_error_us: u16::from_le_bytes([p[22], p[23]]),
health_flags: u16::from_le_bytes([p[24], p[25]]), health_flags: u16::from_le_bytes([p[24], p[25]]),
}) })
} }
@@ -358,31 +380,29 @@ pub fn decode_node_status(p: &[u8]) -> Result<NodeStatus, MeshError> {
pub fn decode_anomaly_alert(p: &[u8]) -> Result<AnomalyAlert, MeshError> { pub fn decode_anomaly_alert(p: &[u8]) -> Result<AnomalyAlert, MeshError> {
if p.len() != 28 { if p.len() != 28 {
return Err(MeshError::PayloadSizeMismatch { return Err(MeshError::PayloadSizeMismatch {
which: "ANOMALY_ALERT", got: p.len(), want: 28, which: "ANOMALY_ALERT",
got: p.len(),
want: 28,
}); });
} }
let mut node_id = [0u8; 8]; let mut node_id = [0u8; 8];
node_id.copy_from_slice(&p[0..8]); node_id.copy_from_slice(&p[0..8]);
let ts_us = u64::from_le_bytes([ let ts_us = u64::from_le_bytes([p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15]]);
p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15],
]);
let anomaly_score = f32::from_le_bytes([p[20], p[21], p[22], p[23]]); let anomaly_score = f32::from_le_bytes([p[20], p[21], p[22], p[23]]);
let motion_score = f32::from_le_bytes([p[24], p[25], p[26], p[27]]); let motion_score = f32::from_le_bytes([p[24], p[25], p[26], p[27]]);
Ok(AnomalyAlert { Ok(AnomalyAlert {
node_id, ts_us, node_id,
ts_us,
severity: p[16], severity: p[16],
reason: p[17], reason: p[17],
anomaly_score, motion_score, anomaly_score,
motion_score,
}) })
} }
/// Encode a `HEALTH` payload. Produces the 16-byte header, 28-byte /// Encode a `HEALTH` payload. Produces the 16-byte header, 28-byte
/// payload, and 4-byte CRC — bit-identical to what the firmware emits. /// payload, and 4-byte CRC — bit-identical to what the firmware emits.
pub fn encode_health( pub fn encode_health(sender_role: MeshRole, epoch: u32, status: &NodeStatus) -> Vec<u8> {
sender_role: MeshRole,
epoch: u32,
status: &NodeStatus,
) -> Vec<u8> {
let payload_len: u16 = 28; let payload_len: u16 = 28;
let mut buf = Vec::with_capacity(MESH_HEADER_SIZE + payload_len as usize + 4); let mut buf = Vec::with_capacity(MESH_HEADER_SIZE + payload_len as usize + 4);
@@ -394,7 +414,7 @@ pub fn encode_health(
buf.push(AuthClass::None as u8); buf.push(AuthClass::None as u8);
buf.extend_from_slice(&epoch.to_le_bytes()); buf.extend_from_slice(&epoch.to_le_bytes());
buf.extend_from_slice(&payload_len.to_le_bytes()); buf.extend_from_slice(&payload_len.to_le_bytes());
buf.extend_from_slice(&0u16.to_le_bytes()); // reserved buf.extend_from_slice(&0u16.to_le_bytes()); // reserved
// payload // payload
buf.extend_from_slice(&status.node_id); buf.extend_from_slice(&status.node_id);
@@ -406,7 +426,7 @@ pub fn encode_health(
buf.extend_from_slice(&status.pkt_yield.to_le_bytes()); buf.extend_from_slice(&status.pkt_yield.to_le_bytes());
buf.extend_from_slice(&status.sync_error_us.to_le_bytes()); buf.extend_from_slice(&status.sync_error_us.to_le_bytes());
buf.extend_from_slice(&status.health_flags.to_le_bytes()); buf.extend_from_slice(&status.health_flags.to_le_bytes());
buf.extend_from_slice(&0u16.to_le_bytes()); // reserved buf.extend_from_slice(&0u16.to_le_bytes()); // reserved
let crc = crc32_ieee(&buf); let crc = crc32_ieee(&buf);
buf.extend_from_slice(&crc.to_le_bytes()); buf.extend_from_slice(&crc.to_le_bytes());
@@ -444,8 +464,8 @@ mod tests {
fn crc32_matches_firmware_vectors() { fn crc32_matches_firmware_vectors() {
// Same vectors as test_rv_feature_state.c // Same vectors as test_rv_feature_state.c
assert_eq!(crc32_ieee(b"123456789"), 0xCBF43926); assert_eq!(crc32_ieee(b"123456789"), 0xCBF43926);
assert_eq!(crc32_ieee(&[]), 0x00000000); assert_eq!(crc32_ieee(&[]), 0x00000000);
assert_eq!(crc32_ieee(&[0u8]), 0xD202EF8D); assert_eq!(crc32_ieee(&[0u8]), 0xD202EF8D);
} }
#[test] #[test]
@@ -490,7 +510,7 @@ mod tests {
health_flags: 0, health_flags: 0,
}; };
let mut wire = encode_health(MeshRole::Observer, 0, &st); let mut wire = encode_health(MeshRole::Observer, 0, &st);
let p0 = MESH_HEADER_SIZE; // first payload byte let p0 = MESH_HEADER_SIZE; // first payload byte
wire[p0] ^= 0xFF; wire[p0] ^= 0xFF;
let err = decode_mesh(&wire).unwrap_err(); let err = decode_mesh(&wire).unwrap_err();
assert!(matches!(err, MeshError::CrcMismatch { .. })); assert!(matches!(err, MeshError::CrcMismatch { .. }));
@@ -10,31 +10,39 @@
//! - Localization algorithms (triangulation, depth estimation) //! - Localization algorithms (triangulation, depth estimation)
//! - Alert generation //! - Alert generation
use criterion::{ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput,
};
use std::f64::consts::PI; use std::f64::consts::PI;
use wifi_densepose_mat::{ use wifi_densepose_mat::{
// Detection types
BreathingDetector, BreathingDetectorConfig,
HeartbeatDetector, HeartbeatDetectorConfig,
MovementClassifier, MovementClassifierConfig,
DetectionConfig, DetectionPipeline, VitalSignsDetector,
// Localization types
Triangulator, DepthEstimator,
// Alerting types // Alerting types
AlertGenerator, AlertGenerator,
// Detection types
BreathingDetector,
BreathingDetectorConfig,
// Domain types exported at crate root // Domain types exported at crate root
BreathingPattern, BreathingType, VitalSignsReading, BreathingPattern,
MovementProfile, ScanZoneId, Survivor, BreathingType,
DepthEstimator,
DetectionConfig,
DetectionPipeline,
HeartbeatDetector,
HeartbeatDetectorConfig,
MovementClassifier,
MovementClassifierConfig,
MovementProfile,
ScanZoneId,
Survivor,
// Localization types
Triangulator,
VitalSignsDetector,
VitalSignsReading,
}; };
// Types that need to be accessed from submodules // Types that need to be accessed from submodules
use wifi_densepose_mat::detection::CsiDataBuffer; use wifi_densepose_mat::detection::CsiDataBuffer;
use wifi_densepose_mat::domain::{ use wifi_densepose_mat::domain::{
ConfidenceScore, SensorPosition, SensorType, ConfidenceScore, DebrisMaterial, DebrisProfile, MetalContent, MoistureLevel, SensorPosition,
DebrisProfile, DebrisMaterial, MoistureLevel, MetalContent, SensorType,
}; };
use chrono::Utc; use chrono::Utc;
@@ -140,7 +148,8 @@ fn generate_multi_person_signal(
(0..num_samples) (0..num_samples)
.map(|i| { .map(|i| {
let t = i as f64 / sample_rate; let t = i as f64 / sample_rate;
base_rates.iter() base_rates
.iter()
.enumerate() .enumerate()
.map(|(idx, &rate)| { .map(|(idx, &rate)| {
let freq = rate / 60.0; let freq = rate / 60.0;
@@ -154,22 +163,26 @@ fn generate_multi_person_signal(
} }
/// Generate movement signal with specified characteristics /// Generate movement signal with specified characteristics
fn generate_movement_signal( fn generate_movement_signal(movement_type: &str, sample_rate: f64, duration_secs: f64) -> Vec<f64> {
movement_type: &str,
sample_rate: f64,
duration_secs: f64,
) -> Vec<f64> {
let num_samples = (sample_rate * duration_secs) as usize; let num_samples = (sample_rate * duration_secs) as usize;
match movement_type { match movement_type {
"gross" => { "gross" => {
// Large, irregular movements // Large, irregular movements
let mut signal = vec![0.0; num_samples]; let mut signal = vec![0.0; num_samples];
for i in (num_samples / 4)..(num_samples / 2) { for s in signal
signal[i] = 2.0; .iter_mut()
.take(num_samples / 2)
.skip(num_samples / 4)
{
*s = 2.0;
} }
for i in (3 * num_samples / 4)..(4 * num_samples / 5) { for s in signal
signal[i] = -1.5; .iter_mut()
.take(4 * num_samples / 5)
.skip(3 * num_samples / 4)
{
*s = -1.5;
} }
signal signal
} }
@@ -259,9 +272,7 @@ fn bench_breathing_detection(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("clean_signal", format!("{}s", duration as u32)), BenchmarkId::new("clean_signal", format!("{}s", duration as u32)),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| detector.detect(black_box(signal), black_box(sample_rate))),
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate)))
},
); );
} }
@@ -270,11 +281,12 @@ fn bench_breathing_detection(c: &mut Criterion) {
let signal = generate_noisy_breathing_signal(16.0, sample_rate, 30.0, noise_level); let signal = generate_noisy_breathing_signal(16.0, sample_rate, 30.0, noise_level);
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("noisy_signal", format!("noise_{}", (noise_level * 10.0) as u32)), BenchmarkId::new(
"noisy_signal",
format!("noise_{}", (noise_level * 10.0) as u32),
),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| detector.detect(black_box(signal), black_box(sample_rate))),
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate)))
},
); );
} }
@@ -285,9 +297,7 @@ fn bench_breathing_detection(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("rate_variation", format!("{}bpm", rate as u32)), BenchmarkId::new("rate_variation", format!("{}bpm", rate as u32)),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| detector.detect(black_box(signal), black_box(sample_rate))),
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate)))
},
); );
} }
@@ -306,9 +316,7 @@ fn bench_breathing_detection(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("high_sensitivity", "30s_noisy"), BenchmarkId::new("high_sensitivity", "30s_noisy"),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| sensitive_detector.detect(black_box(signal), black_box(sample_rate))),
b.iter(|| sensitive_detector.detect(black_box(signal), black_box(sample_rate)))
},
); );
group.finish(); group.finish();
@@ -333,9 +341,7 @@ fn bench_heartbeat_detection(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("clean_signal", format!("{}s", duration as u32)), BenchmarkId::new("clean_signal", format!("{}s", duration as u32)),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| detector.detect(black_box(signal), black_box(sample_rate), None)),
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate), None))
},
); );
} }
@@ -362,9 +368,7 @@ fn bench_heartbeat_detection(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("rate_variation", format!("{}bpm", rate as u32)), BenchmarkId::new("rate_variation", format!("{}bpm", rate as u32)),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| detector.detect(black_box(signal), black_box(sample_rate), None)),
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate), None))
},
); );
} }
@@ -410,9 +414,7 @@ fn bench_movement_classification(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("movement_type", movement_type), BenchmarkId::new("movement_type", movement_type),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| classifier.classify(black_box(signal), black_box(sample_rate))),
b.iter(|| classifier.classify(black_box(signal), black_box(sample_rate)))
},
); );
} }
@@ -423,9 +425,7 @@ fn bench_movement_classification(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("signal_length", format!("{}s", duration as u32)), BenchmarkId::new("signal_length", format!("{}s", duration as u32)),
&signal, &signal,
|b, signal| { |b, signal| b.iter(|| classifier.classify(black_box(signal), black_box(sample_rate))),
b.iter(|| classifier.classify(black_box(signal), black_box(sample_rate)))
},
); );
} }
@@ -480,7 +480,8 @@ fn bench_detection_pipeline(c: &mut Criterion) {
// Benchmark standard pipeline at different data sizes // Benchmark standard pipeline at different data sizes
for duration in [5.0, 10.0, 30.0] { for duration in [5.0, 10.0, 30.0] {
let (amplitudes, phases) = generate_combined_vital_signal(16.0, 72.0, sample_rate, duration); let (amplitudes, phases) =
generate_combined_vital_signal(16.0, 72.0, sample_rate, duration);
let mut buffer = CsiDataBuffer::new(sample_rate); let mut buffer = CsiDataBuffer::new(sample_rate);
buffer.add_samples(&amplitudes, &phases); buffer.add_samples(&amplitudes, &phases);
@@ -488,9 +489,7 @@ fn bench_detection_pipeline(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("standard_pipeline", format!("{}s", duration as u32)), BenchmarkId::new("standard_pipeline", format!("{}s", duration as u32)),
&buffer, &buffer,
|b, buffer| { |b, buffer| b.iter(|| standard_pipeline.detect(black_box(buffer))),
b.iter(|| standard_pipeline.detect(black_box(buffer)))
},
); );
} }
@@ -503,9 +502,7 @@ fn bench_detection_pipeline(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("full_pipeline", format!("{}s", duration as u32)), BenchmarkId::new("full_pipeline", format!("{}s", duration as u32)),
&buffer, &buffer,
|b, buffer| { |b, buffer| b.iter(|| full_pipeline.detect(black_box(buffer))),
b.iter(|| full_pipeline.detect(black_box(buffer)))
},
); );
} }
@@ -518,9 +515,7 @@ fn bench_detection_pipeline(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("multi_person", format!("{}_people", person_count)), BenchmarkId::new("multi_person", format!("{}_people", person_count)),
&buffer, &buffer,
|b, buffer| { |b, buffer| b.iter(|| standard_pipeline.detect(black_box(buffer))),
b.iter(|| standard_pipeline.detect(black_box(buffer)))
},
); );
} }
@@ -541,7 +536,8 @@ fn bench_triangulation(c: &mut Criterion) {
let sensors = create_test_sensors(sensor_count); let sensors = create_test_sensors(sensor_count);
// Generate RSSI values (simulate target at center) // Generate RSSI values (simulate target at center)
let rssi_values: Vec<(String, f64)> = sensors.iter() let rssi_values: Vec<(String, f64)> = sensors
.iter()
.map(|s| { .map(|s| {
let distance = (s.x * s.x + s.y * s.y).sqrt(); let distance = (s.x * s.x + s.y * s.y).sqrt();
let rssi = -30.0 - 20.0 * distance.log10(); // Path loss model let rssi = -30.0 - 20.0 * distance.log10(); // Path loss model
@@ -553,9 +549,7 @@ fn bench_triangulation(c: &mut Criterion) {
BenchmarkId::new("rssi_position", format!("{}_sensors", sensor_count)), BenchmarkId::new("rssi_position", format!("{}_sensors", sensor_count)),
&(sensors.clone(), rssi_values.clone()), &(sensors.clone(), rssi_values.clone()),
|b, (sensors, rssi)| { |b, (sensors, rssi)| {
b.iter(|| { b.iter(|| triangulator.estimate_position(black_box(sensors), black_box(rssi)))
triangulator.estimate_position(black_box(sensors), black_box(rssi))
})
}, },
); );
} }
@@ -565,7 +559,8 @@ fn bench_triangulation(c: &mut Criterion) {
let sensors = create_test_sensors(sensor_count); let sensors = create_test_sensors(sensor_count);
// Generate ToA values (time in nanoseconds) // Generate ToA values (time in nanoseconds)
let toa_values: Vec<(String, f64)> = sensors.iter() let toa_values: Vec<(String, f64)> = sensors
.iter()
.map(|s| { .map(|s| {
let distance = (s.x * s.x + s.y * s.y).sqrt(); let distance = (s.x * s.x + s.y * s.y).sqrt();
// Round trip time: 2 * distance / speed_of_light // Round trip time: 2 * distance / speed_of_light
@@ -578,9 +573,7 @@ fn bench_triangulation(c: &mut Criterion) {
BenchmarkId::new("toa_position", format!("{}_sensors", sensor_count)), BenchmarkId::new("toa_position", format!("{}_sensors", sensor_count)),
&(sensors.clone(), toa_values.clone()), &(sensors.clone(), toa_values.clone()),
|b, (sensors, toa)| { |b, (sensors, toa)| {
b.iter(|| { b.iter(|| triangulator.estimate_from_toa(black_box(sensors), black_box(toa)))
triangulator.estimate_from_toa(black_box(sensors), black_box(toa))
})
}, },
); );
} }
@@ -588,7 +581,8 @@ fn bench_triangulation(c: &mut Criterion) {
// Benchmark with noisy measurements // Benchmark with noisy measurements
let sensors = create_test_sensors(5); let sensors = create_test_sensors(5);
for noise_pct in [0, 5, 10, 20] { for noise_pct in [0, 5, 10, 20] {
let rssi_values: Vec<(String, f64)> = sensors.iter() let rssi_values: Vec<(String, f64)> = sensors
.iter()
.enumerate() .enumerate()
.map(|(i, s)| { .map(|(i, s)| {
let distance = (s.x * s.x + s.y * s.y).sqrt(); let distance = (s.x * s.x + s.y * s.y).sqrt();
@@ -603,9 +597,7 @@ fn bench_triangulation(c: &mut Criterion) {
BenchmarkId::new("noisy_rssi", format!("{}pct_noise", noise_pct)), BenchmarkId::new("noisy_rssi", format!("{}pct_noise", noise_pct)),
&(sensors.clone(), rssi_values.clone()), &(sensors.clone(), rssi_values.clone()),
|b, (sensors, rssi)| { |b, (sensors, rssi)| {
b.iter(|| { b.iter(|| triangulator.estimate_position(black_box(sensors), black_box(rssi)))
triangulator.estimate_position(black_box(sensors), black_box(rssi))
})
}, },
); );
} }
@@ -662,11 +654,7 @@ fn bench_depth_estimation(c: &mut Criterion) {
&debris, &debris,
|b, debris| { |b, debris| {
b.iter(|| { b.iter(|| {
estimator.estimate_depth( estimator.estimate_depth(black_box(30.0), black_box(5.0), black_box(debris))
black_box(30.0),
black_box(5.0),
black_box(debris),
)
}) })
}, },
); );
@@ -699,21 +687,20 @@ fn bench_depth_estimation(c: &mut Criterion) {
} }
// Benchmark debris profile estimation // Benchmark debris profile estimation
for (variance, multipath, moisture) in [ for (variance, multipath, moisture) in [(0.2, 0.3, 0.2), (0.5, 0.5, 0.5), (0.7, 0.8, 0.8)] {
(0.2, 0.3, 0.2),
(0.5, 0.5, 0.5),
(0.7, 0.8, 0.8),
] {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("profile_estimation", format!("v{}_m{}", (variance * 10.0) as u32, (multipath * 10.0) as u32)), BenchmarkId::new(
"profile_estimation",
format!(
"v{}_m{}",
(variance * 10.0) as u32,
(multipath * 10.0) as u32
),
),
&(variance, multipath, moisture), &(variance, multipath, moisture),
|b, &(v, m, mo)| { |b, &(v, m, mo)| {
b.iter(|| { b.iter(|| {
estimator.estimate_debris_profile( estimator.estimate_debris_profile(black_box(v), black_box(m), black_box(mo))
black_box(v),
black_box(m),
black_box(mo),
)
}) })
}, },
); );
@@ -740,10 +727,8 @@ fn bench_alert_generation(c: &mut Criterion) {
// Benchmark escalation alert // Benchmark escalation alert
group.bench_function("generate_escalation_alert", |b| { group.bench_function("generate_escalation_alert", |b| {
b.iter(|| { b.iter(|| {
generator.generate_escalation( generator
black_box(&survivor), .generate_escalation(black_box(&survivor), black_box("Vital signs deteriorating"))
black_box("Vital signs deteriorating"),
)
}) })
}); });
@@ -751,10 +736,7 @@ fn bench_alert_generation(c: &mut Criterion) {
use wifi_densepose_mat::domain::TriageStatus; use wifi_densepose_mat::domain::TriageStatus;
group.bench_function("generate_status_change_alert", |b| { group.bench_function("generate_status_change_alert", |b| {
b.iter(|| { b.iter(|| {
generator.generate_status_change( generator.generate_status_change(black_box(&survivor), black_box(&TriageStatus::Minor))
black_box(&survivor),
black_box(&TriageStatus::Minor),
)
}) })
}); });
@@ -773,7 +755,8 @@ fn bench_alert_generation(c: &mut Criterion) {
group.bench_function("batch_generate_10_alerts", |b| { group.bench_function("batch_generate_10_alerts", |b| {
b.iter(|| { b.iter(|| {
survivors.iter() survivors
.iter()
.map(|s| generator.generate(black_box(s))) .map(|s| generator.generate(black_box(s)))
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
@@ -796,9 +779,7 @@ fn bench_csi_buffer(c: &mut Criterion) {
let amplitudes: Vec<f64> = (0..sample_count) let amplitudes: Vec<f64> = (0..sample_count)
.map(|i| (i as f64 / 100.0).sin()) .map(|i| (i as f64 / 100.0).sin())
.collect(); .collect();
let phases: Vec<f64> = (0..sample_count) let phases: Vec<f64> = (0..sample_count).map(|i| (i as f64 / 50.0).cos()).collect();
.map(|i| (i as f64 / 50.0).cos())
.collect();
group.throughput(Throughput::Elements(sample_count as u64)); group.throughput(Throughput::Elements(sample_count as u64));
group.bench_with_input( group.bench_with_input(
@@ -1,8 +1,8 @@
//! Alert dispatching and delivery. //! Alert dispatching and delivery.
use super::AlertGenerator;
use crate::domain::{Alert, AlertId, Priority, Survivor}; use crate::domain::{Alert, AlertId, Priority, Survivor};
use crate::MatError; use crate::MatError;
use super::AlertGenerator;
use std::collections::HashMap; use std::collections::HashMap;
/// Configuration for alert dispatch /// Configuration for alert dispatch
@@ -67,7 +67,9 @@ impl AlertDispatcher {
let priority = alert.priority(); let priority = alert.priority();
// Store in pending alerts // Store in pending alerts
self.pending_alerts.write().insert(alert_id.clone(), alert.clone()); self.pending_alerts
.write()
.insert(alert_id.clone(), alert.clone());
// Log the alert // Log the alert
tracing::info!( tracing::info!(
@@ -121,7 +123,11 @@ impl AlertDispatcher {
} }
/// Resolve an alert /// Resolve an alert
pub fn resolve(&self, alert_id: &AlertId, resolution: crate::domain::AlertResolution) -> Result<(), MatError> { pub fn resolve(
&self,
alert_id: &AlertId,
resolution: crate::domain::AlertResolution,
) -> Result<(), MatError> {
let mut alerts = self.pending_alerts.write(); let mut alerts = self.pending_alerts.write();
if let Some(alert) = alerts.remove(alert_id) { if let Some(alert) = alerts.remove(alert_id) {
@@ -191,7 +197,9 @@ impl AlertDispatcher {
/// Escalate oldest pending alerts /// Escalate oldest pending alerts
async fn escalate_oldest(&self) -> Result<(), MatError> { async fn escalate_oldest(&self) -> Result<(), MatError> {
let mut alerts: Vec<_> = self.pending_alerts.read() let mut alerts: Vec<_> = self
.pending_alerts
.read()
.iter() .iter()
.map(|(id, alert)| (id.clone(), *alert.created_at())) .map(|(id, alert)| (id.clone(), *alert.created_at()))
.collect(); .collect();
@@ -229,6 +237,7 @@ pub trait AlertHandler: Send + Sync {
} }
/// Console/logging alert handler /// Console/logging alert handler
#[allow(dead_code)]
pub struct ConsoleAlertHandler; pub struct ConsoleAlertHandler;
#[async_trait::async_trait] #[async_trait::async_trait]
@@ -264,6 +273,7 @@ impl AlertHandler for ConsoleAlertHandler {
/// Requires platform audio support. On systems without audio hardware /// Requires platform audio support. On systems without audio hardware
/// (headless servers, embedded), this logs the alert pattern. On systems /// (headless servers, embedded), this logs the alert pattern. On systems
/// with audio, integrate with the platform's audio API. /// with audio, integrate with the platform's audio API.
#[allow(dead_code)]
pub struct AudioAlertHandler { pub struct AudioAlertHandler {
/// Whether audio hardware is available /// Whether audio hardware is available
audio_available: bool, audio_available: bool,
@@ -271,15 +281,19 @@ pub struct AudioAlertHandler {
impl AudioAlertHandler { impl AudioAlertHandler {
/// Create a new audio handler, auto-detecting audio support. /// Create a new audio handler, auto-detecting audio support.
#[allow(dead_code)]
pub fn new() -> Self { pub fn new() -> Self {
let audio_available = std::env::var("DISPLAY").is_ok() let audio_available =
|| std::env::var("PULSE_SERVER").is_ok(); std::env::var("DISPLAY").is_ok() || std::env::var("PULSE_SERVER").is_ok();
Self { audio_available } Self { audio_available }
} }
/// Create with explicit audio availability flag. /// Create with explicit audio availability flag.
#[allow(dead_code)]
pub fn with_availability(available: bool) -> Self { pub fn with_availability(available: bool) -> Self {
Self { audio_available: available } Self {
audio_available: available,
}
} }
} }
@@ -320,7 +334,7 @@ impl AlertHandler for AudioAlertHandler {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{SurvivorId, TriageStatus, AlertPayload}; use crate::domain::{AlertPayload, SurvivorId, TriageStatus};
fn create_test_alert() -> Alert { fn create_test_alert() -> Alert {
Alert::new( Alert::new(
@@ -352,7 +366,9 @@ mod tests {
assert!(result.is_ok()); assert!(result.is_ok());
let pending = dispatcher.pending(); let pending = dispatcher.pending();
assert!(pending.iter().any(|a| a.id() == &alert_id && a.acknowledged_by() == Some("Team Alpha"))); assert!(pending
.iter()
.any(|a| a.id() == &alert_id && a.acknowledged_by() == Some("Team Alpha")));
} }
#[tokio::test] #[tokio::test]
@@ -1,8 +1,6 @@
//! Alert generation from survivor detections. //! Alert generation from survivor detections.
use crate::domain::{ use crate::domain::{Alert, AlertPayload, Priority, ScanZoneId, Survivor, TriageStatus};
Alert, AlertPayload, Priority, Survivor, TriageStatus, ScanZoneId,
};
use crate::MatError; use crate::MatError;
/// Generator for alerts based on survivor status /// Generator for alerts based on survivor status
@@ -40,10 +38,7 @@ impl AlertGenerator {
) -> Result<Alert, MatError> { ) -> Result<Alert, MatError> {
let mut payload = self.create_payload(survivor); let mut payload = self.create_payload(survivor);
payload.title = format!("ESCALATED: {}", payload.title); payload.title = format!("ESCALATED: {}", payload.title);
payload.message = format!( payload.message = format!("{}\n\nReason for escalation: {}", payload.message, reason);
"{}\n\nReason for escalation: {}",
payload.message, reason
);
// Escalated alerts are always at least high priority // Escalated alerts are always at least high priority
let priority = match survivor.triage_status() { let priority = match survivor.triage_status() {
@@ -64,7 +59,8 @@ impl AlertGenerator {
payload.title = format!( payload.title = format!(
"Status Change: {} → {}", "Status Change: {} → {}",
previous_status, survivor.triage_status() previous_status,
survivor.triage_status()
); );
// Determine if this is an upgrade (worse) or downgrade (better) // Determine if this is an upgrade (worse) or downgrade (better)
@@ -97,7 +93,8 @@ impl AlertGenerator {
/// Create alert payload from survivor data /// Create alert payload from survivor data
fn create_payload(&self, survivor: &Survivor) -> AlertPayload { fn create_payload(&self, survivor: &Survivor) -> AlertPayload {
let zone_name = self.zone_names let zone_name = self
.zone_names
.get(survivor.zone_id()) .get(survivor.zone_id())
.map(String::as_str) .map(String::as_str)
.unwrap_or("Unknown Zone"); .unwrap_or("Unknown Zone");
@@ -159,8 +156,7 @@ impl AlertGenerator {
lines.push(format!( lines.push(format!(
" Movement: {:?} (intensity: {:.1})", " Movement: {:?} (intensity: {:.1})",
reading.movement.movement_type, reading.movement.movement_type, reading.movement.intensity
reading.movement.intensity
)); ));
} else { } else {
lines.push(" No recent readings".to_string()); lines.push(" No recent readings".to_string());
@@ -183,9 +179,7 @@ impl AlertGenerator {
" Position: ({:.1}, {:.1})\n\ " Position: ({:.1}, {:.1})\n\
Depth: {}\n\ Depth: {}\n\
Uncertainty: ±{:.1}m", Uncertainty: ±{:.1}m",
loc.x, loc.y, loc.x, loc.y, depth_str, loc.uncertainty.horizontal_error
depth_str,
loc.uncertainty.horizontal_error
) )
} }
None => " Position not yet determined".to_string(), None => " Position not yet determined".to_string(),
@@ -266,11 +260,15 @@ mod tests {
let generator = AlertGenerator::new(); let generator = AlertGenerator::new();
let survivor = create_test_survivor(); let survivor = create_test_survivor();
let alert = generator.generate_escalation(&survivor, "Vital signs deteriorating") let alert = generator
.generate_escalation(&survivor, "Vital signs deteriorating")
.unwrap(); .unwrap();
assert!(alert.payload().title.contains("ESCALATED")); assert!(alert.payload().title.contains("ESCALATED"));
assert!(matches!(alert.priority(), Priority::Critical | Priority::High)); assert!(matches!(
alert.priority(),
Priority::Critical | Priority::High
));
} }
#[test] #[test]
@@ -278,10 +276,9 @@ mod tests {
let generator = AlertGenerator::new(); let generator = AlertGenerator::new();
let survivor = create_test_survivor(); let survivor = create_test_survivor();
let alert = generator.generate_status_change( let alert = generator
&survivor, .generate_status_change(&survivor, &TriageStatus::Minor)
&TriageStatus::Minor, .unwrap();
).unwrap();
assert!(alert.payload().title.contains("Status Change")); assert!(alert.payload().title.contains("Status Change"));
} }
@@ -1,9 +1,9 @@
//! Alerting module for emergency notifications. //! Alerting module for emergency notifications.
mod generator;
mod dispatcher; mod dispatcher;
mod generator;
mod triage_service; mod triage_service;
pub use dispatcher::{AlertConfig, AlertDispatcher};
pub use generator::AlertGenerator; pub use generator::AlertGenerator;
pub use dispatcher::{AlertDispatcher, AlertConfig}; pub use triage_service::{PriorityCalculator, TriageService};
pub use triage_service::{TriageService, PriorityCalculator};
@@ -1,8 +1,7 @@
//! Triage service for calculating and updating survivor priority. //! Triage service for calculating and updating survivor priority.
use crate::domain::{ use crate::domain::{
Priority, Survivor, TriageStatus, VitalSignsReading, triage::TriageCalculator, Priority, Survivor, TriageStatus, VitalSignsReading,
triage::TriageCalculator,
}; };
/// Service for triage operations /// Service for triage operations
@@ -16,10 +15,7 @@ impl TriageService {
/// Check if survivor should be upgraded /// Check if survivor should be upgraded
pub fn should_upgrade(survivor: &Survivor) -> bool { pub fn should_upgrade(survivor: &Survivor) -> bool {
TriageCalculator::should_upgrade( TriageCalculator::should_upgrade(survivor.triage_status(), survivor.is_deteriorating())
survivor.triage_status(),
survivor.is_deteriorating(),
)
} }
/// Get upgraded status /// Get upgraded status
@@ -189,9 +185,14 @@ impl MassCasualtyAssessment {
Total: {} (Living: {}, Deceased: {})\n\ Total: {} (Living: {}, Deceased: {})\n\
Immediate: {}, Delayed: {}, Minor: {}\n\ Immediate: {}, Delayed: {}, Minor: {}\n\
Severity: {:?}, Resources: {:?}", Severity: {:?}, Resources: {:?}",
self.total, self.living(), self.deceased, self.total,
self.immediate, self.delayed, self.minor, self.living(),
self.severity, self.resource_level self.deceased,
self.immediate,
self.delayed,
self.minor,
self.severity,
self.resource_level
) )
} }
} }
@@ -227,9 +228,7 @@ pub enum ResourceLevel {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{ use crate::domain::{BreathingPattern, BreathingType, ConfidenceScore, ScanZoneId};
BreathingPattern, BreathingType, ConfidenceScore, ScanZoneId,
};
use chrono::Utc; use chrono::Utc;
fn create_test_vitals(rate_bpm: f32) -> VitalSignsReading { fn create_test_vitals(rate_bpm: f32) -> VitalSignsReading {
@@ -278,12 +277,14 @@ mod tests {
fn test_mass_casualty_assessment() { fn test_mass_casualty_assessment() {
let survivors: Vec<Survivor> = (0..10) let survivors: Vec<Survivor> = (0..10)
.map(|i| { .map(|i| {
let rate = if i < 3 { 35.0 } else if i < 6 { 16.0 } else { 18.0 }; let rate = if i < 3 {
Survivor::new( 35.0
ScanZoneId::new(), } else if i < 6 {
create_test_vitals(rate), 16.0
None, } else {
) 18.0
};
Survivor::new(ScanZoneId::new(), create_test_vitals(rate), None)
}) })
.collect(); .collect();
@@ -297,21 +298,13 @@ mod tests {
#[test] #[test]
fn test_priority_with_factors() { fn test_priority_with_factors() {
// Deteriorating patient should be upgraded // Deteriorating patient should be upgraded
let priority = PriorityCalculator::calculate_with_factors( let priority =
&TriageStatus::Delayed, PriorityCalculator::calculate_with_factors(&TriageStatus::Delayed, true, 0, None);
true,
0,
None,
);
assert_eq!(priority, Priority::Critical); assert_eq!(priority, Priority::Critical);
// Deep burial should upgrade // Deep burial should upgrade
let priority = PriorityCalculator::calculate_with_factors( let priority =
&TriageStatus::Delayed, PriorityCalculator::calculate_with_factors(&TriageStatus::Delayed, false, 0, Some(4.0));
false,
0,
Some(4.0),
);
assert_eq!(priority, Priority::Critical); assert_eq!(priority, Priority::Critical);
} }
} }
+21 -29
View File
@@ -2,14 +2,14 @@
//! //!
//! These types are used for serializing/deserializing API requests and responses. //! These types are used for serializing/deserializing API requests and responses.
//! They provide a clean separation between domain models and API contracts. //! They provide a clean separation between domain models and API contracts.
#![allow(missing_docs)]
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use crate::domain::{ use crate::domain::{
DisasterType, EventStatus, ZoneStatus, TriageStatus, Priority, AlertStatus, DisasterType, EventStatus, Priority, SurvivorStatus, TriageStatus, ZoneStatus,
AlertStatus, SurvivorStatus,
}; };
// ============================================================================ // ============================================================================
@@ -206,9 +206,7 @@ pub enum ZoneBoundsDto {
radius: f64, radius: f64,
}, },
/// Polygon boundary (list of vertices) /// Polygon boundary (list of vertices)
Polygon { Polygon { vertices: Vec<(f64, f64)> },
vertices: Vec<(f64, f64)>,
},
} }
/// Scan parameters for a zone. /// Scan parameters for a zone.
@@ -232,9 +230,15 @@ pub struct ScanParametersDto {
pub heartbeat_detection: bool, pub heartbeat_detection: bool,
} }
fn default_sensitivity() -> f64 { 0.8 } fn default_sensitivity() -> f64 {
fn default_max_depth() -> f64 { 5.0 } 0.8
fn default_true() -> bool { true } }
fn default_max_depth() -> f64 {
5.0
}
fn default_true() -> bool {
true
}
impl Default for ScanParametersDto { impl Default for ScanParametersDto {
fn default() -> Self { fn default() -> Self {
@@ -550,10 +554,7 @@ pub enum WebSocketMessage {
survivor: SurvivorResponse, survivor: SurvivorResponse,
}, },
/// Survivor lost (signal lost) /// Survivor lost (signal lost)
SurvivorLost { SurvivorLost { event_id: Uuid, survivor_id: Uuid },
event_id: Uuid,
survivor_id: Uuid,
},
/// New alert generated /// New alert generated
AlertCreated { AlertCreated {
event_id: Uuid, event_id: Uuid,
@@ -577,14 +578,9 @@ pub enum WebSocketMessage {
new_status: EventStatusDto, new_status: EventStatusDto,
}, },
/// Heartbeat/keep-alive /// Heartbeat/keep-alive
Heartbeat { Heartbeat { timestamp: DateTime<Utc> },
timestamp: DateTime<Utc>,
},
/// Error message /// Error message
Error { Error { code: String, message: String },
code: String,
message: String,
},
} }
/// WebSocket subscription request. /// WebSocket subscription request.
@@ -592,19 +588,13 @@ pub enum WebSocketMessage {
#[serde(tag = "action", rename_all = "snake_case")] #[serde(tag = "action", rename_all = "snake_case")]
pub enum WebSocketRequest { pub enum WebSocketRequest {
/// Subscribe to events for a disaster event /// Subscribe to events for a disaster event
Subscribe { Subscribe { event_id: Uuid },
event_id: Uuid,
},
/// Unsubscribe from events /// Unsubscribe from events
Unsubscribe { Unsubscribe { event_id: Uuid },
event_id: Uuid,
},
/// Subscribe to all events /// Subscribe to all events
SubscribeAll, SubscribeAll,
/// Request current state /// Request current state
GetState { GetState { event_id: Uuid },
event_id: Uuid,
},
} }
// ============================================================================ // ============================================================================
@@ -816,7 +806,9 @@ pub struct ListEventsQuery {
pub page_size: usize, pub page_size: usize,
} }
fn default_page_size() -> usize { 20 } fn default_page_size() -> usize {
20
}
/// Query parameters for listing survivors. /// Query parameters for listing survivors.
#[derive(Debug, Clone, Deserialize, Default)] #[derive(Debug, Clone, Deserialize, Default)]
+4 -10
View File
@@ -2,6 +2,7 @@
//! //!
//! This module provides a unified error type that maps to appropriate HTTP status codes //! This module provides a unified error type that maps to appropriate HTTP status codes
//! and JSON error responses for the API. //! and JSON error responses for the API.
#![allow(missing_docs)]
use axum::{ use axum::{
http::StatusCode, http::StatusCode,
@@ -23,10 +24,7 @@ use uuid::Uuid;
pub enum ApiError { pub enum ApiError {
/// Resource not found (404) /// Resource not found (404)
#[error("Resource not found: {resource_type} with id {id}")] #[error("Resource not found: {resource_type} with id {id}")]
NotFound { NotFound { resource_type: String, id: String },
resource_type: String,
id: String,
},
/// Invalid request data (400) /// Invalid request data (400)
#[error("Bad request: {message}")] #[error("Bad request: {message}")]
@@ -45,9 +43,7 @@ pub enum ApiError {
/// Conflict with existing resource (409) /// Conflict with existing resource (409)
#[error("Conflict: {message}")] #[error("Conflict: {message}")]
Conflict { Conflict { message: String },
message: String,
},
/// Resource is in invalid state for operation (409) /// Resource is in invalid state for operation (409)
#[error("Invalid state: {message}")] #[error("Invalid state: {message}")]
@@ -66,9 +62,7 @@ pub enum ApiError {
/// Service unavailable (503) /// Service unavailable (503)
#[error("Service unavailable: {message}")] #[error("Service unavailable: {message}")]
ServiceUnavailable { ServiceUnavailable { message: String },
message: String,
},
/// Domain error from business logic /// Domain error from business logic
#[error("Domain error: {0}")] #[error("Domain error: {0}")]
@@ -15,8 +15,7 @@ use super::dto::*;
use super::error::{ApiError, ApiResult}; use super::error::{ApiError, ApiResult};
use super::state::AppState; use super::state::AppState;
use crate::domain::{ use crate::domain::{
DisasterEvent, DisasterType, ScanZone, ZoneBounds, DisasterEvent, DisasterType, MovementType, ScanParameters, ScanResolution, ScanZone, ZoneBounds,
ScanParameters, ScanResolution, MovementType,
}; };
// ============================================================================ // ============================================================================
@@ -95,7 +94,7 @@ pub async fn list_events(
let total = filtered.len(); let total = filtered.len();
// Apply pagination // Apply pagination
let page_size = query.page_size.min(100).max(1); let page_size = query.page_size.clamp(1, 100);
let start = query.page * page_size; let start = query.page * page_size;
let events: Vec<_> = filtered let events: Vec<_> = filtered
.into_iter() .into_iter()
@@ -318,7 +317,12 @@ pub async fn add_zone(
) -> ApiResult<(StatusCode, Json<ZoneResponse>)> { ) -> ApiResult<(StatusCode, Json<ZoneResponse>)> {
// Convert DTO to domain // Convert DTO to domain
let bounds = match request.bounds { let bounds = match request.bounds {
ZoneBoundsDto::Rectangle { min_x, min_y, max_x, max_y } => { ZoneBoundsDto::Rectangle {
min_x,
min_y,
max_x,
max_y,
} => {
if max_x <= min_x || max_y <= min_y { if max_x <= min_x || max_y <= min_y {
return Err(ApiError::validation( return Err(ApiError::validation(
"max coordinates must be greater than min coordinates", "max coordinates must be greater than min coordinates",
@@ -327,7 +331,11 @@ pub async fn add_zone(
} }
ZoneBounds::rectangle(min_x, min_y, max_x, max_y) ZoneBounds::rectangle(min_x, min_y, max_x, max_y)
} }
ZoneBoundsDto::Circle { center_x, center_y, radius } => { ZoneBoundsDto::Circle {
center_x,
center_y,
radius,
} => {
if radius <= 0.0 { if radius <= 0.0 {
return Err(ApiError::validation( return Err(ApiError::validation(
"radius must be positive", "radius must be positive",
@@ -713,26 +721,29 @@ fn event_to_response(event: DisasterEvent) -> EventResponse {
fn zone_to_response(zone: &ScanZone) -> ZoneResponse { fn zone_to_response(zone: &ScanZone) -> ZoneResponse {
let bounds = match zone.bounds() { let bounds = match zone.bounds() {
ZoneBounds::Rectangle { min_x, min_y, max_x, max_y } => { ZoneBounds::Rectangle {
ZoneBoundsDto::Rectangle { min_x,
min_x: *min_x, min_y,
min_y: *min_y, max_x,
max_x: *max_x, max_y,
max_y: *max_y, } => ZoneBoundsDto::Rectangle {
} min_x: *min_x,
} min_y: *min_y,
ZoneBounds::Circle { center_x, center_y, radius } => { max_x: *max_x,
ZoneBoundsDto::Circle { max_y: *max_y,
center_x: *center_x, },
center_y: *center_y, ZoneBounds::Circle {
radius: *radius, center_x,
} center_y,
} radius,
ZoneBounds::Polygon { vertices } => { } => ZoneBoundsDto::Circle {
ZoneBoundsDto::Polygon { center_x: *center_x,
vertices: vertices.clone(), center_y: *center_y,
} radius: *radius,
} },
ZoneBounds::Polygon { vertices } => ZoneBoundsDto::Polygon {
vertices: vertices.clone(),
},
}; };
let params = zone.parameters(); let params = zone.parameters();
@@ -775,7 +786,11 @@ fn survivor_to_response(survivor: &crate::Survivor) -> SurvivorResponse {
let latest_vitals = survivor.vital_signs().latest(); let latest_vitals = survivor.vital_signs().latest();
let vital_signs = VitalSignsSummaryDto { let vital_signs = VitalSignsSummaryDto {
breathing_rate: latest_vitals.and_then(|v| v.breathing.as_ref().map(|b| b.rate_bpm)), breathing_rate: latest_vitals.and_then(|v| v.breathing.as_ref().map(|b| b.rate_bpm)),
breathing_type: latest_vitals.and_then(|v| v.breathing.as_ref().map(|b| format!("{:?}", b.pattern_type))), breathing_type: latest_vitals.and_then(|v| {
v.breathing
.as_ref()
.map(|b| format!("{:?}", b.pattern_type))
}),
heart_rate: latest_vitals.and_then(|v| v.heartbeat.as_ref().map(|h| h.rate_bpm)), heart_rate: latest_vitals.and_then(|v| v.heartbeat.as_ref().map(|h| h.rate_bpm)),
has_heartbeat: latest_vitals.map(|v| v.has_heartbeat()).unwrap_or(false), has_heartbeat: latest_vitals.map(|v| v.has_heartbeat()).unwrap_or(false),
has_movement: latest_vitals.map(|v| v.has_movement()).unwrap_or(false), has_movement: latest_vitals.map(|v| v.has_movement()).unwrap_or(false),
@@ -786,7 +801,9 @@ fn survivor_to_response(survivor: &crate::Survivor) -> SurvivorResponse {
None None
} }
}), }),
timestamp: latest_vitals.map(|v| v.timestamp).unwrap_or_else(chrono::Utc::now), timestamp: latest_vitals
.map(|v| v.timestamp)
.unwrap_or_else(chrono::Utc::now),
}; };
let metadata = { let metadata = {
@@ -795,7 +812,10 @@ fn survivor_to_response(survivor: &crate::Survivor) -> SurvivorResponse {
None None
} else { } else {
Some(SurvivorMetadataDto { Some(SurvivorMetadataDto {
estimated_age_category: m.estimated_age_category.as_ref().map(|a| format!("{:?}", a)), estimated_age_category: m
.estimated_age_category
.as_ref()
.map(|a| format!("{:?}", a)),
assigned_team: m.assigned_team.clone(), assigned_team: m.assigned_team.clone(),
notes: m.notes.clone(), notes: m.notes.clone(),
tags: m.tags.clone(), tags: m.tags.clone(),
@@ -1055,9 +1075,9 @@ pub async fn list_domain_events(
State(state): State<AppState>, State(state): State<AppState>,
) -> ApiResult<Json<DomainEventsResponse>> { ) -> ApiResult<Json<DomainEventsResponse>> {
let store = state.event_store(); let store = state.event_store();
let events = store.all().map_err(|e| ApiError::internal( let events = store
format!("Failed to read event store: {}", e), .all()
))?; .map_err(|e| ApiError::internal(format!("Failed to read event store: {}", e)))?;
let event_dtos: Vec<DomainEventDto> = events let event_dtos: Vec<DomainEventDto> = events
.iter() .iter()
+26 -8
View File
@@ -33,14 +33,14 @@
//! - `WS /ws/mat/stream` - Real-time survivor and alert stream //! - `WS /ws/mat/stream` - Real-time survivor and alert stream
pub mod dto; pub mod dto;
pub mod handlers;
pub mod error; pub mod error;
pub mod handlers;
pub mod state; pub mod state;
pub mod websocket; pub mod websocket;
use axum::{ use axum::{
Router,
routing::{get, post}, routing::{get, post},
Router,
}; };
pub use dto::*; pub use dto::*;
@@ -64,21 +64,39 @@ pub use state::AppState;
pub fn create_router(state: AppState) -> Router { pub fn create_router(state: AppState) -> Router {
Router::new() Router::new()
// Event endpoints // Event endpoints
.route("/api/v1/mat/events", get(handlers::list_events).post(handlers::create_event)) .route(
"/api/v1/mat/events",
get(handlers::list_events).post(handlers::create_event),
)
.route("/api/v1/mat/events/:event_id", get(handlers::get_event)) .route("/api/v1/mat/events/:event_id", get(handlers::get_event))
// Zone endpoints // Zone endpoints
.route("/api/v1/mat/events/:event_id/zones", get(handlers::list_zones).post(handlers::add_zone)) .route(
"/api/v1/mat/events/:event_id/zones",
get(handlers::list_zones).post(handlers::add_zone),
)
// Survivor endpoints // Survivor endpoints
.route("/api/v1/mat/events/:event_id/survivors", get(handlers::list_survivors)) .route(
"/api/v1/mat/events/:event_id/survivors",
get(handlers::list_survivors),
)
// Alert endpoints // Alert endpoints
.route("/api/v1/mat/events/:event_id/alerts", get(handlers::list_alerts)) .route(
.route("/api/v1/mat/alerts/:alert_id/acknowledge", post(handlers::acknowledge_alert)) "/api/v1/mat/events/:event_id/alerts",
get(handlers::list_alerts),
)
.route(
"/api/v1/mat/alerts/:alert_id/acknowledge",
post(handlers::acknowledge_alert),
)
// Scan control endpoints (ADR-001: CSI data ingestion + pipeline control) // Scan control endpoints (ADR-001: CSI data ingestion + pipeline control)
.route("/api/v1/mat/scan/csi", post(handlers::push_csi_data)) .route("/api/v1/mat/scan/csi", post(handlers::push_csi_data))
.route("/api/v1/mat/scan/control", post(handlers::scan_control)) .route("/api/v1/mat/scan/control", post(handlers::scan_control))
.route("/api/v1/mat/scan/status", get(handlers::pipeline_status)) .route("/api/v1/mat/scan/status", get(handlers::pipeline_status))
// Domain event store endpoint // Domain event store endpoint
.route("/api/v1/mat/events/domain", get(handlers::list_domain_events)) .route(
"/api/v1/mat/events/domain",
get(handlers::list_domain_events),
)
// WebSocket endpoint // WebSocket endpoint
.route("/ws/mat/stream", get(websocket::ws_handler)) .route("/ws/mat/stream", get(websocket::ws_handler))
.with_state(state) .with_state(state)
+15 -14
View File
@@ -2,6 +2,7 @@
//! //!
//! This module provides the shared state that is passed to all API handlers. //! This module provides the shared state that is passed to all API handlers.
//! It contains repositories, services, and real-time event broadcasting. //! It contains repositories, services, and real-time event broadcasting.
#![allow(missing_docs)]
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@@ -10,12 +11,12 @@ use parking_lot::RwLock;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use uuid::Uuid; use uuid::Uuid;
use crate::domain::{
DisasterEvent, Alert,
events::{EventStore, InMemoryEventStore},
};
use crate::detection::{DetectionPipeline, DetectionConfig};
use super::dto::WebSocketMessage; use super::dto::WebSocketMessage;
use crate::detection::{DetectionConfig, DetectionPipeline};
use crate::domain::{
events::{EventStore, InMemoryEventStore},
Alert, DisasterEvent,
};
/// Shared application state for the API. /// Shared application state for the API.
/// ///
@@ -109,12 +110,16 @@ impl AppState {
/// Get scanning state. /// Get scanning state.
pub fn is_scanning(&self) -> bool { pub fn is_scanning(&self) -> bool {
self.inner.scanning.load(std::sync::atomic::Ordering::SeqCst) self.inner
.scanning
.load(std::sync::atomic::Ordering::SeqCst)
} }
/// Set scanning state. /// Set scanning state.
pub fn set_scanning(&self, state: bool) { pub fn set_scanning(&self, state: bool) {
self.inner.scanning.store(state, std::sync::atomic::Ordering::SeqCst); self.inner
.scanning
.store(state, std::sync::atomic::Ordering::SeqCst);
} }
// ======================================================================== // ========================================================================
@@ -235,7 +240,7 @@ impl Default for AppState {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{DisasterType, DisasterEvent}; use crate::domain::{DisasterEvent, DisasterType};
use geo::Point; use geo::Point;
#[test] #[test]
@@ -258,11 +263,7 @@ mod tests {
#[test] #[test]
fn test_update_event() { fn test_update_event() {
let state = AppState::new(); let state = AppState::new();
let event = DisasterEvent::new( let event = DisasterEvent::new(DisasterType::Earthquake, Point::new(0.0, 0.0), "Test");
DisasterType::Earthquake,
Point::new(0.0, 0.0),
"Test",
);
let id = *event.id().as_uuid(); let id = *event.id().as_uuid();
state.store_event(event); state.store_event(event);
@@ -279,7 +280,7 @@ mod tests {
#[test] #[test]
fn test_broadcast_subscribe() { fn test_broadcast_subscribe() {
let state = AppState::new(); let state = AppState::new();
let mut rx = state.subscribe(); let _rx = state.subscribe();
state.broadcast(WebSocketMessage::Heartbeat { state.broadcast(WebSocketMessage::Heartbeat {
timestamp: chrono::Utc::now(), timestamp: chrono::Utc::now(),
@@ -76,10 +76,7 @@ use super::state::AppState;
/// description: WebSocket connection established /// description: WebSocket connection established
/// ``` /// ```
#[tracing::instrument(skip(state, ws))] #[tracing::instrument(skip(state, ws))]
pub async fn ws_handler( pub async fn ws_handler(State(state): State<AppState>, ws: WebSocketUpgrade) -> Response {
State(state): State<AppState>,
ws: WebSocketUpgrade,
) -> Response {
ws.on_upgrade(move |socket| handle_socket(socket, state)) ws.on_upgrade(move |socket| handle_socket(socket, state))
} }
@@ -88,7 +85,8 @@ async fn handle_socket(socket: WebSocket, state: AppState) {
let (mut sender, mut receiver) = socket.split(); let (mut sender, mut receiver) = socket.split();
// Subscription state for this connection // Subscription state for this connection
let subscriptions: Arc<Mutex<SubscriptionState>> = Arc::new(Mutex::new(SubscriptionState::new())); let subscriptions: Arc<Mutex<SubscriptionState>> =
Arc::new(Mutex::new(SubscriptionState::new()));
// Subscribe to broadcast channel // Subscribe to broadcast channel
let mut broadcast_rx = state.subscribe(); let mut broadcast_rx = state.subscribe();
@@ -260,7 +258,7 @@ impl SubscriptionState {
WebSocketMessage::ZoneScanComplete { event_id, .. } => Some(*event_id), WebSocketMessage::ZoneScanComplete { event_id, .. } => Some(*event_id),
WebSocketMessage::EventStatusChanged { event_id, .. } => Some(*event_id), WebSocketMessage::EventStatusChanged { event_id, .. } => Some(*event_id),
WebSocketMessage::Heartbeat { .. } => None, // Always receive WebSocketMessage::Heartbeat { .. } => None, // Always receive
WebSocketMessage::Error { .. } => None, // Always receive WebSocketMessage::Error { .. } => None, // Always receive
}; };
match event_id { match event_id {
@@ -1,4 +1,5 @@
//! Breathing pattern detection from CSI signals. //! Breathing pattern detection from CSI signals.
#![allow(missing_docs)]
use crate::domain::{BreathingPattern, BreathingType}; use crate::domain::{BreathingPattern, BreathingType};
@@ -51,7 +52,8 @@ impl CompressedBreathingBuffer {
// policy's age computation (now_ts - last_access_ts + 1) never wraps to // policy's age computation (now_ts - last_access_ts + 1) never wraps to
// zero (which would cause a divide-by-zero in wrapping_div). // zero (which would cause a divide-by-zero in wrapping_div).
self.compressor.set_access(ts, ts); self.compressor.set_access(ts, ts);
self.compressor.push_frame(amplitudes, ts, &mut self.encoded); self.compressor
.push_frame(amplitudes, ts, &mut self.encoded);
self.frame_count += 1; self.frame_count += 1;
} }
@@ -104,8 +106,8 @@ pub struct BreathingDetectorConfig {
impl Default for BreathingDetectorConfig { impl Default for BreathingDetectorConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
min_rate_bpm: 4.0, // Very slow breathing min_rate_bpm: 4.0, // Very slow breathing
max_rate_bpm: 40.0, // Fast breathing (distressed) max_rate_bpm: 40.0, // Fast breathing (distressed)
min_amplitude: 0.1, min_amplitude: 0.1,
window_size: 512, window_size: 512,
window_overlap: 0.5, window_overlap: 0.5,
@@ -147,12 +149,8 @@ impl BreathingDetector {
let min_freq = self.config.min_rate_bpm as f64 / 60.0; let min_freq = self.config.min_rate_bpm as f64 / 60.0;
let max_freq = self.config.max_rate_bpm as f64 / 60.0; let max_freq = self.config.max_rate_bpm as f64 / 60.0;
let (dominant_freq, amplitude) = self.find_dominant_frequency( let (dominant_freq, amplitude) =
&spectrum, self.find_dominant_frequency(&spectrum, sample_rate, min_freq, max_freq)?;
sample_rate,
min_freq,
max_freq,
)?;
// Convert to BPM // Convert to BPM
let rate_bpm = (dominant_freq * 60.0) as f32; let rate_bpm = (dominant_freq * 60.0) as f32;
@@ -185,32 +183,27 @@ impl BreathingDetector {
/// Compute frequency spectrum using FFT /// Compute frequency spectrum using FFT
fn compute_spectrum(&self, signal: &[f64]) -> Vec<f64> { fn compute_spectrum(&self, signal: &[f64]) -> Vec<f64> {
use rustfft::{FftPlanner, num_complex::Complex}; use rustfft::{num_complex::Complex, FftPlanner};
let n = signal.len().next_power_of_two(); let n = signal.len().next_power_of_two();
let mut planner = FftPlanner::new(); let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(n); let fft = planner.plan_fft_forward(n);
// Prepare input with zero padding // Prepare input with zero padding
let mut buffer: Vec<Complex<f64>> = signal let mut buffer: Vec<Complex<f64>> = signal.iter().map(|&x| Complex::new(x, 0.0)).collect();
.iter()
.map(|&x| Complex::new(x, 0.0))
.collect();
buffer.resize(n, Complex::new(0.0, 0.0)); buffer.resize(n, Complex::new(0.0, 0.0));
// Apply Hanning window // Apply Hanning window
for (i, sample) in buffer.iter_mut().enumerate().take(signal.len()) { for (i, sample) in buffer.iter_mut().enumerate().take(signal.len()) {
let window = 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / signal.len() as f64).cos()); let window =
0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / signal.len() as f64).cos());
*sample = Complex::new(sample.re * window, 0.0); *sample = Complex::new(sample.re * window, 0.0);
} }
fft.process(&mut buffer); fft.process(&mut buffer);
// Return magnitude spectrum (only positive frequencies) // Return magnitude spectrum (only positive frequencies)
buffer.iter() buffer.iter().take(n / 2).map(|c| c.norm()).collect()
.take(n / 2)
.map(|c| c.norm())
.collect()
} }
/// Find dominant frequency in a given range /// Find dominant frequency in a given range
@@ -235,10 +228,11 @@ impl BreathingDetector {
let mut max_amplitude = 0.0; let mut max_amplitude = 0.0;
let mut max_bin_idx = min_bin; let mut max_bin_idx = min_bin;
for i in min_bin..=max_bin { for (i, &amp_val) in spectrum[min_bin..=max_bin].iter().enumerate() {
if spectrum[i] > max_amplitude { let bin = min_bin + i;
max_amplitude = spectrum[i]; if amp_val > max_amplitude {
max_bin_idx = i; max_amplitude = amp_val;
max_bin_idx = bin;
} }
} }
@@ -271,7 +265,8 @@ impl BreathingDetector {
} }
// Also check harmonics (2x, 3x frequency) // Also check harmonics (2x, 3x frequency)
let harmonic_power: f64 = [2, 3].iter() let harmonic_power: f64 = [2, 3]
.iter()
.filter_map(|&mult| { .filter_map(|&mult| {
let harmonic_bin = peak_bin * mult; let harmonic_bin = peak_bin * mult;
if harmonic_bin < spectrum.len() { if harmonic_bin < spectrum.len() {
@@ -394,9 +389,7 @@ mod tests {
let detector = BreathingDetector::with_defaults(); let detector = BreathingDetector::with_defaults();
// Random noise with low amplitude // Random noise with low amplitude
let signal: Vec<f64> = (0..1000) let signal: Vec<f64> = (0..1000).map(|i| (i as f64 * 0.1).sin() * 0.01).collect();
.map(|i| (i as f64 * 0.1).sin() * 0.01)
.collect();
let result = detector.detect(&signal, 100.0); let result = detector.detect(&signal, 100.0);
// Should either be None or have very low confidence // Should either be None or have very low confidence
@@ -9,9 +9,7 @@
//! The classifier produces a single confidence score and a recommended //! The classifier produces a single confidence score and a recommended
//! triage status based on the combined signals. //! triage status based on the combined signals.
use crate::domain::{ use crate::domain::{BreathingType, MovementType, TriageStatus, VitalSignsReading};
BreathingType, MovementType, TriageStatus, VitalSignsReading,
};
/// Configuration for the ensemble classifier /// Configuration for the ensemble classifier
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -101,8 +99,9 @@ impl EnsembleClassifier {
}; };
// Weighted ensemble confidence // Weighted ensemble confidence
let total_weight = let total_weight = self.config.breathing_weight
self.config.breathing_weight + self.config.heartbeat_weight + self.config.movement_weight; + self.config.heartbeat_weight
+ self.config.movement_weight;
let ensemble_confidence = if total_weight > 0.0 { let ensemble_confidence = if total_weight > 0.0 {
(breathing_conf * self.config.breathing_weight (breathing_conf * self.config.breathing_weight
@@ -147,11 +146,7 @@ impl EnsembleClassifier {
/// as Immediate regardless of confidence level, because in disaster response /// as Immediate regardless of confidence level, because in disaster response
/// a false negative (missing a survivor in distress) is far more costly /// a false negative (missing a survivor in distress) is far more costly
/// than a false positive. /// than a false positive.
fn determine_triage( fn determine_triage(&self, reading: &VitalSignsReading, confidence: f64) -> TriageStatus {
&self,
reading: &VitalSignsReading,
confidence: f64,
) -> TriageStatus {
// CRITICAL PATTERNS: always classify regardless of confidence. // CRITICAL PATTERNS: always classify regardless of confidence.
// In disaster response, any sign of distress must be escalated. // In disaster response, any sign of distress must be escalated.
if let Some(ref breathing) = reading.breathing { if let Some(ref breathing) = reading.breathing {
@@ -163,7 +158,7 @@ impl EnsembleClassifier {
} }
let rate = breathing.rate_bpm; let rate = breathing.rate_bpm;
if rate < 10.0 || rate > 30.0 { if !(10.0..=30.0).contains(&rate) {
return TriageStatus::Immediate; return TriageStatus::Immediate;
} }
} }
@@ -188,7 +183,7 @@ impl EnsembleClassifier {
if let Some(ref breathing) = reading.breathing { if let Some(ref breathing) = reading.breathing {
let rate = breathing.rate_bpm; let rate = breathing.rate_bpm;
if rate < 12.0 || rate > 24.0 { if !(12.0..=24.0).contains(&rate) {
if has_movement { if has_movement {
return TriageStatus::Delayed; return TriageStatus::Delayed;
} }
@@ -215,8 +210,7 @@ impl EnsembleClassifier {
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{ use crate::domain::{
BreathingPattern, HeartbeatSignature, MovementProfile, BreathingPattern, ConfidenceScore, HeartbeatSignature, MovementProfile, SignalStrength,
SignalStrength, ConfidenceScore,
}; };
fn make_reading( fn make_reading(
@@ -266,11 +260,7 @@ mod tests {
#[test] #[test]
fn test_agonal_breathing_is_immediate() { fn test_agonal_breathing_is_immediate() {
let classifier = EnsembleClassifier::new(EnsembleConfig::default()); let classifier = EnsembleClassifier::new(EnsembleConfig::default());
let reading = make_reading( let reading = make_reading(Some((8.0, BreathingType::Agonal)), None, MovementType::None);
Some((8.0, BreathingType::Agonal)),
None,
MovementType::None,
);
let result = classifier.classify(&reading); let result = classifier.classify(&reading);
assert_eq!(result.recommended_triage, TriageStatus::Immediate); assert_eq!(result.recommended_triage, TriageStatus::Immediate);
@@ -295,8 +285,10 @@ mod tests {
let mut reading = VitalSignsReading::new(None, None, mv); let mut reading = VitalSignsReading::new(None, None, mv);
reading.confidence = ConfidenceScore::new(0.5); reading.confidence = ConfidenceScore::new(0.5);
let mut config = EnsembleConfig::default(); let config = EnsembleConfig {
config.min_ensemble_confidence = 0.0; min_ensemble_confidence: 0.0,
..EnsembleConfig::default()
};
let classifier = EnsembleClassifier::new(config); let classifier = EnsembleClassifier::new(config);
let result = classifier.classify(&reading); let result = classifier.classify(&reading);
@@ -1,4 +1,5 @@
//! Heartbeat detection from micro-Doppler signatures in CSI. //! Heartbeat detection from micro-Doppler signatures in CSI.
#![allow(missing_docs)]
use crate::domain::{HeartbeatSignature, SignalStrength}; use crate::domain::{HeartbeatSignature, SignalStrength};
@@ -31,7 +32,12 @@ impl CompressedHeartbeatSpectrogram {
.map(|i| TemporalTensorCompressor::new(TierPolicy::default(), 1, i as u32)) .map(|i| TemporalTensorCompressor::new(TierPolicy::default(), 1, i as u32))
.collect(); .collect();
let encoded = vec![Vec::new(); n_freq_bins]; let encoded = vec![Vec::new(); n_freq_bins];
Self { bin_buffers, encoded, n_freq_bins, frame_count: 0 } Self {
bin_buffers,
encoded,
n_freq_bins,
frame_count: 0,
}
} }
/// Push one column of the spectrogram (one time step, all frequency bins). /// Push one column of the spectrogram (one time step, all frequency bins).
@@ -71,11 +77,19 @@ impl CompressedHeartbeatSpectrogram {
total += recent; total += recent;
count += 1; count += 1;
} }
if count == 0 { 0.0 } else { total / count as f32 } if count == 0 {
0.0
} else {
total / count as f32
}
} }
pub fn frame_count(&self) -> u64 { self.frame_count } pub fn frame_count(&self) -> u64 {
pub fn n_freq_bins(&self) -> usize { self.n_freq_bins } self.frame_count
}
pub fn n_freq_bins(&self) -> usize {
self.n_freq_bins
}
} }
/// Configuration for heartbeat detection /// Configuration for heartbeat detection
@@ -98,8 +112,8 @@ pub struct HeartbeatDetectorConfig {
impl Default for HeartbeatDetectorConfig { impl Default for HeartbeatDetectorConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
min_rate_bpm: 30.0, // Very slow (bradycardia) min_rate_bpm: 30.0, // Very slow (bradycardia)
max_rate_bpm: 200.0, // Very fast (extreme tachycardia) max_rate_bpm: 200.0, // Very fast (extreme tachycardia)
min_signal_strength: 0.05, min_signal_strength: 0.05,
window_size: 1024, window_size: 1024,
enhanced_processing: true, enhanced_processing: true,
@@ -161,12 +175,8 @@ impl HeartbeatDetector {
let min_freq = self.config.min_rate_bpm as f64 / 60.0; let min_freq = self.config.min_rate_bpm as f64 / 60.0;
let max_freq = self.config.max_rate_bpm as f64 / 60.0; let max_freq = self.config.max_rate_bpm as f64 / 60.0;
let (heart_freq, strength) = self.find_heartbeat_frequency( let (heart_freq, strength) =
&spectrum, self.find_heartbeat_frequency(&spectrum, sample_rate, min_freq, max_freq)?;
sample_rate,
min_freq,
max_freq,
)?;
if strength < self.config.min_signal_strength { if strength < self.config.min_signal_strength {
return None; return None;
@@ -276,7 +286,7 @@ impl HeartbeatDetector {
/// Compute micro-Doppler spectrum optimized for heartbeat detection /// Compute micro-Doppler spectrum optimized for heartbeat detection
fn compute_micro_doppler_spectrum(&self, signal: &[f64], _sample_rate: f64) -> Vec<f64> { fn compute_micro_doppler_spectrum(&self, signal: &[f64], _sample_rate: f64) -> Vec<f64> {
use rustfft::{FftPlanner, num_complex::Complex}; use rustfft::{num_complex::Complex, FftPlanner};
let n = signal.len().next_power_of_two(); let n = signal.len().next_power_of_two();
let mut planner = FftPlanner::new(); let mut planner = FftPlanner::new();
@@ -288,8 +298,7 @@ impl HeartbeatDetector {
.enumerate() .enumerate()
.map(|(i, &x)| { .map(|(i, &x)| {
let n_f = signal.len() as f64; let n_f = signal.len() as f64;
let window = 0.42 let window = 0.42 - 0.5 * (2.0 * std::f64::consts::PI * i as f64 / n_f).cos()
- 0.5 * (2.0 * std::f64::consts::PI * i as f64 / n_f).cos()
+ 0.08 * (4.0 * std::f64::consts::PI * i as f64 / n_f).cos(); + 0.08 * (4.0 * std::f64::consts::PI * i as f64 / n_f).cos();
Complex::new(x * window, 0.0) Complex::new(x * window, 0.0)
}) })
@@ -299,10 +308,7 @@ impl HeartbeatDetector {
fft.process(&mut buffer); fft.process(&mut buffer);
// Return power spectrum // Return power spectrum
buffer.iter() buffer.iter().take(n / 2).map(|c| c.norm_sqr()).collect()
.take(n / 2)
.map(|c| c.norm_sqr())
.collect()
} }
/// Find heartbeat frequency in spectrum /// Find heartbeat frequency in spectrum
@@ -326,22 +332,24 @@ impl HeartbeatDetector {
// Find the strongest peak // Find the strongest peak
let mut max_power = 0.0; let mut max_power = 0.0;
let mut max_bin_idx = min_bin; let mut max_bin_idx = min_bin;
let upper = max_bin.min(spectrum.len() - 1);
for i in min_bin..=max_bin.min(spectrum.len() - 1) { for (i, &pwr) in spectrum[min_bin..=upper].iter().enumerate() {
if spectrum[i] > max_power { let bin = min_bin + i;
max_power = spectrum[i]; if pwr > max_power {
max_bin_idx = i; max_power = pwr;
max_bin_idx = bin;
} }
} }
// Check if it's a real peak (local maximum) // Check if it's a real peak (local maximum)
if max_bin_idx > 0 && max_bin_idx < spectrum.len() - 1 { if max_bin_idx > 0
if spectrum[max_bin_idx] <= spectrum[max_bin_idx - 1] && max_bin_idx < spectrum.len() - 1
|| spectrum[max_bin_idx] <= spectrum[max_bin_idx + 1] && (spectrum[max_bin_idx] <= spectrum[max_bin_idx - 1]
{ || spectrum[max_bin_idx] <= spectrum[max_bin_idx + 1])
// Not a real peak {
return None; // Not a real peak
} return None;
} }
let freq = max_bin_idx as f64 * freq_resolution; let freq = max_bin_idx as f64 * freq_resolution;
@@ -404,11 +412,7 @@ impl HeartbeatDetector {
let strength_score = (strength / 0.5).min(1.0) as f32; let strength_score = (strength / 0.5).min(1.0) as f32;
// Very low or very high HRV might indicate noise // Very low or very high HRV might indicate noise
let hrv_score = if hrv > 0.05 && hrv < 0.5 { let hrv_score = if hrv > 0.05 && hrv < 0.5 { 1.0 } else { 0.5 };
1.0
} else {
0.5
};
strength_score * 0.7 + hrv_score * 0.3 strength_score * 0.7 + hrv_score * 0.3
} }
@@ -434,8 +438,10 @@ mod heartbeat_buffer_tests {
// Low bins (0..15) should have higher power than high bins (16..31) // Low bins (0..15) should have higher power than high bins (16..31)
let low_power = spec.band_power(0, 15, 20); let low_power = spec.band_power(0, 15, 20);
let high_power = spec.band_power(16, 31, 20); let high_power = spec.band_power(16, 31, 20);
assert!(low_power >= high_power, assert!(
"low_power={low_power} should >= high_power={high_power}"); low_power >= high_power,
"low_power={low_power} should >= high_power={high_power}"
);
} }
} }
@@ -12,12 +12,12 @@ mod heartbeat;
mod movement; mod movement;
mod pipeline; mod pipeline;
pub use breathing::{BreathingDetector, BreathingDetectorConfig};
#[cfg(feature = "ruvector")] #[cfg(feature = "ruvector")]
pub use breathing::CompressedBreathingBuffer; pub use breathing::CompressedBreathingBuffer;
pub use breathing::{BreathingDetector, BreathingDetectorConfig};
pub use ensemble::{EnsembleClassifier, EnsembleConfig, EnsembleResult, SignalConfidences}; pub use ensemble::{EnsembleClassifier, EnsembleConfig, EnsembleResult, SignalConfidences};
pub use heartbeat::{HeartbeatDetector, HeartbeatDetectorConfig};
#[cfg(feature = "ruvector")] #[cfg(feature = "ruvector")]
pub use heartbeat::CompressedHeartbeatSpectrogram; pub use heartbeat::CompressedHeartbeatSpectrogram;
pub use heartbeat::{HeartbeatDetector, HeartbeatDetectorConfig};
pub use movement::{MovementClassifier, MovementClassifierConfig}; pub use movement::{MovementClassifier, MovementClassifierConfig};
pub use pipeline::{DetectionPipeline, DetectionConfig, VitalSignsDetector, CsiDataBuffer}; pub use pipeline::{CsiDataBuffer, DetectionConfig, DetectionPipeline, VitalSignsDetector};
@@ -54,11 +54,8 @@ impl MovementClassifier {
let periodicity = self.calculate_periodicity(csi_signal, sample_rate); let periodicity = self.calculate_periodicity(csi_signal, sample_rate);
// Determine movement type // Determine movement type
let (movement_type, is_voluntary) = self.determine_movement_type( let (movement_type, is_voluntary) =
variance, self.determine_movement_type(variance, max_change, periodicity);
max_change,
periodicity,
);
// Calculate intensity // Calculate intensity
let intensity = self.calculate_intensity(variance, max_change); let intensity = self.calculate_intensity(variance, max_change);
@@ -81,9 +78,7 @@ impl MovementClassifier {
} }
let mean = signal.iter().sum::<f64>() / signal.len() as f64; let mean = signal.iter().sum::<f64>() / signal.len() as f64;
let variance = signal.iter() let variance = signal.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / signal.len() as f64;
.map(|x| (x - mean).powi(2))
.sum::<f64>() / signal.len() as f64;
variance variance
} }
@@ -94,7 +89,8 @@ impl MovementClassifier {
return 0.0; return 0.0;
} }
signal.windows(2) signal
.windows(2)
.map(|w| (w[1] - w[0]).abs()) .map(|w| (w[1] - w[0]).abs())
.fold(0.0, f64::max) .fold(0.0, f64::max)
} }
@@ -120,7 +116,8 @@ impl MovementClassifier {
let mut max_corr = 0.0; let mut max_corr = 0.0;
for lag in 1..max_lag { for lag in 1..max_lag {
let corr: f64 = centered.iter() let corr: f64 = centered
.iter()
.take(n - lag) .take(n - lag)
.zip(centered.iter().skip(lag)) .zip(centered.iter().skip(lag))
.map(|(a, b)| a * b) .map(|(a, b)| a * b)
@@ -197,7 +194,8 @@ impl MovementClassifier {
let mean = signal.iter().sum::<f64>() / signal.len() as f64; let mean = signal.iter().sum::<f64>() / signal.len() as f64;
let centered: Vec<f64> = signal.iter().map(|x| x - mean).collect(); let centered: Vec<f64> = signal.iter().map(|x| x - mean).collect();
let zero_crossings: usize = centered.windows(2) let zero_crossings: usize = centered
.windows(2)
.filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0)) .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
.count(); .count();
@@ -227,13 +225,17 @@ mod tests {
let classifier = MovementClassifier::with_defaults(); let classifier = MovementClassifier::with_defaults();
// Simulate large movement // Simulate large movement
let mut signal: Vec<f64> = vec![0.0; 200]; let signal: Vec<f64> = (0..200)
for i in 50..100 { .map(|i| {
signal[i] = 2.0; if (50..100).contains(&i) {
} 2.0
for i in 150..180 { } else if (150..180).contains(&i) {
signal[i] = -1.5; -1.5
} } else {
0.0
}
})
.collect();
let profile = classifier.classify(&signal, 100.0); let profile = classifier.classify(&signal, 100.0);
assert!(matches!(profile.movement_type, MovementType::Gross)); assert!(matches!(profile.movement_type, MovementType::Gross));
@@ -259,15 +261,11 @@ mod tests {
let classifier = MovementClassifier::with_defaults(); let classifier = MovementClassifier::with_defaults();
// Low intensity // Low intensity
let low_signal: Vec<f64> = (0..200) let low_signal: Vec<f64> = (0..200).map(|i| (i as f64 * 0.1).sin() * 0.05).collect();
.map(|i| (i as f64 * 0.1).sin() * 0.05)
.collect();
let low_profile = classifier.classify(&low_signal, 100.0); let low_profile = classifier.classify(&low_signal, 100.0);
// High intensity // High intensity
let high_signal: Vec<f64> = (0..200) let high_signal: Vec<f64> = (0..200).map(|i| (i as f64 * 0.1).sin() * 2.0).collect();
.map(|i| (i as f64 * 0.1).sin() * 2.0)
.collect();
let high_profile = classifier.classify(&high_signal, 100.0); let high_profile = classifier.classify(&high_signal, 100.0);
assert!(high_profile.intensity > low_profile.intensity); assert!(high_profile.intensity > low_profile.intensity);
@@ -3,14 +3,13 @@
//! This module provides both traditional signal-processing-based detection //! This module provides both traditional signal-processing-based detection
//! and optional ML-enhanced detection for improved accuracy. //! and optional ML-enhanced detection for improved accuracy.
use super::{
BreathingDetector, BreathingDetectorConfig, HeartbeatDetector, HeartbeatDetectorConfig,
MovementClassifier, MovementClassifierConfig,
};
use crate::domain::{ScanZone, VitalSignsReading}; use crate::domain::{ScanZone, VitalSignsReading};
use crate::ml::{MlDetectionConfig, MlDetectionPipeline, MlDetectionResult}; use crate::ml::{MlDetectionConfig, MlDetectionPipeline, MlDetectionResult};
use crate::{DisasterConfig, MatError}; use crate::{DisasterConfig, MatError};
use super::{
BreathingDetector, BreathingDetectorConfig,
HeartbeatDetector, HeartbeatDetectorConfig,
MovementClassifier, MovementClassifierConfig,
};
/// Configuration for the detection pipeline /// Configuration for the detection pipeline
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -86,7 +85,7 @@ pub trait VitalSignsDetector: Send + Sync {
} }
/// Buffer for CSI data samples /// Buffer for CSI data samples
#[derive(Debug, Default)] #[derive(Debug, Default, Clone)]
pub struct CsiDataBuffer { pub struct CsiDataBuffer {
/// Amplitude samples /// Amplitude samples
pub amplitudes: Vec<f64>, pub amplitudes: Vec<f64>,
@@ -180,7 +179,7 @@ impl DetectionPipeline {
/// Check if ML pipeline is ready /// Check if ML pipeline is ready
pub fn ml_ready(&self) -> bool { pub fn ml_ready(&self) -> bool {
self.ml_pipeline.as_ref().map_or(true, |ml| ml.is_ready()) self.ml_pipeline.as_ref().is_none_or(|ml| ml.is_ready())
} }
/// Process a scan zone and return detected vital signs. /// Process a scan zone and return detected vital signs.
@@ -192,23 +191,30 @@ impl DetectionPipeline {
/// ///
/// Returns `None` if insufficient data is buffered (< 5 seconds) or if /// Returns `None` if insufficient data is buffered (< 5 seconds) or if
/// detection confidence is below the configured threshold. /// detection confidence is below the configured threshold.
pub async fn process_zone(&self, zone: &ScanZone) -> Result<Option<VitalSignsReading>, MatError> { pub async fn process_zone(
&self,
zone: &ScanZone,
) -> Result<Option<VitalSignsReading>, MatError> {
// Process buffered CSI data through the signal processing pipeline. // Process buffered CSI data through the signal processing pipeline.
// Data arrives via add_data() from hardware adapters (ESP32, Intel 5300, etc.) // Data arrives via add_data() from hardware adapters (ESP32, Intel 5300, etc.)
// or from the CSI push API endpoint. // or from the CSI push API endpoint.
let buffer = self.data_buffer.read(); // Drop the MutexGuard before hitting any await point.
let reading = {
if !buffer.has_sufficient_data(5.0) { let buffer = self.data_buffer.read();
// Need at least 5 seconds of data if !buffer.has_sufficient_data(5.0) {
return Ok(None); // Need at least 5 seconds of data
} return Ok(None);
}
// Detect vital signs using traditional pipeline // Detect vital signs using traditional pipeline
let reading = self.detect_from_buffer(&buffer, zone)?; self.detect_from_buffer(&buffer, zone)?
// `buffer` guard dropped here
};
// If ML is enabled and ready, enhance with ML predictions // If ML is enabled and ready, enhance with ML predictions
let enhanced_reading = if self.config.enable_ml && self.ml_ready() { let enhanced_reading = if self.config.enable_ml && self.ml_ready() {
self.enhance_with_ml(reading, &buffer).await? // Snapshot the buffer under the lock, then drop the guard before await.
let buffer_snapshot = { self.data_buffer.read().clone() };
self.enhance_with_ml(reading, &buffer_snapshot).await?
} else { } else {
reading reading
}; };
@@ -257,12 +263,16 @@ impl DetectionPipeline {
/// Get the latest ML detection results (if ML is enabled) /// Get the latest ML detection results (if ML is enabled)
pub async fn get_ml_results(&self) -> Option<MlDetectionResult> { pub async fn get_ml_results(&self) -> Option<MlDetectionResult> {
let buffer = self.data_buffer.read(); let ml = match &self.ml_pipeline {
if let Some(ref ml) = self.ml_pipeline { Some(ml) => ml,
ml.process(&buffer).await.ok() None => return None,
} else { };
None // Acquire lock, clone the relevant buffer data, then drop the guard before awaiting.
} let buffer = {
let guard = self.data_buffer.read();
guard.clone()
};
ml.process(&buffer).await.ok()
} }
/// Add CSI data to the processing buffer /// Add CSI data to the processing buffer
@@ -292,31 +302,29 @@ impl DetectionPipeline {
_zone: &ScanZone, _zone: &ScanZone,
) -> Result<Option<VitalSignsReading>, MatError> { ) -> Result<Option<VitalSignsReading>, MatError> {
// Detect breathing // Detect breathing
let breathing = self.breathing_detector.detect( let breathing = self
&buffer.amplitudes, .breathing_detector
buffer.sample_rate, .detect(&buffer.amplitudes, buffer.sample_rate);
);
// Detect heartbeat (if enabled) // Detect heartbeat (if enabled)
let heartbeat = if self.config.enable_heartbeat { let heartbeat = if self.config.enable_heartbeat {
let breathing_rate = breathing.as_ref().map(|b| b.rate_bpm as f64); let breathing_rate = breathing.as_ref().map(|b| b.rate_bpm as f64);
self.heartbeat_detector.detect( self.heartbeat_detector
&buffer.phases, .detect(&buffer.phases, buffer.sample_rate, breathing_rate)
buffer.sample_rate,
breathing_rate,
)
} else { } else {
None None
}; };
// Classify movement // Classify movement
let movement = self.movement_classifier.classify( let movement = self
&buffer.amplitudes, .movement_classifier
buffer.sample_rate, .classify(&buffer.amplitudes, buffer.sample_rate);
);
// Check if we detected anything // Check if we detected anything
if breathing.is_none() && heartbeat.is_none() && movement.movement_type == crate::domain::MovementType::None { if breathing.is_none()
&& heartbeat.is_none()
&& movement.movement_type == crate::domain::MovementType::None
{
return Ok(None); return Ok(None);
} }
@@ -358,31 +366,27 @@ impl DetectionPipeline {
impl VitalSignsDetector for DetectionPipeline { impl VitalSignsDetector for DetectionPipeline {
fn detect(&self, csi_data: &CsiDataBuffer) -> Option<VitalSignsReading> { fn detect(&self, csi_data: &CsiDataBuffer) -> Option<VitalSignsReading> {
// Detect breathing from amplitude variations // Detect breathing from amplitude variations
let breathing = self.breathing_detector.detect( let breathing = self
&csi_data.amplitudes, .breathing_detector
csi_data.sample_rate, .detect(&csi_data.amplitudes, csi_data.sample_rate);
);
// Detect heartbeat from phase variations // Detect heartbeat from phase variations
let heartbeat = if self.config.enable_heartbeat { let heartbeat = if self.config.enable_heartbeat {
let breathing_rate = breathing.as_ref().map(|b| b.rate_bpm as f64); let breathing_rate = breathing.as_ref().map(|b| b.rate_bpm as f64);
self.heartbeat_detector.detect( self.heartbeat_detector
&csi_data.phases, .detect(&csi_data.phases, csi_data.sample_rate, breathing_rate)
csi_data.sample_rate,
breathing_rate,
)
} else { } else {
None None
}; };
// Classify movement // Classify movement
let movement = self.movement_classifier.classify( let movement = self
&csi_data.amplitudes, .movement_classifier
csi_data.sample_rate, .classify(&csi_data.amplitudes, csi_data.sample_rate);
);
// Create reading if we detected anything // Create reading if we detected anything
if breathing.is_some() || heartbeat.is_some() if breathing.is_some()
|| heartbeat.is_some()
|| movement.movement_type != crate::domain::MovementType::None || movement.movement_type != crate::domain::MovementType::None
{ {
Some(VitalSignsReading::new(breathing, heartbeat, movement)) Some(VitalSignsReading::new(breathing, heartbeat, movement))
@@ -457,9 +461,7 @@ mod tests {
#[test] #[test]
fn test_config_from_disaster_config() { fn test_config_from_disaster_config() {
let disaster_config = DisasterConfig::builder() let disaster_config = DisasterConfig::builder().sensitivity(0.9).build();
.sensitivity(0.9)
.build();
let detection_config = DetectionConfig::from_disaster_config(&disaster_config); let detection_config = DetectionConfig::from_disaster_config(&disaster_config);
@@ -3,7 +3,7 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid; use uuid::Uuid;
use super::{SurvivorId, TriageStatus, Coordinates3D}; use super::{Coordinates3D, SurvivorId, TriageStatus};
/// Unique identifier for an alert /// Unique identifier for an alert
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -398,11 +398,7 @@ mod tests {
#[test] #[test]
fn test_alert_lifecycle() { fn test_alert_lifecycle() {
let mut alert = Alert::new( let mut alert = Alert::new(SurvivorId::new(), Priority::High, create_test_payload());
SurvivorId::new(),
Priority::High,
create_test_payload(),
);
// Initial state // Initial state
assert!(alert.is_pending()); assert!(alert.is_pending());
@@ -429,11 +425,7 @@ mod tests {
#[test] #[test]
fn test_alert_escalation() { fn test_alert_escalation() {
let mut alert = Alert::new( let mut alert = Alert::new(SurvivorId::new(), Priority::Low, create_test_payload());
SurvivorId::new(),
Priority::Low,
create_test_payload(),
);
alert.escalate(); alert.escalate();
assert_eq!(alert.priority(), Priority::Medium); assert_eq!(alert.priority(), Priority::Medium);
@@ -452,8 +444,17 @@ mod tests {
#[test] #[test]
fn test_priority_from_triage() { fn test_priority_from_triage() {
assert_eq!(Priority::from_triage(&TriageStatus::Immediate), Priority::Critical); assert_eq!(
assert_eq!(Priority::from_triage(&TriageStatus::Delayed), Priority::High); Priority::from_triage(&TriageStatus::Immediate),
assert_eq!(Priority::from_triage(&TriageStatus::Minor), Priority::Medium); Priority::Critical
);
assert_eq!(
Priority::from_triage(&TriageStatus::Delayed),
Priority::High
);
assert_eq!(
Priority::from_triage(&TriageStatus::Minor),
Priority::Medium
);
} }
} }
@@ -17,7 +17,12 @@ pub struct Coordinates3D {
impl Coordinates3D { impl Coordinates3D {
/// Create new coordinates with uncertainty /// Create new coordinates with uncertainty
pub fn new(x: f64, y: f64, z: f64, uncertainty: LocationUncertainty) -> Self { pub fn new(x: f64, y: f64, z: f64, uncertainty: LocationUncertainty) -> Self {
Self { x, y, z, uncertainty } Self {
x,
y,
z,
uncertainty,
}
} }
/// Create coordinates with default uncertainty /// Create coordinates with default uncertainty
@@ -76,9 +81,9 @@ pub struct LocationUncertainty {
impl Default for LocationUncertainty { impl Default for LocationUncertainty {
fn default() -> Self { fn default() -> Self {
Self { Self {
horizontal_error: 2.0, // 2 meter default uncertainty horizontal_error: 2.0, // 2 meter default uncertainty
vertical_error: 1.0, // 1 meter vertical uncertainty vertical_error: 1.0, // 1 meter vertical uncertainty
confidence: 0.95, // 95% confidence confidence: 0.95, // 95% confidence
} }
} }
} }
@@ -118,11 +123,11 @@ impl LocationUncertainty {
// Combined uncertainty is reduced when multiple estimates agree // Combined uncertainty is reduced when multiple estimates agree
let h_var1 = self.horizontal_error * self.horizontal_error; let h_var1 = self.horizontal_error * self.horizontal_error;
let h_var2 = other.horizontal_error * other.horizontal_error; let h_var2 = other.horizontal_error * other.horizontal_error;
let combined_h_var = 1.0 / (1.0/h_var1 + 1.0/h_var2); let combined_h_var = 1.0 / (1.0 / h_var1 + 1.0 / h_var2);
let v_var1 = self.vertical_error * self.vertical_error; let v_var1 = self.vertical_error * self.vertical_error;
let v_var2 = other.vertical_error * other.vertical_error; let v_var2 = other.vertical_error * other.vertical_error;
let combined_v_var = 1.0 / (1.0/v_var1 + 1.0/v_var2); let combined_v_var = 1.0 / (1.0 / v_var1 + 1.0 / v_var2);
LocationUncertainty { LocationUncertainty {
horizontal_error: combined_h_var.sqrt(), horizontal_error: combined_h_var.sqrt(),
@@ -225,8 +230,10 @@ impl DebrisProfile {
/// Check if debris allows good signal penetration /// Check if debris allows good signal penetration
pub fn is_penetrable(&self) -> bool { pub fn is_penetrable(&self) -> bool {
!matches!(self.metal_content, MetalContent::High | MetalContent::Blocking) !matches!(
&& self.primary_material.attenuation_coefficient() < 5.0 self.metal_content,
MetalContent::High | MetalContent::Blocking
) && self.primary_material.attenuation_coefficient() < 5.0
} }
} }
@@ -1,13 +1,10 @@
//! Disaster event aggregate root. //! Disaster event aggregate root.
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid;
use geo::Point; use geo::Point;
use uuid::Uuid;
use super::{ use super::{Coordinates3D, ScanZone, ScanZoneId, Survivor, SurvivorId, VitalSignsReading};
Survivor, SurvivorId, ScanZone, ScanZoneId,
VitalSignsReading, Coordinates3D,
};
use crate::MatError; use crate::MatError;
/// Unique identifier for a disaster event /// Unique identifier for a disaster event
@@ -66,7 +63,7 @@ pub enum DisasterType {
impl DisasterType { impl DisasterType {
/// Get typical debris profile for this disaster type /// Get typical debris profile for this disaster type
pub fn typical_debris_profile(&self) -> super::DebrisProfile { pub fn typical_debris_profile(&self) -> super::DebrisProfile {
use super::{DebrisProfile, DebrisMaterial, MoistureLevel, MetalContent}; use super::{DebrisMaterial, DebrisProfile, MetalContent, MoistureLevel};
match self { match self {
DisasterType::BuildingCollapse => DebrisProfile { DisasterType::BuildingCollapse => DebrisProfile {
@@ -118,9 +115,9 @@ impl DisasterType {
/// Get expected maximum survival time (hours) /// Get expected maximum survival time (hours)
pub fn expected_survival_hours(&self) -> u32 { pub fn expected_survival_hours(&self) -> u32 {
match self { match self {
DisasterType::Avalanche => 2, // Limited air, hypothermia DisasterType::Avalanche => 2, // Limited air, hypothermia
DisasterType::Flood => 6, // Drowning risk DisasterType::Flood => 6, // Drowning risk
DisasterType::MineCollapse => 72, // Air supply critical DisasterType::MineCollapse => 72, // Air supply critical
DisasterType::BuildingCollapse => 96, DisasterType::BuildingCollapse => 96,
DisasterType::Earthquake => 120, DisasterType::Earthquake => 120,
DisasterType::Landslide => 48, DisasterType::Landslide => 48,
@@ -188,11 +185,7 @@ pub struct EventMetadata {
impl DisasterEvent { impl DisasterEvent {
/// Create a new disaster event /// Create a new disaster event
pub fn new( pub fn new(event_type: DisasterType, location: Point<f64>, description: &str) -> Self {
event_type: DisasterType,
location: Point<f64>,
description: &str,
) -> Self {
Self { Self {
id: DisasterEventId::new(), id: DisasterEventId::new(),
event_type, event_type,
@@ -297,7 +290,9 @@ impl DisasterEvent {
if let Some(existing) = existing_id { if let Some(existing) = existing_id {
// Update existing survivor // Update existing survivor
let survivor = self.survivors.iter_mut() let survivor = self
.survivors
.iter_mut()
.find(|s| s.id() == &existing) .find(|s| s.id() == &existing)
.ok_or_else(|| MatError::Domain("Survivor not found".into()))?; .ok_or_else(|| MatError::Domain("Survivor not found".into()))?;
survivor.update_vitals(vitals); survivor.update_vitals(vitals);
@@ -311,7 +306,10 @@ impl DisasterEvent {
let survivor = Survivor::new(zone_id, vitals, location); let survivor = Survivor::new(zone_id, vitals, location);
self.survivors.push(survivor); self.survivors.push(survivor);
// Safe: we just pushed, so last() is always Some // Safe: we just pushed, so last() is always Some
Ok(self.survivors.last().expect("survivors is non-empty after push")) Ok(self
.survivors
.last()
.expect("survivors is non-empty after push"))
} }
/// Find a survivor near a location /// Find a survivor near a location
@@ -425,7 +423,7 @@ impl TriageCounts {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{ZoneBounds, BreathingPattern, BreathingType, ConfidenceScore}; use crate::domain::{BreathingPattern, BreathingType, ConfidenceScore, ZoneBounds};
fn create_test_vitals() -> VitalSignsReading { fn create_test_vitals() -> VitalSignsReading {
VitalSignsReading { VitalSignsReading {
@@ -456,11 +454,8 @@ mod tests {
#[test] #[test]
fn test_add_zone_activates_event() { fn test_add_zone_activates_event() {
let mut event = DisasterEvent::new( let mut event =
DisasterType::BuildingCollapse, DisasterEvent::new(DisasterType::BuildingCollapse, Point::new(0.0, 0.0), "Test");
Point::new(0.0, 0.0),
"Test",
);
assert_eq!(event.status(), &EventStatus::Initializing); assert_eq!(event.status(), &EventStatus::Initializing);
@@ -472,11 +467,7 @@ mod tests {
#[test] #[test]
fn test_record_detection() { fn test_record_detection() {
let mut event = DisasterEvent::new( let mut event = DisasterEvent::new(DisasterType::Earthquake, Point::new(0.0, 0.0), "Test");
DisasterType::Earthquake,
Point::new(0.0, 0.0),
"Test",
);
let zone = ScanZone::new("Zone A", ZoneBounds::rectangle(0.0, 0.0, 10.0, 10.0)); let zone = ScanZone::new("Zone A", ZoneBounds::rectangle(0.0, 0.0, 10.0, 10.0));
let zone_id = zone.id().clone(); let zone_id = zone.id().clone();
@@ -490,6 +481,9 @@ mod tests {
#[test] #[test]
fn test_disaster_type_survival_hours() { fn test_disaster_type_survival_hours() {
assert!(DisasterType::Avalanche.expected_survival_hours() < DisasterType::Earthquake.expected_survival_hours()); assert!(
DisasterType::Avalanche.expected_survival_hours()
< DisasterType::Earthquake.expected_survival_hours()
);
} }
} }
@@ -1,10 +1,11 @@
//! Domain events for the wifi-Mat system. //! Domain events for the wifi-Mat system.
#![allow(missing_docs)]
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use super::{ use super::{
AlertId, Coordinates3D, Priority, ScanZoneId, SurvivorId, AlertId, AlertResolution, Coordinates3D, Priority, ScanZoneId, SurvivorId, TriageStatus,
TriageStatus, VitalSignsReading, AlertResolution, VitalSignsReading,
}; };
/// All domain events in the system /// All domain events in the system
@@ -422,7 +423,7 @@ pub enum ErrorSeverity {
pub enum TrackingEvent { pub enum TrackingEvent {
/// A tentative track has been confirmed (Tentative → Active). /// A tentative track has been confirmed (Tentative → Active).
TrackBorn { TrackBorn {
track_id: String, // TrackId as string (avoids circular dep) track_id: String, // TrackId as string (avoids circular dep)
survivor_id: SurvivorId, survivor_id: SurvivorId,
zone_id: ScanZoneId, zone_id: ScanZoneId,
timestamp: DateTime<Utc>, timestamp: DateTime<Utc>,
@@ -66,12 +66,21 @@ pub enum ZoneBounds {
impl ZoneBounds { impl ZoneBounds {
/// Create a rectangular zone /// Create a rectangular zone
pub fn rectangle(min_x: f64, min_y: f64, max_x: f64, max_y: f64) -> Self { pub fn rectangle(min_x: f64, min_y: f64, max_x: f64, max_y: f64) -> Self {
ZoneBounds::Rectangle { min_x, min_y, max_x, max_y } ZoneBounds::Rectangle {
min_x,
min_y,
max_x,
max_y,
}
} }
/// Create a circular zone /// Create a circular zone
pub fn circle(center_x: f64, center_y: f64, radius: f64) -> Self { pub fn circle(center_x: f64, center_y: f64, radius: f64) -> Self {
ZoneBounds::Circle { center_x, center_y, radius } ZoneBounds::Circle {
center_x,
center_y,
radius,
}
} }
/// Create a polygon zone /// Create a polygon zone
@@ -82,12 +91,13 @@ impl ZoneBounds {
/// Calculate the area of the zone in square meters /// Calculate the area of the zone in square meters
pub fn area(&self) -> f64 { pub fn area(&self) -> f64 {
match self { match self {
ZoneBounds::Rectangle { min_x, min_y, max_x, max_y } => { ZoneBounds::Rectangle {
(max_x - min_x) * (max_y - min_y) min_x,
} min_y,
ZoneBounds::Circle { radius, .. } => { max_x,
std::f64::consts::PI * radius * radius max_y,
} } => (max_x - min_x) * (max_y - min_y),
ZoneBounds::Circle { radius, .. } => std::f64::consts::PI * radius * radius,
ZoneBounds::Polygon { vertices } => { ZoneBounds::Polygon { vertices } => {
// Shoelace formula // Shoelace formula
if vertices.len() < 3 { if vertices.len() < 3 {
@@ -108,10 +118,17 @@ impl ZoneBounds {
/// Check if a point is within the zone bounds /// Check if a point is within the zone bounds
pub fn contains(&self, x: f64, y: f64) -> bool { pub fn contains(&self, x: f64, y: f64) -> bool {
match self { match self {
ZoneBounds::Rectangle { min_x, min_y, max_x, max_y } => { ZoneBounds::Rectangle {
x >= *min_x && x <= *max_x && y >= *min_y && y <= *max_y min_x,
} min_y,
ZoneBounds::Circle { center_x, center_y, radius } => { max_x,
max_y,
} => x >= *min_x && x <= *max_x && y >= *min_y && y <= *max_y,
ZoneBounds::Circle {
center_x,
center_y,
radius,
} => {
let dx = x - center_x; let dx = x - center_x;
let dy = y - center_y; let dy = y - center_y;
(dx * dx + dy * dy).sqrt() <= *radius (dx * dx + dy * dy).sqrt() <= *radius
@@ -127,9 +144,7 @@ impl ZoneBounds {
for i in 0..n { for i in 0..n {
let (xi, yi) = vertices[i]; let (xi, yi) = vertices[i];
let (xj, yj) = vertices[j]; let (xj, yj) = vertices[j];
if ((yi > y) != (yj > y)) if ((yi > y) != (yj > y)) && (x < (xj - xi) * (y - yi) / (yj - yi) + xi) {
&& (x < (xj - xi) * (y - yi) / (yj - yi) + xi)
{
inside = !inside; inside = !inside;
} }
j = i; j = i;
@@ -142,12 +157,15 @@ impl ZoneBounds {
/// Get the center point of the zone /// Get the center point of the zone
pub fn center(&self) -> (f64, f64) { pub fn center(&self) -> (f64, f64) {
match self { match self {
ZoneBounds::Rectangle { min_x, min_y, max_x, max_y } => { ZoneBounds::Rectangle {
((min_x + max_x) / 2.0, (min_y + max_y) / 2.0) min_x,
} min_y,
ZoneBounds::Circle { center_x, center_y, .. } => { max_x,
(*center_x, *center_y) max_y,
} } => ((min_x + max_x) / 2.0, (min_y + max_y) / 2.0),
ZoneBounds::Circle {
center_x, center_y, ..
} => (*center_x, *center_y),
ZoneBounds::Polygon { vertices } => { ZoneBounds::Polygon { vertices } => {
if vertices.is_empty() { if vertices.is_empty() {
return (0.0, 0.0); return (0.0, 0.0);
@@ -271,6 +289,7 @@ pub struct ScanZone {
sensor_positions: Vec<SensorPosition>, sensor_positions: Vec<SensorPosition>,
parameters: ScanParameters, parameters: ScanParameters,
status: ZoneStatus, status: ZoneStatus,
#[allow(dead_code)]
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
last_scan: Option<DateTime<Utc>>, last_scan: Option<DateTime<Utc>>,
scan_count: u32, scan_count: u32,
@@ -403,9 +422,11 @@ impl ScanZone {
/// Check if zone has enough sensors for localization /// Check if zone has enough sensors for localization
pub fn has_sufficient_sensors(&self) -> bool { pub fn has_sufficient_sensors(&self) -> bool {
// Need at least 3 sensors for 2D localization // Need at least 3 sensors for 2D localization
self.sensor_positions.iter() self.sensor_positions
.iter()
.filter(|s| s.is_operational) .filter(|s| s.is_operational)
.count() >= 3 .count()
>= 3
} }
/// Time since last scan /// Time since last scan
@@ -440,10 +461,7 @@ mod tests {
#[test] #[test]
fn test_scan_zone_creation() { fn test_scan_zone_creation() {
let zone = ScanZone::new( let zone = ScanZone::new("Test Zone", ZoneBounds::rectangle(0.0, 0.0, 50.0, 30.0));
"Test Zone",
ZoneBounds::rectangle(0.0, 0.0, 50.0, 30.0),
);
assert_eq!(zone.name(), "Test Zone"); assert_eq!(zone.name(), "Test Zone");
assert!(matches!(zone.status(), ZoneStatus::Active)); assert!(matches!(zone.status(), ZoneStatus::Active));
@@ -452,10 +470,7 @@ mod tests {
#[test] #[test]
fn test_scan_zone_sensors() { fn test_scan_zone_sensors() {
let mut zone = ScanZone::new( let mut zone = ScanZone::new("Test Zone", ZoneBounds::rectangle(0.0, 0.0, 50.0, 30.0));
"Test Zone",
ZoneBounds::rectangle(0.0, 0.0, 50.0, 30.0),
);
assert!(!zone.has_sufficient_sensors()); assert!(!zone.has_sufficient_sensors());
@@ -475,10 +490,7 @@ mod tests {
#[test] #[test]
fn test_scan_zone_status_transitions() { fn test_scan_zone_status_transitions() {
let mut zone = ScanZone::new( let mut zone = ScanZone::new("Test", ZoneBounds::rectangle(0.0, 0.0, 10.0, 10.0));
"Test",
ZoneBounds::rectangle(0.0, 0.0, 10.0, 10.0),
);
assert!(matches!(zone.status(), ZoneStatus::Active)); assert!(matches!(zone.status(), ZoneStatus::Active));
@@ -3,10 +3,7 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid; use uuid::Uuid;
use super::{ use super::{triage::TriageCalculator, Coordinates3D, ScanZoneId, TriageStatus, VitalSignsReading};
Coordinates3D, TriageStatus, VitalSignsReading, ScanZoneId,
triage::TriageCalculator,
};
/// Unique identifier for a survivor /// Unique identifier for a survivor
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -138,9 +135,7 @@ impl VitalSignsHistory {
if self.readings.is_empty() { if self.readings.is_empty() {
return 0.0; return 0.0;
} }
let sum: f64 = self.readings.iter() let sum: f64 = self.readings.iter().map(|r| r.confidence.value()).sum();
.map(|r| r.confidence.value())
.sum();
sum / self.readings.len() as f64 sum / self.readings.len() as f64
} }
@@ -153,17 +148,18 @@ impl VitalSignsHistory {
let recent: Vec<_> = self.readings.iter().rev().take(3).collect(); let recent: Vec<_> = self.readings.iter().rev().take(3).collect();
// Check breathing trend // Check breathing trend
let breathing_declining = recent.windows(2).all(|w| { let breathing_declining =
match (&w[0].breathing, &w[1].breathing) { recent
(Some(a), Some(b)) => a.rate_bpm < b.rate_bpm, .windows(2)
_ => false, .all(|w| match (&w[0].breathing, &w[1].breathing) {
} (Some(a), Some(b)) => a.rate_bpm < b.rate_bpm,
}); _ => false,
});
// Check confidence trend // Check confidence trend
let confidence_declining = recent.windows(2).all(|w| { let confidence_declining = recent
w[0].confidence.value() < w[1].confidence.value() .windows(2)
}); .all(|w| w[0].confidence.value() < w[1].confidence.value());
breathing_declining || confidence_declining breathing_declining || confidence_declining
} }
@@ -3,7 +3,7 @@
//! The START (Simple Triage and Rapid Treatment) protocol is used to //! The START (Simple Triage and Rapid Treatment) protocol is used to
//! quickly categorize victims in mass casualty incidents. //! quickly categorize victims in mass casualty incidents.
use super::{VitalSignsReading, BreathingType, MovementType}; use super::{BreathingType, MovementType, VitalSignsReading};
/// Triage status following START protocol /// Triage status following START protocol
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -132,9 +132,7 @@ impl TriageCalculator {
/// Assess movement/responsiveness /// Assess movement/responsiveness
fn assess_movement(vitals: &VitalSignsReading) -> MovementAssessment { fn assess_movement(vitals: &VitalSignsReading) -> MovementAssessment {
match vitals.movement.movement_type { match vitals.movement.movement_type {
MovementType::Gross if vitals.movement.is_voluntary => { MovementType::Gross if vitals.movement.is_voluntary => MovementAssessment::Responsive,
MovementAssessment::Responsive
}
MovementType::Gross => MovementAssessment::Moving, MovementType::Gross => MovementAssessment::Moving,
MovementType::Fine => MovementAssessment::MinimalMovement, MovementType::Fine => MovementAssessment::MinimalMovement,
MovementType::Tremor => MovementAssessment::InvoluntaryOnly, MovementType::Tremor => MovementAssessment::InvoluntaryOnly,
@@ -150,32 +148,20 @@ impl TriageCalculator {
) -> TriageStatus { ) -> TriageStatus {
match (breathing, movement) { match (breathing, movement) {
// No breathing // No breathing
(BreathingAssessment::Absent, MovementAssessment::None) => { (BreathingAssessment::Absent, MovementAssessment::None) => TriageStatus::Deceased,
TriageStatus::Deceased (BreathingAssessment::Agonal, _) => TriageStatus::Immediate,
}
(BreathingAssessment::Agonal, _) => {
TriageStatus::Immediate
}
(BreathingAssessment::Absent, _) => { (BreathingAssessment::Absent, _) => {
// No breathing but movement - possible airway obstruction // No breathing but movement - possible airway obstruction
TriageStatus::Immediate TriageStatus::Immediate
} }
// Abnormal breathing rates // Abnormal breathing rates
(BreathingAssessment::TooFast, _) => { (BreathingAssessment::TooFast, _) => TriageStatus::Immediate,
TriageStatus::Immediate (BreathingAssessment::TooSlow, _) => TriageStatus::Immediate,
}
(BreathingAssessment::TooSlow, _) => {
TriageStatus::Immediate
}
// Normal breathing with movement assessment // Normal breathing with movement assessment
(BreathingAssessment::Normal, MovementAssessment::Responsive) => { (BreathingAssessment::Normal, MovementAssessment::Responsive) => TriageStatus::Minor,
TriageStatus::Minor (BreathingAssessment::Normal, MovementAssessment::Moving) => TriageStatus::Delayed,
}
(BreathingAssessment::Normal, MovementAssessment::Moving) => {
TriageStatus::Delayed
}
(BreathingAssessment::Normal, MovementAssessment::MinimalMovement) => { (BreathingAssessment::Normal, MovementAssessment::MinimalMovement) => {
TriageStatus::Delayed TriageStatus::Delayed
} }
@@ -288,7 +274,10 @@ mod tests {
is_voluntary: false, is_voluntary: false,
}, },
); );
assert_eq!(TriageCalculator::calculate(&vitals), TriageStatus::Immediate); assert_eq!(
TriageCalculator::calculate(&vitals),
TriageStatus::Immediate
);
} }
#[test] #[test]
@@ -307,7 +296,10 @@ mod tests {
is_voluntary: false, is_voluntary: false,
}, },
); );
assert_eq!(TriageCalculator::calculate(&vitals), TriageStatus::Immediate); assert_eq!(
TriageCalculator::calculate(&vitals),
TriageStatus::Immediate
);
} }
#[test] #[test]
@@ -321,7 +313,10 @@ mod tests {
}), }),
MovementProfile::default(), MovementProfile::default(),
); );
assert_eq!(TriageCalculator::calculate(&vitals), TriageStatus::Immediate); assert_eq!(
TriageCalculator::calculate(&vitals),
TriageStatus::Immediate
);
} }
#[test] #[test]
@@ -344,11 +344,7 @@ mod tests {
pattern_type: BreathingType::Normal, pattern_type: BreathingType::Normal,
}; };
let reading = VitalSignsReading::new( let reading = VitalSignsReading::new(Some(breathing), None, MovementProfile::default());
Some(breathing),
None,
MovementProfile::default(),
);
assert!(reading.has_vitals()); assert!(reading.has_vitals());
assert!(reading.has_breathing()); assert!(reading.has_breathing());
@@ -3,6 +3,7 @@
//! This module provides receivers for: //! This module provides receivers for:
//! - UDP packets (network streaming from remote sensors) //! - UDP packets (network streaming from remote sensors)
//! - Serial port (ESP32 and similar embedded devices) //! - Serial port (ESP32 and similar embedded devices)
#![allow(missing_docs)]
//! - PCAP files (offline analysis and replay) //! - PCAP files (offline analysis and replay)
//! //!
//! # Example //! # Example
@@ -20,10 +21,10 @@
//! } //! }
//! ``` //! ```
use super::AdapterError;
use super::hardware_adapter::{ use super::hardware_adapter::{
Bandwidth, CsiMetadata, CsiReadings, DeviceType, FrameControlType, SensorCsiReading, Bandwidth, CsiMetadata, CsiReadings, DeviceType, FrameControlType, SensorCsiReading,
}; };
use super::AdapterError;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::io::{BufReader, Read}; use std::io::{BufReader, Read};
@@ -268,7 +269,11 @@ impl UdpCsiReceiver {
pub async fn new(config: ReceiverConfig) -> Result<Self, AdapterError> { pub async fn new(config: ReceiverConfig) -> Result<Self, AdapterError> {
let udp_config = match &config.source { let udp_config = match &config.source {
CsiSource::Udp(c) => c, CsiSource::Udp(c) => c,
_ => return Err(AdapterError::Config("Invalid config for UDP receiver".into())), _ => {
return Err(AdapterError::Config(
"Invalid config for UDP receiver".into(),
))
}
}; };
let addr = format!("{}:{}", udp_config.bind_address, udp_config.port); let addr = format!("{}:{}", udp_config.bind_address, udp_config.port);
@@ -328,7 +333,10 @@ impl UdpCsiReceiver {
} }
} }
} }
Ok(Err(e)) => Err(AdapterError::Hardware(format!("Socket receive error: {}", e))), Ok(Err(e)) => Err(AdapterError::Hardware(format!(
"Socket receive error: {}",
e
))),
Err(_) => Ok(None), // Timeout Err(_) => Ok(None), // Timeout
} }
} }
@@ -347,6 +355,7 @@ impl UdpCsiReceiver {
/// Serial CSI receiver /// Serial CSI receiver
pub struct SerialCsiReceiver { pub struct SerialCsiReceiver {
config: ReceiverConfig, config: ReceiverConfig,
#[allow(dead_code)]
port_path: String, port_path: String,
buffer: VecDeque<u8>, buffer: VecDeque<u8>,
parser: CsiParser, parser: CsiParser,
@@ -359,7 +368,11 @@ impl SerialCsiReceiver {
pub fn new(config: ReceiverConfig) -> Result<Self, AdapterError> { pub fn new(config: ReceiverConfig) -> Result<Self, AdapterError> {
let serial_config = match &config.source { let serial_config = match &config.source {
CsiSource::Serial(c) => c, CsiSource::Serial(c) => c,
_ => return Err(AdapterError::Config("Invalid config for serial receiver".into())), _ => {
return Err(AdapterError::Config(
"Invalid config for serial receiver".into(),
))
}
}; };
// Verify port exists // Verify port exists
@@ -517,7 +530,11 @@ impl PcapCsiReader {
pub fn new(config: ReceiverConfig) -> Result<Self, AdapterError> { pub fn new(config: ReceiverConfig) -> Result<Self, AdapterError> {
let pcap_config = match &config.source { let pcap_config = match &config.source {
CsiSource::Pcap(c) => c, CsiSource::Pcap(c) => c,
_ => return Err(AdapterError::Config("Invalid config for PCAP reader".into())), _ => {
return Err(AdapterError::Config(
"Invalid config for PCAP reader".into(),
))
}
}; };
if !Path::new(&pcap_config.file_path).exists() { if !Path::new(&pcap_config.file_path).exists() {
@@ -656,9 +673,9 @@ impl PcapCsiReader {
// Read packet data // Read packet data
let mut data = vec![0u8; incl_len as usize]; let mut data = vec![0u8; incl_len as usize];
reader.read_exact(&mut data).map_err(|e| { reader
AdapterError::Hardware(format!("Failed to read packet data: {}", e)) .read_exact(&mut data)
})?; .map_err(|e| AdapterError::Hardware(format!("Failed to read packet data: {}", e)))?;
// Convert timestamp // Convert timestamp
let timestamp = chrono::DateTime::from_timestamp(ts_sec as i64, ts_usec * 1000) let timestamp = chrono::DateTime::from_timestamp(ts_sec as i64, ts_usec * 1000)
@@ -770,6 +787,7 @@ impl PcapCsiReader {
} }
/// PCAP global header structure /// PCAP global header structure
#[allow(dead_code)]
struct PcapGlobalHeader { struct PcapGlobalHeader {
magic: u32, magic: u32,
version_major: u16, version_major: u16,
@@ -807,7 +825,9 @@ impl CsiParser {
CsiPacketFormat::PicoScenes => self.parse_picoscenes(data), CsiPacketFormat::PicoScenes => self.parse_picoscenes(data),
CsiPacketFormat::JsonCsi => self.parse_json(data), CsiPacketFormat::JsonCsi => self.parse_json(data),
CsiPacketFormat::RawBinary => self.parse_raw_binary(data), CsiPacketFormat::RawBinary => self.parse_raw_binary(data),
CsiPacketFormat::Auto => Err(AdapterError::DataFormat("Unable to detect format".into())), CsiPacketFormat::Auto => {
Err(AdapterError::DataFormat("Unable to detect format".into()))
}
} }
} }
@@ -915,7 +935,9 @@ impl CsiParser {
fn parse_intel_5300(&self, data: &[u8]) -> Result<CsiPacket, AdapterError> { fn parse_intel_5300(&self, data: &[u8]) -> Result<CsiPacket, AdapterError> {
// Intel 5300 BFEE structure (from Linux CSI Tool) // Intel 5300 BFEE structure (from Linux CSI Tool)
if data.len() < 25 { if data.len() < 25 {
return Err(AdapterError::DataFormat("Intel 5300 packet too short".into())); return Err(AdapterError::DataFormat(
"Intel 5300 packet too short".into(),
));
} }
// Parse header // Parse header
@@ -1105,7 +1127,9 @@ impl CsiParser {
fn parse_picoscenes(&self, data: &[u8]) -> Result<CsiPacket, AdapterError> { fn parse_picoscenes(&self, data: &[u8]) -> Result<CsiPacket, AdapterError> {
// PicoScenes has a complex structure with multiple segments // PicoScenes has a complex structure with multiple segments
if data.len() < 100 { if data.len() < 100 {
return Err(AdapterError::DataFormat("PicoScenes packet too short".into())); return Err(AdapterError::DataFormat(
"PicoScenes packet too short".into(),
));
} }
// PicoScenes CSI segment parsing is not yet implemented. // PicoScenes CSI segment parsing is not yet implemented.
@@ -1124,34 +1148,20 @@ impl CsiParser {
let json: serde_json::Value = serde_json::from_str(json_str) let json: serde_json::Value = serde_json::from_str(json_str)
.map_err(|e| AdapterError::DataFormat(format!("Invalid JSON: {}", e)))?; .map_err(|e| AdapterError::DataFormat(format!("Invalid JSON: {}", e)))?;
let rssi = json let rssi = json.get("rssi").and_then(|v| v.as_i64()).unwrap_or(-50) as i8;
.get("rssi")
.and_then(|v| v.as_i64())
.unwrap_or(-50) as i8;
let channel = json let channel = json.get("channel").and_then(|v| v.as_u64()).unwrap_or(6) as u8;
.get("channel")
.and_then(|v| v.as_u64())
.unwrap_or(6) as u8;
let amplitudes: Vec<f64> = json let amplitudes: Vec<f64> = json
.get("amplitudes") .get("amplitudes")
.and_then(|v| v.as_array()) .and_then(|v| v.as_array())
.map(|arr| { .map(|arr| arr.iter().filter_map(|v| v.as_f64()).collect())
arr.iter()
.filter_map(|v| v.as_f64())
.collect()
})
.unwrap_or_default(); .unwrap_or_default();
let phases: Vec<f64> = json let phases: Vec<f64> = json
.get("phases") .get("phases")
.and_then(|v| v.as_array()) .and_then(|v| v.as_array())
.map(|arr| { .map(|arr| arr.iter().filter_map(|v| v.as_f64()).collect())
arr.iter()
.filter_map(|v| v.as_f64())
.collect()
})
.unwrap_or_default(); .unwrap_or_default();
let source_id = json let source_id = json
@@ -1343,9 +1353,11 @@ mod tests {
#[test] #[test]
fn test_receiver_stats() { fn test_receiver_stats() {
let mut stats = ReceiverStats::default(); let mut stats = ReceiverStats {
stats.packets_received = 100; packets_received: 100,
stats.packets_parsed = 95; packets_parsed: 95,
..ReceiverStats::default()
};
assert!((stats.success_rate() - 0.95).abs() < 0.001); assert!((stats.success_rate() - 0.95).abs() < 0.001);
@@ -3,6 +3,7 @@
//! This module provides adapters for various WiFi CSI hardware: //! This module provides adapters for various WiFi CSI hardware:
//! - ESP32 with CSI support via serial communication //! - ESP32 with CSI support via serial communication
//! - Intel 5300 NIC with Linux CSI Tool //! - Intel 5300 NIC with Linux CSI Tool
#![allow(missing_docs)]
//! - Atheros CSI extraction via ath9k/ath10k drivers //! - Atheros CSI extraction via ath9k/ath10k drivers
//! //!
//! # Example //! # Example
@@ -362,6 +363,7 @@ struct DeviceState {
} }
/// Device-specific runtime state /// Device-specific runtime state
#[allow(dead_code)]
enum DeviceSpecificState { enum DeviceSpecificState {
Esp32 { Esp32 {
firmware_version: Option<String>, firmware_version: Option<String>,
@@ -444,7 +446,10 @@ impl HardwareAdapter {
/// Initialize hardware communication /// Initialize hardware communication
pub async fn initialize(&mut self) -> Result<(), AdapterError> { pub async fn initialize(&mut self) -> Result<(), AdapterError> {
tracing::info!("Initializing hardware adapter for {:?}", self.config.device_type); tracing::info!(
"Initializing hardware adapter for {:?}",
self.config.device_type
);
match &self.config.device_type { match &self.config.device_type {
DeviceType::Esp32 => self.initialize_esp32().await?, DeviceType::Esp32 => self.initialize_esp32().await?,
@@ -468,10 +473,18 @@ impl HardwareAdapter {
async fn initialize_esp32(&mut self) -> Result<(), AdapterError> { async fn initialize_esp32(&mut self) -> Result<(), AdapterError> {
let settings = match &self.config.device_settings { let settings = match &self.config.device_settings {
DeviceSettings::Serial(s) => s, DeviceSettings::Serial(s) => s,
_ => return Err(AdapterError::Config("ESP32 requires serial settings".into())), _ => {
return Err(AdapterError::Config(
"ESP32 requires serial settings".into(),
))
}
}; };
tracing::info!("Initializing ESP32 on {} at {} baud", settings.port, settings.baud_rate); tracing::info!(
"Initializing ESP32 on {} at {} baud",
settings.port,
settings.baud_rate
);
// Verify serial port exists // Verify serial port exists
#[cfg(unix)] #[cfg(unix)]
@@ -498,10 +511,17 @@ impl HardwareAdapter {
async fn initialize_intel_5300(&mut self) -> Result<(), AdapterError> { async fn initialize_intel_5300(&mut self) -> Result<(), AdapterError> {
let settings = match &self.config.device_settings { let settings = match &self.config.device_settings {
DeviceSettings::NetworkInterface(s) => s, DeviceSettings::NetworkInterface(s) => s,
_ => return Err(AdapterError::Config("Intel 5300 requires network interface settings".into())), _ => {
return Err(AdapterError::Config(
"Intel 5300 requires network interface settings".into(),
))
}
}; };
tracing::info!("Initializing Intel 5300 on interface {}", settings.interface); tracing::info!(
"Initializing Intel 5300 on interface {}",
settings.interface
);
// Check if iwlwifi driver is loaded // Check if iwlwifi driver is loaded
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
@@ -509,7 +529,9 @@ impl HardwareAdapter {
let output = tokio::process::Command::new("lsmod") let output = tokio::process::Command::new("lsmod")
.output() .output()
.await .await
.map_err(|e| AdapterError::Hardware(format!("Failed to check kernel modules: {}", e)))?; .map_err(|e| {
AdapterError::Hardware(format!("Failed to check kernel modules: {}", e))
})?;
let stdout = String::from_utf8_lossy(&output.stdout); let stdout = String::from_utf8_lossy(&output.stdout);
if !stdout.contains("iwlwifi") { if !stdout.contains("iwlwifi") {
@@ -536,7 +558,11 @@ impl HardwareAdapter {
async fn initialize_atheros(&mut self, driver: AtherosDriver) -> Result<(), AdapterError> { async fn initialize_atheros(&mut self, driver: AtherosDriver) -> Result<(), AdapterError> {
let settings = match &self.config.device_settings { let settings = match &self.config.device_settings {
DeviceSettings::NetworkInterface(s) => s, DeviceSettings::NetworkInterface(s) => s,
_ => return Err(AdapterError::Config("Atheros requires network interface settings".into())), _ => {
return Err(AdapterError::Config(
"Atheros requires network interface settings".into(),
))
}
}; };
tracing::info!( tracing::info!(
@@ -578,10 +604,18 @@ impl HardwareAdapter {
async fn initialize_udp(&mut self) -> Result<(), AdapterError> { async fn initialize_udp(&mut self) -> Result<(), AdapterError> {
let settings = match &self.config.device_settings { let settings = match &self.config.device_settings {
DeviceSettings::Udp(s) => s, DeviceSettings::Udp(s) => s,
_ => return Err(AdapterError::Config("UDP receiver requires UDP settings".into())), _ => {
return Err(AdapterError::Config(
"UDP receiver requires UDP settings".into(),
))
}
}; };
tracing::info!("Initializing UDP receiver on {}:{}", settings.bind_address, settings.port); tracing::info!(
"Initializing UDP receiver on {}:{}",
settings.bind_address,
settings.port
);
// Verify port is available // Verify port is available
let addr = format!("{}:{}", settings.bind_address, settings.port); let addr = format!("{}:{}", settings.bind_address, settings.port);
@@ -597,7 +631,9 @@ impl HardwareAdapter {
socket socket
.join_multicast_v4(multicast_addr, std::net::Ipv4Addr::UNSPECIFIED) .join_multicast_v4(multicast_addr, std::net::Ipv4Addr::UNSPECIFIED)
.map_err(|e| AdapterError::Hardware(format!("Failed to join multicast group: {}", e)))?; .map_err(|e| {
AdapterError::Hardware(format!("Failed to join multicast group: {}", e))
})?;
} }
// Socket will be recreated when streaming starts // Socket will be recreated when streaming starts
@@ -638,7 +674,9 @@ impl HardwareAdapter {
return Err(AdapterError::Hardware("Hardware not initialized".into())); return Err(AdapterError::Hardware("Hardware not initialized".into()));
} }
let broadcaster = self.csi_broadcaster.as_ref() let broadcaster = self
.csi_broadcaster
.as_ref()
.ok_or_else(|| AdapterError::Hardware("CSI broadcaster not initialized".into()))?; .ok_or_else(|| AdapterError::Hardware("CSI broadcaster not initialized".into()))?;
// Create shutdown channel // Create shutdown channel
@@ -1068,17 +1106,28 @@ impl HardwareAdapter {
} }
/// Configure channel settings /// Configure channel settings
pub async fn set_channel(&mut self, channel: u8, bandwidth: Bandwidth) -> Result<(), AdapterError> { pub async fn set_channel(
&mut self,
channel: u8,
bandwidth: Bandwidth,
) -> Result<(), AdapterError> {
if !self.initialized { if !self.initialized {
return Err(AdapterError::Hardware("Hardware not initialized".into())); return Err(AdapterError::Hardware("Hardware not initialized".into()));
} }
// Validate channel // Validate channel
let valid_2g = (1..=14).contains(&channel); let valid_2g = (1..=14).contains(&channel);
let valid_5g = [36, 40, 44, 48, 52, 56, 60, 64, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 140, 144, 149, 153, 157, 161, 165].contains(&channel); let valid_5g = [
36, 40, 44, 48, 52, 56, 60, 64, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 140,
144, 149, 153, 157, 161, 165,
]
.contains(&channel);
if !valid_2g && !valid_5g { if !valid_2g && !valid_5g {
return Err(AdapterError::Config(format!("Invalid WiFi channel: {}", channel))); return Err(AdapterError::Config(format!(
"Invalid WiFi channel: {}",
channel
)));
} }
self.config.channel_config.channel = channel; self.config.channel_config.channel = channel;
@@ -1321,7 +1370,10 @@ mod tests {
#[test] #[test]
fn test_atheros_config() { fn test_atheros_config() {
let config = HardwareConfig::atheros("wlan0", AtherosDriver::Ath10k); let config = HardwareConfig::atheros("wlan0", AtherosDriver::Ath10k);
assert!(matches!(config.device_type, DeviceType::Atheros(AtherosDriver::Ath10k))); assert!(matches!(
config.device_type,
DeviceType::Atheros(AtherosDriver::Ath10k)
));
assert_eq!(config.channel_config.num_subcarriers, 114); assert_eq!(config.channel_config.num_subcarriers, 114);
} }
@@ -36,69 +36,69 @@
//! let mut receiver = UdpCsiReceiver::new(config).await?; //! let mut receiver = UdpCsiReceiver::new(config).await?;
//! ``` //! ```
mod signal_adapter;
mod neural_adapter;
mod hardware_adapter;
pub mod csi_receiver; pub mod csi_receiver;
mod hardware_adapter;
mod neural_adapter;
mod signal_adapter;
pub use signal_adapter::SignalAdapter;
pub use neural_adapter::NeuralAdapter;
pub use hardware_adapter::{ pub use hardware_adapter::{
AntennaConfig,
AtherosDriver,
Bandwidth,
ChannelConfig,
CsiMetadata,
// CSI data types
CsiReadings,
CsiStream,
DeviceSettings,
DeviceType,
FlowControl,
FrameControlType,
// Main adapter // Main adapter
HardwareAdapter, HardwareAdapter,
// Configuration types // Configuration types
HardwareConfig, HardwareConfig,
DeviceType,
DeviceSettings,
AtherosDriver,
ChannelConfig,
Bandwidth,
// Serial settings
SerialSettings,
Parity,
FlowControl,
// Network interface settings
NetworkInterfaceSettings,
AntennaConfig,
// UDP settings
UdpSettings,
// PCAP settings
PcapSettings,
// Sensor types
SensorInfo,
SensorStatus,
// CSI data types
CsiReadings,
CsiMetadata,
SensorCsiReading,
FrameControlType,
CsiStream,
// Health and stats // Health and stats
HardwareHealth, HardwareHealth,
HealthStatus, HealthStatus,
// Network interface settings
NetworkInterfaceSettings,
Parity,
// PCAP settings
PcapSettings,
SensorCsiReading,
// Sensor types
SensorInfo,
SensorStatus,
// Serial settings
SerialSettings,
StreamingStats, StreamingStats,
// UDP settings
UdpSettings,
}; };
pub use neural_adapter::NeuralAdapter;
pub use signal_adapter::SignalAdapter;
pub use csi_receiver::{ pub use csi_receiver::{
// Receiver types
UdpCsiReceiver,
SerialCsiReceiver,
PcapCsiReader,
// Configuration
ReceiverConfig,
CsiSource,
UdpSourceConfig,
SerialSourceConfig,
PcapSourceConfig,
SerialParity,
// Packet types // Packet types
CsiPacket, CsiPacket,
CsiPacketMetadata,
CsiPacketFormat, CsiPacketFormat,
CsiPacketMetadata,
// Parser // Parser
CsiParser, CsiParser,
CsiSource,
PcapCsiReader,
PcapSourceConfig,
// Configuration
ReceiverConfig,
// Stats // Stats
ReceiverStats, ReceiverStats,
SerialCsiReceiver,
SerialParity,
SerialSourceConfig,
// Receiver types
UdpCsiReceiver,
UdpSourceConfig,
}; };
/// Configuration for integration layer /// Configuration for integration layer
@@ -181,16 +181,8 @@ pub enum AdapterError {
/// Prelude module for convenient imports /// Prelude module for convenient imports
pub mod prelude { pub mod prelude {
pub use super::{ pub use super::{
AdapterError, AdapterError, AtherosDriver, Bandwidth, CsiPacket, CsiPacketFormat, CsiReadings,
HardwareAdapter, DeviceType, HardwareAdapter, HardwareConfig, IntegrationConfig,
HardwareConfig,
DeviceType,
AtherosDriver,
Bandwidth,
CsiReadings,
CsiPacket,
CsiPacketFormat,
IntegrationConfig,
}; };
} }
@@ -1,14 +1,16 @@
//! Adapter for wifi-densepose-nn crate (neural network inference). //! Adapter for wifi-densepose-nn crate (neural network inference).
use super::signal_adapter::VitalFeatures;
use super::AdapterError; use super::AdapterError;
use crate::domain::{BreathingPattern, BreathingType, HeartbeatSignature, SignalStrength}; use crate::domain::{BreathingPattern, BreathingType, HeartbeatSignature, SignalStrength};
use super::signal_adapter::VitalFeatures;
/// Adapter for neural network-based vital signs detection /// Adapter for neural network-based vital signs detection
pub struct NeuralAdapter { pub struct NeuralAdapter {
/// Whether to use GPU acceleration /// Whether to use GPU acceleration
#[allow(dead_code)]
use_gpu: bool, use_gpu: bool,
/// Confidence threshold for valid detections /// Confidence threshold for valid detections
#[allow(dead_code)]
confidence_threshold: f32, confidence_threshold: f32,
/// Model loaded status /// Model loaded status
models_loaded: bool, models_loaded: bool,
@@ -74,11 +76,7 @@ impl NeuralAdapter {
let heartbeat = self.classify_heartbeat(features)?; let heartbeat = self.classify_heartbeat(features)?;
// Calculate overall confidence // Calculate overall confidence
let confidence = self.calculate_confidence( let confidence = self.calculate_confidence(&breathing, &heartbeat, features.signal_quality);
&breathing,
&heartbeat,
features.signal_quality,
);
Ok(VitalsClassification { Ok(VitalsClassification {
breathing, breathing,
@@ -106,7 +104,7 @@ impl NeuralAdapter {
let rate_bpm = (peak_freq * 60.0) as f32; let rate_bpm = (peak_freq * 60.0) as f32;
// Validate rate // Validate rate
if rate_bpm < 4.0 || rate_bpm > 60.0 { if !(4.0..=60.0).contains(&rate_bpm) {
return None; return None;
} }
@@ -148,7 +146,7 @@ impl NeuralAdapter {
let rate_bpm = (peak_freq * 60.0) as f32; let rate_bpm = (peak_freq * 60.0) as f32;
// Validate rate (30-200 BPM) // Validate rate (30-200 BPM)
if rate_bpm < 30.0 || rate_bpm > 200.0 { if !(30.0..=200.0).contains(&rate_bpm) {
return None; return None;
} }
@@ -237,7 +235,7 @@ mod tests {
fn create_weak_features() -> VitalFeatures { fn create_weak_features() -> VitalFeatures {
VitalFeatures { VitalFeatures {
breathing_features: vec![0.25, 0.02, 0.05], // Weak breathing_features: vec![0.25, 0.02, 0.05], // Weak
heartbeat_features: vec![1.2, 0.01, 0.02], // Very weak heartbeat_features: vec![1.2, 0.01, 0.02], // Very weak
movement_features: vec![0.01, 0.005, 0.001], movement_features: vec![0.01, 0.005, 0.001],
signal_quality: 0.3, signal_quality: 0.3,
} }
@@ -1,8 +1,8 @@
//! Adapter for wifi-densepose-signal crate. //! Adapter for wifi-densepose-signal crate.
use super::AdapterError; use super::AdapterError;
use crate::domain::{BreathingPattern, BreathingType};
use crate::detection::CsiDataBuffer; use crate::detection::CsiDataBuffer;
use crate::domain::{BreathingPattern, BreathingType};
/// Features extracted from signal for vital signs detection /// Features extracted from signal for vital signs detection
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@@ -20,8 +20,10 @@ pub struct VitalFeatures {
/// Adapter for wifi-densepose-signal crate /// Adapter for wifi-densepose-signal crate
pub struct SignalAdapter { pub struct SignalAdapter {
/// Window size for processing /// Window size for processing
#[allow(dead_code)]
window_size: usize, window_size: usize,
/// Overlap between windows /// Overlap between windows
#[allow(dead_code)]
overlap: f64, overlap: f64,
/// Sample rate /// Sample rate
sample_rate: f64, sample_rate: f64,
@@ -49,23 +51,15 @@ impl SignalAdapter {
) -> Result<VitalFeatures, AdapterError> { ) -> Result<VitalFeatures, AdapterError> {
if csi_data.amplitudes.len() < self.window_size { if csi_data.amplitudes.len() < self.window_size {
return Err(AdapterError::Signal( return Err(AdapterError::Signal(
"Insufficient data for feature extraction".into() "Insufficient data for feature extraction".into(),
)); ));
} }
// Extract breathing-range features (0.1-0.5 Hz) // Extract breathing-range features (0.1-0.5 Hz)
let breathing_features = self.extract_frequency_band( let breathing_features = self.extract_frequency_band(&csi_data.amplitudes, 0.1, 0.5)?;
&csi_data.amplitudes,
0.1,
0.5,
)?;
// Extract heartbeat-range features (0.8-2.0 Hz) // Extract heartbeat-range features (0.8-2.0 Hz)
let heartbeat_features = self.extract_frequency_band( let heartbeat_features = self.extract_frequency_band(&csi_data.phases, 0.8, 2.0)?;
&csi_data.phases,
0.8,
2.0,
)?;
// Extract movement features // Extract movement features
let movement_features = self.extract_movement_features(&csi_data.amplitudes)?; let movement_features = self.extract_movement_features(&csi_data.amplitudes)?;
@@ -82,10 +76,7 @@ impl SignalAdapter {
} }
/// Convert upstream CsiFeatures to breathing pattern /// Convert upstream CsiFeatures to breathing pattern
pub fn to_breathing_pattern( pub fn to_breathing_pattern(&self, features: &VitalFeatures) -> Option<BreathingPattern> {
&self,
features: &VitalFeatures,
) -> Option<BreathingPattern> {
if features.breathing_features.len() < 3 { if features.breathing_features.len() < 3 {
return None; return None;
} }
@@ -99,7 +90,7 @@ impl SignalAdapter {
let rate_bpm = (rate_estimate * 60.0) as f32; let rate_bpm = (rate_estimate * 60.0) as f32;
// Validate rate // Validate rate
if rate_bpm < 4.0 || rate_bpm > 60.0 { if !(4.0..=60.0).contains(&rate_bpm) {
return None; return None;
} }
@@ -121,7 +112,7 @@ impl SignalAdapter {
low_freq: f64, low_freq: f64,
high_freq: f64, high_freq: f64,
) -> Result<Vec<f64>, AdapterError> { ) -> Result<Vec<f64>, AdapterError> {
use rustfft::{FftPlanner, num_complex::Complex}; use rustfft::{num_complex::Complex, FftPlanner};
let n = signal.len().min(self.window_size); let n = signal.len().min(self.window_size);
if n < 32 { if n < 32 {
@@ -133,7 +124,8 @@ impl SignalAdapter {
let fft = planner.plan_fft_forward(fft_size); let fft = planner.plan_fft_forward(fft_size);
// Prepare buffer with windowing // Prepare buffer with windowing
let mut buffer: Vec<Complex<f64>> = signal.iter() let mut buffer: Vec<Complex<f64>> = signal
.iter()
.take(n) .take(n)
.enumerate() .enumerate()
.map(|(i, &x)| { .map(|(i, &x)| {
@@ -156,29 +148,37 @@ impl SignalAdapter {
// Find peak frequency // Find peak frequency
let mut max_mag = 0.0; let mut max_mag = 0.0;
let mut peak_bin = low_bin; let mut peak_bin = low_bin;
for i in low_bin..=high_bin { for (idx, val) in buffer[low_bin..=high_bin].iter().enumerate() {
let mag = buffer[i].norm(); let mag = val.norm();
if mag > max_mag { if mag > max_mag {
max_mag = mag; max_mag = mag;
peak_bin = i; peak_bin = low_bin + idx;
} }
} }
// Peak frequency // Peak frequency
features.push(peak_bin as f64 * freq_resolution); features.push(peak_bin as f64 * freq_resolution);
// Peak magnitude (normalized) // Peak magnitude (normalized)
let total_power: f64 = buffer[1..buffer.len()/2] let total_power: f64 = buffer[1..buffer.len() / 2]
.iter() .iter()
.map(|c| c.norm_sqr()) .map(|c| c.norm_sqr())
.sum(); .sum();
features.push(if total_power > 0.0 { max_mag * max_mag / total_power } else { 0.0 }); features.push(if total_power > 0.0 {
max_mag * max_mag / total_power
} else {
0.0
});
// Band power ratio // Band power ratio
let band_power: f64 = buffer[low_bin..=high_bin] let band_power: f64 = buffer[low_bin..=high_bin]
.iter() .iter()
.map(|c| c.norm_sqr()) .map(|c| c.norm_sqr())
.sum(); .sum();
features.push(if total_power > 0.0 { band_power / total_power } else { 0.0 }); features.push(if total_power > 0.0 {
band_power / total_power
} else {
0.0
});
} }
Ok(features) Ok(features)
@@ -192,18 +192,18 @@ impl SignalAdapter {
// Calculate variance // Calculate variance
let mean = signal.iter().sum::<f64>() / signal.len() as f64; let mean = signal.iter().sum::<f64>() / signal.len() as f64;
let variance = signal.iter() let variance = signal.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / signal.len() as f64;
.map(|x| (x - mean).powi(2))
.sum::<f64>() / signal.len() as f64;
// Calculate max absolute change // Calculate max absolute change
let max_change = signal.windows(2) let max_change = signal
.windows(2)
.map(|w| (w[1] - w[0]).abs()) .map(|w| (w[1] - w[0]).abs())
.fold(0.0, f64::max); .fold(0.0, f64::max);
// Calculate zero crossing rate // Calculate zero crossing rate
let centered: Vec<f64> = signal.iter().map(|x| x - mean).collect(); let centered: Vec<f64> = signal.iter().map(|x| x - mean).collect();
let zero_crossings: usize = centered.windows(2) let zero_crossings: usize = centered
.windows(2)
.filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0)) .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
.count(); .count();
let zcr = zero_crossings as f64 / signal.len() as f64; let zcr = zero_crossings as f64 / signal.len() as f64;
@@ -219,9 +219,7 @@ impl SignalAdapter {
// SNR estimate based on signal statistics // SNR estimate based on signal statistics
let mean = signal.iter().sum::<f64>() / signal.len() as f64; let mean = signal.iter().sum::<f64>() / signal.len() as f64;
let variance = signal.iter() let variance = signal.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / signal.len() as f64;
.map(|x| (x - mean).powi(2))
.sum::<f64>() / signal.len() as f64;
// Higher variance relative to mean suggests better signal // Higher variance relative to mean suggests better signal
let snr_estimate = if mean.abs() > 1e-10 { let snr_estimate = if mean.abs() > 1e-10 {
@@ -323,9 +321,7 @@ mod tests {
let adapter = SignalAdapter::with_defaults(); let adapter = SignalAdapter::with_defaults();
// Good signal // Good signal
let good_signal: Vec<f64> = (0..100) let good_signal: Vec<f64> = (0..100).map(|i| (i as f64 * 0.1).sin()).collect();
.map(|i| (i as f64 * 0.1).sin())
.collect();
let good_quality = adapter.calculate_signal_quality(&good_signal); let good_quality = adapter.calculate_signal_quality(&good_signal);
// Poor signal (constant) // Poor signal (constant)
+135 -91
View File
@@ -88,65 +88,71 @@ pub mod tracking;
// Re-export main types // Re-export main types
pub use domain::{ pub use domain::{
survivor::{Survivor, SurvivorId, SurvivorMetadata, SurvivorStatus},
disaster_event::{DisasterEvent, DisasterEventId, DisasterType, EventStatus},
scan_zone::{ScanZone, ScanZoneId, ZoneBounds, ZoneStatus, ScanParameters},
alert::{Alert, AlertId, AlertPayload, Priority}, alert::{Alert, AlertId, AlertPayload, Priority},
vital_signs::{ coordinates::{Coordinates3D, DepthEstimate, LocationUncertainty},
VitalSignsReading, BreathingPattern, BreathingType, disaster_event::{DisasterEvent, DisasterEventId, DisasterType, EventStatus},
HeartbeatSignature, MovementProfile, MovementType, events::{
AlertEvent, DetectionEvent, DomainEvent, EventStore, InMemoryEventStore, TrackingEvent,
},
scan_zone::{ScanParameters, ScanZone, ScanZoneId, ZoneBounds, ZoneStatus},
survivor::{Survivor, SurvivorId, SurvivorMetadata, SurvivorStatus},
triage::{TriageCalculator, TriageStatus},
vital_signs::{
BreathingPattern, BreathingType, HeartbeatSignature, MovementProfile, MovementType,
VitalSignsReading,
}, },
triage::{TriageStatus, TriageCalculator},
coordinates::{Coordinates3D, LocationUncertainty, DepthEstimate},
events::{DetectionEvent, AlertEvent, DomainEvent, EventStore, InMemoryEventStore, TrackingEvent},
}; };
pub use detection::{ pub use detection::{
BreathingDetector, BreathingDetectorConfig, BreathingDetector, BreathingDetectorConfig, DetectionConfig, DetectionPipeline,
HeartbeatDetector, HeartbeatDetectorConfig, EnsembleClassifier, EnsembleConfig, EnsembleResult, HeartbeatDetector, HeartbeatDetectorConfig,
MovementClassifier, MovementClassifierConfig, MovementClassifier, MovementClassifierConfig, VitalSignsDetector,
VitalSignsDetector, DetectionPipeline, DetectionConfig,
EnsembleClassifier, EnsembleConfig, EnsembleResult,
}; };
pub use localization::{ pub use localization::{
Triangulator, TriangulationConfig, DepthEstimator, DepthEstimatorConfig, LocalizationService, PositionFuser, TriangulationConfig,
DepthEstimator, DepthEstimatorConfig, Triangulator,
PositionFuser, LocalizationService,
}; };
pub use alerting::{ pub use alerting::{
AlertGenerator, AlertDispatcher, AlertConfig, AlertConfig, AlertDispatcher, AlertGenerator, PriorityCalculator, TriageService,
TriageService, PriorityCalculator,
}; };
pub use integration::{ pub use integration::{
SignalAdapter, NeuralAdapter, HardwareAdapter, AdapterError, HardwareAdapter, IntegrationConfig, NeuralAdapter, SignalAdapter,
AdapterError, IntegrationConfig,
}; };
pub use api::{ pub use api::{create_router, AppState};
create_router, AppState,
};
pub use ml::{ pub use ml::{
// Core ML types AttenuationPrediction,
MlError, MlResult, MlDetectionConfig, MlDetectionPipeline, MlDetectionResult, BreathingClassification,
ClassifierOutput,
DebrisClassification,
DebrisFeatureExtractor,
DebrisFeatures,
DebrisModel,
DebrisModelConfig,
// Debris penetration model // Debris penetration model
DebrisPenetrationModel, DebrisFeatures, DepthEstimate as MlDepthEstimate, DebrisPenetrationModel,
DebrisModel, DebrisModelConfig, DebrisFeatureExtractor, DepthEstimate as MlDepthEstimate,
MaterialType, DebrisClassification, AttenuationPrediction, HeartbeatClassification,
MaterialType,
MlDetectionConfig,
MlDetectionPipeline,
MlDetectionResult,
// Core ML types
MlError,
MlResult,
UncertaintyEstimate,
// Vital signs classifier // Vital signs classifier
VitalSignsClassifier, VitalSignsClassifierConfig, VitalSignsClassifier,
BreathingClassification, HeartbeatClassification, VitalSignsClassifierConfig,
UncertaintyEstimate, ClassifierOutput,
}; };
pub use tracking::{ pub use tracking::{
SurvivorTracker, TrackerConfig, TrackId, TrackedSurvivor, AssociationResult, CsiFingerprint, DetectionObservation, KalmanState, SurvivorTracker, TrackId,
DetectionObservation, AssociationResult, TrackLifecycle, TrackState, TrackedSurvivor, TrackerConfig,
KalmanState, CsiFingerprint,
TrackState, TrackLifecycle,
}; };
/// Library version /// Library version
@@ -399,18 +405,18 @@ impl DisasterResponse {
location: geo::Point<f64>, location: geo::Point<f64>,
description: &str, description: &str,
) -> Result<&DisasterEvent> { ) -> Result<&DisasterEvent> {
let event = DisasterEvent::new( let event = DisasterEvent::new(self.config.disaster_type.clone(), location, description);
self.config.disaster_type.clone(),
location,
description,
);
self.event = Some(event); self.event = Some(event);
self.event.as_ref().ok_or_else(|| MatError::Domain("Failed to create event".into())) self.event
.as_ref()
.ok_or_else(|| MatError::Domain("Failed to create event".into()))
} }
/// Add a scan zone to the current event /// Add a scan zone to the current event
pub fn add_zone(&mut self, zone: ScanZone) -> Result<()> { pub fn add_zone(&mut self, zone: ScanZone) -> Result<()> {
let event = self.event.as_mut() let event = self
.event
.as_mut()
.ok_or_else(|| MatError::Domain("No active disaster event".into()))?; .ok_or_else(|| MatError::Domain("No active disaster event".into()))?;
event.add_zone(zone); event.add_zone(zone);
Ok(()) Ok(())
@@ -429,9 +435,10 @@ impl DisasterResponse {
break; break;
} }
tokio::time::sleep( tokio::time::sleep(std::time::Duration::from_millis(
std::time::Duration::from_millis(self.config.scan_interval_ms) self.config.scan_interval_ms,
).await; ))
.await;
} }
Ok(()) Ok(())
@@ -455,7 +462,9 @@ impl DisasterResponse {
let mut detections = Vec::new(); let mut detections = Vec::new();
{ {
let event = self.event.as_ref() let event = self
.event
.as_ref()
.ok_or_else(|| MatError::Domain("No active disaster event".into()))?; .ok_or_else(|| MatError::Domain("No active disaster event".into()))?;
for zone in event.zones() { for zone in event.zones() {
@@ -473,10 +482,17 @@ impl DisasterResponse {
// Only proceed if ensemble confidence meets threshold // Only proceed if ensemble confidence meets threshold
if ensemble_result.confidence >= self.config.confidence_threshold { if ensemble_result.confidence >= self.config.confidence_threshold {
// Attempt localization // Attempt localization
let location = self.localization_service let location = self
.localization_service
.estimate_position(&vital_signs, zone); .estimate_position(&vital_signs, zone);
detections.push((zone.id().clone(), zone.name().to_string(), vital_signs, location, ensemble_result)); detections.push((
zone.id().clone(),
zone.name().to_string(),
vital_signs,
location,
ensemble_result,
));
} }
} }
@@ -494,22 +510,25 @@ impl DisasterResponse {
} }
// Now process detections with mutable access // Now process detections with mutable access
let event = self.event.as_mut() let event = self
.event
.as_mut()
.ok_or_else(|| MatError::Domain("No active disaster event".into()))?; .ok_or_else(|| MatError::Domain("No active disaster event".into()))?;
for (zone_id, _zone_name, vital_signs, location, _ensemble) in detections { for (zone_id, _zone_name, vital_signs, location, _ensemble) in detections {
let survivor = event.record_detection(zone_id.clone(), vital_signs.clone(), location.clone())?; let survivor =
event.record_detection(zone_id.clone(), vital_signs.clone(), location.clone())?;
// Emit SurvivorDetected domain event // Emit SurvivorDetected domain event
let _ = self.event_store.append(DomainEvent::Detection( let _ =
DetectionEvent::SurvivorDetected { self.event_store
survivor_id: survivor.id().clone(), .append(DomainEvent::Detection(DetectionEvent::SurvivorDetected {
zone_id, survivor_id: survivor.id().clone(),
vital_signs, zone_id,
location, vital_signs,
timestamp: chrono::Utc::now(), location,
}, timestamp: chrono::Utc::now(),
)); }));
// Generate and dispatch alert if needed // Generate and dispatch alert if needed
if survivor.should_alert() { if survivor.should_alert() {
@@ -519,14 +538,14 @@ impl DisasterResponse {
let survivor_id = alert.survivor_id().clone(); let survivor_id = alert.survivor_id().clone();
// Emit AlertGenerated domain event // Emit AlertGenerated domain event
let _ = self.event_store.append(DomainEvent::Alert( let _ = self
AlertEvent::AlertGenerated { .event_store
.append(DomainEvent::Alert(AlertEvent::AlertGenerated {
alert_id, alert_id,
survivor_id, survivor_id,
priority, priority,
timestamp: chrono::Utc::now(), timestamp: chrono::Utc::now(),
}, }));
));
self.alert_dispatcher.dispatch(alert).await?; self.alert_dispatcher.dispatch(alert).await?;
} }
@@ -542,7 +561,8 @@ impl DisasterResponse {
/// Get all detected survivors /// Get all detected survivors
pub fn survivors(&self) -> Vec<&Survivor> { pub fn survivors(&self) -> Vec<&Survivor> {
self.event.as_ref() self.event
.as_ref()
.map(|e| e.survivors()) .map(|e| e.survivors())
.unwrap_or_default() .unwrap_or_default()
} }
@@ -559,29 +579,57 @@ impl DisasterResponse {
/// Prelude module for convenient imports /// Prelude module for convenient imports
pub mod prelude { pub mod prelude {
pub use crate::{ pub use crate::{
DisasterConfig, DisasterConfigBuilder, DisasterResponse, Alert,
MatError, Result,
// Domain types
Survivor, SurvivorId, DisasterEvent, DisasterType,
ScanZone, ZoneBounds, TriageStatus,
VitalSignsReading, BreathingPattern, HeartbeatSignature,
Coordinates3D, Alert, Priority,
// Event sourcing
DomainEvent, EventStore, InMemoryEventStore,
DetectionEvent, AlertEvent, TrackingEvent,
// Detection
DetectionPipeline, VitalSignsDetector,
EnsembleClassifier, EnsembleConfig, EnsembleResult,
// Localization
LocalizationService,
// Alerting // Alerting
AlertDispatcher, AlertDispatcher,
AlertEvent,
AssociationResult,
BreathingPattern,
Coordinates3D,
DebrisClassification,
DebrisModel,
DetectionEvent,
DetectionObservation,
// Detection
DetectionPipeline,
DisasterConfig,
DisasterConfigBuilder,
DisasterEvent,
DisasterResponse,
DisasterType,
// Event sourcing
DomainEvent,
EnsembleClassifier,
EnsembleConfig,
EnsembleResult,
EventStore,
HeartbeatSignature,
InMemoryEventStore,
// Localization
LocalizationService,
MatError,
MaterialType,
// ML types // ML types
MlDetectionConfig, MlDetectionPipeline, MlDetectionResult, MlDetectionConfig,
DebrisModel, MaterialType, DebrisClassification, MlDetectionPipeline,
VitalSignsClassifier, UncertaintyEstimate, MlDetectionResult,
Priority,
Result,
ScanZone,
// Domain types
Survivor,
SurvivorId,
// Tracking // Tracking
SurvivorTracker, TrackerConfig, TrackId, DetectionObservation, AssociationResult, SurvivorTracker,
TrackId,
TrackerConfig,
TrackingEvent,
TriageStatus,
UncertaintyEstimate,
VitalSignsClassifier,
VitalSignsDetector,
VitalSignsReading,
ZoneBounds,
}; };
} }
@@ -606,21 +654,17 @@ mod tests {
#[test] #[test]
fn test_sensitivity_clamping() { fn test_sensitivity_clamping() {
let config = DisasterConfig::builder() let config = DisasterConfig::builder().sensitivity(1.5).build();
.sensitivity(1.5)
.build();
assert!((config.sensitivity - 1.0).abs() < f64::EPSILON); assert!((config.sensitivity - 1.0).abs() < f64::EPSILON);
let config = DisasterConfig::builder() let config = DisasterConfig::builder().sensitivity(-0.5).build();
.sensitivity(-0.5)
.build();
assert!(config.sensitivity.abs() < f64::EPSILON); assert!(config.sensitivity.abs() < f64::EPSILON);
} }
#[test] #[test]
fn test_version() { fn test_version() {
assert!(!VERSION.is_empty()); assert!(VERSION.contains('.'), "VERSION should be a semver string");
} }
} }
@@ -1,6 +1,6 @@
//! Depth estimation through debris layers. //! Depth estimation through debris layers.
use crate::domain::{DebrisProfile, DepthEstimate, DebrisMaterial, MoistureLevel}; use crate::domain::{DebrisMaterial, DebrisProfile, DepthEstimate, MoistureLevel};
/// Configuration for depth estimation /// Configuration for depth estimation
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -20,7 +20,7 @@ impl Default for DepthEstimatorConfig {
Self { Self {
max_depth: 10.0, max_depth: 10.0,
min_attenuation: 3.0, min_attenuation: 3.0,
frequency_ghz: 5.8, // 5.8 GHz WiFi frequency_ghz: 5.8, // 5.8 GHz WiFi
free_space_loss_1m: 47.0, // FSPL at 1m for 5.8 GHz free_space_loss_1m: 47.0, // FSPL at 1m for 5.8 GHz
} }
} }
@@ -45,8 +45,8 @@ impl DepthEstimator {
/// Estimate depth from signal attenuation /// Estimate depth from signal attenuation
pub fn estimate_depth( pub fn estimate_depth(
&self, &self,
signal_attenuation: f64, // Total attenuation in dB signal_attenuation: f64, // Total attenuation in dB
distance_2d: f64, // Horizontal distance in meters distance_2d: f64, // Horizontal distance in meters
debris_profile: &DebrisProfile, debris_profile: &DebrisProfile,
) -> Option<DepthEstimate> { ) -> Option<DepthEstimate> {
if signal_attenuation < self.config.min_attenuation { if signal_attenuation < self.config.min_attenuation {
@@ -178,7 +178,7 @@ impl DepthEstimator {
pub fn estimate_from_multipath( pub fn estimate_from_multipath(
&self, &self,
direct_path_attenuation: f64, direct_path_attenuation: f64,
reflected_paths: &[(f64, f64)], // (attenuation, delay) reflected_paths: &[(f64, f64)], // (attenuation, delay)
debris_profile: &DebrisProfile, debris_profile: &DebrisProfile,
) -> Option<DepthEstimate> { ) -> Option<DepthEstimate> {
// Use path differences to estimate depth // Use path differences to estimate depth
@@ -191,7 +191,8 @@ impl DepthEstimator {
let avg_extra_path: f64 = reflected_paths let avg_extra_path: f64 = reflected_paths
.iter() .iter()
.map(|(_, delay)| delay * SPEED_OF_LIGHT / 2.0) // Round trip .map(|(_, delay)| delay * SPEED_OF_LIGHT / 2.0) // Round trip
.sum::<f64>() / reflected_paths.len() as f64; .sum::<f64>()
/ reflected_paths.len() as f64;
// Extra path length is approximately related to depth // Extra path length is approximately related to depth
// (reflections bounce off debris layers) // (reflections bounce off debris layers)
@@ -279,7 +280,10 @@ mod tests {
// High multipath = concrete // High multipath = concrete
let profile2 = estimator.estimate_debris_profile(0.2, 0.8, 0.3); let profile2 = estimator.estimate_debris_profile(0.2, 0.8, 0.3);
assert!(matches!(profile2.primary_material, DebrisMaterial::HeavyConcrete)); assert!(matches!(
profile2.primary_material,
DebrisMaterial::HeavyConcrete
));
} }
#[test] #[test]

Some files were not shown because too many files have changed in this diff Show More