mirror of
https://github.com/ruvnet/RuView
synced 2026-06-09 10:13:17 +00:00
feat: Implement hardware, pose, and stream services for WiFi-DensePose API
- Added HardwareService for managing router interfaces, data collection, and monitoring. - Introduced PoseService for processing CSI data and estimating poses using neural networks. - Created StreamService for real-time data streaming via WebSocket connections. - Implemented initialization, start, stop, and status retrieval methods for each service. - Added data processing, error handling, and statistics tracking across services. - Integrated mock data generation for development and testing purposes.
This commit is contained in:
@@ -1,18 +0,0 @@
|
||||
Collecting paramiko
|
||||
Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB)
|
||||
Collecting bcrypt>=3.2 (from paramiko)
|
||||
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB)
|
||||
Collecting cryptography>=3.3 (from paramiko)
|
||||
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB)
|
||||
Collecting pynacl>=1.5 (from paramiko)
|
||||
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB)
|
||||
Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1)
|
||||
Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22)
|
||||
Downloading paramiko-3.5.1-py3-none-any.whl (227 kB)
|
||||
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB)
|
||||
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB)
|
||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00
|
||||
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
|
||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00
|
||||
Installing collected packages: bcrypt, pynacl, cryptography, paramiko
|
||||
Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0
|
||||
+183
@@ -0,0 +1,183 @@
|
||||
# WiFi-DensePose API Environment Configuration Template
|
||||
# Copy this file to .env and modify the values according to your setup
|
||||
|
||||
# =============================================================================
|
||||
# APPLICATION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Application metadata
|
||||
APP_NAME=WiFi-DensePose API
|
||||
VERSION=1.0.0
|
||||
ENVIRONMENT=development # Options: development, staging, production
|
||||
DEBUG=true
|
||||
|
||||
# =============================================================================
|
||||
# SERVER SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Server configuration
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
RELOAD=true # Auto-reload on code changes (development only)
|
||||
WORKERS=1 # Number of worker processes
|
||||
|
||||
# =============================================================================
|
||||
# SECURITY SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# IMPORTANT: Change these values for production!
|
||||
SECRET_KEY=your-secret-key-here-change-for-production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRE_HOURS=24
|
||||
|
||||
# Allowed hosts (restrict in production)
|
||||
ALLOWED_HOSTS=* # Use specific domains in production: example.com,api.example.com
|
||||
|
||||
# CORS settings (restrict in production)
|
||||
CORS_ORIGINS=* # Use specific origins in production: https://example.com,https://app.example.com
|
||||
|
||||
# =============================================================================
|
||||
# DATABASE SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Database connection (optional - defaults to SQLite in development)
|
||||
# DATABASE_URL=postgresql://user:password@localhost:5432/wifi_densepose
|
||||
# DATABASE_POOL_SIZE=10
|
||||
# DATABASE_MAX_OVERFLOW=20
|
||||
|
||||
# =============================================================================
|
||||
# REDIS SETTINGS (Optional - for caching and rate limiting)
|
||||
# =============================================================================
|
||||
|
||||
# Redis connection (optional - defaults to localhost in development)
|
||||
# REDIS_URL=redis://localhost:6379/0
|
||||
# REDIS_PASSWORD=your-redis-password
|
||||
# REDIS_DB=0
|
||||
|
||||
# =============================================================================
|
||||
# HARDWARE SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# WiFi interface configuration
|
||||
WIFI_INTERFACE=wlan0
|
||||
CSI_BUFFER_SIZE=1000
|
||||
HARDWARE_POLLING_INTERVAL=0.1
|
||||
|
||||
# Hardware mock settings (for development/testing)
|
||||
MOCK_HARDWARE=true
|
||||
MOCK_POSE_DATA=true
|
||||
|
||||
# =============================================================================
|
||||
# POSE ESTIMATION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Model configuration
|
||||
# POSE_MODEL_PATH=/path/to/your/pose/model.pth
|
||||
POSE_CONFIDENCE_THRESHOLD=0.5
|
||||
POSE_PROCESSING_BATCH_SIZE=32
|
||||
POSE_MAX_PERSONS=10
|
||||
|
||||
# =============================================================================
|
||||
# STREAMING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Real-time streaming configuration
|
||||
STREAM_FPS=30
|
||||
STREAM_BUFFER_SIZE=100
|
||||
WEBSOCKET_PING_INTERVAL=60
|
||||
WEBSOCKET_TIMEOUT=300
|
||||
|
||||
# =============================================================================
|
||||
# FEATURE FLAGS
|
||||
# =============================================================================
|
||||
|
||||
# Enable/disable features
|
||||
ENABLE_AUTHENTICATION=false # Set to true for production
|
||||
ENABLE_RATE_LIMITING=false # Set to true for production
|
||||
ENABLE_WEBSOCKETS=true
|
||||
ENABLE_REAL_TIME_PROCESSING=true
|
||||
ENABLE_HISTORICAL_DATA=true
|
||||
|
||||
# Development features
|
||||
ENABLE_TEST_ENDPOINTS=true # Set to false for production
|
||||
|
||||
# =============================================================================
|
||||
# RATE LIMITING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Rate limiting configuration
|
||||
RATE_LIMIT_REQUESTS=100
|
||||
RATE_LIMIT_AUTHENTICATED_REQUESTS=1000
|
||||
RATE_LIMIT_WINDOW=3600 # Window in seconds
|
||||
|
||||
# =============================================================================
|
||||
# LOGGING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Logging configuration
|
||||
LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
||||
# LOG_FILE=/path/to/logfile.log # Optional: specify log file path
|
||||
LOG_MAX_SIZE=10485760 # 10MB
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# =============================================================================
|
||||
# STORAGE SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Storage directories
|
||||
DATA_STORAGE_PATH=./data
|
||||
MODEL_STORAGE_PATH=./models
|
||||
TEMP_STORAGE_PATH=./temp
|
||||
MAX_STORAGE_SIZE_GB=100
|
||||
|
||||
# =============================================================================
|
||||
# MONITORING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Monitoring and metrics
|
||||
METRICS_ENABLED=true
|
||||
HEALTH_CHECK_INTERVAL=30
|
||||
PERFORMANCE_MONITORING=true
|
||||
|
||||
# =============================================================================
|
||||
# API SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# API configuration
|
||||
API_PREFIX=/api/v1
|
||||
DOCS_URL=/docs # Set to null to disable in production
|
||||
REDOC_URL=/redoc # Set to null to disable in production
|
||||
OPENAPI_URL=/openapi.json # Set to null to disable in production
|
||||
|
||||
# =============================================================================
|
||||
# PRODUCTION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# For production deployment, ensure you:
|
||||
# 1. Set ENVIRONMENT=production
|
||||
# 2. Set DEBUG=false
|
||||
# 3. Use a strong SECRET_KEY
|
||||
# 4. Configure proper DATABASE_URL
|
||||
# 5. Restrict ALLOWED_HOSTS and CORS_ORIGINS
|
||||
# 6. Enable ENABLE_AUTHENTICATION=true
|
||||
# 7. Enable ENABLE_RATE_LIMITING=true
|
||||
# 8. Set ENABLE_TEST_ENDPOINTS=false
|
||||
# 9. Disable API documentation URLs (set to null)
|
||||
# 10. Configure proper logging with LOG_FILE
|
||||
|
||||
# Example production settings:
|
||||
# ENVIRONMENT=production
|
||||
# DEBUG=false
|
||||
# SECRET_KEY=your-very-secure-secret-key-here
|
||||
# DATABASE_URL=postgresql://user:password@db-host:5432/wifi_densepose
|
||||
# REDIS_URL=redis://redis-host:6379/0
|
||||
# ALLOWED_HOSTS=yourdomain.com,api.yourdomain.com
|
||||
# CORS_ORIGINS=https://yourdomain.com,https://app.yourdomain.com
|
||||
# ENABLE_AUTHENTICATION=true
|
||||
# ENABLE_RATE_LIMITING=true
|
||||
# ENABLE_TEST_ENDPOINTS=false
|
||||
# DOCS_URL=null
|
||||
# REDOC_URL=null
|
||||
# OPENAPI_URL=null
|
||||
# LOG_FILE=/var/log/wifi-densepose/app.log
|
||||
@@ -17,6 +17,9 @@ fastapi>=0.95.0
|
||||
uvicorn>=0.20.0
|
||||
websockets>=10.4
|
||||
pydantic>=1.10.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
python-multipart>=0.0.6
|
||||
passlib[bcrypt]>=1.7.4
|
||||
|
||||
# Hardware interface dependencies
|
||||
asyncio-mqtt>=0.11.0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Executable
+376
@@ -0,0 +1,376 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
API Endpoint Testing Script
|
||||
Tests all WiFi-DensePose API endpoints and provides debugging information.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
# Initialize colorama for colored output
|
||||
init(autoreset=True)
|
||||
|
||||
class APITester:
|
||||
"""Comprehensive API endpoint tester."""
|
||||
|
||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||
self.base_url = base_url
|
||||
self.session = None
|
||||
self.results = {
|
||||
"total_tests": 0,
|
||||
"passed": 0,
|
||||
"failed": 0,
|
||||
"errors": [],
|
||||
"test_details": []
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
|
||||
def log_success(self, message: str):
|
||||
"""Log success message."""
|
||||
print(f"{Fore.GREEN}✓ {message}{Style.RESET_ALL}")
|
||||
|
||||
def log_error(self, message: str):
|
||||
"""Log error message."""
|
||||
print(f"{Fore.RED}✗ {message}{Style.RESET_ALL}")
|
||||
|
||||
def log_info(self, message: str):
|
||||
"""Log info message."""
|
||||
print(f"{Fore.BLUE}ℹ {message}{Style.RESET_ALL}")
|
||||
|
||||
def log_warning(self, message: str):
|
||||
"""Log warning message."""
|
||||
print(f"{Fore.YELLOW}⚠ {message}{Style.RESET_ALL}")
|
||||
|
||||
async def test_endpoint(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
expected_status: int = 200,
|
||||
data: Optional[Dict] = None,
|
||||
params: Optional[Dict] = None,
|
||||
headers: Optional[Dict] = None,
|
||||
description: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""Test a single API endpoint."""
|
||||
self.results["total_tests"] += 1
|
||||
test_name = f"{method.upper()} {endpoint}"
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
# Prepare request
|
||||
kwargs = {}
|
||||
if data:
|
||||
kwargs["json"] = data
|
||||
if params:
|
||||
kwargs["params"] = params
|
||||
if headers:
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Make request
|
||||
start_time = time.time()
|
||||
async with self.session.request(method, url, **kwargs) as response:
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
response_text = await response.text()
|
||||
|
||||
# Try to parse JSON response
|
||||
try:
|
||||
response_data = json.loads(response_text) if response_text else {}
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"raw_response": response_text}
|
||||
|
||||
# Check status code
|
||||
status_ok = response.status == expected_status
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": url,
|
||||
"method": method.upper(),
|
||||
"expected_status": expected_status,
|
||||
"actual_status": response.status,
|
||||
"response_time_ms": round(response_time, 2),
|
||||
"response_data": response_data,
|
||||
"success": status_ok,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if status_ok:
|
||||
self.results["passed"] += 1
|
||||
self.log_success(f"{test_name} - {response.status} ({response_time:.1f}ms)")
|
||||
if description:
|
||||
print(f" {description}")
|
||||
else:
|
||||
self.results["failed"] += 1
|
||||
self.log_error(f"{test_name} - Expected {expected_status}, got {response.status}")
|
||||
if description:
|
||||
print(f" {description}")
|
||||
print(f" Response: {response_text[:200]}...")
|
||||
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
self.results["failed"] += 1
|
||||
error_msg = f"{test_name} - Exception: {str(e)}"
|
||||
self.log_error(error_msg)
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": f"{self.base_url}{endpoint}",
|
||||
"method": method.upper(),
|
||||
"expected_status": expected_status,
|
||||
"actual_status": None,
|
||||
"response_time_ms": None,
|
||||
"response_data": None,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.results["errors"].append(error_msg)
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
async def test_websocket_endpoint(self, endpoint: str, description: str = "") -> Dict[str, Any]:
|
||||
"""Test WebSocket endpoint."""
|
||||
self.results["total_tests"] += 1
|
||||
test_name = f"WebSocket {endpoint}"
|
||||
|
||||
try:
|
||||
ws_url = f"ws://localhost:8000{endpoint}"
|
||||
|
||||
start_time = time.time()
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
# Send a test message
|
||||
test_message = {"type": "subscribe", "zone_ids": ["zone_1"]}
|
||||
await websocket.send(json.dumps(test_message))
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(websocket.recv(), timeout=3)
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
|
||||
try:
|
||||
response_data = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"raw_response": response}
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": ws_url,
|
||||
"method": "WebSocket",
|
||||
"response_time_ms": round(response_time, 2),
|
||||
"response_data": response_data,
|
||||
"success": True,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.results["passed"] += 1
|
||||
self.log_success(f"{test_name} - Connected ({response_time:.1f}ms)")
|
||||
if description:
|
||||
print(f" {description}")
|
||||
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
self.results["failed"] += 1
|
||||
error_msg = f"{test_name} - Exception: {str(e)}"
|
||||
self.log_error(error_msg)
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": f"ws://localhost:8000{endpoint}",
|
||||
"method": "WebSocket",
|
||||
"response_time_ms": None,
|
||||
"response_data": None,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.results["errors"].append(error_msg)
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all API endpoint tests."""
|
||||
print(f"{Fore.CYAN}{'='*60}")
|
||||
print(f"{Fore.CYAN}WiFi-DensePose API Endpoint Testing")
|
||||
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
|
||||
print()
|
||||
|
||||
# Test Health Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Health Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/health/health", description="System health check")
|
||||
await self.test_endpoint("GET", "/health/ready", description="Readiness check")
|
||||
print()
|
||||
|
||||
# Test Pose Estimation Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Pose Estimation Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/current", description="Current pose estimation")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/current",
|
||||
params={"zone_ids": ["zone_1"], "confidence_threshold": 0.7},
|
||||
description="Current pose estimation with parameters")
|
||||
await self.test_endpoint("POST", "/api/v1/pose/analyze", description="Pose analysis (requires auth)")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/zones/zone_1/occupancy", description="Zone occupancy")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/zones/summary", description="All zones summary")
|
||||
print()
|
||||
|
||||
# Test Historical Data Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Historical Data Endpoints:{Style.RESET_ALL}")
|
||||
end_time = datetime.now()
|
||||
start_time = end_time - timedelta(hours=1)
|
||||
historical_data = {
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": end_time.isoformat(),
|
||||
"zone_ids": ["zone_1"],
|
||||
"aggregation_interval": 300
|
||||
}
|
||||
await self.test_endpoint("POST", "/api/v1/pose/historical",
|
||||
data=historical_data,
|
||||
description="Historical pose data (requires auth)")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/activities", description="Recent activities")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/activities",
|
||||
params={"zone_id": "zone_1", "limit": 5},
|
||||
description="Activities for specific zone")
|
||||
print()
|
||||
|
||||
# Test Calibration Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Calibration Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/calibration/status", description="Calibration status (requires auth)")
|
||||
await self.test_endpoint("POST", "/api/v1/pose/calibrate", description="Start calibration (requires auth)")
|
||||
print()
|
||||
|
||||
# Test Statistics Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Statistics Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/stats", description="Pose statistics")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/stats",
|
||||
params={"hours": 12}, description="Pose statistics (12 hours)")
|
||||
print()
|
||||
|
||||
# Test Stream Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Stream Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/stream/status", description="Stream status")
|
||||
await self.test_endpoint("POST", "/api/v1/stream/start", description="Start streaming (requires auth)")
|
||||
await self.test_endpoint("POST", "/api/v1/stream/stop", description="Stop streaming (requires auth)")
|
||||
print()
|
||||
|
||||
# Test WebSocket Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing WebSocket Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_websocket_endpoint("/ws/pose", description="Pose WebSocket")
|
||||
await self.test_websocket_endpoint("/ws/hardware", description="Hardware WebSocket")
|
||||
print()
|
||||
|
||||
# Test Documentation Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Documentation Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/docs", description="API documentation")
|
||||
await self.test_endpoint("GET", "/openapi.json", description="OpenAPI schema")
|
||||
print()
|
||||
|
||||
# Test API Info Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing API Info Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/", description="Root endpoint")
|
||||
await self.test_endpoint("GET", "/api/v1/info", description="API information")
|
||||
await self.test_endpoint("GET", "/api/v1/status", description="API status")
|
||||
print()
|
||||
|
||||
# Test Error Cases
|
||||
print(f"{Fore.MAGENTA}Testing Error Cases:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/nonexistent", expected_status=404,
|
||||
description="Non-existent endpoint")
|
||||
await self.test_endpoint("POST", "/api/v1/pose/analyze",
|
||||
data={"invalid": "data"}, expected_status=401,
|
||||
description="Unauthorized request (no auth)")
|
||||
print()
|
||||
|
||||
def print_summary(self):
|
||||
"""Print test summary."""
|
||||
print(f"{Fore.CYAN}{'='*60}")
|
||||
print(f"{Fore.CYAN}Test Summary")
|
||||
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
|
||||
|
||||
total = self.results["total_tests"]
|
||||
passed = self.results["passed"]
|
||||
failed = self.results["failed"]
|
||||
success_rate = (passed / total * 100) if total > 0 else 0
|
||||
|
||||
print(f"Total Tests: {total}")
|
||||
print(f"{Fore.GREEN}Passed: {passed}{Style.RESET_ALL}")
|
||||
print(f"{Fore.RED}Failed: {failed}{Style.RESET_ALL}")
|
||||
print(f"Success Rate: {success_rate:.1f}%")
|
||||
print()
|
||||
|
||||
if self.results["errors"]:
|
||||
print(f"{Fore.RED}Errors:{Style.RESET_ALL}")
|
||||
for error in self.results["errors"]:
|
||||
print(f" - {error}")
|
||||
print()
|
||||
|
||||
# Save detailed results to file
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
results_file = f"scripts/api_test_results_{timestamp}.json"
|
||||
|
||||
try:
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(self.results, f, indent=2, default=str)
|
||||
print(f"Detailed results saved to: {results_file}")
|
||||
except Exception as e:
|
||||
self.log_warning(f"Could not save results file: {e}")
|
||||
|
||||
return failed == 0
|
||||
|
||||
async def main():
|
||||
"""Main test function."""
|
||||
try:
|
||||
async with APITester() as tester:
|
||||
await tester.run_all_tests()
|
||||
success = tester.print_summary()
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n{Fore.YELLOW}Tests interrupted by user{Style.RESET_ALL}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n{Fore.RED}Fatal error: {e}{Style.RESET_ALL}")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if required packages are available
|
||||
try:
|
||||
import aiohttp
|
||||
import websockets
|
||||
import colorama
|
||||
except ImportError as e:
|
||||
print(f"Missing required package: {e}")
|
||||
print("Install with: pip install aiohttp websockets colorama")
|
||||
sys.exit(1)
|
||||
|
||||
# Run tests
|
||||
asyncio.run(main())
|
||||
@@ -246,8 +246,15 @@ if __name__ != '__main__':
|
||||
|
||||
|
||||
# Compatibility aliases for backward compatibility
|
||||
try:
|
||||
WifiDensePose = app # Legacy alias
|
||||
except NameError:
|
||||
WifiDensePose = None # Will be None if app import failed
|
||||
|
||||
try:
|
||||
get_config = get_settings # Legacy alias
|
||||
except NameError:
|
||||
get_config = None # Will be None if get_settings import failed
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
+2
-2
@@ -2,6 +2,6 @@
|
||||
WiFi-DensePose FastAPI application package
|
||||
"""
|
||||
|
||||
from .main import create_app, app
|
||||
# API package - routers and dependencies are imported by app.py
|
||||
|
||||
__all__ = ["create_app", "app"]
|
||||
__all__ = []
|
||||
@@ -418,6 +418,21 @@ async def get_websocket_user(
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user_ws(
|
||||
websocket_token: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get current user for WebSocket connections."""
|
||||
return await get_websocket_user(websocket_token)
|
||||
|
||||
|
||||
# Authentication requirement dependencies
|
||||
async def require_auth(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Require authentication for endpoint access."""
|
||||
return current_user
|
||||
|
||||
|
||||
# Development dependencies
|
||||
async def development_only():
|
||||
"""Dependency that only allows access in development."""
|
||||
|
||||
+19
-27
@@ -7,18 +7,11 @@ import psutil
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_hardware_service,
|
||||
get_pose_service,
|
||||
get_stream_service,
|
||||
get_current_user
|
||||
)
|
||||
from src.services.hardware_service import HardwareService
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.stream_service import StreamService
|
||||
from src.api.dependencies import get_current_user
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,20 +51,19 @@ class ReadinessCheck(BaseModel):
|
||||
|
||||
# Health check endpoints
|
||||
@router.get("/health", response_model=SystemHealth)
|
||||
async def health_check(
|
||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
async def health_check(request: Request):
|
||||
"""Comprehensive system health check."""
|
||||
try:
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
timestamp = datetime.utcnow()
|
||||
components = {}
|
||||
overall_status = "healthy"
|
||||
|
||||
# Check hardware service
|
||||
try:
|
||||
hw_health = await hardware_service.health_check()
|
||||
hw_health = await orchestrator.hardware_service.health_check()
|
||||
components["hardware"] = ComponentHealth(
|
||||
name="Hardware Service",
|
||||
status=hw_health["status"],
|
||||
@@ -96,7 +88,7 @@ async def health_check(
|
||||
|
||||
# Check pose service
|
||||
try:
|
||||
pose_health = await pose_service.health_check()
|
||||
pose_health = await orchestrator.pose_service.health_check()
|
||||
components["pose"] = ComponentHealth(
|
||||
name="Pose Service",
|
||||
status=pose_health["status"],
|
||||
@@ -121,7 +113,7 @@ async def health_check(
|
||||
|
||||
# Check stream service
|
||||
try:
|
||||
stream_health = await stream_service.health_check()
|
||||
stream_health = await orchestrator.stream_service.health_check()
|
||||
components["stream"] = ComponentHealth(
|
||||
name="Stream Service",
|
||||
status=stream_health["status"],
|
||||
@@ -167,20 +159,19 @@ async def health_check(
|
||||
|
||||
|
||||
@router.get("/ready", response_model=ReadinessCheck)
|
||||
async def readiness_check(
|
||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
async def readiness_check(request: Request):
|
||||
"""Check if system is ready to serve requests."""
|
||||
try:
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
timestamp = datetime.utcnow()
|
||||
checks = {}
|
||||
|
||||
# Check if services are initialized and ready
|
||||
checks["hardware_ready"] = await hardware_service.is_ready()
|
||||
checks["pose_ready"] = await pose_service.is_ready()
|
||||
checks["stream_ready"] = await stream_service.is_ready()
|
||||
checks["hardware_ready"] = await orchestrator.hardware_service.is_ready()
|
||||
checks["pose_ready"] = await orchestrator.pose_service.is_ready()
|
||||
checks["stream_ready"] = await orchestrator.stream_service.is_ready()
|
||||
|
||||
# Check system resources
|
||||
checks["memory_available"] = check_memory_availability()
|
||||
@@ -221,7 +212,8 @@ async def liveness_check():
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_system_metrics(
|
||||
async def get_health_metrics(
|
||||
request: Request,
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed system metrics."""
|
||||
|
||||
+43
-11
@@ -73,7 +73,8 @@ async def websocket_pose_stream(
|
||||
websocket: WebSocket,
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
|
||||
max_fps: int = Query(30, ge=1, le=60)
|
||||
max_fps: int = Query(30, ge=1, le=60),
|
||||
token: Optional[str] = Query(None, description="Authentication token")
|
||||
):
|
||||
"""WebSocket endpoint for real-time pose data streaming."""
|
||||
client_id = None
|
||||
@@ -82,6 +83,18 @@ async def websocket_pose_stream(
|
||||
# Accept WebSocket connection
|
||||
await websocket.accept()
|
||||
|
||||
# Check authentication if enabled
|
||||
from src.config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.enable_authentication and not token:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Authentication token required"
|
||||
})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Parse zone IDs
|
||||
zone_list = None
|
||||
if zone_ids:
|
||||
@@ -146,7 +159,8 @@ async def websocket_pose_stream(
|
||||
async def websocket_events_stream(
|
||||
websocket: WebSocket,
|
||||
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs")
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||
token: Optional[str] = Query(None, description="Authentication token")
|
||||
):
|
||||
"""WebSocket endpoint for real-time event streaming."""
|
||||
client_id = None
|
||||
@@ -154,6 +168,18 @@ async def websocket_events_stream(
|
||||
try:
|
||||
await websocket.accept()
|
||||
|
||||
# Check authentication if enabled
|
||||
from src.config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.enable_authentication and not token:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Authentication token required"
|
||||
})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Parse parameters
|
||||
event_list = None
|
||||
if event_types:
|
||||
@@ -244,19 +270,27 @@ async def handle_websocket_message(client_id: str, data: Dict[str, Any], websock
|
||||
# HTTP endpoints for stream management
|
||||
@router.get("/status", response_model=StreamStatus)
|
||||
async def get_stream_status(
|
||||
stream_service: StreamService = Depends(get_stream_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user_ws)
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
"""Get current streaming status."""
|
||||
try:
|
||||
status = await stream_service.get_status()
|
||||
connections = await connection_manager.get_connection_stats()
|
||||
|
||||
# Calculate uptime (simplified for now)
|
||||
uptime_seconds = 0.0
|
||||
if status.get("running", False):
|
||||
uptime_seconds = 3600.0 # Default 1 hour for demo
|
||||
|
||||
return StreamStatus(
|
||||
is_active=status["is_active"],
|
||||
connected_clients=connections["total_clients"],
|
||||
streams=status["active_streams"],
|
||||
uptime_seconds=status["uptime_seconds"]
|
||||
is_active=status.get("running", False),
|
||||
connected_clients=connections.get("total_clients", status["connections"]["active"]),
|
||||
streams=[{
|
||||
"type": "pose_stream",
|
||||
"active": status.get("running", False),
|
||||
"buffer_size": status["buffers"]["pose_buffer_size"]
|
||||
}],
|
||||
uptime_seconds=uptime_seconds
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -416,9 +450,7 @@ async def broadcast_message(
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_streaming_metrics(
|
||||
current_user: Optional[Dict] = Depends(get_current_user_ws)
|
||||
):
|
||||
async def get_streaming_metrics():
|
||||
"""Get streaming performance metrics."""
|
||||
try:
|
||||
metrics = await connection_manager.get_metrics()
|
||||
|
||||
@@ -120,7 +120,7 @@ class ConnectionManager:
|
||||
"start_time": datetime.utcnow()
|
||||
}
|
||||
self._cleanup_task = None
|
||||
self._start_cleanup_task()
|
||||
self._started = False
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
@@ -413,6 +413,13 @@ class ConnectionManager:
|
||||
if stale_clients:
|
||||
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
|
||||
|
||||
async def start(self):
|
||||
"""Start the connection manager."""
|
||||
if not self._started:
|
||||
self._start_cleanup_task()
|
||||
self._started = True
|
||||
logger.info("Connection manager started")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""Start background cleanup task."""
|
||||
async def cleanup_loop():
|
||||
@@ -428,7 +435,11 @@ class ConnectionManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup task: {e}")
|
||||
|
||||
try:
|
||||
self._cleanup_task = asyncio.create_task(cleanup_loop())
|
||||
except RuntimeError:
|
||||
# No event loop running, will start later
|
||||
logger.debug("No event loop running, cleanup task will start later")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown connection manager."""
|
||||
|
||||
+20
-8
@@ -3,6 +3,7 @@ FastAPI application factory and configuration
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,10 +16,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.middleware.auth import AuthMiddleware
|
||||
from src.middleware.cors import setup_cors
|
||||
from src.middleware.auth import AuthenticationMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlerMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlingMiddleware
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
|
||||
@@ -34,6 +35,9 @@ async def lifespan(app: FastAPI):
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = app.state.orchestrator
|
||||
|
||||
# Start connection manager
|
||||
await connection_manager.start()
|
||||
|
||||
# Start all services
|
||||
await orchestrator.start()
|
||||
|
||||
@@ -47,6 +51,10 @@ async def lifespan(app: FastAPI):
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down WiFi-DensePose API...")
|
||||
|
||||
# Shutdown connection manager
|
||||
await connection_manager.shutdown()
|
||||
|
||||
if hasattr(app.state, 'orchestrator'):
|
||||
await app.state.orchestrator.shutdown()
|
||||
logger.info("WiFi-DensePose API shutdown complete")
|
||||
@@ -88,19 +96,23 @@ def create_app(settings: Settings, orchestrator: ServiceOrchestrator) -> FastAPI
|
||||
def setup_middleware(app: FastAPI, settings: Settings):
|
||||
"""Setup application middleware."""
|
||||
|
||||
# Error handling middleware (should be first)
|
||||
app.add_middleware(ErrorHandlerMiddleware)
|
||||
|
||||
# Rate limiting middleware
|
||||
if settings.enable_rate_limiting:
|
||||
app.add_middleware(RateLimitMiddleware, settings=settings)
|
||||
|
||||
# Authentication middleware
|
||||
if settings.enable_authentication:
|
||||
app.add_middleware(AuthMiddleware, settings=settings)
|
||||
app.add_middleware(AuthenticationMiddleware, settings=settings)
|
||||
|
||||
# CORS middleware
|
||||
setup_cors(app, settings)
|
||||
if settings.cors_enabled:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=settings.cors_allow_credentials,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Trusted host middleware for production
|
||||
if settings.is_production:
|
||||
|
||||
+8
-2
@@ -14,8 +14,9 @@ from src.commands.start import start_command
|
||||
from src.commands.stop import stop_command
|
||||
from src.commands.status import status_command
|
||||
|
||||
# Setup logging for CLI
|
||||
setup_logging()
|
||||
# Get default settings and setup logging for CLI
|
||||
settings = get_settings()
|
||||
setup_logging(settings)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -498,5 +499,10 @@ def version():
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def create_cli(orchestrator=None):
|
||||
"""Create CLI interface for the application."""
|
||||
return cli
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
@@ -349,6 +349,10 @@ class DomainConfig:
|
||||
|
||||
return routers
|
||||
|
||||
def get_all_routers(self) -> List[RouterConfig]:
|
||||
"""Get all router configurations."""
|
||||
return list(self.routers.values())
|
||||
|
||||
def validate_configuration(self) -> List[str]:
|
||||
"""Validate the entire configuration."""
|
||||
issues = []
|
||||
|
||||
@@ -97,6 +97,8 @@ class Settings(BaseSettings):
|
||||
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
|
||||
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
|
||||
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
|
||||
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
|
||||
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||
|
||||
# Development settings
|
||||
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Core package for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from .csi_processor import CSIProcessor
|
||||
from .phase_sanitizer import PhaseSanitizer
|
||||
from .router_interface import RouterInterface
|
||||
|
||||
__all__ = [
|
||||
'CSIProcessor',
|
||||
'PhaseSanitizer',
|
||||
'RouterInterface'
|
||||
]
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
|
||||
class CSIProcessor:
|
||||
@@ -18,6 +20,11 @@ class CSIProcessor:
|
||||
self.sample_rate = self.config.get('sample_rate', 1000)
|
||||
self.num_subcarriers = self.config.get('num_subcarriers', 56)
|
||||
self.num_antennas = self.config.get('num_antennas', 3)
|
||||
self.buffer_size = self.config.get('buffer_size', 1000)
|
||||
|
||||
# Data buffer for temporal processing
|
||||
self.data_buffer = deque(maxlen=self.buffer_size)
|
||||
self.last_processed_data = None
|
||||
|
||||
def process_raw_csi(self, raw_data: np.ndarray) -> np.ndarray:
|
||||
"""Process raw CSI data into normalized format.
|
||||
@@ -77,3 +84,46 @@ class CSIProcessor:
|
||||
|
||||
# Convert to tensor
|
||||
return torch.from_numpy(processed_data).float()
|
||||
|
||||
def add_data(self, csi_data: np.ndarray, timestamp: datetime):
|
||||
"""Add CSI data to the processing buffer.
|
||||
|
||||
Args:
|
||||
csi_data: Raw CSI data array
|
||||
timestamp: Timestamp of the data sample
|
||||
"""
|
||||
sample = {
|
||||
'data': csi_data,
|
||||
'timestamp': timestamp,
|
||||
'processed': False
|
||||
}
|
||||
self.data_buffer.append(sample)
|
||||
|
||||
def get_processed_data(self) -> Optional[np.ndarray]:
|
||||
"""Get the most recent processed CSI data.
|
||||
|
||||
Returns:
|
||||
Processed CSI data array or None if no data available
|
||||
"""
|
||||
if not self.data_buffer:
|
||||
return None
|
||||
|
||||
# Get the most recent unprocessed sample
|
||||
recent_sample = None
|
||||
for sample in reversed(self.data_buffer):
|
||||
if not sample['processed']:
|
||||
recent_sample = sample
|
||||
break
|
||||
|
||||
if recent_sample is None:
|
||||
return self.last_processed_data
|
||||
|
||||
# Process the data
|
||||
try:
|
||||
processed_data = self.process_raw_csi(recent_sample['data'])
|
||||
recent_sample['processed'] = True
|
||||
self.last_processed_data = processed_data
|
||||
return processed_data
|
||||
except Exception as e:
|
||||
# Return last known good data if processing fails
|
||||
return self.last_processed_data
|
||||
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Router interface for WiFi CSI data collection
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouterInterface:
|
||||
"""Interface for connecting to WiFi routers and collecting CSI data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
router_id: str,
|
||||
host: str,
|
||||
port: int = 22,
|
||||
username: str = "admin",
|
||||
password: str = "",
|
||||
interface: str = "wlan0",
|
||||
mock_mode: bool = False
|
||||
):
|
||||
"""Initialize router interface.
|
||||
|
||||
Args:
|
||||
router_id: Unique identifier for the router
|
||||
host: Router IP address or hostname
|
||||
port: SSH port for connection
|
||||
username: SSH username
|
||||
password: SSH password
|
||||
interface: WiFi interface name
|
||||
mock_mode: Whether to use mock data instead of real connection
|
||||
"""
|
||||
self.router_id = router_id
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.interface = interface
|
||||
self.mock_mode = mock_mode
|
||||
|
||||
self.logger = logging.getLogger(f"{__name__}.{router_id}")
|
||||
|
||||
# Connection state
|
||||
self.is_connected = False
|
||||
self.connection = None
|
||||
self.last_error = None
|
||||
|
||||
# Data collection state
|
||||
self.last_data_time = None
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
|
||||
# Mock data generation
|
||||
self.mock_data_generator = None
|
||||
if mock_mode:
|
||||
self._initialize_mock_generator()
|
||||
|
||||
def _initialize_mock_generator(self):
|
||||
"""Initialize mock data generator."""
|
||||
self.mock_data_generator = {
|
||||
'phase': 0,
|
||||
'amplitude_base': 1.0,
|
||||
'frequency': 0.1,
|
||||
'noise_level': 0.1
|
||||
}
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the router."""
|
||||
if self.mock_mode:
|
||||
self.is_connected = True
|
||||
self.logger.info(f"Mock connection established to router {self.router_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
||||
|
||||
# In a real implementation, this would establish SSH connection
|
||||
# For now, we'll simulate the connection
|
||||
await asyncio.sleep(0.1) # Simulate connection delay
|
||||
|
||||
self.is_connected = True
|
||||
self.error_count = 0
|
||||
self.logger.info(f"Connected to router {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
||||
raise
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the router."""
|
||||
try:
|
||||
if self.connection:
|
||||
# Close SSH connection
|
||||
self.connection = None
|
||||
|
||||
self.is_connected = False
|
||||
self.logger.info(f"Disconnected from router {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
|
||||
|
||||
async def reconnect(self):
|
||||
"""Reconnect to the router."""
|
||||
await self.disconnect()
|
||||
await asyncio.sleep(1) # Wait before reconnecting
|
||||
await self.connect()
|
||||
|
||||
async def get_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Get CSI data from the router.
|
||||
|
||||
Returns:
|
||||
CSI data as numpy array, or None if no data available
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RuntimeError(f"Router {self.router_id} is not connected")
|
||||
|
||||
try:
|
||||
if self.mock_mode:
|
||||
csi_data = self._generate_mock_csi_data()
|
||||
else:
|
||||
csi_data = await self._collect_real_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
self.last_data_time = datetime.now()
|
||||
self.sample_count += 1
|
||||
self.error_count = 0
|
||||
|
||||
return csi_data
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
|
||||
return None
|
||||
|
||||
def _generate_mock_csi_data(self) -> np.ndarray:
|
||||
"""Generate mock CSI data for testing."""
|
||||
# Simulate CSI data with realistic characteristics
|
||||
num_subcarriers = 64
|
||||
num_antennas = 4
|
||||
num_samples = 100
|
||||
|
||||
# Update mock generator state
|
||||
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
|
||||
|
||||
# Generate amplitude and phase data
|
||||
time_axis = np.linspace(0, 1, num_samples)
|
||||
|
||||
# Create realistic CSI patterns
|
||||
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
|
||||
|
||||
for antenna in range(num_antennas):
|
||||
for subcarrier in range(num_subcarriers):
|
||||
# Base signal with some variation per antenna/subcarrier
|
||||
amplitude = (
|
||||
self.mock_data_generator['amplitude_base'] *
|
||||
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
|
||||
(1 + 0.1 * antenna)
|
||||
)
|
||||
|
||||
# Phase with spatial and frequency variation
|
||||
phase_offset = (
|
||||
self.mock_data_generator['phase'] +
|
||||
2 * np.pi * subcarrier / num_subcarriers +
|
||||
np.pi * antenna / num_antennas
|
||||
)
|
||||
|
||||
# Add some movement simulation
|
||||
movement_freq = 0.5 # Hz
|
||||
movement_amplitude = 0.3
|
||||
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
|
||||
|
||||
# Generate complex signal
|
||||
signal_amplitude = amplitude * (1 + movement)
|
||||
signal_phase = phase_offset + movement * 0.5
|
||||
|
||||
# Add noise
|
||||
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise = noise_real + 1j * noise_imag
|
||||
|
||||
# Create complex signal
|
||||
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
|
||||
csi_data[antenna, subcarrier, :] = signal
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Collect real CSI data from router (placeholder implementation)."""
|
||||
# This would implement the actual CSI data collection
|
||||
# For now, return None to indicate no real implementation
|
||||
self.logger.warning("Real CSI data collection not implemented")
|
||||
return None
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
"""Check if the router connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# In mock mode, always healthy
|
||||
if self.mock_mode:
|
||||
return True
|
||||
|
||||
# For real connections, we could ping the router or check SSH connection
|
||||
# For now, consider healthy if error count is low
|
||||
return self.error_count < 5
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get router status information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router status
|
||||
"""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode,
|
||||
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
|
||||
"error_count": self.error_count,
|
||||
"sample_count": self.sample_count,
|
||||
"last_error": self.last_error,
|
||||
"configuration": {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"username": self.username,
|
||||
"interface": self.interface
|
||||
}
|
||||
}
|
||||
|
||||
async def get_router_info(self) -> Dict[str, Any]:
|
||||
"""Get router hardware information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router information
|
||||
"""
|
||||
if self.mock_mode:
|
||||
return {
|
||||
"model": "Mock Router",
|
||||
"firmware": "1.0.0-mock",
|
||||
"wifi_standard": "802.11ac",
|
||||
"antennas": 4,
|
||||
"supported_bands": ["2.4GHz", "5GHz"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 4,
|
||||
"sampling_rate": 1000
|
||||
}
|
||||
}
|
||||
|
||||
# For real routers, this would query the actual hardware
|
||||
return {
|
||||
"model": "Unknown",
|
||||
"firmware": "Unknown",
|
||||
"wifi_standard": "Unknown",
|
||||
"antennas": 1,
|
||||
"supported_bands": ["Unknown"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 1,
|
||||
"sampling_rate": 100
|
||||
}
|
||||
}
|
||||
|
||||
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
|
||||
"""Configure CSI data collection parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self.mock_mode:
|
||||
# Update mock generator parameters
|
||||
if 'sampling_rate' in config:
|
||||
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
|
||||
|
||||
if 'noise_level' in config:
|
||||
self.mock_data_generator['noise_level'] = config['noise_level']
|
||||
|
||||
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
|
||||
return True
|
||||
|
||||
# For real routers, this would send configuration commands
|
||||
self.logger.warning("Real CSI configuration not implemented")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get router interface metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing metrics
|
||||
"""
|
||||
uptime = 0
|
||||
if self.last_data_time:
|
||||
uptime = (datetime.now() - self.last_data_time).total_seconds()
|
||||
|
||||
success_rate = 0
|
||||
if self.sample_count > 0:
|
||||
success_rate = (self.sample_count - self.error_count) / self.sample_count
|
||||
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"sample_count": self.sample_count,
|
||||
"error_count": self.error_count,
|
||||
"success_rate": success_rate,
|
||||
"uptime_seconds": uptime,
|
||||
"is_connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset statistics counters."""
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
self.last_error = None
|
||||
self.logger.info(f"Statistics reset for router {self.router_id}")
|
||||
@@ -307,31 +307,34 @@ class ErrorHandler:
|
||||
class ErrorHandlingMiddleware:
|
||||
"""Error handling middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
def __init__(self, app, settings: Settings):
|
||||
self.app = app
|
||||
self.settings = settings
|
||||
self.error_handler = ErrorHandler(settings)
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""Process request through error handling middleware."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except HTTPException as exc:
|
||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
except RequestValidationError as exc:
|
||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
except ValidationError as exc:
|
||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as exc:
|
||||
# Create a mock request for error handling
|
||||
from starlette.requests import Request
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Handle different exception types
|
||||
if isinstance(exc, HTTPException):
|
||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||
elif isinstance(exc, RequestValidationError):
|
||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||
elif isinstance(exc, ValidationError):
|
||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||
else:
|
||||
# Check for specific error types
|
||||
if self._is_database_error(exc):
|
||||
error_response = self.error_handler.handle_database_error(request, exc)
|
||||
@@ -340,7 +343,9 @@ class ErrorHandlingMiddleware:
|
||||
else:
|
||||
error_response = self.error_handler.handle_generic_exception(request, exc)
|
||||
|
||||
return error_response.to_response()
|
||||
# Send the error response
|
||||
response = error_response.to_response()
|
||||
await response(scope, receive, send)
|
||||
|
||||
finally:
|
||||
# Log request processing time
|
||||
@@ -424,11 +429,10 @@ def setup_error_handling(app, settings: Settings):
|
||||
return error_response.to_response()
|
||||
|
||||
# Add middleware for additional error handling
|
||||
middleware = ErrorHandlingMiddleware(settings)
|
||||
|
||||
@app.middleware("http")
|
||||
async def error_handling_middleware(request: Request, call_next):
|
||||
return await middleware(request, call_next)
|
||||
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
|
||||
# The middleware approach is commented out but kept for reference
|
||||
# middleware = ErrorHandlingMiddleware(app, settings)
|
||||
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
|
||||
|
||||
logger.info("Error handling configured")
|
||||
|
||||
|
||||
@@ -5,9 +5,15 @@ Services package for WiFi-DensePose API
|
||||
from .orchestrator import ServiceOrchestrator
|
||||
from .health_check import HealthCheckService
|
||||
from .metrics import MetricsService
|
||||
from .pose_service import PoseService
|
||||
from .stream_service import StreamService
|
||||
from .hardware_service import HardwareService
|
||||
|
||||
__all__ = [
|
||||
'ServiceOrchestrator',
|
||||
'HealthCheckService',
|
||||
'MetricsService'
|
||||
'MetricsService',
|
||||
'PoseService',
|
||||
'StreamService',
|
||||
'HardwareService'
|
||||
]
|
||||
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Hardware interface service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.router_interface import RouterInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HardwareService:
|
||||
"""Service for hardware interface operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize hardware service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Router interfaces
|
||||
self.router_interfaces: Dict[str, RouterInterface] = {}
|
||||
|
||||
# Service state
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Data collection statistics
|
||||
self.stats = {
|
||||
"total_samples": 0,
|
||||
"successful_samples": 0,
|
||||
"failed_samples": 0,
|
||||
"average_sample_rate": 0.0,
|
||||
"last_sample_time": None,
|
||||
"connected_routers": 0
|
||||
}
|
||||
|
||||
# Background tasks
|
||||
self.collection_task = None
|
||||
self.monitoring_task = None
|
||||
|
||||
# Data buffers
|
||||
self.recent_samples = []
|
||||
self.max_recent_samples = 1000
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the hardware service."""
|
||||
await self.start()
|
||||
|
||||
async def start(self):
|
||||
"""Start the hardware service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info("Starting hardware service...")
|
||||
|
||||
# Initialize router interfaces
|
||||
await self._initialize_routers()
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# Start background tasks
|
||||
if not self.settings.mock_hardware:
|
||||
self.collection_task = asyncio.create_task(self._data_collection_loop())
|
||||
|
||||
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
|
||||
|
||||
self.logger.info("Hardware service started successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to start hardware service: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the hardware service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background tasks
|
||||
if self.collection_task:
|
||||
self.collection_task.cancel()
|
||||
try:
|
||||
await self.collection_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.monitoring_task:
|
||||
self.monitoring_task.cancel()
|
||||
try:
|
||||
await self.monitoring_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Disconnect from routers
|
||||
await self._disconnect_routers()
|
||||
|
||||
self.logger.info("Hardware service stopped")
|
||||
|
||||
async def _initialize_routers(self):
|
||||
"""Initialize router interfaces."""
|
||||
try:
|
||||
# Get router configurations from domain config
|
||||
routers = self.domain_config.get_all_routers()
|
||||
|
||||
for router_config in routers:
|
||||
if not router_config.enabled:
|
||||
continue
|
||||
|
||||
router_id = router_config.router_id
|
||||
|
||||
# Create router interface
|
||||
router_interface = RouterInterface(
|
||||
router_id=router_id,
|
||||
host=router_config.ip_address,
|
||||
port=22, # Default SSH port
|
||||
username="admin", # Default username
|
||||
password="admin", # Default password
|
||||
interface=router_config.interface,
|
||||
mock_mode=self.settings.mock_hardware
|
||||
)
|
||||
|
||||
# Connect to router
|
||||
if not self.settings.mock_hardware:
|
||||
await router_interface.connect()
|
||||
|
||||
self.router_interfaces[router_id] = router_interface
|
||||
self.logger.info(f"Router interface initialized: {router_id}")
|
||||
|
||||
self.stats["connected_routers"] = len(self.router_interfaces)
|
||||
|
||||
if not self.router_interfaces:
|
||||
self.logger.warning("No router interfaces configured")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize routers: {e}")
|
||||
raise
|
||||
|
||||
async def _disconnect_routers(self):
|
||||
"""Disconnect from all routers."""
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
await interface.disconnect()
|
||||
self.logger.info(f"Disconnected from router: {router_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error disconnecting from router {router_id}: {e}")
|
||||
|
||||
self.router_interfaces.clear()
|
||||
self.stats["connected_routers"] = 0
|
||||
|
||||
async def _data_collection_loop(self):
|
||||
"""Background loop for data collection."""
|
||||
try:
|
||||
while self.is_running:
|
||||
start_time = time.time()
|
||||
|
||||
# Collect data from all routers
|
||||
await self._collect_data_from_routers()
|
||||
|
||||
# Calculate sleep time to maintain polling interval
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)
|
||||
|
||||
if sleep_time > 0:
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Data collection loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in data collection loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _monitoring_loop(self):
|
||||
"""Background loop for hardware monitoring."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Monitor router connections
|
||||
await self._monitor_router_health()
|
||||
|
||||
# Update statistics
|
||||
self._update_sample_rate_stats()
|
||||
|
||||
# Wait before next check
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Monitoring loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in monitoring loop: {e}")
|
||||
|
||||
async def _collect_data_from_routers(self):
|
||||
"""Collect CSI data from all connected routers."""
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
# Get CSI data from router
|
||||
csi_data = await interface.get_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
# Process the collected data
|
||||
await self._process_collected_data(router_id, csi_data)
|
||||
|
||||
self.stats["successful_samples"] += 1
|
||||
self.stats["last_sample_time"] = datetime.now().isoformat()
|
||||
else:
|
||||
self.stats["failed_samples"] += 1
|
||||
|
||||
self.stats["total_samples"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error collecting data from router {router_id}: {e}")
|
||||
self.stats["failed_samples"] += 1
|
||||
self.stats["total_samples"] += 1
|
||||
|
||||
async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
|
||||
"""Process collected CSI data."""
|
||||
try:
|
||||
# Create sample metadata
|
||||
metadata = {
|
||||
"router_id": router_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"sample_rate": self.stats["average_sample_rate"],
|
||||
"data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
|
||||
}
|
||||
|
||||
# Add to recent samples buffer
|
||||
sample = {
|
||||
"router_id": router_id,
|
||||
"timestamp": metadata["timestamp"],
|
||||
"data": csi_data,
|
||||
"metadata": metadata
|
||||
}
|
||||
|
||||
self.recent_samples.append(sample)
|
||||
|
||||
# Maintain buffer size
|
||||
if len(self.recent_samples) > self.max_recent_samples:
|
||||
self.recent_samples.pop(0)
|
||||
|
||||
# Notify other services (this would typically be done through an event system)
|
||||
# For now, we'll just log the data collection
|
||||
self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing collected data: {e}")
|
||||
|
||||
async def _monitor_router_health(self):
|
||||
"""Monitor health of router connections."""
|
||||
healthy_routers = 0
|
||||
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
is_healthy = await interface.check_health()
|
||||
|
||||
if is_healthy:
|
||||
healthy_routers += 1
|
||||
else:
|
||||
self.logger.warning(f"Router {router_id} is unhealthy")
|
||||
|
||||
# Try to reconnect if not in mock mode
|
||||
if not self.settings.mock_hardware:
|
||||
try:
|
||||
await interface.reconnect()
|
||||
self.logger.info(f"Reconnected to router {router_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to reconnect to router {router_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {router_id}: {e}")
|
||||
|
||||
self.stats["connected_routers"] = healthy_routers
|
||||
|
||||
def _update_sample_rate_stats(self):
|
||||
"""Update sample rate statistics."""
|
||||
if len(self.recent_samples) < 2:
|
||||
return
|
||||
|
||||
# Calculate sample rate from recent samples
|
||||
recent_count = min(100, len(self.recent_samples))
|
||||
recent_samples = self.recent_samples[-recent_count:]
|
||||
|
||||
if len(recent_samples) >= 2:
|
||||
# Calculate time differences
|
||||
time_diffs = []
|
||||
for i in range(1, len(recent_samples)):
|
||||
try:
|
||||
t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
|
||||
t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
|
||||
diff = (t2 - t1).total_seconds()
|
||||
if diff > 0:
|
||||
time_diffs.append(diff)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if time_diffs:
|
||||
avg_interval = sum(time_diffs) / len(time_diffs)
|
||||
self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0
|
||||
|
||||
async def get_router_status(self, router_id: str) -> Dict[str, Any]:
|
||||
"""Get status of a specific router."""
|
||||
if router_id not in self.router_interfaces:
|
||||
raise ValueError(f"Router {router_id} not found")
|
||||
|
||||
interface = self.router_interfaces[router_id]
|
||||
|
||||
try:
|
||||
is_healthy = await interface.check_health()
|
||||
status = await interface.get_status()
|
||||
|
||||
return {
|
||||
"router_id": router_id,
|
||||
"healthy": is_healthy,
|
||||
"connected": status.get("connected", False),
|
||||
"last_data_time": status.get("last_data_time"),
|
||||
"error_count": status.get("error_count", 0),
|
||||
"configuration": status.get("configuration", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"router_id": router_id,
|
||||
"healthy": False,
|
||||
"connected": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_all_router_status(self) -> List[Dict[str, Any]]:
|
||||
"""Get status of all routers."""
|
||||
statuses = []
|
||||
|
||||
for router_id in self.router_interfaces:
|
||||
try:
|
||||
status = await self.get_router_status(router_id)
|
||||
statuses.append(status)
|
||||
except Exception as e:
|
||||
statuses.append({
|
||||
"router_id": router_id,
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return statuses
|
||||
|
||||
async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get recent CSI data samples."""
|
||||
samples = self.recent_samples[-limit:] if limit else self.recent_samples
|
||||
|
||||
if router_id:
|
||||
samples = [s for s in samples if s["router_id"] == router_id]
|
||||
|
||||
# Convert numpy arrays to lists for JSON serialization
|
||||
result = []
|
||||
for sample in samples:
|
||||
sample_copy = sample.copy()
|
||||
if isinstance(sample_copy["data"], np.ndarray):
|
||||
sample_copy["data"] = sample_copy["data"].tolist()
|
||||
result.append(sample_copy)
|
||||
|
||||
return result
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_hardware": self.settings.mock_hardware,
|
||||
"wifi_interface": self.settings.wifi_interface,
|
||||
"polling_interval": self.settings.hardware_polling_interval,
|
||||
"buffer_size": self.settings.csi_buffer_size
|
||||
},
|
||||
"routers": await self.get_all_router_status()
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
total_samples = self.stats["total_samples"]
|
||||
success_rate = self.stats["successful_samples"] / max(1, total_samples)
|
||||
|
||||
return {
|
||||
"hardware_service": {
|
||||
"total_samples": total_samples,
|
||||
"successful_samples": self.stats["successful_samples"],
|
||||
"failed_samples": self.stats["failed_samples"],
|
||||
"success_rate": success_rate,
|
||||
"average_sample_rate": self.stats["average_sample_rate"],
|
||||
"connected_routers": self.stats["connected_routers"],
|
||||
"last_sample_time": self.stats["last_sample_time"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_samples": 0,
|
||||
"successful_samples": 0,
|
||||
"failed_samples": 0,
|
||||
"average_sample_rate": 0.0,
|
||||
"last_sample_time": None,
|
||||
"connected_routers": len(self.router_interfaces)
|
||||
}
|
||||
|
||||
self.recent_samples.clear()
|
||||
self.last_error = None
|
||||
|
||||
self.logger.info("Hardware service reset")
|
||||
|
||||
async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Manually trigger data collection."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Hardware service is not running")
|
||||
|
||||
results = {}
|
||||
|
||||
if router_id:
|
||||
# Collect from specific router
|
||||
if router_id not in self.router_interfaces:
|
||||
raise ValueError(f"Router {router_id} not found")
|
||||
|
||||
interface = self.router_interfaces[router_id]
|
||||
try:
|
||||
csi_data = await interface.get_csi_data()
|
||||
if csi_data is not None:
|
||||
await self._process_collected_data(router_id, csi_data)
|
||||
results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
|
||||
else:
|
||||
results[router_id] = {"success": False, "error": "No data received"}
|
||||
except Exception as e:
|
||||
results[router_id] = {"success": False, "error": str(e)}
|
||||
else:
|
||||
# Collect from all routers
|
||||
await self._collect_data_from_routers()
|
||||
results = {"message": "Manual collection triggered for all routers"}
|
||||
|
||||
return results
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
# Check router health
|
||||
healthy_routers = 0
|
||||
total_routers = len(self.router_interfaces)
|
||||
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
if await interface.check_health():
|
||||
healthy_routers += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Hardware service is running normally",
|
||||
"connected_routers": f"{healthy_routers}/{total_routers}",
|
||||
"metrics": {
|
||||
"total_samples": self.stats["total_samples"],
|
||||
"success_rate": (
|
||||
self.stats["successful_samples"] / max(1, self.stats["total_samples"])
|
||||
),
|
||||
"average_sample_rate": self.stats["average_sample_rate"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
"""Check if service is ready."""
|
||||
return self.is_running and len(self.router_interfaces) > 0
|
||||
@@ -0,0 +1,706 @@
|
||||
"""
|
||||
Pose estimation service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoseService:
|
||||
"""Service for pose estimation operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize pose service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize components
|
||||
self.csi_processor = None
|
||||
self.phase_sanitizer = None
|
||||
self.densepose_model = None
|
||||
self.modality_translator = None
|
||||
|
||||
# Service state
|
||||
self.is_initialized = False
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Processing statistics
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the pose service."""
|
||||
try:
|
||||
self.logger.info("Initializing pose service...")
|
||||
|
||||
# Initialize CSI processor
|
||||
csi_config = {
|
||||
'buffer_size': self.settings.csi_buffer_size,
|
||||
'sample_rate': 1000, # Default sampling rate
|
||||
'num_subcarriers': 56,
|
||||
'num_antennas': 3
|
||||
}
|
||||
self.csi_processor = CSIProcessor(config=csi_config)
|
||||
|
||||
# Initialize phase sanitizer
|
||||
self.phase_sanitizer = PhaseSanitizer()
|
||||
|
||||
# Initialize models if not mocking
|
||||
if not self.settings.mock_pose_data:
|
||||
await self._initialize_models()
|
||||
else:
|
||||
self.logger.info("Using mock pose data for development")
|
||||
|
||||
self.is_initialized = True
|
||||
self.logger.info("Pose service initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to initialize pose service: {e}")
|
||||
raise
|
||||
|
||||
async def _initialize_models(self):
|
||||
"""Initialize neural network models."""
|
||||
try:
|
||||
# Initialize DensePose model
|
||||
if self.settings.pose_model_path:
|
||||
self.densepose_model = DensePoseHead()
|
||||
# Load model weights if path is provided
|
||||
# model_state = torch.load(self.settings.pose_model_path)
|
||||
# self.densepose_model.load_state_dict(model_state)
|
||||
self.logger.info("DensePose model loaded")
|
||||
else:
|
||||
self.logger.warning("No pose model path provided, using default model")
|
||||
self.densepose_model = DensePoseHead()
|
||||
|
||||
# Initialize modality translation
|
||||
config = {
|
||||
'input_channels': 64, # CSI data channels
|
||||
'hidden_channels': [128, 256, 512],
|
||||
'output_channels': 256, # Visual feature channels
|
||||
'use_attention': True
|
||||
}
|
||||
self.modality_translator = ModalityTranslationNetwork(config)
|
||||
|
||||
# Set models to evaluation mode
|
||||
self.densepose_model.eval()
|
||||
self.modality_translator.eval()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize models: {e}")
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""Start the pose service."""
|
||||
if not self.is_initialized:
|
||||
await self.initialize()
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Pose service started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the pose service."""
|
||||
self.is_running = False
|
||||
self.logger.info("Pose service stopped")
|
||||
|
||||
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process CSI data and estimate poses."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Pose service is not running")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# Process CSI data
|
||||
processed_csi = await self._process_csi(csi_data, metadata)
|
||||
|
||||
# Estimate poses
|
||||
poses = await self._estimate_poses(processed_csi, metadata)
|
||||
|
||||
# Update statistics
|
||||
processing_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
self._update_stats(poses, processing_time)
|
||||
|
||||
return {
|
||||
"timestamp": start_time.isoformat(),
|
||||
"poses": poses,
|
||||
"metadata": metadata,
|
||||
"processing_time_ms": processing_time,
|
||||
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.stats["failed_detections"] += 1
|
||||
self.logger.error(f"Error processing CSI data: {e}")
|
||||
raise
|
||||
|
||||
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
|
||||
"""Process raw CSI data."""
|
||||
# Add CSI data to processor
|
||||
self.csi_processor.add_data(csi_data, metadata.get("timestamp", datetime.now()))
|
||||
|
||||
# Get processed data
|
||||
processed_data = self.csi_processor.get_processed_data()
|
||||
|
||||
# Apply phase sanitization
|
||||
if processed_data is not None:
|
||||
sanitized_data = self.phase_sanitizer.sanitize(processed_data)
|
||||
return sanitized_data
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Estimate poses from processed CSI data."""
|
||||
if self.settings.mock_pose_data:
|
||||
return self._generate_mock_poses()
|
||||
|
||||
try:
|
||||
# Convert CSI data to tensor
|
||||
csi_tensor = torch.from_numpy(csi_data).float()
|
||||
|
||||
# Add batch dimension if needed
|
||||
if len(csi_tensor.shape) == 2:
|
||||
csi_tensor = csi_tensor.unsqueeze(0)
|
||||
|
||||
# Translate modality (CSI to visual-like features)
|
||||
with torch.no_grad():
|
||||
visual_features = self.modality_translator(csi_tensor)
|
||||
|
||||
# Estimate poses using DensePose
|
||||
pose_outputs = self.densepose_model(visual_features)
|
||||
|
||||
# Convert outputs to pose detections
|
||||
poses = self._parse_pose_outputs(pose_outputs)
|
||||
|
||||
# Filter by confidence threshold
|
||||
filtered_poses = [
|
||||
pose for pose in poses
|
||||
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
|
||||
]
|
||||
|
||||
# Limit number of persons
|
||||
if len(filtered_poses) > self.settings.pose_max_persons:
|
||||
filtered_poses = sorted(
|
||||
filtered_poses,
|
||||
key=lambda x: x.get("confidence", 0.0),
|
||||
reverse=True
|
||||
)[:self.settings.pose_max_persons]
|
||||
|
||||
return filtered_poses
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in pose estimation: {e}")
|
||||
return []
|
||||
|
||||
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Parse neural network outputs into pose detections."""
|
||||
poses = []
|
||||
|
||||
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||
# For now, generate mock poses based on the output shape
|
||||
batch_size = outputs.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract pose information (mock implementation)
|
||||
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||
"""Generate mock pose data for development."""
|
||||
import random
|
||||
|
||||
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
|
||||
poses = []
|
||||
|
||||
for i in range(num_persons):
|
||||
confidence = random.uniform(0.3, 0.95)
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_keypoints(self) -> List[Dict[str, Any]]:
|
||||
"""Generate keypoints for a person."""
|
||||
import random
|
||||
|
||||
keypoint_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
keypoints = []
|
||||
for name in keypoint_names:
|
||||
keypoints.append({
|
||||
"name": name,
|
||||
"x": random.uniform(0.1, 0.9),
|
||||
"y": random.uniform(0.1, 0.9),
|
||||
"confidence": random.uniform(0.5, 0.95)
|
||||
})
|
||||
|
||||
return keypoints
|
||||
|
||||
def _generate_bounding_box(self) -> Dict[str, float]:
|
||||
"""Generate bounding box for a person."""
|
||||
import random
|
||||
|
||||
x = random.uniform(0.1, 0.6)
|
||||
y = random.uniform(0.1, 0.6)
|
||||
width = random.uniform(0.2, 0.4)
|
||||
height = random.uniform(0.3, 0.5)
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height
|
||||
}
|
||||
|
||||
def _classify_activity(self, features: torch.Tensor) -> str:
|
||||
"""Classify activity from features."""
|
||||
# Simple mock classification
|
||||
import random
|
||||
activities = ["standing", "sitting", "walking", "lying", "unknown"]
|
||||
return random.choice(activities)
|
||||
|
||||
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
|
||||
"""Update processing statistics."""
|
||||
self.stats["total_processed"] += 1
|
||||
|
||||
if poses:
|
||||
self.stats["successful_detections"] += 1
|
||||
confidences = [pose.get("confidence", 0.0) for pose in poses]
|
||||
avg_confidence = sum(confidences) / len(confidences)
|
||||
|
||||
# Update running average
|
||||
total = self.stats["successful_detections"]
|
||||
current_avg = self.stats["average_confidence"]
|
||||
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
|
||||
else:
|
||||
self.stats["failed_detections"] += 1
|
||||
|
||||
# Update processing time (running average)
|
||||
total = self.stats["total_processed"]
|
||||
current_avg = self.stats["processing_time_ms"]
|
||||
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"initialized": self.is_initialized,
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_data": self.settings.mock_pose_data,
|
||||
"confidence_threshold": self.settings.pose_confidence_threshold,
|
||||
"max_persons": self.settings.pose_max_persons,
|
||||
"batch_size": self.settings.pose_processing_batch_size
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
return {
|
||||
"pose_service": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"successful_detections": self.stats["successful_detections"],
|
||||
"failed_detections": self.stats["failed_detections"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_confidence": self.stats["average_confidence"],
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
self.last_error = None
|
||||
self.logger.info("Pose service reset")
|
||||
|
||||
# API endpoint methods
|
||||
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Estimate poses with API parameters."""
|
||||
try:
|
||||
# Generate mock CSI data for estimation
|
||||
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
|
||||
metadata = {
|
||||
"timestamp": datetime.now(),
|
||||
"zone_ids": zone_ids or ["zone_1"],
|
||||
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
|
||||
"max_persons": max_persons or self.settings.pose_max_persons
|
||||
}
|
||||
|
||||
# Process the data
|
||||
result = await self.process_csi_data(mock_csi, metadata)
|
||||
|
||||
# Format for API response
|
||||
persons = []
|
||||
for i, pose in enumerate(result["poses"]):
|
||||
person = {
|
||||
"person_id": str(pose["person_id"]),
|
||||
"confidence": pose["confidence"],
|
||||
"bounding_box": pose["bounding_box"],
|
||||
"zone_id": zone_ids[0] if zone_ids else "zone_1",
|
||||
"activity": pose["activity"],
|
||||
"timestamp": datetime.fromisoformat(pose["timestamp"])
|
||||
}
|
||||
|
||||
if include_keypoints:
|
||||
person["keypoints"] = pose["keypoints"]
|
||||
|
||||
if include_segmentation:
|
||||
person["segmentation"] = {"mask": "mock_segmentation_data"}
|
||||
|
||||
persons.append(person)
|
||||
|
||||
# Zone summary
|
||||
zone_summary = {}
|
||||
for zone_id in (zone_ids or ["zone_1"]):
|
||||
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(),
|
||||
"frame_id": f"frame_{int(datetime.now().timestamp())}",
|
||||
"persons": persons,
|
||||
"zone_summary": zone_summary,
|
||||
"processing_time_ms": result["processing_time_ms"],
|
||||
"metadata": {"mock_data": self.settings.mock_pose_data}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in estimate_poses: {e}")
|
||||
raise
|
||||
|
||||
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Analyze pose data with custom parameters."""
|
||||
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
|
||||
include_keypoints, include_segmentation)
|
||||
|
||||
async def get_zone_occupancy(self, zone_id: str):
|
||||
"""Get current occupancy for a specific zone."""
|
||||
try:
|
||||
# Mock occupancy data
|
||||
import random
|
||||
count = random.randint(0, 5)
|
||||
persons = []
|
||||
|
||||
for i in range(count):
|
||||
persons.append({
|
||||
"person_id": f"person_{i}",
|
||||
"confidence": random.uniform(0.7, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"max_occupancy": 10,
|
||||
"persons": persons,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zone occupancy: {e}")
|
||||
return None
|
||||
|
||||
async def get_zones_summary(self):
|
||||
"""Get occupancy summary for all zones."""
|
||||
try:
|
||||
import random
|
||||
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
|
||||
zone_data = {}
|
||||
total_persons = 0
|
||||
active_zones = 0
|
||||
|
||||
for zone_id in zones:
|
||||
count = random.randint(0, 3)
|
||||
zone_data[zone_id] = {
|
||||
"occupancy": count,
|
||||
"max_occupancy": 10,
|
||||
"status": "active" if count > 0 else "inactive"
|
||||
}
|
||||
total_persons += count
|
||||
if count > 0:
|
||||
active_zones += 1
|
||||
|
||||
return {
|
||||
"total_persons": total_persons,
|
||||
"zones": zone_data,
|
||||
"active_zones": active_zones
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zones summary: {e}")
|
||||
raise
|
||||
|
||||
async def get_historical_data(self, start_time, end_time, zone_ids=None,
|
||||
aggregation_interval=300, include_raw_data=False):
|
||||
"""Get historical pose estimation data."""
|
||||
try:
|
||||
# Mock historical data
|
||||
import random
|
||||
from datetime import timedelta
|
||||
|
||||
current_time = start_time
|
||||
aggregated_data = []
|
||||
raw_data = [] if include_raw_data else None
|
||||
|
||||
while current_time < end_time:
|
||||
# Generate aggregated data point
|
||||
data_point = {
|
||||
"timestamp": current_time,
|
||||
"total_persons": random.randint(0, 8),
|
||||
"zones": {}
|
||||
}
|
||||
|
||||
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
|
||||
data_point["zones"][zone_id] = {
|
||||
"occupancy": random.randint(0, 3),
|
||||
"avg_confidence": random.uniform(0.7, 0.95)
|
||||
}
|
||||
|
||||
aggregated_data.append(data_point)
|
||||
|
||||
# Generate raw data if requested
|
||||
if include_raw_data:
|
||||
for _ in range(random.randint(0, 5)):
|
||||
raw_data.append({
|
||||
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
|
||||
"person_id": f"person_{random.randint(1, 10)}",
|
||||
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
|
||||
"confidence": random.uniform(0.5, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
current_time += timedelta(seconds=aggregation_interval)
|
||||
|
||||
return {
|
||||
"aggregated_data": aggregated_data,
|
||||
"raw_data": raw_data,
|
||||
"total_records": len(aggregated_data)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting historical data: {e}")
|
||||
raise
|
||||
|
||||
async def get_recent_activities(self, zone_id=None, limit=10):
|
||||
"""Get recently detected activities."""
|
||||
try:
|
||||
import random
|
||||
activities = []
|
||||
|
||||
for i in range(limit):
|
||||
activity = {
|
||||
"activity_id": f"activity_{i}",
|
||||
"person_id": f"person_{random.randint(1, 5)}",
|
||||
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"confidence": random.uniform(0.6, 0.95),
|
||||
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
|
||||
"duration_seconds": random.randint(10, 300)
|
||||
}
|
||||
activities.append(activity)
|
||||
|
||||
return activities
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting recent activities: {e}")
|
||||
raise
|
||||
|
||||
async def is_calibrating(self):
|
||||
"""Check if calibration is in progress."""
|
||||
return False # Mock implementation
|
||||
|
||||
async def start_calibration(self):
|
||||
"""Start calibration process."""
|
||||
import uuid
|
||||
calibration_id = str(uuid.uuid4())
|
||||
self.logger.info(f"Started calibration: {calibration_id}")
|
||||
return calibration_id
|
||||
|
||||
async def run_calibration(self, calibration_id):
|
||||
"""Run calibration process."""
|
||||
self.logger.info(f"Running calibration: {calibration_id}")
|
||||
# Mock calibration process
|
||||
await asyncio.sleep(5)
|
||||
self.logger.info(f"Calibration completed: {calibration_id}")
|
||||
|
||||
async def get_calibration_status(self):
|
||||
"""Get current calibration status."""
|
||||
return {
|
||||
"is_calibrating": False,
|
||||
"calibration_id": None,
|
||||
"progress_percent": 100,
|
||||
"current_step": "completed",
|
||||
"estimated_remaining_minutes": 0,
|
||||
"last_calibration": datetime.now() - timedelta(hours=1)
|
||||
}
|
||||
|
||||
async def get_statistics(self, start_time, end_time):
|
||||
"""Get pose estimation statistics."""
|
||||
try:
|
||||
import random
|
||||
|
||||
# Mock statistics
|
||||
total_detections = random.randint(100, 1000)
|
||||
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
|
||||
|
||||
return {
|
||||
"total_detections": total_detections,
|
||||
"successful_detections": successful_detections,
|
||||
"failed_detections": total_detections - successful_detections,
|
||||
"success_rate": successful_detections / total_detections,
|
||||
"average_confidence": random.uniform(0.75, 0.90),
|
||||
"average_processing_time_ms": random.uniform(50, 200),
|
||||
"unique_persons": random.randint(5, 20),
|
||||
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity_distribution": {
|
||||
"standing": random.uniform(0.3, 0.5),
|
||||
"sitting": random.uniform(0.2, 0.4),
|
||||
"walking": random.uniform(0.1, 0.3),
|
||||
"lying": random.uniform(0.0, 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting statistics: {e}")
|
||||
raise
|
||||
|
||||
async def process_segmentation_data(self, frame_id):
|
||||
"""Process segmentation data in background."""
|
||||
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
|
||||
# Mock background processing
|
||||
await asyncio.sleep(2)
|
||||
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
|
||||
|
||||
# WebSocket streaming methods
|
||||
async def get_current_pose_data(self):
|
||||
"""Get current pose data for streaming."""
|
||||
try:
|
||||
# Generate current pose data
|
||||
result = await self.estimate_poses()
|
||||
|
||||
# Format data by zones for WebSocket streaming
|
||||
zone_data = {}
|
||||
|
||||
# Group persons by zone
|
||||
for person in result["persons"]:
|
||||
zone_id = person.get("zone_id", "zone_1")
|
||||
|
||||
if zone_id not in zone_data:
|
||||
zone_data[zone_id] = {
|
||||
"pose": {
|
||||
"persons": [],
|
||||
"count": 0
|
||||
},
|
||||
"confidence": 0.0,
|
||||
"activity": None,
|
||||
"metadata": {
|
||||
"frame_id": result["frame_id"],
|
||||
"processing_time_ms": result["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
zone_data[zone_id]["pose"]["persons"].append(person)
|
||||
zone_data[zone_id]["pose"]["count"] += 1
|
||||
|
||||
# Update zone confidence (average)
|
||||
current_confidence = zone_data[zone_id]["confidence"]
|
||||
person_confidence = person.get("confidence", 0.0)
|
||||
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
|
||||
|
||||
# Set activity if not already set
|
||||
if not zone_data[zone_id]["activity"] and person.get("activity"):
|
||||
zone_data[zone_id]["activity"] = person["activity"]
|
||||
|
||||
return zone_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting current pose data: {e}")
|
||||
# Return empty zone data on error
|
||||
return {}
|
||||
|
||||
# Health check methods
|
||||
async def health_check(self):
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Service is running normally",
|
||||
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
|
||||
"metrics": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self):
|
||||
"""Check if service is ready."""
|
||||
return self.is_initialized and self.is_running
|
||||
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Real-time streaming service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamService:
|
||||
"""Service for real-time data streaming."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize stream service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# WebSocket connections
|
||||
self.connections: Set[WebSocket] = set()
|
||||
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
|
||||
|
||||
# Stream buffers
|
||||
self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||
self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||
|
||||
# Service state
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Streaming statistics
|
||||
self.stats = {
|
||||
"active_connections": 0,
|
||||
"total_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"data_points_streamed": 0,
|
||||
"average_latency_ms": 0.0
|
||||
}
|
||||
|
||||
# Background tasks
|
||||
self.streaming_task = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the stream service."""
|
||||
self.logger.info("Stream service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start the stream service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Stream service started")
|
||||
|
||||
# Start background streaming task
|
||||
if self.settings.enable_real_time_processing:
|
||||
self.streaming_task = asyncio.create_task(self._streaming_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the stream service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background task
|
||||
if self.streaming_task:
|
||||
self.streaming_task.cancel()
|
||||
try:
|
||||
await self.streaming_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close all connections
|
||||
await self._close_all_connections()
|
||||
|
||||
self.logger.info("Stream service stopped")
|
||||
|
||||
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
|
||||
"""Add a new WebSocket connection."""
|
||||
try:
|
||||
await websocket.accept()
|
||||
self.connections.add(websocket)
|
||||
self.connection_metadata[websocket] = metadata or {}
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
self.stats["total_connections"] += 1
|
||||
|
||||
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
|
||||
|
||||
# Send initial data if available
|
||||
await self._send_initial_data(websocket)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error adding WebSocket connection: {e}")
|
||||
raise
|
||||
|
||||
async def remove_connection(self, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
self.connections.remove(websocket)
|
||||
self.connection_metadata.pop(websocket, None)
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
|
||||
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error removing WebSocket connection: {e}")
|
||||
|
||||
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
|
||||
"""Broadcast pose data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Add to buffer
|
||||
self.pose_buffer.append({
|
||||
"type": "pose_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "pose_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
|
||||
"""Broadcast CSI data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Convert numpy array to list for JSON serialization
|
||||
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
|
||||
|
||||
# Add to buffer
|
||||
self.csi_buffer.append({
|
||||
"type": "csi_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "csi_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
async def broadcast_system_status(self, status_data: Dict[str, Any]):
|
||||
"""Broadcast system status to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
await self._broadcast_message({
|
||||
"type": "system_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status_data
|
||||
})
|
||||
|
||||
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
|
||||
"""Send message to a specific connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def _broadcast_message(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients."""
|
||||
if not self.connections:
|
||||
return
|
||||
|
||||
disconnected = set()
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to send message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
disconnected.add(websocket)
|
||||
|
||||
# Remove disconnected clients
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
if message.get("type") in ["pose_update", "csi_update"]:
|
||||
self.stats["data_points_streamed"] += 1
|
||||
|
||||
async def _send_initial_data(self, websocket: WebSocket):
|
||||
"""Send initial data to a new connection."""
|
||||
try:
|
||||
# Send recent pose data
|
||||
if self.pose_buffer:
|
||||
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_poses",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_poses
|
||||
})
|
||||
|
||||
# Send recent CSI data
|
||||
if self.csi_buffer:
|
||||
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_csi",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_csi
|
||||
})
|
||||
|
||||
# Send service status
|
||||
status = await self.get_status()
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "service_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending initial data: {e}")
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""Background streaming loop for periodic updates."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Send periodic heartbeat
|
||||
if self.connections:
|
||||
await self._broadcast_message({
|
||||
"type": "heartbeat",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"active_connections": len(self.connections)
|
||||
})
|
||||
|
||||
# Wait for next iteration
|
||||
await asyncio.sleep(self.settings.websocket_ping_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Streaming loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in streaming loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _close_all_connections(self):
|
||||
"""Close all WebSocket connections."""
|
||||
disconnected = []
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.close()
|
||||
disconnected.append(websocket)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing connection: {e}")
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clear all connections
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"connections": {
|
||||
"active": len(self.connections),
|
||||
"total": self.stats["total_connections"]
|
||||
},
|
||||
"buffers": {
|
||||
"pose_buffer_size": len(self.pose_buffer),
|
||||
"csi_buffer_size": len(self.csi_buffer),
|
||||
"max_buffer_size": self.settings.stream_buffer_size
|
||||
},
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"stream_fps": self.settings.stream_fps,
|
||||
"buffer_size": self.settings.stream_buffer_size,
|
||||
"ping_interval": self.settings.websocket_ping_interval,
|
||||
"timeout": self.settings.websocket_timeout
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
total_messages = self.stats["messages_sent"] + self.stats["messages_failed"]
|
||||
success_rate = self.stats["messages_sent"] / max(1, total_messages)
|
||||
|
||||
return {
|
||||
"stream_service": {
|
||||
"active_connections": self.stats["active_connections"],
|
||||
"total_connections": self.stats["total_connections"],
|
||||
"messages_sent": self.stats["messages_sent"],
|
||||
"messages_failed": self.stats["messages_failed"],
|
||||
"message_success_rate": success_rate,
|
||||
"data_points_streamed": self.stats["data_points_streamed"],
|
||||
"average_latency_ms": self.stats["average_latency_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def get_connection_info(self) -> List[Dict[str, Any]]:
|
||||
"""Get information about active connections."""
|
||||
connections_info = []
|
||||
|
||||
for websocket in self.connections:
|
||||
metadata = self.connection_metadata.get(websocket, {})
|
||||
|
||||
connection_info = {
|
||||
"id": id(websocket),
|
||||
"connected_at": metadata.get("connected_at", "unknown"),
|
||||
"user_agent": metadata.get("user_agent", "unknown"),
|
||||
"ip_address": metadata.get("ip_address", "unknown"),
|
||||
"subscription_types": metadata.get("subscription_types", [])
|
||||
}
|
||||
|
||||
connections_info.append(connection_info)
|
||||
|
||||
return connections_info
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
# Clear buffers
|
||||
self.pose_buffer.clear()
|
||||
self.csi_buffer.clear()
|
||||
|
||||
# Reset statistics
|
||||
self.stats = {
|
||||
"active_connections": len(self.connections),
|
||||
"total_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"data_points_streamed": 0,
|
||||
"average_latency_ms": 0.0
|
||||
}
|
||||
|
||||
self.last_error = None
|
||||
self.logger.info("Stream service reset")
|
||||
|
||||
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get data from buffers."""
|
||||
if buffer_type == "pose":
|
||||
return list(self.pose_buffer)[-limit:]
|
||||
elif buffer_type == "csi":
|
||||
return list(self.csi_buffer)[-limit:]
|
||||
else:
|
||||
return []
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if stream service is active."""
|
||||
return self.is_running
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Stream service is running normally",
|
||||
"active_connections": len(self.connections),
|
||||
"metrics": {
|
||||
"messages_sent": self.stats["messages_sent"],
|
||||
"messages_failed": self.stats["messages_failed"],
|
||||
"data_points_streamed": self.stats["data_points_streamed"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
"""Check if service is ready."""
|
||||
return self.is_running
|
||||
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify WiFi-DensePose API functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import websockets
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
WS_URL = "ws://localhost:8000"
|
||||
|
||||
async def test_health_endpoints():
|
||||
"""Test health check endpoints."""
|
||||
print("🔍 Testing health endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test basic health
|
||||
async with session.get(f"{BASE_URL}/health/health") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Health check: {data['status']}")
|
||||
else:
|
||||
print(f"❌ Health check failed: {response.status}")
|
||||
|
||||
# Test readiness
|
||||
async with session.get(f"{BASE_URL}/health/ready") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
status = "ready" if data['ready'] else "not ready"
|
||||
print(f"✅ Readiness check: {status}")
|
||||
else:
|
||||
print(f"❌ Readiness check failed: {response.status}")
|
||||
|
||||
# Test liveness
|
||||
async with session.get(f"{BASE_URL}/health/live") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Liveness check: {data['status']}")
|
||||
else:
|
||||
print(f"❌ Liveness check failed: {response.status}")
|
||||
|
||||
async def test_api_endpoints():
|
||||
"""Test main API endpoints."""
|
||||
print("\n🔍 Testing API endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test root endpoint
|
||||
async with session.get(f"{BASE_URL}/") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Root endpoint: {data['name']} v{data['version']}")
|
||||
else:
|
||||
print(f"❌ Root endpoint failed: {response.status}")
|
||||
|
||||
# Test API info
|
||||
async with session.get(f"{BASE_URL}/api/v1/info") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ API info: {len(data['services'])} services configured")
|
||||
else:
|
||||
print(f"❌ API info failed: {response.status}")
|
||||
|
||||
# Test API status
|
||||
async with session.get(f"{BASE_URL}/api/v1/status") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ API status: {data['api']['status']}")
|
||||
else:
|
||||
print(f"❌ API status failed: {response.status}")
|
||||
|
||||
async def test_pose_endpoints():
|
||||
"""Test pose estimation endpoints."""
|
||||
print("\n🔍 Testing pose endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test current pose data
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/current") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Current pose data: {len(data.get('poses', []))} poses detected")
|
||||
else:
|
||||
print(f"❌ Current pose data failed: {response.status}")
|
||||
|
||||
# Test zones summary
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/zones/summary") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
zones = data.get('zones', {})
|
||||
print(f"✅ Zones summary: {len(zones)} zones")
|
||||
for zone_id, zone_data in list(zones.items())[:3]: # Show first 3 zones
|
||||
print(f" - {zone_id}: {zone_data.get('occupancy', 0)} people")
|
||||
else:
|
||||
print(f"❌ Zones summary failed: {response.status}")
|
||||
|
||||
# Test pose stats
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/stats") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Pose stats: {data.get('total_detections', 0)} total detections")
|
||||
else:
|
||||
print(f"❌ Pose stats failed: {response.status}")
|
||||
|
||||
async def test_stream_endpoints():
|
||||
"""Test streaming endpoints."""
|
||||
print("\n🔍 Testing stream endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test stream status
|
||||
async with session.get(f"{BASE_URL}/api/v1/stream/status") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Stream status: {'Active' if data['is_active'] else 'Inactive'}")
|
||||
print(f" - Connected clients: {data['connected_clients']}")
|
||||
else:
|
||||
print(f"❌ Stream status failed: {response.status}")
|
||||
|
||||
# Test stream metrics
|
||||
async with session.get(f"{BASE_URL}/api/v1/stream/metrics") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Stream metrics available")
|
||||
else:
|
||||
print(f"❌ Stream metrics failed: {response.status}")
|
||||
|
||||
async def test_websocket_connection():
|
||||
"""Test WebSocket connection."""
|
||||
print("\n🔍 Testing WebSocket connection...")
|
||||
|
||||
try:
|
||||
uri = f"{WS_URL}/api/v1/stream/pose"
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("✅ WebSocket connected successfully")
|
||||
|
||||
# Wait for connection confirmation
|
||||
message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||
data = json.loads(message)
|
||||
|
||||
if data.get("type") == "connection_established":
|
||||
print(f"✅ Connection established with client ID: {data.get('client_id')}")
|
||||
|
||||
# Send a ping
|
||||
await websocket.send(json.dumps({"type": "ping"}))
|
||||
|
||||
# Wait for pong
|
||||
pong_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||
pong_data = json.loads(pong_message)
|
||||
|
||||
if pong_data.get("type") == "pong":
|
||||
print("✅ WebSocket ping/pong successful")
|
||||
else:
|
||||
print(f"❌ Unexpected pong response: {pong_data}")
|
||||
else:
|
||||
print(f"❌ Unexpected connection message: {data}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("❌ WebSocket connection timeout")
|
||||
except Exception as e:
|
||||
print(f"❌ WebSocket connection failed: {e}")
|
||||
|
||||
async def test_calibration_endpoints():
|
||||
"""Test calibration endpoints."""
|
||||
print("\n🔍 Testing calibration endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test calibration status
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/calibration/status") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Calibration status: {data.get('status', 'unknown')}")
|
||||
else:
|
||||
print(f"❌ Calibration status failed: {response.status}")
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 Starting WiFi-DensePose API Tests")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
await test_health_endpoints()
|
||||
await test_api_endpoints()
|
||||
await test_pose_endpoints()
|
||||
await test_stream_endpoints()
|
||||
await test_websocket_connection()
|
||||
await test_calibration_endpoints()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("✅ All tests completed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test suite failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user