@@ -0,0 +1,476 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+    auth::{calculate_hmac, generate_nonce, verify_hmac},
+    config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups},
+    protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo},
+    DEFAULT_BUFFER_SIZE,
+};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::{
+    collections::HashMap,
+    net::{IpAddr, Ipv4Addr, SocketAddr},
+    path::PathBuf,
+    str::FromStr,
+    sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    net::{TcpListener, TcpStream},
+    sync::{mpsc, Mutex},
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(short, long, default_value = "server_config.toml")]
+    config: PathBuf,
+
+    #[arg(short, long, action)]
+    generate_default: bool,
+}
+
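+// Map from each connected client's socket address to the mpsc sender used to push
+// serialized protocol messages to that client's writer task.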
+type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    env_logger::init();
+    let args = Args::parse();
+
+    // Generate default configs if requested
+    if args.generate_default {
+        ensure_default_configs()?;
+        return Ok(());
+    }
+
+    // Load configuration
+    let config = load_server_config(&args.config)
+        .context(format!("Failed to load config from {:?}", args.config))?;
+
+    info!("Server configuration loaded from {:?}", args.config);
+
+    let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port);
+    let listener = TcpListener::bind(&listen_addr).await
+        .context("Failed to bind TCP listener")?;
+
+    info!("Server listening on {}", listen_addr);
+
+    // Set up multicast receivers
+    let clients: ClientMap = Arc::new(Mutex::new(HashMap::new()));
+
+    // Start multicast listeners for each multicast group
+    for (group_id, group) in &config.multicast_groups {
+        let ports = group.get_ports();
+        if ports.is_empty() {
+            error!("No ports defined for group {}", group_id);
+            continue;
+        }
+
+        let display_group_id = group_id.clone();
+        let ports_display = if ports.len() == 1 {
+            format!("port {}", ports[0])
+        } else {
+            format!("ports {}-{}", ports[0], ports.last().unwrap())
+        };
+
+        // Create a listener for each port in the range
+        for port in ports {
+            let clients = clients.clone();
+            let _secret = config.secret.clone();
+            let group_id_clone = group_id.clone();
+            let mut group_info = group.clone();
+
+            // Set the specific port for this listener
+            group_info.port = Some(port);
+            group_info.port_range = None;
+
+            tokio::spawn(async move {
+                if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await {
+                    error!("Multicast listener error for group {} port {}: {}",
+                        group_id_clone, port, e);
+                }
+            });
+        }
+
+        info!("Listening for multicast group {} on address {} with {}",
+            display_group_id, group.address, ports_display);
+    }
+
+    // Store config for use in client handlers
+    let config = Arc::new(config);
+
+    // Accept client connections
+    while let Ok((stream, addr)) = listener.accept().await {
+        info!("New client connection from: {}", addr);
+        let secret = config.secret.clone();
+        let clients = clients.clone();
+        let config = config.clone();
+
+        tokio::spawn(async move {
+            if let Err(e) = handle_client(stream, addr, &secret, clients, config).await {
+                error!("Client error: {}: {}", addr, e);
+            }
+            info!("Client disconnected: {}", addr);
+        });
+    }
+
+    Ok(())
+}
+
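+// Joins a single multicast group on a single port and forwards every datagram it
+// receives to all connected clients as a serialized MulticastPacket message.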
+async fn listen_to_multicast(
+    group_id: &str,
+    group: &MulticastGroup,
+    clients: ClientMap
+) -> Result<()> {
+    // Get the port to use
+    let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?;
+
+    // Parse the multicast address
+    let mcast_ip = match IpAddr::from_str(&group.address)
+        .context("Invalid multicast address")? {
+        IpAddr::V4(addr) => addr,
+        _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported"))
+    };
+
+    // Create a UDP socket with explicit settings
+    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+        .context("Failed to create socket")?;
+
+    // Allow address/port reuse so multiple listeners can coexist on the same port
+    socket.set_reuse_address(true)?;
+
+    #[cfg(unix)]
+    socket.set_reuse_port(true)?;
+
+    socket.set_nonblocking(true)?;
+    socket.set_multicast_loop_v4(true)?;
+
+    // Bind directly to the multicast address and port (rather than 0.0.0.0:port) so
+    // this socket only receives datagrams addressed to this specific group
+    let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port);
+    info!("Binding multicast listener to specific address: {:?}", bind_addr);
+    socket.bind(&bind_addr.into())?;
+
+    // Join the multicast group with a specific interface
+    let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface
+    info!("Joining multicast group {} on interface {:?}", mcast_ip, interface);
+    socket.join_multicast_v4(&mcast_ip, &interface)?;
+
+    // Additional multicast option: set the IP_MULTICAST_IF option
+    socket.set_multicast_if_v4(&interface)?;
+
+    // Convert to a tokio socket
+    let udp_socket = tokio::net::UdpSocket::from_std(socket.into())
+        .context("Failed to convert socket to async")?;
+
+    info!("Multicast listener ready and bound specifically to {}:{} (group {})",
+        group.address, port, group_id);
+
+    let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+    let group_id = group_id.to_string();
+
+    loop {
+        match udp_socket.recv_from(&mut buf).await {
+            Ok((len, src)) => {
+                // Since we're bound to the exact multicast address, we can be confident
+                // this packet was sent to our specific multicast group
+                let data = buf[..len].to_vec();
+
+                info!("RECEIVED: group={} from={} size={} destination={}:{}",
+                    group_id, src, len, mcast_ip, port);
+
+                // Create a message with the packet
+                let message = Message::MulticastPacket {
+                    group_id: group_id.clone(),
+                    source: src,
+                    destination: group.address.clone(),
+                    port,
+                    data,
+                };
+
+                // Send to clients
+                match serialize_message(&message) {
+                    Ok(serialized) => {
+                        let clients_lock = clients.lock().await;
+                        for (client_addr, sender) in clients_lock.iter() {
+                            if sender.send(serialized.clone()).await.is_err() {
+                                debug!("Failed to send to client {}", client_addr);
+                            } else {
+                                debug!("Sent multicast packet to client {}", client_addr);
+                            }
+                        }
+                    }
+                    Err(e) => error!("Failed to serialize message: {}", e),
+                }
+            }
+            Err(e) => {
+                if e.kind() != std::io::ErrorKind::WouldBlock {
+                    error!("Error receiving from socket: {}", e);
+                }
+                // Small delay to avoid busy-waiting on errors
+                tokio::time::sleep(Duration::from_millis(10)).await;
+            }
+        }
+    }
+}
+
+#[derive(PartialEq)]
+enum StatusMessageType {
+    ClientHeartbeat,
+    ClientPong,
+    Other,
+}
+
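+// Per-connection handler: authenticates the client, registers its sender in the client
+// map, then runs reader, writer and heartbeat tasks until the connection drops.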
+async fn handle_client(
+    stream: TcpStream,
+    addr: SocketAddr,
+    secret: &str,
+    clients: ClientMap,
+    config: Arc<ServerConfig>,
+) -> Result<()> {
+    // Reject clients from outside the local/private ranges unless external clients are explicitly allowed
+    if !config.allow_external_clients &&
+       !addr.ip().is_loopback() &&
+       !addr.ip().to_string().starts_with("192.168.") &&
+       !addr.ip().to_string().starts_with("10.") {
+        warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr);
+        return Err(anyhow::anyhow!("External clients not allowed"));
+    }
+
+    // Split the TCP stream into read and write halves once
+    let (mut read_stream, mut write_stream) = tokio::io::split(stream);
+
+    // Authenticate using the split streams
+    if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? {
+        return Err(anyhow::anyhow!("Authentication failed"));
+    }
+
+    info!("Client authenticated: {}", addr);
+
+    // Get client info
+    let client_ip = addr.ip().to_string();
+    let client_port = addr.port();
+
+    // Check if the client has specific group permissions
+    let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) {
+        Some(groups) => groups,
+        None => return Err(anyhow::anyhow!("Client not authorized for any groups")),
+    };
+
+    // Create a channel for sending multicast packets to this client
+    let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
+
+    // Add client to map
+    clients.lock().await.insert(addr, tx.clone());
+
+    // Build the set of groups available to this client
+    let mut available_groups = HashMap::new();
+    for (id, group) in &config.multicast_groups {
+        // An empty group list means all groups are allowed; otherwise the group must be listed
+        if authorized_groups.is_empty() || authorized_groups.contains(id) {
+            let ports = group.get_ports();
+            if ports.is_empty() {
+                continue;
+            }
+
+            // The primary port is the first one
+            let primary_port = ports[0];
+
+            // Collect any additional ports
+            let additional_ports = if ports.len() > 1 {
+                Some(ports[1..].to_vec())
+            } else {
+                None
+            };
+
+            available_groups.insert(id.clone(), MulticastGroupInfo {
+                address: group.address.clone(),
+                port: primary_port,
+                additional_ports,
+            });
+        }
+    }
+
+    // Keep a clone of tx for the reader task; the original stays available for the heartbeat task below
+    let tx_for_read = tx.clone();
+
+    // Spawn a task to read client messages
+    let clients_clone = clients.clone();
+
+    // Use the already split read_stream
+    let read_task = tokio::spawn(async move {
+        let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+        loop {
+            match read_stream.read(&mut buf).await {
+                Ok(0) => break, // Connection closed
+                Ok(n) => {
+                    if let Ok(msg) = deserialize_message(&buf[..n]) {
+                        match msg {
+                            Message::Subscribe { group_ids } => {
+                                info!("Client {} subscribing to groups: {:?}", addr, group_ids);
+                                // Group subscriptions handled by server
+                            },
+                            Message::MulticastGroupsRequest => {
+                                // Send available groups to the client
+                                let response = Message::MulticastGroupsResponse {
+                                    groups: available_groups.clone()
+                                };
+
+                                if let Ok(bytes) = serialize_message(&response) {
+                                    let _ = tx_for_read.send(bytes).await;
+                                }
+                            },
+                            Message::PingStatus { timestamp, status } => {
+                                // Determine the type of status message
+                                let msg_type = if status.starts_with("Client keepalive ping") ||
+                                               status.starts_with("Client periodic ping") {
+                                    StatusMessageType::ClientHeartbeat
+                                } else if status.starts_with("Client pong response") {
+                                    StatusMessageType::ClientPong
+                                } else {
+                                    StatusMessageType::Other
+                                };
+
+                                // Log the message receipt
+                                match msg_type {
+                                    StatusMessageType::ClientHeartbeat => {
+                                        info!("Heartbeat from client {}: {}", addr, status);
+
+                                        // Respond only to actual heartbeat pings, not pong responses
+                                        let response = Message::PingStatus {
+                                            timestamp,
+                                            status: format!("Server connection to {} is OK", addr),
+                                        };
+
+                                        if let Ok(bytes) = serialize_message(&response) {
+                                            let _ = tx_for_read.send(bytes).await;
+                                        }
+                                    },
+                                    StatusMessageType::ClientPong => {
+                                        // Just log pongs without responding to avoid loops
+                                        debug!("Pong from client {}: {}", addr, status);
+                                    },
+                                    StatusMessageType::Other => {
+                                        info!("Status message from client {}: {}", addr, status);
+                                    }
+                                }
+                            },
+                            _ => {}
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!("Error reading from client: {}: {}", addr, e);
+                    break;
+                }
+            }
+        }
+        // Clean up on disconnect
+        clients_clone.lock().await.remove(&addr);
+        info!("Client reader task ended: {}", addr);
+    });
+
+    // Forward multicast packets to the client using the already split write_stream
+    let write_task = tokio::spawn(async move {
+        while let Some(packet) = rx.recv().await {
+            if let Err(e) = write_stream.write_all(&packet).await {
+                error!("Error writing to client {}: {}", addr, e);
+                break;
+            }
+        }
+        info!("Client writer task ended: {}", addr);
+    });
+
+    // tx is still valid here; use it for periodic heartbeats, starting after a short delay
+    let tx_for_heartbeat = tx.clone();
+    let client_addr = addr;
+    tokio::spawn(async move {
+        // Initial delay so heartbeats don't interfere with the initial setup messages
+        tokio::time::sleep(Duration::from_secs(5)).await;
+
+        let mut interval = tokio::time::interval(Duration::from_secs(30));
+        loop {
+            interval.tick().await;
+            let now = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap()
+                .as_secs();
+
+            // Send a heartbeat with a clear identifier
+            let msg = Message::PingStatus {
+                timestamp: now,
+                status: format!("Server heartbeat to {}", client_addr),
+            };
+
+            if let Ok(bytes) = serialize_message(&msg) {
+                if tx_for_heartbeat.send(bytes).await.is_err() {
+                    break;
+                }
+            }
+        }
+    });
+
+    // Wait for either task to complete
+    tokio::select! {
+        _ = read_task => {},
+        _ = write_task => {},
+    }
+
+    // Clean up
+    clients.lock().await.remove(&addr);
+    Ok(())
+}
+
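+// Mutual challenge-response handshake: the client opens with a nonce, the server replies
+// with its own nonce plus HMAC(secret, client_nonce || server_nonce), and the client
+// proves knowledge of the shared secret with HMAC(secret, server_nonce || client_nonce).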
+async fn authenticate_client(
+    reader: &mut (impl AsyncReadExt + Unpin),
+    writer: &mut (impl AsyncWriteExt + Unpin),
+    _addr: SocketAddr,
+    secret: &str
+) -> Result<bool> {
+    let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+
+    // Receive auth request
+    let n = reader.read(&mut buf).await?;
+    let auth_request = deserialize_message(&buf[..n])?;
+
+    if let Message::AuthRequest { client_nonce } = auth_request {
+        // Generate server nonce
+        let server_nonce = generate_nonce();
+
+        // Calculate auth token
+        let auth_data = format!("{}{}", client_nonce, server_nonce);
+        let auth_token = calculate_hmac(secret, &auth_data);
+
+        // Send response
+        let response = Message::AuthResponse {
+            server_nonce: server_nonce.clone(),
+            auth_token,
+        };
+
+        let response_bytes = serialize_message(&response)?;
+        writer.write_all(&response_bytes).await?;
+
+        // Receive confirmation
+        let n = reader.read(&mut buf).await?;
+        let auth_confirm = deserialize_message(&buf[..n])?;
+
+        if let Message::AuthConfirm { auth_token } = auth_confirm {
+            // Verify token
+            let expected_data = format!("{}{}", server_nonce, client_nonce);
+            return Ok(verify_hmac(secret, &expected_data, &auth_token));
+        }
+    }
+
+    Ok(false)
+}