Refactor: project structure and logging. Add: Integration tests for happy flow of UDP and TLS.

This commit is contained in:
ghost 2025-11-24 16:56:54 +01:00
parent fe0b4559d0
commit 15dfec8695
14 changed files with 1814 additions and 1331 deletions

39
Cargo.lock generated
View File

@ -17,6 +17,15 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "aho-corasick"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
dependencies = [
"memchr",
]
[[package]]
name = "anyhow"
version = "1.0.100"
@ -273,6 +282,15 @@ version = "0.4.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]]
name = "md5"
version = "0.7.0"
@ -449,6 +467,23 @@ dependencies = [
"bitflags",
]
[[package]]
name = "regex-automata"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
[[package]]
name = "ring"
version = "0.16.20"
@ -817,10 +852,14 @@ version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]

View File

@ -17,7 +17,7 @@ hex = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt"] }
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
# TLS for turns (server)
tokio-rustls = "0.23"

View File

@ -1,138 +0,0 @@
use bytes::BytesMut;
use niom_turn::constants::*;
// use niom_turn::stun; // not needed; use specific functions via path when required
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::UdpSocket;
// Use shared decoder from library: niom_turn::stun::decode_xor_relayed_address
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
let server: SocketAddr = "127.0.0.1:3478".parse()?;
let local = UdpSocket::bind("0.0.0.0:0").await?;
let username = "testuser";
let password = "secretpassword";
// Build Allocate request (method METHOD_ALLOCATE)
let mut buf = BytesMut::new();
buf.extend_from_slice(&METHOD_ALLOCATE.to_be_bytes()); // Allocate Request
buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
let trans = [13u8; 12];
buf.extend_from_slice(&trans);
// USERNAME
let uname = username.as_bytes();
buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes());
buf.extend_from_slice(&(uname.len() as u16).to_be_bytes());
buf.extend_from_slice(uname);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0u8]);
}
// MESSAGE-INTEGRITY placeholder
let mi_attr_offset = buf.len();
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes()));
let mi_val_pos = buf.len();
buf.extend_from_slice(&[0u8; 20]);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0u8]);
}
// fix length
let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
// compute HMAC over bytes up to MI attribute header
{
use hmac::{Hmac, Mac};
use sha1::Sha1;
type HmacSha1 = Hmac<Sha1>;
let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key");
mac.update(&buf[..mi_attr_offset]);
let res = mac.finalize().into_bytes();
for i in 0..20 {
buf[mi_val_pos + i] = res[i];
}
}
// send Allocate
local.send_to(&buf, server).await?;
// receive response
let mut r = vec![0u8; 1500];
let (len, _addr) = local.recv_from(&mut r).await?;
println!("got {} bytes", len);
let resp = &r[..len];
// expect success (METHOD_ALLOCATE | CLASS_SUCCESS) with XOR-RELAYED-ADDRESS attr
if resp.len() < 20 {
anyhow::bail!("response too short");
}
let typ = u16::from_be_bytes([resp[0], resp[1]]);
println!("resp type 0x{:04x}", typ);
let expected_type = METHOD_ALLOCATE | CLASS_SUCCESS;
if typ != expected_type {
anyhow::bail!("expected success response, got 0x{:04x}", typ);
}
// parse attributes
let length = u16::from_be_bytes([resp[2], resp[3]]) as usize;
let total = 20 + length;
let mut offset = 20;
let mut relay_addr_opt: Option<SocketAddr> = None;
while offset + 4 <= total {
let atype = u16::from_be_bytes([resp[offset], resp[offset + 1]]);
let alen = u16::from_be_bytes([resp[offset + 2], resp[offset + 3]]) as usize;
offset += 4;
if offset + alen > total {
break;
}
println!("attr type=0x{:04x} len={}", atype, alen);
println!("raw: {}", hex::encode(&resp[offset..offset + alen]));
if atype == ATTR_XOR_RELAYED_ADDRESS {
// XOR-RELAYED-ADDRESS: decode via shared library function
if let Some(sa) =
niom_turn::stun::decode_xor_relayed_address(&resp[offset..offset + alen], &trans)
{
relay_addr_opt = Some(sa);
}
}
offset += alen;
let pad = (4 - (alen % 4)) % 4;
offset += pad;
}
let relay_addr = match relay_addr_opt {
Some(a) => a,
None => anyhow::bail!("no relay address in response"),
};
println!("got relayed addr: {}", relay_addr);
// send test payload to relay addr
let payload = b"hello-relay";
local.send_to(payload, relay_addr).await?;
// wait for forwarded packet (should arrive via server socket) using tokio timeout
let mut buf2 = vec![0u8; 1500];
match tokio::time::timeout(Duration::from_secs(2), local.recv_from(&mut buf2)).await {
Ok(Ok((l, src))) => {
println!("received {} bytes from {}", l, src);
let got = &buf2[..l];
println!("payload: {:?}", got);
if got == payload {
println!("relay test success");
Ok(())
} else {
anyhow::bail!("payload mismatch")
}
}
Ok(Err(e)) => anyhow::bail!("recv error: {:?}", e),
Err(_) => anyhow::bail!("no forwarded packet received: timeout"),
}
}

View File

