Replace STUN/TURN hex literals with RFC/project constants; refactor main to use library exports

This commit is contained in:
ghost 2025-09-26 16:15:25 +02:00
parent 5bbeb8fc55
commit 8363217c96
3 changed files with 32 additions and 28 deletions

View File

@ -10,6 +10,10 @@ pub const METHOD_ALLOCATE: u16 = 0x0003;
// Common response/error types // Common response/error types
pub const RESP_BINDING_SUCCESS: u16 = 0x0101; pub const RESP_BINDING_SUCCESS: u16 = 0x0101;
// STUN/TURN class bits per RFC5389/RFC5766
pub const CLASS_SUCCESS: u16 = 0x0100;
pub const CLASS_ERROR: u16 = 0x0110;
// Common attribute types // Common attribute types
pub const ATTR_USERNAME: u16 = 0x0006; pub const ATTR_USERNAME: u16 = 0x0006;
pub const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008; pub const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008;
@ -22,3 +26,9 @@ pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016;
// Some helper values // Some helper values
pub const FAMILY_IPV4: u8 = 0x01; pub const FAMILY_IPV4: u8 = 0x01;
// Fingerprint XOR magic (XOR with CRC32 for FINGERPRINT attribute)
pub const FINGERPRINT_XOR: u32 = 0x5354554e;
// Length of HMAC-SHA1 (MESSAGE-INTEGRITY)
pub const HMAC_SHA1_LEN: usize = 20;

View File

@ -3,18 +3,12 @@ use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tracing::{info, error}; use tracing::{info, error};
mod stun; // Use the library crate's public modules instead of local `mod` declarations.
mod auth; use niom_turn::constants::*;
mod traits; use niom_turn::auth::InMemoryStore;
mod models; use niom_turn::stun::{parse_message, build_401_response, find_message_integrity, validate_message_integrity, build_success_response, encode_xor_relayed_address};
mod alloc; use niom_turn::traits::CredentialStore;
mod constants; use niom_turn::alloc::AllocationManager;
use crate::constants::*;
use crate::auth::InMemoryStore;
use crate::stun::{parse_message, build_401_response};
use crate::traits::CredentialStore;
// use crate::models::stun::StunHeader; // currently unused
use crate::alloc::AllocationManager;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
@ -62,7 +56,7 @@ async fn udp_reader_loop(udp: Arc<UdpSocket>, creds: InMemoryStore, allocs: Allo
if let Ok(msg) = parse_message(&buf[..len]) { if let Ok(msg) = parse_message(&buf[..len]) {
tracing::info!("STUN/TURN message from {} type=0x{:04x} len={}", peer, msg.header.msg_type, len); tracing::info!("STUN/TURN message from {} type=0x{:04x} len={}", peer, msg.header.msg_type, len);
// If MESSAGE-INTEGRITY present, attempt validation using credential store // If MESSAGE-INTEGRITY present, attempt validation using credential store
if let Some(_mi_attr) = crate::stun::find_message_integrity(&msg) { if let Some(_mi_attr) = find_message_integrity(&msg) {
// For MVP we expect username attribute (USERNAME) to be present // For MVP we expect username attribute (USERNAME) to be present
let username_attr = msg.attributes.iter().find(|a| a.typ == ATTR_USERNAME); let username_attr = msg.attributes.iter().find(|a| a.typ == ATTR_USERNAME);
if let Some(u) = username_attr { if let Some(u) = username_attr {
@ -71,7 +65,7 @@ async fn udp_reader_loop(udp: Arc<UdpSocket>, creds: InMemoryStore, allocs: Allo
let store = creds.clone(); let store = creds.clone();
let pw = store.get_password(username).await; let pw = store.get_password(username).await;
if let Some(password) = pw { if let Some(password) = pw {
let valid = crate::stun::validate_message_integrity(&msg, &password); let valid = validate_message_integrity(&msg, &password);
if valid { if valid {
tracing::info!("MI valid for user {}", username); tracing::info!("MI valid for user {}", username);
// If this is an Allocate request, perform allocation // If this is an Allocate request, perform allocation
@ -85,7 +79,7 @@ async fn udp_reader_loop(udp: Arc<UdpSocket>, creds: InMemoryStore, allocs: Allo
out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes()); out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes());
out.extend_from_slice(&msg.header.transaction_id); out.extend_from_slice(&msg.header.transaction_id);
// RFC: XOR-RELAYED-ADDRESS (0x0016) // RFC: XOR-RELAYED-ADDRESS (0x0016)
let attr_val = crate::stun::encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id); let attr_val = encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id);
out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes());
out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes())); out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes()));
out.extend_from_slice(&attr_val); out.extend_from_slice(&attr_val);
@ -102,7 +96,7 @@ async fn udp_reader_loop(udp: Arc<UdpSocket>, creds: InMemoryStore, allocs: Allo
} }
} }
// default success response // default success response
let resp = crate::stun::build_success_response(&msg.header); let resp = build_success_response(&msg.header);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} else { } else {
@ -126,7 +120,7 @@ async fn udp_reader_loop(udp: Arc<UdpSocket>, creds: InMemoryStore, allocs: Allo
out.extend_from_slice(&0u16.to_be_bytes()); out.extend_from_slice(&0u16.to_be_bytes());
out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes()); out.extend_from_slice(&MAGIC_COOKIE_U32.to_be_bytes());
out.extend_from_slice(&msg.header.transaction_id); out.extend_from_slice(&msg.header.transaction_id);
let attr_val = crate::stun::encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id); let attr_val = encode_xor_relayed_address(&relay_addr, &msg.header.transaction_id);
out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); out.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes());
out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes())); out.extend_from_slice(&((attr_val.len() as u16).to_be_bytes()));
out.extend_from_slice(&attr_val); out.extend_from_slice(&attr_val);

