"""Reconnect + heartbeat timing against a real local websockets server.""" from __future__ import annotations import asyncio import json import pytest import websockets from ria_toolkit_oss.agent.ws_client import WsClient async def _recv_json(ws) -> dict: raw = await ws.recv() return json.loads(raw) async def _open_server(handler): # websockets 13 ignores extra positional args; bind to localhost:0 for an # ephemeral port and return both the server and the port. server = await websockets.serve(handler, "127.0.0.1", 0) port = server.sockets[0].getsockname()[1] return server, port def test_heartbeat_sent_on_connect(): async def scenario(): received: list[dict] = [] connected = asyncio.Event() async def handler(ws): connected.set() msg = await _recv_json(ws) received.append(msg) server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=0.05, reconnect_pause=0.05, ) task = asyncio.create_task( client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1}) ) await asyncio.wait_for(connected.wait(), timeout=2.0) for _ in range(50): if received: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return received received = asyncio.run(scenario()) assert received and received[0]["type"] == "heartbeat" def test_reconnects_after_server_drop(): async def scenario(): connections = 0 first_dropped = asyncio.Event() async def handler(ws): nonlocal connections connections += 1 if connections == 1: await ws.close() first_dropped.set() else: try: await ws.recv() except Exception: pass server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=10.0, reconnect_pause=0.05, ) task = asyncio.create_task( client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"}) ) await asyncio.wait_for(first_dropped.wait(), timeout=2.0) for _ in range(100): if connections >= 2: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return connections n = asyncio.run(scenario()) assert n >= 2 def test_malformed_control_frame_does_not_crash(): async def scenario(): handled: list[dict] = [] done = asyncio.Event() async def handler(ws): await ws.send("not json") await ws.send(json.dumps({"type": "ping"})) done.set() try: await ws.wait_closed() except Exception: pass server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=10.0, reconnect_pause=0.05, ) async def on_msg(m): handled.append(m) task = asyncio.create_task( client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) ) for _ in range(50): if handled: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return handled handled = asyncio.run(scenario()) assert handled and handled[0] == {"type": "ping"}