// niom-turn/src/auth.rs — 618 lines, 25 KiB, Rust

//! Authentication helpers and the in-memory credential store used for the MVP server.
//! Backlog: Argon2-backed storage, nonce lifecycle, and integration with persistent secrets.
use crate::config::AuthOptions;
use crate::constants::{ATTR_NONCE, ATTR_REALM, ATTR_USERNAME};
use crate::models::stun::StunMessage;
use crate::stun::{
compute_message_integrity_adjusted,
compute_message_integrity_adjusted_nozero,
compute_message_integrity_before_mi,
compute_message_integrity_full,
compute_message_integrity_full_nozero,
compute_message_integrity_full_len_to_mi_end,
compute_message_integrity_len_preserved as compute_mi_len_preserved,
compute_message_integrity_len_preserved_nozero,
compute_message_integrity_through_mi_header,
find_message_integrity,
MessageIntegrityMode,
validate_message_integrity,
validate_message_integrity_len_preserved_nozero,
validate_message_integrity_nozero,
validate_message_integrity_len_preserved,
};
use crate::traits::CredentialStore;
use async_trait::async_trait;
use base64::Engine;
use hmac::{Hmac, Mac};
use sha1::Sha1;
use tracing::warn;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Simple in-memory credential store for the MVP.
///
/// Cloning is cheap: every clone shares the same map through the `Arc`.
#[derive(Clone, Default)]
pub struct InMemoryStore {
    // Username -> password map; for production replace with a DB-backed store.
    inner: Arc<std::sync::Mutex<std::collections::HashMap<String, String>>>,
}