@ -1,70 +0,0 @@
use bytes::BytesMut;
use niom_turn::constants::*;
use std::net::SocketAddr;
use tokio::net::UdpSocket;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
let server: SocketAddr = "127.0.0.1:3478".parse()?;
let local = UdpSocket::bind("0.0.0.0:0").await?;
// Build a minimal STUN Binding Request with USERNAME and placeholder MESSAGE-INTEGRITY
let username = "testuser";
let password = "secretpassword"; // matches server's in-memory creds
let mut buf = BytesMut::new();
buf.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request
buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
let trans = [7u8; 12];
buf.extend_from_slice(&trans);
// USERNAME
let uname = username.as_bytes();
buf.extend_from_slice(&ATTR_USERNAME.to_be_bytes());
buf.extend_from_slice(&(uname.len() as u16).to_be_bytes());
buf.extend_from_slice(uname);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0u8]);
}
// MESSAGE-INTEGRITY placeholder
let mi_attr_offset = buf.len();
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes()));
let mi_val_pos = buf.len();
buf.extend_from_slice(&[0u8; 20]);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0u8]);
}
// fix length
let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
// compute HMAC over bytes up to MI attribute header
{
use hmac::{Hmac, Mac};
use sha1::Sha1;
type HmacSha1 = Hmac<Sha1>;
let mut mac = HmacSha1::new_from_slice(password.as_bytes()).expect("HMAC key");
mac.update(&buf[..mi_attr_offset]);
let res = mac.finalize().into_bytes();
for i in 0..20 {
buf[mi_val_pos + i] = res[i];
}
}
// send
local.send_to(&buf, server).await?;
let mut r = vec![0u8; 1500];
let (len, addr) = local.recv_from(&mut r).await?;
println!("got {} bytes from {}", len, addr);
// dump hex
println!("{:02x?}", &r[..len]);
Ok(())
}

View File

@ -3,11 +3,15 @@ pub mod alloc;
pub mod auth;
pub mod config;
pub mod constants;
pub mod logging;
pub mod models;
pub mod server;
pub mod stun;
pub mod tls;
pub mod traits;
pub use crate::alloc::*;
pub use crate::auth::*;
pub use crate::logging::*;
pub use crate::server::*;
pub use crate::stun::*;

28
src/logging.rs Normal file
View File

@ -0,0 +1,28 @@
//! Logging helpers shared between the library, binaries, and integration tests.
//! Ensures we configure tracing once with sensible default filters while still
//! allowing `RUST_LOG` to override verbosity.
use std::sync::Once;
use tracing_subscriber::{fmt, EnvFilter};
static INIT: Once = Once::new();
/// Initialise tracing with a sane default filter (`warn` globally,
/// `niom_turn=info`) unless `RUST_LOG` is provided.
pub fn init_tracing() {
init_tracing_with_default("warn,niom_turn=info");
}
/// Initialise tracing with a custom default directive that can still be
/// overridden via `RUST_LOG` at runtime.
pub fn init_tracing_with_default(default_directive: &str) {
INIT.call_once(|| {
let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(default_directive));
fmt()
.with_env_filter(env_filter)
.with_target(false)
.compact()
.init();
});
}

View File

