# gain-viz/gain_viz/app.py
# 2025-12-17 22:42:21 -05:00
#
# 630 lines
# 20 KiB
# Python

from flask import Flask, render_template, send_file, request, jsonify
import zmq
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import os
import threading
import time
import serial
import json
import subprocess
import io
import base64 # ADD THIS IMPORT
from PIL import Image
from flask_socketio import SocketIO, emit
# Fallback image that /plot serves until the plotting thread has stored a frame.
PLOT_PATH = os.path.join(os.getcwd(), "plot.png")

app = Flask(__name__)
# Socket.IO server that pushes rendered frames to browsers; threading async
# mode so it can be driven from the plain background threads below.
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading')

# ----------------- Shared Config -----------------
# Runtime settings shared between request handlers and background threads;
# readers and writers elsewhere in this module hold config_lock.
config = {
    "usrp_tx_gain": 60,      # USRP gains, applied through the RAN tmux console
    "usrp_rx_gain": 30,
    "scm_tx_gain": 30,       # SCM gains, applied over the serial link
    "scm_rx_gain": 30,
    "sample_rate": 23.04e6,  # Hz
    "window_ms": 20,         # length of the plotted IQ window, in ms
    "center_freq": 3.415e9,  # Hz
    "NFFT": 1024,            # spectrogram FFT size
    "tcp_port": 5556,        # ZMQ publisher port on localhost
    "streaming": False,      # True while the plotting thread should render
}
config_lock = threading.Lock()

# Last-applied gains; gain_update() compares incoming requests against these.
usrp_tx_gain, usrp_rx_gain, scm_tx_gain, scm_rx_gain = (
    config[key]
    for key in ("usrp_tx_gain", "usrp_rx_gain", "scm_tx_gain", "scm_rx_gain")
)

# Plotting thread control
plot_thread = None
stop_event, pause_event = threading.Event(), threading.Event()

# TMUX output capture: captured pane lines, guarded by tmux_lock
tmux_output = []
tmux_lock = threading.Lock()
tmux_thread = None
tmux_stop_event = threading.Event()

# Latest rendered PNG frame, guarded by plot_lock
plot_lock = threading.Lock()
plot_buffer = io.BytesIO()

# WebSocket sender thread
# NOTE(review): not referenced by the code visible here — possibly unused.
websocket_thread = None
# ----------------- Serial / SCM -----------------
def connect_serial(port, baudrate=115200, timeout=1):
    """Open *port* as an 8E1 serial link (8 data bits, even parity, 1 stop bit).

    Args:
        port: Device path, e.g. "/dev/ttyUSB0".
        baudrate: Line speed in baud.
        timeout: Read timeout in seconds, passed to ``serial.Serial``.

    Returns:
        The open ``serial.Serial`` object, or ``None`` (after printing the
        error) when the port cannot be opened.
    """
    settings = dict(
        port=port,
        baudrate=baudrate,
        timeout=timeout,
        bytesize=serial.EIGHTBITS,
        parity=serial.PARITY_EVEN,
        stopbits=serial.STOPBITS_ONE,
    )
    try:
        return serial.Serial(**settings)
    except serial.SerialException as err:
        print(f"Error connecting to {port}: {err}")
        return None
def send_command(ser, command):
    """Write *command* to *ser* encoded as UTF-8.

    Does nothing when *ser* is falsy (e.g. ``None``) or not open.
    """
    if not ser or not ser.is_open:
        return
    payload = command.encode('utf-8')
    ser.write(payload)
def receive_feedback(ser):
    """Read the device's reply and return its final status line (e.g. "OK").

    The reply is drained with ``ser.readlines()`` and decoded as UTF-8, with
    undecodable bytes replaced. The result is the last non-empty line that
    was terminated by a carriage return, stripped of surrounding whitespace.
    If no line carries a CR (e.g. the device omits it), the last non-empty
    fragment is used instead.

    Returns:
        The status line, or "" when *ser* is missing/closed, nothing was
        received, or a ``serial.SerialTimeoutException`` occurred.
    """
    if not (ser and ser.is_open):
        return ""
    try:
        ser.flush()
        raw_lines = ser.readlines()
    except serial.SerialTimeoutException:
        return ""

    # Work on decoded text, not on str(bytes) reprs: splitting the repr on a
    # literal backslash-r only worked for single-line replies. An echoed
    # command followed by "OK" yielded a fragment like "\n' ,b'OK". A reply
    # with no CR at all raised IndexError.
    decoded = [raw.decode('utf-8', errors='replace') for raw in raw_lines]
    candidates = [text for text in (line.strip() for line in decoded if '\r' in line) if text]
    if not candidates:
        candidates = [line.strip() for line in decoded if line.strip()]
    return candidates[-1] if candidates else ""
