mirror of
https://github.com/ruvnet/RuView
synced 2026-06-13 10:53:20 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0223ef6d2e | |||
| 2f5e7ffb41 | |||
| 4ce8ffc465 | |||
| 3be63a7589 | |||
| c4e640c812 |
@@ -87,7 +87,7 @@ docker run -p 3000:3000 ruvnet/wifi-densepose:latest
|
||||
</a>
|
||||
<br>
|
||||
<em>Real-time pose skeleton from WiFi CSI signals — no cameras, no wearables</em>
|
||||
<br><br>
|
||||
<br>
|
||||
<a href="https://ruvnet.github.io/RuView/"><strong>▶ Live Observatory Demo</strong></a>
|
||||
|
|
||||
<a href="https://ruvnet.github.io/RuView/pose-fusion.html"><strong>▶ Dual-Modal Pose Fusion Demo</strong></a>
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# ADR-060: Provision Channel Override and MAC Address Filtering
|
||||
|
||||
- **Status:** Accepted
|
||||
- **Date:** 2026-03-12
|
||||
- **Issues:** [#247](https://github.com/ruvnet/RuView/issues/247), [#229](https://github.com/ruvnet/RuView/issues/229)
|
||||
|
||||
## Context
|
||||
|
||||
Two related provisioning gaps were reported by users:
|
||||
|
||||
1. **Channel mismatch (Issue #247):** The CSI collector initializes on the
|
||||
Kconfig default channel (typically 6), even when the ESP32 connects to an AP
|
||||
on a different channel (e.g. 11). On managed networks where the user cannot
|
||||
change the router channel, this makes nodes undiscoverable. The
|
||||
`provision.py` script has no `--channel` argument.
|
||||
|
||||
2. **Missing MAC filter (Issue #229):** The v0.2.0 release notes documented a
|
||||
`--filter-mac` argument for `provision.py`, but it was never implemented.
|
||||
The firmware's CSI callback accepts frames from all sources, causing signal
|
||||
mixing in multi-AP environments.
|
||||
|
||||
## Decision
|
||||
|
||||
### Channel configuration
|
||||
|
||||
- Add `--channel` argument to `provision.py` that writes a `csi_channel` key
|
||||
(u8) to NVS.
|
||||
- In `nvs_config.c`, read the `csi_channel` key and override
|
||||
`channel_list[0]` when present.
|
||||
- In `csi_collector_init()`, after WiFi connects, auto-detect the AP channel
|
||||
via `esp_wifi_sta_get_ap_info()` and use it as the default CSI channel when
|
||||
no NVS override is set. This ensures the CSI collector always matches the
|
||||
connected AP's channel without requiring manual provisioning.
|
||||
|
||||
### MAC address filtering
|
||||
|
||||
- Add `--filter-mac` argument to `provision.py` that writes a `filter_mac`
|
||||
key (6-byte blob) to NVS.
|
||||
- In `nvs_config.h`, add a `filter_mac[6]` field and `filter_mac_set` flag.
|
||||
- In `nvs_config.c`, read the `filter_mac` blob from NVS.
|
||||
- In the CSI callback (`wifi_csi_callback`), if `filter_mac_set` is true,
|
||||
compare the source MAC from the received frame against the configured MAC
|
||||
and drop non-matching frames.
|
||||
|
||||
### Provisioning flow
|
||||
|
||||
```
|
||||
python provision.py --port COM7 --channel 11
|
||||
python provision.py --port COM7 --filter-mac "AA:BB:CC:DD:EE:FF"
|
||||
python provision.py --port COM7 --channel 11 --filter-mac "AA:BB:CC:DD:EE:FF"
|
||||
```
|
||||
|
||||
## Consequences
|
||||
|
||||
- Users on managed networks can force the CSI channel to match their AP
|
||||
- Multi-AP environments can filter CSI to a single source
|
||||
- Auto-channel detection eliminates the most common misconfiguration
|
||||
- Backward compatible: existing provisioned nodes without these keys behave
|
||||
as before (use Kconfig default channel, accept all MACs)
|
||||
@@ -12,7 +12,6 @@
|
||||
*/
|
||||
|
||||
#include "csi_collector.h"
|
||||
#include "nvs_config.h"
|
||||
#include "stream_sender.h"
|
||||
#include "edge_processing.h"
|
||||
|
||||
@@ -22,9 +21,6 @@
|
||||
#include "esp_timer.h"
|
||||
#include "sdkconfig.h"
|
||||
|
||||
/* ADR-060: Access the global NVS config for MAC filter and channel override. */
|
||||
extern nvs_config_t g_nvs_config;
|
||||
|
||||
/* ADR-057: Build-time guard — fail early if CSI is not enabled in sdkconfig.
|
||||
* Without this, the firmware compiles but crashes at runtime with:
|
||||
* "E (xxxx) wifi:CSI not enabled in menuconfig!"
|
||||
@@ -155,14 +151,6 @@ size_t csi_serialize_frame(const wifi_csi_info_t *info, uint8_t *buf, size_t buf
|
||||
static void wifi_csi_callback(void *ctx, wifi_csi_info_t *info)
|
||||
{
|
||||
(void)ctx;
|
||||
|
||||
/* ADR-060: MAC address filtering — drop frames from non-matching sources. */
|
||||
if (g_nvs_config.filter_mac_set) {
|
||||
if (memcmp(info->mac, g_nvs_config.filter_mac, 6) != 0) {
|
||||
return; /* Source MAC doesn't match filter — skip frame. */
|
||||
}
|
||||
}
|
||||
|
||||
s_cb_count++;
|
||||
|
||||
if (s_cb_count <= 3 || (s_cb_count % 100) == 0) {
|
||||
@@ -215,29 +203,6 @@ static void wifi_promiscuous_cb(void *buf, wifi_promiscuous_pkt_type_t type)
|
||||
|
||||
void csi_collector_init(void)
|
||||
{
|
||||
/* ADR-060: Determine the CSI channel.
|
||||
* Priority: 1) NVS override (--channel), 2) connected AP channel, 3) Kconfig default. */
|
||||
uint8_t csi_channel = (uint8_t)CONFIG_CSI_WIFI_CHANNEL;
|
||||
|
||||
if (g_nvs_config.csi_channel > 0) {
|
||||
/* Explicit NVS override via provision.py --channel */
|
||||
csi_channel = g_nvs_config.csi_channel;
|
||||
ESP_LOGI(TAG, "Using NVS channel override: %u", (unsigned)csi_channel);
|
||||
} else {
|
||||
/* Auto-detect from connected AP */
|
||||
wifi_ap_record_t ap_info;
|
||||
if (esp_wifi_sta_get_ap_info(&ap_info) == ESP_OK && ap_info.primary > 0) {
|
||||
csi_channel = ap_info.primary;
|
||||
ESP_LOGI(TAG, "Auto-detected AP channel: %u", (unsigned)csi_channel);
|
||||
} else {
|
||||
ESP_LOGW(TAG, "Could not detect AP channel, using Kconfig default: %u",
|
||||
(unsigned)csi_channel);
|
||||
}
|
||||
}
|
||||
|
||||
/* Update the hop table's first channel to match. */
|
||||
s_hop_channels[0] = csi_channel;
|
||||
|
||||
/* Enable promiscuous mode — required for reliable CSI callbacks.
|
||||
* Without this, CSI only fires on frames destined to this station,
|
||||
* which may be very infrequent on a quiet network. */
|
||||
@@ -265,15 +230,8 @@ void csi_collector_init(void)
|
||||
ESP_ERROR_CHECK(esp_wifi_set_csi_rx_cb(wifi_csi_callback, NULL));
|
||||
ESP_ERROR_CHECK(esp_wifi_set_csi(true));
|
||||
|
||||
if (g_nvs_config.filter_mac_set) {
|
||||
ESP_LOGI(TAG, "MAC filter active: %02x:%02x:%02x:%02x:%02x:%02x",
|
||||
g_nvs_config.filter_mac[0], g_nvs_config.filter_mac[1],
|
||||
g_nvs_config.filter_mac[2], g_nvs_config.filter_mac[3],
|
||||
g_nvs_config.filter_mac[4], g_nvs_config.filter_mac[5]);
|
||||
}
|
||||
|
||||
ESP_LOGI(TAG, "CSI collection initialized (node_id=%d, channel=%u)",
|
||||
CONFIG_CSI_NODE_ID, (unsigned)csi_channel);
|
||||
ESP_LOGI(TAG, "CSI collection initialized (node_id=%d, channel=%d)",
|
||||
CONFIG_CSI_NODE_ID, CONFIG_CSI_WIFI_CHANNEL);
|
||||
}
|
||||
|
||||
/* ---- ADR-029: Channel hopping ---- */
|
||||
|
||||
@@ -91,11 +91,6 @@ void nvs_config_load(nvs_config_t *cfg)
|
||||
cfg->wasm_verify = 0; /* Kconfig disabled signature verification. */
|
||||
#endif
|
||||
|
||||
/* ADR-060: Channel override and MAC filter defaults. */
|
||||
cfg->csi_channel = 0; /* 0 = auto-detect from connected AP. */
|
||||
cfg->filter_mac_set = 0;
|
||||
memset(cfg->filter_mac, 0, 6);
|
||||
|
||||
/* Try to override from NVS */
|
||||
nvs_handle_t handle;
|
||||
esp_err_t err = nvs_open("csi_cfg", NVS_READONLY, &handle);
|
||||
@@ -282,26 +277,6 @@ void nvs_config_load(nvs_config_t *cfg)
|
||||
ESP_LOGW(TAG, "wasm_verify=1 but no wasm_pubkey in NVS — uploads will be rejected");
|
||||
}
|
||||
|
||||
/* ADR-060: CSI channel override. */
|
||||
uint8_t csi_ch_val;
|
||||
if (nvs_get_u8(handle, "csi_channel", &csi_ch_val) == ESP_OK) {
|
||||
if ((csi_ch_val >= 1 && csi_ch_val <= 14) || (csi_ch_val >= 36 && csi_ch_val <= 177)) {
|
||||
cfg->csi_channel = csi_ch_val;
|
||||
ESP_LOGI(TAG, "NVS override: csi_channel=%u", (unsigned)cfg->csi_channel);
|
||||
} else {
|
||||
ESP_LOGW(TAG, "NVS csi_channel=%u invalid, ignored", (unsigned)csi_ch_val);
|
||||
}
|
||||
}
|
||||
|
||||
/* ADR-060: MAC address filter (6-byte blob). */
|
||||
size_t mac_len = 6;
|
||||
if (nvs_get_blob(handle, "filter_mac", cfg->filter_mac, &mac_len) == ESP_OK && mac_len == 6) {
|
||||
cfg->filter_mac_set = 1;
|
||||
ESP_LOGI(TAG, "NVS override: filter_mac=%02x:%02x:%02x:%02x:%02x:%02x",
|
||||
cfg->filter_mac[0], cfg->filter_mac[1], cfg->filter_mac[2],
|
||||
cfg->filter_mac[3], cfg->filter_mac[4], cfg->filter_mac[5]);
|
||||
}
|
||||
|
||||
/* Validate tdm_slot_index < tdm_node_count */
|
||||
if (cfg->tdm_slot_index >= cfg->tdm_node_count) {
|
||||
ESP_LOGW(TAG, "tdm_slot_index=%u >= tdm_node_count=%u, clamping to 0",
|
||||
|
||||
@@ -50,11 +50,6 @@ typedef struct {
|
||||
uint8_t wasm_verify; /**< Require Ed25519 signature for uploads. */
|
||||
uint8_t wasm_pubkey[32]; /**< Ed25519 public key for WASM signature. */
|
||||
uint8_t wasm_pubkey_valid; /**< 1 if pubkey was loaded from NVS. */
|
||||
|
||||
/* ADR-060: Channel override and MAC address filtering */
|
||||
uint8_t csi_channel; /**< Explicit CSI channel override (0 = auto-detect). */
|
||||
uint8_t filter_mac[6]; /**< MAC address to filter CSI frames. */
|
||||
uint8_t filter_mac_set; /**< 1 if filter_mac was loaded from NVS. */
|
||||
} nvs_config_t;
|
||||
|
||||
/**
|
||||
|
||||
@@ -64,13 +64,6 @@ def build_nvs_csv(args):
|
||||
writer.writerow(["vital_int", "data", "u16", str(args.vital_int)])
|
||||
if args.subk_count is not None:
|
||||
writer.writerow(["subk_count", "data", "u8", str(args.subk_count)])
|
||||
# ADR-060: Channel override and MAC filter
|
||||
if args.channel is not None:
|
||||
writer.writerow(["csi_channel", "data", "u8", str(args.channel)])
|
||||
if args.filter_mac is not None:
|
||||
mac_bytes = bytes(int(b, 16) for b in args.filter_mac.split(":"))
|
||||
# NVS blob: write as hex-encoded string for CSV compatibility
|
||||
writer.writerow(["filter_mac", "data", "hex2bin", mac_bytes.hex()])
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
@@ -172,10 +165,6 @@ def main():
|
||||
parser.add_argument("--vital-win", type=int, help="Phase history window in frames (default: 300)")
|
||||
parser.add_argument("--vital-int", type=int, help="Vitals packet interval in ms (default: 1000)")
|
||||
parser.add_argument("--subk-count", type=int, help="Top-K subcarrier count (default: 32)")
|
||||
# ADR-060: Channel override and MAC filter
|
||||
parser.add_argument("--channel", type=int, help="CSI channel (1-14 for 2.4GHz, 36-177 for 5GHz). "
|
||||
"Overrides auto-detection from connected AP.")
|
||||
parser.add_argument("--filter-mac", type=str, help="MAC address to filter CSI frames (AA:BB:CC:DD:EE:FF)")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Generate NVS binary but don't flash")
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -187,7 +176,6 @@ def main():
|
||||
args.edge_tier is not None, args.pres_thresh is not None,
|
||||
args.fall_thresh is not None, args.vital_win is not None,
|
||||
args.vital_int is not None, args.subk_count is not None,
|
||||
args.channel is not None, args.filter_mac is not None,
|
||||
])
|
||||
if not has_value:
|
||||
parser.error("At least one config value must be specified")
|
||||
@@ -198,22 +186,6 @@ def main():
|
||||
if args.tdm_slot is not None and args.tdm_slot >= args.tdm_total:
|
||||
parser.error(f"--tdm-slot ({args.tdm_slot}) must be less than --tdm-total ({args.tdm_total})")
|
||||
|
||||
# ADR-060: Validate channel and MAC filter
|
||||
if args.channel is not None:
|
||||
if not ((1 <= args.channel <= 14) or (36 <= args.channel <= 177)):
|
||||
parser.error(f"--channel must be 1-14 (2.4GHz) or 36-177 (5GHz), got {args.channel}")
|
||||
if args.filter_mac is not None:
|
||||
parts = args.filter_mac.split(":")
|
||||
if len(parts) != 6:
|
||||
parser.error(f"--filter-mac must be in AA:BB:CC:DD:EE:FF format, got '{args.filter_mac}'")
|
||||
try:
|
||||
for p in parts:
|
||||
val = int(p, 16)
|
||||
if val < 0 or val > 255:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
parser.error(f"--filter-mac contains invalid hex bytes: '{args.filter_mac}'")
|
||||
|
||||
print("Building NVS configuration:")
|
||||
if args.ssid:
|
||||
print(f" WiFi SSID: {args.ssid}")
|
||||
@@ -240,10 +212,6 @@ def main():
|
||||
print(f" Vital Interval:{args.vital_int} ms")
|
||||
if args.subk_count is not None:
|
||||
print(f" Top-K Subcarr: {args.subk_count}")
|
||||
if args.channel is not None:
|
||||
print(f" CSI Channel: {args.channel}")
|
||||
if args.filter_mac is not None:
|
||||
print(f" Filter MAC: {args.filter_mac}")
|
||||
|
||||
csv_content = build_nvs_csv(args)
|
||||
|
||||
|
||||
+10
-51
@@ -3,15 +3,15 @@
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>RuView — Dual-Modal Pose Estimation</title>
|
||||
<link rel="stylesheet" href="pose-fusion/css/style.css?v=13">
|
||||
<title>WiFi-DensePose — Dual-Modal Pose Estimation</title>
|
||||
<link rel="stylesheet" href="pose-fusion/css/style.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<!-- Header -->
|
||||
<header class="header">
|
||||
<div class="header-left">
|
||||
<div class="logo"><span class="pi">π</span> RuView</div>
|
||||
<div class="logo"><span class="pi">π</span> DensePose</div>
|
||||
<div class="header-title">Dual-Modal Pose Estimation — Live Video + WiFi CSI Fusion</div>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
@@ -40,7 +40,6 @@
|
||||
<div class="video-overlay-label" id="mode-label">DUAL FUSION</div>
|
||||
|
||||
<div id="camera-prompt" class="camera-prompt">
|
||||
<div class="camera-prompt-label" id="prompt-mode-label">DUAL FUSION</div>
|
||||
<p>Enable your webcam for live video pose estimation.<br>
|
||||
Or switch to <strong>CSI Only</strong> mode for WiFi-based sensing.</p>
|
||||
<button id="start-camera-btn">Enable Camera</button>
|
||||
@@ -79,24 +78,7 @@
|
||||
<div class="panel">
|
||||
<div class="panel-title">◆ CSI Amplitude Heatmap</div>
|
||||
<div class="csi-canvas-wrapper">
|
||||
<canvas id="csi-canvas" width="320" height="100"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- RSSI Signal Strength -->
|
||||
<div class="panel">
|
||||
<div class="panel-title">◆ RSSI Signal Strength</div>
|
||||
<div class="rssi-row">
|
||||
<div class="rssi-gauge">
|
||||
<div class="rssi-bar-track">
|
||||
<div class="rssi-bar-fill" id="rssi-bar" style="width:0%"></div>
|
||||
</div>
|
||||
<div class="rssi-values">
|
||||
<span class="rssi-dbm" id="rssi-value">-- dBm</span>
|
||||
<span class="rssi-quality" id="rssi-quality">--</span>
|
||||
</div>
|
||||
</div>
|
||||
<canvas id="rssi-sparkline" width="160" height="32"></canvas>
|
||||
<canvas id="csi-canvas" width="320" height="120"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -104,30 +86,7 @@
|
||||
<div class="panel">
|
||||
<div class="panel-title">◆ Embedding Space (2D Projection)</div>
|
||||
<div class="embedding-canvas-wrapper">
|
||||
<canvas id="embedding-canvas" width="320" height="100"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- RuVector Attention Pipeline -->
|
||||
<div class="panel">
|
||||
<div class="panel-title">◆ RuVector WASM Attention Pipeline</div>
|
||||
<div class="rv-pipeline">
|
||||
<div class="rv-stage" id="rv-flash">Flash</div>
|
||||
<div class="rv-arrow">→</div>
|
||||
<div class="rv-stage" id="rv-mha">MHA</div>
|
||||
<div class="rv-arrow">→</div>
|
||||
<div class="rv-stage" id="rv-hyp">Hyper</div>
|
||||
<div class="rv-arrow">→</div>
|
||||
<div class="rv-stage" id="rv-lin">Linear</div>
|
||||
<div class="rv-arrow">→</div>
|
||||
<div class="rv-stage" id="rv-moe">MoE</div>
|
||||
<div class="rv-arrow">→</div>
|
||||
<div class="rv-stage" id="rv-lg">L+G</div>
|
||||
</div>
|
||||
<div class="rv-stats">
|
||||
<span>Energy: <span id="rv-energy" style="color:var(--green-glow)">--</span></span>
|
||||
<span>Refinement: <span id="rv-refine" style="color:var(--cyan)">--</span></span>
|
||||
<span>Pose Impact: <span id="rv-impact" style="color:var(--amber)">--</span></span>
|
||||
<canvas id="embedding-canvas" width="320" height="140"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -184,18 +143,18 @@
|
||||
<!-- Bottom Bar -->
|
||||
<div class="bottom-bar">
|
||||
<div>
|
||||
RuView · Dual-Modal Pose Estimation ·
|
||||
Architecture: Conv2D → RuVector 6-Stage Attention (Flash+MHA+Hyperbolic+Linear+MoE+L/G) → Fusion → 26-Keypoint Pose
|
||||
WiFi-DensePose · Dual-Modal Pose Estimation ·
|
||||
Architecture: MobileNet-V3 × 2 → Attention Fusion → 17-Keypoint COCO
|
||||
</div>
|
||||
<div>
|
||||
<a href="https://github.com/ruvnet/RuView">GitHub</a> ·
|
||||
CNN: <span id="cnn-backend">ruvector-cnn (loading…)</span> ·
|
||||
<a href="https://github.com/ruvnet/wifi-densepose">GitHub</a> ·
|
||||
CNN: ruvector-cnn (JS fallback) ·
|
||||
<a href="observatory.html">Observatory</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div><!-- /main-grid -->
|
||||
|
||||
<script type="module" src="pose-fusion/js/main.js?v=13"></script>
|
||||
<script type="module" src="pose-fusion/js/main.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/* RuView — Dual-Modal Pose Fusion Demo
|
||||
/* WiFi-DensePose — Dual-Modal Pose Fusion Demo
|
||||
Dark theme matching Observatory */
|
||||
|
||||
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&family=JetBrains+Mono:wght@400;600&display=swap');
|
||||
@@ -136,14 +136,6 @@ body {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.video-panel {
|
||||
grid-row: 1;
|
||||
}
|
||||
|
||||
.side-panels {
|
||||
grid-row: 1;
|
||||
}
|
||||
|
||||
/* === Video Panel === */
|
||||
.video-panel {
|
||||
position: relative;
|
||||
@@ -184,20 +176,14 @@ body {
|
||||
|
||||
.camera-prompt {
|
||||
position: absolute;
|
||||
top: 0; left: 0; right: 0; bottom: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
top: 50%; left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
text-align: center;
|
||||
color: var(--text-secondary);
|
||||
padding: 24px;
|
||||
z-index: 6;
|
||||
background: radial-gradient(ellipse at center, rgba(0,210,120,0.08) 0%, rgba(8,12,20,0.95) 70%);
|
||||
}
|
||||
|
||||
.camera-prompt button {
|
||||
margin-top: 16px;
|
||||
margin-top: 12px;
|
||||
padding: 10px 24px;
|
||||
background: var(--green-glow);
|
||||
color: #000;
|
||||
@@ -212,34 +198,20 @@ body {
|
||||
|
||||
.camera-prompt button:hover { background: var(--green-bright); }
|
||||
|
||||
.camera-prompt-label {
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
letter-spacing: 2px;
|
||||
color: var(--green-glow);
|
||||
text-shadow: 0 0 12px rgba(0,216,120,0.4);
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
/* === Side Panels === */
|
||||
.side-panels {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
gap: 12px;
|
||||
overflow-y: auto;
|
||||
min-height: 0;
|
||||
max-height: 100%;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: var(--green-dim) transparent;
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: var(--bg-panel);
|
||||
border: 1px solid var(--bg-panel-border);
|
||||
border-radius: var(--radius);
|
||||
padding: 10px 14px;
|
||||
flex-shrink: 0;
|
||||
padding: 14px;
|
||||
}
|
||||
|
||||
.panel-title {
|
||||
@@ -324,44 +296,6 @@ body {
|
||||
display: block;
|
||||
}
|
||||
|
||||
/* === RuVector Pipeline === */
|
||||
.rv-pipeline {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 2px;
|
||||
margin-bottom: 8px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.rv-stage {
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-size: 10px;
|
||||
padding: 3px 6px;
|
||||
border-radius: 3px;
|
||||
background: rgba(0,210,120,0.12);
|
||||
border: 1px solid rgba(0,210,120,0.3);
|
||||
color: var(--green-glow);
|
||||
transition: all 0.3s;
|
||||
}
|
||||
|
||||
.rv-stage.active {
|
||||
background: rgba(0,210,120,0.25);
|
||||
box-shadow: 0 0 6px rgba(0,210,120,0.3);
|
||||
}
|
||||
|
||||
.rv-arrow {
|
||||
font-size: 10px;
|
||||
color: var(--text-label);
|
||||
}
|
||||
|
||||
.rv-stats {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-size: 10px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
/* === Latency Panel === */
|
||||
.latency-grid {
|
||||
display: grid;
|
||||
@@ -453,71 +387,6 @@ body {
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* === RSSI Signal Strength === */
|
||||
.rssi-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.rssi-gauge { flex: 1; }
|
||||
|
||||
.rssi-bar-track {
|
||||
height: 8px;
|
||||
background: rgba(255,255,255,0.06);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.rssi-bar-fill {
|
||||
height: 100%;
|
||||
border-radius: 4px;
|
||||
background: linear-gradient(90deg, var(--red-alert), var(--amber), var(--green-glow));
|
||||
transition: width 0.4s ease;
|
||||
position: relative;
|
||||
box-shadow: 0 0 6px rgba(0,210,120,0.3);
|
||||
}
|
||||
|
||||
.rssi-bar-fill::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0; left: 0; right: 0; bottom: 0;
|
||||
background: linear-gradient(90deg, transparent 0%, rgba(255,255,255,0.2) 50%, transparent 100%);
|
||||
animation: rssi-shimmer 2s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes rssi-shimmer {
|
||||
0% { transform: translateX(-100%); }
|
||||
100% { transform: translateX(100%); }
|
||||
}
|
||||
|
||||
.rssi-values {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.rssi-dbm {
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: var(--green-glow);
|
||||
}
|
||||
|
||||
.rssi-quality {
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-size: 11px;
|
||||
color: var(--text-secondary);
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
#rssi-sparkline {
|
||||
flex-shrink: 0;
|
||||
border-radius: 4px;
|
||||
background: rgba(0,0,0,0.3);
|
||||
}
|
||||
|
||||
/* === Skeleton colors === */
|
||||
.skeleton-joint { fill: var(--green-glow); }
|
||||
.skeleton-limb { stroke: var(--green-bright); }
|
||||
|
||||
@@ -37,18 +37,12 @@ export class CanvasRenderer {
|
||||
const limbColor = color === 'amber' ? this.colors.csiLimb : this.colors.limb;
|
||||
const glowColor = color === 'amber' ? 'rgba(255,176,32,0.4)' : this.colors.jointGlow;
|
||||
|
||||
// Extended keypoint styling
|
||||
const fingerColor = '#ff6ef0'; // Magenta for finger tips
|
||||
const fingerGlow = 'rgba(255,110,240,0.4)';
|
||||
const fingerLimb = 'rgba(255,110,240,0.5)';
|
||||
const toeColor = '#6ef0ff'; // Cyan for toes
|
||||
const neckColor = '#ffffff'; // White for neck
|
||||
|
||||
ctx.clearRect(0, 0, width, height);
|
||||
|
||||
if (!keypoints || keypoints.length === 0) return;
|
||||
|
||||
// Draw limbs first (behind joints)
|
||||
ctx.lineWidth = 3;
|
||||
ctx.lineCap = 'round';
|
||||
|
||||
for (const [i, j] of SKELETON_CONNECTIONS) {
|
||||
@@ -60,22 +54,18 @@ export class CanvasRenderer {
|
||||
const bx = kpB.x * width, by = kpB.y * height;
|
||||
const avgConf = (kpA.confidence + kpB.confidence) / 2;
|
||||
|
||||
// Is this a hand/finger connection? (indices 17-22)
|
||||
const isFingerLink = i >= 17 && i <= 22 || j >= 17 && j <= 22;
|
||||
const isToeLink = i >= 23 && i <= 24 || j >= 23 && j <= 24;
|
||||
|
||||
// Glow
|
||||
ctx.strokeStyle = isFingerLink ? fingerLimb : this.colors.limbGlow;
|
||||
ctx.lineWidth = isFingerLink ? 4 : 8;
|
||||
ctx.globalAlpha = avgConf * (isFingerLink ? 0.3 : 0.4);
|
||||
ctx.strokeStyle = this.colors.limbGlow;
|
||||
ctx.lineWidth = 8;
|
||||
ctx.globalAlpha = avgConf * 0.4;
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(ax, ay);
|
||||
ctx.lineTo(bx, by);
|
||||
ctx.stroke();
|
||||
|
||||
// Main line
|
||||
ctx.strokeStyle = isFingerLink ? fingerColor : isToeLink ? toeColor : limbColor;
|
||||
ctx.lineWidth = isFingerLink || isToeLink ? 1.5 : 2.5;
|
||||
ctx.strokeStyle = limbColor;
|
||||
ctx.lineWidth = 2.5;
|
||||
ctx.globalAlpha = avgConf;
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(ax, ay);
|
||||
@@ -85,52 +75,43 @@ export class CanvasRenderer {
|
||||
|
||||
// Draw joints
|
||||
ctx.globalAlpha = 1;
|
||||
for (let idx = 0; idx < keypoints.length; idx++) {
|
||||
const kp = keypoints[idx];
|
||||
for (const kp of keypoints) {
|
||||
if (!kp || kp.confidence < minConf) continue;
|
||||
|
||||
const x = kp.x * width;
|
||||
const y = kp.y * height;
|
||||
const isFinger = idx >= 17 && idx <= 22;
|
||||
const isToe = idx >= 23 && idx <= 24;
|
||||
const isNeck = idx === 25;
|
||||
const r = isFinger ? 2 + kp.confidence * 2 : isToe ? 2 : 3 + kp.confidence * 3;
|
||||
const jColor = isFinger ? fingerColor : isToe ? toeColor : isNeck ? neckColor : jointColor;
|
||||
const gColor = isFinger ? fingerGlow : glowColor;
|
||||
const r = 3 + kp.confidence * 3;
|
||||
|
||||
// Glow
|
||||
ctx.beginPath();
|
||||
ctx.arc(x, y, r + (isFinger ? 3 : 4), 0, Math.PI * 2);
|
||||
ctx.fillStyle = gColor;
|
||||
ctx.globalAlpha = kp.confidence * (isFinger ? 0.5 : 0.6);
|
||||
ctx.arc(x, y, r + 4, 0, Math.PI * 2);
|
||||
ctx.fillStyle = glowColor;
|
||||
ctx.globalAlpha = kp.confidence * 0.6;
|
||||
ctx.fill();
|
||||
|
||||
// Joint dot
|
||||
ctx.beginPath();
|
||||
ctx.arc(x, y, r, 0, Math.PI * 2);
|
||||
ctx.fillStyle = jColor;
|
||||
ctx.fillStyle = jointColor;
|
||||
ctx.globalAlpha = kp.confidence;
|
||||
ctx.fill();
|
||||
|
||||
// White center (body joints only)
|
||||
if (!isFinger && !isToe) {
|
||||
ctx.beginPath();
|
||||
ctx.arc(x, y, r * 0.4, 0, Math.PI * 2);
|
||||
ctx.fillStyle = '#fff';
|
||||
ctx.globalAlpha = kp.confidence * 0.8;
|
||||
ctx.fill();
|
||||
}
|
||||
// White center
|
||||
ctx.beginPath();
|
||||
ctx.arc(x, y, r * 0.4, 0, Math.PI * 2);
|
||||
ctx.fillStyle = '#fff';
|
||||
ctx.globalAlpha = kp.confidence * 0.8;
|
||||
ctx.fill();
|
||||
}
|
||||
|
||||
ctx.globalAlpha = 1;
|
||||
|
||||
// Confidence label + keypoint count
|
||||
// Confidence label
|
||||
if (opts.label) {
|
||||
const visCount = keypoints.filter(kp => kp && kp.confidence >= minConf).length;
|
||||
ctx.font = '11px "JetBrains Mono", monospace';
|
||||
ctx.fillStyle = jointColor;
|
||||
ctx.globalAlpha = 0.8;
|
||||
ctx.fillText(`${opts.label} · ${visCount} joints`, 8, height - 8);
|
||||
ctx.fillText(opts.label, 8, height - 8);
|
||||
ctx.globalAlpha = 1;
|
||||
}
|
||||
}
|
||||
@@ -204,63 +185,22 @@ export class CanvasRenderer {
|
||||
ctx.beginPath(); ctx.moveTo(w / 2, 0); ctx.lineTo(w / 2, h); ctx.stroke();
|
||||
ctx.beginPath(); ctx.moveTo(0, h / 2); ctx.lineTo(w, h / 2); ctx.stroke();
|
||||
|
||||
// Auto-scale: find max extent across all point sets
|
||||
let maxExtent = 0.01;
|
||||
for (const pts of [points.video, points.csi, points.fused]) {
|
||||
if (!pts) continue;
|
||||
for (const p of pts) {
|
||||
if (!p) continue;
|
||||
maxExtent = Math.max(maxExtent, Math.abs(p[0]), Math.abs(p[1]));
|
||||
}
|
||||
}
|
||||
const scale = 0.42 / maxExtent; // Fill ~84% of half-width
|
||||
|
||||
const drawPoints = (pts, color, size) => {
|
||||
if (!pts || pts.length === 0) return;
|
||||
const len = pts.length;
|
||||
|
||||
// Draw trail line connecting recent points
|
||||
if (len >= 2) {
|
||||
ctx.beginPath();
|
||||
let started = false;
|
||||
for (let i = 0; i < len; i++) {
|
||||
const p = pts[i];
|
||||
if (!p) continue;
|
||||
const px = w / 2 + p[0] * scale * w;
|
||||
const py = h / 2 + p[1] * scale * h;
|
||||
if (px < -10 || px > w + 10 || py < -10 || py > h + 10) continue;
|
||||
if (!started) { ctx.moveTo(px, py); started = true; }
|
||||
else ctx.lineTo(px, py);
|
||||
}
|
||||
ctx.strokeStyle = color;
|
||||
ctx.globalAlpha = 0.2;
|
||||
ctx.lineWidth = 1;
|
||||
ctx.stroke();
|
||||
}
|
||||
|
||||
// Draw dots with glow on newest
|
||||
for (let i = 0; i < len; i++) {
|
||||
const p = pts[i];
|
||||
if (!p) continue;
|
||||
const age = 1 - (i / len) * 0.7;
|
||||
const px = w / 2 + p[0] * scale * w;
|
||||
const py = h / 2 + p[1] * scale * h;
|
||||
const age = 1 - (i / len) * 0.7; // Fade older points
|
||||
const px = w / 2 + p[0] * w * 0.35;
|
||||
const py = h / 2 + p[1] * h * 0.35;
|
||||
|
||||
if (px < -10 || px > w + 10 || py < -10 || py > h + 10) continue;
|
||||
|
||||
// Glow on newest point
|
||||
if (i === len - 1) {
|
||||
ctx.beginPath();
|
||||
ctx.arc(px, py, size + 4, 0, Math.PI * 2);
|
||||
ctx.fillStyle = color;
|
||||
ctx.globalAlpha = 0.3;
|
||||
ctx.fill();
|
||||
}
|
||||
if (px < 0 || px > w || py < 0 || py > h) continue;
|
||||
|
||||
ctx.beginPath();
|
||||
ctx.arc(px, py, i === len - 1 ? size + 1 : size, 0, Math.PI * 2);
|
||||
ctx.arc(px, py, size, 0, Math.PI * 2);
|
||||
ctx.fillStyle = color;
|
||||
ctx.globalAlpha = age * 0.8;
|
||||
ctx.globalAlpha = age * 0.7;
|
||||
ctx.fill();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
/**
|
||||
* CNN Embedder — RuVector Attention-powered feature extractor.
|
||||
* CNN Embedder — Lightweight MobileNet-V3-style feature extractor.
|
||||
*
|
||||
* Uses the real ruvector-attention-wasm WASM module for Multi-Head Attention
|
||||
* and Flash Attention on CSI/video data. Falls back to a JS Conv2D pipeline
|
||||
* when WASM is not available.
|
||||
* Architecture mirrors ruvector-cnn: Conv2D → BatchNorm → ReLU → Pool → Project → L2 Normalize
|
||||
* Uses pre-seeded random weights (deterministic). When ruvector-cnn-wasm is available,
|
||||
* transparently delegates to the WASM implementation.
|
||||
*
|
||||
* Pipeline: Conv2D → BatchNorm → ReLU → Pool → RuVector Attention → Project → L2 Normalize
|
||||
* Two instances are created: one for video frames, one for CSI pseudo-images.
|
||||
*/
|
||||
|
||||
@@ -32,14 +31,6 @@ export class CnnEmbedder {
|
||||
this.embeddingDim = opts.embeddingDim || 128;
|
||||
this.normalize = opts.normalize !== false;
|
||||
this.wasmEmbedder = null;
|
||||
this.rvAttention = null; // RuVector Multi-Head Attention (WASM)
|
||||
this.rvFlash = null; // RuVector Flash Attention (WASM)
|
||||
this.rvHyperbolic = null; // RuVector Hyperbolic Attention (hierarchical body)
|
||||
this.rvMoE = null; // RuVector Mixture-of-Experts (body-region routing)
|
||||
this.rvLinear = null; // RuVector Linear Attention (O(n) fast hand refinement)
|
||||
this.rvLocalGlobal = null; // RuVector Local-Global Attention (detail + context)
|
||||
this.rvModule = null; // RuVector WASM module reference
|
||||
this.useRuVector = false;
|
||||
|
||||
// Initialize weights with deterministic PRNG
|
||||
const rng = mulberry32(opts.seed || 42);
|
||||
@@ -57,50 +48,18 @@ export class CnnEmbedder {
|
||||
this.bnMean = new Float32Array(16).fill(0.0);
|
||||
this.bnVar = new Float32Array(16).fill(1.0);
|
||||
|
||||
// Projection: 16 → embeddingDim (used when RuVector not available)
|
||||
// Projection: 16 → embeddingDim
|
||||
this.projWeights = new Float32Array(16 * this.embeddingDim);
|
||||
for (let i = 0; i < this.projWeights.length; i++) {
|
||||
this.projWeights[i] = randRange(-0.1, 0.1);
|
||||
}
|
||||
|
||||
// Attention projection: attention_dim → embeddingDim
|
||||
this.attnProjWeights = new Float32Array(16 * this.embeddingDim);
|
||||
for (let i = 0; i < this.attnProjWeights.length; i++) {
|
||||
this.attnProjWeights[i] = randRange(-0.08, 0.08);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to load RuVector attention WASM, then fall back to ruvector-cnn-wasm
|
||||
* Try to load WASM embedder from ruvector-cnn-wasm package
|
||||
* @param {string} wasmPath - Path to the WASM package directory
|
||||
*/
|
||||
async tryLoadWasm(wasmPath) {
|
||||
// First try: RuVector Attention WASM (the real thing — browser ESM build)
|
||||
try {
|
||||
const attnBase = new URL('../pkg/ruvector-attention/ruvector_attention_browser.js', import.meta.url).href;
|
||||
const mod = await import(attnBase);
|
||||
await mod.default(); // async WASM init via fetch
|
||||
mod.init();
|
||||
|
||||
// Create all 6 attention mechanisms
|
||||
this.rvAttention = new mod.WasmMultiHeadAttention(16, 4);
|
||||
this.rvFlash = new mod.WasmFlashAttention(16, 8);
|
||||
this.rvHyperbolic = new mod.WasmHyperbolicAttention(16, -1.0);
|
||||
this.rvMoE = new mod.WasmMoEAttention(16, 3, 2);
|
||||
this.rvLinear = new mod.WasmLinearAttention(16, 16);
|
||||
this.rvLocalGlobal = new mod.WasmLocalGlobalAttention(16, 4, 2);
|
||||
this.rvModule = mod;
|
||||
this.useRuVector = true;
|
||||
|
||||
// Log available mechanisms
|
||||
const mechs = mod.available_mechanisms();
|
||||
console.log(`[CNN] RuVector WASM v${mod.version()} — all 6 attention mechanisms active`, mechs);
|
||||
return true;
|
||||
} catch (e) {
|
||||
console.log('[CNN] RuVector Attention WASM not available:', e.message);
|
||||
}
|
||||
|
||||
// Second try: ruvector-cnn-wasm (legacy path)
|
||||
try {
|
||||
const mod = await import(`${wasmPath}/ruvector_cnn_wasm.js`);
|
||||
await mod.default();
|
||||
@@ -109,10 +68,10 @@ export class CnnEmbedder {
|
||||
config.embedding_dim = this.embeddingDim;
|
||||
config.normalize = this.normalize;
|
||||
this.wasmEmbedder = new mod.WasmCnnEmbedder(config);
|
||||
console.log('[CNN] WASM CNN embedder loaded successfully');
|
||||
console.log('[CNN] WASM embedder loaded successfully');
|
||||
return true;
|
||||
} catch (e) {
|
||||
console.log('[CNN] WASM CNN not available, using JS fallback:', e.message);
|
||||
console.log('[CNN] WASM not available, using JS fallback:', e.message);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -166,17 +125,10 @@ export class CnnEmbedder {
|
||||
if (convOut[i] < 0) convOut[i] = 0;
|
||||
}
|
||||
|
||||
// 6. Global average pooling → spatial tokens (each 16-dim)
|
||||
// 6. Global average pooling → 16-dim
|
||||
const outH = sz - 2, outW = sz - 2;
|
||||
const spatial = outH * outW;
|
||||
|
||||
// 7. RuVector Attention (if loaded) — apply attention over spatial tokens
|
||||
if (this.useRuVector && this.rvAttention) {
|
||||
return this._extractWithAttention(convOut, spatial, 16);
|
||||
}
|
||||
|
||||
// Fallback: simple global average pool + linear projection
|
||||
const pooled = new Float32Array(16);
|
||||
const spatial = outH * outW;
|
||||
for (let i = 0; i < spatial; i++) {
|
||||
for (let c = 0; c < 16; c++) {
|
||||
pooled[c] += convOut[i * 16 + c];
|
||||
@@ -184,7 +136,7 @@ export class CnnEmbedder {
|
||||
}
|
||||
for (let c = 0; c < 16; c++) pooled[c] /= spatial;
|
||||
|
||||
// Linear projection → embeddingDim
|
||||
// 7. Linear projection → embeddingDim
|
||||
const emb = new Float32Array(this.embeddingDim);
|
||||
for (let o = 0; o < this.embeddingDim; o++) {
|
||||
let sum = 0;
|
||||
@@ -194,7 +146,7 @@ export class CnnEmbedder {
|
||||
emb[o] = sum;
|
||||
}
|
||||
|
||||
// L2 normalize
|
||||
// 8. L2 normalize
|
||||
if (this.normalize) {
|
||||
let norm = 0;
|
||||
for (let i = 0; i < emb.length; i++) norm += emb[i] * emb[i];
|
||||
@@ -207,149 +159,6 @@ export class CnnEmbedder {
|
||||
return emb;
|
||||
}
|
||||
|
||||
/**
|
||||
* Full 6-stage RuVector WASM attention pipeline:
|
||||
* 1. Flash Attention (efficient O(n) pre-screening of spatial tokens)
|
||||
* 2. Multi-Head Attention (global spatial reasoning)
|
||||
* 3. Hyperbolic Attention (hierarchical body-part structure, Poincaré ball)
|
||||
* 4. Linear Attention (O(n) refinement for fine detail — hands/extremities)
|
||||
* 5. MoE Attention (body-region specialized expert routing)
|
||||
* 6. Local-Global Attention (local detail + global context fusion)
|
||||
* → Weighted blend + batch_normalize + project + L2 normalize
|
||||
*/
|
||||
_extractWithAttention(convOut, numTokens, channels) {
|
||||
const mod = this.rvModule;
|
||||
|
||||
// Subsample spatial tokens for attention (max 64 for speed)
|
||||
const maxTokens = 64;
|
||||
const step = numTokens > maxTokens ? Math.floor(numTokens / maxTokens) : 1;
|
||||
const tokens = [];
|
||||
for (let i = 0; i < numTokens && tokens.length < maxTokens; i += step) {
|
||||
const token = new Float32Array(channels);
|
||||
for (let c = 0; c < channels; c++) {
|
||||
token[c] = convOut[i * channels + c];
|
||||
}
|
||||
tokens.push(token);
|
||||
}
|
||||
|
||||
const numQueries = Math.min(4, tokens.length);
|
||||
const queryStride = Math.floor(tokens.length / numQueries);
|
||||
|
||||
// === Stage 1: Flash Attention (efficient pre-screening) ===
|
||||
const flashOut = new Float32Array(channels);
|
||||
try {
|
||||
// Flash attention with block size 8 for efficient O(n) screening
|
||||
const result = this.rvFlash.compute(tokens[0], tokens, tokens);
|
||||
for (let c = 0; c < channels; c++) flashOut[c] = result[c];
|
||||
} catch (_) {
|
||||
flashOut.set(tokens[0]);
|
||||
}
|
||||
|
||||
// === Stage 2: Multi-Head Attention (global spatial reasoning) ===
|
||||
const mhaOut = new Float32Array(channels);
|
||||
for (let q = 0; q < numQueries; q++) {
|
||||
const queryToken = tokens[q * queryStride];
|
||||
try {
|
||||
const result = this.rvAttention.compute(queryToken, tokens, tokens);
|
||||
for (let c = 0; c < channels; c++) mhaOut[c] += result[c] / numQueries;
|
||||
} catch (_) {
|
||||
for (let c = 0; c < channels; c++) mhaOut[c] += queryToken[c] / numQueries;
|
||||
}
|
||||
}
|
||||
|
||||
// === Stage 3: Hyperbolic Attention (hierarchical body structure) ===
|
||||
const hyOut = new Float32Array(channels);
|
||||
try {
|
||||
const result = this.rvHyperbolic.compute(mhaOut, tokens, tokens);
|
||||
for (let c = 0; c < channels; c++) hyOut[c] = result[c];
|
||||
} catch (_) {
|
||||
hyOut.set(mhaOut);
|
||||
}
|
||||
|
||||
// === Stage 4: Linear Attention (O(n) fast refinement for extremities) ===
|
||||
const linOut = new Float32Array(channels);
|
||||
try {
|
||||
const result = this.rvLinear.compute(hyOut, tokens, tokens);
|
||||
for (let c = 0; c < channels; c++) linOut[c] = result[c];
|
||||
} catch (_) {
|
||||
linOut.set(hyOut);
|
||||
}
|
||||
|
||||
// === Stage 5: MoE Attention (body-region expert routing) ===
|
||||
const moeOut = new Float32Array(channels);
|
||||
try {
|
||||
const result = this.rvMoE.compute(linOut, tokens, tokens);
|
||||
for (let c = 0; c < channels; c++) moeOut[c] = result[c];
|
||||
} catch (_) {
|
||||
moeOut.set(linOut);
|
||||
}
|
||||
|
||||
// === Stage 6: Local-Global Attention (detail + context) ===
|
||||
const lgOut = new Float32Array(channels);
|
||||
try {
|
||||
const result = this.rvLocalGlobal.compute(moeOut, tokens, tokens);
|
||||
for (let c = 0; c < channels; c++) lgOut[c] = result[c];
|
||||
} catch (_) {
|
||||
lgOut.set(moeOut);
|
||||
}
|
||||
|
||||
// === Blend all 6 outputs ===
|
||||
// Use WASM softmax on log-energy scores for dynamic stage weighting
|
||||
const blended = new Float32Array(channels);
|
||||
const stages = [flashOut, mhaOut, hyOut, linOut, moeOut, lgOut];
|
||||
// Use log-energy to prevent exp() overflow in softmax
|
||||
const logEnergies = new Float32Array(6);
|
||||
for (let s = 0; s < 6; s++) {
|
||||
const e = this._energy(stages[s]);
|
||||
logEnergies[s] = e > 1e-10 ? Math.log(e) : -20;
|
||||
}
|
||||
try { mod.softmax(logEnergies); } catch (_) {
|
||||
let max = -Infinity;
|
||||
for (let i = 0; i < 6; i++) max = Math.max(max, logEnergies[i]);
|
||||
let sum = 0;
|
||||
for (let i = 0; i < 6; i++) { logEnergies[i] = Math.exp(logEnergies[i] - max); sum += logEnergies[i]; }
|
||||
for (let i = 0; i < 6; i++) logEnergies[i] /= sum;
|
||||
}
|
||||
for (let c = 0; c < channels; c++) {
|
||||
for (let s = 0; s < 6; s++) {
|
||||
blended[c] += logEnergies[s] * stages[s][c];
|
||||
}
|
||||
}
|
||||
|
||||
// Batch normalize only when we have enough diversity (skip for single vectors)
|
||||
// Single-vector batch norm collapses to zeros, killing embedding space
|
||||
let normed = blended;
|
||||
|
||||
// Project to embeddingDim
|
||||
const emb = new Float32Array(this.embeddingDim);
|
||||
for (let o = 0; o < this.embeddingDim; o++) {
|
||||
let sum = 0;
|
||||
for (let i = 0; i < channels; i++) {
|
||||
sum += normed[i] * this.attnProjWeights[i * this.embeddingDim + o];
|
||||
}
|
||||
emb[o] = sum;
|
||||
}
|
||||
|
||||
// L2 normalize using RuVector WASM
|
||||
if (this.normalize) {
|
||||
try { mod.normalize(emb); } catch (_) {
|
||||
let norm = 0;
|
||||
for (let i = 0; i < emb.length; i++) norm += emb[i] * emb[i];
|
||||
norm = Math.sqrt(norm);
|
||||
if (norm > 1e-8) for (let i = 0; i < emb.length; i++) emb[i] /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
return emb;
|
||||
}
|
||||
|
||||
/** Compute vector energy (L2 norm squared) for attention weighting */
|
||||
_energy(vec) {
|
||||
let e = 0;
|
||||
for (let i = 0; i < vec.length; i++) e += vec[i] * vec[i];
|
||||
return e;
|
||||
}
|
||||
|
||||
_conv2d3x3(input, H, W, Cin, Cout) {
|
||||
const outH = H - 2, outW = W - 2;
|
||||
const output = new Float32Array(outH * outW * Cout);
|
||||
@@ -401,33 +210,7 @@ export class CnnEmbedder {
|
||||
return output;
|
||||
}
|
||||
|
||||
/** Cosine similarity using WASM when available, JS fallback */
|
||||
cosineSim(a, b) {
|
||||
if (this.rvModule) {
|
||||
try { return this.rvModule.cosine_similarity(a, b); } catch (_) { /* fallback */ }
|
||||
}
|
||||
return CnnEmbedder.cosineSimilarity(a, b);
|
||||
}
|
||||
|
||||
/** L2 norm using WASM when available */
|
||||
l2Norm(vec) {
|
||||
if (this.rvModule) {
|
||||
try { return this.rvModule.l2_norm(vec); } catch (_) { /* fallback */ }
|
||||
}
|
||||
let norm = 0;
|
||||
for (let i = 0; i < vec.length; i++) norm += vec[i] * vec[i];
|
||||
return Math.sqrt(norm);
|
||||
}
|
||||
|
||||
/** Pairwise distance matrix using WASM (for skeleton validation) */
|
||||
pairwiseDistances(vectors) {
|
||||
if (this.rvModule) {
|
||||
try { return this.rvModule.pairwise_distances(vectors); } catch (_) { /* fallback */ }
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/** Static JS fallback for cosine similarity */
|
||||
/** Cosine similarity between two embeddings */
|
||||
static cosineSimilarity(a, b) {
|
||||
let dot = 0, normA = 0, normB = 0;
|
||||
for (let i = 0; i < a.length; i++) {
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
*/
|
||||
|
||||
export class CsiSimulator {
|
||||
static VERSION = 'v4-drift'; // Cache-bust verification
|
||||
|
||||
constructor(opts = {}) {
|
||||
this.subcarriers = opts.subcarriers || 52; // 802.11n HT20
|
||||
this.timeWindow = opts.timeWindow || 56; // frames in sliding window
|
||||
@@ -34,10 +32,6 @@ export class CsiSimulator {
|
||||
this._basePhase[i] = (i / this.subcarriers) * Math.PI * 2;
|
||||
}
|
||||
|
||||
// RSSI tracking
|
||||
this.rssiDbm = -70; // default mid-range
|
||||
this._rssiTarget = -70;
|
||||
|
||||
// Person influence (updated from video motion)
|
||||
this.personPresence = 0;
|
||||
this.personX = 0.5;
|
||||
@@ -79,9 +73,6 @@ export class CsiSimulator {
|
||||
* (simulating through-wall sensing capability).
|
||||
*/
|
||||
updatePersonState(presence, x, y, motion) {
|
||||
// Don't override real CSI sensing with synthetic video-derived state
|
||||
if (this.mode === 'live') return;
|
||||
|
||||
if (presence > 0.1) {
|
||||
// Person detected in video — update CSI state directly
|
||||
this.personPresence = presence;
|
||||
@@ -135,13 +126,6 @@ export class CsiSimulator {
|
||||
this.phaseBuffer.shift();
|
||||
}
|
||||
|
||||
// RSSI: smooth toward target (demo mode generates synthetic RSSI)
|
||||
if (this.mode === 'demo') {
|
||||
// Simulate RSSI based on person presence and slow drift
|
||||
this._rssiTarget = -55 - 25 * (1 - this.personPresence) + Math.sin(elapsed * 0.3) * 3;
|
||||
}
|
||||
this.rssiDbm += (this._rssiTarget - this.rssiDbm) * 0.1;
|
||||
|
||||
// SNR estimate
|
||||
let signalPower = 0, noisePower = 0;
|
||||
for (let i = 0; i < this.subcarriers; i++) {
|
||||
@@ -231,11 +215,6 @@ export class CsiSimulator {
|
||||
this._noiseState[i] = 0.95 * this._noiseState[i] + 0.05 * (rng() * 2 - 1) * 0.03;
|
||||
a += this._noiseState[i];
|
||||
|
||||
// Ambient temporal drift (multipath fading even in empty room)
|
||||
a += 0.06 * Math.sin(elapsed * 0.7 + i * 0.25)
|
||||
+ 0.04 * Math.sin(elapsed * 1.3 - i * 0.18)
|
||||
+ 0.03 * Math.cos(elapsed * 2.1 + i * 0.4);
|
||||
|
||||
// Person-induced CSI perturbation
|
||||
if (presence > 0.1) {
|
||||
// Subcarrier-dependent body reflection (Fresnel zone model)
|
||||
@@ -258,23 +237,6 @@ export class CsiSimulator {
|
||||
}
|
||||
|
||||
_handleLiveFrame(data) {
|
||||
// Handle JSON text frames from the sensing server
|
||||
if (typeof data === 'string') {
|
||||
try {
|
||||
const msg = JSON.parse(data);
|
||||
this._handleJsonFrame(msg);
|
||||
} catch (_) { /* ignore malformed JSON */ }
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle Blob data (convert to ArrayBuffer and re-process)
|
||||
if (data instanceof Blob) {
|
||||
data.arrayBuffer().then(ab => this._handleLiveFrame(ab)).catch(() => {});
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle binary ArrayBuffer frames (ADR-018 format)
|
||||
if (!(data instanceof ArrayBuffer)) return;
|
||||
const view = new DataView(data);
|
||||
// Check ADR-018 magic: 0xC5110001
|
||||
if (data.byteLength < 20) return;
|
||||
@@ -294,64 +256,6 @@ export class CsiSimulator {
|
||||
}
|
||||
}
|
||||
|
||||
_handleJsonFrame(msg) {
|
||||
// Sensing server sends: { type: "sensing_update", nodes: [{ amplitude: [...], subcarrier_count }], classification, features }
|
||||
this._liveAmplitude = new Float32Array(this.subcarriers);
|
||||
this._livePhase = new Float32Array(this.subcarriers);
|
||||
|
||||
// Extract amplitude from sensing_update node data
|
||||
const node = (msg.nodes && msg.nodes[0]) || msg;
|
||||
const ampArr = node.amplitude || msg.amplitude;
|
||||
if (ampArr && Array.isArray(ampArr)) {
|
||||
const n = Math.min(ampArr.length, this.subcarriers);
|
||||
// Server sends raw amplitude (already magnitude), normalize to 0-1
|
||||
let maxAmp = 0;
|
||||
for (let i = 0; i < n; i++) maxAmp = Math.max(maxAmp, Math.abs(ampArr[i]));
|
||||
const scale = maxAmp > 0 ? 1.0 / maxAmp : 1.0;
|
||||
for (let i = 0; i < n; i++) {
|
||||
this._liveAmplitude[i] = Math.abs(ampArr[i]) * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Phase from node (if available)
|
||||
const phaseArr = node.phase || msg.phase;
|
||||
if (phaseArr && Array.isArray(phaseArr)) {
|
||||
const n = Math.min(phaseArr.length, this.subcarriers);
|
||||
for (let i = 0; i < n; i++) this._livePhase[i] = phaseArr[i];
|
||||
} else if (ampArr) {
|
||||
// Synthesize phase from amplitude variation (Hilbert-like estimate)
|
||||
for (let i = 1; i < this.subcarriers; i++) {
|
||||
this._livePhase[i] = this._livePhase[i - 1] + (this._liveAmplitude[i] - this._liveAmplitude[i - 1]) * Math.PI;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle raw I/Q pairs
|
||||
const iq = node.iq || msg.iq;
|
||||
if (iq && Array.isArray(iq)) {
|
||||
const n = Math.min(iq.length / 2, this.subcarriers);
|
||||
for (let i = 0; i < n; i++) {
|
||||
const real = iq[i * 2], imag = iq[i * 2 + 1];
|
||||
this._liveAmplitude[i] = Math.sqrt(real * real + imag * imag) / 2048;
|
||||
this._livePhase[i] = Math.atan2(imag, real);
|
||||
}
|
||||
}
|
||||
|
||||
// Extract RSSI from node data
|
||||
if (typeof node.rssi_dbm === 'number') {
|
||||
this._rssiTarget = node.rssi_dbm;
|
||||
} else if (msg.features && typeof msg.features.mean_rssi === 'number') {
|
||||
this._rssiTarget = msg.features.mean_rssi;
|
||||
}
|
||||
|
||||
// Update presence from server classification
|
||||
const cls = msg.classification;
|
||||
if (cls) {
|
||||
if (typeof cls.confidence === 'number') {
|
||||
this.personPresence = cls.presence ? cls.confidence : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_mulberry32(seed) {
|
||||
return function() {
|
||||
let t = (seed += 0x6D2B79F5);
|
||||
|
||||
@@ -8,14 +8,12 @@
|
||||
export class FusionEngine {
|
||||
/**
|
||||
* @param {number} embeddingDim
|
||||
* @param {object} opts
|
||||
* @param {object} opts.wasmModule - RuVector WASM module for cosine_similarity etc.
|
||||
*/
|
||||
constructor(embeddingDim = 128, opts = {}) {
|
||||
constructor(embeddingDim = 128) {
|
||||
this.embeddingDim = embeddingDim;
|
||||
this.wasmModule = opts.wasmModule || null;
|
||||
|
||||
// Learnable attention weights (initialized to balanced 0.5)
|
||||
// In production, these would be loaded from trained JSON
|
||||
this.attentionWeights = new Float32Array(embeddingDim).fill(0.5);
|
||||
|
||||
// Dynamic modality confidence [0, 1]
|
||||
@@ -33,9 +31,6 @@ export class FusionEngine {
|
||||
this.maxHistory = 50;
|
||||
}
|
||||
|
||||
/** Set the WASM module reference (called after WASM loads) */
|
||||
setWasmModule(mod) { this.wasmModule = mod; }
|
||||
|
||||
/**
|
||||
* Update quality-based confidence scores
|
||||
* @param {number} videoBrightness - [0,1] video brightness quality
|
||||
@@ -99,11 +94,12 @@ export class FusionEngine {
|
||||
fused[i] = alpha * videoEmb[i] + (1 - alpha) * csiEmb[i];
|
||||
}
|
||||
|
||||
// Re-normalize using WASM when available
|
||||
if (this.wasmModule) {
|
||||
try { this.wasmModule.normalize(fused); } catch (_) { this._jsNormalize(fused); }
|
||||
} else {
|
||||
this._jsNormalize(fused);
|
||||
// Re-normalize
|
||||
let norm = 0;
|
||||
for (let i = 0; i < dim; i++) norm += fused[i] * fused[i];
|
||||
norm = Math.sqrt(norm);
|
||||
if (norm > 1e-8) {
|
||||
for (let i = 0; i < dim; i++) fused[i] /= norm;
|
||||
}
|
||||
|
||||
this._recordEmbedding(videoEmb, csiEmb, fused);
|
||||
@@ -115,19 +111,18 @@ export class FusionEngine {
|
||||
* @returns {{ video: Array, csi: Array, fused: Array }}
|
||||
*/
|
||||
getEmbeddingPoints() {
|
||||
// Sparse random projection: pick a few dimensions with fixed coefficients
|
||||
// to get visible 2D spread (avoids cancellation from summing all 128 dims)
|
||||
// Simple 2D projection using first two principal components (approximated)
|
||||
const project = (emb) => {
|
||||
if (!emb || emb.length < 4) return null;
|
||||
// Use 8 sparse dimensions with predetermined signs (seeded, not random)
|
||||
const dim = emb.length;
|
||||
const x = emb[0] * 3.2 - emb[3] * 2.8 + emb[7] * 2.1 - emb[12] * 1.9
|
||||
+ (dim > 30 ? emb[29] * 1.5 - emb[31] * 1.3 : 0)
|
||||
+ (dim > 60 ? emb[55] * 1.1 - emb[60] * 0.9 : 0);
|
||||
const y = emb[1] * 3.0 - emb[5] * 2.5 + emb[9] * 2.3 - emb[15] * 1.7
|
||||
+ (dim > 40 ? emb[37] * 1.4 - emb[42] * 1.2 : 0)
|
||||
+ (dim > 80 ? emb[73] * 1.0 - emb[80] * 0.8 : 0);
|
||||
return [x, y];
|
||||
// Use pairs of dimensions as crude 2D projection
|
||||
let x = 0, y = 0;
|
||||
for (let i = 0; i < emb.length; i += 2) {
|
||||
x += emb[i] * (i % 4 < 2 ? 1 : -1);
|
||||
if (i + 1 < emb.length) {
|
||||
y += emb[i + 1] * (i % 4 < 2 ? 1 : -1);
|
||||
}
|
||||
}
|
||||
return [x * 2, y * 2]; // Scale for visibility
|
||||
};
|
||||
|
||||
return {
|
||||
@@ -146,11 +141,6 @@ export class FusionEngine {
|
||||
const c = this.recentCsiEmbeddings[this.recentCsiEmbeddings.length - 1];
|
||||
if (!v || !c) return 0;
|
||||
|
||||
// Use WASM cosine_similarity when available
|
||||
if (this.wasmModule) {
|
||||
try { return this.wasmModule.cosine_similarity(v, c); } catch (_) { /* fallback */ }
|
||||
}
|
||||
|
||||
let dot = 0, na = 0, nb = 0;
|
||||
for (let i = 0; i < v.length; i++) {
|
||||
dot += v[i] * c[i];
|
||||
@@ -161,13 +151,6 @@ export class FusionEngine {
|
||||
return (na > 1e-8 && nb > 1e-8) ? dot / (na * nb) : 0;
|
||||
}
|
||||
|
||||
_jsNormalize(vec) {
|
||||
let norm = 0;
|
||||
for (let i = 0; i < vec.length; i++) norm += vec[i] * vec[i];
|
||||
norm = Math.sqrt(norm);
|
||||
if (norm > 1e-8) for (let i = 0; i < vec.length; i++) vec[i] /= norm;
|
||||
}
|
||||
|
||||
_recordEmbedding(video, csi, fused) {
|
||||
if (video) {
|
||||
this.recentVideoEmbeddings.push(new Float32Array(video));
|
||||
|
||||
+16
-173
@@ -1,15 +1,15 @@
|
||||
/**
|
||||
* RuView — Dual-Modal Pose Estimation Demo
|
||||
* WiFi-DensePose — Dual-Modal Pose Estimation Demo
|
||||
*
|
||||
* Main orchestration: video capture → CNN embedding → CSI processing → fusion → rendering
|
||||
*/
|
||||
|
||||
import { VideoCapture } from './video-capture.js?v=13';
|
||||
import { CsiSimulator } from './csi-simulator.js?v=13';
|
||||
import { CnnEmbedder } from './cnn-embedder.js?v=13';
|
||||
import { FusionEngine } from './fusion-engine.js?v=13';
|
||||
import { PoseDecoder } from './pose-decoder.js?v=13';
|
||||
import { CanvasRenderer } from './canvas-renderer.js?v=13';
|
||||
import { VideoCapture } from './video-capture.js';
|
||||
import { CsiSimulator } from './csi-simulator.js';
|
||||
import { CnnEmbedder } from './cnn-embedder.js';
|
||||
import { FusionEngine } from './fusion-engine.js';
|
||||
import { PoseDecoder } from './pose-decoder.js';
|
||||
import { CanvasRenderer } from './canvas-renderer.js';
|
||||
|
||||
// === State ===
|
||||
let mode = 'dual'; // 'dual' | 'video' | 'csi'
|
||||
@@ -71,20 +71,9 @@ const latTotalEl = document.getElementById('lat-total');
|
||||
// Cross-modal similarity
|
||||
const crossModalEl = document.getElementById('cross-modal-sim');
|
||||
|
||||
// RSSI elements
|
||||
const rssiBarEl = document.getElementById('rssi-bar');
|
||||
const rssiValueEl = document.getElementById('rssi-value');
|
||||
const rssiQualityEl = document.getElementById('rssi-quality');
|
||||
const rssiSparkCanvas = document.getElementById('rssi-sparkline');
|
||||
const rssiSparkCtx = rssiSparkCanvas ? rssiSparkCanvas.getContext('2d') : null;
|
||||
const rssiHistory = [];
|
||||
const RSSI_HISTORY_MAX = 80;
|
||||
|
||||
// === Initialize ===
|
||||
function init() {
|
||||
console.log(`[PoseFusion] init() v4 — CsiSimulator=${CsiSimulator.VERSION || 'OLD'}, starting...`);
|
||||
resizeCanvases();
|
||||
console.log(`[PoseFusion] canvases: skeleton=${skeletonCanvas.width}x${skeletonCanvas.height}, csi=${csiCanvas.width}x${csiCanvas.height}, emb=${embeddingCanvas.width}x${embeddingCanvas.height}`);
|
||||
window.addEventListener('resize', resizeCanvases);
|
||||
|
||||
// Mode change
|
||||
@@ -121,19 +110,10 @@ function init() {
|
||||
}
|
||||
});
|
||||
|
||||
// Try to load RuVector Attention WASM embedders (non-blocking)
|
||||
const wasmBase = new URL('../pkg/ruvector-attention', import.meta.url).href;
|
||||
visualCnn.tryLoadWasm(wasmBase).then((ok) => {
|
||||
// Share the WASM module with FusionEngine for cosine_similarity, normalize, etc.
|
||||
if (visualCnn.rvModule) fusionEngine.setWasmModule(visualCnn.rvModule);
|
||||
// Update footer backend label
|
||||
const backendEl = document.getElementById('cnn-backend');
|
||||
if (backendEl) {
|
||||
backendEl.textContent = ok && visualCnn.useRuVector
|
||||
? `RuVector WASM v${visualCnn.rvModule.version()} — 6 attention mechanisms`
|
||||
: 'ruvector-cnn (JS fallback)';
|
||||
}
|
||||
});
|
||||
// Try to load WASM embedders (non-blocking)
|
||||
// Resolve relative to this JS module file (in pose-fusion/js/) → ../pkg/
|
||||
const wasmBase = new URL('../pkg/ruvector_cnn_wasm', import.meta.url).href;
|
||||
visualCnn.tryLoadWasm(wasmBase);
|
||||
csiCnn.tryLoadWasm(wasmBase);
|
||||
|
||||
// Auto-connect to local sensing server WebSocket if available
|
||||
@@ -170,6 +150,7 @@ async function startCamera() {
|
||||
|
||||
function updateModeUI() {
|
||||
const needsVideo = mode !== 'csi';
|
||||
const needsCsi = mode !== 'video';
|
||||
|
||||
// Show/hide camera prompt
|
||||
if (needsVideo && !videoCapture.isActive) {
|
||||
@@ -177,13 +158,6 @@ function updateModeUI() {
|
||||
} else {
|
||||
cameraPrompt.style.display = 'none';
|
||||
}
|
||||
|
||||
// Update mode label in both the overlay and the camera prompt
|
||||
const labelMap = { dual: 'DUAL FUSION', video: 'VIDEO ONLY', csi: 'CSI ONLY' };
|
||||
const modeLabel = document.getElementById('mode-label');
|
||||
const promptLabel = document.getElementById('prompt-mode-label');
|
||||
if (modeLabel) modeLabel.textContent = labelMap[mode] || mode;
|
||||
if (promptLabel) promptLabel.textContent = labelMap[mode] || mode;
|
||||
}
|
||||
|
||||
function resizeCanvases() {
|
||||
@@ -194,25 +168,22 @@ function resizeCanvases() {
|
||||
skeletonCanvas.height = rect.height;
|
||||
}
|
||||
|
||||
// CSI canvas (min 200px width)
|
||||
csiCanvas.width = Math.max(200, csiCanvas.parentElement.clientWidth);
|
||||
// CSI canvas
|
||||
csiCanvas.width = csiCanvas.parentElement.clientWidth;
|
||||
csiCanvas.height = 120;
|
||||
|
||||
// Embedding canvas (min 200px width)
|
||||
embeddingCanvas.width = Math.max(200, embeddingCanvas.parentElement.clientWidth);
|
||||
// Embedding canvas
|
||||
embeddingCanvas.width = embeddingCanvas.parentElement.clientWidth;
|
||||
embeddingCanvas.height = 140;
|
||||
}
|
||||
|
||||
// === Main Loop ===
|
||||
let _loopErrorShown = false;
|
||||
let _diagDone = false;
|
||||
function mainLoop(timestamp) {
|
||||
if (!isRunning) return;
|
||||
requestAnimationFrame(mainLoop);
|
||||
|
||||
if (isPaused) return;
|
||||
|
||||
try {
|
||||
const elapsed = performance.now() / 1000 - startTime;
|
||||
const totalStart = performance.now();
|
||||
|
||||
@@ -338,134 +309,6 @@ function mainLoop(timestamp) {
|
||||
// Cross-modal similarity
|
||||
const sim = fusionEngine.getCrossModalSimilarity();
|
||||
crossModalEl.textContent = sim.toFixed(3);
|
||||
|
||||
// RuVector attention pipeline stats
|
||||
const rvStats = poseDecoder.attentionStats;
|
||||
const rvEnergyEl = document.getElementById('rv-energy');
|
||||
const rvRefineEl = document.getElementById('rv-refine');
|
||||
const rvImpactEl = document.getElementById('rv-impact');
|
||||
if (rvEnergyEl) rvEnergyEl.textContent = (rvStats.energy || 0).toFixed(2);
|
||||
if (rvRefineEl) rvRefineEl.textContent = ((rvStats.refinementMag || 0) * 1000).toFixed(1) + 'px';
|
||||
if (rvImpactEl) {
|
||||
const impact = Math.min(100, (rvStats.refinementMag || 0) * 5000);
|
||||
rvImpactEl.textContent = impact.toFixed(0) + '%';
|
||||
}
|
||||
// Pulse the pipeline stages when active
|
||||
if (visualCnn.useRuVector && rvStats.energy > 0.1) {
|
||||
document.querySelectorAll('.rv-stage').forEach(el => el.classList.add('active'));
|
||||
}
|
||||
|
||||
// RSSI update
|
||||
updateRssi(csiSimulator.rssiDbm);
|
||||
|
||||
// One-time diagnostic
|
||||
if (!_diagDone) {
|
||||
_diagDone = true;
|
||||
console.log(`[PoseFusion] frame 1 OK — mode=${mode}, csi.bufLen=${csiSimulator.amplitudeBuffer.length}, embPts=${embPoints?.fused?.length ?? 0}, rssi=${(csiSimulator.rssiDbm ?? -99).toFixed(1)}`);
|
||||
}
|
||||
|
||||
} catch (err) {
|
||||
if (!_loopErrorShown) {
|
||||
_loopErrorShown = true;
|
||||
console.error('[MainLoop]', err);
|
||||
// Show error visually on page
|
||||
const errDiv = document.createElement('div');
|
||||
errDiv.style.cssText = 'position:fixed;bottom:60px;left:24px;right:24px;background:rgba(255,48,64,0.95);color:#fff;padding:12px 16px;border-radius:8px;font:12px/1.4 "JetBrains Mono",monospace;z-index:9999;max-height:120px;overflow:auto';
|
||||
errDiv.textContent = `[MainLoop Error] ${err.message}\n${err.stack?.split('\n').slice(0,3).join('\n')}`;
|
||||
document.body.appendChild(errDiv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === RSSI Visualization ===
|
||||
function updateRssi(dbm) {
|
||||
if (!rssiBarEl) return;
|
||||
|
||||
// Clamp to typical WiFi range: -100 (worst) to -30 (best)
|
||||
const clamped = Math.max(-100, Math.min(-30, dbm));
|
||||
const pct = ((clamped + 100) / 70) * 100; // 0-100%
|
||||
|
||||
rssiBarEl.style.width = `${pct}%`;
|
||||
rssiValueEl.textContent = `${Math.round(clamped)} dBm`;
|
||||
|
||||
// Quality label
|
||||
let quality;
|
||||
if (clamped > -50) quality = 'Excellent';
|
||||
else if (clamped > -60) quality = 'Good';
|
||||
else if (clamped > -70) quality = 'Fair';
|
||||
else if (clamped > -80) quality = 'Weak';
|
||||
else quality = 'Poor';
|
||||
rssiQualityEl.textContent = quality;
|
||||
|
||||
// Color the dBm value based on quality
|
||||
if (clamped > -60) rssiValueEl.style.color = 'var(--green-glow)';
|
||||
else if (clamped > -75) rssiValueEl.style.color = 'var(--amber)';
|
||||
else rssiValueEl.style.color = 'var(--red-alert)';
|
||||
|
||||
// Sparkline history
|
||||
rssiHistory.push(clamped);
|
||||
if (rssiHistory.length > RSSI_HISTORY_MAX) rssiHistory.shift();
|
||||
drawRssiSparkline();
|
||||
}
|
||||
|
||||
function drawRssiSparkline() {
|
||||
if (!rssiSparkCtx || rssiHistory.length < 2) return;
|
||||
const w = rssiSparkCanvas.width;
|
||||
const h = rssiSparkCanvas.height;
|
||||
const ctx = rssiSparkCtx;
|
||||
|
||||
ctx.clearRect(0, 0, w, h);
|
||||
|
||||
// Draw signal strength line
|
||||
const len = rssiHistory.length;
|
||||
const step = w / (RSSI_HISTORY_MAX - 1);
|
||||
|
||||
// Gradient fill under line
|
||||
const grad = ctx.createLinearGradient(0, 0, 0, h);
|
||||
grad.addColorStop(0, 'rgba(0,210,120,0.3)');
|
||||
grad.addColorStop(1, 'rgba(0,210,120,0)');
|
||||
|
||||
ctx.beginPath();
|
||||
for (let i = 0; i < len; i++) {
|
||||
const x = (RSSI_HISTORY_MAX - len + i) * step;
|
||||
const y = h - ((rssiHistory[i] + 100) / 70) * h;
|
||||
if (i === 0) ctx.moveTo(x, y);
|
||||
else ctx.lineTo(x, y);
|
||||
}
|
||||
// Fill area
|
||||
const lastX = (RSSI_HISTORY_MAX - 1) * step;
|
||||
const firstX = (RSSI_HISTORY_MAX - len) * step;
|
||||
ctx.lineTo(lastX, h);
|
||||
ctx.lineTo(firstX, h);
|
||||
ctx.closePath();
|
||||
ctx.fillStyle = grad;
|
||||
ctx.fill();
|
||||
|
||||
// Draw line on top
|
||||
ctx.beginPath();
|
||||
for (let i = 0; i < len; i++) {
|
||||
const x = (RSSI_HISTORY_MAX - len + i) * step;
|
||||
const y = h - ((rssiHistory[i] + 100) / 70) * h;
|
||||
if (i === 0) ctx.moveTo(x, y);
|
||||
else ctx.lineTo(x, y);
|
||||
}
|
||||
ctx.strokeStyle = '#00d878';
|
||||
ctx.lineWidth = 1.5;
|
||||
ctx.stroke();
|
||||
|
||||
// Pulsing dot at latest value
|
||||
const latestX = lastX;
|
||||
const latestY = h - ((rssiHistory[len - 1] + 100) / 70) * h;
|
||||
const pulse = 0.5 + 0.5 * Math.sin(performance.now() / 300);
|
||||
ctx.beginPath();
|
||||
ctx.arc(latestX, latestY, 2 + pulse, 0, Math.PI * 2);
|
||||
ctx.fillStyle = '#00d878';
|
||||
ctx.fill();
|
||||
ctx.beginPath();
|
||||
ctx.arc(latestX, latestY, 4 + pulse * 2, 0, Math.PI * 2);
|
||||
ctx.strokeStyle = `rgba(0,216,120,${0.3 + pulse * 0.3})`;
|
||||
ctx.lineWidth = 1;
|
||||
ctx.stroke();
|
||||
}
|
||||
|
||||
// Boot
|
||||
|
||||
+139
-319
@@ -9,35 +9,24 @@
|
||||
* When person exits frame, CSI data continues tracking (through-wall mode).
|
||||
*/
|
||||
|
||||
// Extended keypoint definitions: 17 COCO + 9 hand/fingertip approximations = 26 total
|
||||
// COCO keypoint definitions
|
||||
export const KEYPOINT_NAMES = [
|
||||
'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
|
||||
'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
|
||||
'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
|
||||
'left_knee', 'right_knee', 'left_ankle', 'right_ankle',
|
||||
// Extended: hand keypoints (17-25)
|
||||
'left_thumb', 'left_index', 'left_pinky', // 17, 18, 19
|
||||
'right_thumb', 'right_index', 'right_pinky', // 20, 21, 22
|
||||
'left_foot_index', 'right_foot_index', // 23, 24 (toe tips)
|
||||
'neck', // 25 (mid-shoulder)
|
||||
'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
|
||||
];
|
||||
|
||||
// Skeleton connections (pairs of keypoint indices)
|
||||
export const SKELETON_CONNECTIONS = [
|
||||
[0, 1], [0, 2], [1, 3], [2, 4], // Head
|
||||
[0, 25], // Nose → neck
|
||||
[25, 5], [25, 6], // Neck → shoulders
|
||||
[5, 6], // Shoulders
|
||||
[5, 7], [7, 9], // Left arm
|
||||
[6, 8], [8, 10], // Right arm
|
||||
[5, 11], [6, 12], // Torso
|
||||
[11, 12], // Hips
|
||||
[11, 13], [13, 15], // Left leg
|
||||
[12, 14], [14, 16], // Right leg
|
||||
// Hand connections
|
||||
[9, 17], [9, 18], [9, 19], // Left wrist → fingers
|
||||
[10, 20], [10, 21], [10, 22], // Right wrist → fingers
|
||||
// Foot connections
|
||||
[15, 23], [16, 24], // Ankles → toes
|
||||
];
|
||||
|
||||
// Standard body proportions (relative to body height)
|
||||
@@ -52,19 +41,13 @@ const PROPORTIONS = {
|
||||
kneeToAnkle: 0.24,
|
||||
eyeSpacing: 0.04,
|
||||
earSpacing: 0.07,
|
||||
// Hand proportions
|
||||
wristToFinger: 0.09,
|
||||
fingerSpread: 0.04,
|
||||
thumbAngle: 0.6, // radians from wrist-elbow axis
|
||||
// Foot proportions
|
||||
ankleToToe: 0.06,
|
||||
};
|
||||
|
||||
export class PoseDecoder {
|
||||
constructor(embeddingDim = 128) {
|
||||
this.embeddingDim = embeddingDim;
|
||||
this.smoothedKeypoints = null;
|
||||
this.smoothingFactor = 0.25; // Low = responsive to real movement
|
||||
this.smoothingFactor = 0.45; // Lower = more responsive to movement
|
||||
this._time = 0;
|
||||
|
||||
// Through-wall tracking state
|
||||
@@ -73,53 +56,12 @@ export class PoseDecoder {
|
||||
this._ghostConfidence = 0;
|
||||
this._ghostVelocity = { x: 0, y: 0 };
|
||||
|
||||
// Zone centroid tracking (normalized 0-1 positions)
|
||||
this._headCx = 0.5;
|
||||
this._headCy = 0.15;
|
||||
this._leftArmCx = 0.3;
|
||||
this._leftArmCy = 0.35;
|
||||
this._rightArmCx = 0.7;
|
||||
this._rightArmCy = 0.35;
|
||||
this._leftLegCx = 0.4;
|
||||
this._leftLegCy = 0.8;
|
||||
this._rightLegCx = 0.6;
|
||||
this._rightLegCy = 0.8;
|
||||
this._torsoCx = 0.5;
|
||||
this._torsoCy = 0.45;
|
||||
|
||||
// RuVector embedding → joint mapping
|
||||
// Each joint gets 2 consecutive embedding dimensions (dx, dy offset)
|
||||
// and 1 dimension for confidence modulation. 26 joints × 3 = 78 dims used from 128.
|
||||
// Remaining 50 dims encode global pose features (body scale, rotation, lean).
|
||||
this._jointEmbMap = this._buildJointEmbeddingMap(embeddingDim);
|
||||
|
||||
// Attention contribution tracking (for UI overlay)
|
||||
this.attentionStats = { energy: 0, maxDim: 0, refinementMag: 0 };
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the mapping from embedding dimensions to joint refinement signals.
|
||||
* This maps the RuVector attention output to anatomically meaningful joint offsets.
|
||||
*/
|
||||
_buildJointEmbeddingMap(dim) {
|
||||
const map = [];
|
||||
// 26 joints × 3 dims each (dx, dy, confidence_mod) = 78 dims
|
||||
for (let j = 0; j < 26; j++) {
|
||||
const base = j * 3;
|
||||
if (base + 2 < dim) {
|
||||
map.push({ dxDim: base, dyDim: base + 1, confDim: base + 2 });
|
||||
} else {
|
||||
map.push({ dxDim: j % dim, dyDim: (j + 1) % dim, confDim: (j + 2) % dim });
|
||||
}
|
||||
}
|
||||
// Global pose features from dims 78-127
|
||||
return {
|
||||
joints: map,
|
||||
scaleDim: Math.min(78, dim - 1), // body scale factor
|
||||
rotDim: Math.min(79, dim - 1), // body rotation
|
||||
leanXDim: Math.min(80, dim - 1), // lateral lean
|
||||
leanYDim: Math.min(81, dim - 1), // forward/back lean
|
||||
};
|
||||
// Arm tracking history (smoothed positions)
|
||||
this._leftArmY = 0.5;
|
||||
this._rightArmY = 0.5;
|
||||
this._leftArmX = 0;
|
||||
this._rightArmX = 0;
|
||||
this._headOffsetX = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -183,129 +125,71 @@ export class PoseDecoder {
|
||||
|
||||
/**
|
||||
* Track body parts from the motion grid.
|
||||
* Finds the centroid of motion in each body zone and positions joints there.
|
||||
* The grid tells us WHERE motion is happening → we map that to joint positions.
|
||||
*/
|
||||
_trackFromMotionGrid(region, embedding, elapsed) {
|
||||
const grid = region.motionGrid;
|
||||
const cols = region.gridCols || 10;
|
||||
const rows = region.gridRows || 8;
|
||||
|
||||
// Body bounding box (in normalized 0-1 coords)
|
||||
const bx = region.x, by = region.y, bw = region.w, bh = region.h;
|
||||
const cx = bx + bw / 2;
|
||||
const cy = by + bh / 2;
|
||||
const bodyH = Math.max(bh, 0.3);
|
||||
const bodyW = Math.max(bw, 0.15);
|
||||
// Body bounding box
|
||||
const cx = region.x + region.w / 2;
|
||||
const cy = region.y + region.h / 2;
|
||||
const bodyH = Math.max(region.h, 0.3);
|
||||
const bodyW = Math.max(region.w, 0.15);
|
||||
|
||||
// Find motion centroids per body zone from the grid
|
||||
// Analyze the motion grid to find arm positions
|
||||
// Divide body into zones: head (top 20%), arms (top 60% sides), torso (center), legs (bottom 40%)
|
||||
if (grid) {
|
||||
const zones = this._findZoneCentroids(grid, cols, rows, bx, by, bw, bh);
|
||||
// Smooth with low alpha for responsiveness
|
||||
const a = 0.3; // 30% old, 70% new → responsive
|
||||
this._headCx = a * this._headCx + (1 - a) * zones.head.x;
|
||||
this._headCy = a * this._headCy + (1 - a) * zones.head.y;
|
||||
this._leftArmCx = a * this._leftArmCx + (1 - a) * zones.leftArm.x;
|
||||
this._leftArmCy = a * this._leftArmCy + (1 - a) * zones.leftArm.y;
|
||||
this._rightArmCx= a * this._rightArmCx+ (1 - a) * zones.rightArm.x;
|
||||
this._rightArmCy= a * this._rightArmCy+ (1 - a) * zones.rightArm.y;
|
||||
this._leftLegCx = a * this._leftLegCx + (1 - a) * zones.leftLeg.x;
|
||||
this._leftLegCy = a * this._leftLegCy + (1 - a) * zones.leftLeg.y;
|
||||
this._rightLegCx= a * this._rightLegCx+ (1 - a) * zones.rightLeg.x;
|
||||
this._rightLegCy= a * this._rightLegCy+ (1 - a) * zones.rightLeg.y;
|
||||
this._torsoCx = a * this._torsoCx + (1 - a) * zones.torso.x;
|
||||
this._torsoCy = a * this._torsoCy + (1 - a) * zones.torso.y;
|
||||
const armAnalysis = this._analyzeArmMotion(grid, cols, rows, region);
|
||||
// Smooth arm tracking
|
||||
this._leftArmY = 0.6 * this._leftArmY + 0.4 * armAnalysis.leftArmHeight;
|
||||
this._rightArmY = 0.6 * this._rightArmY + 0.4 * armAnalysis.rightArmHeight;
|
||||
this._leftArmX = 0.6 * this._leftArmX + 0.4 * armAnalysis.leftArmSpread;
|
||||
this._rightArmX = 0.6 * this._rightArmX + 0.4 * armAnalysis.rightArmSpread;
|
||||
this._headOffsetX = 0.7 * this._headOffsetX + 0.3 * armAnalysis.headOffsetX;
|
||||
}
|
||||
|
||||
const P = PROPORTIONS;
|
||||
const halfW = P.shoulderWidth * bodyH / 2;
|
||||
const hipHalfW = P.hipWidth * bodyH / 2;
|
||||
|
||||
// Breathing (subtle)
|
||||
const breathe = Math.sin(elapsed * 1.5) * 0.002;
|
||||
|
||||
// === Position joints using tracked centroids ===
|
||||
// Core body positions from detection center
|
||||
const hipY = cy + bodyH * 0.15;
|
||||
const shoulderY = hipY - P.shoulderToHip * bodyH + breathe;
|
||||
const headY = shoulderY - P.headToShoulder * bodyH;
|
||||
const kneeY = hipY + P.hipToKnee * bodyH;
|
||||
const ankleY = kneeY + P.kneeToAnkle * bodyH;
|
||||
|
||||
// HEAD: tracked centroid (top zone)
|
||||
const headX = this._headCx;
|
||||
const headY = this._headCy;
|
||||
// HEAD follows motion centroid
|
||||
const headX = cx + this._headOffsetX * bodyW * 0.3;
|
||||
|
||||
// TORSO center drives shoulder/hip
|
||||
const torsoX = this._torsoCx;
|
||||
const shoulderY = this._torsoCy - bodyH * 0.08 + breathe;
|
||||
const halfW = P.shoulderWidth * bodyH / 2;
|
||||
const hipHalfW = P.hipWidth * bodyH / 2;
|
||||
const hipY = shoulderY + P.shoulderToHip * bodyH;
|
||||
// ARM POSITIONS driven by motion grid analysis
|
||||
// leftArmY: 0 = arm down at side, 1 = arm fully raised
|
||||
// leftArmSpread: how far out the arm extends
|
||||
const leftArmRaise = this._leftArmY; // 0-1
|
||||
const rightArmRaise = this._rightArmY;
|
||||
const leftSpread = 0.02 + this._leftArmX * 0.12;
|
||||
const rightSpread = 0.02 + this._rightArmX * 0.12;
|
||||
|
||||
// ARMS: elbow + wrist driven toward arm zone centroids
|
||||
// Left arm: shoulder is fixed, elbow/wrist pulled toward left arm centroid
|
||||
const lShX = torsoX - halfW;
|
||||
const lShY = shoulderY;
|
||||
// Vector from shoulder toward arm centroid
|
||||
const lArmDx = this._leftArmCx - lShX;
|
||||
const lArmDy = this._leftArmCy - lShY;
|
||||
const lArmDist = Math.sqrt(lArmDx * lArmDx + lArmDy * lArmDy) || 0.01;
|
||||
const lArmNx = lArmDx / lArmDist;
|
||||
const lArmNy = lArmDy / lArmDist;
|
||||
// Elbow at shoulderToElbow distance along that direction
|
||||
const elbowLen = P.shoulderToElbow * bodyH;
|
||||
const lElbowX = lShX + lArmNx * elbowLen;
|
||||
const lElbowY = lShY + lArmNy * elbowLen;
|
||||
// Wrist continues further
|
||||
const wristLen = P.elbowToWrist * bodyH;
|
||||
const lWristX = lElbowX + lArmNx * wristLen;
|
||||
const lWristY = lElbowY + lArmNy * wristLen;
|
||||
// Elbow: interpolate between "at side" and "raised"
|
||||
const lElbowY = shoulderY + P.shoulderToElbow * bodyH * (1 - leftArmRaise * 0.9);
|
||||
const rElbowY = shoulderY + P.shoulderToElbow * bodyH * (1 - rightArmRaise * 0.9);
|
||||
const lElbowX = cx - halfW - leftSpread;
|
||||
const rElbowX = cx + halfW + rightSpread;
|
||||
|
||||
// Right arm: same approach
|
||||
const rShX = torsoX + halfW;
|
||||
const rShY = shoulderY;
|
||||
const rArmDx = this._rightArmCx - rShX;
|
||||
const rArmDy = this._rightArmCy - rShY;
|
||||
const rArmDist = Math.sqrt(rArmDx * rArmDx + rArmDy * rArmDy) || 0.01;
|
||||
const rArmNx = rArmDx / rArmDist;
|
||||
const rArmNy = rArmDy / rArmDist;
|
||||
const rElbowX = rShX + rArmNx * elbowLen;
|
||||
const rElbowY = rShY + rArmNy * elbowLen;
|
||||
const rWristX = rElbowX + rArmNx * wristLen;
|
||||
const rWristY = rElbowY + rArmNy * wristLen;
|
||||
// Wrist: extends further when raised
|
||||
const lWristY = lElbowY + P.elbowToWrist * bodyH * (1 - leftArmRaise * 1.1);
|
||||
const rWristY = rElbowY + P.elbowToWrist * bodyH * (1 - rightArmRaise * 1.1);
|
||||
const lWristX = lElbowX - leftSpread * 0.6;
|
||||
const rWristX = rElbowX + rightSpread * 0.6;
|
||||
|
||||
// LEGS: knees/ankles pulled toward leg zone centroids
|
||||
const lHipX = torsoX - hipHalfW;
|
||||
const rHipX = torsoX + hipHalfW;
|
||||
const lLegDx = this._leftLegCx - lHipX;
|
||||
const lLegDy = Math.max(0.05, this._leftLegCy - hipY); // always downward
|
||||
const lLegDist = Math.sqrt(lLegDx * lLegDx + lLegDy * lLegDy) || 0.01;
|
||||
const lLegNx = lLegDx / lLegDist;
|
||||
const lLegNy = lLegDy / lLegDist;
|
||||
const kneeLen = P.hipToKnee * bodyH;
|
||||
const ankleLen = P.kneeToAnkle * bodyH;
|
||||
const lKneeX = lHipX + lLegNx * kneeLen;
|
||||
const lKneeY = hipY + lLegNy * kneeLen;
|
||||
const lAnkleX = lKneeX + lLegNx * ankleLen;
|
||||
const lAnkleY = lKneeY + lLegNy * ankleLen;
|
||||
|
||||
const rLegDx = this._rightLegCx - rHipX;
|
||||
const rLegDy = Math.max(0.05, this._rightLegCy - hipY);
|
||||
const rLegDist = Math.sqrt(rLegDx * rLegDx + rLegDy * rLegDy) || 0.01;
|
||||
const rLegNx = rLegDx / rLegDist;
|
||||
const rLegNy = rLegDy / rLegDist;
|
||||
const rKneeX = rHipX + rLegNx * kneeLen;
|
||||
const rKneeY = hipY + rLegNy * kneeLen;
|
||||
const rAnkleX = rKneeX + rLegNx * ankleLen;
|
||||
const rAnkleY = rKneeY + rLegNy * ankleLen;
|
||||
|
||||
// Arm raise amount (for hand openness)
|
||||
const leftArmRaise = Math.max(0, Math.min(1, (shoulderY - this._leftArmCy) / (bodyH * 0.3)));
|
||||
const rightArmRaise = Math.max(0, Math.min(1, (shoulderY - this._rightArmCy) / (bodyH * 0.3)));
|
||||
|
||||
// Compute hand finger positions from wrist-elbow axis
|
||||
const lHandAngle = Math.atan2(lWristY - lElbowY, lWristX - lElbowX);
|
||||
const rHandAngle = Math.atan2(rWristY - rElbowY, rWristX - rElbowX);
|
||||
const fingerLen = P.wristToFinger * bodyH;
|
||||
const fingerSpr = P.fingerSpread * bodyH;
|
||||
|
||||
// Hand openness driven by arm raise + arm lateral spread
|
||||
const lArmSpread = Math.abs(this._leftArmCx - (bx + bw * 0.3)) / (bw * 0.3);
|
||||
const rArmSpread = Math.abs(this._rightArmCx - (bx + bw * 0.7)) / (bw * 0.3);
|
||||
const lHandOpen = Math.min(1, leftArmRaise * 0.5 + lArmSpread * 0.5);
|
||||
const rHandOpen = Math.min(1, rightArmRaise * 0.5 + rArmSpread * 0.5);
|
||||
// Leg motion from lower grid cells
|
||||
const legMotion = grid ? this._analyzeLegMotion(grid, cols, rows) : { left: 0, right: 0 };
|
||||
const legSwing = 0.015;
|
||||
|
||||
const keypoints = [
|
||||
// 0: nose
|
||||
@@ -319,9 +203,9 @@ export class PoseDecoder {
|
||||
// 4: right_ear
|
||||
{ x: headX + P.earSpacing * bodyH, y: headY + 0.005, confidence: 0.72 },
|
||||
// 5: left_shoulder
|
||||
{ x: lShX, y: lShY, confidence: 0.94 },
|
||||
{ x: cx - halfW, y: shoulderY, confidence: 0.94 },
|
||||
// 6: right_shoulder
|
||||
{ x: rShX, y: rShY, confidence: 0.94 },
|
||||
{ x: cx + halfW, y: shoulderY, confidence: 0.94 },
|
||||
// 7: left_elbow
|
||||
{ x: lElbowX, y: lElbowY, confidence: 0.87 },
|
||||
// 8: right_elbow
|
||||
@@ -331,179 +215,115 @@ export class PoseDecoder {
|
||||
// 10: right_wrist
|
||||
{ x: rWristX, y: rWristY, confidence: 0.82 },
|
||||
// 11: left_hip
|
||||
{ x: lHipX, y: hipY, confidence: 0.91 },
|
||||
{ x: cx - hipHalfW, y: hipY, confidence: 0.91 },
|
||||
// 12: right_hip
|
||||
{ x: rHipX, y: hipY, confidence: 0.91 },
|
||||
{ x: cx + hipHalfW, y: hipY, confidence: 0.91 },
|
||||
// 13: left_knee
|
||||
{ x: lKneeX, y: lKneeY, confidence: 0.88 },
|
||||
{ x: cx - hipHalfW + legMotion.left * legSwing, y: kneeY, confidence: 0.88 },
|
||||
// 14: right_knee
|
||||
{ x: rKneeX, y: rKneeY, confidence: 0.88 },
|
||||
{ x: cx + hipHalfW + legMotion.right * legSwing, y: kneeY, confidence: 0.88 },
|
||||
// 15: left_ankle
|
||||
{ x: lAnkleX, y: lAnkleY, confidence: 0.83 },
|
||||
{ x: cx - hipHalfW + legMotion.left * legSwing * 1.3, y: ankleY, confidence: 0.83 },
|
||||
// 16: right_ankle
|
||||
{ x: rAnkleX, y: rAnkleY, confidence: 0.83 },
|
||||
|
||||
// === Extended keypoints (17-25) ===
|
||||
|
||||
// 17: left_thumb — offset at thumb angle from wrist-elbow axis
|
||||
{ x: lWristX + fingerLen * Math.cos(lHandAngle + P.thumbAngle) * (0.6 + lHandOpen * 0.4),
|
||||
y: lWristY + fingerLen * Math.sin(lHandAngle + P.thumbAngle) * (0.6 + lHandOpen * 0.4),
|
||||
confidence: 0.68 * (0.5 + lHandOpen * 0.5) },
|
||||
// 18: left_index — extends along wrist-elbow axis
|
||||
{ x: lWristX + fingerLen * Math.cos(lHandAngle) + fingerSpr * lHandOpen * Math.cos(lHandAngle + 0.3),
|
||||
y: lWristY + fingerLen * Math.sin(lHandAngle) + fingerSpr * lHandOpen * Math.sin(lHandAngle + 0.3),
|
||||
confidence: 0.72 * (0.5 + lHandOpen * 0.5) },
|
||||
// 19: left_pinky — offset opposite thumb
|
||||
{ x: lWristX + fingerLen * 0.85 * Math.cos(lHandAngle - P.thumbAngle * 0.7),
|
||||
y: lWristY + fingerLen * 0.85 * Math.sin(lHandAngle - P.thumbAngle * 0.7),
|
||||
confidence: 0.60 * (0.5 + lHandOpen * 0.5) },
|
||||
|
||||
// 20: right_thumb
|
||||
{ x: rWristX + fingerLen * Math.cos(rHandAngle - P.thumbAngle) * (0.6 + rHandOpen * 0.4),
|
||||
y: rWristY + fingerLen * Math.sin(rHandAngle - P.thumbAngle) * (0.6 + rHandOpen * 0.4),
|
||||
confidence: 0.68 * (0.5 + rHandOpen * 0.5) },
|
||||
// 21: right_index
|
||||
{ x: rWristX + fingerLen * Math.cos(rHandAngle) + fingerSpr * rHandOpen * Math.cos(rHandAngle - 0.3),
|
||||
y: rWristY + fingerLen * Math.sin(rHandAngle) + fingerSpr * rHandOpen * Math.sin(rHandAngle - 0.3),
|
||||
confidence: 0.72 * (0.5 + rHandOpen * 0.5) },
|
||||
// 22: right_pinky
|
||||
{ x: rWristX + fingerLen * 0.85 * Math.cos(rHandAngle + P.thumbAngle * 0.7),
|
||||
y: rWristY + fingerLen * 0.85 * Math.sin(rHandAngle + P.thumbAngle * 0.7),
|
||||
confidence: 0.60 * (0.5 + rHandOpen * 0.5) },
|
||||
|
||||
// 23: left_foot_index (toe tip) — extends forward from ankle
|
||||
{ x: lAnkleX + P.ankleToToe * bodyH * 0.5,
|
||||
y: lAnkleY + P.ankleToToe * bodyH * 0.3,
|
||||
confidence: 0.65 },
|
||||
// 24: right_foot_index
|
||||
{ x: rAnkleX + P.ankleToToe * bodyH * 0.5,
|
||||
y: rAnkleY + P.ankleToToe * bodyH * 0.3,
|
||||
confidence: 0.65 },
|
||||
|
||||
// 25: neck (midpoint between shoulders, slightly above)
|
||||
{ x: (lShX + rShX) / 2, y: shoulderY - P.headToShoulder * bodyH * 0.35, confidence: 0.93 },
|
||||
{ x: cx + hipHalfW + legMotion.right * legSwing * 1.3, y: ankleY, confidence: 0.83 },
|
||||
];
|
||||
|
||||
for (let i = 0; i < keypoints.length; i++) {
|
||||
keypoints[i].name = KEYPOINT_NAMES[i];
|
||||
}
|
||||
|
||||
// === RuVector Attention Embedding Refinement ===
|
||||
// Compute attention stats for the UI pipeline display, but only apply
|
||||
// positional refinement when a trained model is loaded (random-weight
|
||||
// embeddings carry no meaningful spatial signal and distort the skeleton).
|
||||
if (embedding && embedding.length >= 26 * 3) {
|
||||
this._computeEmbeddingStats(keypoints, embedding, bodyH);
|
||||
}
|
||||
|
||||
return keypoints;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply RuVector attention embedding to refine joint positions and confidence.
|
||||
*
|
||||
* The 128-dim fused embedding is decoded as:
|
||||
* - Dims 0-77: Per-joint (dx, dy, confidence_mod) × 26 joints
|
||||
* - Dims 78-81: Global pose parameters (scale, rotation, lean)
|
||||
* - Dims 82-127: Reserved for cross-modal fusion features
|
||||
*
|
||||
* The attention mechanism determines HOW MUCH each spatial region contributes
|
||||
* to each joint's refinement. Multi-Head captures global relationships,
|
||||
* Hyperbolic captures hierarchical (torso→limb→hand) dependencies,
|
||||
* MoE routes different body regions to specialized experts,
|
||||
* Linear provides fast extremity refinement, Local-Global balances detail/context.
|
||||
* Analyze the motion grid to determine arm positions.
|
||||
* Left side of grid = left side of body, etc.
|
||||
*/
|
||||
/**
|
||||
* Compute embedding statistics for UI display without modifying joint positions.
|
||||
* The 6-stage attention pipeline stats are shown in the RuVector panel.
|
||||
* Position refinement is disabled until a trained model replaces random weights.
|
||||
*/
|
||||
_computeEmbeddingStats(keypoints, emb, bodyH) {
|
||||
const map = this._jointEmbMap;
|
||||
const tc = (v) => Math.tanh(Number(v) || 0);
|
||||
_analyzeArmMotion(grid, cols, rows, region) {
|
||||
// Body center column
|
||||
const centerCol = Math.floor(cols / 2);
|
||||
|
||||
// Embedding energy (L2 norm of the used dims)
|
||||
let energy = 0;
|
||||
for (let i = 0; i < Math.min(emb.length, 82); i++) {
|
||||
energy += emb[i] * emb[i];
|
||||
}
|
||||
energy = Math.sqrt(energy);
|
||||
// Upper body rows (top 60% of detected region)
|
||||
const upperEnd = Math.floor(rows * 0.6);
|
||||
|
||||
// Simulated per-joint refinement magnitude (what WOULD be applied)
|
||||
const scale = bodyH * 0.015;
|
||||
let totalRefinement = 0;
|
||||
let maxDimVal = 0;
|
||||
// Compute motion intensity for left vs right, at different heights
|
||||
let leftUpperMotion = 0, leftMidMotion = 0;
|
||||
let rightUpperMotion = 0, rightMidMotion = 0;
|
||||
let leftCount = 0, rightCount = 0;
|
||||
let headMotionX = 0, headMotionWeight = 0;
|
||||
|
||||
for (let j = 0; j < Math.min(keypoints.length, 26); j++) {
|
||||
const jmap = map.joints[j];
|
||||
if (!jmap) continue;
|
||||
const dx = tc(emb[jmap.dxDim]) * scale;
|
||||
const dy = tc(emb[jmap.dyDim]) * scale;
|
||||
totalRefinement += Math.sqrt(dx * dx + dy * dy);
|
||||
maxDimVal = Math.max(maxDimVal, Math.abs(tc(emb[jmap.dxDim])), Math.abs(tc(emb[jmap.dyDim])));
|
||||
}
|
||||
for (let r = 0; r < upperEnd; r++) {
|
||||
const heightWeight = 1.0 - (r / upperEnd) * 0.3; // Upper rows weighted more
|
||||
|
||||
this.attentionStats.energy = energy;
|
||||
this.attentionStats.maxDim = maxDimVal;
|
||||
this.attentionStats.refinementMag = totalRefinement / 26;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find weighted motion centroids for each body zone.
|
||||
* Divides the bounding box into 6 zones: head, left arm, right arm, torso, left leg, right leg.
|
||||
* Returns the (x,y) centroid of motion intensity for each zone.
|
||||
*/
|
||||
_findZoneCentroids(grid, cols, rows, bx, by, bw, bh) {
|
||||
// Zone definitions (in grid-relative fractions)
|
||||
const zones = {
|
||||
head: { rMin: 0, rMax: 0.2, cMin: 0.25, cMax: 0.75, wx: 0, wy: 0, wt: 0 },
|
||||
leftArm: { rMin: 0.1, rMax: 0.6, cMin: 0, cMax: 0.35, wx: 0, wy: 0, wt: 0 },
|
||||
rightArm: { rMin: 0.1, rMax: 0.6, cMin: 0.65, cMax: 1.0, wx: 0, wy: 0, wt: 0 },
|
||||
torso: { rMin: 0.15, rMax: 0.55, cMin: 0.3, cMax: 0.7, wx: 0, wy: 0, wt: 0 },
|
||||
leftLeg: { rMin: 0.5, rMax: 1.0, cMin: 0.1, cMax: 0.5, wx: 0, wy: 0, wt: 0 },
|
||||
rightLeg: { rMin: 0.5, rMax: 1.0, cMin: 0.5, cMax: 0.9, wx: 0, wy: 0, wt: 0 },
|
||||
};
|
||||
|
||||
// Accumulate weighted centroids per zone
|
||||
for (let r = 0; r < rows; r++) {
|
||||
const ry = r / rows; // 0-1 within grid
|
||||
for (let c = 0; c < cols; c++) {
|
||||
const cx_g = c / cols; // 0-1 within grid
|
||||
const val = grid[r][c];
|
||||
if (val < 0.005) continue; // skip near-zero motion
|
||||
|
||||
// Map grid position to body-space coordinates (0-1)
|
||||
const worldX = bx + cx_g * bw;
|
||||
const worldY = by + ry * bh;
|
||||
|
||||
// Assign to matching zones (a cell can contribute to multiple overlapping zones)
|
||||
for (const z of Object.values(zones)) {
|
||||
if (ry >= z.rMin && ry < z.rMax && cx_g >= z.cMin && cx_g < z.cMax) {
|
||||
z.wx += worldX * val;
|
||||
z.wy += worldY * val;
|
||||
z.wt += val;
|
||||
}
|
||||
// Head zone: top 25%, center 40% of width
|
||||
if (r < Math.floor(rows * 0.25)) {
|
||||
const headLeft = Math.floor(cols * 0.3);
|
||||
const headRight = Math.floor(cols * 0.7);
|
||||
for (let c = headLeft; c <= headRight; c++) {
|
||||
const val = grid[r][c];
|
||||
headMotionX += (c / cols - 0.5) * val;
|
||||
headMotionWeight += val;
|
||||
}
|
||||
}
|
||||
|
||||
// Left arm zone: left 40% of grid
|
||||
for (let c = 0; c < Math.floor(cols * 0.4); c++) {
|
||||
const val = grid[r][c];
|
||||
if (r < rows * 0.3) leftUpperMotion += val * heightWeight;
|
||||
else leftMidMotion += val * heightWeight;
|
||||
leftCount++;
|
||||
}
|
||||
|
||||
// Right arm zone: right 40% of grid
|
||||
for (let c = Math.floor(cols * 0.6); c < cols; c++) {
|
||||
const val = grid[r][c];
|
||||
if (r < rows * 0.3) rightUpperMotion += val * heightWeight;
|
||||
else rightMidMotion += val * heightWeight;
|
||||
rightCount++;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute centroids with fallback defaults
|
||||
const centroid = (z, defX, defY) => ({
|
||||
x: z.wt > 0.01 ? z.wx / z.wt : defX,
|
||||
y: z.wt > 0.01 ? z.wy / z.wt : defY,
|
||||
weight: z.wt
|
||||
});
|
||||
// Normalize
|
||||
const leftTotal = leftUpperMotion + leftMidMotion;
|
||||
const rightTotal = rightUpperMotion + rightMidMotion;
|
||||
const maxMotion = 0.15; // Calibration threshold
|
||||
|
||||
const midX = bx + bw / 2;
|
||||
const midY = by + bh / 2;
|
||||
// Arm height: 0 = at side, 1 = raised
|
||||
// High motion in upper-left → left arm is raised
|
||||
const leftArmHeight = Math.min(1, (leftUpperMotion / maxMotion) * 2);
|
||||
const rightArmHeight = Math.min(1, (rightUpperMotion / maxMotion) * 2);
|
||||
|
||||
// Arm spread: how far out from body
|
||||
const leftArmSpread = Math.min(1, leftTotal / maxMotion);
|
||||
const rightArmSpread = Math.min(1, rightTotal / maxMotion);
|
||||
|
||||
// Head offset
|
||||
const headOffsetX = headMotionWeight > 0.01 ? headMotionX / headMotionWeight : 0;
|
||||
|
||||
return { leftArmHeight, rightArmHeight, leftArmSpread, rightArmSpread, headOffsetX };
|
||||
}
|
||||
|
||||
/**
|
||||
* Analyze lower grid for leg motion.
|
||||
*/
|
||||
_analyzeLegMotion(grid, cols, rows) {
|
||||
const lowerStart = Math.floor(rows * 0.6);
|
||||
let leftMotion = 0, rightMotion = 0;
|
||||
|
||||
for (let r = lowerStart; r < rows; r++) {
|
||||
for (let c = 0; c < Math.floor(cols / 2); c++) {
|
||||
leftMotion += grid[r][c];
|
||||
}
|
||||
for (let c = Math.floor(cols / 2); c < cols; c++) {
|
||||
rightMotion += grid[r][c];
|
||||
}
|
||||
}
|
||||
|
||||
// Return as -1 to 1 range (asymmetry indicates which leg is moving)
|
||||
const total = leftMotion + rightMotion + 0.001;
|
||||
return {
|
||||
head: centroid(zones.head, midX, by + bh * 0.1),
|
||||
leftArm: centroid(zones.leftArm, bx + bw * 0.2, midY - bh * 0.05),
|
||||
rightArm: centroid(zones.rightArm, bx + bw * 0.8, midY - bh * 0.05),
|
||||
torso: centroid(zones.torso, midX, midY),
|
||||
leftLeg: centroid(zones.leftLeg, bx + bw * 0.35,by + bh * 0.75),
|
||||
rightLeg: centroid(zones.rightLeg, bx + bw * 0.65,by + bh * 0.75),
|
||||
left: (leftMotion - rightMotion) / total,
|
||||
right: (rightMotion - leftMotion) / total
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 rUv
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,220 +0,0 @@
|
||||
# ruvector-attention-wasm
|
||||
|
||||
WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Attention Mechanisms**:
|
||||
- Scaled Dot-Product Attention
|
||||
- Multi-Head Attention
|
||||
- Hyperbolic Attention (for hierarchical data)
|
||||
- Linear Attention (Performer-style)
|
||||
- Flash Attention (memory-efficient)
|
||||
- Local-Global Attention
|
||||
- Mixture of Experts (MoE) Attention
|
||||
- **CGT Sheaf Attention** (coherence-gated via Prime-Radiant)
|
||||
|
||||
- **Training Utilities**:
|
||||
- InfoNCE contrastive loss
|
||||
- Adam optimizer
|
||||
- AdamW optimizer (with decoupled weight decay)
|
||||
- Learning rate scheduler (warmup + cosine decay)
|
||||
|
||||
- **TypeScript Support**: Full type definitions and modern API
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-attention-wasm
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
```typescript
|
||||
import { initialize, MultiHeadAttention, utils } from 'ruvector-attention-wasm';
|
||||
|
||||
// Initialize WASM module
|
||||
await initialize();
|
||||
|
||||
// Create multi-head attention
|
||||
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });
|
||||
|
||||
// Prepare inputs
|
||||
const query = new Float32Array(64);
|
||||
const keys = [new Float32Array(64), new Float32Array(64)];
|
||||
const values = [new Float32Array(64), new Float32Array(64)];
|
||||
|
||||
// Compute attention
|
||||
const output = attention.compute(query, keys, values);
|
||||
|
||||
// Use utilities
|
||||
const similarity = utils.cosineSimilarity(query, keys[0]);
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
|
||||
#### Hyperbolic Attention
|
||||
|
||||
```typescript
|
||||
import { HyperbolicAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const hyperbolic = new HyperbolicAttention({
|
||||
dim: 128,
|
||||
curvature: 1.0
|
||||
});
|
||||
|
||||
const output = hyperbolic.compute(query, keys, values);
|
||||
```
|
||||
|
||||
#### MoE Attention with Expert Stats
|
||||
|
||||
```typescript
|
||||
import { MoEAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const moe = new MoEAttention({
|
||||
dim: 64,
|
||||
numExperts: 4,
|
||||
topK: 2
|
||||
});
|
||||
|
||||
const output = moe.compute(query, keys, values);
|
||||
|
||||
// Get expert utilization
|
||||
const stats = moe.getExpertStats();
|
||||
console.log('Load balance:', stats.loadBalance);
|
||||
```
|
||||
|
||||
#### Training with InfoNCE Loss
|
||||
|
||||
```typescript
|
||||
import { InfoNCELoss, Adam } from 'ruvector-attention-wasm';
|
||||
|
||||
const loss = new InfoNCELoss(0.07);
|
||||
const optimizer = new Adam(paramCount, {
|
||||
learningRate: 0.001,
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
});
|
||||
|
||||
// Training loop
|
||||
const lossValue = loss.compute(anchor, positive, negatives);
|
||||
optimizer.step(params, gradients);
|
||||
```
|
||||
|
||||
#### Learning Rate Scheduling
|
||||
|
||||
```typescript
|
||||
import { LRScheduler, AdamW } from 'ruvector-attention-wasm';
|
||||
|
||||
const scheduler = new LRScheduler({
|
||||
initialLR: 0.001,
|
||||
warmupSteps: 1000,
|
||||
totalSteps: 10000,
|
||||
});
|
||||
|
||||
const optimizer = new AdamW(paramCount, {
|
||||
learningRate: scheduler.getLR(),
|
||||
weightDecay: 0.01,
|
||||
});
|
||||
|
||||
// Training loop
|
||||
for (let step = 0; step < 10000; step++) {
|
||||
optimizer.learningRate = scheduler.getLR();
|
||||
optimizer.step(params, gradients);
|
||||
scheduler.step();
|
||||
}
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust 1.70+
|
||||
- wasm-pack
|
||||
|
||||
### Build Commands
|
||||
|
||||
```bash
|
||||
# Build for web (ES modules)
|
||||
wasm-pack build --target web --out-dir pkg
|
||||
|
||||
# Build for Node.js
|
||||
wasm-pack build --target nodejs --out-dir pkg-node
|
||||
|
||||
# Build for bundlers (webpack, vite, etc.)
|
||||
wasm-pack build --target bundler --out-dir pkg-bundler
|
||||
|
||||
# Run tests
|
||||
wasm-pack test --headless --firefox
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Attention Mechanisms
|
||||
|
||||
- `MultiHeadAttention` - Standard multi-head attention
|
||||
- `HyperbolicAttention` - Attention in hyperbolic space
|
||||
- `LinearAttention` - Linear complexity attention (Performer)
|
||||
- `FlashAttention` - Memory-efficient attention
|
||||
- `LocalGlobalAttention` - Combined local and global attention
|
||||
- `MoEAttention` - Mixture of Experts attention
|
||||
- `CGTSheafAttention` - Coherence-gated via Prime-Radiant energy
|
||||
- `scaledDotAttention()` - Functional API for basic attention
|
||||
|
||||
### CGT Sheaf Attention (Prime-Radiant Integration)
|
||||
|
||||
The CGT (Coherence-Gated Transformer) Sheaf Attention mechanism uses Prime-Radiant's sheaf Laplacian energy to gate attention based on mathematical consistency:
|
||||
|
||||
```typescript
|
||||
import { CGTSheafAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const cgtAttention = new CGTSheafAttention({
|
||||
dim: 128,
|
||||
numHeads: 8,
|
||||
coherenceThreshold: 0.3, // Block if energy > threshold
|
||||
});
|
||||
|
||||
// Attention is gated by coherence energy
|
||||
const result = cgtAttention.compute(query, keys, values);
|
||||
console.log('Coherence energy:', result.energy);
|
||||
console.log('Is coherent:', result.isCoherent);
|
||||
```
|
||||
|
||||
**Key features:**
|
||||
- Energy-weighted attention: Lower coherence energy → higher attention
|
||||
- Automatic hallucination detection via residual analysis
|
||||
- GPU-accelerated with wgpu WGSL shaders (vec4 optimized)
|
||||
- SIMD fallback (AVX-512/AVX2/NEON)
|
||||
|
||||
### Training
|
||||
|
||||
- `InfoNCELoss` - Contrastive loss function
|
||||
- `Adam` - Adam optimizer
|
||||
- `AdamW` - AdamW optimizer with weight decay
|
||||
- `LRScheduler` - Learning rate scheduler
|
||||
|
||||
### Utilities
|
||||
|
||||
- `utils.cosineSimilarity()` - Cosine similarity between vectors
|
||||
- `utils.l2Norm()` - L2 norm of a vector
|
||||
- `utils.normalize()` - Normalize vector to unit length
|
||||
- `utils.softmax()` - Apply softmax transformation
|
||||
- `utils.attentionWeights()` - Compute attention weights from scores
|
||||
- `utils.batchNormalize()` - Batch normalization
|
||||
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
|
||||
- `utils.pairwiseDistances()` - Compute pairwise distances
|
||||
|
||||
## Performance
|
||||
|
||||
The WASM bindings provide near-native performance for attention computations:
|
||||
|
||||
- Optimized with `opt-level = "s"` and LTO
|
||||
- SIMD acceleration where available
|
||||
- Efficient memory management
|
||||
- Zero-copy data transfer where possible
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
@@ -1,28 +0,0 @@
|
||||
{
|
||||
"name": "ruvector-attention-wasm",
|
||||
"collaborators": [
|
||||
"Ruvector Team"
|
||||
],
|
||||
"description": "High-performance WebAssembly attention mechanisms: Multi-Head, Flash, Hyperbolic, MoE, CGT Sheaf Attention with GPU acceleration for transformers and LLMs",
|
||||
"version": "2.0.5",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/ruvnet/ruvector"
|
||||
},
|
||||
"files": [
|
||||
"ruvector_attention_wasm_bg.wasm",
|
||||
"ruvector_attention_wasm.js",
|
||||
"ruvector_attention_wasm.d.ts"
|
||||
],
|
||||
"main": "ruvector_attention_wasm.js",
|
||||
"homepage": "https://ruv.io/ruvector",
|
||||
"types": "ruvector_attention_wasm.d.ts",
|
||||
"keywords": [
|
||||
"wasm",
|
||||
"attention",
|
||||
"transformer",
|
||||
"flash-attention",
|
||||
"llm"
|
||||
]
|
||||
}
|
||||
@@ -1,642 +0,0 @@
|
||||
/**
|
||||
* Browser ESM wrapper for ruvector-attention-wasm v2.0.5
|
||||
*
|
||||
* The upstream pkg/ was built with wasm-pack --target nodejs (CJS + fs.readFileSync).
|
||||
* This wrapper loads the same WASM binary via fetch() for browser use.
|
||||
*
|
||||
* Usage:
|
||||
* import initWasm, { WasmMultiHeadAttention, ... } from './ruvector_attention_browser.js';
|
||||
* await initWasm();
|
||||
* const attn = new WasmMultiHeadAttention(dim, heads);
|
||||
*/
|
||||
|
||||
let _wasm;
|
||||
let _initialized = false;
|
||||
|
||||
// The entire CJS module runs inside this IIFE to avoid polluting global scope.
|
||||
// We capture all exports in _mod.
|
||||
const _mod = {};
|
||||
|
||||
(function(exports, wasm_getter) {
|
||||
|
||||
// ── wasm-bindgen heap management ──────────────────────────────────
|
||||
const heap = new Array(128).fill(undefined);
|
||||
heap.push(undefined, null, true, false);
|
||||
let heap_next = heap.length;
|
||||
|
||||
function addHeapObject(obj) {
|
||||
if (heap_next === heap.length) heap.push(heap.length + 1);
|
||||
const idx = heap_next;
|
||||
heap_next = heap[idx];
|
||||
heap[idx] = obj;
|
||||
return idx;
|
||||
}
|
||||
function getObject(idx) { return heap[idx]; }
|
||||
function dropObject(idx) {
|
||||
if (idx < 132) return;
|
||||
heap[idx] = heap_next;
|
||||
heap_next = idx;
|
||||
}
|
||||
function takeObject(idx) {
|
||||
const ret = getObject(idx);
|
||||
dropObject(idx);
|
||||
return ret;
|
||||
}
|
||||
function isLikeNone(x) { return x === undefined || x === null; }
|
||||
|
||||
// ── Memory views ──────────────────────────────────────────────────
|
||||
let cachedDataViewMemory0 = null;
|
||||
let cachedUint8ArrayMemory0 = null;
|
||||
let cachedFloat32ArrayMemory0 = null;
|
||||
|
||||
function wasm() { return wasm_getter(); }
|
||||
|
||||
function getDataViewMemory0() {
|
||||
if (cachedDataViewMemory0 === null || cachedDataViewMemory0.buffer !== wasm().memory.buffer)
|
||||
cachedDataViewMemory0 = new DataView(wasm().memory.buffer);
|
||||
return cachedDataViewMemory0;
|
||||
}
|
||||
function getUint8ArrayMemory0() {
|
||||
if (cachedUint8ArrayMemory0 === null || cachedUint8ArrayMemory0.buffer !== wasm().memory.buffer)
|
||||
cachedUint8ArrayMemory0 = new Uint8Array(wasm().memory.buffer);
|
||||
return cachedUint8ArrayMemory0;
|
||||
}
|
||||
function getFloat32ArrayMemory0() {
|
||||
if (cachedFloat32ArrayMemory0 === null || cachedFloat32ArrayMemory0.buffer !== wasm().memory.buffer)
|
||||
cachedFloat32ArrayMemory0 = new Float32Array(wasm().memory.buffer);
|
||||
return cachedFloat32ArrayMemory0;
|
||||
}
|
||||
function getArrayF32FromWasm0(ptr, len) {
|
||||
ptr = ptr >>> 0;
|
||||
return getFloat32ArrayMemory0().subarray(ptr / 4, ptr / 4 + len);
|
||||
}
|
||||
function getArrayU8FromWasm0(ptr, len) {
|
||||
ptr = ptr >>> 0;
|
||||
return getUint8ArrayMemory0().subarray(ptr, ptr + len);
|
||||
}
|
||||
|
||||
let WASM_VECTOR_LEN = 0;
|
||||
|
||||
function passArrayF32ToWasm0(arg, malloc) {
|
||||
const ptr = malloc(arg.length * 4, 4) >>> 0;
|
||||
getFloat32ArrayMemory0().set(arg, ptr / 4);
|
||||
WASM_VECTOR_LEN = arg.length;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
const cachedTextEncoder = new TextEncoder();
|
||||
const cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true });
|
||||
cachedTextDecoder.decode();
|
||||
|
||||
function getStringFromWasm0(ptr, len) {
|
||||
ptr = ptr >>> 0;
|
||||
return cachedTextDecoder.decode(getUint8ArrayMemory0().subarray(ptr, ptr + len));
|
||||
}
|
||||
|
||||
function passStringToWasm0(arg, malloc, realloc) {
|
||||
const buf = cachedTextEncoder.encode(arg);
|
||||
const ptr = malloc(buf.length, 1) >>> 0;
|
||||
getUint8ArrayMemory0().subarray(ptr, ptr + buf.length).set(buf);
|
||||
WASM_VECTOR_LEN = buf.length;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
function debugString(val) {
|
||||
const type = typeof val;
|
||||
if (type == 'number' || type == 'boolean' || val == null) return `${val}`;
|
||||
if (type == 'string') return `"${val}"`;
|
||||
if (type == 'symbol') return val.description ? `Symbol(${val.description})` : 'Symbol';
|
||||
if (type == 'function') return 'Function';
|
||||
if (Array.isArray(val)) return `[${val.map(debugString).join(', ')}]`;
|
||||
try {
|
||||
const keys = Object.keys(val);
|
||||
return `{${keys.map(k => `${k}: ${debugString(val[k])}`).join(', ')}}`;
|
||||
} catch (_) { return Object.prototype.toString.call(val); }
|
||||
}
|
||||
|
||||
function handleError(f, args) {
|
||||
try { return f.apply(this, args); }
|
||||
catch (e) { wasm().__wbindgen_export3(addHeapObject(e)); }
|
||||
}
|
||||
|
||||
// ── FinalizationRegistry ──────────────────────────────────────────
|
||||
const FR = typeof FinalizationRegistry !== 'undefined'
|
||||
? FinalizationRegistry
|
||||
: class { register() {} unregister() {} };
|
||||
|
||||
const WasmMultiHeadAttentionFinalization = new FR(ptr => wasm().__wbg_wasmmultiheadattention_free(ptr >>> 0, 1));
|
||||
const WasmFlashAttentionFinalization = new FR(ptr => wasm().__wbg_wasmflashattention_free(ptr >>> 0, 1));
|
||||
const WasmHyperbolicAttentionFinalization = new FR(ptr => wasm().__wbg_wasmhyperbolicattention_free(ptr >>> 0, 1));
|
||||
const WasmMoEAttentionFinalization = new FR(ptr => wasm().__wbg_wasmmoeattention_free(ptr >>> 0, 1));
|
||||
const WasmLinearAttentionFinalization = new FR(ptr => wasm().__wbg_wasmlinearattention_free(ptr >>> 0, 1));
|
||||
const WasmLocalGlobalAttentionFinalization = new FR(ptr => wasm().__wbg_wasmlocalglobalattention_free(ptr >>> 0, 1));
|
||||
|
||||
// ── Classes ───────────────────────────────────────────────────────
|
||||
|
||||
class WasmMultiHeadAttention {
|
||||
constructor(dim, num_heads) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
wasm().wasmmultiheadattention_new(retptr, dim, num_heads);
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
if (r2) throw takeObject(r1);
|
||||
this.__wbg_ptr = r0 >>> 0;
|
||||
WasmMultiHeadAttentionFinalization.register(this, this.__wbg_ptr, this);
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
free() {
|
||||
const ptr = this.__wbg_ptr; this.__wbg_ptr = 0;
|
||||
WasmMultiHeadAttentionFinalization.unregister(this);
|
||||
wasm().__wbg_wasmmultiheadattention_free(ptr, 0);
|
||||
}
|
||||
get dim() { return wasm().wasmmultiheadattention_dim(this.__wbg_ptr); }
|
||||
get num_heads() { return wasm().wasmmultiheadattention_num_heads(this.__wbg_ptr); }
|
||||
compute(query, keys, values) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().wasmmultiheadattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class WasmFlashAttention {
|
||||
constructor(dim, block_size) {
|
||||
const ret = wasm().wasmflashattention_new(dim, block_size);
|
||||
this.__wbg_ptr = ret >>> 0;
|
||||
WasmFlashAttentionFinalization.register(this, this.__wbg_ptr, this);
|
||||
}
|
||||
free() {
|
||||
const ptr = this.__wbg_ptr; this.__wbg_ptr = 0;
|
||||
WasmFlashAttentionFinalization.unregister(this);
|
||||
wasm().__wbg_wasmflashattention_free(ptr, 0);
|
||||
}
|
||||
compute(query, keys, values) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().wasmflashattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class WasmHyperbolicAttention {
|
||||
constructor(dim, curvature) {
|
||||
const ret = wasm().wasmhyperbolicattention_new(dim, curvature);
|
||||
this.__wbg_ptr = ret >>> 0;
|
||||
WasmHyperbolicAttentionFinalization.register(this, this.__wbg_ptr, this);
|
||||
}
|
||||
free() {
|
||||
const ptr = this.__wbg_ptr; this.__wbg_ptr = 0;
|
||||
WasmHyperbolicAttentionFinalization.unregister(this);
|
||||
wasm().__wbg_wasmhyperbolicattention_free(ptr, 0);
|
||||
}
|
||||
get curvature() { return wasm().wasmhyperbolicattention_curvature(this.__wbg_ptr); }
|
||||
compute(query, keys, values) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().wasmhyperbolicattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class WasmMoEAttention {
|
||||
constructor(dim, num_experts, top_k) {
|
||||
const ret = wasm().wasmmoeattention_new(dim, num_experts, top_k);
|
||||
this.__wbg_ptr = ret >>> 0;
|
||||
WasmMoEAttentionFinalization.register(this, this.__wbg_ptr, this);
|
||||
}
|
||||
free() {
|
||||
const ptr = this.__wbg_ptr; this.__wbg_ptr = 0;
|
||||
WasmMoEAttentionFinalization.unregister(this);
|
||||
wasm().__wbg_wasmmoeattention_free(ptr, 0);
|
||||
}
|
||||
compute(query, keys, values) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().wasmmoeattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class WasmLinearAttention {
|
||||
constructor(dim, num_features) {
|
||||
const ret = wasm().wasmlinearattention_new(dim, num_features || dim);
|
||||
this.__wbg_ptr = ret >>> 0;
|
||||
WasmLinearAttentionFinalization.register(this, this.__wbg_ptr, this);
|
||||
}
|
||||
free() {
|
||||
const ptr = this.__wbg_ptr; this.__wbg_ptr = 0;
|
||||
WasmLinearAttentionFinalization.unregister(this);
|
||||
wasm().__wbg_wasmlinearattention_free(ptr, 0);
|
||||
}
|
||||
compute(query, keys, values) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().wasmlinearattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class WasmLocalGlobalAttention {
|
||||
constructor(dim, local_window, global_tokens) {
|
||||
const ret = wasm().wasmlocalglobalattention_new(dim, local_window || 4, global_tokens || 2);
|
||||
this.__wbg_ptr = ret >>> 0;
|
||||
WasmLocalGlobalAttentionFinalization.register(this, this.__wbg_ptr, this);
|
||||
}
|
||||
free() {
|
||||
const ptr = this.__wbg_ptr; this.__wbg_ptr = 0;
|
||||
WasmLocalGlobalAttentionFinalization.unregister(this);
|
||||
wasm().__wbg_wasmlocalglobalattention_free(ptr, 0);
|
||||
}
|
||||
compute(query, keys, values) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().wasmlocalglobalattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Standalone functions ──────────────────────────────────────────
|
||||
|
||||
function cosine_similarity(a, b) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(a, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
const ptr1 = passArrayF32ToWasm0(b, wasm().__wbindgen_export);
|
||||
const len1 = WASM_VECTOR_LEN;
|
||||
wasm().cosine_similarity(retptr, ptr0, len0, ptr1, len1);
|
||||
var r0 = getDataViewMemory0().getFloat64(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r2) throw takeObject(r1);
|
||||
return r0;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(vec) {
|
||||
const ptr0 = passArrayF32ToWasm0(vec, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().normalize(ptr0, len0, addHeapObject(vec));
|
||||
}
|
||||
|
||||
function l2_norm(vec) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(vec, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().l2_norm(retptr, ptr0, len0);
|
||||
var r0 = getDataViewMemory0().getFloat64(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r2) throw takeObject(r1);
|
||||
return r0;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
|
||||
function softmax(vec) {
|
||||
const ptr0 = passArrayF32ToWasm0(vec, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().softmax(ptr0, len0, addHeapObject(vec));
|
||||
}
|
||||
|
||||
function batch_normalize(vectors, epsilon) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
wasm().batch_normalize(retptr, addHeapObject(vectors), isLikeNone(epsilon) ? 0x100000001 : Math.fround(epsilon));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
|
||||
function pairwise_distances(vectors) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
wasm().pairwise_distances(retptr, addHeapObject(vectors));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
|
||||
function scaled_dot_attention(query, keys, values, scale) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
const ptr0 = passArrayF32ToWasm0(query, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().scaled_dot_attention(retptr, ptr0, len0, addHeapObject(keys), addHeapObject(values), isLikeNone(scale) ? 0x100000001 : Math.fround(scale));
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var r2 = getDataViewMemory0().getInt32(retptr + 8, true);
|
||||
var r3 = getDataViewMemory0().getInt32(retptr + 12, true);
|
||||
if (r3) throw takeObject(r2);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
|
||||
function attention_weights(scores, temperature) {
|
||||
const ptr0 = passArrayF32ToWasm0(scores, wasm().__wbindgen_export);
|
||||
const len0 = WASM_VECTOR_LEN;
|
||||
wasm().attention_weights(ptr0, len0, addHeapObject(scores), isLikeNone(temperature) ? 0x100000001 : Math.fround(temperature));
|
||||
}
|
||||
|
||||
function available_mechanisms() {
|
||||
const ret = wasm().available_mechanisms();
|
||||
return takeObject(ret);
|
||||
}
|
||||
|
||||
function random_orthogonal_matrix(dim) {
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
wasm().random_orthogonal_matrix(retptr, dim);
|
||||
var r0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
var r1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
var v1 = getArrayF32FromWasm0(r0, r1).slice();
|
||||
wasm().__wbindgen_export4(r0, r1 * 4, 4);
|
||||
return v1;
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
}
|
||||
}
|
||||
|
||||
function rv_init() { wasm().init(); }
|
||||
|
||||
function rv_version() {
|
||||
let d0, d1;
|
||||
const retptr = wasm().__wbindgen_add_to_stack_pointer(-16);
|
||||
try {
|
||||
wasm().version(retptr);
|
||||
d0 = getDataViewMemory0().getInt32(retptr + 0, true);
|
||||
d1 = getDataViewMemory0().getInt32(retptr + 4, true);
|
||||
return getStringFromWasm0(d0, d1);
|
||||
} finally {
|
||||
wasm().__wbindgen_add_to_stack_pointer(16);
|
||||
if (d0 !== undefined) wasm().__wbindgen_export4(d0, d1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Collect exports ───────────────────────────────────────────────
|
||||
exports.WasmMultiHeadAttention = WasmMultiHeadAttention;
|
||||
exports.WasmFlashAttention = WasmFlashAttention;
|
||||
exports.WasmHyperbolicAttention = WasmHyperbolicAttention;
|
||||
exports.WasmMoEAttention = WasmMoEAttention;
|
||||
exports.WasmLinearAttention = WasmLinearAttention;
|
||||
exports.WasmLocalGlobalAttention = WasmLocalGlobalAttention;
|
||||
exports.cosine_similarity = cosine_similarity;
|
||||
exports.normalize = normalize;
|
||||
exports.l2_norm = l2_norm;
|
||||
exports.softmax = softmax;
|
||||
exports.batch_normalize = batch_normalize;
|
||||
exports.pairwise_distances = pairwise_distances;
|
||||
exports.scaled_dot_attention = scaled_dot_attention;
|
||||
exports.attention_weights = attention_weights;
|
||||
exports.available_mechanisms = available_mechanisms;
|
||||
exports.random_orthogonal_matrix = random_orthogonal_matrix;
|
||||
exports.init = rv_init;
|
||||
exports.version = rv_version;
|
||||
|
||||
// ── Build WASM import object ──────────────────────────────────────
|
||||
exports.__wbg_get_imports = function() {
|
||||
const import0 = {
|
||||
__proto__: null,
|
||||
__wbg_Error_4577686b3a6d9b3a: (arg0, arg1) => addHeapObject(Error(getStringFromWasm0(arg0, arg1))),
|
||||
__wbg_String_8564e559799eccda: (arg0, arg1) => {
|
||||
const ret = String(getObject(arg1));
|
||||
const ptr1 = passStringToWasm0(ret, wasm().__wbindgen_export, wasm().__wbindgen_export2);
|
||||
const len1 = WASM_VECTOR_LEN;
|
||||
getDataViewMemory0().setInt32(arg0 + 4, len1, true);
|
||||
getDataViewMemory0().setInt32(arg0, ptr1, true);
|
||||
},
|
||||
__wbg___wbindgen_boolean_get_18c4ed9422296fff: (arg0) => {
|
||||
const v = getObject(arg0);
|
||||
const ret = typeof v === 'boolean' ? v : undefined;
|
||||
return isLikeNone(ret) ? 0xFFFFFF : ret ? 1 : 0;
|
||||
},
|
||||
__wbg___wbindgen_copy_to_typed_array_5294f8e46aecc086: (arg0, arg1, arg2) => {
|
||||
new Uint8Array(getObject(arg2).buffer, getObject(arg2).byteOffset, getObject(arg2).byteLength).set(getArrayU8FromWasm0(arg0, arg1));
|
||||
},
|
||||
__wbg___wbindgen_debug_string_ddde1867f49c2442: (arg0, arg1) => {
|
||||
const ret = debugString(getObject(arg1));
|
||||
const ptr1 = passStringToWasm0(ret, wasm().__wbindgen_export, wasm().__wbindgen_export2);
|
||||
const len1 = WASM_VECTOR_LEN;
|
||||
getDataViewMemory0().setInt32(arg0 + 4, len1, true);
|
||||
getDataViewMemory0().setInt32(arg0, ptr1, true);
|
||||
},
|
||||
__wbg___wbindgen_is_function_d633e708baf0d146: (arg0) => typeof getObject(arg0) === 'function',
|
||||
__wbg___wbindgen_is_object_4b3de556756ee8a8: (arg0) => {
|
||||
const val = getObject(arg0);
|
||||
return typeof val === 'object' && val !== null;
|
||||
},
|
||||
__wbg___wbindgen_jsval_loose_eq_1562ceb9af84e990: (arg0, arg1) => getObject(arg0) == getObject(arg1),
|
||||
__wbg___wbindgen_number_get_5854912275df1894: (arg0, arg1) => {
|
||||
const obj = getObject(arg1);
|
||||
const ret = typeof obj === 'number' ? obj : undefined;
|
||||
getDataViewMemory0().setFloat64(arg0 + 8, isLikeNone(ret) ? 0 : ret, true);
|
||||
getDataViewMemory0().setInt32(arg0, !isLikeNone(ret), true);
|
||||
},
|
||||
__wbg___wbindgen_string_get_3e5751597f39a112: (arg0, arg1) => {
|
||||
const obj = getObject(arg1);
|
||||
const ret = typeof obj === 'string' ? obj : undefined;
|
||||
var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm().__wbindgen_export, wasm().__wbindgen_export2);
|
||||
var len1 = WASM_VECTOR_LEN;
|
||||
getDataViewMemory0().setInt32(arg0 + 4, len1, true);
|
||||
getDataViewMemory0().setInt32(arg0, ptr1, true);
|
||||
},
|
||||
__wbg___wbindgen_throw_39bc967c0e5a9b58: (arg0, arg1) => { throw new Error(getStringFromWasm0(arg0, arg1)); },
|
||||
__wbg_call_73af281463ec8b58: function() { return handleError(function(arg0, arg1) {
|
||||
return addHeapObject(getObject(arg0).call(getObject(arg1)));
|
||||
}, arguments); },
|
||||
__wbg_done_5aad55ec6b1954b1: (arg0) => getObject(arg0).done,
|
||||
__wbg_error_a6fa202b58aa1cd3: (arg0, arg1) => {
|
||||
try { console.error(getStringFromWasm0(arg0, arg1)); }
|
||||
finally { wasm().__wbindgen_export4(arg0, arg1, 1); }
|
||||
},
|
||||
__wbg_error_ad28debb48b5c6bb: (arg0) => console.error(getObject(arg0)),
|
||||
__wbg_get_4920fefd3451364b: function() { return handleError(function(arg0, arg1) {
|
||||
return addHeapObject(Reflect.get(getObject(arg0), getObject(arg1)));
|
||||
}, arguments); },
|
||||
__wbg_get_unchecked_3d0f4b91c8eca4f0: (arg0, arg1) => addHeapObject(getObject(arg0)[arg1 >>> 0]),
|
||||
__wbg_instanceof_ArrayBuffer_15859862b80b732d: (arg0) => {
|
||||
try { return getObject(arg0) instanceof ArrayBuffer; } catch (_) { return false; }
|
||||
},
|
||||
__wbg_instanceof_Uint8Array_2240b7046ac16f05: (arg0) => {
|
||||
try { return getObject(arg0) instanceof Uint8Array; } catch (_) { return false; }
|
||||
},
|
||||
__wbg_isArray_fad08a0d12828686: (arg0) => Array.isArray(getObject(arg0)),
|
||||
__wbg_iterator_fc7ad8d33bab9e26: () => addHeapObject(Symbol.iterator),
|
||||
__wbg_length_5855c1f289dfffc1: (arg0) => getObject(arg0).length,
|
||||
__wbg_length_a31e05262e09b7f8: (arg0) => getObject(arg0).length,
|
||||
__wbg_log_3c5e4b64af29e724: (arg0) => console.log(getObject(arg0)),
|
||||
__wbg_new_09959f7b4c92c246: (arg0) => addHeapObject(new Uint8Array(getObject(arg0))),
|
||||
__wbg_new_227d7c05414eb861: () => addHeapObject(new Error()),
|
||||
__wbg_new_cbee8c0d5c479eac: () => addHeapObject(new Array()),
|
||||
__wbg_next_a5fe6f328f7affc2: (arg0) => addHeapObject(getObject(arg0).next),
|
||||
__wbg_next_e592122bb4ed4c67: function() { return handleError(function(arg0) {
|
||||
return addHeapObject(getObject(arg0).next());
|
||||
}, arguments); },
|
||||
__wbg_prototypesetcall_f034d444741426c3: (arg0, arg1, arg2) => {
|
||||
Uint8Array.prototype.set.call(getArrayU8FromWasm0(arg0, arg1), getObject(arg2));
|
||||
},
|
||||
__wbg_random_2b7bed8995d680fb: () => Math.random(),
|
||||
__wbg_set_4c81cfb5dc3a333c: (arg0, arg1, arg2) => { getObject(arg0)[arg1 >>> 0] = takeObject(arg2); },
|
||||
__wbg_stack_3b0d974bbf31e44f: (arg0, arg1) => {
|
||||
const ret = getObject(arg1).stack;
|
||||
const ptr1 = passStringToWasm0(ret, wasm().__wbindgen_export, wasm().__wbindgen_export2);
|
||||
const len1 = WASM_VECTOR_LEN;
|
||||
getDataViewMemory0().setInt32(arg0 + 4, len1, true);
|
||||
getDataViewMemory0().setInt32(arg0, ptr1, true);
|
||||
},
|
||||
__wbg_value_667dcb90597486a6: (arg0) => addHeapObject(getObject(arg0).value),
|
||||
__wbindgen_cast_0000000000000001: (arg0, arg1) => addHeapObject(getStringFromWasm0(arg0, arg1)),
|
||||
__wbindgen_object_drop_ref: (arg0) => takeObject(arg0),
|
||||
};
|
||||
return { __proto__: null, "./ruvector_attention_wasm_bg.js": import0 };
|
||||
};
|
||||
|
||||
})(_mod, () => _wasm);
|
||||
|
||||
|
||||
// ── Async WASM init (fetch-based for browsers) ───────────────────
|
||||
|
||||
export default async function initWasm() {
|
||||
if (_initialized) return;
|
||||
const wasmUrl = new URL('ruvector_attention_wasm_bg.wasm', import.meta.url);
|
||||
const imports = _mod.__wbg_get_imports();
|
||||
let result;
|
||||
if (typeof WebAssembly.instantiateStreaming === 'function') {
|
||||
try {
|
||||
result = await WebAssembly.instantiateStreaming(fetch(wasmUrl), imports);
|
||||
} catch (e) {
|
||||
// Fallback if streaming fails (e.g. wrong MIME type)
|
||||
const bytes = await (await fetch(wasmUrl)).arrayBuffer();
|
||||
result = await WebAssembly.instantiate(bytes, imports);
|
||||
}
|
||||
} else {
|
||||
const bytes = await (await fetch(wasmUrl)).arrayBuffer();
|
||||
result = await WebAssembly.instantiate(bytes, imports);
|
||||
}
|
||||
_wasm = result.instance.exports;
|
||||
_wasm.__wbindgen_start();
|
||||
_initialized = true;
|
||||
}
|
||||
|
||||
// ── ESM re-exports ────────────────────────────────────────────────
|
||||
// Attention mechanism classes
|
||||
export const WasmMultiHeadAttention = _mod.WasmMultiHeadAttention;
|
||||
export const WasmFlashAttention = _mod.WasmFlashAttention;
|
||||
export const WasmHyperbolicAttention = _mod.WasmHyperbolicAttention;
|
||||
export const WasmMoEAttention = _mod.WasmMoEAttention;
|
||||
export const WasmLinearAttention = _mod.WasmLinearAttention;
|
||||
export const WasmLocalGlobalAttention = _mod.WasmLocalGlobalAttention;
|
||||
// Utility functions
|
||||
export const cosine_similarity = _mod.cosine_similarity;
|
||||
export const normalize = _mod.normalize;
|
||||
export const l2_norm = _mod.l2_norm;
|
||||
export const softmax = _mod.softmax;
|
||||
export const batch_normalize = _mod.batch_normalize;
|
||||
export const pairwise_distances = _mod.pairwise_distances;
|
||||
export const scaled_dot_attention = _mod.scaled_dot_attention;
|
||||
export const attention_weights = _mod.attention_weights;
|
||||
export const random_orthogonal_matrix = _mod.random_orthogonal_matrix;
|
||||
export const available_mechanisms = _mod.available_mechanisms;
|
||||
// Lifecycle
|
||||
export const init = _mod.init;
|
||||
export const version = _mod.version;
|
||||
@@ -1,359 +0,0 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Adam optimizer
|
||||
*/
|
||||
export class WasmAdam {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new Adam optimizer
|
||||
*
|
||||
* # Arguments
|
||||
* * `param_count` - Number of parameters
|
||||
* * `learning_rate` - Learning rate
|
||||
*/
|
||||
constructor(param_count: number, learning_rate: number);
|
||||
/**
|
||||
* Reset optimizer state
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Perform optimization step
|
||||
*
|
||||
* # Arguments
|
||||
* * `params` - Current parameter values (will be updated in-place)
|
||||
* * `gradients` - Gradient values
|
||||
*/
|
||||
step(params: Float32Array, gradients: Float32Array): void;
|
||||
/**
|
||||
* Get current learning rate
|
||||
*/
|
||||
learning_rate: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* AdamW optimizer (Adam with decoupled weight decay)
|
||||
*/
|
||||
export class WasmAdamW {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new AdamW optimizer
|
||||
*
|
||||
* # Arguments
|
||||
* * `param_count` - Number of parameters
|
||||
* * `learning_rate` - Learning rate
|
||||
* * `weight_decay` - Weight decay coefficient
|
||||
*/
|
||||
constructor(param_count: number, learning_rate: number, weight_decay: number);
|
||||
/**
|
||||
* Reset optimizer state
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Perform optimization step with weight decay
|
||||
*/
|
||||
step(params: Float32Array, gradients: Float32Array): void;
|
||||
/**
|
||||
* Get current learning rate
|
||||
*/
|
||||
learning_rate: number;
|
||||
/**
|
||||
* Get weight decay
|
||||
*/
|
||||
readonly weight_decay: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Flash attention mechanism
|
||||
*/
|
||||
export class WasmFlashAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute flash attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new flash attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `block_size` - Block size for tiling
|
||||
*/
|
||||
constructor(dim: number, block_size: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Hyperbolic attention mechanism
|
||||
*/
|
||||
export class WasmHyperbolicAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute hyperbolic attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new hyperbolic attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `curvature` - Hyperbolic curvature parameter
|
||||
*/
|
||||
constructor(dim: number, curvature: number);
|
||||
/**
|
||||
* Get the curvature
|
||||
*/
|
||||
readonly curvature: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* InfoNCE contrastive loss for training
|
||||
*/
|
||||
export class WasmInfoNCELoss {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute InfoNCE loss
|
||||
*
|
||||
* # Arguments
|
||||
* * `anchor` - Anchor embedding
|
||||
* * `positive` - Positive example embedding
|
||||
* * `negatives` - Array of negative example embeddings
|
||||
*/
|
||||
compute(anchor: Float32Array, positive: Float32Array, negatives: any): number;
|
||||
/**
|
||||
* Create a new InfoNCE loss instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `temperature` - Temperature parameter for softmax
|
||||
*/
|
||||
constructor(temperature: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Learning rate scheduler
|
||||
*/
|
||||
export class WasmLRScheduler {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Get learning rate for current step
|
||||
*/
|
||||
get_lr(): number;
|
||||
/**
|
||||
* Create a new learning rate scheduler with warmup and cosine decay
|
||||
*
|
||||
* # Arguments
|
||||
* * `initial_lr` - Initial learning rate
|
||||
* * `warmup_steps` - Number of warmup steps
|
||||
* * `total_steps` - Total training steps
|
||||
*/
|
||||
constructor(initial_lr: number, warmup_steps: number, total_steps: number);
|
||||
/**
|
||||
* Reset scheduler
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Advance to next step
|
||||
*/
|
||||
step(): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Linear attention (Performer-style)
|
||||
*/
|
||||
export class WasmLinearAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute linear attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new linear attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_features` - Number of random features
|
||||
*/
|
||||
constructor(dim: number, num_features: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Local-global attention mechanism
|
||||
*/
|
||||
export class WasmLocalGlobalAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute local-global attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new local-global attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `local_window` - Size of local attention window
|
||||
* * `global_tokens` - Number of global attention tokens
|
||||
*/
|
||||
constructor(dim: number, local_window: number, global_tokens: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Mixture of Experts (MoE) attention
|
||||
*/
|
||||
export class WasmMoEAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute MoE attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new MoE attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_experts` - Number of expert attention mechanisms
|
||||
* * `top_k` - Number of experts to use per query
|
||||
*/
|
||||
constructor(dim: number, num_experts: number, top_k: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Multi-head attention mechanism
|
||||
*/
|
||||
export class WasmMultiHeadAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute multi-head attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new multi-head attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_heads` - Number of attention heads
|
||||
*/
|
||||
constructor(dim: number, num_heads: number);
|
||||
/**
|
||||
* Get the dimension
|
||||
*/
|
||||
readonly dim: number;
|
||||
/**
|
||||
* Get the number of heads
|
||||
*/
|
||||
readonly num_heads: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* SGD optimizer with momentum
|
||||
*/
|
||||
export class WasmSGD {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new SGD optimizer
|
||||
*
|
||||
* # Arguments
|
||||
* * `param_count` - Number of parameters
|
||||
* * `learning_rate` - Learning rate
|
||||
* * `momentum` - Momentum coefficient (default: 0)
|
||||
*/
|
||||
constructor(param_count: number, learning_rate: number, momentum?: number | null);
|
||||
/**
|
||||
* Reset optimizer state
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Perform optimization step
|
||||
*/
|
||||
step(params: Float32Array, gradients: Float32Array): void;
|
||||
/**
|
||||
* Get current learning rate
|
||||
*/
|
||||
learning_rate: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute attention weights from scores
|
||||
*/
|
||||
export function attention_weights(scores: Float32Array, temperature?: number | null): void;
|
||||
|
||||
/**
|
||||
* Get information about available attention mechanisms
|
||||
*/
|
||||
export function available_mechanisms(): any;
|
||||
|
||||
/**
|
||||
* Batch normalize vectors
|
||||
*/
|
||||
export function batch_normalize(vectors: any, epsilon?: number | null): Float32Array;
|
||||
|
||||
/**
|
||||
* Compute cosine similarity between two vectors
|
||||
*/
|
||||
export function cosine_similarity(a: Float32Array, b: Float32Array): number;
|
||||
|
||||
/**
|
||||
* Initialize the WASM module with panic hook
|
||||
*/
|
||||
export function init(): void;
|
||||
|
||||
/**
|
||||
* Compute L2 norm of a vector
|
||||
*/
|
||||
export function l2_norm(vec: Float32Array): number;
|
||||
|
||||
/**
|
||||
* Log a message to the browser console
|
||||
*/
|
||||
export function log(message: string): void;
|
||||
|
||||
/**
|
||||
* Log an error to the browser console
|
||||
*/
|
||||
export function log_error(message: string): void;
|
||||
|
||||
/**
|
||||
* Normalize a vector to unit length
|
||||
*/
|
||||
export function normalize(vec: Float32Array): void;
|
||||
|
||||
/**
|
||||
* Compute pairwise distances between vectors
|
||||
*/
|
||||
export function pairwise_distances(vectors: any): Float32Array;
|
||||
|
||||
/**
|
||||
* Generate random orthogonal matrix (for initialization)
|
||||
*/
|
||||
export function random_orthogonal_matrix(dim: number): Float32Array;
|
||||
|
||||
/**
|
||||
* Compute scaled dot-product attention
|
||||
*
|
||||
* # Arguments
|
||||
* * `query` - Query vector as Float32Array
|
||||
* * `keys` - Array of key vectors
|
||||
* * `values` - Array of value vectors
|
||||
* * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
*/
|
||||
export function scaled_dot_attention(query: Float32Array, keys: any, values: any, scale?: number | null): Float32Array;
|
||||
|
||||
/**
|
||||
* Compute softmax of a vector
|
||||
*/
|
||||
export function softmax(vec: Float32Array): void;
|
||||
|
||||
/**
|
||||
* Get the version of the ruvector-attention-wasm crate
|
||||
*/
|
||||
export function version(): string;
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -1,71 +0,0 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
export const memory: WebAssembly.Memory;
|
||||
export const __wbg_wasmadam_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmadamw_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmflashattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmhyperbolicattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasminfonceloss_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmlinearattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmoeattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmultiheadattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmsgd_free: (a: number, b: number) => void;
|
||||
export const attention_weights: (a: number, b: number, c: number, d: number) => void;
|
||||
export const available_mechanisms: () => number;
|
||||
export const batch_normalize: (a: number, b: number, c: number) => void;
|
||||
export const cosine_similarity: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const l2_norm: (a: number, b: number) => number;
|
||||
export const log: (a: number, b: number) => void;
|
||||
export const log_error: (a: number, b: number) => void;
|
||||
export const normalize: (a: number, b: number, c: number, d: number) => void;
|
||||
export const pairwise_distances: (a: number, b: number) => void;
|
||||
export const random_orthogonal_matrix: (a: number, b: number) => void;
|
||||
export const scaled_dot_attention: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const softmax: (a: number, b: number, c: number) => void;
|
||||
export const version: (a: number) => void;
|
||||
export const wasmadam_learning_rate: (a: number) => number;
|
||||
export const wasmadam_new: (a: number, b: number) => number;
|
||||
export const wasmadam_reset: (a: number) => void;
|
||||
export const wasmadam_set_learning_rate: (a: number, b: number) => void;
|
||||
export const wasmadam_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmadamw_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmadamw_reset: (a: number) => void;
|
||||
export const wasmadamw_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmadamw_weight_decay: (a: number) => number;
|
||||
export const wasmflashattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmflashattention_new: (a: number, b: number) => number;
|
||||
export const wasmhyperbolicattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmhyperbolicattention_curvature: (a: number) => number;
|
||||
export const wasmhyperbolicattention_new: (a: number, b: number) => number;
|
||||
export const wasminfonceloss_compute: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
|
||||
export const wasminfonceloss_new: (a: number) => number;
|
||||
export const wasmlinearattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmlinearattention_new: (a: number, b: number) => number;
|
||||
export const wasmlocalglobalattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmlocalglobalattention_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmlrscheduler_get_lr: (a: number) => number;
|
||||
export const wasmlrscheduler_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmlrscheduler_reset: (a: number) => void;
|
||||
export const wasmlrscheduler_step: (a: number) => void;
|
||||
export const wasmmoeattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmmoeattention_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmmultiheadattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmmultiheadattention_dim: (a: number) => number;
|
||||
export const wasmmultiheadattention_new: (a: number, b: number, c: number) => void;
|
||||
export const wasmmultiheadattention_num_heads: (a: number) => number;
|
||||
export const wasmsgd_learning_rate: (a: number) => number;
|
||||
export const wasmsgd_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmsgd_reset: (a: number) => void;
|
||||
export const wasmsgd_set_learning_rate: (a: number, b: number) => void;
|
||||
export const wasmsgd_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const init: () => void;
|
||||
export const wasmadamw_set_learning_rate: (a: number, b: number) => void;
|
||||
export const wasmadamw_learning_rate: (a: number) => number;
|
||||
export const __wbg_wasmlocalglobalattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmlrscheduler_free: (a: number, b: number) => void;
|
||||
export const __wbindgen_export: (a: number, b: number) => number;
|
||||
export const __wbindgen_export2: (a: number, b: number, c: number, d: number) => number;
|
||||
export const __wbindgen_export3: (a: number) => void;
|
||||
export const __wbindgen_export4: (a: number, b: number, c: number) => void;
|
||||
export const __wbindgen_add_to_stack_pointer: (a: number) => number;
|
||||
export const __wbindgen_start: () => void;
|
||||
Reference in New Issue
Block a user