From 15dfec869527f02ce998887ea634d352fb676e61 Mon Sep 17 00:00:00 2001 From: ghost Date: Mon, 24 Nov 2025 16:56:54 +0100 Subject: [PATCH] Refactor: project structure and logging. Add: Integration tests for happy flow of UDP and TLS. --- Cargo.lock | 39 + Cargo.toml | 2 +- src/bin/allocate_smoke.rs | 138 ---- src/bin/smoke_client.rs | 70 -- src/lib.rs | 4 + src/logging.rs | 28 + src/main.rs | 459 +----------- src/server.rs | 429 +++++++++++ src/tls.rs | 1272 +++++++++++++++----------------- tests/support/mod.rs | 37 + tests/support/stun_builders.rs | 210 ++++++ tests/support/tls.rs | 59 ++ tests/tls_turn.rs | 155 ++++ tests/udp_turn.rs | 243 ++++++ 14 files changed, 1814 insertions(+), 1331 deletions(-) delete mode 100644 src/bin/allocate_smoke.rs delete mode 100644 src/bin/smoke_client.rs create mode 100644 src/logging.rs create mode 100644 src/server.rs create mode 100644 tests/support/mod.rs create mode 100644 tests/support/stun_builders.rs create mode 100644 tests/support/tls.rs create mode 100644 tests/tls_turn.rs create mode 100644 tests/udp_turn.rs diff --git a/Cargo.lock b/Cargo.lock index 7feaa80..3f1196b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "anyhow" version = "1.0.100" @@ -273,6 +282,15 @@ version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "md5" version = "0.7.0" @@ -449,6 +467,23 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex-automata" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + [[package]] name = "ring" version = "0.16.20" @@ -817,10 +852,14 @@ version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex-automata", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] diff --git a/Cargo.toml b/Cargo.toml index 924d831..df8c7bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ hex = "0.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["fmt"] } +tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } # TLS for turns (server) tokio-rustls = "0.23" diff --git a/src/bin/allocate_smoke.rs b/src/bin/allocate_smoke.rs deleted file mode 100644 index cb998fb..0000000 --- a/src/bin/allocate_smoke.rs +++ /dev/null @@ -1,138 +0,0 @@ -use bytes::BytesMut; -use niom_turn::constants::*; -// use niom_turn::stun; // not needed; use specific functions via path when required -use std::net::SocketAddr; -use std::time::Duration; -use tokio::net::UdpSocket; - -// Use shared decoder from library: niom_turn::stun::decode_xor_relayed_address - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt::init(); - let server: SocketAddr = "127.0.0.1:3478".parse()?; - let local = UdpSocket::bind("0.0.0.0:0").await?; - - let username = "testuser"; - let password = "secretpassword"; - - // Build Allocate request (method METHOD_ALLOCATE) - let mut buf = BytesMut::new(); - buf.extend_from_slice(&METHOD_ALLOCATE.to_be_bytes()); // Allocate Request - buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder - buf.extend_from_slice(&MAGIC_COOKIE_BYTES); - let trans = [13u8; 12]; - buf.extend_from_slice(&trans); - - // USERNAME - let uname = username.as_bytes(); - buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes()); - buf.extend_from_slice(&(uname.len() as u16).to_be_bytes()); - buf.extend_from_slice(uname); - while (buf.len() % 4) != 0 { - buf.extend_from_slice(&[0u8]); - } - - // MESSAGE-INTEGRITY placeholder - let mi_attr_offset = buf.len(); - buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); - buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); - let mi_val_pos = buf.len(); - buf.extend_from_slice(&[0u8; 20]); - while (buf.len() % 4) != 0 { - buf.extend_from_slice(&[0u8]); - } - - // fix length - let total_len = (buf.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - buf[2] = len_bytes[0]; - buf[3] = len_bytes[1]; - - // compute HMAC over bytes up to MI attribute header - { - use hmac::{Hmac, Mac}; - use sha1::Sha1; - type HmacSha1 = Hmac; - let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key"); - mac.update(&buf[..mi_attr_offset]); - let res = mac.finalize().into_bytes(); - for i in 0..20 { - buf[mi_val_pos + i] = res[i]; - } - } - - // send Allocate - local.send_to(&buf, server).await?; - - // receive response - let mut r = vec![0u8; 1500]; - let (len, _addr) = local.recv_from(&mut r).await?; - println!("got {} bytes", len); - let resp = &r[..len]; - // expect success (METHOD_ALLOCATE | CLASS_SUCCESS) with XOR-RELAYED-ADDRESS attr - if resp.len() < 20 { - anyhow::bail!("response too short"); - } - let typ = u16::from_be_bytes([resp[0], resp[1]]); - println!("resp type 0x{:04x}", typ); - let expected_type = METHOD_ALLOCATE | CLASS_SUCCESS; - if typ != expected_type { - anyhow::bail!("expected success response, got 0x{:04x}", typ); - } - // parse attributes - let length = u16::from_be_bytes([resp[2], resp[3]]) as usize; - let total = 20 + length; - let mut offset = 20; - let mut relay_addr_opt: Option = None; - while offset + 4 <= total { - let atype = u16::from_be_bytes([resp[offset], resp[offset + 1]]); - let alen = u16::from_be_bytes([resp[offset + 2], resp[offset + 3]]) as usize; - offset += 4; - if offset + alen > total { - break; - } - println!("attr type=0x{:04x} len={}", atype, alen); - println!("raw: {}", hex::encode(&resp[offset..offset + alen])); - if atype == ATTR_XOR_RELAYED_ADDRESS { - // XOR-RELAYED-ADDRESS: decode via shared library function - if let Some(sa) = - niom_turn::stun::decode_xor_relayed_address(&resp[offset..offset + alen], &trans) - { - relay_addr_opt = Some(sa); - } - } - offset += alen; - let pad = (4 - (alen % 4)) % 4; - offset += pad; - } - - let relay_addr = match relay_addr_opt { - Some(a) => a, - None => anyhow::bail!("no relay address in response"), - }; - - println!("got relayed addr: {}", relay_addr); - - // send test payload to relay addr - let payload = b"hello-relay"; - local.send_to(payload, relay_addr).await?; - - // wait for forwarded packet (should arrive via server socket) using tokio timeout - let mut buf2 = vec![0u8; 1500]; - match tokio::time::timeout(Duration::from_secs(2), local.recv_from(&mut buf2)).await { - Ok(Ok((l, src))) => { - println!("received {} bytes from {}", l, src); - let got = &buf2[..l]; - println!("payload: {:?}", got); - if got == payload { - println!("relay test success"); - Ok(()) - } else { - anyhow::bail!("payload mismatch") - } - } - Ok(Err(e)) => anyhow::bail!("recv error: {:?}", e), - Err(_) => anyhow::bail!("no forwarded packet received: timeout"), - } -} diff --git a/src/bin/smoke_client.rs b/src/bin/smoke_client.rs deleted file mode 100644 index 7bd39df..0000000 --- a/src/bin/smoke_client.rs +++ /dev/null @@ -1,70 +0,0 @@ -use bytes::BytesMut; -use niom_turn::constants::*; -use std::net::SocketAddr; -use tokio::net::UdpSocket; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt::init(); - let server: SocketAddr = "127.0.0.1:3478".parse()?; - let local = UdpSocket::bind("0.0.0.0:0").await?; - - // Build a minimal STUN Binding Request with USERNAME and placeholder MESSAGE-INTEGRITY - let username = "testuser"; - let password = "secretpassword"; // matches server's in-memory creds - - let mut buf = BytesMut::new(); - buf.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request - buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder - buf.extend_from_slice(&MAGIC_COOKIE_BYTES); - let trans = [7u8; 12]; - buf.extend_from_slice(&trans); - - // USERNAME - let uname = username.as_bytes(); - buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes()); - buf.extend_from_slice(&(uname.len() as u16).to_be_bytes()); - buf.extend_from_slice(uname); - while (buf.len() % 4) != 0 { - buf.extend_from_slice(&[0u8]); - } - - // MESSAGE-INTEGRITY placeholder - let mi_attr_offset = buf.len(); - buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); - buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); - let mi_val_pos = buf.len(); - buf.extend_from_slice(&[0u8; 20]); - while (buf.len() % 4) != 0 { - buf.extend_from_slice(&[0u8]); - } - - // fix length - let total_len = (buf.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - buf[2] = len_bytes[0]; - buf[3] = len_bytes[1]; - - // compute HMAC over bytes up to MI attribute header - { - use hmac::{Hmac, Mac}; - use sha1::Sha1; - type HmacSha1 = Hmac; - let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key"); - mac.update(&buf[..mi_attr_offset]); - let res = mac.finalize().into_bytes(); - for i in 0..20 { - buf[mi_val_pos + i] = res[i]; - } - } - - // send - local.send_to(&buf, server).await?; - - let mut r = vec![0u8; 1500]; - let (len, addr) = local.recv_from(&mut r).await?; - println!("got {} bytes from {}", len, addr); - // dump hex - println!("{:02x?}", &r[..len]); - Ok(()) -} diff --git a/src/lib.rs b/src/lib.rs index ad2e14b..07a5a15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,11 +3,15 @@ pub mod alloc; pub mod auth; pub mod config; pub mod constants; +pub mod logging; pub mod models; +pub mod server; pub mod stun; pub mod tls; pub mod traits; pub use crate::alloc::*; pub use crate::auth::*; +pub use crate::logging::*; +pub use crate::server::*; pub use crate::stun::*; diff --git a/src/logging.rs b/src/logging.rs new file mode 100644 index 0000000..c7094b3 --- /dev/null +++ b/src/logging.rs @@ -0,0 +1,28 @@ +//! Logging helpers shared between the library, binaries, and integration tests. +//! Ensures we configure tracing once with sensible default filters while still +//! allowing `RUST_LOG` to override verbosity. +use std::sync::Once; +use tracing_subscriber::{fmt, EnvFilter}; + +static INIT: Once = Once::new(); + +/// Initialise tracing with a sane default filter (`warn` globally, +/// `niom_turn=info`) unless `RUST_LOG` is provided. +pub fn init_tracing() { + init_tracing_with_default("warn,niom_turn=info"); +} + +/// Initialise tracing with a custom default directive that can still be +/// overridden via `RUST_LOG` at runtime. +pub fn init_tracing_with_default(default_directive: &str) { + INIT.call_once(|| { + let env_filter = + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(default_directive)); + + fmt() + .with_env_filter(env_filter) + .with_target(false) + .compact() + .init(); + }); +} diff --git a/src/main.rs b/src/main.rs index e3fb84a..2758024 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,24 +2,18 @@ //! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners. use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; use tokio::net::UdpSocket; use tracing::{error, info}; // Use the library crate's public modules instead of local `mod` declarations. use niom_turn::alloc::AllocationManager; -use niom_turn::auth::{AuthManager, AuthStatus, InMemoryStore}; +use niom_turn::auth::{AuthManager, InMemoryStore}; use niom_turn::config::{AuthOptions, Config}; -use niom_turn::constants::*; -use niom_turn::stun::{ - build_401_response, build_allocate_success, build_error_response, build_lifetime_success, - build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data, - parse_message, -}; +use niom_turn::server::udp_reader_loop; #[tokio::main] async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt::init(); + niom_turn::logging::init_tracing(); info!("niom-turn starting"); @@ -64,7 +58,7 @@ async fn main() -> anyhow::Result<()> { let udp = UdpSocket::bind(bind_addr).await?; let udp = Arc::new(udp); - // allocation manager + // Allocation manager shared by UDP + TLS frontends. let alloc_mgr = AllocationManager::new(); // Spawn the asynchronous packet loop that handles all UDP requests. @@ -103,448 +97,3 @@ async fn main() -> anyhow::Result<()> { tokio::time::sleep(std::time::Duration::from_secs(60)).await; } } - -async fn udp_reader_loop( - udp: Arc, - auth: AuthManager, - allocs: AllocationManager, -) -> anyhow::Result<()> { - let mut buf = vec![0u8; 1500]; - loop { - // Read the next datagram and keep peer metadata for follow-up responses. - let (len, peer) = udp.recv_from(&mut buf).await?; - tracing::debug!("got {} bytes from {}", len, peer); - - if let Some((channel, payload)) = parse_channel_data(&buf[..len]) { - let allocation = match allocs.get_allocation(&peer) { - Some(a) => a, - None => { - tracing::warn!("channel data without allocation from {}", peer); - continue; - } - }; - - let target = match allocation.channel_peer(channel) { - Some(addr) => addr, - None => { - tracing::warn!( - "channel data with unknown channel 0x{:04x} from {}", - channel, - peer - ); - continue; - } - }; - - if !allocation.is_peer_allowed(&target) { - tracing::warn!( - "channel data target {} no longer permitted for {}", - target, - peer - ); - continue; - } - - match allocation.send_to_peer(target, payload).await { - Ok(sent) => tracing::debug!( - "forwarded {} bytes via channel 0x{:04x} from {} to {}", - sent, - channel, - peer, - target - ), - Err(e) => tracing::error!( - "failed to forward channel data 0x{:04x} from {} to {}: {:?}", - channel, - peer, - target, - e - ), - } - continue; - } - - // Minimal STUN/TURN detection: parse STUN messages and send 401 challenge - if let Ok(msg) = parse_message(&buf[..len]) { - tracing::info!( - "STUN/TURN message from {} type=0x{:04x} len={}", - peer, - msg.header.msg_type, - len - ); - let requires_auth = matches!( - msg.header.msg_type, - METHOD_ALLOCATE - | METHOD_CREATE_PERMISSION - | METHOD_CHANNEL_BIND - | METHOD_SEND - | METHOD_REFRESH - ); - - if requires_auth { - match auth.authenticate(&msg, &peer).await { - AuthStatus::Granted { username } => { - tracing::debug!( - "TURN auth ok for {} as {} (0x{:04x})", - peer, - username, - msg.header.msg_type - ); - } - AuthStatus::Challenge { nonce } => { - let resp = build_401_response( - &msg.header, - auth.realm(), - &nonce, - 401, - "Unauthorized", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - AuthStatus::StaleNonce { nonce } => { - let resp = build_401_response( - &msg.header, - auth.realm(), - &nonce, - 438, - "Stale Nonce", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - AuthStatus::Reject { code, reason } => { - let resp = build_error_response(&msg.header, code, reason); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - } - - match msg.header.msg_type { - METHOD_ALLOCATE => { - let requested_lifetime = extract_lifetime_seconds(&msg) - .map(|secs| Duration::from_secs(secs as u64)) - .filter(|d| !d.is_zero()); - - match allocs.allocate_for(peer, udp.clone()).await { - Ok(relay_addr) => { - let applied = - match allocs.refresh_allocation(peer, requested_lifetime) { - Ok(d) => d, - Err(e) => { - tracing::error!( - "failed to apply lifetime for {}: {:?}", - peer, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Allocate Failed", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32; - let resp = - build_allocate_success(&msg.header, &relay_addr, lifetime_secs); - tracing::info!( - "allocated relay {} for {} lifetime={}s", - relay_addr, - peer, - lifetime_secs - ); - let _ = udp.send_to(&resp, &peer).await; - } - Err(e) => { - tracing::error!("allocate failed: {:?}", e); - let resp = - build_error_response(&msg.header, 500, "Allocate Failed"); - let _ = udp.send_to(&resp, &peer).await; - } - } - continue; - } - METHOD_CREATE_PERMISSION => { - if allocs.get_allocation(&peer).is_none() { - tracing::warn!("create-permission without allocation from {}", peer); - let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - - let mut added = 0usize; - for attr in msg - .attributes - .iter() - .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) - { - if let Some(peer_addr) = - decode_xor_peer_address(&attr.value, &msg.header.transaction_id) - { - match allocs.add_permission(peer, peer_addr) { - Ok(()) => { - tracing::info!( - "added permission for {} -> {}", - peer, - peer_addr - ); - added += 1; - } - Err(e) => { - tracing::error!( - "failed to persist permission {} -> {}: {:?}", - peer, - peer_addr, - e - ); - } - } - } else { - tracing::warn!("invalid XOR-PEER-ADDRESS in request from {}", peer); - } - } - - if added == 0 { - let resp = - build_error_response(&msg.header, 400, "No valid XOR-PEER-ADDRESS"); - let _ = udp.send_to(&resp, &peer).await; - } else { - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; - } - continue; - } - METHOD_CHANNEL_BIND => { - let allocation = match allocs.get_allocation(&peer) { - Some(a) => a, - None => { - tracing::warn!("channel-bind without allocation from {}", peer); - let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - let channel_attr = - msg.attributes.iter().find(|a| a.typ == ATTR_CHANNEL_NUMBER); - let peer_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); - - let channel = match channel_attr.and_then(|attr| { - if attr.value.len() >= 4 { - Some(u16::from_be_bytes([attr.value[0], attr.value[1]])) - } else { - None - } - }) { - Some(c) => c, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing CHANNEL-NUMBER", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - if channel < 0x4000 || channel > 0x7FFF { - let resp = - build_error_response(&msg.header, 400, "Channel Out Of Range"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - - let peer_addr = match peer_attr.and_then(|attr| { - decode_xor_peer_address(&attr.value, &msg.header.transaction_id) - }) { - Some(addr) => addr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing XOR-PEER-ADDRESS", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - - match allocs.add_channel_binding(peer, channel, peer_addr) { - Ok(()) => { - tracing::info!( - "bound channel 0x{:04x} for {} -> {}", - channel, - peer, - peer_addr - ); - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; - } - Err(e) => { - tracing::error!( - "failed to add channel binding {} -> {} (channel 0x{:04x}): {:?}", - peer, - peer_addr, - channel, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Channel Binding Failed", - ); - let _ = udp.send_to(&resp, &peer).await; - } - } - continue; - } - METHOD_SEND => { - let allocation = match allocs.get_allocation(&peer) { - Some(a) => a, - None => { - tracing::warn!("send without allocation from {}", peer); - let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - let peer_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); - let data_attr = msg.attributes.iter().find(|a| a.typ == ATTR_DATA); - - let peer_addr = match peer_attr.and_then(|attr| { - decode_xor_peer_address(&attr.value, &msg.header.transaction_id) - }) { - Some(addr) => addr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing XOR-PEER-ADDRESS", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - let data_attr = match data_attr { - Some(attr) => attr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing DATA Attribute", - ); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - }; - - if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - - match allocation.send_to_peer(peer_addr, &data_attr.value).await { - Ok(sent) => { - tracing::info!( - "forwarded {} bytes from {} to peer {}", - sent, - peer, - peer_addr - ); - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; - } - Err(e) => { - tracing::error!( - "failed to send payload from {} to {}: {:?}", - peer, - peer_addr, - e - ); - let resp = - build_error_response(&msg.header, 500, "Peer Send Failed"); - let _ = udp.send_to(&resp, &peer).await; - } - } - continue; - } - METHOD_REFRESH => { - let requested = extract_lifetime_seconds(&msg) - .map(|secs| Duration::from_secs(secs as u64)); - - match allocs.refresh_allocation(peer, requested) { - Ok(applied) => { - if applied.is_zero() { - tracing::info!("allocation for {} released", peer); - } else { - tracing::debug!( - "allocation for {} refreshed to {}s", - peer, - applied.as_secs() - ); - } - let resp = build_lifetime_success( - &msg.header, - applied.as_secs().min(u32::MAX as u64) as u32, - ); - let _ = udp.send_to(&resp, &peer).await; - } - Err(_) => { - let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); - let _ = udp.send_to(&resp, &peer).await; - } - } - continue; - } - _ => { - let resp = build_error_response(&msg.header, 420, "Unknown TURN Method"); - let _ = udp.send_to(&resp, &peer).await; - continue; - } - } - } - - match msg.header.msg_type { - METHOD_BINDING => { - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; - } - _ => { - let nonce = auth.mint_nonce(&peer); - let resp = - build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized"); - if let Err(e) = udp.send_to(&resp, &peer).await { - error!("failed to send 401: {:?}", e); - } - } - } - } else { - tracing::debug!("Non-STUN or parse error from {} len={}", peer, len); - } - } -} - -// existing helper functions moved to stun.rs diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..c58e984 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,429 @@ +//! Shared server routines for UDP TURN handling so integration tests can reuse the core loop. +use std::sync::Arc; +use tokio::net::UdpSocket; +use tracing::{error, warn}; + +use crate::alloc::AllocationManager; +use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; +use crate::constants::*; +use crate::stun::{ + build_401_response, build_allocate_success, build_error_response, build_lifetime_success, + build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data, + parse_message, +}; +use std::time::Duration; + +/// Main UDP reader loop shared between binary entry point and integration tests. +pub async fn udp_reader_loop( + udp: Arc, + auth: AuthManager, + allocs: AllocationManager, +) -> anyhow::Result<()> { + let mut buf = vec![0u8; 1500]; + loop { + let (len, peer) = udp.recv_from(&mut buf).await?; + tracing::debug!("got {} bytes from {}", len, peer); + + if let Some((channel, payload)) = parse_channel_data(&buf[..len]) { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + warn!("channel data without allocation from {}", peer); + continue; + } + }; + + let target = match allocation.channel_peer(channel) { + Some(addr) => addr, + None => { + warn!( + "channel data with unknown channel 0x{:04x} from {}", + channel, peer + ); + continue; + } + }; + + if !allocation.is_peer_allowed(&target) { + warn!( + "channel data target {} no longer permitted for {}", + target, peer + ); + continue; + } + + match allocation.send_to_peer(target, payload).await { + Ok(sent) => tracing::debug!( + "forwarded {} bytes via channel 0x{:04x} from {} to {}", + sent, + channel, + peer, + target + ), + Err(e) => error!( + "failed to forward channel data 0x{:04x} from {} to {}: {:?}", + channel, peer, target, e + ), + } + continue; + } + + if let Ok(msg) = parse_message(&buf[..len]) { + tracing::info!( + "STUN/TURN message from {} type=0x{:04x} len={}", + peer, + msg.header.msg_type, + len + ); + let requires_auth = matches!( + msg.header.msg_type, + METHOD_ALLOCATE + | METHOD_CREATE_PERMISSION + | METHOD_CHANNEL_BIND + | METHOD_SEND + | METHOD_REFRESH + ); + + if requires_auth { + match auth.authenticate(&msg, &peer).await { + AuthStatus::Granted { username } => { + tracing::debug!( + "TURN auth ok for {} as {} (0x{:04x})", + peer, + username, + msg.header.msg_type + ); + } + AuthStatus::Challenge { nonce } => { + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 401, + "Unauthorized", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + AuthStatus::StaleNonce { nonce } => { + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 438, + "Stale Nonce", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + AuthStatus::Reject { code, reason } => { + let resp = build_error_response(&msg.header, code, reason); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + } + + match msg.header.msg_type { + METHOD_ALLOCATE => { + let requested_lifetime = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)) + .filter(|d| !d.is_zero()); + + match allocs.allocate_for(peer, udp.clone()).await { + Ok(relay_addr) => { + let applied = + match allocs.refresh_allocation(peer, requested_lifetime) { + Ok(d) => d, + Err(e) => { + tracing::error!( + "failed to apply lifetime for {}: {:?}", + peer, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Allocate Failed", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32; + let resp = + build_allocate_success(&msg.header, &relay_addr, lifetime_secs); + tracing::info!( + "allocated relay {} for {} lifetime={}s", + relay_addr, + peer, + lifetime_secs + ); + let _ = udp.send_to(&resp, &peer).await; + } + Err(e) => { + tracing::error!("allocate failed: {:?}", e); + let resp = + build_error_response(&msg.header, 500, "Allocate Failed"); + let _ = udp.send_to(&resp, &peer).await; + } + } + continue; + } + METHOD_CREATE_PERMISSION => { + if allocs.get_allocation(&peer).is_none() { + warn!("create-permission without allocation from {}", peer); + let resp = + build_error_response(&msg.header, 437, "Allocation Mismatch"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + let mut added = 0usize; + for attr in msg + .attributes + .iter() + .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) + { + if let Some(peer_addr) = + decode_xor_peer_address(&attr.value, &msg.header.transaction_id) + { + match allocs.add_permission(peer, peer_addr) { + Ok(()) => { + tracing::info!( + "added permission for {} -> {}", + peer, + peer_addr + ); + added += 1; + } + Err(e) => { + tracing::error!( + "failed to persist permission {} -> {}: {:?}", + peer, + peer_addr, + e + ); + } + } + } else { + tracing::warn!("invalid XOR-PEER-ADDRESS in request from {}", peer); + } + } + + if added == 0 { + let resp = + build_error_response(&msg.header, 400, "No valid XOR-PEER-ADDRESS"); + let _ = udp.send_to(&resp, &peer).await; + } else { + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + } + continue; + } + METHOD_CHANNEL_BIND => { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + warn!("channel-bind without allocation from {}", peer); + let resp = + build_error_response(&msg.header, 437, "Allocation Mismatch"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let channel_attr = + msg.attributes.iter().find(|a| a.typ == ATTR_CHANNEL_NUMBER); + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + let (channel_attr, peer_attr) = match (channel_attr, peer_attr) { + (Some(c), Some(p)) => (c, p), + _ => { + let resp = build_error_response( + &msg.header, + 400, + "Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let channel = + u16::from_be_bytes([channel_attr.value[0], channel_attr.value[1]]); + let peer_addr = match decode_xor_peer_address( + &peer_attr.value, + &msg.header.transaction_id, + ) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Invalid XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + if let Err(e) = allocs.add_channel_binding(peer, channel, peer_addr) { + tracing::error!( + "failed to persist channel binding {} -> {} (0x{:04x}): {:?}", + peer, + peer_addr, + channel, + e + ); + let resp = + build_error_response(&msg.header, 500, "Channel Bind Failed"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + METHOD_SEND => { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + warn!("send indication without allocation from {}", peer); + let resp = + build_error_response(&msg.header, 437, "Allocation Mismatch"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + let data_attr = msg.attributes.iter().find(|a| a.typ == ATTR_DATA); + let (peer_attr, data_attr) = match (peer_attr, data_attr) { + (Some(p), Some(d)) => (p, d), + _ => { + let resp = build_error_response( + &msg.header, + 400, + "Missing DATA or XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let peer_addr = match decode_xor_peer_address( + &peer_attr.value, + &msg.header.transaction_id, + ) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Invalid XOR-PEER-ADDRESS", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + + match allocation.send_to_peer(peer_addr, &data_attr.value).await { + Ok(sent) => { + tracing::info!( + "forwarded {} bytes from {} to {}", + sent, + peer, + peer_addr + ); + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + } + Err(e) => { + tracing::error!( + "failed to send payload from {} to {}: {:?}", + peer, + peer_addr, + e + ); + let resp = + build_error_response(&msg.header, 500, "Peer Send Failed"); + let _ = udp.send_to(&resp, &peer).await; + } + } + continue; + } + METHOD_REFRESH => { + let requested = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)); + + match allocs.refresh_allocation(peer, requested) { + Ok(applied) => { + if applied.is_zero() { + tracing::info!("allocation for {} released", peer); + } else { + tracing::debug!( + "allocation for {} refreshed to {}s", + peer, + applied.as_secs() + ); + } + let resp = build_lifetime_success( + &msg.header, + applied.as_secs().min(u32::MAX as u64) as u32, + ); + let _ = udp.send_to(&resp, &peer).await; + } + Err(_) => { + let resp = + build_error_response(&msg.header, 437, "Allocation Mismatch"); + let _ = udp.send_to(&resp, &peer).await; + } + } + continue; + } + _ => { + let resp = build_error_response(&msg.header, 420, "Unknown TURN Method"); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + } + } + + match msg.header.msg_type { + METHOD_BINDING => { + let resp = build_success_response(&msg.header); + let _ = udp.send_to(&resp, &peer).await; + } + _ => { + let nonce = auth.mint_nonce(&peer); + let resp = + build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized"); + if let Err(e) = udp.send_to(&resp, &peer).await { + error!("failed to send 401: {:?}", e); + } + } + } + } else { + tracing::debug!("Non-STUN or parse error from {} len={}", peer, len); + } + } +} diff --git a/src/tls.rs b/src/tls.rs index 4658500..c5c0abf 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -5,7 +5,7 @@ use std::fs::File; use std::io::BufReader; use std::sync::Arc; use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; use tokio_rustls::TlsAcceptor; @@ -75,672 +75,16 @@ pub async fn serve_tls( tokio::spawn(async move { match acceptor.accept(stream).await { Ok(mut tls_stream) => { - tracing::info!("accepted TLS connection from {}", peer); - let mut read_buf = vec![0u8; 4096]; - let mut buffer: Vec = Vec::new(); - - loop { - match tls_stream.read(&mut read_buf).await { - Ok(0) => { - tracing::info!("TLS client {} closed connection", peer); - break; - } - Ok(n) => { - buffer.extend_from_slice(&read_buf[..n]); - while buffer.len() >= 20 { - let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; - let total = len + 20; - if buffer.len() < total { - break; - } - let chunk = buffer.drain(..total).collect::>(); - if let Ok(msg) = parse_message(&chunk) { - tracing::info!( - "STUN/TURN over TLS from {} type=0x{:04x} len={}", - peer, - msg.header.msg_type, - total - ); - - let requires_auth = matches!( - msg.header.msg_type, - METHOD_ALLOCATE - | METHOD_CREATE_PERMISSION - | METHOD_CHANNEL_BIND - | METHOD_SEND - | METHOD_REFRESH - ); - - if requires_auth { - match auth_clone.authenticate(&msg, &peer).await { - AuthStatus::Granted { username } => { - tracing::debug!( - "TURN TLS auth ok for {} as {} (0x{:04x})", - peer, - username, - msg.header.msg_type - ); - } - AuthStatus::Challenge { nonce } => { - let resp = build_401_response( - &msg.header, - auth_clone.realm(), - &nonce, - 401, - "Unauthorized", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls challenge: {:?}", - e - ); - } - continue; - } - AuthStatus::StaleNonce { nonce } => { - let resp = build_401_response( - &msg.header, - auth_clone.realm(), - &nonce, - 438, - "Stale Nonce", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls stale nonce: {:?}", - e - ); - } - continue; - } - AuthStatus::Reject { code, reason } => { - let resp = build_error_response( - &msg.header, - code, - reason, - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls auth error: {:?}", - e - ); - } - continue; - } - } - } - - match msg.header.msg_type { - METHOD_ALLOCATE => { - let requested_lifetime = - extract_lifetime_seconds(&msg) - .map(|secs| { - Duration::from_secs(secs as u64) - }) - .filter(|d| !d.is_zero()); - - match alloc_clone - .allocate_for(peer, udp_clone.clone()) - .await - { - Ok(relay_addr) => { - let applied = match alloc_clone - .refresh_allocation( - peer, - requested_lifetime, - ) { - Ok(d) => d, - Err(e) => { - tracing::error!( - "failed to apply TLS lifetime for {}: {:?}", - peer, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Allocate Failed", - ); - if let Err(e2) = tls_stream - .write_all(&resp) - .await - { - tracing::error!( - "failed to write tls allocate error: {:?}", - e2 - ); - } - continue; - } - }; - let lifetime_secs = - applied.as_secs().min(u32::MAX as u64) - as u32; - let resp = build_allocate_success( - &msg.header, - &relay_addr, - lifetime_secs, - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls allocate success: {:?}", - e - ); - } - } - Err(e) => { - tracing::error!( - "allocate failed (tls): {:?}", - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Allocate Failed", - ); - if let Err(e2) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls allocate error: {:?}", - e2 - ); - } - } - } - continue; - } - METHOD_CREATE_PERMISSION => { - if alloc_clone.get_allocation(&peer).is_none() { - tracing::warn!( - "create-permission without allocation from {} (tls)", - peer - ); - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - - let mut added = 0usize; - for attr in msg - .attributes - .iter() - .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) - { - if let Some(peer_addr) = decode_xor_peer_address( - &attr.value, - &msg.header.transaction_id, - ) { - match alloc_clone - .add_permission(peer, peer_addr) - { - Ok(()) => { - tracing::info!( - "added TLS permission for {} -> {}", - peer, - peer_addr - ); - added += 1; - } - Err(e) => { - tracing::error!( - "failed to persist TLS permission {} -> {}: {:?}", - peer, - peer_addr, - e - ); - } - } - } else { - tracing::warn!( - "invalid XOR-PEER-ADDRESS via TLS from {}", - peer - ); - } - } - - let resp = if added == 0 { - build_error_response( - &msg.header, - 400, - "No valid XOR-PEER-ADDRESS", - ) - } else { - build_success_response(&msg.header) - }; - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls response: {:?}", - e - ); - } - continue; - } - METHOD_CHANNEL_BIND => { - let allocation = match alloc_clone - .get_allocation(&peer) - { - Some(a) => a, - None => { - tracing::warn!( - "channel-bind without allocation from {} (tls)", - peer - ); - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - let channel_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_CHANNEL_NUMBER); - let peer_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); - - let channel = match channel_attr.and_then(|attr| { - if attr.value.len() >= 4 { - Some(u16::from_be_bytes([ - attr.value[0], - attr.value[1], - ])) - } else { - None - } - }) { - Some(c) => c, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing CHANNEL-NUMBER", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - if channel < 0x4000 || channel > 0x7FFF { - let resp = build_error_response( - &msg.header, - 400, - "Channel Out Of Range", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - - let peer_addr = match peer_attr.and_then(|attr| { - decode_xor_peer_address( - &attr.value, - &msg.header.transaction_id, - ) - }) { - Some(addr) => addr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing XOR-PEER-ADDRESS", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response( - &msg.header, - 403, - "Peer Not Permitted", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - - match alloc_clone - .add_channel_binding(peer, channel, peer_addr) - { - Ok(()) => { - tracing::info!( - "bound channel 0x{:04x} for {} -> {} over TLS", - channel, - peer, - peer_addr - ); - let resp = - build_success_response(&msg.header); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls response: {:?}", - e - ); - } - } - Err(e) => { - tracing::error!( - "failed TLS channel binding {} -> {} (0x{:04x}): {:?}", - peer, - peer_addr, - channel, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Channel Binding Failed", - ); - if let Err(e2) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e2 - ); - } - } - } - continue; - } - METHOD_SEND => { - let allocation = - match alloc_clone.get_allocation(&peer) { - Some(a) => a, - None => { - tracing::warn!( - "send without allocation from {} (tls)", - peer - ); - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - let peer_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); - let data_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_DATA); - - let peer_addr = match peer_attr.and_then(|attr| { - decode_xor_peer_address( - &attr.value, - &msg.header.transaction_id, - ) - }) { - Some(addr) => addr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing XOR-PEER-ADDRESS", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - let data_attr = match data_attr { - Some(attr) => attr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing DATA Attribute", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response( - &msg.header, - 403, - "Peer Not Permitted", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - - match allocation - .send_to_peer(peer_addr, &data_attr.value) - .await - { - Ok(sent) => { - tracing::info!( - "forwarded {} bytes from {} to {} via TLS session", - sent, - peer, - peer_addr - ); - let resp = - build_success_response(&msg.header); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls response: {:?}", - e - ); - } - } - Err(e) => { - tracing::error!( - "failed to send payload from {} to {} via TLS: {:?}", - peer, - peer_addr, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Peer Send Failed", - ); - if let Err(e2) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls error: {:?}", - e2 - ); - } - } - } - continue; - } - METHOD_REFRESH => { - let requested = extract_lifetime_seconds(&msg) - .map(|secs| Duration::from_secs(secs as u64)); - - match alloc_clone - .refresh_allocation(peer, requested) - { - Ok(applied) => { - if applied.is_zero() { - tracing::info!( - "allocation for {} released (tls)", - peer - ); - } else { - tracing::debug!( - "allocation for {} refreshed to {}s (tls)", - peer, - applied.as_secs() - ); - } - let resp = build_lifetime_success( - &msg.header, - applied.as_secs().min(u32::MAX as u64) - as u32, - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls refresh response: {:?}", - e - ); - } - } - Err(_) => { - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls refresh error: {:?}", - e - ); - } - } - } - continue; - } - METHOD_BINDING => { - let resp = build_success_response(&msg.header); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls binding response: {:?}", - e - ); - } - continue; - } - _ => { - let nonce = auth_clone.mint_nonce(&peer); - let resp = build_401_response( - &msg.header, - auth_clone.realm(), - &nonce, - 401, - "Unauthorized", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls fallback challenge: {:?}", - e - ); - } - continue; - } - } - } else { - tracing::debug!( - "failed to parse stun message on tls from {}", - peer - ); - } - } - } - Err(e) => { - tracing::error!("tls read error from {}: {:?}", peer, e); - break; - } - } + if let Err(e) = handle_tls_connection( + &mut tls_stream, + peer, + udp_clone, + auth_clone, + alloc_clone, + ) + .await + { + tracing::error!("TLS connection error: {:?}", e); } } Err(e) => tracing::error!("TLS accept error: {:?}", e), @@ -748,3 +92,597 @@ pub async fn serve_tls( }); } } + +#[allow(clippy::too_many_arguments)] +pub async fn handle_tls_connection( + tls_stream: &mut S, + peer: std::net::SocketAddr, + udp_sock: std::sync::Arc, + auth: AuthManager, + allocs: AllocationManager, +) -> anyhow::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + tracing::info!("accepted TLS connection from {}", peer); + let mut read_buf = vec![0u8; 4096]; + let mut buffer: Vec = Vec::new(); + + loop { + match tls_stream.read(&mut read_buf).await { + Ok(0) => { + tracing::info!("TLS client {} closed connection", peer); + break; + } + Ok(n) => { + buffer.extend_from_slice(&read_buf[..n]); + while buffer.len() >= 20 { + let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; + let total = len + 20; + if buffer.len() < total { + break; + } + let chunk = buffer.drain(..total).collect::>(); + match parse_message(&chunk) { + Ok(msg) => { + tracing::info!( + "STUN/TURN over TLS from {} type=0x{:04x} len={}", + peer, + msg.header.msg_type, + total + ); + + let requires_auth = matches!( + msg.header.msg_type, + METHOD_ALLOCATE + | METHOD_CREATE_PERMISSION + | METHOD_CHANNEL_BIND + | METHOD_SEND + | METHOD_REFRESH + ); + + if requires_auth { + match auth.authenticate(&msg, &peer).await { + AuthStatus::Granted { username } => { + tracing::debug!( + "TURN TLS auth ok for {} as {} (0x{:04x})", + peer, + username, + msg.header.msg_type + ); + } + AuthStatus::Challenge { nonce } => { + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 401, + "Unauthorized", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls challenge: {:?}", + e + ); + } + continue; + } + AuthStatus::StaleNonce { nonce } => { + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 438, + "Stale Nonce", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls stale nonce: {:?}", + e + ); + } + continue; + } + AuthStatus::Reject { code, reason } => { + let resp = build_error_response(&msg.header, code, reason); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls auth error: {:?}", + e + ); + } + continue; + } + } + } + + match msg.header.msg_type { + METHOD_ALLOCATE => { + let requested_lifetime = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)) + .filter(|d| !d.is_zero()); + + match allocs.allocate_for(peer, udp_sock.clone()).await { + Ok(relay_addr) => { + let applied = match allocs + .refresh_allocation(peer, requested_lifetime) + { + Ok(d) => d, + Err(e) => { + tracing::error!( + "failed to apply TLS lifetime for {}: {:?}", + peer, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Allocate Failed", + ); + if let Err(e2) = + tls_stream.write_all(&resp).await + { + tracing::error!( + "failed to write tls allocate error: {:?}", + e2 + ); + } + continue; + } + }; + let lifetime_secs = + applied.as_secs().min(u32::MAX as u64) as u32; + let resp = build_allocate_success( + &msg.header, + &relay_addr, + lifetime_secs, + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls allocate success: {:?}", + e + ); + } + } + Err(e) => { + tracing::error!("allocate failed (tls): {:?}", e); + let resp = build_error_response( + &msg.header, + 500, + "Allocate Failed", + ); + if let Err(e2) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls allocate error: {:?}", + e2 + ); + } + } + } + continue; + } + METHOD_CREATE_PERMISSION => { + if allocs.get_allocation(&peer).is_none() { + tracing::warn!( + "create-permission without allocation from {} (tls)", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!("failed to write tls error: {:?}", e); + } + continue; + } + + let mut added = 0usize; + for attr in msg + .attributes + .iter() + .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) + { + if let Some(peer_addr) = decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) { + match allocs.add_permission(peer, peer_addr) { + Ok(()) => { + tracing::info!( + "added TLS permission for {} -> {}", + peer, + peer_addr + ); + added += 1; + } + Err(e) => { + tracing::error!( + "failed to persist TLS permission {} -> {}: {:?}", + peer, + peer_addr, + e + ); + } + } + } else { + tracing::warn!( + "invalid XOR-PEER-ADDRESS via TLS from {}", + peer + ); + } + } + + let resp = if added == 0 { + build_error_response( + &msg.header, + 400, + "No valid XOR-PEER-ADDRESS", + ) + } else { + build_success_response(&msg.header) + }; + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!("failed to write tls response: {:?}", e); + } + continue; + } + METHOD_CHANNEL_BIND => { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + tracing::warn!( + "channel-bind without allocation from {} (tls)", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + let channel_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_CHANNEL_NUMBER); + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + + let channel = match channel_attr.and_then(|attr| { + if attr.value.len() >= 4 { + Some(u16::from_be_bytes([attr.value[0], attr.value[1]])) + } else { + None + } + }) { + Some(c) => c, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing CHANNEL-NUMBER", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + if channel < 0x4000 || channel > 0x7FFF { + let resp = build_error_response( + &msg.header, + 400, + "Channel Out Of Range", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!("failed to write tls error: {:?}", e); + } + continue; + } + + let peer_addr = match peer_attr.and_then(|attr| { + decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) + }) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing XOR-PEER-ADDRESS", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response( + &msg.header, + 403, + "Peer Not Permitted", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!("failed to write tls error: {:?}", e); + } + continue; + } + + match allocs.add_channel_binding(peer, channel, peer_addr) { + Ok(()) => { + tracing::info!( + "bound channel 0x{:04x} for {} -> {} over TLS", + channel, + peer, + peer_addr + ); + let resp = build_success_response(&msg.header); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls response: {:?}", + e + ); + } + } + Err(e) => { + tracing::error!( + "failed TLS channel binding {} -> {} (0x{:04x}): {:?}", + peer, + peer_addr, + channel, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Channel Binding Failed", + ); + if let Err(e2) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e2 + ); + } + } + } + continue; + } + METHOD_SEND => { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + tracing::warn!( + "send without allocation from {} (tls)", + peer + ); + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + let data_attr = + msg.attributes.iter().find(|a| a.typ == ATTR_DATA); + + let peer_addr = match peer_attr.and_then(|attr| { + decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) + }) { + Some(addr) => addr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing XOR-PEER-ADDRESS", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + let data_attr = match data_attr { + Some(attr) => attr, + None => { + let resp = build_error_response( + &msg.header, + 400, + "Missing DATA Attribute", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e + ); + } + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response( + &msg.header, + 403, + "Peer Not Permitted", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!("failed to write tls error: {:?}", e); + } + continue; + } + + match allocation.send_to_peer(peer_addr, &data_attr.value).await + { + Ok(sent) => { + tracing::info!( + "forwarded {} bytes from {} to {} via TLS session", + sent, + peer, + peer_addr + ); + let resp = build_success_response(&msg.header); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls response: {:?}", + e + ); + } + } + Err(e) => { + tracing::error!( + "failed to send payload from {} to {} via TLS: {:?}", + peer, + peer_addr, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Peer Send Failed", + ); + if let Err(e2) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls error: {:?}", + e2 + ); + } + } + } + continue; + } + METHOD_REFRESH => { + let requested = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)); + + match allocs.refresh_allocation(peer, requested) { + Ok(applied) => { + if applied.is_zero() { + tracing::info!( + "allocation for {} released (tls)", + peer + ); + } else { + tracing::debug!( + "allocation for {} refreshed to {}s (tls)", + peer, + applied.as_secs() + ); + } + let resp = build_lifetime_success( + &msg.header, + applied.as_secs().min(u32::MAX as u64) as u32, + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls refresh response: {:?}", + e + ); + } + } + Err(_) => { + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls refresh error: {:?}", + e + ); + } + } + } + continue; + } + METHOD_BINDING => { + let resp = build_success_response(&msg.header); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls binding response: {:?}", + e + ); + } + continue; + } + _ => { + let nonce = auth.mint_nonce(&peer); + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 401, + "Unauthorized", + ); + if let Err(e) = tls_stream.write_all(&resp).await { + tracing::error!( + "failed to write tls fallback challenge: {:?}", + e + ); + } + continue; + } + } + } + Err(e) => { + tracing::warn!( + error = ?e, + length = chunk.len(), + "dropping unparseable STUN/TURN frame over TLS from {}", + peer + ); + } + } + } + } + Err(e) => { + tracing::error!("tls read error from {}: {:?}", peer, e); + break; + } + } + } + + Ok(()) +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs new file mode 100644 index 0000000..e950d19 --- /dev/null +++ b/tests/support/mod.rs @@ -0,0 +1,37 @@ +pub mod stun_builders; +pub mod tls; + +use std::net::SocketAddr; + +use niom_turn::auth::{AuthManager, InMemoryStore}; +use niom_turn::config::AuthOptions; + +/// Ensure tracing is initialised for integration tests with the library defaults. +#[allow(dead_code)] +pub fn init_tracing() { + niom_turn::logging::init_tracing(); +} + +/// Initialise tracing with a custom default directive (still overridable via `RUST_LOG`). +#[allow(dead_code)] +pub fn init_tracing_with(default_directive: &str) { + niom_turn::logging::init_tracing_with_default(default_directive); +} + +/// Helper to construct a basic AuthManager with a single credential for integration tests. +pub fn test_auth_manager(user: &str, password: &str) -> AuthManager { + let store = InMemoryStore::new(); + store.insert(user, password); + AuthManager::new(store, &AuthOptions::default()) +} + +/// Default TURN client credential used in integration scenarios. +pub fn default_test_credentials() -> (&'static str, &'static str) { + ("testuser", "secretpassword") +} + +/// Convenience to parse socket addresses in tests. +#[allow(dead_code)] +pub fn addr(addr: &str) -> SocketAddr { + addr.parse().expect("valid socket addr") +} diff --git a/tests/support/stun_builders.rs b/tests/support/stun_builders.rs new file mode 100644 index 0000000..f3b1019 --- /dev/null +++ b/tests/support/stun_builders.rs @@ -0,0 +1,210 @@ +#![allow(dead_code)] + +use bytes::BytesMut; +use niom_turn::constants::*; +use niom_turn::stun::{compute_message_integrity, parse_message}; +use uuid::Uuid; + +/// Build a basic STUN header for TURN requests with a freshly generated transaction id. +pub fn new_transaction_id() -> [u8; 12] { + let uuid = Uuid::new_v4(); + let mut trans = [0u8; 12]; + trans.copy_from_slice(&uuid.as_bytes()[..12]); + trans +} + +/// Construct a TURN Allocate request optionally including lifetime and auth attributes. +pub fn build_allocate_request( + username: Option<&str>, + realm: Option<&str>, + nonce: Option<&str>, + key: Option<&[u8]>, + lifetime: Option, +) -> Vec { + build_authenticated_request( + METHOD_ALLOCATE, + username, + realm, + nonce, + key, + lifetime, + None, + None, + ) +} + +/// Construct a TURN Refresh request using MESSAGE-INTEGRITY. +pub fn build_refresh_request( + trans: [u8; 12], + username: &str, + realm: &str, + nonce: &str, + key: &[u8], + lifetime: u32, +) -> Vec { + build_request_with_body( + METHOD_REFRESH, + Some(username), + Some(realm), + Some(nonce), + Some(key), + Some(lifetime), + None, + None, + Some(trans), + ) +} + +/// Build a CreatePermission request for the specified peer address. +pub fn build_create_permission_request( + username: &str, + realm: &str, + nonce: &str, + key: &[u8], + peer: &std::net::SocketAddr, +) -> Vec { + build_request_with_body( + METHOD_CREATE_PERMISSION, + Some(username), + Some(realm), + Some(nonce), + Some(key), + None, + Some(peer), + None, + None, + ) +} + +/// Build a Send indication to forward payload to peer. +pub fn build_send_request( + username: &str, + realm: &str, + nonce: &str, + key: &[u8], + peer: &std::net::SocketAddr, + payload: &[u8], +) -> Vec { + build_request_with_body( + METHOD_SEND, + Some(username), + Some(realm), + Some(nonce), + Some(key), + None, + Some(peer), + Some(payload), + None, + ) +} + +fn build_authenticated_request( + method: u16, + username: Option<&str>, + realm: Option<&str>, + nonce: Option<&str>, + key: Option<&[u8]>, + lifetime: Option, + peer: Option<&std::net::SocketAddr>, + payload: Option<&[u8]>, +) -> Vec { + build_request_with_body( + method, username, realm, nonce, key, lifetime, peer, payload, None, + ) +} + +fn build_request_with_body( + method: u16, + username: Option<&str>, + realm: Option<&str>, + nonce: Option<&str>, + key: Option<&[u8]>, + lifetime: Option, + peer: Option<&std::net::SocketAddr>, + payload: Option<&[u8]>, + override_trans: Option<[u8; 12]>, +) -> Vec { + let mut buf = BytesMut::new(); + buf.extend_from_slice(&method.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + let trans = override_trans.unwrap_or_else(new_transaction_id); + buf.extend_from_slice(&trans); + + if let Some(username) = username { + push_string_attr(&mut buf, ATTR_USERNAME, username); + } + if let Some(realm) = realm { + push_string_attr(&mut buf, ATTR_REALM, realm); + } + if let Some(nonce) = nonce { + push_string_attr(&mut buf, ATTR_NONCE, nonce); + } + if let Some(lifetime) = lifetime { + push_u32_attr(&mut buf, ATTR_LIFETIME, lifetime); + } + if let Some(peer) = peer { + let encoded = niom_turn::stun::encode_xor_peer_address(peer, &trans); + push_bytes_attr(&mut buf, ATTR_XOR_PEER_ADDRESS, &encoded); + } + if let Some(data) = payload { + push_bytes_attr(&mut buf, ATTR_DATA, data); + } + + if let Some(key) = key { + append_message_integrity(&mut buf, key); + } + + // update length + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + + buf.to_vec() +} + +fn push_string_attr(buf: &mut BytesMut, typ: u16, value: &str) { + push_bytes_attr(buf, typ, value.as_bytes()); +} + +fn push_u32_attr(buf: &mut BytesMut, typ: u16, value: u32) { + push_bytes_attr(buf, typ, &value.to_be_bytes()); +} + +fn push_bytes_attr(buf: &mut BytesMut, typ: u16, data: &[u8]) { + buf.extend_from_slice(&typ.to_be_bytes()); + buf.extend_from_slice(&(data.len() as u16).to_be_bytes()); + buf.extend_from_slice(data); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } +} + +fn append_message_integrity(buf: &mut BytesMut, key: &[u8]) { + // position before adding MESSAGE-INTEGRITY attribute + let attribute_start = buf.len(); + + // append attribute header and placeholder value + buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); + buf.extend_from_slice(&(20u16.to_be_bytes())); + let value_start = buf.len(); + buf.extend_from_slice(&[0u8; 20]); + + // update message length to include the attribute (spec requires this before HMAC) + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + + // compute the HMAC over all bytes preceding the attribute (RFC 5389 ยง15.4) + let signed = compute_message_integrity(key, &buf[..attribute_start]); + + // write the computed MAC into the placeholder we appended above + buf[value_start..value_start + 20].copy_from_slice(&signed[..20]); +} + +/// Ensure builders produce parseable STUN messages. +pub fn parse(buf: &[u8]) -> niom_turn::models::stun::StunMessage { + parse_message(buf).expect("valid stun message") +} diff --git a/tests/support/tls.rs b/tests/support/tls.rs new file mode 100644 index 0000000..d1edbaa --- /dev/null +++ b/tests/support/tls.rs @@ -0,0 +1,59 @@ +#![allow(dead_code)] + +use std::sync::Arc; + +use std::net::IpAddr; + +use rcgen::{Certificate, CertificateParams, DistinguishedName, DnType, SanType}; +use tokio_rustls::rustls::{Certificate as RustlsCert, PrivateKey}; + +/// Generate a self-signed certificate and matching key for test TLS servers. +pub fn generate_self_signed_cert() -> (RustlsCert, PrivateKey) { + let mut params = CertificateParams::default(); + params.distinguished_name = DistinguishedName::new(); + params + .distinguished_name + .push(DnType::CommonName, "niom-turn-test"); + params.alg = &rcgen::PKCS_ECDSA_P256_SHA256; + params + .subject_alt_names + .push(SanType::DnsName("localhost".into())); + params.subject_alt_names.push(SanType::IpAddress( + "127.0.0.1" + .parse::() + .expect("localhost loopback ip"), + )); + + let cert = Certificate::from_params(params).expect("certificate params"); + let pem = cert.serialize_der().expect("cert der"); + let key = cert.serialize_private_key_der(); + (RustlsCert(pem), PrivateKey(key)) +} + +/// Build a rustls server config for tests using a generated certificate. +pub fn build_server_config() -> tokio_rustls::rustls::ServerConfig { + let (cert, key) = generate_self_signed_cert(); + let mut cfg = tokio_rustls::rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .expect("valid test server config"); + cfg.alpn_protocols = vec![b"turn".to_vec()]; + cfg +} + +/// Build a rustls client config trusting the generated test certificate. +pub fn build_client_config(cert: &RustlsCert) -> tokio_rustls::rustls::ClientConfig { + let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); + root_store.add(cert).expect("add root cert"); + + tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth() +} + +/// Wrap tls config into acceptor for tests. +pub fn build_acceptor(cfg: tokio_rustls::rustls::ServerConfig) -> tokio_rustls::TlsAcceptor { + tokio_rustls::TlsAcceptor::from(Arc::new(cfg)) +} diff --git a/tests/tls_turn.rs b/tests/tls_turn.rs new file mode 100644 index 0000000..f052605 --- /dev/null +++ b/tests/tls_turn.rs @@ -0,0 +1,155 @@ +use std::sync::Arc; + +use niom_turn::alloc::AllocationManager; +use niom_turn::stun::parse_message; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, UdpSocket}; +use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; + +use crate::support::stun_builders::{build_allocate_request, build_refresh_request}; +use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager}; + +mod support; + +#[tokio::test] +async fn tls_allocate_refresh_flow() { + init_tracing_with("warn,niom_turn=info"); + + let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let udp_arc = Arc::new(udp); + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let (cert, key) = support::tls::generate_self_signed_cert(); + let mut cfg = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key) + .expect("server config"); + cfg.alpn_protocols.push(b"turn".to_vec()); + let acceptor = TlsAcceptor::from(Arc::new(cfg)); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + let udp_clone = udp_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let acceptor = acceptor.clone(); + let udp_clone = udp_clone.clone(); + let auth_clone = auth_clone.clone(); + let alloc_clone = alloc_clone.clone(); + tokio::spawn(async move { + match acceptor.accept(stream).await { + Ok(mut tls_stream) => { + match niom_turn::tls::handle_tls_connection( + &mut tls_stream, + peer, + udp_clone, + auth_clone, + alloc_clone, + ) + .await + { + Ok(_) => {} + Err(e) => { + tracing::error!("tls connection error: {:?}", e); + } + } + } + Err(e) => { + tracing::error!("tls accept failed: {:?}", e); + } + } + }); + } + }); + + // Build client config trusting generated cert + let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); + root_store.add(&cert).expect("add root"); + let client_config = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + + let tcp_stream = tokio::net::TcpStream::connect(tcp_addr) + .await + .expect("tcp connect"); + let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); + let mut tls_stream = connector + .connect(domain, tcp_stream) + .await + .expect("tls connect"); + tracing::info!("client connected"); + + let allocate = build_allocate_request(None, None, None, None, None); + tls_stream + .write_all(&allocate) + .await + .expect("write allocate"); + tracing::info!("sent unauthenticated allocate request"); + + let mut buf = vec![0u8; 1500]; + let n = tls_stream.read(&mut buf).await.expect("read challenge"); + tracing::info!(bytes = n, "received nonce challenge"); + let resp = parse_message(&buf[..n]).expect("parse 401"); + let nonce_attr = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + let nonce = String::from_utf8(nonce_attr.value.clone()).expect("nonce str"); + + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + let allocate = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + tls_stream + .write_all(&allocate) + .await + .expect("write auth allocate"); + tracing::info!("sent authenticated allocate request"); + let n = tls_stream.read(&mut buf).await.expect("read success"); + tracing::info!(bytes = n, "received allocate success"); + let resp = parse_message(&buf[..n]).expect("parse alloc success"); + assert_eq!(resp.header.msg_type & 0x0110, 0x0100); + + let refresh = build_refresh_request( + resp.header.transaction_id, + username, + auth.realm(), + &nonce, + &key, + 0, + ); + tls_stream.write_all(&refresh).await.expect("write refresh"); + tracing::info!("sent refresh request"); + let n = tls_stream.read(&mut buf).await.expect("read refresh resp"); + tracing::info!(bytes = n, "received refresh response"); + let resp = parse_message(&buf[..n]).expect("parse refresh resp"); + let lifetime = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_LIFETIME) + .expect("lifetime attr"); + let secs = u32::from_be_bytes([ + lifetime.value[0], + lifetime.value[1], + lifetime.value[2], + lifetime.value[3], + ]); + assert_eq!(secs, 0); +} diff --git a/tests/udp_turn.rs b/tests/udp_turn.rs new file mode 100644 index 0000000..6c2bb6d --- /dev/null +++ b/tests/udp_turn.rs @@ -0,0 +1,243 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::InMemoryStore; +use niom_turn::server::udp_reader_loop; +use niom_turn::stun::parse_message; +use tokio::net::UdpSocket; + +use crate::support::stun_builders::{ + build_allocate_request, build_create_permission_request, build_refresh_request, + build_send_request, new_transaction_id, parse, +}; +use crate::support::{default_test_credentials, init_tracing, test_auth_manager}; + +mod support; + +const SERVER_ADDR: &str = "127.0.0.1:0"; + +#[tokio::test] +async fn allocate_requires_auth_then_succeeds() { + init_tracing(); + let (server, client_addr) = start_udp_server().await; + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_arc = Arc::new(server); + + let server_clone = server_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await; + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + // initial unauthenticated allocate should trigger 401 with nonce + let req = build_allocate_request(None, None, None, None, None); + client + .send_to(&req, client_addr) + .await + .expect("send unauth allocate"); + let mut buf = [0u8; 1500]; + let (len, _) = client.recv_from(&mut buf).await.expect("recv challenge"); + let resp = parse_message(&buf[..len]).expect("parse 401"); + assert_eq!(resp.header.msg_type & 0x0110, 0x0110); + + let nonce = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr") + .value + .clone(); + let nonce_str = String::from_utf8(nonce).expect("nonce utf8"); + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + + let req = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce_str), + Some(&key), + Some(600), + ); + client + .send_to(&req, client_addr) + .await + .expect("send auth allocate"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv success"); + let resp = parse(&buf[..len]); + assert_eq!(resp.header.msg_type & 0x0110, 0x0100); +} + +#[tokio::test] +async fn refresh_zero_lifetime_releases_allocation() { + init_tracing(); + let (server, client_addr) = start_udp_server().await; + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_arc = Arc::new(server); + + let server_clone = server_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await; + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + let nonce = + perform_authenticated_allocate(&client, client_addr, &auth, username, password, &allocs) + .await; + + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + let trans_id = new_transaction_id(); + let refresh = build_refresh_request(trans_id, username, auth.realm(), &nonce, &key, 0); + client + .send_to(&refresh, client_addr) + .await + .expect("send refresh"); + let mut buf = [0u8; 1500]; + let (len, _) = client.recv_from(&mut buf).await.expect("recv refresh resp"); + let resp = parse(&buf[..len]); + assert_eq!(resp.header.msg_type & 0x0110, 0x0100); + let lifetime = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_LIFETIME) + .expect("lifetime attr"); + assert_eq!( + u32::from_be_bytes([ + lifetime.value[0], + lifetime.value[1], + lifetime.value[2], + lifetime.value[3] + ]), + 0 + ); + + assert!(allocs + .get_allocation(&client.local_addr().unwrap()) + .is_none()); +} + +#[tokio::test] +async fn create_permission_and_send_relays_data() { + init_tracing(); + let (server, client_addr) = start_udp_server().await; + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_arc = Arc::new(server); + + let server_clone = server_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await; + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + let nonce = + perform_authenticated_allocate(&client, client_addr, &auth, username, password, &allocs) + .await; + + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind"); + let relay_addr = allocs + .get_allocation(&client.local_addr().unwrap()) + .expect("allocation exists") + .relay_addr; + + let perm_req = build_create_permission_request( + username, + auth.realm(), + &nonce, + &key, + &peer_sock.local_addr().unwrap(), + ); + client + .send_to(&perm_req, client_addr) + .await + .expect("send create permission"); + let mut buf = [0u8; 1500]; + client.recv_from(&mut buf).await.expect("permission resp"); + + let payload = b"hello-turn"; + let send_req = build_send_request( + username, + auth.realm(), + &nonce, + &key, + &peer_sock.local_addr().unwrap(), + payload, + ); + client + .send_to(&send_req, client_addr) + .await + .expect("send indication"); + + let mut peer_buf = [0u8; 1500]; + let (len, addr) = peer_sock.recv_from(&mut peer_buf).await.expect("peer recv"); + assert_eq!(len, payload.len()); + assert_eq!(&peer_buf[..len], payload); + assert_eq!(addr.port(), relay_addr.port()); + assert!(addr.ip().is_loopback()); +} + +async fn start_udp_server() -> (UdpSocket, SocketAddr) { + let server = UdpSocket::bind(SERVER_ADDR).await.expect("server bind"); + let addr = server.local_addr().expect("server addr"); + (server, addr) +} + +async fn perform_authenticated_allocate( + client: &UdpSocket, + server_addr: SocketAddr, + auth: &niom_turn::auth::AuthManager, + username: &str, + password: &str, + allocs: &AllocationManager, +) -> String { + init_tracing(); + // trigger nonce challenge + let req = build_allocate_request(None, None, None, None, None); + client + .send_to(&req, server_addr) + .await + .expect("send initial allocate"); + let mut buf = [0u8; 1500]; + let (len, _) = client.recv_from(&mut buf).await.expect("recv nonce"); + let resp = parse_message(&buf[..len]).expect("parse nonce resp"); + let nonce = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr") + .value + .clone(); + let nonce = String::from_utf8(nonce).expect("nonce utf8"); + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + + let req = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + client + .send_to(&req, server_addr) + .await + .expect("send auth allocate"); + client.recv_from(&mut buf).await.expect("recv success"); + + assert!(allocs + .get_allocation(&client.local_addr().unwrap()) + .is_some()); + nonce +}