def scm_conf(port, baudrate, rx_cmd, tx_cmd, max_attempts=5):
    """Send the RX and then the TX gain command to the SCM on *port*.

    Each command is terminated with "\\r" and resent until the reply read by
    receive_feedback() is "OK", at most *max_attempts* times. A command that
    is never acknowledged is reported on stdout. The port is always closed,
    even if sending or reading raises.

    Args:
        port: Serial device path.
        baudrate: Line speed in baud.
        rx_cmd: RX gain command, sent first.
        tx_cmd: TX gain command, sent second.
        max_attempts: Maximum sends per command.

    Returns:
        True if the port could be opened (regardless of acknowledgements),
        False otherwise.
    """
    ser = connect_serial(port, baudrate)
    if not ser:
        return False
    try:
        for cmd in (rx_cmd, tx_cmd):
            feedback = None
            for _ in range(max_attempts):
                send_command(ser, cmd + "\r")
                feedback = receive_feedback(ser)
                if feedback == "OK":
                    break
            else:
                print(f"SCM on {port} did not acknowledge {cmd!r} after "
                      f"{max_attempts} attempts (last reply: {feedback!r})")
    finally:
        ser.close()
    return True
# ----------------- TMUX Output Capture -----------------
def capture_tmux_output(session="ran", max_lines=100):
    """Mirror the visible pane of tmux *session* into ``tmux_output`` until stopped.

    Intended as a background-thread target; runs until ``tmux_stop_event`` is
    set. About once a second it checks that the session exists.
    - If it does, the last *max_lines* lines of ``tmux capture-pane -p`` are
      stored in ``tmux_output``.
    - Otherwise a one-line "not found" message is stored.

    On any error the error text is stored instead and the loop backs off for
    5 seconds. All writes to ``tmux_output`` happen under ``tmux_lock``.

    Args:
        session: tmux session name to capture.
        max_lines: Maximum number of pane lines kept.
    """
    while not tmux_stop_event.is_set():
        try:
            # argv lists instead of shell strings: no shell is spawned and
            # the session name is never shell-interpreted.
            probe = subprocess.run(
                ["tmux", "has-session", "-t", session],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            if probe.returncode == 0:
                # errors="replace": one undecodable byte in the pane should
                # not discard the whole capture.
                output = subprocess.check_output(
                    ["tmux", "capture-pane", "-t", session, "-p"],
                    encoding="utf-8",
                    errors="replace",
                )
                lines = output.split('\n')
                with tmux_lock:
                    tmux_output[:] = lines[max(0, len(lines) - max_lines):]
            else:
                with tmux_lock:
                    tmux_output[:] = [f"TMUX session '{session}' not found. Please start the RAN application."]
            delay = 1
        except Exception as e:
            print(f"Error capturing tmux output: {e}")
            with tmux_lock:
                tmux_output[:] = [f"Error capturing tmux output: {str(e)}"]
            delay = 5  # back off longer after an error
        # Event.wait returns as soon as a stop is requested, so the 2 s join
        # in stop_tmux_capture is not outlasted by an uninterruptible sleep.
        tmux_stop_event.wait(delay)
def start_tmux_capture():
    """Clear the stop flag and launch the tmux-capture daemon thread if none is alive.

    Returns:
        Always True.
    """
    global tmux_thread
    tmux_stop_event.clear()
    already_running = tmux_thread is not None and tmux_thread.is_alive()
    if not already_running:
        tmux_thread = threading.Thread(target=capture_tmux_output, daemon=True)
        tmux_thread.start()
        print("TMUX capture thread started")
    return True
def stop_tmux_capture():
    """Signal the tmux-capture thread to stop and wait up to 2 s for it to exit.

    Returns:
        Always True.
    """
    tmux_stop_event.set()
    worker = tmux_thread
    if worker and worker.is_alive():
        worker.join(timeout=2.0)
    print("TMUX capture thread stopped")
    return True
# ----------------- Gain Updates -----------------
def _send_ran_console(command, session="ran"):
    """Type *command* into tmux *session* followed by Enter (best effort).

    A failure to launch tmux is printed and otherwise ignored.
    """
    try:
        # argv list, no shell: *command* reaches tmux verbatim as one argument.
        subprocess.run(["tmux", "send-keys", "-t", session, command, "C-m"], check=False)
    except OSError as err:
        print(f"Error sending {command!r} to tmux session {session!r}: {err}")


def gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx):
    """Apply every gain that differs from the last-applied value.

    - USRP gains are sent as ``tx_gain 0 <v>`` / ``rx_gain 0 <v>`` to the RAN
      console in tmux session "ran".
    - If either SCM gain changed, both SCM commands are written over serial
      to /dev/ttyUSB0 and /dev/ttyUSB1.

    The module-level gain globals and the matching ``config`` entries are
    updated.

    Returns:
        Always True. Per-device serial failures are reported by the serial
        helpers, not through this return value.
    """
    global usrp_tx_gain, usrp_rx_gain, scm_tx_gain, scm_rx_gain

    if usrp_tx != usrp_tx_gain:
        usrp_tx_gain = usrp_tx
        _send_ran_console(f"tx_gain 0 {usrp_tx_gain} ")
    if usrp_rx != usrp_rx_gain:
        usrp_rx_gain = usrp_rx
        _send_ran_console(f"rx_gain 0 {usrp_rx_gain} ")

    scm_change = False
    # Assign only when changed, so an equal value (e.g. 30.0 vs 30) keeps the
    # stored representation used in the SCM commands and /get_gains.
    if scm_tx != scm_tx_gain:
        scm_tx_gain = scm_tx
        scm_change = True
    if scm_rx != scm_rx_gain:
        scm_rx_gain = scm_rx
        scm_change = True

    if scm_change:
        t_cmd = f"HW:GAIN 0 TX 0 {scm_tx_gain}"
        r_cmd = f"HW:GAIN 1 RX 0 {scm_rx_gain}"
        for port in ("/dev/ttyUSB0", "/dev/ttyUSB1"):
            scm_conf(port, 115200, r_cmd, t_cmd)

    with config_lock:
        config["usrp_tx_gain"] = usrp_tx_gain
        config["usrp_rx_gain"] = usrp_rx_gain
        if scm_change:
            config["scm_tx_gain"] = scm_tx_gain
            config["scm_rx_gain"] = scm_rx_gain
    return True