impl InMemoryStore {
    /// Create an empty store (same as `InMemoryStore::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a credential, replacing any password previously stored for `user`.
    ///
    /// A poisoned lock is recovered instead of panicking: each critical section here is a
    /// single map operation, so a panic elsewhere cannot leave the map half-updated, and
    /// one crashed task should not take credential management down with it.
    pub fn insert(&self, user: impl Into<String>, password: impl Into<String>) {
        let mut credentials = self
            .inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        credentials.insert(user.into(), password.into());
    }
}
#[async_trait]
impl CredentialStore for InMemoryStore {
    /// Look up the stored password for `username`, returning a clone of it if present.
    ///
    /// A poisoned lock is recovered rather than unwrapped: the map is only ever touched by
    /// single-call operations, so its contents stay consistent even if a holder panicked,
    /// and authentication should keep working afterwards.
    async fn get_password(&self, username: &str) -> Option<String> {
        let credentials = self
            .inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        credentials.get(username).cloned()
    }
}
/// Authentication settings resolved from configuration for runtime usage.
///
/// Built by `AuthSettings::from_options` and held by `AuthManager`.
// NOTE(review): the derived `Debug` prints `nonce_secret` and `rest_secret` verbatim;
// avoid formatting this struct into logs.
#[derive(Clone, Debug)]
pub struct AuthSettings {
/// Realm the server expects; `authenticate` rejects requests whose REALM attribute differs.
pub realm: String,
/// Key for the HMAC-SHA1 signature embedded in minted nonces.
pub nonce_secret: Vec<u8>,
/// Maximum nonce age before `check_nonce` reports it as expired.
pub nonce_ttl: Duration,
/// Secret for TURN REST password derivation; when `None` that fallback is unavailable.
pub rest_secret: Option<Vec<u8>>,
/// Upper bound on how far in the future a TURN REST username's expiry may lie.
pub rest_max_ttl: Duration,
}
impl AuthSettings {
    /// Resolve runtime settings from the configured [`AuthOptions`].
    ///
    /// * A missing nonce secret is replaced by a freshly generated v4 UUID string.
    /// * Both TTL values are raised to at least 60 seconds, so a zero or tiny configured
    ///   value cannot make challenges or REST credentials expire immediately.
    pub fn from_options(opts: &AuthOptions) -> Self {
        let nonce_secret = match opts.nonce_secret.clone() {
            Some(configured) => configured,
            None => uuid::Uuid::new_v4().to_string(),
        }
        .into_bytes();

        let nonce_ttl = Duration::from_secs(std::cmp::max(opts.nonce_ttl_seconds, 60));
        let rest_max_ttl = Duration::from_secs(std::cmp::max(opts.rest_max_ttl_seconds, 60));
        let rest_secret = opts
            .rest_secret
            .as_ref()
            .map(|secret| secret.clone().into_bytes());

        Self {
            realm: opts.realm.clone(),
            nonce_secret,
            nonce_ttl,
            rest_secret,
            rest_max_ttl,
        }
    }
}
/// Result of validating authentication attributes on an incoming STUN/TURN request.
///
/// Produced by `AuthManager::authenticate`.
#[derive(Debug, Clone)]
pub enum AuthStatus {
/// MESSAGE-INTEGRITY validated.
Granted {
/// USERNAME attribute value taken from the request.
username: String,
/// The key that validated the request: either the long-term key or the raw password bytes.
key: Vec<u8>,
/// Which MESSAGE-INTEGRITY computation variant matched.
mi_mode: MessageIntegrityMode,
},
/// Credentials were missing or the nonce was invalid; carries a freshly minted nonce.
Challenge { nonce: String },
/// The presented nonce was older than the configured TTL; carries a freshly minted nonce.
StaleNonce { nonce: String },
/// The request was refused; `code` is the numeric error code (400 or 401 in this file)
/// and `reason` a short fixed phrase.
Reject { code: u16, reason: &'static str },
}
/// Orchestrates STUN/TURN long-term credential validation for the server.
///
/// Generic over the credential backend `S`, which is queried for passwords during
/// `authenticate`.
pub struct AuthManager<S: CredentialStore + Clone> {
// Backend that maps a username to its password.
store: S,
// Realm, nonce and TURN REST parameters, resolved once in `AuthManager::new`.
settings: AuthSettings,
}
/// Reconstruct the byte string covered by MESSAGE-INTEGRITY with the header length rewritten.
///
/// Copies `msg.raw` up to the end of the MESSAGE-INTEGRITY attribute (`mi_offset` is the
/// offset of its 4-byte attribute header), overwrites bytes 2..4 with the big-endian value
/// `mi_end - 20`, and zeroes the 20-byte attribute value. Used for diagnostics only.
///
/// Returns `None` when `mi_len` is not 20, when the end offset overflows, or when the
/// attribute would extend past the end of `msg.raw`.
fn build_signed_bytes_adjusted(
    msg: &StunMessage,
    mi_offset: usize,
    mi_len: usize,
) -> Option<Vec<u8>> {
    const STUN_HEADER_LEN: usize = 20;
    const ATTR_HEADER_LEN: usize = 4;
    const MI_VALUE_LEN: usize = 20;

    if mi_len != MI_VALUE_LEN {
        return None;
    }
    let mi_end = mi_offset.checked_add(ATTR_HEADER_LEN + MI_VALUE_LEN)?;
    // `get` yields `None` exactly when the attribute runs past the buffer.
    let mut signed = msg.raw.get(..mi_end)?.to_vec();

    let declared_len = (mi_end - STUN_HEADER_LEN) as u16;
    signed[2..4].copy_from_slice(&declared_len.to_be_bytes());

    // The attribute value occupies the last 20 bytes of the copied prefix.
    signed[mi_end - MI_VALUE_LEN..].fill(0);
    Some(signed)
}
/// Reconstruct the byte string covered by MESSAGE-INTEGRITY, leaving the header untouched.
///
/// Copies `msg.raw` up to the end of the MESSAGE-INTEGRITY attribute (`mi_offset` is the
/// offset of its 4-byte attribute header) and zeroes the 20-byte attribute value. Unlike
/// the "adjusted" variant, bytes 2..4 of the header are kept as received. Used for
/// diagnostics only.
///
/// Returns `None` when `mi_len` is not 20, when the end offset overflows, or when the
/// attribute would extend past the end of `msg.raw`.
fn build_signed_bytes_len_preserved(
    msg: &StunMessage,
    mi_offset: usize,
    mi_len: usize,
) -> Option<Vec<u8>> {
    const ATTR_HEADER_LEN: usize = 4;
    const MI_VALUE_LEN: usize = 20;

    if mi_len != MI_VALUE_LEN {
        return None;
    }
    let mi_end = mi_offset.checked_add(ATTR_HEADER_LEN + MI_VALUE_LEN)?;
    // `get` yields `None` exactly when the attribute runs past the buffer.
    let mut signed = msg.raw.get(..mi_end)?.to_vec();

    // The attribute value occupies the last 20 bytes of the copied prefix.
    signed[mi_end - MI_VALUE_LEN..].fill(0);
    Some(signed)
}
impl<S: CredentialStore + Clone> Clone for AuthManager<S> {
    /// Clone the manager by cloning its store handle and its resolved settings.
    fn clone(&self) -> Self {
        let Self { store, settings } = self;
        Self {
            store: Clone::clone(store),
            settings: Clone::clone(settings),
        }
    }
}
impl<S: CredentialStore + Clone> AuthManager<S> {
/// Build a manager around `store`, resolving the runtime settings from `opts`.
pub fn new(store: S, opts: &AuthOptions) -> Self {
    let settings = AuthSettings::from_options(opts);
    Self { store, settings }
}
/// The realm this manager validates requests against.
pub fn realm(&self) -> &str {
    self.settings.realm.as_str()
}
/// Inspect a parsed STUN/TURN message and determine whether credentials are acceptable.
///
/// Decision order:
/// 1. No MESSAGE-INTEGRITY, USERNAME or REALM attribute -> `Challenge` with a fresh nonce.
/// 2. REALM differs from the configured realm -> `Reject` 400 "Realm Mismatch".
/// 3. No NONCE attribute or an invalid nonce -> `Challenge`; an expired nonce -> `StaleNonce`.
/// 4. The password comes from the store, else from TURN REST derivation; if neither yields
///    one -> `Reject` 401 "Unknown User".
/// 5. MESSAGE-INTEGRITY is validated with the long-term key and with the raw password
///    bytes, first through the dedicated validators and then against a table of interop
///    variants. The first match wins and yields `Granted` with the key that matched.
/// 6. Nothing matched -> `Reject` 401 "Bad Credentials".
///
/// Diagnostics: setting `NIOM_TURN_DEBUG_AUTH_HEX` enables a message dump on failure and
/// `NIOM_TURN_DEBUG_AUTH_HEX_FULL` additionally dumps the full raw/signed bytes. The
/// long-term key (`a1_md5`) is credential-equivalent, so the reject log only includes it
/// when `NIOM_TURN_DEBUG_AUTH_HEX` is set and prints `<redacted>` otherwise.
pub async fn authenticate(&self, msg: &StunMessage, peer: &SocketAddr) -> AuthStatus {
    // Every "retry with credentials" outcome hands the client a freshly minted nonce.
    let challenge = || AuthStatus::Challenge {
        nonce: self.mint_nonce(peer),
    };

    if find_message_integrity(msg).is_none() {
        // Client has not yet computed MESSAGE-INTEGRITY; ask it to retry with credentials.
        return challenge();
    }
    let username = match self.attribute_utf8(msg, ATTR_USERNAME) {
        Some(u) => u,
        None => return challenge(),
    };
    let realm = match self.attribute_utf8(msg, ATTR_REALM) {
        Some(r) => r,
        None => return challenge(),
    };
    if realm != self.settings.realm {
        warn!("auth reject: realm mismatch client_realm={} expected={} peer={}", realm, self.settings.realm, peer);
        return AuthStatus::Reject {
            code: 400,
            reason: "Realm Mismatch",
        };
    }
    let nonce = match self.attribute_utf8(msg, ATTR_NONCE) {
        Some(n) => n,
        None => return challenge(),
    };
    match self.check_nonce(&nonce, peer) {
        NonceValidation::Valid => {}
        NonceValidation::Expired => {
            return AuthStatus::StaleNonce {
                nonce: self.mint_nonce(peer),
            }
        }
        NonceValidation::Invalid => return challenge(),
    }

    let password = match self.store.get_password(&username).await {
        Some(p) => p,
        None => match self.derive_turn_rest_password(&username) {
            Some(p) => p,
            None => {
                warn!("auth reject: unknown user username={} realm={} peer={}", username, realm, peer);
                return AuthStatus::Reject {
                    code: 401,
                    reason: "Unknown User",
                };
            }
        },
    };
    let key = self.derive_long_term_key(&username, &password);

    // Primary: long-term (MD5(username:realm:password))
    if validate_message_integrity(msg, &key) {
        return AuthStatus::Granted {
            username,
            key,
            mi_mode: MessageIntegrityMode::Rfc5389,
        };
    }
    // Interop: some clients appear to compute MESSAGE-INTEGRITY without zeroing the MI bytes.
    if validate_message_integrity_nozero(msg, &key)
        || validate_message_integrity_len_preserved_nozero(msg, &key)
    {
        warn!(
            "auth accept via MI nozero username={} realm={} peer={} (interop)",
            username, realm, peer
        );
        return AuthStatus::Granted {
            username,
            key,
            mi_mode: MessageIntegrityMode::Rfc5389,
        };
    }
    // Workaround: also accept short-term style (raw password as key) for test clients like turnutils_uclient.
    let short_key = password.as_bytes();
    if validate_message_integrity(msg, short_key)
        || validate_message_integrity_len_preserved(msg, short_key)
    {
        warn!("auth accept via short-term key username={} realm={} peer={} (workaround)", username, realm, peer);
        return AuthStatus::Granted {
            username,
            key: short_key.to_vec(),
            mi_mode: MessageIntegrityMode::Rfc5389,
        };
    }
    if validate_message_integrity_nozero(msg, short_key)
        || validate_message_integrity_len_preserved_nozero(msg, short_key)
    {
        warn!(
            "auth accept via short-term nozero username={} realm={} peer={} (interop)",
            username, realm, peer
        );
        return AuthStatus::Granted {
            username,
            key: short_key.to_vec(),
            mi_mode: MessageIntegrityMode::Rfc5389,
        };
    }
    // Additional interop fallback: some clients miscompute length when adding FINGERPRINT;
    // try validation without adjusting the header length.
    if validate_message_integrity_len_preserved(msg, &key) {
        warn!("auth accept via len-preserved MI username={} realm={} peer={} (interop fallback)", username, realm, peer);
        return AuthStatus::Granted {
            username,
            key,
            mi_mode: MessageIntegrityMode::Rfc5389,
        };
    }

    // No acceptance without MI validation. Emit detailed diagnostics.
    // Keep logs compact by default to avoid journald truncation.
    // Set NIOM_TURN_DEBUG_AUTH_HEX=1 for a summary, and NIOM_TURN_DEBUG_AUTH_HEX_FULL=1 for full raw/signed hex.
    let debug_hex = std::env::var_os("NIOM_TURN_DEBUG_AUTH_HEX").is_some();
    let mi = find_message_integrity(msg);
    if debug_hex {
        let mi_end = mi.map(|a| a.offset + 4 + a.value.len()).unwrap_or(0);
        let attrs: Vec<String> = msg
            .attributes
            .iter()
            .map(|a| {
                format!(
                    "t=0x{:04x} len={} off={} v={}",
                    a.typ,
                    a.value.len(),
                    a.offset,
                    hex::encode(&a.value)
                )
            })
            .collect();
        let signed_adj_hex = mi
            .and_then(|a| build_signed_bytes_adjusted(msg, a.offset, a.value.len()))
            .map(hex::encode);
        let signed_len_hex = mi
            .and_then(|a| build_signed_bytes_len_preserved(msg, a.offset, a.value.len()))
            .map(hex::encode);
        if std::env::var_os("NIOM_TURN_DEBUG_AUTH_HEX_FULL").is_some() {
            warn!(
                "auth debug dump FULL peer={} msg_type=0x{:04x} raw_len={} raw={} mi_end={} attrs=[{}] signed_adj={:?} signed_len_preserved={:?}",
                peer,
                msg.header.msg_type,
                msg.raw.len(),
                hex::encode(&msg.raw),
                mi_end,
                attrs.join(" | "),
                signed_adj_hex,
                signed_len_hex
            );
        } else {
            let mi_hex = mi.map(|a| hex::encode(&a.value));
            warn!(
                "auth debug dump peer={} msg_type=0x{:04x} raw_len={} mi_end={} mi={:?} attrs=[{}]",
                peer,
                msg.header.msg_type,
                msg.raw.len(),
                mi_end,
                mi_hex,
                attrs.join(" | ")
            );
            warn!(
                "auth debug signed (truncated) peer={} signed_adj_prefix={:?} signed_len_preserved_prefix={:?}",
                peer,
                signed_adj_hex.as_ref().map(|s| s.chars().take(160).collect::<String>()),
                signed_len_hex.as_ref().map(|s| s.chars().take(160).collect::<String>())
            );
        }
    }

    // Candidate MACs are computed exactly once; the same table drives both the interop
    // acceptance check and the reject log below. Order matters: the first match wins.
    let variants: [(&str, Option<Vec<u8>>); 28] = [
        ("long_adj", compute_message_integrity_adjusted(msg, &key)),
        ("long_len", compute_mi_len_preserved(msg, &key)),
        ("long_adj_nozero", compute_message_integrity_adjusted_nozero(msg, &key)),
        ("long_len_nozero", compute_message_integrity_len_preserved_nozero(msg, &key)),
        ("short_adj", compute_message_integrity_adjusted(msg, short_key)),
        ("short_len", compute_mi_len_preserved(msg, short_key)),
        ("short_adj_nozero", compute_message_integrity_adjusted_nozero(msg, short_key)),
        ("short_len_nozero", compute_message_integrity_len_preserved_nozero(msg, short_key)),
        ("long_full_adj", compute_message_integrity_full(msg, &key, true)),
        ("long_full_len", compute_message_integrity_full(msg, &key, false)),
        ("short_full_adj", compute_message_integrity_full(msg, short_key, true)),
        ("short_full_len", compute_message_integrity_full(msg, short_key, false)),
        ("long_full_adj_nozero", compute_message_integrity_full_nozero(msg, &key, true, false)),
        ("long_full_adj_nozero_zfp", compute_message_integrity_full_nozero(msg, &key, true, true)),
        ("short_full_adj_nozero", compute_message_integrity_full_nozero(msg, short_key, true, false)),
        ("short_full_adj_nozero_zfp", compute_message_integrity_full_nozero(msg, short_key, true, true)),
        ("long_full_len_to_mi_end", compute_message_integrity_full_len_to_mi_end(msg, &key, true, false)),
        ("long_full_len_to_mi_end_nozero", compute_message_integrity_full_len_to_mi_end(msg, &key, false, false)),
        ("long_full_len_to_mi_end_nozero_zfp", compute_message_integrity_full_len_to_mi_end(msg, &key, false, true)),
        ("short_full_len_to_mi_end", compute_message_integrity_full_len_to_mi_end(msg, short_key, true, false)),
        ("short_full_len_to_mi_end_nozero", compute_message_integrity_full_len_to_mi_end(msg, short_key, false, false)),
        ("short_full_len_to_mi_end_nozero_zfp", compute_message_integrity_full_len_to_mi_end(msg, short_key, false, true)),
        ("long_before_mi_len_to_mi_end", compute_message_integrity_before_mi(msg, &key, true)),
        ("long_before_mi_len_before_mi", compute_message_integrity_before_mi(msg, &key, false)),
        ("long_through_mi_header", compute_message_integrity_through_mi_header(msg, &key)),
        ("short_before_mi_len_to_mi_end", compute_message_integrity_before_mi(msg, short_key, true)),
        ("short_before_mi_len_before_mi", compute_message_integrity_before_mi(msg, short_key, false)),
        ("short_through_mi_header", compute_message_integrity_through_mi_header(msg, short_key)),
    ];

    // Accept if any variant matches received MI (still requires correct key).
    if let Some(mi_attr_val) = mi {
        let mi_bytes = &mi_attr_val.value;
        for (label, cand) in variants.iter() {
            if let Some(c) = cand {
                if c.len() >= 20 && &c[..20] == mi_bytes.as_slice() {
                    let mi_mode = if matches!(
                        *label,
                        "long_before_mi_len_to_mi_end" | "short_before_mi_len_to_mi_end"
                    ) {
                        MessageIntegrityMode::BeforeMiLenToMiEnd
                    } else {
                        MessageIntegrityMode::Rfc5389
                    };
                    let chosen_key = if label.starts_with("short_") {
                        short_key.to_vec()
                    } else {
                        key.clone()
                    };
                    warn!("auth accept via MI variant={} username={} realm={} peer={} (interop)", label, username, realm, peer);
                    return AuthStatus::Granted {
                        username,
                        key: chosen_key,
                        mi_mode,
                    };
                }
            }
        }
    }

    // Reject path: report the received MI next to the candidates computed above.
    let mi_attr = mi.map(|a| hex::encode(&a.value));
    let hex_of = |wanted: &str| -> Option<String> {
        variants
            .iter()
            .find(|entry| entry.0 == wanted)
            .and_then(|entry| entry.1.as_ref().map(hex::encode))
    };
    // The long-term key is as good as the password; only expose it when debugging was
    // explicitly requested.
    let a1_md5 = if debug_hex {
        hex::encode(&key)
    } else {
        String::from("<redacted>")
    };
    warn!(
        "auth reject: bad credentials username={} realm={} peer={} a1_md5={} mi_attr={:?} mi_long_adj={:?} mi_long_len={:?} mi_long_adj_nozero={:?} mi_long_len_nozero={:?} mi_short_adj={:?} mi_short_len={:?} mi_short_adj_nozero={:?} mi_short_len_nozero={:?} mi_long_full_adj={:?} mi_long_full_len={:?} mi_short_full_adj={:?} mi_short_full_len={:?} mi_long_full_adj_nozero={:?} mi_long_full_adj_nozero_zfp={:?} mi_short_full_adj_nozero={:?} mi_short_full_adj_nozero_zfp={:?} mi_long_full_len_to_mi_end={:?} mi_long_full_len_to_mi_end_nozero={:?} mi_long_full_len_to_mi_end_nozero_zfp={:?} mi_short_full_len_to_mi_end={:?} mi_long_before_mi_len_to_mi_end={:?} mi_long_before_mi_len_before_mi={:?} mi_long_through_mi_hdr={:?}",
        username,
        realm,
        peer,
        a1_md5,
        mi_attr,
        hex_of("long_adj"),
        hex_of("long_len"),
        hex_of("long_adj_nozero"),
        hex_of("long_len_nozero"),
        hex_of("short_adj"),
        hex_of("short_len"),
        hex_of("short_adj_nozero"),
        hex_of("short_len_nozero"),
        hex_of("long_full_adj"),
        hex_of("long_full_len"),
        hex_of("short_full_adj"),
        hex_of("short_full_len"),
        hex_of("long_full_adj_nozero"),
        hex_of("long_full_adj_nozero_zfp"),
        hex_of("short_full_adj_nozero"),
        hex_of("short_full_adj_nozero_zfp"),
        hex_of("long_full_len_to_mi_end"),
        hex_of("long_full_len_to_mi_end_nozero"),
        hex_of("long_full_len_to_mi_end_nozero_zfp"),
        hex_of("short_full_len_to_mi_end"),
        hex_of("long_before_mi_len_to_mi_end"),
        hex_of("long_before_mi_len_before_mi"),
        hex_of("long_through_mi_header")
    );
    AuthStatus::Reject {
        code: 401,
        reason: "Bad Credentials",
    }
}
/// Return the value of the first attribute of type `attr_type` as an owned string.
///
/// Yields `None` when no such attribute exists or when that first attribute's value is
/// not valid UTF-8 (later attributes of the same type are not consulted).
fn attribute_utf8(&self, msg: &StunMessage, attr_type: u16) -> Option<String> {
    let attr = msg.attributes.iter().find(|a| a.typ == attr_type)?;
    let text = std::str::from_utf8(&attr.value).ok()?;
    Some(text.to_string())
}
/// Derive the long-term credential key for `username`/`password` in the configured realm.
fn derive_long_term_key(&self, username: &str, password: &str) -> Vec<u8> {
    let realm = &self.settings.realm;
    compute_a1_md5(username, realm, password)
}
/// TURN REST (ephemeral) password derivation.
///
/// Expected username format: `<expiry_unix_seconds>` or `<expiry_unix_seconds>:<opaque>`.
/// Password is `base64(HMAC-SHA1(rest_secret, username))`.
///
/// Security: Reject if expired or if expiry is too far in the future (bounded by rest_max_ttl).
fn derive_turn_rest_password(&self, username: &str) -> Option<String> {
let secret = self.settings.rest_secret.as_ref()?;
let expiry = parse_turn_rest_expiry(username)?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0))
.as_secs();
if now > expiry {
return None;
}
let delta = expiry.saturating_sub(now);
if delta > self.settings.rest_max_ttl.as_secs() {
return None;
}
Some(turn_rest_password_base64(secret, username))
}
pub fn mint_nonce(&self, peer: &SocketAddr) -> String {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0))
.as_secs();
let payload = format!("{}|{}", now, peer.ip());
let sig = self.sign_payload(payload.as_bytes());
format!("{}:{}", now, sig)
}
/// Validate a client-supplied nonce of the form `<unix_seconds>:<hex_signature>`.
///
/// * `Invalid` - the nonce has no `:` separator, the timestamp is not a `u64`, or the
///   signature does not match the one computed for this timestamp and `peer`'s IP.
/// * `Expired` - the timestamp is older than the configured nonce TTL.
/// * `Valid` - the nonce is fresh and correctly signed.
///
/// The signature comparison inspects every byte instead of using `==`, which stops at the
/// first mismatch and so leaks, through timing, how much of a forged signature was right.
fn check_nonce(&self, nonce: &str, peer: &SocketAddr) -> NonceValidation {
    let (ts_str, sig_str) = match nonce.split_once(':') {
        Some(parts) => parts,
        None => return NonceValidation::Invalid,
    };
    let timestamp = match ts_str.parse::<u64>() {
        Ok(t) => t,
        Err(_) => return NonceValidation::Invalid,
    };
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_else(|_| Duration::from_secs(0))
        .as_secs();
    // Timestamps in the future saturate to an age of zero; they still need a valid signature.
    if now.saturating_sub(timestamp) > self.settings.nonce_ttl.as_secs() {
        return NonceValidation::Expired;
    }
    let payload = format!("{}|{}", timestamp, peer.ip());
    let expected = self.sign_payload(payload.as_bytes());

    // Length is not secret (the signature has a fixed size); the byte contents are, so
    // accumulate all differences before deciding.
    let expected = expected.as_bytes();
    let presented = sig_str.as_bytes();
    let signature_matches = expected.len() == presented.len()
        && expected
            .iter()
            .zip(presented)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0;

    if signature_matches {
        NonceValidation::Valid
    } else {
        NonceValidation::Invalid
    }
}
fn sign_payload(&self, payload: &[u8]) -> String {
type HmacSha1 = Hmac<Sha1>;
let mut mac = HmacSha1::new_from_slice(&self.settings.nonce_secret)
.expect("nonce secret to build hmac");
mac.update(payload);
let bytes = mac.finalize().into_bytes();
hex::encode(bytes)
}
}
/// Extract the expiry from a TURN REST username.
///
/// The expiry is the text before the first `:` (or the whole username when there is no
/// `:`), parsed as a `u64`. Returns `None` when that text is not a valid `u64`.
fn parse_turn_rest_expiry(username: &str) -> Option<u64> {
    let expiry_text = match username.split_once(':') {
        Some((head, _opaque)) => head,
        None => username,
    };
    expiry_text.parse().ok()
}
fn turn_rest_password_base64(secret: &[u8], username: &str) -> String {
type HmacSha1 = Hmac<Sha1>;
let mut mac = HmacSha1::new_from_slice(secret).expect("rest secret to build hmac");
mac.update(username.as_bytes());
let bytes = mac.finalize().into_bytes();
base64::engine::general_purpose::STANDARD.encode(bytes)
}
/// Outcome of `AuthManager::check_nonce`.
enum NonceValidation {
/// Timestamp is within the TTL and the signature matches.
Valid,
/// Timestamp is older than the configured nonce TTL.
Expired,
/// Nonce is malformed or its signature does not match.
Invalid,
}
/// Helper: compute MESSAGE-INTEGRITY, i.e. the raw HMAC-SHA1 tag of `data` keyed by the
/// UTF-8 bytes of `key`.
pub fn compute_hmac_sha1_bytes(key: &str, data: &[u8]) -> Vec<u8> {
    use hmac::{Hmac, Mac};
    use sha1::Sha1;
    let mut signer = Hmac::<Sha1>::new_from_slice(key.as_bytes()).expect("HMAC key");
    signer.update(data);
    let tag = signer.finalize().into_bytes();
    tag.to_vec()
}
/// Compute A1 MD5(username:realm:password) as bytes for long-term credential derivation
pub fn compute_a1_md5(username: &str, realm: &str, password: &str) -> Vec<u8> {
let s = format!("{}:{}:{}", username, realm, password);
let digest = md5::compute(s.as_bytes());
digest.0.to_vec()
}