niom-turn/src/tls.rs

126 lines
4.1 KiB
Rust

//! TLS listener that wraps the STUN/TURN logic for `turns:` clients.
//! Backlog: ALPN negotiation, TCP relay support, and shared flow-control with the UDP path.
use anyhow::Context;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
use tokio_rustls::TlsAcceptor;
use crate::alloc::AllocationManager;
use crate::auth::{AuthManager, InMemoryStore};
use crate::rate_limit::RateLimiters;
use crate::turn_stream::{handle_turn_stream_connection, handle_turn_stream_connection_with_limits};
fn load_certs(path: &str) -> anyhow::Result<Vec<Certificate>> {
let f = File::open(path).context("opening cert file")?;
let mut reader = BufReader::new(f);
let certs = rustls_pemfile::certs(&mut reader).context("reading certs")?;
Ok(certs.into_iter().map(Certificate).collect())
}
fn load_private_key(path: &str) -> anyhow::Result<PrivateKey> {
let f = File::open(path).context("opening key file")?;
let mut reader = BufReader::new(f);
let keys = rustls_pemfile::pkcs8_private_keys(&mut reader).context("reading pkcs8 keys")?;
if !keys.is_empty() {
return Ok(PrivateKey(keys[0].clone()));
}
// try RSA keys
let mut reader = BufReader::new(File::open(path)?);
let rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).context("reading rsa keys")?;
if !rsa_keys.is_empty() {
return Ok(PrivateKey(rsa_keys[0].clone()));
}
Err(anyhow::anyhow!("no private key found in {}", path))
}
/// Start a TLS-backed listener (turns) on the given bind address.
/// This reuses the existing STUN/TURN message handling logic, but sends replies
/// back over the TLS stream rather than UDP.
pub async fn serve_tls(
bind: &str,
cert_path: &str,
key_path: &str,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> anyhow::Result<()> {
serve_tls_with_limits(
bind,
cert_path,
key_path,
auth,
allocs,
std::sync::Arc::new(RateLimiters::disabled()),
)
.await
}
pub async fn serve_tls_with_limits(
bind: &str,
cert_path: &str,
key_path: &str,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
rate_limiters: std::sync::Arc<RateLimiters>,
) -> anyhow::Result<()> {
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
let cfg = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
let acceptor = TlsAcceptor::from(Arc::new(cfg));
let listener = TcpListener::bind(bind).await?;
tracing::info!("TLS listener bound to {}", bind);
loop {
let (stream, peer) = listener.accept().await?;
let acceptor = acceptor.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_tls_connection_with_limits(tls_stream, peer, auth_clone, alloc_clone, rl).await {
tracing::info!("TLS connection ended for {}: {:?}", peer, e);
}
}
Err(e) => tracing::error!("TLS accept error: {:?}", e),
}
});
}
}
/// Serve STUN/TURN over an already-established stream from `peer`. This variant takes no
/// `RateLimiters`.
///
/// Thin adapter that forwards to `handle_turn_stream_connection`. It only requires
/// `AsyncRead + AsyncWrite`, so a decrypted TLS stream (or any other byte stream) can be
/// passed in. Errors are whatever the underlying stream handler returns.
pub async fn handle_tls_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
    tls_stream: S,
    peer: std::net::SocketAddr,
    auth: AuthManager<InMemoryStore>,
    allocs: AllocationManager,
) -> anyhow::Result<()> {
    let session = handle_turn_stream_connection(tls_stream, peer, auth, allocs);
    session.await
}
/// Serve STUN/TURN over an already-established stream from `peer`, passing `rate_limiters`
/// through to the shared stream handler.
///
/// Thin adapter that forwards to `handle_turn_stream_connection_with_limits`. Errors are
/// whatever the underlying stream handler returns.
pub async fn handle_tls_connection_with_limits<
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
>(
    tls_stream: S,
    peer: std::net::SocketAddr,
    auth: AuthManager<InMemoryStore>,
    allocs: AllocationManager,
    rate_limiters: std::sync::Arc<RateLimiters>,
) -> anyhow::Result<()> {
    let session =
        handle_turn_stream_connection_with_limits(tls_stream, peer, auth, allocs, rate_limiters);
    session.await
}