diff --git a/src/alloc.rs b/src/alloc.rs index 6dea58f..484d444 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -17,6 +17,7 @@ pub struct Allocation { _socket: Arc, permissions: Arc>>, channel_bindings: Arc>>, + expiry: Arc>, } #[derive(Clone, Default)] @@ -114,15 +115,18 @@ impl AllocationManager { _socket: relay_arc, permissions: Arc::new(Mutex::new(HashMap::new())), channel_bindings: Arc::new(Mutex::new(HashMap::new())), + expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)), }; tracing::info!("created allocation for {} -> {}", client, relay_local); let mut m = self.inner.lock().unwrap(); + prune_expired_locked(&mut m); m.insert(client, alloc); Ok(relay_local) } pub fn get_allocation(&self, client: &SocketAddr) -> Option { - let m = self.inner.lock().unwrap(); + let mut m = self.inner.lock().unwrap(); + prune_expired_locked(&mut m); m.get(client).cloned() } @@ -130,6 +134,7 @@ impl AllocationManager { /// to the specified peer address. Permissions currently expire after a static timeout. pub fn add_permission(&self, client: SocketAddr, peer: SocketAddr) -> anyhow::Result<()> { let mut guard = self.inner.lock().unwrap(); + prune_expired_locked(&mut guard); let alloc = guard .get_mut(&client) .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; @@ -147,6 +152,7 @@ impl AllocationManager { peer: SocketAddr, ) -> anyhow::Result<()> { let mut guard = self.inner.lock().unwrap(); + prune_expired_locked(&mut guard); let alloc = guard .get_mut(&client) .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; @@ -155,11 +161,46 @@ impl AllocationManager { bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME)); Ok(()) } + + /// Refresh allocation lifetime or delete it when zero requested. Returns the applied lifetime. + pub fn refresh_allocation( + &self, + client: SocketAddr, + requested: Option, + ) -> anyhow::Result { + let mut guard = self.inner.lock().unwrap(); + prune_expired_locked(&mut guard); + let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME); + if let Some(d) = requested { + if d.is_zero() { + guard.remove(&client); + return Ok(Duration::from_secs(0)); + } + } + + let alloc = guard + .get(&client) + .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; + let mut expiry = alloc.expiry.lock().unwrap(); + let now = Instant::now(); + let clamped = clamp_lifetime(req); + *expiry = now + clamped; + Ok(clamped) + } + + /// Remove allocation explicitly (e.g. on zero lifetime). Returns true if removed. + pub fn remove_allocation(&self, client: &SocketAddr) -> bool { + let mut guard = self.inner.lock().unwrap(); + guard.remove(client).is_some() + } } impl Allocation { /// Check whether a peer address is currently permitted for this allocation. pub fn is_peer_allowed(&self, peer: &SocketAddr) -> bool { + if self.is_expired() { + return false; + } let mut perms = self.permissions.lock().unwrap(); prune_permissions(&mut perms); perms.contains_key(peer) @@ -167,6 +208,9 @@ impl Allocation { /// Resolve an active channel binding to its peer socket, if still valid. pub fn channel_peer(&self, channel: u16) -> Option { + if self.is_expired() { + return None; + } let mut bindings = self.channel_bindings.lock().unwrap(); prune_channel_bindings(&mut bindings); bindings.get(&channel).map(|(peer, _)| *peer) @@ -174,6 +218,9 @@ impl Allocation { /// Return the bound channel number for a peer if available. pub fn channel_for_peer(&self, peer: &SocketAddr) -> Option { + if self.is_expired() { + return None; + } let mut bindings = self.channel_bindings.lock().unwrap(); prune_channel_bindings(&mut bindings); bindings @@ -186,9 +233,23 @@ impl Allocation { let sent = self._socket.send_to(data, peer).await?; Ok(sent) } + + /// Remaining lifetime for this allocation (saturates at zero). + pub fn remaining_lifetime(&self) -> Duration { + let expiry = self.expiry.lock().unwrap(); + let now = Instant::now(); + expiry.saturating_duration_since(now) + } + + pub fn is_expired(&self) -> bool { + self.remaining_lifetime().is_zero() + } } const PERMISSION_LIFETIME: Duration = Duration::from_secs(300); +const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); +const MIN_ALLOCATION_LIFETIME: Duration = Duration::from_secs(60); +const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); fn prune_permissions(perms: &mut HashMap) { let now = Instant::now(); @@ -199,3 +260,21 @@ fn prune_channel_bindings(bindings: &mut HashMap) { let now = Instant::now(); bindings.retain(|_, (_, expiry)| *expiry > now); } + +fn prune_expired_locked(map: &mut HashMap) { + let now = Instant::now(); + map.retain(|_, alloc| { + let expiry = alloc.expiry.lock().unwrap(); + *expiry > now + }); +} + +fn clamp_lifetime(requested: Duration) -> Duration { + if requested < MIN_ALLOCATION_LIFETIME { + MIN_ALLOCATION_LIFETIME + } else if requested > MAX_ALLOCATION_LIFETIME { + MAX_ALLOCATION_LIFETIME + } else { + requested + } +} diff --git a/src/main.rs b/src/main.rs index c3a62a9..e3fb84a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ //! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners. use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use tokio::net::UdpSocket; use tracing::{error, info}; @@ -11,8 +12,9 @@ use niom_turn::auth::{AuthManager, AuthStatus, InMemoryStore}; use niom_turn::config::{AuthOptions, Config}; use niom_turn::constants::*; use niom_turn::stun::{ - build_401_response, build_error_response, build_success_response, decode_xor_peer_address, - encode_xor_relayed_address, parse_channel_data, parse_message, + build_401_response, build_allocate_success, build_error_response, build_lifetime_success, + build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data, + parse_message, }; #[tokio::main] @@ -220,36 +222,41 @@ async fn udp_reader_loop( match msg.header.msg_type { METHOD_ALLOCATE => { - use bytes::BytesMut; + let requested_lifetime = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)) + .filter(|d| !d.is_zero()); + match allocs.allocate_for(peer, udp.clone()).await { Ok(relay_addr) => { - let mut out = BytesMut::new(); - let success_type = msg.header.msg_type | CLASS_SUCCESS; - out.extend_from_slice(&success_type.to_be_bytes()); - out.extend_from_slice(&0u16.to_be_bytes()); - out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes()); - out.extend_from_slice(&msg.header.transaction_id); - let attr_val = encode_xor_relayed_address( - &relay_addr, - &msg.header.transaction_id, - ); - out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); - out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes())); - out.extend_from_slice(&attr_val); - while (out.len() % 4) != 0 { - out.extend_from_slice(&[0]); - } - let total_len = (out.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - out[2] = len_bytes[0]; - out[3] = len_bytes[1]; - let vec_out = out.to_vec(); + let applied = + match allocs.refresh_allocation(peer, requested_lifetime) { + Ok(d) => d, + Err(e) => { + tracing::error!( + "failed to apply lifetime for {}: {:?}", + peer, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Allocate Failed", + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + }; + + let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32; + let resp = + build_allocate_success(&msg.header, &relay_addr, lifetime_secs); tracing::info!( - "sending allocate success -> {} bytes hex={} ", - vec_out.len(), - hex::encode(&vec_out) + "allocated relay {} for {} lifetime={}s", + relay_addr, + peer, + lifetime_secs ); - let _ = udp.send_to(&vec_out, &peer).await; + let _ = udp.send_to(&resp, &peer).await; } Err(e) => { tracing::error!("allocate failed: {:?}", e); @@ -484,9 +491,32 @@ async fn udp_reader_loop( continue; } METHOD_REFRESH => { - // Refresh support is still MVP-level; acknowledge so clients extend allocations. - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; + let requested = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)); + + match allocs.refresh_allocation(peer, requested) { + Ok(applied) => { + if applied.is_zero() { + tracing::info!("allocation for {} released", peer); + } else { + tracing::debug!( + "allocation for {} refreshed to {}s", + peer, + applied.as_secs() + ); + } + let resp = build_lifetime_success( + &msg.header, + applied.as_secs().min(u32::MAX as u64) as u32, + ); + let _ = udp.send_to(&resp, &peer).await; + } + Err(_) => { + let resp = + build_error_response(&msg.header, 437, "Allocation Mismatch"); + let _ = udp.send_to(&resp, &peer).await; + } + } continue; } _ => { diff --git a/src/stun.rs b/src/stun.rs index a1349ed..458b164 100644 --- a/src/stun.rs +++ b/src/stun.rs @@ -158,6 +158,87 @@ pub fn build_error_response(req: &StunHeader, code: u16, reason: &str) -> Vec Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let relay_val = encode_xor_relayed_address(relay, &req.transaction_id); + buf.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); + buf.extend_from_slice(&((relay_val.len() as u16).to_be_bytes())); + buf.extend_from_slice(&relay_val); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + let lifetime_bytes = lifetime_secs.to_be_bytes(); + buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes()); + buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes()); + buf.extend_from_slice(&lifetime_bytes); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + buf.to_vec() +} + +/// Build a success response that advertises the remaining allocation lifetime. +pub fn build_lifetime_success(req: &StunHeader, lifetime_secs: u32) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let lifetime_bytes = lifetime_secs.to_be_bytes(); + buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes()); + buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes()); + buf.extend_from_slice(&lifetime_bytes); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + buf.to_vec() +} + +/// Extract requested LIFETIME (seconds) from STUN/TURN message attributes if present. +pub fn extract_lifetime_seconds(msg: &StunMessage) -> Option { + msg.attributes + .iter() + .find(|a| a.typ == ATTR_LIFETIME) + .and_then(|attr| { + if attr.value.len() >= 4 { + Some(u32::from_be_bytes([ + attr.value[0], + attr.value[1], + attr.value[2], + attr.value[3], + ])) + } else { + None + } + }) +} + /// Find MESSAGE-INTEGRITY attribute (ATTR_MESSAGE_INTEGRITY) if present pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> { msg.attributes diff --git a/src/tls.rs b/src/tls.rs index d068041..4658500 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -4,6 +4,7 @@ use anyhow::Context; use std::fs::File; use std::io::BufReader; use std::sync::Arc; +use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; @@ -13,8 +14,8 @@ use crate::alloc::AllocationManager; use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; use crate::constants::*; use crate::stun::{ - build_401_response, build_error_response, build_success_response, decode_xor_peer_address, - encode_xor_relayed_address, parse_message, + build_401_response, build_allocate_success, build_error_response, build_lifetime_success, + build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_message, }; fn load_certs(path: &str) -> anyhow::Result> { @@ -177,45 +178,57 @@ pub async fn serve_tls( match msg.header.msg_type { METHOD_ALLOCATE => { - use bytes::BytesMut; + let requested_lifetime = + extract_lifetime_seconds(&msg) + .map(|secs| { + Duration::from_secs(secs as u64) + }) + .filter(|d| !d.is_zero()); + match alloc_clone .allocate_for(peer, udp_clone.clone()) .await { Ok(relay_addr) => { - let mut out = BytesMut::new(); - let success_type = - msg.header.msg_type | CLASS_SUCCESS; - out.extend_from_slice( - &success_type.to_be_bytes(), - ); - out.extend_from_slice(&0u16.to_be_bytes()); - out.extend_from_slice(&MAGIC_COOKIE_BYTES); - - out.extend_from_slice( - &msg.header.transaction_id, - ); - let attr_val = encode_xor_relayed_address( + let applied = match alloc_clone + .refresh_allocation( + peer, + requested_lifetime, + ) { + Ok(d) => d, + Err(e) => { + tracing::error!( + "failed to apply TLS lifetime for {}: {:?}", + peer, + e + ); + let resp = build_error_response( + &msg.header, + 500, + "Allocate Failed", + ); + if let Err(e2) = tls_stream + .write_all(&resp) + .await + { + tracing::error!( + "failed to write tls allocate error: {:?}", + e2 + ); + } + continue; + } + }; + let lifetime_secs = + applied.as_secs().min(u32::MAX as u64) + as u32; + let resp = build_allocate_success( + &msg.header, &relay_addr, - &msg.header.transaction_id, + lifetime_secs, ); - out.extend_from_slice( - &ATTR_XOR_RELAYED_ADDRESS.to_be_bytes(), - ); - out.extend_from_slice( - &((attr_val.len() as u16) - .to_be_bytes()), - ); - out.extend_from_slice(&attr_val); - while (out.len() % 4) != 0 { - out.extend_from_slice(&[0]); - } - let total_len = (out.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - out[2] = len_bytes[0]; - out[3] = len_bytes[1]; if let Err(e) = - tls_stream.write_all(&out).await + tls_stream.write_all(&resp).await { tracing::error!( "failed to write tls allocate success: {:?}", @@ -636,12 +649,54 @@ pub async fn serve_tls( continue; } METHOD_REFRESH => { - let resp = build_success_response(&msg.header); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls refresh response: {:?}", - e - ); + let requested = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)); + + match alloc_clone + .refresh_allocation(peer, requested) + { + Ok(applied) => { + if applied.is_zero() { + tracing::info!( + "allocation for {} released (tls)", + peer + ); + } else { + tracing::debug!( + "allocation for {} refreshed to {}s (tls)", + peer, + applied.as_secs() + ); + } + let resp = build_lifetime_success( + &msg.header, + applied.as_secs().min(u32::MAX as u64) + as u32, + ); + if let Err(e) = + tls_stream.write_all(&resp).await + { + tracing::error!( + "failed to write tls refresh response: {:?}", + e + ); + } + } + Err(_) => { + let resp = build_error_response( + &msg.header, + 437, + "Allocation Mismatch", + ); + if let Err(e) = + tls_stream.write_all(&resp).await + { + tracing::error!( + "failed to write tls refresh error: {:?}", + e + ); + } + } } continue; }