diff --git a/README.md b/README.md index 10f351b..7ab7f8c 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Artifacts that track this milestone live in two places: - TURN data plane enablement: - [x] `CreatePermission` handling and permission tracking - [x] `ChannelBind` setup and `Send` forwarding to peers - - [ ] ChannelData framing and Data Indication responses from relay to client + - [x] ChannelData framing and Data Indication responses from relay to client License: MIT diff --git a/src/alloc.rs b/src/alloc.rs index 6d50159..6dea58f 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -7,6 +7,8 @@ use std::time::{Duration, Instant}; use tokio::net::UdpSocket; use tracing::info; +use crate::stun::{build_channel_data, build_data_indication}; + #[derive(Clone)] pub struct Allocation { pub client: SocketAddr, @@ -45,6 +47,7 @@ impl AllocationManager { let relay_clone = relay_arc.clone(); let server_sock_clone = server_sock.clone(); let client_clone = client; + let manager_clone = self.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 2048]; loop { @@ -54,8 +57,48 @@ impl AllocationManager { "relay got {} bytes from {} for client {}", len, src, client_clone ); - // forward to client via server socket - let _ = server_sock_clone.send_to(&buf[..len], client_clone).await; + if let Some(allocation) = manager_clone.get_allocation(&client_clone) { + if !allocation.is_peer_allowed(&src) { + tracing::debug!( + "dropping peer packet {} -> {} (permission expired)", + src, + client_clone + ); + continue; + } + + if let Some(channel) = allocation.channel_for_peer(&src) { + let frame = build_channel_data(channel, &buf[..len]); + if let Err(e) = + server_sock_clone.send_to(&frame, client_clone).await + { + tracing::error!( + "failed to send channel data {} -> {}: {:?}", + src, + client_clone, + e + ); + } + } else { + let indication = build_data_indication(&src, &buf[..len]); + if let Err(e) = + server_sock_clone.send_to(&indication, client_clone).await + { + tracing::error!( + "failed to send data indication {} -> {}: {:?}", + src, + client_clone, + e + ); + } + } + } else { + tracing::debug!( + "allocation missing while forwarding from peer {} -> {}", + src, + client_clone + ); + } } Err(e) => { tracing::error!("relay socket error: {:?}", e); @@ -129,6 +172,15 @@ impl Allocation { bindings.get(&channel).map(|(peer, _)| *peer) } + /// Return the bound channel number for a peer if available. + pub fn channel_for_peer(&self, peer: &SocketAddr) -> Option { + let mut bindings = self.channel_bindings.lock().unwrap(); + prune_channel_bindings(&mut bindings); + bindings + .iter() + .find_map(|(channel, (addr, _))| if addr == peer { Some(*channel) } else { None }) + } + /// Forward payload to a TURN peer via the relay socket. pub async fn send_to_peer(&self, peer: SocketAddr, data: &[u8]) -> anyhow::Result { let sent = self._socket.send_to(data, peer).await?; diff --git a/src/constants.rs b/src/constants.rs index eff6ac6..18d9953 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -15,6 +15,7 @@ pub const METHOD_CHANNEL_BIND: u16 = 0x0009; // STUN/TURN class bits per RFC5389/RFC5766 pub const CLASS_SUCCESS: u16 = 0x0100; pub const CLASS_ERROR: u16 = 0x0110; +pub const CLASS_INDICATION: u16 = 0x0010; // Common attribute types pub const ATTR_USERNAME: u16 = 0x0006; diff --git a/src/main.rs b/src/main.rs index b863215..c3a62a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use niom_turn::config::{AuthOptions, Config}; use niom_turn::constants::*; use niom_turn::stun::{ build_401_response, build_error_response, build_success_response, decode_xor_peer_address, - encode_xor_relayed_address, parse_message, + encode_xor_relayed_address, parse_channel_data, parse_message, }; #[tokio::main] @@ -113,6 +113,55 @@ async fn udp_reader_loop( let (len, peer) = udp.recv_from(&mut buf).await?; tracing::debug!("got {} bytes from {}", len, peer); + if let Some((channel, payload)) = parse_channel_data(&buf[..len]) { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + tracing::warn!("channel data without allocation from {}", peer); + continue; + } + }; + + let target = match allocation.channel_peer(channel) { + Some(addr) => addr, + None => { + tracing::warn!( + "channel data with unknown channel 0x{:04x} from {}", + channel, + peer + ); + continue; + } + }; + + if !allocation.is_peer_allowed(&target) { + tracing::warn!( + "channel data target {} no longer permitted for {}", + target, + peer + ); + continue; + } + + match allocation.send_to_peer(target, payload).await { + Ok(sent) => tracing::debug!( + "forwarded {} bytes via channel 0x{:04x} from {} to {}", + sent, + channel, + peer, + target + ), + Err(e) => tracing::error!( + "failed to forward channel data 0x{:04x} from {} to {}: {:?}", + channel, + peer, + target, + e + ), + } + continue; + } + // Minimal STUN/TURN detection: parse STUN messages and send 401 challenge if let Ok(msg) = parse_message(&buf[..len]) { tracing::info!( diff --git a/src/stun.rs b/src/stun.rs index fc31d0f..a1349ed 100644 --- a/src/stun.rs +++ b/src/stun.rs @@ -3,6 +3,7 @@ use crate::constants::*; use crate::models::stun::{StunAttribute, StunHeader, StunMessage}; use std::convert::TryInto; +use uuid::Uuid; #[derive(thiserror::Error, Debug)] pub enum ParseError { @@ -224,7 +225,68 @@ pub fn compute_message_integrity(key: &[u8], msg: &[u8]) -> Vec { /// Encode an IPv4 SocketAddr into XOR-RELAYED-ADDRESS attribute value. /// Format (per RFC5389/RFC5766): 1 byte family, 2 byte xport, 4 byte xaddr for IPv4 -pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec { +pub fn build_channel_data(channel: u16, data: &[u8]) -> Vec { + let mut out = Vec::with_capacity(4 + data.len()); + out.extend_from_slice(&channel.to_be_bytes()); + out.extend_from_slice(&(data.len() as u16).to_be_bytes()); + out.extend_from_slice(data); + while (out.len() % 4) != 0 { + out.push(0); + } + out +} + +pub fn build_data_indication(peer: &std::net::SocketAddr, data: &[u8]) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = METHOD_DATA | CLASS_INDICATION; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + let uuid = Uuid::new_v4(); + let mut trans_id = [0u8; 12]; + trans_id.copy_from_slice(&uuid.as_bytes()[..12]); + buf.extend_from_slice(&trans_id); + + let addr_val = encode_xor_peer_address(peer, &trans_id); + buf.extend_from_slice(&ATTR_XOR_PEER_ADDRESS.to_be_bytes()); + buf.extend_from_slice(&(addr_val.len() as u16).to_be_bytes()); + buf.extend_from_slice(&addr_val); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + buf.extend_from_slice(&ATTR_DATA.to_be_bytes()); + buf.extend_from_slice(&(data.len() as u16).to_be_bytes()); + buf.extend_from_slice(data); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + + buf.to_vec() +} + +pub fn parse_channel_data(buf: &[u8]) -> Option<(u16, &[u8])> { + if buf.len() < 4 { + return None; + } + let channel = u16::from_be_bytes([buf[0], buf[1]]); + if (channel & 0xC000) != 0x4000 { + return None; + } + let data_len = u16::from_be_bytes([buf[2], buf[3]]) as usize; + if buf.len() < 4 + data_len { + return None; + } + Some((channel, &buf[4..4 + data_len])) +} + +fn encode_xor_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec { use std::net::IpAddr; let mut out = Vec::new(); match addr.ip() { @@ -250,6 +312,14 @@ pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8; out } +pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, trans_id: &[u8; 12]) -> Vec { + encode_xor_address(addr, trans_id) +} + +pub fn encode_xor_peer_address(addr: &std::net::SocketAddr, trans_id: &[u8; 12]) -> Vec { + encode_xor_address(addr, trans_id) +} + /// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4 only) pub fn decode_xor_relayed_address( value: &[u8],