mirror of
https://github.com/ruvnet/RuView
synced 2026-06-17 11:33:19 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d639c747df | |||
| 42c764652d | |||
| db02956c22 | |||
| c84ea39e62 | |||
| 760d05026c |
@@ -277,3 +277,6 @@ aether-arena/staging/
|
||||
# MM-Fi benchmark dataset archives — large data, fetch separately, never commit
|
||||
assets/MM-Fi/E0*.zip
|
||||
assets/MM-Fi/*.zip
|
||||
|
||||
# through-wall demo: regenerable trained model artifact
|
||||
examples/through-wall/model/
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<title>WiFlow · live WiFi-inferred pose</title>
|
||||
<style>
|
||||
:root{--bg:#0a0c10;--panel:#11151c;--amber:#ffb840;--green:#46e08a;--red:#ff5a5a;--mute:#7d8796;--line:#1d2430}
|
||||
*{box-sizing:border-box}
|
||||
body{margin:0;background:var(--bg);color:#dfe6ee;font:14px/1.5 'JetBrains Mono',ui-monospace,Menlo,monospace}
|
||||
header{padding:14px 18px;border-bottom:1px solid var(--line);display:flex;align-items:center;gap:14px;flex-wrap:wrap}
|
||||
h1{font-size:15px;margin:0;letter-spacing:1px;text-transform:uppercase;font-weight:600}
|
||||
h1 span{color:var(--amber)}
|
||||
#banner{margin-left:auto;padding:5px 12px;border-radius:5px;font-weight:600;font-size:12px;letter-spacing:.5px}
|
||||
.live{background:rgba(70,224,138,.15);color:var(--green);border:1px solid var(--green)}
|
||||
.sim{background:rgba(255,184,64,.15);color:var(--amber);border:1px solid var(--amber)}
|
||||
.down{background:rgba(255,90,90,.15);color:var(--red);border:1px solid var(--red)}
|
||||
main{display:flex;gap:18px;padding:18px;flex-wrap:wrap}
|
||||
.card{background:var(--panel);border:1px solid var(--line);border-radius:10px;padding:14px}
|
||||
canvas{background:#070a0e;border-radius:8px;display:block}
|
||||
.label{font-size:11px;text-transform:uppercase;letter-spacing:1.5px;color:var(--mute);margin-bottom:8px}
|
||||
.stats{min-width:240px}
|
||||
.row{display:flex;justify-content:space-between;padding:3px 0;border-bottom:1px dashed var(--line)}
|
||||
.row .k{color:var(--mute)} .row .v{color:var(--amber);font-variant-numeric:tabular-nums}
|
||||
.v.green{color:var(--green)}
|
||||
.note{margin-top:12px;font-size:11px;color:var(--mute);line-height:1.6;max-width:300px}
|
||||
.note b{color:#dfe6ee}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>WiFlow · <span>live WiFi-inferred pose</span></h1>
|
||||
<div id="banner" class="down">CONNECTING…</div>
|
||||
</header>
|
||||
<main>
|
||||
<div class="card">
|
||||
<div class="label">CSI → pose (skeleton) overlaid on your laptop camera</div>
|
||||
<div id="stage" style="width:420px;height:560px;border-radius:8px;overflow:hidden;background:#070a0e">
|
||||
<video id="cam" autoplay muted playsinline style="position:absolute;width:2px;height:2px;opacity:0;pointer-events:none"></video>
|
||||
<canvas id="cv" width="420" height="560"></canvas>
|
||||
</div>
|
||||
<div style="margin-top:10px;display:flex;gap:8px;align-items:center;flex-wrap:wrap">
|
||||
<button id="camBtn" style="background:var(--amber);color:#0a0c10;border:0;border-radius:6px;padding:7px 14px;font:inherit;font-weight:600;cursor:pointer">enable laptop camera</button>
|
||||
<select id="camSel" style="display:none;background:var(--panel);color:#dfe6ee;border:1px solid var(--line);border-radius:6px;padding:6px;font:inherit;max-width:220px"></select>
|
||||
</div>
|
||||
<div id="camStatus" style="margin-top:6px;font-size:11px;color:var(--mute)">camera: off</div>
|
||||
<div class="note" style="margin-top:8px">Camera is a <b>visual reference only</b> — it is NOT fed to the model. Overlay alignment is approximate (model trained in a different camera's frame).</div>
|
||||
</div>
|
||||
<div class="card stats">
|
||||
<div class="label">live</div>
|
||||
<div class="row"><span class="k">CSI source</span><span class="v" id="src">—</span></div>
|
||||
<div class="row"><span class="k">nodes</span><span class="v" id="nodes">—</span></div>
|
||||
<div class="row"><span class="k">presence</span><span class="v" id="pres">—</span></div>
|
||||
<div class="row"><span class="k">motion</span><span class="v" id="motion">—</span></div>
|
||||
<div class="row"><span class="k">pose fps</span><span class="v" id="fps">—</span></div>
|
||||
<div class="note">
|
||||
This skeleton is inferred <b>from WiFi CSI only</b> — no camera in the loop here. A model was
|
||||
trained on paired (camera-pose, CSI) data in this room (ADR-079/180).
|
||||
<br/><br/>
|
||||
<b>Honest accuracy:</b> ~<b>59.5% PCK@0.10</b> on held-out data (vs a 50% mean-pose baseline →
|
||||
<b>+9.4 pp real signal</b>). It captures <b>coarse</b> pose; fine detail is weak (PCK@0.05 ≈ 24%).
|
||||
Same person / room / session — not validated cross-day or through-wall.
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
<script>
|
||||
const POSE_WS = (new URLSearchParams(location.search)).get('ws') || `ws://${location.hostname||'localhost'}:8770/pose`;
|
||||
const cv = document.getElementById('cv'), ctx = cv.getContext('2d');
|
||||
const $ = id => document.getElementById(id);
|
||||
let edges = [[5,7],[7,9],[6,8],[8,10],[5,6],[11,12],[5,11],[6,12],[11,13],[13,15],[12,14],[14,16],[0,1],[0,2],[1,3],[2,4],[0,5],[0,6]];
|
||||
let last = null, frames = 0, t0 = performance.now();
|
||||
|
||||
function banner(state, txt){ const b=$('banner'); b.className=state; b.textContent=txt; }
|
||||
|
||||
// per-joint smoothing (EMA) so dropped/jittery CSI frames render fluidly (ADR-180 dead-reckoning, lite)
|
||||
let sm = null;
|
||||
function smooth(kps){
|
||||
if(!sm){ sm = kps.map(p=>[p[0],p[1]]); return sm; }
|
||||
const a=0.35; for(let i=0;i<kps.length;i++){ sm[i][0]+=a*(kps[i][0]-sm[i][0]); sm[i][1]+=a*(kps[i][1]-sm[i][1]); }
|
||||
return sm;
|
||||
}
|
||||
const camEl=document.getElementById('cam');
|
||||
function draw(p){
|
||||
const W=cv.width, H=cv.height;
|
||||
// paint the live camera frame onto the canvas (robust — no z-index/overlay tricks)
|
||||
if(camEl && camEl.videoWidth>0){
|
||||
ctx.save(); ctx.globalAlpha=0.9;
|
||||
// cover-fit the camera frame into the canvas
|
||||
const vr=camEl.videoWidth/camEl.videoHeight, cr=W/H;
|
||||
let dw=W, dh=H, dx=0, dy=0;
|
||||
if(vr>cr){ dh=H; dw=H*vr; dx=(W-dw)/2; } else { dw=W; dh=W/vr; dy=(H-dh)/2; }
|
||||
ctx.drawImage(camEl, dx, dy, dw, dh); ctx.restore();
|
||||
} else {
|
||||
ctx.fillStyle='#070a0e'; ctx.fillRect(0,0,W,H);
|
||||
}
|
||||
if(!p || !p.kps){ return; }
|
||||
const s = smooth(p.kps);
|
||||
const k = s.map(([x,y])=>[x*W, y*H]);
|
||||
ctx.lineWidth=5; ctx.strokeStyle=p.presence?'rgba(70,224,138,.95)':'rgba(125,135,150,.8)'; ctx.lineCap='round';
|
||||
ctx.shadowColor='rgba(70,224,138,.6)'; ctx.shadowBlur=8;
|
||||
for(const [a,b] of edges){ ctx.beginPath(); ctx.moveTo(k[a][0],k[a][1]); ctx.lineTo(k[b][0],k[b][1]); ctx.stroke(); }
|
||||
ctx.shadowBlur=0;
|
||||
for(const [x,y] of k){ ctx.beginPath(); ctx.arc(x,y,5,0,7); ctx.fillStyle=p.presence?'#ffb840':'#667'; ctx.fill(); }
|
||||
}
|
||||
|
||||
// ---- laptop webcam (visual reference only; NOT fed to the model) ----
|
||||
let camStream=null;
|
||||
async function startCam(deviceId){
|
||||
if(camStream){ camStream.getTracks().forEach(t=>t.stop()); }
|
||||
const constraints = deviceId ? {video:{deviceId:{exact:deviceId}}} : {video:true};
|
||||
const st=document.getElementById('camStatus');
|
||||
try{
|
||||
st.textContent='camera: requesting…';
|
||||
camStream = await navigator.mediaDevices.getUserMedia(constraints);
|
||||
const v=document.getElementById('cam'); v.muted=true; v.srcObject=camStream;
|
||||
v.onloadedmetadata=()=>{ v.play().catch(err=>st.textContent='camera: play() blocked '+err.name); };
|
||||
await v.play().catch(()=>{});
|
||||
const tr=camStream.getVideoTracks()[0]; const ss=tr.getSettings();
|
||||
// live readout: shows if real frames are flowing (videoWidth>0) and which device
|
||||
const tick=()=>{ st.textContent = `camera: "${tr.label}" ${v.videoWidth}x${v.videoHeight} ${tr.readyState} ${v.paused?'PAUSED':'playing'}`; };
|
||||
tick(); setInterval(tick, 1000);
|
||||
document.getElementById('camBtn').textContent='switch camera ↻';
|
||||
// populate the picker now that we have permission (labels need permission)
|
||||
const devs = (await navigator.mediaDevices.enumerateDevices()).filter(d=>d.kind==='videoinput');
|
||||
const sel=document.getElementById('camSel'); sel.style.display = devs.length>1?'inline-block':'none';
|
||||
sel.innerHTML = devs.map((d,i)=>`<option value="${d.deviceId}">${d.label||('camera '+(i+1))}</option>`).join('');
|
||||
const cur = camStream.getVideoTracks()[0].getSettings().deviceId; if(cur) sel.value=cur;
|
||||
}catch(e){
|
||||
document.getElementById('camBtn').textContent = 'camera error: '+e.name+(e.name==='NotReadableError'?' (in use by Zoom/Teams?)':'');
|
||||
console.error('getUserMedia', e);
|
||||
}
|
||||
}
|
||||
document.getElementById('camBtn').addEventListener('click', ()=>startCam());
|
||||
document.getElementById('camSel').addEventListener('change', e=>startCam(e.target.value));
|
||||
|
||||
function connect(){
|
||||
banner('down','CONNECTING…');
|
||||
const ws = new WebSocket(POSE_WS);
|
||||
ws.onopen = ()=> banner('sim','WAITING FOR POSE…');
|
||||
ws.onmessage = ev => {
|
||||
const d = JSON.parse(ev.data);
|
||||
if(d.type==='meta'){ edges = d.edges; return; }
|
||||
if(d.type!=='pose') return;
|
||||
last=d; frames++;
|
||||
if(d.src==='esp32') banner('live','LIVE — WiFi-inferred pose (real ESP32 CSI)');
|
||||
else banner('sim','SIMULATED CSI — not real ('+d.src+')');
|
||||
$('src').textContent=d.src; $('src').className = d.src==='esp32'?'v green':'v';
|
||||
$('nodes').textContent=(d.nodes||[]).join(', ')||'—';
|
||||
$('pres').textContent=d.presence?'PRESENT':'—';
|
||||
$('motion').textContent=(d.motion!=null?Math.round(d.motion):'—');
|
||||
};
|
||||
ws.onclose = ()=>{ banner('down','NO BRIDGE — start wiflow_infer.py'); setTimeout(connect,1500); };
|
||||
ws.onerror = ()=> ws.close();
|
||||
}
|
||||
function loop(){ draw(last); const now=performance.now(); if(now-t0>1000){ $('fps').textContent=frames; frames=0; t0=now; } requestAnimationFrame(loop); }
|
||||
connect(); loop();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Rigorous A/B for WiFlow CSI->pose: is the held-out PCK real signal or split leakage?
|
||||
|
||||
For a dataset of {csi:[D], kps:17x[x,y,vis]} pairs, train the SAME small MLP under
|
||||
several train/val SPLITS and report held-out PCK@0.10 vs the mean-pose baseline:
|
||||
|
||||
- chronological_80_20 : last 20% in time (val temporally ADJACENT to train -> leaks
|
||||
via CSI/pose autocorrelation; this is what gave us +9.4)
|
||||
- random_80_20 : shuffled (val frames interleaved with train -> MAX leak)
|
||||
- blocked_gap : hold out a contiguous MIDDLE block with a time GAP buffer on
|
||||
each side so val is NOT adjacent to any train frame -> the
|
||||
honest, leakage-controlled test
|
||||
|
||||
If the model beats baseline on chronological/random but COLLAPSES to ~baseline on
|
||||
blocked_gap, the apparent signal was temporal leakage, not generalizable CSI->pose.
|
||||
|
||||
Usage (ruvultra venv): python wiflow_ab.py --data ~/wiflow-room/dataset.jsonl
|
||||
"""
|
||||
import argparse, json, sys
|
||||
import numpy as np, torch, torch.nn as nn
|
||||
|
||||
def _rec(r, X, Y, V, B):
|
||||
X.append(r["csi"]); kp=r["kps"]
|
||||
if kp and isinstance(kp[0], (list,tuple)): # 17 x [x,y(,vis)]
|
||||
Y.append([c for k in kp for c in (k[0],k[1])]); V.append([(k[2] if len(k)>2 else 1.0) for k in kp])
|
||||
else: # flat 34 (browser export, no vis)
|
||||
Y.append(list(kp)); V.append([1.0]*17)
|
||||
B.append(r.get("bucket"))
|
||||
|
||||
def load(path):
|
||||
X,Y,V,B=[],[],[],[]
|
||||
txt=open(path).read().strip()
|
||||
if txt[:1] in "[{": # JSON (browser export: dict{samples:[]} or bare array)
|
||||
d=json.loads(txt)
|
||||
rows = d if isinstance(d,list) else d.get("samples", d.get("data", []))
|
||||
for r in rows: _rec(r,X,Y,V,B)
|
||||
else: # JSONL (python capture)
|
||||
for line in txt.splitlines():
|
||||
if line.strip(): _rec(json.loads(line),X,Y,V,B)
|
||||
return np.array(X,np.float32), np.array(Y,np.float32), np.array(V,np.float32), B
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(s,din,dout):
|
||||
super().__init__()
|
||||
s.n=nn.Sequential(nn.Linear(din,384),nn.ReLU(),nn.Dropout(.35),
|
||||
nn.Linear(384,192),nn.ReLU(),nn.Dropout(.35),
|
||||
nn.Linear(192,96),nn.ReLU(),nn.Linear(96,dout),nn.Sigmoid())
|
||||
def forward(s,x): return s.n(x)
|
||||
|
||||
def pck(pred,gt,vis,thr=0.10):
|
||||
p=pred.reshape(-1,17,2); g=gt.reshape(-1,17,2)
|
||||
d=np.linalg.norm(p-g,axis=2); m=vis>0.5
|
||||
return float((d[m]<thr).mean()) if m.any() else 0.0
|
||||
|
||||
def split_idx(n, kind, B=None):
|
||||
idx=np.arange(n)
|
||||
if kind=="chronological_80_20":
|
||||
c=int(n*.8); return idx[:c], idx[c:]
|
||||
if kind=="random_80_20":
|
||||
rng=np.random.default_rng(0); p=rng.permutation(n); c=int(n*.8); return p[:c], p[c:]
|
||||
if kind=="blocked_gap":
|
||||
# val = contiguous middle 20%; a WIDE 10% time gap each side guarantees no train
|
||||
# frame is temporally adjacent to a val frame (kills frame-autocorrelation leakage).
|
||||
v0=int(n*.4); v1=int(n*.6); gap=int(n*.10)
|
||||
val=idx[v0:v1]; train=np.concatenate([idx[:max(0,v0-gap)], idx[min(n,v1+gap):]])
|
||||
return train, val
|
||||
if kind=="grouped_bucket":
|
||||
# hold out ENTIRE activity buckets -> val poses/activities never seen in train.
|
||||
# the strictest leakage-free test (only when bucket labels exist).
|
||||
b=np.array([x if x is not None else -1 for x in B])
|
||||
uniq=[u for u in sorted(set(b.tolist())) if u!=-1]
|
||||
if len(uniq)<3: raise ValueError("too few buckets")
|
||||
hold=set(uniq[::max(1,len(uniq)//3)][:max(1,len(uniq)//3)]) # ~1/3 of activities held out
|
||||
val=idx[np.isin(b,list(hold))]; train=idx[~np.isin(b,list(hold))]
|
||||
return train, val
|
||||
raise ValueError(kind)
|
||||
|
||||
def run(X,Y,V,tr,va,epochs=250,seed=0):
|
||||
torch.manual_seed(seed); np.random.seed(seed) # seed weight init + batch shuffle
|
||||
dev="cuda" if torch.cuda.is_available() else "cpu"
|
||||
mu,sd=X[tr].mean(0),X[tr].std(0)+1e-6
|
||||
Xtr=torch.tensor((X[tr]-mu)/sd).to(dev); Ytr=torch.tensor(Y[tr]).to(dev)
|
||||
Xva=torch.tensor((X[va]-mu)/sd).to(dev)
|
||||
net=Net(X.shape[1],Y.shape[1]).to(dev)
|
||||
opt=torch.optim.Adam(net.parameters(),lr=1e-3,weight_decay=1e-4); lf=nn.MSELoss()
|
||||
best=(1e9,None)
|
||||
for ep in range(epochs):
|
||||
net.train(); perm=torch.randperm(len(Xtr),device=dev)
|
||||
for i in range(0,len(Xtr),64):
|
||||
j=perm[i:i+64]; opt.zero_grad(); loss=lf(net(Xtr[j]),Ytr[j]); loss.backward(); opt.step()
|
||||
net.eval()
|
||||
with torch.no_grad(): pv=net(Xva).cpu().numpy()
|
||||
vl=float(((pv-Y[va])**2).mean())
|
||||
if vl<best[0]: best=(vl,pv)
|
||||
base=np.tile(Y[tr].mean(0),(len(va),1))
|
||||
return pck(best[1],Y[va],V[va]), pck(base,Y[va],V[va])
|
||||
|
||||
def main():
|
||||
ap=argparse.ArgumentParser(); ap.add_argument("--data",required=True)
|
||||
ap.add_argument("--epochs",type=int,default=250); ap.add_argument("--seeds",type=int,default=3)
|
||||
a=ap.parse_args()
|
||||
X,Y,V,B=load(a.data); n=len(X)
|
||||
has_buckets=any(x is not None for x in B)
|
||||
print(f"[ab] {n} samples, X={X.shape}, buckets={'yes' if has_buckets else 'no'}, "
|
||||
f"seeds={a.seeds}, epochs={a.epochs}\n")
|
||||
print(f"{'split':<22}{'model PCK@0.10':>16}{'baseline':>11}{'delta (mean±sd)':>20} verdict")
|
||||
print("-"*86)
|
||||
splits=["chronological_80_20","random_80_20","blocked_gap"]+(["grouped_bucket"] if has_buckets else [])
|
||||
for kind in splits:
|
||||
try:
|
||||
tr,va=split_idx(n,kind,B)
|
||||
ms=[]; bs=[]
|
||||
for s in range(a.seeds):
|
||||
m,b=run(X,Y,V,tr,va,a.epochs,seed=s); ms.append(m); bs.append(b)
|
||||
ms=np.array(ms)*100; bs=np.array(bs)*100; ds=ms-bs
|
||||
dm,dsd=ds.mean(),ds.std()
|
||||
# REAL only if the mean delta minus 1 sd still clears the 1.5pp threshold (robust to seed variance)
|
||||
verdict = "REAL signal" if dm-dsd>1.5 else ("weak/uncertain" if dm>1.5 else "no signal (==baseline)")
|
||||
print(f"{kind:<22}{ms.mean():>13.1f}±{ms.std():>3.1f}{bs.mean():>10.1f}%{dm:>+12.1f}±{dsd:>4.1f}pp {verdict}")
|
||||
except Exception as e:
|
||||
print(f"{kind:<22} skipped: {e}")
|
||||
print(f"\nmean±sd over {a.seeds} seeds (weight init + batch order). blocked_gap = 10% time gap each")
|
||||
print("side; grouped_bucket holds out ENTIRE activities (strictest). If only the LEAKY splits")
|
||||
print("(chronological/random) beat baseline, the apparent signal is leakage, not generalizable pose.")
|
||||
|
||||
if __name__=="__main__": main()
|
||||
@@ -112,7 +112,11 @@
|
||||
<div class="label">empty-room baseline (ADR-151) — step OUT of the space</div>
|
||||
<canvas id="calCv" width="420" height="300"></canvas>
|
||||
<div style="margin-top:10px;display:flex;gap:8px;align-items:center;flex-wrap:wrap">
|
||||
<button id="calBtn" class="btn">calibrate baseline (10 s)</button>
|
||||
<button id="detBtn" class="btn">① detect ESP32 sensors</button>
|
||||
<span id="detNodes" class="v">not detected</span>
|
||||
</div>
|
||||
<div style="margin-top:10px;display:flex;gap:8px;align-items:center;flex-wrap:wrap">
|
||||
<button id="calBtn" class="btn">② calibrate baseline (10 s)</button>
|
||||
<button id="recalBtn" class="ghost btn">recalibrate</button>
|
||||
<label class="note" style="margin:0">get-ready countdown
|
||||
<input id="calReady" type="number" value="5" min="3" max="15" style="width:64px"> s</label>
|
||||
@@ -285,9 +289,15 @@
|
||||
// wss when served over https (mobile/secure-context safe), else ws; ?ws= overrides
|
||||
const CSI_WS = (new URLSearchParams(location.search)).get('ws')
|
||||
|| `${location.protocol === 'https:' ? 'wss' : 'ws'}://${location.hostname || 'localhost'}:8765/ws/sensing`;
|
||||
const NODE_IDS = [9, 13]; // per-node features in this fixed order (matches Python pipeline)
|
||||
// Per-node feature schema — AUTO-DETECTED from the live stream (see detectSensors).
|
||||
// [9,13] is only the fallback until detection runs. ORDER is fixed (sorted ascending)
|
||||
// so the model's input layout is stable across capture / train / infer.
|
||||
let NODE_IDS = [9, 13];
|
||||
const FIELD_LEN = 400; // signal_field.values padded/truncated to 400
|
||||
const CSI_DIM = 4 + NODE_IDS.length * 3 + FIELD_LEN; // 4 + 6 + 400 = 410
|
||||
let CSI_DIM = 4 + NODE_IDS.length * 3 + FIELD_LEN; // 4 global + 3/node + 400 field
|
||||
function recomputeCsiDim(){ CSI_DIM = 4 + NODE_IDS.length * 3 + FIELD_LEN; }
|
||||
let sensorsDetected = false; // true once a detect (auto/manual/restored) has locked the node set
|
||||
let autoDetectStarted = false; // one-shot guard for the auto-detect on first live frame
|
||||
const N_KP = 17, OUT_DIM = N_KP * 2; // 17 COCO keypoints -> 34 coords
|
||||
const BASELINE_SECONDS = 10; // empty-room calibration window
|
||||
const EPS = 1e-6;
|
||||
@@ -333,9 +343,9 @@ async function selectBackend(){
|
||||
// ============================================================================
|
||||
// CSI vector construction — MUST match wiflow_capture.py csi_vector() exactly.
|
||||
// [mean_rssi, variance, motion_band_power, breathing_band_power] (4 global)
|
||||
// + for node 9 then node 13: [mean_rssi, variance, motion_band_power] (6 per-node)
|
||||
// + for each node in NODE_IDS order: [mean_rssi, variance, motion_band_power] (3 per-node)
|
||||
// + signal_field.values padded/truncated to 400 (400 field)
|
||||
// = 410-d (RAW — baseline-normalization applied separately, see baselineNorm)
|
||||
// = CSI_DIM-d (RAW — baseline-normalization applied separately, see baselineNorm)
|
||||
// ============================================================================
|
||||
function csiVector(frame){
|
||||
const f = frame.features || {};
|
||||
@@ -368,6 +378,87 @@ function baselineNorm(vecRaw){
|
||||
return out;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ESP32 sensor auto-detection
|
||||
// Sniff the live /ws/sensing stream, find which node_ids are actually present
|
||||
// and healthy, and lock that ordered set as the per-node schema (NODE_IDS/CSI_DIM).
|
||||
// The node set defines the model's input dimension, so detection must run BEFORE
|
||||
// calibration + capture; changing it invalidates a baseline/dataset built on a
|
||||
// different set (we confirm, then reset, on a manual re-detect).
|
||||
// ============================================================================
|
||||
async function detectSensors(ms = 3000){
|
||||
const tally = {}; // node_id -> { seen, fps, rssi }
|
||||
let frames = 0;
|
||||
const t0 = performance.now();
|
||||
const el = $('detNodes'); if (el){ el.textContent = 'scanning…'; el.className = 'v'; }
|
||||
while (performance.now() - t0 < ms){
|
||||
if (latestCSI.frame && latestCSI.source === 'esp32'){
|
||||
frames++;
|
||||
for (const nf of (latestCSI.frame.node_features || [])){
|
||||
const id = nf.node_id; if (id == null) continue;
|
||||
const f = nf.features || {};
|
||||
const t = (tally[id] || (tally[id] = { seen:0, fps:0, rssi:0 }));
|
||||
t.seen++; t.fps += (+nf.frame_rate_hz || 0);
|
||||
t.rssi += (+f.mean_rssi || +nf.rssi_dbm || 0);
|
||||
}
|
||||
}
|
||||
await new Promise(r => setTimeout(r, 100));
|
||||
}
|
||||
// healthy = seen in >40% of sampled frames (filters transient / duplicate ids)
|
||||
const healthy = Object.keys(tally).map(k => ({
|
||||
id:+k, seen:tally[k].seen, fps:tally[k].fps/tally[k].seen, rssi:tally[k].rssi/tally[k].seen }))
|
||||
.filter(n => n.seen >= Math.max(2, frames * 0.4))
|
||||
.sort((a,b)=> a.id - b.id);
|
||||
return { healthy, frames };
|
||||
}
|
||||
|
||||
function renderDetectedSensors(list){
|
||||
const el = $('detNodes'); if (!el) return;
|
||||
el.textContent = list.length
|
||||
? list.map(n => `#${n.id} (${Math.round(n.fps)}fps, ${Math.round(n.rssi)}dB)`).join(' · ')
|
||||
: 'none found';
|
||||
el.className = list.length ? 'v green' : 'v red';
|
||||
}
|
||||
|
||||
async function runDetect(manual){
|
||||
const { healthy, frames } = await detectSensors(manual ? 4000 : 3000);
|
||||
if (!healthy.length){
|
||||
const el = $('detNodes');
|
||||
if (el){ el.textContent = frames ? 'no healthy nodes' : 'no live CSI (start sensing-server / esp32)';
|
||||
el.className = 'v red'; }
|
||||
return;
|
||||
}
|
||||
const ids = healthy.map(n => n.id);
|
||||
const changed = ids.length !== NODE_IDS.length || ids.some((v,i)=> v !== NODE_IDS[i]);
|
||||
if (changed && (baseline || SAMPLES.length)){
|
||||
const ok = confirm(
|
||||
`Detected sensors [${ids.join(', ')}] differ from the current set [${NODE_IDS.join(', ')}].\n\n` +
|
||||
`The node set defines the model input, so switching invalidates the existing baseline` +
|
||||
(SAMPLES.length ? ` and ${SAMPLES.length} captured samples` : ``) +
|
||||
`. Reset and use the detected set?`);
|
||||
if (!ok){ renderDetectedSensors(healthy); return; }
|
||||
if (baseline){ baseline = null; stageDone.calibrate = false; idbDel('baseline');
|
||||
$('calStatus').textContent = 'NOT CALIBRATED'; $('calStatus').className = 'v'; $('calBar').style.width = '0%'; }
|
||||
if (SAMPLES.length){ SAMPLES = []; covCounts = new Array(BUCKETS.length).fill(0);
|
||||
idbPut('samples', []); $('capN').textContent = '0'; $('trN').textContent = '0'; renderCoverage(); }
|
||||
}
|
||||
NODE_IDS = ids; recomputeCsiDim(); sensorsDetected = true;
|
||||
idbPut('nodeIds', NODE_IDS);
|
||||
renderDetectedSensors(healthy);
|
||||
refreshGates();
|
||||
}
|
||||
|
||||
async function restoreNodeIds(){
|
||||
try{
|
||||
const ids = await idbGet('nodeIds');
|
||||
if (Array.isArray(ids) && ids.length){
|
||||
NODE_IDS = ids.slice(); recomputeCsiDim(); sensorsDetected = true;
|
||||
const el = $('detNodes');
|
||||
if (el){ el.textContent = 'restored: ' + NODE_IDS.map(i => '#' + i).join(' '); el.className = 'v'; }
|
||||
}
|
||||
}catch(e){ /* ignore */ }
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CSI WebSocket
|
||||
// ============================================================================
|
||||
@@ -388,6 +479,11 @@ function connectCSI(){
|
||||
source: src,
|
||||
nodes: (d.nodes || []).map(n => n.node_id).filter(x => x != null).sort((a,b)=>a-b)
|
||||
};
|
||||
// auto-detect the sensor set once, on the first live frame, only when starting fresh
|
||||
// (no baseline / no samples) so we never silently change a schema work is built on.
|
||||
if (src === 'esp32' && !sensorsDetected && !autoDetectStarted && !baseline && SAMPLES.length === 0){
|
||||
autoDetectStarted = true; runDetect(false);
|
||||
}
|
||||
if (src === 'esp32') banner('live','LIVE — real ESP32 CSI');
|
||||
else banner('sim',`SIMULATED — not real (source=${src})`);
|
||||
};
|
||||
@@ -599,6 +695,7 @@ function finishCalibration(){
|
||||
refreshGates();
|
||||
}
|
||||
$('calBtn').addEventListener('click', startCalibration);
|
||||
$('detBtn').addEventListener('click', ()=> runDetect(true));
|
||||
$('recalBtn').addEventListener('click', ()=>{ baseline = null; stageDone.calibrate = false;
|
||||
$('calStatus').textContent = 'NOT CALIBRATED'; $('calStatus').className = 'v';
|
||||
$('calBar').style.width = '0%'; $('calN').textContent = '0'; idbDel('baseline'); refreshGates(); startCalibration(); });
|
||||
@@ -736,7 +833,7 @@ $('clrBtn').addEventListener('click', async ()=>{
|
||||
$('expBtn').addEventListener('click', ()=>{
|
||||
const out = {
|
||||
format: 'wiflow-browser-dataset', version: 1, exported: new Date().toISOString(),
|
||||
csi_dim: CSI_DIM, out_dim: OUT_DIM, buckets: BUCKETS,
|
||||
csi_dim: CSI_DIM, out_dim: OUT_DIM, buckets: BUCKETS, nodes: NODE_IDS.slice(),
|
||||
note: 'csi is baseline-normalized (ADR-151 deviation-from-baseline); kps are 17 COCO keypoints in [0,1] image coords',
|
||||
samples: SAMPLES.map((s,i)=>({ csi: Array.from(s.csi), kps: Array.from(s.kps), bucket: s.bucket, t: (s.t!=null?s.t:i) }))
|
||||
};
|
||||
@@ -1152,6 +1249,7 @@ function inferLoop(){
|
||||
(async function boot(){
|
||||
connectCSI();
|
||||
await selectBackend();
|
||||
await restoreNodeIds(); // restore a previously-detected sensor set (fixes CSI_DIM before baseline)
|
||||
await loadBaseline();
|
||||
await idbLoad();
|
||||
await loadModel();
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python3
|
||||
"""WiFlow-style camera-supervised capture (ADR-079 / ADR-180).
|
||||
|
||||
Runs on a box with BOTH a camera (ground truth) and reachable live CSI:
|
||||
- opens a camera, runs MediaPipe Pose -> 17 COCO keypoints (the LABEL),
|
||||
- subscribes to the sensing-server /ws/sensing (the INPUT: CSI features +
|
||||
20x20 signal-field),
|
||||
- writes timestamp-aligned (csi -> pose) pairs to a JSONL dataset.
|
||||
|
||||
This is the *collect* phase of camera-supervised CSI->pose training. The camera
|
||||
and the CSI nodes MUST see the same person in the same space at the same time,
|
||||
or the pairs are meaningless. Honest by construction: we only emit a pair when
|
||||
BOTH a confident camera pose AND a live (source=esp32) CSI frame are present in
|
||||
the same ~100 ms window.
|
||||
|
||||
Usage (on ruvultra, with the CSI tunneled to localhost:8765):
|
||||
python3 wiflow_capture.py --ws ws://localhost:8765/ws/sensing \
|
||||
--cam 0 --out ~/wiflow-room/dataset.jsonl --seconds 180
|
||||
"""
|
||||
import argparse, asyncio, json, time, threading, sys, os
|
||||
from collections import deque
|
||||
|
||||
import urllib.request
|
||||
import cv2
|
||||
import numpy as np
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks.python import BaseOptions
|
||||
from mediapipe.tasks.python.vision import PoseLandmarker, PoseLandmarkerOptions, RunningMode
|
||||
import websockets
|
||||
|
||||
_MODEL_URL = ("https://storage.googleapis.com/mediapipe-models/pose_landmarker/"
|
||||
"pose_landmarker_lite/float16/latest/pose_landmarker_lite.task")
|
||||
|
||||
def ensure_model(path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
print(f"[capture] downloading pose model -> {path}", flush=True)
|
||||
urllib.request.urlretrieve(_MODEL_URL, path)
|
||||
return path
|
||||
|
||||
# MediaPipe Pose (33 landmarks) -> 17 COCO keypoints (same mapping as
|
||||
# scripts/collect-ground-truth.py, ADR-079).
|
||||
COCO_FROM_MP = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]
|
||||
COCO_NAMES = ["nose","l_eye","r_eye","l_ear","r_ear","l_sho","r_sho","l_elb",
|
||||
"r_elb","l_wri","r_wri","l_hip","r_hip","l_knee","r_knee","l_ank","r_ank"]
|
||||
|
||||
# ---- shared state between the CSI (async) thread and the camera (sync) loop ----
|
||||
_latest_csi = {"t": 0.0, "frame": None}
|
||||
_csi_lock = threading.Lock()
|
||||
_stop = threading.Event()
|
||||
|
||||
|
||||
def csi_thread(ws_url: str):
|
||||
"""Background thread: keep the most recent LIVE csi frame in _latest_csi."""
|
||||
async def run():
|
||||
while not _stop.is_set():
|
||||
try:
|
||||
async with websockets.connect(ws_url, open_timeout=8, ping_interval=20) as ws:
|
||||
while not _stop.is_set():
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=8)
|
||||
d = json.loads(msg)
|
||||
with _csi_lock:
|
||||
_latest_csi["t"] = time.time()
|
||||
_latest_csi["frame"] = d
|
||||
except Exception as e:
|
||||
print(f"[csi] reconnect ({e})", flush=True)
|
||||
await asyncio.sleep(1.0)
|
||||
asyncio.new_event_loop().run_until_complete(run())
|
||||
|
||||
|
||||
def csi_vector(frame: dict):
|
||||
"""Flatten a csi frame to a fixed-length input vector: features + field."""
|
||||
f = frame.get("features", {}) or {}
|
||||
feats = [f.get("mean_rssi", 0.0), f.get("variance", 0.0),
|
||||
f.get("motion_band_power", 0.0), f.get("breathing_band_power", 0.0)]
|
||||
# per-node mean_rssi/variance/motion for up to the 2 nodes (9, 13)
|
||||
pernode = {nf.get("node_id"): (nf.get("features") or {}) for nf in (frame.get("node_features") or [])}
|
||||
for nid in (9, 13):
|
||||
nf = pernode.get(nid, {})
|
||||
feats += [nf.get("mean_rssi", 0.0), nf.get("variance", 0.0), nf.get("motion_band_power", 0.0)]
|
||||
field = (frame.get("signal_field", {}) or {}).get("values") or []
|
||||
field = (field + [0.0] * 400)[:400]
|
||||
return feats + field # 4 + 6 + 400 = 410-d
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="WiFlow camera-supervised CSI<->pose capture (ADR-180).")
|
||||
ap.add_argument("--ws", default="ws://localhost:8765/ws/sensing")
|
||||
ap.add_argument("--cam", type=int, default=0)
|
||||
ap.add_argument("--out", default=os.path.expanduser("~/wiflow-room/dataset.jsonl"))
|
||||
ap.add_argument("--seconds", type=int, default=180)
|
||||
ap.add_argument("--min-vis", type=float, default=0.5, help="min mean landmark visibility to accept a pose label")
|
||||
ap.add_argument("--max-skew-ms", type=float, default=150, help="max csi/pose time skew to pair")
|
||||
ap.add_argument("--require-esp32", action="store_true", default=True,
|
||||
help="only pair when csi source==esp32 (real). Default on.")
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
||||
th = threading.Thread(target=csi_thread, args=(args.ws,), daemon=True)
|
||||
th.start()
|
||||
|
||||
cap = cv2.VideoCapture(args.cam)
|
||||
if not cap.isOpened():
|
||||
print(f"ERROR: cannot open camera {args.cam}", file=sys.stderr); sys.exit(2)
|
||||
W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
|
||||
H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 480
|
||||
model_path = ensure_model(os.path.expanduser("~/wiflow-room/pose_landmarker_lite.task"))
|
||||
landmarker = PoseLandmarker.create_from_options(PoseLandmarkerOptions(
|
||||
base_options=BaseOptions(model_asset_path=model_path),
|
||||
running_mode=RunningMode.IMAGE, min_pose_detection_confidence=0.5))
|
||||
|
||||
n_pairs = 0; n_nopose = 0; n_nocsi = 0; n_skew = 0; n_sim = 0
|
||||
t0 = time.time()
|
||||
print(f"[capture] camera {args.cam} {W}x{H} -> {args.out} for {args.seconds}s")
|
||||
print("[capture] stand in view AND in the CSI field; move/walk so poses vary. Ctrl-C to stop.")
|
||||
with open(args.out, "a") as out:
|
||||
try:
|
||||
while time.time() - t0 < args.seconds:
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
continue
|
||||
now = time.time()
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
res = landmarker.detect(mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb))
|
||||
if not res.pose_landmarks:
|
||||
n_nopose += 1; continue
|
||||
lm = res.pose_landmarks[0]
|
||||
kps = [[lm[i].x, lm[i].y, lm[i].visibility] for i in COCO_FROM_MP]
|
||||
vis = float(np.mean([k[2] for k in kps]))
|
||||
if vis < args.min_vis:
|
||||
n_nopose += 1; continue
|
||||
with _csi_lock:
|
||||
ct = _latest_csi["t"]; cf = _latest_csi["frame"]
|
||||
if cf is None:
|
||||
n_nocsi += 1; continue
|
||||
if (now - ct) * 1000.0 > args.max_skew_ms:
|
||||
n_skew += 1; continue
|
||||
if args.require_esp32 and cf.get("source") != "esp32":
|
||||
n_sim += 1; continue
|
||||
rec = {"t": now, "vis": round(vis, 3),
|
||||
"kps": [[round(x, 4), round(y, 4), round(v, 3)] for x, y, v in kps],
|
||||
"csi": csi_vector(cf),
|
||||
"src": cf.get("source"),
|
||||
"nodes": sorted(n.get("node_id") for n in cf.get("nodes", []) if n.get("node_id") is not None)}
|
||||
out.write(json.dumps(rec) + "\n")
|
||||
n_pairs += 1
|
||||
if n_pairs % 30 == 0:
|
||||
out.flush()
|
||||
el = int(now - t0)
|
||||
print(f"[capture] t+{el:3d}s pairs={n_pairs} (skip: nopose={n_nopose} nocsi={n_nocsi} skew={n_skew} sim={n_sim})", flush=True)
|
||||
except KeyboardInterrupt:
|
||||
print("\n[capture] stopped by user")
|
||||
_stop.set(); cap.release()
|
||||
print(f"[capture] DONE. wrote {n_pairs} paired samples to {args.out}")
|
||||
print(f"[capture] skipped: no-pose={n_nopose} no-csi={n_nocsi} skew={n_skew} simulated={n_sim}")
|
||||
if n_pairs == 0:
|
||||
print("[capture] WARNING: 0 pairs — check camera sees you AND csi source==esp32 (live).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Live CSI->pose inference bridge (ADR-180).
|
||||
|
||||
Runs on the box with the live CSI. Loads the camera-supervised model (numpy,
|
||||
no torch needed), subscribes to /ws/sensing, runs a forward pass per frame, and
|
||||
broadcasts the predicted 17-keypoint pose to HTML clients on ws://:8770/pose.
|
||||
|
||||
python wiflow_infer.py --model model/model.npz \
|
||||
--in ws://localhost:8765/ws/sensing --port 8770
|
||||
"""
|
||||
import argparse, asyncio, json, os
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
||||
# COCO skeleton edges (for the client; sent once in 'meta')
|
||||
EDGES = [[5,7],[7,9],[6,8],[8,10],[5,6],[11,12],[5,11],[6,12],
|
||||
[11,13],[13,15],[12,14],[14,16],[0,1],[0,2],[1,3],[2,4],[0,5],[0,6]]
|
||||
|
||||
def csi_vector(frame):
|
||||
f = frame.get("features", {}) or {}
|
||||
feats = [f.get("mean_rssi",0.0), f.get("variance",0.0),
|
||||
f.get("motion_band_power",0.0), f.get("breathing_band_power",0.0)]
|
||||
pernode = {nf.get("node_id"): (nf.get("features") or {}) for nf in (frame.get("node_features") or [])}
|
||||
for nid in (9,13):
|
||||
nf = pernode.get(nid,{}); feats += [nf.get("mean_rssi",0.0), nf.get("variance",0.0), nf.get("motion_band_power",0.0)]
|
||||
field = (frame.get("signal_field",{}) or {}).get("values") or []
|
||||
field = (field + [0.0]*400)[:400]
|
||||
return np.array(feats + field, np.float32)
|
||||
|
||||
class Model:
|
||||
def __init__(self, path):
|
||||
z = np.load(path)
|
||||
self.mu, self.sd = z["mu"], z["sd"]
|
||||
self.W = [z["net_0_weight"], z["net_3_weight"], z["net_6_weight"], z["net_8_weight"]]
|
||||
self.b = [z["net_0_bias"], z["net_3_bias"], z["net_6_bias"], z["net_8_bias"]]
|
||||
def __call__(self, x):
|
||||
h = (x - self.mu) / self.sd
|
||||
for i in range(3):
|
||||
h = np.maximum(0.0, h @ self.W[i].T + self.b[i]) # Linear+ReLU
|
||||
out = 1.0/(1.0+np.exp(-(h @ self.W[3].T + self.b[3]))) # Linear+Sigmoid -> 34
|
||||
return out.reshape(17,2)
|
||||
|
||||
CLIENTS = set()
|
||||
LATEST = {"pose": None}
|
||||
|
||||
async def serve_client(ws):
|
||||
CLIENTS.add(ws)
|
||||
try:
|
||||
await ws.send(json.dumps({"type":"meta","edges":EDGES}))
|
||||
async for _ in ws: # client is read-only; just keep alive
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
CLIENTS.discard(ws)
|
||||
|
||||
async def infer_loop(model, in_url):
|
||||
while True:
|
||||
try:
|
||||
async with websockets.connect(in_url, open_timeout=8, ping_interval=20) as ws:
|
||||
async for msg in ws:
|
||||
d = json.loads(msg)
|
||||
kp = model(csi_vector(d))
|
||||
cls = d.get("classification",{})
|
||||
payload = {"type":"pose","src":d.get("source"),
|
||||
"presence":bool(cls.get("presence")),
|
||||
"motion":(d.get("features",{}) or {}).get("motion_band_power"),
|
||||
"kps":[[round(float(x),4),round(float(y),4)] for x,y in kp],
|
||||
"nodes":sorted(n.get("node_id") for n in d.get("nodes",[]) if n.get("node_id") is not None)}
|
||||
LATEST["pose"]=payload
|
||||
if CLIENTS:
|
||||
dead=[]
|
||||
for c in list(CLIENTS):
|
||||
try: await c.send(json.dumps(payload))
|
||||
except Exception: dead.append(c)
|
||||
for c in dead: CLIENTS.discard(c)
|
||||
except Exception as e:
|
||||
print(f"[infer] reconnect ({e})", flush=True); await asyncio.sleep(1.0)
|
||||
|
||||
async def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__),"model","model.npz"))
|
||||
ap.add_argument("--in", dest="in_url", default="ws://localhost:8765/ws/sensing")
|
||||
ap.add_argument("--port", type=int, default=8770)
|
||||
args = ap.parse_args()
|
||||
model = Model(args.model)
|
||||
print(f"[infer] model {args.model} loaded; serving predicted poses on ws://0.0.0.0:{args.port}/pose")
|
||||
async with websockets.serve(serve_client, "0.0.0.0", args.port):
|
||||
await infer_loop(model, args.in_url)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train a CSI->pose model on the camera-supervised dataset (ADR-079/180).
|
||||
|
||||
Input : 410-d CSI vector (4 global feats + 6 per-node + 400 signal-field).
|
||||
Target : 17 COCO keypoints (x,y), normalized 0..1 from the camera (ground truth).
|
||||
Reports HONEST held-out PCK@k + MPJPE on a chronological val split (the last
|
||||
20% of the session — never trained on), so the number is not leaked.
|
||||
|
||||
Usage (ruvultra venv):
|
||||
python wiflow_train.py --data ~/wiflow-room/dataset.jsonl --out ~/wiflow-room/model.pt
|
||||
"""
|
||||
import argparse, json, math, os, sys
|
||||
import numpy as np
|
||||
import torch, torch.nn as nn
|
||||
|
||||
|
||||
def load(path):
|
||||
X, Y, V = [], [], []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
r = json.loads(line)
|
||||
X.append(r["csi"]) # 410
|
||||
kp = r["kps"] # 17 x [x,y,vis]
|
||||
Y.append([c for k in kp for c in (k[0], k[1])]) # 34
|
||||
V.append([k[2] for k in kp]) # 17 visibilities
|
||||
return np.array(X, np.float32), np.array(Y, np.float32), np.array(V, np.float32)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, din, dout):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(din, 512), nn.ReLU(), nn.Dropout(0.3),
|
||||
nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3),
|
||||
nn.Linear(256, 128), nn.ReLU(),
|
||||
nn.Linear(128, dout), nn.Sigmoid()) # coords in 0..1
|
||||
def forward(self, x): return self.net(x)
|
||||
|
||||
|
||||
def pck(pred, gt, vis, thr):
|
||||
# pred/gt: [N,34] -> [N,17,2]; PCK@thr in normalized image units, visible kps only
|
||||
p = pred.reshape(-1, 17, 2); g = gt.reshape(-1, 17, 2)
|
||||
d = np.linalg.norm(p - g, axis=2) # [N,17]
|
||||
m = vis > 0.5
|
||||
return float((d[m] < thr).mean()) if m.any() else 0.0, float(d[m].mean()) if m.any() else float("nan")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--data", required=True)
|
||||
ap.add_argument("--out", default=os.path.expanduser("~/wiflow-room/model.pt"))
|
||||
ap.add_argument("--epochs", type=int, default=300)
|
||||
ap.add_argument("--bs", type=int, default=64)
|
||||
args = ap.parse_args()
|
||||
|
||||
X, Y, V = load(args.data)
|
||||
n = len(X)
|
||||
print(f"[train] {n} samples, X={X.shape} Y={Y.shape}")
|
||||
if n < 200:
|
||||
print("[train] too few samples"); sys.exit(2)
|
||||
|
||||
# chronological split (NOT shuffled) so val is a held-out time segment -> honest
|
||||
cut = int(n * 0.8)
|
||||
mu, sd = X[:cut].mean(0), X[:cut].std(0) + 1e-6 # standardize on train only
|
||||
Xn = (X - mu) / sd
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
Xtr = torch.tensor(Xn[:cut]).to(dev); Ytr = torch.tensor(Y[:cut]).to(dev)
|
||||
Xva = torch.tensor(Xn[cut:]).to(dev); Yva = Y[cut:]; Vva = V[cut:]
|
||||
|
||||
# mean-pose baseline (predict the train-mean pose for everything) — the bar to beat
|
||||
mean_pose = Y[:cut].mean(0)
|
||||
base_pck, base_mpjpe = pck(np.tile(mean_pose, (len(Yva), 1)), Yva, Vva, 0.10)
|
||||
|
||||
net = Net(X.shape[1], Y.shape[1]).to(dev)
|
||||
opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
|
||||
lossf = nn.MSELoss()
|
||||
best = (1e9, None)
|
||||
for ep in range(args.epochs):
|
||||
net.train(); perm = torch.randperm(len(Xtr), device=dev)
|
||||
for i in range(0, len(Xtr), args.bs):
|
||||
idx = perm[i:i+args.bs]
|
||||
opt.zero_grad(); out = net(Xtr[idx]); loss = lossf(out, Ytr[idx]); loss.backward(); opt.step()
|
||||
if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
|
||||
net.eval()
|
||||
with torch.no_grad(): pv = net(Xva).cpu().numpy()
|
||||
p10, mpj = pck(pv, Yva, Vva, 0.10); p05, _ = pck(pv, Yva, Vva, 0.05)
|
||||
vloss = float(((pv - Yva) ** 2).mean())
|
||||
print(f"[train] ep{ep+1:3d} val_mse={vloss:.4f} PCK@0.10={p10*100:.1f}% PCK@0.05={p05*100:.1f}% MPJPE={mpj:.4f}")
|
||||
if vloss < best[0]: best = (vloss, {"sd": net.state_dict(), "p10": p10, "p05": p05, "mpj": mpj})
|
||||
|
||||
torch.save({"model": best[1]["sd"], "mu": mu, "sd": sd, "din": X.shape[1]}, args.out)
|
||||
print("\n==================== HONEST RESULT (held-out 20%, never trained) ====================")
|
||||
print(f" MEAN-POSE BASELINE : PCK@0.10 = {base_pck*100:.1f}% MPJPE = {base_mpjpe:.4f} (the bar to beat)")
|
||||
print(f" CSI->POSE MODEL : PCK@0.10 = {best[1]['p10']*100:.1f}% PCK@0.05 = {best[1]['p05']*100:.1f}% MPJPE = {best[1]['mpj']:.4f}")
|
||||
delta = (best[1]['p10'] - base_pck) * 100
|
||||
print(f" VERDICT: model {'BEATS' if delta>1 else 'does NOT beat'} mean-pose baseline by {delta:+.1f} pp "
|
||||
f"-> {'real CSI->pose signal' if delta>1 else 'NO usable CSI->pose signal (honest negative)'}")
|
||||
print(f" saved -> {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+1
-1
Submodule v2/crates/ruv-neural updated: 1ece3afa33...81be9e1e19
Reference in New Issue
Block a user