diff --git a/Cargo.lock b/Cargo.lock index d5f3653..874239e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.9.4" @@ -396,6 +402,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "base64 0.22.1", "bytes", "crc32fast", "hex", diff --git a/Cargo.toml b/Cargo.toml index 408de9f..29951f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ bytes = "1.4" hmac = "0.12" sha1 = "0.10" hex = "0.4" +base64 = "0.22" # config and logging serde = { version = "1.0", features = ["derive"] } diff --git a/README.md b/README.md index 7ab7f8c..d2f6513 100644 --- a/README.md +++ b/README.md @@ -147,9 +147,34 @@ Das Projekt kann eine JSON-Konfigdatei `appsettings.json` im Arbeitsverzeichnis { "server": { "bind": "0.0.0.0:3478", + "udp_bind": null, + "tcp_bind": null, + "tls_bind": "0.0.0.0:5349", + "enable_udp": true, + "enable_tcp": true, + "enable_tls": true, "tls_cert": null, "tls_key": null }, + "relay": { + "relay_port_min": null, + "relay_port_max": null, + "relay_bind_ip": "0.0.0.0", + "advertised_ip": null + }, + "logging": { + "default_directive": "warn,niom_turn=info" + }, + "limits": { + "max_allocations_per_ip": null, + "max_permissions_per_allocation": null, + "max_channel_bindings_per_allocation": null, + + "unauth_rps": null, + "unauth_burst": null, + "binding_rps": null, + "binding_burst": null + }, "credentials": [ { "username": "testuser", @@ -159,12 +184,25 @@ Das Projekt kann eine JSON-Konfigdatei `appsettings.json` im Arbeitsverzeichnis "auth": { "realm": "niom-turn.local", "nonce_secret": null, - "nonce_ttl_seconds": 300 + "nonce_ttl_seconds": 300, + "rest_secret": null, + "rest_max_ttl_seconds": 600 } } ``` -Wenn `appsettings.json` vorhanden ist, verwendet der Server die `server.bind` Adresse, befüllt den Credential-Store aus dem `credentials`-Array und übernimmt zusätzlich Realm/Nonce-Einstellungen aus `auth`. Falls die Datei fehlt, verwendet der Server die internen Defaults (Bind `0.0.0.0:3478`, Demo-Cred `testuser`, Realm `niom-turn.local`). +Wenn `appsettings.json` vorhanden ist, befüllt der Server den Credential-Store aus `credentials` und übernimmt Realm/Nonce/REST-Einstellungen aus `auth`. + +Listener-Binds: +- `server.bind` ist der Legacy-Default (wird genutzt, wenn `udp_bind`/`tcp_bind` nicht gesetzt sind). +- TCP/TLS nutzen denselben TURN-over-Stream Handler; `turns:` läuft auf `server.tls_bind`. +- Wenn `server.enable_tls=true`, aber `tls_cert`/`tls_key` fehlen, wird der TLS-Listener übersprungen. + +Relay/NAT: +- Hinter NAT sollte `relay.advertised_ip` auf die öffentliche IP gesetzt werden, damit Clients in XOR-RELAYED-ADDRESS eine erreichbare Adresse erhalten. +- Für Firewalls ist ein fixer Relay-Port-Range (`relay_port_min/max`) sinnvoll. + +Details und Runbook: siehe `docs/config/runtime.md`. Deployment & TLS / Long-term Auth roadmap ----------------------------------------- diff --git a/appsettings.example.json b/appsettings.example.json index 1db6abe..2d8c070 100644 --- a/appsettings.example.json +++ b/appsettings.example.json @@ -1,9 +1,34 @@ { "server": { "bind": "0.0.0.0:3478", + "udp_bind": null, + "tcp_bind": null, + "tls_bind": "0.0.0.0:5349", + "enable_udp": true, + "enable_tcp": true, + "enable_tls": true, "tls_cert": null, "tls_key": null }, + "relay": { + "relay_port_min": null, + "relay_port_max": null, + "relay_bind_ip": "0.0.0.0", + "advertised_ip": null + }, + "logging": { + "default_directive": "warn,niom_turn=info" + }, + "limits": { + "max_allocations_per_ip": null, + "max_permissions_per_allocation": null, + "max_channel_bindings_per_allocation": null, + + "unauth_rps": null, + "unauth_burst": null, + "binding_rps": null, + "binding_burst": null + }, "credentials": [ { "username": "testuser", @@ -13,6 +38,8 @@ "auth": { "realm": "niom-turn.local", "nonce_secret": null, - "nonce_ttl_seconds": 300 + "nonce_ttl_seconds": 300, + "rest_secret": null, + "rest_max_ttl_seconds": 600 } } diff --git a/docs/config/runtime.md b/docs/config/runtime.md index 23ede01..f2f73a9 100644 --- a/docs/config/runtime.md +++ b/docs/config/runtime.md @@ -9,9 +9,34 @@ Config { server: ServerOptions { bind: String, + udp_bind: Option, + tcp_bind: Option, + tls_bind: String, + enable_udp: bool, + enable_tcp: bool, + enable_tls: bool, tls_cert: Option, tls_key: Option, }, + relay: RelayOptions { + relay_port_min: Option, + relay_port_max: Option, + relay_bind_ip: Option, + advertised_ip: Option, + }, + logging: LoggingOptions { + default_directive: Option, + }, + limits: LimitsOptions { + max_allocations_per_ip: Option, + max_permissions_per_allocation: Option, + max_channel_bindings_per_allocation: Option, + + unauth_rps: Option, + unauth_burst: Option, + binding_rps: Option, + binding_burst: Option, + }, credentials: Vec { username: String, password: String, @@ -23,7 +48,125 @@ Config { - Bind: `0.0.0.0:3478` - Single Test Credential: `testuser` / `secretpassword` +## Runbook (Start & Betrieb) + +### 1) Konfiguration anlegen + +- Der Server lädt **immer** `appsettings.json` aus dem aktuellen Working Directory (siehe `Config::load_default()`). +- Als Basis kannst du `appsettings.example.json` nach `appsettings.json` kopieren und anpassen. + +### 2) Server starten + +```bash +# im Repo-Root +cargo run --bin niom-turn +``` + +Optional kannst du das Log-Level überschreiben: + +```bash +RUST_LOG=info cargo run --bin niom-turn +``` + +Hinweise: +- Wenn `server.enable_tls=true`, aber `server.tls_cert`/`server.tls_key` fehlen, wird der TLS-Listener **übersprungen** (Info-Log). +- Beim Start loggt der Server, welche Listener aktiv sind und welche Relay-Optionen gelten. + +### 3) Minimal-Konfig: UDP-only (einfachster Start) + +```json +{ + "server": { + "bind": "0.0.0.0:3478", + "enable_udp": true, + "enable_tcp": false, + "enable_tls": false, + "tls_cert": null, + "tls_key": null + }, + "relay": { + "relay_bind_ip": "0.0.0.0", + "advertised_ip": null, + "relay_port_min": null, + "relay_port_max": null + }, + "logging": { + "default_directive": "warn,niom_turn=info" + }, + "limits": { + "max_allocations_per_ip": null, + "max_permissions_per_allocation": null, + "max_channel_bindings_per_allocation": null, + + "unauth_rps": null, + "unauth_burst": null, + "binding_rps": null, + "binding_burst": null + }, + "credentials": [ + { "username": "testuser", "password": "secretpassword" } + ], + "auth": { + "realm": "niom-turn.local", + "nonce_secret": null, + "nonce_ttl_seconds": 300, + "rest_secret": null, + "rest_max_ttl_seconds": 600 + } +} +``` + +### 4) Minimal-Konfig: UDP + TCP + TLS (für maximale WebRTC-Kompatibilität) + +```json +{ + "server": { + "bind": "0.0.0.0:3478", + "tls_bind": "0.0.0.0:5349", + "enable_udp": true, + "enable_tcp": true, + "enable_tls": true, + "tls_cert": "/etc/niom-turn/tls/fullchain.pem", + "tls_key": "/etc/niom-turn/tls/privkey.pem" + }, + "relay": { + "relay_bind_ip": "0.0.0.0", + "advertised_ip": "203.0.113.10", + "relay_port_min": 49152, + "relay_port_max": 49252 + }, + "logging": { + "default_directive": "warn,niom_turn=info" + }, + "limits": { + "max_allocations_per_ip": 50, + "max_permissions_per_allocation": 200, + "max_channel_bindings_per_allocation": 50, + + "unauth_rps": 5, + "unauth_burst": 20, + "binding_rps": 50, + "binding_burst": 200 + }, + "credentials": [ + { "username": "testuser", "password": "secretpassword" } + ] +} +``` + +Betriebs-Checkliste (kurz): +- Firewall öffnen: UDP/3478, TCP/3478, TCP/5349 sowie den UDP-Relay-Portbereich (`relay_port_min/max`). +- Hinter NAT: `relay.advertised_ip` auf die öffentliche IP setzen. +- Für TURN REST Credentials: `auth.rest_secret` setzen und ggf. feste `credentials` leer lassen. + ## TODOs - Shared Secret / REST API zur Credential-Verwaltung. -- Konfigurierbare TLS-Bind-Adresse (`turns` Standard 5349). - Health-Port (HTTP) für Monitoring. +- Rate-Limits/Quota (pro Username) und Relay-Port-Range produktiv setzen. + +## Hinweise für produktiven Betrieb + +- Wenn der Server hinter NAT läuft, setze `relay.advertised_ip` auf die öffentliche IP, damit Clients in XOR-RELAYED-ADDRESS eine erreichbare Adresse erhalten. +- Für Firewalls ist ein fester Relay-Port-Range sinnvoll (`relay_port_min/max`), damit nur dieser UDP-Bereich geöffnet werden muss. + +- Allocations werden auch ohne Traffic periodisch auf Expiry geprüft und bereinigt (Housekeeping), damit Relay-Sockets/Tasks nicht dauerhaft hängen bleiben. diff --git a/docs/deployment.md b/docs/deployment.md new file mode 100644 index 0000000..a516927 --- /dev/null +++ b/docs/deployment.md @@ -0,0 +1,136 @@ +# Deployment Guide (niom-turn) + +This guide assumes a fresh Debian LXC (e.g., 10.0.0.22), Fritzbox port forwards are in place, and you want TURN reachable on 3478/udp+tcp and 5349/tcp with a UDP relay range (e.g., 49152-49200). + +## 1) Install dependencies + +```bash +sudo apt update +sudo apt install -y build-essential pkg-config libssl-dev curl git systemd +# Rust toolchain (stable) +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" +``` + +## 2) Clone and build + +```bash +cd /opt +sudo mkdir -p niom-turn && sudo chown "$USER":"$USER" niom-turn +cd /opt/niom-turn +git clone https://github.com//niom-turn.git . +cargo build --release +# Binary: target/release/niom-turn +``` + +## 3) Configuration + +Create config dir and place TLS cert/key (exported from NPM) and config: + +```bash +sudo mkdir -p /etc/niom-turn +sudo chown "$USER":"$USER" /etc/niom-turn +# place /etc/niom-turn/fullchain.pem and /etc/niom-turn/privkey.pem +``` + +Example `/etc/niom-turn/appsettings.json` (adjust realm, WAN IP, secrets): + +```json +{ + "logging": { "level": "info" }, + "auth": { + "realm": "turn.example.com", + "nonce_ttl_seconds": 600, + "rest_secret": "CHANGE_ME_REST_SECRET", + "rest_max_ttl_seconds": 86400 + }, + "listeners": { + "udp": "0.0.0.0:3478", + "tcp": "0.0.0.0:3478", + "tls": { + "addr": "0.0.0.0:5349", + "cert_file": "/etc/niom-turn/fullchain.pem", + "key_file": "/etc/niom-turn/privkey.pem" + } + }, + "relay": { + "bind_addr": "0.0.0.0", + "public_addr": "YOUR_WAN_IP", + "port_range": "49152-49200" + }, + "rate_limits": { + "enabled": true, + "max_allocations_per_ip": 10, + "max_permissions_per_allocation": 10, + "max_channels_per_allocation": 10 + } +} +``` + +- `public_addr` must be your public WAN IP (not the LXC IP). +- `rest_secret` is used for TURN REST credentials (time-based user/pass). + +## 4) Systemd service + +Install binary and user: + +```bash +sudo cp /opt/niom-turn/target/release/niom-turn /usr/local/bin/niom-turn +sudo useradd --system --no-create-home --shell /usr/sbin/nologin niomturn +sudo chown root:root /usr/local/bin/niom-turn +sudo chmod 0755 /usr/local/bin/niom-turn +sudo chown -R niomturn:niomturn /etc/niom-turn +``` + +Create `/etc/systemd/system/niom-turn.service`: + +``` +[Unit] +Description=niom-turn +After=network.target + +[Service] +User=niomturn +Group=niomturn +ExecStart=/usr/local/bin/niom-turn --config /etc/niom-turn/appsettings.json +Environment=RUST_LOG=debug,niom_turn=debug +Restart=on-failure +RestartSec=3 +# Optional: LimitNOFILE=65535 + +[Install] +WantedBy=multi-user.target +``` + +Enable/start: + +```bash +sudo systemctl daemon-reload +sudo systemctl enable --now niom-turn +``` + +## 5) Firewall (LXC) + +Allow inbound: UDP 3478, TCP 3478, TCP 5349, UDP relay range (49152-49200). Outbound allow all. + +## 6) Quick checks + +- Listener ports: `ss -tulpen | grep -E '3478|5349'` +- Logs: `journalctl -u niom-turn -f` +- External TCP reachability (from Hotspot): `nc -vz turn.example.com 3478` and `nc -vz turn.example.com 5349` +- STUN/TURN test: `stunclient turn.example.com 3478 -u user -p pass` (or REST creds) +- WebRTC: open webrtc-internals / about:webrtc; ensure relay candidates show your WAN IP + ports in 49152-49200. + +## 7) Fritzbox / Port forwards (reference) + +- UDP 3478 → 10.0.0.22:3478 +- TCP 3478 → 10.0.0.22:3478 +- TCP 5349 → 10.0.0.22:5349 +- UDP 49152-49200 → 10.0.0.22:49152-49200 +Test from external network (Hotspot), not from LAN (avoid NAT loopback assumptions). + +## 8) Tuning / next steps + +- For more logs temporarily set `RUST_LOG=trace,niom_turn=trace` in the service env. +- Consider JSON logging + metrics export if you need richer observability. +- Keep certs renewed via NPM and re-export to the LXC. diff --git a/docs/index.md b/docs/index.md index f234b3c..a6ffbc3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,6 +4,12 @@ Dokumentationsübersicht für den TURN/STUN-Server. - [`architecture/data_flow.md`](architecture/data_flow.md) – UDP/TLS-Loop, Allocation Manager. - [`config/runtime.md`](config/runtime.md) – Appsettings & Credentials. +- [`turn_end_to_end_flow.md`](turn_end_to_end_flow.md) – Sequenzgrafik: Allocate → Permission → Send/ChannelData → Rückweg (UDP + TLS). +- [`tcp_tls_data_plane.md`](tcp_tls_data_plane.md) – Warum TCP/TLS Data-Plane wichtig ist und wie sie implementiert ist. +- [`mvp_gaps_and_rfc_notes.md`](mvp_gaps_and_rfc_notes.md) – MVP-Lücken, RFC-Notizen und Auswirkungen. +- [`turn_rest_credentials.md`](turn_rest_credentials.md) – TURN REST Credentials: Algorithmus, CLI-Tool und Betrieb. +- [`testing.md`](testing.md) – Testübersicht: was abgedeckt ist und wie man es ausführt. +- [`testing_todo.md`](testing_todo.md) – Vorschläge für zusätzliche Tests (Roadmap). - Bereits vorhanden: `deploy_tls_lxc.md`, RFC-Referenzen (STUN/TURN Specs). ## Zielsetzung diff --git a/docs/mvp_gaps_and_rfc_notes.md b/docs/mvp_gaps_and_rfc_notes.md new file mode 100644 index 0000000..1658015 --- /dev/null +++ b/docs/mvp_gaps_and_rfc_notes.md @@ -0,0 +1,114 @@ +# MVP-Lücken & RFC-Notizen (STUN/TURN) + +Dieses Dokument listet **bewusst vereinfachte/fehlende Teile** im aktuellen `niom-turn` MVP auf, jeweils mit kurzer Auswirkung und Code-Anker. + +> Ziel: Klarheit, was schon interoperabel ist, und wo (für Production/Interop) noch Arbeit nötig ist. + +--- + +## 1) STUN Binding ist minimal + +**Ist-Zustand** +- `METHOD_BINDING` beantwortet der Server als "Success" und enthält `XOR-MAPPED-ADDRESS` (IPv4+IPv6): [src/stun.rs](../src/stun.rs) + +**Auswirkung** +- Für STUN-Diagnose/ICE-NAT-Discovery ist das damit deutlich interoperabler. + +--- + +## 2) IPv6 wird nicht unterstützt (historisch) + +**Ist-Zustand** + - XOR-Address Encoding/Decoding unterstützt IPv4 **und** IPv6 (XOR-Key: Magic Cookie + Transaction ID): [src/stun.rs](../src/stun.rs) + +**Auswirkung** +- TURN/STUN kann IPv6-Adressen in XOR-ADDRESS Attributen korrekt verarbeiten. + +--- + +## 3) TURN Allocate ist stark vereinfacht + +**Ist-Zustand** +- `allocate_for` bindet immer ein UDP-Relay auf `0.0.0.0:0`: [src/alloc.rs](../src/alloc.rs) +- `REQUESTED-TRANSPORT` wird nicht ausgewertet (es gibt nur UDP-Relay). +- Weitere RFC-TURN Optionen (EVEN-PORT, RESERVATION-TOKEN, DONT-FRAGMENT, etc.) sind nicht implementiert. + +**Auswirkung** +- Das MVP deckt den typischen UDP-Relay-Fall ab, aber nicht die volle TURN-Feature-Matrix. + +--- + +## 4) Data Plane: UDP-relay ja, TCP-relay nein + +**Ist-Zustand** +- Relay-Datenpfad ist UDP-only (Relay-Socket ist `tokio::net::UdpSocket`): [src/alloc.rs](../src/alloc.rs) +- TLS-Listener re-used Control-Plane Logik, aber kein TCP-Relay: [src/tls.rs](../src/tls.rs) + +**Auswirkung** +- `turns:` (TLS) kann Control-Plane, aber der Rückweg der relayed Daten läuft derzeit weiterhin über UDP an die Client-UDP-Adresse. +- Für echte TURN-over-TCP/TLS Data Plane (Interoperabilität mit restriktiven Netzwerken) fehlt ein eigener TCP-Relay-Pfad. + +--- + +## 5) Allocation- und Timer-Verhalten ist MVP-artig + +**Ist-Zustand** +- Allocation-Lifetime wird über `refresh_allocation` gesetzt/geklemmt: [src/alloc.rs](../src/alloc.rs) +- `remove_allocation` entfernt den Eintrag. +- Der Relay-Task (spawn in `allocate_for`) läuft jedoch weiter und prüft nur noch, ob Allocation existiert: [src/alloc.rs](../src/alloc.rs) + +**Auswirkung** +- Bei vielen Allocations kann das zu unnötigen Hintergrund-Tasks führen (Resource-Management/Backpressure fehlt). + +--- + +## 6) Permissions & ChannelBindings sind minimal + +**Ist-Zustand** +- Permission TTL ist statisch (300s) und wird nur durch erneutes `CreatePermission` erneuert: [src/alloc.rs](../src/alloc.rs) +- ChannelBindings nutzen denselben TTL-Wert (`PERMISSION_LIFETIME`): [src/alloc.rs](../src/alloc.rs) + +**Auswirkung** +- Funktional ok für MVP, aber nicht unbedingt exakt RFC-getreu im Detail (separate Lifetime-Regeln/Refresh-Strategien sind üblich). + +--- + +## 7) MESSAGE-INTEGRITY/FINGERPRINT + +**Ist-Zustand** +- `MESSAGE-INTEGRITY` wird RFC-konform geprüft (HMAC-SHA1 über die Message bis inkl. MI-Attribut; Header-Länge wird dafür auf „Ende von MI“ gesetzt, d.h. kompatibel mit nachfolgenden Attributen wie `FINGERPRINT`): [src/stun.rs](../src/stun.rs) +- Für TURN-Responses nach erfolgreicher Authentisierung hängt der Server `MESSAGE-INTEGRITY` an und setzt `FINGERPRINT` als letztes Attribut: [src/server.rs](../src/server.rs), [src/turn_stream.rs](../src/turn_stream.rs) +- `FINGERPRINT` wird bei allen vom Server gebauten STUN-Nachrichten als letztes Attribut angehängt und bei eingehenden Nachrichten (falls vorhanden) validiert: [src/stun.rs](../src/stun.rs) + +**Auswirkung** +- Bessere Browser/ICE-Interop und leichtes Hardening (Messages mit ungültigem `FINGERPRINT` werden verworfen). + +--- + +## 8) Observability / Limits / Hardening fehlen (noch) + +**Ist-Zustand** +- Keine Quotas pro User/IP, keine Rate-Limits, keine Bandbreitenlimits pro Allocation. +- Credential Store ist In-Memory (Test/Dev): [src/auth.rs](../src/auth.rs), Trait: [src/traits/credential_store.rs](../src/traits/credential_store.rs) + +**Auswirkung** +- Für Production braucht es Limits, persistente Credentials, Monitoring/Metrics und härteres Error-Handling. + +--- + +## 9) Weitere RFC-Ecken (nicht implementiert) + +Typische Punkte, die im MVP fehlen/noch offen sind: +- Full attribute coverage (z.B. UNKNOWN-ATTRIBUTES, SOFTWARE, etc.) +- Vollständige STUN/TURN Class/Method Encoding nach RFC (hier bewusst vereinfacht über `METHOD | CLASS_*`): [src/constants.rs](../src/constants.rs) +- IPv6, Hairpinning-Sonderfälle, NAT-bezogene Interop-Edge-Cases + +--- + +## Was bereits gut abgedeckt ist + +- End-to-end UDP TURN Core: `Allocate` → `CreatePermission` → `Send`/ChannelData → Rückweg als Data Indication/ChannelData. +- Long-term Auth (Realm/Nonce + MI) mit klaren 401/438/437/403 Pfaden. +- TLS Listener (Control-Plane) mit STUN-Framing über TCP/TLS. + +Siehe auch die Integrationstests: [tests/udp_turn.rs](../tests/udp_turn.rs) und [tests/tls_turn.rs](../tests/tls_turn.rs). diff --git a/docs/tcp_tls_data_plane.md b/docs/tcp_tls_data_plane.md new file mode 100644 index 0000000..5ea71b7 --- /dev/null +++ b/docs/tcp_tls_data_plane.md @@ -0,0 +1,98 @@ +# TCP/TLS Data-Plane (TURN over TCP/TLS) in niom-turn + +Dieses Dokument erklärt **wofür** eine TCP/TLS Data-Plane bei TURN gebraucht wird und **wie** `niom-turn` sie (aktuell) implementiert. + +## Wofür wird das benötigt? + +TURN hat zwei Arten von Traffic: + +- **Control-Plane**: STUN/TURN Requests/Responses (Allocate, CreatePermission, ChannelBind, Refresh, …) +- **Data-Plane**: Nutzdaten zwischen Client und Peer, die über das Relay laufen (Send/Data-Indication oder ChannelData) + +In vielen realen Netzen ist **UDP eingeschränkt oder komplett blockiert** (Corporate WLAN, Mobilfunk-APNs, Captive Portals, Proxy-Umgebungen). WebRTC/ICE versucht deshalb typischerweise: + +1. UDP (schnell, bevorzugt) +2. TURN über TCP +3. TURN über TLS ("turns:") als letzte, aber oft funktionierende Option + +Damit TURN über TCP/TLS wirklich nutzbar ist, muss nicht nur die Control-Plane über den Stream laufen, sondern auch der **Rückweg Peer → Client** (Data-Plane) über **dieselbe TCP/TLS Verbindung** beim Client ankommen. + +## Was implementiert niom-turn konkret? + +- Client ↔ Server Transport kann **UDP**, **TCP** oder **TLS** sein. +- Das Relay zum Peer ist weiterhin **UDP** (klassisches TURN UDP-Relay). +- Bei **TCP/TLS** liefert der Server die Data-Plane zurück an den Client über den **Stream** (statt über UDP an die Client-Adresse zu senden). + +Das entspricht dem üblichen WebRTC-Fallback: "Client-Server über TCP/TLS, Peer-Transport über UDP". + +## Architektur im Code + +- Stream-Handler (TCP/TLS): [src/turn_stream.rs](../src/turn_stream.rs) +- TCP Listener: [src/tcp.rs](../src/tcp.rs) +- TLS Listener: [src/tls.rs](../src/tls.rs) +- Allocation + Relay: [src/alloc.rs](../src/alloc.rs) + +### Schlüsselidee: `ClientSink` + +Damit der Relay-Loop Peer-Pakete an unterschiedliche Client-Transporte schicken kann, gibt es in [src/alloc.rs](../src/alloc.rs) eine Abstraktion: + +- `ClientSink::Udp { sock, addr }` → sendet Peer-Daten per `udp.send_to(..., addr)` +- `ClientSink::Stream { tx }` → queued Bytes in einen Writer-Task, der auf den TCP/TLS Stream schreibt + +Wenn ein Client über TCP/TLS allocatet, wird die Allocation mit einem `ClientSink::Stream` erzeugt. + +## Framing: STUN vs. ChannelData auf einem Byte-Stream + +Auf UDP bekommt man Datagramme; auf TCP/TLS bekommt man einen **kontinuierlichen Byte-Stream**. TURN over TCP/TLS multiplexed: + +- STUN/TURN Messages (Control-Plane) +- ChannelData Frames (Data-Plane, Client → Server) + +`niom-turn` parst daher im Stream in einer Schleife "nächstes Frame" (siehe `try_pop_next_frame(...)` in [src/turn_stream.rs](../src/turn_stream.rs)): + +### STUN Message + +- Header ist 20 Bytes +- Length-Feld ist die Body-Länge +- Gesamtlänge ist: $20 + length$ + +### ChannelData Frame + +- 4 Byte Header: `CHANNEL-NUMBER` (2) + `LENGTH` (2) +- Channel-Nummern liegen im Bereich `0x4000..=0x7FFF` (Top-Bits `01`) +- Gesamtlänge ist: $4 + length$ + +Wichtig: Bei TCP/TLS darf **kein Padding** als "separate Bytes" im Stream verbleiben. Deshalb baut `niom-turn` ChannelData als exakt `4 + len` Bytes (siehe [src/stun.rs](../src/stun.rs)). + +### Hardening: Resync & Limits + +Da TCP/TLS ein Byte-Stream ist, können kaputte oder bösartige Clients den Parser sonst leicht „desynchronisieren“. +`niom-turn` implementiert daher im Stream-Parser: + +- **Magic-Cookie Check** für STUN: Ungültige Cookies führen zu einem Byte-weisen Resync (statt auf riesige Längen zu warten). +- **Frame-Size Limits** (STUN-Body und ChannelData), um Speicher-/DoS-Risiken zu begrenzen. +- **Max Buffer Limit** pro Verbindung: wenn der Eingangspuffer zu groß wird, wird die Verbindung geschlossen. + +## Datenfluss (TCP/TLS) + +1. Client verbindet sich per TCP oder TLS. +2. Der Stream-Handler liest Frames: + - STUN/TURN Requests → verarbeitet wie UDP-Pfad (Auth, Allocation, Permission, ChannelBind, Send, Refresh) + - ChannelData (Client→Peer) → wird über das UDP-Relay an den Peer geschickt +3. Peer sendet UDP an die Relay-Adresse. +4. Relay-Loop leitet die Bytes an den `ClientSink` weiter: + - bei Stream: `tx.send(bytes)` → Writer-Task schreibt Data-Indication oder ChannelData zurück auf denselben Stream + +## Grenzen / Noch nicht implementiert + +- Kein TCP-Relay zum Peer (TURN TCP allocations / CONNECT-Methoden wie in RFC6062). +- Fokus liegt auf: Client-Server Transport über TCP/TLS + UDP-Relay. +- IPv6 ist im aktuellen MVP noch nicht vollständig umgesetzt. + +## Tests + +- TCP Stream Data-Plane: `tests/tcp_turn.rs` +- TLS Stream Data-Plane: `tests/tls_data_plane.rs` +- Gemeinsames Framing (STUN + ChannelData): `tests/support/stream.rs` + +Wenn du willst, kann ich als nächsten Schritt die Doku um eine kurze Interop-Checkliste (WebRTC/ICE Verhalten, Candidate-Types, typische Fehlerbilder) ergänzen. diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..4d7445b --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,60 @@ +# Testing + +Dieses Projekt ist so aufgebaut, dass sich die wichtigsten TURN-Pfade als **Unit- und Integrationstests** ausführen lassen. + +## Schnellstart + +- Alle Tests: `cargo test` +- Mit weniger Output: `cargo test -q` +- Mit Logs (Beispiele): + - `RUST_LOG=warn,niom_turn=info cargo test -- --nocapture` + +Hinweis: Die Integrationstests initialisieren `tracing` über die Helpers in `tests/support`. + +## Was wird getestet? + +### STUN RFC-Interop (FINGERPRINT) + +- Unit-Tests prüfen, dass der Server `FINGERPRINT` an Responses anhängt und dass die CRC32/XOR Validierung fehlschlägt, wenn die Nachricht manipuliert wird. + +Siehe: Unit-Tests in [src/stun.rs](../src/stun.rs). + +### STUN RFC-Interop (MESSAGE-INTEGRITY) + +- Unit-Tests prüfen `MESSAGE-INTEGRITY` Validierung (inkl. Fall „MI + nachfolgendes FINGERPRINT“). +- UDP-Integrationstests prüfen, dass Responses nach erfolgreicher Authentisierung `MESSAGE-INTEGRITY` enthalten und validierbar sind. + +### UDP (turn:) + +- Auth-Challenge (401 + NONCE) und erfolgreicher Allocate +- Refresh mit `LIFETIME=0` entfernt Allocation +- CreatePermission + Send → Peer erhält UDP Payload über Relay + +Siehe: `tests/udp_turn.rs` sowie die thematischen Ordner in `tests/`. + +### TLS (turns:) / Stream-basierte Data-Plane + +- Allocate/Refresh über TLS-Stream +- (Neu) Data-Plane Rückweg Peer→Client über denselben TLS-Stream (Data-Indication oder ChannelData) + +Siehe: `tests/tls_turn.rs` und `tests/tls_data_plane.rs`. + +### TCP (turn:?transport=tcp) / Stream-basierte Data-Plane + +- (Neu) Allocate/CreatePermission/Send über TCP +- (Neu) Peer→Client Rückweg als STUN Data Indication über TCP +- (Neu) ChannelBind + ChannelData in beide Richtungen + +Siehe: `tests/tcp_turn.rs`. + +## Test-Hilfen + +- STUN/TURN Builder: `tests/support/stun_builders.rs` +- Stream-Framing (STUN + ChannelData über TCP/TLS): `tests/support/stream.rs` +- TLS Test-Certs: `tests/support/tls.rs` + +## Erweiterungsideen (nächste sinnvolle Abdeckung) + +- Split-Reads/Writes (Frames in mehreren TCP Reads) als Regressionstest +- IPv6 Encode/Decode Tests für XOR-ADDRESS Varianten +- Negative Tests: Peer nicht permitted, Channel ohne Bind, Allocation Timeout diff --git a/docs/testing_todo.md b/docs/testing_todo.md new file mode 100644 index 0000000..7791b7e --- /dev/null +++ b/docs/testing_todo.md @@ -0,0 +1,59 @@ +# Test-ToDo (Vorschläge) + +Dieses Dokument sammelt **konkrete** Test-Ideen, die den sicheren/stabilen Betrieb (insb. unter Last/Fehlverhalten) absichern sollen. + +## Stream (TCP/TLS) Robustheit + +- Split-Reads: STUN Header (20B) in 2 Reads, Body in mehreren Reads +- Split-Reads: ChannelData Header (4B) und Payload getrennt +- Mixed Frames: STUN → ChannelData → STUN in einem Read (und in mehreren Reads) +- Oversize Frames: + - STUN Length > Max → Verbindung wird geschlossen (oder Frame gedroppt, je nach Policy) + - ChannelData Length > Max → Verbindung wird geschlossen (oder Frame gedroppt) +- Garbage Resync: + - Zufallsbytes vor gültigem STUN (bereits abgedeckt) + - Zufallsbytes zwischen gültigen Frames + +## TURN Flows (Happy + Negative) + +- Negative pro Methode (UDP/TCP/TLS jeweils): + - ohne Allocation → 437 Allocation Mismatch + - ohne Permission → 403 Peer Not Permitted + - ChannelData ohne ChannelBind → drop + optional log counter + - Stale Nonce → 438 + - falsches MI → 401/403 je nach Policy + +## Auth + +- TURN REST: + - abgelaufener Username → reject + - Username zu weit in der Zukunft (max TTL) → reject + - falsches HMAC/base64 → reject + - „user exists in store“ vs. „REST fallback“ Priorität + +## Lifecycle + +- Allocation expiry: + - Refresh verkürzt/verlängert, Min/Max Lifetime + - Expiry entfernt Allocation und beendet Relay-Task (keine Task-Leaks) +- Permission expiry: + - Peer wird nach Ablauf verworfen +- Channel binding expiry: + - Rückweg fällt auf Data Indication zurück, wenn Binding abläuft + +## Abuse-/DoS Prevention (sobald Limits implementiert sind) + +- Rate limit: auth failures pro IP/Username +- Max allocations pro IP +- Max permissions/channels pro allocation +- Bandwidth caps (bytes/s) pro allocation +- Backpressure: Writer-Queue voll → Verhalten definieren (drop/close) + +## Interop (manuell reproduzierbar, aber dokumentiert) + +- Browser Plan: + - Trickle ICE / webrtc-internals: forced relay + - UDP-only block: erwarte TCP/TLS fallback + - `turns:` mit self-signed vs. valid cert + +Wenn du willst, kann ich als nächsten Schritt aus diesen Punkten eine priorisierte Test-Roadmap machen (P0/P1/P2) und direkt die nächsten P0-Tests implementieren. diff --git a/docs/turn_end_to_end_flow.md b/docs/turn_end_to_end_flow.md new file mode 100644 index 0000000..a196952 --- /dev/null +++ b/docs/turn_end_to_end_flow.md @@ -0,0 +1,139 @@ +# End-to-End TURN Flow (UDP + TLS) + +Dieses Dokument beschreibt den **konkret implementierten** End-to-End Ablauf in `niom-turn` anhand des aktuellen Codes (MVP): + +- UDP-Control-Plane und UDP-Data-Plane: [src/server.rs](../src/server.rs), [src/alloc.rs](../src/alloc.rs), [src/stun.rs](../src/stun.rs), [src/auth.rs](../src/auth.rs) +- TLS-Control-Plane ("turns") mit STUN-Framing: [src/tls.rs](../src/tls.rs) +- Test-Builder (wie Requests gebaut werden): [tests/support/stun_builders.rs](../tests/support/stun_builders.rs) + +## Begriffe + +- **Client**: TURN-Client (z.B. WebRTC ICE Agent) +- **Server**: `niom-turn` (dieses Projekt) +- **Peer**: Gegenstelle, zu der relayed werden soll (typischerweise anderer WebRTC-Endpunkt) +- **Allocation**: Server-seitige Sitzung, die ein **Relay-UDP-Socket** bereitstellt +- **Permission**: Erlaubnis, zu einem Peer zu senden/Peer-Pakete zu akzeptieren +- **Channel Binding**: Zuordnung `channel-number -> peer`, um ChannelData nutzen zu können + +--- + +## UDP: Sequenzgrafik (Happy Path) + +```mermaid +sequenceDiagram + autonumber + participant C as Client (TURN) + participant S as niom-turn (UDP:3478) + participant R as Relay Socket (UDP:ephemeral) + participant P as Peer + + Note over C,S: 1) Allocate ohne Auth → 401 Challenge + C->>S: STUN Allocate Request (ohne MI) + S->>C: STUN Error Response 401 + REALM + NONCE + + Note over C,S: 2) Allocate mit Long-Term Auth + C->>S: STUN Allocate Request + USERNAME/REALM/NONCE + MESSAGE-INTEGRITY + S->>R: bind("0.0.0.0:0"), spawn Relay-Loop + S->>C: Allocate Success + XOR-RELAYED-ADDRESS + LIFETIME (+ MESSAGE-INTEGRITY + FINGERPRINT) + + Note over C,S: 3) CreatePermission (Pflicht vor Send/ChannelBind) + C->>S: CreatePermission + XOR-PEER-ADDRESS (+ Auth + MI) + S->>C: Success (200) (+ MESSAGE-INTEGRITY + FINGERPRINT) + + Note over C,S: 4) Send (Client→Peer via Relay) + C->>S: Send + XOR-PEER-ADDRESS + DATA (+ Auth + MI) + S->>R: relay.send_to(DATA, Peer) + R->>P: UDP payload (source = relay_addr) + + Note over P,C: 5) Rückweg (Peer→Client) + P->>R: UDP payload (dest = relay_addr) + alt Channel Binding existiert + R->>S: recv_from(Peer) + S->>C: ChannelData(channel, payload) + else Kein Channel Binding + R->>S: recv_from(Peer) + S->>C: Data Indication (METHOD_DATA|INDICATION) + XOR-PEER-ADDRESS + DATA + end + + Note over C,S: 6) Optional: ChannelBind + ChannelData + C->>S: ChannelBind + CHANNEL-NUMBER + XOR-PEER-ADDRESS (+ Auth + MI) + S->>C: Success (200) (+ MESSAGE-INTEGRITY + FINGERPRINT) + C->>S: ChannelData(channel, payload) + S->>R: relay.send_to(payload, Peer) + R->>P: UDP payload + + Note over C,S: 7) Refresh + C->>S: Refresh + LIFETIME (+ Auth + MI) + S->>C: Success + LIFETIME(applied) (+ MESSAGE-INTEGRITY + FINGERPRINT) +``` + +### Was der Server dabei **genau** macht + +- Eingangspunkt: `udp_reader_loop` in [src/server.rs](../src/server.rs) +- Frühe Abzweigung: Wenn `parse_channel_data(...)` erfolgreich ist, wird **kein STUN** geparst, sondern ChannelData direkt weitergeleitet (nur wenn Allocation+Binding+Permission passt). +- STUN/TURN Requests werden mit `parse_message(...)` in [src/stun.rs](../src/stun.rs) geparst. + +### RFC-Interop Hinweis: FINGERPRINT + +- Alle vom Server gebauten STUN-Nachrichten enthalten `FINGERPRINT` als letztes Attribut. +- Wenn ein Client `FINGERPRINT` mitsendet, wird es validiert; bei ungültigem `FINGERPRINT` wird die Nachricht verworfen (kein Response). + +### Auth-Entscheidungen und typische Error-Codes + +Die Auth-Policy ist zentral in `AuthManager::authenticate` in [src/auth.rs](../src/auth.rs). + +- **401 Unauthorized**: Wenn `MESSAGE-INTEGRITY` fehlt → Challenge mit `REALM` + `NONCE`. +- **438 Stale Nonce**: Wenn `NONCE` abgelaufen/ungültig ist → neue Challenge. +- **437 Allocation Mismatch**: Wenn CreatePermission/Send/ChannelBind/Refresh ohne Allocation kommt. +- **403 Peer Not Permitted**: Wenn ein Peer nicht (mehr) permitted ist. +- **400 Missing/Invalid ...**: Wenn Attribute fehlen oder XOR-PEER-ADDRESS nicht dekodierbar ist. + +--- + +## TLS (turns): Sequenzgrafik (Control-Plane) + +Wichtig: Die TLS-Implementierung in [src/tls.rs](../src/tls.rs) nutzt denselben TURN-Handler wie TCP und implementiert eine echte **Stream-Data-Plane**. + +```mermaid +sequenceDiagram + autonumber + participant C as Client (turns) + participant T as niom-turn (TLS:5349) + participant R as Relay Socket (UDP:ephemeral) + participant P as Peer + + Note over C,T: STUN framing über TCP/TLS: read → chunk by (len+20) + C->>T: STUN Allocate (ohne MI) + T->>C: 401 + REALM + NONCE (über TLS) + + C->>T: STUN Allocate + Auth + MI + T->>R: allocate_for(peer, stream-sink) + T->>C: Allocate Success + XOR-RELAYED-ADDRESS + LIFETIME (über TLS) + + C->>T: CreatePermission/Send/Refresh/ChannelBind (über TLS) + T->>C: Success/Error (über TLS) + + Note over P,C: Peer-Daten kommen über den TLS-Stream zurück + P->>R: UDP payload an relay_addr + R->>T: recv_from(Peer) + T->>C: Data Indication / ChannelData (über TLS) +``` + +### Konsequenz + +- Control-Plane über TLS funktioniert (Allocate/Refresh/… werden über TLS beantwortet). +- Der **Data-Plane Rückweg** (Peer → Client) läuft ebenfalls über den TLS-Stream (Relay → `ClientSink::Stream`). + +Mehr Details dazu: [docs/tcp_tls_data_plane.md](tcp_tls_data_plane.md) + +--- + +## Mini-Checkliste: Minimaler Ablauf (praktisch) + +1. `ALLOCATE` ohne MI → `401` + `REALM` + `NONCE` +2. `ALLOCATE` mit `USERNAME/REALM/NONCE` + `MESSAGE-INTEGRITY` → `XOR-RELAYED-ADDRESS` + `LIFETIME` +3. `CREATE_PERMISSION` für Peer → `200` +4. `SEND` mit `DATA` → Server sendet via Relay an Peer +5. Peer sendet zurück an Relay → Server liefert an Client als `DATA-INDICATION` oder `CHANNEL-DATA` +6. Optional `CHANNEL_BIND` + ChannelData für effizientere Data-Plane +7. `REFRESH` zum Verlängern oder `LIFETIME=0` zum Freigeben diff --git a/docs/turn_rest_credentials.md b/docs/turn_rest_credentials.md new file mode 100644 index 0000000..a46df6a --- /dev/null +++ b/docs/turn_rest_credentials.md @@ -0,0 +1,75 @@ +# TURN REST Credentials (Ephemeral) – Nutzung & Betrieb + +Dieses Dokument erklärt die **TURN REST Credential** Strategie für `niom-turn` (MVP → production-fähiger Pfad). + +## Warum TURN REST? + +- Du willst für WebRTC *kurzlebige* TURN-Logins ausstellen (z.B. 5–10 Minuten). +- Dein TURN-Server speichert keine User-Passwörter pro Nutzer. +- Dein Backend (später) kann Tokens ausstellen; bis dahin kannst du lokal/ops-seitig Tokens generieren. + +## Verfahren (kompatibel zu gängigen WebRTC-Stacks) + +**Username**: `` oder `:` + +- Beispiel: `1735412345:alice` +- `expiry_unix_seconds` ist ein Unix-Timestamp in Sekunden. + +**Credential (Password)**: + +$$\n\text{credential} = \text{base64}(\text{HMAC-SHA1}(\text{secret}, \text{username}))\n$$ + +Dieses `credential` wird dann ganz normal als TURN-Passwort im ICE-Server-Config verwendet. + +## Server-Seite (`niom-turn`) + +Konfiguration in [appsettings.example.json](../appsettings.example.json): + +- `auth.rest_secret`: Shared Secret (muss geheim bleiben) +- `auth.rest_max_ttl_seconds`: Maximal akzeptiertes Zeitfenster in die Zukunft. Wenn ein Token zu weit in der Zukunft liegt, wird es abgewiesen (Sicherheitsmaßnahme). + +Wichtig: +- Der Server akzeptiert TURN REST Credentials als Fallback **nur wenn** der Username nicht im CredentialStore gefunden wird. +- Ablauf/TTL: + - `expiry` muss >= `now` sein + - und `expiry - now <= rest_max_ttl_seconds` + +## Lokal Tokens generieren (ohne Backend) + +Das Repo enthält ein kleines CLI: + +- Binary: [src/bin/turn_rest_cred.rs](../src/bin/turn_rest_cred.rs) + +Beispiele: + +1) Einfach (stdout als ENV-Zeilen) + +`cargo run --bin turn_rest_cred -- --secret "SUPER_SECRET" --user alice --ttl 600` + +2) JSON Ausgabe + +`cargo run --bin turn_rest_cred -- --secret "SUPER_SECRET" --user alice --ttl 600 --json` + +Du erhältst: +- `username` → in WebRTC als `iceServers[].username` +- `credential` → in WebRTC als `iceServers[].credential` + +## WebRTC Beispiel (Frontend) + +Du setzt in deiner App typischerweise: + +- URLs (mindestens UDP + TLS): + - `turn:your-domain:3478?transport=udp` + - `turn:your-domain:3478?transport=tcp` + - `turns:your-domain:5349?transport=tcp` + +und dann: +- `username`: vom Generator/Backend +- `credential`: vom Generator/Backend + +## Betriebshinweise + +- **Secret Rotation**: Plane Rotation (z.B. zwei Secrets parallel) – sonst brechen Tokens im Umlauf. +- **TTL klein halten**: 5–10 Minuten ist typisch. +- **Logs**: Niemals `secret` oder vollständige Credentials loggen. +- **Rate Limits/Quotas**: Unbedingt ergänzen (Open-Relay/Abuse vermeiden). diff --git a/src/alloc.rs b/src/alloc.rs index 484d444..feb409b 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -1,20 +1,56 @@ //! Allocation manager: provisions relay sockets and forwards packets for TURN allocations. //! Backlog: permission tables, channel bindings, allocation refresh timers, and rate limits. use std::collections::HashMap; +use std::net::IpAddr; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio::net::UdpSocket; +use tokio::sync::Notify; +use tokio::sync::mpsc; use tracing::info; use crate::stun::{build_channel_data, build_data_indication}; +#[derive(thiserror::Error, Debug)] +pub enum AllocationError { + #[error("allocation quota exceeded")] + AllocationQuotaExceeded, + #[error("permission quota exceeded")] + PermissionQuotaExceeded, + #[error("channel binding quota exceeded")] + ChannelQuotaExceeded, +} + +#[derive(Clone)] +pub enum ClientSink { + Udp { sock: Arc, addr: SocketAddr }, + Stream { tx: mpsc::Sender> }, +} + +impl ClientSink { + pub async fn send(&self, data: Vec) -> anyhow::Result<()> { + match self { + ClientSink::Udp { sock, addr } => { + sock.send_to(&data, addr).await?; + Ok(()) + } + ClientSink::Stream { tx } => { + tx.send(data) + .await + .map_err(|_| anyhow::anyhow!("client stream closed")) + } + } + } +} + #[derive(Clone)] pub struct Allocation { pub client: SocketAddr, pub relay_addr: SocketAddr, // keep the socket so it stays bound _socket: Arc, + stop: Arc, permissions: Arc>>, channel_bindings: Arc>>, expiry: Arc>, @@ -23,36 +59,159 @@ pub struct Allocation { #[derive(Clone, Default)] pub struct AllocationManager { inner: Arc>>, + opts: AllocationOptions, +} + +#[derive(Clone, Debug)] +pub struct AllocationOptions { + pub relay_bind_ip: IpAddr, + pub relay_port_min: Option, + pub relay_port_max: Option, + pub advertised_ip: Option, + + pub max_allocations_per_ip: Option, + pub max_permissions_per_allocation: Option, + pub max_channel_bindings_per_allocation: Option, +} + +impl Default for AllocationOptions { + fn default() -> Self { + Self { + relay_bind_ip: IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), + relay_port_min: None, + relay_port_max: None, + advertised_ip: None, + max_allocations_per_ip: None, + max_permissions_per_allocation: None, + max_channel_bindings_per_allocation: None, + } + } } impl AllocationManager { pub fn new() -> Self { Self { inner: Arc::new(Mutex::new(HashMap::new())), + opts: AllocationOptions::default(), + } + } + + pub fn new_with_options(opts: AllocationOptions) -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + opts, + } + } + + /// Translate a locally-bound relay socket address into the address that should be + /// advertised to clients (e.g. replace 0.0.0.0 with a public IP). + pub fn relay_addr_for_response(&self, relay_local: SocketAddr) -> SocketAddr { + match self.opts.advertised_ip { + Some(ip) => SocketAddr::new(ip, relay_local.port()), + None => relay_local, } } /// Create a relay UDP socket for the given client and spawn a relay loop that forwards - /// any packets received on the relay socket back to the client via the provided server socket. + /// any packets received on the relay socket back to the client via the provided client sink. pub async fn allocate_for( &self, client: SocketAddr, - server_sock: Arc, + client_sink: ClientSink, ) -> anyhow::Result { - // bind relay socket to OS-chosen port - let relay = UdpSocket::bind("0.0.0.0:0").await?; + // If an allocation already exists for this exact 5-tuple, reuse it. + { + let mut guard = self.inner.lock().unwrap(); + prune_expired_locked(&mut guard); + if let Some(existing) = guard.get(&client) { + return Ok(existing.relay_addr); + } + + if let Some(max) = self.opts.max_allocations_per_ip { + let count_for_ip = guard + .values() + .filter(|a| a.client.ip() == client.ip()) + .count() as u32; + if count_for_ip >= max { + return Err(anyhow::anyhow!(AllocationError::AllocationQuotaExceeded)); + } + } + } + + // bind relay socket (optionally within configured port range) + let relay = match (self.opts.relay_port_min, self.opts.relay_port_max) { + (Some(min), Some(max)) if min > 0 && max >= min => { + let mut bound: Option = None; + let mut last_err: Option = None; + for port in min..=max { + let addr = SocketAddr::new(self.opts.relay_bind_ip, port); + match UdpSocket::bind(addr).await { + Ok(sock) => { + bound = Some(sock); + last_err = None; + break; + } + Err(e) => { + last_err = Some(anyhow::anyhow!(e)); + } + } + } + + match bound { + Some(sock) => sock, + None => { + let detail = last_err + .map(|e| format!("{e:?}")) + .unwrap_or_else(|| "no ports in range available".to_string()); + return Err(anyhow::anyhow!( + "failed to bind relay socket in range {}-{} on {}: {}", + min, + max, + self.opts.relay_bind_ip, + detail + )); + } + } + } + _ => { + let addr = SocketAddr::new(self.opts.relay_bind_ip, 0); + UdpSocket::bind(addr).await? + } + }; let relay_local = relay.local_addr()?; let relay_arc = Arc::new(relay); + // Insert allocation before spawning relay loop to avoid races. + let stop = Arc::new(Notify::new()); + let alloc = Allocation { + client, + relay_addr: relay_local, + _socket: relay_arc.clone(), + stop: stop.clone(), + permissions: Arc::new(Mutex::new(HashMap::new())), + channel_bindings: Arc::new(Mutex::new(HashMap::new())), + expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)), + }; + { + let mut m = self.inner.lock().unwrap(); + prune_expired_locked(&mut m); + m.insert(client, alloc); + } + // spawn relay loop let relay_clone = relay_arc.clone(); - let server_sock_clone = server_sock.clone(); + let sink_clone = client_sink.clone(); let client_clone = client; let manager_clone = self.clone(); + let stop_clone = stop.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 2048]; loop { - match relay_clone.recv_from(&mut buf).await { + tokio::select! { + _ = stop_clone.notified() => { + break; + } + res = relay_clone.recv_from(&mut buf) => match res { Ok((len, src)) => { info!( "relay got {} bytes from {} for client {}", @@ -70,27 +229,29 @@ impl AllocationManager { if let Some(channel) = allocation.channel_for_peer(&src) { let frame = build_channel_data(channel, &buf[..len]); - if let Err(e) = - server_sock_clone.send_to(&frame, client_clone).await - { + if let Err(e) = sink_clone.send(frame).await { tracing::error!( "failed to send channel data {} -> {}: {:?}", src, client_clone, e ); + if matches!(sink_clone, ClientSink::Stream { .. }) { + break; + } } } else { let indication = build_data_indication(&src, &buf[..len]); - if let Err(e) = - server_sock_clone.send_to(&indication, client_clone).await - { + if let Err(e) = sink_clone.send(indication).await { tracing::error!( "failed to send data indication {} -> {}: {:?}", src, client_clone, e ); + if matches!(sink_clone, ClientSink::Stream { .. }) { + break; + } } } } else { @@ -99,28 +260,20 @@ impl AllocationManager { src, client_clone ); + // Allocation removed/expired: stop the relay task. + break; } } Err(e) => { tracing::error!("relay socket error: {:?}", e); break; } + } } } }); - let alloc = Allocation { - client, - relay_addr: relay_local, - _socket: relay_arc, - permissions: Arc::new(Mutex::new(HashMap::new())), - channel_bindings: Arc::new(Mutex::new(HashMap::new())), - expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)), - }; tracing::info!("created allocation for {} -> {}", client, relay_local); - let mut m = self.inner.lock().unwrap(); - prune_expired_locked(&mut m); - m.insert(client, alloc); Ok(relay_local) } @@ -140,6 +293,13 @@ impl AllocationManager { .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; let mut perms = alloc.permissions.lock().unwrap(); prune_permissions(&mut perms); + + if let Some(max) = self.opts.max_permissions_per_allocation { + let max = max as usize; + if !perms.contains_key(&peer) && perms.len() >= max { + return Err(anyhow::anyhow!(AllocationError::PermissionQuotaExceeded)); + } + } perms.insert(peer, Instant::now() + PERMISSION_LIFETIME); Ok(()) } @@ -158,6 +318,13 @@ impl AllocationManager { .ok_or_else(|| anyhow::anyhow!("allocation not found"))?; let mut bindings = alloc.channel_bindings.lock().unwrap(); prune_channel_bindings(&mut bindings); + + if let Some(max) = self.opts.max_channel_bindings_per_allocation { + let max = max as usize; + if !bindings.contains_key(&channel) && bindings.len() >= max { + return Err(anyhow::anyhow!(AllocationError::ChannelQuotaExceeded)); + } + } bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME)); Ok(()) } @@ -173,7 +340,9 @@ impl AllocationManager { let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME); if let Some(d) = requested { if d.is_zero() { - guard.remove(&client); + if let Some(alloc) = guard.remove(&client) { + alloc.stop.notify_waiters(); + } return Ok(Duration::from_secs(0)); } } @@ -191,7 +360,42 @@ impl AllocationManager { /// Remove allocation explicitly (e.g. on zero lifetime). Returns true if removed. pub fn remove_allocation(&self, client: &SocketAddr) -> bool { let mut guard = self.inner.lock().unwrap(); - guard.remove(client).is_some() + if let Some(alloc) = guard.remove(client) { + alloc.stop.notify_waiters(); + true + } else { + false + } + } + + /// Spawn a background housekeeping task that periodically prunes expired allocations. + /// This avoids keeping relay tasks/sockets alive indefinitely when allocations expire. + pub fn spawn_housekeeping(&self, interval: Duration) { + let mgr = self.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(interval).await; + mgr.prune_expired(); + } + }); + } + + /// Remove expired allocations and notify their relay tasks to stop. + pub fn prune_expired(&self) { + let mut guard = self.inner.lock().unwrap(); + let now = Instant::now(); + let expired: Vec = guard + .iter() + .filter_map(|(k, alloc)| { + let expiry = alloc.expiry.lock().unwrap(); + if *expiry <= now { Some(*k) } else { None } + }) + .collect(); + for client in expired { + if let Some(alloc) = guard.remove(&client) { + alloc.stop.notify_waiters(); + } + } } } diff --git a/src/auth.rs b/src/auth.rs index e247d39..24d1240 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -6,6 +6,7 @@ use crate::models::stun::StunMessage; use crate::stun::{find_message_integrity, validate_message_integrity}; use crate::traits::CredentialStore; use async_trait::async_trait; +use base64::Engine; use hmac::{Hmac, Mac}; use sha1::Sha1; use std::net::SocketAddr; @@ -46,6 +47,8 @@ pub struct AuthSettings { pub realm: String, pub nonce_secret: Vec, pub nonce_ttl: Duration, + pub rest_secret: Option>, + pub rest_max_ttl: Duration, } impl AuthSettings { @@ -56,10 +59,13 @@ impl AuthSettings { .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); // Ensure TTL does not collapse to zero so challenges stay valid briefly. let ttl = Duration::from_secs(opts.nonce_ttl_seconds.max(60)); + let rest_max_ttl = Duration::from_secs(opts.rest_max_ttl_seconds.max(60)); Self { realm: opts.realm.clone(), nonce_secret: secret.into_bytes(), nonce_ttl: ttl, + rest_secret: opts.rest_secret.clone().map(|s| s.into_bytes()), + rest_max_ttl, } } } @@ -67,7 +73,7 @@ impl AuthSettings { /// Result of validating authentication attributes on an incoming STUN/TURN request. #[derive(Debug, Clone)] pub enum AuthStatus { - Granted { username: String }, + Granted { username: String, key: Vec }, Challenge { nonce: String }, StaleNonce { nonce: String }, Reject { code: u16, reason: &'static str }, @@ -159,12 +165,15 @@ impl AuthManager { let password = match self.store.get_password(&username).await { Some(p) => p, - None => { - return AuthStatus::Reject { - code: 401, - reason: "Unknown User", + None => match self.derive_turn_rest_password(&username) { + Some(p) => p, + None => { + return AuthStatus::Reject { + code: 401, + reason: "Unknown User", + }; } - } + }, }; let key = self.derive_long_term_key(&username, &password); @@ -175,7 +184,7 @@ impl AuthManager { }; } - AuthStatus::Granted { username } + AuthStatus::Granted { username, key } } fn attribute_utf8(&self, msg: &StunMessage, attr_type: u16) -> Option { @@ -190,6 +199,32 @@ impl AuthManager { compute_a1_md5(username, &self.settings.realm, password) } + /// TURN REST (ephemeral) password derivation. + /// + /// Expected username format: `` or `:`. + /// Password is `base64(HMAC-SHA1(rest_secret, username))`. + /// + /// Security: Reject if expired or if expiry is too far in the future (bounded by rest_max_ttl). + fn derive_turn_rest_password(&self, username: &str) -> Option { + let secret = self.settings.rest_secret.as_ref()?; + let expiry = parse_turn_rest_expiry(username)?; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_secs(); + + if now > expiry { + return None; + } + + let delta = expiry.saturating_sub(now); + if delta > self.settings.rest_max_ttl.as_secs() { + return None; + } + + Some(turn_rest_password_base64(secret, username)) + } + pub fn mint_nonce(&self, peer: &SocketAddr) -> String { let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -241,6 +276,19 @@ impl AuthManager { } } +fn parse_turn_rest_expiry(username: &str) -> Option { + let prefix = username.split(':').next().unwrap_or(username); + prefix.parse::().ok() +} + +fn turn_rest_password_base64(secret: &[u8], username: &str) -> String { + type HmacSha1 = Hmac; + let mut mac = HmacSha1::new_from_slice(secret).expect("rest secret to build hmac"); + mac.update(username.as_bytes()); + let bytes = mac.finalize().into_bytes(); + base64::engine::general_purpose::STANDARD.encode(bytes) +} + enum NonceValidation { Valid, Expired, diff --git a/src/bin/turn_rest_cred.rs b/src/bin/turn_rest_cred.rs new file mode 100644 index 0000000..fc0108b --- /dev/null +++ b/src/bin/turn_rest_cred.rs @@ -0,0 +1,76 @@ +use anyhow::Context; +use base64::Engine; +use hmac::{Hmac, Mac}; +use sha1::Sha1; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +fn now_unix_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_secs() +} + +fn turn_rest_password_base64(secret: &[u8], username: &str) -> String { + type HmacSha1 = Hmac; + let mut mac = HmacSha1::new_from_slice(secret).expect("rest secret to build hmac"); + mac.update(username.as_bytes()); + let bytes = mac.finalize().into_bytes(); + base64::engine::general_purpose::STANDARD.encode(bytes) +} + +/// Minimal offline TURN REST credential generator. +/// +/// Usage: +/// - `cargo run --bin turn_rest_cred -- --secret "..." --user alice --ttl 600` +/// +/// Output: +/// - `TURN_USERNAME=...` +/// - `TURN_PASSWORD=...` +fn main() -> anyhow::Result<()> { + let mut secret: Option = None; + let mut user: Option = None; + let mut ttl: u64 = 600; + let mut json = false; + + let mut args = std::env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--secret" => secret = Some(args.next().context("--secret requires a value")?), + "--user" => user = Some(args.next().context("--user requires a value")?), + "--ttl" => { + let v = args.next().context("--ttl requires a value")?; + ttl = v.parse::().context("--ttl must be an integer")?; + } + "--json" => json = true, + "-h" | "--help" => { + println!( + "turn_rest_cred\n\nUSAGE:\n turn_rest_cred --secret [--user ] [--ttl ] [--json]\n\nNOTES:\n Username format is [:]. Password is base64(HMAC-SHA1(secret, username)).\n" + ); + return Ok(()); + } + other => return Err(anyhow::anyhow!("unknown arg: {}", other)), + } + } + + let secret = secret.context("missing --secret")?; + let expiry = now_unix_secs().saturating_add(ttl.max(60)); + let username = match user { + Some(u) if !u.is_empty() => format!("{}:{}", expiry, u), + _ => expiry.to_string(), + }; + let password = turn_rest_password_base64(secret.as_bytes(), &username); + + if json { + println!( + "{{\n \"username\": \"{}\",\n \"credential\": \"{}\",\n \"ttl\": {}\n}}", + username, password, ttl + ); + } else { + println!("TURN_USERNAME={}", username); + println!("TURN_PASSWORD={}", password); + println!("TURN_TTL={}", ttl); + } + + Ok(()) +} diff --git a/src/config.rs b/src/config.rs index d199edb..007636e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -11,6 +11,26 @@ fn default_nonce_ttl_seconds() -> u64 { 300 } +fn default_rest_max_ttl_seconds() -> u64 { + 600 +} + +fn default_tls_bind() -> String { + "0.0.0.0:5349".to_string() +} + +fn default_enable_udp() -> bool { + true +} + +fn default_enable_tcp() -> bool { + true +} + +fn default_enable_tls() -> bool { + true +} + #[derive(Debug, Deserialize, Clone)] pub struct CredentialEntry { pub username: String, @@ -28,6 +48,16 @@ pub struct AuthOptions { /// Validity period for generated nonces in seconds. #[serde(default = "default_nonce_ttl_seconds")] pub nonce_ttl_seconds: u64, + + /// Optional TURN REST shared secret. When set, the server can validate ephemeral credentials + /// derived from this secret (username contains expiry; password is base64(HMAC-SHA1(secret, username))). + #[serde(default)] + pub rest_secret: Option, + + /// Maximum acceptable TTL window (in seconds) for TURN REST usernames. + /// If the username expiry is too far in the future, credentials are rejected. + #[serde(default = "default_rest_max_ttl_seconds")] + pub rest_max_ttl_seconds: u64, } impl Default for AuthOptions { @@ -36,19 +66,99 @@ impl Default for AuthOptions { realm: default_realm(), nonce_secret: None, nonce_ttl_seconds: default_nonce_ttl_seconds(), + rest_secret: None, + rest_max_ttl_seconds: default_rest_max_ttl_seconds(), } } } #[derive(Debug, Deserialize, Clone)] pub struct ServerOptions { - /// Listen address, e.g. "0.0.0.0:3478" + /// Listen address (legacy/default), e.g. "0.0.0.0:3478". + /// If `udp_bind` / `tcp_bind` are not set, this value is used. pub bind: String, + + /// Optional per-protocol bind addresses. + #[serde(default)] + pub udp_bind: Option, + #[serde(default)] + pub tcp_bind: Option, + #[serde(default = "default_tls_bind")] + pub tls_bind: String, + + /// Enable/disable protocol listeners. + #[serde(default = "default_enable_udp")] + pub enable_udp: bool, + #[serde(default = "default_enable_tcp")] + pub enable_tcp: bool, + #[serde(default = "default_enable_tls")] + pub enable_tls: bool, + /// Optional TLS: cert/key paths (not used in MVP) pub tls_cert: Option, pub tls_key: Option, } +#[derive(Debug, Deserialize, Clone, Default)] +pub struct RelayOptions { + /// Optional UDP relay port range. If set, allocations bind within this range. + /// If omitted, OS chooses an ephemeral port. + #[serde(default)] + pub relay_port_min: Option, + #[serde(default)] + pub relay_port_max: Option, + + /// IP address used for binding relay sockets (default: 0.0.0.0). + /// Example: "0.0.0.0" or a specific local interface IP. + #[serde(default)] + pub relay_bind_ip: Option, + + /// Optional public/advertised IP that is placed into XOR-RELAYED-ADDRESS. + /// This is important when running behind NAT or when relay sockets bind to 0.0.0.0. + #[serde(default)] + pub advertised_ip: Option, +} + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct LoggingOptions { + /// Default tracing directive (overridable via RUST_LOG). + /// Example: "warn,niom_turn=info". + #[serde(default)] + pub default_directive: Option, +} + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct LimitsOptions { + /// Max concurrent allocations per source IP (across different source ports). + /// If omitted, unlimited. + #[serde(default)] + pub max_allocations_per_ip: Option, + + /// Max permissions per allocation. + /// If omitted, unlimited. + #[serde(default)] + pub max_permissions_per_allocation: Option, + + /// Max channel bindings per allocation. + /// If omitted, unlimited. + #[serde(default)] + pub max_channel_bindings_per_allocation: Option, + + /// Rate-limit unauthenticated responses (401/438) per source IP. + /// If omitted, unlimited. + #[serde(default)] + pub unauth_rps: Option, + #[serde(default)] + pub unauth_burst: Option, + + /// Rate-limit STUN Binding success responses per source IP. + /// If omitted, unlimited. + #[serde(default)] + pub binding_rps: Option, + #[serde(default)] + pub binding_burst: Option, +} + #[derive(Debug, Deserialize, Clone)] pub struct Config { /// Server options @@ -59,6 +169,18 @@ pub struct Config { /// Authentication behaviour advertised to clients. #[serde(default)] pub auth: AuthOptions, + + /// Relay socket behaviour (port range, bind ip, advertised ip). + #[serde(default)] + pub relay: RelayOptions, + + /// Logging/tracing defaults. + #[serde(default)] + pub logging: LoggingOptions, + + /// Resource limits (basic abuse-prevention). + #[serde(default)] + pub limits: LimitsOptions, } impl Config { diff --git a/src/constants.rs b/src/constants.rs index 18d9953..f56d404 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -26,13 +26,23 @@ pub const ATTR_LIFETIME: u16 = 0x000D; pub const ATTR_REALM: u16 = 0x0014; pub const ATTR_NONCE: u16 = 0x0015; pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012; +pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020; + +// RFC5389: FINGERPRINT +pub const ATTR_FINGERPRINT: u16 = 0x8028; // TURN attrs pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016; pub const ATTR_DATA: u16 = 0x0013; +pub const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019; + +// IP protocol numbers used by TURN REQUESTED-TRANSPORT +pub const IPPROTO_UDP: u8 = 17; +pub const IPPROTO_TCP: u8 = 6; // Some helper values pub const FAMILY_IPV4: u8 = 0x01; +pub const FAMILY_IPV6: u8 = 0x02; // Fingerprint XOR magic (XOR with CRC32 for FINGERPRINT attribute) pub const FINGERPRINT_XOR: u32 = 0x5354554e; diff --git a/src/lib.rs b/src/lib.rs index 07a5a15..c1d1e52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,14 @@ pub mod auth; pub mod config; pub mod constants; pub mod logging; +pub mod metrics; pub mod models; +pub mod rate_limit; pub mod server; pub mod stun; +pub mod tcp; pub mod tls; +pub mod turn_stream; pub mod traits; pub use crate::alloc::*; diff --git a/src/main.rs b/src/main.rs index 2758024..d25a439 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,28 @@ //! Binary entry point that wires configuration, UDP listener, optional TLS listener, and allocation handling. //! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners. use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; +use std::time::Duration; use tokio::net::UdpSocket; use tracing::{error, info}; // Use the library crate's public modules instead of local `mod` declarations. use niom_turn::alloc::AllocationManager; +use niom_turn::alloc::AllocationOptions; use niom_turn::auth::{AuthManager, InMemoryStore}; use niom_turn::config::{AuthOptions, Config}; -use niom_turn::server::udp_reader_loop; +use niom_turn::rate_limit::RateLimiters; #[tokio::main] async fn main() -> anyhow::Result<()> { - niom_turn::logging::init_tracing(); - - info!("niom-turn starting"); - // Bootstrap configuration: prefer appsettings.json, otherwise rely on baked-in demo defaults. let cfg = match Config::load_default() { Ok(c) => { - info!("loaded config from appsettings.json"); c } Err(e) => { - info!( + eprintln!( "no appsettings.json found or failed to load: {} — using defaults", e ); @@ -32,6 +30,12 @@ async fn main() -> anyhow::Result<()> { Config { server: niom_turn::config::ServerOptions { bind: "0.0.0.0:3478".to_string(), + udp_bind: None, + tcp_bind: None, + tls_bind: "0.0.0.0:5349".to_string(), + enable_udp: true, + enable_tcp: true, + enable_tls: true, tls_cert: None, tls_key: None, }, @@ -40,11 +44,39 @@ async fn main() -> anyhow::Result<()> { password: "secretpassword".into(), }], auth: AuthOptions::default(), + relay: niom_turn::config::RelayOptions::default(), + logging: niom_turn::config::LoggingOptions::default(), + limits: niom_turn::config::LimitsOptions::default(), } } }; - let bind_addr: SocketAddr = cfg.server.bind.parse()?; + let log_directive = cfg + .logging + .default_directive + .as_deref() + .unwrap_or("warn,niom_turn=info"); + niom_turn::logging::init_tracing_with_default(log_directive); + + // Build per-server rate limiters (defaults to disabled when unset). + let rate_limiters = Arc::new(RateLimiters::from_limits(&cfg.limits)); + + info!("niom-turn starting"); + info!("logging.default_directive={}", log_directive); + + let udp_bind = cfg + .server + .udp_bind + .clone() + .unwrap_or_else(|| cfg.server.bind.clone()); + let tcp_bind = cfg + .server + .tcp_bind + .clone() + .unwrap_or_else(|| cfg.server.bind.clone()); + let tls_bind = cfg.server.tls_bind.clone(); + + let udp_bind_addr: SocketAddr = udp_bind.parse()?; // Materialise the credential backend before starting network endpoints. let creds = InMemoryStore::new(); @@ -54,44 +86,148 @@ async fn main() -> anyhow::Result<()> { let auth = AuthManager::new(creds.clone(), &cfg.auth); - // Bind the UDP socket that receives STUN/TURN traffic from WebRTC clients. - let udp = UdpSocket::bind(bind_addr).await?; - let udp = Arc::new(udp); + let relay_bind_ip: IpAddr = cfg + .relay + .relay_bind_ip + .as_deref() + .unwrap_or("0.0.0.0") + .parse() + .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); - // Allocation manager shared by UDP + TLS frontends. - let alloc_mgr = AllocationManager::new(); + let advertised_ip: Option = cfg + .relay + .advertised_ip + .as_deref() + .and_then(|s| s.parse().ok()); - // Spawn the asynchronous packet loop that handles all UDP requests. - let udp_clone = udp.clone(); - let auth_clone = auth.clone(); - let alloc_clone = alloc_mgr.clone(); + let alloc_mgr = AllocationManager::new_with_options(AllocationOptions { + relay_bind_ip, + relay_port_min: cfg.relay.relay_port_min, + relay_port_max: cfg.relay.relay_port_max, + advertised_ip, + max_allocations_per_ip: cfg.limits.max_allocations_per_ip, + max_permissions_per_allocation: cfg.limits.max_permissions_per_allocation, + max_channel_bindings_per_allocation: cfg.limits.max_channel_bindings_per_allocation, + }); + + // Periodically prune expired allocations so relay tasks can terminate even when idle. + alloc_mgr.spawn_housekeeping(Duration::from_secs(5)); + + // Periodically emit a compact metrics snapshot to logs. tokio::spawn(async move { - if let Err(e) = udp_reader_loop(udp_clone, auth_clone, alloc_clone).await { - error!("udp loop error: {:?}", e); + loop { + tokio::time::sleep(Duration::from_secs(30)).await; + let snap = niom_turn::metrics::snapshot(); + info!( + "metrics: stun={} channeldata={} streams={} auth_challenge={} auth_stale={} auth_reject={} alloc_total={} alloc_ok={} alloc_fail={} perms_added={} channels_added={} alloc_active={} rate_limited={}", + snap.stun_messages_total, + snap.channel_data_total, + snap.stream_connections_total, + snap.auth_challenge_total, + snap.auth_stale_total, + snap.auth_reject_total, + snap.allocate_total, + snap.allocate_success_total, + snap.allocate_fail_total, + snap.permissions_added_total, + snap.channel_bindings_added_total, + snap.allocations_active, + snap.rate_limited_total + ); } }); - // Optionally start the TLS listener so `turns:` clients can connect via TCP/TLS. - if let (Some(cert), Some(key)) = (cfg.server.tls_cert.clone(), cfg.server.tls_key.clone()) { - let udp_for_tls = udp.clone(); - let auth_for_tls = auth.clone(); - let alloc_for_tls = alloc_mgr.clone(); + info!( + "listeners: udp={} tcp={} tls={} udp_bind={} tcp_bind={} tls_bind={}", + cfg.server.enable_udp, + cfg.server.enable_tcp, + cfg.server.enable_tls, + udp_bind, + tcp_bind, + tls_bind + ); + info!( + "relay: bind_ip={} port_range={:?}-{:?} advertised_ip={:?}", + relay_bind_ip, + cfg.relay.relay_port_min, + cfg.relay.relay_port_max, + advertised_ip + ); + + // Bind the UDP socket that receives STUN/TURN traffic from WebRTC clients. + let udp = if cfg.server.enable_udp { + let udp = UdpSocket::bind(udp_bind_addr).await?; + Some(Arc::new(udp)) + } else { + None + }; + + // Spawn the asynchronous packet loop that handles all UDP requests. + if let Some(udp_sock) = udp.clone() { + let udp_clone = udp_sock.clone(); + let auth_clone = auth.clone(); + let alloc_clone = alloc_mgr.clone(); + let rl = rate_limiters.clone(); tokio::spawn(async move { - if let Err(e) = niom_turn::tls::serve_tls( - "0.0.0.0:5349", - &cert, - &key, - udp_for_tls, - auth_for_tls, - alloc_for_tls, + if let Err(e) = niom_turn::server::udp_reader_loop_with_limits( + udp_clone, + auth_clone, + alloc_clone, + rl, ) .await { - error!("tls serve failed: {:?}", e); + error!("udp loop error: {:?}", e); } }); } + // Start a plain TCP listener for `turn:` clients that require TCP. + if cfg.server.enable_tcp { + let auth_for_tcp = auth.clone(); + let alloc_for_tcp = alloc_mgr.clone(); + let tcp_bind = tcp_bind.clone(); + let rl = rate_limiters.clone(); + tokio::spawn(async move { + if let Err(e) = niom_turn::tcp::serve_tcp_with_limits( + &tcp_bind, + auth_for_tcp, + alloc_for_tcp, + rl, + ) + .await + { + error!("tcp serve failed: {:?}", e); + } + }); + } + + // Optionally start the TLS listener so `turns:` clients can connect via TCP/TLS. + if cfg.server.enable_tls { + if let (Some(cert), Some(key)) = (cfg.server.tls_cert.clone(), cfg.server.tls_key.clone()) { + let auth_for_tls = auth.clone(); + let alloc_for_tls = alloc_mgr.clone(); + let tls_bind = tls_bind.clone(); + let rl = rate_limiters.clone(); + tokio::spawn(async move { + if let Err(e) = niom_turn::tls::serve_tls_with_limits( + &tls_bind, + &cert, + &key, + auth_for_tls, + alloc_for_tls, + rl, + ) + .await + { + error!("tls serve failed: {:?}", e); + } + }); + } else { + info!("TLS enabled but tls_cert/tls_key not configured; skipping TLS listener"); + } + } + // Keep the runtime alive while background tasks process packets. loop { tokio::time::sleep(std::time::Duration::from_secs(60)).await; diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..fb78bf3 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,155 @@ +use std::sync::OnceLock; +use std::sync::atomic::{AtomicU64, Ordering}; + +#[derive(Debug)] +pub struct Metrics { + pub stun_messages_total: AtomicU64, + pub channel_data_total: AtomicU64, + pub stream_connections_total: AtomicU64, + + pub auth_challenge_total: AtomicU64, + pub auth_stale_total: AtomicU64, + pub auth_reject_total: AtomicU64, + + pub allocate_total: AtomicU64, + pub allocate_success_total: AtomicU64, + pub allocate_fail_total: AtomicU64, + + pub permissions_added_total: AtomicU64, + pub channel_bindings_added_total: AtomicU64, + + pub allocations_active: AtomicU64, + + pub rate_limited_total: AtomicU64, +} + +impl Default for Metrics { + fn default() -> Self { + Self { + stun_messages_total: AtomicU64::new(0), + channel_data_total: AtomicU64::new(0), + stream_connections_total: AtomicU64::new(0), + auth_challenge_total: AtomicU64::new(0), + auth_stale_total: AtomicU64::new(0), + auth_reject_total: AtomicU64::new(0), + allocate_total: AtomicU64::new(0), + allocate_success_total: AtomicU64::new(0), + allocate_fail_total: AtomicU64::new(0), + permissions_added_total: AtomicU64::new(0), + channel_bindings_added_total: AtomicU64::new(0), + allocations_active: AtomicU64::new(0), + rate_limited_total: AtomicU64::new(0), + } + } +} + +static METRICS: OnceLock = OnceLock::new(); + +pub fn metrics() -> &'static Metrics { + METRICS.get_or_init(Metrics::default) +} + +pub fn inc_stun_messages() { + metrics().stun_messages_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_channel_data() { + metrics().channel_data_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_stream_connections() { + metrics().stream_connections_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_auth_challenge() { + metrics().auth_challenge_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_auth_stale() { + metrics().auth_stale_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_auth_reject() { + metrics().auth_reject_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_allocate_total() { + metrics().allocate_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_allocate_success() { + metrics().allocate_success_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_allocate_fail() { + metrics().allocate_fail_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_permission_added() { + metrics().permissions_added_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_channel_binding_added() { + metrics().channel_bindings_added_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_rate_limited() { + metrics().rate_limited_total.fetch_add(1, Ordering::Relaxed); +} + +pub fn inc_allocations_active() { + metrics().allocations_active.fetch_add(1, Ordering::Relaxed); +} + +pub fn dec_allocations_active() { + // Saturating decrement to avoid underflow. + let m = metrics(); + let mut current = m.allocations_active.load(Ordering::Relaxed); + while current > 0 { + match m.allocations_active.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(v) => current = v, + } + } +} + +#[derive(Debug, Clone)] +pub struct MetricsSnapshot { + pub stun_messages_total: u64, + pub channel_data_total: u64, + pub stream_connections_total: u64, + pub auth_challenge_total: u64, + pub auth_stale_total: u64, + pub auth_reject_total: u64, + pub allocate_total: u64, + pub allocate_success_total: u64, + pub allocate_fail_total: u64, + pub permissions_added_total: u64, + pub channel_bindings_added_total: u64, + pub allocations_active: u64, + pub rate_limited_total: u64, +} + +pub fn snapshot() -> MetricsSnapshot { + let m = metrics(); + MetricsSnapshot { + stun_messages_total: m.stun_messages_total.load(Ordering::Relaxed), + channel_data_total: m.channel_data_total.load(Ordering::Relaxed), + stream_connections_total: m.stream_connections_total.load(Ordering::Relaxed), + auth_challenge_total: m.auth_challenge_total.load(Ordering::Relaxed), + auth_stale_total: m.auth_stale_total.load(Ordering::Relaxed), + auth_reject_total: m.auth_reject_total.load(Ordering::Relaxed), + allocate_total: m.allocate_total.load(Ordering::Relaxed), + allocate_success_total: m.allocate_success_total.load(Ordering::Relaxed), + allocate_fail_total: m.allocate_fail_total.load(Ordering::Relaxed), + permissions_added_total: m.permissions_added_total.load(Ordering::Relaxed), + channel_bindings_added_total: m.channel_bindings_added_total.load(Ordering::Relaxed), + allocations_active: m.allocations_active.load(Ordering::Relaxed), + rate_limited_total: m.rate_limited_total.load(Ordering::Relaxed), + } +} diff --git a/src/rate_limit.rs b/src/rate_limit.rs new file mode 100644 index 0000000..6cfc32c --- /dev/null +++ b/src/rate_limit.rs @@ -0,0 +1,145 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Mutex; +use std::time::Instant; + +use crate::config::LimitsOptions; + +#[derive(Debug, Clone, Copy)] +struct TokenBucketConfig { + rate_per_sec: f64, + burst: f64, +} + +#[derive(Debug)] +struct Bucket { + tokens: f64, + last: Instant, +} + +#[derive(Debug)] +pub struct TokenBucketRateLimiter { + cfg: TokenBucketConfig, + state: Mutex>, +} + +impl TokenBucketRateLimiter { + pub fn new(rate_per_sec: u32, burst: u32) -> Self { + Self { + cfg: TokenBucketConfig { + rate_per_sec: rate_per_sec.max(1) as f64, + burst: burst.max(1) as f64, + }, + state: Mutex::new(HashMap::new()), + } + } + + fn allow_at(&self, ip: IpAddr, now: Instant) -> bool { + let mut guard = self.state.lock().unwrap(); + let entry = guard.entry(ip).or_insert(Bucket { + tokens: self.cfg.burst, + last: now, + }); + + let elapsed = now.saturating_duration_since(entry.last); + let refill = elapsed.as_secs_f64() * self.cfg.rate_per_sec; + entry.tokens = (entry.tokens + refill).min(self.cfg.burst); + entry.last = now; + + if entry.tokens >= 1.0 { + entry.tokens -= 1.0; + true + } else { + false + } + } + + pub fn allow(&self, ip: IpAddr) -> bool { + self.allow_at(ip, Instant::now()) + } + + #[cfg(test)] + fn allow_at_for_test(&self, ip: IpAddr, now: Instant) -> bool { + self.allow_at(ip, now) + } + + #[cfg(test)] + fn configured_burst(&self) -> u32 { + self.cfg.burst as u32 + } +} + +#[derive(Debug, Default)] +pub struct RateLimiters { + unauth: Option, + binding: Option, +} + +impl RateLimiters { + pub fn disabled() -> Self { + Self { + unauth: None, + binding: None, + } + } + + /// Build per-IP rate limiters from runtime configuration. + /// + /// Any pair where either value is missing/zero disables that limiter. + pub fn from_limits(limits: &LimitsOptions) -> Self { + let unauth = match (limits.unauth_rps, limits.unauth_burst) { + (Some(rps), Some(burst)) if rps > 0 && burst > 0 => { + Some(TokenBucketRateLimiter::new(rps, burst)) + } + _ => None, + }; + + let binding = match (limits.binding_rps, limits.binding_burst) { + (Some(rps), Some(burst)) if rps > 0 && burst > 0 => { + Some(TokenBucketRateLimiter::new(rps, burst)) + } + _ => None, + }; + + Self { unauth, binding } + } + + pub fn allow_unauth(&self, ip: IpAddr) -> bool { + match &self.unauth { + Some(l) => l.allow(ip), + None => true, + } + } + + pub fn allow_binding(&self, ip: IpAddr) -> bool { + match &self.binding { + Some(l) => l.allow(ip), + None => true, + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn token_bucket_enforces_burst_then_refills() { + let lim = TokenBucketRateLimiter::new(10, 3); + assert_eq!(lim.configured_burst(), 3); + + let ip: IpAddr = "127.0.0.1".parse().unwrap(); + let t0 = Instant::now(); + + assert!(lim.allow_at_for_test(ip, t0)); + assert!(lim.allow_at_for_test(ip, t0)); + assert!(lim.allow_at_for_test(ip, t0)); + assert!(!lim.allow_at_for_test(ip, t0)); + + let t1 = t0 + Duration::from_millis(200); // 10 rps => ~2 tokens + assert!(lim.allow_at_for_test(ip, t1)); + assert!(lim.allow_at_for_test(ip, t1)); + } +} diff --git a/src/server.rs b/src/server.rs index c58e984..fcba8ee 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,13 +3,17 @@ use std::sync::Arc; use tokio::net::UdpSocket; use tracing::{error, warn}; -use crate::alloc::AllocationManager; +use crate::alloc::{AllocationManager, ClientSink}; +use crate::alloc::AllocationError; use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; use crate::constants::*; +use crate::rate_limit::RateLimiters; use crate::stun::{ - build_401_response, build_allocate_success, build_error_response, build_lifetime_success, - build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data, - parse_message, + build_401_response, build_allocate_success_with_integrity, build_error_response, + build_error_response_with_integrity, build_lifetime_success_with_integrity, + build_success_response_with_integrity, decode_xor_peer_address, extract_lifetime_seconds, + parse_channel_data, extract_requested_transport_protocol, parse_message, + validate_fingerprint_if_present, }; use std::time::Duration; @@ -18,6 +22,16 @@ pub async fn udp_reader_loop( udp: Arc, auth: AuthManager, allocs: AllocationManager, +) -> anyhow::Result<()> { + udp_reader_loop_with_limits(udp, auth, allocs, Arc::new(RateLimiters::disabled())).await +} + +/// UDP reader loop with explicit (per-server) rate limiters. +pub async fn udp_reader_loop_with_limits( + udp: Arc, + auth: AuthManager, + allocs: AllocationManager, + rate_limiters: Arc, ) -> anyhow::Result<()> { let mut buf = vec![0u8; 1500]; loop { @@ -25,6 +39,7 @@ pub async fn udp_reader_loop( tracing::debug!("got {} bytes from {}", len, peer); if let Some((channel, payload)) = parse_channel_data(&buf[..len]) { + crate::metrics::inc_channel_data(); let allocation = match allocs.get_allocation(&peer) { Some(a) => a, None => { @@ -69,6 +84,11 @@ pub async fn udp_reader_loop( } if let Ok(msg) = parse_message(&buf[..len]) { + if !validate_fingerprint_if_present(&msg) { + tracing::debug!("dropping STUN/TURN message from {} due to invalid FINGERPRINT", peer); + continue; + } + crate::metrics::inc_stun_messages(); tracing::info!( "STUN/TURN message from {} type=0x{:04x} len={}", peer, @@ -85,16 +105,22 @@ pub async fn udp_reader_loop( ); if requires_auth { - match auth.authenticate(&msg, &peer).await { - AuthStatus::Granted { username } => { + let key = match auth.authenticate(&msg, &peer).await { + AuthStatus::Granted { username, key } => { tracing::debug!( "TURN auth ok for {} as {} (0x{:04x})", peer, username, msg.header.msg_type ); + key } AuthStatus::Challenge { nonce } => { + crate::metrics::inc_auth_challenge(); + if !rate_limiters.allow_unauth(peer.ip()) { + crate::metrics::inc_rate_limited(); + continue; + } let resp = build_401_response( &msg.header, auth.realm(), @@ -106,6 +132,11 @@ pub async fn udp_reader_loop( continue; } AuthStatus::StaleNonce { nonce } => { + crate::metrics::inc_auth_stale(); + if !rate_limiters.allow_unauth(peer.ip()) { + crate::metrics::inc_rate_limited(); + continue; + } let resp = build_401_response( &msg.header, auth.realm(), @@ -117,19 +148,55 @@ pub async fn udp_reader_loop( continue; } AuthStatus::Reject { code, reason } => { + crate::metrics::inc_auth_reject(); let resp = build_error_response(&msg.header, code, reason); let _ = udp.send_to(&resp, &peer).await; continue; } - } + }; match msg.header.msg_type { METHOD_ALLOCATE => { + crate::metrics::inc_allocate_total(); + + // TURN Allocate MUST include REQUESTED-TRANSPORT; WebRTC expects UDP (17). + match extract_requested_transport_protocol(&msg) { + Some(IPPROTO_UDP) => {} + Some(_) => { + crate::metrics::inc_allocate_fail(); + let resp = build_error_response_with_integrity( + &msg.header, + 442, + "Unsupported Transport", + &key, + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + None => { + crate::metrics::inc_allocate_fail(); + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Missing REQUESTED-TRANSPORT", + &key, + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } + } + let requested_lifetime = extract_lifetime_seconds(&msg) .map(|secs| Duration::from_secs(secs as u64)) .filter(|d| !d.is_zero()); - match allocs.allocate_for(peer, udp.clone()).await { + match allocs + .allocate_for(peer, ClientSink::Udp { + sock: udp.clone(), + addr: peer, + }) + .await + { Ok(relay_addr) => { let applied = match allocs.refresh_allocation(peer, requested_lifetime) { @@ -140,10 +207,11 @@ pub async fn udp_reader_loop( peer, e ); - let resp = build_error_response( + let resp = build_error_response_with_integrity( &msg.header, 500, "Allocate Failed", + &key, ); let _ = udp.send_to(&resp, &peer).await; continue; @@ -151,20 +219,32 @@ pub async fn udp_reader_loop( }; let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32; - let resp = - build_allocate_success(&msg.header, &relay_addr, lifetime_secs); + let advertised = allocs.relay_addr_for_response(relay_addr); + let resp = build_allocate_success_with_integrity( + &msg.header, + &advertised, + lifetime_secs, + &key, + ); tracing::info!( "allocated relay {} for {} lifetime={}s", relay_addr, peer, lifetime_secs ); + crate::metrics::inc_allocate_success(); let _ = udp.send_to(&resp, &peer).await; } Err(e) => { tracing::error!("allocate failed: {:?}", e); - let resp = - build_error_response(&msg.header, 500, "Allocate Failed"); + let (code, reason) = match e.downcast_ref::() { + Some(AllocationError::AllocationQuotaExceeded) => { + (486, "Allocation Quota Reached") + } + _ => (500, "Allocate Failed"), + }; + crate::metrics::inc_allocate_fail(); + let resp = build_error_response_with_integrity(&msg.header, code, reason, &key); let _ = udp.send_to(&resp, &peer).await; } } @@ -174,7 +254,7 @@ pub async fn udp_reader_loop( if allocs.get_allocation(&peer).is_none() { warn!("create-permission without allocation from {}", peer); let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); + build_error_response_with_integrity(&msg.header, 437, "Allocation Mismatch", &key); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -195,6 +275,7 @@ pub async fn udp_reader_loop( peer, peer_addr ); + crate::metrics::inc_permission_added(); added += 1; } Err(e) => { @@ -204,6 +285,19 @@ pub async fn udp_reader_loop( peer_addr, e ); + if matches!( + e.downcast_ref::(), + Some(AllocationError::PermissionQuotaExceeded) + ) { + let resp = build_error_response_with_integrity( + &msg.header, + 508, + "Insufficient Capacity", + &key, + ); + let _ = udp.send_to(&resp, &peer).await; + continue; + } } } } else { @@ -213,10 +307,10 @@ pub async fn udp_reader_loop( if added == 0 { let resp = - build_error_response(&msg.header, 400, "No valid XOR-PEER-ADDRESS"); + build_error_response_with_integrity(&msg.header, 400, "No valid XOR-PEER-ADDRESS", &key); let _ = udp.send_to(&resp, &peer).await; } else { - let resp = build_success_response(&msg.header); + let resp = build_success_response_with_integrity(&msg.header, &key); let _ = udp.send_to(&resp, &peer).await; } continue; @@ -227,7 +321,7 @@ pub async fn udp_reader_loop( None => { warn!("channel-bind without allocation from {}", peer); let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); + build_error_response_with_integrity(&msg.header, 437, "Allocation Mismatch", &key); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -242,10 +336,11 @@ pub async fn udp_reader_loop( let (channel_attr, peer_attr) = match (channel_attr, peer_attr) { (Some(c), Some(p)) => (c, p), _ => { - let resp = build_error_response( + let resp = build_error_response_with_integrity( &msg.header, 400, "Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS", + &key, ); let _ = udp.send_to(&resp, &peer).await; continue; @@ -260,10 +355,11 @@ pub async fn udp_reader_loop( ) { Some(addr) => addr, None => { - let resp = build_error_response( + let resp = build_error_response_with_integrity( &msg.header, 400, "Invalid XOR-PEER-ADDRESS", + &key, ); let _ = udp.send_to(&resp, &peer).await; continue; @@ -271,7 +367,12 @@ pub async fn udp_reader_loop( }; if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); + let resp = build_error_response_with_integrity( + &msg.header, + 403, + "Peer Not Permitted", + &key, + ); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -284,13 +385,19 @@ pub async fn udp_reader_loop( channel, e ); - let resp = - build_error_response(&msg.header, 500, "Channel Bind Failed"); + let (code, reason) = match e.downcast_ref::() { + Some(AllocationError::ChannelQuotaExceeded) => { + (508, "Insufficient Capacity") + } + _ => (500, "Channel Bind Failed"), + }; + let resp = build_error_response_with_integrity(&msg.header, code, reason, &key); let _ = udp.send_to(&resp, &peer).await; continue; } - let resp = build_success_response(&msg.header); + crate::metrics::inc_channel_binding_added(); + let resp = build_success_response_with_integrity(&msg.header, &key); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -300,7 +407,7 @@ pub async fn udp_reader_loop( None => { warn!("send indication without allocation from {}", peer); let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); + build_error_response_with_integrity(&msg.header, 437, "Allocation Mismatch", &key); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -314,10 +421,11 @@ pub async fn udp_reader_loop( let (peer_attr, data_attr) = match (peer_attr, data_attr) { (Some(p), Some(d)) => (p, d), _ => { - let resp = build_error_response( + let resp = build_error_response_with_integrity( &msg.header, 400, "Missing DATA or XOR-PEER-ADDRESS", + &key, ); let _ = udp.send_to(&resp, &peer).await; continue; @@ -330,10 +438,11 @@ pub async fn udp_reader_loop( ) { Some(addr) => addr, None => { - let resp = build_error_response( + let resp = build_error_response_with_integrity( &msg.header, 400, "Invalid XOR-PEER-ADDRESS", + &key, ); let _ = udp.send_to(&resp, &peer).await; continue; @@ -341,7 +450,12 @@ pub async fn udp_reader_loop( }; if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); + let resp = build_error_response_with_integrity( + &msg.header, + 403, + "Peer Not Permitted", + &key, + ); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -354,7 +468,7 @@ pub async fn udp_reader_loop( peer, peer_addr ); - let resp = build_success_response(&msg.header); + let resp = build_success_response_with_integrity(&msg.header, &key); let _ = udp.send_to(&resp, &peer).await; } Err(e) => { @@ -364,8 +478,12 @@ pub async fn udp_reader_loop( peer_addr, e ); - let resp = - build_error_response(&msg.header, 500, "Peer Send Failed"); + let resp = build_error_response_with_integrity( + &msg.header, + 500, + "Peer Send Failed", + &key, + ); let _ = udp.send_to(&resp, &peer).await; } } @@ -386,22 +504,32 @@ pub async fn udp_reader_loop( applied.as_secs() ); } - let resp = build_lifetime_success( + let resp = build_lifetime_success_with_integrity( &msg.header, applied.as_secs().min(u32::MAX as u64) as u32, + &key, ); let _ = udp.send_to(&resp, &peer).await; } Err(_) => { - let resp = - build_error_response(&msg.header, 437, "Allocation Mismatch"); + let resp = build_error_response_with_integrity( + &msg.header, + 437, + "Allocation Mismatch", + &key, + ); let _ = udp.send_to(&resp, &peer).await; } } continue; } _ => { - let resp = build_error_response(&msg.header, 420, "Unknown TURN Method"); + let resp = build_error_response_with_integrity( + &msg.header, + 420, + "Unknown TURN Method", + &key, + ); let _ = udp.send_to(&resp, &peer).await; continue; } @@ -410,10 +538,18 @@ pub async fn udp_reader_loop( match msg.header.msg_type { METHOD_BINDING => { - let resp = build_success_response(&msg.header); - let _ = udp.send_to(&resp, &peer).await; + if rate_limiters.allow_binding(peer.ip()) { + let resp = crate::stun::build_binding_success(&msg.header, &peer); + let _ = udp.send_to(&resp, &peer).await; + } else { + crate::metrics::inc_rate_limited(); + } } _ => { + if !rate_limiters.allow_unauth(peer.ip()) { + crate::metrics::inc_rate_limited(); + continue; + } let nonce = auth.mint_nonce(&peer); let resp = build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized"); diff --git a/src/stun.rs b/src/stun.rs index 458b164..193561b 100644 --- a/src/stun.rs +++ b/src/stun.rs @@ -117,10 +117,7 @@ pub fn build_401_response( } // Update length - let total_len = (buf.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - buf[2] = len_bytes[0]; - buf[3] = len_bytes[1]; + append_fingerprint(&mut buf); buf.to_vec() } @@ -150,14 +147,46 @@ pub fn build_error_response(req: &StunHeader, code: u16, reason: &str) -> Vec Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_ERROR; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let mut value = Vec::new(); + let class = (code / 100) as u8; + let number = (code % 100) as u8; + value.extend_from_slice(&[0, 0]); + value.push(class); + value.push(number); + value.extend_from_slice(reason.as_bytes()); + + buf.extend_from_slice(&ATTR_ERROR_CODE.to_be_bytes()); + buf.extend_from_slice(&(value.len() as u16).to_be_bytes()); + buf.extend_from_slice(&value); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + append_message_integrity(&mut buf, key); + append_fingerprint(&mut buf); + buf.to_vec() +} + /// Build an Allocate success response containing XOR-RELAYED-ADDRESS and LIFETIME attributes. pub fn build_allocate_success( req: &StunHeader, @@ -188,10 +217,43 @@ pub fn build_allocate_success( buf.extend_from_slice(&[0]); } - let total_len = (buf.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - buf[2] = len_bytes[0]; - buf[3] = len_bytes[1]; + append_fingerprint(&mut buf); + buf.to_vec() +} + +/// Build an Allocate success response including MESSAGE-INTEGRITY and FINGERPRINT. +pub fn build_allocate_success_with_integrity( + req: &StunHeader, + relay: &std::net::SocketAddr, + lifetime_secs: u32, + key: &[u8], +) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let relay_val = encode_xor_relayed_address(relay, &req.transaction_id); + buf.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes()); + buf.extend_from_slice(&((relay_val.len() as u16).to_be_bytes())); + buf.extend_from_slice(&relay_val); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + let lifetime_bytes = lifetime_secs.to_be_bytes(); + buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes()); + buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes()); + buf.extend_from_slice(&lifetime_bytes); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + append_message_integrity(&mut buf, key); + append_fingerprint(&mut buf); buf.to_vec() } @@ -213,10 +275,34 @@ pub fn build_lifetime_success(req: &StunHeader, lifetime_secs: u32) -> Vec { buf.extend_from_slice(&[0]); } - let total_len = (buf.len() - 20) as u16; - let len_bytes = total_len.to_be_bytes(); - buf[2] = len_bytes[0]; - buf[3] = len_bytes[1]; + append_fingerprint(&mut buf); + buf.to_vec() +} + +/// Build a Refresh success response including MESSAGE-INTEGRITY and FINGERPRINT. +pub fn build_lifetime_success_with_integrity( + req: &StunHeader, + lifetime_secs: u32, + key: &[u8], +) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let lifetime_bytes = lifetime_secs.to_be_bytes(); + buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes()); + buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes()); + buf.extend_from_slice(&lifetime_bytes); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + append_message_integrity(&mut buf, key); + append_fingerprint(&mut buf); buf.to_vec() } @@ -239,6 +325,22 @@ pub fn extract_lifetime_seconds(msg: &StunMessage) -> Option { }) } +/// Extract REQUESTED-TRANSPORT IP protocol number from a TURN Allocate request. +/// +/// RFC5766: value is 4 bytes: (protocol, rsvd, rsvd, rsvd) +pub fn extract_requested_transport_protocol(msg: &StunMessage) -> Option { + msg.attributes + .iter() + .find(|a| a.typ == ATTR_REQUESTED_TRANSPORT) + .and_then(|attr| { + if attr.value.len() >= 4 { + Some(attr.value[0]) + } else { + None + } + }) +} + /// Find MESSAGE-INTEGRITY attribute (ATTR_MESSAGE_INTEGRITY) if present pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> { msg.attributes @@ -246,25 +348,67 @@ pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> { .find(|a| a.typ == ATTR_MESSAGE_INTEGRITY) } -/// Validate MESSAGE-INTEGRITY using provided key (password). Returns true if valid. -/// Note: This is a simplified validator that assumes the MESSAGE-INTEGRITY attribute exists and -/// that the message bytes passed are the full STUN message (including attributes). +/// Validate MESSAGE-INTEGRITY using provided key. Returns true if valid. +/// +/// RFC5389: The HMAC is computed over the message up to (and including) the +/// MESSAGE-INTEGRITY attribute, with the header length field set to the end of +/// that attribute. Any attributes after MESSAGE-INTEGRITY (e.g. FINGERPRINT) +/// are excluded from the HMAC computation. pub fn validate_message_integrity(msg: &StunMessage, key: &[u8]) -> bool { if let Some(mi) = find_message_integrity(msg) { // MESSAGE-INTEGRITY attribute value is 20 bytes (HMAC-SHA1) if mi.value.len() != 20 { return false; } - // Compute HMAC over the message up to (but excluding) MESSAGE-INTEGRITY attribute header and value - let mi_attr_start = mi.offset; // offset points to attribute header - let msg_slice = &msg.raw[..mi_attr_start]; - let computed = crate::stun::compute_message_integrity(key, msg_slice); - // compare first 20 bytes - return &computed[..20] == mi.value.as_slice(); + + let mi_end = mi.offset + 4 + HMAC_SHA1_LEN; + if mi_end > msg.raw.len() { + return false; + } + + let mut signed = msg.raw[..mi_end].to_vec(); + + // Adjust header length to end-of-MI (exclude later attributes like FINGERPRINT). + let len = (mi_end - 20) as u16; + let len_bytes = len.to_be_bytes(); + signed[2] = len_bytes[0]; + signed[3] = len_bytes[1]; + + // Zero the MI value before computing. + let value_start = mi.offset + 4; + for b in &mut signed[value_start..value_start + HMAC_SHA1_LEN] { + *b = 0; + } + + let computed = crate::stun::compute_message_integrity(key, &signed); + return &computed[..HMAC_SHA1_LEN] == mi.value.as_slice(); } false } +fn append_message_integrity(buf: &mut bytes::BytesMut, key: &[u8]) { + // Append attribute header and placeholder; set length to end-of-MI, then compute + // HMAC over the message slice up to end-of-MI (with the MI placeholder still zero). + buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); + buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); + let mi_val_pos = buf.len(); + buf.extend_from_slice(&[0u8; HMAC_SHA1_LEN]); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0u8]); + } + + let mi_end = buf.len(); + + // Set length to end-of-MI (excluding any later attributes like FINGERPRINT) + let len = (mi_end - 20) as u16; + let len_bytes = len.to_be_bytes(); + buf[2] = len_bytes[0]; + buf[3] = len_bytes[1]; + + let hmac = compute_message_integrity(key, &buf[..mi_end]); + buf[mi_val_pos..mi_val_pos + HMAC_SHA1_LEN].copy_from_slice(&hmac[..HMAC_SHA1_LEN]); +} + /// Build a simple success (200) response echoing transaction id pub fn build_success_response(req: &StunHeader) -> Vec { use bytes::BytesMut; @@ -274,11 +418,66 @@ pub fn build_success_response(req: &StunHeader) -> Vec { buf.extend_from_slice(&0u16.to_be_bytes()); buf.extend_from_slice(&MAGIC_COOKIE_BYTES); buf.extend_from_slice(&req.transaction_id); + append_fingerprint(&mut buf); + buf.to_vec() +} + +/// Build a simple success (200) response including MESSAGE-INTEGRITY and FINGERPRINT. +pub fn build_success_response_with_integrity(req: &StunHeader, key: &[u8]) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + append_message_integrity(&mut buf, key); + append_fingerprint(&mut buf); + buf.to_vec() +} + +/// Build a STUN Binding success response containing XOR-MAPPED-ADDRESS. +pub fn build_binding_success(req: &StunHeader, mapped: &std::net::SocketAddr) -> Vec { + use bytes::BytesMut; + let mut buf = BytesMut::new(); + let msg_type: u16 = req.msg_type | CLASS_SUCCESS; + buf.extend_from_slice(&msg_type.to_be_bytes()); + buf.extend_from_slice(&0u16.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE_BYTES); + buf.extend_from_slice(&req.transaction_id); + + let mapped_val = encode_xor_address(mapped, &req.transaction_id); + buf.extend_from_slice(&ATTR_XOR_MAPPED_ADDRESS.to_be_bytes()); + buf.extend_from_slice(&((mapped_val.len() as u16).to_be_bytes())); + buf.extend_from_slice(&mapped_val); + while (buf.len() % 4) != 0 { + buf.extend_from_slice(&[0]); + } + + append_fingerprint(&mut buf); + buf.to_vec() +} + +fn append_fingerprint(buf: &mut bytes::BytesMut) { + // FINGERPRINT must be the last attribute. We'll append a placeholder, + // update the message length to include it, compute CRC32 over the message + // up to (but excluding) the FINGERPRINT attribute, and then write it. + let fingerprint_attr_offset = buf.len(); + + buf.extend_from_slice(&ATTR_FINGERPRINT.to_be_bytes()); + buf.extend_from_slice(&(4u16).to_be_bytes()); + let fingerprint_value_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); + + // Update STUN length (bytes after the 20-byte header) let total_len = (buf.len() - 20) as u16; let len_bytes = total_len.to_be_bytes(); buf[2] = len_bytes[0]; buf[3] = len_bytes[1]; - buf.to_vec() + + let fp = compute_fingerprint(&buf[..fingerprint_attr_offset]); + let fp_bytes = fp.to_be_bytes(); + buf[fingerprint_value_pos..fingerprint_value_pos + 4].copy_from_slice(&fp_bytes); } /// Compute STUN fingerprint (XOR-32 of CRC32) @@ -290,6 +489,37 @@ pub fn compute_fingerprint(msg: &[u8]) -> u32 { crc ^ FINGERPRINT_XOR } +pub fn find_fingerprint(msg: &StunMessage) -> Option<&StunAttribute> { + msg.attributes + .iter() + .find(|a| a.typ == ATTR_FINGERPRINT) +} + +/// Validate FINGERPRINT if present. +/// +/// Returns true if the message has no FINGERPRINT attribute. If present, enforces: +/// - value length is exactly 4 bytes +/// - FINGERPRINT is the last attribute +/// - CRC32/XOR matches RFC5389 +pub fn validate_fingerprint_if_present(msg: &StunMessage) -> bool { + let Some(fp) = find_fingerprint(msg) else { + return true; + }; + + if fp.value.len() != 4 { + return false; + } + + // FINGERPRINT should be the last attribute (type+len+4) + if fp.offset + 8 != msg.raw.len() { + return false; + } + + let expected = compute_fingerprint(&msg.raw[..fp.offset]); + let actual = u32::from_be_bytes([fp.value[0], fp.value[1], fp.value[2], fp.value[3]]); + expected == actual +} + /// Compute MESSAGE-INTEGRITY (HMAC-SHA1) over the message pub fn compute_message_integrity(key: &[u8], msg: &[u8]) -> Vec { use hmac::{Hmac, Mac}; @@ -311,9 +541,6 @@ pub fn build_channel_data(channel: u16, data: &[u8]) -> Vec { out.extend_from_slice(&channel.to_be_bytes()); out.extend_from_slice(&(data.len() as u16).to_be_bytes()); out.extend_from_slice(data); - while (out.len() % 4) != 0 { - out.push(0); - } out } @@ -344,10 +571,7 @@ pub fn build_data_indication(peer: &std::net::SocketAddr, data: &[u8]) -> Vec Option<(u16, &[u8])> { Some((channel, &buf[4..4 + data_len])) } -fn encode_xor_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec { +fn encode_xor_address(addr: &std::net::SocketAddr, trans_id: &[u8; 12]) -> Vec { use std::net::IpAddr; let mut out = Vec::new(); match addr.ip() { @@ -385,9 +609,22 @@ fn encode_xor_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec< out.push(octets[i] ^ cookie_bytes[i]); } } - IpAddr::V6(_v6) => { - // For now, we don't support IPv6 in this MVP implementation - // Return an empty vec to indicate unsupported + IpAddr::V6(v6) => { + out.push(0); + out.push(FAMILY_IPV6); + let port = addr.port(); + let xport = (port ^ ((MAGIC_COOKIE_U32 >> 16) as u16)) as u16; + out.extend_from_slice(&xport.to_be_bytes()); + + // RFC5389: IPv6 XOR uses magic cookie (4) concatenated with transaction id (12) + let mut xor_key = [0u8; 16]; + xor_key[..4].copy_from_slice(&MAGIC_COOKIE_BYTES); + xor_key[4..].copy_from_slice(trans_id); + + let octets = v6.octets(); + for i in 0..16 { + out.push(octets[i] ^ xor_key[i]); + } } } out @@ -401,36 +638,58 @@ pub fn encode_xor_peer_address(addr: &std::net::SocketAddr, trans_id: &[u8; 12]) encode_xor_address(addr, trans_id) } -/// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4 only) +/// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4/IPv6) pub fn decode_xor_relayed_address( value: &[u8], - _trans_id: &[u8; 12], + trans_id: &[u8; 12], ) -> Option { - if value.len() < 8 { + if value.len() < 4 { return None; } - if value[1] != FAMILY_IPV4 { - return None; - } // not IPv4 let xport = u16::from_be_bytes([value[2], value[3]]); let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16); - let cookie_bytes = MAGIC_COOKIE_BYTES; - let mut ipb = [0u8; 4]; - for i in 0..4 { - ipb[i] = value[4 + i] ^ cookie_bytes[i]; + + match value[1] { + FAMILY_IPV4 => { + if value.len() < 8 { + return None; + } + let cookie_bytes = MAGIC_COOKIE_BYTES; + let mut ipb = [0u8; 4]; + for i in 0..4 { + ipb[i] = value[4 + i] ^ cookie_bytes[i]; + } + let ip = std::net::Ipv4Addr::from(ipb); + Some(std::net::SocketAddr::new(std::net::IpAddr::V4(ip), port)) + } + FAMILY_IPV6 => { + if value.len() < 20 { + return None; + } + let mut xor_key = [0u8; 16]; + xor_key[..4].copy_from_slice(&MAGIC_COOKIE_BYTES); + xor_key[4..].copy_from_slice(trans_id); + + let mut ipb = [0u8; 16]; + for i in 0..16 { + ipb[i] = value[4 + i] ^ xor_key[i]; + } + let ip = std::net::Ipv6Addr::from(ipb); + Some(std::net::SocketAddr::new(std::net::IpAddr::V6(ip), port)) + } + _ => None, } - let ip = std::net::Ipv4Addr::from(ipb); - Some(std::net::SocketAddr::new(std::net::IpAddr::V4(ip), port)) } -/// Decode XOR-PEER-ADDRESS / XOR-MAPPED-ADDRESS style attributes (IPv4 only). -pub fn decode_xor_peer_address(value: &[u8], _trans_id: &[u8; 12]) -> Option { - decode_xor_relayed_address(value, _trans_id) +/// Decode XOR-PEER-ADDRESS / XOR-MAPPED-ADDRESS style attributes (IPv4/IPv6). +pub fn decode_xor_peer_address(value: &[u8], trans_id: &[u8; 12]) -> Option { + decode_xor_relayed_address(value, trans_id) } #[cfg(test)] mod tests { use super::*; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; #[test] fn parse_minimal_binding() { @@ -445,6 +704,7 @@ mod tests { assert_eq!(msg.header.msg_type, METHOD_BINDING); assert_eq!(msg.header.transaction_id, trans); assert!(msg.attributes.is_empty()); + assert!(validate_fingerprint_if_present(&msg)); } #[test] @@ -459,6 +719,30 @@ mod tests { // parse back should succeed let parsed = parse_message(&out).expect("parse resp"); assert!(!parsed.attributes.is_empty()); + assert!(validate_fingerprint_if_present(&parsed)); + } + + #[test] + fn fingerprint_is_appended_and_valid() { + let req = StunHeader { + msg_type: METHOD_BINDING, + length: 0, + cookie: MAGIC_COOKIE_U32, + transaction_id: [4u8; 12], + }; + + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 5)), 9999); + let out = build_binding_success(&req, &peer); + let parsed = parse_message(&out).expect("parse"); + assert!(find_fingerprint(&parsed).is_some()); + assert!(validate_fingerprint_if_present(&parsed)); + + // Tamper any byte before the fingerprint attribute => invalid + let fp = find_fingerprint(&parsed).unwrap(); + let mut tampered = out.clone(); + tampered[fp.offset.saturating_sub(1)] ^= 0xFF; + let parsed2 = parse_message(&tampered).expect("parse tampered"); + assert!(!validate_fingerprint_if_present(&parsed2)); } #[test] @@ -486,7 +770,6 @@ mod tests { } // MESSAGE-INTEGRITY placeholder (0x0008) length 20 - let mi_attr_offset = buf.len(); buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); let mi_val_pos = buf.len(); @@ -495,19 +778,23 @@ mod tests { buf.extend_from_slice(&[0u8]); } - // Fix length - let total_len = (buf.len() - 20) as u16; + // Fix length to end-of-MI + let mi_end = buf.len(); + let total_len = (mi_end - 20) as u16; let len_bytes = total_len.to_be_bytes(); buf[2] = len_bytes[0]; buf[3] = len_bytes[1]; - // Compute HMAC over message up to MI attribute header (mi_attr_offset) - let hmac = compute_message_integrity(password.as_bytes(), &buf[..mi_attr_offset]); + // Compute HMAC over message up to end-of-MI (MI value is still zero here) + let hmac = compute_message_integrity(password.as_bytes(), &buf[..mi_end]); // place first 20 bytes into mi value for i in 0..20 { buf[mi_val_pos + i] = hmac[i]; } + // Add FINGERPRINT after MESSAGE-INTEGRITY and ensure validation still succeeds. + append_fingerprint(&mut buf); + // Parse and validate let parsed = parse_message(&buf).expect("parsed"); assert!(validate_message_integrity(&parsed, password.as_bytes())); @@ -518,4 +805,44 @@ mod tests { let parsed2 = parse_message(&tampered).expect("parsed2"); assert!(!validate_message_integrity(&parsed2, password.as_bytes())); } + + #[test] + fn build_binding_success_includes_xor_mapped_ipv4() { + let req = StunHeader { + msg_type: METHOD_BINDING, + length: 0, + cookie: MAGIC_COOKIE_U32, + transaction_id: [7u8; 12], + }; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 10)), 54321); + let out = build_binding_success(&req, &peer); + let parsed = parse_message(&out).expect("parse"); + let attr = parsed + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_MAPPED_ADDRESS) + .expect("xor-mapped present"); + let decoded = decode_xor_peer_address(&attr.value, &req.transaction_id).expect("decode"); + assert_eq!(decoded, peer); + } + + #[test] + fn build_binding_success_includes_xor_mapped_ipv6() { + let req = StunHeader { + msg_type: METHOD_BINDING, + length: 0, + cookie: MAGIC_COOKIE_U32, + transaction_id: [8u8; 12], + }; + let peer = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 12345); + let out = build_binding_success(&req, &peer); + let parsed = parse_message(&out).expect("parse"); + let attr = parsed + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_MAPPED_ADDRESS) + .expect("xor-mapped present"); + let decoded = decode_xor_peer_address(&attr.value, &req.transaction_id).expect("decode"); + assert_eq!(decoded, peer); + } } diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..cf1e307 --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,38 @@ +//! Plain TCP listener for TURN over TCP (RFC5389 framing + ChannelData interleaving). +use tokio::net::TcpListener; + +use crate::alloc::AllocationManager; +use crate::auth::{AuthManager, InMemoryStore}; +use crate::rate_limit::RateLimiters; +use crate::turn_stream::handle_turn_stream_connection_with_limits; + +pub async fn serve_tcp( + bind: &str, + auth: AuthManager, + allocs: AllocationManager, +) -> anyhow::Result<()> { + serve_tcp_with_limits(bind, auth, allocs, std::sync::Arc::new(RateLimiters::disabled())).await +} + +pub async fn serve_tcp_with_limits( + bind: &str, + auth: AuthManager, + allocs: AllocationManager, + rate_limiters: std::sync::Arc, +) -> anyhow::Result<()> { + let listener = TcpListener::bind(bind).await?; + tracing::info!("TCP listener bound to {}", bind); + + loop { + let (stream, peer) = listener.accept().await?; + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + let rl = rate_limiters.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_turn_stream_connection_with_limits(stream, peer, auth_clone, alloc_clone, rl).await { + tracing::info!("TCP connection ended for {}: {:?}", peer, e); + } + }); + } +} diff --git a/src/tls.rs b/src/tls.rs index c5c0abf..f30768f 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -4,19 +4,15 @@ use anyhow::Context; use std::fs::File; use std::io::BufReader; use std::sync::Arc; -use std::time::Duration; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; use tokio_rustls::TlsAcceptor; use crate::alloc::AllocationManager; -use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; -use crate::constants::*; -use crate::stun::{ - build_401_response, build_allocate_success, build_error_response, build_lifetime_success, - build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_message, -}; +use crate::auth::{AuthManager, InMemoryStore}; +use crate::rate_limit::RateLimiters; +use crate::turn_stream::{handle_turn_stream_connection, handle_turn_stream_connection_with_limits}; fn load_certs(path: &str) -> anyhow::Result> { let f = File::open(path).context("opening cert file")?; @@ -48,9 +44,27 @@ pub async fn serve_tls( bind: &str, cert_path: &str, key_path: &str, - udp_sock: std::sync::Arc, auth: AuthManager, allocs: AllocationManager, +) -> anyhow::Result<()> { + serve_tls_with_limits( + bind, + cert_path, + key_path, + auth, + allocs, + std::sync::Arc::new(RateLimiters::disabled()), + ) + .await +} + +pub async fn serve_tls_with_limits( + bind: &str, + cert_path: &str, + key_path: &str, + auth: AuthManager, + allocs: AllocationManager, + rate_limiters: std::sync::Arc, ) -> anyhow::Result<()> { let certs = load_certs(cert_path)?; let key = load_private_key(key_path)?; @@ -68,23 +82,15 @@ pub async fn serve_tls( loop { let (stream, peer) = listener.accept().await?; let acceptor = acceptor.clone(); - let udp_clone = udp_sock.clone(); let auth_clone = auth.clone(); let alloc_clone = allocs.clone(); + let rl = rate_limiters.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { - Ok(mut tls_stream) => { - if let Err(e) = handle_tls_connection( - &mut tls_stream, - peer, - udp_clone, - auth_clone, - alloc_clone, - ) - .await - { - tracing::error!("TLS connection error: {:?}", e); + Ok(tls_stream) => { + if let Err(e) = handle_tls_connection_with_limits(tls_stream, peer, auth_clone, alloc_clone, rl).await { + tracing::info!("TLS connection ended for {}: {:?}", peer, e); } } Err(e) => tracing::error!("TLS accept error: {:?}", e), @@ -93,596 +99,27 @@ pub async fn serve_tls( } } -#[allow(clippy::too_many_arguments)] pub async fn handle_tls_connection( - tls_stream: &mut S, + tls_stream: S, peer: std::net::SocketAddr, - udp_sock: std::sync::Arc, auth: AuthManager, allocs: AllocationManager, ) -> anyhow::Result<()> where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - tracing::info!("accepted TLS connection from {}", peer); - let mut read_buf = vec![0u8; 4096]; - let mut buffer: Vec = Vec::new(); - - loop { - match tls_stream.read(&mut read_buf).await { - Ok(0) => { - tracing::info!("TLS client {} closed connection", peer); - break; - } - Ok(n) => { - buffer.extend_from_slice(&read_buf[..n]); - while buffer.len() >= 20 { - let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; - let total = len + 20; - if buffer.len() < total { - break; - } - let chunk = buffer.drain(..total).collect::>(); - match parse_message(&chunk) { - Ok(msg) => { - tracing::info!( - "STUN/TURN over TLS from {} type=0x{:04x} len={}", - peer, - msg.header.msg_type, - total - ); - - let requires_auth = matches!( - msg.header.msg_type, - METHOD_ALLOCATE - | METHOD_CREATE_PERMISSION - | METHOD_CHANNEL_BIND - | METHOD_SEND - | METHOD_REFRESH - ); - - if requires_auth { - match auth.authenticate(&msg, &peer).await { - AuthStatus::Granted { username } => { - tracing::debug!( - "TURN TLS auth ok for {} as {} (0x{:04x})", - peer, - username, - msg.header.msg_type - ); - } - AuthStatus::Challenge { nonce } => { - let resp = build_401_response( - &msg.header, - auth.realm(), - &nonce, - 401, - "Unauthorized", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls challenge: {:?}", - e - ); - } - continue; - } - AuthStatus::StaleNonce { nonce } => { - let resp = build_401_response( - &msg.header, - auth.realm(), - &nonce, - 438, - "Stale Nonce", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls stale nonce: {:?}", - e - ); - } - continue; - } - AuthStatus::Reject { code, reason } => { - let resp = build_error_response(&msg.header, code, reason); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls auth error: {:?}", - e - ); - } - continue; - } - } - } - - match msg.header.msg_type { - METHOD_ALLOCATE => { - let requested_lifetime = extract_lifetime_seconds(&msg) - .map(|secs| Duration::from_secs(secs as u64)) - .filter(|d| !d.is_zero()); - - match allocs.allocate_for(peer, udp_sock.clone()).await { - Ok(relay_addr) => { - let applied = match allocs - .refresh_allocation(peer, requested_lifetime) - { - Ok(d) => d, - Err(e) => { - tracing::error!( - "failed to apply TLS lifetime for {}: {:?}", - peer, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Allocate Failed", - ); - if let Err(e2) = - tls_stream.write_all(&resp).await - { - tracing::error!( - "failed to write tls allocate error: {:?}", - e2 - ); - } - continue; - } - }; - let lifetime_secs = - applied.as_secs().min(u32::MAX as u64) as u32; - let resp = build_allocate_success( - &msg.header, - &relay_addr, - lifetime_secs, - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls allocate success: {:?}", - e - ); - } - } - Err(e) => { - tracing::error!("allocate failed (tls): {:?}", e); - let resp = build_error_response( - &msg.header, - 500, - "Allocate Failed", - ); - if let Err(e2) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls allocate error: {:?}", - e2 - ); - } - } - } - continue; - } - METHOD_CREATE_PERMISSION => { - if allocs.get_allocation(&peer).is_none() { - tracing::warn!( - "create-permission without allocation from {} (tls)", - peer - ); - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!("failed to write tls error: {:?}", e); - } - continue; - } - - let mut added = 0usize; - for attr in msg - .attributes - .iter() - .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) - { - if let Some(peer_addr) = decode_xor_peer_address( - &attr.value, - &msg.header.transaction_id, - ) { - match allocs.add_permission(peer, peer_addr) { - Ok(()) => { - tracing::info!( - "added TLS permission for {} -> {}", - peer, - peer_addr - ); - added += 1; - } - Err(e) => { - tracing::error!( - "failed to persist TLS permission {} -> {}: {:?}", - peer, - peer_addr, - e - ); - } - } - } else { - tracing::warn!( - "invalid XOR-PEER-ADDRESS via TLS from {}", - peer - ); - } - } - - let resp = if added == 0 { - build_error_response( - &msg.header, - 400, - "No valid XOR-PEER-ADDRESS", - ) - } else { - build_success_response(&msg.header) - }; - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!("failed to write tls response: {:?}", e); - } - continue; - } - METHOD_CHANNEL_BIND => { - let allocation = match allocs.get_allocation(&peer) { - Some(a) => a, - None => { - tracing::warn!( - "channel-bind without allocation from {} (tls)", - peer - ); - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - let channel_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_CHANNEL_NUMBER); - let peer_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); - - let channel = match channel_attr.and_then(|attr| { - if attr.value.len() >= 4 { - Some(u16::from_be_bytes([attr.value[0], attr.value[1]])) - } else { - None - } - }) { - Some(c) => c, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing CHANNEL-NUMBER", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - if channel < 0x4000 || channel > 0x7FFF { - let resp = build_error_response( - &msg.header, - 400, - "Channel Out Of Range", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!("failed to write tls error: {:?}", e); - } - continue; - } - - let peer_addr = match peer_attr.and_then(|attr| { - decode_xor_peer_address( - &attr.value, - &msg.header.transaction_id, - ) - }) { - Some(addr) => addr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing XOR-PEER-ADDRESS", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response( - &msg.header, - 403, - "Peer Not Permitted", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!("failed to write tls error: {:?}", e); - } - continue; - } - - match allocs.add_channel_binding(peer, channel, peer_addr) { - Ok(()) => { - tracing::info!( - "bound channel 0x{:04x} for {} -> {} over TLS", - channel, - peer, - peer_addr - ); - let resp = build_success_response(&msg.header); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls response: {:?}", - e - ); - } - } - Err(e) => { - tracing::error!( - "failed TLS channel binding {} -> {} (0x{:04x}): {:?}", - peer, - peer_addr, - channel, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Channel Binding Failed", - ); - if let Err(e2) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e2 - ); - } - } - } - continue; - } - METHOD_SEND => { - let allocation = match allocs.get_allocation(&peer) { - Some(a) => a, - None => { - tracing::warn!( - "send without allocation from {} (tls)", - peer - ); - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - let peer_attr = msg - .attributes - .iter() - .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); - let data_attr = - msg.attributes.iter().find(|a| a.typ == ATTR_DATA); - - let peer_addr = match peer_attr.and_then(|attr| { - decode_xor_peer_address( - &attr.value, - &msg.header.transaction_id, - ) - }) { - Some(addr) => addr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing XOR-PEER-ADDRESS", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - let data_attr = match data_attr { - Some(attr) => attr, - None => { - let resp = build_error_response( - &msg.header, - 400, - "Missing DATA Attribute", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e - ); - } - continue; - } - }; - - if !allocation.is_peer_allowed(&peer_addr) { - let resp = build_error_response( - &msg.header, - 403, - "Peer Not Permitted", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!("failed to write tls error: {:?}", e); - } - continue; - } - - match allocation.send_to_peer(peer_addr, &data_attr.value).await - { - Ok(sent) => { - tracing::info!( - "forwarded {} bytes from {} to {} via TLS session", - sent, - peer, - peer_addr - ); - let resp = build_success_response(&msg.header); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls response: {:?}", - e - ); - } - } - Err(e) => { - tracing::error!( - "failed to send payload from {} to {} via TLS: {:?}", - peer, - peer_addr, - e - ); - let resp = build_error_response( - &msg.header, - 500, - "Peer Send Failed", - ); - if let Err(e2) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls error: {:?}", - e2 - ); - } - } - } - continue; - } - METHOD_REFRESH => { - let requested = extract_lifetime_seconds(&msg) - .map(|secs| Duration::from_secs(secs as u64)); - - match allocs.refresh_allocation(peer, requested) { - Ok(applied) => { - if applied.is_zero() { - tracing::info!( - "allocation for {} released (tls)", - peer - ); - } else { - tracing::debug!( - "allocation for {} refreshed to {}s (tls)", - peer, - applied.as_secs() - ); - } - let resp = build_lifetime_success( - &msg.header, - applied.as_secs().min(u32::MAX as u64) as u32, - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls refresh response: {:?}", - e - ); - } - } - Err(_) => { - let resp = build_error_response( - &msg.header, - 437, - "Allocation Mismatch", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls refresh error: {:?}", - e - ); - } - } - } - continue; - } - METHOD_BINDING => { - let resp = build_success_response(&msg.header); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls binding response: {:?}", - e - ); - } - continue; - } - _ => { - let nonce = auth.mint_nonce(&peer); - let resp = build_401_response( - &msg.header, - auth.realm(), - &nonce, - 401, - "Unauthorized", - ); - if let Err(e) = tls_stream.write_all(&resp).await { - tracing::error!( - "failed to write tls fallback challenge: {:?}", - e - ); - } - continue; - } - } - } - Err(e) => { - tracing::warn!( - error = ?e, - length = chunk.len(), - "dropping unparseable STUN/TURN frame over TLS from {}", - peer - ); - } - } - } - } - Err(e) => { - tracing::error!("tls read error from {}: {:?}", peer, e); - break; - } - } - } - - Ok(()) + handle_turn_stream_connection(tls_stream, peer, auth, allocs).await +} + +pub async fn handle_tls_connection_with_limits( + tls_stream: S, + peer: std::net::SocketAddr, + auth: AuthManager, + allocs: AllocationManager, + rate_limiters: std::sync::Arc, +) -> anyhow::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + handle_turn_stream_connection_with_limits(tls_stream, peer, auth, allocs, rate_limiters).await } diff --git a/src/turn_stream.rs b/src/turn_stream.rs new file mode 100644 index 0000000..bded45b --- /dev/null +++ b/src/turn_stream.rs @@ -0,0 +1,715 @@ +//! Shared TURN-over-stream (TCP/TLS) handler. +//! +//! This implements TURN over a byte stream where STUN messages and ChannelData frames +//! may be interleaved on the same connection. +use std::net::SocketAddr; +use std::time::Duration; + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc; + +use crate::alloc::{AllocationManager, ClientSink}; +use crate::alloc::AllocationError; +use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; +use crate::constants::*; +use crate::stun::{ + build_401_response, build_allocate_success_with_integrity, build_error_response, + build_error_response_with_integrity, build_lifetime_success_with_integrity, + build_success_response_with_integrity, decode_xor_peer_address, extract_lifetime_seconds, + parse_message, validate_fingerprint_if_present, extract_requested_transport_protocol, +}; + +use crate::rate_limit::RateLimiters; + +enum StreamFrame { + ChannelData { channel: u16, payload: Vec }, + Stun(crate::models::stun::StunMessage), +} + +const MAX_STREAM_BUFFER_BYTES: usize = 256 * 1024; +const MAX_STUN_BODY_BYTES: usize = 64 * 1024; +const MAX_CHANNELDATA_BYTES: usize = 64 * 1024; + +fn try_pop_next_frame(buffer: &mut Vec) -> Option> { + if buffer.len() < 4 { + return None; + } + + // ChannelData detection: channel number is 0x4000..0x7FFF + let channel = u16::from_be_bytes([buffer[0], buffer[1]]); + if (channel & 0xC000) == 0x4000 { + let data_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; + if data_len > MAX_CHANNELDATA_BYTES { + buffer.drain(..1); + return Some(Err(anyhow::anyhow!( + "channeldata length {} exceeds max {}", + data_len, + MAX_CHANNELDATA_BYTES + ))); + } + let total = 4 + data_len; + if buffer.len() < total { + return None; + } + let frame = buffer.drain(..total).collect::>(); + let payload = frame[4..].to_vec(); + return Some(Ok(StreamFrame::ChannelData { channel, payload })); + } + + // Otherwise assume STUN over stream: 20-byte header + length. + if buffer.len() < 20 { + return None; + } + + // Quick resync: STUN messages must include the magic cookie. + if buffer[4..8] != MAGIC_COOKIE_BYTES { + buffer.drain(..1); + return Some(Err(anyhow::anyhow!("invalid STUN magic cookie"))); + } + + let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; + if len > MAX_STUN_BODY_BYTES { + buffer.drain(..1); + return Some(Err(anyhow::anyhow!( + "stun length {} exceeds max {}", + len, + MAX_STUN_BODY_BYTES + ))); + } + let total = 20 + len; + if buffer.len() < total { + return None; + } + let chunk = buffer.drain(..total).collect::>(); + Some(parse_message(&chunk).map(StreamFrame::Stun).map_err(|e| anyhow::anyhow!("parse stun: {e:?}"))) +} + +/// Handle a single TURN-over-stream connection (plain TCP or TLS). +/// +/// All server responses and relayed peer packets are written back over the same stream. +#[allow(clippy::too_many_arguments)] +pub async fn handle_turn_stream_connection( + stream: S, + peer: SocketAddr, + auth: AuthManager, + allocs: AllocationManager, +) -> anyhow::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + handle_turn_stream_connection_with_limits( + stream, + peer, + auth, + allocs, + std::sync::Arc::new(RateLimiters::disabled()), + ) + .await +} + +/// Handle a single TURN-over-stream connection (plain TCP or TLS) with explicit rate limiters. +#[allow(clippy::too_many_arguments)] +pub async fn handle_turn_stream_connection_with_limits( + stream: S, + peer: SocketAddr, + auth: AuthManager, + allocs: AllocationManager, + rate_limiters: std::sync::Arc, +) -> anyhow::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + tracing::info!("accepted TURN stream connection from {}", peer); + crate::metrics::inc_stream_connections(); + + let (mut reader, mut writer) = tokio::io::split(stream); + + // Single ordered output queue for both direct responses and relayed peer traffic. + let (tx, mut rx) = mpsc::channel::>(256); + + let writer_task = tokio::spawn(async move { + while let Some(buf) = rx.recv().await { + if let Err(e) = writer.write_all(&buf).await { + tracing::info!("stream writer closed: {:?}", e); + break; + } + } + }); + + let mut read_buf = vec![0u8; 4096]; + let mut buffer: Vec = Vec::new(); + + loop { + match reader.read(&mut read_buf).await { + Ok(0) => { + tracing::info!("TURN stream client {} closed connection", peer); + break; + } + Ok(n) => { + buffer.extend_from_slice(&read_buf[..n]); + + if buffer.len() > MAX_STREAM_BUFFER_BYTES { + tracing::warn!( + "closing stream connection {} due to oversized buffer ({} bytes)", + peer, + buffer.len() + ); + break; + } + + while let Some(frame_res) = try_pop_next_frame(&mut buffer) { + let frame = match frame_res { + Ok(f) => f, + Err(e) => { + tracing::debug!("dropping invalid frame from {}: {:?}", peer, e); + continue; + } + }; + + match frame { + StreamFrame::ChannelData { channel, payload } => { + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + tracing::warn!("channel data without allocation from {}", peer); + continue; + } + }; + + let target = match allocation.channel_peer(channel) { + Some(addr) => addr, + None => { + tracing::warn!( + "channel data with unknown channel 0x{:04x} from {}", + channel, + peer + ); + continue; + } + }; + + if !allocation.is_peer_allowed(&target) { + tracing::warn!( + "channel data target {} no longer permitted for {}", + target, + peer + ); + continue; + } + + if let Err(e) = allocation.send_to_peer(target, &payload).await { + tracing::error!( + "failed to forward channel data 0x{:04x} from {} to {}: {:?}", + channel, + peer, + target, + e + ); + } + } + StreamFrame::Stun(msg) => { + if !validate_fingerprint_if_present(&msg) { + tracing::debug!( + "dropping STUN/TURN over stream from {} due to invalid FINGERPRINT", + peer + ); + continue; + } + crate::metrics::inc_stun_messages(); + tracing::info!( + "STUN/TURN over stream from {} type=0x{:04x}", + peer, + msg.header.msg_type + ); + + let requires_auth = matches!( + msg.header.msg_type, + METHOD_ALLOCATE + | METHOD_CREATE_PERMISSION + | METHOD_CHANNEL_BIND + | METHOD_SEND + | METHOD_REFRESH + ); + + let mut auth_key: Option> = None; + if requires_auth { + match auth.authenticate(&msg, &peer).await { + AuthStatus::Granted { username, key } => { + auth_key = Some(key); + tracing::debug!( + "TURN stream auth ok for {} as {} (0x{:04x})", + peer, + username, + msg.header.msg_type + ); + } + AuthStatus::Challenge { nonce } => { + crate::metrics::inc_auth_challenge(); + if !rate_limiters.allow_unauth(peer.ip()) { + crate::metrics::inc_rate_limited(); + continue; + } + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 401, + "Unauthorized", + ); + let _ = tx.send(resp).await; + continue; + } + AuthStatus::StaleNonce { nonce } => { + crate::metrics::inc_auth_stale(); + if !rate_limiters.allow_unauth(peer.ip()) { + crate::metrics::inc_rate_limited(); + continue; + } + let resp = build_401_response( + &msg.header, + auth.realm(), + &nonce, + 438, + "Stale Nonce", + ); + let _ = tx.send(resp).await; + continue; + } + AuthStatus::Reject { code, reason } => { + crate::metrics::inc_auth_reject(); + let resp = build_error_response(&msg.header, code, reason); + let _ = tx.send(resp).await; + continue; + } + } + } + + match msg.header.msg_type { + METHOD_ALLOCATE => { + crate::metrics::inc_allocate_total(); + + let key = auth_key + .as_deref() + .expect("auth key must be set after AuthStatus::Granted"); + + // TURN Allocate MUST include REQUESTED-TRANSPORT; WebRTC expects UDP (17). + match extract_requested_transport_protocol(&msg) { + Some(IPPROTO_UDP) => {} + Some(_) => { + crate::metrics::inc_allocate_fail(); + let resp = build_error_response_with_integrity( + &msg.header, + 442, + "Unsupported Transport", + key, + ); + let _ = tx.send(resp).await; + continue; + } + None => { + crate::metrics::inc_allocate_fail(); + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Missing REQUESTED-TRANSPORT", + key, + ); + let _ = tx.send(resp).await; + continue; + } + } + + let requested_lifetime = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)) + .filter(|d| !d.is_zero()); + + match allocs + .allocate_for( + peer, + ClientSink::Stream { + tx: tx.clone(), + }, + ) + .await + { + Ok(relay_addr) => { + let applied = match allocs + .refresh_allocation(peer, requested_lifetime) + { + Ok(d) => d, + Err(e) => { + tracing::error!( + "failed to apply lifetime for {}: {:?}", + peer, + e + ); + let resp = build_error_response_with_integrity( + &msg.header, + 500, + "Allocate Failed", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + let lifetime_secs = + applied.as_secs().min(u32::MAX as u64) as u32; + let advertised = + allocs.relay_addr_for_response(relay_addr); + let resp = build_allocate_success_with_integrity( + &msg.header, + &advertised, + lifetime_secs, + key, + ); + crate::metrics::inc_allocate_success(); + let _ = tx.send(resp).await; + } + Err(e) => { + tracing::error!("allocate failed (stream): {:?}", e); + let (code, reason) = match e.downcast_ref::() { + Some(AllocationError::AllocationQuotaExceeded) => { + (486, "Allocation Quota Reached") + } + _ => (500, "Allocate Failed"), + }; + crate::metrics::inc_allocate_fail(); + let resp = build_error_response_with_integrity( + &msg.header, + code, + reason, + key, + ); + let _ = tx.send(resp).await; + } + } + } + METHOD_CREATE_PERMISSION => { + let key = auth_key + .as_deref() + .expect("auth key must be set after AuthStatus::Granted"); + if allocs.get_allocation(&peer).is_none() { + let resp = build_error_response_with_integrity( + &msg.header, + 437, + "Allocation Mismatch", + key, + ); + let _ = tx.send(resp).await; + continue; + } + + let mut added = 0usize; + for attr in msg + .attributes + .iter() + .filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS) + { + if let Some(peer_addr) = decode_xor_peer_address( + &attr.value, + &msg.header.transaction_id, + ) { + match allocs.add_permission(peer, peer_addr) { + Ok(()) => added += 1, + Err(e) => { + if matches!( + e.downcast_ref::(), + Some(AllocationError::PermissionQuotaExceeded) + ) { + let resp = build_error_response_with_integrity( + &msg.header, + 508, + "Insufficient Capacity", + key, + ); + let _ = tx.send(resp).await; + continue; + } + } + } + } + } + + if added == 0 { + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "No valid XOR-PEER-ADDRESS", + key, + ); + let _ = tx.send(resp).await; + } else { + let resp = build_success_response_with_integrity(&msg.header, key); + let _ = tx.send(resp).await; + } + } + METHOD_CHANNEL_BIND => { + let key = auth_key + .as_deref() + .expect("auth key must be set after AuthStatus::Granted"); + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + let resp = build_error_response_with_integrity( + &msg.header, + 437, + "Allocation Mismatch", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + let channel_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_CHANNEL_NUMBER); + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + let (channel_attr, peer_attr) = match (channel_attr, peer_attr) + { + (Some(c), Some(p)) => (c, p), + _ => { + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + if channel_attr.value.len() < 2 { + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Invalid CHANNEL-NUMBER", + key, + ); + let _ = tx.send(resp).await; + continue; + } + + let channel = u16::from_be_bytes([ + channel_attr.value[0], + channel_attr.value[1], + ]); + let peer_addr = match decode_xor_peer_address( + &peer_attr.value, + &msg.header.transaction_id, + ) { + Some(addr) => addr, + None => { + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Invalid XOR-PEER-ADDRESS", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response_with_integrity( + &msg.header, + 403, + "Peer Not Permitted", + key, + ); + let _ = tx.send(resp).await; + continue; + } + + if let Err(e) = + allocs.add_channel_binding(peer, channel, peer_addr) + { + tracing::error!( + "failed to persist channel binding {} -> {} (0x{:04x}): {:?}", + peer, + peer_addr, + channel, + e + ); + let (code, reason) = match e.downcast_ref::() { + Some(AllocationError::ChannelQuotaExceeded) => { + (508, "Insufficient Capacity") + } + _ => (500, "Channel Bind Failed"), + }; + let resp = build_error_response_with_integrity( + &msg.header, + code, + reason, + key, + ); + let _ = tx.send(resp).await; + continue; + } + + crate::metrics::inc_channel_binding_added(); + + let resp = build_success_response_with_integrity(&msg.header, key); + let _ = tx.send(resp).await; + } + METHOD_SEND => { + let key = auth_key + .as_deref() + .expect("auth key must be set after AuthStatus::Granted"); + let allocation = match allocs.get_allocation(&peer) { + Some(a) => a, + None => { + let resp = build_error_response_with_integrity( + &msg.header, + 437, + "Allocation Mismatch", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + let peer_attr = msg + .attributes + .iter() + .find(|a| a.typ == ATTR_XOR_PEER_ADDRESS); + let data_attr = msg.attributes.iter().find(|a| a.typ == ATTR_DATA); + let (peer_attr, data_attr) = match (peer_attr, data_attr) { + (Some(p), Some(d)) => (p, d), + _ => { + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Missing DATA or XOR-PEER-ADDRESS", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + let peer_addr = match decode_xor_peer_address( + &peer_attr.value, + &msg.header.transaction_id, + ) { + Some(addr) => addr, + None => { + let resp = build_error_response_with_integrity( + &msg.header, + 400, + "Invalid XOR-PEER-ADDRESS", + key, + ); + let _ = tx.send(resp).await; + continue; + } + }; + + if !allocation.is_peer_allowed(&peer_addr) { + let resp = build_error_response_with_integrity( + &msg.header, + 403, + "Peer Not Permitted", + key, + ); + let _ = tx.send(resp).await; + continue; + } + + match allocation.send_to_peer(peer_addr, &data_attr.value).await { + Ok(_) => { + let resp = build_success_response_with_integrity(&msg.header, key); + let _ = tx.send(resp).await; + } + Err(e) => { + tracing::error!( + "failed to send payload from {} to {}: {:?}", + peer, + peer_addr, + e + ); + let resp = build_error_response_with_integrity( + &msg.header, + 500, + "Peer Send Failed", + key, + ); + let _ = tx.send(resp).await; + } + } + } + METHOD_REFRESH => { + let key = auth_key + .as_deref() + .expect("auth key must be set after AuthStatus::Granted"); + let requested = extract_lifetime_seconds(&msg) + .map(|secs| Duration::from_secs(secs as u64)); + + match allocs.refresh_allocation(peer, requested) { + Ok(applied) => { + let resp = build_lifetime_success_with_integrity( + &msg.header, + applied.as_secs().min(u32::MAX as u64) as u32, + key, + ); + let _ = tx.send(resp).await; + } + Err(_) => { + let resp = build_error_response_with_integrity( + &msg.header, + 437, + "Allocation Mismatch", + key, + ); + let _ = tx.send(resp).await; + } + } + } + METHOD_BINDING => { + if rate_limiters.allow_binding(peer.ip()) { + let resp = crate::stun::build_binding_success(&msg.header, &peer); + let _ = tx.send(resp).await; + } else { + crate::metrics::inc_rate_limited(); + } + } + _ => { + let resp = match auth_key.as_deref() { + Some(key) => build_error_response_with_integrity( + &msg.header, + 420, + "Unknown TURN Method", + key, + ), + None => build_error_response( + &msg.header, + 420, + "Unknown TURN Method", + ), + }; + let _ = tx.send(resp).await; + } + } + } + } + } + } + Err(e) => return Err(anyhow::anyhow!("stream read error: {e:?}")), + } + } + + // Stop writer. + drop(tx); + let _ = writer_task.await; + + Ok(()) +} diff --git a/tests/alloc/unit.rs b/tests/alloc/unit.rs index 90f918d..c74869c 100644 --- a/tests/alloc/unit.rs +++ b/tests/alloc/unit.rs @@ -6,7 +6,7 @@ mod support; mod helpers; use helpers::*; -use niom_turn::alloc::AllocationManager; +use niom_turn::alloc::{AllocationManager, ClientSink}; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -16,7 +16,13 @@ async fn allocate_sample(manager: &AllocationManager) -> SocketAddr { let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await.expect("udp bind")); let client = sample_client(); manager - .allocate_for(client, server) + .allocate_for( + client, + ClientSink::Udp { + sock: server, + addr: client, + }, + ) .await .expect("allocate relay"); client diff --git a/tests/auth/integration_tls.rs b/tests/auth/integration_tls.rs index 3ebf5d6..001df6b 100644 --- a/tests/auth/integration_tls.rs +++ b/tests/auth/integration_tls.rs @@ -11,14 +11,12 @@ use niom_turn::alloc::AllocationManager; use support::{default_test_credentials, init_tracing, test_auth_manager}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, UdpSocket}; +use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tls_rejects_invalid_credentials() { init_tracing(); - let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); - let udp_arc = Arc::new(udp); let (username, password) = default_test_credentials(); let auth = test_auth_manager(username, password); let allocs = AllocationManager::new(); @@ -34,7 +32,6 @@ async fn tls_rejects_invalid_credentials() { let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); - let udp_clone = udp_arc.clone(); let auth_clone = auth.clone(); let alloc_clone = allocs.clone(); tokio::spawn(async move { @@ -44,21 +41,18 @@ async fn tls_rejects_invalid_credentials() { Err(_) => break, }; let acceptor = acceptor.clone(); - let udp_clone = udp_clone.clone(); let auth_clone = auth_clone.clone(); let alloc_clone = alloc_clone.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { - Ok(mut tls_stream) => { + Ok(tls_stream) => { if let Err(e) = niom_turn::tls::handle_tls_connection( - &mut tls_stream, + tls_stream, peer, - udp_clone, auth_clone, alloc_clone, ) - .await - { + .await { tracing::error!("tls connection error: {:?}", e); } } diff --git a/tests/channel/integration_tls.rs b/tests/channel/integration_tls.rs index 96d2ab8..21c5eab 100644 --- a/tests/channel/integration_tls.rs +++ b/tests/channel/integration_tls.rs @@ -12,14 +12,12 @@ use niom_turn::auth; use std::sync::Arc; use support::{default_test_credentials, init_tracing, test_auth_manager}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, UdpSocket}; +use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tls_channel_bind_without_allocation_returns_mismatch() { init_tracing(); - let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); - let udp_arc = Arc::new(udp); let (username, password) = default_test_credentials(); let auth_manager = test_auth_manager(username, password); let allocs = AllocationManager::new(); @@ -35,7 +33,6 @@ async fn tls_channel_bind_without_allocation_returns_mismatch() { let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); - let udp_clone = udp_arc.clone(); let auth_clone = auth_manager.clone(); let alloc_clone = allocs.clone(); tokio::spawn(async move { @@ -45,21 +42,18 @@ async fn tls_channel_bind_without_allocation_returns_mismatch() { Err(_) => break, }; let acceptor = acceptor.clone(); - let udp_clone = udp_clone.clone(); let auth_clone = auth_clone.clone(); let alloc_clone = alloc_clone.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { - Ok(mut tls_stream) => { + Ok(tls_stream) => { if let Err(e) = niom_turn::tls::handle_tls_connection( - &mut tls_stream, + tls_stream, peer, - udp_clone, auth_clone, alloc_clone, ) - .await - { + .await { tracing::error!("tls connection error: {:?}", e); } } diff --git a/tests/channel/unit.rs b/tests/channel/unit.rs index 8cf6e4c..552caee 100644 --- a/tests/channel/unit.rs +++ b/tests/channel/unit.rs @@ -27,6 +27,7 @@ async fn channel_sink_mock_records_payload() { fn parse_channel_data_round_trip() { let payload = sample_payload(); let frame = build_channel_data(sample_channel_number(), &payload); + assert_eq!(frame.len(), 4 + payload.len()); let (channel, body) = parse_channel_data(&frame).expect("parse channel frame"); assert_eq!(channel, sample_channel_number()); assert_eq!(body, payload.as_slice()); diff --git a/tests/errors/integration_tls.rs b/tests/errors/integration_tls.rs index 111b233..6f7dbb9 100644 --- a/tests/errors/integration_tls.rs +++ b/tests/errors/integration_tls.rs @@ -10,15 +10,13 @@ use niom_turn::alloc::AllocationManager; use std::sync::Arc; use support::{init_tracing, test_auth_manager, default_test_credentials}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, UdpSocket}; +use tokio::net::TcpListener; use tokio::time::{timeout, Duration}; use tokio_rustls::TlsAcceptor; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn malformed_tls_frame_is_ignored() { init_tracing(); - let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); - let udp_arc = Arc::new(udp); let (username, password) = default_test_credentials(); let auth = test_auth_manager(username, password); let allocs = AllocationManager::new(); @@ -34,7 +32,6 @@ async fn malformed_tls_frame_is_ignored() { let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); - let udp_clone = udp_arc.clone(); let auth_clone = auth.clone(); let alloc_clone = allocs.clone(); tokio::spawn(async move { @@ -44,16 +41,14 @@ async fn malformed_tls_frame_is_ignored() { Err(_) => break, }; let acceptor = acceptor.clone(); - let udp_clone = udp_clone.clone(); let auth_clone = auth_clone.clone(); let alloc_clone = alloc_clone.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { - Ok(mut tls_stream) => { + Ok(tls_stream) => { let _ = niom_turn::tls::handle_tls_connection( - &mut tls_stream, + tls_stream, peer, - udp_clone, auth_clone, alloc_clone, ) diff --git a/tests/rate_limit_tcp.rs b/tests/rate_limit_tcp.rs new file mode 100644 index 0000000..019766f --- /dev/null +++ b/tests/rate_limit_tcp.rs @@ -0,0 +1,153 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::InMemoryStore; +use niom_turn::config::LimitsOptions; +use niom_turn::rate_limit::RateLimiters; +use tokio::io::AsyncWriteExt; +use tokio::net::{TcpListener, TcpStream}; +use tokio::time::{timeout, Duration}; + +use crate::support::stream::{StreamFrame, StreamFramer}; +use crate::support::stun_builders::{build_allocate_request, build_binding_request}; +use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager}; + +mod support; + +async fn start_tcp_test_server( + auth: niom_turn::auth::AuthManager, + allocs: AllocationManager, + rate_limiters: Arc, +) -> SocketAddr { + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + let rl = rate_limiters.clone(); + tokio::spawn(async move { + let _ = niom_turn::turn_stream::handle_turn_stream_connection_with_limits( + stream, peer, auth_clone, alloc_clone, rl, + ) + .await; + }); + } + }); + + tcp_addr +} + +#[tokio::test] +async fn tcp_binding_is_rate_limited_by_ip() { + init_tracing_with("warn,niom_turn=info"); + + // Configure a very small burst to make the test deterministic. + let mut limits = LimitsOptions::default(); + limits.binding_rps = Some(1); + limits.binding_burst = Some(1); + let rate_limiters = Arc::new(RateLimiters::from_limits(&limits)); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_addr = start_tcp_test_server(auth.clone(), allocs.clone(), rate_limiters.clone()).await; + + let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect"); + + // Fire multiple Binding requests quickly; with burst=1 we should only get 1 success response. + for _ in 0..3 { + let req = build_binding_request(); + stream.write_all(&req).await.expect("write binding"); + } + + let mut framer = StreamFramer::new(); + let mut responses = 0usize; + + // Read for a short bounded period. + let deadline = tokio::time::Instant::now() + Duration::from_millis(150); + loop { + let now = tokio::time::Instant::now(); + if now >= deadline { + break; + } + let remaining = deadline - now; + + let frame = match timeout(remaining, framer.read_frame(&mut stream)).await { + Ok(Ok(f)) => f, + Ok(Err(e)) => panic!("read_frame error: {e:?}"), + Err(_) => break, + }; + + match frame { + StreamFrame::Stun(msg) => { + assert_eq!(msg.header.msg_type & 0x0110, 0x0100); + responses += 1; + } + other => panic!("expected STUN response, got: {other:?}"), + } + } + + assert_eq!(responses, 1, "expected exactly 1 Binding response under burst=1"); + +} + +#[tokio::test] +async fn tcp_unauth_challenge_is_rate_limited_by_ip() { + init_tracing_with("warn,niom_turn=info"); + + let mut limits = LimitsOptions::default(); + limits.unauth_rps = Some(1); + limits.unauth_burst = Some(1); + let rate_limiters = Arc::new(RateLimiters::from_limits(&limits)); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_addr = start_tcp_test_server(auth.clone(), allocs.clone(), rate_limiters.clone()).await; + + let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect"); + + for _ in 0..3 { + let req = build_allocate_request(None, None, None, None, None); + stream.write_all(&req).await.expect("write allocate"); + } + + let mut framer = StreamFramer::new(); + let mut responses = 0usize; + + let deadline = tokio::time::Instant::now() + Duration::from_millis(150); + loop { + let now = tokio::time::Instant::now(); + if now >= deadline { + break; + } + let remaining = deadline - now; + + let frame = match timeout(remaining, framer.read_frame(&mut stream)).await { + Ok(Ok(f)) => f, + Ok(Err(e)) => panic!("read_frame error: {e:?}"), + Err(_) => break, + }; + + match frame { + StreamFrame::Stun(msg) => { + assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR); + msg.attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + responses += 1; + } + other => panic!("expected STUN response, got: {other:?}"), + } + } + + assert_eq!(responses, 1, "expected exactly 1 unauth challenge under burst=1"); +} diff --git a/tests/rate_limit_udp.rs b/tests/rate_limit_udp.rs new file mode 100644 index 0000000..68099fe --- /dev/null +++ b/tests/rate_limit_udp.rs @@ -0,0 +1,148 @@ +use std::sync::Arc; + +use niom_turn::alloc::AllocationManager; +use niom_turn::config::LimitsOptions; +use niom_turn::rate_limit::RateLimiters; +use tokio::net::UdpSocket; + +use crate::support::stun_builders::{build_allocate_request, build_binding_request, parse}; +use crate::support::{default_test_credentials, init_tracing, test_auth_manager}; + +mod support; + +#[tokio::test] +async fn udp_binding_is_rate_limited_by_ip() { + init_tracing(); + + // Configure a very small burst to make the test deterministic. + // We only limit Binding here to avoid affecting other integration tests. + let mut limits = LimitsOptions::default(); + limits.binding_rps = Some(1); + limits.binding_burst = Some(1); + let rate_limiters = Arc::new(RateLimiters::from_limits(&limits)); + + // Start a UDP server loop. + let server = UdpSocket::bind("127.0.0.1:0").await.expect("server bind"); + let server_addr = server.local_addr().expect("server addr"); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let server_arc = Arc::new(server); + tokio::spawn({ + let reader = server_arc.clone(); + let auth = auth.clone(); + let allocs = allocs.clone(); + let rl = rate_limiters.clone(); + async move { + let _ = niom_turn::server::udp_reader_loop_with_limits(reader, auth, allocs, rl).await; + } + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + // Fire multiple Binding requests quickly; with burst=1 we should only get 1 success response. + for _ in 0..3 { + let req = build_binding_request(); + client + .send_to(&req, server_addr) + .await + .expect("send binding"); + } + + let mut buf = [0u8; 1500]; + let mut responses = 0usize; + + // Collect responses for a short, bounded period. + let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(150); + loop { + let now = tokio::time::Instant::now(); + if now >= deadline { + break; + } + + let remaining = deadline - now; + match tokio::time::timeout(remaining, client.recv_from(&mut buf)).await { + Ok(Ok((len, _from))) => { + let resp = parse(&buf[..len]); + // Success class + assert_eq!(resp.header.msg_type & 0x0110, 0x0100); + responses += 1; + } + Ok(Err(e)) => panic!("recv error: {e}"), + Err(_) => break, // timeout + } + } + + assert_eq!(responses, 1, "expected exactly 1 Binding response under burst=1"); + +} + +#[tokio::test] +async fn udp_unauth_challenge_is_rate_limited_by_ip() { + init_tracing(); + + // Configure a very small burst so only the first unauth challenge is answered. + let mut limits = LimitsOptions::default(); + limits.unauth_rps = Some(1); + limits.unauth_burst = Some(1); + let rate_limiters = Arc::new(RateLimiters::from_limits(&limits)); + + // Start a UDP server loop. + let server = UdpSocket::bind("127.0.0.1:0").await.expect("server bind"); + let server_addr = server.local_addr().expect("server addr"); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let server_arc = Arc::new(server); + tokio::spawn({ + let reader = server_arc.clone(); + let auth = auth.clone(); + let allocs = allocs.clone(); + let rl = rate_limiters.clone(); + async move { + let _ = niom_turn::server::udp_reader_loop_with_limits(reader, auth, allocs, rl).await; + } + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + for _ in 0..3 { + let req = build_allocate_request(None, None, None, None, None); + client + .send_to(&req, server_addr) + .await + .expect("send unauth allocate"); + } + + let mut buf = [0u8; 1500]; + let mut responses = 0usize; + + let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(150); + loop { + let now = tokio::time::Instant::now(); + if now >= deadline { + break; + } + + let remaining = deadline - now; + match tokio::time::timeout(remaining, client.recv_from(&mut buf)).await { + Ok(Ok((len, _from))) => { + let resp = parse(&buf[..len]); + assert_eq!(resp.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR); + resp.attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + responses += 1; + } + Ok(Err(e)) => panic!("recv error: {e}"), + Err(_) => break, + } + } + + assert_eq!(responses, 1, "expected exactly 1 unauth challenge under burst=1"); +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index bfe9058..05b191b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,4 +1,5 @@ pub mod mocks; +pub mod stream; pub mod stun_builders; pub mod tls; diff --git a/tests/support/stream.rs b/tests/support/stream.rs new file mode 100644 index 0000000..d605d46 --- /dev/null +++ b/tests/support/stream.rs @@ -0,0 +1,78 @@ +#![allow(dead_code)] + +use std::io; + +use niom_turn::models::stun::StunMessage; +use niom_turn::stun::parse_message; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; + +#[derive(Debug)] +pub enum StreamFrame { + Stun(StunMessage), + ChannelData { channel: u16, payload: Vec }, +} + +#[derive(Default)] +pub struct StreamFramer { + buffer: Vec, +} + +impl StreamFramer { + pub fn new() -> Self { + Self { buffer: Vec::new() } + } + + fn try_pop_next(&mut self) -> Option> { + if self.buffer.len() < 4 { + return None; + } + + // ChannelData: channel number 0x4000..=0x7FFF (top bits 01) + let channel = u16::from_be_bytes([self.buffer[0], self.buffer[1]]); + if (channel & 0xC000) == 0x4000 { + let len = u16::from_be_bytes([self.buffer[2], self.buffer[3]]) as usize; + let total = 4 + len; + if self.buffer.len() < total { + return None; + } + let frame = self.buffer.drain(..total).collect::>(); + return Some(Ok(StreamFrame::ChannelData { + channel, + payload: frame[4..].to_vec(), + })); + } + + // STUN over stream: 20 byte header + length. + if self.buffer.len() < 20 { + return None; + } + let len = u16::from_be_bytes([self.buffer[2], self.buffer[3]]) as usize; + let total = 20 + len; + if self.buffer.len() < total { + return None; + } + let chunk = self.buffer.drain(..total).collect::>(); + match parse_message(&chunk) { + Ok(msg) => Some(Ok(StreamFrame::Stun(msg))), + Err(e) => Some(Err(io::Error::new(io::ErrorKind::InvalidData, format!( + "parse stun: {e:?}" + )))), + } + } + + pub async fn read_frame(&mut self, reader: &mut R) -> io::Result { + loop { + if let Some(frame) = self.try_pop_next() { + return frame; + } + + let mut tmp = [0u8; 4096]; + let n = reader.read(&mut tmp).await?; + if n == 0 { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed")); + } + self.buffer.extend_from_slice(&tmp[..n]); + } + } +} diff --git a/tests/support/stun_builders.rs b/tests/support/stun_builders.rs index b840e8c..125a22a 100644 --- a/tests/support/stun_builders.rs +++ b/tests/support/stun_builders.rs @@ -21,7 +21,30 @@ pub fn build_allocate_request( key: Option<&[u8]>, lifetime: Option, ) -> Vec { - build_authenticated_request( + build_allocate_request_with_requested_transport( + username, + realm, + nonce, + key, + lifetime, + Some(IPPROTO_UDP), + ) +} + +/// Construct a TURN Allocate request with explicit REQUESTED-TRANSPORT. +/// +/// - `Some(IPPROTO_UDP)` (17) is the WebRTC default. +/// - `Some(IPPROTO_TCP)` (6) should be rejected by this server (UDP relay only). +/// - `None` omits the attribute (should be rejected after auth). +pub fn build_allocate_request_with_requested_transport( + username: Option<&str>, + realm: Option<&str>, + nonce: Option<&str>, + key: Option<&[u8]>, + lifetime: Option, + requested_transport: Option, +) -> Vec { + build_request_with_body( METHOD_ALLOCATE, username, realm, @@ -30,6 +53,24 @@ pub fn build_allocate_request( lifetime, None, None, + None, + requested_transport, + ) +} + +/// Build a basic STUN Binding request (no auth). +pub fn build_binding_request() -> Vec { + build_request_with_body( + METHOD_BINDING, + None, + None, + None, + None, + None, + None, + None, + None, + None, ) } @@ -52,6 +93,7 @@ pub fn build_refresh_request( None, None, Some(trans), + None, ) } @@ -73,6 +115,7 @@ pub fn build_create_permission_request( Some(peer), None, None, + None, ) } @@ -95,6 +138,7 @@ pub fn build_send_request( Some(peer), Some(payload), None, + None, ) } @@ -148,7 +192,16 @@ fn build_authenticated_request( payload: Option<&[u8]>, ) -> Vec { build_request_with_body( - method, username, realm, nonce, key, lifetime, peer, payload, None, + method, + username, + realm, + nonce, + key, + lifetime, + peer, + payload, + None, + None, ) } @@ -162,6 +215,7 @@ fn build_request_with_body( peer: Option<&std::net::SocketAddr>, payload: Option<&[u8]>, override_trans: Option<[u8; 12]>, + requested_transport: Option, ) -> Vec { let mut buf = BytesMut::new(); buf.extend_from_slice(&method.to_be_bytes()); @@ -182,6 +236,12 @@ fn build_request_with_body( if let Some(lifetime) = lifetime { push_u32_attr(&mut buf, ATTR_LIFETIME, lifetime); } + + if method == METHOD_ALLOCATE { + if let Some(proto) = requested_transport { + push_bytes_attr(&mut buf, ATTR_REQUESTED_TRANSPORT, &[proto, 0, 0, 0]); + } + } if let Some(peer) = peer { let encoded = niom_turn::stun::encode_xor_peer_address(peer, &trans); push_bytes_attr(&mut buf, ATTR_XOR_PEER_ADDRESS, &encoded); @@ -221,23 +281,21 @@ fn push_bytes_attr(buf: &mut BytesMut, typ: u16, data: &[u8]) { } fn append_message_integrity(buf: &mut BytesMut, key: &[u8]) { - // position before adding MESSAGE-INTEGRITY attribute - let attribute_start = buf.len(); - // append attribute header and placeholder value buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&(20u16.to_be_bytes())); let value_start = buf.len(); buf.extend_from_slice(&[0u8; 20]); - // update message length to include the attribute (spec requires this before HMAC) - let total_len = (buf.len() - 20) as u16; + // update message length to end-of-MI (exclude any later attributes like FINGERPRINT) + let mi_end = buf.len(); + let total_len = (mi_end - 20) as u16; let len_bytes = total_len.to_be_bytes(); buf[2] = len_bytes[0]; buf[3] = len_bytes[1]; - // compute the HMAC over all bytes preceding the attribute (RFC 5389 §15.4) - let signed = compute_message_integrity(key, &buf[..attribute_start]); + // compute the HMAC over the message up to end-of-MI (MI value is zero here) + let signed = compute_message_integrity(key, &buf[..mi_end]); // write the computed MAC into the placeholder we appended above buf[value_start..value_start + 20].copy_from_slice(&signed[..20]); diff --git a/tests/tcp_turn.rs b/tests/tcp_turn.rs new file mode 100644 index 0000000..7d084ce --- /dev/null +++ b/tests/tcp_turn.rs @@ -0,0 +1,365 @@ +use std::net::SocketAddr; + +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::{compute_a1_md5, InMemoryStore}; +use tokio::io::AsyncWriteExt; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::time::{timeout, Duration}; + +use crate::support::stream::{StreamFrame, StreamFramer}; +use crate::support::stun_builders::{ + build_allocate_request, build_channel_bind_request, build_create_permission_request, + build_send_request, +}; +use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager}; + +mod support; + +async fn start_tcp_test_server( + auth: niom_turn::auth::AuthManager, + allocs: AllocationManager, +) -> SocketAddr { + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + if let Err(e) = niom_turn::turn_stream::handle_turn_stream_connection( + stream, peer, auth_clone, alloc_clone, + ) + .await + { + tracing::info!("tcp connection ended: {:?}", e); + } + }); + } + }); + + tcp_addr +} + +#[tokio::test] +async fn tcp_stream_resyncs_after_garbage_and_still_processes_allocate() { + init_tracing_with("warn,niom_turn=info"); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_addr = start_tcp_test_server(auth.clone(), allocs.clone()).await; + + let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect"); + + // Send garbage bytes that look like a STUN header with a huge length but invalid cookie. + // Without resync, the server could wait for 20+65535 bytes and stall parsing. + let garbage = [0x00, 0x01, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00]; + stream.write_all(&garbage).await.expect("write garbage"); + + // Now send a valid Allocate request; server should still respond with a 401 challenge. + let allocate = build_allocate_request(None, None, None, None, None); + stream.write_all(&allocate).await.expect("write allocate"); + + let mut framer = StreamFramer::new(); + let frame = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout response") + .expect("read response"); + + match frame { + StreamFrame::Stun(msg) => { + assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR); + msg.attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + } + other => panic!("expected STUN challenge, got: {:?}", other), + } +} + +#[tokio::test] +async fn tcp_peer_data_is_delivered_over_stream_as_data_indication() { + init_tracing_with("warn,niom_turn=info"); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let server_addr = start_tcp_test_server(auth.clone(), allocs.clone()).await; + + let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect"); + let client_addr = stream.local_addr().expect("client addr"); + + // 1) Allocate without auth -> 401 + NONCE + let allocate = build_allocate_request(None, None, None, None, None); + stream.write_all(&allocate).await.expect("write allocate"); + + let mut framer = StreamFramer::new(); + let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout challenge") + .expect("read challenge"); + + let nonce = match challenge { + StreamFrame::Stun(msg) => { + assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR); + let nonce_attr = msg + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8") + } + _ => panic!("expected STUN 401 challenge"), + }; + + // 2) Authenticated allocate + let key = compute_a1_md5(username, auth.realm(), password); + let allocate = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + stream.write_all(&allocate).await.expect("write auth allocate"); + + let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout alloc success") + .expect("read alloc success"); + + match alloc_success { + StreamFrame::Stun(msg) => { + assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_SUCCESS); + msg.attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS) + .expect("xor-relayed attr"); + } + _ => panic!("expected STUN allocate success"), + } + + let relay_addr = allocs + .get_allocation(&client_addr) + .expect("allocation exists") + .relay_addr; + + // 3) CreatePermission + let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind"); + let perm = build_create_permission_request( + username, + auth.realm(), + &nonce, + &key, + &peer_sock.local_addr().unwrap(), + ); + stream.write_all(&perm).await.expect("write create permission"); + + let perm_resp = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout perm resp") + .expect("read perm resp"); + match perm_resp { + StreamFrame::Stun(msg) => assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_SUCCESS), + _ => panic!("expected STUN permission success"), + } + + // 4) Send indication -> peer should receive UDP payload + let payload = b"hello-turn-tcp"; + let send = build_send_request( + username, + auth.realm(), + &nonce, + &key, + &peer_sock.local_addr().unwrap(), + payload, + ); + stream.write_all(&send).await.expect("write send"); + + // This implementation responds to SEND with a success response. + let send_resp = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout send resp") + .expect("read send resp"); + match send_resp { + StreamFrame::Stun(msg) => { + assert_eq!( + msg.header.msg_type, + niom_turn::constants::METHOD_SEND | niom_turn::constants::CLASS_SUCCESS + ); + } + _ => panic!("expected STUN send success"), + } + + let mut peer_buf = [0u8; 1500]; + let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf)) + .await + .expect("timeout peer recv") + .expect("peer recv"); + assert_eq!(&peer_buf[..n], payload); + assert!(from.ip().is_loopback()); + + // 5) Peer -> relay -> client: should come back over TCP as Data Indication + let back = b"peer-reply"; + peer_sock + .send_to(back, relay_addr) + .await + .expect("peer send back"); + + let frame = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout data indication") + .expect("read data indication"); + + match frame { + StreamFrame::Stun(msg) => { + assert_eq!(msg.header.msg_type, niom_turn::constants::METHOD_DATA | niom_turn::constants::CLASS_INDICATION); + let data_attr = msg + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_DATA) + .expect("data attr"); + assert_eq!(data_attr.value.as_slice(), back); + } + _ => panic!("expected STUN data indication"), + } + + // sanity: allocation exists by client addr + assert!(allocs.get_allocation(&client_addr).is_some()); +} + +#[tokio::test] +async fn tcp_channel_data_round_trip_works() { + init_tracing_with("warn,niom_turn=info"); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let server_addr = start_tcp_test_server(auth.clone(), allocs.clone()).await; + + let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect"); + + let allocate = build_allocate_request(None, None, None, None, None); + stream.write_all(&allocate).await.expect("write allocate"); + + let mut framer = StreamFramer::new(); + let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout challenge") + .expect("read challenge"); + let nonce = match challenge { + StreamFrame::Stun(msg) => { + let nonce_attr = msg + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8") + } + _ => panic!("expected STUN 401 challenge"), + }; + + let key = compute_a1_md5(username, auth.realm(), password); + let allocate = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + stream.write_all(&allocate).await.expect("write auth allocate"); + + let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout alloc success") + .expect("read alloc success"); + match alloc_success { + StreamFrame::Stun(msg) => { + msg.attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS) + .expect("xor-relayed attr"); + } + _ => panic!("expected STUN allocate success"), + } + + let client_addr = stream.local_addr().expect("client addr"); + let relay_addr = allocs + .get_allocation(&client_addr) + .expect("allocation exists") + .relay_addr; + + let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind"); + + // Permission + let perm = build_create_permission_request( + username, + auth.realm(), + &nonce, + &key, + &peer_sock.local_addr().unwrap(), + ); + stream.write_all(&perm).await.expect("write create permission"); + let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout perm resp") + .expect("read perm resp"); + + // ChannelBind + let channel: u16 = 0x4000; + let bind = build_channel_bind_request( + username, + auth.realm(), + &nonce, + &key, + channel, + &peer_sock.local_addr().unwrap(), + ); + stream.write_all(&bind).await.expect("write channel bind"); + + let bind_resp = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout bind resp") + .expect("read bind resp"); + match bind_resp { + StreamFrame::Stun(msg) => assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_SUCCESS), + _ => panic!("expected STUN channel bind success"), + } + + // Client -> Server -> Peer via ChannelData + let payload = b"chan-hello"; + let ch = niom_turn::stun::build_channel_data(channel, payload); + stream.write_all(&ch).await.expect("write channel data"); + + let mut peer_buf = [0u8; 1500]; + let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf)) + .await + .expect("timeout peer recv") + .expect("peer recv"); + assert_eq!(&peer_buf[..n], payload); + assert!(from.ip().is_loopback()); + + // Peer -> Relay -> Client as ChannelData + let back = b"chan-back"; + peer_sock.send_to(back, relay_addr).await.expect("peer send back"); + + let frame = timeout(Duration::from_secs(2), framer.read_frame(&mut stream)) + .await + .expect("timeout channel back") + .expect("read channel back"); + match frame { + StreamFrame::ChannelData { channel: chn, payload } => { + assert_eq!(chn, channel); + assert_eq!(payload.as_slice(), back); + } + _ => panic!("expected ChannelData frame"), + } +} diff --git a/tests/tls_data_plane.rs b/tests/tls_data_plane.rs new file mode 100644 index 0000000..1eb20d7 --- /dev/null +++ b/tests/tls_data_plane.rs @@ -0,0 +1,206 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use niom_turn::alloc::AllocationManager; +use niom_turn::auth::{compute_a1_md5, InMemoryStore}; +use tokio::io::AsyncWriteExt; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::time::{timeout, Duration}; +use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; + +use crate::support::stream::{StreamFrame, StreamFramer}; +use crate::support::stun_builders::{ + build_allocate_request, build_channel_bind_request, build_create_permission_request, +}; +use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager}; + +mod support; + +async fn start_tls_test_server( + acceptor: TlsAcceptor, + auth: niom_turn::auth::AuthManager, + allocs: AllocationManager, +) -> SocketAddr { + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); + let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); + + tokio::spawn(async move { + loop { + let (stream, peer) = match tcp_listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let acceptor = acceptor.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + match acceptor.accept(stream).await { + Ok(tls_stream) => { + if let Err(e) = niom_turn::tls::handle_tls_connection( + tls_stream, + peer, + auth_clone, + alloc_clone, + ) + .await + { + tracing::info!("tls connection ended: {:?}", e); + } + } + Err(e) => tracing::error!("tls accept failed: {:?}", e), + } + }); + } + }); + + tcp_addr +} + +#[tokio::test] +async fn tls_channel_data_round_trip_works() { + init_tracing_with("warn,niom_turn=info"); + + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + + let (cert, key) = support::tls::generate_self_signed_cert(); + let mut cfg = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key) + .expect("server config"); + cfg.alpn_protocols.push(b"turn".to_vec()); + let acceptor = TlsAcceptor::from(Arc::new(cfg)); + + let server_addr = start_tls_test_server(acceptor, auth.clone(), allocs.clone()).await; + + // client config trusting generated cert + let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); + root_store.add(&cert).expect("add root"); + let client_config = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + + let tcp_stream = TcpStream::connect(server_addr).await.expect("tcp connect"); + let client_addr = tcp_stream.local_addr().expect("client addr"); + let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap(); + let mut tls_stream = connector + .connect(domain, tcp_stream) + .await + .expect("tls connect"); + + // allocate (unauth -> nonce) + let allocate = build_allocate_request(None, None, None, None, None); + tls_stream.write_all(&allocate).await.expect("write allocate"); + + let mut framer = StreamFramer::new(); + let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) + .await + .expect("timeout challenge") + .expect("read challenge"); + let nonce = match challenge { + StreamFrame::Stun(msg) => { + let nonce_attr = msg + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr"); + String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8") + } + _ => panic!("expected STUN 401 challenge"), + }; + + // auth allocate + let key = compute_a1_md5(username, auth.realm(), password); + let allocate = build_allocate_request( + Some(username), + Some(auth.realm()), + Some(&nonce), + Some(&key), + Some(600), + ); + tls_stream.write_all(&allocate).await.expect("write auth allocate"); + + let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) + .await + .expect("timeout alloc success") + .expect("read alloc success"); + match alloc_success { + StreamFrame::Stun(msg) => { + msg.attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS) + .expect("xor-relayed attr"); + } + _ => panic!("expected STUN allocate success"), + } + + let relay_addr = allocs + .get_allocation(&client_addr) + .expect("allocation exists") + .relay_addr; + + // set up peer + permission + let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind"); + let perm = build_create_permission_request( + username, + auth.realm(), + &nonce, + &key, + &peer_sock.local_addr().unwrap(), + ); + tls_stream.write_all(&perm).await.expect("write create permission"); + let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) + .await + .expect("timeout perm resp") + .expect("read perm resp"); + + // channel bind + let channel: u16 = 0x4000; + let bind = build_channel_bind_request( + username, + auth.realm(), + &nonce, + &key, + channel, + &peer_sock.local_addr().unwrap(), + ); + tls_stream.write_all(&bind).await.expect("write channel bind"); + let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) + .await + .expect("timeout bind resp") + .expect("read bind resp"); + + // client -> peer via ChannelData + let payload = b"tls-chan"; + let frame = niom_turn::stun::build_channel_data(channel, payload); + tls_stream.write_all(&frame).await.expect("write channel data"); + + let mut peer_buf = [0u8; 1500]; + let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf)) + .await + .expect("timeout peer recv") + .expect("peer recv"); + assert_eq!(&peer_buf[..n], payload); + assert!(from.ip().is_loopback()); + + // peer -> client as ChannelData over TLS + let back = b"tls-back"; + peer_sock.send_to(back, relay_addr).await.expect("peer send back"); + + let received = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream)) + .await + .expect("timeout channel back") + .expect("read channel back"); + + match received { + StreamFrame::ChannelData { channel: ch, payload } => { + assert_eq!(ch, channel); + assert_eq!(payload.as_slice(), back); + } + other => panic!("expected ChannelData frame, got: {:?}", other), + } +} diff --git a/tests/tls_turn.rs b/tests/tls_turn.rs index f052605..34e0c19 100644 --- a/tests/tls_turn.rs +++ b/tests/tls_turn.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use niom_turn::alloc::AllocationManager; use niom_turn::stun::parse_message; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, UdpSocket}; +use tokio::net::TcpListener; use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; use crate::support::stun_builders::{build_allocate_request, build_refresh_request}; @@ -15,8 +15,6 @@ mod support; async fn tls_allocate_refresh_flow() { init_tracing_with("warn,niom_turn=info"); - let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"); - let udp_arc = Arc::new(udp); let (username, password) = default_test_credentials(); let auth = test_auth_manager(username, password); let allocs = AllocationManager::new(); @@ -32,7 +30,6 @@ async fn tls_allocate_refresh_flow() { let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); - let udp_clone = udp_arc.clone(); let auth_clone = auth.clone(); let alloc_clone = allocs.clone(); @@ -43,21 +40,18 @@ async fn tls_allocate_refresh_flow() { Err(_) => break, }; let acceptor = acceptor.clone(); - let udp_clone = udp_clone.clone(); let auth_clone = auth_clone.clone(); let alloc_clone = alloc_clone.clone(); tokio::spawn(async move { match acceptor.accept(stream).await { - Ok(mut tls_stream) => { + Ok(tls_stream) => { match niom_turn::tls::handle_tls_connection( - &mut tls_stream, + tls_stream, peer, - udp_clone, auth_clone, alloc_clone, ) - .await - { + .await { Ok(_) => {} Err(e) => { tracing::error!("tls connection error: {:?}", e); diff --git a/tests/udp_turn.rs b/tests/udp_turn.rs index 6c2bb6d..05a86aa 100644 --- a/tests/udp_turn.rs +++ b/tests/udp_turn.rs @@ -4,11 +4,14 @@ use std::sync::Arc; use niom_turn::alloc::AllocationManager; use niom_turn::auth::InMemoryStore; use niom_turn::server::udp_reader_loop; -use niom_turn::stun::parse_message; +use niom_turn::stun::{ + find_fingerprint, find_message_integrity, parse_message, validate_fingerprint_if_present, + validate_message_integrity, +}; use tokio::net::UdpSocket; use crate::support::stun_builders::{ - build_allocate_request, build_create_permission_request, build_refresh_request, + build_allocate_request, build_allocate_request_with_requested_transport, build_create_permission_request, build_refresh_request, build_send_request, new_transaction_id, parse, }; use crate::support::{default_test_credentials, init_tracing, test_auth_manager}; @@ -70,6 +73,120 @@ async fn allocate_requires_auth_then_succeeds() { let (len, _) = client.recv_from(&mut buf).await.expect("recv success"); let resp = parse(&buf[..len]); assert_eq!(resp.header.msg_type & 0x0110, 0x0100); + assert!(find_message_integrity(&resp).is_some()); + assert!(validate_message_integrity(&resp, &key)); + assert!(find_fingerprint(&resp).is_some()); + assert!(validate_fingerprint_if_present(&resp)); +} + +#[tokio::test] +async fn authenticated_allocate_rejects_missing_requested_transport() { + init_tracing(); + let (server, client_addr) = start_udp_server().await; + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_arc = Arc::new(server); + + let server_clone = server_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await; + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + // Get nonce via normal unauth Allocate (builder includes requested-transport=UDP) + let req = build_allocate_request(None, None, None, None, None); + client.send_to(&req, client_addr).await.expect("send unauth allocate"); + let mut buf = [0u8; 1500]; + let (len, _) = client.recv_from(&mut buf).await.expect("recv challenge"); + let resp = parse_message(&buf[..len]).expect("parse 401"); + let nonce = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr") + .value + .clone(); + let nonce_str = String::from_utf8(nonce).expect("nonce utf8"); + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + + // Authenticated Allocate but omit REQUESTED-TRANSPORT => 400 + let req = build_allocate_request_with_requested_transport( + Some(username), + Some(auth.realm()), + Some(&nonce_str), + Some(&key), + Some(600), + None, + ); + client.send_to(&req, client_addr).await.expect("send auth allocate"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv error"); + let resp = parse(&buf[..len]); + assert_eq!(resp.header.msg_type & 0x0110, 0x0110); + assert!(find_message_integrity(&resp).is_some()); + assert!(validate_message_integrity(&resp, &key)); + assert!(find_fingerprint(&resp).is_some()); + assert!(validate_fingerprint_if_present(&resp)); + let code = crate::support::stun_builders::extract_error_code(&resp).expect("error code"); + assert_eq!(code, 400); +} + +#[tokio::test] +async fn authenticated_allocate_rejects_tcp_requested_transport() { + init_tracing(); + let (server, client_addr) = start_udp_server().await; + let (username, password) = default_test_credentials(); + let auth = test_auth_manager(username, password); + let allocs = AllocationManager::new(); + let server_arc = Arc::new(server); + + let server_clone = server_arc.clone(); + let auth_clone = auth.clone(); + let alloc_clone = allocs.clone(); + tokio::spawn(async move { + let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await; + }); + + let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind"); + + // Get nonce + let req = build_allocate_request(None, None, None, None, None); + client.send_to(&req, client_addr).await.expect("send unauth allocate"); + let mut buf = [0u8; 1500]; + let (len, _) = client.recv_from(&mut buf).await.expect("recv challenge"); + let resp = parse_message(&buf[..len]).expect("parse 401"); + let nonce = resp + .attributes + .iter() + .find(|a| a.typ == niom_turn::constants::ATTR_NONCE) + .expect("nonce attr") + .value + .clone(); + let nonce_str = String::from_utf8(nonce).expect("nonce utf8"); + let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password); + + // Request TCP relay (unsupported): 442 + let req = build_allocate_request_with_requested_transport( + Some(username), + Some(auth.realm()), + Some(&nonce_str), + Some(&key), + Some(600), + Some(niom_turn::constants::IPPROTO_TCP), + ); + client.send_to(&req, client_addr).await.expect("send auth allocate"); + let (len, _) = client.recv_from(&mut buf).await.expect("recv error"); + let resp = parse(&buf[..len]); + assert_eq!(resp.header.msg_type & 0x0110, 0x0110); + assert!(find_message_integrity(&resp).is_some()); + assert!(validate_message_integrity(&resp, &key)); + assert!(find_fingerprint(&resp).is_some()); + assert!(validate_fingerprint_if_present(&resp)); + let code = crate::support::stun_builders::extract_error_code(&resp).expect("error code"); + assert_eq!(code, 442); } #[tokio::test] @@ -105,6 +222,10 @@ async fn refresh_zero_lifetime_releases_allocation() { let (len, _) = client.recv_from(&mut buf).await.expect("recv refresh resp"); let resp = parse(&buf[..len]); assert_eq!(resp.header.msg_type & 0x0110, 0x0100); + assert!(find_message_integrity(&resp).is_some()); + assert!(validate_message_integrity(&resp, &key)); + assert!(find_fingerprint(&resp).is_some()); + assert!(validate_fingerprint_if_present(&resp)); let lifetime = resp .attributes .iter()