# ----------------- Plot Generation -----------------
def _spectrum_render_png(draw, *draw_args, **subplot_kw):
    """Create a figure with plt.subplots(**subplot_kw), draw it, and return PNG bytes.

    The figure is always closed, even when drawing or saving raises, so a
    persistent error cannot accumulate open pyplot figures.
    """
    fig, axes = plt.subplots(**subplot_kw)
    try:
        draw(fig, axes, *draw_args)
        with io.BytesIO() as buf:
            # fig.savefig, not plt.savefig: pyplot's "current figure" is
            # process-global, so a figure created concurrently by another
            # thread could otherwise be saved in place of this one.
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=80)
            return buf.getvalue()
    finally:
        plt.close(fig)


def _spectrum_draw_message(fig, ax, text, title, fontsize):
    """Draw a centred status message with a title on a single axes."""
    ax.text(0.5, 0.5, text,
            ha='center', va='center', transform=ax.transAxes, fontsize=fontsize)
    ax.set_title(title)


def _spectrum_draw_frame(fig, axes, iq_sample, sample_rate, window_ms, center_freq, nfft):
    """Draw the IQ time trace (top axes) and the spectrogram (bottom axes)."""
    ax_time, ax_spec = axes
    fig.subplots_adjust(hspace=0.3)

    # Time-domain plot
    times_ms = np.arange(len(iq_sample)) * 1000 / sample_rate
    ax_time.plot(times_ms, np.real(iq_sample), label="Real", color='b', linewidth=0.8)
    ax_time.plot(times_ms, np.imag(iq_sample), label="Imag", color='r', linewidth=0.8)
    ax_time.set_xlim(0, window_ms)
    ax_time.set_xlabel("Time (ms)")
    ax_time.set_ylabel("IQ Amplitude")
    ax_time.grid(True, which='both', linestyle='--', linewidth=0.5)
    ax_time.legend(fontsize=8)

    # Spectrogram. noverlap was a hard-coded 512, which makes specgram raise
    # "noverlap must be less than NFFT" for NFFT <= 512. Half-overlap keeps
    # 512 for the default NFFT of 1024.
    ax_spec.specgram(
        iq_sample,
        Fs=sample_rate,
        Fc=center_freq,
        NFFT=nfft,
        noverlap=nfft // 2,
        cmap=plt.get_cmap('twilight'),
    )
    ax_spec.set_xlabel("Time (ms)")
    ax_spec.set_ylabel("Frequency (Hz)")
    ax_spec.grid(False)
    ax_spec.set_ylim(center_freq - sample_rate / 2,
                     center_freq + sample_rate / 2)
    # specgram's time axis is in seconds; label the ticks in milliseconds.
    ax_spec.xaxis.set_major_formatter(
        ticker.FuncFormatter(lambda t, pos: '{0:g}'.format(t * 1e3))
    )
    ax_spec.xaxis.set_minor_locator(ticker.AutoMinorLocator())


