From c77e95afdd1fd6f66dfd5a12f4697cac8bba9ced Mon Sep 17 00:00:00 2001 From: ghost Date: Wed, 12 Nov 2025 17:55:14 +0100 Subject: [PATCH] CreatePermission handling and permission tracking. ChannelBind setup and Send forwarding to peers. --- README.md | 27 +- src/alloc.rs | 92 ++++++- src/auth.rs | 6 +- src/bin/allocate_smoke.rs | 58 ++-- src/bin/smoke_client.rs | 24 +- src/constants.rs | 16 +- src/lib.rs | 12 +- src/main.rs | 416 ++++++++++++++++++++++++---- src/models/mod.rs | 2 +- src/stun.rs | 175 +++++++++--- src/tls.rs | 567 +++++++++++++++++++++++++++++++++++--- 11 files changed, 1206 insertions(+), 189 deletions(-) diff --git a/README.md b/README.md index abc38a2..7d05f5c 100644 --- a/README.md +++ b/README.md @@ -62,16 +62,17 @@ Milestone 1 — Protocol Backlog This milestone focuses on turning the current MVP into a feature-complete TURN core that can be used reliably by `niom-webrtc`. -- **Authentication Hardening**: nonce lifecycle, realm configuration, Argon2-backed credential - storage, and detailed error handling for 401/438 responses. -- **TURN Method Coverage**: implement `Allocate` attributes, `CreatePermission`, `ChannelBind`, - `Refresh`, and full relay path (peer data forwarding, Send/Data indications). -- **Allocation Lifecycle**: timers, refresh logic, cleanup of expired allocations, and resource - limits per user/IP. -- **Protocol Compliance**: FINGERPRINT support, XOR-MAPPED-ADDRESS, IPv6 evaluation, checksum - validation, and robustness against malformed packets. -- **Observability & Limits**: structured tracing, metrics, rate limiting, and integration tests - (including the bundled `smoke_client`). +**Prioritised Backlog (live order)** +1. **TURN Data Plane Enablement** — `CreatePermission`, `ChannelBind`, Send/Data indications, and + peer forwarding so allocations actually relay packets between clients and peers. +2. **Authentication Hardening** — nonce lifecycle, realm configuration, Argon2-backed credential + storage, and detailed error handling for 401/438 responses. +3. **Allocation Lifecycle & Quotas** — timers, refresh requests, cleanup of expired allocations, + and resource limits per user/IP. +4. **Protocol Compliance Extras** — FINGERPRINT support, XOR-MAPPED-ADDRESS, IPv6 evaluation, + checksum validation, and fuzz/interop testing. +5. **Observability & Limits** — structured tracing, metrics, rate limiting, and CI coverage (incl. + the bundled `smoke_client`). Artifacts that track this milestone live in two places: @@ -79,6 +80,12 @@ Artifacts that track this milestone live in two places: 2. Inline module docs (`//!`) inside `src/` record the current responsibilities and open backlog items for each subsystem as we iterate. +**Task in progress** +- TURN data plane enablement: + - [x] `CreatePermission` handling and permission tracking + - [x] `ChannelBind` setup and `Send` forwarding to peers + - [ ] ChannelData framing and Data Indication responses from relay to client + License: MIT Smoke-Test (End-to-End) diff --git a/src/alloc.rs b/src/alloc.rs index cccac2c..6d50159 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; use tokio::net::UdpSocket; use tracing::info; @@ -12,6 +13,8 @@ pub struct Allocation { pub relay_addr: SocketAddr, // keep the socket so it stays bound _socket: Arc, + permissions: Arc>>, + channel_bindings: Arc>>, } #[derive(Clone, Default)] @@ -20,11 +23,19 @@ pub struct AllocationManager { } impl AllocationManager { - pub fn new() -> Self { Self { inner: Arc::new(Mutex::new(HashMap::new())) } } + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + } + } /// Create a relay UDP socket for the given client and spawn a relay loop that forwards /// any packets received on the relay socket back to the client via the provided server socket. - pub async fn allocate_for(&self, client: SocketAddr, server_sock: Arc) -> anyhow::Result { + pub async fn allocate_for( + &self, + client: SocketAddr, + server_sock: Arc, + ) -> anyhow::Result { // bind relay socket to OS-chosen port let relay = UdpSocket::bind("0.0.0.0:0").await?; let relay_local = relay.local_addr()?; @@ -39,7 +50,10 @@ impl AllocationManager { loop { match relay_clone.recv_from(&mut buf).await { Ok((len, src)) => { - info!("relay got {} bytes from {} for client {}", len, src, client_clone); + info!( + "relay got {} bytes from {} for client {}", + len, src, client_clone + ); // forward to client via server socket let _ = server_sock_clone.send_to(&buf[..len], client_clone).await; } @@ -51,7 +65,13 @@ impl AllocationManager { } }); - let alloc = Allocation { client, relay_addr: relay_local, _socket: relay_arc }; + let alloc = Allocation { + client, + relay_addr: relay_local, + _socket: relay_arc, + permissions: Arc::new(Mutex::new(HashMap::new())), + channel_bindings: Arc::new(Mutex::new(HashMap::new())), + }; tracing::info!("created allocation for {} -> {}", client, relay_local); let mut m = self.inner.lock().unwrap(); m.insert(client, alloc); @@ -62,4 +82,68 @@ impl AllocationManager { let m = self.inner.lock().unwrap(); m.get(client).cloned() } + + /// Register a permission for the given client allocation so the relay can forward packets + /// to the specified peer address. Permissions currently expire after a static timeout. + pub fn add_permission(&self, client: SocketAddr, peer: SocketAddr) -> anyhow::Result<()> { + let mut guard = self.inner.lock().unwrap(); + let alloc = guard + .get_mut(&client) + .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; + let mut perms = alloc.permissions.lock().unwrap(); + prune_permissions(&mut perms); + perms.insert(peer, Instant::now() + PERMISSION_LIFETIME); + Ok(()) + } + + /// Associate a TURN channel number with a specific peer socket for the allocation. + pub fn add_channel_binding( + &self, + client: SocketAddr, + channel: u16, + peer: SocketAddr, + ) -> anyhow::Result<()> { + let mut guard = self.inner.lock().unwrap(); + let alloc = guard + .get_mut(&client) + .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; + let mut bindings = alloc.channel_bindings.lock().unwrap(); + prune_channel_bindings(&mut bindings); + bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME)); + Ok(()) + } +} + +impl Allocation { + /// Check whether a peer address is currently permitted for this allocation. + pub fn is_peer_allowed(&self, peer: &SocketAddr) -> bool { + let mut perms = self.permissions.lock().unwrap(); + prune_permissions(&mut perms); + perms.contains_key(peer) + } + + /// Resolve an active channel binding to its peer socket, if still valid. + pub fn channel_peer(&self, channel: u16) -> Option { + let mut bindings = self.channel_bindings.lock().unwrap(); + prune_channel_bindings(&mut bindings); + bindings.get(&channel).map(|(peer, _)| *peer) + } + + /// Forward payload to a TURN peer via the relay socket. + pub async fn send_to_peer(&self, peer: SocketAddr, data: &[u8]) -> anyhow::Result { + let sent = self._socket.send_to(data, peer).await?; + Ok(sent) + } +} + +const PERMISSION_LIFETIME: Duration = Duration::from_secs(300); + +fn prune_permissions(perms: &mut HashMap) { + let now = Instant::now(); + perms.retain(|_, expiry| *expiry > now); +} + +fn prune_channel_bindings(bindings: &mut HashMap) { + let now = Instant::now(); + bindings.retain(|_, (_, expiry)| *expiry > now); } diff --git a/src/auth.rs b/src/auth.rs index 22f8640..33fcf18 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,8 +1,8 @@ //! Authentication helpers and the in-memory credential store used for the MVP server. //! Backlog: Argon2-backed storage, nonce lifecycle, and integration with persistent secrets. +use crate::traits::CredentialStore; use async_trait::async_trait; use std::sync::Arc; -use crate::traits::CredentialStore; /// Simple in-memory credential store for MVP #[derive(Clone, Default)] @@ -13,7 +13,9 @@ pub struct InMemoryStore { impl InMemoryStore { pub fn new() -> Self { - Self { inner: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())) } + Self { + inner: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), + } } pub fn insert(&self, user: impl Into, password: impl Into) { diff --git a/src/bin/allocate_smoke.rs b/src/bin/allocate_smoke.rs index 9ec3e5e..cb998fb 100644 --- a/src/bin/allocate_smoke.rs +++ b/src/bin/allocate_smoke.rs @@ -2,8 +2,8 @@ use bytes::BytesMut; use niom_turn::constants::*; // use niom_turn::stun; // not needed; use specific functions via path when required use std::net::SocketAddr; -use tokio::net::UdpSocket; use std::time::Duration; +use tokio::net::UdpSocket; // Use shared decoder from library: niom_turn::stun::decode_xor_relayed_address @@ -29,15 +29,19 @@ async fn main() -> anyhow::Result<()> { buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes()); buf.extend_from_slice(&(uname.len() as u16).to_be_bytes()); buf.extend_from_slice(uname); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } // MESSAGE-INTEGRITY placeholder let mi_attr_offset = buf.len(); buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); let mi_val_pos = buf.len(); - buf.extend_from_slice(&[0u8;20]); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } + buf.extend_from_slice(&[0u8; 20]); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } // fix length let total_len = (buf.len() - 20) as u16; @@ -47,13 +51,15 @@ async fn main() -> anyhow::Result<()> { // compute HMAC over bytes up to MI attribute header { - use hmac::{Hmac, Mac}; - use sha1::Sha1; - type HmacSha1 = Hmac; - let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key"); + use hmac::{Hmac, Mac}; + use sha1::Sha1; + type HmacSha1 = Hmac; + let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key"); mac.update(&buf[..mi_attr_offset]); let res = mac.finalize().into_bytes(); - for i in 0..20 { buf[mi_val_pos + i] = res[i]; } + for i in 0..20 { + buf[mi_val_pos + i] = res[i]; + } } // send Allocate @@ -64,13 +70,14 @@ async fn main() -> anyhow::Result<()> { let (len, _addr) = local.recv_from(&mut r).await?; println!("got {} bytes", len); let resp = &r[..len]; - // expect success (RESP_BINDING_SUCCESS) with XOR-RELAYED-ADDRESS attr + // expect success (METHOD_ALLOCATE | CLASS_SUCCESS) with XOR-RELAYED-ADDRESS attr if resp.len() < 20 { anyhow::bail!("response too short"); } let typ = u16::from_be_bytes([resp[0], resp[1]]); println!("resp type 0x{:04x}", typ); - if typ != RESP_BINDING_SUCCESS { + let expected_type = METHOD_ALLOCATE | CLASS_SUCCESS; + if typ != expected_type { anyhow::bail!("expected success response, got 0x{:04x}", typ); } // parse attributes @@ -79,18 +86,22 @@ async fn main() -> anyhow::Result<()> { let mut offset = 20; let mut relay_addr_opt: Option = None; while offset + 4 <= total { - let atype = u16::from_be_bytes([resp[offset], resp[offset+1]]); - let alen = u16::from_be_bytes([resp[offset+2], resp[offset+3]]) as usize; + let atype = u16::from_be_bytes([resp[offset], resp[offset + 1]]); + let alen = u16::from_be_bytes([resp[offset + 2], resp[offset + 3]]) as usize; offset += 4; - if offset + alen > total { break; } + if offset + alen > total { + break; + } println!("attr type=0x{:04x} len={}", atype, alen); - println!("raw: {}", hex::encode(&resp[offset..offset+alen])); - if atype == ATTR_XOR_RELAYED_ADDRESS { - // XOR-RELAYED-ADDRESS: decode via shared library function - if let Some(sa) = niom_turn::stun::decode_xor_relayed_address(&resp[offset..offset+alen], &trans) { - relay_addr_opt = Some(sa); - } + println!("raw: {}", hex::encode(&resp[offset..offset + alen])); + if atype == ATTR_XOR_RELAYED_ADDRESS { + // XOR-RELAYED-ADDRESS: decode via shared library function + if let Some(sa) = + niom_turn::stun::decode_xor_relayed_address(&resp[offset..offset + alen], &trans) + { + relay_addr_opt = Some(sa); } + } offset += alen; let pad = (4 - (alen % 4)) % 4; offset += pad; @@ -114,7 +125,12 @@ async fn main() -> anyhow::Result<()> { println!("received {} bytes from {}", l, src); let got = &buf2[..l]; println!("payload: {:?}", got); - if got == payload { println!("relay test success"); Ok(()) } else { anyhow::bail!("payload mismatch") } + if got == payload { + println!("relay test success"); + Ok(()) + } else { + anyhow::bail!("payload mismatch") + } } Ok(Err(e)) => anyhow::bail!("recv error: {:?}", e), Err(_) => anyhow::bail!("no forwarded packet received: timeout"), diff --git a/src/bin/smoke_client.rs b/src/bin/smoke_client.rs index af241bb..7bd39df 100644 --- a/src/bin/smoke_client.rs +++ b/src/bin/smoke_client.rs @@ -1,7 +1,7 @@ use bytes::BytesMut; +use niom_turn::constants::*; use std::net::SocketAddr; use tokio::net::UdpSocket; -use niom_turn::constants::*; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -25,15 +25,19 @@ async fn main() -> anyhow::Result<()> { buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes()); buf.extend_from_slice(&(uname.len() as u16).to_be_bytes()); buf.extend_from_slice(uname); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } // MESSAGE-INTEGRITY placeholder let mi_attr_offset = buf.len(); buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); let mi_val_pos = buf.len(); - buf.extend_from_slice(&[0u8;20]); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } + buf.extend_from_slice(&[0u8; 20]); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } // fix length let total_len = (buf.len() - 20) as u16; @@ -43,13 +47,15 @@ async fn main() -> anyhow::Result<()> { // compute HMAC over bytes up to MI attribute header { - use hmac::{Hmac, Mac}; - use sha1::Sha1; - type HmacSha1 = Hmac; - let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key"); + use hmac::{Hmac, Mac}; + use sha1::Sha1; + type HmacSha1 = Hmac; + let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key"); mac.update(&buf[..mi_attr_offset]); let res = mac.finalize().into_bytes(); - for i in 0..20 { buf[mi_val_pos + i] = res[i]; } + for i in 0..20 { + buf[mi_val_pos + i] = res[i]; + } } // send diff --git a/src/constants.rs b/src/constants.rs index f6f01eb..eff6ac6 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,14 +1,16 @@ //! Central constants for STUN/TURN implementations (magic cookie, attribute types, methods) pub const MAGIC_COOKIE_U32: u32 = 0x2112A442; -pub const MAGIC_COOKIE_BYTES: [u8;4] = MAGIC_COOKIE_U32.to_be_bytes(); +pub const MAGIC_COOKIE_BYTES: [u8; 4] = MAGIC_COOKIE_U32.to_be_bytes(); // STUN Methods/Message Types (only those used in this MVP) pub const METHOD_BINDING: u16 = 0x0001; pub const METHOD_ALLOCATE: u16 = 0x0003; - -// Common response/error types -pub const RESP_BINDING_SUCCESS: u16 = 0x0101; +pub const METHOD_CREATE_PERMISSION: u16 = 0x0008; +pub const METHOD_REFRESH: u16 = 0x0004; +pub const METHOD_SEND: u16 = 0x0006; +pub const METHOD_DATA: u16 = 0x0007; +pub const METHOD_CHANNEL_BIND: u16 = 0x0009; // STUN/TURN class bits per RFC5389/RFC5766 pub const CLASS_SUCCESS: u16 = 0x0100; @@ -17,11 +19,16 @@ pub const CLASS_ERROR: u16 = 0x0110; // Common attribute types pub const ATTR_USERNAME: u16 = 0x0006; pub const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008; +pub const ATTR_ERROR_CODE: u16 = 0x0009; +pub const ATTR_CHANNEL_NUMBER: u16 = 0x000C; +pub const ATTR_LIFETIME: u16 = 0x000D; pub const ATTR_REALM: u16 = 0x0014; pub const ATTR_NONCE: u16 = 0x0015; +pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012; // TURN attrs pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016; +pub const ATTR_DATA: u16 = 0x0013; // Some helper values pub const FAMILY_IPV4: u8 = 0x01; @@ -31,4 +38,3 @@ pub const FINGERPRINT_XOR: u32 = 0x5354554e; // Length of HMAC-SHA1 (MESSAGE-INTEGRITY) pub const HMAC_SHA1_LEN: usize = 20; - diff --git a/src/lib.rs b/src/lib.rs index e2e655e..ad2e14b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,13 @@ //! Library root for niom-turn shared modules -pub mod constants; -pub mod stun; -pub mod auth; -pub mod traits; -pub mod models; pub mod alloc; +pub mod auth; pub mod config; +pub mod constants; +pub mod models; +pub mod stun; pub mod tls; +pub mod traits; +pub use crate::alloc::*; pub use crate::auth::*; pub use crate::stun::*; -pub use crate::alloc::*; diff --git a/src/main.rs b/src/main.rs index 003ab90..66d23a4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,15 +3,18 @@ use std::net::SocketAddr; use std::sync::Arc; use tokio::net::UdpSocket; -use tracing::{info, error}; +use tracing::{error, info}; // Use the library crate's public modules instead of local `mod` declarations. -use niom_turn::constants::*; -use niom_turn::auth::InMemoryStore; -use niom_turn::stun::{parse_message, build_401_response, find_message_integrity, validate_message_integrity, build_success_response, encode_xor_relayed_address}; -use niom_turn::traits::CredentialStore; use niom_turn::alloc::AllocationManager; +use niom_turn::auth::InMemoryStore; use niom_turn::config::Config; +use niom_turn::constants::*; +use niom_turn::stun::{ + build_401_response, build_error_response, build_success_response, decode_xor_peer_address, + encode_xor_relayed_address, find_message_integrity, parse_message, validate_message_integrity, +}; +use niom_turn::traits::CredentialStore; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -19,14 +22,17 @@ async fn main() -> anyhow::Result<()> { info!("niom-turn starting"); - // config: try to load appsettings.json, otherwise fall back to defaults + // Bootstrap configuration: prefer appsettings.json, otherwise rely on baked-in demo defaults. let cfg = match Config::load_default() { Ok(c) => { info!("loaded config from appsettings.json"); c } Err(e) => { - info!("no appsettings.json found or failed to load: {} — using defaults", e); + info!( + "no appsettings.json found or failed to load: {} — using defaults", + e + ); // defaults Config { server: niom_turn::config::ServerOptions { @@ -34,27 +40,30 @@ async fn main() -> anyhow::Result<()> { tls_cert: None, tls_key: None, }, - credentials: vec![niom_turn::config::CredentialEntry { username: "testuser".into(), password: "secretpassword".into() }], + credentials: vec![niom_turn::config::CredentialEntry { + username: "testuser".into(), + password: "secretpassword".into(), + }], } } }; let bind_addr: SocketAddr = cfg.server.bind.parse()?; - // Initialize credential store and populate from config + // Materialise the credential backend before starting network endpoints. let creds = InMemoryStore::new(); for c in cfg.credentials.iter() { creds.insert(&c.username, &c.password); } - // UDP listener for TURN/STUN + // Bind the UDP socket that receives STUN/TURN traffic from WebRTC clients. let udp = UdpSocket::bind(bind_addr).await?; let udp = Arc::new(udp); // allocation manager let alloc_mgr = AllocationManager::new(); - // spawn packet handling loop + // Spawn the asynchronous packet loop that handles all UDP requests. let udp_clone = udp.clone(); let creds_clone = creds.clone(); let alloc_clone = alloc_mgr.clone(); @@ -64,34 +73,53 @@ async fn main() -> anyhow::Result<()> { } }); - // If TLS cert/key are present in config, start a TLS-backed listener (turns) + // Optionally start the TLS listener so `turns:` clients can connect via TCP/TLS. if let (Some(cert), Some(key)) = (cfg.server.tls_cert.clone(), cfg.server.tls_key.clone()) { let udp_for_tls = udp.clone(); let creds_for_tls = creds.clone(); let alloc_for_tls = alloc_mgr.clone(); tokio::spawn(async move { - if let Err(e) = niom_turn::tls::serve_tls("0.0.0.0:5349", &cert, &key, udp_for_tls, creds_for_tls, alloc_for_tls).await { + if let Err(e) = niom_turn::tls::serve_tls( + "0.0.0.0:5349", + &cert, + &key, + udp_for_tls, + creds_for_tls, + alloc_for_tls, + ) + .await + { error!("tls serve failed: {:?}", e); } }); } - // keep running + // Keep the runtime alive while background tasks process packets. loop { tokio::time::sleep(std::time::Duration::from_secs(60)).await; } } -async fn udp_reader_loop(udp: Arc, creds: InMemoryStore, allocs: AllocationManager) -> anyhow::Result<()> { +async fn udp_reader_loop( + udp: Arc, + creds: InMemoryStore, + allocs: AllocationManager, +) -> anyhow::Result<()> { let mut buf = vec![0u8; 1500]; loop { + // Read the next datagram and keep peer metadata for follow-up responses. let (len, peer) = udp.recv_from(&mut buf).await?; tracing::debug!("got {} bytes from {}", len, peer); // Minimal STUN/TURN detection: parse STUN messages and send 401 challenge if let Ok(msg) = parse_message(&buf[..len]) { - tracing::info!("STUN/TURN message from {} type=0x{:04x} len={}", peer, msg.header.msg_type, len); - // If MESSAGE-INTEGRITY present, attempt validation using credential store + tracing::info!( + "STUN/TURN message from {} type=0x{:04x} len={}", + peer, + msg.header.msg_type, + len + ); + // Fast-path authenticated requests when MESSAGE-INTEGRITY can be validated. if let Some(_mi_attr) = find_message_integrity(&msg) { // For MVP we expect username attribute (USERNAME) to be present let username_attr = msg.attributes.iter().find(|a| a.typ == ATTR_USERNAME); @@ -103,39 +131,312 @@ async fn udp_reader_loop(udp: Arc, creds: InMemoryStore, allocs: Allo if let Some(password) = pw { let valid = validate_message_integrity(&msg, &password); if valid { - tracing::info!("MI valid for user {}", username); - // If this is an Allocate request, perform allocation - if msg.header.msg_type == METHOD_ALLOCATE { - match allocs.allocate_for(peer, udp.clone()).await { - Ok(relay_addr) => { - use bytes::BytesMut; - let mut out = BytesMut::new(); - out.extend_from_slice(&RESP_BINDING_SUCCESS.to_be_bytes()); - out.extend_from_slice(&0u16.to_be_bytes()); - out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes()); - out.extend_from_slice(&msg.header.transaction_id); - // RFC: XOR-RELAYED-ADDRESS (0x0016) - let attr_val = encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id); - out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); - out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes())); - out.extend_from_slice(&attr_val); - while (out.len() % 4) != 0 { out.extend_from_slice(&[0]); } - let total_len = (out.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - out[2] = len_bytes[0]; out[3] = len_bytes[1]; - let vec_out = out.to_vec(); - tracing::info!("sending allocate success (mi-valid) -> {} bytes hex={} ", vec_out.len(), hex::encode(&vec_out)); - let _ = udp.send_to(&vec_out, &peer).await; - continue; + tracing::info!("MI valid for user {}", username); + // Handle authenticated Allocate to mint a relay binding for the client. + if msg.header.msg_type == METHOD_ALLOCATE { + match allocs.allocate_for(peer, udp.clone()).await { + Ok(relay_addr) => { + use bytes::BytesMut; + let mut out = BytesMut::new(); + let success_type = msg.header.msg_type | CLASS_SUCCESS; + out.extend_from_slice(&success_type.to_be_bytes()); + out.extend_from_slice(&0u16.to_be_bytes()); + out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes()); + out.extend_from_slice(&msg.header.transaction_id); + // RFC: XOR-RELAYED-ADDRESS (0x0016) + let attr_val = encode_xor_relayed_address( + &relay_addr, + &msg.header.transaction_id, + ); + out.extend_from_slice( + &ATTR_XOR_RELAYED_ADDRESS.to_be_bytes(), + ); + out.extend_from_slice( + &((attr_val.len() as u16).to_be_bytes()), + ); + out.extend_from_slice(&attr_val); + while (out.len() % 4) != 0 { + out.extend_from_slice(&[0]); } - Err(e) => tracing::error!("allocate failed after MI valid: {:?}", e), + let total_len = (out.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + out[2] = len_bytes[0]; + out[3] = len_bytes[1]; + let vec_out = out.to_vec(); + tracing::info!("sending allocate success (mi-valid) -> {} bytes hex={} ", vec_out.len(), hex::encode(&vec_out)); + let _ = udp.send_to(&vec_out, &peer).await; + continue; + } + Err(e) => tracing::error!( + "allocate failed after MI valid: {:?}", + e + ), + } + } else if msg.header.msg_type == METHOD_CREATE_PERMISSION { + // Permission updates extend the list of peer addresses an allocation may forward to. + if allocs.get_allocation(&peer).is_none() { + tracing::warn!( + "create-permission without allocation from {}", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + let mut added = 0usize; + for attr in msg + .attributes + .iter() + .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) + { + if let Some(peer_addr) = decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) { + match allocs.add_permission(peer, peer_addr) { + Ok(()) => { + tracing::info!( + "added permission for {} -> {}", + peer, + peer_addr + ); + added += 1; + } + Err(e) => { + tracing::error!("failed to persist permission {} -> {}: {:?}", peer, peer_addr, e); + } + } + } else { + tracing::warn!( + "invalid XOR-PEER-ADDRESS in request from {}", + peer + ); } } - // default success response - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; + + if added == 0 { + let resp = build_error_response( + &msg.header, + 400, + "No valid XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + } else { + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + } continue; - } else { + } else if msg.header.msg_type == METHOD_CHANNEL_BIND { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + tracing::warn!( + "channel-bind without allocation from {}", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let channel_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_CHANNEL_NUMBER); + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + + let channel = match channel_attr.and_then(|attr| { + if attr.value.len() >= 4 { + Some(u16::from_be_bytes([attr.value[0], attr.value[1]])) + } else { + None + } + }) { + Some(c) => c, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing CHANNEL-NUMBER", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + if channel < 0x4000 || channel > 0x7FFF { + let resp = build_error_response( + &msg.header, + 400, + "Channel Out Of Range", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + let peer_addr = match peer_attr.and_then(|attr| { + decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) + }) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response( + &msg.header, + 403, + "Peer Not Permitted", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + match allocs.add_channel_binding(peer, channel, peer_addr) { + Ok(()) => { + tracing::info!( + "bound channel 0x{:04x} for {} -> {}", + channel, + peer, + peer_addr + ); + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + } + Err(e) => { + tracing::error!( + "failed to add channel binding {} -> {} (channel 0x{:04x}): {:?}", + peer, peer_addr, channel, e + ); + let resp = build_error_response( + &msg.header, + 500, + "Channel Binding Failed", + ); + let _ = udp.send_to(&resp, &peer).await; + } + } + continue; + } else if msg.header.msg_type == METHOD_SEND { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + tracing::warn!("send without allocation from {}", peer); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + let data_attr = + msg.attributes.iter().find(|a| a.typ == ATTR_DATA); + + let peer_addr = match peer_attr.and_then(|attr| { + decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) + }) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let data_attr = match data_attr { + Some(attr) => attr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing DATA Attribute", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response( + &msg.header, + 403, + "Peer Not Permitted", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + match allocation.send_to_peer(peer_addr, &data_attr.value).await + { + Ok(sent) => { + tracing::info!( + "forwarded {} bytes from {} to peer {}", + sent, + peer, + peer_addr + ); + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + } + Err(e) => { + tracing::error!( + "failed to send payload from {} to {}: {:?}", + peer, + peer_addr, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Peer Send Failed", + ); + let _ = udp.send_to(&resp, &peer).await; + } + } + continue; + } + // Non-specific success path: echo a success response so the client continues handshake. + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + continue; + } else { tracing::info!("MI invalid for user {}", username); } } else { @@ -144,7 +445,7 @@ async fn udp_reader_loop(udp: Arc, creds: InMemoryStore, allocs: Allo } } } - // If it's an Allocate request (TURN method ALLOCATE) and MI valid, allocate a relay socket + // Allow unauthenticated Allocate to fall back to challenge/early success for now (MVP compatibility). if msg.header.msg_type == METHOD_ALLOCATE { // If we reach here without MI, still attempt allocation but we will send a 401 earlier let relay = allocs.allocate_for(peer, udp.clone()).await; @@ -152,20 +453,29 @@ async fn udp_reader_loop(udp: Arc, creds: InMemoryStore, allocs: Allo Ok(relay_addr) => { use bytes::BytesMut; let mut out = BytesMut::new(); - out.extend_from_slice(&RESP_BINDING_SUCCESS.to_be_bytes()); + let success_type = msg.header.msg_type | CLASS_SUCCESS; + out.extend_from_slice(&success_type.to_be_bytes()); out.extend_from_slice(&0u16.to_be_bytes()); out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes()); out.extend_from_slice(&msg.header.transaction_id); - let attr_val = encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id); + let attr_val = + encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id); out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes())); out.extend_from_slice(&attr_val); - while (out.len() % 4) != 0 { out.extend_from_slice(&[0]); } + while (out.len() % 4) != 0 { + out.extend_from_slice(&[0]); + } let total_len = (out.len() - 20) as u16; let len_bytes = total_len.to_be_bytes(); - out[2] = len_bytes[0]; out[3] = len_bytes[1]; + out[2] = len_bytes[0]; + out[3] = len_bytes[1]; let vec_out = out.to_vec(); - tracing::info!("sending allocate success (no-mi) -> {} bytes hex={} ", vec_out.len(), hex::encode(&vec_out)); + tracing::info!( + "sending allocate success (no-mi) -> {} bytes hex={} ", + vec_out.len(), + hex::encode(&vec_out) + ); let _ = udp.send_to(&vec_out, &peer).await; } Err(e) => { @@ -175,7 +485,7 @@ async fn udp_reader_loop(udp: Arc, creds: InMemoryStore, allocs: Allo continue; } - // default: send 401 challenge + // Everything else receives a 401 challenge so the client can retry with credentials. let nonce = format!("nonce-{}", uuid::Uuid::new_v4()); let resp = build_401_response(&msg.header, "niom-turn.local", &nonce, 401); if let Err(e) = udp.send_to(&resp, &peer).await { diff --git a/src/models/mod.rs b/src/models/mod.rs index 2380ff3..92b5fdb 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,3 +1,3 @@ pub mod stun; -pub use stun::{StunHeader, StunAttribute, StunMessage}; +pub use stun::{StunAttribute, StunHeader, StunMessage}; diff --git a/src/stun.rs b/src/stun.rs index de6555c..a78f2cb 100644 --- a/src/stun.rs +++ b/src/stun.rs @@ -1,38 +1,53 @@ //! STUN/TURN message parsing and builders for the server. //! Backlog: full attribute coverage, fingerprint helpers, IPv6 handling, and fuzz testing. -use std::convert::TryInto; -use crate::models::stun::{StunHeader, StunAttribute, StunMessage}; use crate::constants::*; +use crate::models::stun::{StunAttribute, StunHeader, StunMessage}; +use std::convert::TryInto; #[derive(thiserror::Error, Debug)] pub enum ParseError { - #[error("too short")] TooShort, - #[error("invalid magic cookie")] InvalidCookie, - #[error("attribute overflow")] AttrOverflow, + #[error("too short")] + TooShort, + #[error("invalid magic cookie")] + InvalidCookie, + #[error("attribute overflow")] + AttrOverflow, } pub fn parse_message(buf: &[u8]) -> Result { - if buf.len() < 20 { return Err(ParseError::TooShort); } + if buf.len() < 20 { + return Err(ParseError::TooShort); + } let msg_type = u16::from_be_bytes(buf[0..2].try_into().unwrap()); let length = u16::from_be_bytes(buf[2..4].try_into().unwrap()); let cookie = u32::from_be_bytes(buf[4..8].try_into().unwrap()); - if cookie != MAGIC_COOKIE_U32 { return Err(ParseError::InvalidCookie); } + if cookie != MAGIC_COOKIE_U32 { + return Err(ParseError::InvalidCookie); + } let mut trans = [0u8; 12]; trans.copy_from_slice(&buf[8..20]); let mut attrs = Vec::new(); let mut offset = 20usize; let total_len = (length as usize) + 20; - if buf.len() < total_len { return Err(ParseError::TooShort); } + if buf.len() < total_len { + return Err(ParseError::TooShort); + } while offset + 4 <= total_len { - let typ = u16::from_be_bytes(buf[offset..offset+2].try_into().unwrap()); - let attr_len = u16::from_be_bytes(buf[offset+2..offset+4].try_into().unwrap()) as usize; + let typ = u16::from_be_bytes(buf[offset..offset + 2].try_into().unwrap()); + let attr_len = u16::from_be_bytes(buf[offset + 2..offset + 4].try_into().unwrap()) as usize; let attr_header_offset = offset; offset += 4; - if offset + attr_len > total_len { return Err(ParseError::AttrOverflow); } - let value = buf[offset..offset+attr_len].to_vec(); - attrs.push(StunAttribute { typ, value, offset: attr_header_offset }); + if offset + attr_len > total_len { + return Err(ParseError::AttrOverflow); + } + let value = buf[offset..offset + attr_len].to_vec(); + attrs.push(StunAttribute { + typ, + value, + offset: attr_header_offset, + }); offset += attr_len; // padding to 32-bit boundary let pad = (4 - (attr_len % 4)) % 4; @@ -40,7 +55,12 @@ pub fn parse_message(buf: &[u8]) -> Result { } Ok(StunMessage { - header: StunHeader { msg_type, length, cookie, transaction_id: trans }, + header: StunHeader { + msg_type, + length, + cookie, + transaction_id: trans, + }, attributes: attrs, raw: buf[..total_len].to_vec(), }) @@ -62,14 +82,18 @@ pub fn build_401_response(req: &StunHeader, realm: &str, nonce: &str, _err_code: buf.extend_from_slice(&ATTR_REALM.to_be_bytes()); buf.extend_from_slice(&(realm_bytes.len() as u16).to_be_bytes()); buf.extend_from_slice(realm_bytes); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0]); } + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } // NONCE (RFC attr) let nonce_bytes = nonce.as_bytes(); buf.extend_from_slice(&ATTR_NONCE.to_be_bytes()); buf.extend_from_slice(&(nonce_bytes.len() as u16).to_be_bytes()); buf.extend_from_slice(nonce_bytes); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0]); } + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } // Update length let total_len = (buf.len() - 20) as u16; @@ -80,9 +104,44 @@ pub fn build_401_response(req: &StunHeader, realm: &str, nonce: &str, _err_code: buf.to_vec() } +/// Build a generic STUN error response with an ERROR-CODE attribute plus optional reason phrase. +pub fn build_error_response(req: &StunHeader, code: u16, reason: &str) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_ERROR; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let mut value = Vec::new(); + let class = (code / 100) as u8; + let number = (code % 100) as u8; + value.extend_from_slice(&[0, 0]); + value.push(class); + value.push(number); + value.extend_from_slice(reason.as_bytes()); + + buf.extend_from_slice(&ATTR_ERROR_CODE.to_be_bytes()); + buf.extend_from_slice(&(value.len() as u16).to_be_bytes()); + buf.extend_from_slice(&value); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + + buf.to_vec() +} + /// Find MESSAGE-INTEGRITY attribute (ATTR_MESSAGE_INTEGRITY) if present pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> { - msg.attributes.iter().find(|a| a.typ == ATTR_MESSAGE_INTEGRITY) + msg.attributes + .iter() + .find(|a| a.typ == ATTR_MESSAGE_INTEGRITY) } /// Validate MESSAGE-INTEGRITY using provided key (password). Returns true if valid. @@ -91,7 +150,9 @@ pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> { pub fn validate_message_integrity(msg: &StunMessage, key: &str) -> bool { if let Some(mi) = find_message_integrity(msg) { // MESSAGE-INTEGRITY attribute value is 20 bytes (HMAC-SHA1) - if mi.value.len() != 20 { return false; } + if mi.value.len() != 20 { + return false; + } // Compute HMAC over the message up to (but excluding) MESSAGE-INTEGRITY attribute header and value let mi_attr_start = mi.offset; // offset points to attribute header let msg_slice = &msg.raw[..mi_attr_start]; @@ -106,7 +167,7 @@ pub fn validate_message_integrity(msg: &StunMessage, key: &str) -> bool { pub fn build_success_response(req: &StunHeader) -> Vec { use bytes::BytesMut; let mut buf = BytesMut::new(); - let msg_type: u16 = RESP_BINDING_SUCCESS; // Binding success response (example) + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; buf.extend_from_slice(&msg_type.to_be_bytes()); buf.extend_from_slice(&0u16.to_be_bytes()); buf.extend_from_slice(&MAGIC_COOKIE_BYTES); @@ -143,21 +204,23 @@ pub fn compute_message_integrity(key: &str, msg: &[u8]) -> Vec { /// Encode an IPv4 SocketAddr into XOR-RELAYED-ADDRESS attribute value. /// Format (per RFC5389/RFC5766): 1 byte family, 2 byte xport, 4 byte xaddr for IPv4 -pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8;12]) -> Vec { +pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec { use std::net::IpAddr; let mut out = Vec::new(); match addr.ip() { IpAddr::V4(v4) => { out.push(0); // first 8 bits zero per spec out.push(FAMILY_IPV4); // family: IPv4 - // xport = port ^ (magic_cookie >> 16) + // xport = port ^ (magic_cookie >> 16) let port = addr.port(); let xport = (port ^ ((MAGIC_COOKIE_U32 >> 16) as u16)) as u16; out.extend_from_slice(&xport.to_be_bytes()); // xaddr = ipv4 ^ magic_cookie let octets = v4.octets(); let cookie_bytes = MAGIC_COOKIE_BYTES; - for i in 0..4 { out.push(octets[i] ^ cookie_bytes[i]); } + for i in 0..4 { + out.push(octets[i] ^ cookie_bytes[i]); + } } IpAddr::V6(_v6) => { // For now, we don't support IPv6 in this MVP implementation @@ -168,18 +231,31 @@ pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8;1 } /// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4 only) -pub fn decode_xor_relayed_address(value: &[u8], _trans_id: &[u8;12]) -> Option { - if value.len() < 8 { return None; } - if value[1] != FAMILY_IPV4 { return None; } // not IPv4 +pub fn decode_xor_relayed_address( + value: &[u8], + _trans_id: &[u8; 12], +) -> Option { + if value.len() < 8 { + return None; + } + if value[1] != FAMILY_IPV4 { + return None; + } // not IPv4 let xport = u16::from_be_bytes([value[2], value[3]]); - let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16); - let cookie_bytes = MAGIC_COOKIE_BYTES; - let mut ipb = [0u8;4]; - for i in 0..4 { ipb[i] = value[4 + i] ^ cookie_bytes[i]; } + let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16); + let cookie_bytes = MAGIC_COOKIE_BYTES; + let mut ipb = [0u8; 4]; + for i in 0..4 { + ipb[i] = value[4 + i] ^ cookie_bytes[i]; + } let ip = std::net::Ipv4Addr::from(ipb); Some(std::net::SocketAddr::new(std::net::IpAddr::V4(ip), port)) } +/// Decode XOR-PEER-ADDRESS / XOR-MAPPED-ADDRESS style attributes (IPv4 only). +pub fn decode_xor_peer_address(value: &[u8], _trans_id: &[u8; 12]) -> Option { + decode_xor_relayed_address(value, _trans_id) +} #[cfg(test)] mod tests { @@ -189,20 +265,25 @@ mod tests { fn parse_minimal_binding() { // Build a minimal STUN Binding request with empty attributes let mut b = Vec::new(); - b.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request + b.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request b.extend_from_slice(&0u16.to_be_bytes()); // length - b.extend_from_slice(&MAGIC_COOKIE_BYTES); + b.extend_from_slice(&MAGIC_COOKIE_BYTES); let trans = [1u8; 12]; b.extend_from_slice(&trans); let msg = parse_message(&b).expect("parse"); - assert_eq!(msg.header.msg_type, METHOD_BINDING); + assert_eq!(msg.header.msg_type, METHOD_BINDING); assert_eq!(msg.header.transaction_id, trans); assert!(msg.attributes.is_empty()); } #[test] fn build_401_roundtrip() { - let req = StunHeader { msg_type: METHOD_BINDING, length: 0, cookie: MAGIC_COOKIE_U32, transaction_id: [2u8;12] }; + let req = StunHeader { + msg_type: METHOD_BINDING, + length: 0, + cookie: MAGIC_COOKIE_U32, + transaction_id: [2u8; 12], + }; let out = build_401_response(&req, "realm", "nonce", 401); // parse back should succeed let parsed = parse_message(&out).expect("parse resp"); @@ -218,26 +299,30 @@ mod tests { // Build message: Binding Request + USERNAME attribute + MESSAGE-INTEGRITY placeholder let mut buf = BytesMut::new(); - buf.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request + buf.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder - buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); let trans = [9u8; 12]; buf.extend_from_slice(&trans); - // USERNAME (ATTR_USERNAME) - let uname_bytes = username.as_bytes(); - buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes()); + // USERNAME (ATTR_USERNAME) + let uname_bytes = username.as_bytes(); + buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes()); buf.extend_from_slice(&(uname_bytes.len() as u16).to_be_bytes()); buf.extend_from_slice(uname_bytes); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } // MESSAGE-INTEGRITY placeholder (0x0008) length 20 - let mi_attr_offset = buf.len(); - buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); + let mi_attr_offset = buf.len(); + buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); let mi_val_pos = buf.len(); - buf.extend_from_slice(&[0u8;20]); - while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } + buf.extend_from_slice(&[0u8; 20]); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } // Fix length let total_len = (buf.len() - 20) as u16; @@ -248,7 +333,9 @@ mod tests { // Compute HMAC over message up to MI attribute header (mi_attr_offset) let hmac = compute_message_integrity(password, &buf[..mi_attr_offset]); // place first 20 bytes into mi value - for i in 0..20 { buf[mi_val_pos + i] = hmac[i]; } + for i in 0..20 { + buf[mi_val_pos + i] = hmac[i]; + } // Parse and validate let parsed = parse_message(&buf).expect("parsed"); diff --git a/src/tls.rs b/src/tls.rs index d42fdf3..b610040 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,19 +1,22 @@ //! TLS listener that wraps the STUN/TURN logic for `turns:` clients. //! Backlog: ALPN negotiation, TCP relay support, and shared flow-control with the UDP path. -use std::sync::Arc; use anyhow::Context; -use tokio::net::TcpListener; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio_rustls::TlsAcceptor; -use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; use std::fs::File; use std::io::BufReader; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; +use tokio_rustls::TlsAcceptor; -use crate::auth::InMemoryStore; -use crate::traits::CredentialStore; use crate::alloc::AllocationManager; -use crate::stun::{parse_message, build_401_response, find_message_integrity, validate_message_integrity, build_success_response, encode_xor_relayed_address}; +use crate::auth::InMemoryStore; use crate::constants::*; +use crate::stun::{ + build_401_response, build_error_response, build_success_response, decode_xor_peer_address, + encode_xor_relayed_address, find_message_integrity, parse_message, validate_message_integrity, +}; +use crate::traits::CredentialStore; fn load_certs(path: &str) -> anyhow::Result> { let f = File::open(path).context("opening cert file")?; @@ -41,7 +44,14 @@ fn load_private_key(path: &str) -> anyhow::Result { /// Start a TLS-backed listener (turns) on the given bind address. /// This reuses the existing STUN/TURN message handling logic, but sends replies /// back over the TLS stream rather than UDP. -pub async fn serve_tls(bind: &str, cert_path: &str, key_path: &str, udp_sock: std::sync::Arc, creds: InMemoryStore, allocs: AllocationManager) -> anyhow::Result<()> { +pub async fn serve_tls( + bind: &str, + cert_path: &str, + key_path: &str, + udp_sock: std::sync::Arc, + creds: InMemoryStore, + allocs: AllocationManager, +) -> anyhow::Result<()> { let certs = load_certs(cert_path)?; let key = load_private_key(key_path)?; @@ -81,24 +91,39 @@ pub async fn serve_tls(bind: &str, cert_path: &str, key_path: &str, udp_sock: st while buffer.len() >= 20 { let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; let total = len + 20; - if buffer.len() < total { break; } + if buffer.len() < total { + break; + } let chunk = buffer.drain(..total).collect::>(); if let Ok(msg) = parse_message(&chunk) { // process message similarly to UDP path if let Some(_mi_attr) = find_message_integrity(&msg) { - let username_attr = msg.attributes.iter().find(|a| a.typ == ATTR_USERNAME); + let username_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_USERNAME); if let Some(u) = username_attr { - if let Ok(username) = std::str::from_utf8(&u.value) { - let pw = creds_clone.get_password(username).await; + if let Ok(username) = std::str::from_utf8(&u.value) + { + let pw = + creds_clone.get_password(username).await; if let Some(password) = pw { - let valid = validate_message_integrity(&msg, &password); + let valid = validate_message_integrity( + &msg, &password, + ); if valid { - tracing::info!("MI valid for user {} on TLS", username); - if msg.header.msg_type == METHOD_ALLOCATE { + tracing::info!( + "MI valid for user {} on TLS", + username + ); + if msg.header.msg_type + == METHOD_ALLOCATE + { match alloc_clone.allocate_for(peer, udp_clone.clone()).await { Ok(relay_addr) => { let mut out = Vec::new(); - out.extend_from_slice(&RESP_BINDING_SUCCESS.to_be_bytes()); + let success_type = msg.header.msg_type | CLASS_SUCCESS; + out.extend_from_slice(&success_type.to_be_bytes()); out.extend_from_slice(&0u16.to_be_bytes()); out.extend_from_slice(&MAGIC_COOKIE_BYTES); out.extend_from_slice(&msg.header.transaction_id); @@ -117,54 +142,528 @@ pub async fn serve_tls(bind: &str, cert_path: &str, key_path: &str, udp_sock: st } Err(e) => tracing::error!("allocate failed after MI valid (tls): {:?}", e), } + } else if msg.header.msg_type + == METHOD_CREATE_PERMISSION + { + if alloc_clone + .get_allocation(&peer) + .is_none() + { + tracing::warn!("create-permission without allocation from {} (tls)", peer); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!("failed to write tls error: {:?}", e); + } + continue; + } + + let mut added = 0usize; + for attr in msg + .attributes + .iter() + .filter(|a| { + a.typ + == ATTR_XOR_PEER_ADDRESS + }) + { + if let Some(peer_addr) = + decode_xor_peer_address( + &attr.value, + &msg.header + .transaction_id, + ) + { + match alloc_clone.add_permission(peer, peer_addr) { + Ok(()) => { + tracing::info!("added TLS permission for {} -> {}", peer, peer_addr); + added += 1; + } + Err(e) => tracing::error!("failed to persist TLS permission {} -> {}: {:?}", peer, peer_addr, e), + } + } else { + tracing::warn!("invalid XOR-PEER-ADDRESS via TLS from {}", peer); + } + } + + let resp = if added == 0 { + build_error_response( + &msg.header, + 400, + "No valid XOR-PEER-ADDRESS", + ) + } else { + build_success_response( + &msg.header, + ) + }; + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!("failed to write tls response: {:?}", e); + } + continue; + } else if msg.header.msg_type + == METHOD_CHANNEL_BIND + { + let allocation = match alloc_clone + .get_allocation(&peer) + { + Some(a) => a, + None => { + tracing::warn!( + "channel-bind without allocation from {} (tls)", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + let channel_attr = msg + .attributes + .iter() + .find(|a| { + a.typ == ATTR_CHANNEL_NUMBER + }); + let peer_attr = msg + .attributes + .iter() + .find(|a| { + a.typ + == ATTR_XOR_PEER_ADDRESS + }); + + let channel = match channel_attr + .and_then(|attr| { + if attr.value.len() >= 4 { + Some( + u16::from_be_bytes( + [ + attr.value + [0], + attr.value + [1], + ], + ), + ) + } else { + None + } + }) { + Some(c) => c, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing CHANNEL-NUMBER", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + if channel < 0x4000 + || channel > 0x7FFF + { + let resp = build_error_response( + &msg.header, + 400, + "Channel Out Of Range", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + + let peer_addr = match peer_attr + .and_then(|attr| { + decode_xor_peer_address( + &attr.value, + &msg.header + .transaction_id, + ) + }) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing XOR-PEER-ADDRESS", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + if !allocation + .is_peer_allowed(&peer_addr) + { + let resp = build_error_response( + &msg.header, + 403, + "Peer Not Permitted", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + + match alloc_clone + .add_channel_binding( + peer, channel, peer_addr, + ) { + Ok(()) => { + tracing::info!( + "bound channel 0x{:04x} for {} -> {} over TLS", + channel, + peer, + peer_addr + ); + let resp = + build_success_response( + &msg.header, + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls response: {:?}", + e + ); + } + } + Err(e) => { + tracing::error!( + "failed TLS channel binding {} -> {} (0x{:04x}): {:?}", + peer, + peer_addr, + channel, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Channel Binding Failed", + ); + if let Err(e2) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e2 + ); + } + } + } + continue; + } else if msg.header.msg_type + == METHOD_SEND + { + let allocation = match alloc_clone + .get_allocation(&peer) + { + Some(a) => a, + None => { + tracing::warn!( + "send without allocation from {} (tls)", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + let peer_attr = msg + .attributes + .iter() + .find(|a| { + a.typ + == ATTR_XOR_PEER_ADDRESS + }); + let data_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_DATA); + + let peer_addr = match peer_attr + .and_then(|attr| { + decode_xor_peer_address( + &attr.value, + &msg.header + .transaction_id, + ) + }) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing XOR-PEER-ADDRESS", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + let data_attr = match data_attr { + Some(attr) => attr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing DATA Attribute", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + if !allocation + .is_peer_allowed(&peer_addr) + { + let resp = build_error_response( + &msg.header, + 403, + "Peer Not Permitted", + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + + match allocation + .send_to_peer( + peer_addr, + &data_attr.value, + ) + .await + { + Ok(sent) => { + tracing::info!( + "forwarded {} bytes from {} to {} via TLS session", + sent, + peer, + peer_addr + ); + let resp = + build_success_response( + &msg.header, + ); + if let Err(e) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls response: {:?}", + e + ); + } + } + Err(e) => { + tracing::error!( + "failed to send payload from {} to {} via TLS: {:?}", + peer, + peer_addr, + e + ); + let resp = + build_error_response( + &msg.header, + 500, + "Peer Send Failed", + ); + if let Err(e2) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls error: {:?}", + e2 + ); + } + } + } + continue; } - let resp = build_success_response(&msg.header); - if let Err(e) = tls_stream.write_all(&resp).await { + let resp = + build_success_response(&msg.header); + if let Err(e) = + tls_stream.write_all(&resp).await + { tracing::error!("failed to write tls response: {:?}", e); } continue; } else { - tracing::info!("MI invalid for user {} (tls)", username); + tracing::info!( + "MI invalid for user {} (tls)", + username + ); } } else { - tracing::info!("unknown user {} (tls)", username); + tracing::info!( + "unknown user {} (tls)", + username + ); } } } } if msg.header.msg_type == METHOD_ALLOCATE { - match alloc_clone.allocate_for(peer, udp_clone.clone()).await { + match alloc_clone + .allocate_for(peer, udp_clone.clone()) + .await + { Ok(relay_addr) => { let mut out = Vec::new(); - out.extend_from_slice(&RESP_BINDING_SUCCESS.to_be_bytes()); + let success_type = + msg.header.msg_type | CLASS_SUCCESS; + out.extend_from_slice( + &success_type.to_be_bytes(), + ); out.extend_from_slice(&0u16.to_be_bytes()); out.extend_from_slice(&MAGIC_COOKIE_BYTES); - out.extend_from_slice(&msg.header.transaction_id); - let attr_val = encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id); - out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); - out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes())); + out.extend_from_slice( + &msg.header.transaction_id, + ); + let attr_val = encode_xor_relayed_address( + &relay_addr, + &msg.header.transaction_id, + ); + out.extend_from_slice( + &ATTR_XOR_RELAYED_ADDRESS.to_be_bytes(), + ); + out.extend_from_slice( + &((attr_val.len() as u16).to_be_bytes()), + ); out.extend_from_slice(&attr_val); - while (out.len() % 4) != 0 { out.extend_from_slice(&[0]); } + while (out.len() % 4) != 0 { + out.extend_from_slice(&[0]); + } let total_len = (out.len() - 20) as u16; let len_bytes = total_len.to_be_bytes(); - out[2] = len_bytes[0]; out[3] = len_bytes[1]; - if let Err(e) = tls_stream.write_all(&out).await { - tracing::error!("failed to write tls response: {:?}", e); + out[2] = len_bytes[0]; + out[3] = len_bytes[1]; + if let Err(e) = tls_stream.write_all(&out).await + { + tracing::error!( + "failed to write tls response: {:?}", + e + ); } } - Err(e) => tracing::error!("allocate failed (tls): {:?}", e), + Err(e) => tracing::error!( + "allocate failed (tls): {:?}", + e + ), } continue; } // default: send 401 challenge let nonce = format!("nonce-{}", uuid::Uuid::new_v4()); - let resp = build_401_response(&msg.header, "niom-turn.local", &nonce, 401); + let resp = build_401_response( + &msg.header, + "niom-turn.local", + &nonce, + 401, + ); if let Err(e) = tls_stream.write_all(&resp).await { tracing::error!("failed to write tls 401: {:?}", e); } } else { - tracing::debug!("failed to parse stun message on tls from {}", peer); + tracing::debug!( + "failed to parse stun message on tls from {}", + peer + ); } } }