Compare commits
10 Commits
ria-agent-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e7ccf506d | |||
| 0b75013653 | |||
| 5a67b39d22 | |||
| 0b736612ec | |||
| 3dbbe1721d | |||
| 0fd9550977 | |||
| 6a106fe2c4 | |||
| 4a71dcf6c2 | |||
| 1b6d79f65c | |||
| 07fc871463 |
12
CHANGELOG.md
12
CHANGELOG.md
|
|
@ -1,17 +1,5 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
## [0.1.8] - 2026-06-01
|
|
||||||
|
|
||||||
### Changed
|
|
||||||
|
|
||||||
- **`ria-agent register --hub` now defaults to `https://riahub.ai`** — most users can run `ria-agent register --api-key ria_reg_...` without the `--hub` flag. Dev and self-hosted users keep the existing override (`--hub http://my-hub:3005`). The default lives in `ria_toolkit_oss.agent.cli.DEFAULT_HUB_URL`.
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- **`websockets` is now a runtime dependency** — previously declared only in the optional `agent` poetry group, so a vanilla `pip install ria-toolkit-oss` left `ria-agent stream` failing with `ModuleNotFoundError: No module named 'websockets'`. Added to `[project].dependencies` with the same constraint (`>=12.0,<14.0`).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [0.1.7] - 2026-05-26
|
## [0.1.7] - 2026-05-26
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
||||||
project = 'ria-toolkit-oss'
|
project = 'ria-toolkit-oss'
|
||||||
copyright = '2026, Qoherent Inc'
|
copyright = '2026, Qoherent Inc'
|
||||||
author = 'Qoherent Inc.'
|
author = 'Qoherent Inc.'
|
||||||
release = '0.1.8'
|
release = '0.1.7'
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||||
|
|
|
||||||
1228
poetry.lock
generated
1228
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "ria-toolkit-oss"
|
name = "ria-toolkit-oss"
|
||||||
version = "0.1.8"
|
version = "0.1.7"
|
||||||
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
||||||
license = { text = "AGPL-3.0-only" }
|
license = { text = "AGPL-3.0-only" }
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
@ -50,8 +50,7 @@ dependencies = [
|
||||||
"pyyaml (>=6.0.3,<7.0.0)",
|
"pyyaml (>=6.0.3,<7.0.0)",
|
||||||
"click (>=8.1.0,<9.0.0)",
|
"click (>=8.1.0,<9.0.0)",
|
||||||
"matplotlib (>=3.8.0,<4.0.0)",
|
"matplotlib (>=3.8.0,<4.0.0)",
|
||||||
"paramiko (>=3.5.1)",
|
"paramiko (>=3.5.1)"
|
||||||
"websockets (>=12.0,<14.0)"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
||||||
|
|
@ -78,7 +77,6 @@ packages = [
|
||||||
]
|
]
|
||||||
include = [
|
include = [
|
||||||
"**/*.so", # Required for Nuitkaification
|
"**/*.so", # Required for Nuitkaification
|
||||||
"src/ria_toolkit_oss/agent/udev/*.rules", # Shipped SDR udev rules (ria-agent install-udev)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,11 @@ Subcommands:
|
||||||
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
||||||
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
||||||
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
||||||
- ``ria-agent register --api-key KEY`` — register with the production hub
|
- ``ria-agent register --hub URL --api-key KEY`` — register with the hub
|
||||||
(``https://riahub.ai`` by default; override with ``--hub URL`` for dev
|
using a personal registration key (minted from **Settings → RIA Agents**
|
||||||
or self-hosted) using a personal registration key (minted from
|
on the hub, shown once at mint time) and save credentials (and optional
|
||||||
**Settings → RIA Agents** on the hub, shown once at mint time) and save
|
TX interlocks) to ``~/.ria/agent.json``. The hub also accepts the legacy
|
||||||
credentials (and optional TX interlocks) to ``~/.ria/agent.json``. The
|
shared ``[wac] API_KEY`` for back-compat, but that path is deprecated.
|
||||||
hub also accepts the legacy shared ``[wac] API_KEY`` for back-compat,
|
|
||||||
but that path is deprecated.
|
|
||||||
|
|
||||||
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
||||||
long-poll behavior for back-compatibility with existing deployments.
|
long-poll behavior for back-compatibility with existing deployments.
|
||||||
|
|
@ -30,6 +28,8 @@ from .hardware import available_devices
|
||||||
from .legacy_executor import main as _legacy_main
|
from .legacy_executor import main as _legacy_main
|
||||||
from .namegen import generate_agent_name
|
from .namegen import generate_agent_name
|
||||||
|
|
||||||
|
DEFAULT_HUB_URL = "https://riahub.ai"
|
||||||
|
|
||||||
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,24 +54,11 @@ def _user_agent() -> str:
|
||||||
# small DB lookup + insert; anything past this is a stuck hub, not a slow one.
|
# small DB lookup + insert; anything past this is a stuck hub, not a slow one.
|
||||||
_REGISTER_TIMEOUT_S = 15
|
_REGISTER_TIMEOUT_S = 15
|
||||||
|
|
||||||
# Production hub URL — used as the default for `ria-agent register` so most
|
|
||||||
# users don't need to pass --hub. Dev / self-hosted users override explicitly.
|
|
||||||
DEFAULT_HUB_URL = "https://riahub.ai"
|
|
||||||
|
|
||||||
|
|
||||||
REGISTRATION_REASON_MESSAGES = {
|
REGISTRATION_REASON_MESSAGES = {
|
||||||
"invalid_key": (
|
"invalid_key": ("Registration key not recognized. Generate a fresh key from " "Settings → RIA Agents on the hub."),
|
||||||
"Registration key not recognized. Generate a fresh key from "
|
"expired": ("This registration key has expired. Generate a new one from " "Settings → RIA Agents on the hub."),
|
||||||
"Settings → RIA Agents on the hub."
|
"revoked": ("This registration key was revoked. Generate a new one from " "Settings → RIA Agents on the hub."),
|
||||||
),
|
|
||||||
"expired": (
|
|
||||||
"This registration key has expired. Generate a new one from "
|
|
||||||
"Settings → RIA Agents on the hub."
|
|
||||||
),
|
|
||||||
"revoked": (
|
|
||||||
"This registration key was revoked. Generate a new one from "
|
|
||||||
"Settings → RIA Agents on the hub."
|
|
||||||
),
|
|
||||||
"already_consumed": (
|
"already_consumed": (
|
||||||
"This single-use registration key has already been used. "
|
"This single-use registration key has already been used. "
|
||||||
"Generate a new one, or mint a reusable key instead."
|
"Generate a new one, or mint a reusable key instead."
|
||||||
|
|
@ -205,71 +192,6 @@ def _cmd_stream(args: argparse.Namespace) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
_UDEV_RULES_NAME = "90-ria-sdr.rules"
|
|
||||||
|
|
||||||
|
|
||||||
def _cmd_install_udev(args: argparse.Namespace) -> int:
|
|
||||||
"""Install the bundled SDR udev rules so USB radios open without sudo.
|
|
||||||
|
|
||||||
This is the one OS-level step needed for USB SDRs (B2x0 / RTL-SDR / HackRF /
|
|
||||||
bladeRF). It ships inside ria-toolkit-oss — no separate tool to install — but
|
|
||||||
writing to ``/etc/udev/rules.d`` and reloading udev requires root, so run it
|
|
||||||
once with sudo. Network radios (Pluto/USRP over IP) need nothing here.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
from importlib.resources import files
|
|
||||||
|
|
||||||
try:
|
|
||||||
src = files("ria_toolkit_oss.agent").joinpath("udev", _UDEV_RULES_NAME)
|
|
||||||
rules_text = src.read_text()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"error: bundled udev rules not found: {e}", file=sys.stderr)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
dest_dir = args.dest
|
|
||||||
dest = os.path.join(dest_dir, _UDEV_RULES_NAME)
|
|
||||||
|
|
||||||
if os.geteuid() != 0:
|
|
||||||
print(
|
|
||||||
"error: installing udev rules requires root.\n"
|
|
||||||
f" run once: sudo {os.path.basename(sys.argv[0])} install-udev",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.makedirs(dest_dir, exist_ok=True)
|
|
||||||
with open(dest, "w") as f:
|
|
||||||
f.write(rules_text)
|
|
||||||
print(f"Installed udev rules -> {dest}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"error: failed to write {dest}: {e}", file=sys.stderr)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
if not args.no_reload and shutil.which("udevadm"):
|
|
||||||
for cmd in (["udevadm", "control", "--reload-rules"], ["udevadm", "trigger"]):
|
|
||||||
try:
|
|
||||||
subprocess.run(cmd, check=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"warning: '{' '.join(cmd)}' failed: {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Add the invoking (pre-sudo) user to the access group so group-based rules
|
|
||||||
# apply even without a local logind session.
|
|
||||||
target_user = os.environ.get("SUDO_USER") or ""
|
|
||||||
if target_user and shutil.which("usermod"):
|
|
||||||
try:
|
|
||||||
subprocess.run(["usermod", "-aG", args.group, target_user], check=True)
|
|
||||||
print(f"Added user '{target_user}' to group '{args.group}'.")
|
|
||||||
print(f"Log out and back in (or run 'newgrp {args.group}') for the group to take effect.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"warning: could not add '{target_user}' to '{args.group}': {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
print("Done. Unplug and replug your USB SDR, then run `ria-agent stream`.")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _derive_ws_url(hub_url: str, agent_id: str) -> str:
|
def _derive_ws_url(hub_url: str, agent_id: str) -> str:
|
||||||
if not hub_url:
|
if not hub_url:
|
||||||
return ""
|
return ""
|
||||||
|
|
@ -296,35 +218,8 @@ def main() -> None:
|
||||||
sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)")
|
sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)")
|
||||||
sub.add_parser("detect", help="List available SDR drivers")
|
sub.add_parser("detect", help="List available SDR drivers")
|
||||||
|
|
||||||
p_udev = sub.add_parser(
|
|
||||||
"install-udev",
|
|
||||||
help="Install SDR udev rules so USB radios open without sudo (run once, with sudo)",
|
|
||||||
)
|
|
||||||
p_udev.add_argument(
|
|
||||||
"--dest",
|
|
||||||
default="/etc/udev/rules.d",
|
|
||||||
help="Directory to install the rules file into (default: /etc/udev/rules.d)",
|
|
||||||
)
|
|
||||||
p_udev.add_argument(
|
|
||||||
"--group",
|
|
||||||
default="plugdev",
|
|
||||||
help="Group granted device access; the invoking user is added to it (default: plugdev)",
|
|
||||||
)
|
|
||||||
p_udev.add_argument(
|
|
||||||
"--no-reload",
|
|
||||||
action="store_true",
|
|
||||||
help="Skip 'udevadm control --reload-rules' / 'udevadm trigger'",
|
|
||||||
)
|
|
||||||
|
|
||||||
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
||||||
p_reg.add_argument(
|
p_reg.add_argument("--hub", default=DEFAULT_HUB_URL, help=f"RIA Hub URL (default: {DEFAULT_HUB_URL})")
|
||||||
"--hub",
|
|
||||||
default=DEFAULT_HUB_URL,
|
|
||||||
help=(
|
|
||||||
f"RIA Hub URL (default: {DEFAULT_HUB_URL}). "
|
|
||||||
"Override for dev or self-hosted hubs, e.g. http://whitehorse:3005."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
p_reg.add_argument(
|
p_reg.add_argument(
|
||||||
"--api-key",
|
"--api-key",
|
||||||
dest="api_key",
|
dest="api_key",
|
||||||
|
|
@ -393,8 +288,6 @@ def main() -> None:
|
||||||
return
|
return
|
||||||
if args.command == "detect":
|
if args.command == "detect":
|
||||||
sys.exit(_cmd_detect(args))
|
sys.exit(_cmd_detect(args))
|
||||||
if args.command == "install-udev":
|
|
||||||
sys.exit(_cmd_install_udev(args))
|
|
||||||
if args.command == "register":
|
if args.command == "register":
|
||||||
sys.exit(_cmd_register(args))
|
sys.exit(_cmd_register(args))
|
||||||
if args.command == "stream":
|
if args.command == "stream":
|
||||||
|
|
|
||||||
|
|
@ -1,181 +1,17 @@
|
||||||
"""Hardware detection and heartbeat payload construction for the streamer.
|
"""Hardware detection and heartbeat payload construction for the streamer."""
|
||||||
|
|
||||||
The heartbeat advertises a ``hardware`` list the hub uses to populate the
|
|
||||||
radio-device picker. Each entry is a dict::
|
|
||||||
|
|
||||||
{"device": "usrp", "identifier": "name=MyB210",
|
|
||||||
"label": "Ettus USRP B210 (35D7CAD)", "connected": True}
|
|
||||||
|
|
||||||
- ``device`` — driver/device-type name (``"usrp"``, ``"pluto"``, …).
|
|
||||||
- ``identifier`` — the exact addressing string this driver wants, or ``None``
|
|
||||||
to let the driver auto-select the sole device of its type.
|
|
||||||
The hub forwards this verbatim in ``radio_config`` so the
|
|
||||||
identifier is always agent-owned — never derived from the
|
|
||||||
composer graph (which is what used to leak a Pluto IP into a
|
|
||||||
USRP open). It must round-trip through
|
|
||||||
``ria_toolkit_oss_cli.common.parse_ident``: a bare value is
|
|
||||||
read as an IP address, so non-network devices use ``None`` or
|
|
||||||
``name=<value>``.
|
|
||||||
- ``label`` — human-friendly text for the hub dropdown.
|
|
||||||
- ``connected`` — ``True`` when the device was physically enumerated,
|
|
||||||
``False`` when only the driver is importable (no hardware
|
|
||||||
probed/found), ``None`` when presence is unknown.
|
|
||||||
|
|
||||||
The hub tolerates plain string entries from older agents (see
|
|
||||||
``_agent_device_names`` / ``hwName``), so this richer shape is backward
|
|
||||||
compatible.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
from ria_toolkit_oss.sdr import detect_available
|
from ria_toolkit_oss.sdr import detect_available
|
||||||
|
|
||||||
from .config import AgentConfig
|
from .config import AgentConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Human-friendly names for the hub dropdown, keyed by device-type name.
|
|
||||||
DEVICE_LABELS: dict[str, str] = {
|
|
||||||
"usrp": "Ettus USRP (UHD)",
|
|
||||||
"pluto": "ADALM-Pluto",
|
|
||||||
"rtlsdr": "RTL-SDR",
|
|
||||||
"hackrf": "HackRF One",
|
|
||||||
"blade": "BladeRF",
|
|
||||||
"thinkrf": "ThinkRF (RTSA)",
|
|
||||||
"mock": "Mock SDR (synthetic)",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Enumeration can shell out (e.g. ``uhd_find_devices``), so cache results for a
|
|
||||||
# short window rather than re-probing on every ~30s heartbeat. Hot-plug shows up
|
|
||||||
# within one TTL.
|
|
||||||
_PROBE_TTL_S = 60.0
|
|
||||||
_probe_cache: tuple[float, list[dict]] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def available_devices() -> list[str]:
|
def available_devices() -> list[str]:
|
||||||
"""Return a sorted list of device names whose driver modules import cleanly."""
|
"""Return a sorted list of device names whose driver modules import cleanly."""
|
||||||
return sorted(detect_available().keys())
|
return sorted(detect_available().keys())
|
||||||
|
|
||||||
|
|
||||||
def _label_for(device: str, suffix: str = "") -> str:
|
|
||||||
base = DEVICE_LABELS.get(device, device)
|
|
||||||
return f"{base} ({suffix})" if suffix else base
|
|
||||||
|
|
||||||
|
|
||||||
def _enumerate_usrp() -> list[dict] | None:
|
|
||||||
"""Probe for connected USRPs via ``uhd_find_devices``.
|
|
||||||
|
|
||||||
Returns a list of concrete device entries (``connected=True``), an empty
|
|
||||||
list when UHD ran but found nothing, or ``None`` when probing is not
|
|
||||||
possible (UHD/driver unavailable) so the caller can fall back to a
|
|
||||||
driver-only entry.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from ria_toolkit_oss.sdr.usrp import _parse_uhd_find_devices
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
found = _parse_uhd_find_devices() or []
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("USRP enumeration failed: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Addressing reality for the CLI get_sdr_device path: its USRP
|
|
||||||
# _create_device_dict matches the identifier against *raw* device values,
|
|
||||||
# but common.py prepends "addr="/"name=" before handing it over — so no
|
|
||||||
# prefixed identifier ever matches. The only reliable open for a USB USRP
|
|
||||||
# (B2x0) is auto-select (identifier=None → first device found). Networked
|
|
||||||
# USRPs addressed by IP would need a separate fix and aren't enumerated
|
|
||||||
# distinctly here. So advertise one auto-select entry, labelled with the
|
|
||||||
# serial(s) we saw so the operator still knows what's attached.
|
|
||||||
labels = [dev.get("serial") or dev.get("name") or dev.get("product") or "USRP" for dev in found]
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"device": "usrp",
|
|
||||||
"identifier": None,
|
|
||||||
"label": _label_for("usrp", ", ".join(labels)),
|
|
||||||
"connected": True,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Device types we can cheaply enumerate into concrete instances. Anything not
|
|
||||||
# listed is advertised as a single driver-only entry (presence unknown).
|
|
||||||
_PROBERS = {
|
|
||||||
"usrp": _enumerate_usrp,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_devices_uncached() -> list[dict]:
|
|
||||||
out: list[dict] = []
|
|
||||||
for device in available_devices():
|
|
||||||
prober = _PROBERS.get(device)
|
|
||||||
if prober is not None:
|
|
||||||
probed = prober()
|
|
||||||
if probed: # one or more concrete instances found
|
|
||||||
out.extend(probed)
|
|
||||||
continue
|
|
||||||
if probed == []: # prober ran but found no hardware
|
|
||||||
out.append(
|
|
||||||
{
|
|
||||||
"device": device,
|
|
||||||
"identifier": None,
|
|
||||||
"label": _label_for(device),
|
|
||||||
"connected": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
# probed is None — couldn't probe; fall through to unknown entry.
|
|
||||||
out.append(
|
|
||||||
{
|
|
||||||
"device": device,
|
|
||||||
"identifier": None,
|
|
||||||
"label": _label_for(device),
|
|
||||||
"connected": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _driver_only_devices() -> list[dict]:
|
|
||||||
"""Hardware list from importable drivers alone — no device probing."""
|
|
||||||
return [{"device": d, "identifier": None, "label": _label_for(d), "connected": None} for d in available_devices()]
|
|
||||||
|
|
||||||
|
|
||||||
def detect_devices(*, use_cache: bool = True, probe: bool = True) -> list[dict]:
|
|
||||||
"""Return enriched ``hardware`` entries for the heartbeat.
|
|
||||||
|
|
||||||
Results are cached for ``_PROBE_TTL_S`` seconds because enumeration may shell
|
|
||||||
out to hardware tools (e.g. ``uhd_find_devices``). Pass ``use_cache=False``
|
|
||||||
to force a fresh probe.
|
|
||||||
|
|
||||||
``probe=False`` MUST be used while a capture/transmit session is active:
|
|
||||||
probing a USB SDR (running ``uhd_find_devices``) while it is streaming
|
|
||||||
disrupts the live stream and makes the device briefly disappear. In that
|
|
||||||
case we return the last good enumeration if we have one, else a driver-only
|
|
||||||
list — never touching the hardware.
|
|
||||||
"""
|
|
||||||
global _probe_cache
|
|
||||||
now = time.monotonic()
|
|
||||||
if use_cache and _probe_cache is not None:
|
|
||||||
ts, cached = _probe_cache
|
|
||||||
if not probe or (now - ts < _PROBE_TTL_S):
|
|
||||||
return cached
|
|
||||||
if not probe:
|
|
||||||
# No cache yet and we must not touch the hardware mid-stream.
|
|
||||||
return _driver_only_devices()
|
|
||||||
devices = _detect_devices_uncached()
|
|
||||||
_probe_cache = (now, devices)
|
|
||||||
return devices
|
|
||||||
|
|
||||||
|
|
||||||
def heartbeat_payload(
|
def heartbeat_payload(
|
||||||
status: str = "idle",
|
status: str = "idle",
|
||||||
app_id: str | None = None,
|
app_id: str | None = None,
|
||||||
|
|
@ -194,11 +30,9 @@ def heartbeat_payload(
|
||||||
if c.tx_enabled:
|
if c.tx_enabled:
|
||||||
capabilities.append("tx")
|
capabilities.append("tx")
|
||||||
|
|
||||||
# Never probe the hardware while a session is active: running
|
|
||||||
# uhd_find_devices against a streaming USB SDR disrupts the live capture.
|
|
||||||
payload: dict = {
|
payload: dict = {
|
||||||
"type": "heartbeat",
|
"type": "heartbeat",
|
||||||
"hardware": detect_devices(probe=not bool(sessions)),
|
"hardware": available_devices(),
|
||||||
"status": status,
|
"status": status,
|
||||||
"capabilities": capabilities,
|
"capabilities": capabilities,
|
||||||
"tx_enabled": bool(c.tx_enabled),
|
"tx_enabled": bool(c.tx_enabled),
|
||||||
|
|
|
||||||
|
|
@ -249,19 +249,12 @@ class Streamer:
|
||||||
await self._send_error(app_id, "start missing radio_config.device")
|
await self._send_error(app_id, "start missing radio_config.device")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Open the SDR in a thread, never inline. The open is blocking and can be
|
|
||||||
# slow — a USRP shells out to uhd_find_devices and loads its FPGA, which
|
|
||||||
# takes seconds — and doing it on the event loop freezes the WebSocket
|
|
||||||
# keepalive long enough that the hub drops the agent and stops the app.
|
|
||||||
# (A Pluto opens fast enough to slip under the timeout, which is why it
|
|
||||||
# worked where a USRP hung.)
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
try:
|
try:
|
||||||
sdr, device_key = await loop.run_in_executor(None, self._registry.acquire, device, identifier)
|
sdr, device_key = self._registry.acquire(device, identifier)
|
||||||
await loop.run_in_executor(None, _apply_sdr_config, sdr, radio_config)
|
_apply_sdr_config(sdr, radio_config)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Failed to open SDR %r", device)
|
logger.exception("Failed to open SDR %r", device)
|
||||||
await self._send_error(app_id, f"SDR init failed: {_friendly_sdr_error(device, exc)}")
|
await self._send_error(app_id, f"SDR init failed: {exc}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Inherit any pending config that was queued before start.
|
# Inherit any pending config that was queued before start.
|
||||||
|
|
@ -392,51 +385,42 @@ class Streamer:
|
||||||
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
|
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Open + init the SDR in a thread, never inline — the open is blocking and
|
device_key: tuple[str, str | None] | None = None
|
||||||
# slow on a USRP (uhd_find_devices + FPGA load), and freezing the event
|
sdr: Any = None
|
||||||
# loop stalls the WebSocket keepalive until the hub drops us. Cleanup on
|
try:
|
||||||
# failure (release/close) stays inside the thread so a partial open never
|
sdr, device_key = self._registry.acquire(device, identifier)
|
||||||
# leaks a device handle.
|
_apply_sdr_config(sdr, radio_config)
|
||||||
def _open_and_init_tx() -> tuple[Any, tuple[str, str | None]]:
|
# init_tx is mandatory for any driver that exposes it: drivers
|
||||||
sdr_local, key_local = self._registry.acquire(device, identifier)
|
# that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP,
|
||||||
try:
|
# …) crash with a confusing "TX was not initialized" error 2 s
|
||||||
_apply_sdr_config(sdr_local, radio_config)
|
# later in the executor thread if we skip it. Treat the three
|
||||||
# init_tx is mandatory for any driver that exposes it: drivers
|
# required keys as a hard contract — a missing one is a hub-side
|
||||||
# that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP,
|
# manifest bug and we want it surfaced immediately, not papered
|
||||||
# …) crash with a confusing "TX was not initialized" error 2 s
|
# over with stale radio state.
|
||||||
# later in the executor thread if we skip it. Treat the three
|
if hasattr(sdr, "init_tx"):
|
||||||
# required keys as a hard contract — a missing one is a hub-side
|
init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")}
|
||||||
# manifest bug and we want it surfaced immediately, not papered
|
missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
|
||||||
# over with stale radio state.
|
if missing:
|
||||||
if hasattr(sdr_local, "init_tx"):
|
raise ValueError(f"tx_start missing required radio_config keys: {missing}")
|
||||||
init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")}
|
sdr.init_tx(
|
||||||
missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
|
sample_rate=init_args["sample_rate"],
|
||||||
if missing:
|
center_frequency=init_args["center_frequency"],
|
||||||
raise ValueError(f"tx_start missing required radio_config keys: {missing}")
|
gain=init_args["gain"],
|
||||||
sdr_local.init_tx(
|
channel=radio_config.get("tx_channel", 0),
|
||||||
sample_rate=init_args["sample_rate"],
|
gain_mode=radio_config.get("tx_gain_mode", "manual"),
|
||||||
center_frequency=init_args["center_frequency"],
|
)
|
||||||
gain=init_args["gain"],
|
except Exception as exc:
|
||||||
channel=radio_config.get("tx_channel", 0),
|
if device_key is not None:
|
||||||
gain_mode=radio_config.get("tx_gain_mode", "manual"),
|
if self._registry.release(device_key):
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
if self._registry.release(key_local):
|
|
||||||
try:
|
try:
|
||||||
sdr_local.close()
|
sdr.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
raise
|
|
||||||
return sdr_local, key_local
|
|
||||||
|
|
||||||
self._loop = asyncio.get_running_loop()
|
|
||||||
try:
|
|
||||||
sdr, device_key = await self._loop.run_in_executor(None, _open_and_init_tx)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.exception("Failed to init TX on %r", device)
|
logger.exception("Failed to init TX on %r", device)
|
||||||
await self._send_tx_status(app_id, "error", f"tx init failed: {_friendly_sdr_error(device, exc)}")
|
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._loop = asyncio.get_running_loop()
|
||||||
session = TxSession(
|
session = TxSession(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
sdr=sdr,
|
sdr=sdr,
|
||||||
|
|
@ -748,25 +732,6 @@ def _default_sdr_factory(device: str, identifier: str | None):
|
||||||
return get_sdr_device(device, ident=identifier)
|
return get_sdr_device(device, ident=identifier)
|
||||||
|
|
||||||
|
|
||||||
def _friendly_sdr_error(device: str, exc: Exception) -> str:
|
|
||||||
"""Add an actionable hint when an SDR open fails on USB permissions.
|
|
||||||
|
|
||||||
UHD/libusb surface this as 'insufficient permissions' / EACCES, which is
|
|
||||||
cryptic to operators. Point them at the one-time fix that ships with the
|
|
||||||
toolkit instead of leaving them to discover udev rules on their own.
|
|
||||||
"""
|
|
||||||
text = str(exc).lower()
|
|
||||||
permission_markers = ("insufficient permissions", "permission denied", "eacces", "access denied")
|
|
||||||
is_perm = isinstance(exc, PermissionError) or any(m in text for m in permission_markers)
|
|
||||||
if is_perm:
|
|
||||||
return (
|
|
||||||
f"{exc}\n"
|
|
||||||
f"USB permission denied opening '{device}'. Run this once, then replug the device:\n"
|
|
||||||
f" sudo ria-agent install-udev"
|
|
||||||
)
|
|
||||||
return str(exc)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Top-level entry
|
# Top-level entry
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
# RIA Toolkit SDR udev rules
|
|
||||||
#
|
|
||||||
# Grants non-root access to the USB SDRs ria-agent can drive, so `ria-agent
|
|
||||||
# stream` can open them without sudo. Installed by `ria-agent install-udev`.
|
|
||||||
#
|
|
||||||
# Access is granted two ways for portability:
|
|
||||||
# - GROUP="plugdev", MODE="0660" — classic group-based access.
|
|
||||||
# - TAG+="uaccess" — systemd-logind grants the active local
|
|
||||||
# session user access dynamically.
|
|
||||||
# A user in `plugdev` (or logged in locally) can open the device after replug.
|
|
||||||
|
|
||||||
# ADALM-Pluto (Analog Devices)
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="0456", ATTRS{idProduct}=="b673", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
|
|
||||||
# RTL-SDR (Realtek RTL2832U)
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="0bda", ATTRS{idProduct}=="2832", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="0bda", ATTRS{idProduct}=="2838", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
|
|
||||||
# HackRF (Great Scott Gadgets)
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="1d50", ATTRS{idProduct}=="6089", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="1d50", ATTRS{idProduct}=="604b", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="1d50", ATTRS{idProduct}=="cc15", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
|
|
||||||
# Ettus USRP B2x0 (UHD)
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="2500", ATTRS{idProduct}=="0020", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="2500", ATTRS{idProduct}=="0021", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="2500", ATTRS{idProduct}=="0022", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
# USRP B2x0 in bootloader / uninitialized (Cypress FX3 / legacy Ettus VID)
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="fffe", ATTRS{idProduct}=="0002", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="04b4", ATTRS{idProduct}=="00f3", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
|
|
||||||
# Nuand bladeRF
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="2cf0", ATTRS{idProduct}=="5246", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
SUBSYSTEM=="usb", ATTRS{idVendor}=="2cf0", ATTRS{idProduct}=="5250", MODE="0660", GROUP="plugdev", TAG+="uaccess"
|
|
||||||
|
|
@ -119,15 +119,9 @@ class WsClient:
|
||||||
await asyncio.sleep(self.reconnect_pause)
|
await asyncio.sleep(self.reconnect_pause)
|
||||||
|
|
||||||
async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
|
async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# Build off the event loop: a heartbeat can probe SDR hardware
|
await self.send_json(heartbeat())
|
||||||
# (e.g. uhd_find_devices on a USRP), which blocks for seconds and
|
|
||||||
# would otherwise freeze the WebSocket keepalive long enough for
|
|
||||||
# the hub to drop the agent.
|
|
||||||
payload = await loop.run_in_executor(None, heartbeat)
|
|
||||||
await self.send_json(payload)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Heartbeat send failed: %s", exc)
|
logger.debug("Heartbeat send failed: %s", exc)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,14 @@ def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
|
||||||
outer_sample_stop = outer.sample_start + outer.sample_count
|
outer_sample_stop = outer.sample_start + outer.sample_count
|
||||||
|
|
||||||
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
|
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
|
||||||
if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
|
if (
|
||||||
|
inner.freq_lower_edge is not None
|
||||||
|
and inner.freq_upper_edge is not None
|
||||||
|
and outer.freq_lower_edge is not None
|
||||||
|
and outer.freq_upper_edge is not None
|
||||||
|
and inner.freq_lower_edge > outer.freq_lower_edge
|
||||||
|
and inner.freq_upper_edge < outer.freq_upper_edge
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,10 @@ class Annotation:
|
||||||
:type sample_start: int
|
:type sample_start: int
|
||||||
:param sample_count: The index of the ending sample of the annotation, inclusive.
|
:param sample_count: The index of the ending sample of the annotation, inclusive.
|
||||||
:type sample_count: int
|
:type sample_count: int
|
||||||
:param freq_lower_edge: The lower frequency of the annotation.
|
:param freq_lower_edge: The lower frequency of the annotation. Optional; None if not specified in source.
|
||||||
:type freq_lower_edge: float
|
:type freq_lower_edge: float, optional
|
||||||
:param freq_upper_edge: The upper frequency of the annotation.
|
:param freq_upper_edge: The upper frequency of the annotation. Optional; None if not specified in source.
|
||||||
:type freq_upper_edge: float
|
:type freq_upper_edge: float, optional
|
||||||
:param label: The label that will be displayed with the bounding box in compatible viewers including IQEngine.
|
:param label: The label that will be displayed with the bounding box in compatible viewers including IQEngine.
|
||||||
Defaults to an emtpy string.
|
Defaults to an emtpy string.
|
||||||
:type label: str, optional
|
:type label: str, optional
|
||||||
|
|
@ -34,8 +34,8 @@ class Annotation:
|
||||||
self,
|
self,
|
||||||
sample_start: int,
|
sample_start: int,
|
||||||
sample_count: int,
|
sample_count: int,
|
||||||
freq_lower_edge: float,
|
freq_lower_edge: Optional[float] = None,
|
||||||
freq_upper_edge: float,
|
freq_upper_edge: Optional[float] = None,
|
||||||
label: Optional[str] = "",
|
label: Optional[str] = "",
|
||||||
comment: Optional[str] = "",
|
comment: Optional[str] = "",
|
||||||
detail: Optional[dict] = None,
|
detail: Optional[dict] = None,
|
||||||
|
|
@ -43,8 +43,8 @@ class Annotation:
|
||||||
"""Initialize a new Annotation instance."""
|
"""Initialize a new Annotation instance."""
|
||||||
self.sample_start = int(sample_start)
|
self.sample_start = int(sample_start)
|
||||||
self.sample_count = int(sample_count)
|
self.sample_count = int(sample_count)
|
||||||
self.freq_lower_edge = float(freq_lower_edge)
|
self.freq_lower_edge = float(freq_lower_edge) if freq_lower_edge is not None else None
|
||||||
self.freq_upper_edge = float(freq_upper_edge)
|
self.freq_upper_edge = float(freq_upper_edge) if freq_upper_edge is not None else None
|
||||||
self.label = str(label)
|
self.label = str(label)
|
||||||
self.comment = str(comment)
|
self.comment = str(comment)
|
||||||
|
|
||||||
|
|
@ -62,6 +62,8 @@ class Annotation:
|
||||||
:returns: True if valid, False if not.
|
:returns: True if valid, False if not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.freq_lower_edge is None or self.freq_upper_edge is None:
|
||||||
|
return self.sample_count > 0
|
||||||
return self.sample_count > 0 and self.freq_lower_edge < self.freq_upper_edge
|
return self.sample_count > 0 and self.freq_lower_edge < self.freq_upper_edge
|
||||||
|
|
||||||
def overlap(self, other):
|
def overlap(self, other):
|
||||||
|
|
@ -73,6 +75,14 @@ class Annotation:
|
||||||
|
|
||||||
:returns: The area of the overlap in samples*frequency, or 0 if they do not overlap."""
|
:returns: The area of the overlap in samples*frequency, or 0 if they do not overlap."""
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.freq_lower_edge is None
|
||||||
|
or self.freq_upper_edge is None
|
||||||
|
or other.freq_lower_edge is None
|
||||||
|
or other.freq_upper_edge is None
|
||||||
|
):
|
||||||
|
return 0
|
||||||
|
|
||||||
sample_overlap_start = max(self.sample_start, other.sample_start)
|
sample_overlap_start = max(self.sample_start, other.sample_start)
|
||||||
sample_overlap_end = min(self.sample_start + self.sample_count, other.sample_start + other.sample_count)
|
sample_overlap_end = min(self.sample_start + self.sample_count, other.sample_start + other.sample_count)
|
||||||
|
|
||||||
|
|
@ -91,6 +101,8 @@ class Annotation:
|
||||||
|
|
||||||
:returns: sample length multiplied by bandwidth."""
|
:returns: sample length multiplied by bandwidth."""
|
||||||
|
|
||||||
|
if self.freq_lower_edge is None or self.freq_upper_edge is None:
|
||||||
|
return 0
|
||||||
return self.sample_count * (self.freq_upper_edge - self.freq_lower_edge)
|
return self.sample_count * (self.freq_upper_edge - self.freq_lower_edge)
|
||||||
|
|
||||||
def __eq__(self, other: Annotation) -> bool:
|
def __eq__(self, other: Annotation) -> bool:
|
||||||
|
|
@ -103,13 +115,16 @@ class Annotation:
|
||||||
|
|
||||||
annotation_dict = {SigMFFile.START_INDEX_KEY: self.sample_start, SigMFFile.LENGTH_INDEX_KEY: self.sample_count}
|
annotation_dict = {SigMFFile.START_INDEX_KEY: self.sample_start, SigMFFile.LENGTH_INDEX_KEY: self.sample_count}
|
||||||
|
|
||||||
annotation_dict["metadata"] = {
|
metadata = {
|
||||||
SigMFFile.LABEL_KEY: self.label,
|
SigMFFile.LABEL_KEY: self.label,
|
||||||
SigMFFile.COMMENT_KEY: self.comment,
|
SigMFFile.COMMENT_KEY: self.comment,
|
||||||
SigMFFile.FHI_KEY: self.freq_upper_edge,
|
|
||||||
SigMFFile.FLO_KEY: self.freq_lower_edge,
|
|
||||||
"ria:detail": self.detail,
|
"ria:detail": self.detail,
|
||||||
}
|
}
|
||||||
|
if self.freq_upper_edge is not None:
|
||||||
|
metadata[SigMFFile.FHI_KEY] = self.freq_upper_edge
|
||||||
|
if self.freq_lower_edge is not None:
|
||||||
|
metadata[SigMFFile.FLO_KEY] = self.freq_lower_edge
|
||||||
|
annotation_dict["metadata"] = metadata
|
||||||
|
|
||||||
if _is_jsonable(annotation_dict):
|
if _is_jsonable(annotation_dict):
|
||||||
return annotation_dict
|
return annotation_dict
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import numpy as np
|
||||||
import uhd
|
import uhd
|
||||||
|
|
||||||
from ria_toolkit_oss.data.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr.sdr import SDR, SdrDisconnectedError, SDRError, SDRParameterError
|
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
||||||
|
|
||||||
|
|
||||||
class USRP(SDR):
|
class USRP(SDR):
|
||||||
|
|
@ -32,13 +32,6 @@ class USRP(SDR):
|
||||||
|
|
||||||
self._rx_initialized = False
|
self._rx_initialized = False
|
||||||
self._tx_initialized = False
|
self._tx_initialized = False
|
||||||
# True once a continuous RX stream has been started (see rx()). Kept
|
|
||||||
# running across rx() calls so the agent streamer gets gapless capture
|
|
||||||
# instead of a start/stop per buffer.
|
|
||||||
self._rx_streaming = False
|
|
||||||
# Samples received past the end of one rx() request, carried into the
|
|
||||||
# next call so nothing is dropped between buffers.
|
|
||||||
self._rx_residual = np.empty(0, dtype=np.complex64)
|
|
||||||
|
|
||||||
def init_rx(
|
def init_rx(
|
||||||
self,
|
self,
|
||||||
|
|
@ -72,7 +65,7 @@ class USRP(SDR):
|
||||||
|
|
||||||
# build USRP object
|
# build USRP object
|
||||||
usrp_args = _generate_usrp_config_string(sample_rate=sample_rate, device_dict=self.device_dict)
|
usrp_args = _generate_usrp_config_string(sample_rate=sample_rate, device_dict=self.device_dict)
|
||||||
self.usrp = _open_multi_usrp(usrp_args)
|
self.usrp = uhd.usrp.MultiUSRP(usrp_args)
|
||||||
|
|
||||||
# check if channel arg is valid
|
# check if channel arg is valid
|
||||||
max_num_channels = self.usrp.get_rx_num_channels()
|
max_num_channels = self.usrp.get_rx_num_channels()
|
||||||
|
|
@ -103,8 +96,6 @@ class USRP(SDR):
|
||||||
# flag to prevent user from calling certain functions before this one.
|
# flag to prevent user from calling certain functions before this one.
|
||||||
self._rx_initialized = True
|
self._rx_initialized = True
|
||||||
self._tx_initialized = False
|
self._tx_initialized = False
|
||||||
self._rx_streaming = False # (re)started lazily on the first rx() call
|
|
||||||
self._rx_residual = np.empty(0, dtype=np.complex64)
|
|
||||||
|
|
||||||
return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}
|
return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}
|
||||||
|
|
||||||
|
|
@ -274,97 +265,6 @@ class USRP(SDR):
|
||||||
|
|
||||||
return Recording(data=store_array[:, :num_samples], metadata=metadata)
|
return Recording(data=store_array[:, :num_samples], metadata=metadata)
|
||||||
|
|
||||||
def rx(self, num_samples: int) -> "np.ndarray":
|
|
||||||
"""Return *num_samples* complex64 IQ samples from a continuous RX stream.
|
|
||||||
|
|
||||||
This is the interface the agent streamer's capture loop calls every
|
|
||||||
buffer. Unlike ``record()`` (a one-shot that issues ``start_cont`` /
|
|
||||||
``stop_cont`` and sleeps each call), this keeps a single continuous
|
|
||||||
stream running across calls, so capture is gapless — no per-buffer
|
|
||||||
start/stop churn, transients, or zero-filled gaps that show up as black
|
|
||||||
bands in the spectrogram.
|
|
||||||
|
|
||||||
On the first call it auto-initializes RX (from ``sample_rate`` /
|
|
||||||
``center_freq`` / ``gain`` set by the caller) and issues ``start_cont``
|
|
||||||
once. ``close()`` (or ``stop()``) stops the stream.
|
|
||||||
"""
|
|
||||||
if not self._rx_initialized:
|
|
||||||
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
|
|
||||||
self.init_rx(
|
|
||||||
sample_rate=self.sample_rate,
|
|
||||||
center_frequency=self.center_freq,
|
|
||||||
gain=gain,
|
|
||||||
channel=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._rx_streaming:
|
|
||||||
stream_command = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont)
|
|
||||||
stream_command.stream_now = True
|
|
||||||
self.rx_stream.issue_stream_cmd(stream_command)
|
|
||||||
self._enable_rx = True
|
|
||||||
self._rx_streaming = True
|
|
||||||
print("USRP Starting RX (continuous)...")
|
|
||||||
|
|
||||||
out = np.empty(num_samples, dtype=np.complex64)
|
|
||||||
filled = 0
|
|
||||||
|
|
||||||
# Drain any samples carried over from the previous call first.
|
|
||||||
if self._rx_residual.size:
|
|
||||||
take = min(self._rx_residual.size, num_samples)
|
|
||||||
out[:take] = self._rx_residual[:take]
|
|
||||||
self._rx_residual = self._rx_residual[take:]
|
|
||||||
filled = take
|
|
||||||
|
|
||||||
recv_buffer = np.zeros((1, self.rx_buffer_size), dtype=np.complex64)
|
|
||||||
consecutive_timeouts = 0
|
|
||||||
error_codes = uhd.types.RXMetadataErrorCode
|
|
||||||
|
|
||||||
while filled < num_samples:
|
|
||||||
n = self.rx_stream.recv(recv_buffer, self.metadata, self.timeout)
|
|
||||||
err = self.metadata.error_code
|
|
||||||
|
|
||||||
if err == error_codes.timeout:
|
|
||||||
consecutive_timeouts += 1
|
|
||||||
# A stalled stream is a disconnect, not a transient hiccup.
|
|
||||||
if consecutive_timeouts >= 5:
|
|
||||||
self._rx_streaming = False
|
|
||||||
raise SdrDisconnectedError("USRP RX timed out repeatedly — device may be disconnected")
|
|
||||||
continue
|
|
||||||
consecutive_timeouts = 0
|
|
||||||
|
|
||||||
# Overflow ("O") means the host fell behind and UHD dropped samples
|
|
||||||
# upstream; the samples we did get are still valid, so keep going.
|
|
||||||
if err not in (error_codes.none, error_codes.overflow):
|
|
||||||
self._rx_streaming = False
|
|
||||||
raise SDRError(f"USRP RX error: {err}")
|
|
||||||
|
|
||||||
if n <= 0:
|
|
||||||
continue
|
|
||||||
take = min(n, num_samples - filled)
|
|
||||||
out[filled : filled + take] = recv_buffer[0, :take]
|
|
||||||
filled += take
|
|
||||||
# Keep anything received past this request for the next call so the
|
|
||||||
# stream stays gapless across rx() boundaries.
|
|
||||||
if take < n:
|
|
||||||
self._rx_residual = recv_buffer[0, take:n].copy()
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _stop_rx_stream(self) -> None:
|
|
||||||
"""Issue stop_cont for the continuous RX stream, if running."""
|
|
||||||
if not self._rx_streaming:
|
|
||||||
return
|
|
||||||
self._enable_rx = False
|
|
||||||
try:
|
|
||||||
stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont)
|
|
||||||
stop_cmd.stream_now = True
|
|
||||||
self.rx_stream.issue_stream_cmd(stop_cmd)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._rx_streaming = False
|
|
||||||
self._rx_residual = np.empty(0, dtype=np.complex64)
|
|
||||||
print("USRP RX stopped.")
|
|
||||||
|
|
||||||
def init_tx(
|
def init_tx(
|
||||||
self,
|
self,
|
||||||
sample_rate: int | float,
|
sample_rate: int | float,
|
||||||
|
|
@ -394,7 +294,7 @@ class USRP(SDR):
|
||||||
print(f"USRP TX Gain Mode = '{gain_mode}'")
|
print(f"USRP TX Gain Mode = '{gain_mode}'")
|
||||||
|
|
||||||
config_str = _generate_usrp_config_string(sample_rate=sample_rate, device_dict=self.device_dict)
|
config_str = _generate_usrp_config_string(sample_rate=sample_rate, device_dict=self.device_dict)
|
||||||
self.usrp = _open_multi_usrp(config_str)
|
self.usrp = uhd.usrp.MultiUSRP(config_str)
|
||||||
|
|
||||||
# check if channel arg is valid
|
# check if channel arg is valid
|
||||||
max_num_channels = self.usrp.get_rx_num_channels()
|
max_num_channels = self.usrp.get_rx_num_channels()
|
||||||
|
|
@ -471,7 +371,6 @@ class USRP(SDR):
|
||||||
print(f"USRP TX Gain = {self.tx_gain}")
|
print(f"USRP TX Gain = {self.tx_gain}")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._stop_rx_stream()
|
|
||||||
self._tx_initialized = False
|
self._tx_initialized = False
|
||||||
self._rx_initialized = False
|
self._rx_initialized = False
|
||||||
if hasattr(self, "rx_stream"):
|
if hasattr(self, "rx_stream"):
|
||||||
|
|
@ -563,32 +462,6 @@ class USRP(SDR):
|
||||||
return {"center_frequency": True, "sample_rate": True, "gain": True}
|
return {"center_frequency": True, "sample_rate": True, "gain": True}
|
||||||
|
|
||||||
|
|
||||||
def _open_multi_usrp(usrp_args, *, attempts=4, settle_s=2.0):
|
|
||||||
"""Construct a ``uhd.usrp.MultiUSRP``, retrying transient B200 USB states.
|
|
||||||
|
|
||||||
On USB USRPs (B200/B210) the ``uhd_find_devices`` enumeration that resolves
|
|
||||||
the device (see ``_create_device_dict``) runs immediately before the open and
|
|
||||||
can leave the FX3 USB controller mid-reset, so the first open fails with e.g.
|
|
||||||
``RuntimeError: fx3 is in state 5``. The device settles once that
|
|
||||||
enumeration's USB handle is fully released — and the failed open itself nudges
|
|
||||||
the FX3 to reload firmware/FPGA — so we retry with a short backoff before
|
|
||||||
giving up. A non-transient error (bad args, genuinely absent device) is
|
|
||||||
re-raised immediately.
|
|
||||||
"""
|
|
||||||
for attempt in range(1, attempts + 1):
|
|
||||||
try:
|
|
||||||
return uhd.usrp.MultiUSRP(usrp_args)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
msg = str(exc).lower()
|
|
||||||
transient = ("fx3" in msg) or ("usb" in msg) or ("no devices found" in msg)
|
|
||||||
if not transient or attempt == attempts:
|
|
||||||
raise
|
|
||||||
print(
|
|
||||||
f"\033[93mUSRP open attempt {attempt}/{attempts} failed " f"({exc}); retrying in {settle_s}s…\033[0m"
|
|
||||||
)
|
|
||||||
time.sleep(settle_s)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_device_dict(identifier_value=None):
|
def _create_device_dict(identifier_value=None):
|
||||||
"""
|
"""
|
||||||
Get the dictionary of information corresponding to any unique identifier,
|
Get the dictionary of information corresponding to any unique identifier,
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,8 @@ def view_annotations(
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
|
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
|
||||||
|
if annotation.freq_lower_edge is None or annotation.freq_upper_edge is None:
|
||||||
|
continue
|
||||||
t_start = annotation.sample_start / sample_rate
|
t_start = annotation.sample_start / sample_rate
|
||||||
t_width = annotation.sample_count / sample_rate
|
t_width = annotation.sample_count / sample_rate
|
||||||
f_start = annotation.freq_lower_edge
|
f_start = annotation.freq_lower_edge
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,57 @@
|
||||||
This module contains the main group for the ria toolkit oss CLI.
|
This module contains the main group for the ria toolkit oss CLI.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import click
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
from ria_toolkit_oss_cli.ria_toolkit_oss import commands
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Unable to import Axes3D",
|
||||||
|
category=UserWarning,
|
||||||
|
module="matplotlib",
|
||||||
|
)
|
||||||
|
|
||||||
|
import click # noqa: E402
|
||||||
|
|
||||||
|
from ria_toolkit_oss_cli.ria_toolkit_oss import commands # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
def _git_lfs_installed() -> bool:
|
||||||
|
"""Return True if git-lfs is available on PATH."""
|
||||||
|
try:
|
||||||
|
return (
|
||||||
|
subprocess.run(
|
||||||
|
["git", "lfs", "version"],
|
||||||
|
capture_output=True,
|
||||||
|
).returncode
|
||||||
|
== 0
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@click.group(invoke_without_command=True)
|
||||||
@click.option("-v", "--verbose", is_flag=True, type=bool, help="Increase verbosity, especially useful for debugging.")
|
@click.option("-v", "--verbose", is_flag=True, type=bool, help="Increase verbosity, especially useful for debugging.")
|
||||||
def cli(verbose):
|
@click.pass_context
|
||||||
pass
|
def cli(ctx, verbose):
|
||||||
|
lfs_missing = not _git_lfs_installed()
|
||||||
|
if lfs_missing:
|
||||||
|
click.echo(
|
||||||
|
"Warning: git-lfs is not installed. RIA Hub projects require git-lfs to\n"
|
||||||
|
"track large binary files (models, recordings, datasets).\n"
|
||||||
|
"\n"
|
||||||
|
" Linux: sudo apt-get install git-lfs\n"
|
||||||
|
" macOS: brew install git-lfs\n"
|
||||||
|
" Other platforms: https://git-lfs.com\n"
|
||||||
|
"\n"
|
||||||
|
"After installing, run: git lfs install",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
if ctx.invoked_subcommand is None:
|
||||||
|
if lfs_missing and sys.stdin.isatty():
|
||||||
|
click.pause(info="\nPress Enter to continue...", err=True)
|
||||||
|
click.echo(ctx.get_help())
|
||||||
|
|
||||||
|
|
||||||
# Loop through project commands, binding them all to the CLI.
|
# Loop through project commands, binding them all to the CLI.
|
||||||
|
|
|
||||||
97
src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py
Normal file
97
src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Shared authentication and security helpers for RIA Hub API calls."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import subprocess
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
DEFAULT_HUB = "https://riahub.ai"
|
||||||
|
|
||||||
|
|
||||||
|
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
||||||
|
"""Block redirects on authenticated requests to prevent credential exfiltration.
|
||||||
|
|
||||||
|
urllib re-sends the Authorization header on same-host redirects by default.
|
||||||
|
A malicious server could redirect a POST to a different host to harvest
|
||||||
|
credentials. We refuse all redirects — API clients should not encounter them
|
||||||
|
in normal operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
||||||
|
raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials")
|
||||||
|
|
||||||
|
|
||||||
|
def hub_opener() -> urllib.request.OpenerDirector:
|
||||||
|
"""Return a urllib opener that blocks redirects."""
|
||||||
|
return urllib.request.build_opener(_NoRedirectHandler)
|
||||||
|
|
||||||
|
|
||||||
|
def warn_if_insecure(hub: str) -> None:
|
||||||
|
"""Warn when credentials would be sent over plain HTTP to a non-localhost host."""
|
||||||
|
parsed = urllib.parse.urlparse(hub)
|
||||||
|
if parsed.scheme == "http":
|
||||||
|
host = parsed.hostname or ""
|
||||||
|
if host not in ("localhost", "127.0.0.1", "::1"):
|
||||||
|
click.echo(
|
||||||
|
f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def basic_auth(username: str, password: str) -> str:
|
||||||
|
return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
|
||||||
|
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
|
||||||
|
parsed = urllib.parse.urlparse(hub_url)
|
||||||
|
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "credential", "fill"],
|
||||||
|
input=payload,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
creds = {}
|
||||||
|
for line in result.stdout.splitlines():
|
||||||
|
# Partition on the FIRST '=' only so passwords containing '=' are preserved.
|
||||||
|
k, sep, v = line.partition("=")
|
||||||
|
if sep:
|
||||||
|
creds[k.strip()] = v # keep value verbatim
|
||||||
|
return creds.get("username"), creds.get("password")
|
||||||
|
except Exception:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def store_credentials(hub_url: str, username: str, password: str) -> None:
|
||||||
|
"""Cache credentials via git credential approve (uses the system keychain/store)."""
|
||||||
|
parsed = urllib.parse.urlparse(hub_url)
|
||||||
|
payload = (
|
||||||
|
f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["git", "credential", "approve"],
|
||||||
|
input=payload,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # non-fatal — next push just prompts again
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_credentials(hub: str) -> tuple[str, str]:
|
||||||
|
"""Return (username, password), prompting interactively if not cached."""
|
||||||
|
username, password = get_stored_credentials(hub)
|
||||||
|
if username and password:
|
||||||
|
return username, password
|
||||||
|
click.echo(f"No stored credentials found for {hub}.")
|
||||||
|
username = click.prompt("RIA Hub username")
|
||||||
|
password = click.prompt("Password / personal access token", hide_input=True)
|
||||||
|
return username, password
|
||||||
|
|
@ -86,9 +86,7 @@ def save_recording_auto(recording, output_path, input_path, quiet=False, overwri
|
||||||
input_path = Path(input_path)
|
input_path = Path(input_path)
|
||||||
fmt = detect_input_format(input_path)
|
fmt = detect_input_format(input_path)
|
||||||
|
|
||||||
output_path = determine_output_path(
|
output_path = determine_output_path(input_path=input_path, output_path=output_path, fmt=fmt, overwrite=overwrite)
|
||||||
input_path=input_path, output_path=output_path, fmt=fmt, overwrite=overwrite
|
|
||||||
)
|
|
||||||
|
|
||||||
if not quiet:
|
if not quiet:
|
||||||
if fmt == "sigmf":
|
if fmt == "sigmf":
|
||||||
|
|
@ -258,7 +256,11 @@ def list(input, verbose):
|
||||||
user_comment = ann.comment or ""
|
user_comment = ann.comment or ""
|
||||||
|
|
||||||
# Basic info
|
# Basic info
|
||||||
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
freq_range = (
|
||||||
|
f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||||
|
if ann.freq_lower_edge is not None and ann.freq_upper_edge is not None
|
||||||
|
else "N/A"
|
||||||
|
)
|
||||||
click.echo(
|
click.echo(
|
||||||
f" [{i}] Samples {format_sample_count(ann.sample_start)}-"
|
f" [{i}] Samples {format_sample_count(ann.sample_start)}-"
|
||||||
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}"
|
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}"
|
||||||
|
|
@ -502,8 +504,7 @@ def clear(input, output, overwrite, force, quiet):
|
||||||
help="Annotation type",
|
help="Annotation type",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None,
|
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -617,8 +618,7 @@ def energy(
|
||||||
help="Annotation type",
|
help="Annotation type",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None,
|
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -707,8 +707,7 @@ def cusum(input, label, min_duration, window_size, tolerance, annotation_type, s
|
||||||
)
|
)
|
||||||
@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)")
|
@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None,
|
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -787,8 +786,7 @@ def threshold(input, threshold, label, window_size, annotation_type, channel, sa
|
||||||
@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)")
|
@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)")
|
||||||
@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz")
|
@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None,
|
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -809,7 +807,8 @@ def _log_separate_start(quiet, recording, indices_list, nfft, noise_threshold_db
|
||||||
|
|
||||||
|
|
||||||
def separate(
|
def separate(
|
||||||
input, indices, nfft, noise_threshold_db, min_component_bw, sample_rate, output, overwrite, quiet, verbose):
|
input, indices, nfft, noise_threshold_db, min_component_bw, sample_rate, output, overwrite, quiet, verbose
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Auto-detect parallel frequency-offset signals and split into sub-bands.
|
Auto-detect parallel frequency-offset signals and split into sub-bands.
|
||||||
|
|
||||||
|
|
@ -883,7 +882,11 @@ def separate(
|
||||||
click.echo("\n Details:")
|
click.echo("\n Details:")
|
||||||
for i in range(initial_count, final_count):
|
for i in range(initial_count, final_count):
|
||||||
ann = recording.annotations[i]
|
ann = recording.annotations[i]
|
||||||
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
freq_range = (
|
||||||
|
f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||||
|
if ann.freq_lower_edge is not None and ann.freq_upper_edge is not None
|
||||||
|
else "N/A"
|
||||||
|
)
|
||||||
click.echo(
|
click.echo(
|
||||||
f" [{i}] samples {format_sample_count(ann.sample_start)}-"
|
f" [{i}] samples {format_sample_count(ann.sample_start)}-"
|
||||||
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}"
|
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}"
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,11 @@ from .generate import generate
|
||||||
# from .generate import generate
|
# from .generate import generate
|
||||||
from .init import init
|
from .init import init
|
||||||
from .serve import serve
|
from .serve import serve
|
||||||
|
from .setup_repo import setup_repo
|
||||||
from .split import split
|
from .split import split
|
||||||
from .transform import transform
|
from .transform import transform
|
||||||
from .transmit import transmit
|
from .transmit import transmit
|
||||||
|
from .upload import upload
|
||||||
from .view import view
|
from .view import view
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
|
|
||||||
401
src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py
Normal file
401
src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py
Normal file
|
|
@ -0,0 +1,401 @@
|
||||||
|
"""ria setup_repo — create and configure a RIA Hub Project repo."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from ._hub_auth import (
|
||||||
|
DEFAULT_HUB,
|
||||||
|
_NoRedirectHandler,
|
||||||
|
basic_auth,
|
||||||
|
resolve_credentials,
|
||||||
|
store_credentials,
|
||||||
|
warn_if_insecure,
|
||||||
|
)
|
||||||
|
|
||||||
|
RIA_LFS_RULES = [
|
||||||
|
("*.pt", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.pth", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.onnx", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.sigmf", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.sigmf-data", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.sigmf-meta", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.npy", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.npz", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.h5", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.hdf5", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.bin", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
("*.pkl", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Repo names must be safe directory names and valid git remote path components.
|
||||||
|
_SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# API helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _api_request(
|
||||||
|
hub: str,
|
||||||
|
path: str,
|
||||||
|
method: str,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
body: dict | None = None,
|
||||||
|
) -> tuple[dict, int]:
|
||||||
|
"""
|
||||||
|
Make an authenticated request to the RIA Hub API.
|
||||||
|
Returns (parsed_response_body, http_status_code).
|
||||||
|
Status 0 means a network/connection error.
|
||||||
|
Credentials are sent as HTTP Basic auth — safe over HTTPS and localhost HTTP.
|
||||||
|
Redirects are blocked to prevent credential exfiltration.
|
||||||
|
"""
|
||||||
|
url = f"{hub.rstrip('/')}/api/v1{path}"
|
||||||
|
data = json.dumps(body).encode() if body is not None else None
|
||||||
|
req = urllib.request.Request(url, data=data, method=method)
|
||||||
|
req.add_header("Content-Type", "application/json")
|
||||||
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
|
|
||||||
|
opener = urllib.request.build_opener(_NoRedirectHandler)
|
||||||
|
try:
|
||||||
|
with opener.open(req, timeout=15) as resp:
|
||||||
|
return json.loads(resp.read() or b"{}"), resp.status
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
try:
|
||||||
|
resp_body = json.loads(e.read() or b"{}")
|
||||||
|
except Exception:
|
||||||
|
resp_body = {}
|
||||||
|
return resp_body, e.code
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
return {"message": str(e.reason)}, 0
|
||||||
|
|
||||||
|
|
||||||
|
def _get_authenticated_username(hub: str, username: str, password: str) -> str | None:
|
||||||
|
"""Return the login name of the authenticated user from GET /api/v1/user.
|
||||||
|
|
||||||
|
This is the canonical username for URL construction — it may differ from
|
||||||
|
git config user.name which is a display name, not a login.
|
||||||
|
"""
|
||||||
|
body, status = _api_request(hub, "/user", "GET", username, password)
|
||||||
|
if status == 200:
|
||||||
|
return body.get("login")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _repo_exists(hub: str, owner: str, name: str, username: str, password: str) -> bool:
|
||||||
|
body, status = _api_request(
|
||||||
|
hub,
|
||||||
|
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(name, safe='')}",
|
||||||
|
"GET",
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
)
|
||||||
|
return status == 200
|
||||||
|
|
||||||
|
|
||||||
|
def _create_repo_on_hub(hub: str, name: str, username: str, password: str, private: bool) -> bool:
|
||||||
|
"""Create an RIA Hub Project repo via API.
|
||||||
|
|
||||||
|
Returns True if the repo was freshly created (server seeded README.md and
|
||||||
|
.gitattributes via auto_init + is_ria), False if the hub was unreachable
|
||||||
|
(local fallback needed). Exits on fatal errors (auth, quota, name taken).
|
||||||
|
"""
|
||||||
|
body, status = _api_request(
|
||||||
|
hub,
|
||||||
|
"/user/repos",
|
||||||
|
"POST",
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"auto_init": True,
|
||||||
|
"is_ria": True,
|
||||||
|
"private": private,
|
||||||
|
"default_branch": "main",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if status == 201:
|
||||||
|
click.echo(f"Repository '{name}' created on RIA Hub.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if status == 0:
|
||||||
|
click.echo(
|
||||||
|
f"Warning: could not reach RIA Hub at {hub}: {body.get('message', 'connection failed')}",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
click.echo("Continuing with local setup only — create the repo manually on RIA Hub.", err=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
msg = body.get("message", "")
|
||||||
|
|
||||||
|
if status == 401:
|
||||||
|
click.echo("Error: authentication failed — check your username/password.", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if status in (403, 413) or "quota" in msg.lower() or "limit" in msg.lower():
|
||||||
|
click.echo("Error: cannot create repository — storage quota or account limit reached.", err=True)
|
||||||
|
if msg:
|
||||||
|
click.echo(f" Server message: {msg}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if status == 422 or "already exist" in msg.lower():
|
||||||
|
click.echo(f"Repository '{name}' already exists on RIA Hub.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
click.echo(f"Error creating repository (HTTP {status}): {msg}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Local git helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _tracked_patterns(ga_path: str) -> set:
|
||||||
|
if not os.path.exists(ga_path):
|
||||||
|
return set()
|
||||||
|
patterns = set()
|
||||||
|
with open(ga_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
m = re.match(r"^(\S+)\s+", line)
|
||||||
|
if m:
|
||||||
|
patterns.add(m.group(1))
|
||||||
|
return patterns
|
||||||
|
|
||||||
|
|
||||||
|
def _write_local_ria_files(repo_path: str, repo_name: str) -> None:
|
||||||
|
"""Seed README.md and .gitattributes locally (used when hub is unreachable or --no-remote)."""
|
||||||
|
# README
|
||||||
|
for candidate in ("README.md", "README.rst", "README.txt", "README"):
|
||||||
|
if os.path.exists(os.path.join(repo_path, candidate)):
|
||||||
|
click.echo(f"README: {candidate} already exists, skipping")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
with open(os.path.join(repo_path, "README.md"), "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
f"# {repo_name}\n"
|
||||||
|
"\n"
|
||||||
|
"A RIA Hub project.\n"
|
||||||
|
"\n"
|
||||||
|
"## Description\n"
|
||||||
|
"\n"
|
||||||
|
"<!-- Add your project description here -->\n"
|
||||||
|
"\n"
|
||||||
|
"## Contents\n"
|
||||||
|
"\n"
|
||||||
|
"<!-- Describe the signals, models, or datasets in this repository -->\n"
|
||||||
|
)
|
||||||
|
click.echo("README.md: created")
|
||||||
|
|
||||||
|
# .gitattributes
|
||||||
|
ga_path = os.path.join(repo_path, ".gitattributes")
|
||||||
|
existing = _tracked_patterns(ga_path)
|
||||||
|
new_rules = [(p, a) for p, a in RIA_LFS_RULES if p not in existing]
|
||||||
|
|
||||||
|
if new_rules:
|
||||||
|
existing_content = ""
|
||||||
|
if os.path.exists(ga_path):
|
||||||
|
with open(ga_path, encoding="utf-8") as f:
|
||||||
|
existing_content = f.read()
|
||||||
|
|
||||||
|
separator = "" if (not existing_content or existing_content.endswith("\n")) else "\n"
|
||||||
|
addition = separator + "".join(f"{pattern} {attrs}\n" for pattern, attrs in new_rules)
|
||||||
|
|
||||||
|
with open(ga_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(addition)
|
||||||
|
click.echo(f".gitattributes: {len(new_rules)} rule(s) added")
|
||||||
|
else:
|
||||||
|
click.echo(".gitattributes: all RIA Hub rules are already present")
|
||||||
|
|
||||||
|
|
||||||
|
def _git(repo_path: str, *args: str, check: bool = True) -> subprocess.CompletedProcess:
|
||||||
|
return subprocess.run(
|
||||||
|
["git", "-C", repo_path, *args],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=check,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path_and_name(name: str | None, local_path: str | None) -> tuple[str, str]:
|
||||||
|
if local_path:
|
||||||
|
repo_path = os.path.abspath(local_path)
|
||||||
|
repo_name = name or os.path.basename(repo_path)
|
||||||
|
elif name:
|
||||||
|
repo_path = os.path.abspath(name)
|
||||||
|
repo_name = name
|
||||||
|
else:
|
||||||
|
repo_path = os.path.abspath(".")
|
||||||
|
repo_name = os.path.basename(repo_path)
|
||||||
|
return repo_path, repo_name
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_owner(hub: str, username: str | None, password: str | None, owner: str | None) -> str:
|
||||||
|
if not owner and username and password:
|
||||||
|
api_login = _get_authenticated_username(hub, username, password)
|
||||||
|
owner = api_login or username
|
||||||
|
return owner or "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _git_init(repo_path: str) -> None:
|
||||||
|
if os.path.isdir(os.path.join(repo_path, ".git")):
|
||||||
|
return
|
||||||
|
result = _git(repo_path, "init", "-b", "main", check=False)
|
||||||
|
if result.returncode != 0:
|
||||||
|
# Older git (< 2.28) doesn't support -b; fall back and rename.
|
||||||
|
_git(repo_path, "init")
|
||||||
|
_git(repo_path, "symbolic-ref", "HEAD", "refs/heads/main")
|
||||||
|
click.echo("git init: done (branch: main)")
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_remote(
|
||||||
|
repo_path: str, hub: str, resolved_owner: str, repo_name: str, username: str | None, no_remote: bool
|
||||||
|
) -> None:
|
||||||
|
if no_remote or not username:
|
||||||
|
click.echo(
|
||||||
|
f"Skipped remote setup. Add it manually:\n"
|
||||||
|
f" git -C {repo_path} remote add origin "
|
||||||
|
f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
||||||
|
existing = _git(repo_path, "remote", "get-url", "origin", check=False)
|
||||||
|
if existing.returncode == 0:
|
||||||
|
existing_url = existing.stdout.strip()
|
||||||
|
if existing_url == remote_url:
|
||||||
|
click.echo(f"remote origin: {remote_url} (already set)")
|
||||||
|
else:
|
||||||
|
click.echo(
|
||||||
|
f"remote 'origin' already points to {existing_url}.\n"
|
||||||
|
f" To update: git remote set-url origin {remote_url}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_git(repo_path, "remote", "add", "origin", remote_url)
|
||||||
|
click.echo(f"remote origin: {remote_url}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@click.command("setup_repo")
|
||||||
|
@click.argument("name", required=False)
|
||||||
|
@click.option(
|
||||||
|
"--path", "local_path", default=None, help="Local directory (default: current dir, or created from NAME)."
|
||||||
|
)
|
||||||
|
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
|
||||||
|
@click.option(
|
||||||
|
"--owner",
|
||||||
|
default=None,
|
||||||
|
metavar="USER",
|
||||||
|
help="RIA Hub login username (default: looked up from the API using your credentials).",
|
||||||
|
)
|
||||||
|
@click.option("--private", is_flag=True, default=False, help="Create the repository as private.")
|
||||||
|
@click.option(
|
||||||
|
"--no-remote", is_flag=True, default=False, help="Skip creating the repository on RIA Hub (local setup only)."
|
||||||
|
)
|
||||||
|
def setup_repo(
|
||||||
|
name: str | None,
|
||||||
|
local_path: str | None,
|
||||||
|
hub: str,
|
||||||
|
owner: str | None,
|
||||||
|
private: bool,
|
||||||
|
no_remote: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Create and configure a RIA Hub Project repo.
|
||||||
|
|
||||||
|
NAME is the repository name. If the local directory does not exist or is
|
||||||
|
not a git repo, it will be initialised automatically. Credentials are
|
||||||
|
retrieved from git's credential store — no token setup required if you
|
||||||
|
have used RIA Hub with git before.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
ria setup_repo my-dataset
|
||||||
|
ria setup_repo my-dataset --hub https://riahub.example.com
|
||||||
|
ria setup_repo --path ./existing-dir
|
||||||
|
ria setup_repo my-dataset --private
|
||||||
|
"""
|
||||||
|
repo_path, repo_name = _resolve_path_and_name(name, local_path)
|
||||||
|
|
||||||
|
if not _SAFE_NAME_RE.match(repo_name):
|
||||||
|
click.echo(
|
||||||
|
f"Error: '{repo_name}' is not a valid repository name.\n"
|
||||||
|
"Use only letters, numbers, hyphens, underscores, and dots (max 100 chars).",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not no_remote:
|
||||||
|
warn_if_insecure(hub)
|
||||||
|
|
||||||
|
username, password = (None, None) if no_remote else resolve_credentials(hub)
|
||||||
|
resolved_owner = _resolve_owner(hub, username, password, owner)
|
||||||
|
|
||||||
|
# newly_created=True means the server ran auto_init+is_ria and seeded
|
||||||
|
# README.md + .gitattributes in the initial commit; local setup pulls
|
||||||
|
# those files via fetch rather than writing them from scratch.
|
||||||
|
newly_created = False
|
||||||
|
if not no_remote and username and password:
|
||||||
|
if _repo_exists(hub, resolved_owner, repo_name, username, password):
|
||||||
|
click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.")
|
||||||
|
else:
|
||||||
|
newly_created = _create_repo_on_hub(hub, repo_name, username, password, private)
|
||||||
|
store_credentials(hub, username, password)
|
||||||
|
|
||||||
|
if not os.path.exists(repo_path):
|
||||||
|
os.makedirs(repo_path)
|
||||||
|
click.echo(f"Created directory: {repo_path}")
|
||||||
|
|
||||||
|
_git_init(repo_path)
|
||||||
|
|
||||||
|
if subprocess.run(["git", "lfs", "version"], capture_output=True).returncode != 0:
|
||||||
|
click.echo(
|
||||||
|
"Error: git-lfs is not installed.\n"
|
||||||
|
" Linux: sudo apt-get install git-lfs\n"
|
||||||
|
" macOS: brew install git-lfs\n"
|
||||||
|
" Other platforms: https://git-lfs.com",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
_git(repo_path, "lfs", "install", "--local")
|
||||||
|
click.echo("git lfs install --local: done")
|
||||||
|
|
||||||
|
_configure_remote(repo_path, hub, resolved_owner, repo_name, username, no_remote)
|
||||||
|
|
||||||
|
if newly_created:
|
||||||
|
fetch = _git(repo_path, "fetch", "origin", check=False)
|
||||||
|
if fetch.returncode == 0:
|
||||||
|
_git(repo_path, "reset", "--hard", "origin/main")
|
||||||
|
click.echo("Pulled initial commit from RIA Hub (README.md + .gitattributes)")
|
||||||
|
else:
|
||||||
|
click.echo("Warning: fetch failed — falling back to local file setup.", err=True)
|
||||||
|
_write_local_ria_files(repo_path, repo_name)
|
||||||
|
else:
|
||||||
|
_write_local_ria_files(repo_path, repo_name)
|
||||||
|
|
||||||
|
if newly_created:
|
||||||
|
click.echo(f"\nRepo is ready. Push your work:\n cd {repo_path}\n git push -u origin main")
|
||||||
|
else:
|
||||||
|
click.echo(
|
||||||
|
f"\nRepo is ready. Commit and push:\n"
|
||||||
|
f" cd {repo_path}\n"
|
||||||
|
f" git add README.md .gitattributes\n"
|
||||||
|
f" git commit -m 'chore: initialise RIA Hub project'\n"
|
||||||
|
f" git push -u origin main"
|
||||||
|
)
|
||||||
392
src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py
Normal file
392
src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py
Normal file
|
|
@ -0,0 +1,392 @@
|
||||||
|
"""ria upload — stream large files to a RIA Hub Project via the LFS API.
|
||||||
|
|
||||||
|
How it works
|
||||||
|
------------
|
||||||
|
1. The file is hashed locally (SHA-256 + size) — this is the LFS object ID.
|
||||||
|
2. A single POST to the repo's LFS batch endpoint returns an upload URL
|
||||||
|
(and headers) for any object the server does not already have.
|
||||||
|
3. The file is streamed to that URL in fixed-size chunks — nothing is ever
|
||||||
|
fully loaded into memory, so files of any size work.
|
||||||
|
4. A commit is created via the Gitea contents API that records the LFS
|
||||||
|
pointer (a small text file) so the file appears in the repo tree.
|
||||||
|
|
||||||
|
No server-side changes are required — this uses the same authenticated LFS
|
||||||
|
protocol that `git lfs push` uses internally.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import http.client
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from ._hub_auth import (
|
||||||
|
DEFAULT_HUB,
|
||||||
|
basic_auth,
|
||||||
|
hub_opener,
|
||||||
|
resolve_credentials,
|
||||||
|
warn_if_insecure,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read buffer for hashing and streaming — 8 MB keeps memory use flat
|
||||||
|
# for arbitrarily large files.
|
||||||
|
_CHUNK = 8 * 1024 * 1024
|
||||||
|
|
||||||
|
LFS_MEDIA_TYPE = "application/vnd.git-lfs+json"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_file(path: str) -> tuple[str, int]:
|
||||||
|
"""Return (sha256_hex, byte_size) by streaming the file."""
|
||||||
|
h = hashlib.sha256()
|
||||||
|
size = 0
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = f.read(_CHUNK)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
h.update(chunk)
|
||||||
|
size += len(chunk)
|
||||||
|
return h.hexdigest(), size
|
||||||
|
|
||||||
|
|
||||||
|
def _lfs_pointer_text(oid: str, size: int) -> str:
|
||||||
|
return f"version https://git-lfs.github.com/spec/v1\noid sha256:{oid}\nsize {size}\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LFS batch API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _lfs_batch(
|
||||||
|
hub: str,
|
||||||
|
owner: str,
|
||||||
|
repo: str,
|
||||||
|
objects: list[dict],
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
POST to /{owner}/{repo}.git/info/lfs/objects/batch.
|
||||||
|
Returns the parsed JSON response.
|
||||||
|
Raises on HTTP error or JSON decode failure.
|
||||||
|
"""
|
||||||
|
url = (
|
||||||
|
f"{hub.rstrip('/')}"
|
||||||
|
f"/{urllib.parse.quote(owner, safe='')}"
|
||||||
|
f"/{urllib.parse.quote(repo, safe='')}"
|
||||||
|
f".git/info/lfs/objects/batch"
|
||||||
|
)
|
||||||
|
body = json.dumps(
|
||||||
|
{
|
||||||
|
"operation": "upload",
|
||||||
|
"transfers": ["basic"],
|
||||||
|
"objects": objects,
|
||||||
|
}
|
||||||
|
).encode()
|
||||||
|
|
||||||
|
req = urllib.request.Request(url, data=body, method="POST")
|
||||||
|
req.add_header("Content-Type", LFS_MEDIA_TYPE)
|
||||||
|
req.add_header("Accept", LFS_MEDIA_TYPE)
|
||||||
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
|
|
||||||
|
opener = hub_opener()
|
||||||
|
try:
|
||||||
|
with opener.open(req, timeout=30) as resp:
|
||||||
|
return json.loads(resp.read())
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
body_text = e.read().decode(errors="replace")
|
||||||
|
raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Gitea contents API — create / update a file to record the LFS pointer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_file_sha(
|
||||||
|
hub: str,
|
||||||
|
owner: str,
|
||||||
|
repo: str,
|
||||||
|
path: str,
|
||||||
|
branch: str,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
) -> str | None:
|
||||||
|
"""Return the blob SHA of an existing file, or None if it doesn't exist."""
|
||||||
|
url = (
|
||||||
|
f"{hub.rstrip('/')}/api/v1"
|
||||||
|
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(repo, safe='')}"
|
||||||
|
f"/contents/{urllib.parse.quote(path)}"
|
||||||
|
f"?ref={urllib.parse.quote(branch)}"
|
||||||
|
)
|
||||||
|
req = urllib.request.Request(url)
|
||||||
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
|
try:
|
||||||
|
with hub_opener().open(req, timeout=15) as resp:
|
||||||
|
return json.loads(resp.read()).get("sha")
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
if e.code == 404:
|
||||||
|
return None
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _commit_lfs_pointer(
|
||||||
|
hub: str,
|
||||||
|
owner: str,
|
||||||
|
repo: str,
|
||||||
|
remote_path: str,
|
||||||
|
pointer_text: str,
|
||||||
|
branch: str,
|
||||||
|
message: str,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
) -> None:
|
||||||
|
"""Create or update a file in the repo containing the LFS pointer."""
|
||||||
|
url = (
|
||||||
|
f"{hub.rstrip('/')}/api/v1"
|
||||||
|
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(repo, safe='')}"
|
||||||
|
f"/contents/{urllib.parse.quote(remote_path)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_sha = _get_file_sha(hub, owner, repo, remote_path, branch, username, password)
|
||||||
|
|
||||||
|
body: dict = {
|
||||||
|
"message": message,
|
||||||
|
"content": base64.b64encode(pointer_text.encode()).decode(),
|
||||||
|
"branch": branch,
|
||||||
|
}
|
||||||
|
if existing_sha:
|
||||||
|
body["sha"] = existing_sha
|
||||||
|
|
||||||
|
method = "PUT" if existing_sha else "POST"
|
||||||
|
req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method)
|
||||||
|
req.add_header("Content-Type", "application/json")
|
||||||
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with hub_opener().open(req, timeout=30) as resp:
|
||||||
|
resp.read()
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
body_text = e.read().decode(errors="replace")
|
||||||
|
raise RuntimeError(f"Failed to commit LFS pointer for '{remote_path}' (HTTP {e.code}): {body_text}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-file upload logic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _upload_single_file(
|
||||||
|
hub: str,
|
||||||
|
owner: str,
|
||||||
|
repo_name: str,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
file_path: str,
|
||||||
|
remote_dir: str,
|
||||||
|
message: str | None,
|
||||||
|
branch: str,
|
||||||
|
) -> None:
|
||||||
|
"""Hash, upload (if needed), and commit the LFS pointer for one file."""
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
size_mb = file_size / (1024 * 1024)
|
||||||
|
|
||||||
|
click.echo(f"\n {filename} ({size_mb:.1f} MB)")
|
||||||
|
|
||||||
|
click.echo(" Hashing...", nl=False)
|
||||||
|
oid, size = _hash_file(file_path)
|
||||||
|
click.echo(f" sha256:{oid[:12]}...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
batch = _lfs_batch(hub, owner, repo_name, [{"oid": oid, "size": size}], username, password)
|
||||||
|
except RuntimeError as e:
|
||||||
|
click.echo(f"\n Error: {e}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
objects = batch.get("objects", [])
|
||||||
|
if not objects:
|
||||||
|
click.echo(" Already in LFS — skipping upload.")
|
||||||
|
else:
|
||||||
|
obj = objects[0]
|
||||||
|
if "error" in obj:
|
||||||
|
err_msg = obj["error"].get("message", "unknown error")
|
||||||
|
err_code = obj["error"].get("code", 0)
|
||||||
|
if err_code == 413 or "quota" in err_msg.lower() or "limit" in err_msg.lower():
|
||||||
|
click.echo(
|
||||||
|
f"\n Error: storage quota exceeded for this repo.\n Server: {err_msg}",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
click.echo(f"\n Error from server: {err_msg}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
upload_action = obj.get("actions", {}).get("upload")
|
||||||
|
if not upload_action:
|
||||||
|
click.echo(" Already in LFS — skipping upload.")
|
||||||
|
else:
|
||||||
|
href = upload_action["href"]
|
||||||
|
up_headers = upload_action.get("header", {})
|
||||||
|
chunks = math.ceil(size / _CHUNK)
|
||||||
|
click.echo(f" Uploading ({size_mb:.1f} MB, {chunks} chunk{'s' if chunks != 1 else ''})...")
|
||||||
|
try:
|
||||||
|
_stream_upload_progress(href, up_headers, file_path, size)
|
||||||
|
except RuntimeError as e:
|
||||||
|
click.echo(f"\n Upload failed: {e}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
click.echo(" Upload complete.")
|
||||||
|
|
||||||
|
verify_action = obj.get("actions", {}).get("verify")
|
||||||
|
if verify_action:
|
||||||
|
try:
|
||||||
|
vreq = urllib.request.Request(
|
||||||
|
verify_action["href"],
|
||||||
|
data=json.dumps({"oid": oid, "size": size}).encode(),
|
||||||
|
method="POST",
|
||||||
|
)
|
||||||
|
vreq.add_header("Content-Type", LFS_MEDIA_TYPE)
|
||||||
|
vreq.add_header("Accept", LFS_MEDIA_TYPE)
|
||||||
|
for k, v in verify_action.get("header", {}).items():
|
||||||
|
vreq.add_header(k, v)
|
||||||
|
with urllib.request.urlopen(vreq, timeout=15):
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass # verify is optional; non-fatal on failure
|
||||||
|
|
||||||
|
pointer = _lfs_pointer_text(oid, size)
|
||||||
|
remote_path = (f"{remote_dir.rstrip('/')}/{filename}").lstrip("/") if remote_dir else filename
|
||||||
|
commit_msg = message or f"chore: upload {filename} via ria"
|
||||||
|
|
||||||
|
click.echo(f" Committing pointer → {remote_path}...", nl=False)
|
||||||
|
try:
|
||||||
|
_commit_lfs_pointer(hub, owner, repo_name, remote_path, pointer, branch, commit_msg, username, password)
|
||||||
|
click.echo(" done.")
|
||||||
|
except RuntimeError as e:
|
||||||
|
click.echo(f"\n Error: {e}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int) -> None:
|
||||||
|
"""Stream file_path to href with a click progress bar."""
|
||||||
|
parsed = urllib.parse.urlparse(href)
|
||||||
|
host = parsed.netloc
|
||||||
|
path_q = parsed.path + (f"?{parsed.query}" if parsed.query else "")
|
||||||
|
|
||||||
|
if parsed.scheme == "https":
|
||||||
|
conn = http.client.HTTPSConnection(host, timeout=300)
|
||||||
|
else:
|
||||||
|
conn = http.client.HTTPConnection(host, timeout=300)
|
||||||
|
|
||||||
|
all_headers = dict(headers)
|
||||||
|
all_headers.setdefault("Content-Type", "application/octet-stream")
|
||||||
|
all_headers["Content-Length"] = str(size)
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn.connect()
|
||||||
|
conn.putrequest("PUT", path_q)
|
||||||
|
for k, v in all_headers.items():
|
||||||
|
conn.putheader(k, v)
|
||||||
|
conn.endheaders()
|
||||||
|
|
||||||
|
with click.progressbar(
|
||||||
|
length=size,
|
||||||
|
label=" ",
|
||||||
|
width=40,
|
||||||
|
show_eta=True,
|
||||||
|
show_percent=True,
|
||||||
|
fill_char="█",
|
||||||
|
empty_char="░",
|
||||||
|
) as bar:
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = f.read(_CHUNK)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
conn.send(chunk)
|
||||||
|
bar.update(len(chunk))
|
||||||
|
|
||||||
|
resp = conn.getresponse()
|
||||||
|
resp.read()
|
||||||
|
if resp.status not in (200, 201):
|
||||||
|
raise RuntimeError(f"HTTP {resp.status}")
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@click.command("upload")
|
||||||
|
@click.argument("files", nargs=-1, required=True)
|
||||||
|
@click.option(
|
||||||
|
"--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)."
|
||||||
|
)
|
||||||
|
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
|
||||||
|
@click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.")
|
||||||
|
@click.option(
|
||||||
|
"--path",
|
||||||
|
"remote_dir",
|
||||||
|
default="",
|
||||||
|
metavar="DIR",
|
||||||
|
help="Remote directory path inside the repo (default: repo root).",
|
||||||
|
)
|
||||||
|
@click.option("--message", "-m", default=None, help="Commit message (default: 'chore: upload <filename> via ria').")
|
||||||
|
def upload(
|
||||||
|
files: tuple[str],
|
||||||
|
repo: str,
|
||||||
|
hub: str,
|
||||||
|
branch: str,
|
||||||
|
remote_dir: str,
|
||||||
|
message: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Upload large files to a RIA Hub Project via Git LFS.
|
||||||
|
|
||||||
|
Files are streamed directly to the repo's LFS object store — nothing is
|
||||||
|
buffered into memory, so files of any size work. Each file creates one
|
||||||
|
commit recording the LFS pointer.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
ria upload recording.sigmf-data --repo benchinnery/my-recordings
|
||||||
|
ria upload *.npy --repo benchinnery/my-recordings --branch main
|
||||||
|
ria upload big.pt --repo benchinnery/models --path weights/
|
||||||
|
"""
|
||||||
|
# Validate repo argument
|
||||||
|
if "/" not in repo:
|
||||||
|
click.echo("Error: --repo must be in the form OWNER/NAME.", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
owner, repo_name = repo.split("/", 1)
|
||||||
|
|
||||||
|
# Expand and validate files
|
||||||
|
resolved = []
|
||||||
|
for pattern in files:
|
||||||
|
if not os.path.isfile(pattern):
|
||||||
|
click.echo(f"Error: '{pattern}' is not a file or does not exist.", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
resolved.append(os.path.abspath(pattern))
|
||||||
|
|
||||||
|
hub = hub.rstrip("/")
|
||||||
|
warn_if_insecure(hub)
|
||||||
|
username, password = resolve_credentials(hub)
|
||||||
|
|
||||||
|
click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...")
|
||||||
|
|
||||||
|
for file_path in resolved:
|
||||||
|
_upload_single_file(hub, owner, repo_name, username, password, file_path, remote_dir, message, branch)
|
||||||
|
|
||||||
|
click.echo(f"\nAll done. {len(resolved)} file(s) uploaded to {owner}/{repo_name}.")
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
"""Tests for `ria-agent install-udev` and the bundled udev rules."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from importlib.resources import files
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from ria_toolkit_oss.agent import cli as agent_cli
|
|
||||||
|
|
||||||
|
|
||||||
def test_bundled_udev_rules_present_and_cover_usb_sdrs():
|
|
||||||
text = files("ria_toolkit_oss.agent").joinpath("udev", "90-ria-sdr.rules").read_text()
|
|
||||||
# ADALM-Pluto, RTL-SDR, HackRF, and USRP B2x0 VIDs must be covered.
|
|
||||||
for vid in ("0456", "0bda", "1d50", "2500"):
|
|
||||||
assert vid in text
|
|
||||||
|
|
||||||
|
|
||||||
def test_install_udev_requires_root(capsys):
|
|
||||||
args = agent_cli.argparse.Namespace(dest="/etc/udev/rules.d", group="plugdev", no_reload=True)
|
|
||||||
with patch("os.geteuid", return_value=1000):
|
|
||||||
rc = agent_cli._cmd_install_udev(args)
|
|
||||||
assert rc == 1
|
|
||||||
err = capsys.readouterr().err
|
|
||||||
assert "requires root" in err
|
|
||||||
assert "install-udev" in err
|
|
||||||
|
|
||||||
|
|
||||||
def test_install_udev_writes_rules_when_root(tmp_path, monkeypatch, capsys):
|
|
||||||
args = agent_cli.argparse.Namespace(dest=str(tmp_path), group="plugdev", no_reload=True)
|
|
||||||
# No SUDO_USER and --no-reload → no subprocess calls; just the file write.
|
|
||||||
monkeypatch.delenv("SUDO_USER", raising=False)
|
|
||||||
with patch("os.geteuid", return_value=0):
|
|
||||||
rc = agent_cli._cmd_install_udev(args)
|
|
||||||
assert rc == 0
|
|
||||||
written = (tmp_path / "90-ria-sdr.rules").read_text()
|
|
||||||
assert "SUBSYSTEM" in written
|
|
||||||
|
|
@ -79,9 +79,9 @@ def test_user_agent_is_set_and_not_python_default():
|
||||||
"""
|
"""
|
||||||
ua = agent_cli._user_agent()
|
ua = agent_cli._user_agent()
|
||||||
assert ua, "User-Agent must not be empty"
|
assert ua, "User-Agent must not be empty"
|
||||||
assert not ua.lower().startswith("python-urllib"), (
|
assert not ua.lower().startswith(
|
||||||
f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
"python-urllib"
|
||||||
)
|
), f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
||||||
assert ua.startswith("ria-agent/")
|
assert ua.startswith("ria-agent/")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -96,7 +96,10 @@ def test_register_request_carries_explicit_user_agent(tmp_path):
|
||||||
captured["api_key"] = req.get_header("X-api-key")
|
captured["api_key"] = req.get_header("X-api-key")
|
||||||
captured["timeout"] = kwargs.get("timeout")
|
captured["timeout"] = kwargs.get("timeout")
|
||||||
raise urllib.error.HTTPError(
|
raise urllib.error.HTTPError(
|
||||||
url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type]
|
url=req.full_url,
|
||||||
|
code=403,
|
||||||
|
msg="",
|
||||||
|
hdrs=None, # type: ignore[arg-type]
|
||||||
fp=BytesIO(_structured("invalid_key")),
|
fp=BytesIO(_structured("invalid_key")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -140,59 +143,3 @@ def test_register_surfaces_reason_on_http_error(tmp_path, capsys):
|
||||||
assert "Settings → RIA Agents" in captured.err
|
assert "Settings → RIA Agents" in captured.err
|
||||||
# Config must NOT be written on failure.
|
# Config must NOT be written on failure.
|
||||||
assert not cfg_path.exists()
|
assert not cfg_path.exists()
|
||||||
|
|
||||||
|
|
||||||
def test_default_hub_url_is_production():
|
|
||||||
"""Lock in the constant so a future typo doesn't silently redirect users."""
|
|
||||||
assert agent_cli.DEFAULT_HUB_URL == "https://riahub.ai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_register_defaults_hub_to_production(tmp_path):
|
|
||||||
"""Omitting --hub uses the production hub URL constant."""
|
|
||||||
cfg_path = tmp_path / "agent.json"
|
|
||||||
captured: dict = {}
|
|
||||||
|
|
||||||
def _fake_urlopen(req, *args, **kwargs):
|
|
||||||
captured["url"] = req.full_url
|
|
||||||
raise urllib.error.HTTPError(
|
|
||||||
url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type]
|
|
||||||
fp=BytesIO(_structured("invalid_key")),
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False),
|
|
||||||
patch("urllib.request.urlopen", side_effect=_fake_urlopen),
|
|
||||||
patch.object(sys, "argv", ["ria-agent", "register", "--api-key", "ria_reg_x"]),
|
|
||||||
):
|
|
||||||
with pytest.raises(SystemExit):
|
|
||||||
agent_cli.main()
|
|
||||||
|
|
||||||
assert captured["url"] == f"{agent_cli.DEFAULT_HUB_URL}/screens/agents/register"
|
|
||||||
|
|
||||||
|
|
||||||
def test_register_hub_override_wins_over_default(tmp_path):
|
|
||||||
"""Explicit --hub still wins; default is only a fallback."""
|
|
||||||
cfg_path = tmp_path / "agent.json"
|
|
||||||
captured: dict = {}
|
|
||||||
|
|
||||||
def _fake_urlopen(req, *args, **kwargs):
|
|
||||||
captured["url"] = req.full_url
|
|
||||||
raise urllib.error.HTTPError(
|
|
||||||
url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type]
|
|
||||||
fp=BytesIO(_structured("invalid_key")),
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False),
|
|
||||||
patch("urllib.request.urlopen", side_effect=_fake_urlopen),
|
|
||||||
patch.object(
|
|
||||||
sys,
|
|
||||||
"argv",
|
|
||||||
["ria-agent", "register", "--hub", "http://whitehorse:3005", "--api-key", "ria_reg_x"],
|
|
||||||
),
|
|
||||||
):
|
|
||||||
with pytest.raises(SystemExit):
|
|
||||||
agent_cli.main()
|
|
||||||
|
|
||||||
assert captured["url"] == "http://whitehorse:3005/screens/agents/register"
|
|
||||||
assert agent_cli.DEFAULT_HUB_URL not in captured["url"]
|
|
||||||
|
|
|
||||||
|
|
@ -17,16 +17,11 @@ def test_available_devices_sorted_list():
|
||||||
assert "mock" in devices
|
assert "mock" in devices
|
||||||
|
|
||||||
|
|
||||||
def _device_names(hardware_list):
|
|
||||||
return {e["device"] for e in hardware_list}
|
|
||||||
|
|
||||||
|
|
||||||
def test_heartbeat_payload_shape():
|
def test_heartbeat_payload_shape():
|
||||||
p = hardware.heartbeat_payload()
|
p = hardware.heartbeat_payload()
|
||||||
assert p["type"] == "heartbeat"
|
assert p["type"] == "heartbeat"
|
||||||
assert p["status"] == "idle"
|
assert p["status"] == "idle"
|
||||||
# hardware is now a list of rich dict entries.
|
assert "mock" in p["hardware"]
|
||||||
assert "mock" in _device_names(p["hardware"])
|
|
||||||
assert "app_id" not in p
|
assert "app_id" not in p
|
||||||
# New fields, default shape
|
# New fields, default shape
|
||||||
assert p["capabilities"] == ["rx"]
|
assert p["capabilities"] == ["rx"]
|
||||||
|
|
@ -37,53 +32,6 @@ def test_heartbeat_payload_shape():
|
||||||
assert p2["app_id"] == "abc"
|
assert p2["app_id"] == "abc"
|
||||||
|
|
||||||
|
|
||||||
def test_detect_devices_entry_shape():
|
|
||||||
devices = hardware.detect_devices(use_cache=False)
|
|
||||||
assert isinstance(devices, list)
|
|
||||||
for entry in devices:
|
|
||||||
assert set(entry) >= {"device", "identifier", "label", "connected"}
|
|
||||||
assert isinstance(entry["device"], str)
|
|
||||||
# identifier round-trips through parse_ident: None or a string.
|
|
||||||
assert entry["identifier"] is None or isinstance(entry["identifier"], str)
|
|
||||||
mock = next(e for e in devices if e["device"] == "mock")
|
|
||||||
assert mock["label"] # has a human label
|
|
||||||
|
|
||||||
|
|
||||||
def test_detect_devices_cache():
|
|
||||||
a = hardware.detect_devices(use_cache=False)
|
|
||||||
b = hardware.detect_devices(use_cache=True)
|
|
||||||
assert _device_names(a) == _device_names(b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_detect_devices_probe_false_never_touches_hardware(monkeypatch):
|
|
||||||
# probe=False must not run the hardware enumerators (uhd_find_devices etc.),
|
|
||||||
# which would disrupt an active USB capture.
|
|
||||||
def boom():
|
|
||||||
raise AssertionError("hardware must not be probed when probe=False")
|
|
||||||
|
|
||||||
monkeypatch.setattr(hardware, "_detect_devices_uncached", boom)
|
|
||||||
monkeypatch.setattr(hardware, "_probe_cache", None)
|
|
||||||
devices = hardware.detect_devices(probe=False, use_cache=False)
|
|
||||||
assert all(e.get("connected") is None for e in devices) # driver-only
|
|
||||||
|
|
||||||
|
|
||||||
def test_heartbeat_disables_probe_during_active_session(monkeypatch):
|
|
||||||
seen = {}
|
|
||||||
|
|
||||||
def fake_detect(**kw):
|
|
||||||
seen.clear()
|
|
||||||
seen.update(kw)
|
|
||||||
return []
|
|
||||||
|
|
||||||
monkeypatch.setattr(hardware, "detect_devices", fake_detect)
|
|
||||||
|
|
||||||
hardware.heartbeat_payload(sessions={"rx": {"app_id": "a", "state": "streaming"}})
|
|
||||||
assert seen.get("probe") is False # streaming → no hardware probe
|
|
||||||
|
|
||||||
hardware.heartbeat_payload(sessions=None)
|
|
||||||
assert seen.get("probe") is True # idle → probe allowed
|
|
||||||
|
|
||||||
|
|
||||||
def test_heartbeat_payload_tx_capability_from_cfg():
|
def test_heartbeat_payload_tx_capability_from_cfg():
|
||||||
from ria_toolkit_oss.agent.config import AgentConfig
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,24 +9,11 @@ import numpy as np
|
||||||
from ria_toolkit_oss.agent.streamer import (
|
from ria_toolkit_oss.agent.streamer import (
|
||||||
Streamer,
|
Streamer,
|
||||||
_apply_sdr_config,
|
_apply_sdr_config,
|
||||||
_friendly_sdr_error,
|
|
||||||
_samples_to_interleaved_float32,
|
_samples_to_interleaved_float32,
|
||||||
)
|
)
|
||||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||||
|
|
||||||
|
|
||||||
def test_friendly_sdr_error_adds_udev_hint_on_permission():
|
|
||||||
msg = _friendly_sdr_error("usrp", RuntimeError("USB open failed: insufficient permissions."))
|
|
||||||
assert "install-udev" in msg
|
|
||||||
assert "insufficient permissions" in msg
|
|
||||||
|
|
||||||
|
|
||||||
def test_friendly_sdr_error_passes_through_other_errors():
|
|
||||||
msg = _friendly_sdr_error("usrp", RuntimeError("No USRP device found for identifier 'name=x'"))
|
|
||||||
assert "install-udev" not in msg
|
|
||||||
assert "No USRP device found" in msg
|
|
||||||
|
|
||||||
|
|
||||||
class FakeWs:
|
class FakeWs:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.json_sent: list[dict] = []
|
self.json_sent: list[dict] = []
|
||||||
|
|
|
||||||
|
|
@ -1,91 +0,0 @@
|
||||||
"""Hardware-free tests for _open_multi_usrp's transient-FX3 retry.
|
|
||||||
|
|
||||||
On B200/B210 the `uhd_find_devices` enumeration that runs right before opening
|
|
||||||
can leave the FX3 USB controller mid-reset, so the first MultiUSRP open fails
|
|
||||||
with "fx3 is in state 5". _open_multi_usrp retries transient USB states with a
|
|
||||||
short settle; a non-transient error is re-raised immediately.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def usrp_mod(monkeypatch):
|
|
||||||
"""Import the usrp module against a stub `uhd`, with time.sleep neutered."""
|
|
||||||
saved_uhd = sys.modules.get("uhd")
|
|
||||||
saved_usrp = sys.modules.get("ria_toolkit_oss.sdr.usrp")
|
|
||||||
|
|
||||||
uhd = types.ModuleType("uhd")
|
|
||||||
uhd.usrp = types.SimpleNamespace(MultiUSRP=None) # set per-test
|
|
||||||
sys.modules["uhd"] = uhd
|
|
||||||
sys.modules.pop("ria_toolkit_oss.sdr.usrp", None)
|
|
||||||
import ria_toolkit_oss.sdr.usrp as mod
|
|
||||||
|
|
||||||
monkeypatch.setattr(mod.time, "sleep", lambda *_a, **_k: None)
|
|
||||||
|
|
||||||
yield mod
|
|
||||||
|
|
||||||
for name, m in (("uhd", saved_uhd), ("ria_toolkit_oss.sdr.usrp", saved_usrp)):
|
|
||||||
if m is None:
|
|
||||||
sys.modules.pop(name, None)
|
|
||||||
else:
|
|
||||||
sys.modules[name] = m
|
|
||||||
|
|
||||||
|
|
||||||
def _flaky_factory(fail_times, exc):
|
|
||||||
"""A MultiUSRP stand-in that raises `exc` the first `fail_times` calls."""
|
|
||||||
calls = {"n": 0}
|
|
||||||
|
|
||||||
def make(args):
|
|
||||||
calls["n"] += 1
|
|
||||||
if calls["n"] <= fail_times:
|
|
||||||
raise exc
|
|
||||||
return f"usrp<{args}>"
|
|
||||||
|
|
||||||
make.calls = calls
|
|
||||||
return make
|
|
||||||
|
|
||||||
|
|
||||||
def test_retries_transient_fx3_state_then_succeeds(usrp_mod):
|
|
||||||
factory = _flaky_factory(2, RuntimeError("RuntimeError: fx3 is in state 5"))
|
|
||||||
usrp_mod.uhd.usrp.MultiUSRP = factory
|
|
||||||
|
|
||||||
out = usrp_mod._open_multi_usrp("name=B210,", attempts=4, settle_s=0)
|
|
||||||
|
|
||||||
assert out == "usrp<name=B210,>"
|
|
||||||
assert factory.calls["n"] == 3 # failed twice, third succeeded
|
|
||||||
|
|
||||||
|
|
||||||
def test_gives_up_after_attempts_and_raises_last(usrp_mod):
|
|
||||||
factory = _flaky_factory(99, RuntimeError("fx3 is in state 5"))
|
|
||||||
usrp_mod.uhd.usrp.MultiUSRP = factory
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="fx3 is in state 5"):
|
|
||||||
usrp_mod._open_multi_usrp("name=B210,", attempts=3, settle_s=0)
|
|
||||||
|
|
||||||
assert factory.calls["n"] == 3 # exactly `attempts` tries, no infinite loop
|
|
||||||
|
|
||||||
|
|
||||||
def test_non_transient_error_is_raised_immediately(usrp_mod):
|
|
||||||
factory = _flaky_factory(99, RuntimeError("EnvironmentError: no UHD images"))
|
|
||||||
usrp_mod.uhd.usrp.MultiUSRP = factory
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="no UHD images"):
|
|
||||||
usrp_mod._open_multi_usrp("name=B210,", attempts=4, settle_s=0)
|
|
||||||
|
|
||||||
assert factory.calls["n"] == 1 # not retried — fails fast
|
|
||||||
|
|
||||||
|
|
||||||
def test_success_on_first_try_does_not_retry(usrp_mod):
|
|
||||||
factory = _flaky_factory(0, RuntimeError("fx3 is in state 5"))
|
|
||||||
usrp_mod.uhd.usrp.MultiUSRP = factory
|
|
||||||
|
|
||||||
out = usrp_mod._open_multi_usrp("addr=192.168.10.2,", attempts=4, settle_s=0)
|
|
||||||
|
|
||||||
assert out == "usrp<addr=192.168.10.2,>"
|
|
||||||
assert factory.calls["n"] == 1
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
"""Hardware-free tests for the USRP continuous-streaming rx().
|
|
||||||
|
|
||||||
`uhd` isn't importable without the UHD install, so we stub the bits USRP.rx()
|
|
||||||
touches and drive it with a scripted fake rx_stream. The point is to prove the
|
|
||||||
capture is gapless across rx() calls — the property that fixes the choppy /
|
|
||||||
black-banded spectrogram caused by the old start/stop-per-buffer record().
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def _install_fake_uhd():
|
|
||||||
uhd = types.ModuleType("uhd")
|
|
||||||
|
|
||||||
class StreamCMD:
|
|
||||||
def __init__(self, mode):
|
|
||||||
self.mode = mode
|
|
||||||
self.stream_now = False
|
|
||||||
self.time_spec = None
|
|
||||||
|
|
||||||
uhd.types = types.SimpleNamespace(
|
|
||||||
StreamCMD=StreamCMD,
|
|
||||||
StreamMode=types.SimpleNamespace(start_cont="start_cont", stop_cont="stop_cont"),
|
|
||||||
RXMetadataErrorCode=types.SimpleNamespace(none="none", overflow="overflow", timeout="timeout"),
|
|
||||||
)
|
|
||||||
uhd.usrp = types.SimpleNamespace()
|
|
||||||
sys.modules["uhd"] = uhd
|
|
||||||
return uhd
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def USRP():
|
|
||||||
# Snapshot so the fake uhd / freshly-imported usrp don't leak into other
|
|
||||||
# tests (e.g. detect_available() would otherwise think usrp is importable).
|
|
||||||
saved_uhd = sys.modules.get("uhd")
|
|
||||||
saved_usrp = sys.modules.get("ria_toolkit_oss.sdr.usrp")
|
|
||||||
|
|
||||||
_install_fake_uhd()
|
|
||||||
sys.modules.pop("ria_toolkit_oss.sdr.usrp", None)
|
|
||||||
from ria_toolkit_oss.sdr.usrp import USRP as _USRP
|
|
||||||
|
|
||||||
yield _USRP
|
|
||||||
|
|
||||||
for name, mod in (("uhd", saved_uhd), ("ria_toolkit_oss.sdr.usrp", saved_usrp)):
|
|
||||||
if mod is None:
|
|
||||||
sys.modules.pop(name, None)
|
|
||||||
else:
|
|
||||||
sys.modules[name] = mod
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeStream:
|
|
||||||
"""Delivers a contiguous ramp of samples; ``real`` part is the sample index.
|
|
||||||
|
|
||||||
``script`` is a list of (count, error_code) the recv loop walks through.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, script, metadata):
|
|
||||||
self._script = list(script)
|
|
||||||
self._metadata = metadata
|
|
||||||
self._counter = 0
|
|
||||||
self.issued = []
|
|
||||||
|
|
||||||
def issue_stream_cmd(self, cmd):
|
|
||||||
self.issued.append(cmd.mode)
|
|
||||||
|
|
||||||
def recv(self, buffer, metadata, timeout):
|
|
||||||
count, err = self._script.pop(0)
|
|
||||||
metadata.error_code = err
|
|
||||||
if count > 0:
|
|
||||||
idx = np.arange(self._counter, self._counter + count, dtype=np.float32)
|
|
||||||
buffer[0, :count] = idx.astype(np.complex64)
|
|
||||||
self._counter += count
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
def _make_usrp(USRP, script, rx_buffer_size=4):
|
|
||||||
u = USRP.__new__(USRP)
|
|
||||||
u._rx_initialized = True
|
|
||||||
u._rx_streaming = False
|
|
||||||
u._rx_residual = np.empty(0, dtype=np.complex64)
|
|
||||||
u.rx_buffer_size = rx_buffer_size
|
|
||||||
u.timeout = 0.1
|
|
||||||
u._enable_rx = False
|
|
||||||
u.metadata = types.SimpleNamespace(error_code="none")
|
|
||||||
u.rx_stream = _FakeStream(script, u.metadata)
|
|
||||||
return u
|
|
||||||
|
|
||||||
|
|
||||||
def test_rx_is_gapless_across_calls(USRP):
|
|
||||||
# rx_buffer_size=4; each recv yields 4 fresh samples. Two rx(6) calls must
|
|
||||||
# return a contiguous 0..11 ramp — the over-read remainder is carried over.
|
|
||||||
script = [(4, "none")] * 4
|
|
||||||
u = _make_usrp(USRP, script)
|
|
||||||
|
|
||||||
first = u.rx(6)
|
|
||||||
second = u.rx(6)
|
|
||||||
|
|
||||||
assert first.dtype == np.complex64 and len(first) == 6
|
|
||||||
combined = np.concatenate([first, second]).real
|
|
||||||
assert np.array_equal(combined, np.arange(12, dtype=np.float32)) # no drops, no zeros
|
|
||||||
assert "start_cont" in u.rx_stream.issued # stream started exactly via start_cont
|
|
||||||
assert u.rx_stream.issued.count("start_cont") == 1 # ...and only once
|
|
||||||
|
|
||||||
|
|
||||||
def test_rx_starts_stream_only_once(USRP):
|
|
||||||
u = _make_usrp(USRP, [(4, "none")] * 6)
|
|
||||||
u.rx(4)
|
|
||||||
u.rx(4)
|
|
||||||
assert u.rx_stream.issued.count("start_cont") == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_rx_keeps_going_on_overflow(USRP):
|
|
||||||
# Overflow samples are still valid — they must be used, not dropped.
|
|
||||||
script = [(2, "none"), (2, "overflow"), (2, "none")]
|
|
||||||
u = _make_usrp(USRP, script)
|
|
||||||
out = u.rx(6).real
|
|
||||||
assert np.array_equal(out, np.arange(6, dtype=np.float32))
|
|
||||||
|
|
||||||
|
|
||||||
def test_rx_raises_on_persistent_timeout(USRP):
|
|
||||||
from ria_toolkit_oss.sdr.sdr import SdrDisconnectedError
|
|
||||||
|
|
||||||
u = _make_usrp(USRP, [(0, "timeout")] * 10)
|
|
||||||
with pytest.raises(SdrDisconnectedError):
|
|
||||||
u.rx(4)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stop_rx_stream_resets_state(USRP):
|
|
||||||
u = _make_usrp(USRP, [(4, "none")] * 4)
|
|
||||||
u.rx(6) # leaves a 2-sample residual, stream running
|
|
||||||
assert u._rx_streaming is True
|
|
||||||
assert u._rx_residual.size == 2
|
|
||||||
u._stop_rx_stream()
|
|
||||||
assert u._rx_streaming is False
|
|
||||||
assert u._rx_residual.size == 0
|
|
||||||
assert "stop_cont" in u.rx_stream.issued
|
|
||||||
|
|
@ -199,3 +199,44 @@ def test_annotation_to_sigmf_format_values():
|
||||||
values = list(result.values())
|
values = list(result.values())
|
||||||
assert 50 in values or ann.sample_start in values
|
assert 50 in values or ann.sample_start in values
|
||||||
assert 100 in values or ann.sample_count in values
|
assert 100 in values or ann.sample_count in values
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# None freq-edge regression tests (SigMF optional fields)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_no_freq_edges():
|
||||||
|
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||||
|
assert ann.freq_lower_edge is None
|
||||||
|
assert ann.freq_upper_edge is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_is_valid_no_freq_edges():
|
||||||
|
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||||
|
assert ann.is_valid() is True
|
||||||
|
|
||||||
|
ann_zero = Annotation(sample_start=0, sample_count=0, label="burst")
|
||||||
|
assert ann_zero.is_valid() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_overlap_none_edges_returns_zero():
|
||||||
|
ann1 = Annotation(sample_start=0, sample_count=10)
|
||||||
|
ann2 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||||
|
assert ann1.overlap(ann2) == 0
|
||||||
|
assert ann2.overlap(ann1) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_area_none_edges_returns_zero():
|
||||||
|
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||||
|
assert ann.area() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_to_sigmf_omits_freq_keys_when_none():
|
||||||
|
from sigmf import SigMFFile
|
||||||
|
|
||||||
|
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||||
|
result = ann.to_sigmf_format()
|
||||||
|
metadata = result["metadata"]
|
||||||
|
assert SigMFFile.FLO_KEY not in metadata
|
||||||
|
assert SigMFFile.FHI_KEY not in metadata
|
||||||
|
|
|
||||||
|
|
@ -189,3 +189,21 @@ def test_sigmf_3(tmp_path):
|
||||||
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
|
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
assert str(e) == "File already exists"
|
assert str(e) == "File already exists"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sigmf_annotation_without_freq_edges(tmp_path):
|
||||||
|
# Regression: annotations that omit the optional SigMF freq edge fields must
|
||||||
|
# load without error; edges should be None and the annotation still valid.
|
||||||
|
ann = Annotation(sample_start=0, sample_count=5, label="burst")
|
||||||
|
recording1 = Recording(data=complex_data_1, metadata=sample_metadata, annotations=[ann])
|
||||||
|
|
||||||
|
filename = tmp_path / "test"
|
||||||
|
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name, overwrite=True)
|
||||||
|
recording2 = from_sigmf(filename)
|
||||||
|
|
||||||
|
assert len(recording2.annotations) == 1
|
||||||
|
loaded = recording2.annotations[0]
|
||||||
|
assert loaded.freq_lower_edge is None
|
||||||
|
assert loaded.freq_upper_edge is None
|
||||||
|
assert loaded.is_valid() is True
|
||||||
|
assert loaded.label == "burst"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user