//! TLS channel bind integration tests. #[path = "../support/mod.rs"] mod support; mod helpers; use crate::support::stun_builders::{build_allocate_request, build_channel_bind_request, extract_error_code, parse}; use helpers::*; use niom_turn::alloc::AllocationManager; use niom_turn::auth; use std::sync::Arc; use support::{default_test_credentials, init_tracing, test_auth_manager}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, UdpSocket}; use tokio_rustls::TlsAcceptor; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tls_channel_bind_without_allocation_returns_mismatch() { init_tracing(); let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); let udp_arc = Arc::new(udp); let (username, password) = default_test_credentials(); let auth_manager = test_auth_manager(username, password); let allocs = AllocationManager::new(); let (cert, key) = support::tls::generate_self_signed_cert(); let mut cfg = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert.clone()], key) .expect("server config"); cfg.alpn_protocols.push(b"turn".to_vec()); let acceptor = TlsAcceptor::from(Arc::new(cfg)); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); let udp_clone = udp_arc.clone(); let auth_clone = auth_manager.clone(); let alloc_clone = allocs.clone(); tokio::spawn(async move { loop { let (stream, peer) = match tcp_listener.accept().await { Ok(conn) => conn, Err(_) => break, }; let acceptor = acceptor.clone(); let udp_clone = udp_clone.clone(); let auth_clone = auth_clone.clone(); let alloc_clone = alloc_clone.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { Ok(mut tls_stream) => { if let Err(e) = niom_turn::tls::handle_tls_connection( &mut tls_stream, peer, udp_clone, auth_clone, alloc_clone, ) .await { tracing::error!("tls connection error: {:?}", e); } } Err(e) => tracing::error!("tls accept failed: {:?}", e), } }); } }); let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); root_store.add(&cert).expect("add root"); let client_cfg = tokio_rustls::rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(Arc::new(client_cfg)); let tcp_stream = tokio::net::TcpStream::connect(tcp_addr) .await .expect("tcp connect"); let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); let mut tls_stream = connector .connect(domain, tcp_stream) .await .expect("tls connect"); // Obtain nonce via challenge let allocate = build_allocate_request(None, None, None, None, None); tls_stream .write_all(&allocate) .await .expect("write challenge"); let mut buf = vec![0u8; 1500]; let n = tls_stream.read(&mut buf).await.expect("read nonce"); let resp = parse(&buf[..n]); let nonce = extract_nonce(&resp).expect("nonce attr"); let key = auth::compute_a1_md5(username, auth_manager.realm(), password); let channel_req = build_channel_bind_request( username, auth_manager.realm(), &nonce, &key, sample_channel_number(), &sample_peer(), ); tls_stream .write_all(&channel_req) .await .expect("write channel bind"); let n = tls_stream.read(&mut buf).await.expect("read response"); let resp = parse(&buf[..n]); let code = extract_error_code(&resp).expect("error attr"); assert_eq!(code, 437); }