@ -2,24 +2,18 @@
//! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tracing::{error, info};
// Use the library crate's public modules instead of local `mod` declarations.
use niom_turn::alloc::AllocationManager;
use niom_turn::auth::{AuthManager, AuthStatus, InMemoryStore};
use niom_turn::auth::{AuthManager, InMemoryStore};
use niom_turn::config::{AuthOptions, Config};
use niom_turn::constants::*;
use niom_turn::stun::{
build_401_response, build_allocate_success, build_error_response, build_lifetime_success,
build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data,
parse_message,
};
use niom_turn::server::udp_reader_loop;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
niom_turn::logging::init_tracing();
info!("niom-turn starting");
@ -64,7 +58,7 @@ async fn main() -> anyhow::Result<()> {
let udp = UdpSocket::bind(bind_addr).await?;
let udp = Arc::new(udp);
// allocation manager
// Allocation manager shared by UDP + TLS frontends.
let alloc_mgr = AllocationManager::new();
// Spawn the asynchronous packet loop that handles all UDP requests.
@ -103,448 +97,3 @@ async fn main() -> anyhow::Result<()> {
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
}
}
async fn udp_reader_loop(
udp: Arc<UdpSocket>,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> anyhow::Result<()> {
let mut buf = vec![0u8; 1500];
loop {
// Read the next datagram and keep peer metadata for follow-up responses.
let (len, peer) = udp.recv_from(&mut buf).await?;
tracing::debug!("got {} bytes from {}", len, peer);
if let Some((channel, payload)) = parse_channel_data(&buf[..len]) {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
tracing::warn!("channel data without allocation from {}", peer);
continue;
}
};
let target = match allocation.channel_peer(channel) {
Some(addr) => addr,
None => {
tracing::warn!(
"channel data with unknown channel 0x{:04x} from {}",
channel,
peer
);
continue;
}
};
if !allocation.is_peer_allowed(&target) {
tracing::warn!(
"channel data target {} no longer permitted for {}",
target,
peer
);
continue;
}
match allocation.send_to_peer(target, payload).await {
Ok(sent) => tracing::debug!(
"forwarded {} bytes via channel 0x{:04x} from {} to {}",
sent,
channel,
peer,
target
),
Err(e) => tracing::error!(
"failed to forward channel data 0x{:04x} from {} to {}: {:?}",
channel,
peer,
target,
e
),
}
continue;
}
// Minimal STUN/TURN detection: parse STUN messages and send 401 challenge
if let Ok(msg) = parse_message(&buf[..len]) {
tracing::info!(
"STUN/TURN message from {} type=0x{:04x} len={}",
peer,
msg.header.msg_type,
len
);
let requires_auth = matches!(
msg.header.msg_type,
METHOD_ALLOCATE
| METHOD_CREATE_PERMISSION
| METHOD_CHANNEL_BIND
| METHOD_SEND
| METHOD_REFRESH
);
if requires_auth {
match auth.authenticate(&msg, &peer).await {
AuthStatus::Granted { username } => {
tracing::debug!(
"TURN auth ok for {} as {} (0x{:04x})",
peer,
username,
msg.header.msg_type
);
}
AuthStatus::Challenge { nonce } => {
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
401,
"Unauthorized",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
AuthStatus::StaleNonce { nonce } => {
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
438,
"Stale Nonce",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
AuthStatus::Reject { code, reason } => {
let resp = build_error_response(&msg.header, code, reason);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
}
match msg.header.msg_type {
METHOD_ALLOCATE => {
let requested_lifetime = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64))
.filter(|d| !d.is_zero());
match allocs.allocate_for(peer, udp.clone()).await {
Ok(relay_addr) => {
let applied =
match allocs.refresh_allocation(peer, requested_lifetime) {
Ok(d) => d,
Err(e) => {
tracing::error!(
"failed to apply lifetime for {}: {:?}",
peer,
e
);
let resp = build_error_response(
&msg.header,
500,
"Allocate Failed",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32;
let resp =
build_allocate_success(&msg.header, &relay_addr, lifetime_secs);
tracing::info!(
"allocated relay {} for {} lifetime={}s",
relay_addr,
peer,
lifetime_secs
);
let _ = udp.send_to(&resp, &peer).await;
}
Err(e) => {
tracing::error!("allocate failed: {:?}", e);
let resp =
build_error_response(&msg.header, 500, "Allocate Failed");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
METHOD_CREATE_PERMISSION => {
if allocs.get_allocation(&peer).is_none() {
tracing::warn!("create-permission without allocation from {}", peer);
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
let mut added = 0usize;
for attr in msg
.attributes
.iter()
.filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS)
{
if let Some(peer_addr) =
decode_xor_peer_address(&attr.value, &msg.header.transaction_id)
{
match allocs.add_permission(peer, peer_addr) {
Ok(()) => {
tracing::info!(
"added permission for {} -> {}",
peer,
peer_addr
);
added += 1;
}
Err(e) => {
tracing::error!(
"failed to persist permission {} -> {}: {:?}",
peer,
peer_addr,
e
);
}
}
} else {
tracing::warn!("invalid XOR-PEER-ADDRESS in request from {}", peer);
}
}
if added == 0 {
let resp =
build_error_response(&msg.header, 400, "No valid XOR-PEER-ADDRESS");
let _ = udp.send_to(&resp, &peer).await;
} else {
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
continue;
}
METHOD_CHANNEL_BIND => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
tracing::warn!("channel-bind without allocation from {}", peer);
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let channel_attr =
msg.attributes.iter().find(|a| a.typ == ATTR_CHANNEL_NUMBER);
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let channel = match channel_attr.and_then(|attr| {
if attr.value.len() >= 4 {
Some(u16::from_be_bytes([attr.value[0], attr.value[1]]))
} else {
None
}
}) {
Some(c) => c,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing CHANNEL-NUMBER",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
if channel < 0x4000 || channel > 0x7FFF {
let resp =
build_error_response(&msg.header, 400, "Channel Out Of Range");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
let peer_addr = match peer_attr.and_then(|attr| {
decode_xor_peer_address(&attr.value, &msg.header.transaction_id)
}) {
Some(addr) => addr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing XOR-PEER-ADDRESS",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(&msg.header, 403, "Peer Not Permitted");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
match allocs.add_channel_binding(peer, channel, peer_addr) {
Ok(()) => {
tracing::info!(
"bound channel 0x{:04x} for {} -> {}",
channel,
peer,
peer_addr
);
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
Err(e) => {
tracing::error!(
"failed to add channel binding {} -> {} (channel 0x{:04x}): {:?}",
peer,
peer_addr,
channel,
e
);
let resp = build_error_response(
&msg.header,
500,
"Channel Binding Failed",
);
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
METHOD_SEND => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
tracing::warn!("send without allocation from {}", peer);
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let data_attr = msg.attributes.iter().find(|a| a.typ == ATTR_DATA);
let peer_addr = match peer_attr.and_then(|attr| {
decode_xor_peer_address(&attr.value, &msg.header.transaction_id)
}) {
Some(addr) => addr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing XOR-PEER-ADDRESS",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let data_attr = match data_attr {
Some(attr) => attr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing DATA Attribute",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(&msg.header, 403, "Peer Not Permitted");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
match allocation.send_to_peer(peer_addr, &data_attr.value).await {
Ok(sent) => {
tracing::info!(
"forwarded {} bytes from {} to peer {}",
sent,
peer,
peer_addr
);
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
Err(e) => {
tracing::error!(
"failed to send payload from {} to {}: {:?}",
peer,
peer_addr,
e
);
let resp =
build_error_response(&msg.header, 500, "Peer Send Failed");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
METHOD_REFRESH => {
let requested = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64));
match allocs.refresh_allocation(peer, requested) {
Ok(applied) => {
if applied.is_zero() {
tracing::info!("allocation for {} released", peer);
} else {
tracing::debug!(
"allocation for {} refreshed to {}s",
peer,
applied.as_secs()
);
}
let resp = build_lifetime_success(
&msg.header,
applied.as_secs().min(u32::MAX as u64) as u32,
);
let _ = udp.send_to(&resp, &peer).await;
}
Err(_) => {
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
_ => {
let resp = build_error_response(&msg.header, 420, "Unknown TURN Method");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
}
}
match msg.header.msg_type {
METHOD_BINDING => {
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
_ => {
let nonce = auth.mint_nonce(&peer);
let resp =
build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized");
if let Err(e) = udp.send_to(&resp, &peer).await {
error!("failed to send 401: {:?}", e);
}
}
}
} else {
tracing::debug!("Non-STUN or parse error from {} len={}", peer, len);
}
}
}
// existing helper functions moved to stun.rs

429
src/server.rs Normal file
View File

@ -0,0 +1,429 @@
//! Shared server routines for UDP TURN handling so integration tests can reuse the core loop.
use std::sync::Arc;
use tokio::net::UdpSocket;
use tracing::{error, warn};
use crate::alloc::AllocationManager;
use crate::auth::{AuthManager, AuthStatus, InMemoryStore};
use crate::constants::*;
use crate::stun::{
build_401_response, build_allocate_success, build_error_response, build_lifetime_success,
build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data,
parse_message,
};
use std::time::Duration;
/// Main UDP reader loop shared between binary entry point and integration tests.
pub async fn udp_reader_loop(
udp: Arc<UdpSocket>,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> anyhow::Result<()> {
let mut buf = vec![0u8; 1500];
loop {
let (len, peer) = udp.recv_from(&mut buf).await?;
tracing::debug!("got {} bytes from {}", len, peer);
if let Some((channel, payload)) = parse_channel_data(&buf[..len]) {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
warn!("channel data without allocation from {}", peer);
continue;
}
};
let target = match allocation.channel_peer(channel) {
Some(addr) => addr,
None => {
warn!(
"channel data with unknown channel 0x{:04x} from {}",
channel, peer
);
continue;
}
};
if !allocation.is_peer_allowed(&target) {
warn!(
"channel data target {} no longer permitted for {}",
target, peer
);
continue;
}
match allocation.send_to_peer(target, payload).await {
Ok(sent) => tracing::debug!(
"forwarded {} bytes via channel 0x{:04x} from {} to {}",
sent,
channel,
peer,
target
),
Err(e) => error!(
"failed to forward channel data 0x{:04x} from {} to {}: {:?}",
channel, peer, target, e
),
}
continue;
}
if let Ok(msg) = parse_message(&buf[..len]) {
tracing::info!(
"STUN/TURN message from {} type=0x{:04x} len={}",
peer,
msg.header.msg_type,
len
);
let requires_auth = matches!(
msg.header.msg_type,
METHOD_ALLOCATE
| METHOD_CREATE_PERMISSION
| METHOD_CHANNEL_BIND
| METHOD_SEND
| METHOD_REFRESH
);
if requires_auth {
match auth.authenticate(&msg, &peer).await {
AuthStatus::Granted { username } => {
tracing::debug!(
"TURN auth ok for {} as {} (0x{:04x})",
peer,
username,
msg.header.msg_type
);
}
AuthStatus::Challenge { nonce } => {
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
401,
"Unauthorized",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
AuthStatus::StaleNonce { nonce } => {
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
438,
"Stale Nonce",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
AuthStatus::Reject { code, reason } => {
let resp = build_error_response(&msg.header, code, reason);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
}
match msg.header.msg_type {
METHOD_ALLOCATE => {
let requested_lifetime = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64))
.filter(|d| !d.is_zero());
match allocs.allocate_for(peer, udp.clone()).await {
Ok(relay_addr) => {
let applied =
match allocs.refresh_allocation(peer, requested_lifetime) {
Ok(d) => d,
Err(e) => {
tracing::error!(
"failed to apply lifetime for {}: {:?}",
peer,
e
);
let resp = build_error_response(
&msg.header,
500,
"Allocate Failed",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32;
let resp =
build_allocate_success(&msg.header, &relay_addr, lifetime_secs);
tracing::info!(
"allocated relay {} for {} lifetime={}s",
relay_addr,
peer,
lifetime_secs
);
let _ = udp.send_to(&resp, &peer).await;
}
Err(e) => {
tracing::error!("allocate failed: {:?}", e);
let resp =
build_error_response(&msg.header, 500, "Allocate Failed");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
METHOD_CREATE_PERMISSION => {
if allocs.get_allocation(&peer).is_none() {
warn!("create-permission without allocation from {}", peer);
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
let mut added = 0usize;
for attr in msg
.attributes
.iter()
.filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS)
{
if let Some(peer_addr) =
decode_xor_peer_address(&attr.value, &msg.header.transaction_id)
{
match allocs.add_permission(peer, peer_addr) {
Ok(()) => {
tracing::info!(
"added permission for {} -> {}",
peer,
peer_addr
);
added += 1;
}
Err(e) => {
tracing::error!(
"failed to persist permission {} -> {}: {:?}",
peer,
peer_addr,
e
);
}
}
} else {
tracing::warn!("invalid XOR-PEER-ADDRESS in request from {}", peer);
}
}
if added == 0 {
let resp =
build_error_response(&msg.header, 400, "No valid XOR-PEER-ADDRESS");
let _ = udp.send_to(&resp, &peer).await;
} else {
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
continue;
}
METHOD_CHANNEL_BIND => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
warn!("channel-bind without allocation from {}", peer);
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let channel_attr =
msg.attributes.iter().find(|a| a.typ == ATTR_CHANNEL_NUMBER);
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let (channel_attr, peer_attr) = match (channel_attr, peer_attr) {
(Some(c), Some(p)) => (c, p),
_ => {
let resp = build_error_response(
&msg.header,
400,
"Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let channel =
u16::from_be_bytes([channel_attr.value[0], channel_attr.value[1]]);
let peer_addr = match decode_xor_peer_address(
&peer_attr.value,
&msg.header.transaction_id,
) {
Some(addr) => addr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Invalid XOR-PEER-ADDRESS",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(&msg.header, 403, "Peer Not Permitted");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
if let Err(e) = allocs.add_channel_binding(peer, channel, peer_addr) {
tracing::error!(
"failed to persist channel binding {} -> {} (0x{:04x}): {:?}",
peer,
peer_addr,
channel,
e
);
let resp =
build_error_response(&msg.header, 500, "Channel Bind Failed");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
METHOD_SEND => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
warn!("send indication without allocation from {}", peer);
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let data_attr = msg.attributes.iter().find(|a| a.typ == ATTR_DATA);
let (peer_attr, data_attr) = match (peer_attr, data_attr) {
(Some(p), Some(d)) => (p, d),
_ => {
let resp = build_error_response(
&msg.header,
400,
"Missing DATA or XOR-PEER-ADDRESS",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
let peer_addr = match decode_xor_peer_address(
&peer_attr.value,
&msg.header.transaction_id,
) {
Some(addr) => addr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Invalid XOR-PEER-ADDRESS",
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(&msg.header, 403, "Peer Not Permitted");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
match allocation.send_to_peer(peer_addr, &data_attr.value).await {
Ok(sent) => {
tracing::info!(
"forwarded {} bytes from {} to {}",
sent,
peer,
peer_addr
);
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
Err(e) => {
tracing::error!(
"failed to send payload from {} to {}: {:?}",
peer,
peer_addr,
e
);
let resp =
build_error_response(&msg.header, 500, "Peer Send Failed");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
METHOD_REFRESH => {
let requested = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64));
match allocs.refresh_allocation(peer, requested) {
Ok(applied) => {
if applied.is_zero() {
tracing::info!("allocation for {} released", peer);
} else {
tracing::debug!(
"allocation for {} refreshed to {}s",
peer,
applied.as_secs()
);
}
let resp = build_lifetime_success(
&msg.header,
applied.as_secs().min(u32::MAX as u64) as u32,
);
let _ = udp.send_to(&resp, &peer).await;
}
Err(_) => {
let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch");
let _ = udp.send_to(&resp, &peer).await;
}
}
continue;
}
_ => {
let resp = build_error_response(&msg.header, 420, "Unknown TURN Method");
let _ = udp.send_to(&resp, &peer).await;
continue;
}
}
}
match msg.header.msg_type {
METHOD_BINDING => {
let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await;
}
_ => {
let nonce = auth.mint_nonce(&peer);
let resp =
build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized");
if let Err(e) = udp.send_to(&resp, &peer).await {
error!("failed to send 401: {:?}", e);
}
}
}
} else {
tracing::debug!("Non-STUN or parse error from {} len={}", peer, len);
}
}
}

1272
src/tls.rs

File diff suppressed because it is too large Load Diff

37
tests/support/mod.rs Normal file
View File

@ -0,0 +1,37 @@
pub mod stun_builders;
pub mod tls;
use std::net::SocketAddr;
use niom_turn::auth::{AuthManager, InMemoryStore};
use niom_turn::config::AuthOptions;
/// Ensure tracing is initialised for integration tests with the library defaults.
#[allow(dead_code)]
pub fn init_tracing() {
niom_turn::logging::init_tracing();
}
/// Initialise tracing with a custom default directive (still overridable via `RUST_LOG`).
#[allow(dead_code)]
pub fn init_tracing_with(default_directive: &str) {
niom_turn::logging::init_tracing_with_default(default_directive);
}
/// Helper to construct a basic AuthManager with a single credential for integration tests.
pub fn test_auth_manager(user: &str, password: &str) -> AuthManager<InMemoryStore> {
let store = InMemoryStore::new();
store.insert(user, password);
AuthManager::new(store, &AuthOptions::default())
}
/// Default TURN client credential used in integration scenarios.
pub fn default_test_credentials() -> (&'static str, &'static str) {
("testuser", "secretpassword")
}
/// Convenience to parse socket addresses in tests.
#[allow(dead_code)]
pub fn addr(addr: &str) -> SocketAddr {
addr.parse().expect("valid socket addr")
}

View File

@ -0,0 +1,210 @@
#![allow(dead_code)]
use bytes::BytesMut;
use niom_turn::constants::*;
use niom_turn::stun::{compute_message_integrity, parse_message};
use uuid::Uuid;
/// Build a basic STUN header for TURN requests with a freshly generated transaction id.
pub fn new_transaction_id() -> [u8; 12] {
let uuid = Uuid::new_v4();
let mut trans = [0u8; 12];
trans.copy_from_slice(&uuid.as_bytes()[..12]);
trans
}
/// Construct a TURN Allocate request optionally including lifetime and auth attributes.
pub fn build_allocate_request(
username: Option<&str>,
realm: Option<&str>,
nonce: Option<&str>,
key: Option<&[u8]>,
lifetime: Option<u32>,
) -> Vec<u8> {
build_authenticated_request(
METHOD_ALLOCATE,
username,
realm,
nonce,
key,
lifetime,
None,
None,
)
}
/// Construct a TURN Refresh request using MESSAGE-INTEGRITY.
pub fn build_refresh_request(
trans: [u8; 12],
username: &str,
realm: &str,
nonce: &str,
key: &[u8],
lifetime: u32,
) -> Vec<u8> {
build_request_with_body(
METHOD_REFRESH,
Some(username),
Some(realm),
Some(nonce),
Some(key),
Some(lifetime),
None,
None,
Some(trans),
)
}
/// Build a CreatePermission request for the specified peer address.
pub fn build_create_permission_request(
username: &str,
realm: &str,
nonce: &str,
key: &[u8],
peer: &std::net::SocketAddr,
) -> Vec<u8> {
build_request_with_body(
METHOD_CREATE_PERMISSION,
Some(username),
Some(realm),
Some(nonce),
Some(key),
None,
Some(peer),
None,
None,
)
}
/// Build a Send indication to forward payload to peer.
pub fn build_send_request(
username: &str,
realm: &str,
nonce: &str,
key: &[u8],
peer: &std::net::SocketAddr,
payload: &[u8],
) -> Vec<u8> {
build_request_with_body(
METHOD_SEND,
Some(username),
Some(realm),
Some(nonce),
Some(key),
None,
Some(peer),
Some(payload),
None,
)
}
fn build_authenticated_request(
method: u16,
username: Option<&str>,
realm: Option<&str>,
nonce: Option<&str>,
key: Option<&[u8]>,
lifetime: Option<u32>,
peer: Option<&std::net::SocketAddr>,
payload: Option<&[u8]>,
) -> Vec<u8> {
build_request_with_body(
method, username, realm, nonce, key, lifetime, peer, payload, None,
)
}
fn build_request_with_body(
method: u16,
username: Option<&str>,
realm: Option<&str>,
nonce: Option<&str>,
key: Option<&[u8]>,
lifetime: Option<u32>,
peer: Option<&std::net::SocketAddr>,
payload: Option<&[u8]>,
override_trans: Option<[u8; 12]>,
) -> Vec<u8> {
let mut buf = BytesMut::new();
buf.extend_from_slice(&method.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
let trans = override_trans.unwrap_or_else(new_transaction_id);
buf.extend_from_slice(&trans);
if let Some(username) = username {
push_string_attr(&mut buf, ATTR_USERNAME, username);
}
if let Some(realm) = realm {
push_string_attr(&mut buf, ATTR_REALM, realm);
}
if let Some(nonce) = nonce {
push_string_attr(&mut buf, ATTR_NONCE, nonce);
}
if let Some(lifetime) = lifetime {
push_u32_attr(&mut buf, ATTR_LIFETIME, lifetime);
}
if let Some(peer) = peer {
let encoded = niom_turn::stun::encode_xor_peer_address(peer, &trans);
push_bytes_attr(&mut buf, ATTR_XOR_PEER_ADDRESS, &encoded);
}
if let Some(data) = payload {
push_bytes_attr(&mut buf, ATTR_DATA, data);
}
if let Some(key) = key {
append_message_integrity(&mut buf, key);
}
// update length
let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
buf.to_vec()
}
fn push_string_attr(buf: &mut BytesMut, typ: u16, value: &str) {
push_bytes_attr(buf, typ, value.as_bytes());
}
fn push_u32_attr(buf: &mut BytesMut, typ: u16, value: u32) {
push_bytes_attr(buf, typ, &value.to_be_bytes());
}
fn push_bytes_attr(buf: &mut BytesMut, typ: u16, data: &[u8]) {
buf.extend_from_slice(&typ.to_be_bytes());
buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
buf.extend_from_slice(data);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
}
fn append_message_integrity(buf: &mut BytesMut, key: &[u8]) {
// position before adding MESSAGE-INTEGRITY attribute
let attribute_start = buf.len();
// append attribute header and placeholder value
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&(20u16.to_be_bytes()));
let value_start = buf.len();
buf.extend_from_slice(&[0u8; 20]);
// update message length to include the attribute (spec requires this before HMAC)
let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
// compute the HMAC over all bytes preceding the attribute (RFC 5389 §15.4)
let signed = compute_message_integrity(key, &buf[..attribute_start]);
// write the computed MAC into the placeholder we appended above
buf[value_start..value_start + 20].copy_from_slice(&signed[..20]);
}
/// Ensure builders produce parseable STUN messages.
pub fn parse(buf: &[u8]) -> niom_turn::models::stun::StunMessage {
parse_message(buf).expect("valid stun message")
}

59
tests/support/tls.rs Normal file
View File

@ -0,0 +1,59 @@
#![allow(dead_code)]
use std::sync::Arc;
use std::net::IpAddr;
use rcgen::{Certificate, CertificateParams, DistinguishedName, DnType, SanType};
use tokio_rustls::rustls::{Certificate as RustlsCert, PrivateKey};
/// Generate a self-signed certificate and matching key for test TLS servers.
pub fn generate_self_signed_cert() -> (RustlsCert, PrivateKey) {
let mut params = CertificateParams::default();
params.distinguished_name = DistinguishedName::new();
params
.distinguished_name
.push(DnType::CommonName, "niom-turn-test");
params.alg = &rcgen::PKCS_ECDSA_P256_SHA256;
params
.subject_alt_names
.push(SanType::DnsName("localhost".into()));
params.subject_alt_names.push(SanType::IpAddress(
"127.0.0.1"
.parse::<IpAddr>()
.expect("localhost loopback ip"),
));
let cert = Certificate::from_params(params).expect("certificate params");
let pem = cert.serialize_der().expect("cert der");
let key = cert.serialize_private_key_der();
(RustlsCert(pem), PrivateKey(key))
}
/// Build a rustls server config for tests using a generated certificate.
pub fn build_server_config() -> tokio_rustls::rustls::ServerConfig {
let (cert, key) = generate_self_signed_cert();
let mut cfg = tokio_rustls::rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)
.expect("valid test server config");
cfg.alpn_protocols = vec![b"turn".to_vec()];
cfg
}
/// Build a rustls client config trusting the generated test certificate.
pub fn build_client_config(cert: &RustlsCert) -> tokio_rustls::rustls::ClientConfig {
let mut root_store = tokio_rustls::rustls::RootCertStore::empty();
root_store.add(cert).expect("add root cert");
tokio_rustls::rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth()
}
/// Wrap tls config into acceptor for tests.
pub fn build_acceptor(cfg: tokio_rustls::rustls::ServerConfig) -> tokio_rustls::TlsAcceptor {
tokio_rustls::TlsAcceptor::from(Arc::new(cfg))
}

155
tests/tls_turn.rs Normal file
View File

@ -0,0 +1,155 @@
use std::sync::Arc;
use niom_turn::alloc::AllocationManager;
use niom_turn::stun::parse_message;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, UdpSocket};
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
use crate::support::stun_builders::{build_allocate_request, build_refresh_request};
use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager};
mod support;
#[tokio::test]
async fn tls_allocate_refresh_flow() {
init_tracing_with("warn,niom_turn=info");
let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind");
let udp_arc = Arc::new(udp);
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let (cert, key) = support::tls::generate_self_signed_cert();
let mut cfg = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key)
.expect("server config");
cfg.alpn_protocols.push(b"turn".to_vec());
let acceptor = TlsAcceptor::from(Arc::new(cfg));
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
let udp_clone = udp_arc.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
loop {
let (stream, peer) = match tcp_listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acceptor = acceptor.clone();
let udp_clone = udp_clone.clone();
let auth_clone = auth_clone.clone();
let alloc_clone = alloc_clone.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(mut tls_stream) => {
match niom_turn::tls::handle_tls_connection(
&mut tls_stream,
peer,
udp_clone,
auth_clone,
alloc_clone,
)
.await
{
Ok(_) => {}
Err(e) => {
tracing::error!("tls connection error: {:?}", e);
}
}
}
Err(e) => {
tracing::error!("tls accept failed: {:?}", e);
}
}
});
}
});
// Build client config trusting generated cert
let mut root_store = tokio_rustls::rustls::RootCertStore::empty();
root_store.add(&cert).expect("add root");
let client_config = tokio_rustls::rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
let tcp_stream = tokio::net::TcpStream::connect(tcp_addr)
.await
.expect("tcp connect");
let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap();
let mut tls_stream = connector
.connect(domain, tcp_stream)
.await
.expect("tls connect");
tracing::info!("client connected");
let allocate = build_allocate_request(None, None, None, None, None);
tls_stream
.write_all(&allocate)
.await
.expect("write allocate");
tracing::info!("sent unauthenticated allocate request");
let mut buf = vec![0u8; 1500];
let n = tls_stream.read(&mut buf).await.expect("read challenge");
tracing::info!(bytes = n, "received nonce challenge");
let resp = parse_message(&buf[..n]).expect("parse 401");
let nonce_attr = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
let nonce = String::from_utf8(nonce_attr.value.clone()).expect("nonce str");
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
let allocate = build_allocate_request(
Some(username),
Some(auth.realm()),
Some(&nonce),
Some(&key),
Some(600),
);
tls_stream
.write_all(&allocate)
.await
.expect("write auth allocate");
tracing::info!("sent authenticated allocate request");
let n = tls_stream.read(&mut buf).await.expect("read success");
tracing::info!(bytes = n, "received allocate success");
let resp = parse_message(&buf[..n]).expect("parse alloc success");
assert_eq!(resp.header.msg_type & 0x0110, 0x0100);
let refresh = build_refresh_request(
resp.header.transaction_id,
username,
auth.realm(),
&nonce,
&key,
0,
);
tls_stream.write_all(&refresh).await.expect("write refresh");
tracing::info!("sent refresh request");
let n = tls_stream.read(&mut buf).await.expect("read refresh resp");
tracing::info!(bytes = n, "received refresh response");
let resp = parse_message(&buf[..n]).expect("parse refresh resp");
let lifetime = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_LIFETIME)
.expect("lifetime attr");
let secs = u32::from_be_bytes([
lifetime.value[0],
lifetime.value[1],
lifetime.value[2],
lifetime.value[3],
]);
assert_eq!(secs, 0);
}

243
tests/udp_turn.rs Normal file
View File

@ -0,0 +1,243 @@
use std::net::SocketAddr;
use std::sync::Arc;
use niom_turn::alloc::AllocationManager;
use niom_turn::auth::InMemoryStore;
use niom_turn::server::udp_reader_loop;
use niom_turn::stun::parse_message;
use tokio::net::UdpSocket;
use crate::support::stun_builders::{
build_allocate_request, build_create_permission_request, build_refresh_request,
build_send_request, new_transaction_id, parse,
};
use crate::support::{default_test_credentials, init_tracing, test_auth_manager};
mod support;
const SERVER_ADDR: &str = "127.0.0.1:0";
#[tokio::test]
async fn allocate_requires_auth_then_succeeds() {
init_tracing();
let (server, client_addr) = start_udp_server().await;
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
let server_clone = server_arc.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await;
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
// initial unauthenticated allocate should trigger 401 with nonce
let req = build_allocate_request(None, None, None, None, None);
client
.send_to(&req, client_addr)
.await
.expect("send unauth allocate");
let mut buf = [0u8; 1500];
let (len, _) = client.recv_from(&mut buf).await.expect("recv challenge");
let resp = parse_message(&buf[..len]).expect("parse 401");
assert_eq!(resp.header.msg_type & 0x0110, 0x0110);
let nonce = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr")
.value
.clone();
let nonce_str = String::from_utf8(nonce).expect("nonce utf8");
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
let req = build_allocate_request(
Some(username),
Some(auth.realm()),
Some(&nonce_str),
Some(&key),
Some(600),
);
client
.send_to(&req, client_addr)
.await
.expect("send auth allocate");
let (len, _) = client.recv_from(&mut buf).await.expect("recv success");
let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, 0x0100);
}
#[tokio::test]
async fn refresh_zero_lifetime_releases_allocation() {
init_tracing();
let (server, client_addr) = start_udp_server().await;
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
let server_clone = server_arc.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await;
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
let nonce =
perform_authenticated_allocate(&client, client_addr, &auth, username, password, &allocs)
.await;
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
let trans_id = new_transaction_id();
let refresh = build_refresh_request(trans_id, username, auth.realm(), &nonce, &key, 0);
client
.send_to(&refresh, client_addr)
.await
.expect("send refresh");
let mut buf = [0u8; 1500];
let (len, _) = client.recv_from(&mut buf).await.expect("recv refresh resp");
let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, 0x0100);
let lifetime = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_LIFETIME)
.expect("lifetime attr");
assert_eq!(
u32::from_be_bytes([
lifetime.value[0],
lifetime.value[1],
lifetime.value[2],
lifetime.value[3]
]),
0
);
assert!(allocs
.get_allocation(&client.local_addr().unwrap())
.is_none());
}
#[tokio::test]
async fn create_permission_and_send_relays_data() {
init_tracing();
let (server, client_addr) = start_udp_server().await;
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
let server_clone = server_arc.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await;
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
let nonce =
perform_authenticated_allocate(&client, client_addr, &auth, username, password, &allocs)
.await;
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind");
let relay_addr = allocs
.get_allocation(&client.local_addr().unwrap())
.expect("allocation exists")
.relay_addr;
let perm_req = build_create_permission_request(
username,
auth.realm(),
&nonce,
&key,
&peer_sock.local_addr().unwrap(),
);
client
.send_to(&perm_req, client_addr)
.await
.expect("send create permission");
let mut buf = [0u8; 1500];
client.recv_from(&mut buf).await.expect("permission resp");
let payload = b"hello-turn";
let send_req = build_send_request(
username,
auth.realm(),
&nonce,
&key,
&peer_sock.local_addr().unwrap(),
payload,
);
client
.send_to(&send_req, client_addr)
.await
.expect("send indication");
let mut peer_buf = [0u8; 1500];
let (len, addr) = peer_sock.recv_from(&mut peer_buf).await.expect("peer recv");
assert_eq!(len, payload.len());
assert_eq!(&peer_buf[..len], payload);
assert_eq!(addr.port(), relay_addr.port());
assert!(addr.ip().is_loopback());
}
async fn start_udp_server() -> (UdpSocket, SocketAddr) {
let server = UdpSocket::bind(SERVER_ADDR).await.expect("server bind");
let addr = server.local_addr().expect("server addr");
(server, addr)
}
async fn perform_authenticated_allocate(
client: &UdpSocket,
server_addr: SocketAddr,
auth: &niom_turn::auth::AuthManager<InMemoryStore>,
username: &str,
password: &str,
allocs: &AllocationManager,
) -> String {
init_tracing();
// trigger nonce challenge
let req = build_allocate_request(None, None, None, None, None);
client
.send_to(&req, server_addr)
.await
.expect("send initial allocate");
let mut buf = [0u8; 1500];
let (len, _) = client.recv_from(&mut buf).await.expect("recv nonce");
let resp = parse_message(&buf[..len]).expect("parse nonce resp");
let nonce = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr")
.value
.clone();
let nonce = String::from_utf8(nonce).expect("nonce utf8");
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
let req = build_allocate_request(
Some(username),
Some(auth.realm()),
Some(&nonce),
Some(&key),
Some(600),
);
client
.send_to(&req, server_addr)
.await
.expect("send auth allocate");
client.recv_from(&mut buf).await.expect("recv success");
assert!(allocs
.get_allocation(&client.local_addr().unwrap())
.is_some());
nonce
}