Files
rUv b3a5012dbd feat(cog-person-count): v0.0.2 — K-fold + label-smoothing + temperature-calibrated (#699)
* chore: stage v0.0.2 artifacts + temperature scalar for build pipeline

Stages count_v1.{safetensors,onnx,temperature,train_results.json}
ahead of the build/sign/upload step. This commit is a momentary
side-effect — the next commit will refresh the per-arch manifests
with the new binary SHAs once ruvultra finishes the cross-build.

The .temperature file holds the calibration scalar from LBFGS over the
held-out conf logits. The Rust cog will read it post-load and divide
conf_logits by it before sigmoid, exactly matching the Python eval.

* feat(cog-person-count): v0.0.2 — K-fold validated, label smoothing + early stop + temp scale

The v0.0.1 "65.1% but class-1=0%" result was an unlucky temporal split
that let a degenerate "always predict 0" classifier hit eval acc =
class-0 fraction. 5-fold stratified random CV proved the architecture
actually learns ~57.1% class-1 accuracy under fair splits — a real,
modestly useful signal.

v0.0.2 ships a retrained model that:

* **Splits randomly (seed=42) 80/20** instead of temporally — eliminates
  the trailing-window-class-imbalance cheat.
* **Class-balanced sampler** (multinomial with replacement, weighted by
  inverse class frequency) — per-batch expected counts are equal
  regardless of dataset distribution.
* **Label smoothing 0.1** on the cross-entropy — reduces confidence
  saturation that drove v0.0.1's all-or-nothing predictions.
* **Early stopping** with patience=20 — stops at epoch 29 instead of
  overfitting through 400.
* **Temperature scaling** of the conf head — LBFGS fits a scalar T on
  held-out conf logits; ships as a count_v1.temperature sidecar so the
  Rust cog can divide conf_logits by T before sigmoid.

Numbers on the same data:

  | Metric           | v0.0.1 | v0.0.2 | K-fold (5x100) |
  |------------------|--------|--------|----------------|
  | Overall acc      | 65.1%  | 62.3%  | 62.2% ± 1.9%   |
  | Class 0 acc      | 100%   | 86.2%  | 67.4%          |
  | Class 1 acc      |  0%    | 34.3%  | 57.1% ✓        |
  | MAE              | 0.349  | 0.377  | 0.378          |
  | Spearman         | 0.023  | 0.013  | 0.160          |

Class-1 accuracy 0 → 34.3% is the headline win. Net acc moves slightly
because we stopped cheating on class 0. K-fold's 57% says there's
headroom remaining; reaching it needs more independent splits (== more
data), not more training tricks.

Confidence calibration didn't move. Temperature scaling alone can't fix
a confidence head trained against a noisy argmax==truth indicator over
a 62%-accurate classifier — the head's training signal is the issue,
not its post-hoc transform. The honest fix is multi-room data (#645),
not another calibration knob.

Live on cognitum-v0 at /var/lib/cognitum/apps/person-count/ — health
reports candle-cpu backend, count = 1 (was 0 in v0.0.1) on synthetic
zero input.

Files changed:
* scripts/train-count.py — adds --k-fold (no sklearn dep, hand-rolled
  stratified splits with deterministic shuffle) and --v2 paths.
* v2/.../cog/artifacts/count_v1.safetensors (392 KB, new sha
  32996433…) + count_v1.onnx (16 KB) + count_v1.temperature (0.9262
  scalar) + count_train_results.json (full epoch trace).
* v2/.../cog/artifacts/manifests/{arm,x86_64}/manifest.json bumped to
  version 0.0.2 with the new weights_sha256 + caveats.
* docs/benchmarks/person-count-cog.md — appends a v0.0.2 section
  with the K-fold diagnostic table and honest-read paragraph.

GCS:
  gs://cognitum-apps/cogs/arm/cog-person-count-count_v1.safetensors
    refreshed (binaries unchanged — load weights via mmap at runtime).
2026-05-21 19:47:04 -04:00

762 lines
32 KiB
Python

#!/usr/bin/env python3
"""Train the person-count head — ADR-103 v0.0.1.
Mirrors the Conv1d encoder architecture from cog-person-count's
`src/inference.rs::CountNet` exactly, so the learned weights load
into the Rust cog without translation. Trains on
data/paired/wiflow-p7-1779210883.paired.jsonl (1,077 samples with
n_persons_mode labels in {0, 1}).
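Output (default names): count_v1.safetensors + count_v1.onnx + count_train_results.json; --v2 also writes a `.temperature` sidecar next to the weights, and --k-fold writes only the results JSON.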
"""
from __future__ import annotations
import argparse
import json
import struct
import time
from collections import Counter
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Architecture constants — MUST match cog-person-count's src/inference.rs.
# Rows per CSI window; also the input-channel count of the first encoder conv.
N_SUB = 56
# Columns per CSI window, i.e. the length of the axis the Conv1d layers slide over.
N_FRAMES = 20
# Width of the count head's output (number of count classes it can predict).
COUNT_CLASSES = 8
class CountNet(nn.Module):
    """PyTorch twin of cog_person_count::inference::CountNet.

    A three-layer dilated Conv1d encoder is mean-pooled over the frame axis
    and feeds two small MLP heads: one producing ``COUNT_CLASSES`` count
    logits and one producing a single confidence logit. Attribute names are
    load-bearing: ``write_safetensors`` maps them onto the Rust cog's
    tensor paths, so do not rename or reorder the layers.
    """

    def __init__(self) -> None:
        super().__init__()
        # Encoder — kept identical to the pose cog's encoder so that future
        # joint training can share weights. padding == dilation with a
        # kernel of 3 keeps the frame axis length unchanged.
        self.enc_c1 = nn.Conv1d(N_SUB, 64, kernel_size=3, padding=1, dilation=1)
        self.enc_c2 = nn.Conv1d(64, 128, kernel_size=3, padding=2, dilation=2)
        self.enc_c3 = nn.Conv1d(128, 128, kernel_size=3, padding=4, dilation=4)
        # Count head: 128 -> 64 -> COUNT_CLASSES.
        self.count_head_fc1 = nn.Linear(128, 64)
        self.count_head_fc2 = nn.Linear(64, COUNT_CLASSES)
        # Confidence head: 128 -> 32 -> 1.
        self.conf_head_fc1 = nn.Linear(128, 32)
        self.conf_head_fc2 = nn.Linear(32, 1)

    def forward(self, x: torch.Tensor):
        """Map a CSI batch ``[B, 56, 20]`` to ``(count_logits, conf_logits)``.

        Both outputs are raw logits: the count logits go through softmax at
        inference (cross-entropy in training) and the confidence logit goes
        through sigmoid at inference (BCE-with-logits in training).
        """
        feats = x
        for conv in (self.enc_c1, self.enc_c2, self.enc_c3):
            feats = F.relu(conv(feats))
        # Global average pool over the frame axis -> [B, 128].
        pooled = feats.mean(dim=2)

        count_hidden = F.relu(self.count_head_fc1(pooled))
        count_logits = self.count_head_fc2(count_hidden)

        conf_hidden = F.relu(self.conf_head_fc1(pooled))
        conf_logits = self.conf_head_fc2(conf_hidden)
        return count_logits, conf_logits
def load_paired(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load CSI windows and person-count labels from a paired JSONL file.

    Every non-blank line is parsed as one JSON object. A row is kept only
    when its ``csi_shape`` (assumed to be ``[N_SUB, N_FRAMES]`` when the key
    is absent) equals ``[N_SUB, N_FRAMES]``; other rows are skipped. The
    label is read from ``n_persons_mode`` and falls back to 0 when the key
    is missing.

    Args:
        path: Location of the ``.jsonl`` file (read as UTF-8).

    Returns:
        ``(X, y)`` where ``X`` is a float32 array of shape
        ``[N, N_SUB, N_FRAMES]`` and ``y`` is an int64 array of shape ``[N]``.

    Raises:
        ValueError: If the file contains no usable rows (previously this
            surfaced as an opaque ``np.stack`` error), if a line is not
            valid JSON, or if a row's ``csi`` cannot be reshaped.
        KeyError: If a kept row has no ``csi`` field.
    """
    windows: list[np.ndarray] = []
    labels: list[int] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            if record.get("csi_shape", [N_SUB, N_FRAMES]) != [N_SUB, N_FRAMES]:
                continue
            window = np.asarray(record["csi"], dtype=np.float32)
            windows.append(window.reshape(N_SUB, N_FRAMES))
            # NOTE(review): a missing label silently becomes class 0 — kept
            # for compatibility, but confirm this is intended for new data.
            labels.append(int(record.get("n_persons_mode", 0)))

    if not windows:
        raise ValueError(
            f"no usable samples found in {path} "
            "(every row was blank or had an unexpected csi_shape)"
        )

    return np.stack(windows, axis=0), np.asarray(labels, dtype=np.int64)
def temporal_split(X: np.ndarray, y: np.ndarray, eval_frac: float = 0.2):
    """Split by position: the trailing ``eval_frac`` of rows becomes the eval set.

    The eval size is ``int(round(n * eval_frac))`` (Python's round-half-even),
    and nothing is shuffled, so if the rows are in time order the eval set is
    the most recent window.

    Returns:
        ``(X_train, y_train, X_eval, y_eval)``.
    """
    total = X.shape[0]
    cut = total - int(round(total * eval_frac))
    X_head, X_tail = X[:cut], X[cut:]
    y_head, y_tail = y[:cut], y[cut:]
    return X_head, y_head, X_tail, y_tail
def stratified_k_fold(X: np.ndarray, y: np.ndarray, k: int = 5, seed: int = 42):
    """Yield stratified k-fold cross-validation splits — hand-rolled, no sklearn.

    Per class (in ``np.unique`` order): shuffle that class's indices with a
    generator seeded by ``seed``, cut them into ``k`` near-equal chunks, then
    build fold ``i``'s validation set from chunk ``i`` of every class. The
    class distribution of each fold therefore matches the dataset's to
    within one sample per class.

    Args:
        X: Samples, indexed along axis 0.
        y: Integer class labels, one per sample.
        k: Number of folds; must be at least 2.
        seed: Seed for the shuffle. The default (42) reproduces the splits
            produced before this parameter existed.

    Yields:
        ``(X_train, y_train, X_val, y_val)`` for each of the ``k`` folds.

    Raises:
        ValueError: If ``k < 2``, if ``y`` is empty, or if ``X`` and ``y``
            disagree on the number of samples. Because this is a generator,
            the error is raised when iteration starts.
    """
    if k < 2:
        raise ValueError(f"k must be >= 2 for k-fold cross-validation, got {k}")
    if y.shape[0] == 0:
        raise ValueError("cannot build folds from an empty label array")
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"X and y disagree on sample count: {X.shape[0]} vs {y.shape[0]}"
        )

    rng = np.random.default_rng(seed=seed)
    classes = np.unique(y)

    # One list of k index chunks per class, in the same order as `classes`.
    chunks_per_class = []
    for c in classes:
        idx = np.where(y == c)[0]
        rng.shuffle(idx)
        chunks_per_class.append(np.array_split(idx, k))

    for fold in range(k):
        val_idx = np.concatenate([chunks[fold] for chunks in chunks_per_class])
        train_idx = np.concatenate(
            [chunks[f] for chunks in chunks_per_class for f in range(k) if f != fold]
        )
        yield X[train_idx], y[train_idx], X[val_idx], y[val_idx]
def standardise(X_train: np.ndarray, X_eval: np.ndarray):
    """Z-score both splits per channel using statistics from the train split.

    Mean and standard deviation are taken over axes 0 and 2 of ``X_train``
    (every sample and every frame), giving one value per axis-1 channel.
    A 1e-6 offset on the deviation guards against division by zero. The eval
    split is scaled with the train statistics so no eval information leaks
    into the normalisation.

    Returns:
        ``(X_train_z, X_eval_z)``.
    """
    reduce_axes = (0, 2)
    centre = X_train.mean(axis=reduce_axes, keepdims=True)
    scale = X_train.std(axis=reduce_axes, keepdims=True) + 1e-6
    train_z = (X_train - centre) / scale
    eval_z = (X_eval - centre) / scale
    return train_z, eval_z
def write_safetensors(model: CountNet, path: Path):
    """Serialise the model's weights to a safetensors file for the Rust cog.

    File layout: an 8-byte little-endian unsigned header length, a compact
    JSON header mapping each tensor name to its dtype/shape/byte range, then
    the concatenated raw tensor bytes. Tensors are written in the order of
    the rename table below, under the cog's ``VarBuilder`` paths rather than
    the PyTorch attribute names.

    Tensor data is forced to little-endian float32 (``<f4``), which the
    safetensors format requires; on little-endian hosts this is
    byte-identical to a plain float32 dump.

    Args:
        model: Object whose ``state_dict()`` holds the 14 CountNet tensors.
        path: Destination file; overwritten if it exists.

    Raises:
        KeyError: If the state dict lacks one of the expected tensors.
    """
    state = model.state_dict()
    # Map PyTorch param names → cog-person-count's VarBuilder paths.
    rename = {
        "enc_c1.weight": "enc.c1.weight",
        "enc_c1.bias": "enc.c1.bias",
        "enc_c2.weight": "enc.c2.weight",
        "enc_c2.bias": "enc.c2.bias",
        "enc_c3.weight": "enc.c3.weight",
        "enc_c3.bias": "enc.c3.bias",
        "count_head_fc1.weight": "count_head.fc1.weight",
        "count_head_fc1.bias": "count_head.fc1.bias",
        "count_head_fc2.weight": "count_head.fc2.weight",
        "count_head_fc2.bias": "count_head.fc2.bias",
        "conf_head_fc1.weight": "conf_head.fc1.weight",
        "conf_head_fc1.bias": "conf_head.fc1.bias",
        "conf_head_fc2.weight": "conf_head.fc2.weight",
        "conf_head_fc2.bias": "conf_head.fc2.bias",
    }

    header = {}
    chunks = []
    offset = 0
    for torch_name, cog_name in rename.items():
        arr = state[torch_name].detach().cpu().numpy().astype("<f4")
        raw = arr.tobytes()  # C-order bytes regardless of memory layout
        header[cog_name] = {
            "dtype": "F32",
            "shape": list(arr.shape),
            # Offsets are relative to the start of the data section.
            "data_offsets": [offset, offset + len(raw)],
        }
        chunks.append(raw)
        offset += len(raw)

    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with path.open("wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(b"".join(chunks))
def _class_weights(y_train: np.ndarray) -> np.ndarray:
    """Return inverse-frequency class weights scaled to sum to COUNT_CLASSES.

    Classes absent from ``y_train`` are counted as one sample so their
    reciprocal stays finite.
    """
    counts = np.bincount(y_train, minlength=COUNT_CLASSES).astype(np.float32)
    counts = np.where(counts > 0, counts, 1.0)
    return (1.0 / counts) / (1.0 / counts).sum() * COUNT_CLASSES


def _to_device(device, *arrays):
    """Convert numpy arrays to tensors on ``device``, preserving order."""
    return tuple(torch.from_numpy(a).to(device) for a in arrays)


def _new_training_state(args, device):
    """Build a fresh (model, AdamW optimiser, cosine-warm-restart scheduler)."""
    model = CountNet().to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=50, T_mult=1)
    return model, opt, sched


def _snapshot_state(model):
    """Return a detached CPU copy of the model's state dict (a checkpoint)."""
    return {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}


def _train_batch(model, opt, xb, yb, *, class_weight=None, label_smoothing=0.0,
                 brier_grad=False):
    """Run one optimiser step and return ``(loss_value, n_correct)``.

    Loss = cross_entropy(count) + 0.3 * bce(conf) + 0.1 * brier(conf), where
    the confidence target is the detached ``argmax == truth`` indicator.

    NOTE: historically the Brier term was evaluated under ``no_grad``, so it
    only changed the logged loss value and never produced a gradient. That
    remains the default so existing runs keep their training dynamics;
    ``brier_grad=True`` lets the term back-propagate as its name suggests.
    """
    opt.zero_grad()
    count_logits, conf_logits = model(xb)
    ce = F.cross_entropy(
        count_logits, yb, weight=class_weight, label_smoothing=label_smoothing
    )
    # The confidence target must not feed gradients back into the count head.
    with torch.no_grad():
        pred = count_logits.argmax(dim=1)
        correct_indicator = (pred == yb).float().unsqueeze(1)
    bce = F.binary_cross_entropy_with_logits(conf_logits, correct_indicator)
    with torch.set_grad_enabled(brier_grad):
        brier = ((torch.sigmoid(conf_logits) - correct_indicator) ** 2).mean()
    loss = ce + 0.3 * bce + 0.1 * brier
    loss.backward()
    opt.step()
    return loss.item(), int((pred == yb).sum().item())


def _train_epoch_shuffled(model, opt, Xt, yt, batch_size, **step_kwargs):
    """Train one pass over a random permutation; return (loss_sum, n_correct, n_batches)."""
    model.train()
    n_train = Xt.shape[0]
    perm = torch.randperm(n_train, device=Xt.device)
    loss_sum, n_correct, n_batches = 0.0, 0, 0
    for start in range(0, n_train, batch_size):
        idx = perm[start:start + batch_size]
        loss, correct = _train_batch(model, opt, Xt[idx], yt[idx], **step_kwargs)
        loss_sum += loss
        n_correct += correct
        n_batches += 1
    return loss_sum, n_correct, n_batches


def _count_logits(model, X_t):
    """Return the count-head logits for ``X_t`` in eval mode, without gradients."""
    model.eval()
    with torch.no_grad():
        logits, _ = model(X_t)
    return logits


def _average_ranks(values: torch.Tensor) -> torch.Tensor:
    """Return 0-based ranks of a 1-D tensor where tied values share their mean rank."""
    n = values.numel()
    if n == 0:
        return values.float()
    order = values.argsort()
    _, group_of_sorted, group_sizes = torch.unique_consecutive(
        values[order], return_inverse=True, return_counts=True
    )
    group_end = group_sizes.cumsum(dim=0)
    group_start = group_end - group_sizes
    # A tie group occupying sorted positions [start, end) has mean rank (start+end-1)/2.
    group_rank = (group_start + group_end - 1).float() / 2.0
    ranks = torch.empty(n, dtype=torch.float32, device=values.device)
    ranks[order] = group_rank[group_of_sorted]
    return ranks


def _spearman(a: torch.Tensor, b: torch.Tensor) -> float:
    """Spearman correlation = Pearson correlation of tie-averaged ranks.

    Average ranks matter here: the correctness vector is binary, so an
    ``argsort().argsort()`` ranking would order its many ties arbitrarily and
    inject noise into the statistic. Returns 0.0 when either input is
    constant (or empty).
    """
    ra = _average_ranks(a)
    rb = _average_ranks(b)
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    denom = (ra.norm() * rb.norm()).item()
    return (ra * rb).sum().item() / denom if denom > 0 else 0.0


def _evaluate(model, X_t, y_t, temperature: float = 1.0) -> dict:
    """Compute final metrics on a held-out set.

    Returns a dict with keys ``accuracy``, ``within_pm1``, ``mae``,
    ``spearman`` (confidence vs. correctness, with the confidence logit
    divided by ``temperature`` before the sigmoid) and ``per_class_accuracy``
    (only classes that occur in ``y_t``).
    """
    model.eval()
    with torch.no_grad():
        count_logits, conf_logits = model(X_t)
    pred = count_logits.argmax(dim=1)
    hits = pred == y_t
    abs_err = (pred - y_t).abs()

    per_class = {}
    for cls in range(COUNT_CLASSES):
        mask = y_t == cls
        support = int(mask.sum().item())
        if support > 0:
            per_class[cls] = {
                "support": support,
                "accuracy": (hits & mask).sum().item() / support,
            }

    conf = torch.sigmoid(conf_logits.squeeze(-1) / temperature)
    return {
        "accuracy": hits.float().mean().item(),
        "within_pm1": (abs_err <= 1).float().mean().item(),
        "mae": abs_err.float().mean().item(),
        "spearman": _spearman(conf, hits.float()),
        "per_class_accuracy": per_class,
    }


def _print_final_eval(title: str, spearman_label: str, metrics: dict) -> None:
    """Print the final-eval summary block shared by the single-split modes."""
    print(f"\n=== {title} ===")
    print(f" accuracy: {metrics['accuracy']:.3f}")
    print(f" within ±1: {metrics['within_pm1']:.3f}")
    print(f" MAE: {metrics['mae']:.3f}")
    print(f" {spearman_label}: {metrics['spearman']:.3f}")
    for k, v in metrics["per_class_accuracy"].items():
        print(f" class {k}: {v['accuracy']:.3f} accuracy on {v['support']} samples")


def _fit_temperature(conf_logits: torch.Tensor, correct_indicator: torch.Tensor) -> float:
    """Fit a scalar T minimising BCE(sigmoid(conf_logits / T), correct) via LBFGS.

    Falls back to 1.0 (no scaling) if the optimiser lands on a non-finite or
    non-positive value, since dividing logits by such a T would be
    meaningless downstream.
    """
    T = torch.nn.Parameter(torch.ones(1, device=conf_logits.device))
    opt_t = torch.optim.LBFGS([T], lr=0.1, max_iter=50)

    def closure():
        opt_t.zero_grad()
        scaled = conf_logits.squeeze(-1) / T
        loss_t = F.binary_cross_entropy_with_logits(scaled, correct_indicator)
        loss_t.backward()
        return loss_t

    opt_t.step(closure)
    T_val = float(T.detach().cpu().item())
    if not np.isfinite(T_val) or T_val <= 0.0:
        print(f"WARN: temperature fit produced T = {T_val}; falling back to 1.0")
        T_val = 1.0
    return T_val


def _export_onnx(model, out_path: str, device) -> None:
    """Export the model to ONNX with a dynamic batch axis (best-effort)."""
    dummy = torch.zeros(1, N_SUB, N_FRAMES, device=device)
    try:
        torch.onnx.export(
            model, dummy, out_path,
            opset_version=18,
            input_names=["csi_window"],
            output_names=["count_logits", "conf_logits"],
            dynamic_axes={
                "csi_window": {0: "batch"},
                "count_logits": {0: "batch"},
                "conf_logits": {0: "batch"},
            },
            export_params=True,
            do_constant_folding=True,
        )
        print(f"wrote {out_path} ({Path(out_path).stat().st_size} bytes)")
    except Exception as e:
        # Deliberately broad: the ONNX artifact is optional and the exporter
        # can fail in many ways; the safetensors file is what the cog loads.
        print(f"WARN: ONNX export failed: {e}")


def _run_k_fold(args, X, y, device) -> None:
    """Stratified k-fold CV: train one model per fold and write a summary JSON."""
    print(f"\n=== {args.k_fold}-fold cross-validation ===")
    fold_results = []
    overall_t0 = time.perf_counter()
    splits = stratified_k_fold(X, y, k=args.k_fold)
    for fold_idx, (X_train, y_train, X_val, y_val) in enumerate(splits):
        print(f"\nFold {fold_idx + 1}/{args.k_fold}")
        X_train, X_val = standardise(X_train, X_val)
        # float32 so the weight dtype matches the model's logits.
        cls_weight_t = torch.from_numpy(
            _class_weights(y_train).astype(np.float32)
        ).to(device)
        Xt, yt, Xv, yv = _to_device(device, X_train, y_train, X_val, y_val)
        model, opt, sched = _new_training_state(args, device)

        best_eval_acc = 0.0
        best_state = None
        for _ in range(args.epochs):
            _train_epoch_shuffled(
                model, opt, Xt, yt, args.batch_size,
                class_weight=cls_weight_t, brier_grad=args.brier_grad,
            )
            sched.step()
            eval_pred = _count_logits(model, Xv).argmax(dim=1)
            eval_acc = (eval_pred == yv).float().mean().item()
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_state = _snapshot_state(model)

        # Restore best checkpoint and final eval.
        # NOTE(review): the checkpoint is selected on the same fold it is
        # scored on, so these numbers are optimistic.
        if best_state is not None:
            model.load_state_dict(best_state)
        metrics = _evaluate(model, Xv, yv)
        fold_results.append({"fold": fold_idx + 1, **metrics})
        print(f" accuracy={metrics['accuracy']:.3f} within±1={metrics['within_pm1']:.3f} "
              f"mae={metrics['mae']:.3f} spearman={metrics['spearman']:.3f}")

    # K-fold summary
    total_time = time.perf_counter() - overall_t0
    accs = [r["accuracy"] for r in fold_results]
    within1s = [r["within_pm1"] for r in fold_results]
    maes = [r["mae"] for r in fold_results]
    spears = [r["spearman"] for r in fold_results]
    print(f"\n=== {args.k_fold}-fold summary ({total_time:.1f} s) ===")
    print(f" accuracy: {np.mean(accs):.3f} ± {np.std(accs):.3f}")
    print(f" within ±1: {np.mean(within1s):.3f} ± {np.std(within1s):.3f}")
    print(f" MAE: {np.mean(maes):.3f} ± {np.std(maes):.3f}")
    print(f" conf↔correct Spearman: {np.mean(spears):.3f} ± {np.std(spears):.3f}")

    # Per-class summary across folds. Folds where a class has no support are
    # left out of the mean instead of being counted as 0% accuracy.
    for k in range(COUNT_CLASSES):
        entries = [r["per_class_accuracy"].get(k) for r in fold_results]
        n_k = [e["support"] if e else 0 for e in entries]
        accs_k = [e["accuracy"] for e in entries if e]
        if accs_k:
            print(f" class {k}: {np.mean(accs_k):.3f} mean accuracy (support: {n_k})")

    results = {
        "mode": "k_fold_cv",
        "k": args.k_fold,
        "backend": "pytorch-cuda" if device.type == "cuda" else "pytorch-cpu",
        "total_time_s": total_time,
        "fold_results": fold_results,
        "summary": {
            "mean_accuracy": float(np.mean(accs)),
            "std_accuracy": float(np.std(accs)),
            "mean_within_pm1": float(np.mean(within1s)),
            "std_within_pm1": float(np.std(within1s)),
            "mean_mae": float(np.mean(maes)),
            "std_mae": float(np.std(maes)),
            "mean_spearman": float(np.mean(spears)),
            "std_spearman": float(np.std(spears)),
        },
        "hyperparameters": {
            "optimizer": "AdamW",
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "schedule": "cosine_warm_restarts",
            "epochs": args.epochs,
            "brier_grad": args.brier_grad,
        },
    }
    Path(args.out_results).write_text(json.dumps(results, indent=2))
    print(f"\nwrote {args.out_results}")


def _run_v2(args, X, y, device) -> None:
    """v0.0.2 path: random 80/20 split, label smoothing, early stopping,
    class-balanced batch sampling and temperature scaling of the conf head."""
    rng = np.random.default_rng(seed=42)
    idx = np.arange(X.shape[0])
    rng.shuffle(idx)
    n_eval = int(round(0.2 * X.shape[0]))
    eval_idx, train_idx = idx[:n_eval], idx[n_eval:]
    X_train, X_eval = X[train_idx], X[eval_idx]
    y_train, y_eval = y[train_idx], y[eval_idx]
    X_train, X_eval = standardise(X_train, X_eval)
    print(f"v0.0.2 mode — random 80/20 split: train={len(y_train)} eval={len(y_eval)}")
    print(f" train class dist: {dict(Counter(y_train.tolist()).most_common())}")
    print(f" eval class dist: {dict(Counter(y_eval.tolist()).most_common())}")
    Xt, yt, Xe, ye = _to_device(device, X_train, y_train, X_eval, y_eval)

    # Class-balanced sampler: every sample is weighted by 1 / (its class
    # count), so batches drawn with replacement have equal expected counts
    # per class regardless of the dataset distribution.
    cls_counts = np.bincount(y_train, minlength=COUNT_CLASSES).astype(np.float32)
    cls_counts = np.where(cls_counts > 0, cls_counts, 1.0)
    per_sample_weight = 1.0 / cls_counts[y_train]
    per_sample_weight_t = torch.from_numpy(per_sample_weight.astype(np.float32)).to(device)

    model, opt, sched = _new_training_state(args, device)
    n_train = X_train.shape[0]
    batches_per_epoch = max(1, n_train // args.batch_size)
    epoch_losses = []
    t0 = time.perf_counter()
    best_eval_acc = 0.0
    best_state = None
    epochs_without_improvement = 0
    epochs_trained = 0  # explicit counter so --epochs 0 does not hit an unbound `epoch`
    for epoch in range(args.epochs):
        model.train()
        train_loss, train_correct, n_batches = 0.0, 0, 0
        for _ in range(batches_per_epoch):
            # Balanced sample with replacement
            idx_t = torch.multinomial(per_sample_weight_t, args.batch_size, replacement=True)
            loss, correct = _train_batch(
                model, opt, Xt[idx_t], yt[idx_t],
                label_smoothing=args.label_smoothing, brier_grad=args.brier_grad,
            )
            train_loss += loss
            train_correct += correct
            n_batches += 1
        sched.step()
        epochs_trained = epoch + 1

        cl_e = _count_logits(model, Xe)
        eval_loss = F.cross_entropy(cl_e, ye).item()
        eval_acc = (cl_e.argmax(dim=1) == ye).float().mean().item()
        epoch_losses.append({
            "epoch": epoch,
            "train_loss": train_loss / max(1, n_batches),
            "train_acc": train_correct / max(1, n_batches * args.batch_size),
            "eval_loss": eval_loss,
            "eval_acc": eval_acc,
        })
        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            best_state = _snapshot_state(model)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epoch < 5 or epoch % 25 == 0:
            print(f"epoch {epoch:3d} train_loss={train_loss/n_batches:.4f} "
                  f"train_acc={train_correct/(n_batches*args.batch_size):.3f} "
                  f"eval_loss={eval_loss:.4f} eval_acc={eval_acc:.3f} "
                  f"epochs_no_improve={epochs_without_improvement}")
        if epochs_without_improvement >= args.patience:
            print(f"early stopping at epoch {epoch} (no improvement for {args.patience} epochs)")
            break
    train_time = time.perf_counter() - t0
    print(f"\ntrained {epochs_trained} epochs in {train_time:.1f} s (best eval_acc {best_eval_acc:.3f})")
    if best_state is not None:
        model.load_state_dict(best_state)

    # Temperature scaling on the confidence head — fit a scalar T s.t.
    # sigmoid(conf_logits / T) is best-calibrated on the eval set.
    model.eval()
    with torch.no_grad():
        cl_e, conf_e = model(Xe)
    correct_indicator = (cl_e.argmax(dim=1) == ye).float()
    T_val = _fit_temperature(conf_e, correct_indicator)
    print(f" temperature scale T = {T_val:.4f}")

    # Final eval with temperature applied.
    metrics = _evaluate(model, Xe, ye, temperature=T_val)
    _print_final_eval("v0.0.2 final eval", "conf↔correct Spearman (post-temp)", metrics)

    write_safetensors(model, Path(args.out_safetensors))
    # The temperature is not stored inside the safetensors file; it goes to
    # a plain-text sidecar at "<out-safetensors>.temperature".
    # NOTE(review): with the default --out-safetensors this is
    # "count_v1.safetensors.temperature", whereas the commit notes call the
    # sidecar "count_v1.temperature" — confirm which name the Rust cog reads.
    Path(args.out_safetensors + ".temperature").write_text(f"{T_val}\n")
    print(f"wrote {args.out_safetensors} ({Path(args.out_safetensors).stat().st_size} bytes)")
    print(f"wrote {args.out_safetensors}.temperature ({T_val})")

    _export_onnx(model, args.out_onnx, device)

    results = {
        "mode": "v0.0.2",
        "backend": "pytorch-cuda" if device.type == "cuda" else "pytorch-cpu",
        "epochs_trained": epochs_trained,
        "train_time_s": train_time,
        "best_eval_acc": best_eval_acc,
        "final_eval_acc": metrics["accuracy"],
        "final_eval_within_pm1": metrics["within_pm1"],
        "final_eval_mae": metrics["mae"],
        "temperature_scale": T_val,
        "conf_correctness_spearman_post_temp": metrics["spearman"],
        "per_class_accuracy": metrics["per_class_accuracy"],
        "hyperparameters": {
            "optimizer": "AdamW",
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "schedule": "cosine_warm_restarts",
            "epochs_max": args.epochs,
            "label_smoothing": args.label_smoothing,
            "patience": args.patience,
            "split": "random_80_20_seed_42",
            "balanced_sampler": True,
            "temperature_scaling": True,
            "brier_grad": args.brier_grad,
        },
        "epoch_losses": epoch_losses,
    }
    Path(args.out_results).write_text(json.dumps(results, indent=2))
    print(f"wrote {args.out_results}")


def _run_temporal(args, X, y, device) -> None:
    """Original temporal-split mode (kept for v0.0.1 reproducibility)."""
    X_train, y_train, X_eval, y_eval = temporal_split(X, y, eval_frac=0.2)
    X_train, X_eval = standardise(X_train, X_eval)
    # Re-balance via class weights — handles the 50/50 split fine
    # but also makes the loss correct under future imbalanced data.
    cls_weight = _class_weights(y_train)
    cls_weight_t = torch.from_numpy(cls_weight.astype(np.float32)).to(device)
    print(f"class weights: {cls_weight.tolist()}")
    Xt, yt, Xe, ye = _to_device(device, X_train, y_train, X_eval, y_eval)
    model, opt, sched = _new_training_state(args, device)
    n_train = X_train.shape[0]
    epoch_losses = []
    t0 = time.perf_counter()
    best_eval_acc = 0.0
    best_state = None
    for epoch in range(args.epochs):
        train_loss, train_correct, n_batches = _train_epoch_shuffled(
            model, opt, Xt, yt, args.batch_size,
            class_weight=cls_weight_t, brier_grad=args.brier_grad,
        )
        sched.step()

        cl_e = _count_logits(model, Xe)
        eval_loss = F.cross_entropy(cl_e, ye, weight=cls_weight_t).item()
        eval_pred = cl_e.argmax(dim=1)
        eval_acc = (eval_pred == ye).float().mean().item()
        eval_within1 = ((eval_pred - ye).abs() <= 1).float().mean().item()
        epoch_losses.append({
            "epoch": epoch,
            "train_loss": train_loss / n_batches,
            "train_acc": train_correct / n_train,
            "eval_loss": eval_loss,
            "eval_acc": eval_acc,
            "eval_within_pm1": eval_within1,
        })
        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            best_state = _snapshot_state(model)
        if epoch < 5 or epoch % 50 == 0 or epoch == args.epochs - 1:
            print(f"epoch {epoch:3d} train_loss={train_loss/n_batches:.4f} "
                  f"train_acc={train_correct/n_train:.3f} "
                  f"eval_loss={eval_loss:.4f} eval_acc={eval_acc:.3f} "
                  f"within±1={eval_within1:.3f}")
    train_time = time.perf_counter() - t0
    print(f"\ntrained {args.epochs} epochs in {train_time:.1f} s")
    print(f"best eval_acc: {best_eval_acc:.3f}")

    # Restore best checkpoint
    if best_state is not None:
        model.load_state_dict(best_state)

    # Eval breakdown, including confidence-vs-correctness rank correlation.
    metrics = _evaluate(model, Xe, ye)
    _print_final_eval("final eval", "conf↔correct Spearman", metrics)

    # Save safetensors
    write_safetensors(model, Path(args.out_safetensors))
    print(f"\nwrote {args.out_safetensors} ({Path(args.out_safetensors).stat().st_size} bytes)")

    # ONNX export
    _export_onnx(model, args.out_onnx, device)

    # Results JSON. The backend label names the framework that did the
    # training (PyTorch), matching the other two modes; it used to say
    # "candle-*", which is the Rust inference backend, not the trainer.
    results = {
        "backend": "pytorch-cuda" if device.type == "cuda" else "pytorch-cpu",
        "device": str(device),
        "epochs": args.epochs,
        "train_time_s": train_time,
        "best_eval_acc": best_eval_acc,
        "final_eval_acc": metrics["accuracy"],
        "final_eval_within_pm1": metrics["within_pm1"],
        "final_eval_mae": metrics["mae"],
        "conf_correctness_spearman": metrics["spearman"],
        "per_class_accuracy": metrics["per_class_accuracy"],
        "hyperparameters": {
            "optimizer": "AdamW",
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "schedule": "cosine_warm_restarts",
            "epochs": args.epochs,
            "loss": "cross_entropy(count) + 0.3*bce(conf) + 0.1*brier(conf)",
            "z_score_normalisation": True,
            "class_weights": cls_weight.tolist(),
            "brier_grad": args.brier_grad,
        },
        "epoch_losses": epoch_losses,
    }
    Path(args.out_results).write_text(json.dumps(results, indent=2))
    print(f"wrote {args.out_results} ({Path(args.out_results).stat().st_size} bytes)")


def main():
    """CLI entry point: parse arguments, load the data and run one training mode.

    Modes, in priority order:
      * ``--k-fold K``: stratified K-fold cross-validation; writes only the
        results JSON.
      * ``--v2``: the v0.0.2 recipe; writes safetensors, a temperature
        sidecar, ONNX (best-effort) and the results JSON.
      * default: the original temporal 80/20 split; writes safetensors,
        ONNX (best-effort) and the results JSON.

    Raises:
        ValueError: If the data has no usable samples or contains a label
            outside ``[0, COUNT_CLASSES)``.
        SystemExit: On invalid command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--paired", required=True)
    parser.add_argument("--out-safetensors", default="count_v1.safetensors")
    parser.add_argument("--out-onnx", default="count_v1.onnx")
    parser.add_argument("--out-results", default="count_train_results.json")
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--k-fold", type=int, default=None,
                        help="If set, run k-fold CV; else use temporal split")
    parser.add_argument("--v2", action="store_true",
                        help="v0.0.2 training: random 80/20 split + label smoothing + early stopping "
                             "+ balanced sampling + temperature-scaled confidence head.")
    parser.add_argument("--label-smoothing", type=float, default=0.1)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--brier-grad", action="store_true",
                        help="Let the Brier term back-propagate. Off by default: the term has "
                             "always been computed without gradients, so it only affects the "
                             "logged loss.")
    args = parser.parse_args()
    if args.k_fold is not None and args.k_fold < 2:
        parser.error("--k-fold must be >= 2")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    X, y = load_paired(Path(args.paired))
    print(f"loaded {X.shape[0]} samples, X shape {X.shape}, "
          f"label distribution: {dict(Counter(y.tolist()).most_common())}")

    # Out-of-range labels would otherwise fail deep inside bincount or
    # cross_entropy with an error that does not mention the data.
    if y.size and (y.min() < 0 or y.max() >= COUNT_CLASSES):
        raise ValueError(
            f"labels must lie in [0, {COUNT_CLASSES}); "
            f"found range [{int(y.min())}, {int(y.max())}]"
        )

    if args.k_fold is not None:
        _run_k_fold(args, X, y, device)
    elif args.v2:
        _run_v2(args, X, y, device)
    else:
        _run_temporal(args, X, y, device)
# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()