def _spectrum_store_png(png):
    """Replace the frame that /plot serves with *png* (under plot_lock)."""
    with plot_lock:
        plot_buffer.seek(0)
        plot_buffer.truncate()
        plot_buffer.write(png)
        plot_buffer.seek(0)


def _spectrum_connect(context, tcp_port):
    """Open a SUB socket on *context* subscribed to everything at localhost:*tcp_port*."""
    sub = context.socket(zmq.SUB)
    try:
        sub.setsockopt(zmq.CONFLATE, 1)    # keep only the newest message
        sub.setsockopt(zmq.LINGER, 0)      # close()/term() never block
        sub.setsockopt_string(zmq.SUBSCRIBE, "")
        sub.setsockopt(zmq.RCVTIMEO, 1000)  # recv() raises zmq.Again after 1 s
        sub.connect(f"tcp://localhost:{tcp_port}")
    except Exception:
        sub.close()
        raise
    return sub


def _spectrum_window(msg, window_samples, previous):
    """Return the newest *window_samples* complex samples decoded from *msg*.

    *msg* is read as float32 values, taken as consecutive (real, imag) pairs.
    Shorter payloads are zero-padded at the front. *previous* is returned
    unchanged when the payload holds less than one pair.
    """
    float_data = np.frombuffer(msg, dtype=np.float32)
    if float_data.size < 2:
        return previous
    pairs = float_data.reshape(-1, 2)
    iq_all = pairs[:, 0] + 1j * pairs[:, 1]
    if len(iq_all) >= window_samples:
        return iq_all[-window_samples:]
    return np.pad(iq_all, (window_samples - len(iq_all), 0))


def generate_spectrum_plot():
    """Background loop: receive IQ data over ZMQ and publish spectrum plots.

    Runs until ``stop_event`` is set, idling while ``pause_event`` is set or
    ``config["streaming"]`` is False. Otherwise, on each iteration it:
    - (re)connects to localhost:``config["tcp_port"]`` when needed;
    - waits up to 1 s for the newest message;
    - renders a time-domain + spectrogram PNG into ``plot_buffer`` and emits
      it base64-encoded as the Socket.IO event 'plot_update'.

    On timeout or error, a status image is stored in ``plot_buffer`` (not
    emitted). The ZMQ socket and context are released on exit.
    """
    # One context for the whole thread; previously a new one was created
    # (and never terminated) on every reconnect.
    context = zmq.Context()
    sub = None
    last_port = None
    iq_sample = np.zeros(1, dtype=np.complex64)
    try:
        while not stop_event.is_set():
            if pause_event.is_set():
                time.sleep(0.1)
                continue

            with config_lock:
                sample_rate = config["sample_rate"]
                window_ms = config["window_ms"]
                center_freq = config["center_freq"]
                nfft = config["NFFT"]
                tcp_port = config["tcp_port"]
                streaming = config["streaming"]

            if not streaming:
                time.sleep(0.1)
                continue

            # Reconnect if the port changed or there is no socket yet.
            if sub is None or tcp_port != last_port:
                if sub is not None:
                    sub.close()
                    sub = None
                try:
                    sub = _spectrum_connect(context, tcp_port)
                except Exception as e:
                    print(f"ZMQ connection error: {e}")
                    time.sleep(1)
                    continue
                last_port = tcp_port
                print(f"Connected to ZMQ on port {tcp_port}")

            try:
                # Inside the try so an invalid (negative/NaN) window shows an
                # error frame instead of killing the thread; at least 1 sample.
                window_samples = max(1, int(sample_rate * window_ms / 1000))
                if iq_sample.size != window_samples:
                    iq_sample = np.zeros(window_samples, dtype=np.complex64)

                # Blocking recv honours RCVTIMEO. The previous NOBLOCK flag
                # bypassed it, so every poll that fell between two frames
                # replaced the last plot with "Waiting for data...".
                msg = sub.recv()
                iq_sample = _spectrum_window(msg, window_samples, iq_sample)

                png = _spectrum_render_png(
                    _spectrum_draw_frame,
                    iq_sample, sample_rate, window_ms, center_freq, nfft,
                    nrows=2, ncols=1, figsize=(10, 5), dpi=80,
                )
                _spectrum_store_png(png)
                img_data = base64.b64encode(png).decode('utf-8')
                socketio.emit('plot_update', {'image': img_data})
            except zmq.Again:
                # No message within RCVTIMEO
                _spectrum_store_png(_spectrum_render_png(
                    _spectrum_draw_message,
                    "Waiting for data...",
                    "Spectrum Analyzer - No Data (Streaming Active)",
                    14,
                    figsize=(10, 5),
                ))
            except Exception as e:
                print(f"Plot generation error: {e}")
                _spectrum_store_png(_spectrum_render_png(
                    _spectrum_draw_message,
                    f"Error: {str(e)}",
                    "Spectrum Analyzer - Error",
                    12,
                    figsize=(10, 5),
                ))
            time.sleep(0.1)
    finally:
        if sub is not None:
            sub.close()
        context.term()
    print("Plotting thread stopped")
