niom-turn/src/alloc.rs

//! Allocation manager: provisions relay sockets and forwards packets for TURN allocations.
//! Also tracks permissions, channel bindings, allocation lifetimes/refresh, and per-IP and
//! per-allocation quotas.
use std::collections::HashMap;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tracing::info;
use crate::stun::{build_channel_data, build_data_indication};
#[derive(thiserror::Error, Debug)]
pub enum AllocationError {
#[error("allocation quota exceeded")]
AllocationQuotaExceeded,
#[error("permission quota exceeded")]
PermissionQuotaExceeded,
#[error("channel binding quota exceeded")]
ChannelQuotaExceeded,
}
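/// Destination for traffic relayed back to a TURN client: either a shared UDP socket
/// (classic UDP transport) or an mpsc sender drained by a stream-oriented (e.g. TCP)
/// connection task.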
#[derive(Clone)]
pub enum ClientSink {
Udp { sock: Arc<UdpSocket>, addr: SocketAddr },
Stream { tx: mpsc::Sender<Vec<u8>> },
}
impl ClientSink {
pub async fn send(&self, data: Vec<u8>) -> anyhow::Result<()> {
match self {
ClientSink::Udp { sock, addr } => {
sock.send_to(&data, addr).await?;
Ok(())
}
ClientSink::Stream { tx } => {
tx.send(data)
.await
.map_err(|_| anyhow::anyhow!("client stream closed"))
}
}
}
}
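/// Per-client allocation state: the bound relay socket, the stop signal for its relay task,
/// installed permissions and channel bindings, and the expiry consulted by housekeeping.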
#[derive(Clone)]
pub struct Allocation {
pub client: SocketAddr,
pub relay_addr: SocketAddr,
// relay socket: kept here to hold the port binding and used by `send_to_peer` for outbound traffic
_socket: Arc<UdpSocket>,
stop: Arc<Notify>,
permissions: Arc<Mutex<HashMap<SocketAddr, Instant>>>,
channel_bindings: Arc<Mutex<HashMap<u16, (SocketAddr, Instant)>>>,
expiry: Arc<Mutex<Instant>>,
}
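/// Shared registry of allocations keyed by client transport address. Clones are cheap and
/// operate on the same underlying map.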
#[derive(Clone, Default)]
pub struct AllocationManager {
inner: Arc<Mutex<HashMap<SocketAddr, Allocation>>>,
opts: AllocationOptions,
}
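/// Tuning knobs for relay binding and quotas. A `None` quota disables that limit, leaving the
/// port range unset binds an ephemeral port, and `advertised_ip` replaces the locally bound IP
/// in the relay address reported back to clients.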
#[derive(Clone, Debug)]
pub struct AllocationOptions {
pub relay_bind_ip: IpAddr,
pub relay_port_min: Option<u16>,
pub relay_port_max: Option<u16>,
pub advertised_ip: Option<IpAddr>,
pub max_allocations_per_ip: Option<u32>,
pub max_permissions_per_allocation: Option<u32>,
pub max_channel_bindings_per_allocation: Option<u32>,
}
impl Default for AllocationOptions {
fn default() -> Self {
Self {
relay_bind_ip: IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
relay_port_min: None,
relay_port_max: None,
advertised_ip: None,
max_allocations_per_ip: None,
max_permissions_per_allocation: None,
max_channel_bindings_per_allocation: None,
}
}
}
impl AllocationManager {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
opts: AllocationOptions::default(),
}
}
pub fn new_with_options(opts: AllocationOptions) -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
opts,
}
}
/// Translate a locally-bound relay socket address into the address that should be
/// advertised to clients (e.g. replace 0.0.0.0 with a public IP).
pub fn relay_addr_for_response(&self, relay_local: SocketAddr) -> SocketAddr {
match self.opts.advertised_ip {
Some(ip) => SocketAddr::new(ip, relay_local.port()),
None => relay_local,
}
}
/// Create a relay UDP socket for the given client and spawn a relay loop that forwards
/// any packets received on the relay socket back to the client via the provided client sink.
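/// Callers typically pass the returned (local) relay address through `relay_addr_for_response`
/// before advertising it to the client.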
pub async fn allocate_for(
&self,
client: SocketAddr,
client_sink: ClientSink,
) -> anyhow::Result<SocketAddr> {
// If an allocation already exists for this client transport address, reuse it.
{
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
if let Some(existing) = guard.get(&client) {
return Ok(existing.relay_addr);
}
if let Some(max) = self.opts.max_allocations_per_ip {
let count_for_ip = guard
.values()
.filter(|a| a.client.ip() == client.ip())
.count() as u32;
if count_for_ip >= max {
return Err(anyhow::anyhow!(AllocationError::AllocationQuotaExceeded));
}
}
}
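// NOTE: the quota check above and the insert further below run under separate lock
// acquisitions, so concurrent allocate_for calls can briefly overshoot the per-IP quota.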
// bind relay socket (optionally within configured port range)
let relay = match (self.opts.relay_port_min, self.opts.relay_port_max) {
(Some(min), Some(max)) if min > 0 && max >= min => {
let mut bound: Option<UdpSocket> = None;
let mut last_err: Option<anyhow::Error> = None;
for port in min..=max {
let addr = SocketAddr::new(self.opts.relay_bind_ip, port);
match UdpSocket::bind(addr).await {
Ok(sock) => {
bound = Some(sock);
last_err = None;
break;
}
Err(e) => {
last_err = Some(anyhow::anyhow!(e));
}
}
}
match bound {
Some(sock) => sock,
None => {
let detail = last_err
.map(|e| format!("{e:?}"))
.unwrap_or_else(|| "no ports in range available".to_string());
return Err(anyhow::anyhow!(
"failed to bind relay socket in range {}-{} on {}: {}",
min,
max,
self.opts.relay_bind_ip,
detail
));
}
}
}
_ => {
let addr = SocketAddr::new(self.opts.relay_bind_ip, 0);
UdpSocket::bind(addr).await?
}
};
let relay_local = relay.local_addr()?;
let relay_arc = Arc::new(relay);
// Insert allocation before spawning relay loop to avoid races.
let stop = Arc::new(Notify::new());
let alloc = Allocation {
client,
relay_addr: relay_local,
_socket: relay_arc.clone(),
stop: stop.clone(),
permissions: Arc::new(Mutex::new(HashMap::new())),
channel_bindings: Arc::new(Mutex::new(HashMap::new())),
expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)),
};
{
let mut m = self.inner.lock().unwrap();
prune_expired_locked(&mut m);
m.insert(client, alloc);
}
// spawn relay loop
let relay_clone = relay_arc.clone();
let sink_clone = client_sink.clone();
let client_clone = client;
let manager_clone = self.clone();
let stop_clone = stop.clone();
tokio::spawn(async move {
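// 2048 bytes covers typical MTU-sized UDP payloads; anything larger is truncated by recv_from.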
let mut buf = vec![0u8; 2048];
loop {
tokio::select! {
_ = stop_clone.notified() => {
break;
}
res = relay_clone.recv_from(&mut buf) => match res {
Ok((len, src)) => {
info!(
"relay got {} bytes from {} for client {}",
len, src, client_clone
);
if let Some(allocation) = manager_clone.get_allocation(&client_clone) {
if !allocation.is_peer_allowed(&src) {
tracing::debug!(
"dropping peer packet {} -> {} (permission expired)",
src,
client_clone
);
continue;
}
if let Some(channel) = allocation.channel_for_peer(&src) {
let frame = build_channel_data(channel, &buf[..len]);
if let Err(e) = sink_clone.send(frame).await {
tracing::error!(
"failed to send channel data {} -> {}: {:?}",
src,
client_clone,
e
);
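// A closed client stream cannot recover, so stop relaying; UDP send errors may be transient.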
if matches!(sink_clone, ClientSink::Stream { .. }) {
break;
}
}
} else {
let indication = build_data_indication(&src, &buf[..len]);
if let Err(e) = sink_clone.send(indication).await {
tracing::error!(
"failed to send data indication {} -> {}: {:?}",
src,
client_clone,
e
);
if matches!(sink_clone, ClientSink::Stream { .. }) {
break;
}
}
}
} else {
tracing::debug!(
"allocation missing while forwarding from peer {} -> {}",
src,
client_clone
);
// Allocation removed/expired: stop the relay task.
break;
}
}
Err(e) => {
tracing::error!("relay socket error: {:?}", e);
break;
}
}
}
}
});
tracing::info!("created allocation for {} -> {}", client, relay_local);
Ok(relay_local)
}
pub fn get_allocation(&self, client: &SocketAddr) -> Option<Allocation> {
let mut m = self.inner.lock().unwrap();
prune_expired_locked(&mut m);
m.get(client).cloned()
}
/// Register a permission for the given client allocation so the relay will forward packets
/// received from the specified peer address back to the client. Each call resets the expiry
/// to PERMISSION_LIFETIME from now.
pub fn add_permission(&self, client: SocketAddr, peer: SocketAddr) -> anyhow::Result<()> {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
let alloc = guard
.get_mut(&client)
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
let mut perms = alloc.permissions.lock().unwrap();
prune_permissions(&mut perms);
if let Some(max) = self.opts.max_permissions_per_allocation {
let max = max as usize;
if !perms.contains_key(&peer) && perms.len() >= max {
return Err(anyhow::anyhow!(AllocationError::PermissionQuotaExceeded));
}
}
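// Note: permissions are keyed by the peer's full socket address here, whereas RFC 8656
// installs permissions per peer IP address only.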
perms.insert(peer, Instant::now() + PERMISSION_LIFETIME);
Ok(())
}
/// Associate a TURN channel number with a specific peer socket for the allocation.
pub fn add_channel_binding(
&self,
client: SocketAddr,
channel: u16,
peer: SocketAddr,
) -> anyhow::Result<()> {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
let alloc = guard
.get_mut(&client)
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
let mut bindings = alloc.channel_bindings.lock().unwrap();
prune_channel_bindings(&mut bindings);
if let Some(max) = self.opts.max_channel_bindings_per_allocation {
let max = max as usize;
if !bindings.contains_key(&channel) && bindings.len() >= max {
return Err(anyhow::anyhow!(AllocationError::ChannelQuotaExceeded));
}
}
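// Note: channel bindings reuse PERMISSION_LIFETIME (5 min) here; RFC 8656 gives channel
// bindings a 10-minute lifetime.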
bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME));
Ok(())
}
/// Refresh the allocation lifetime, or delete the allocation when a zero lifetime is
/// requested. Returns the lifetime actually applied (after clamping).
pub fn refresh_allocation(
&self,
client: SocketAddr,
requested: Option<Duration>,
) -> anyhow::Result<Duration> {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME);
if let Some(d) = requested {
if d.is_zero() {
if let Some(alloc) = guard.remove(&client) {
alloc.stop.notify_one(); // stores a permit, so the relay task stops even if not currently parked
}
return Ok(Duration::from_secs(0));
}
}
let alloc = guard
.get(&client)
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
let mut expiry = alloc.expiry.lock().unwrap();
let now = Instant::now();
let clamped = clamp_lifetime(req);
*expiry = now + clamped;
Ok(clamped)
}
/// Remove an allocation explicitly (e.g. on a zero-lifetime refresh). Returns true if one was removed.
pub fn remove_allocation(&self, client: &SocketAddr) -> bool {
let mut guard = self.inner.lock().unwrap();
if let Some(alloc) = guard.remove(client) {
alloc.stop.notify_one();
true
} else {
false
}
}
/// Spawn a background housekeeping task that periodically prunes expired allocations.
/// This avoids keeping relay tasks/sockets alive indefinitely when allocations expire.
pub fn spawn_housekeeping(&self, interval: Duration) {
let mgr = self.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
mgr.prune_expired();
}
});
}
/// Remove expired allocations and notify their relay tasks to stop.
pub fn prune_expired(&self) {
let mut guard = self.inner.lock().unwrap();
let now = Instant::now();
let expired: Vec<SocketAddr> = guard
.iter()
.filter_map(|(k, alloc)| {
let expiry = alloc.expiry.lock().unwrap();
if *expiry <= now { Some(*k) } else { None }
})
.collect();
for client in expired {
if let Some(alloc) = guard.remove(&client) {
alloc.stop.notify_one();
}
}
}
}
impl Allocation {
/// Check whether a peer address is currently permitted for this allocation.
pub fn is_peer_allowed(&self, peer: &SocketAddr) -> bool {
if self.is_expired() {
return false;
}
let mut perms = self.permissions.lock().unwrap();
prune_permissions(&mut perms);
perms.contains_key(peer)
}
/// Resolve an active channel binding to its peer socket, if still valid.
pub fn channel_peer(&self, channel: u16) -> Option<SocketAddr> {
if self.is_expired() {
return None;
}
let mut bindings = self.channel_bindings.lock().unwrap();
prune_channel_bindings(&mut bindings);
bindings.get(&channel).map(|(peer, _)| *peer)
}
/// Return the bound channel number for a peer if available.
pub fn channel_for_peer(&self, peer: &SocketAddr) -> Option<u16> {
if self.is_expired() {
return None;
}
let mut bindings = self.channel_bindings.lock().unwrap();
prune_channel_bindings(&mut bindings);
bindings
.iter()
.find_map(|(channel, (addr, _))| if addr == peer { Some(*channel) } else { None })
}
/// Forward payload to a TURN peer via the relay socket.
pub async fn send_to_peer(&self, peer: SocketAddr, data: &[u8]) -> anyhow::Result<usize> {
let sent = self._socket.send_to(data, peer).await?;
Ok(sent)
}
/// Remaining lifetime for this allocation (saturates at zero).
pub fn remaining_lifetime(&self) -> Duration {
let expiry = self.expiry.lock().unwrap();
let now = Instant::now();
expiry.saturating_duration_since(now)
}
pub fn is_expired(&self) -> bool {
self.remaining_lifetime().is_zero()
}
}
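// Lifetimes mirror the RFC 8656 defaults (5-minute permissions, 10-minute default allocation
// lifetime); the min/max clamp below is local policy rather than part of the spec.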
const PERMISSION_LIFETIME: Duration = Duration::from_secs(300);
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
const MIN_ALLOCATION_LIFETIME: Duration = Duration::from_secs(60);
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600);
fn prune_permissions(perms: &mut HashMap<SocketAddr, Instant>) {
let now = Instant::now();
perms.retain(|_, expiry| *expiry > now);
}
fn prune_channel_bindings(bindings: &mut HashMap<u16, (SocketAddr, Instant)>) {
let now = Instant::now();
bindings.retain(|_, (_, expiry)| *expiry > now);
}
fn prune_expired_locked(map: &mut HashMap<SocketAddr, Allocation>) {
let now = Instant::now();
map.retain(|_, alloc| {
let expired = *alloc.expiry.lock().unwrap() <= now;
if expired {
// Wake the relay task so its socket is released promptly instead of lingering
// until the next inbound packet.
alloc.stop.notify_one();
}
!expired
});
}
fn clamp_lifetime(requested: Duration) -> Duration {
if requested < MIN_ALLOCATION_LIFETIME {
MIN_ALLOCATION_LIFETIME
} else if requested > MAX_ALLOCATION_LIFETIME {
MAX_ALLOCATION_LIFETIME
} else {
requested
}
}
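// ---------------------------------------------------------------------------
// Minimal test sketch (not exhaustive): exercises the pure helpers plus the basic
// allocate/remove path. Assumes tokio's `macros` and `rt` features, which the
// select!/spawn usage above already requires; the addresses below are illustrative.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn clamp_lifetime_bounds_requested_duration() {
assert_eq!(clamp_lifetime(Duration::from_secs(1)), MIN_ALLOCATION_LIFETIME);
assert_eq!(clamp_lifetime(Duration::from_secs(7200)), MAX_ALLOCATION_LIFETIME);
assert_eq!(clamp_lifetime(Duration::from_secs(600)), Duration::from_secs(600));
}
#[test]
fn relay_addr_for_response_substitutes_advertised_ip() {
let opts = AllocationOptions {
advertised_ip: Some(IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, 7))),
..AllocationOptions::default()
};
let mgr = AllocationManager::new_with_options(opts);
let local: SocketAddr = "0.0.0.0:50000".parse().unwrap();
let expected: SocketAddr = "203.0.113.7:50000".parse().unwrap();
assert_eq!(mgr.relay_addr_for_response(local), expected);
}
#[tokio::test]
async fn allocate_and_remove_roundtrip() {
let mgr = AllocationManager::new();
let (tx, _rx) = mpsc::channel(8);
let client: SocketAddr = "127.0.0.1:40000".parse().unwrap();
let relay = mgr
.allocate_for(client, ClientSink::Stream { tx })
.await
.expect("allocation should succeed");
assert_ne!(relay.port(), 0);
assert!(mgr.get_allocation(&client).is_some());
assert!(mgr.remove_allocation(&client));
assert!(mgr.get_allocation(&client).is_none());
}
}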