485 lines
18 KiB
Rust
485 lines
18 KiB
Rust
//! Allocation manager: provisions relay sockets and forwards packets for TURN allocations.
|
|
//! Also tracks per-allocation permissions and channel bindings, allocation lifetimes (refresh and expiry), and allocation/permission/channel-binding quotas. Backlog: rate limiting.
|
|
use std::collections::HashMap;
|
|
use std::net::IpAddr;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{Arc, Mutex};
|
|
use std::time::{Duration, Instant};
|
|
use tokio::net::UdpSocket;
|
|
use tokio::sync::Notify;
|
|
use tokio::sync::mpsc;
|
|
use tracing::info;
|
|
|
|
use crate::stun::{build_channel_data, build_data_indication};
|
|
|
|
#[derive(thiserror::Error, Debug)]
|
|
pub enum AllocationError {
|
|
#[error("allocation quota exceeded")]
|
|
AllocationQuotaExceeded,
|
|
#[error("permission quota exceeded")]
|
|
PermissionQuotaExceeded,
|
|
#[error("channel binding quota exceeded")]
|
|
ChannelQuotaExceeded,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub enum ClientSink {
|
|
Udp { sock: Arc<UdpSocket>, addr: SocketAddr },
|
|
Stream { tx: mpsc::Sender<Vec<u8>> },
|
|
}
|
|
|
|
impl ClientSink {
|
|
pub async fn send(&self, data: Vec<u8>) -> anyhow::Result<()> {
|
|
match self {
|
|
ClientSink::Udp { sock, addr } => {
|
|
sock.send_to(&data, addr).await?;
|
|
Ok(())
|
|
}
|
|
ClientSink::Stream { tx } => {
|
|
tx.send(data)
|
|
.await
|
|
.map_err(|_| anyhow::anyhow!("client stream closed"))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Allocation {
|
|
pub client: SocketAddr,
|
|
pub relay_addr: SocketAddr,
|
|
// keep the socket so it stays bound
|
|
_socket: Arc<UdpSocket>,
|
|
stop: Arc<Notify>,
|
|
permissions: Arc<Mutex<HashMap<SocketAddr, Instant>>>,
|
|
channel_bindings: Arc<Mutex<HashMap<u16, (SocketAddr, Instant)>>>,
|
|
expiry: Arc<Mutex<Instant>>,
|
|
}
|
|
|
|
#[derive(Clone, Default)]
|
|
pub struct AllocationManager {
|
|
inner: Arc<Mutex<HashMap<SocketAddr, Allocation>>>,
|
|
opts: AllocationOptions,
|
|
}
|
|
|
|
/// Configuration for relay socket binding, address advertisement and quotas.
#[derive(Debug, Clone)]
pub struct AllocationOptions {
    /// Local IP that relay sockets are bound to.
    pub relay_bind_ip: IpAddr,
    /// Lowest relay port to try. The range is used only when both bounds are set,
    /// `min > 0`, and `max >= min`; otherwise the OS picks the port.
    pub relay_port_min: Option<u16>,
    /// Highest relay port to try (inclusive).
    pub relay_port_max: Option<u16>,
    /// IP to report to clients in place of the bound IP (the port is kept).
    pub advertised_ip: Option<IpAddr>,

    /// Maximum concurrent allocations per client IP; `None` means unlimited.
    pub max_allocations_per_ip: Option<u32>,
    /// Maximum live permissions per allocation; `None` means unlimited.
    pub max_permissions_per_allocation: Option<u32>,
    /// Maximum live channel bindings per allocation; `None` means unlimited.
    pub max_channel_bindings_per_allocation: Option<u32>,
}
|
|
|
|
impl Default for AllocationOptions {
|
|
fn default() -> Self {
|
|
Self {
|
|
relay_bind_ip: IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
|
|
relay_port_min: None,
|
|
relay_port_max: None,
|
|
advertised_ip: None,
|
|
max_allocations_per_ip: None,
|
|
max_permissions_per_allocation: None,
|
|
max_channel_bindings_per_allocation: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AllocationManager {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
inner: Arc::new(Mutex::new(HashMap::new())),
|
|
opts: AllocationOptions::default(),
|
|
}
|
|
}
|
|
|
|
pub fn new_with_options(opts: AllocationOptions) -> Self {
|
|
Self {
|
|
inner: Arc::new(Mutex::new(HashMap::new())),
|
|
opts,
|
|
}
|
|
}
|
|
|
|
/// Translate a locally-bound relay socket address into the address that should be
|
|
/// advertised to clients (e.g. replace 0.0.0.0 with a public IP).
|
|
pub fn relay_addr_for_response(&self, relay_local: SocketAddr) -> SocketAddr {
|
|
match self.opts.advertised_ip {
|
|
Some(ip) => SocketAddr::new(ip, relay_local.port()),
|
|
None => relay_local,
|
|
}
|
|
}
|
|
|
|
/// Create a relay UDP socket for the given client and spawn a relay loop that forwards
|
|
/// any packets received on the relay socket back to the client via the provided client sink.
|
|
pub async fn allocate_for(
|
|
&self,
|
|
client: SocketAddr,
|
|
client_sink: ClientSink,
|
|
) -> anyhow::Result<SocketAddr> {
|
|
// If an allocation already exists for this exact 5-tuple, reuse it.
|
|
{
|
|
let mut guard = self.inner.lock().unwrap();
|
|
prune_expired_locked(&mut guard);
|
|
if let Some(existing) = guard.get(&client) {
|
|
return Ok(existing.relay_addr);
|
|
}
|
|
|
|
if let Some(max) = self.opts.max_allocations_per_ip {
|
|
let count_for_ip = guard
|
|
.values()
|
|
.filter(|a| a.client.ip() == client.ip())
|
|
.count() as u32;
|
|
if count_for_ip >= max {
|
|
return Err(anyhow::anyhow!(AllocationError::AllocationQuotaExceeded));
|
|
}
|
|
}
|
|
}
|
|
|
|
// bind relay socket (optionally within configured port range)
|
|
let relay = match (self.opts.relay_port_min, self.opts.relay_port_max) {
|
|
(Some(min), Some(max)) if min > 0 && max >= min => {
|
|
let mut bound: Option<UdpSocket> = None;
|
|
let mut last_err: Option<anyhow::Error> = None;
|
|
for port in min..=max {
|
|
let addr = SocketAddr::new(self.opts.relay_bind_ip, port);
|
|
match UdpSocket::bind(addr).await {
|
|
Ok(sock) => {
|
|
bound = Some(sock);
|
|
last_err = None;
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
last_err = Some(anyhow::anyhow!(e));
|
|
}
|
|
}
|
|
}
|
|
|
|
match bound {
|
|
Some(sock) => sock,
|
|
None => {
|
|
let detail = last_err
|
|
.map(|e| format!("{e:?}"))
|
|
.unwrap_or_else(|| "no ports in range available".to_string());
|
|
return Err(anyhow::anyhow!(
|
|
"failed to bind relay socket in range {}-{} on {}: {}",
|
|
min,
|
|
max,
|
|
self.opts.relay_bind_ip,
|
|
detail
|
|
));
|
|
}
|
|
}
|
|
}
|
|
_ => {
|
|
let addr = SocketAddr::new(self.opts.relay_bind_ip, 0);
|
|
UdpSocket::bind(addr).await?
|
|
}
|
|
};
|
|
let relay_local = relay.local_addr()?;
|
|
let relay_arc = Arc::new(relay);
|
|
|
|
// Insert allocation before spawning relay loop to avoid races.
|
|
let stop = Arc::new(Notify::new());
|
|
let alloc = Allocation {
|
|
client,
|
|
relay_addr: relay_local,
|
|
_socket: relay_arc.clone(),
|
|
stop: stop.clone(),
|
|
permissions: Arc::new(Mutex::new(HashMap::new())),
|
|
channel_bindings: Arc::new(Mutex::new(HashMap::new())),
|
|
expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)),
|
|
};
|
|
{
|
|
let mut m = self.inner.lock().unwrap();
|
|
prune_expired_locked(&mut m);
|
|
m.insert(client, alloc);
|
|
}
|
|
|
|
// spawn relay loop
|
|
let relay_clone = relay_arc.clone();
|
|
let sink_clone = client_sink.clone();
|
|
let client_clone = client;
|
|
let manager_clone = self.clone();
|
|
let stop_clone = stop.clone();
|
|
tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 2048];
|
|
loop {
|
|
tokio::select! {
|
|
_ = stop_clone.notified() => {
|
|
break;
|
|
}
|
|
res = relay_clone.recv_from(&mut buf) => match res {
|
|
Ok((len, src)) => {
|
|
info!(
|
|
"relay got {} bytes from {} for client {}",
|
|
len, src, client_clone
|
|
);
|
|
if let Some(allocation) = manager_clone.get_allocation(&client_clone) {
|
|
if !allocation.is_peer_allowed(&src) {
|
|
tracing::debug!(
|
|
"dropping peer packet {} -> {} (permission expired)",
|
|
src,
|
|
client_clone
|
|
);
|
|
continue;
|
|
}
|
|
|
|
if let Some(channel) = allocation.channel_for_peer(&src) {
|
|
let frame = build_channel_data(channel, &buf[..len]);
|
|
if let Err(e) = sink_clone.send(frame).await {
|
|
tracing::error!(
|
|
"failed to send channel data {} -> {}: {:?}",
|
|
src,
|
|
client_clone,
|
|
e
|
|
);
|
|
if matches!(sink_clone, ClientSink::Stream { .. }) {
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
let indication = build_data_indication(&src, &buf[..len]);
|
|
if let Err(e) = sink_clone.send(indication).await {
|
|
tracing::error!(
|
|
"failed to send data indication {} -> {}: {:?}",
|
|
src,
|
|
client_clone,
|
|
e
|
|
);
|
|
if matches!(sink_clone, ClientSink::Stream { .. }) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
tracing::debug!(
|
|
"allocation missing while forwarding from peer {} -> {}",
|
|
src,
|
|
client_clone
|
|
);
|
|
// Allocation removed/expired: stop the relay task.
|
|
break;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("relay socket error: {:?}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
tracing::info!("created allocation for {} -> {}", client, relay_local);
|
|
Ok(relay_local)
|
|
}
|
|
|
|
pub fn get_allocation(&self, client: &SocketAddr) -> Option<Allocation> {
|
|
let mut m = self.inner.lock().unwrap();
|
|
prune_expired_locked(&mut m);
|
|
m.get(client).cloned()
|
|
}
|
|
|
|
/// Register a permission for the given client allocation so the relay can forward packets
|
|
/// to the specified peer address. Permissions currently expire after a static timeout.
|
|
pub fn add_permission(&self, client: SocketAddr, peer: SocketAddr) -> anyhow::Result<()> {
|
|
let mut guard = self.inner.lock().unwrap();
|
|
prune_expired_locked(&mut guard);
|
|
let alloc = guard
|
|
.get_mut(&client)
|
|
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
|
|
let mut perms = alloc.permissions.lock().unwrap();
|
|
prune_permissions(&mut perms);
|
|
|
|
if let Some(max) = self.opts.max_permissions_per_allocation {
|
|
let max = max as usize;
|
|
if !perms.contains_key(&peer) && perms.len() >= max {
|
|
return Err(anyhow::anyhow!(AllocationError::PermissionQuotaExceeded));
|
|
}
|
|
}
|
|
perms.insert(peer, Instant::now() + PERMISSION_LIFETIME);
|
|
Ok(())
|
|
}
|
|
|
|
/// Associate a TURN channel number with a specific peer socket for the allocation.
|
|
pub fn add_channel_binding(
|
|
&self,
|
|
client: SocketAddr,
|
|
channel: u16,
|
|
peer: SocketAddr,
|
|
) -> anyhow::Result<()> {
|
|
let mut guard = self.inner.lock().unwrap();
|
|
prune_expired_locked(&mut guard);
|
|
let alloc = guard
|
|
.get_mut(&client)
|
|
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
|
|
let mut bindings = alloc.channel_bindings.lock().unwrap();
|
|
prune_channel_bindings(&mut bindings);
|
|
|
|
if let Some(max) = self.opts.max_channel_bindings_per_allocation {
|
|
let max = max as usize;
|
|
if !bindings.contains_key(&channel) && bindings.len() >= max {
|
|
return Err(anyhow::anyhow!(AllocationError::ChannelQuotaExceeded));
|
|
}
|
|
}
|
|
bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME));
|
|
Ok(())
|
|
}
|
|
|
|
/// Refresh allocation lifetime or delete it when zero requested. Returns the applied lifetime.
|
|
pub fn refresh_allocation(
|
|
&self,
|
|
client: SocketAddr,
|
|
requested: Option<Duration>,
|
|
) -> anyhow::Result<Duration> {
|
|
let mut guard = self.inner.lock().unwrap();
|
|
prune_expired_locked(&mut guard);
|
|
let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME);
|
|
if let Some(d) = requested {
|
|
if d.is_zero() {
|
|
if let Some(alloc) = guard.remove(&client) {
|
|
alloc.stop.notify_waiters();
|
|
}
|
|
return Ok(Duration::from_secs(0));
|
|
}
|
|
}
|
|
|
|
let alloc = guard
|
|
.get(&client)
|
|
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
|
|
let mut expiry = alloc.expiry.lock().unwrap();
|
|
let now = Instant::now();
|
|
let clamped = clamp_lifetime(req);
|
|
*expiry = now + clamped;
|
|
Ok(clamped)
|
|
}
|
|
|
|
/// Remove allocation explicitly (e.g. on zero lifetime). Returns true if removed.
|
|
pub fn remove_allocation(&self, client: &SocketAddr) -> bool {
|
|
let mut guard = self.inner.lock().unwrap();
|
|
if let Some(alloc) = guard.remove(client) {
|
|
alloc.stop.notify_waiters();
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Spawn a background housekeeping task that periodically prunes expired allocations.
|
|
/// This avoids keeping relay tasks/sockets alive indefinitely when allocations expire.
|
|
pub fn spawn_housekeeping(&self, interval: Duration) {
|
|
let mgr = self.clone();
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::time::sleep(interval).await;
|
|
mgr.prune_expired();
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Remove expired allocations and notify their relay tasks to stop.
|
|
pub fn prune_expired(&self) {
|
|
let mut guard = self.inner.lock().unwrap();
|
|
let now = Instant::now();
|
|
let expired: Vec<SocketAddr> = guard
|
|
.iter()
|
|
.filter_map(|(k, alloc)| {
|
|
let expiry = alloc.expiry.lock().unwrap();
|
|
if *expiry <= now { Some(*k) } else { None }
|
|
})
|
|
.collect();
|
|
for client in expired {
|
|
if let Some(alloc) = guard.remove(&client) {
|
|
alloc.stop.notify_waiters();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Allocation {
|
|
/// Check whether a peer address is currently permitted for this allocation.
|
|
pub fn is_peer_allowed(&self, peer: &SocketAddr) -> bool {
|
|
if self.is_expired() {
|
|
return false;
|
|
}
|
|
let mut perms = self.permissions.lock().unwrap();
|
|
prune_permissions(&mut perms);
|
|
perms.contains_key(peer)
|
|
}
|
|
|
|
/// Resolve an active channel binding to its peer socket, if still valid.
|
|
pub fn channel_peer(&self, channel: u16) -> Option<SocketAddr> {
|
|
if self.is_expired() {
|
|
return None;
|
|
}
|
|
let mut bindings = self.channel_bindings.lock().unwrap();
|
|
prune_channel_bindings(&mut bindings);
|
|
bindings.get(&channel).map(|(peer, _)| *peer)
|
|
}
|
|
|
|
/// Return the bound channel number for a peer if available.
|
|
pub fn channel_for_peer(&self, peer: &SocketAddr) -> Option<u16> {
|
|
if self.is_expired() {
|
|
return None;
|
|
}
|
|
let mut bindings = self.channel_bindings.lock().unwrap();
|
|
prune_channel_bindings(&mut bindings);
|
|
bindings
|
|
.iter()
|
|
.find_map(|(channel, (addr, _))| if addr == peer { Some(*channel) } else { None })
|
|
}
|
|
|
|
/// Forward payload to a TURN peer via the relay socket.
|
|
pub async fn send_to_peer(&self, peer: SocketAddr, data: &[u8]) -> anyhow::Result<usize> {
|
|
let sent = self._socket.send_to(data, peer).await?;
|
|
Ok(sent)
|
|
}
|
|
|
|
/// Remaining lifetime for this allocation (saturates at zero).
|
|
pub fn remaining_lifetime(&self) -> Duration {
|
|
let expiry = self.expiry.lock().unwrap();
|
|
let now = Instant::now();
|
|
expiry.saturating_duration_since(now)
|
|
}
|
|
|
|
pub fn is_expired(&self) -> bool {
|
|
self.remaining_lifetime().is_zero()
|
|
}
|
|
}
|
|
|
|
/// Lifetime of a peer permission: 5 minutes.
const PERMISSION_LIFETIME: Duration = Duration::from_secs(5 * 60);

/// Lifetime of a new allocation, and of a refresh that requests no specific value: 10 minutes.
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Shortest lifetime a refresh will grant: 1 minute.
const MIN_ALLOCATION_LIFETIME: Duration = Duration::from_secs(60);

/// Longest lifetime a refresh will grant: 1 hour.
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
|
|
|
/// Discard every permission whose expiry instant is not strictly in the future.
fn prune_permissions(perms: &mut HashMap<SocketAddr, Instant>) {
    let cutoff = Instant::now();
    perms.retain(|_peer, deadline| cutoff < *deadline);
}
|
|
|
|
/// Discard every channel binding whose expiry instant is not strictly in the future.
fn prune_channel_bindings(bindings: &mut HashMap<u16, (SocketAddr, Instant)>) {
    let cutoff = Instant::now();
    bindings.retain(|_channel, &mut (_, deadline)| deadline > cutoff);
}
|
|
|
|
fn prune_expired_locked(map: &mut HashMap<SocketAddr, Allocation>) {
|
|
let now = Instant::now();
|
|
map.retain(|_, alloc| {
|
|
let expiry = alloc.expiry.lock().unwrap();
|
|
*expiry > now
|
|
});
|
|
}
|
|
|
|
fn clamp_lifetime(requested: Duration) -> Duration {
|
|
if requested < MIN_ALLOCATION_LIFETIME {
|
|
MIN_ALLOCATION_LIFETIME
|
|
} else if requested > MAX_ALLOCATION_LIFETIME {
|
|
MAX_ALLOCATION_LIFETIME
|
|
} else {
|
|
requested
|
|
}
|
|
}
|