View File

@ -48,23 +48,23 @@ pub fn parse_message(buf: &[u8]) -> Result<StunMessage, ParseError> {
pub fn build_401_response(req: &StunHeader, realm: &str, nonce: &str, _err_code: u16) -> Vec<u8> { pub fn build_401_response(req: &StunHeader, realm: &str, nonce: &str, _err_code: u16) -> Vec<u8> {
use bytes::BytesMut; use bytes::BytesMut;
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
// Error response type for TURN often uses same method with error bit set; here we reuse 0x0111 placeholder // Error response type for TURN: reuse the request method with error class bits set
let msg_type: u16 = 0x0111; let msg_type: u16 = req.msg_type | CLASS_ERROR;
buf.extend_from_slice(&msg_type.to_be_bytes()); buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes()); // length buf.extend_from_slice(&0u16.to_be_bytes()); // length
buf.extend_from_slice(&MAGIC_COOKIE_BYTES); buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id); buf.extend_from_slice(&req.transaction_id);
// REALM (0x0014) // REALM (RFC attr)
let realm_bytes = realm.as_bytes(); let realm_bytes = realm.as_bytes();
buf.extend_from_slice(&0x0014u16.to_be_bytes()); buf.extend_from_slice(&ATTR_REALM.to_be_bytes());
buf.extend_from_slice(&(realm_bytes.len() as u16).to_be_bytes()); buf.extend_from_slice(&(realm_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(realm_bytes); buf.extend_from_slice(realm_bytes);
while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0]); } while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0]); }
// NONCE (0x0015) // NONCE (RFC attr)
let nonce_bytes = nonce.as_bytes(); let nonce_bytes = nonce.as_bytes();
buf.extend_from_slice(&0x0015u16.to_be_bytes()); buf.extend_from_slice(&ATTR_NONCE.to_be_bytes());
buf.extend_from_slice(&(nonce_bytes.len() as u16).to_be_bytes()); buf.extend_from_slice(&(nonce_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(nonce_bytes); buf.extend_from_slice(nonce_bytes);
while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0]); } while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0]); }
@ -122,7 +122,7 @@ pub fn compute_fingerprint(msg: &[u8]) -> u32 {
let mut hasher = Hasher::new(); let mut hasher = Hasher::new();
hasher.update(msg); hasher.update(msg);
let crc = hasher.finalize(); let crc = hasher.finalize();
crc ^ 0x5354554e crc ^ FINGERPRINT_XOR
} }
/// Compute MESSAGE-INTEGRITY (HMAC-SHA1) over the message /// Compute MESSAGE-INTEGRITY (HMAC-SHA1) over the message
@ -147,7 +147,7 @@ pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8;1
match addr.ip() { match addr.ip() {
IpAddr::V4(v4) => { IpAddr::V4(v4) => {
out.push(0); // first 8 bits zero per spec out.push(0); // first 8 bits zero per spec
out.push(0x01); // family: 0x01 for IPv4 out.push(FAMILY_IPV4); // family: IPv4
// xport = port ^ (magic_cookie >> 16) // xport = port ^ (magic_cookie >> 16)
let port = addr.port(); let port = addr.port();
let xport = (port ^ ((MAGIC_COOKIE_U32 >> 16) as u16)) as u16; let xport = (port ^ ((MAGIC_COOKIE_U32 >> 16) as u16)) as u16;
@ -168,7 +168,7 @@ pub fn encode_xor_relayed_address(addr: &std::net::SocketAddr, _trans_id: &[u8;1
/// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4 only) /// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4 only)
pub fn decode_xor_relayed_address(value: &[u8], _trans_id: &[u8;12]) -> Option<std::net::SocketAddr> { pub fn decode_xor_relayed_address(value: &[u8], _trans_id: &[u8;12]) -> Option<std::net::SocketAddr> {
if value.len() < 8 { return None; } if value.len() < 8 { return None; }
if value[1] != 0x01 { return None; } // not IPv4 if value[1] != FAMILY_IPV4 { return None; } // not IPv4
let xport = u16::from_be_bytes([value[2], value[3]]); let xport = u16::from_be_bytes([value[2], value[3]]);
let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16); let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16);
let cookie_bytes = MAGIC_COOKIE_BYTES; let cookie_bytes = MAGIC_COOKIE_BYTES;
@ -218,7 +218,7 @@ mod tests {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
buf.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request buf.extend_from_slice(&METHOD_BINDING.to_be_bytes()); // Binding Request
buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder buf.extend_from_slice(&0u16.to_be_bytes()); // length placeholder
buf.extend_from_slice(&0x2112A442u32.to_be_bytes()); buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
let trans = [9u8; 12]; let trans = [9u8; 12];
buf.extend_from_slice(&trans); buf.extend_from_slice(&trans);
@ -232,7 +232,7 @@ mod tests {
// MESSAGE-INTEGRITY placeholder (0x0008) length 20 // MESSAGE-INTEGRITY placeholder (0x0008) length 20
let mi_attr_offset = buf.len(); let mi_attr_offset = buf.len();
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&(20u16).to_be_bytes()); buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes()));
let mi_val_pos = buf.len(); let mi_val_pos = buf.len();
buf.extend_from_slice(&[0u8;20]); buf.extend_from_slice(&[0u8;20]);
while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); } while (buf.len() % 4) != 0 { buf.extend_from_slice(&[0u8]); }