Improved TLS lifetime handling.

This commit is contained in:
ghost 2025-11-17 01:54:49 +01:00
parent 5a74a0d945
commit fe0b4559d0
4 changed files with 316 additions and 71 deletions

View File

@ -17,6 +17,7 @@ pub struct Allocation {
_socket: Arc<UdpSocket>,
permissions: Arc<Mutex<HashMap<SocketAddr, Instant>>>,
channel_bindings: Arc<Mutex<HashMap<u16, (SocketAddr, Instant)>>>,
expiry: Arc<Mutex<Instant>>,
}
#[derive(Clone, Default)]
@ -114,15 +115,18 @@ impl AllocationManager {
_socket: relay_arc,
permissions: Arc::new(Mutex::new(HashMap::new())),
channel_bindings: Arc::new(Mutex::new(HashMap::new())),
expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)),
};
tracing::info!("created allocation for {} -> {}", client, relay_local);
let mut m = self.inner.lock().unwrap();
prune_expired_locked(&mut m);
m.insert(client, alloc);
Ok(relay_local)
}
pub fn get_allocation(&self, client: &SocketAddr) -> Option<Allocation> {
let m = self.inner.lock().unwrap();
let mut m = self.inner.lock().unwrap();
prune_expired_locked(&mut m);
m.get(client).cloned()
}
@ -130,6 +134,7 @@ impl AllocationManager {
/// to the specified peer address. Permissions currently expire after a static timeout.
pub fn add_permission(&self, client: SocketAddr, peer: SocketAddr) -> anyhow::Result<()> {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
let alloc = guard
.get_mut(&client)
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
@ -147,6 +152,7 @@ impl AllocationManager {
peer: SocketAddr,
) -> anyhow::Result<()> {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
let alloc = guard
.get_mut(&client)
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
@ -155,11 +161,46 @@ impl AllocationManager {
bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME));
Ok(())
}
/// Refresh allocation lifetime or delete it when zero requested. Returns the applied lifetime.
pub fn refresh_allocation(
&self,
client: SocketAddr,
requested: Option<Duration>,
) -> anyhow::Result<Duration> {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME);
if let Some(d) = requested {
if d.is_zero() {
guard.remove(&client);
return Ok(Duration::from_secs(0));
}
}
let alloc = guard
.get(&client)
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
let mut expiry = alloc.expiry.lock().unwrap();
let now = Instant::now();
let clamped = clamp_lifetime(req);
*expiry = now + clamped;
Ok(clamped)
}
/// Remove allocation explicitly (e.g. on zero lifetime). Returns true if removed.
pub fn remove_allocation(&self, client: &SocketAddr) -> bool {
let mut guard = self.inner.lock().unwrap();
guard.remove(client).is_some()
}
}
impl Allocation {
/// Check whether a peer address is currently permitted for this allocation.
pub fn is_peer_allowed(&self, peer: &SocketAddr) -> bool {
if self.is_expired() {
return false;
}
let mut perms = self.permissions.lock().unwrap();
prune_permissions(&mut perms);
perms.contains_key(peer)
@ -167,6 +208,9 @@ impl Allocation {
/// Resolve an active channel binding to its peer socket, if still valid.
pub fn channel_peer(&self, channel: u16) -> Option<SocketAddr> {
if self.is_expired() {
return None;
}
let mut bindings = self.channel_bindings.lock().unwrap();
prune_channel_bindings(&mut bindings);
bindings.get(&channel).map(|(peer, _)| *peer)
@ -174,6 +218,9 @@ impl Allocation {
/// Return the bound channel number for a peer if available.
pub fn channel_for_peer(&self, peer: &SocketAddr) -> Option<u16> {
if self.is_expired() {
return None;
}
let mut bindings = self.channel_bindings.lock().unwrap();
prune_channel_bindings(&mut bindings);
bindings
@ -186,9 +233,23 @@ impl Allocation {
let sent = self._socket.send_to(data, peer).await?;
Ok(sent)
}
/// Remaining lifetime for this allocation (saturates at zero).
pub fn remaining_lifetime(&self) -> Duration {
let expiry = self.expiry.lock().unwrap();
let now = Instant::now();
expiry.saturating_duration_since(now)
}
pub fn is_expired(&self) -> bool {
self.remaining_lifetime().is_zero()
}
}
const PERMISSION_LIFETIME: Duration = Duration::from_secs(300);
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
const MIN_ALLOCATION_LIFETIME: Duration = Duration::from_secs(60);
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600);
fn prune_permissions(perms: &mut HashMap<SocketAddr, Instant>) {
let now = Instant::now();
@ -199,3 +260,21 @@ fn prune_channel_bindings(bindings: &mut HashMap<u16, (SocketAddr, Instant)>) {
let now = Instant::now();
bindings.retain(|_, (_, expiry)| *expiry > now);
}
fn prune_expired_locked(map: &mut HashMap<SocketAddr, Allocation>) {
let now = Instant::now();
map.retain(|_, alloc| {
let expiry = alloc.expiry.lock().unwrap();
*expiry > now
});
}
fn clamp_lifetime(requested: Duration) -> Duration {
if requested < MIN_ALLOCATION_LIFETIME {
MIN_ALLOCATION_LIFETIME
} else if requested > MAX_ALLOCATION_LIFETIME {
MAX_ALLOCATION_LIFETIME
} else {
requested
}
}

View File

@ -2,6 +2,7 @@
//! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tracing::{error, info};
@ -11,8 +12,9 @@ use niom_turn::auth::{AuthManager, AuthStatus, InMemoryStore};
use niom_turn::config::{AuthOptions, Config};
use niom_turn::constants::*;
use niom_turn::stun::{
build_401_response, build_error_response, build_success_response, decode_xor_peer_address,
encode_xor_relayed_address, parse_channel_data, parse_message,
build_401_response, build_allocate_success, build_error_response, build_lifetime_success,
build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data,
parse_message,
};
#[tokio::main]
@ -220,36 +222,41 @@ async fn udp_reader_loop(
match msg.header.msg_type {
METHOD_ALLOCATE => {
use bytes::BytesMut;
let requested_lifetime = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64))
.filter(|d| !d.is_zero());
match allocs.allocate_for(peer, udp.clone()).await {
Ok(relay_addr) => {
let mut out = BytesMut::new();
let success_type = msg.header.msg_type | CLASS_SUCCESS;
out.extend_from_slice(&success_type.to_be_bytes());
out.extend_from_slice(&0u16.to_be_bytes());
out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes());
out.extend_from_slice(&msg.header.transaction_id);
let attr_val = encode_xor_relayed_address(
&relay_addr,
&msg.header.transaction_id,
);
out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes());
out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes()));
out.extend_from_slice(&attr_val);
while (out.len() % 4) != 0 {
out.extend_from_slice(&[0]);
}
let total_len = (out.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
out[2] = len_bytes[0];
out[3] = len_bytes[1];
let vec_out = out.to_vec();
let applied =
match allocs.refresh_allocation(peer, requested_lifetime) {
Ok(d) => d,
Err(e) => {
tracing::error!(
"failed to apply lifetime for {}: {:?}",
peer,
e
);
let resp = build_error_response(
&msg.header,
500,
"Allocate Failed",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32;
let resp =
build_allocate_success(&msg.header, &relay_addr, lifetime_secs);
tracing::info!(
"sending allocate success -> {} bytes hex={} ",
vec_out.len(),
hex::encode(&vec_out)
"allocated relay {} for {} lifetime={}s",
relay_addr,
peer,
lifetime_secs
);
let _ = udp.send_to(&vec_out, &peer).await;
let _ = udp.send_to(&resp, &peer).await;
}
Err(e) => {
tracing::error!("allocate failed: {:?}", e);
@ -484,9 +491,32 @@ async fn udp_reader_loop(
continue;
}
METHOD_REFRESH => {
// Refresh support is still MVP-level; acknowledge so clients extend allocations.
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
let requested = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64));
match allocs.refresh_allocation(peer, requested) {
Ok(applied) => {
if applied.is_zero() {
tracing::info!("allocation for {} released", peer);
} else {
tracing::debug!(
"allocation for {} refreshed to {}s",
peer,
applied.as_secs()
);
}
let resp = build_lifetime_success(
&msg.header,
applied.as_secs().min(u32::MAX as u64) as u32,
);
let _ = udp.send_to(&resp, &peer).await;
}
Err(_) => {
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
_ => {

View File

@ -158,6 +158,87 @@ pub fn build_error_response(req: &StunHeader, code: u16, reason: &str) -> Vec<u8
buf.to_vec()
}
/// Build an Allocate success response containing XOR-RELAYED-ADDRESS and LIFETIME attributes.
pub fn build_allocate_success(
req: &StunHeader,
relay: &std::net::SocketAddr,
lifetime_secs: u32,
) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_SUCCESS;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
let relay_val = encode_xor_relayed_address(relay, &req.transaction_id);
buf.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes());
buf.extend_from_slice(&((relay_val.len() as u16).to_be_bytes()));
buf.extend_from_slice(&relay_val);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
let lifetime_bytes = lifetime_secs.to_be_bytes();
buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes());
buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(&lifetime_bytes);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
buf.to_vec()
}
/// Build a success response that advertises the remaining allocation lifetime.
pub fn build_lifetime_success(req: &StunHeader, lifetime_secs: u32) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_SUCCESS;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
let lifetime_bytes = lifetime_secs.to_be_bytes();
buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes());
buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(&lifetime_bytes);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
buf.to_vec()
}
/// Extract requested LIFETIME (seconds) from STUN/TURN message attributes if present.
pub fn extract_lifetime_seconds(msg: &StunMessage) -> Option<u32> {
msg.attributes
.iter()
.find(|a| a.typ == ATTR_LIFETIME)
.and_then(|attr| {
if attr.value.len() >= 4 {
Some(u32::from_be_bytes([
attr.value[0],
attr.value[1],
attr.value[2],
attr.value[3],
]))
} else {
None
}
})
}
/// Find MESSAGE-INTEGRITY attribute (ATTR_MESSAGE_INTEGRITY) if present
pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> {
msg.attributes

View File

@ -4,6 +4,7 @@ use anyhow::Context;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
@ -13,8 +14,8 @@ use crate::alloc::AllocationManager;
use crate::auth::{AuthManager, AuthStatus, InMemoryStore};
use crate::constants::*;
use crate::stun::{
build_401_response, build_error_response, build_success_response, decode_xor_peer_address,
encode_xor_relayed_address, parse_message,
build_401_response, build_allocate_success, build_error_response, build_lifetime_success,
build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_message,
};
fn load_certs(path: &str) -> anyhow::Result<Vec<Certificate>> {
@ -177,45 +178,57 @@ pub async fn serve_tls(
match msg.header.msg_type {
METHOD_ALLOCATE => {
use bytes::BytesMut;
let requested_lifetime =
extract_lifetime_seconds(&msg)
.map(|secs| {
Duration::from_secs(secs as u64)
})
.filter(|d| !d.is_zero());
match alloc_clone
.allocate_for(peer, udp_clone.clone())
.await
{
Ok(relay_addr) => {
let mut out = BytesMut::new();
let success_type =
msg.header.msg_type | CLASS_SUCCESS;
out.extend_from_slice(
&success_type.to_be_bytes(),
);
out.extend_from_slice(&0u16.to_be_bytes());
out.extend_from_slice(&MAGIC_COOKIE_BYTES);
out.extend_from_slice(
&msg.header.transaction_id,
);
let attr_val = encode_xor_relayed_address(
let applied = match alloc_clone
.refresh_allocation(
peer,
requested_lifetime,
) {
Ok(d) => d,
Err(e) => {
tracing::error!(
"failed to apply TLS lifetime for {}: {:?}",
peer,
e
);
let resp = build_error_response(
&msg.header,
500,
"Allocate Failed",
);
if let Err(e2) = tls_stream
.write_all(&resp)
.await
{
tracing::error!(
"failed to write tls allocate error: {:?}",
e2
);
}
continue;
}
};
let lifetime_secs =
applied.as_secs().min(u32::MAX as u64)
as u32;
let resp = build_allocate_success(
&msg.header,
&relay_addr,
&msg.header.transaction_id,
lifetime_secs,
);
out.extend_from_slice(
&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes(),
);
out.extend_from_slice(
&((attr_val.len() as u16)
.to_be_bytes()),
);
out.extend_from_slice(&attr_val);
while (out.len() % 4) != 0 {
out.extend_from_slice(&[0]);
}
let total_len = (out.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
out[2] = len_bytes[0];
out[3] = len_bytes[1];
if let Err(e) =
tls_stream.write_all(&out).await
tls_stream.write_all(&resp).await
{
tracing::error!(
"failed to write tls allocate success: {:?}",
@ -636,12 +649,54 @@ pub async fn serve_tls(
continue;
}
METHOD_REFRESH => {
let resp = build_success_response(&msg.header);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls refresh response: {:?}",
e
);
let requested = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64));
match alloc_clone
.refresh_allocation(peer, requested)
{
Ok(applied) => {
if applied.is_zero() {
tracing::info!(
"allocation for {} released (tls)",
peer
);
} else {
tracing::debug!(
"allocation for {} refreshed to {}s (tls)",
peer,
applied.as_secs()
);
}
let resp = build_lifetime_success(
&msg.header,
applied.as_secs().min(u32::MAX as u64)
as u32,
);
if let Err(e) =
tls_stream.write_all(&resp).await
{
tracing::error!(
"failed to write tls refresh response: {:?}",
e
);
}
}
Err(_) => {
let resp = build_error_response(
&msg.header,
437,
"Allocation Mismatch",
);
if let Err(e) =
tls_stream.write_all(&resp).await
{
tracing::error!(
"failed to write tls refresh error: {:?}",
e
);
}
}
}
continue;
}