mirror of
https://github.com/ruvnet/RuView
synced 2026-06-09 10:13:17 +00:00
374 lines
14 KiB
Python
"""
Pose streaming WebSocket handler
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Optional, Any
|
|
from datetime import datetime
|
|
|
|
from fastapi import WebSocket
|
|
from pydantic import BaseModel, Field
|
|
|
|
from src.api.websocket.connection_manager import ConnectionManager
|
|
from src.services.pose_service import PoseService
|
|
from src.services.stream_service import StreamService
|
|
|
|
# Module-level logger named after this module, so its output can be
# configured and filtered independently of other modules.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class PoseStreamData(BaseModel):
    """A single zone's pose estimate as carried on the pose stream.

    ``timestamp``, ``zone_id``, ``pose_data`` and ``confidence`` must be
    supplied; ``confidence`` is validated to lie within [0.0, 1.0].
    ``activity`` and ``metadata`` are optional and default to ``None``.
    """

    # Required fields (``default=...`` means "no default").
    timestamp: datetime = Field(default=..., description="Data timestamp")
    zone_id: str = Field(default=..., description="Zone identifier")
    pose_data: Dict[str, Any] = Field(default=..., description="Pose estimation data")
    confidence: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Confidence score",
    )

    # Optional fields.
    activity: Optional[str] = Field(None, description="Detected activity")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
|
|
|
|
|
class PoseStreamHandler:
    """Handles pose data streaming to WebSocket clients.

    A background task (``_stream_loop``) polls ``pose_service`` for the current
    per-zone pose data, filters it using the handler-wide ``stream_config`` and
    broadcasts it through ``connection_manager`` to clients of the "pose"
    stream. The handler also serves per-client requests (subscription updates,
    historical data, zone statistics) and broadcasts system events.
    """

    def __init__(
        self,
        connection_manager: ConnectionManager,
        pose_service: PoseService,
        stream_service: StreamService
    ):
        """Create a handler bound to its collaborating services.

        Args:
            connection_manager: Used to broadcast to "pose" stream clients and
                to send messages to individual clients.
            pose_service: Source of live pose data, historical data, zone
                statistics and performance metrics.
            stream_service: Stored on the instance; not otherwise used within
                this class.
        """
        self.connection_manager = connection_manager
        self.pose_service = pose_service
        self.stream_service = stream_service
        self.is_streaming = False
        self.stream_task: Optional[asyncio.Task] = None
        # client_id -> subscription settings recorded by handle_client_subscription
        self.subscribers: Dict[str, Dict[str, Any]] = {}
        # Handler-wide settings; update_stream_config clamps values to safe ranges.
        self.stream_config = {
            "fps": 30,
            "min_confidence": 0.5,
            "include_metadata": True,
            "buffer_size": 100
        }

    async def start_streaming(self):
        """Start the background streaming task unless it is already running."""
        if self.is_streaming:
            logger.warning("Pose streaming already active")
            return

        self.is_streaming = True
        self.stream_task = asyncio.create_task(self._stream_loop())
        logger.info("Pose streaming started")

    async def stop_streaming(self):
        """Stop pose data streaming and wait for the streaming task to finish.

        Also reaps a task whose loop already exited on its own (in which case
        ``is_streaming`` is already False), so ``stream_task`` is always
        cleared afterwards.
        """
        if not self.is_streaming and self.stream_task is None:
            return

        self.is_streaming = False

        task, self.stream_task = self.stream_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        logger.info("Pose streaming stopped")

    async def _stream_loop(self):
        """Poll pose data and broadcast it at roughly ``stream_config["fps"]``.

        Errors inside an iteration are logged and followed by a one-second
        back-off; the loop runs until ``is_streaming`` is cleared or the task
        is cancelled. ``is_streaming`` is reset to False whenever the loop exits.
        """
        loop = asyncio.get_running_loop()
        try:
            while self.is_streaming:
                iteration_start = loop.time()
                try:
                    # Get current pose data from all zones
                    pose_data = await self.pose_service.get_current_pose_data()

                    if pose_data:
                        await self._process_and_broadcast_pose_data(pose_data)

                    # Sleep only for what remains of the frame interval, so the
                    # time spent fetching/broadcasting does not lower the
                    # effective rate. fps is re-read each frame so config
                    # updates take effect immediately.
                    frame_interval = 1.0 / self.stream_config["fps"]
                    elapsed = loop.time() - iteration_start
                    await asyncio.sleep(max(0.0, frame_interval - elapsed))

                except Exception as e:
                    logger.error(f"Error in pose streaming loop: {e}")
                    await asyncio.sleep(1.0)  # Brief pause on error

        except asyncio.CancelledError:
            logger.info("Pose streaming loop cancelled")
        except Exception as e:
            logger.error(f"Fatal error in pose streaming loop: {e}")
        finally:
            self.is_streaming = False

    async def _process_and_broadcast_pose_data(self, raw_pose_data: Dict[str, Any]):
        """Convert raw per-zone data into ``PoseStreamData`` and broadcast it.

        Args:
            raw_pose_data: Mapping of zone id to that zone's data. Each zone
                value is read via ``.get("pose")``, ``.get("confidence")``,
                ``.get("activity")`` and ``.get("metadata")``.

        Empty zones and zones below ``stream_config["min_confidence"]`` are
        skipped. A zone whose data cannot be converted (for example a
        confidence outside [0, 1], which ``PoseStreamData`` rejects) is logged
        and skipped without affecting the remaining zones.
        """
        try:
            # Snapshot the items: broadcasting awaits between zones, and the
            # mapping presumably may be updated meanwhile by the pose service.
            zone_items = list(raw_pose_data.items())
        except Exception as e:
            logger.error(f"Error processing pose data: {e}")
            return

        for zone_id, zone_data in zone_items:
            if not zone_data:
                continue

            try:
                # Create structured pose data
                pose_stream_data = PoseStreamData(
                    timestamp=datetime.utcnow(),
                    zone_id=zone_id,
                    pose_data=zone_data.get("pose", {}),
                    confidence=zone_data.get("confidence", 0.0),
                    activity=zone_data.get("activity"),
                    metadata=zone_data.get("metadata") if self.stream_config["include_metadata"] else None
                )
            except Exception as e:
                logger.error(f"Error processing pose data for zone {zone_id}: {e}")
                continue

            # Filter by minimum confidence
            if pose_stream_data.confidence < self.stream_config["min_confidence"]:
                continue

            # _broadcast_pose_data logs its own failures, so a failed broadcast
            # for one zone does not stop the others.
            await self._broadcast_pose_data(pose_stream_data)

    async def _broadcast_pose_data(self, pose_data: PoseStreamData):
        """Broadcast one zone's pose data to "pose" stream clients of that zone.

        Failures are logged, not raised.
        """
        try:
            # Prepare broadcast data
            broadcast_data = {
                "type": "pose_data",
                "timestamp": pose_data.timestamp.isoformat(),
                "zone_id": pose_data.zone_id,
                "data": {
                    "pose": pose_data.pose_data,
                    "confidence": pose_data.confidence,
                    "activity": pose_data.activity
                }
            }

            # Add metadata if enabled
            if pose_data.metadata and self.stream_config["include_metadata"]:
                broadcast_data["metadata"] = pose_data.metadata

            # Broadcast to pose stream subscribers
            sent_count = await self.connection_manager.broadcast(
                data=broadcast_data,
                stream_type="pose",
                zone_ids=[pose_data.zone_id]
            )

            if sent_count > 0:
                logger.debug(f"Broadcasted pose data for zone {pose_data.zone_id} to {sent_count} clients")

        except Exception as e:
            logger.error(f"Error broadcasting pose data: {e}")

    async def handle_client_subscription(
        self,
        client_id: str,
        subscription_config: Dict[str, Any]
    ):
        """Record a client's subscription settings and confirm them to the client.

        Missing settings fall back to defaults (all zones, min_confidence 0.5,
        max_fps 30, metadata included, ``["pose_data"]`` stream types). The
        confirmation echoes the stored settings, with ``subscribed_at``
        rendered as an ISO-8601 string so the message is JSON-serializable.
        Failures are logged, not raised.
        """
        try:
            # Store client subscription config
            subscription = {
                "zone_ids": subscription_config.get("zone_ids", []),
                "min_confidence": subscription_config.get("min_confidence", 0.5),
                "max_fps": subscription_config.get("max_fps", 30),
                "include_metadata": subscription_config.get("include_metadata", True),
                "stream_types": subscription_config.get("stream_types", ["pose_data"]),
                "subscribed_at": datetime.utcnow()
            }
            self.subscribers[client_id] = subscription

            logger.info(f"Updated subscription for client {client_id}")

            # Send confirmation. The stored datetime is kept (get_stream_status
            # formats it), but the outgoing copy must carry a string: a raw
            # datetime is not JSON-serializable.
            confirmation = {
                "type": "subscription_updated",
                "client_id": client_id,
                "config": {
                    **subscription,
                    "subscribed_at": subscription["subscribed_at"].isoformat()
                },
                "timestamp": datetime.utcnow().isoformat()
            }

            await self.connection_manager.send_to_client(client_id, confirmation)

        except Exception as e:
            logger.error(f"Error handling client subscription: {e}")

    async def handle_client_disconnect(self, client_id: str):
        """Forget the subscription settings of a disconnected client, if any."""
        if self.subscribers.pop(client_id, None) is not None:
            logger.info(f"Removed subscription for disconnected client {client_id}")

    async def send_historical_data(
        self,
        client_id: str,
        zone_id: str,
        start_time: datetime,
        end_time: datetime,
        limit: int = 100
    ):
        """Send historical pose data for a zone to one client, in chunks.

        Records from ``pose_service.get_historical_data`` are sent as
        ``historical_data`` messages of up to 10 records each, followed by a
        ``historical_data_complete`` message. A ``None`` result is treated as
        no records. On failure an ``error`` message is sent to the client on a
        best-effort basis; if that also fails, it is only logged.
        """
        try:
            # Get historical data from pose service
            historical_data = await self.pose_service.get_historical_data(
                zone_id=zone_id,
                start_time=start_time,
                end_time=end_time,
                limit=limit
            )
            # Treat "no data" (None) as an empty result rather than failing on len().
            records = list(historical_data) if historical_data is not None else []

            # Send data in chunks to avoid overwhelming the client
            chunk_size = 10
            total_chunks = (len(records) + chunk_size - 1) // chunk_size
            for chunk_index, start in enumerate(range(0, len(records), chunk_size)):
                message = {
                    "type": "historical_data",
                    "zone_id": zone_id,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "data": records[start:start + chunk_size],
                    "timestamp": datetime.utcnow().isoformat()
                }

                await self.connection_manager.send_to_client(client_id, message)

                # Small delay between chunks (none needed after the last one)
                if chunk_index < total_chunks - 1:
                    await asyncio.sleep(0.1)

            # Send completion message
            completion_message = {
                "type": "historical_data_complete",
                "zone_id": zone_id,
                "total_records": len(records),
                "timestamp": datetime.utcnow().isoformat()
            }

            await self.connection_manager.send_to_client(client_id, completion_message)

        except Exception as e:
            logger.error(f"Error sending historical data: {e}")

            # Send error message to client
            error_message = {
                "type": "error",
                "message": f"Failed to retrieve historical data: {str(e)}",
                "timestamp": datetime.utcnow().isoformat()
            }

            # The client may already be unreachable; don't let the error
            # report itself escape from this handler.
            try:
                await self.connection_manager.send_to_client(client_id, error_message)
            except Exception as send_error:
                logger.error(f"Error sending historical data error to client {client_id}: {send_error}")

    async def send_zone_statistics(self, client_id: str, zone_id: str):
        """Send ``pose_service`` statistics for one zone to one client.

        Failures are logged, not raised.
        """
        try:
            # Get zone statistics
            stats = await self.pose_service.get_zone_statistics(zone_id)

            message = {
                "type": "zone_statistics",
                "zone_id": zone_id,
                "statistics": stats,
                "timestamp": datetime.utcnow().isoformat()
            }

            await self.connection_manager.send_to_client(client_id, message)

        except Exception as e:
            logger.error(f"Error sending zone statistics: {e}")

    async def broadcast_system_event(self, event_type: str, event_data: Dict[str, Any]):
        """Broadcast a ``system_event`` message to all "pose" stream clients.

        Failures are logged, not raised.
        """
        try:
            message = {
                "type": "system_event",
                "event_type": event_type,
                "data": event_data,
                "timestamp": datetime.utcnow().isoformat()
            }

            # Broadcast to all pose stream clients
            sent_count = await self.connection_manager.broadcast(
                data=message,
                stream_type="pose"
            )

            logger.info(f"Broadcasted system event '{event_type}' to {sent_count} clients")

        except Exception as e:
            logger.error(f"Error broadcasting system event: {e}")

    async def update_stream_config(self, config: Dict[str, Any]):
        """Update the handler-wide streaming configuration.

        Recognised keys are clamped to safe ranges: ``fps`` to [1, 60],
        ``min_confidence`` to [0.0, 1.0], ``buffer_size`` to [10, 1000];
        ``include_metadata`` is coerced to bool. Unknown keys are ignored.
        All values are validated before any are applied, so an invalid value
        leaves the configuration unchanged. On success the new configuration
        is broadcast as a ``stream_config_updated`` system event. Failures are
        logged, not raised.
        """
        try:
            # Validate every field first, then apply them together.
            updates: Dict[str, Any] = {}
            if "fps" in config:
                updates["fps"] = max(1, min(60, config["fps"]))

            if "min_confidence" in config:
                updates["min_confidence"] = max(0.0, min(1.0, config["min_confidence"]))

            if "include_metadata" in config:
                updates["include_metadata"] = bool(config["include_metadata"])

            if "buffer_size" in config:
                updates["buffer_size"] = max(10, min(1000, config["buffer_size"]))

            self.stream_config.update(updates)

            logger.info(f"Updated stream configuration: {self.stream_config}")

            # Broadcast a snapshot so the event payload is decoupled from the
            # live configuration dict.
            await self.broadcast_system_event("stream_config_updated", {
                "new_config": dict(self.stream_config)
            })

        except Exception as e:
            logger.error(f"Error updating stream configuration: {e}")

    def get_stream_status(self) -> Dict[str, Any]:
        """Return a snapshot of the streaming state and client subscriptions.

        The returned ``config`` is a copy, so callers cannot mutate the
        handler's configuration through it.
        """
        return {
            "is_streaming": self.is_streaming,
            "config": dict(self.stream_config),
            "subscriber_count": len(self.subscribers),
            "subscribers": {
                client_id: {
                    "zone_ids": sub["zone_ids"],
                    "min_confidence": sub["min_confidence"],
                    "subscribed_at": sub["subscribed_at"].isoformat()
                }
                for client_id, sub in self.subscribers.items()
            }
        }

    async def get_performance_metrics(self) -> Dict[str, Any]:
        """Combine streaming, connection and pose-service metrics.

        Returns an empty dict (after logging) if any metrics source fails.
        """
        try:
            # Get connection manager metrics
            conn_metrics = await self.connection_manager.get_metrics()

            # Get pose service metrics
            pose_metrics = await self.pose_service.get_performance_metrics()

            return {
                "streaming": {
                    "is_active": self.is_streaming,
                    "fps": self.stream_config["fps"],
                    "subscriber_count": len(self.subscribers)
                },
                "connections": conn_metrics,
                "pose_service": pose_metrics,
                "timestamp": datetime.utcnow().isoformat()
            }

        except Exception as e:
            logger.error(f"Error getting performance metrics: {e}")
            return {}

    async def shutdown(self):
        """Stop streaming and drop all recorded client subscriptions."""
        await self.stop_streaming()
        self.subscribers.clear()
        logger.info("Pose stream handler shutdown complete")