From 29c0d8c0cfbc8a8c8fe823254505c2d8e220ab89 Mon Sep 17 00:00:00 2001 From: ghost Date: Wed, 26 Nov 2025 15:06:48 +0100 Subject: [PATCH] Add: Finalize unit and integration tests. README for test usage. --- Cargo.lock | 127 +++++++++++++++++++++++++++++++ Cargo.toml | 4 + tests/README.md | 74 ++++++++++++++++++ tests/alloc/helpers.rs | 50 ++++++++++++ tests/alloc/integration_udp.rs | 70 +++++++++++++++++ tests/alloc/unit.rs | 44 +++++++++++ tests/auth/helpers.rs | 50 ++++++++++++ tests/auth/integration_tls.rs | 116 ++++++++++++++++++++++++++++ tests/auth/integration_udp.rs | 53 +++++++++++++ tests/auth/unit.rs | 80 +++++++++++++++++++ tests/channel/helpers.rs | 49 ++++++++++++ tests/channel/integration_tls.rs | 117 ++++++++++++++++++++++++++++ tests/channel/integration_udp.rs | 55 +++++++++++++ tests/channel/unit.rs | 40 ++++++++++ tests/config/helpers.rs | 24 ++++++ tests/config/integration.rs | 29 +++++++ tests/config/unit.rs | 37 +++++++++ tests/errors/helpers.rs | 36 +++++++++ tests/errors/integration_tls.rs | 92 ++++++++++++++++++++++ tests/errors/integration_udp.rs | 30 ++++++++ tests/errors/unit.rs | 28 +++++++ tests/support/mocks.rs | 94 +++++++++++++++++++++++ tests/support/mod.rs | 1 + tests/support/stun_builders.rs | 74 ++++++++++++++++++ 24 files changed, 1374 insertions(+) create mode 100644 tests/README.md create mode 100644 tests/alloc/helpers.rs create mode 100644 tests/alloc/integration_udp.rs create mode 100644 tests/alloc/unit.rs create mode 100644 tests/auth/helpers.rs create mode 100644 tests/auth/integration_tls.rs create mode 100644 tests/auth/integration_udp.rs create mode 100644 tests/auth/unit.rs create mode 100644 tests/channel/helpers.rs create mode 100644 tests/channel/integration_tls.rs create mode 100644 tests/channel/integration_udp.rs create mode 100644 tests/channel/unit.rs create mode 100644 tests/config/helpers.rs create mode 100644 tests/config/integration.rs create mode 100644 tests/config/unit.rs create mode 100644 tests/errors/helpers.rs create mode 100644 tests/errors/integration_tls.rs create mode 100644 tests/errors/integration_udp.rs create mode 100644 tests/errors/unit.rs create mode 100644 tests/support/mocks.rs diff --git a/Cargo.lock b/Cargo.lock index 3f1196b..d5f3653 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + [[package]] name = "anyhow" version = "1.0.100" @@ -167,12 +173,40 @@ dependencies = [ "subtle", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ced73b1dacfc750a6db6c0a0c3a3853c8b41997e2e2c563dc90804ae6867959" +[[package]] +name = "fragile" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" + [[package]] name = "generic-array" version = "0.14.7" @@ -266,6 +300,12 @@ version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "lock_api" version = "0.4.13" @@ -323,6 +363,33 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mockall" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43766c2b5203b10de348ffe19f7e54564b64f3d6018ff7648d1e2d6d3a0f0a48" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cbce79ec385a1d4f54baa90a76401eb15d9cab93685f62e7e9f942aa00ae2" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "niom-turn" version = "0.1.0" @@ -334,12 +401,14 @@ dependencies = [ "hex", "hmac", "md5", + "mockall", "rcgen", "rustls 0.21.12", "rustls-pemfile", "serde", "serde_json", "sha1", + "tempfile", "thiserror", "tokio", "tokio-rustls", @@ -422,6 +491,32 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -519,6 +614,19 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "rustls" version = "0.20.9" @@ -713,6 +821,25 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index df8c7bd..408de9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,7 @@ thiserror = "1.0" crc32fast = "1.3" md5 = "0.7" +[dev-dependencies] +mockall = "0.12" +tempfile = "3" + diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..f98f380 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,74 @@ +# Test Suite Overview + +The test tree is split by concern so it is easy to find the relevant scenarios. Each folder owns its +own helpers while all reusable mocks live in `support/mocks.rs`. + +``` +tests/ +├── README.md # This file +├── support/ # Shared tooling (stun builders, TLS utils, mocks) +│ ├── mocks.rs # mockall definitions grouped by domain +│ ├── mod.rs # re-exports used by integration crates +│ ├── stun_builders.rs +│ └── tls.rs +├── udp_turn.rs # Legacy UDP end-to-end happy-path integration +├── tls_turn.rs # Legacy TLS allocate/refresh integration +├── auth/ +│ ├── helpers.rs # Realm/nonce utilities + TURN bootstrapping +│ ├── unit.rs # Realm mismatch + unknown user unit tests +│ ├── integration_udp.rs # UDP auth failures (unknown user) +│ └── integration_tls.rs # TLS auth failure path (bad credentials) +├── alloc/ +│ ├── helpers.rs # Lifetime/nonce helpers and UDP harness +│ ├── unit.rs # Lifetime clamp + zero release +│ └── integration_udp.rs # Refresh request clamping (server path) +├── channel/ +│ ├── helpers.rs # Channel/peer fixtures + harness +│ ├── unit.rs # ChannelData parsing coverage +│ ├── integration_udp.rs # ChannelBind without allocation → 437 +│ └── integration_tls.rs # TLS ChannelBind mismatch handling +├── config/ +│ ├── helpers.rs # JSON builders + temp-file writer +│ ├── unit.rs # Minimal/malformed parse tests +│ └── integration.rs # Config::from_file → AuthManager wiring +└── errors/ + ├── helpers.rs # Malformed frame + harness + ├── unit.rs # Frame stream mock + parse errors + ├── integration_udp.rs # Malformed packet dropped silently + └── integration_tls.rs # TLS reader ignores garbage frames +``` + +## Mocks + +- `support/mocks.rs` centralises every `mockall` mock. Sections are grouped by domain (Auth, + Allocation, Channel, Config, Error) so it is obvious which tests should use which mock. +- Tests import mocks via `#[path = "../support/mod.rs"] mod support;` and reach the relevant mock at + `support::mocks::MockFoo`. +- Additional domain-specific helper traits can be added to `mocks.rs` without touching the main + crate. + +## Helper Policy + +- Each folder keeps tiny `helpers.rs` files for fixtures that are only meaningful within that domain + (e.g. auth peers, config JSON blobs). This keeps intent local and avoids a mega helper file. +- When scenarios need runtime bootstrapping they should add small helper modules in their folder + rather than extending `tests/support`. + +## Domain Coverage Highlights + +- **Auth**: Validates realm mismatch and unknown users at unit level plus UDP/TLS rejection flows. +- **Allocation**: Exercises lifetime clamping in isolation and via refresh STUN requests. +- **Channel**: Covers ChannelData parsing plus UDP/TLS ChannelBind error paths when allocations + are missing. +- **Config**: Confirms JSON defaults, malformed detection, and `Config::from_file` wiring into an + `AuthManager` via a temp file. +- **Errors**: Ensures malformed frames trigger parser errors and are ignored by UDP/TLS loops + without producing responses. +- **End-to-end baselines**: `udp_turn.rs` and `tls_turn.rs` remain as the canonical happy-path + integration tests for Allocate/Refresh/Permission flows over both transports. + +## Running Tests + +- Full suite: `cargo test` (runs unit + integration crates, TLS fixtures included). +- Per-domain focus: `cargo test --test auth_integration_udp`, `cargo test --test channel_unit`, etc. +- Include ignored tests: none remain; every scenario runs as part of the default suite. diff --git a/tests/alloc/helpers.rs b/tests/alloc/helpers.rs new file mode 100644 index 0000000..d9c3a11 --- /dev/null +++ b/tests/alloc/helpers.rs @@ -0,0 +1,50 @@ +//! Allocation test helpers. +use crate::support::{default_test_credentials, test_auth_manager}; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::{AuthManager, InMemoryStore}; +use niom_turn::constants::ATTR_NONCE; +use niom_turn::models::stun::StunMessage; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::UdpSocket; + +pub fn sample_client() -> SocketAddr { + "127.0.0.1:41000".parse().expect("client addr") +} + +pub fn sample_peer() -> SocketAddr { + "127.0.0.1:42000".parse().expect("peer addr") +} + +pub fn lifetime_secs(secs: u64) -> Duration { + Duration::from_secs(secs) +} + +pub fn build_auth_manager() -> AuthManager { + let (user, password) = default_test_credentials(); + test_auth_manager(user, password) +} + +pub async fn spawn_udp_server( + auth: AuthManager, + allocs: AllocationManager, +) -> SocketAddr { + let server = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let addr = server.local_addr().expect("udp addr"); + let arc = Arc::new(server); + let reader = arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = niom_turn::server::udp_reader_loop(reader, auth_clone, alloc_clone).await; + }); + addr +} + +pub fn extract_nonce(msg: &StunMessage) -> Option { + msg.attributes + .iter() + .find(|attr| attr.typ == ATTR_NONCE) + .and_then(|attr| String::from_utf8(attr.value.clone()).ok()) +} diff --git a/tests/alloc/integration_udp.rs b/tests/alloc/integration_udp.rs new file mode 100644 index 0000000..8e8e8df --- /dev/null +++ b/tests/alloc/integration_udp.rs @@ -0,0 +1,70 @@ +//! UDP allocation lifecycle integration tests. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use crate::support::stun_builders::{ + build_allocate_request, build_refresh_request, extract_lifetime, new_transaction_id, parse, +}; +use helpers::*; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth; +use support::{default_test_credentials, init_tracing, test_auth_manager}; +use tokio::net::UdpSocket; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn refresh_request_is_clamped_to_maximum_lifetime() { + init_tracing(); + let (username, password) = default_test_credentials(); + let auth_manager = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_addr = spawn_udp_server(auth_manager.clone(), allocs.clone()).await; + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + let mut buf = [0u8; 1500]; + + // Challenge for nonce + let challenge = build_allocate_request(None, None, None, None, None); + client + .send_to(&challenge, server_addr) + .await + .expect("send challenge"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv nonce"); + let resp = parse(&buf[..len]); + let nonce = helpers::extract_nonce(&resp).expect("nonce attr"); + + // Successful allocation + let key = auth::compute_a1_md5(username, auth_manager.realm(), password); + let allocate = build_allocate_request( + Some(username), + Some(auth_manager.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + client + .send_to(&allocate, server_addr) + .await + .expect("send auth allocate"); + client.recv_from(&mut buf).await.expect("recv alloc success"); + + // Request refresh with value exceeding MAX (7200s) and assert server clamps to 3600 + let refresh = build_refresh_request( + new_transaction_id(), + username, + auth_manager.realm(), + &nonce, + &key, + 7200, + ); + client + .send_to(&refresh, server_addr) + .await + .expect("send refresh"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv refresh"); + let resp = parse(&buf[..len]); + let lifetime = extract_lifetime(&resp).expect("lifetime attr"); + assert_eq!(lifetime, 3600); +} diff --git a/tests/alloc/unit.rs b/tests/alloc/unit.rs new file mode 100644 index 0000000..90f918d --- /dev/null +++ b/tests/alloc/unit.rs @@ -0,0 +1,44 @@ +//! Allocation lifecycle unit tests covering clamping and removal. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::alloc::AllocationManager; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::UdpSocket; + +async fn allocate_sample(manager: &AllocationManager) -> SocketAddr { + let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await.expect("udp bind")); + let client = sample_client(); + manager + .allocate_for(client, server) + .await + .expect("allocate relay"); + client +} + +#[tokio::test(flavor = "current_thread")] +async fn refresh_clamps_to_minimum_lifetime() { + let manager = AllocationManager::new(); + let client = allocate_sample(&manager).await; + let applied = manager + .refresh_allocation(client, Some(Duration::from_secs(30))) + .expect("refresh"); + assert_eq!(applied, Duration::from_secs(60)); +} + +#[tokio::test(flavor = "current_thread")] +async fn zero_lifetime_removes_allocation() { + let manager = AllocationManager::new(); + let client = allocate_sample(&manager).await; + let applied = manager + .refresh_allocation(client, Some(Duration::from_secs(0))) + .expect("refresh zero"); + assert_eq!(applied, Duration::from_secs(0)); + assert!(manager.get_allocation(&client).is_none()); +} diff --git a/tests/auth/helpers.rs b/tests/auth/helpers.rs new file mode 100644 index 0000000..15c47fa --- /dev/null +++ b/tests/auth/helpers.rs @@ -0,0 +1,50 @@ +//! Helpers dedicated to auth-focused tests. +use crate::support::{default_test_credentials, test_auth_manager}; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::{AuthManager, InMemoryStore}; +use niom_turn::config::AuthOptions; +use niom_turn::constants::ATTR_NONCE; +use niom_turn::models::stun::StunMessage; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::UdpSocket; + +pub fn default_options() -> AuthOptions { + AuthOptions::default() +} + +pub fn loopback_peer() -> SocketAddr { + "127.0.0.1:55000".parse().expect("loopback socket") +} + +pub fn build_auth_manager() -> AuthManager { + let (user, password) = default_test_credentials(); + test_auth_manager(user, password) +} + +pub fn default_credentials() -> (&'static str, &'static str) { + default_test_credentials() +} + +pub async fn spawn_udp_server( + auth: AuthManager, + allocs: AllocationManager, +) -> SocketAddr { + let server = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let addr = server.local_addr().expect("udp addr"); + let server_arc = Arc::new(server); + let reader = server_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = niom_turn::server::udp_reader_loop(reader, auth_clone, alloc_clone).await; + }); + addr +} + +pub fn extract_nonce(msg: &StunMessage) -> Option { + msg.attributes + .iter() + .find(|attr| attr.typ == ATTR_NONCE) + .and_then(|attr| String::from_utf8(attr.value.clone()).ok()) +} diff --git a/tests/auth/integration_tls.rs b/tests/auth/integration_tls.rs new file mode 100644 index 0000000..3ebf5d6 --- /dev/null +++ b/tests/auth/integration_tls.rs @@ -0,0 +1,116 @@ +//! TLS-focused authentication integration tests. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use crate::support::stun_builders::{build_allocate_request, extract_error_code, parse}; +use helpers::*; +use niom_turn::alloc::AllocationManager; +use support::{default_test_credentials, init_tracing, test_auth_manager}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, UdpSocket}; +use tokio_rustls::TlsAcceptor; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tls_rejects_invalid_credentials() { + init_tracing(); + let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let udp_arc = Arc::new(udp); + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let (cert, key) = support::tls::generate_self_signed_cert(); + let mut cfg = tokio_rustls::rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key) + .expect("server config"); + cfg.alpn_protocols.push(b"turn".to_vec()); + let acceptor = TlsAcceptor::from(Arc::new(cfg)); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + let udp_clone = udp_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let acceptor = acceptor.clone(); + let udp_clone = udp_clone.clone(); + let auth_clone = auth_clone.clone(); + let alloc_clone = alloc_clone.clone(); + tokio::spawn(async move { + match acceptor.accept(stream).await { + Ok(mut tls_stream) => { + if let Err(e) = niom_turn::tls::handle_tls_connection( + &mut tls_stream, + peer, + udp_clone, + auth_clone, + alloc_clone, + ) + .await + { + tracing::error!("tls connection error: {:?}", e); + } + } + Err(e) => { + tracing::error!("tls accept failed: {:?}", e); + } + } + }); + } + }); + + let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); + root_store.add(&cert).expect("add root"); + let client_cfg = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_cfg)); + + let tcp_stream = tokio::net::TcpStream::connect(tcp_addr) + .await + .expect("tcp connect"); + let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); + let mut tls_stream = connector + .connect(domain, tcp_stream) + .await + .expect("tls connect"); + + let allocate = build_allocate_request(None, None, None, None, None); + tls_stream + .write_all(&allocate) + .await + .expect("write challenge"); + let mut buf = vec![0u8; 1500]; + let n = tls_stream.read(&mut buf).await.expect("read nonce"); + let resp = parse(&buf[..n]); + let nonce = extract_nonce(&resp).expect("nonce attr"); + + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), "wrongpass"); + let request = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + tls_stream + .write_all(&request) + .await + .expect("write invalid alloc"); + let n = tls_stream.read(&mut buf).await.expect("read reject"); + let resp = parse(&buf[..n]); + let code = extract_error_code(&resp).expect("error attr"); + assert_eq!(code, 401); +} diff --git a/tests/auth/integration_udp.rs b/tests/auth/integration_udp.rs new file mode 100644 index 0000000..e70a803 --- /dev/null +++ b/tests/auth/integration_udp.rs @@ -0,0 +1,53 @@ +//! UDP-focused authentication integration tests. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use crate::support::stun_builders::{build_allocate_request, extract_error_code, parse}; +use helpers::*; +use niom_turn::alloc::AllocationManager; +use support::{default_test_credentials, init_tracing, test_auth_manager}; +use tokio::net::UdpSocket; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn udp_rejects_unknown_user_after_nonce() { + init_tracing(); + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_addr = spawn_udp_server(auth.clone(), allocs.clone()).await; + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + // Trigger initial challenge to receive nonce + let request = build_allocate_request(None, None, None, None, None); + client + .send_to(&request, server_addr) + .await + .expect("send challenge"); + let mut buf = [0u8; 1500]; + let (len, _) = client.recv_from(&mut buf).await.expect("recv nonce"); + let resp = parse(&buf[..len]); + let nonce = extract_nonce(&resp).expect("nonce attr"); + + // Attempt to authenticate with an unknown username + let intruder = "intruder"; + let key = niom_turn::auth::compute_a1_md5(intruder, auth.realm(), "wrongpass"); + let request = build_allocate_request( + Some(intruder), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + client + .send_to(&request, server_addr) + .await + .expect("send invalid auth allocate"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv reject"); + let resp = parse(&buf[..len]); + let code = extract_error_code(&resp).expect("error code attr"); + assert_eq!(code, 401); +} diff --git a/tests/auth/unit.rs b/tests/auth/unit.rs new file mode 100644 index 0000000..6ba479c --- /dev/null +++ b/tests/auth/unit.rs @@ -0,0 +1,80 @@ +//! Auth-specific unit tests driven by mock credential stores. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::auth::{self, AuthStatus}; +use niom_turn::traits::CredentialStore; +use support::mocks; +use crate::support::stun_builders::{build_allocate_request, parse}; + +fn realm_options(realm: &str) -> niom_turn::config::AuthOptions { + let mut opts = default_options(); + opts.realm = realm.to_string(); + opts.nonce_secret = Some("static-secret".into()); + opts +} + +#[tokio::test(flavor = "current_thread")] +async fn credential_store_mock_allows_lookup() { + let mut store = mocks::MockCredentialStore::new(); + store + .expect_get_password() + .with(mocks::predicates::eq("alice")) + .returning(|_| Box::pin(async { Some("s3cret".to_string()) })); + + let password = CredentialStore::get_password(&store, "alice").await; + assert_eq!(password.as_deref(), Some("s3cret")); +} + +#[tokio::test(flavor = "current_thread")] +async fn rejects_mismatched_realm_requests() { + let peer = loopback_peer(); + let mut store = niom_turn::auth::InMemoryStore::new(); + store.insert("alice", "secret"); + let auth = niom_turn::auth::AuthManager::new(store, &realm_options("expected.realm")); + let nonce = auth.mint_nonce(&peer); + let wrong_realm = "other.realm"; + let key = auth::compute_a1_md5("alice", wrong_realm, "secret"); + let buf = build_allocate_request( + Some("alice"), + Some(wrong_realm), + Some(&nonce), + Some(&key), + Some(600), + ); + let msg = parse(&buf); + match auth.authenticate(&msg, &peer).await { + AuthStatus::Reject { code, reason } => { + assert_eq!(code, 400); + assert_eq!(reason, "Realm Mismatch"); + } + other => panic!("unexpected auth result: {:?}", other), + } +} + +#[tokio::test(flavor = "current_thread")] +async fn rejects_unknown_user() { + let peer = loopback_peer(); + let auth = build_auth_manager(); + let nonce = auth.mint_nonce(&peer); + let key = auth::compute_a1_md5("intruder", auth.realm(), "badpass"); + let buf = build_allocate_request( + Some("intruder"), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + let msg = parse(&buf); + match auth.authenticate(&msg, &peer).await { + AuthStatus::Reject { code, reason } => { + assert_eq!(code, 401); + assert_eq!(reason, "Unknown User"); + } + other => panic!("unexpected auth result: {:?}", other), + } +} diff --git a/tests/channel/helpers.rs b/tests/channel/helpers.rs new file mode 100644 index 0000000..19f2703 --- /dev/null +++ b/tests/channel/helpers.rs @@ -0,0 +1,49 @@ +//! Helpers for channel-oriented tests. +use crate::support::{default_test_credentials, test_auth_manager}; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::{AuthManager, InMemoryStore}; +use niom_turn::constants::ATTR_NONCE; +use niom_turn::models::stun::StunMessage; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::UdpSocket; + +pub fn sample_channel_number() -> u16 { + 0x4001 +} + +pub fn sample_peer() -> SocketAddr { + "127.0.0.1:43000".parse().expect("peer addr") +} + +pub fn sample_payload() -> Vec { + b"some-channel-payload".to_vec() +} + +pub fn build_auth_manager() -> AuthManager { + let (user, password) = default_test_credentials(); + test_auth_manager(user, password) +} + +pub async fn spawn_udp_server( + auth: AuthManager, + allocs: AllocationManager, +) -> SocketAddr { + let server = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let addr = server.local_addr().expect("udp addr"); + let arc = Arc::new(server); + let reader = arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = niom_turn::server::udp_reader_loop(reader, auth_clone, alloc_clone).await; + }); + addr +} + +pub fn extract_nonce(msg: &StunMessage) -> Option { + msg.attributes + .iter() + .find(|attr| attr.typ == ATTR_NONCE) + .and_then(|attr| String::from_utf8(attr.value.clone()).ok()) +} diff --git a/tests/channel/integration_tls.rs b/tests/channel/integration_tls.rs new file mode 100644 index 0000000..96d2ab8 --- /dev/null +++ b/tests/channel/integration_tls.rs @@ -0,0 +1,117 @@ +//! TLS channel bind integration tests. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use crate::support::stun_builders::{build_allocate_request, build_channel_bind_request, extract_error_code, parse}; +use helpers::*; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth; +use std::sync::Arc; +use support::{default_test_credentials, init_tracing, test_auth_manager}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, UdpSocket}; +use tokio_rustls::TlsAcceptor; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tls_channel_bind_without_allocation_returns_mismatch() { + init_tracing(); + let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let udp_arc = Arc::new(udp); + let (username, password) = default_test_credentials(); + let auth_manager = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let (cert, key) = support::tls::generate_self_signed_cert(); + let mut cfg = tokio_rustls::rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key) + .expect("server config"); + cfg.alpn_protocols.push(b"turn".to_vec()); + let acceptor = TlsAcceptor::from(Arc::new(cfg)); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + let udp_clone = udp_arc.clone(); + let auth_clone = auth_manager.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let acceptor = acceptor.clone(); + let udp_clone = udp_clone.clone(); + let auth_clone = auth_clone.clone(); + let alloc_clone = alloc_clone.clone(); + tokio::spawn(async move { + match acceptor.accept(stream).await { + Ok(mut tls_stream) => { + if let Err(e) = niom_turn::tls::handle_tls_connection( + &mut tls_stream, + peer, + udp_clone, + auth_clone, + alloc_clone, + ) + .await + { + tracing::error!("tls connection error: {:?}", e); + } + } + Err(e) => tracing::error!("tls accept failed: {:?}", e), + } + }); + } + }); + + let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); + root_store.add(&cert).expect("add root"); + let client_cfg = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_cfg)); + + let tcp_stream = tokio::net::TcpStream::connect(tcp_addr) + .await + .expect("tcp connect"); + let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); + let mut tls_stream = connector + .connect(domain, tcp_stream) + .await + .expect("tls connect"); + + // Obtain nonce via challenge + let allocate = build_allocate_request(None, None, None, None, None); + tls_stream + .write_all(&allocate) + .await + .expect("write challenge"); + let mut buf = vec![0u8; 1500]; + let n = tls_stream.read(&mut buf).await.expect("read nonce"); + let resp = parse(&buf[..n]); + let nonce = extract_nonce(&resp).expect("nonce attr"); + + let key = auth::compute_a1_md5(username, auth_manager.realm(), password); + let channel_req = build_channel_bind_request( + username, + auth_manager.realm(), + &nonce, + &key, + sample_channel_number(), + &sample_peer(), + ); + tls_stream + .write_all(&channel_req) + .await + .expect("write channel bind"); + let n = tls_stream.read(&mut buf).await.expect("read response"); + let resp = parse(&buf[..n]); + let code = extract_error_code(&resp).expect("error attr"); + assert_eq!(code, 437); +} diff --git a/tests/channel/integration_udp.rs b/tests/channel/integration_udp.rs new file mode 100644 index 0000000..d9ae7a0 --- /dev/null +++ b/tests/channel/integration_udp.rs @@ -0,0 +1,55 @@ +//! UDP channel bind integration coverage. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use crate::support::stun_builders::{ + build_allocate_request, build_channel_bind_request, extract_error_code, parse, +}; +use helpers::*; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth; +use support::{default_test_credentials, init_tracing, test_auth_manager}; +use tokio::net::UdpSocket; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn channel_bind_without_allocation_returns_mismatch() { + init_tracing(); + let (username, password) = default_test_credentials(); + let auth_manager = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_addr = spawn_udp_server(auth_manager.clone(), allocs.clone()).await; + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + let mut buf = [0u8; 1500]; + + // Challenge to obtain nonce (no allocation performed yet) + let challenge = build_allocate_request(None, None, None, None, None); + client + .send_to(&challenge, server_addr) + .await + .expect("send challenge"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv nonce"); + let resp = parse(&buf[..len]); + let nonce = extract_nonce(&resp).expect("nonce attr"); + + let key = auth::compute_a1_md5(username, auth_manager.realm(), password); + let channel_req = build_channel_bind_request( + username, + auth_manager.realm(), + &nonce, + &key, + sample_channel_number(), + &sample_peer(), + ); + client + .send_to(&channel_req, server_addr) + .await + .expect("send channel bind"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv error"); + let resp = parse(&buf[..len]); + let code = extract_error_code(&resp).expect("error attr"); + assert_eq!(code, 437); +} diff --git a/tests/channel/unit.rs b/tests/channel/unit.rs new file mode 100644 index 0000000..8cf6e4c --- /dev/null +++ b/tests/channel/unit.rs @@ -0,0 +1,40 @@ +//! Channel-centric unit scaffolding validating mock sinks. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::stun::{build_channel_data, parse_channel_data}; +use support::mocks; + +#[tokio::test(flavor = "current_thread")] +async fn channel_sink_mock_records_payload() { + let mut sink = mocks::MockChannelSink::new(); + sink + .expect_send_channel_data() + .withf(|channel, payload| *channel == sample_channel_number() && payload == &sample_payload()) + .returning(|_, _| Box::pin(async { Ok(()) })); + + sink + .send_channel_data(sample_channel_number(), sample_payload()) + .await + .expect("channel data to send"); +} + +#[test] +fn parse_channel_data_round_trip() { + let payload = sample_payload(); + let frame = build_channel_data(sample_channel_number(), &payload); + let (channel, body) = parse_channel_data(&frame).expect("parse channel frame"); + assert_eq!(channel, sample_channel_number()); + assert_eq!(body, payload.as_slice()); +} + +#[test] +fn parse_channel_data_rejects_invalid_channel_range() { + let mut frame = build_channel_data(sample_channel_number(), &sample_payload()); + frame[0] = 0x20; // invalid prefix (must be 0x40) + assert!(parse_channel_data(&frame).is_none()); +} diff --git a/tests/config/helpers.rs b/tests/config/helpers.rs new file mode 100644 index 0000000..57e431d --- /dev/null +++ b/tests/config/helpers.rs @@ -0,0 +1,24 @@ +//! Helpers for config parsing tests. +use serde_json::json; +use std::io::Write; +use tempfile::NamedTempFile; + +pub fn minimal_config_json() -> String { + json!({ + "server": { "bind": "127.0.0.1:0", "tls_cert": null, "tls_key": null }, + "credentials": [], + "auth": { "realm": "niom-turn.test", "nonce_ttl_seconds": 300 } + }) + .to_string() +} + +pub fn malformed_config_json() -> String { + "{ server: }".to_string() +} + +pub fn write_temp_config(body: &str) -> NamedTempFile { + let mut file = NamedTempFile::new().expect("temp file"); + file.write_all(body.as_bytes()).expect("write temp config"); + file.flush().expect("flush temp config"); + file +} diff --git a/tests/config/integration.rs b/tests/config/integration.rs new file mode 100644 index 0000000..54e51b9 --- /dev/null +++ b/tests/config/integration.rs @@ -0,0 +1,29 @@ +//! Config-driven integration scaffolding testing startup paths. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use serde_json::json; + +#[tokio::test(flavor = "current_thread")] +async fn config_file_round_trip_populates_auth_manager() { + let cfg_json = json!({ + "server": { "bind": "127.0.0.1:3478", "tls_cert": null, "tls_key": null }, + "credentials": [ { "username": "alice", "password": "secret" } ], + "auth": { "realm": "niom-turn.integration", "nonce_ttl_seconds": 120 } + }) + .to_string(); + let temp = helpers::write_temp_config(&cfg_json); + let cfg = niom_turn::config::Config::from_file(temp.path()).expect("config load"); + assert_eq!(cfg.server.bind, "127.0.0.1:3478"); + assert_eq!(cfg.credentials.len(), 1); + + let store = niom_turn::auth::InMemoryStore::new(); + for cred in &cfg.credentials { + store.insert(&cred.username, &cred.password); + } + let auth = niom_turn::auth::AuthManager::new(store, &cfg.auth); + assert_eq!(auth.realm(), "niom-turn.integration"); +} diff --git a/tests/config/unit.rs b/tests/config/unit.rs new file mode 100644 index 0000000..563e46b --- /dev/null +++ b/tests/config/unit.rs @@ -0,0 +1,37 @@ +//! Config parsing unit scaffolding using mock sources. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::config::Config; +use support::mocks; + +#[test] +fn config_source_mock_serves_payload() { + let mut source = mocks::MockConfigSource::new(); + source + .expect_load() + .returning(|| Ok(minimal_config_json())); + + let body = source.load().expect("config payload"); + assert!(body.contains("niom-turn.test")); +} + +#[test] +fn parsing_minimal_config_populates_defaults() { + let json = minimal_config_json(); + let cfg: Config = serde_json::from_str(&json).expect("config parse"); + assert_eq!(cfg.server.bind, "127.0.0.1:0"); + assert_eq!(cfg.auth.realm, "niom-turn.test"); + assert_eq!(cfg.auth.nonce_ttl_seconds, 300); +} + +#[test] +fn malformed_config_fails_to_parse() { + let json = malformed_config_json(); + let parsed: Result = serde_json::from_str(&json); + assert!(parsed.is_err()); +} diff --git a/tests/errors/helpers.rs b/tests/errors/helpers.rs new file mode 100644 index 0000000..957861b --- /dev/null +++ b/tests/errors/helpers.rs @@ -0,0 +1,36 @@ +//! Helpers for error-path tests. +use crate::support::{default_test_credentials, test_auth_manager}; +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::{AuthManager, InMemoryStore}; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::UdpSocket; + +pub fn malformed_stun_frame() -> Vec { + vec![0x00, 0x01, 0x02] // too short for STUN header +} + +pub fn oversized_payload() -> Vec { + vec![0u8; 4096] +} + +pub fn build_auth_manager() -> AuthManager { + let (user, password) = default_test_credentials(); + test_auth_manager(user, password) +} + +pub async fn spawn_udp_server( + auth: AuthManager, + allocs: AllocationManager, +) -> SocketAddr { + let server = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let addr = server.local_addr().expect("udp addr"); + let arc = Arc::new(server); + let reader = arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = niom_turn::server::udp_reader_loop(reader, auth_clone, alloc_clone).await; + }); + addr +} diff --git a/tests/errors/integration_tls.rs b/tests/errors/integration_tls.rs new file mode 100644 index 0000000..111b233 --- /dev/null +++ b/tests/errors/integration_tls.rs @@ -0,0 +1,92 @@ +//! TLS error-path integration tests. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::alloc::AllocationManager; +use std::sync::Arc; +use support::{init_tracing, test_auth_manager, default_test_credentials}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, UdpSocket}; +use tokio::time::{timeout, Duration}; +use tokio_rustls::TlsAcceptor; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn malformed_tls_frame_is_ignored() { + init_tracing(); + let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); + let udp_arc = Arc::new(udp); + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let (cert, key) = support::tls::generate_self_signed_cert(); + let mut cfg = tokio_rustls::rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key) + .expect("server config"); + cfg.alpn_protocols.push(b"turn".to_vec()); + let acceptor = TlsAcceptor::from(Arc::new(cfg)); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + let udp_clone = udp_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let acceptor = acceptor.clone(); + let udp_clone = udp_clone.clone(); + let auth_clone = auth_clone.clone(); + let alloc_clone = alloc_clone.clone(); + tokio::spawn(async move { + match acceptor.accept(stream).await { + Ok(mut tls_stream) => { + let _ = niom_turn::tls::handle_tls_connection( + &mut tls_stream, + peer, + udp_clone, + auth_clone, + alloc_clone, + ) + .await; + } + Err(e) => tracing::error!("tls accept failed: {:?}", e), + } + }); + } + }); + + let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); + root_store.add(&cert).expect("add root"); + let client_cfg = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_cfg)); + + let tcp_stream = tokio::net::TcpStream::connect(tcp_addr) + .await + .expect("tcp connect"); + let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); + let mut tls_stream = connector + .connect(domain, tcp_stream) + .await + .expect("tls connect"); + + tls_stream + .write_all(&malformed_stun_frame()) + .await + .expect("write malformed"); + let mut buf = vec![0u8; 512]; + let result = timeout(Duration::from_millis(200), tls_stream.read(&mut buf)).await; + assert!(result.is_err(), "server should not respond to malformed frame"); +} diff --git a/tests/errors/integration_udp.rs b/tests/errors/integration_udp.rs new file mode 100644 index 0000000..a408f9e --- /dev/null +++ b/tests/errors/integration_udp.rs @@ -0,0 +1,30 @@ +//! UDP error-path integration tests. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::alloc::AllocationManager; +use support::init_tracing; +use tokio::net::UdpSocket; +use tokio::time::{timeout, Duration}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn malformed_packet_is_dropped_without_response() { + init_tracing(); + let auth = build_auth_manager(); + let allocs = AllocationManager::new(); + let server_addr = spawn_udp_server(auth, allocs).await; + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + client + .send_to(&malformed_stun_frame(), server_addr) + .await + .expect("send malformed"); + + let mut buf = [0u8; 1500]; + let recv = timeout(Duration::from_millis(200), client.recv_from(&mut buf)).await; + assert!(recv.is_err(), "server should not respond to malformed frame"); +} diff --git a/tests/errors/unit.rs b/tests/errors/unit.rs new file mode 100644 index 0000000..5b9238f --- /dev/null +++ b/tests/errors/unit.rs @@ -0,0 +1,28 @@ +//! Error-path unit scaffolding leveraging mock frame streams. + +#[path = "../support/mod.rs"] +mod support; + +mod helpers; + +use helpers::*; +use niom_turn::stun::{parse_message, ParseError}; +use support::mocks; + +#[tokio::test(flavor = "current_thread")] +async fn frame_stream_mock_yields_sequence() { + let mut stream = mocks::MockFrameStream::new(); + stream + .expect_next_frame() + .times(1) + .returning(|| Box::pin(async { Some(malformed_stun_frame()) })); + + let first = stream.next_frame().await; + assert_eq!(first.as_ref().map(|f| f.len()), Some(3)); +} + +#[test] +fn parse_message_rejects_short_frame() { + let err = parse_message(&malformed_stun_frame()).unwrap_err(); + assert!(matches!(err, ParseError::TooShort)); +} diff --git a/tests/support/mocks.rs b/tests/support/mocks.rs new file mode 100644 index 0000000..73e3fb2 --- /dev/null +++ b/tests/support/mocks.rs @@ -0,0 +1,94 @@ +//! Centralised mock definitions for test crates. +//! Sections are grouped by domain to keep the mapping between mocks and +//! their intended test areas obvious. +#![allow(dead_code)] + +use async_trait::async_trait; +use mockall::mock; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +pub mod predicates { + #![allow(unused_imports)] + pub use mockall::predicate::*; +} + +// --- Auth domain ----------------------------------------------------------- +mock! { + pub CredentialStore {} + #[async_trait] + impl niom_turn::traits::CredentialStore for CredentialStore { + async fn get_password(&self, username: &str) -> Option; + } +} + +// --- Allocation domain ----------------------------------------------------- +#[async_trait] +pub trait RelayIo: Send + Sync { + async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>; + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result; +} + +mock! { + pub RelayIo {} + #[async_trait] + impl RelayIo for RelayIo { + async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>; + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result; + } +} + +pub trait AllocationClock: Send + Sync { + fn now(&self) -> Instant; + fn advance(&self, delta: Duration); +} + +mock! { + pub AllocationClock {} + impl AllocationClock for AllocationClock { + fn now(&self) -> Instant; + fn advance(&self, delta: Duration); + } +} + +// --- Channel domain -------------------------------------------------------- +#[async_trait] +pub trait ChannelSink: Send + Sync { + async fn send_channel_data(&self, channel: u16, payload: Vec) -> anyhow::Result<()>; + async fn send_data_indication(&self, peer: SocketAddr, payload: Vec) -> anyhow::Result<()>; +} + +mock! { + pub ChannelSink {} + #[async_trait] + impl ChannelSink for ChannelSink { + async fn send_channel_data(&self, channel: u16, payload: Vec) -> anyhow::Result<()>; + async fn send_data_indication(&self, peer: SocketAddr, payload: Vec) -> anyhow::Result<()>; + } +} + +// --- Config domain --------------------------------------------------------- +pub trait ConfigSource: Send + Sync { + fn load(&self) -> Result; +} + +mock! { + pub ConfigSource {} + impl ConfigSource for ConfigSource { + fn load(&self) -> Result; + } +} + +// --- Error/Parser domain --------------------------------------------------- +#[async_trait] +pub trait FrameStream: Send + Sync { + async fn next_frame(&self) -> Option>; +} + +mock! { + pub FrameStream {} + #[async_trait] + impl FrameStream for FrameStream { + async fn next_frame(&self) -> Option>; + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index e950d19..bfe9058 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod mocks; pub mod stun_builders; pub mod tls; diff --git a/tests/support/stun_builders.rs b/tests/support/stun_builders.rs index f3b1019..b840e8c 100644 --- a/tests/support/stun_builders.rs +++ b/tests/support/stun_builders.rs @@ -98,6 +98,45 @@ pub fn build_send_request( ) } +/// Build a ChannelBind request binding `channel` to `peer`. +pub fn build_channel_bind_request( + username: &str, + realm: &str, + nonce: &str, + key: &[u8], + channel: u16, + peer: &std::net::SocketAddr, +) -> Vec { + let mut buf = BytesMut::new(); + buf.extend_from_slice(&METHOD_CHANNEL_BIND.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + let trans = new_transaction_id(); + buf.extend_from_slice(&trans); + + push_string_attr(&mut buf, ATTR_USERNAME, username); + push_string_attr(&mut buf, ATTR_REALM, realm); + push_string_attr(&mut buf, ATTR_NONCE, nonce); + + let mut channel_value = vec![0u8; 4]; + channel_value[0] = (channel >> 8) as u8; + channel_value[1] = channel as u8; + push_bytes_attr(&mut buf, ATTR_CHANNEL_NUMBER, &channel_value); + + let encoded = niom_turn::stun::encode_xor_peer_address(peer, &trans); + push_bytes_attr(&mut buf, ATTR_XOR_PEER_ADDRESS, &encoded); + + append_message_integrity(&mut buf, key); + + // update length + let total_len = (buf.len() - 20) as u16; + let len_bytes = total_len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + + buf.to_vec() +} + fn build_authenticated_request( method: u16, username: Option<&str>, @@ -208,3 +247,38 @@ fn append_message_integrity(buf: &mut BytesMut, key: &[u8]) { pub fn parse(buf: &[u8]) -> niom_turn::models::stun::StunMessage { parse_message(buf).expect("valid stun message") } + +/// Extract ERROR-CODE attribute value (e.g. 401, 437) if present. +pub fn extract_error_code(msg: &niom_turn::models::stun::StunMessage) -> Option { + msg.attributes + .iter() + .find(|a| a.typ == ATTR_ERROR_CODE) + .and_then(|attr| { + if attr.value.len() >= 4 { + let class = attr.value[2] as u16; + let number = attr.value[3] as u16; + Some(class * 100 + number) + } else { + None + } + }) +} + +/// Extract lifetime (seconds) from responses when present. +pub fn extract_lifetime(msg: &niom_turn::models::stun::StunMessage) -> Option { + msg.attributes + .iter() + .find(|a| a.typ == ATTR_LIFETIME) + .and_then(|attr| { + if attr.value.len() >= 4 { + Some(u32::from_be_bytes([ + attr.value[0], + attr.value[1], + attr.value[2], + attr.value[3], + ])) + } else { + None + } + }) +}