def start_plotting():
    """Mark streaming active and make sure the spectrum thread is running.

    Clears the stop and pause flags, sets ``config["streaming"]``, and starts
    a new daemon thread only if no live one exists.

    Returns:
        Always True.
    """
    global plot_thread
    for flag in (stop_event, pause_event):
        flag.clear()
    with config_lock:
        config["streaming"] = True
    if plot_thread is not None and plot_thread.is_alive():
        return True
    plot_thread = threading.Thread(target=generate_spectrum_plot, daemon=True)
    plot_thread.start()
    print("Plotting thread started")
    return True
def stop_plotting():
    """Stop streaming, wait up to 2 s for the plotting thread, and show a "Stopped" frame.

    Returns:
        Always True.
    """
    with config_lock:
        config["streaming"] = False
    stop_event.set()
    if plot_thread and plot_thread.is_alive():
        plot_thread.join(timeout=2.0)

    # Create the stopped-message plot.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.text(0.5, 0.5, "Streaming Stopped\nClick Start to begin",
                ha='center', va='center', transform=ax.transAxes, fontsize=14)
        ax.set_title("Spectrum Analyzer - Stopped")
        with io.BytesIO() as buf:
            # fig.savefig, not plt.savefig: if the join above timed out, the
            # plotting thread may still be creating figures, and pyplot's
            # "current figure" is shared across threads.
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=80)
            png = buf.getvalue()
    finally:
        plt.close(fig)

    with plot_lock:
        plot_buffer.seek(0)
        plot_buffer.truncate()
        plot_buffer.write(png)
        plot_buffer.seek(0)
    print("Plotting thread stopped")
    return True
def pause_plotting():
    """Toggle the pause flag of the plotting loop.

    Returns:
        The new state: "paused" or "running".
    """
    was_paused = pause_event.is_set()
    if was_paused:
        pause_event.clear()
        print("Plotting resumed")
        return "running"
    pause_event.set()
    print("Plotting paused")
    return "paused"
# ----------------- Flask Routes -----------------
@app.route('/')
def index():
    """Render templates/index.html, or return 404 if the template is missing."""
    # Check the app's template folder (relative to app.root_path), the folder
    # Flask's template loader reads. The old 'templates/index.html' check was
    # relative to the process CWD, so it returned 404 when the server was
    # started from another directory even though render_template would work.
    template_dir = os.path.join(app.root_path, app.template_folder or 'templates')
    if os.path.exists(os.path.join(template_dir, 'index.html')):
        return render_template('index.html')
    return "Template not found", 404
