"""Persistent WebSocket client for the streamer agent.

Handles the connection lifecycle: connect, heartbeat, auto-reconnect on drop.
The caller drives the I/O loop via ``run()`` with a message handler callback
and ends it with ``stop()``.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Awaitable, Callable

logger = logging.getLogger("ria_agent.ws")

# Awaited with each text frame after it has been decoded with ``json.loads``.
MessageHandler = Callable[[dict], Awaitable[None]]
# Called before each heartbeat; its return value is sent as JSON.
HeartbeatBuilder = Callable[[], dict]
# Awaited with the raw payload of each binary frame.
BinaryHandler = Callable[[bytes], Awaitable[None]]


class WsClient:
    """Persistent WebSocket connection with heartbeat and auto-reconnect.

    ``url`` should be a full ``wss://host/path`` (or ``ws://``) URL.
    ``token`` is sent as a bearer in the ``Authorization`` header on connect;
    an empty token sends no extra headers.
    """

    def __init__(
        self,
        url: str,
        token: str,
        *,
        heartbeat_interval: float = 30.0,
        reconnect_pause: float = 5.0,
    ) -> None:
        self.url = url
        self.token = token
        self.heartbeat_interval = heartbeat_interval  # seconds between heartbeats
        self.reconnect_pause = reconnect_pause  # seconds to wait before reconnecting
        self._ws = None  # live connection, or None while disconnected
        self._stop = asyncio.Event()

    # ------------------------------------------------------------------

    async def _connect(self):
        """Open a new connection, sending the bearer token if one is set."""
        import websockets

        headers = [("Authorization", f"Bearer {self.token}")] if self.token else None
        # websockets' newer asyncio implementation takes ``additional_headers``;
        # the legacy implementation only knows ``extra_headers`` and raises
        # TypeError on the unknown keyword, hence the fallback.
        try:
            return await websockets.connect(self.url, additional_headers=headers)
        except TypeError:
            return await websockets.connect(self.url, extra_headers=headers)

    # ------------------------------------------------------------------

    async def send_json(self, payload: dict) -> None:
        """Serialize ``payload`` with ``json.dumps`` and send it as a text frame.

        Raises:
            ConnectionError: if no connection is currently open.
        """
        if self._ws is None:
            raise ConnectionError("WebSocket is not connected")
        await self._ws.send(json.dumps(payload))

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary frame.

        Raises:
            ConnectionError: if no connection is currently open.
        """
        if self._ws is None:
            raise ConnectionError("WebSocket is not connected")
        await self._ws.send(data)

    def stop(self) -> None:
        """Ask ``run()`` to exit.

        ``run()`` reacts promptly: the live connection is closed (ending a
        blocked receive) and any pending reconnect pause is cut short. Call it
        from the event loop's thread; ``asyncio.Event`` is not thread-safe.
        """
        self._stop.set()

    # ------------------------------------------------------------------

    async def run(
        self,
        on_message: MessageHandler,
        heartbeat: HeartbeatBuilder,
        on_binary: BinaryHandler | None = None,
    ) -> None:
        """Main loop: connect, heartbeat, dispatch messages, reconnect on drop.

        Returns once ``stop()`` has been called. Connection errors, and
        exceptions raised by ``on_message``, are logged as warnings and are
        followed by a reconnect after ``reconnect_pause`` seconds. Cancelling
        the task running this coroutine always propagates ``CancelledError``.
        """
        while not self._stop.is_set():
            try:
                await self._run_session(on_message, heartbeat, on_binary)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                if self._stop.is_set():
                    break
                logger.warning(
                    "WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause
                )
            # Both the exit check and the pause live outside any ``finally`` so
            # that an in-flight cancellation is neither swallowed nor delayed.
            if self._stop.is_set():
                break
            await self._pause_unless_stopped(self.reconnect_pause)

    async def _run_session(
        self,
        on_message: MessageHandler,
        heartbeat: HeartbeatBuilder,
        on_binary: BinaryHandler | None,
    ) -> None:
        """Run one connection until it closes, fails, or ``stop()`` is called."""
        ws = await self._connect()
        self._ws = ws
        try:
            logger.info("Connected to %s", self.url)
            helpers = [
                asyncio.create_task(self._heartbeat_loop(heartbeat)),
                # Closing the socket on stop() ends the receive loop even when
                # the server is silent.
                asyncio.create_task(self._close_on_stop(ws)),
            ]
            try:
                await self._dispatch(ws, on_message, on_binary)
            finally:
                for task in helpers:
                    task.cancel()
                # asyncio.wait does not re-raise the helpers' CancelledError,
                # and it lets a cancellation aimed at *this* task propagate.
                await asyncio.wait(helpers)
        finally:
            await self._close_quietly()

    async def _dispatch(
        self,
        ws,
        on_message: MessageHandler,
        on_binary: BinaryHandler | None,
    ) -> None:
        """Route each incoming frame to its handler until ``ws`` stops yielding."""
        async for raw in ws:
            if isinstance(raw, bytes):
                if on_binary is None:
                    logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
                    continue
                try:
                    await on_binary(raw)
                except Exception:
                    logger.exception("on_binary handler raised; dropping frame")
                continue
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning("Malformed control frame: %r", raw[:200])
                continue
            # NOTE(review): unlike on_binary, an exception from on_message is not
            # caught here; it ends the session and triggers a reconnect. Kept
            # as-is in case callers rely on it — confirm this is intended.
            await on_message(msg)

    async def _close_on_stop(self, ws) -> None:
        """Wait for ``stop()``, then close ``ws`` so a pending receive returns."""
        await self._stop.wait()
        try:
            await ws.close()
        except Exception as exc:
            logger.debug("Close on stop failed: %s", exc)

    async def _close_quietly(self) -> None:
        """Close the current connection (if any), ignoring errors, and forget it."""
        # Clear first so send_json/send_bytes fail fast while the close runs.
        ws, self._ws = self._ws, None
        if ws is None:
            return
        try:
            await ws.close()
        except Exception:
            pass  # best effort: the connection is being discarded anyway

    async def _pause_unless_stopped(self, delay: float) -> None:
        """Sleep up to ``delay`` seconds, returning early if ``stop()`` is called."""
        try:
            await asyncio.wait_for(self._stop.wait(), timeout=delay)
        except asyncio.TimeoutError:
            pass

    async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
        """Send ``heartbeat()`` now and then every ``heartbeat_interval`` seconds.

        Returns on the first failure, whether building or sending the payload.
        Detecting a dead connection is left to the receive loop.
        """
        while True:
            try:
                await self.send_json(heartbeat())
            except Exception as exc:
                logger.debug("Heartbeat send failed: %s", exc)
                return
            await asyncio.sleep(self.heartbeat_interval)