//! TLS error-path integration tests. #[path = "../support/mod.rs"] mod support; mod helpers; use helpers::*; use niom_turn::alloc::AllocationManager; use std::sync::Arc; use support::{init_tracing, test_auth_manager, default_test_credentials}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, UdpSocket}; use tokio::time::{timeout, Duration}; use tokio_rustls::TlsAcceptor; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn malformed_tls_frame_is_ignored() { init_tracing(); let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); let udp_arc = Arc::new(udp); let (username, password) = default_test_credentials(); let auth = test_auth_manager(username, password); let allocs = AllocationManager::new(); let (cert, key) = support::tls::generate_self_signed_cert(); let mut cfg = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert.clone()], key) .expect("server config"); cfg.alpn_protocols.push(b"turn".to_vec()); let acceptor = TlsAcceptor::from(Arc::new(cfg)); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); let udp_clone = udp_arc.clone(); let auth_clone = auth.clone(); let alloc_clone = allocs.clone(); tokio::spawn(async move { loop { let (stream, peer) = match tcp_listener.accept().await { Ok(conn) => conn, Err(_) => break, }; let acceptor = acceptor.clone(); let udp_clone = udp_clone.clone(); let auth_clone = auth_clone.clone(); let alloc_clone = alloc_clone.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { Ok(mut tls_stream) => { let _ = niom_turn::tls::handle_tls_connection( &mut tls_stream, peer, udp_clone, auth_clone, alloc_clone, ) .await; } Err(e) => tracing::error!("tls accept failed: {:?}", e), } }); } }); let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); root_store.add(&cert).expect("add root"); let client_cfg = tokio_rustls::rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(Arc::new(client_cfg)); let tcp_stream = tokio::net::TcpStream::connect(tcp_addr) .await .expect("tcp connect"); let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); let mut tls_stream = connector .connect(domain, tcp_stream) .await .expect("tls connect"); tls_stream .write_all(&malformed_stun_frame()) .await .expect("write malformed"); let mut buf = vec![0u8; 512]; let result = timeout(Duration::from_millis(200), tls_stream.read(&mut buf)).await; assert!(result.is_err(), "server should not respond to malformed frame"); }