@app.route('/update_gains', methods=['POST'])
def update_gains():
    """Apply gains posted as form fields.

    Any field that is absent (or not parseable as float) keeps its current
    value. Responds with JSON: success, or status "error" with HTTP 500.
    """
    field_names = ('usrp_tx_gain', 'usrp_rx_gain', 'scm_tx_gain', 'scm_rx_gain')
    try:
        submitted = [request.form.get(name, type=float) for name in field_names]
        current = (usrp_tx_gain, usrp_rx_gain, scm_tx_gain, scm_rx_gain)
        merged = [old if new is None else new for new, old in zip(submitted, current)]
        if gain_update(*merged):
            return jsonify({"status": "success", "message": "Gains updated successfully"})
        return jsonify({"status": "error", "message": "Failed to update gains"}), 500
    except Exception as exc:
        return jsonify({"status": "error", "message": f"Error: {str(exc)}"}), 500
@app.route('/plot')
def plot():
    """Serve the latest rendered frame as an uncacheable PNG.

    Falls back to the placeholder image at PLOT_PATH when no frame has been
    rendered yet or serving the frame fails. Returns a plain 500 if even the
    placeholder cannot be sent.
    """
    try:
        with plot_lock:
            plot_buffer.seek(0)
            img_data = plot_buffer.read()
        if not img_data:
            # Return placeholder if buffer is empty
            return send_file(PLOT_PATH, mimetype='image/png')
        # No cache_timeout argument: Flask removed it in 2.2 (renamed to
        # max_age), where it raised TypeError and this route always fell
        # back to the placeholder. Caching is disabled by the headers below.
        response = send_file(io.BytesIO(img_data), mimetype='image/png')
        response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
        response.headers['Pragma'] = 'no-cache'
        response.headers['Expires'] = '0'
        return response
    except Exception as e:
        print(f"Error serving plot: {e}")
        try:
            return send_file(PLOT_PATH, mimetype='image/png')
        except Exception:
            return "Error serving plot", 500
@app.route('/get_gains')
def get_gains():
    """Report the last-applied USRP and SCM gains as JSON."""
    gains = dict(
        usrp_tx_gain=usrp_tx_gain,
        usrp_rx_gain=usrp_rx_gain,
        scm_tx_gain=scm_tx_gain,
        scm_rx_gain=scm_rx_gain,
    )
    return jsonify(gains)
@app.route('/update_params', methods=['POST'])
def update_params():
    """Validate and apply the spectrum-analyzer parameters posted as form fields.

    Form fields:
    - center_freq, sample_rate, window_ms: floats.
    - fft_size, tcp_port: ints.

    On success the values are stored in ``config`` and persisted with
    save_config().

    Responses:
    - 200: JSON success.
    - 400: a field is missing, unparsable, or out of range (center_freq,
      sample_rate and window_ms must be positive and finite; fft_size
      positive; tcp_port in 1..65535).
    - 500: unexpected error.
    """
    import math

    def reject(message):
        return jsonify({'status': 'error', 'message': message}), 400

    try:
        center_freq = request.form.get('center_freq', type=float)
        sample_rate = request.form.get('sample_rate', type=float)
        NFFT = request.form.get('fft_size', type=int)
        window_ms = request.form.get('window_ms', type=float)
        tcp_port = request.form.get('tcp_port', type=int)

        # "is None" instead of truthiness: a 0 is present-but-invalid and
        # gets a specific message below rather than "required".
        if any(v is None for v in (center_freq, sample_rate, NFFT, window_ms, tcp_port)):
            return reject('All parameters are required')
        # Reject NaN/inf/non-positive values before they are stored in config
        # and used to size the plotting window.
        for name, value in (('center_freq', center_freq),
                            ('sample_rate', sample_rate),
                            ('window_ms', window_ms)):
            if not (math.isfinite(value) and value > 0):
                return reject(f'{name} must be a positive, finite number')
        if NFFT <= 0:
            return reject('fft_size must be a positive integer')
        if not 1 <= tcp_port <= 65535:
            return reject('tcp_port must be between 1 and 65535')

        with config_lock:
            config["center_freq"] = center_freq
            config["sample_rate"] = sample_rate
            config["NFFT"] = NFFT
            config["window_ms"] = window_ms
            config["tcp_port"] = tcp_port
        print(f"Updated params: center_freq={center_freq}, sample_rate={sample_rate}, NFFT={NFFT}, window_ms={window_ms}, tcp_port={tcp_port}")
        save_config()
        return jsonify({
            'status': 'success',
            'message': 'Parameters updated successfully'
        })
    except Exception as e:
        print(f"Error updating params: {e}")
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500
@app.route('/start_stream', methods=['POST'])
def start_stream():
    """Start the plotting thread and, once it is running, the tmux capture thread."""
    try:
        if not start_plotting():
            return jsonify({"status": "error", "message": "Failed to start streaming"}), 500
        start_tmux_capture()
        return jsonify({"status": "success", "message": "Streaming started"})
    except Exception as exc:
        print(f"Error starting stream: {exc}")
        return jsonify({"status": "error", "message": f"Error: {str(exc)}"}), 500
