use std::net::SocketAddr; use std::sync::Arc; use niom_turn::alloc::AllocationManager; use niom_turn::auth::{compute_a1_md5, InMemoryStore}; use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::time::{timeout, Duration}; use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; use crate::support::stream::{StreamFrame, StreamFramer}; use crate::support::stun_builders::{ build_allocate_request, build_channel_bind_request, build_create_permission_request, }; use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager}; mod support; async fn start_tls_test_server( acceptor: TlsAcceptor, auth: niom_turn::auth::AuthManager, allocs: AllocationManager, ) -> SocketAddr { let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); tokio::spawn(async move { loop { let (stream, peer) = match tcp_listener.accept().await { Ok(conn) => conn, Err(_) => break, }; let acceptor = acceptor.clone(); let auth_clone = auth.clone(); let alloc_clone = allocs.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { Ok(tls_stream) => { if let Err(e) = niom_turn::tls::handle_tls_connection( tls_stream, peer, auth_clone, alloc_clone, ) .await { tracing::info!("tls connection ended: {:?}", e); } } Err(e) => tracing::error!("tls accept failed: {:?}", e), } }); } }); tcp_addr } #[tokio::test] async fn tls_channel_data_round_trip_works() { init_tracing_with("warn,niom_turn=info"); let (username, password) = default_test_credentials(); let auth = test_auth_manager(username, password); let allocs = AllocationManager::new(); let (cert, key) = support::tls::generate_self_signed_cert(); let mut cfg = ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert.clone()], key) .expect("server config"); cfg.alpn_protocols.push(b"turn".to_vec()); let acceptor = TlsAcceptor::from(Arc::new(cfg)); let server_addr = start_tls_test_server(acceptor, auth.clone(), allocs.clone()).await; // client config trusting generated cert let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); root_store.add(&cert).expect("add root"); let client_config = tokio_rustls::rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); let tcp_stream = TcpStream::connect(server_addr).await.expect("tcp connect"); let client_addr = tcp_stream.local_addr().expect("client addr"); let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); let mut tls_stream = connector .connect(domain, tcp_stream) .await .expect("tls connect"); // allocate (unauth -> nonce) let allocate = build_allocate_request(None, None, None, None, None); tls_stream.write_all(&allocate).await.expect("write allocate"); let mut framer = StreamFramer::new(); let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) .await .expect("timeout challenge") .expect("read challenge"); let nonce = match challenge { StreamFrame::Stun(msg) => { let nonce_attr = msg .attributes .iter() .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) .expect("nonce attr"); String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8") } _ => panic!("expected STUN 401 challenge"), }; // auth allocate let key = compute_a1_md5(username, auth.realm(), password); let allocate = build_allocate_request( Some(username), Some(auth.realm()), Some(&nonce), Some(&key), Some(600), ); tls_stream.write_all(&allocate).await.expect("write auth allocate"); let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) .await .expect("timeout alloc success") .expect("read alloc success"); match alloc_success { StreamFrame::Stun(msg) => { msg.attributes .iter() .find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS) .expect("xor-relayed attr"); } _ => panic!("expected STUN allocate success"), } let relay_addr = allocs .get_allocation(&client_addr) .expect("allocation exists") .relay_addr; // set up peer + permission let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind"); let perm = build_create_permission_request( username, auth.realm(), &nonce, &key, &peer_sock.local_addr().unwrap(), ); tls_stream.write_all(&perm).await.expect("write create permission"); let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) .await .expect("timeout perm resp") .expect("read perm resp"); // channel bind let channel: u16 = 0x4000; let bind = build_channel_bind_request( username, auth.realm(), &nonce, &key, channel, &peer_sock.local_addr().unwrap(), ); tls_stream.write_all(&bind).await.expect("write channel bind"); let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) .await .expect("timeout bind resp") .expect("read bind resp"); // client -> peer via ChannelData let payload = b"tls-chan"; let frame = niom_turn::stun::build_channel_data(channel, payload); tls_stream.write_all(&frame).await.expect("write channel data"); let mut peer_buf = [0u8; 1500]; let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf)) .await .expect("timeout peer recv") .expect("peer recv"); assert_eq!(&peer_buf[..n], payload); assert!(from.ip().is_loopback()); // peer -> client as ChannelData over TLS let back = b"tls-back"; peer_sock.send_to(back, relay_addr).await.expect("peer send back"); let received = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) .await .expect("timeout channel back") .expect("read channel back"); match received { StreamFrame::ChannelData { channel: ch, payload } => { assert_eq!(ch, channel); assert_eq!(payload.as_slice(), back); } other => panic!("expected ChannelData frame, got: {:?}", other), } }