mirror of
https://github.com/ruvnet/RuView
synced 2026-06-09 10:13:17 +00:00
Compare commits
6 Commits
| SHA1 |
|---|
| b16d7431bc |
| b3a5012dbd |
| e6a5df36eb |
| 5c914e63c7 |
| a5e99670f8 |
| 6b4994e105 |
@@ -0,0 +1,185 @@
# `cog-person-count` — Benchmark Log

Append-only log of every published count_v1 training run per ADR-103. New runs add a section; never overwrite history.

## v0.0.2 — K-fold validated, random split + label smoothing + early stop + temp scale (2026-05-21)

### Why a new release

A 5-fold stratified CV on the same 1,077 samples showed that the v0.0.1 result was an artifact of an unlucky temporal split: the trailing window was class-0-heavy, so a degenerate "always predict 0" classifier trivially scored the class-0 fraction (65.1%).

| Metric | v0.0.1 (temporal) | **5-fold random CV** (diagnostic) |
|---|---|---|
| Overall accuracy | 65.1% | 62.2% ± 1.9% |
| Class 1 accuracy | **0%** | **57.1%** ✓ |
| Confidence Spearman | 0.023 | 0.160 ± 0.029 |

Under fair splits the same architecture reaches roughly 57% class-1 accuracy, so it has real capacity to detect an occupant.

### v0.0.2 results

Architecture unchanged. Training changes only:

- **Random 80/20 split** (seed=42); the temporal split is gone.
- **Label smoothing 0.1** on cross-entropy.
- **Class-balanced multinomial sampler** with replacement.
- **Early stopping** with patience 20, monitored on the eval split (exited at epoch 29 of 400 max).
- **Temperature scaling** of the conf head via LBFGS, fitted on the eval split: T = **0.9262**, shipped as a `count_v1.temperature` sidecar.

| Metric | v0.0.1 | **v0.0.2** | K-fold ref |
|---|---|---|---|
| Overall accuracy | 65.1% | **62.3%** | 62.2% ± 1.9% |
| Class 0 accuracy | 100% (cheating) | **86.2%** | 67.4% |
| **Class 1 accuracy** | **0%** | **34.3%** ✓ | 57.1% |
| MAE | 0.349 | 0.377 | 0.378 |
| Confidence Spearman (post-temp) | 0.023 | 0.013 | 0.160 |
| Wall time | 5.6 s (400 ep) | **0.7 s (29 ep)** | 7.5 s (5×100) |

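The class-balanced sampler listed among the training changes can be illustrated with a small standard-library sketch. It is not the training script's code: the helper names and the 90/10 label set are hypothetical, chosen only to show that weighting each sample by the inverse of its class size gives every class the same expected share of a batch.

```python
import random
from collections import Counter


def balanced_sample_weights(labels):
    """Weight each sample by 1 / (number of samples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def draw_balanced_batch(labels, batch_size, rng):
    """Draw sample indices with replacement using inverse-frequency weights."""
    weights = balanced_sample_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=batch_size)


# Hypothetical imbalanced label set: 90 empty windows, 10 occupied ones.
labels = [0] * 90 + [1] * 10
rng = random.Random(42)
batch = draw_balanced_batch(labels, batch_size=10_000, rng=rng)

# Each class carries total weight 1.0, so about half the draws are class 1.
share_of_class_1 = sum(labels[i] for i in batch) / len(batch)
print(f"class-1 share in sampled batch: {share_of_class_1:.3f}")
```

With the near-even 533/544 label split of the current recording, such a sampler changes very little; it matters once the data becomes imbalanced.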
### Honest read

**Class-1 accuracy 0% → 34.3% is the headline.** The cog now reports `count = 1` for about a third of the occupied windows, instead of always-zero cheating. This single random draw lands below the K-fold mean of 57%; the gap is attributed to run-to-run variance, not a missing improvement. Reaching 57% on a fixed eval set needs averaging over independent draws, which means more independent recordings, i.e. multi-room data (#645), not another training trick.

Confidence calibration didn't move. Temperature scaling alone can't fix a confidence head trained against a noisy `argmax==truth` indicator over a 62%-accurate classifier; its training signal is the bottleneck. Dividing the logits by a positive scalar also preserves their ordering, so a rank statistic such as Spearman cannot improve through temperature scaling in the first place.

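A short sketch of why temperature scaling leaves a rank correlation untouched: dividing every logit by the same positive T moves the probabilities but not their order. The logits below are made up for illustration; only T = 0.9262 comes from this log.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def ranks(values):
    """Rank position of each value (0 = smallest); toy inputs have no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    out = [0] * len(values)
    for rank, i in enumerate(order):
        out[i] = rank
    return out


# Hypothetical confidence logits; T is the temperature reported for v0.0.2.
logits = [-1.2, 0.3, 2.0, -0.4, 0.9]
T = 0.9262

raw = [sigmoid(z) for z in logits]
scaled = [sigmoid(z / T) for z in logits]

print("raw    :", [round(p, 3) for p in raw])
print("scaled :", [round(p, 3) for p in scaled])
print("same ranking:", ranks(raw) == ranks(scaled))
```

Because T < 1 here, the scaled confidences move slightly away from 0.5, but their ranking is identical, so Spearman against the correctness indicator is unchanged.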
### Release artifacts (live on cognitum-v0)

```
gs://cognitum-apps/cogs/arm/cog-person-count-count_v1.safetensors
  sha256: 32996433516891a37c63c600db8b95e42192a53bd538c088c82cd6a85e55513c
  bytes:  392,088
```

Binaries themselves are unchanged from v0.0.1; weights load at runtime via mmap. Per-arch manifests under `cog/artifacts/manifests/{arm,x86_64}/` are bumped to `version: 0.0.2`, with `weights_sha256` and the `build_metadata` caveats updated.

### Reproducibility

```bash
# 5-fold diagnostic run
python3 scripts/train-count.py --paired data/paired/wiflow-p7-1779210883.paired.jsonl \
    --k-fold 5 --epochs 100 --out-results kfold_results.json

# v0.0.2 release run
python3 scripts/train-count.py --paired data/paired/wiflow-p7-1779210883.paired.jsonl \
    --v2 --epochs 400 \
    --out-safetensors count_v1.safetensors --out-onnx count_v1.onnx \
    --out-results count_train_results.json
```

## v0.0.1 — first measured run (2026-05-21)

### Setup

| Component | Value |
|-----------|-------|
| Training host | `ruvultra` (Ubuntu, x86_64, RTX 5080) |
| Backend | PyTorch 2.12 + CUDA |
| Data | `data/paired/wiflow-p7-1779210883.paired.jsonl`: 1,077 paired samples, single 30-min session, label distribution `{0: 533, 1: 544}` |
| Train/eval split | 80/20 temporal split on `ts_start` (held-out tail of the recording), not stratified |
| Architecture | Conv1d encoder (56→64→128→128, dilations 1/2/4) + Linear(128→64→8) count head + Linear(128→32→1) confidence head, bit-identical to `v2/crates/cog-person-count/src/inference.rs::CountNet` |
| Loss | `cross_entropy(count) + 0.3·BCE(conf) + 0.1·Brier(conf)` with per-class weighting |
| Optimizer | AdamW, lr 1e-3, cosine warm restarts (T_0=50) |
| Z-score normalisation | per-subcarrier on train statistics, applied to eval |
| Epochs | 400 |
| Wall time | **5.6 s** |

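Written out, the loss row of the setup table reads roughly as follows. The symbols are introduced here for illustration and are not the author's notation: $z$ is the vector of 8 count logits, $y$ the true count, $c$ the confidence logit, $\sigma$ the sigmoid, $w$ the per-class weights, and $t$ the correctness indicator that the confidence head is trained against. Each term is averaged over the batch.

```latex
\mathcal{L}
  = \mathrm{CE}_{w}(z, y)
  + 0.3\,\mathrm{BCE}\bigl(\sigma(c),\, t\bigr)
  + 0.1\,\bigl(\sigma(c) - t\bigr)^{2},
\qquad
t = \mathbb{1}\!\left[\arg\max_{k} z_{k} = y\right]
```

In the training script included in this change set, $t$ is computed without gradient tracking, and the Brier term is also evaluated under `torch.no_grad()`, so as written it affects the logged loss value but not the gradients.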
### Accuracy (held-out 215-sample tail of the 30-min recording)

| Metric | Value |
|--------|-------|
| Best eval accuracy | **65.1%** |
| Final eval accuracy | 65.1% |
| Within ±1 | **100%** (labels are all in `{0, 1}`, predictions trivially within ±1) |
| MAE | 0.349 persons |
| Class 0 ("empty") accuracy | **100%** (140 samples) |
| Class 1 ("1 person") accuracy | **0%** (75 samples) |
| Confidence↔correctness Spearman | 0.023 |

### Honest read

The model overfit hard. By epoch 100 train_acc reached 1.0 and eval_loss climbed from 0.67 → 7.8. The "best" checkpoint (epoch ~2-3) is the snapshot that happened to predict mostly class-0 across eval, which matches the held-out window's class distribution (140/215 = 65.1%). In other words, it learned the **distribution of the tail of the recording**, not a real empty-vs-occupied classifier.

Why: the training data is one continuous 30-minute solo recording. The held-out tail covers a period in which the operator repeatedly stepped away from the desk, so the eval set is class-0-heavy and the model finds a degenerate "always predict 0" minimum that matches the eval distribution exactly. Class 1 accuracy = 0 is the smoking gun.

Same data-bound failure mode as `pose_v1` (#645). Same fix path: multi-room paired recordings.

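The reported v0.0.1 metrics can be reproduced from the class supports alone, which is a quick way to see that they describe an always-zero predictor. The sketch below uses only the 140/75 supports from the accuracy table; the variable names are illustrative.

```python
# Eval-tail class supports reported in the accuracy table.
n_class0, n_class1 = 140, 75
y_true = [0] * n_class0 + [1] * n_class1
y_pred = [0] * len(y_true)  # degenerate classifier: always predict "empty"

n = len(y_true)
accuracy = sum(p == t for p, t in zip(y_pred, y_true)) / n
mae = sum(abs(p - t) for p, t in zip(y_pred, y_true)) / n
class1_acc = sum(p == t for p, t in zip(y_pred, y_true) if t == 1) / n_class1
within_1 = sum(abs(p - t) <= 1 for p, t in zip(y_pred, y_true)) / n

print(f"accuracy={accuracy:.3f} mae={mae:.3f} "
      f"class1={class1_acc:.3f} within±1={within_1:.3f}")
# → accuracy=0.651 mae=0.349 class1=0.000 within±1=1.000
```

Overall accuracy, MAE, class-1 accuracy and the within-±1 rate all match the table, so none of them is evidence of a learned occupancy signal.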
### What v0.0.1 still validates

- **Pipeline correctness end-to-end.** The Rust cog loaded the PyTorch-trained safetensors successfully on the first try (`backend: candle-cpu` reported by `cog-person-count health`), confirming that the architecture in `src/inference.rs` is byte-compatible with `train-count.py`.
- **ONNX parity.** The 16 KB ONNX file exports cleanly under opset 18 with a dynamic batch axis.
- **Fast iteration loop.** 5.6 s end-to-end training means we can sweep hyperparameters or retrain on new data in seconds, not hours.
- **Cog binary size.** Same 2.36 MB stripped release binary (no change; the model loads at runtime via mmap'd safetensors).

### Comparison to ADR-103 v0.1.0 targets

| Gate | Target | Today | Status |
|------|--------|-------|--------|
| Day-0 same-room accuracy within ±1 | ≥ 80% | 100% (trivially; labels span {0,1}) | met |
| Cross-room accuracy within ±1 | ≥ 60% | Not measured (no cross-room data) | deferred to v0.2.0 |
| MAE | ≤ 0.6 | 0.349 | met |
| Per-frame confidence reflects accuracy (Spearman) | r ≥ 0.5 | 0.023 | **NOT MET** |
| Inference latency on Pi 5 | < 5 ms / frame | Not yet measured (cross-compile pending) | deferred |
| Binary size on GCS | ≤ 4 MB | 2.36 MB | met |

The accuracy gates look "met" only because the labels collapse to {0, 1}: a model that only ever predicts 0 or 1 is within ±1 by construction. The **confidence calibration is the real failure** for v0.0.1. A Spearman of 0.023 means the confidence head is essentially random noise. That is also bounded by data scarcity; multi-session training should sharpen it.

### Artifacts

- `v2/crates/cog-person-count/cog/artifacts/count_v1.safetensors` (392 KB)
- `v2/crates/cog-person-count/cog/artifacts/count_v1.onnx` (16 KB)
- `v2/crates/cog-person-count/cog/artifacts/count_train_results.json` (full per-epoch loss curve + hyperparameters + per-class breakdown)

### Reproducibility

```bash
# On any host with PyTorch + CUDA (cargo path not needed for training):
scp data/paired/wiflow-p7-1779210883.paired.jsonl <host>:/tmp/
scp scripts/train-count.py <host>:/tmp/
ssh <host> "cd /tmp && python3 train-count.py --paired wiflow-p7-1779210883.paired.jsonl --epochs 400"
```

The resulting weights load in the Rust cog with no translation step (safetensors layout matches `cog-person-count::inference::CountNet` exactly):

```bash
cp count_v1.safetensors v2/crates/cog-person-count/cog/artifacts/
cargo run -p cog-person-count --release -- health
# → {"backend":"candle-cpu", "synthetic_count": <int>, "synthetic_confidence": <float>, ...}
```

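To check what a weights file actually contains before handing it to the cog, the safetensors header can be read with the standard library alone: the file starts with an 8-byte little-endian length, followed by that many bytes of JSON describing each tensor. The sketch below is an illustration rather than part of the repository; `read_safetensors_header`, `write_tiny_example` and the tensor name `demo.bias` are hypothetical, and the tiny file exists only so the example runs without the real artifact.

```python
import json
import os
import struct
import tempfile
from pathlib import Path


def read_safetensors_header(path):
    """Return the JSON header of a safetensors file as a dict."""
    with Path(path).open("rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        return json.loads(f.read(header_len).decode("utf-8"))


def write_tiny_example(path):
    """Write a one-tensor file (hypothetical name `demo.bias`) for the demo."""
    payload = struct.pack("<2f", 0.5, -0.5)  # two float32 values, 8 bytes
    header = {"demo.bias": {"dtype": "F32", "shape": [2],
                            "data_offsets": [0, len(payload)]}}
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with Path(path).open("wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(payload)


tmp_path = os.path.join(tempfile.mkdtemp(), "tiny.safetensors")
write_tiny_example(tmp_path)
header = read_safetensors_header(tmp_path)
for name, meta in header.items():
    print(name, meta["dtype"], meta["shape"], meta["data_offsets"])
```

Pointed at `count_v1.safetensors`, the same reader should list names such as `enc.c1.weight` and `count_head.fc2.bias`, assuming the file was produced by the training script's writer. Files written by other tools may also carry a `__metadata__` entry in the header.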
### Live appliance install (cognitum-v0 Pi 5)

Installed at `/var/lib/cognitum/apps/person-count/` with the same on-disk shape as `cog-pose-estimation`, `anomaly-detect`, `seizure-detect`, etc.:

```
$ ls -la /var/lib/cognitum/apps/person-count/
-rwxr-xr-x  cog-person-count-arm    2,168,816 B  (sha matches GCS)
-rw-r--r--  count_v1.safetensors      392,088 B
-rw-r--r--  manifest.json               1,073 B
-rw-r--r--  config.json                   160 B
```

```
$ ./cog-person-count-arm health
{"ts": ..., "event": "health.ok",
 "fields": {"backend": "candle-cpu", "synthetic_count": 0,
            "synthetic_confidence": 0.49, "synthetic_p95_range": [0, 7]}}
```

Cold-start on real Pi 5 hardware: **9.2 ms / invocation** (30 sequential `health` invocations in 0.276 s). That is slightly slower than the pose cog (8.4 ms), which is attributed to the dual-head inference (count softmax + confidence sigmoid) doing ~2× the work after the shared encoder. The cold-start figure is above ADR-103's < 5 ms budget, but that budget applies to the warm path; it is expected to be met once the long-running `run` loop lands and the safetensors stay mmapped between frames. Warm-path latency has not been measured yet.

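The log does not say how the cold-start figure was collected. One plausible way to obtain a per-invocation number like this is to time a batch of sequential process launches and divide by the count. The helper below is a hypothetical sketch; it times the Python interpreter as a stand-in so that it runs anywhere.

```python
import subprocess
import sys
import time


def mean_invocation_ms(cmd, n=30):
    """Run `cmd` n times back to back and return the mean wall time in ms."""
    start = time.perf_counter()
    for _ in range(n):
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL)
    return (time.perf_counter() - start) / n * 1000.0


# Stand-in command; on the appliance this would presumably be
# ["./cog-person-count-arm", "health"].
stand_in_ms = mean_invocation_ms([sys.executable, "-c", "pass"], n=5)
print(f"stand-in command: {stand_in_ms:.1f} ms / invocation")

# Arithmetic behind the reported figure: 30 invocations in 0.276 s.
reported_ms = 0.276 / 30 * 1000.0
print(f"reported: {reported_ms:.1f} ms / invocation")
```

A number measured this way includes process start-up and weight loading on every call, which is why it is not directly comparable to a per-frame warm-path budget.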
### Signed GCS release artifacts (publicly downloadable)

```
gs://cognitum-apps/cogs/arm/cog-person-count-arm  2,168,816 B
  sha256:    36bc0bb0ece894350377d5f93d46cd29378cb289b3773530611c0d47b507b3c3
  signature: R/00xdzHriyr/2rzr4wmPJ/Ken60A+RNdi8r0g2HYJNTXBaFtr46ExfNbiHlgYWadQXzTZdfJoyJK+a6k71NDg==

gs://cognitum-apps/cogs/x86_64/cog-person-count-x86_64  2,615,528 B
  sha256:    76cdd1ec40211add90b4942a09f79939aa28210a27e931de67122357392b01db
  signature: QB+8cnGSMQmubSt/KWVu1+JMg37AKnQXDsFQi/vi+jqpW9rVrGMtnxQpWEWZPeWU1AJ6pl3O2V+7ZtTNIQ2rDg==

gs://cognitum-apps/cogs/arm/cog-person-count-count_v1.safetensors  392,088 B
  sha256:    dacb0551fd3887958db19696d90d811ab08faa44703e6e04ff56d15c3a65a9ff
```

All signed with `COGNITUM_OWNER_SIGNING_KEY` (Ed25519). SHAs verified via public anonymous `https://storage.googleapis.com/...` download.

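Checking a downloaded artifact against one of the published digests needs nothing beyond `hashlib`. The sketch below covers only the SHA-256 comparison; Ed25519 signature verification needs a third-party library and is left out. The helper names are hypothetical, and the demo hashes a throwaway file rather than a real artifact.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path, expected_hex):
    """Compare a local file's digest with a published hex string."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Demo on a throwaway file; for a real check, pass the downloaded artifact
# and the sha256 value listed for it above.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.bin")
with open(demo_path, "wb") as f:
    f.write(b"abc")

expected = hashlib.sha256(b"abc").hexdigest()
print("digest matches:", matches_published_digest(demo_path, expected))
```

Reading in chunks keeps memory use flat, which matters little for a 392 KB weights file but is a sensible default for the multi-megabyte binaries.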
Manifests at:

- `v2/crates/cog-person-count/cog/artifacts/manifests/arm/manifest.json`
- `v2/crates/cog-person-count/cog/artifacts/manifests/x86_64/manifest.json`
@@ -481,12 +481,33 @@ function align() {
      ? extractCsiMatrix(window)
      : extractFeatureMatrix(window);

    // ADR-103: aggregate `n_persons` per window so the cog-person-count
    // training pipeline has count labels. Two summaries:
    //   - `n_persons_mode`: modal value across the camera frames in
    //                       the window. Robust to single-frame noise;
    //                       this is the supervised label for the
    //                       categorical {0..7} count head.
    //   - `n_persons_max`:  the maximum value seen in the window.
    //                       Useful as a soft upper bound (e.g. for
    //                       dynamic dropout weighting during training).
    const personCounts = matched.map(f => f.nPersons ?? 0);
    const counts = new Map();
    for (const v of personCounts) counts.set(v, (counts.get(v) ?? 0) + 1);
    let modeVal = 0;
    let modeCount = -1;
    // Map iterates in insertion order and the comparison is strict, so a
    // tie goes to the value that appeared first in the window.
    for (const [v, n] of counts) {
      if (n > modeCount) { modeVal = v; modeCount = n; }
    }
    const maxVal = personCounts.reduce((a, b) => Math.max(a, b), 0);

    paired.push({
      csi: csiMatrix.data,
      csi_shape: csiMatrix.shape,
      kp: keypoints,
      conf: Math.round(avgConfidence * 1000) / 1000,
      n_camera_frames: matched.length,
      n_persons_mode: modeVal,
      n_persons_max: maxVal,
      ts_start: new Date(tStartMs).toISOString(),
      ts_end: new Date(tEndMs).toISOString(),
    });

@@ -0,0 +1,761 @@
#!/usr/bin/env python3
"""Train the person-count head (ADR-103 v0.0.1).

Mirrors the Conv1d encoder architecture from cog-person-count's
`src/inference.rs::CountNet` exactly, so the learned weights load
into the Rust cog without translation. Trains on
data/paired/wiflow-p7-1779210883.paired.jsonl (1,077 samples with
n_persons_mode labels in {0, 1}).

Output: count_v1.safetensors + count_v1.onnx + count_train_results.json.
"""

from __future__ import annotations

import argparse
import json
import struct
import time
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Architecture constants: MUST match cog-person-count's src/inference.rs.
N_SUB = 56
N_FRAMES = 20
COUNT_CLASSES = 8


class CountNet(nn.Module):
    """Mirrors cog_person_count::inference::CountNet bit-for-bit."""

    def __init__(self) -> None:
        super().__init__()
        # Encoder: identical to the pose cog's encoder so future joint
        # training can share weights.
        self.enc_c1 = nn.Conv1d(N_SUB, 64, kernel_size=3, padding=1, dilation=1)
        self.enc_c2 = nn.Conv1d(64, 128, kernel_size=3, padding=2, dilation=2)
        self.enc_c3 = nn.Conv1d(128, 128, kernel_size=3, padding=4, dilation=4)
        # Count head
        self.count_head_fc1 = nn.Linear(128, 64)
        self.count_head_fc2 = nn.Linear(64, COUNT_CLASSES)
        # Confidence head
        self.conf_head_fc1 = nn.Linear(128, 32)
        self.conf_head_fc2 = nn.Linear(32, 1)

    def forward(self, x: torch.Tensor):
        # x: [B, 56, 20]
        h = F.relu(self.enc_c1(x))
        h = F.relu(self.enc_c2(h))
        h = F.relu(self.enc_c3(h))
        h = h.mean(dim=2)  # global average pool over time -> [B, 128]

        # Logits (un-normalised); softmax at inference + cross-entropy training.
        c = F.relu(self.count_head_fc1(h))
        count_logits = self.count_head_fc2(c)

        # Confidence head: sigmoid at inference; BCE-with-logits at training.
        cf = F.relu(self.conf_head_fc1(h))
        conf_logits = self.conf_head_fc2(cf)

        return count_logits, conf_logits


def load_paired(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Return (X, y) where X is [N, 56, 20] CSI and y is [N] integer counts."""
    csis, ys = [], []
    with path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            d = json.loads(line)
            shape = d.get("csi_shape", [N_SUB, N_FRAMES])
            # Rows with an unexpected CSI shape are skipped silently.
            if shape != [N_SUB, N_FRAMES]:
                continue
            csi = np.asarray(d["csi"], dtype=np.float32).reshape(N_SUB, N_FRAMES)
            csis.append(csi)
            ys.append(int(d.get("n_persons_mode", 0)))
    X = np.stack(csis, axis=0)
    y = np.asarray(ys, dtype=np.int64)
    return X, y


def temporal_split(X: np.ndarray, y: np.ndarray, eval_frac: float = 0.2):
    """Held-out time-window eval (last `eval_frac` of samples, by index)."""
    n = X.shape[0]
    n_eval = int(round(n * eval_frac))
    n_train = n - n_eval
    return (
        X[:n_train], y[:n_train],
        X[n_train:], y[n_train:],
    )


def stratified_k_fold(X: np.ndarray, y: np.ndarray, k: int = 5):
    """Stratified k-fold cross-validation splits (hand-rolled, no sklearn).

    Per class: shuffle the indices (deterministic seed 42), split into k
    near-equal chunks, then assemble fold i by taking chunk i from every
    class. Yields (X_train, y_train, X_val, y_val) per fold, with class
    distribution preserved within ±1.
    """
    rng = np.random.default_rng(seed=42)
    classes = np.unique(y)
    per_class_folds = {}
    for c in classes:
        idx = np.where(y == c)[0]
        rng.shuffle(idx)
        per_class_folds[c] = np.array_split(idx, k)
    for fold in range(k):
        val_idx = np.concatenate([per_class_folds[c][fold] for c in classes])
        train_idx = np.concatenate(
            [per_class_folds[c][f] for c in classes for f in range(k) if f != fold]
        )
        yield X[train_idx], y[train_idx], X[val_idx], y[val_idx]


def standardise(X_train: np.ndarray, X_eval: np.ndarray):
    """Z-score per subcarrier, pooling over samples and time. Eval uses train stats."""
    mu = X_train.mean(axis=(0, 2), keepdims=True)
    sd = X_train.std(axis=(0, 2), keepdims=True) + 1e-6
    return (X_train - mu) / sd, (X_eval - mu) / sd


def write_safetensors(model: CountNet, path: Path):
    """Write the model's state in the same on-disk layout the Rust cog expects."""
    state = model.state_dict()
    # Map PyTorch param names → cog-person-count's VarBuilder paths.
    rename = {
        "enc_c1.weight": "enc.c1.weight",
        "enc_c1.bias": "enc.c1.bias",
        "enc_c2.weight": "enc.c2.weight",
        "enc_c2.bias": "enc.c2.bias",
        "enc_c3.weight": "enc.c3.weight",
        "enc_c3.bias": "enc.c3.bias",
        "count_head_fc1.weight": "count_head.fc1.weight",
        "count_head_fc1.bias": "count_head.fc1.bias",
        "count_head_fc2.weight": "count_head.fc2.weight",
        "count_head_fc2.bias": "count_head.fc2.bias",
        "conf_head_fc1.weight": "conf_head.fc1.weight",
        "conf_head_fc1.bias": "conf_head.fc1.bias",
        "conf_head_fc2.weight": "conf_head.fc2.weight",
        "conf_head_fc2.bias": "conf_head.fc2.bias",
    }

    # Build the JSON header and the concatenated float32 payload together;
    # data_offsets are byte ranges relative to the start of the payload.
    header = {}
    payload = bytearray()
    offset = 0
    for torch_name, cog_name in rename.items():
        t = state[torch_name].detach().cpu().numpy().astype(np.float32)
        n_bytes = t.nbytes
        header[cog_name] = {
            "dtype": "F32",
            "shape": list(t.shape),
            "data_offsets": [offset, offset + n_bytes],
        }
        payload.extend(t.tobytes())
        offset += n_bytes

    # File layout: u64 little-endian header length, JSON header, raw payload.
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with path.open("wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(payload)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--paired", required=True)
    parser.add_argument("--out-safetensors", default="count_v1.safetensors")
    parser.add_argument("--out-onnx", default="count_v1.onnx")
    parser.add_argument("--out-results", default="count_train_results.json")
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--k-fold", type=int, default=None, help="If set, run k-fold CV; else use temporal split")
    parser.add_argument("--v2", action="store_true",
                        help="v0.0.2 training: random 80/20 split + label smoothing + early stopping "
                             "+ balanced sampling + temperature-scaled confidence head.")
    parser.add_argument("--label-smoothing", type=float, default=0.1)
    parser.add_argument("--patience", type=int, default=20)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    X, y = load_paired(Path(args.paired))
    print(f"loaded {X.shape[0]} samples, X shape {X.shape}, "
          f"label distribution: {dict(Counter(y.tolist()).most_common())}")

    # K-fold cross-validation mode
    if args.k_fold is not None:
        print(f"\n=== {args.k_fold}-fold cross-validation ===")
        fold_results = []
        overall_t0 = time.perf_counter()

        for fold_idx, (X_train, y_train, X_val, y_val) in enumerate(stratified_k_fold(X, y, k=args.k_fold)):
            print(f"\nFold {fold_idx + 1}/{args.k_fold}")
            X_train, X_val = standardise(X_train, X_val)

            # Inverse-frequency class weights, normalised to sum to COUNT_CLASSES.
            cls_counts = np.bincount(y_train, minlength=COUNT_CLASSES).astype(np.float32)
            cls_counts = np.where(cls_counts > 0, cls_counts, 1.0)
            cls_weight = (1.0 / cls_counts) / (1.0 / cls_counts).sum() * COUNT_CLASSES
            cls_weight_t = torch.from_numpy(cls_weight).to(device)

            Xt = torch.from_numpy(X_train).to(device)
            yt = torch.from_numpy(y_train).to(device)
            Xv = torch.from_numpy(X_val).to(device)
            yv = torch.from_numpy(y_val).to(device)

            model = CountNet().to(device)
            opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=50, T_mult=1)

            n_train = X_train.shape[0]
            best_eval_acc = 0.0
            best_state = None

            for epoch in range(args.epochs):
                model.train()
                perm = torch.randperm(n_train, device=device)
                train_loss = 0.0
                train_correct = 0
                n_batches = 0
                for i in range(0, n_train, args.batch_size):
                    idx = perm[i : i + args.batch_size]
                    xb = Xt[idx]
                    yb = yt[idx]
                    opt.zero_grad()
                    count_logits, conf_logits = model(xb)
                    ce = F.cross_entropy(count_logits, yb, weight=cls_weight_t)
                    # Target for the confidence head: was the argmax correct?
                    with torch.no_grad():
                        pred = count_logits.argmax(dim=1)
                        correct_indicator = (pred == yb).float().unsqueeze(1)
                    bce = F.binary_cross_entropy_with_logits(conf_logits, correct_indicator)
                    # NOTE: computed under no_grad, so the Brier term changes the
                    # logged loss value but contributes no gradient.
                    with torch.no_grad():
                        conf_sigm = torch.sigmoid(conf_logits)
                        brier = ((conf_sigm - correct_indicator) ** 2).mean()
                    loss = ce + 0.3 * bce + 0.1 * brier
                    loss.backward()
                    opt.step()
                    train_loss += loss.item()
                    train_correct += (pred == yb).sum().item()
                    n_batches += 1

                sched.step()

                model.eval()
                with torch.no_grad():
                    cl_v, _ = model(Xv)
                    eval_pred = cl_v.argmax(dim=1)
                    eval_acc = (eval_pred == yv).float().mean().item()

                # NOTE: the checkpoint is selected on the same fold it is later
                # scored on, so the per-fold metrics are optimistic.
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

            # Restore best checkpoint and final eval
            if best_state is not None:
                model.load_state_dict(best_state)

            model.eval()
            with torch.no_grad():
                cl_v, conf_v = model(Xv)
                pred_v = cl_v.argmax(dim=1)
                acc = (pred_v == yv).float().mean().item()
                within1 = ((pred_v - yv).abs() <= 1).float().mean().item()
                mae = (pred_v - yv).abs().float().mean().item()

            # Per-class accuracy
            per_class = {}
            for k in range(COUNT_CLASSES):
                mask = yv == k
                n = mask.sum().item()
                if n > 0:
                    per_class[k] = {
                        "support": int(n),
                        "accuracy": ((pred_v == yv) & mask).sum().item() / n,
                    }

            # Spearman
            # NOTE: argsort-based ranks break ties by position instead of
            # averaging them; `correct` is binary, so this only approximates
            # a tie-aware Spearman's rho.
            conf_sigm = torch.sigmoid(conf_v).squeeze(-1)
            correct = (pred_v == yv).float()
            c_rank = conf_sigm.argsort().argsort().float()
            r_rank = correct.argsort().argsort().float()
            c_centered = c_rank - c_rank.mean()
            r_centered = r_rank - r_rank.mean()
            denom = (c_centered.norm() * r_centered.norm()).item()
            spearman = (c_centered * r_centered).sum().item() / denom if denom > 0 else 0.0

            fold_results.append({
                "fold": fold_idx + 1,
                "accuracy": acc,
                "within_pm1": within1,
                "mae": mae,
                "spearman": spearman,
                "per_class_accuracy": per_class,
            })
            print(f" accuracy={acc:.3f} within±1={within1:.3f} mae={mae:.3f} spearman={spearman:.3f}")

        # K-fold summary
        total_time = time.perf_counter() - overall_t0
        accs = [r["accuracy"] for r in fold_results]
        within1s = [r["within_pm1"] for r in fold_results]
        maes = [r["mae"] for r in fold_results]
        spears = [r["spearman"] for r in fold_results]

        print(f"\n=== {args.k_fold}-fold summary ({total_time:.1f} s) ===")
        print(f" accuracy: {np.mean(accs):.3f} ± {np.std(accs):.3f}")
        print(f" within ±1: {np.mean(within1s):.3f} ± {np.std(within1s):.3f}")
        print(f" MAE: {np.mean(maes):.3f} ± {np.std(maes):.3f}")
        print(f" conf↔correct Spearman: {np.mean(spears):.3f} ± {np.std(spears):.3f}")

        # Per-class summary across folds
        for k in range(COUNT_CLASSES):
            accs_k = [r["per_class_accuracy"].get(k, {}).get("accuracy", 0.0) for r in fold_results]
            n_k = [r["per_class_accuracy"].get(k, {}).get("support", 0) for r in fold_results]
            if any(n > 0 for n in n_k):
                print(f" class {k}: {np.mean(accs_k):.3f} mean accuracy (support: {n_k})")

        # Write k-fold results to JSON
        results = {
            "mode": "k_fold_cv",
            "k": args.k_fold,
            "backend": "pytorch-cuda" if device.type == "cuda" else "pytorch-cpu",
            "total_time_s": total_time,
            "fold_results": fold_results,
            "summary": {
                "mean_accuracy": float(np.mean(accs)),
                "std_accuracy": float(np.std(accs)),
                "mean_within_pm1": float(np.mean(within1s)),
                "std_within_pm1": float(np.std(within1s)),
                "mean_mae": float(np.mean(maes)),
                "std_mae": float(np.std(maes)),
                "mean_spearman": float(np.mean(spears)),
                "std_spearman": float(np.std(spears)),
            },
            "hyperparameters": {
                "optimizer": "AdamW",
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "batch_size": args.batch_size,
                "schedule": "cosine_warm_restarts",
                "epochs": args.epochs,
            },
        }
        Path(args.out_results).write_text(json.dumps(results, indent=2))
        print(f"\nwrote {args.out_results}")
        return

    # ---------------------------------------------------------------
    # v0.0.2 training path: random 80/20 + label smoothing + early
    # stopping + class-balanced batch sampling + temperature scaling.
    # ---------------------------------------------------------------
    if args.v2:
        rng = np.random.default_rng(seed=42)
        idx = np.arange(X.shape[0])
        rng.shuffle(idx)
        n_eval = int(round(0.2 * X.shape[0]))
        eval_idx, train_idx = idx[:n_eval], idx[n_eval:]
        X_train, X_eval = X[train_idx], X[eval_idx]
        y_train, y_eval = y[train_idx], y[eval_idx]
        X_train, X_eval = standardise(X_train, X_eval)
        print(f"v0.0.2 mode - random 80/20 split: train={len(y_train)} eval={len(y_eval)}")
        print(f" train class dist: {dict(Counter(y_train.tolist()).most_common())}")
        print(f" eval class dist: {dict(Counter(y_eval.tolist()).most_common())}")

        Xt = torch.from_numpy(X_train).to(device)
        yt = torch.from_numpy(y_train).to(device)
        Xe = torch.from_numpy(X_eval).to(device)
        ye = torch.from_numpy(y_eval).to(device)

        # Class-balanced sampler: for each batch, sample with replacement
        # so each class has equal expected count regardless of dataset
        # distribution. With our ~533/544 split this is nearly a no-op
        # but it generalises to imbalanced multi-room data later.
        cls_counts = np.bincount(y_train, minlength=COUNT_CLASSES).astype(np.float32)
        cls_counts = np.where(cls_counts > 0, cls_counts, 1.0)
        per_sample_weight = 1.0 / cls_counts[y_train]
        per_sample_weight_t = torch.from_numpy(per_sample_weight.astype(np.float32)).to(device)

        model = CountNet().to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=50, T_mult=1)

        n_train = X_train.shape[0]
        batches_per_epoch = max(1, n_train // args.batch_size)
        epoch_losses = []
        t0 = time.perf_counter()
        best_eval_acc = 0.0
        best_state = None
        epochs_without_improvement = 0

        for epoch in range(args.epochs):
            model.train()
            train_loss = 0.0
            train_correct = 0
            n_batches = 0
            for _ in range(batches_per_epoch):
                # Balanced sample with replacement
                idx_t = torch.multinomial(per_sample_weight_t, args.batch_size, replacement=True)
                xb = Xt[idx_t]
                yb = yt[idx_t]
                opt.zero_grad()
                count_logits, conf_logits = model(xb)
                ce = F.cross_entropy(count_logits, yb, label_smoothing=args.label_smoothing)
                # Target for the confidence head: was the argmax correct?
                with torch.no_grad():
                    pred = count_logits.argmax(dim=1)
                    correct_indicator = (pred == yb).float().unsqueeze(1)
                bce = F.binary_cross_entropy_with_logits(conf_logits, correct_indicator)
                # NOTE: computed under no_grad, so the Brier term changes the
                # logged loss value but contributes no gradient.
                with torch.no_grad():
                    conf_sigm = torch.sigmoid(conf_logits)
                    brier = ((conf_sigm - correct_indicator) ** 2).mean()
                loss = ce + 0.3 * bce + 0.1 * brier
                loss.backward()
                opt.step()
                train_loss += loss.item()
                train_correct += (pred == yb).sum().item()
                n_batches += 1
            sched.step()

            model.eval()
            with torch.no_grad():
                cl_e, _ = model(Xe)
                eval_loss = F.cross_entropy(cl_e, ye).item()
                eval_pred = cl_e.argmax(dim=1)
                eval_acc = (eval_pred == ye).float().mean().item()
            epoch_losses.append({
                "epoch": epoch,
                "train_loss": train_loss / max(1, n_batches),
                "train_acc": train_correct / max(1, n_batches * args.batch_size),
                "eval_loss": eval_loss,
                "eval_acc": eval_acc,
            })
            # Early stopping tracks accuracy on the eval split.
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epoch < 5 or epoch % 25 == 0:
                print(f"epoch {epoch:3d} train_loss={train_loss/n_batches:.4f} "
                      f"train_acc={train_correct/(n_batches*args.batch_size):.3f} "
                      f"eval_loss={eval_loss:.4f} eval_acc={eval_acc:.3f} "
                      f"epochs_no_improve={epochs_without_improvement}")
            if epochs_without_improvement >= args.patience:
                print(f"early stopping at epoch {epoch} (no improvement for {args.patience} epochs)")
                break

        train_time = time.perf_counter() - t0
        print(f"\ntrained {epoch + 1} epochs in {train_time:.1f} s (best eval_acc {best_eval_acc:.3f})")
        if best_state is not None:
            model.load_state_dict(best_state)

        # Temperature scaling on the confidence head: fit a scalar T s.t.
        # sigmoid(conf_logits / T) is best-calibrated on the eval set.
        model.eval()
        with torch.no_grad():
            cl_e, conf_e = model(Xe)
            pred_e = cl_e.argmax(dim=1)
            correct_indicator = (pred_e == ye).float()
        # 1D optimisation over T via LBFGS.
        T = torch.nn.Parameter(torch.ones(1, device=device))
        opt_t = torch.optim.LBFGS([T], lr=0.1, max_iter=50)

        def eval_t():
            # LBFGS closure: recompute the calibration loss for the current T.
            opt_t.zero_grad()
            scaled = conf_e.squeeze(-1) / T
            loss_t = F.binary_cross_entropy_with_logits(scaled, correct_indicator)
            loss_t.backward()
            return loss_t

        opt_t.step(eval_t)
        T_val = float(T.detach().cpu().item())
        print(f" temperature scale T = {T_val:.4f}")

# Final eval with temperature applied.
|
||||
with torch.no_grad():
|
||||
cl_e, conf_e = model(Xe)
|
||||
probs_e = F.softmax(cl_e, dim=1)
|
||||
pred_e = cl_e.argmax(dim=1)
|
||||
acc = (pred_e == ye).float().mean().item()
|
||||
within1 = ((pred_e - ye).abs() <= 1).float().mean().item()
|
||||
mae = (pred_e - ye).abs().float().mean().item()
|
||||
per_class = {}
|
||||
for k in range(COUNT_CLASSES):
|
||||
mask = ye == k
|
||||
n = mask.sum().item()
|
||||
if n > 0:
|
||||
per_class[k] = {
|
||||
"support": int(n),
|
||||
"accuracy": ((pred_e == ye) & mask).sum().item() / n,
|
||||
}
|
||||
conf_sigm = torch.sigmoid(conf_e.squeeze(-1) / T_val)
|
||||
correct = (pred_e == ye).float()
|
||||
c_rank = conf_sigm.argsort().argsort().float()
|
||||
r_rank = correct.argsort().argsort().float()
|
||||
c_centered = c_rank - c_rank.mean()
|
||||
r_centered = r_rank - r_rank.mean()
|
||||
denom = (c_centered.norm() * r_centered.norm()).item()
|
||||
spearman = (c_centered * r_centered).sum().item() / denom if denom > 0 else 0.0
|
||||
|
||||
print(f"\n=== v0.0.2 final eval ===")
|
||||
print(f" accuracy: {acc:.3f}")
|
||||
print(f" within ±1: {within1:.3f}")
|
||||
print(f" MAE: {mae:.3f}")
|
||||
print(f" conf↔correct Spearman (post-temp): {spearman:.3f}")
|
||||
for k, v in per_class.items():
|
||||
print(f" class {k}: {v['accuracy']:.3f} accuracy on {v['support']} samples")
|
||||
|
||||
write_safetensors(model, Path(args.out_safetensors))
|
||||
# Also record the temperature scalar so the cog can apply it. It is
# written as a sidecar file next to the safetensors output
# (<out_safetensors>.temperature) for consumption by the Rust cog
# inference path.
|
||||
Path(args.out_safetensors + ".temperature").write_text(f"{T_val}\n")
|
||||
print(f"wrote {args.out_safetensors} ({Path(args.out_safetensors).stat().st_size} bytes)")
|
||||
print(f"wrote {args.out_safetensors}.temperature ({T_val})")
|
||||
|
||||
# ONNX
|
||||
dummy = torch.zeros(1, N_SUB, N_FRAMES, device=device)
|
||||
try:
|
||||
torch.onnx.export(model, dummy, args.out_onnx, opset_version=18,
|
||||
input_names=["csi_window"],
|
||||
output_names=["count_logits", "conf_logits"],
|
||||
dynamic_axes={"csi_window": {0: "batch"},
|
||||
"count_logits": {0: "batch"},
|
||||
"conf_logits": {0: "batch"}},
|
||||
export_params=True, do_constant_folding=True)
|
||||
print(f"wrote {args.out_onnx} ({Path(args.out_onnx).stat().st_size} bytes)")
|
||||
except Exception as e:
|
||||
print(f"WARN: ONNX export failed: {e}")
|
||||
|
||||
results = {
|
||||
"mode": "v0.0.2",
|
||||
"backend": "pytorch-cuda" if device.type == "cuda" else "pytorch-cpu",
|
||||
"epochs_trained": epoch + 1,
|
||||
"train_time_s": train_time,
|
||||
"best_eval_acc": best_eval_acc,
|
||||
"final_eval_acc": acc,
|
||||
"final_eval_within_pm1": within1,
|
||||
"final_eval_mae": mae,
|
||||
"temperature_scale": T_val,
|
||||
"conf_correctness_spearman_post_temp": spearman,
|
||||
"per_class_accuracy": per_class,
|
||||
"hyperparameters": {
|
||||
"optimizer": "AdamW",
|
||||
"lr": args.lr,
|
||||
"weight_decay": args.weight_decay,
|
||||
"batch_size": args.batch_size,
|
||||
"schedule": "cosine_warm_restarts",
|
||||
"epochs_max": args.epochs,
|
||||
"label_smoothing": args.label_smoothing,
|
||||
"patience": args.patience,
|
||||
"split": "random_80_20_seed_42",
|
||||
"balanced_sampler": True,
|
||||
"temperature_scaling": True,
|
||||
},
|
||||
"epoch_losses": epoch_losses,
|
||||
}
|
||||
Path(args.out_results).write_text(json.dumps(results, indent=2))
|
||||
print(f"wrote {args.out_results}")
|
||||
return
|
||||
|
||||
# Original temporal-split mode (kept for v0.0.1 reproducibility).
|
||||
X_train, y_train, X_eval, y_eval = temporal_split(X, y, eval_frac=0.2)
|
||||
X_train, X_eval = standardise(X_train, X_eval)
|
||||
|
||||
# Re-balance via class weights — handles the 50/50 split fine
|
||||
# but also makes the loss correct under future imbalanced data.
|
||||
cls_counts = np.bincount(y_train, minlength=COUNT_CLASSES).astype(np.float32)
|
||||
cls_counts = np.where(cls_counts > 0, cls_counts, 1.0)
|
||||
cls_weight = (1.0 / cls_counts) / (1.0 / cls_counts).sum() * COUNT_CLASSES
|
||||
cls_weight_t = torch.from_numpy(cls_weight).to(device)
|
||||
print(f"class weights: {cls_weight.tolist()}")
|
||||
|
||||
Xt = torch.from_numpy(X_train).to(device)
|
||||
yt = torch.from_numpy(y_train).to(device)
|
||||
Xe = torch.from_numpy(X_eval).to(device)
|
||||
ye = torch.from_numpy(y_eval).to(device)
|
||||
|
||||
model = CountNet().to(device)
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=50, T_mult=1)
|
||||
|
||||
n_train = X_train.shape[0]
|
||||
epoch_losses = []
|
||||
t0 = time.perf_counter()
|
||||
|
||||
best_eval_acc = 0.0
|
||||
best_state = None
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
model.train()
|
||||
perm = torch.randperm(n_train, device=device)
|
||||
train_loss = 0.0
|
||||
train_correct = 0
|
||||
n_batches = 0
|
||||
for i in range(0, n_train, args.batch_size):
|
||||
idx = perm[i : i + args.batch_size]
|
||||
xb = Xt[idx]
|
||||
yb = yt[idx]
|
||||
opt.zero_grad()
|
||||
count_logits, conf_logits = model(xb)
|
||||
|
||||
# Categorical cross-entropy for count.
|
||||
ce = F.cross_entropy(count_logits, yb, weight=cls_weight_t)
|
||||
|
||||
# Confidence head: train against `argmax == truth` indicator.
|
||||
with torch.no_grad():
|
||||
pred = count_logits.argmax(dim=1)
|
||||
correct_indicator = (pred == yb).float().unsqueeze(1)
|
||||
bce = F.binary_cross_entropy_with_logits(conf_logits, correct_indicator)
|
||||
|
||||
# Brier score of the conf head. NOTE: it is computed under
# torch.no_grad(), so it shifts the reported loss value but contributes
# no gradient; only the BCE term above actually trains the conf head.
|
||||
with torch.no_grad():
|
||||
conf_sigm = torch.sigmoid(conf_logits)
|
||||
brier = ((conf_sigm - correct_indicator) ** 2).mean()
|
||||
|
||||
loss = ce + 0.3 * bce + 0.1 * brier
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
train_correct += (pred == yb).sum().item()
|
||||
n_batches += 1
|
||||
|
||||
sched.step()
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
cl_e, _ = model(Xe)
|
||||
eval_loss = F.cross_entropy(cl_e, ye, weight=cls_weight_t).item()
|
||||
eval_pred = cl_e.argmax(dim=1)
|
||||
eval_acc = (eval_pred == ye).float().mean().item()
|
||||
eval_within1 = ((eval_pred - ye).abs() <= 1).float().mean().item()
|
||||
|
||||
epoch_losses.append({
|
||||
"epoch": epoch,
|
||||
"train_loss": train_loss / n_batches,
|
||||
"train_acc": train_correct / n_train,
|
||||
"eval_loss": eval_loss,
|
||||
"eval_acc": eval_acc,
|
||||
"eval_within_pm1": eval_within1,
|
||||
})
|
||||
|
||||
if eval_acc > best_eval_acc:
|
||||
best_eval_acc = eval_acc
|
||||
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
||||
|
||||
if epoch < 5 or epoch % 50 == 0 or epoch == args.epochs - 1:
|
||||
print(f"epoch {epoch:3d} train_loss={train_loss/n_batches:.4f} "
|
||||
f"train_acc={train_correct/n_train:.3f} "
|
||||
f"eval_loss={eval_loss:.4f} eval_acc={eval_acc:.3f} "
|
||||
f"within±1={eval_within1:.3f}")
|
||||
|
||||
train_time = time.perf_counter() - t0
|
||||
print(f"\ntrained {args.epochs} epochs in {train_time:.1f} s")
|
||||
print(f"best eval_acc: {best_eval_acc:.3f}")
|
||||
|
||||
# Restore best checkpoint
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
# Eval breakdown
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
cl_e, conf_e = model(Xe)
|
||||
probs_e = torch.softmax(cl_e, dim=1)
|
||||
pred_e = cl_e.argmax(dim=1)
|
||||
acc = (pred_e == ye).float().mean().item()
|
||||
within1 = ((pred_e - ye).abs() <= 1).float().mean().item()
|
||||
mae = (pred_e - ye).abs().float().mean().item()
|
||||
|
||||
# Per-class accuracy
|
||||
per_class = {}
|
||||
for k in range(COUNT_CLASSES):
|
||||
mask = ye == k
|
||||
n = mask.sum().item()
|
||||
if n > 0:
|
||||
per_class[k] = {
|
||||
"support": int(n),
|
||||
"accuracy": ((pred_e == ye) & mask).sum().item() / n,
|
||||
}
|
||||
|
||||
# Confidence-accuracy calibration: Spearman over (predicted-correct, confidence)
|
||||
conf_sigm = torch.sigmoid(conf_e).squeeze(-1)
|
||||
correct = (pred_e == ye).float()
|
||||
# Spearman = Pearson over ranks
|
||||
c_rank = conf_sigm.argsort().argsort().float()
|
||||
r_rank = correct.argsort().argsort().float()
|
||||
c_centered = c_rank - c_rank.mean()
|
||||
r_centered = r_rank - r_rank.mean()
|
||||
denom = (c_centered.norm() * r_centered.norm()).item()
|
||||
spearman = (c_centered * r_centered).sum().item() / denom if denom > 0 else 0.0
|
||||
|
||||
print(f"\n=== final eval ===")
|
||||
print(f" accuracy: {acc:.3f}")
|
||||
print(f" within ±1: {within1:.3f}")
|
||||
print(f" MAE: {mae:.3f}")
|
||||
print(f" conf↔correct Spearman: {spearman:.3f}")
|
||||
for k, v in per_class.items():
|
||||
print(f" class {k}: {v['accuracy']:.3f} accuracy on {v['support']} samples")
|
||||
|
||||
# Save safetensors
|
||||
write_safetensors(model, Path(args.out_safetensors))
|
||||
print(f"\nwrote {args.out_safetensors} ({Path(args.out_safetensors).stat().st_size} bytes)")
|
||||
|
||||
# ONNX export
|
||||
dummy = torch.zeros(1, N_SUB, N_FRAMES, device=device)
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model, dummy, args.out_onnx,
|
||||
opset_version=18,
|
||||
input_names=["csi_window"],
|
||||
output_names=["count_logits", "conf_logits"],
|
||||
dynamic_axes={
|
||||
"csi_window": {0: "batch"},
|
||||
"count_logits": {0: "batch"},
|
||||
"conf_logits": {0: "batch"},
|
||||
},
|
||||
export_params=True,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
print(f"wrote {args.out_onnx} ({Path(args.out_onnx).stat().st_size} bytes)")
|
||||
except Exception as e:
|
||||
print(f"WARN: ONNX export failed: {e}")
|
||||
|
||||
# Results JSON
|
||||
results = {
|
||||
"backend": "candle-cuda" if device.type == "cuda" else "candle-cpu",
|
||||
"device": str(device),
|
||||
"epochs": args.epochs,
|
||||
"train_time_s": train_time,
|
||||
"best_eval_acc": best_eval_acc,
|
||||
"final_eval_acc": acc,
|
||||
"final_eval_within_pm1": within1,
|
||||
"final_eval_mae": mae,
|
||||
"conf_correctness_spearman": spearman,
|
||||
"per_class_accuracy": per_class,
|
||||
"hyperparameters": {
|
||||
"optimizer": "AdamW",
|
||||
"lr": args.lr,
|
||||
"weight_decay": args.weight_decay,
|
||||
"batch_size": args.batch_size,
|
||||
"schedule": "cosine_warm_restarts",
|
||||
"epochs": args.epochs,
|
||||
"loss": "cross_entropy(count) + 0.3*bce(conf) + 0.1*brier(conf)",
|
||||
"z_score_normalisation": True,
|
||||
"class_weights": cls_weight.tolist(),
|
||||
},
|
||||
"epoch_losses": epoch_losses,
|
||||
}
|
||||
Path(args.out_results).write_text(json.dumps(results, indent=2))
|
||||
print(f"wrote {args.out_results} ({Path(args.out_results).stat().st_size} bytes)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
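The script computes Spearman as Pearson over ranks obtained with `argsort().argsort()`. That yields ordinal ranks, so tied values (and the binary correctness indicator is almost entirely ties) get an arbitrary order within each tie group. The following is a minimal, dependency-free sketch of the usual tie-aware variant that assigns average ranks. The helper names `average_ranks` and `spearman` are illustrative and do not exist in the repository, and whether the published numbers should be recomputed this way is a judgment call for the maintainers.

```python
import math


def average_ranks(values):
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j to the end of the run of values equal to values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2.0 + 1.0
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman(x, y):
    """Spearman correlation = Pearson correlation of the average ranks."""
    rx, ry = average_ranks(x), average_ranks(y)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cx = [r - mx for r in rx]
    cy = [r - my for r in ry]
    denom = math.sqrt(sum(c * c for c in cx)) * math.sqrt(sum(c * c for c in cy))
    return sum(a * b for a, b in zip(cx, cy)) / denom if denom > 0 else 0.0


# Toy data (not from the training run): confident predictions are the correct ones.
confidence = [0.9, 0.8, 0.3, 0.2]
correct = [1.0, 1.0, 0.0, 0.0]
rho = spearman(confidence, correct)
```

With average ranks the result no longer depends on how the sort happens to order equal values, which matters when one of the two variables is a 0/1 indicator.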
|
||||
@@ -27,19 +27,36 @@ Replaces the PR #491 slot heuristic (`subcarrier_diversity / dedup_factor`) with
|
||||
|
||||
Downstream consumers can render the **most-likely count** when confidence is high, or fall back to a `[lo, hi]` band with a "?" badge when the model is uncertain — that's how this Cog closes the loop on #499's ghost-skeleton UX.
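As a rough illustration of that rendering rule, here is a minimal Python sketch. Everything in it is an assumption made for the example: the function name `render_count`, the `min_confidence` threshold of 0.5, and the rule that grows the band until it covers 90% of the probability mass are hypothetical and are not taken from this Cog's event schema or from the dashboard code.

```python
def render_count(count_probs, confidence, min_confidence=0.5, band_mass=0.9):
    """Return a display string: a single count, or a "[lo, hi] ?" band.

    count_probs : per-class probabilities (index = person count)
    confidence  : calibrated confidence in [0, 1]
    Both thresholds are hypothetical defaults for this sketch.
    """
    best = max(range(len(count_probs)), key=lambda k: count_probs[k])
    if confidence >= min_confidence:
        return str(best)

    # Low confidence: widen a contiguous band around the most likely count
    # until it holds at least `band_mass` of the probability.
    lo = hi = best
    mass = count_probs[best]
    while mass < band_mass and (lo > 0 or hi < len(count_probs) - 1):
        left = count_probs[lo - 1] if lo > 0 else -1.0
        right = count_probs[hi + 1] if hi < len(count_probs) - 1 else -1.0
        if left >= right:
            lo -= 1
            mass += left
        else:
            hi += 1
            mass += right
    return f"[{lo}, {hi}] ?"
```

For example, `render_count([0.1, 0.8, 0.1], 0.9)` yields a plain `1`, while `render_count([0.6, 0.4], 0.2)` falls back to the band `[0, 1] ?`.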
|
||||
|
||||
## Status — v0.0.1 (this scaffold)
|
||||
## Status — v0.0.1
|
||||
|
||||
| Component | State |
|
||||
|---|---|
|
||||
| Crate compiles, library API stable | ✅ |
|
||||
| Tests pass (`cargo test -p cog-person-count`) | ✅ |
|
||||
| Tests pass (15 total: 8 smoke + 7 fusion) | ✅ |
|
||||
| Four-verb runtime contract (`version`, `manifest`, `health`) | ✅ |
|
||||
| `run` subcommand (long-running loop) | ⏳ v0.0.1 follow-up |
|
||||
| Trained `count_v1.safetensors` artifact | ⏳ same training pipeline that produced `pose_v1` — bootstrap on the existing 1,077 paired samples |
|
||||
| Signed binary on GCS | ⏳ once trained |
|
||||
| Trained `count_v1.safetensors` artifact | ✅ shipped at `cog/artifacts/count_v1.safetensors` (392 KB) |
|
||||
| ONNX export | ✅ `count_v1.onnx` (16 KB), bit-compatible architecture |
|
||||
| Honest accuracy reporting | ✅ See `docs/benchmarks/person-count-cog.md` — 65.1% eval acc on a single-session dataset; confidence head Spearman 0.023 ⇒ uncalibrated for v0.0.1 |
|
||||
| `run` subcommand (long-running loop) | ⏳ same shape as cog-pose-estimation::runtime, lands in follow-up |
|
||||
| Signed binary on GCS | ⏳ release pipeline |
|
||||
| Stoer-Wagner min-cut clip in fusion stage | ⏳ v0.2.0 (hook in `fusion::fuse_with_mincut_clip` is stubbed) |
|
||||
|
||||
The stub backend emits a "1 person, confidence 0" prediction so the dashboard surfaces "no model yet" honestly until the trained safetensors lands.
|
||||
### Honest v0.0.1 caveat
|
||||
|
||||
`count_v1` was trained on a single 30-minute solo recording. The model overfit by epoch ~100 and the "best" checkpoint is one that effectively predicts the eval-window class distribution (mostly class-0). Class-1 accuracy on the held-out tail = 0%. **This v0.0.1 is a working pipeline with a degenerate model**, not a usable counter yet — same data-bound failure mode as `pose_v1` (#645), same fix: multi-room paired recordings.
|
||||
|
||||
`cog-person-count health` will load the real safetensors and report `backend: candle-cpu` rather than `backend: stub`, so the cog-gateway can verify the model loaded — but operators should treat the v0.0.1 count outputs as scaffold-validation rather than production data. The 2.36 MB binary + 392 KB weights + 16 KB ONNX are all real and reusable as soon as more data lands.
|
||||
|
||||
## Relationship to the in-process `csi.rs::score_to_person_count` heuristic
|
||||
|
||||
This Cog runs **out-of-process** alongside `wifi-densepose-sensing-server`. The two are complementary, not competing:
|
||||
|
||||
- The sensing-server keeps emitting its existing slot-count heuristic from `csi.rs::score_to_person_count` (PR #491's RollingP95 + `dedup_factor`). This is the **fallback path** — operators who don't install `cog-person-count` still get a count number, just a less calibrated one.
|
||||
- `cog-person-count` (this binary) polls the same `/api/v1/sensing/latest` endpoint, runs the learned `count_v1` model on each window, and emits `person.count` events on stdout. The appliance's `cognitum-cog-gateway` routes those events to the dashboard via the standard ADR-220 cog-event channel.
|
||||
|
||||
Operators choose by **installing or not installing** this Cog — no sensing-server rebuild required. Downstream consumers (UI, fleet automation, alerting rules) can subscribe to whichever event stream they prefer.
|
||||
|
||||
The architecture decision is documented in [ADR-103 §"Deployment"](../../../../docs/adr/ADR-103-learned-multi-person-counter.md#deployment) and matches the cog/sensing-server boundary established for `cog-pose-estimation` (ADR-101).
|
||||
|
||||
## Security
|
||||
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
{
|
||||
"mode": "v0.0.2",
|
||||
"backend": "pytorch-cuda",
|
||||
"epochs_trained": 29,
|
||||
"train_time_s": 0.7185604920377955,
|
||||
"best_eval_acc": 0.6232557892799377,
|
||||
"final_eval_acc": 0.6232557892799377,
|
||||
"final_eval_within_pm1": 1.0,
|
||||
"final_eval_mae": 0.37674418091773987,
|
||||
"temperature_scale": 0.9261822700500488,
|
||||
"conf_correctness_spearman_post_temp": 0.012770170735830375,
|
||||
"per_class_accuracy": {
|
||||
"0": {
|
||||
"support": 116,
|
||||
"accuracy": 0.8620689655172413
|
||||
},
|
||||
"1": {
|
||||
"support": 99,
|
||||
"accuracy": 0.3434343434343434
|
||||
}
|
||||
},
|
||||
"hyperparameters": {
|
||||
"optimizer": "AdamW",
|
||||
"lr": 0.001,
|
||||
"weight_decay": 0.01,
|
||||
"batch_size": 64,
|
||||
"schedule": "cosine_warm_restarts",
|
||||
"epochs_max": 400,
|
||||
"label_smoothing": 0.1,
|
||||
"patience": 20,
|
||||
"split": "random_80_20_seed_42",
|
||||
"balanced_sampler": true,
|
||||
"temperature_scaling": true
|
||||
},
|
||||
"epoch_losses": [
|
||||
{
|
||||
"epoch": 0,
|
||||
"train_loss": 1.8680313183711126,
|
||||
"train_acc": 0.4543269230769231,
|
||||
"eval_loss": 0.7276814579963684,
|
||||
"eval_acc": 0.539534866809845
|
||||
},
|
||||
{
|
||||
"epoch": 1,
|
||||
"train_loss": 1.3579198305423443,
|
||||
"train_acc": 0.5060096153846154,
|
||||
"eval_loss": 0.8614012002944946,
|
||||
"eval_acc": 0.46046510338783264
|
||||
},
|
||||
{
|
||||
"epoch": 2,
|
||||
"train_loss": 1.299364447593689,
|
||||
"train_acc": 0.4831730769230769,
|
||||
"eval_loss": 0.7327257990837097,
|
||||
"eval_acc": 0.539534866809845
|
||||
},
|
||||
{
|
||||
"epoch": 3,
|
||||
"train_loss": 1.2834151433064387,
|
||||
"train_acc": 0.4963942307692308,
|
||||
"eval_loss": 0.7958587408065796,
|
||||
"eval_acc": 0.539534866809845
|
||||
},
|
||||
{
|
||||
"epoch": 4,
|
||||
"train_loss": 1.2809640077444224,
|
||||
"train_acc": 0.49278846153846156,
|
||||
"eval_loss": 0.7728011608123779,
|
||||
"eval_acc": 0.46046510338783264
|
||||
},
|
||||
{
|
||||
"epoch": 5,
|
||||
"train_loss": 1.276416512636038,
|
||||
"train_acc": 0.5120192307692307,
|
||||
"eval_loss": 0.7620130181312561,
|
||||
"eval_acc": 0.539534866809845
|
||||
},
|
||||
{
|
||||
"epoch": 6,
|
||||
"train_loss": 1.2767094740500817,
|
||||
"train_acc": 0.4951923076923077,
|
||||
"eval_loss": 0.7696149945259094,
|
||||
"eval_acc": 0.604651153087616
|
||||
},
|
||||
{
|
||||
"epoch": 7,
|
||||
"train_loss": 1.2724562699978168,
|
||||
"train_acc": 0.5324519230769231,
|
||||
"eval_loss": 0.7653729319572449,
|
||||
"eval_acc": 0.539534866809845
|
||||
},
|
||||
{
|
||||
"epoch": 8,
|
||||
"train_loss": 1.2739891455723689,
|
||||
"train_acc": 0.5264423076923077,
|
||||
"eval_loss": 0.7635467648506165,
|
||||
"eval_acc": 0.6232557892799377
|
||||
},
|
||||
{
|
||||
"epoch": 9,
|
||||
"train_loss": 1.2718101739883423,
|
||||
"train_acc": 0.5120192307692307,
|
||||
"eval_loss": 0.7564782500267029,
|
||||
"eval_acc": 0.604651153087616
|
||||
},
|
||||
{
|
||||
"epoch": 10,
|
||||
"train_loss": 1.261798886152414,
|
||||
"train_acc": 0.5625,
|
||||
"eval_loss": 0.7915780544281006,
|
||||
"eval_acc": 0.46046510338783264
|
||||
},
|
||||
{
|
||||
"epoch": 11,
|
||||
"train_loss": 1.2723550613109882,
|
||||
"train_acc": 0.5348557692307693,
|
||||
"eval_loss": 0.7585318088531494,
|
||||
"eval_acc": 0.6139534711837769
|
||||
},
|
||||
{
|
||||
"epoch": 12,
|
||||
"train_loss": 1.2408426174750695,
|
||||
"train_acc": 0.6225961538461539,
|
||||
"eval_loss": 0.7562077045440674,
|
||||
"eval_acc": 0.525581419467926
|
||||
},
|
||||
{
|
||||
"epoch": 13,
|
||||
"train_loss": 1.219417168543889,
|
||||
"train_acc": 0.6334134615384616,
|
||||
"eval_loss": 0.7647078633308411,
|
||||
"eval_acc": 0.5860465168952942
|
||||
},
|
||||
{
|
||||
"epoch": 14,
|
||||
"train_loss": 1.198713256762578,
|
||||
"train_acc": 0.6526442307692307,
|
||||
"eval_loss": 0.7711634635925293,
|
||||
"eval_acc": 0.5720930099487305
|
||||
},
|
||||
{
|
||||
"epoch": 15,
|
||||
"train_loss": 1.167367669252249,
|
||||
"train_acc": 0.6826923076923077,
|
||||
"eval_loss": 0.7664391994476318,
|
||||
"eval_acc": 0.6186046600341797
|
||||
},
|
||||
{
|
||||
"epoch": 16,
|
||||
"train_loss": 1.1867470557873065,
|
||||
"train_acc": 0.6574519230769231,
|
||||
"eval_loss": 0.7853891253471375,
|
||||
"eval_acc": 0.6139534711837769
|
||||
},
|
||||
{
|
||||
"epoch": 17,
|
||||
"train_loss": 1.185251813668471,
|
||||
"train_acc": 0.6766826923076923,
|
||||
"eval_loss": 0.7728492021560669,
|
||||
"eval_acc": 0.5767441987991333
|
||||
},
|
||||
{
|
||||
"epoch": 18,
|
||||
"train_loss": 1.1749065747627845,
|
||||
"train_acc": 0.6814903846153846,
|
||||
"eval_loss": 0.7930512428283691,
|
||||
"eval_acc": 0.5488371849060059
|
||||
},
|
||||
{
|
||||
"epoch": 19,
|
||||
"train_loss": 1.1521984338760376,
|
||||
"train_acc": 0.6983173076923077,
|
||||
"eval_loss": 0.7875214219093323,
|
||||
"eval_acc": 0.5860465168952942
|
||||
},
|
||||
{
|
||||
"epoch": 20,
|
||||
"train_loss": 1.158121026479281,
|
||||
"train_acc": 0.6802884615384616,
|
||||
"eval_loss": 0.785778820514679,
|
||||
"eval_acc": 0.5860465168952942
|
||||
},
|
||||
{
|
||||
"epoch": 21,
|
||||
"train_loss": 1.1232389486753023,
|
||||
"train_acc": 0.7319711538461539,
|
||||
"eval_loss": 0.7949181795120239,
|
||||
"eval_acc": 0.5767441987991333
|
||||
},
|
||||
{
|
||||
"epoch": 22,
|
||||
"train_loss": 1.1163162634922907,
|
||||
"train_acc": 0.7391826923076923,
|
||||
"eval_loss": 0.867073118686676,
|
||||
"eval_acc": 0.539534866809845
|
||||
},
|
||||
{
|
||||
"epoch": 23,
|
||||
"train_loss": 1.1119057948772724,
|
||||
"train_acc": 0.7211538461538461,
|
||||
"eval_loss": 0.8135209679603577,
|
||||
"eval_acc": 0.5953488349914551
|
||||
},
|
||||
{
|
||||
"epoch": 24,
|
||||
"train_loss": 1.107274578167842,
|
||||
"train_acc": 0.7271634615384616,
|
||||
"eval_loss": 0.8401668071746826,
|
||||
"eval_acc": 0.5534883737564087
|
||||
},
|
||||
{
|
||||
"epoch": 25,
|
||||
"train_loss": 1.0781027399576628,
|
||||
"train_acc": 0.7451923076923077,
|
||||
"eval_loss": 0.8606341481208801,
|
||||
"eval_acc": 0.5441860556602478
|
||||
},
|
||||
{
|
||||
"epoch": 26,
|
||||
"train_loss": 1.041811819259937,
|
||||
"train_acc": 0.7584134615384616,
|
||||
"eval_loss": 0.8801625967025757,
|
||||
"eval_acc": 0.5767441987991333
|
||||
},
|
||||
{
|
||||
"epoch": 27,
|
||||
"train_loss": 1.0369769976689265,
|
||||
"train_acc": 0.7764423076923077,
|
||||
"eval_loss": 0.8642652034759521,
|
||||
"eval_acc": 0.5860465168952942
|
||||
},
|
||||
{
|
||||
"epoch": 28,
|
||||
"train_loss": 1.0502384350850031,
|
||||
"train_acc": 0.7524038461538461,
|
||||
"eval_loss": 0.8719286322593689,
|
||||
"eval_acc": 0.5720930099487305
|
||||
}
|
||||
]
|
||||
}
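The `epoch_losses` above are enough to check the early-stopping behaviour by hand. The sketch below replays the patience rule from the training script (strict improvement resets the counter, otherwise it increments) over the logged `eval_acc` values, rounded to four decimals. The function name `replay_early_stopping` is illustrative, and starting the best accuracy at negative infinity is an assumption, since the v0.0.2 initialisation is not shown in this diff.

```python
# eval_acc per epoch, copied from the results file above and rounded to 4 decimals.
EVAL_ACC = [
    0.5395, 0.4605, 0.5395, 0.5395, 0.4605, 0.5395, 0.6047, 0.5395, 0.6233,
    0.6047, 0.4605, 0.6140, 0.5256, 0.5860, 0.5721, 0.6186, 0.6140, 0.5767,
    0.5488, 0.5860, 0.5860, 0.5767, 0.5395, 0.5953, 0.5535, 0.5442, 0.5767,
    0.5860, 0.5721,
]


def replay_early_stopping(eval_accs, patience):
    """Return (best_epoch, stop_epoch); stop_epoch is None if never triggered."""
    best_acc = float("-inf")  # assumption: initial value not shown in the diff
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, acc in enumerate(eval_accs):
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return best_epoch, epoch
    return best_epoch, None


best_epoch, stop_epoch = replay_early_stopping(EVAL_ACC, patience=20)
epochs_trained = stop_epoch + 1
```

The replay puts the best checkpoint at epoch 8 and the stop at epoch 28, which is consistent with `"epochs_trained": 29` and with `best_eval_acc` equalling `final_eval_acc` after the best state is restored.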
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
0.9261822700500488
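This sidecar holds the single scalar T that divides the confidence logit before the sigmoid. For readers without PyTorch at hand, the following is a dependency-free sketch of the same idea: choose T to minimise binary cross-entropy between `sigmoid(logit / T)` and the correctness indicator. It deliberately swaps the script's LBFGS step for a golden-section search over 1/T (the loss is convex in 1/T). The function names, the search bounds, and the toy data are assumptions for illustration, not the repository's code.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def fit_temperature(conf_logits, correct, inv_lo=0.05, inv_hi=20.0, iters=60):
    """Fit T by golden-section search over inv_t = 1 / T."""
    def loss(inv_t):
        total = sum(bce_with_logits(z * inv_t, y) for z, y in zip(conf_logits, correct))
        return total / len(conf_logits)

    inv_phi = (math.sqrt(5.0) - 1.0) / 2.0
    a, b = inv_lo, inv_hi
    c = b - inv_phi * (b - a)
    d = a + inv_phi * (b - a)
    fc, fd = loss(c), loss(d)
    for _ in range(iters):
        if fc < fd:          # minimum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - inv_phi * (b - a)
            fc = loss(c)
        else:                # minimum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + inv_phi * (b - a)
            fd = loss(d)
    return 1.0 / ((a + b) / 2.0)


def calibrated_confidence(conf_logit, temperature):
    """Apply the fitted temperature at inference time."""
    return 1.0 / (1.0 + math.exp(-conf_logit / temperature))


# Toy data: logits of magnitude 2 that are right only 75% of the time,
# i.e. an over-confident head, so the fitted T should exceed 1.
toy_logits = [2.0, 2.0, 2.0, 2.0, -2.0, -2.0, -2.0, -2.0]
toy_correct = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
toy_T = fit_temperature(toy_logits, toy_correct)

# Reading the sidecar is a plain float parse; the trailing newline is harmless.
shipped_T = float("0.9261822700500488\n")
```

On the toy data the optimum satisfies `sigmoid(2 / T) = 0.75`, so T comes out near `2 / ln 3`. A shipped value below 1, as here, sharpens the confidence slightly instead of flattening it.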
|
||||
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"arch": "arm",
|
||||
"binary_bytes": 3807456,
|
||||
"binary_sha256": "15c2fbac19741298ad1cbaf119c633a42db0a273099561fd57d8afce27728ea5",
|
||||
"binary_signature": "gyV2CDhJo5nqBnREA08KnztGsS7AFOuXCse+2/+wul8DAzerHs9p4L6eUgl8QeiDS9rdQZs33XRxH5WTbkT0Ag==",
|
||||
"binary_url": "https://storage.googleapis.com/cognitum-apps/cogs/arm/cog-person-count-arm",
|
||||
"build_metadata": {
|
||||
"candle": "0.9 cpu",
|
||||
"cog_person_count_version": "0.3.0",
|
||||
"rust": "1.95.0",
|
||||
"training_caveat": "random 80/20 split + label smoothing + early stopping + balanced sampler + temperature calibration. K-fold reference: class-1 mean 57.1% across 5 folds.",
|
||||
"training_class1_accuracy": 0.343,
|
||||
"training_eval_accuracy": 0.623,
|
||||
"training_eval_mae": 0.377,
|
||||
"training_temperature_scale": 0.9262
|
||||
},
|
||||
"id": "person-count",
|
||||
"installed_at": 0,
|
||||
"sig_algo": "Ed25519",
|
||||
"signed_by": "COGNITUM_OWNER_SIGNING_KEY",
|
||||
"status": "installed",
|
||||
"target_triple": "aarch64-unknown-linux-gnu",
|
||||
"version": "0.0.2",
|
||||
"weights_bytes": 392088,
|
||||
"weights_sha256": "32996433516891a37c63c600db8b95e42192a53bd538c088c82cd6a85e55513c",
|
||||
"weights_url": "https://storage.googleapis.com/cognitum-apps/cogs/arm/cog-person-count-count_v1.safetensors"
|
||||
}
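The accuracy figures in `build_metadata` can be cross-checked against the per-class supports reported in the v0.0.2 results file (116 class-0 and 99 class-1 eval samples). The sketch below is only arithmetic on those published numbers. It assumes that classes 0 and 1 are the only ones present in the eval set, as the per-class breakdown suggests, in which case every wrong prediction is off by exactly one and the MAE equals one minus the accuracy.

```python
# Per-class eval results copied from the v0.0.2 results JSON.
per_class = {
    0: {"support": 116, "accuracy": 0.8620689655172413},
    1: {"support": 99, "accuracy": 0.3434343434343434},
}

total = sum(entry["support"] for entry in per_class.values())
n_correct = sum(entry["support"] * entry["accuracy"] for entry in per_class.values())

# Support-weighted mean of the per-class accuracies.
overall_accuracy = n_correct / total

# With only two classes, each error has |pred - truth| == 1.
expected_mae = 1.0 - overall_accuracy
```

This gives 134 correct out of 215, an overall accuracy of about 0.623 and an MAE of about 0.377, in line with `final_eval_acc` and `final_eval_mae` in the results file. It also explains why `final_eval_within_pm1` is exactly 1.0 there.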
|
||||
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"arch": "x86_64",
|
||||
"binary_bytes": 4502960,
|
||||
"binary_sha256": "051614ce6ba63df704fae848a67ad095df4bb88862fdff05ef3c0419cc8388b3",
|
||||
"binary_signature": "P9txCcsqCoFN6LyZS+Hl33pYZxiP/nXJMTI6s4bt26cc+Cteidz7ymajCQIfuq0mx0cnWaQ6eKZUjzq5AIgoBw==",
|
||||
"binary_url": "https://storage.googleapis.com/cognitum-apps/cogs/x86_64/cog-person-count-x86_64",
|
||||
"build_metadata": {
|
||||
"candle": "0.9 cpu",
|
||||
"cog_person_count_version": "0.3.0",
|
||||
"rust": "1.95.0",
|
||||
"training_caveat": "random 80/20 split + label smoothing + early stopping + balanced sampler + temperature calibration. K-fold reference: class-1 mean 57.1% across 5 folds.",
|
||||
"training_class1_accuracy": 0.343,
|
||||
"training_eval_accuracy": 0.623,
|
||||
"training_eval_mae": 0.377,
|
||||
"training_temperature_scale": 0.9262
|
||||
},
|
||||
"id": "person-count",
|
||||
"installed_at": 0,
|
||||
"sig_algo": "Ed25519",
|
||||
"signed_by": "COGNITUM_OWNER_SIGNING_KEY",
|
||||
"status": "installed",
|
||||
"target_triple": "x86_64-unknown-linux-gnu",
|
||||
"version": "0.0.2",
|
||||
"weights_bytes": 392088,
|
||||
"weights_sha256": "32996433516891a37c63c600db8b95e42192a53bd538c088c82cd6a85e55513c",
|
||||
"weights_url": "https://storage.googleapis.com/cognitum-apps/cogs/arm/cog-person-count-count_v1.safetensors"
|
||||
}
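Both manifests carry a byte count and a SHA-256 digest for the binary and the weights, so a downloaded artifact can be checked before use. The following is a minimal sketch of such a check using only the Python standard library. The function name `verify_artifact` is hypothetical, and the sketch does not verify the Ed25519 `binary_signature`, which would need the owner's public key and a signature library.

```python
import hashlib


def verify_artifact(path, expected_sha256, expected_bytes):
    """Return True if the file at `path` matches the manifest size and digest."""
    digest = hashlib.sha256()
    size = 0
    with open(path, "rb") as fh:
        # Stream in 1 MiB chunks so large binaries are not read into memory at once.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
            size += len(chunk)
    return size == expected_bytes and digest.hexdigest() == expected_sha256.lower()
```

For the weights, the call would be along the lines of `verify_artifact("count_v1.safetensors", manifest["weights_sha256"], manifest["weights_bytes"])`, where the local file name is whatever the download step chose.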
|
||||
@@ -10,6 +10,7 @@
|
||||
pub mod fusion;
|
||||
pub mod inference;
|
||||
pub mod publisher;
|
||||
pub mod runtime;
|
||||
|
||||
pub const COG_ID: &str = "person-count";
|
||||
pub const COG_VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
@@ -103,10 +103,31 @@ fn cmd_health() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_run(_config_path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Long-running mode is wired in the v0.0.1 release follow-up — same
|
||||
// approach as cog-pose-estimation's runtime.rs. For now, the cog
|
||||
// satisfies the four-verb contract; downstream consumers integrate
|
||||
// via the in-process `InferenceEngine` API.
|
||||
Err("`run` subcommand wiring is pending v0.0.1 — for now consume via the InferenceEngine library API".into())
|
||||
fn cmd_run(config_path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let raw = std::fs::read_to_string(&config_path)
|
||||
.map_err(|e| format!("failed to read config at {}: {}", config_path.display(), e))?;
|
||||
let cfg: RunConfig = serde_json::from_str(&raw)
|
||||
.map_err(|e| format!("failed to parse config at {}: {}", config_path.display(), e))?;
|
||||
|
||||
let engine = InferenceEngine::with_weights(cfg.model_path.as_deref())?;
|
||||
publisher::run_started(
|
||||
COG_ID,
|
||||
&cfg.sensing_url,
|
||||
cfg.poll_ms,
|
||||
&cfg.model_path
|
||||
.as_ref()
|
||||
.map(|p| p.display().to_string())
|
||||
.unwrap_or_else(|| "(auto-discover)".to_string()),
|
||||
);
|
||||
|
||||
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
rt.block_on(cog_person_count::runtime::run_loop(
|
||||
cog_person_count::runtime::RunConfig {
|
||||
sensing_url: cfg.sensing_url,
|
||||
poll_ms: cfg.poll_ms,
|
||||
},
|
||||
engine,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
//! Long-running inference loop. Polls the appliance's sensing-server,
|
||||
//! slides a CSI window, runs the count head, and emits `person.count`
|
||||
//! events. Same shape as `cog-pose-estimation::runtime`.
|
||||
//!
|
||||
//! v0.0.1 runs single-node only. The appliance's
//! `/api/v1/sensing/latest` endpoint already aggregates across nodes
//! before serving, so per-cog multi-node fusion is deferred until each
//! node ships raw frames separately (ADR-103 §"Multi-node fusion" v0.2.0).
|
||||
|
||||
use crate::inference::{CsiWindow, InferenceEngine, INPUT_SUBCARRIERS, INPUT_TIMESTEPS};
|
||||
use crate::publisher;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
pub struct RunConfig {
|
||||
pub sensing_url: String,
|
||||
pub poll_ms: u64,
|
||||
}
|
||||
|
||||
pub async fn run_loop(
|
||||
cfg: RunConfig,
|
||||
engine: InferenceEngine,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Window size in samples; also used as the initial buffer capacity.
let cap = INPUT_SUBCARRIERS * INPUT_TIMESTEPS;
let mut buffer: Vec<f32> = Vec::with_capacity(cap);
|
||||
let mut tick: u64 = 0;
|
||||
|
||||
loop {
|
||||
match fetch_frame(&cfg.sensing_url).await {
|
||||
Ok(amplitudes) => {
|
||||
tick += 1;
|
||||
buffer.extend(amplitudes);
|
||||
while buffer.len() > 2 * cap {
|
||||
let extra = buffer.len() - cap;
|
||||
buffer.drain(0..extra);
|
||||
}
|
||||
if buffer.len() >= cap {
|
||||
let window = CsiWindow { data: buffer[buffer.len() - cap..].to_vec() };
|
||||
if let Ok(pred) = engine.infer(&window) {
|
||||
// v0.0.1 ships single-node — fusion is a no-op for
|
||||
// N=1. v0.2.0 will append additional per-node
|
||||
// predictions to a vec and call
|
||||
// `fusion::fuse_confidence_weighted` before emit.
|
||||
publisher::person_count(tick, &pred, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "sensing-server fetch failed");
|
||||
}
|
||||
}
|
||||
sleep(Duration::from_millis(cfg.poll_ms)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_frame(url: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
let url = url.to_string();
|
||||
let body = tokio::task::spawn_blocking(move || -> Result<String, ureq::Error> {
|
||||
Ok(ureq::get(&url).call()?.into_string()?)
|
||||
})
|
||||
.await??;
|
||||
let json: serde_json::Value = serde_json::from_str(&body)?;
|
||||
let snapshot = json.get("snapshot").unwrap_or(&json);
|
||||
let nodes = snapshot
|
||||
.get("nodes")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or("missing nodes[]")?;
|
||||
let amplitude = nodes
|
||||
.first()
|
||||
.and_then(|n| n.get("amplitude"))
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or("missing nodes[0].amplitude[]")?;
|
||||
Ok(amplitude
|
||||
.iter()
|
||||
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
||||
.collect())
|
||||
}
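For readers who want to prototype against the same endpoint without building the crate, the two pieces of logic in this file (frame parsing and the sliding window) can be sketched in Python as follows. This is a loose translation under assumptions: the names `parse_amplitudes` and `SlidingWindow` are made up for the example, the window size is passed in because `INPUT_SUBCARRIERS` and `INPUT_TIMESTEPS` are not shown in this diff, and the JSON shape is inferred only from the field lookups in `fetch_frame`.

```python
import json


def parse_amplitudes(body):
    """Extract nodes[0].amplitude[] from a sensing-server response body.

    Accepts either {"snapshot": {...}} or the snapshot object itself,
    and skips non-numeric entries, mirroring the Rust filter_map.
    """
    doc = json.loads(body)
    snapshot = doc.get("snapshot", doc)
    nodes = snapshot.get("nodes")
    if not isinstance(nodes, list):
        raise ValueError("missing nodes[]")
    amplitude = nodes[0].get("amplitude") if nodes and isinstance(nodes[0], dict) else None
    if not isinstance(amplitude, list):
        raise ValueError("missing nodes[0].amplitude[]")
    return [
        float(v) for v in amplitude
        if isinstance(v, (int, float)) and not isinstance(v, bool)
    ]


class SlidingWindow:
    """Keeps recent samples and yields the newest `cap` of them once available."""

    def __init__(self, cap):
        self.cap = cap
        self.buffer = []

    def push(self, amplitudes):
        self.buffer.extend(amplitudes)
        # Trim only when the buffer has grown past twice the window size.
        if len(self.buffer) > 2 * self.cap:
            del self.buffer[: len(self.buffer) - self.cap]
        if len(self.buffer) >= self.cap:
            return list(self.buffer[-self.cap:])
        return None
```

Trimming at twice the window size, not on every frame, keeps the number of front-of-buffer removals low while still bounding memory.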
|
||||