@app.route('/stop_stream', methods=['POST'])
def stop_stream():
    """Stop the plotting thread and, if that succeeds, the tmux capture thread."""
    try:
        stopped = stop_plotting()
        if stopped:
            stop_tmux_capture()
            return jsonify({"status": "success", "message": "Streaming stopped"})
        return jsonify({"status": "error", "message": "Failed to stop streaming"}), 500
    except Exception as exc:
        print(f"Error stopping stream: {exc}")
        return jsonify({"status": "error", "message": f"Error: {str(exc)}"}), 500
@app.route('/pause_stream', methods=['POST'])
def pause_stream():
    """Toggle pause on the plotting loop and report the resulting state as JSON."""
    try:
        new_state = pause_plotting()
        payload = {
            "status": "success",
            "message": f"Streaming {new_state}",
            "state": new_state,
        }
        return jsonify(payload)
    except Exception as exc:
        print(f"Error pausing stream: {exc}")
        return jsonify({"status": "error", "message": f"Error: {str(exc)}"}), 500
@app.route('/get_stream_state', methods=['GET'])
def get_stream_state():
    """Report the plotting loop state ("stopped", "paused" or "running") as JSON."""
    with config_lock:
        streaming = config["streaming"]
    paused = pause_event.is_set()
    if not streaming:
        state = "stopped"
    elif paused:
        state = "paused"
    else:
        state = "running"
    return jsonify({"state": state})
@app.route('/tmux_output', methods=['GET'])
def get_tmux_output():
    """Return the most recently captured tmux pane lines as JSON."""
    # Copy under the lock, then serialise without holding it.
    with tmux_lock:
        snapshot = list(tmux_output)
    return jsonify({"output": snapshot})
# WebSocket event handlers
@socketio.on('connect')
def handle_connect():
    """Log each new Socket.IO client connection."""
    message = 'Client connected via WebSocket'
    print(message)
@socketio.on('disconnect')
def handle_disconnect():
    """Log when a Socket.IO client disconnects."""
    message = 'Client disconnected from WebSocket'
    print(message)
def save_config(path=None):
    """Persist a snapshot of ``config`` as indented JSON.

    The snapshot is written to "<path>.tmp" and moved over *path* with
    os.replace. A crash mid-write therefore cannot leave a truncated config
    file. Errors are printed, not raised.

    Args:
        path: Destination file; defaults to gain_viz.json in the current
            working directory.
    """
    with config_lock:
        cfg = dict(config)
    if path is None:
        path = os.path.join(os.getcwd(), "gain_viz.json")
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(cfg, f, indent=2)
        os.replace(tmp_path, path)
    except Exception as e:
        print(f"Error saving config: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
# ----------------- Main -----------------
def main(host="0.0.0.0", port=5000):
    """Create the placeholder plot image if missing, then run the Socket.IO server.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
    """
    # Ensure placeholder image exists
    if not os.path.exists(PLOT_PATH):
        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.text(0.5, 0.5, "Click Start to begin streaming", ha='center', va='center', fontsize=14)
            ax.set_title("Gain-Viz Spectrum Analyzer - Ready")
            fig.savefig(PLOT_PATH, bbox_inches='tight', dpi=80)
        finally:
            plt.close(fig)
    print(f"Gain-Viz server starting on http://{host}:{port}")
    print("WebSocket support enabled")
    # NOTE(review): allow_unsafe_werkzeug=True opts into Werkzeug's development
    # server. Together with the default host 0.0.0.0 and
    # cors_allowed_origins="*", the gain/stream endpoints (which perform no
    # authentication in this module) are reachable from the network.
    # Confirm this deployment is intended.
    socketio.run(app, host=host, port=port, debug=False, use_reloader=False, allow_unsafe_werkzeug=True)


if __name__ == '__main__':
    main()