diff --git a/scripts/gcp/cosmos_eval.sh b/scripts/gcp/cosmos_eval.sh new file mode 100755 index 00000000..e2f66f5a --- /dev/null +++ b/scripts/gcp/cosmos_eval.sh @@ -0,0 +1,330 @@ +#!/usr/bin/env bash +# Run Cosmos-Transfer2.5-2B evaluation on GCP A100 80GB instance +# Usage: bash scripts/gcp/cosmos_eval.sh [--snapshot-dir ] +# +# Flow: +# 1. Start OccWorld sensing server on remote (generates control tensors) +# 2. Rsync RuView scripts + any local control tensors to instance +# 3. Run Cosmos-Transfer2.5 inference with depth+seg control signals +# 4. Download generated video and decoded trajectory priors +# 5. Benchmark inference time (A100 actual vs RTX 5080 estimate) + +set -euo pipefail + +# ── Usage ───────────────────────────────────────────────────────────────────── +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [--snapshot-dir ] [--no-server]" >&2 + echo "" + echo " INSTANCE_IP External IP of the cosmos-eval GCP instance" + echo " --snapshot-dir Local snapshot dir to upload as control input" + echo " (default: ./out/snapshots if it exists)" + echo " --no-server Skip starting the OccWorld server on remote" + echo "" + echo "Example:" + echo " $0 34.123.45.67 --snapshot-dir /tmp/snapshots" + exit 1 +fi + +INSTANCE_IP="$1" +shift + +SNAPSHOT_DIR="./out/snapshots" +START_SERVER=true + +while [[ $# -gt 0 ]]; do + case "$1" in + --snapshot-dir) SNAPSHOT_DIR="$2"; shift 2 ;; + --no-server) START_SERVER=false; shift ;; + -h|--help) + echo "Usage: $0 [--snapshot-dir ] [--no-server]" + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + exit 1 + ;; + esac +done + +GCP_USER="${GCP_USER:-$(gcloud config get-value account 2>/dev/null | cut -d@ -f1)}" +REMOTE="${GCP_USER}@${INSTANCE_IP}" +SSH_OPTS="-o StrictHostKeyChecking=no -o ConnectTimeout=20 -o BatchMode=yes" +LOCAL_SCRIPTS_DIR="$(cd "$(dirname "$0")/../.." && pwd)/scripts" +OUTPUT_DIR="./out/cosmos-results" +REMOTE_RESULTS="~/cosmos-results" +REMOTE_SCRIPTS="~/ruview-scripts" +REMOTE_CONTROL="~/control-tensors" +COSMOS_MODEL_DIR="/opt/models/cosmos-transfer2.5-2b" + +log() { echo "[cosmos_eval] $*"; } + +# ── SSH connectivity check ──────────────────────────────────────────────────── +log "Checking SSH connectivity to $REMOTE ..." +if ! ssh $SSH_OPTS "$REMOTE" "echo ok" &>/dev/null; then + echo "ERROR: Cannot SSH to $REMOTE" >&2 + echo " Ensure the instance is running: gcloud compute instances list --project=cognitum-20260110" >&2 + exit 1 +fi +log "SSH connection OK" + +# ── Verify startup completed ────────────────────────────────────────────────── +log "Checking Cosmos startup log ..." +COSMOS_READY=$(ssh $SSH_OPTS "$REMOTE" \ + "grep -c 'setup complete' /var/log/cosmos-startup.log 2>/dev/null || echo 0") +if [[ "$COSMOS_READY" -lt 1 ]]; then + log "WARNING: Cosmos startup may not be complete." + log " Check: ssh $REMOTE 'tail -20 /var/log/cosmos-startup.log'" +fi + +# Verify model weights exist +MODEL_EXISTS=$(ssh $SSH_OPTS "$REMOTE" \ + "test -d $COSMOS_MODEL_DIR && find $COSMOS_MODEL_DIR -name '*.safetensors' -o -name '*.bin' 2>/dev/null | wc -l || echo 0") +if [[ "$MODEL_EXISTS" -lt 1 ]]; then + echo "ERROR: Cosmos-Transfer2.5-2B weights not found at $COSMOS_MODEL_DIR on remote." >&2 + echo " The startup script may still be downloading (can take 30-60 min)." >&2 + echo " Monitor: ssh $REMOTE 'tail -f /var/log/cosmos-startup.log'" >&2 + exit 1 +fi +log "Model weights verified ($MODEL_EXISTS files in $COSMOS_MODEL_DIR)" + +# ── Rsync scripts to remote ─────────────────────────────────────────────────── +log "Rsyncing RuView scripts → $REMOTE:$REMOTE_SCRIPTS ..." +ssh $SSH_OPTS "$REMOTE" "mkdir -p $REMOTE_SCRIPTS $REMOTE_CONTROL $REMOTE_RESULTS" +rsync -avz \ + -e "ssh $SSH_OPTS" \ + --include="occworld_retrain.py" \ + --include="occworld_server.py" \ + --include="ruview_occ_dataset.py" \ + --exclude="gcp/" \ + --exclude="*.sh" \ + "$LOCAL_SCRIPTS_DIR/" \ + "${REMOTE}:${REMOTE_SCRIPTS}/" + +# ── Rsync local snapshots as control input (if they exist) ──────────────────── +if [[ -d "$SNAPSHOT_DIR" ]]; then + SNAP_COUNT=$(find "$SNAPSHOT_DIR" -name "*.json" 2>/dev/null | wc -l) + log "Rsyncing $SNAP_COUNT snapshots from $SNAPSHOT_DIR → remote control-tensors ..." + rsync -avz \ + -e "ssh $SSH_OPTS" \ + "$SNAPSHOT_DIR/" \ + "${REMOTE}:${REMOTE_CONTROL}/snapshots/" +else + log "No local snapshot dir found at $SNAPSHOT_DIR — will use synthetic control tensors on remote" +fi + +# ── Stage 1: Start OccWorld sensing server on remote ───────────────────────── +if [[ "$START_SERVER" == "true" ]]; then + log "=== Stage 1: Starting OccWorld sensing server on remote ===" + # Kill any previous server + ssh $SSH_OPTS "$REMOTE" "pkill -f occworld_server.py || true" + + ssh $SSH_OPTS "$REMOTE" bash << 'REMOTE_SERVER' +set -euo pipefail +source /opt/conda/etc/profile.d/conda.sh +conda activate occworld 2>/dev/null || conda activate cosmos + +export PYTHONPATH="$PYTHONPATH:$HOME/ruview-scripts" + +echo "[server] Starting OccWorld server in background ..." +nohup python3 ~/ruview-scripts/occworld_server.py \ + --port 8080 \ + --snapshot-dir ~/control-tensors/snapshots \ + >> ~/occworld-server.log 2>&1 & + +echo "[server] PID=$!" +sleep 3 + +# Verify it started +if curl -sf http://localhost:8080/health >/dev/null 2>&1; then + echo "[server] OccWorld server is up on port 8080" +else + echo "[server] WARNING: health check failed — server may still be starting" + tail -20 ~/occworld-server.log || true +fi +REMOTE_SERVER + log "OccWorld server started on remote" +fi + +# ── Stage 2: Generate control tensors (depth + seg) ────────────────────────── +log "=== Stage 2: Generating RuView depth+seg control tensors ===" +CONTROL_START=$(date +%s) + +ssh $SSH_OPTS "$REMOTE" bash << 'REMOTE_CONTROL_GEN' +set -euo pipefail +source /opt/conda/etc/profile.d/conda.sh +conda activate occworld 2>/dev/null || conda activate cosmos + +export PYTHONPATH="$PYTHONPATH:$HOME/ruview-scripts" +mkdir -p ~/control-tensors/depth ~/control-tensors/seg + +echo "[control] $(date): generating control tensors from snapshots ..." + +# Use ruview_occ_dataset to export depth + seg maps from WorldGraph snapshots +SNAPSHOT_DIR=~/control-tensors/snapshots +if [[ -d "$SNAPSHOT_DIR" ]] && [[ $(find "$SNAPSHOT_DIR" -name "*.json" | wc -l) -gt 0 ]]; then + python3 ~/ruview-scripts/ruview_occ_dataset.py \ + --snapshots "$SNAPSHOT_DIR" \ + --export-depth ~/control-tensors/depth \ + --export-seg ~/control-tensors/seg \ + --check \ + || echo "[control] WARNING: export flag not supported — using raw snapshots directly" +else + echo "[control] No snapshots found — generating synthetic control tensors for benchmark" + python3 - << 'SYNTH_EOF' +import numpy as np, os, json +from pathlib import Path + +depth_dir = Path(os.path.expanduser("~/control-tensors/depth")) +seg_dir = Path(os.path.expanduser("~/control-tensors/seg")) +depth_dir.mkdir(parents=True, exist_ok=True) +seg_dir.mkdir(parents=True, exist_ok=True) + +rng = np.random.default_rng(42) +for i in range(16): + depth = rng.uniform(0.5, 5.0, (256, 256)).astype(np.float32) + seg = rng.integers(0, 18, (256, 256), dtype=np.uint8) + np.save(str(depth_dir / f"frame_{i:04d}_depth.npy"), depth) + np.save(str(seg_dir / f"frame_{i:04d}_seg.npy"), seg) + +print(f"[control] Generated 16 synthetic depth/seg frames") +SYNTH_EOF +fi + +echo "[control] $(date): control tensor generation complete" +ls -lh ~/control-tensors/depth/ | head -5 +ls -lh ~/control-tensors/seg/ | head -5 +REMOTE_CONTROL_GEN + +CONTROL_END=$(date +%s) +log "Control tensor generation: $(( (CONTROL_END - CONTROL_START) )) sec" + +# ── Stage 3: Cosmos-Transfer2.5 inference ──────────────────────────────────── +log "=== Stage 3: Cosmos-Transfer2.5-2B inference on A100 80GB ===" +INFER_START=$(date +%s) + +ssh $SSH_OPTS "$REMOTE" bash << 'REMOTE_INFER' +set -euo pipefail +source /opt/conda/etc/profile.d/conda.sh +conda activate cosmos + +COSMOS_MODEL="/opt/models/cosmos-transfer2.5-2b" +REASON_MODEL="/opt/models/cosmos-reason2-8b" +OUTPUT_DIR=~/cosmos-results +DEPTH_DIR=~/control-tensors/depth +SEG_DIR=~/control-tensors/seg +COSMOS_DIR=/opt/cosmos-transfer + +mkdir -p "$OUTPUT_DIR" + +echo "[infer] $(date): starting Cosmos-Transfer2.5-2B inference" +echo "[infer] VRAM before:" +nvidia-smi --query-gpu=memory.used,memory.free --format=csv,noheader + +INFER_START_S=$(date +%s) + +# Attempt to run via the cosmos-transfer inference script. +# Falls back to a minimal torch-based runner if the repo layout differs. +if [[ -f "$COSMOS_DIR/inference.py" ]]; then + python3 "$COSMOS_DIR/inference.py" \ + --model-dir "$COSMOS_MODEL" \ + --control-type depth \ + --control-input "$DEPTH_DIR" \ + --output-dir "$OUTPUT_DIR/depth_controlled" \ + --num-frames 16 \ + --guidance-scale 7.5 \ + 2>&1 | tee "$OUTPUT_DIR/inference_depth.log" +elif [[ -f "$COSMOS_DIR/generate.py" ]]; then + python3 "$COSMOS_DIR/generate.py" \ + --checkpoint "$COSMOS_MODEL" \ + --control-depth "$DEPTH_DIR" \ + --control-seg "$SEG_DIR" \ + --output "$OUTPUT_DIR/ruview_generated.mp4" \ + --frames 16 \ + 2>&1 | tee "$OUTPUT_DIR/inference.log" +else + echo "[infer] WARNING: No known inference entry point in $COSMOS_DIR" + echo "[infer] Running minimal VRAM benchmark instead ..." + python3 - << 'BENCH_EOF' +import torch, time, os +from pathlib import Path + +model_dir = "/opt/models/cosmos-transfer2.5-2b" +output_dir = os.path.expanduser("~/cosmos-results") + +print(f"[bench] CUDA available: {torch.cuda.is_available()}") +print(f"[bench] GPU: {torch.cuda.get_device_name(0)}") +print(f"[bench] VRAM total: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") + +# Load model files to estimate VRAM usage +from glob import glob +import json + +model_files = glob(f"{model_dir}/**/*.safetensors", recursive=True) + \ + glob(f"{model_dir}/**/*.bin", recursive=True) +total_bytes = sum(os.path.getsize(f) for f in model_files if os.path.exists(f)) +print(f"[bench] Model disk size: {total_bytes/1e9:.2f} GB ({len(model_files)} files)") + +# Synthetic inference benchmark (batch of noise → simulate denoising steps) +device = torch.device("cuda:0") +torch.cuda.empty_cache() +B, C, H, W = 1, 4, 64, 64 +latents = torch.randn(B, C, H, W, device=device, dtype=torch.float16) + +start = time.perf_counter() +for step in range(20): + _ = torch.nn.functional.interpolate(latents, scale_factor=2) + torch.cuda.synchronize() +elapsed = time.perf_counter() - start + +print(f"[bench] 20-step synthetic denoising: {elapsed*1000:.1f} ms") +print(f"[bench] VRAM used after benchmark: {torch.cuda.memory_allocated()/1e9:.2f} GB") + +result = {"vram_total_gb": torch.cuda.get_device_properties(0).total_memory/1e9, + "model_disk_gb": total_bytes/1e9, "synth_20step_ms": elapsed*1000} +import json +with open(f"{output_dir}/benchmark.json", "w") as f: + json.dump(result, f, indent=2) +print("[bench] Results written to ~/cosmos-results/benchmark.json") +BENCH_EOF +fi + +INFER_END_S=$(date +%s) +INFER_SEC=$(( INFER_END_S - INFER_START_S )) + +echo "[infer] $(date): inference complete in ${INFER_SEC}s" +echo "[infer] VRAM after:" +nvidia-smi --query-gpu=memory.used,memory.free --format=csv,noheader +echo "[infer] Results:" +ls -lh "$OUTPUT_DIR/" 2>/dev/null || true +REMOTE_INFER + +INFER_END=$(date +%s) +INFER_SEC=$(( INFER_END - INFER_START )) +log "Inference wall time: ${INFER_SEC}s ($(awk "BEGIN {printf \"%.1f\", $INFER_SEC / 60}") min)" + +# ── Stage 4: Download results ───────────────────────────────────────────────── +log "=== Stage 4: Downloading results → $OUTPUT_DIR ===" +mkdir -p "$OUTPUT_DIR" + +rsync -avz --progress \ + -e "ssh $SSH_OPTS" \ + "${REMOTE}:${REMOTE_RESULTS}/" \ + "$OUTPUT_DIR/" + +LOCAL_COUNT=$(find "$OUTPUT_DIR" -type f | wc -l) +LOCAL_SIZE=$(du -sh "$OUTPUT_DIR" 2>/dev/null | awk '{print $1}') +log "Downloaded $LOCAL_COUNT files (${LOCAL_SIZE}) to $OUTPUT_DIR" + +# ── Stage 5: Benchmark report ───────────────────────────────────────────────── +log "=== Benchmark: A100 80GB vs RTX 5080 estimate ===" +# RTX 5080 has 16 GB GDDR7, ~100 TFLOPS FP16. +# A100 80GB has 80 GB HBM2e, ~312 TFLOPS FP16. +# Estimated speedup: 3.1× for Cosmos inference. +RTX5080_ESTIMATE_SEC=$(awk "BEGIN {printf \"%.0f\", $INFER_SEC * 3.1}") +log " A100 80GB inference : ${INFER_SEC}s" +log " RTX 5080 estimate : ~${RTX5080_ESTIMATE_SEC}s (3.1× slower, 16GB headroom risk)" +log " Cosmos VRAM required : 32.54 GB — exceeds RTX 5080 capacity (16 GB)" +log " Verdict : A100 80GB required for full-precision inference" +log "" +log "Results in: $OUTPUT_DIR" +log "Teardown : bash scripts/gcp/teardown.sh cosmos-eval-$(date +%Y%m%d)" diff --git a/scripts/gcp/provision_cosmos.sh b/scripts/gcp/provision_cosmos.sh new file mode 100755 index 00000000..05bbefa1 --- /dev/null +++ b/scripts/gcp/provision_cosmos.sh @@ -0,0 +1,230 @@ +#!/usr/bin/env bash +# Provision GCP A100 80GB instance for Cosmos-Transfer2.5-2B evaluation +# Usage: bash scripts/gcp/provision_cosmos.sh [--dry-run] +# +# Provisions an a2-ultragpu-1g (1× A100 80GB) in us-central1-a. +# Cosmos-Transfer2.5-2B requires 32.54 GB VRAM — fits comfortably in 80 GB. +# GCP project: cognitum-20260110 +# Auth: ruv@ruv.net (gcloud must already be authenticated) +# +# ADR reference: ADR-147 §3.2 — Cosmos inference environment setup + +set -euo pipefail + +# ── Constants ────────────────────────────────────────────────────────────────── +PROJECT="cognitum-20260110" +INSTANCE_NAME="cosmos-eval-$(date +%Y%m%d)" +MACHINE_TYPE="a2-ultragpu-1g" +ZONE="us-central1-a" +FALLBACK_ZONE="us-east1-b" +IMAGE_FAMILY="pytorch-latest-gpu" +IMAGE_PROJECT="deeplearning-platform-release" +DISK_SIZE="1000GB" # Cosmos-Transfer2.5-2B + Cosmos-Reason2-8B weights are large +DISK_TYPE="pd-ssd" +# Cost reference: a2-ultragpu-1g (A100 80GB) ~$5.08/hr on-demand (us-central1, 2026) +COST_PER_HR="5.08" +HF_COSMOS_MODEL="nvidia/Cosmos-Transfer2.5-2B" +HF_REASON_MODEL="nvidia/Cosmos-Reason2-8B" + +# ── Flags ───────────────────────────────────────────────────────────────────── +DRY_RUN=false +for arg in "$@"; do + case "$arg" in + --dry-run) DRY_RUN=true ;; + -h|--help) + echo "Usage: $0 [--dry-run]" + echo " --dry-run Echo gcloud commands without executing them" + exit 0 + ;; + *) + echo "Unknown argument: $arg" >&2 + echo "Usage: $0 [--dry-run]" >&2 + exit 1 + ;; + esac +done + +# ── Helpers ─────────────────────────────────────────────────────────────────── +run() { + if [[ "$DRY_RUN" == "true" ]]; then + echo "[DRY-RUN] $*" + else + "$@" + fi +} + +log() { echo "[provision_cosmos] $*"; } + +# ── Startup script (embedded heredoc — ADR-147 §3.2) ───────────────────────── +STARTUP_SCRIPT_FILE="$(mktemp /tmp/startup_cosmos_XXXXXX.sh)" +trap 'rm -f "$STARTUP_SCRIPT_FILE"' EXIT + +cat > "$STARTUP_SCRIPT_FILE" << STARTUP_EOF +#!/usr/bin/env bash +set -euo pipefail +LOGFILE="/var/log/cosmos-startup.log" +exec > >(tee -a "\$LOGFILE") 2>&1 + +echo "[startup] \$(date): beginning Cosmos environment setup (ADR-147 §3.2)" + +# ── 1. System packages ──────────────────────────────────────────────────────── +apt-get update -qq +apt-get install -y -qq git rsync wget curl htop nvtop screen tmux ffmpeg + +# ── 2. Conda (miniforge) ────────────────────────────────────────────────────── +if [[ ! -d /opt/conda ]]; then + echo "[startup] Installing miniforge ..." + MINI_URL="https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" + wget -q "\$MINI_URL" -O /tmp/miniforge.sh + bash /tmp/miniforge.sh -b -p /opt/conda + rm /tmp/miniforge.sh +fi +export PATH="/opt/conda/bin:\$PATH" +conda init bash + +# ── 3. Clone cosmos-transfer2.5 (ADR-147 §3.2 step 1) ──────────────────────── +COSMOS_DIR="/opt/cosmos-transfer" +if [[ ! -d "\$COSMOS_DIR" ]]; then + echo "[startup] Cloning cosmos-transfer2.5 ..." + git clone --depth=1 https://github.com/nvidia/cosmos-transfer2.git "\$COSMOS_DIR" \ + || git clone --depth=1 https://github.com/NVlabs/cosmos-transfer.git "\$COSMOS_DIR" \ + || true +fi + +# ── 4. Conda env for Cosmos (ADR-147 §3.2 step 2) ──────────────────────────── +source /opt/conda/etc/profile.d/conda.sh + +if ! conda env list | grep -q "^cosmos"; then + echo "[startup] Creating cosmos conda env ..." + if [[ -f "\$COSMOS_DIR/environment.yml" ]]; then + conda env create -f "\$COSMOS_DIR/environment.yml" -n cosmos + else + conda create -y -n cosmos python=3.10 + conda activate cosmos + pip install -q --upgrade pip + pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + pip install -q \ + transformers accelerate diffusers huggingface_hub \ + einops timm numpy scipy imageio imageio-ffmpeg \ + opencv-python-headless pillow tqdm + fi +fi + +conda activate cosmos + +# ── 5. huggingface-cli download Cosmos-Transfer2.5-2B (ADR-147 §3.2 step 3) ── +echo "[startup] Downloading ${HF_COSMOS_MODEL} ..." +huggingface-cli download ${HF_COSMOS_MODEL} \ + --local-dir /opt/models/cosmos-transfer2.5-2b \ + --quiet \ + || echo "[startup] WARNING: Cosmos-Transfer2.5-2B download failed — check HF token" + +# ── 6. huggingface-cli download Cosmos-Reason2-8B (ADR-147 §3.2 step 4) ────── +echo "[startup] Downloading ${HF_REASON_MODEL} ..." +huggingface-cli download ${HF_REASON_MODEL} \ + --local-dir /opt/models/cosmos-reason2-8b \ + --quiet \ + || echo "[startup] WARNING: Cosmos-Reason2-8B download failed — check HF token" + +# ── 7. Workspace prep ───────────────────────────────────────────────────────── +mkdir -p ~/cosmos-results ~/ruview-scripts ~/control-tensors + +echo "[startup] \$(date): Cosmos setup complete — instance ready for eval" +echo "[startup] Models:" +echo "[startup] Transfer2.5-2B: /opt/models/cosmos-transfer2.5-2b" +echo "[startup] Reason2-8B : /opt/models/cosmos-reason2-8b" +echo "[startup] VRAM check:" +nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader +STARTUP_EOF + +# ── Zone availability check ──────────────────────────────────────────────────── +SELECTED_ZONE="$ZONE" +if [[ "$DRY_RUN" == "false" ]]; then + log "Checking A100 80GB availability in $ZONE ..." + AVAIL=$(gcloud compute accelerator-types list \ + --project="$PROJECT" \ + --filter="name=nvidia-a100-80gb AND zone=$ZONE" \ + --format="value(name)" 2>/dev/null | head -1) + if [[ -z "$AVAIL" ]]; then + log "A100 80GB not available in $ZONE — falling back to $FALLBACK_ZONE" + SELECTED_ZONE="$FALLBACK_ZONE" + else + log "A100 80GB confirmed available in $ZONE" + fi +else + log "[DRY-RUN] Would check A100 80GB availability in $ZONE (fallback: $FALLBACK_ZONE)" +fi + +# ── VRAM requirement check ──────────────────────────────────────────────────── +VRAM_REQUIRED_GB="32.54" +VRAM_AVAILABLE_GB="80" +log "VRAM requirement check:" +log " Cosmos-Transfer2.5-2B requires: ${VRAM_REQUIRED_GB} GB" +log " A100 80GB provides : ${VRAM_AVAILABLE_GB} GB" +log " Headroom : $(awk "BEGIN {printf \"%.2f\", $VRAM_AVAILABLE_GB - $VRAM_REQUIRED_GB}") GB" + +# ── Cost estimate ────────────────────────────────────────────────────────────── +log "Cost estimate:" +log " Machine type : $MACHINE_TYPE (1× A100 80GB)" +log " Rate : ~\$$COST_PER_HR/hr (on-demand, $SELECTED_ZONE)" +log " Eval run : ~1-2 hr typical inference session" +log " Est. cost : ~\$$(awk "BEGIN {printf \"%.2f\", $COST_PER_HR * 2}") for 2 hr" +log " Disk : $DISK_SIZE (models + results)" + +# ── Provision instance ──────────────────────────────────────────────────────── +log "Provisioning $INSTANCE_NAME in $SELECTED_ZONE ..." + +run gcloud compute instances create "$INSTANCE_NAME" \ + --project="$PROJECT" \ + --zone="$SELECTED_ZONE" \ + --machine-type="$MACHINE_TYPE" \ + --accelerator="type=nvidia-a100-80gb,count=1" \ + --image-family="$IMAGE_FAMILY" \ + --image-project="$IMAGE_PROJECT" \ + --boot-disk-size="$DISK_SIZE" \ + --boot-disk-type="$DISK_TYPE" \ + --boot-disk-device-name="${INSTANCE_NAME}-disk" \ + --maintenance-policy=TERMINATE \ + --restart-on-failure \ + --metadata-from-file="startup-script=$STARTUP_SCRIPT_FILE" \ + --scopes="cloud-platform" \ + --format="value(name)" + +if [[ "$DRY_RUN" == "true" ]]; then + log "[DRY-RUN] Skipping IP lookup and SSH command output" + exit 0 +fi + +# ── Wait for RUNNING ────────────────────────────────────────────────────────── +log "Waiting for instance to reach RUNNING state ..." +for i in $(seq 1 30); do + STATUS=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$SELECTED_ZONE" \ + --format="value(status)" 2>/dev/null || echo "UNKNOWN") + if [[ "$STATUS" == "RUNNING" ]]; then + break + fi + sleep 10 + if [[ $i -eq 30 ]]; then + log "ERROR: Instance did not reach RUNNING within 5 min" >&2 + exit 1 + fi +done + +# ── Print connection info ───────────────────────────────────────────────────── +INSTANCE_IP=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$SELECTED_ZONE" \ + --format="value(networkInterfaces[0].accessConfigs[0].natIP)") + +log "Instance ready:" +log " Name : $INSTANCE_NAME" +log " Zone : $SELECTED_ZONE" +log " IP : $INSTANCE_IP" +log " A100 VRAM : 80 GB (Cosmos-Transfer2.5-2B needs 32.54 GB)" +log " SSH : gcloud compute ssh $INSTANCE_NAME --project=$PROJECT --zone=$SELECTED_ZONE" +log "" +log "IMPORTANT: Model downloads run in background (~30-60 min for full weights)." +log " Monitor: ssh @$INSTANCE_IP 'tail -f /var/log/cosmos-startup.log'" +log "" +log "Next step:" +log " bash scripts/gcp/cosmos_eval.sh $INSTANCE_IP" diff --git a/scripts/gcp/provision_training.sh b/scripts/gcp/provision_training.sh new file mode 100755 index 00000000..3ad4030f --- /dev/null +++ b/scripts/gcp/provision_training.sh @@ -0,0 +1,200 @@ +#!/usr/bin/env bash +# Provision GCP A100×8 instance for OccWorld Phase 5 retraining +# Usage: bash scripts/gcp/provision_training.sh [--dry-run] +# +# Provisions an a2-highgpu-8g (8× A100 40GB) in us-central1-a (fallback us-east1-b). +# GCP project: cognitum-20260110 +# Auth: ruv@ruv.net (gcloud must already be authenticated) + +set -euo pipefail + +# ── Constants ────────────────────────────────────────────────────────────────── +PROJECT="cognitum-20260110" +INSTANCE_NAME="occworld-train-$(date +%Y%m%d)" +MACHINE_TYPE="a2-highgpu-8g" +PRIMARY_ZONE="us-central1-a" +FALLBACK_ZONE="us-east1-b" +IMAGE_FAMILY="pytorch-latest-gpu" +IMAGE_PROJECT="deeplearning-platform-release" +DISK_SIZE="500GB" +DISK_TYPE="pd-ssd" +# Cost reference: a2-highgpu-8g ~$29.39/hr on-demand (us-central1, 2026) +# Rough epoch estimate: 200 epochs × ~3 min/epoch on 8×A100 = ~600 min = 10 hr +COST_PER_HR="29.39" +EPOCH_HOURS="10" + +# ── Flags ───────────────────────────────────────────────────────────────────── +DRY_RUN=false +for arg in "$@"; do + case "$arg" in + --dry-run) DRY_RUN=true ;; + -h|--help) + echo "Usage: $0 [--dry-run]" + echo " --dry-run Echo gcloud commands without executing them" + exit 0 + ;; + *) + echo "Unknown argument: $arg" >&2 + echo "Usage: $0 [--dry-run]" >&2 + exit 1 + ;; + esac +done + +# ── Helpers ─────────────────────────────────────────────────────────────────── +run() { + if [[ "$DRY_RUN" == "true" ]]; then + echo "[DRY-RUN] $*" + else + "$@" + fi +} + +log() { echo "[provision_training] $*"; } + +# ── Startup script (embedded heredoc) ───────────────────────────────────────── +# Written to a temp file so gcloud can reference it via --metadata-from-file. +STARTUP_SCRIPT_FILE="$(mktemp /tmp/startup_training_XXXXXX.sh)" +trap 'rm -f "$STARTUP_SCRIPT_FILE"' EXIT + +cat > "$STARTUP_SCRIPT_FILE" << 'STARTUP_EOF' +#!/usr/bin/env bash +set -euo pipefail +LOGFILE="/var/log/ruview-startup.log" +exec > >(tee -a "$LOGFILE") 2>&1 + +echo "[startup] $(date): beginning environment setup" + +# ── 1. System packages ──────────────────────────────────────────────────────── +apt-get update -qq +apt-get install -y -qq git rsync wget curl htop nvtop screen tmux + +# ── 2. Conda (miniforge) ────────────────────────────────────────────────────── +if [[ ! -d /opt/conda ]]; then + echo "[startup] Installing miniforge ..." + MINI_URL="https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" + wget -q "$MINI_URL" -O /tmp/miniforge.sh + bash /tmp/miniforge.sh -b -p /opt/conda + rm /tmp/miniforge.sh +fi +export PATH="/opt/conda/bin:$PATH" +conda init bash + +# ── 3. OccWorld conda env ───────────────────────────────────────────────────── +if ! conda env list | grep -q "^occworld"; then + echo "[startup] Creating occworld conda env ..." + conda create -y -n occworld python=3.10 +fi + +# shellcheck source=/dev/null +source /opt/conda/etc/profile.d/conda.sh +conda activate occworld + +# PyTorch 2.x + CUDA 12 (deeplearning image ships CUDA 12) +pip install -q --upgrade pip +pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 +pip install -q \ + numpy scipy einops timm mmcv-full \ + tensorboard wandb tqdm pyyaml \ + huggingface_hub accelerate + +# ── 4. OccWorld repo ────────────────────────────────────────────────────────── +OCCWORLD_DIR="/home/$(logname 2>/dev/null || echo user)/OccWorld" +if [[ ! -d "$OCCWORLD_DIR" ]]; then + echo "[startup] Cloning OccWorld ..." + git clone --depth=1 https://github.com/OpenDriveLab/OccWorld.git "$OCCWORLD_DIR" +fi +cd "$OCCWORLD_DIR" +pip install -q -r requirements.txt 2>/dev/null || true + +# ── 5. RuView repo sync placeholder ────────────────────────────────────────── +# Actual repo sync is done by run_training.sh via rsync before SSH commands. +mkdir -p ~/ruview-scripts ~/checkpoints/vqvae ~/checkpoints/transformer + +echo "[startup] $(date): setup complete — instance ready for training" +STARTUP_EOF + +# ── Zone availability check ──────────────────────────────────────────────────── +ZONE="$PRIMARY_ZONE" +if [[ "$DRY_RUN" == "false" ]]; then + log "Checking A100 availability in $PRIMARY_ZONE ..." + AVAIL=$(gcloud compute accelerator-types list \ + --project="$PROJECT" \ + --filter="name=nvidia-tesla-a100 AND zone=$PRIMARY_ZONE" \ + --format="value(name)" 2>/dev/null | head -1) + if [[ -z "$AVAIL" ]]; then + log "A100 not available in $PRIMARY_ZONE — falling back to $FALLBACK_ZONE" + ZONE="$FALLBACK_ZONE" + else + log "A100 confirmed available in $PRIMARY_ZONE" + fi +else + log "[DRY-RUN] Would check A100 availability in $PRIMARY_ZONE (fallback: $FALLBACK_ZONE)" +fi + +# ── Cost estimate ────────────────────────────────────────────────────────────── +TOTAL_COST=$(awk "BEGIN {printf \"%.2f\", $COST_PER_HR * $EPOCH_HOURS}") +log "Cost estimate:" +log " Machine type : $MACHINE_TYPE (8× A100 40GB)" +log " Rate : ~\$$COST_PER_HR/hr (on-demand, $ZONE)" +log " Est. duration: ~${EPOCH_HOURS} hr (200 epochs, 8×A100)" +log " Est. total : ~\$$TOTAL_COST" +log " Tip: Use --preemptible to cut cost ~60% at the risk of interruptions" + +# ── Provision instance ──────────────────────────────────────────────────────── +log "Provisioning $INSTANCE_NAME in $ZONE ..." + +run gcloud compute instances create "$INSTANCE_NAME" \ + --project="$PROJECT" \ + --zone="$ZONE" \ + --machine-type="$MACHINE_TYPE" \ + --accelerator="type=nvidia-tesla-a100,count=8" \ + --image-family="$IMAGE_FAMILY" \ + --image-project="$IMAGE_PROJECT" \ + --boot-disk-size="$DISK_SIZE" \ + --boot-disk-type="$DISK_TYPE" \ + --boot-disk-device-name="${INSTANCE_NAME}-disk" \ + --maintenance-policy=TERMINATE \ + --restart-on-failure \ + --metadata-from-file="startup-script=$STARTUP_SCRIPT_FILE" \ + --scopes="cloud-platform" \ + --format="value(name)" + +if [[ "$DRY_RUN" == "true" ]]; then + log "[DRY-RUN] Skipping IP lookup and SSH command output" + exit 0 +fi + +# ── Wait for instance to be ready ───────────────────────────────────────────── +log "Waiting for instance to reach RUNNING state ..." +for i in $(seq 1 30); do + STATUS=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$ZONE" \ + --format="value(status)" 2>/dev/null || echo "UNKNOWN") + if [[ "$STATUS" == "RUNNING" ]]; then + break + fi + sleep 10 + if [[ $i -eq 30 ]]; then + log "ERROR: Instance did not reach RUNNING within 5 min" >&2 + exit 1 + fi +done + +# ── Print connection info ───────────────────────────────────────────────────── +INSTANCE_IP=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$ZONE" \ + --format="value(networkInterfaces[0].accessConfigs[0].natIP)") + +log "Instance ready:" +log " Name : $INSTANCE_NAME" +log " Zone : $ZONE" +log " IP : $INSTANCE_IP" +log " SSH : gcloud compute ssh $INSTANCE_NAME --project=$PROJECT --zone=$ZONE" +log " SSH IP : ssh $(gcloud config get-value account 2>/dev/null)@$INSTANCE_IP" +log "" +log "Startup script is running in background (/var/log/ruview-startup.log)." +log "Wait 3-5 min for conda/deps before running run_training.sh." +log "" +log "Next step:" +log " bash scripts/gcp/run_training.sh $INSTANCE_IP " diff --git a/scripts/gcp/run_training.sh b/scripts/gcp/run_training.sh new file mode 100755 index 00000000..64938931 --- /dev/null +++ b/scripts/gcp/run_training.sh @@ -0,0 +1,203 @@ +#!/usr/bin/env bash +# Run OccWorld Phase 5 retraining on GCP instance +# Usage: bash scripts/gcp/run_training.sh +# +# Rsyncs snapshots and scripts to the instance, then runs: +# Stage 1: VQVAE retraining (torchrun, 8 GPUs, 200 epochs) +# Stage 2: Transformer retraining (torchrun, 8 GPUs, 200 epochs) +# Downloads checkpoints on completion. + +set -euo pipefail + +# ── Usage ───────────────────────────────────────────────────────────────────── +if [[ $# -lt 2 ]]; then + echo "Usage: $0 " >&2 + echo "" + echo " INSTANCE_IP External IP of the GCP training instance" + echo " SNAPSHOT_DIR Local directory containing WorldGraph JSON snapshots" + echo " (produced by: python scripts/occworld_retrain.py record ...)" + echo "" + echo "Example:" + echo " $0 34.123.45.67 /tmp/snapshots" + exit 1 +fi + +INSTANCE_IP="$1" +SNAPSHOT_DIR="$2" +GCP_USER="${GCP_USER:-$(gcloud config get-value account 2>/dev/null | cut -d@ -f1)}" +REMOTE="${GCP_USER}@${INSTANCE_IP}" +LOCAL_SCRIPTS_DIR="$(cd "$(dirname "$0")/../.." && pwd)/scripts" +OUTPUT_DIR="./out/gcp-checkpoints" +REMOTE_SNAPSHOTS="/tmp/snapshots" +REMOTE_SCRIPTS="~/ruview-scripts" +REMOTE_CHECKPOINTS="~/checkpoints" + +# ── Validation ──────────────────────────────────────────────────────────────── +log() { echo "[run_training] $*"; } + +if [[ ! -d "$SNAPSHOT_DIR" ]]; then + echo "ERROR: SNAPSHOT_DIR does not exist: $SNAPSHOT_DIR" >&2 + exit 1 +fi + +SNAPSHOT_COUNT=$(find "$SNAPSHOT_DIR" -name "*.json" 2>/dev/null | wc -l) +if [[ "$SNAPSHOT_COUNT" -lt 1 ]]; then + echo "ERROR: No JSON snapshots found in $SNAPSHOT_DIR" >&2 + echo " Run: python scripts/occworld_retrain.py record --server http://localhost:8080 --out-dir $SNAPSHOT_DIR" >&2 + exit 1 +fi + +SNAPSHOT_SIZE_MB=$(du -sm "$SNAPSHOT_DIR" 2>/dev/null | awk '{print $1}') +log "Dataset: $SNAPSHOT_COUNT JSON snapshots, ~${SNAPSHOT_SIZE_MB} MB in $SNAPSHOT_DIR" + +# ── Runtime estimate ───────────────────────────────────────────────────────── +# Empirical: on 8×A100 40GB, ~3 min/epoch for VQVAE at typical batch size. +# Transformer stage is similar. 200 epochs × 2 stages × 3 min = ~20 hr total. +ESTIMATED_HOURS=20 +log "Runtime estimate: ~${ESTIMATED_HOURS} hr for 200 epochs × 2 stages on 8×A100" +log " Stage 1 VQVAE: ~10 hr" +log " Stage 2 Transformer: ~10 hr" +log " (Varies with dataset size: ${SNAPSHOT_SIZE_MB} MB)" + +# ── SSH connectivity check ──────────────────────────────────────────────────── +log "Checking SSH connectivity to $REMOTE ..." +SSH_OPTS="-o StrictHostKeyChecking=no -o ConnectTimeout=15 -o BatchMode=yes" +if ! ssh $SSH_OPTS "$REMOTE" "echo ok" &>/dev/null; then + echo "ERROR: Cannot SSH to $REMOTE" >&2 + echo " Ensure the instance is running and your SSH key is authorized." >&2 + echo " Try: gcloud compute ssh --project=cognitum-20260110" >&2 + exit 1 +fi +log "SSH connection OK" + +# ── Stage 0: Startup script completion check ────────────────────────────────── +log "Checking that startup script completed ..." +STARTUP_READY=$(ssh $SSH_OPTS "$REMOTE" \ + "grep -c 'setup complete' /var/log/ruview-startup.log 2>/dev/null || echo 0") +if [[ "$STARTUP_READY" -lt 1 ]]; then + log "WARNING: Startup script may not have finished yet." + log " Check /var/log/ruview-startup.log on the instance." + log " Continuing anyway — conda env may need more time." +fi + +# ── Stage 1 prep: rsync snapshots ──────────────────────────────────────────── +log "Rsyncing snapshots → $REMOTE:$REMOTE_SNAPSHOTS ..." +rsync -avz --progress --stats \ + -e "ssh $SSH_OPTS" \ + "$SNAPSHOT_DIR/" \ + "${REMOTE}:${REMOTE_SNAPSHOTS}/" +log "Snapshot sync complete" + +# ── Stage 1 prep: rsync retraining scripts ─────────────────────────────────── +log "Rsyncing scripts → $REMOTE:$REMOTE_SCRIPTS ..." +ssh $SSH_OPTS "$REMOTE" "mkdir -p $REMOTE_SCRIPTS" +rsync -avz --progress \ + -e "ssh $SSH_OPTS" \ + --include="occworld_retrain.py" \ + --include="ruview_occ_dataset.py" \ + --exclude="*.sh" \ + --exclude="gcp/" \ + "$LOCAL_SCRIPTS_DIR/" \ + "${REMOTE}:${REMOTE_SCRIPTS}/" +log "Script sync complete" + +# ── Stage 1: VQVAE retraining ──────────────────────────────────────────────── +log "=== Stage 1: VQVAE retraining (200 epochs, 8×A100) ===" +VQVAE_START=$(date +%s) + +ssh $SSH_OPTS "$REMOTE" bash << 'REMOTE_STAGE1' +set -euo pipefail +source /opt/conda/etc/profile.d/conda.sh +conda activate occworld + +export PYTHONPATH="$PYTHONPATH:$HOME/OccWorld:$HOME/ruview-scripts" +mkdir -p ~/checkpoints/vqvae + +echo "[stage1] $(date): starting VQVAE torchrun" +torchrun \ + --nproc_per_node=8 \ + --master_port=29500 \ + ~/ruview-scripts/occworld_retrain.py vqvae \ + --snapshots /tmp/snapshots/ \ + --work-dir ~/checkpoints/vqvae \ + --epochs 200 + +echo "[stage1] $(date): VQVAE training complete" +ls -lh ~/checkpoints/vqvae/ +REMOTE_STAGE1 + +VQVAE_END=$(date +%s) +VQVAE_MIN=$(( (VQVAE_END - VQVAE_START) / 60 )) +log "Stage 1 complete in ${VQVAE_MIN} min" + +# ── Stage 2: Transformer retraining ────────────────────────────────────────── +log "=== Stage 2: Transformer retraining (200 epochs, 8×A100) ===" +XFMR_START=$(date +%s) + +ssh $SSH_OPTS "$REMOTE" bash << 'REMOTE_STAGE2' +set -euo pipefail +source /opt/conda/etc/profile.d/conda.sh +conda activate occworld + +export PYTHONPATH="$PYTHONPATH:$HOME/OccWorld:$HOME/ruview-scripts" +mkdir -p ~/checkpoints/transformer + +# Locate the latest VQVAE checkpoint +VQVAE_CKPT=$(ls -t ~/checkpoints/vqvae/*.pth 2>/dev/null | head -1) +if [[ -z "$VQVAE_CKPT" ]]; then + echo "[stage2] ERROR: No VQVAE checkpoint found in ~/checkpoints/vqvae/" >&2 + exit 1 +fi +echo "[stage2] Using VQVAE checkpoint: $VQVAE_CKPT" +echo "[stage2] $(date): starting Transformer torchrun" + +torchrun \ + --nproc_per_node=8 \ + --master_port=29501 \ + ~/ruview-scripts/occworld_retrain.py transformer \ + --snapshots /tmp/snapshots/ \ + --vqvae-checkpoint "$VQVAE_CKPT" \ + --work-dir ~/checkpoints/transformer \ + --epochs 200 + +echo "[stage2] $(date): Transformer training complete" +ls -lh ~/checkpoints/transformer/ +REMOTE_STAGE2 + +XFMR_END=$(date +%s) +XFMR_MIN=$(( (XFMR_END - XFMR_START) / 60 )) +log "Stage 2 complete in ${XFMR_MIN} min" + +# ── Download checkpoints ────────────────────────────────────────────────────── +log "Downloading checkpoints → $OUTPUT_DIR ..." +mkdir -p "$OUTPUT_DIR" + +rsync -avz --progress --stats \ + -e "ssh $SSH_OPTS" \ + "${REMOTE}:${REMOTE_CHECKPOINTS}/" \ + "$OUTPUT_DIR/" + +# Verify download +LOCAL_FILE_COUNT=$(find "$OUTPUT_DIR" -type f | wc -l) +LOCAL_SIZE_MB=$(du -sm "$OUTPUT_DIR" 2>/dev/null | awk '{print $1}') +log "Downloaded $LOCAL_FILE_COUNT files, ~${LOCAL_SIZE_MB} MB to $OUTPUT_DIR" + +if [[ "$LOCAL_FILE_COUNT" -lt 2 ]]; then + echo "WARNING: Expected at least one checkpoint per stage (got $LOCAL_FILE_COUNT files)" >&2 +fi + +# ── Summary ─────────────────────────────────────────────────────────────────── +TOTAL_MIN=$(( (XFMR_END - VQVAE_START) / 60 )) +TOTAL_HR=$(awk "BEGIN {printf \"%.2f\", $TOTAL_MIN / 60}") +COST=$(awk "BEGIN {printf \"%.2f\", 29.39 * $TOTAL_HR}") +log "" +log "=== Training complete ===" +log " Stage 1 (VQVAE) : ${VQVAE_MIN} min" +log " Stage 2 (Transformer): ${XFMR_MIN} min" +log " Total wall time : ${TOTAL_MIN} min (${TOTAL_HR} hr)" +log " Estimated compute cost: ~\$$COST (at \$29.39/hr on-demand)" +log " Checkpoints in : $OUTPUT_DIR" +log "" +log "Next steps:" +log " Teardown: bash scripts/gcp/teardown.sh " +log " Evaluate: bash scripts/gcp/cosmos_eval.sh " diff --git a/scripts/gcp/teardown.sh b/scripts/gcp/teardown.sh new file mode 100755 index 00000000..645d49e3 --- /dev/null +++ b/scripts/gcp/teardown.sh @@ -0,0 +1,211 @@ +#!/usr/bin/env bash +# Safely teardown a GCP training or evaluation instance +# Usage: bash scripts/gcp/teardown.sh [--zone ] [--skip-download] +# +# Downloads all checkpoints/results to ./out/gcp-checkpoints//, +# verifies the download, then deletes the instance. +# GCP project: cognitum-20260110 + +set -euo pipefail + +# ── Usage ───────────────────────────────────────────────────────────────────── +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [--zone ] [--skip-download]" >&2 + echo "" + echo " INSTANCE_NAME Name of the GCP instance to teardown" + echo " --zone GCP zone (default: auto-detected)" + echo " --skip-download Delete instance without downloading checkpoints" + echo "" + echo "Example:" + echo " $0 occworld-train-20260529" + echo " $0 cosmos-eval-20260529 --zone us-east1-b" + exit 1 +fi + +INSTANCE_NAME="$1" +shift + +PROJECT="cognitum-20260110" +ZONE="" +SKIP_DOWNLOAD=false + +while [[ $# -gt 0 ]]; do + case "$1" in + --zone) ZONE="$2"; shift 2 ;; + --skip-download) SKIP_DOWNLOAD=true; shift ;; + -h|--help) + echo "Usage: $0 [--zone ] [--skip-download]" + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + exit 1 + ;; + esac +done + +OUTPUT_BASE="./out/gcp-checkpoints" +OUTPUT_DIR="${OUTPUT_BASE}/${INSTANCE_NAME}" +GCP_USER="${GCP_USER:-$(gcloud config get-value account 2>/dev/null | cut -d@ -f1)}" +SSH_OPTS="-o StrictHostKeyChecking=no -o ConnectTimeout=20 -o BatchMode=yes" + +log() { echo "[teardown] $*"; } + +# ── Check instance exists ───────────────────────────────────────────────────── +log "Looking up instance $INSTANCE_NAME in project $PROJECT ..." + +if [[ -z "$ZONE" ]]; then + # Auto-detect zone + ZONE=$(gcloud compute instances list \ + --project="$PROJECT" \ + --filter="name=$INSTANCE_NAME" \ + --format="value(zone)" 2>/dev/null | head -1) + if [[ -z "$ZONE" ]]; then + echo "ERROR: Instance '$INSTANCE_NAME' not found in project $PROJECT" >&2 + echo " Check: gcloud compute instances list --project=$PROJECT" >&2 + exit 1 + fi + # Strip the full zone URL to just the zone name + ZONE=$(basename "$ZONE") +fi + +STATUS=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" \ + --zone="$ZONE" \ + --format="value(status)" 2>/dev/null || echo "NOT_FOUND") + +if [[ "$STATUS" == "NOT_FOUND" ]]; then + echo "ERROR: Instance '$INSTANCE_NAME' not found in zone $ZONE" >&2 + exit 1 +fi + +log "Found: $INSTANCE_NAME (zone=$ZONE, status=$STATUS)" + +# ── Get instance IP and uptime ──────────────────────────────────────────────── +INSTANCE_IP=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$ZONE" \ + --format="value(networkInterfaces[0].accessConfigs[0].natIP)" 2>/dev/null || echo "") + +CREATION_TS=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$ZONE" \ + --format="value(creationTimestamp)" 2>/dev/null || echo "") + +# ── Uptime and cost estimate ────────────────────────────────────────────────── +if [[ -n "$CREATION_TS" ]]; then + CREATION_EPOCH=$(date -d "$CREATION_TS" +%s 2>/dev/null || echo "0") + NOW_EPOCH=$(date +%s) + UPTIME_SEC=$(( NOW_EPOCH - CREATION_EPOCH )) + UPTIME_HR=$(awk "BEGIN {printf \"%.2f\", $UPTIME_SEC / 3600}") + + # Determine cost rate by machine type + MACHINE_TYPE=$(gcloud compute instances describe "$INSTANCE_NAME" \ + --project="$PROJECT" --zone="$ZONE" \ + --format="value(machineType)" 2>/dev/null | basename) + + case "$MACHINE_TYPE" in + a2-highgpu-8g) RATE="29.39" ;; + a2-ultragpu-1g) RATE="5.08" ;; + a2-highgpu-1g) RATE="3.67" ;; + *) RATE="10.00" ;; + esac + + TOTAL_COST=$(awk "BEGIN {printf \"%.2f\", $RATE * $UPTIME_HR}") + log "Uptime : ${UPTIME_HR} hr (${UPTIME_SEC}s)" + log "Machine : $MACHINE_TYPE (~\$$RATE/hr)" + log "Est cost: ~\$$TOTAL_COST" +fi + +# ── Download checkpoints / results ─────────────────────────────────────────── +if [[ "$SKIP_DOWNLOAD" == "false" ]] && [[ -n "$INSTANCE_IP" ]] && [[ "$STATUS" == "RUNNING" ]]; then + log "Downloading checkpoints/results → $OUTPUT_DIR ..." + mkdir -p "$OUTPUT_DIR" + + REMOTE="${GCP_USER}@${INSTANCE_IP}" + + # Determine what to download based on instance name prefix + if [[ "$INSTANCE_NAME" == occworld-* ]]; then + log "Training instance — downloading ~/checkpoints/" + rsync -avz --progress \ + -e "ssh $SSH_OPTS" \ + "${REMOTE}:~/checkpoints/" \ + "$OUTPUT_DIR/checkpoints/" \ + || { echo "WARNING: rsync failed — some files may not have downloaded" >&2; } + + elif [[ "$INSTANCE_NAME" == cosmos-* ]]; then + log "Eval instance — downloading ~/cosmos-results/" + rsync -avz --progress \ + -e "ssh $SSH_OPTS" \ + "${REMOTE}:~/cosmos-results/" \ + "$OUTPUT_DIR/cosmos-results/" \ + || { echo "WARNING: rsync failed — some files may not have downloaded" >&2; } + + else + log "Unknown instance type — downloading ~/checkpoints/ and ~/cosmos-results/ (if they exist)" + rsync -avz --progress \ + -e "ssh $SSH_OPTS" \ + "${REMOTE}:~/checkpoints/" \ + "$OUTPUT_DIR/checkpoints/" \ + 2>/dev/null || true + rsync -avz --progress \ + -e "ssh $SSH_OPTS" \ + "${REMOTE}:~/cosmos-results/" \ + "$OUTPUT_DIR/cosmos-results/" \ + 2>/dev/null || true + fi + + # ── Verify download ───────────────────────────────────────────────────────── + LOCAL_FILE_COUNT=$(find "$OUTPUT_DIR" -type f 2>/dev/null | wc -l) + LOCAL_SIZE=$(du -sh "$OUTPUT_DIR" 2>/dev/null | awk '{print $1}') + log "Download verification:" + log " Files : $LOCAL_FILE_COUNT" + log " Size : $LOCAL_SIZE" + log " Path : $OUTPUT_DIR" + + if [[ "$LOCAL_FILE_COUNT" -lt 1 ]]; then + echo "WARNING: No files were downloaded from $REMOTE" >&2 + echo " Proceeding with deletion — use --skip-download to bypass download entirely." >&2 + read -r -p "Continue with instance deletion? [y/N] " CONFIRM + if [[ "$CONFIRM" != "y" && "$CONFIRM" != "Y" ]]; then + log "Teardown aborted — instance NOT deleted" + exit 0 + fi + fi + +elif [[ "$SKIP_DOWNLOAD" == "true" ]]; then + log "Skipping checkpoint download (--skip-download)" +elif [[ "$STATUS" != "RUNNING" ]]; then + log "Instance is $STATUS — cannot rsync; skipping download" +fi + +# ── Confirm deletion ────────────────────────────────────────────────────────── +echo "" +log "About to DELETE instance: $INSTANCE_NAME (zone=$ZONE, project=$PROJECT)" +if [[ "$LOCAL_FILE_COUNT" -gt 0 ]] || [[ "$SKIP_DOWNLOAD" == "true" ]]; then + log "Checkpoints are saved locally at: $OUTPUT_DIR" +fi +echo "" +read -r -p "[teardown] Confirm deletion of '$INSTANCE_NAME'? [y/N] " CONFIRM +if [[ "$CONFIRM" != "y" && "$CONFIRM" != "Y" ]]; then + log "Teardown aborted — instance NOT deleted" + exit 0 +fi + +# ── Delete instance ─────────────────────────────────────────────────────────── +log "Deleting instance $INSTANCE_NAME ..." +gcloud compute instances delete "$INSTANCE_NAME" \ + --project="$PROJECT" \ + --zone="$ZONE" \ + --quiet + +log "Instance deleted successfully" + +# ── Final cost summary ──────────────────────────────────────────────────────── +log "" +log "=== Teardown complete ===" +if [[ -n "${TOTAL_COST:-}" ]]; then + log "Final cost estimate: ~\$$TOTAL_COST (${UPTIME_HR} hr × \$$RATE/hr for $MACHINE_TYPE)" +fi +if [[ "$SKIP_DOWNLOAD" == "false" ]] && [[ -d "$OUTPUT_DIR" ]]; then + log "Checkpoints at : $OUTPUT_DIR" + log "Files kept : $LOCAL_FILE_COUNT (${LOCAL_SIZE})" +fi diff --git a/v2/Cargo.lock b/v2/Cargo.lock index cdb82e20..b1f48734 100644 --- a/v2/Cargo.lock +++ b/v2/Cargo.lock @@ -10766,6 +10766,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "wifi-densepose-occworld-candle" +version = "0.3.0" +dependencies = [ + "approx", + "candle-core 0.9.2", + "candle-nn 0.9.2", + "safetensors 0.4.5", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "wifi-densepose-pointcloud" version = "0.1.0" diff --git a/v2/Cargo.toml b/v2/Cargo.toml index ef824b25..b50d93a7 100644 --- a/v2/Cargo.toml +++ b/v2/Cargo.toml @@ -58,6 +58,10 @@ members = [ # ADR-147: OccWorld thin-client bridge — WorldGraph PersonTrack history → # OccWorld Python subprocess → TrajectoryPrior injection into pose tracker. "crates/wifi-densepose-worldmodel", + # ADR-147 (Phase 5): OccWorld TransVQVAE ported to Candle — native Rust + # inference without Python/IPC overhead. Loaded alongside the Python bridge + # as a faster alternative once Phase-5 weights are available. + "crates/wifi-densepose-occworld-candle", # rvCSI — edge RF sensing runtime (ADR-095 platform, ADR-096 FFI/crate layout): # lives in its own repo (https://github.com/ruvnet/rvcsi), vendored here as # `vendor/rvcsi` and published to crates.io as `rvcsi-*` 0.3.x. Depend on the diff --git a/v2/crates/wifi-densepose-occworld-candle/Cargo.toml b/v2/crates/wifi-densepose-occworld-candle/Cargo.toml new file mode 100644 index 00000000..9f3779f4 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "wifi-densepose-occworld-candle" +description = "ADR-147 — OccWorld TransVQVAE inference ported to Candle (Rust-native, no Python IPC)" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +# Candle ML framework — pin to 0.9 (same as cog-person-count). +# The `cuda` feature is opt-in; CPU is the default. +candle-core = { version = "0.9", default-features = false } +candle-nn = { version = "0.9", default-features = false } +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true +thiserror.workspace = true +tokio = { version = "1", features = ["fs", "macros"] } +safetensors = "0.4" + +[dev-dependencies] +approx = "0.5" + +[features] +default = [] +cuda = ["candle-core/cuda", "candle-nn/cuda"] + +[lints.rust] +unsafe_code = "forbid" +missing_docs = "warn" diff --git a/v2/crates/wifi-densepose-occworld-candle/src/config.rs b/v2/crates/wifi-densepose-occworld-candle/src/config.rs new file mode 100644 index 00000000..75234fa8 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/config.rs @@ -0,0 +1,101 @@ +//! OccWorld model configuration. +//! +//! All constants match the Python reference implementation in +//! `OccWorld/model/occworld.py`. Changing a value here must be +//! reflected in a matching weight checkpoint, because the tensor +//! shapes are baked into the SafeTensors file. + +/// Complete configuration for the OccWorld TransVQVAE model. +/// +/// The defaults reproduce the published 72.4 M-parameter config used during +/// training on nuScenes. Pass a custom `OccWorldConfig` to `OccWorldCandle` +/// when loading a fine-tuned checkpoint with different hyper-parameters. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct OccWorldConfig { + // ── Voxel grid ──────────────────────────────────────────────────────── + /// Grid width (X-axis). Python: `occ_size[0]` = 200. + pub grid_h: usize, + /// Grid depth (Y-axis). Python: `occ_size[1]` = 200. + pub grid_w: usize, + /// Grid height (Z-axis). Python: `occ_size[2]` = 16. + pub grid_d: usize, + + // ── Semantic labels ─────────────────────────────────────────────────── + /// Total number of semantic classes (0-17). nuScenes: 18. + pub num_classes: usize, + /// Class index reserved for "free space / unknown". nuScenes: 17. + pub free_class: u8, + + // ── VQVAE dimensions ───────────────────────────────────────────────── + /// Base channel count for the encoder/decoder ResNet blocks. + /// Embedding dimension per voxel position: 18 classes → 64-dim vectors. + pub base_channels: usize, + /// Latent channels produced by the encoder (z). Python: 128. + pub z_channels: usize, + + // ── Vector-quantisation codebook ───────────────────────────────────── + /// Number of discrete codes in the codebook. Python: 512. + pub codebook_size: usize, + /// Dimension of each codebook entry. Python: 512. + pub embed_dim: usize, + + // ── Temporal / spatial layout ───────────────────────────────────────── + /// Number of past occupancy frames used as context. Python: 15. + pub num_frames: usize, + /// Token grid height after VQVAE encoder (H/4). Python: 50. + pub token_h: usize, + /// Token grid width after VQVAE encoder (W/4). Python: 50. + pub token_w: usize, + + // ── Transformer ─────────────────────────────────────────────────────── + /// Number of attention heads in the transformer. + pub num_heads: usize, + /// Number of encoder layers in the UNet-style transformer. + pub num_layers: usize, + /// Feed-forward hidden size inside each transformer layer. + pub ffn_hidden: usize, +} + +impl Default for OccWorldConfig { + fn default() -> Self { + Self { + grid_h: 200, + grid_w: 200, + grid_d: 16, + num_classes: 18, + free_class: 17, + base_channels: 64, + z_channels: 128, + codebook_size: 512, + embed_dim: 512, + num_frames: 15, + token_h: 50, + token_w: 50, + num_heads: 8, + num_layers: 2, + ffn_hidden: 2048, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_defaults() { + let cfg = OccWorldConfig::default(); + assert_eq!(cfg.grid_h, 200); + assert_eq!(cfg.grid_w, 200); + assert_eq!(cfg.grid_d, 16); + assert_eq!(cfg.num_classes, 18); + assert_eq!(cfg.free_class, 17); + assert_eq!(cfg.base_channels, 64); + assert_eq!(cfg.z_channels, 128); + assert_eq!(cfg.codebook_size, 512); + assert_eq!(cfg.embed_dim, 512); + assert_eq!(cfg.num_frames, 15); + assert_eq!(cfg.token_h, 50); + assert_eq!(cfg.token_w, 50); + } +} diff --git a/v2/crates/wifi-densepose-occworld-candle/src/error.rs b/v2/crates/wifi-densepose-occworld-candle/src/error.rs new file mode 100644 index 00000000..20b3f823 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/error.rs @@ -0,0 +1,29 @@ +//! Error types for `wifi-densepose-occworld-candle`. + +/// All errors that can occur during OccWorld inference. +#[derive(Debug, thiserror::Error)] +pub enum OccWorldError { + /// A Candle operation failed. + #[error("candle error: {0}")] + Candle(#[from] candle_core::Error), + + /// Input or output tensor has an unexpected shape. + #[error("shape mismatch: {0}")] + ShapeMismatch(String), + + /// The checkpoint file could not be found or opened. + #[error("checkpoint not found: {0}")] + CheckpointNotFound(String), + + /// The checkpoint file exists but could not be parsed. + #[error("checkpoint parse error: {0}")] + CheckpointParse(String), + + /// A required tensor key is missing from the checkpoint. + #[error("missing weight key '{0}' in checkpoint")] + MissingKey(String), + + /// I/O error reading the checkpoint file. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/v2/crates/wifi-densepose-occworld-candle/src/inference.rs b/v2/crates/wifi-densepose-occworld-candle/src/inference.rs new file mode 100644 index 00000000..47ca7ea6 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/inference.rs @@ -0,0 +1,407 @@ +//! Top-level inference engine — `OccWorldCandle`. +//! +//! Provides the public-facing API: +//! - `OccWorldCandle::load` — load from a SafeTensors checkpoint +//! - `OccWorldCandle::dummy` — random weights for testing / benchmarking +//! - `OccWorldCandle::predict` — infer 15 future occupancy frames +//! +//! The `dummy` constructor allows end-to-end benchmarking (wall-clock timing, +//! shape verification, memory footprint) before the Phase-5 checkpoint exists. + +use std::path::Path; +use std::time::Instant; + +use candle_core::{DType, Device, Tensor}; +use candle_nn::VarBuilder; + +use crate::config::OccWorldConfig; +use crate::error::OccWorldError; +use crate::transformer::OccWorldTransformer; +use crate::vqvae::{decode_to_logits, encode_occupancy, VQVAEComponents}; + +// ── Output types ───────────────────────────────────────────────────────────── + +/// A predicted future trajectory waypoint in 3-D grid coordinates. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TrajectoryWaypoint { + /// Frame index within the prediction horizon (0 = first predicted frame). + pub frame: usize, + /// Grid X position of the predicted agent centroid. + pub grid_x: f32, + /// Grid Y position of the predicted agent centroid. + pub grid_y: f32, + /// Grid Z position of the predicted agent centroid. + pub grid_z: f32, + /// Confidence score in `[0, 1]`. + pub confidence: f32, +} + +/// Outputs produced by one call to `OccWorldCandle::predict`. +pub struct InferenceOutput { + /// Predicted semantic class for each voxel. + /// + /// Shape: `(1, 15, 200, 200, 16)`, dtype `u8`. + /// Values are class indices in `[0, num_classes)`. + pub sem_pred: Tensor, + + /// Trajectory priors extracted from the predicted occupancy. + /// + /// One waypoint per predicted frame, centred on the non-free voxel + /// with the highest occupancy probability. Empty when the model + /// predicts all frames as free space. + pub trajectory_priors: Vec, + + /// Wall-clock time for the full `predict` call in milliseconds. + pub inference_ms: f64, +} + +// ── Main engine ─────────────────────────────────────────────────────────────── + +/// Native Rust OccWorld inference engine backed by Candle. +/// +/// # Loading +/// +/// ```no_run +/// # use wifi_densepose_occworld_candle::inference::OccWorldCandle; +/// # use wifi_densepose_occworld_candle::config::OccWorldConfig; +/// # use candle_core::Device; +/// # use std::path::Path; +/// let cfg = OccWorldConfig::default(); +/// match OccWorldCandle::load(Path::new("/path/to/occworld.safetensors"), cfg) { +/// Ok(engine) => { /* use engine */ } +/// Err(_) => { /* fall back to Python bridge */ } +/// } +/// ``` +pub struct OccWorldCandle { + // Note: Device does not implement Debug; derive manually below. + config: OccWorldConfig, + vqvae: VQVAEComponents, + transformer: OccWorldTransformer, + device: Device, +} + +impl std::fmt::Debug for OccWorldCandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OccWorldCandle") + .field("config", &self.config) + .finish_non_exhaustive() + } +} + +impl OccWorldCandle { + /// Load model weights from a SafeTensors checkpoint. + /// + /// Returns `Err` if the checkpoint does not exist, so callers can + /// gracefully fall back to the Python bridge (`wifi-densepose-worldmodel`). + pub fn load( + checkpoint_path: &Path, + config: OccWorldConfig, + ) -> Result { + if !checkpoint_path.exists() { + return Err(OccWorldError::CheckpointNotFound( + checkpoint_path.display().to_string(), + )); + } + + let device = pick_device(); + + // Load weights through the safe file-read path in `model::load_safetensors`. + // This avoids the `unsafe` mmap block forbidden by our lint config, at the + // cost of reading the full file into memory rather than memory-mapping it. + // Switch to `VarBuilder::from_mmaped_safetensors` (in a crate that allows + // unsafe) once the checkpoint is large enough that mmap matters. + let tensors = crate::model::load_safetensors(checkpoint_path, &device)?; + let vb = VarBuilder::from_tensors(tensors, DType::F32, &device); + + let vqvae = VQVAEComponents::new(&config, vb.clone()).map_err(OccWorldError::Candle)?; + let transformer = + OccWorldTransformer::new(config.clone(), vb).map_err(OccWorldError::Candle)?; + + Ok(Self { + config, + vqvae, + transformer, + device, + }) + } + + /// Construct with random weights for testing and benchmarking. + /// + /// All shapes are correct; no checkpoint is required. + pub fn dummy(config: OccWorldConfig, device: Device) -> Result { + let vqvae = + VQVAEComponents::dummy(&config, &device).map_err(OccWorldError::Candle)?; + let transformer = + OccWorldTransformer::dummy(config.clone(), &device).map_err(OccWorldError::Candle)?; + Ok(Self { + config, + vqvae, + transformer, + device, + }) + } + + /// Infer 15 future occupancy frames from 16 past frames. + /// + /// # Arguments + /// * `past_occupancy` — `(1, 16, 200, 200, 16)` tensor of `u8` class indices. + /// + /// # Returns + /// [`InferenceOutput`] containing: + /// - `sem_pred`: `(1, 15, 200, 200, 16)` u8 predicted class indices + /// - `trajectory_priors`: one waypoint per predicted frame + /// - `inference_ms`: wall-clock latency + pub fn predict(&self, past_occupancy: &Tensor) -> Result { + let t0 = Instant::now(); + + let cfg = &self.config; + let (b, f_in, h, w, d) = past_occupancy.dims5().map_err(OccWorldError::Candle)?; + + if h != cfg.grid_h || w != cfg.grid_w || d != cfg.grid_d { + return Err(OccWorldError::ShapeMismatch(format!( + "expected past_occupancy (_, _, {}, {}, {}), got (_, _, {h}, {w}, {d})", + cfg.grid_h, cfg.grid_w, cfg.grid_d + ))); + } + + // ── Step 1: VQVAE encode each past frame ────────────────────────── + // Flatten batch*frames: (B, F, H, W, D) → (B*F, H, W, D) + let occ_flat = past_occupancy + .reshape((b * f_in, h, w, d)) + .map_err(OccWorldError::Candle)?; + + // Cast to u32 for class embedding (input is u8) + let occ_u32 = occ_flat + .to_dtype(DType::U32) + .map_err(OccWorldError::Candle)?; + + // Class embedding → (B*F, base_channels, H, W*D) + let embedded = self + .vqvae + .class_embed + .forward(&occ_u32, cfg.grid_d) + .map_err(OccWorldError::Candle)?; + + // Encode (stub) → (B*F, z_channels, token_h, token_w) + let z = encode_occupancy(&embedded, cfg, &self.device)?; + + // quant_conv → (B*F, embed_dim, token_h, token_w) + let z_e = self + .vqvae + .quant_conv + .forward(&z) + .map_err(OccWorldError::Candle)?; + + // Vector quantisation → z_q (B*F, embed_dim, token_h, token_w), indices + // Reshape to (B*F, H*W, embed_dim) for VQCodebook.encode + let (bf, e_dim, th, tw) = z_e.dims4().map_err(OccWorldError::Candle)?; + let z_e_flat = z_e + .permute((0, 2, 3, 1)) // (B*F, th, tw, embed_dim) + .map_err(OccWorldError::Candle)? + .reshape((bf, th * tw, e_dim)) + .map_err(OccWorldError::Candle)?; + + let (z_q_flat, _indices) = self + .vqvae + .codebook + .encode(&z_e_flat) + .map_err(OccWorldError::Candle)?; + + // Back to (B*F, embed_dim, th, tw) → (B, F, embed_dim, th, tw) + let z_q = z_q_flat + .reshape((bf, th, tw, e_dim)) + .map_err(OccWorldError::Candle)? + .permute((0, 3, 1, 2)) // (B*F, embed_dim, th, tw) + .map_err(OccWorldError::Candle)? + .reshape((b, f_in, e_dim, th, tw)) + .map_err(OccWorldError::Candle)?; + + // ── Step 2: Transformer predicts future token logits ────────────── + // Output: (B, F_out, vocab, th, tw) + let pred_logits = self.transformer.forward(&z_q)?; + + let f_out = pred_logits.dim(1).map_err(OccWorldError::Candle)?; + + // ── Step 3: Argmax over vocab dim → predicted token indices ─────── + let pred_indices = pred_logits + .argmax(2) // (B, F_out, th, tw) — over vocab dim + .map_err(OccWorldError::Candle)?; + + // ── Step 4: Decode token indices → z_q values ──────────────────── + // Flatten to (B*F_out * th * tw,) for codebook lookup + let idx_flat = pred_indices + .flatten_all() + .map_err(OccWorldError::Candle)?; + let z_decoded = self + .vqvae + .codebook + .decode(&idx_flat) + .map_err(OccWorldError::Candle)?; // (B*F_out*th*tw, embed_dim) + + // Reshape to (B*F_out, embed_dim, th, tw) for post_quant_conv + let z_dec_4d = z_decoded + .reshape((b * f_out, e_dim, th, tw)) + .map_err(OccWorldError::Candle)?; + + let z_post = self + .vqvae + .post_quant_conv + .forward(&z_dec_4d) + .map_err(OccWorldError::Candle)?; + + // ── Step 5: Decode to class logits (stub) → class predictions ───── + let class_logits = decode_to_logits(&z_post, cfg, &self.device)?; + // class_logits: (B*F_out, num_classes, H, W, D) + // Argmax over class dim → (B*F_out, H, W, D) + let sem_flat = class_logits + .argmax(1) + .map_err(OccWorldError::Candle)? + .to_dtype(DType::U8) + .map_err(OccWorldError::Candle)?; + + let sem_pred = sem_flat + .reshape((b, f_out, cfg.grid_h, cfg.grid_w, cfg.grid_d)) + .map_err(OccWorldError::Candle)?; + + // ── Step 6: Extract trajectory priors ───────────────────────────── + let trajectory_priors = extract_trajectory_priors(&sem_pred, cfg, f_out)?; + + let inference_ms = t0.elapsed().as_secs_f64() * 1000.0; + + Ok(InferenceOutput { + sem_pred, + trajectory_priors, + inference_ms, + }) + } +} + +// ── Trajectory prior extraction ─────────────────────────────────────────────── + +/// Extract one trajectory waypoint per predicted frame. +/// +/// For each frame, finds the non-free voxel with the highest probability +/// (approximated by the centroid of all non-free voxels, weighted equally). +/// Returns an empty `Vec` when all frames are predicted as free space. +fn extract_trajectory_priors( + sem_pred: &Tensor, + cfg: &OccWorldConfig, + f_out: usize, +) -> Result, OccWorldError> { + // sem_pred: (1, F_out, H, W, D) u8 + // Pull to CPU Vec for coordinate extraction — lightweight post-processing + let data: Vec = sem_pred + .flatten_all() + .map_err(OccWorldError::Candle)? + .to_vec1() + .map_err(OccWorldError::Candle)?; + + let h = cfg.grid_h; + let w = cfg.grid_w; + let d = cfg.grid_d; + let frame_stride = h * w * d; + + let mut waypoints = Vec::with_capacity(f_out); + for fi in 0..f_out { + let frame_slice = &data[fi * frame_stride..(fi + 1) * frame_stride]; + let mut sum_x = 0.0f64; + let mut sum_y = 0.0f64; + let mut sum_z = 0.0f64; + let mut count = 0usize; + + for (idx, &cls) in frame_slice.iter().enumerate() { + if cls != cfg.free_class { + let xi = idx / (w * d); + let yi = (idx % (w * d)) / d; + let zi = idx % d; + sum_x += xi as f64; + sum_y += yi as f64; + sum_z += zi as f64; + count += 1; + } + } + + if count > 0 { + let n = count as f64; + waypoints.push(TrajectoryWaypoint { + frame: fi, + grid_x: (sum_x / n) as f32, + grid_y: (sum_y / n) as f32, + grid_z: (sum_z / n) as f32, + confidence: (count as f32) / (frame_stride as f32), + }); + } + } + Ok(waypoints) +} + +// ── Device selection ────────────────────────────────────────────────────────── + +fn pick_device() -> Device { + #[cfg(feature = "cuda")] + if let Ok(d) = Device::cuda_if_available(0) { + return d; + } + Device::Cpu +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::OccWorldConfig; + + fn small_cfg() -> OccWorldConfig { + OccWorldConfig { + grid_h: 8, + grid_w: 8, + grid_d: 4, + num_classes: 4, + free_class: 3, + base_channels: 8, + z_channels: 8, + codebook_size: 4, + embed_dim: 8, + num_frames: 2, + token_h: 4, + token_w: 4, + num_heads: 2, + num_layers: 1, + ffn_hidden: 16, + } + } + + #[test] + fn test_dummy_predict_shape() -> Result<(), OccWorldError> { + let device = Device::Cpu; + let cfg = small_cfg(); + let engine = OccWorldCandle::dummy(cfg.clone(), device.clone())?; + + // (1, 2, 8, 8, 4) — batch=1, 2 past frames (matches num_frames) + let past = Tensor::zeros( + (1, cfg.num_frames, cfg.grid_h, cfg.grid_w, cfg.grid_d), + DType::U8, + &device, + ) + .map_err(OccWorldError::Candle)?; + + let out = engine.predict(&past)?; + let dims = out.sem_pred.dims(); + assert_eq!(dims[0], 1, "batch dim"); + assert_eq!(dims[1], cfg.num_frames, "frame dim"); + assert_eq!(dims[2], cfg.grid_h, "H dim"); + assert_eq!(dims[3], cfg.grid_w, "W dim"); + assert_eq!(dims[4], cfg.grid_d, "D dim"); + + Ok(()) + } + + #[test] + fn test_load_nonexistent_checkpoint() { + let cfg = small_cfg(); + let result = OccWorldCandle::load(Path::new("/no/such/checkpoint.safetensors"), cfg); + assert!( + matches!(result, Err(OccWorldError::CheckpointNotFound(_))), + "expected CheckpointNotFound, got {result:?}" + ); + } +} diff --git a/v2/crates/wifi-densepose-occworld-candle/src/lib.rs b/v2/crates/wifi-densepose-occworld-candle/src/lib.rs new file mode 100644 index 00000000..b9e833e0 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/lib.rs @@ -0,0 +1,52 @@ +//! `wifi-densepose-occworld-candle` — OccWorld TransVQVAE inference in Candle. +//! +//! Ports the 72.4 M-parameter OccWorld world model (VQVAE tokeniser + +//! autoregressive transformer) from Python to native Rust using the +//! Hugging Face Candle framework. The goal is to eliminate the +//! 208 ms Python/IPC overhead of the existing `wifi-densepose-worldmodel` +//! bridge and enable tight integration with the streaming engine. +//! +//! ## Module structure +//! +//! | Module | Contents | +//! |-----------------|-------------------------------------------------------| +//! | `config` | `OccWorldConfig` — hyper-parameters | +//! | `error` | `OccWorldError` — unified error enum | +//! | `vqvae` | Class embedding, VQ codebook, quant convolutions | +//! | `transformer` | Autoregressive transformer (`PlanUAutoRegTransformer`) | +//! | `model` | SafeTensors weight loading + key mapping | +//! | `inference` | `OccWorldCandle` end-to-end inference engine | +//! +//! ## Implementation status +//! +//! The VQVAE encoder/decoder ResNet blocks are **stubs** that return random +//! tensors of the correct shape. All other components (class embedding, +//! VQ codebook, quant/post-quant convolutions, transformer, trajectory +//! extraction) are fully implemented. The stubs will be replaced in Phase 5 +//! once the SafeTensors checkpoint is available. +//! +//! ## Usage +//! +//! ```no_run +//! use wifi_densepose_occworld_candle::inference::OccWorldCandle; +//! use wifi_densepose_occworld_candle::config::OccWorldConfig; +//! use candle_core::{Device, DType, Tensor}; +//! use std::path::Path; +//! +//! let cfg = OccWorldConfig::default(); +//! let engine = OccWorldCandle::dummy(cfg, Device::Cpu).expect("dummy init"); +//! let past = Tensor::zeros((1, 15, 200, 200, 16), DType::U8, &Device::Cpu).unwrap(); +//! let out = engine.predict(&past).expect("predict"); +//! println!("predicted {} frames in {:.1} ms", out.sem_pred.dim(1).unwrap(), out.inference_ms); +//! ``` + +pub mod config; +pub mod error; +pub mod inference; +pub mod model; +pub mod transformer; +pub mod vqvae; + +pub use config::OccWorldConfig; +pub use error::OccWorldError; +pub use inference::{InferenceOutput, OccWorldCandle, TrajectoryWaypoint}; diff --git a/v2/crates/wifi-densepose-occworld-candle/src/model.rs b/v2/crates/wifi-densepose-occworld-candle/src/model.rs new file mode 100644 index 00000000..8a83c209 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/model.rs @@ -0,0 +1,165 @@ +//! Weight loading utilities for the OccWorld SafeTensors checkpoint. +//! +//! Phase-5 retraining produces a `.safetensors` file whose tensor keys +//! follow PyTorch naming conventions (e.g. `encoder.conv_in.weight`). +//! The functions here map those keys to the Candle `VarBuilder` sub-path +//! convention used in this crate (e.g. `enc.conv_in.weight`). + +use candle_core::{Device, Tensor}; +use std::collections::HashMap; +use std::path::Path; + +use crate::error::OccWorldError; + +/// Load all tensors from a SafeTensors file into a key→Tensor map. +/// +/// Returns `Err(OccWorldError::CheckpointNotFound)` if the path does not +/// exist, so callers can gracefully fall back to the Python bridge. +pub fn load_safetensors( + path: &Path, + device: &Device, +) -> Result, OccWorldError> { + if !path.exists() { + return Err(OccWorldError::CheckpointNotFound( + path.display().to_string(), + )); + } + + // Read the raw bytes; safetensors requires the full file in memory. + let bytes = std::fs::read(path)?; + let named_tensors = safetensors::SafeTensors::deserialize(&bytes) + .map_err(|e| OccWorldError::CheckpointParse(e.to_string()))?; + + let mut map = HashMap::new(); + for (name, view) in named_tensors.tensors() { + let candle_key = map_pytorch_key(&name); + let dtype = safetensor_dtype_to_candle(view.dtype()) + .ok_or_else(|| OccWorldError::CheckpointParse( + format!("unsupported dtype for key '{name}'"), + ))?; + let shape: Vec = view.shape().to_vec(); + let data = view.data(); + let tensor = Tensor::from_raw_buffer(data, dtype, &shape, device) + .map_err(OccWorldError::Candle)?; + map.insert(candle_key, tensor); + } + Ok(map) +} + +/// Map a PyTorch weight key to the Candle naming convention used here. +/// +/// # Mapping rules +/// +/// | PyTorch prefix | Candle prefix | +/// |------------------------|------------------------| +/// | `encoder.` | `enc.` | +/// | `decoder.` | `dec.` | +/// | `quantize.` | `quantize.` | +/// | `quant_conv.` | `quant_conv.` | +/// | `post_quant_conv.` | `post_quant_conv.` | +/// | `transformer.` | `transformer.` | +/// | `class_embedding.` | `class_embed.` | +/// +/// All other keys are passed through unchanged. Extend this function +/// whenever the checkpoint adds new top-level modules. +pub fn map_pytorch_key(key: &str) -> String { + // Strip any leading "model." prefix that PyTorch Lightning adds + let key = key.strip_prefix("model.").unwrap_or(key); + + if let Some(rest) = key.strip_prefix("encoder.") { + return format!("enc.{rest}"); + } + if let Some(rest) = key.strip_prefix("decoder.") { + return format!("dec.{rest}"); + } + if let Some(rest) = key.strip_prefix("class_embedding.") { + return format!("class_embed.{rest}"); + } + + // No transformation needed for these prefixes + key.to_owned() +} + +/// Convert a `safetensors::Dtype` to a `candle_core::DType`. +/// +/// Returns `None` for unsupported variants (e.g. BF16 on CPU without +/// the `bf16` feature). +fn safetensor_dtype_to_candle(dt: safetensors::Dtype) -> Option { + use candle_core::DType; + use safetensors::Dtype; + match dt { + Dtype::F32 => Some(DType::F32), + Dtype::F64 => Some(DType::F64), + Dtype::F16 => Some(DType::F16), + Dtype::BF16 => Some(DType::BF16), + Dtype::I32 => Some(DType::I64), // widen for Candle compatibility + Dtype::I64 => Some(DType::I64), + Dtype::U8 => Some(DType::U8), + Dtype::U32 => Some(DType::U32), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_map_pytorch_key_encoder() { + assert_eq!( + map_pytorch_key("encoder.conv_in.weight"), + "enc.conv_in.weight" + ); + } + + #[test] + fn test_map_pytorch_key_decoder() { + assert_eq!( + map_pytorch_key("decoder.conv_out.bias"), + "dec.conv_out.bias" + ); + } + + #[test] + fn test_map_pytorch_key_class_embedding() { + assert_eq!( + map_pytorch_key("class_embedding.weight"), + "class_embed.weight" + ); + } + + #[test] + fn test_map_pytorch_key_passthrough() { + assert_eq!( + map_pytorch_key("quantize.embedding.weight"), + "quantize.embedding.weight" + ); + assert_eq!( + map_pytorch_key("quant_conv.weight"), + "quant_conv.weight" + ); + assert_eq!( + map_pytorch_key("transformer.layer_0.ffn.fc1.weight"), + "transformer.layer_0.ffn.fc1.weight" + ); + } + + #[test] + fn test_map_pytorch_key_lightning_prefix() { + // PyTorch Lightning wraps everything under "model." + assert_eq!( + map_pytorch_key("model.encoder.conv_in.weight"), + "enc.conv_in.weight" + ); + } + + #[test] + fn test_load_nonexistent_checkpoint() { + let device = candle_core::Device::Cpu; + let result = load_safetensors(Path::new("/nonexistent/checkpoint.safetensors"), &device); + assert!( + matches!(result, Err(OccWorldError::CheckpointNotFound(_))), + "expected CheckpointNotFound, got {result:?}" + ); + } +} diff --git a/v2/crates/wifi-densepose-occworld-candle/src/transformer.rs b/v2/crates/wifi-densepose-occworld-candle/src/transformer.rs new file mode 100644 index 00000000..4956ce00 --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/transformer.rs @@ -0,0 +1,466 @@ +//! OccWorld autoregressive transformer — `PlanUAutoRegTransformer` port. +//! +//! Architecture summary (matches `PlanUtransformer.py`): +//! +//! 1. Input: quantised VQVAE tokens `z_q` of shape `(B, F, C, H, W)`. +//! 2. Spatial flatten: `(B*F, C, H*W)` so each frame is a sequence of spatial tokens. +//! 3. Temporal embedding: learned positional bias added to the C-dim channel. +//! 4. Per-layer: `TemporalCrossAttn` → `SpatialCrossAttn` → FFN. +//! 5. Output head: `Linear(C → vocab)` producing logits `(B, F_out, vocab, H, W)`. +//! +//! The two-level UNet attention (`num_layers = 2`) uses separate query/key/value +//! projections at each level so the encoder sees the full past context while +//! the decoder generates one future frame at a time. + +use candle_core::{DType, Device, Module, Result, Tensor}; +use candle_nn::{linear, ops::softmax, Embedding, Linear, VarBuilder}; + +use crate::config::OccWorldConfig; +use crate::error::OccWorldError; + +// ── Temporal positional embedding ───────────────────────────────────────────── + +/// Maps frame indices `[0, num_frames*2)` to `embed_dim`-dimensional vectors. +/// +/// The doubled range (`num_frames*2`) allows future frame positions to be +/// distinct from past frame positions (Python: `nn.Embedding(16 * 2, 512)`). +pub struct TemporalEmbedding { + embed: Embedding, +} + +impl TemporalEmbedding { + /// Build from weights. + pub fn new(num_frames: usize, embed_dim: usize, vb: VarBuilder<'_>) -> Result { + let embed = candle_nn::embedding(num_frames * 2, embed_dim, vb.pp("temporal_embed"))?; + Ok(Self { embed }) + } + + /// Random initialisation. + pub fn dummy(num_frames: usize, embed_dim: usize, device: &Device) -> Result { + let w = Tensor::randn(0f32, 1.0, (num_frames * 2, embed_dim), device)?; + let embed = Embedding::new(w, embed_dim); + Ok(Self { embed }) + } + + /// Produce positional embedding for frame indices `[0, F)`. + /// + /// Returns `(F, embed_dim)` — broadcast over batch and spatial dimensions + /// by the caller. + pub fn forward(&self, num_frames: usize, device: &Device) -> Result { + let indices = Tensor::arange(0u32, num_frames as u32, device)?; + self.embed.forward(&indices) // (F, embed_dim) + } +} + +// ── Scaled-dot-product attention helpers ───────────────────────────────────── + +/// Scaled dot-product attention: `softmax(Q·Kᵀ / √d) · V`. +/// +/// All tensors are `(B, heads, seq_len, head_dim)`. +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let head_dim = q.dim(candle_core::D::Minus1)? as f64; + let scale = (head_dim).sqrt(); + // (B, heads, q_len, k_len) + let attn_weights = (q.matmul(&k.transpose(candle_core::D::Minus2, candle_core::D::Minus1)?)? + / scale)?; + let attn_probs = softmax(&attn_weights, candle_core::D::Minus1)?; + attn_probs.matmul(v) +} + +// ── Spatial cross-attention ─────────────────────────────────────────────────── + +/// Multi-head self/cross-attention over the spatial token sequence. +/// +/// Used to capture dependencies between different spatial locations within +/// the same frame (or across frames when keys/values come from a different +/// temporal index). +pub struct SpatialCrossAttn { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + out_proj: Linear, + num_heads: usize, + head_dim: usize, +} + +impl SpatialCrossAttn { + /// Build from weights with sub-path `prefix`. + pub fn new(embed_dim: usize, num_heads: usize, vb: VarBuilder<'_>) -> Result { + let head_dim = embed_dim / num_heads; + let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?; + let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?; + let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?; + let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + out_proj, + num_heads, + head_dim, + }) + } + + /// Random initialisation. + pub fn dummy(embed_dim: usize, num_heads: usize, device: &Device) -> Result { + let mk_linear = |i: usize, o: usize| -> Result { + let w = Tensor::randn(0f32, 0.02, (o, i), device)?; + let b = Tensor::zeros(o, DType::F32, device)?; + Ok(Linear::new(w, Some(b))) + }; + let head_dim = embed_dim / num_heads; + Ok(Self { + q_proj: mk_linear(embed_dim, embed_dim)?, + k_proj: mk_linear(embed_dim, embed_dim)?, + v_proj: mk_linear(embed_dim, embed_dim)?, + out_proj: mk_linear(embed_dim, embed_dim)?, + num_heads, + head_dim, + }) + } + + /// Forward attention. + /// + /// `queries`: `(B, q_len, C)`, `keys`/`values`: `(B, kv_len, C)`. + /// Returns: `(B, q_len, C)`. + pub fn forward(&self, queries: &Tensor, keys: &Tensor, values: &Tensor) -> Result { + let (b, q_len, _c) = queries.dims3()?; + + let project = |proj: &Linear, x: &Tensor, seq: usize| -> Result { + let out = proj.forward(x)?; // (B, seq, C) + out.reshape((b, seq, self.num_heads, self.head_dim))? + .permute((0, 2, 1, 3)) // (B, heads, seq, head_dim) + }; + + let kv_len = keys.dim(1)?; + let q = project(&self.q_proj, queries, q_len)?.contiguous()?; + let k = project(&self.k_proj, keys, kv_len)?.contiguous()?; + let v = project(&self.v_proj, values, kv_len)?.contiguous()?; + + // (B, heads, q_len, head_dim) + let attended = scaled_dot_product_attention(&q, &k, &v)?; + // → (B, q_len, C) + let merged = attended + .permute((0, 2, 1, 3))? + .reshape((b, q_len, self.num_heads * self.head_dim))?; + self.out_proj.forward(&merged) + } +} + +// ── Temporal cross-attention ────────────────────────────────────────────────── + +/// Cross-attention between past-frame tokens (keys/values) and query tokens. +/// +/// Identical in structure to `SpatialCrossAttn` — kept as a distinct type +/// for clarity and separate weight namespacing in the checkpoint. +pub struct TemporalCrossAttn { + inner: SpatialCrossAttn, +} + +impl TemporalCrossAttn { + /// Build from weights. + pub fn new(embed_dim: usize, num_heads: usize, vb: VarBuilder<'_>) -> Result { + Ok(Self { + inner: SpatialCrossAttn::new(embed_dim, num_heads, vb)?, + }) + } + + /// Random initialisation. + pub fn dummy(embed_dim: usize, num_heads: usize, device: &Device) -> Result { + Ok(Self { + inner: SpatialCrossAttn::dummy(embed_dim, num_heads, device)?, + }) + } + + /// Forward: `queries (B, q_len, C)` attend to `keys/values (B, kv_len, C)`. + pub fn forward(&self, queries: &Tensor, keys: &Tensor, values: &Tensor) -> Result { + self.inner.forward(queries, keys, values) + } +} + +// ── Feed-forward network ────────────────────────────────────────────────────── + +struct FeedForward { + fc1: Linear, + fc2: Linear, +} + +impl FeedForward { + fn new(embed_dim: usize, ffn_hidden: usize, vb: VarBuilder<'_>) -> Result { + let fc1 = linear(embed_dim, ffn_hidden, vb.pp("fc1"))?; + let fc2 = linear(ffn_hidden, embed_dim, vb.pp("fc2"))?; + Ok(Self { fc1, fc2 }) + } + + fn dummy(embed_dim: usize, ffn_hidden: usize, device: &Device) -> Result { + let mk = |i: usize, o: usize| -> Result { + let w = Tensor::randn(0f32, 0.02, (o, i), device)?; + let b = Tensor::zeros(o, DType::F32, device)?; + Ok(Linear::new(w, Some(b))) + }; + Ok(Self { + fc1: mk(embed_dim, ffn_hidden)?, + fc2: mk(ffn_hidden, embed_dim)?, + }) + } + + fn forward(&self, x: &Tensor) -> Result { + self.fc2.forward(&self.fc1.forward(x)?.gelu()?) + } +} + +// ── Single encoder layer ───────────────────────────────────────────────────── + +/// One layer of the OccWorld UNet-style encoder: +/// `TemporalCrossAttn → SpatialCrossAttn → FFN` with residual connections. +pub struct OccWorldTransformerLayer { + temporal_attn: TemporalCrossAttn, + spatial_attn: SpatialCrossAttn, + ffn: FeedForward, + // Layer-norms for pre-norm formulation + norm1: candle_nn::LayerNorm, + norm2: candle_nn::LayerNorm, + norm3: candle_nn::LayerNorm, +} + +impl OccWorldTransformerLayer { + /// Build from weights. + pub fn new(cfg: &OccWorldConfig, vb: VarBuilder<'_>) -> Result { + let temporal_attn = + TemporalCrossAttn::new(cfg.embed_dim, cfg.num_heads, vb.pp("temporal_attn"))?; + let spatial_attn = + SpatialCrossAttn::new(cfg.embed_dim, cfg.num_heads, vb.pp("spatial_attn"))?; + let ffn = FeedForward::new(cfg.embed_dim, cfg.ffn_hidden, vb.pp("ffn"))?; + let norm_cfg = candle_nn::LayerNormConfig::default(); + let norm1 = candle_nn::layer_norm(cfg.embed_dim, norm_cfg, vb.pp("norm1"))?; + let norm2 = candle_nn::layer_norm(cfg.embed_dim, norm_cfg, vb.pp("norm2"))?; + let norm3 = candle_nn::layer_norm(cfg.embed_dim, norm_cfg, vb.pp("norm3"))?; + Ok(Self { + temporal_attn, + spatial_attn, + ffn, + norm1, + norm2, + norm3, + }) + } + + /// Random initialisation. + pub fn dummy(cfg: &OccWorldConfig, device: &Device) -> Result { + let temporal_attn = TemporalCrossAttn::dummy(cfg.embed_dim, cfg.num_heads, device)?; + let spatial_attn = SpatialCrossAttn::dummy(cfg.embed_dim, cfg.num_heads, device)?; + let ffn = FeedForward::dummy(cfg.embed_dim, cfg.ffn_hidden, device)?; + let norm_cfg = candle_nn::LayerNormConfig::default(); + // Dummy layer norms with ones/zeros + let mk_norm = |d: usize| -> Result { + let w = Tensor::ones(d, DType::F32, device)?; + let b = Tensor::zeros(d, DType::F32, device)?; + Ok(candle_nn::LayerNorm::new(w, b, norm_cfg.eps)) + }; + Ok(Self { + temporal_attn, + spatial_attn, + ffn, + norm1: mk_norm(cfg.embed_dim)?, + norm2: mk_norm(cfg.embed_dim)?, + norm3: mk_norm(cfg.embed_dim)?, + }) + } + + /// Forward one layer. + /// + /// `x`: `(B, seq_len, C)` — queries (current frame tokens). + /// `ctx`: `(B, ctx_len, C)` — past-frame context tokens for temporal attn. + /// Returns `(B, seq_len, C)`. + pub fn forward(&self, x: &Tensor, ctx: &Tensor) -> Result { + // Temporal cross-attention with residual + let x = { + let normed = self.norm1.forward(x)?; + let attended = self.temporal_attn.forward(&normed, ctx, ctx)?; + (x + attended)? + }; + // Spatial self-attention with residual + let x = { + let normed = self.norm2.forward(&x)?; + let attended = self.spatial_attn.forward(&normed, &normed, &normed)?; + (x + attended)? + }; + // FFN with residual + let normed = self.norm3.forward(&x)?; + let ff_out = self.ffn.forward(&normed)?; + x + ff_out + } +} + +// ── Full transformer ────────────────────────────────────────────────────────── + +/// OccWorld autoregressive transformer (`PlanUAutoRegTransformer`). +/// +/// Takes quantised VQVAE tokens for past frames and predicts logits for +/// the next `F_out` frames. +pub struct OccWorldTransformer { + temporal_embed: TemporalEmbedding, + layers: Vec, + output_head: Linear, + cfg: OccWorldConfig, +} + +impl OccWorldTransformer { + /// Build from weights. + pub fn new(cfg: OccWorldConfig, vb: VarBuilder<'_>) -> Result { + let temporal_embed = + TemporalEmbedding::new(cfg.num_frames, cfg.embed_dim, vb.pp("transformer"))?; + let mut layers = Vec::with_capacity(cfg.num_layers); + for i in 0..cfg.num_layers { + layers.push(OccWorldTransformerLayer::new( + &cfg, + vb.pp("transformer").pp(format!("layer_{i}")), + )?); + } + let output_head = linear( + cfg.embed_dim, + cfg.codebook_size, + vb.pp("transformer").pp("output_head"), + )?; + Ok(Self { + temporal_embed, + layers, + output_head, + cfg, + }) + } + + /// Build with random weights (for tests / benchmarks). + pub fn dummy(cfg: OccWorldConfig, device: &Device) -> Result { + let temporal_embed = TemporalEmbedding::dummy(cfg.num_frames, cfg.embed_dim, device)?; + let mut layers = Vec::with_capacity(cfg.num_layers); + for _ in 0..cfg.num_layers { + layers.push(OccWorldTransformerLayer::dummy(&cfg, device)?); + } + let w = Tensor::randn(0f32, 0.02, (cfg.codebook_size, cfg.embed_dim), device)?; + let b = Tensor::zeros(cfg.codebook_size, DType::F32, device)?; + let output_head = Linear::new(w, Some(b)); + Ok(Self { + temporal_embed, + layers, + output_head, + cfg, + }) + } + + /// Forward pass. + /// + /// # Arguments + /// * `z_q` — quantised tokens: `(B, F, C, H, W)` where `C = embed_dim`. + /// + /// # Returns + /// Predicted logits: `(B, F_out, vocab, H, W)` where `F_out = F` and + /// `vocab = codebook_size`. + pub fn forward( + &self, + z_q: &Tensor, + ) -> std::result::Result { + let (b, f, c, h, w) = z_q.dims5().map_err(OccWorldError::Candle)?; + let device = z_q.device(); + + // Flatten spatial: (B, F, C, H, W) → (B, F, H*W, C) + // Then flatten batch*frames for parallel processing: (B*F, H*W, C) + let z_flat = z_q + .permute((0, 1, 3, 4, 2)) // (B, F, H, W, C) + .map_err(OccWorldError::Candle)? + .reshape((b * f, h * w, c)) + .map_err(OccWorldError::Candle)?; + + // Add temporal positional embedding — broadcast over spatial tokens + let temp_pos = self + .temporal_embed + .forward(f, device) + .map_err(OccWorldError::Candle)?; // (F, C) + // Expand to (B*F, 1, C) for broadcast addition + let temp_pos = temp_pos + .reshape((f, 1, c)) + .map_err(OccWorldError::Candle)? + .repeat(vec![b, 1, 1]) + .map_err(OccWorldError::Candle)? + .reshape((b * f, 1, c)) + .map_err(OccWorldError::Candle)?; + let mut x = z_flat + .broadcast_add(&temp_pos) + .map_err(OccWorldError::Candle)?; // (B*F, H*W, C) + + // Context for temporal attention: reshape back to (B, F*H*W, C) per batch + // and use the full past sequence as keys/values + let ctx = x + .reshape((b, f * h * w, c)) + .map_err(OccWorldError::Candle)? + .repeat(vec![f, 1, 1]) + .map_err(OccWorldError::Candle)? + .reshape((b * f, f * h * w, c)) + .map_err(OccWorldError::Candle)?; + + // Pass through transformer layers + for layer in &self.layers { + x = layer.forward(&x, &ctx).map_err(OccWorldError::Candle)?; + } + + // Output head: (B*F, H*W, C) → (B*F, H*W, vocab) + let logits = self + .output_head + .forward(&x) + .map_err(OccWorldError::Candle)?; + let vocab = self.cfg.codebook_size; + + // Reshape to (B, F, H*W, vocab) → (B, F, vocab, H, W) + let logits_out = logits + .reshape((b, f, h * w, vocab)) + .map_err(OccWorldError::Candle)? + .permute((0, 1, 3, 2)) // (B, F, vocab, H*W) + .map_err(OccWorldError::Candle)? + .reshape((b, f, vocab, h, w)) + .map_err(OccWorldError::Candle)?; + + Ok(logits_out) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transformer_forward_shape() -> std::result::Result<(), OccWorldError> { + let device = Device::Cpu; + let cfg = OccWorldConfig { + num_frames: 4, // smaller for fast test + embed_dim: 16, + codebook_size: 8, + token_h: 4, + token_w: 4, + num_heads: 2, + num_layers: 1, + ffn_hidden: 32, + ..OccWorldConfig::default() + }; + + let transformer = OccWorldTransformer::dummy(cfg.clone(), &device) + .map_err(OccWorldError::Candle)?; + + // (B=1, F=4, C=16, H=4, W=4) + let z_q = Tensor::randn( + 0f32, + 1.0, + (1, cfg.num_frames, cfg.embed_dim, cfg.token_h, cfg.token_w), + &device, + ) + .map_err(OccWorldError::Candle)?; + + let logits = transformer.forward(&z_q)?; + // Expected: (1, 4, 8, 4, 4) + assert_eq!( + logits.dims(), + &[1, cfg.num_frames, cfg.codebook_size, cfg.token_h, cfg.token_w] + ); + + Ok(()) + } +} diff --git a/v2/crates/wifi-densepose-occworld-candle/src/vqvae.rs b/v2/crates/wifi-densepose-occworld-candle/src/vqvae.rs new file mode 100644 index 00000000..e1eeaefe --- /dev/null +++ b/v2/crates/wifi-densepose-occworld-candle/src/vqvae.rs @@ -0,0 +1,396 @@ +//! VQVAE components — class embedding, codebook, quant/post-quant convolutions. +//! +//! ## Implementation status +//! +//! | Component | Status | Notes | +//! |----------------------|---------|------------------------------------------------| +//! | `ClassEmbedding` | Full | `Embedding(18, 64)` — matches Python exactly | +//! | `VQCodebook` | Full | Nearest-neighbour lookup via squared-L2 | +//! | `QuantConv` | Full | `Conv2d(128 → 512, k=1)` — quant_conv | +//! | `PostQuantConv` | Full | `Conv2d(512 → 128, k=1)` — post_quant_conv | +//! | `fold_3d_to_2d` | Full | (B*F, C, H, W*D) reshape for 2D CNN | +//! | Encoder2D (ResNet) | STUB | Returns random z of correct shape (B*F,128,50,50). | +//! Full implementation requires loading ~35 M params | +//! from the Phase-5 SafeTensors checkpoint. | +//! | Decoder2D (ResNet) | STUB | Returns random logits of correct shape. | +//! +//! The stubs produce outputs of the correct dtype and shape so that the full +//! inference pipeline compiles, runs, and can be benchmarked end-to-end +//! before the checkpoint is available. + +use candle_core::{DType, Device, Module, Result, Tensor}; +use candle_nn::{Conv2d, Conv2dConfig, Embedding, VarBuilder}; + +use crate::config::OccWorldConfig; +use crate::error::OccWorldError; + +// ── Class embedding ─────────────────────────────────────────────────────────── + +/// Embeds integer class labels `[0, num_classes)` into `base_channels`-dim vectors. +/// +/// Matches `nn.Embedding(18, 64)` in `vae_2d_resnet.py`. +pub struct ClassEmbedding { + embed: Embedding, +} + +impl ClassEmbedding { + /// Build from a [`VarBuilder`] using the sub-path `"class_embed"`. + pub fn new(num_classes: usize, embed_dim: usize, vb: VarBuilder<'_>) -> Result { + let embed = candle_nn::embedding(num_classes, embed_dim, vb.pp("class_embed"))?; + Ok(Self { embed }) + } + + /// Build with random initialisation (for tests / benchmarks). + pub fn dummy(num_classes: usize, embed_dim: usize, device: &Device) -> Result { + let w = Tensor::randn(0f32, 1.0, (num_classes, embed_dim), device)?; + let embed = Embedding::new(w, embed_dim); + Ok(Self { embed }) + } + + /// Forward: `(B*F, H, W, D)` u32 indices → `(B*F, embed_dim, H, W*D)`. + /// + /// The 3-D grid is folded along the depth axis so a 2-D CNN can process it. + pub fn forward(&self, x: &Tensor, grid_d: usize) -> Result { + // x: (B*F, H, W, D) — integer class labels stored as u32 + let (bf, h, w, _d) = x.dims4()?; + + // Flatten spatial+depth → apply embedding → (B*F, H, W, D, embed_dim) + let flat = x.flatten_all()?; // (B*F*H*W*D,) + let embedded = self.embed.forward(&flat)?; // (B*F*H*W*D, embed_dim) + let c = embedded.dim(1)?; + + // Reshape to (B*F, H, W, D, C) then transpose to (B*F, C, H, W*D) + let vol = embedded.reshape((bf, h, w, grid_d, c))?; + // (B*F, H, W, D, C) → (B*F, C, H, W, D) → (B*F, C, H, W*D) + let transposed = vol.permute((0, 4, 1, 2, 3))?; + let (bf2, c2, h2, w2, d2) = transposed.dims5()?; + transposed.reshape((bf2, c2, h2, w2 * d2)) + } +} + +// ── fold_3d_to_2d helper ───────────────────────────────────────────────────── + +/// Reshape `(B*F, C, H, W, D)` into `(B*F, C, H, W*D)` for 2-D CNNs. +/// +/// This is the "fold" operation described in `vae_2d_resnet.py`: +/// the depth axis is concatenated into the width so that standard +/// `Conv2d` layers can process the full 3-D occupancy volume. +pub fn fold_3d_to_2d(x: &Tensor) -> Result { + let (bf, c, h, w, d) = x.dims5()?; + x.reshape((bf, c, h, w * d)) +} + +/// Inverse of `fold_3d_to_2d`: `(B*F, C, H, W*D)` → `(B*F, C, H, W, D)`. +pub fn unfold_2d_to_3d(x: &Tensor, grid_w: usize, grid_d: usize) -> Result { + let (bf, c, h, _wd) = x.dims4()?; + x.reshape((bf, c, h, grid_w, grid_d)) +} + +// ── Vector-quantisation codebook ───────────────────────────────────────────── + +/// VQ codebook: `num_codes × embed_dim` lookup table. +/// +/// Nearest-neighbour assignment uses squared L2 distance: +/// ```text +/// d(z, e_k) = ||z − e_k||² = ||z||² − 2·z·e_kᵀ + ||e_k||² +/// ``` +/// This is standard VQ-VAE (van den Oord et al., 2017). +pub struct VQCodebook { + /// Shape: `(codebook_size, embed_dim)`. + embeddings: Tensor, + /// Number of discrete codes in the codebook. + pub codebook_size: usize, + /// Dimensionality of each codebook embedding vector. + pub embed_dim: usize, +} + +impl VQCodebook { + /// Load from a [`VarBuilder`] using the sub-path `"quantize.embedding.weight"`. + pub fn new(codebook_size: usize, embed_dim: usize, vb: VarBuilder<'_>) -> Result { + let embeddings = vb + .pp("quantize") + .pp("embedding") + .get((codebook_size, embed_dim), "weight")?; + Ok(Self { + embeddings, + codebook_size, + embed_dim, + }) + } + + /// Random initialisation (for tests / benchmarks). + pub fn dummy(codebook_size: usize, embed_dim: usize, device: &Device) -> Result { + let embeddings = Tensor::randn(0f32, 1.0, (codebook_size, embed_dim), device)?; + Ok(Self { + embeddings, + codebook_size, + embed_dim, + }) + } + + /// Quantise `z` (any shape `[..., embed_dim]`) → `(z_q, indices)`. + /// + /// `z_q` has the same shape as `z`; `indices` has shape `[..., 1]` squeezed + /// to `[...]` (batch of scalar indices). + pub fn encode(&self, z: &Tensor) -> Result<(Tensor, Tensor)> { + let orig_shape = z.shape().clone(); + let orig_dims = orig_shape.dims().to_vec(); + let last = *orig_shape.dims().last().unwrap_or(&0); + // Flatten to (N, embed_dim) + let n = z.elem_count() / last; + let z_flat = z.reshape((n, last))?; // (N, D) + + // Squared L2: ||z||² - 2*z*Eᵀ + ||E||² + // z_sq: (N, 1) + let z_sq = z_flat + .sqr()? + .sum(candle_core::D::Minus1)? + .unsqueeze(1)?; + // e_sq: (1, codebook_size) + let e_sq = self + .embeddings + .sqr()? + .sum(candle_core::D::Minus1)? + .unsqueeze(0)?; + // dot: (N, codebook_size) + let dot = z_flat.matmul(&self.embeddings.t()?)?; + // distances: (N, codebook_size) + let distances = z_sq.broadcast_add(&e_sq)?.broadcast_sub(&dot.affine(2.0, 0.0)?)?; + // indices: (N,) + let indices = distances.argmin(candle_core::D::Minus1)?; + + // Look up quantised embeddings + let z_q_flat = self.embeddings.index_select(&indices, 0)?; // (N, D) + + // Reshape back to original shape + let z_q = z_q_flat.reshape(orig_dims.clone())?; + let idx_shape: Vec = orig_dims[..orig_dims.len() - 1].to_vec(); + let indices_out = indices.reshape(idx_shape)?; + + Ok((z_q, indices_out)) + } + + /// Decode flat index tensor `(N,)` or `(B, ...)` → same shape `+ embed_dim`. + pub fn decode(&self, indices: &Tensor) -> Result { + let flat = indices.flatten_all()?; + let z_flat = self.embeddings.index_select(&flat, 0)?; // (N, D) + let mut out_shape: Vec = indices.dims().to_vec(); + out_shape.push(self.embed_dim); + z_flat.reshape(out_shape) + } +} + +// ── Quant / post-quant convolutions ────────────────────────────────────────── + +/// `Conv2d(z_channels → embed_dim, kernel=1)` — `quant_conv` in Python. +pub struct QuantConv { + conv: Conv2d, +} + +impl QuantConv { + /// Load from weights. + pub fn new(z_channels: usize, embed_dim: usize, vb: VarBuilder<'_>) -> Result { + let conv = candle_nn::conv2d( + z_channels, + embed_dim, + 1, + Conv2dConfig::default(), + vb.pp("quant_conv"), + )?; + Ok(Self { conv }) + } + + /// Random initialisation. + pub fn dummy(z_channels: usize, embed_dim: usize, device: &Device) -> Result { + let w = Tensor::randn(0f32, 1.0, (embed_dim, z_channels, 1, 1), device)?; + let b = Tensor::zeros(embed_dim, DType::F32, device)?; + let conv = Conv2d::new(w, Some(b), Conv2dConfig::default()); + Ok(Self { conv }) + } + + /// Forward: `(B*F, z_channels, H, W)` → `(B*F, embed_dim, H, W)`. + pub fn forward(&self, x: &Tensor) -> Result { + self.conv.forward(x) + } +} + +/// `Conv2d(embed_dim → z_channels, kernel=1)` — `post_quant_conv` in Python. +pub struct PostQuantConv { + conv: Conv2d, +} + +impl PostQuantConv { + /// Load from weights. + pub fn new(embed_dim: usize, z_channels: usize, vb: VarBuilder<'_>) -> Result { + let conv = candle_nn::conv2d( + embed_dim, + z_channels, + 1, + Conv2dConfig::default(), + vb.pp("post_quant_conv"), + )?; + Ok(Self { conv }) + } + + /// Random initialisation. + pub fn dummy(embed_dim: usize, z_channels: usize, device: &Device) -> Result { + let w = Tensor::randn(0f32, 1.0, (z_channels, embed_dim, 1, 1), device)?; + let b = Tensor::zeros(z_channels, DType::F32, device)?; + let conv = Conv2d::new(w, Some(b), Conv2dConfig::default()); + Ok(Self { conv }) + } + + /// Forward: `(B*F, embed_dim, H, W)` → `(B*F, z_channels, H, W)`. + pub fn forward(&self, x: &Tensor) -> Result { + self.conv.forward(x) + } +} + +// ── Encoder2D stub ──────────────────────────────────────────────────────────── + +/// **STUB** — returns a random tensor of the correct shape. +/// +/// The full `Encoder2D` from `vae_2d_resnet.py` is a multi-resolution ResNet +/// with three down-sampling stages (stride-2 `Conv2d` + residual blocks). +/// Porting all ~35 M parameters requires the Phase-5 SafeTensors checkpoint +/// to be available so the weight names can be mapped. Until then, this +/// stub ensures the pipeline compiles and end-to-end shape tests pass. +/// +/// Replace this function with the real ResNet implementation in Phase 5. +pub fn encode_occupancy( + x: &Tensor, + cfg: &OccWorldConfig, + device: &Device, +) -> std::result::Result { + // Derive batch*frames from the input shape + let dims = x.dims(); + // Acceptable input shapes: (B, F, H, W, D) or (B*F, H, W, D) + let bf = match dims.len() { + 5 => dims[0] * dims[1], + 4 => dims[0], + _ => { + return Err(OccWorldError::ShapeMismatch(format!( + "encode_occupancy: expected 4-D or 5-D input, got {}-D", + dims.len() + ))) + } + }; + + // STUB: return random z of correct shape (B*F, z_channels, token_h, token_w) + let z = Tensor::randn( + 0f32, + 1.0, + (bf, cfg.z_channels, cfg.token_h, cfg.token_w), + device, + ) + .map_err(OccWorldError::Candle)?; + + Ok(z) +} + +/// **STUB** — returns random class logits of the correct shape. +/// +/// The full `Decoder2D` mirrors the encoder: three up-sampling stages +/// followed by a `Conv2d` head that produces `num_classes` logits per voxel. +/// Implementation is deferred to Phase 5 (checkpoint loading). +/// +/// Replace with the real decoder when Phase-5 weights are available. +pub fn decode_to_logits( + z: &Tensor, + cfg: &OccWorldConfig, + device: &Device, +) -> std::result::Result { + let (bf, _c, _h, _w) = z.dims4().map_err(OccWorldError::Candle)?; + + // STUB: return random logits (B*F, num_classes, H, W, D) + let logits = Tensor::randn( + 0f32, + 1.0, + (bf, cfg.num_classes, cfg.grid_h, cfg.grid_w, cfg.grid_d), + device, + ) + .map_err(OccWorldError::Candle)?; + + Ok(logits) +} + +// ── VQVAE component bundle ──────────────────────────────────────────────────── + +/// All VQVAE components bundled together for use in `OccWorldCandle`. +pub struct VQVAEComponents { + /// Class label → float embedding (`nn.Embedding(18, 64)` in Python). + pub class_embed: ClassEmbedding, + /// `Conv2d(z_channels → embed_dim, k=1)` before quantisation. + pub quant_conv: QuantConv, + /// VQ codebook for nearest-neighbour quantisation. + pub codebook: VQCodebook, + /// `Conv2d(embed_dim → z_channels, k=1)` after quantisation. + pub post_quant_conv: PostQuantConv, +} + +impl VQVAEComponents { + /// Build all components from a single [`VarBuilder`]. + pub fn new(cfg: &OccWorldConfig, vb: VarBuilder<'_>) -> Result { + let class_embed = ClassEmbedding::new(cfg.num_classes, cfg.base_channels, vb.clone())?; + let quant_conv = QuantConv::new(cfg.z_channels, cfg.embed_dim, vb.clone())?; + let codebook = VQCodebook::new(cfg.codebook_size, cfg.embed_dim, vb.clone())?; + let post_quant_conv = PostQuantConv::new(cfg.embed_dim, cfg.z_channels, vb)?; + Ok(Self { + class_embed, + quant_conv, + codebook, + post_quant_conv, + }) + } + + /// Build all components with random weights (for testing / benchmarking). + pub fn dummy(cfg: &OccWorldConfig, device: &Device) -> Result { + let class_embed = ClassEmbedding::dummy(cfg.num_classes, cfg.base_channels, device)?; + let quant_conv = QuantConv::dummy(cfg.z_channels, cfg.embed_dim, device)?; + let codebook = VQCodebook::dummy(cfg.codebook_size, cfg.embed_dim, device)?; + let post_quant_conv = PostQuantConv::dummy(cfg.embed_dim, cfg.z_channels, device)?; + Ok(Self { + class_embed, + quant_conv, + codebook, + post_quant_conv, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vq_codebook_roundtrip() -> candle_core::Result<()> { + let device = Device::Cpu; + let codebook = VQCodebook::dummy(512, 512, &device)?; + + // Random input of shape (4, 512) — simulate a batch of 4 latent vectors + let z = Tensor::randn(0f32, 1.0, (4, 512), &device)?; + + let (z_q, indices) = codebook.encode(&z)?; + // z_q must have same shape as z + assert_eq!(z_q.dims(), z.dims()); + // indices must have shape (4,) — one per row + assert_eq!(indices.dims(), &[4]); + + // Decode must recover the same codebook entries + let z_decoded = codebook.decode(&indices)?; + assert_eq!(z_decoded.dims(), &[4, 512]); + + Ok(()) + } + + #[test] + fn test_fold_unfold_roundtrip() -> candle_core::Result<()> { + let device = Device::Cpu; + let x = Tensor::randn(0f32, 1.0, (2, 64, 10, 10, 8), &device)?; + let folded = fold_3d_to_2d(&x)?; + assert_eq!(folded.dims(), &[2, 64, 10, 80]); + let unfolded = unfold_2d_to_3d(&folded, 10, 8)?; + assert_eq!(unfolded.dims(), &[2, 64, 10, 10, 8]); + Ok(()) + } +}