ria-toolkit-oss/tests/agent/test_ws_client.py
ben 22b035dbee
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Has been cancelled
Test with tox / Test with tox (3.10) (pull_request) Has been cancelled
Test with tox / Test with tox (3.11) (pull_request) Has been cancelled
Test with tox / Test with tox (3.12) (pull_request) Has been cancelled
Build Project / Build Project (3.12) (pull_request) Has been cancelled
Build Project / Build Project (3.11) (pull_request) Has been cancelled
Build Project / Build Project (3.10) (pull_request) Has been cancelled
format fixes
2026-04-20 13:51:15 -04:00

163 lines
4.7 KiB
Python

"""Reconnect + heartbeat + malformed-control-frame behavior.
Binary-frame delivery lives in ``test_ws_client_binary.py`` to match the
test matrix spelled out in ``Agent TX Streaming Handoff.md`` §A7.
"""
from __future__ import annotations
import asyncio
import json
import websockets
from ria_toolkit_oss.agent.ws_client import WsClient
async def _recv_json(ws) -> dict:
    """Await the next frame from *ws* and return it decoded with ``json.loads``."""
    return json.loads(await ws.recv())
async def _open_server(handler):
    """Start a websockets server for *handler* on the IPv4 loopback address.

    Port 0 asks the OS for a free ephemeral port; the port that was actually
    bound is read back from the first listening socket.

    Returns:
        ``(server, port)`` — the running server and its bound port number.
    """
    host, requested_port = "127.0.0.1", 0
    server = await websockets.serve(handler, host, requested_port)
    listening_socket = server.sockets[0]
    bound_port = listening_socket.getsockname()[1]
    return server, bound_port
def test_heartbeat_sent_on_connect():
    """The first frame the client sends after connecting is a heartbeat."""

    async def scenario():
        received: list[dict] = []
        connected = asyncio.Event()
        got_frame = asyncio.Event()

        async def handler(ws):
            connected.set()
            msg = await _recv_json(ws)
            received.append(msg)
            # Wake the scenario as soon as the first frame is recorded,
            # rather than having it poll ``received`` on a fixed budget.
            got_frame.set()

        server, port = await _open_server(handler)
        task = None
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=0.05,
                reconnect_pause=0.05,
            )
            task = asyncio.create_task(
                client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1})
            )
            await asyncio.wait_for(connected.wait(), timeout=2.0)
            await asyncio.wait_for(got_frame.wait(), timeout=2.0)
        finally:
            # Stop the client even if a wait above timed out, so its reconnect
            # loop is not left running against the server being closed below.
            if task is not None:
                client.stop()
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):
                    # Shutdown is best-effort; the assertion below is the check.
                    pass
            server.close()
            await server.wait_closed()
        return received

    received = asyncio.run(scenario())
    assert received and received[0]["type"] == "heartbeat"
def test_reconnects_after_server_drop():
    """After the server closes the first connection, the client connects again."""

    async def scenario():
        connections = 0
        first_dropped = asyncio.Event()
        reconnected = asyncio.Event()

        async def handler(ws):
            nonlocal connections
            connections += 1
            if connections == 1:
                # Drop the first connection immediately to force a reconnect.
                await ws.close()
                first_dropped.set()
                return
            # Second (or later) connection: signal the scenario instead of
            # having it poll the ``connections`` counter.
            reconnected.set()
            try:
                # Hold the connection until the client sends a frame or goes away.
                await ws.recv()
            except Exception:
                pass

        server, port = await _open_server(handler)
        task = None
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )
            task = asyncio.create_task(
                client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"})
            )
            await asyncio.wait_for(first_dropped.wait(), timeout=2.0)
            await asyncio.wait_for(reconnected.wait(), timeout=2.0)
        finally:
            # Stop the client even if a wait above timed out, so its reconnect
            # loop is not left running against the server being closed below.
            if task is not None:
                client.stop()
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):
                    # Shutdown is best-effort; the assertion below is the check.
                    pass
            server.close()
            await server.wait_closed()
        return connections

    n = asyncio.run(scenario())
    assert n >= 2
def test_malformed_control_frame_does_not_crash():
    """A non-JSON frame is not delivered, and the next valid frame still is."""

    async def scenario():
        handled: list[dict] = []
        delivered = asyncio.Event()

        async def handler(ws):
            # Malformed frame first, then a well-formed control frame.
            await ws.send("not json")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        task = None
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_msg(m):
                handled.append(m)
                # Wake the scenario on first delivery instead of polling.
                delivered.set()

            task = asyncio.create_task(client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}))
            await asyncio.wait_for(delivered.wait(), timeout=2.0)
        finally:
            # Stop the client even if the wait above timed out, so its reconnect
            # loop is not left running against the server being closed below.
            if task is not None:
                client.stop()
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):
                    # Shutdown is best-effort; the assertion below is the check.
                    pass
            server.close()
            await server.wait_closed()
        return handled

    handled = asyncio.run(scenario())
    assert handled and handled[0] == {"type": "ping"}