Add: MI-signed TURN responses, RFC MI validation, and deployment guide.

This commit is contained in:
ghost 2025-12-28 15:57:06 +01:00
parent 29c0d8c0cf
commit a42af38cfe
41 changed files with 4389 additions and 819 deletions

7
Cargo.lock generated
View File

@ -82,6 +82,12 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.9.4" version = "2.9.4"
@ -396,6 +402,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"base64 0.22.1",
"bytes", "bytes",
"crc32fast", "crc32fast",
"hex", "hex",

View File

@ -12,6 +12,7 @@ bytes = "1.4"
hmac = "0.12" hmac = "0.12"
sha1 = "0.10" sha1 = "0.10"
hex = "0.4" hex = "0.4"
base64 = "0.22"
# config and logging # config and logging
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }

View File

@ -147,9 +147,34 @@ Das Projekt kann eine JSON-Konfigdatei `appsettings.json` im Arbeitsverzeichnis
{ {
"server": { "server": {
"bind": "0.0.0.0:3478", "bind": "0.0.0.0:3478",
"udp_bind": null,
"tcp_bind": null,
"tls_bind": "0.0.0.0:5349",
"enable_udp": true,
"enable_tcp": true,
"enable_tls": true,
"tls_cert": null, "tls_cert": null,
"tls_key": null "tls_key": null
}, },
"relay": {
"relay_port_min": null,
"relay_port_max": null,
"relay_bind_ip": "0.0.0.0",
"advertised_ip": null
},
"logging": {
"default_directive": "warn,niom_turn=info"
},
"limits": {
"max_allocations_per_ip": null,
"max_permissions_per_allocation": null,
"max_channel_bindings_per_allocation": null,
"unauth_rps": null,
"unauth_burst": null,
"binding_rps": null,
"binding_burst": null
},
"credentials": [ "credentials": [
{ {
"username": "testuser", "username": "testuser",
@ -159,12 +184,25 @@ Das Projekt kann eine JSON-Konfigdatei `appsettings.json` im Arbeitsverzeichnis
"auth": { "auth": {
"realm": "niom-turn.local", "realm": "niom-turn.local",
"nonce_secret": null, "nonce_secret": null,
"nonce_ttl_seconds": 300 "nonce_ttl_seconds": 300,
"rest_secret": null,
"rest_max_ttl_seconds": 600
} }
} }
``` ```
Wenn `appsettings.json` vorhanden ist, verwendet der Server die `server.bind` Adresse, befüllt den Credential-Store aus dem `credentials`-Array und übernimmt zusätzlich Realm/Nonce-Einstellungen aus `auth`. Falls die Datei fehlt, verwendet der Server die internen Defaults (Bind `0.0.0.0:3478`, Demo-Cred `testuser`, Realm `niom-turn.local`). Wenn `appsettings.json` vorhanden ist, befüllt der Server den Credential-Store aus `credentials` und übernimmt Realm/Nonce/REST-Einstellungen aus `auth`.
Listener-Binds:
- `server.bind` ist der Legacy-Default (wird genutzt, wenn `udp_bind`/`tcp_bind` nicht gesetzt sind).
- TCP/TLS nutzen denselben TURN-over-Stream Handler; `turns:` läuft auf `server.tls_bind`.
- Wenn `server.enable_tls=true`, aber `tls_cert`/`tls_key` fehlen, wird der TLS-Listener übersprungen.
Relay/NAT:
- Hinter NAT sollte `relay.advertised_ip` auf die öffentliche IP gesetzt werden, damit Clients in XOR-RELAYED-ADDRESS eine erreichbare Adresse erhalten.
- Für Firewalls ist ein fixer Relay-Port-Range (`relay_port_min/max`) sinnvoll.
Details und Runbook: siehe `docs/config/runtime.md`.
Deployment & TLS / Long-term Auth roadmap Deployment & TLS / Long-term Auth roadmap
----------------------------------------- -----------------------------------------

View File

@ -1,9 +1,34 @@
{ {
"server": { "server": {
"bind": "0.0.0.0:3478", "bind": "0.0.0.0:3478",
"udp_bind": null,
"tcp_bind": null,
"tls_bind": "0.0.0.0:5349",
"enable_udp": true,
"enable_tcp": true,
"enable_tls": true,
"tls_cert": null, "tls_cert": null,
"tls_key": null "tls_key": null
}, },
"relay": {
"relay_port_min": null,
"relay_port_max": null,
"relay_bind_ip": "0.0.0.0",
"advertised_ip": null
},
"logging": {
"default_directive": "warn,niom_turn=info"
},
"limits": {
"max_allocations_per_ip": null,
"max_permissions_per_allocation": null,
"max_channel_bindings_per_allocation": null,
"unauth_rps": null,
"unauth_burst": null,
"binding_rps": null,
"binding_burst": null
},
"credentials": [ "credentials": [
{ {
"username": "testuser", "username": "testuser",
@ -13,6 +38,8 @@
"auth": { "auth": {
"realm": "niom-turn.local", "realm": "niom-turn.local",
"nonce_secret": null, "nonce_secret": null,
"nonce_ttl_seconds": 300 "nonce_ttl_seconds": 300,
"rest_secret": null,
"rest_max_ttl_seconds": 600
} }
} }

View File

@ -9,9 +9,34 @@
Config { Config {
server: ServerOptions { server: ServerOptions {
bind: String, bind: String,
udp_bind: Option<String>,
tcp_bind: Option<String>,
tls_bind: String,
enable_udp: bool,
enable_tcp: bool,
enable_tls: bool,
tls_cert: Option<String>, tls_cert: Option<String>,
tls_key: Option<String>, tls_key: Option<String>,
}, },
relay: RelayOptions {
relay_port_min: Option<u16>,
relay_port_max: Option<u16>,
relay_bind_ip: Option<String>,
advertised_ip: Option<String>,
},
logging: LoggingOptions {
default_directive: Option<String>,
},
limits: LimitsOptions {
max_allocations_per_ip: Option<u32>,
max_permissions_per_allocation: Option<u32>,
max_channel_bindings_per_allocation: Option<u32>,
unauth_rps: Option<u32>,
unauth_burst: Option<u32>,
binding_rps: Option<u32>,
binding_burst: Option<u32>,
},
credentials: Vec<CredentialEntry> { credentials: Vec<CredentialEntry> {
username: String, username: String,
password: String, password: String,
@ -23,7 +48,125 @@ Config {
- Bind: `0.0.0.0:3478` - Bind: `0.0.0.0:3478`
- Single Test Credential: `testuser` / `secretpassword` - Single Test Credential: `testuser` / `secretpassword`
## Runbook (Start & Betrieb)
### 1) Konfiguration anlegen
- Der Server lädt **immer** `appsettings.json` aus dem aktuellen Working Directory (siehe `Config::load_default()`).
- Als Basis kannst du `appsettings.example.json` nach `appsettings.json` kopieren und anpassen.
### 2) Server starten
```bash
# im Repo-Root
cargo run --bin niom-turn
```
Optional kannst du das Log-Level überschreiben:
```bash
RUST_LOG=info cargo run --bin niom-turn
```
Hinweise:
- Wenn `server.enable_tls=true`, aber `server.tls_cert`/`server.tls_key` fehlen, wird der TLS-Listener **übersprungen** (Info-Log).
- Beim Start loggt der Server, welche Listener aktiv sind und welche Relay-Optionen gelten.
### 3) Minimal-Konfig: UDP-only (einfachster Start)
```json
{
"server": {
"bind": "0.0.0.0:3478",
"enable_udp": true,
"enable_tcp": false,
"enable_tls": false,
"tls_cert": null,
"tls_key": null
},
"relay": {
"relay_bind_ip": "0.0.0.0",
"advertised_ip": null,
"relay_port_min": null,
"relay_port_max": null
},
"logging": {
"default_directive": "warn,niom_turn=info"
},
"limits": {
"max_allocations_per_ip": null,
"max_permissions_per_allocation": null,
"max_channel_bindings_per_allocation": null,
"unauth_rps": null,
"unauth_burst": null,
"binding_rps": null,
"binding_burst": null
},
"credentials": [
{ "username": "testuser", "password": "secretpassword" }
],
"auth": {
"realm": "niom-turn.local",
"nonce_secret": null,
"nonce_ttl_seconds": 300,
"rest_secret": null,
"rest_max_ttl_seconds": 600
}
}
```
### 4) Minimal-Konfig: UDP + TCP + TLS (für maximale WebRTC-Kompatibilität)
```json
{
"server": {
"bind": "0.0.0.0:3478",
"tls_bind": "0.0.0.0:5349",
"enable_udp": true,
"enable_tcp": true,
"enable_tls": true,
"tls_cert": "/etc/niom-turn/tls/fullchain.pem",
"tls_key": "/etc/niom-turn/tls/privkey.pem"
},
"relay": {
"relay_bind_ip": "0.0.0.0",
"advertised_ip": "203.0.113.10",
"relay_port_min": 49152,
"relay_port_max": 49252
},
"logging": {
"default_directive": "warn,niom_turn=info"
},
"limits": {
"max_allocations_per_ip": 50,
"max_permissions_per_allocation": 200,
"max_channel_bindings_per_allocation": 50,
"unauth_rps": 5,
"unauth_burst": 20,
"binding_rps": 50,
"binding_burst": 200
},
"credentials": [
{ "username": "testuser", "password": "secretpassword" }
]
}
```
Betriebs-Checkliste (kurz):
- Firewall öffnen: UDP/3478, TCP/3478, TCP/5349 sowie den UDP-Relay-Portbereich (`relay_port_min/max`).
- Hinter NAT: `relay.advertised_ip` auf die öffentliche IP setzen.
- Für TURN REST Credentials: `auth.rest_secret` setzen und ggf. feste `credentials` leer lassen.
## TODOs ## TODOs
- Shared Secret / REST API zur Credential-Verwaltung. - Shared Secret / REST API zur Credential-Verwaltung.
- Konfigurierbare TLS-Bind-Adresse (`turns` Standard 5349).
- Health-Port (HTTP) für Monitoring. - Health-Port (HTTP) für Monitoring.
- Rate-Limits/Quota (pro Username) und Relay-Port-Range produktiv setzen.
## Hinweise für produktiven Betrieb
- Wenn der Server hinter NAT läuft, setze `relay.advertised_ip` auf die öffentliche IP, damit Clients in XOR-RELAYED-ADDRESS eine erreichbare Adresse erhalten.
- Für Firewalls ist ein fester Relay-Port-Range sinnvoll (`relay_port_min/max`), damit nur dieser UDP-Bereich geöffnet werden muss.
- Allocations werden auch ohne Traffic periodisch auf Expiry geprüft und bereinigt (Housekeeping), damit Relay-Sockets/Tasks nicht dauerhaft hängen bleiben.

136
docs/deployment.md Normal file
View File

@ -0,0 +1,136 @@
# Deployment Guide (niom-turn)
This guide assumes a fresh Debian LXC (e.g., 10.0.0.22), Fritzbox port forwards are in place, and you want TURN reachable on 3478/udp+tcp and 5349/tcp with a UDP relay range (e.g., 49152-49200).
## 1) Install dependencies
```bash
sudo apt update
sudo apt install -y build-essential pkg-config libssl-dev curl git systemd
# Rust toolchain (stable)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
```
## 2) Clone and build
```bash
cd /opt
sudo mkdir -p niom-turn && sudo chown "$USER":"$USER" niom-turn
cd /opt/niom-turn
git clone https://github.com/<your-repo>/niom-turn.git .
cargo build --release
# Binary: target/release/niom-turn
```
## 3) Configuration
Create config dir and place TLS cert/key (exported from NPM) and config:
```bash
sudo mkdir -p /etc/niom-turn
sudo chown "$USER":"$USER" /etc/niom-turn
# place /etc/niom-turn/fullchain.pem and /etc/niom-turn/privkey.pem
```
Example `/etc/niom-turn/appsettings.json` (adjust realm, WAN IP, secrets):
```json
{
"logging": { "level": "info" },
"auth": {
"realm": "turn.example.com",
"nonce_ttl_seconds": 600,
"rest_secret": "CHANGE_ME_REST_SECRET",
"rest_max_ttl_seconds": 86400
},
"listeners": {
"udp": "0.0.0.0:3478",
"tcp": "0.0.0.0:3478",
"tls": {
"addr": "0.0.0.0:5349",
"cert_file": "/etc/niom-turn/fullchain.pem",
"key_file": "/etc/niom-turn/privkey.pem"
}
},
"relay": {
"bind_addr": "0.0.0.0",
"public_addr": "YOUR_WAN_IP",
"port_range": "49152-49200"
},
"rate_limits": {
"enabled": true,
"max_allocations_per_ip": 10,
"max_permissions_per_allocation": 10,
"max_channels_per_allocation": 10
}
}
```
- `public_addr` must be your public WAN IP (not the LXC IP).
- `rest_secret` is used for TURN REST credentials (time-based user/pass).
## 4) Systemd service
Install binary and user:
```bash
sudo cp /opt/niom-turn/target/release/niom-turn /usr/local/bin/niom-turn
sudo useradd --system --no-create-home --shell /usr/sbin/nologin niomturn
sudo chown root:root /usr/local/bin/niom-turn
sudo chmod 0755 /usr/local/bin/niom-turn
sudo chown -R niomturn:niomturn /etc/niom-turn
```
Create `/etc/systemd/system/niom-turn.service`:
```
[Unit]
Description=niom-turn
After=network.target
[Service]
User=niomturn
Group=niomturn
ExecStart=/usr/local/bin/niom-turn --config /etc/niom-turn/appsettings.json
Environment=RUST_LOG=debug,niom_turn=debug
Restart=on-failure
RestartSec=3
# Optional: LimitNOFILE=65535
[Install]
WantedBy=multi-user.target
```
Enable/start:
```bash
sudo systemctl daemon-reload
sudo systemctl enable --now niom-turn
```
## 5) Firewall (LXC)
Allow inbound: UDP 3478, TCP 3478, TCP 5349, UDP relay range (49152-49200). Outbound allow all.
## 6) Quick checks
- Listener ports: `ss -tulpen | grep -E '3478|5349'`
- Logs: `journalctl -u niom-turn -f`
- External TCP reachability (from Hotspot): `nc -vz turn.example.com 3478` and `nc -vz turn.example.com 5349`
- STUN/TURN test: `stunclient turn.example.com 3478 -u user -p pass` (or REST creds)
- WebRTC: open webrtc-internals / about:webrtc; ensure relay candidates show your WAN IP + ports in 49152-49200.
## 7) Fritzbox / Port forwards (reference)
- UDP 3478 → 10.0.0.22:3478
- TCP 3478 → 10.0.0.22:3478
- TCP 5349 → 10.0.0.22:5349
- UDP 49152-49200 → 10.0.0.22:49152-49200
Test from external network (Hotspot), not from LAN (avoid NAT loopback assumptions).
## 8) Tuning / next steps
- For more logs temporarily set `RUST_LOG=trace,niom_turn=trace` in the service env.
- Consider JSON logging + metrics export if you need richer observability.
- Keep certs renewed via NPM and re-export to the LXC.

View File

@ -4,6 +4,12 @@ Dokumentationsübersicht für den TURN/STUN-Server.
- [`architecture/data_flow.md`](architecture/data_flow.md) UDP/TLS-Loop, Allocation Manager. - [`architecture/data_flow.md`](architecture/data_flow.md) UDP/TLS-Loop, Allocation Manager.
- [`config/runtime.md`](config/runtime.md) Appsettings & Credentials. - [`config/runtime.md`](config/runtime.md) Appsettings & Credentials.
- [`turn_end_to_end_flow.md`](turn_end_to_end_flow.md) Sequenzgrafik: Allocate → Permission → Send/ChannelData → Rückweg (UDP + TLS).
- [`tcp_tls_data_plane.md`](tcp_tls_data_plane.md) Warum TCP/TLS Data-Plane wichtig ist und wie sie implementiert ist.
- [`mvp_gaps_and_rfc_notes.md`](mvp_gaps_and_rfc_notes.md) MVP-Lücken, RFC-Notizen und Auswirkungen.
- [`turn_rest_credentials.md`](turn_rest_credentials.md) TURN REST Credentials: Algorithmus, CLI-Tool und Betrieb.
- [`testing.md`](testing.md) Testübersicht: was abgedeckt ist und wie man es ausführt.
- [`testing_todo.md`](testing_todo.md) Vorschläge für zusätzliche Tests (Roadmap).
- Bereits vorhanden: `deploy_tls_lxc.md`, RFC-Referenzen (STUN/TURN Specs). - Bereits vorhanden: `deploy_tls_lxc.md`, RFC-Referenzen (STUN/TURN Specs).
## Zielsetzung ## Zielsetzung

View File

@ -0,0 +1,114 @@
# MVP-Lücken & RFC-Notizen (STUN/TURN)
Dieses Dokument listet **bewusst vereinfachte/fehlende Teile** im aktuellen `niom-turn` MVP auf, jeweils mit kurzer Auswirkung und Code-Anker.
> Ziel: Klarheit, was schon interoperabel ist, und wo (für Production/Interop) noch Arbeit nötig ist.
---
## 1) STUN Binding ist minimal
**Ist-Zustand**
- `METHOD_BINDING` beantwortet der Server als "Success" und enthält `XOR-MAPPED-ADDRESS` (IPv4+IPv6): [src/stun.rs](../src/stun.rs)
**Auswirkung**
- Für STUN-Diagnose/ICE-NAT-Discovery ist das damit deutlich interoperabler.
---
## 2) IPv6 wird nicht unterstützt (historisch)
**Ist-Zustand**
- XOR-Address Encoding/Decoding unterstützt IPv4 **und** IPv6 (XOR-Key: Magic Cookie + Transaction ID): [src/stun.rs](../src/stun.rs)
**Auswirkung**
- TURN/STUN kann IPv6-Adressen in XOR-ADDRESS Attributen korrekt verarbeiten.
---
## 3) TURN Allocate ist stark vereinfacht
**Ist-Zustand**
- `allocate_for` bindet immer ein UDP-Relay auf `0.0.0.0:0`: [src/alloc.rs](../src/alloc.rs)
- `REQUESTED-TRANSPORT` wird nicht ausgewertet (es gibt nur UDP-Relay).
- Weitere RFC-TURN Optionen (EVEN-PORT, RESERVATION-TOKEN, DONT-FRAGMENT, etc.) sind nicht implementiert.
**Auswirkung**
- Das MVP deckt den typischen UDP-Relay-Fall ab, aber nicht die volle TURN-Feature-Matrix.
---
## 4) Data Plane: UDP-relay ja, TCP-relay nein
**Ist-Zustand**
- Relay-Datenpfad ist UDP-only (Relay-Socket ist `tokio::net::UdpSocket`): [src/alloc.rs](../src/alloc.rs)
- TLS-Listener re-used Control-Plane Logik, aber kein TCP-Relay: [src/tls.rs](../src/tls.rs)
**Auswirkung**
- `turns:` (TLS) kann Control-Plane, aber der Rückweg der relayed Daten läuft derzeit weiterhin über UDP an die Client-UDP-Adresse.
- Für echte TURN-over-TCP/TLS Data Plane (Interoperabilität mit restriktiven Netzwerken) fehlt ein eigener TCP-Relay-Pfad.
---
## 5) Allocation- und Timer-Verhalten ist MVP-artig
**Ist-Zustand**
- Allocation-Lifetime wird über `refresh_allocation` gesetzt/geklemmt: [src/alloc.rs](../src/alloc.rs)
- `remove_allocation` entfernt den Eintrag.
- Der Relay-Task (spawn in `allocate_for`) läuft jedoch weiter und prüft nur noch, ob Allocation existiert: [src/alloc.rs](../src/alloc.rs)
**Auswirkung**
- Bei vielen Allocations kann das zu unnötigen Hintergrund-Tasks führen (Resource-Management/Backpressure fehlt).
---
## 6) Permissions & ChannelBindings sind minimal
**Ist-Zustand**
- Permission TTL ist statisch (300s) und wird nur durch erneutes `CreatePermission` erneuert: [src/alloc.rs](../src/alloc.rs)
- ChannelBindings nutzen denselben TTL-Wert (`PERMISSION_LIFETIME`): [src/alloc.rs](../src/alloc.rs)
**Auswirkung**
- Funktional ok für MVP, aber nicht unbedingt exakt RFC-getreu im Detail (separate Lifetime-Regeln/Refresh-Strategien sind üblich).
---
## 7) MESSAGE-INTEGRITY/FINGERPRINT
**Ist-Zustand**
- `MESSAGE-INTEGRITY` wird RFC-konform geprüft (HMAC-SHA1 über die Message bis inkl. MI-Attribut; Header-Länge wird dafür auf „Ende von MI“ gesetzt, d.h. kompatibel mit nachfolgenden Attributen wie `FINGERPRINT`): [src/stun.rs](../src/stun.rs)
- Für TURN-Responses nach erfolgreicher Authentisierung hängt der Server `MESSAGE-INTEGRITY` an und setzt `FINGERPRINT` als letztes Attribut: [src/server.rs](../src/server.rs), [src/turn_stream.rs](../src/turn_stream.rs)
- `FINGERPRINT` wird bei allen vom Server gebauten STUN-Nachrichten als letztes Attribut angehängt und bei eingehenden Nachrichten (falls vorhanden) validiert: [src/stun.rs](../src/stun.rs)
**Auswirkung**
- Bessere Browser/ICE-Interop und leichtes Hardening (Messages mit ungültigem `FINGERPRINT` werden verworfen).
---
## 8) Observability / Limits / Hardening fehlen (noch)
**Ist-Zustand**
- Keine Quotas pro User/IP, keine Rate-Limits, keine Bandbreitenlimits pro Allocation.
- Credential Store ist In-Memory (Test/Dev): [src/auth.rs](../src/auth.rs), Trait: [src/traits/credential_store.rs](../src/traits/credential_store.rs)
**Auswirkung**
- Für Production braucht es Limits, persistente Credentials, Monitoring/Metrics und härteres Error-Handling.
---
## 9) Weitere RFC-Ecken (nicht implementiert)
Typische Punkte, die im MVP fehlen/noch offen sind:
- Full attribute coverage (z.B. UNKNOWN-ATTRIBUTES, SOFTWARE, etc.)
- Vollständige STUN/TURN Class/Method Encoding nach RFC (hier bewusst vereinfacht über `METHOD | CLASS_*`): [src/constants.rs](../src/constants.rs)
- IPv6, Hairpinning-Sonderfälle, NAT-bezogene Interop-Edge-Cases
---
## Was bereits gut abgedeckt ist
- End-to-end UDP TURN Core: `Allocate``CreatePermission``Send`/ChannelData → Rückweg als Data Indication/ChannelData.
- Long-term Auth (Realm/Nonce + MI) mit klaren 401/438/437/403 Pfaden.
- TLS Listener (Control-Plane) mit STUN-Framing über TCP/TLS.
Siehe auch die Integrationstests: [tests/udp_turn.rs](../tests/udp_turn.rs) und [tests/tls_turn.rs](../tests/tls_turn.rs).

View File

@ -0,0 +1,98 @@
# TCP/TLS Data-Plane (TURN over TCP/TLS) in niom-turn
Dieses Dokument erklärt **wofür** eine TCP/TLS Data-Plane bei TURN gebraucht wird und **wie** `niom-turn` sie (aktuell) implementiert.
## Wofür wird das benötigt?
TURN hat zwei Arten von Traffic:
- **Control-Plane**: STUN/TURN Requests/Responses (Allocate, CreatePermission, ChannelBind, Refresh, …)
- **Data-Plane**: Nutzdaten zwischen Client und Peer, die über das Relay laufen (Send/Data-Indication oder ChannelData)
In vielen realen Netzen ist **UDP eingeschränkt oder komplett blockiert** (Corporate WLAN, Mobilfunk-APNs, Captive Portals, Proxy-Umgebungen). WebRTC/ICE versucht deshalb typischerweise:
1. UDP (schnell, bevorzugt)
2. TURN über TCP
3. TURN über TLS ("turns:") als letzte, aber oft funktionierende Option
Damit TURN über TCP/TLS wirklich nutzbar ist, muss nicht nur die Control-Plane über den Stream laufen, sondern auch der **Rückweg Peer → Client** (Data-Plane) über **dieselbe TCP/TLS Verbindung** beim Client ankommen.
## Was implementiert niom-turn konkret?
- Client ↔ Server Transport kann **UDP**, **TCP** oder **TLS** sein.
- Das Relay zum Peer ist weiterhin **UDP** (klassisches TURN UDP-Relay).
- Bei **TCP/TLS** liefert der Server die Data-Plane zurück an den Client über den **Stream** (statt über UDP an die Client-Adresse zu senden).
Das entspricht dem üblichen WebRTC-Fallback: "Client-Server über TCP/TLS, Peer-Transport über UDP".
## Architektur im Code
- Stream-Handler (TCP/TLS): [src/turn_stream.rs](../src/turn_stream.rs)
- TCP Listener: [src/tcp.rs](../src/tcp.rs)
- TLS Listener: [src/tls.rs](../src/tls.rs)
- Allocation + Relay: [src/alloc.rs](../src/alloc.rs)
### Schlüsselidee: `ClientSink`
Damit der Relay-Loop Peer-Pakete an unterschiedliche Client-Transporte schicken kann, gibt es in [src/alloc.rs](../src/alloc.rs) eine Abstraktion:
- `ClientSink::Udp { sock, addr }` → sendet Peer-Daten per `udp.send_to(..., addr)`
- `ClientSink::Stream { tx }` → queued Bytes in einen Writer-Task, der auf den TCP/TLS Stream schreibt
Wenn ein Client über TCP/TLS allocatet, wird die Allocation mit einem `ClientSink::Stream` erzeugt.
## Framing: STUN vs. ChannelData auf einem Byte-Stream
Auf UDP bekommt man Datagramme; auf TCP/TLS bekommt man einen **kontinuierlichen Byte-Stream**. TURN over TCP/TLS multiplexed:
- STUN/TURN Messages (Control-Plane)
- ChannelData Frames (Data-Plane, Client → Server)
`niom-turn` parst daher im Stream in einer Schleife "nächstes Frame" (siehe `try_pop_next_frame(...)` in [src/turn_stream.rs](../src/turn_stream.rs)):
### STUN Message
- Header ist 20 Bytes
- Length-Feld ist die Body-Länge
- Gesamtlänge ist: $20 + length$
### ChannelData Frame
- 4 Byte Header: `CHANNEL-NUMBER` (2) + `LENGTH` (2)
- Channel-Nummern liegen im Bereich `0x4000..=0x7FFF` (Top-Bits `01`)
- Gesamtlänge ist: $4 + length$
Wichtig: Bei TCP/TLS darf **kein Padding** als "separate Bytes" im Stream verbleiben. Deshalb baut `niom-turn` ChannelData als exakt `4 + len` Bytes (siehe [src/stun.rs](../src/stun.rs)).
### Hardening: Resync & Limits
Da TCP/TLS ein Byte-Stream ist, können kaputte oder bösartige Clients den Parser sonst leicht „desynchronisieren“.
`niom-turn` implementiert daher im Stream-Parser:
- **Magic-Cookie Check** für STUN: Ungültige Cookies führen zu einem Byte-weisen Resync (statt auf riesige Längen zu warten).
- **Frame-Size Limits** (STUN-Body und ChannelData), um Speicher-/DoS-Risiken zu begrenzen.
- **Max Buffer Limit** pro Verbindung: wenn der Eingangspuffer zu groß wird, wird die Verbindung geschlossen.
## Datenfluss (TCP/TLS)
1. Client verbindet sich per TCP oder TLS.
2. Der Stream-Handler liest Frames:
- STUN/TURN Requests → verarbeitet wie UDP-Pfad (Auth, Allocation, Permission, ChannelBind, Send, Refresh)
- ChannelData (Client→Peer) → wird über das UDP-Relay an den Peer geschickt
3. Peer sendet UDP an die Relay-Adresse.
4. Relay-Loop leitet die Bytes an den `ClientSink` weiter:
- bei Stream: `tx.send(bytes)` → Writer-Task schreibt Data-Indication oder ChannelData zurück auf denselben Stream
## Grenzen / Noch nicht implementiert
- Kein TCP-Relay zum Peer (TURN TCP allocations / CONNECT-Methoden wie in RFC6062).
- Fokus liegt auf: Client-Server Transport über TCP/TLS + UDP-Relay.
- IPv6 ist im aktuellen MVP noch nicht vollständig umgesetzt.
## Tests
- TCP Stream Data-Plane: `tests/tcp_turn.rs`
- TLS Stream Data-Plane: `tests/tls_data_plane.rs`
- Gemeinsames Framing (STUN + ChannelData): `tests/support/stream.rs`
Wenn du willst, kann ich als nächsten Schritt die Doku um eine kurze Interop-Checkliste (WebRTC/ICE Verhalten, Candidate-Types, typische Fehlerbilder) ergänzen.

60
docs/testing.md Normal file
View File

@ -0,0 +1,60 @@
# Testing
Dieses Projekt ist so aufgebaut, dass sich die wichtigsten TURN-Pfade als **Unit- und Integrationstests** ausführen lassen.
## Schnellstart
- Alle Tests: `cargo test`
- Mit weniger Output: `cargo test -q`
- Mit Logs (Beispiele):
- `RUST_LOG=warn,niom_turn=info cargo test -- --nocapture`
Hinweis: Die Integrationstests initialisieren `tracing` über die Helpers in `tests/support`.
## Was wird getestet?
### STUN RFC-Interop (FINGERPRINT)
- Unit-Tests prüfen, dass der Server `FINGERPRINT` an Responses anhängt und dass die CRC32/XOR Validierung fehlschlägt, wenn die Nachricht manipuliert wird.
Siehe: Unit-Tests in [src/stun.rs](../src/stun.rs).
### STUN RFC-Interop (MESSAGE-INTEGRITY)
- Unit-Tests prüfen `MESSAGE-INTEGRITY` Validierung (inkl. Fall „MI + nachfolgendes FINGERPRINT“).
- UDP-Integrationstests prüfen, dass Responses nach erfolgreicher Authentisierung `MESSAGE-INTEGRITY` enthalten und validierbar sind.
### UDP (turn:)
- Auth-Challenge (401 + NONCE) und erfolgreicher Allocate
- Refresh mit `LIFETIME=0` entfernt Allocation
- CreatePermission + Send → Peer erhält UDP Payload über Relay
Siehe: `tests/udp_turn.rs` sowie die thematischen Ordner in `tests/`.
### TLS (turns:) / Stream-basierte Data-Plane
- Allocate/Refresh über TLS-Stream
- (Neu) Data-Plane Rückweg Peer→Client über denselben TLS-Stream (Data-Indication oder ChannelData)
Siehe: `tests/tls_turn.rs` und `tests/tls_data_plane.rs`.
### TCP (turn:?transport=tcp) / Stream-basierte Data-Plane
- (Neu) Allocate/CreatePermission/Send über TCP
- (Neu) Peer→Client Rückweg als STUN Data Indication über TCP
- (Neu) ChannelBind + ChannelData in beide Richtungen
Siehe: `tests/tcp_turn.rs`.
## Test-Hilfen
- STUN/TURN Builder: `tests/support/stun_builders.rs`
- Stream-Framing (STUN + ChannelData über TCP/TLS): `tests/support/stream.rs`
- TLS Test-Certs: `tests/support/tls.rs`
## Erweiterungsideen (nächste sinnvolle Abdeckung)
- Split-Reads/Writes (Frames in mehreren TCP Reads) als Regressionstest
- IPv6 Encode/Decode Tests für XOR-ADDRESS Varianten
- Negative Tests: Peer nicht permitted, Channel ohne Bind, Allocation Timeout

59
docs/testing_todo.md Normal file
View File

@ -0,0 +1,59 @@
# Test-ToDo (Vorschläge)
Dieses Dokument sammelt **konkrete** Test-Ideen, die den sicheren/stabilen Betrieb (insb. unter Last/Fehlverhalten) absichern sollen.
## Stream (TCP/TLS) Robustheit
- Split-Reads: STUN Header (20B) in 2 Reads, Body in mehreren Reads
- Split-Reads: ChannelData Header (4B) und Payload getrennt
- Mixed Frames: STUN → ChannelData → STUN in einem Read (und in mehreren Reads)
- Oversize Frames:
- STUN Length > Max → Verbindung wird geschlossen (oder Frame gedroppt, je nach Policy)
- ChannelData Length > Max → Verbindung wird geschlossen (oder Frame gedroppt)
- Garbage Resync:
- Zufallsbytes vor gültigem STUN (bereits abgedeckt)
- Zufallsbytes zwischen gültigen Frames
## TURN Flows (Happy + Negative)
- Negative pro Methode (UDP/TCP/TLS jeweils):
- ohne Allocation → 437 Allocation Mismatch
- ohne Permission → 403 Peer Not Permitted
- ChannelData ohne ChannelBind → drop + optional log counter
- Stale Nonce → 438
- falsches MI → 401/403 je nach Policy
## Auth
- TURN REST:
- abgelaufener Username → reject
- Username zu weit in der Zukunft (max TTL) → reject
- falsches HMAC/base64 → reject
- „user exists in store“ vs. „REST fallback“ Priorität
## Lifecycle
- Allocation expiry:
- Refresh verkürzt/verlängert, Min/Max Lifetime
- Expiry entfernt Allocation und beendet Relay-Task (keine Task-Leaks)
- Permission expiry:
- Peer wird nach Ablauf verworfen
- Channel binding expiry:
- Rückweg fällt auf Data Indication zurück, wenn Binding abläuft
## Abuse-/DoS Prevention (sobald Limits implementiert sind)
- Rate limit: auth failures pro IP/Username
- Max allocations pro IP
- Max permissions/channels pro allocation
- Bandwidth caps (bytes/s) pro allocation
- Backpressure: Writer-Queue voll → Verhalten definieren (drop/close)
## Interop (manuell reproduzierbar, aber dokumentiert)
- Browser Plan:
- Trickle ICE / webrtc-internals: forced relay
- UDP-only block: erwarte TCP/TLS fallback
- `turns:` mit self-signed vs. valid cert
Wenn du willst, kann ich als nächsten Schritt aus diesen Punkten eine priorisierte Test-Roadmap machen (P0/P1/P2) und direkt die nächsten P0-Tests implementieren.

View File

@ -0,0 +1,139 @@
# End-to-End TURN Flow (UDP + TLS)
Dieses Dokument beschreibt den **konkret implementierten** End-to-End Ablauf in `niom-turn` anhand des aktuellen Codes (MVP):
- UDP-Control-Plane und UDP-Data-Plane: [src/server.rs](../src/server.rs), [src/alloc.rs](../src/alloc.rs), [src/stun.rs](../src/stun.rs), [src/auth.rs](../src/auth.rs)
- TLS-Control-Plane ("turns") mit STUN-Framing: [src/tls.rs](../src/tls.rs)
- Test-Builder (wie Requests gebaut werden): [tests/support/stun_builders.rs](../tests/support/stun_builders.rs)
## Begriffe
- **Client**: TURN-Client (z.B. WebRTC ICE Agent)
- **Server**: `niom-turn` (dieses Projekt)
- **Peer**: Gegenstelle, zu der relayed werden soll (typischerweise anderer WebRTC-Endpunkt)
- **Allocation**: Server-seitige Sitzung, die ein **Relay-UDP-Socket** bereitstellt
- **Permission**: Erlaubnis, zu einem Peer zu senden/Peer-Pakete zu akzeptieren
- **Channel Binding**: Zuordnung `channel-number -> peer`, um ChannelData nutzen zu können
---
## UDP: Sequenzgrafik (Happy Path)
```mermaid
sequenceDiagram
autonumber
participant C as Client (TURN)
participant S as niom-turn (UDP:3478)
participant R as Relay Socket (UDP:ephemeral)
participant P as Peer
Note over C,S: 1) Allocate ohne Auth → 401 Challenge
C->>S: STUN Allocate Request (ohne MI)
S->>C: STUN Error Response 401 + REALM + NONCE
Note over C,S: 2) Allocate mit Long-Term Auth
C->>S: STUN Allocate Request + USERNAME/REALM/NONCE + MESSAGE-INTEGRITY
S->>R: bind("0.0.0.0:0"), spawn Relay-Loop
S->>C: Allocate Success + XOR-RELAYED-ADDRESS + LIFETIME (+ MESSAGE-INTEGRITY + FINGERPRINT)
Note over C,S: 3) CreatePermission (Pflicht vor Send/ChannelBind)
C->>S: CreatePermission + XOR-PEER-ADDRESS (+ Auth + MI)
S->>C: Success (200) (+ MESSAGE-INTEGRITY + FINGERPRINT)
Note over C,S: 4) Send (Client→Peer via Relay)
C->>S: Send + XOR-PEER-ADDRESS + DATA (+ Auth + MI)
S->>R: relay.send_to(DATA, Peer)
R->>P: UDP payload (source = relay_addr)
Note over P,C: 5) Rückweg (Peer→Client)
P->>R: UDP payload (dest = relay_addr)
alt Channel Binding existiert
R->>S: recv_from(Peer)
S->>C: ChannelData(channel, payload)
else Kein Channel Binding
R->>S: recv_from(Peer)
S->>C: Data Indication (METHOD_DATA|INDICATION) + XOR-PEER-ADDRESS + DATA
end
Note over C,S: 6) Optional: ChannelBind + ChannelData
C->>S: ChannelBind + CHANNEL-NUMBER + XOR-PEER-ADDRESS (+ Auth + MI)
S->>C: Success (200) (+ MESSAGE-INTEGRITY + FINGERPRINT)
C->>S: ChannelData(channel, payload)
S->>R: relay.send_to(payload, Peer)
R->>P: UDP payload
Note over C,S: 7) Refresh
C->>S: Refresh + LIFETIME (+ Auth + MI)
S->>C: Success + LIFETIME(applied) (+ MESSAGE-INTEGRITY + FINGERPRINT)
```
### Was der Server dabei **genau** macht
- Eingangspunkt: `udp_reader_loop` in [src/server.rs](../src/server.rs)
- Frühe Abzweigung: Wenn `parse_channel_data(...)` erfolgreich ist, wird **kein STUN** geparst, sondern ChannelData direkt weitergeleitet (nur wenn Allocation+Binding+Permission passt).
- STUN/TURN Requests werden mit `parse_message(...)` in [src/stun.rs](../src/stun.rs) geparst.
### RFC-Interop Hinweis: FINGERPRINT
- Alle vom Server gebauten STUN-Nachrichten enthalten `FINGERPRINT` als letztes Attribut.
- Wenn ein Client `FINGERPRINT` mitsendet, wird es validiert; bei ungültigem `FINGERPRINT` wird die Nachricht verworfen (kein Response).
### Auth-Entscheidungen und typische Error-Codes
Die Auth-Policy ist zentral in `AuthManager::authenticate` in [src/auth.rs](../src/auth.rs).
- **401 Unauthorized**: Wenn `MESSAGE-INTEGRITY` fehlt → Challenge mit `REALM` + `NONCE`.
- **438 Stale Nonce**: Wenn `NONCE` abgelaufen/ungültig ist → neue Challenge.
- **437 Allocation Mismatch**: Wenn CreatePermission/Send/ChannelBind/Refresh ohne Allocation kommt.
- **403 Peer Not Permitted**: Wenn ein Peer nicht (mehr) permitted ist.
- **400 Missing/Invalid ...**: Wenn Attribute fehlen oder XOR-PEER-ADDRESS nicht dekodierbar ist.
---
## TLS (turns): Sequenzgrafik (Control-Plane)
Wichtig: Die TLS-Implementierung in [src/tls.rs](../src/tls.rs) nutzt denselben TURN-Handler wie TCP und implementiert eine echte **Stream-Data-Plane**.
```mermaid
sequenceDiagram
autonumber
participant C as Client (turns)
participant T as niom-turn (TLS:5349)
participant R as Relay Socket (UDP:ephemeral)
participant P as Peer
Note over C,T: STUN framing über TCP/TLS: read → chunk by (len+20)
C->>T: STUN Allocate (ohne MI)
T->>C: 401 + REALM + NONCE (über TLS)
C->>T: STUN Allocate + Auth + MI
T->>R: allocate_for(peer, stream-sink)
T->>C: Allocate Success + XOR-RELAYED-ADDRESS + LIFETIME (über TLS)
C->>T: CreatePermission/Send/Refresh/ChannelBind (über TLS)
T->>C: Success/Error (über TLS)
Note over P,C: Peer-Daten kommen über den TLS-Stream zurück
P->>R: UDP payload an relay_addr
R->>T: recv_from(Peer)
T->>C: Data Indication / ChannelData (über TLS)
```
### Konsequenz
- Control-Plane über TLS funktioniert (Allocate/Refresh/… werden über TLS beantwortet).
- Der **Data-Plane Rückweg** (Peer → Client) läuft ebenfalls über den TLS-Stream (Relay → `ClientSink::Stream`).
Mehr Details dazu: [docs/tcp_tls_data_plane.md](tcp_tls_data_plane.md)
---
## Mini-Checkliste: Minimaler Ablauf (praktisch)
1. `ALLOCATE` ohne MI → `401` + `REALM` + `NONCE`
2. `ALLOCATE` mit `USERNAME/REALM/NONCE` + `MESSAGE-INTEGRITY``XOR-RELAYED-ADDRESS` + `LIFETIME`
3. `CREATE_PERMISSION` für Peer → `200`
4. `SEND` mit `DATA` → Server sendet via Relay an Peer
5. Peer sendet zurück an Relay → Server liefert an Client als `DATA-INDICATION` oder `CHANNEL-DATA`
6. Optional `CHANNEL_BIND` + ChannelData für effizientere Data-Plane
7. `REFRESH` zum Verlängern oder `LIFETIME=0` zum Freigeben

View File

@ -0,0 +1,75 @@
# TURN REST Credentials (Ephemeral) Nutzung & Betrieb
Dieses Dokument erklärt die **TURN REST Credential** Strategie für `niom-turn` (MVP → production-fähiger Pfad).
## Warum TURN REST?
- Du willst für WebRTC *kurzlebige* TURN-Logins ausstellen (z.B. 510 Minuten).
- Dein TURN-Server speichert keine User-Passwörter pro Nutzer.
- Dein Backend (später) kann Tokens ausstellen; bis dahin kannst du lokal/ops-seitig Tokens generieren.
## Verfahren (kompatibel zu gängigen WebRTC-Stacks)
**Username**: `<expiry_unix_seconds>` oder `<expiry_unix_seconds>:<opaque_user_id>`
- Beispiel: `1735412345:alice`
- `expiry_unix_seconds` ist ein Unix-Timestamp in Sekunden.
**Credential (Password)**:
$$\n\text{credential} = \text{base64}(\text{HMAC-SHA1}(\text{secret}, \text{username}))\n$$
Dieses `credential` wird dann ganz normal als TURN-Passwort im ICE-Server-Config verwendet.
## Server-Seite (`niom-turn`)
Konfiguration in [appsettings.example.json](../appsettings.example.json):
- `auth.rest_secret`: Shared Secret (muss geheim bleiben)
- `auth.rest_max_ttl_seconds`: Maximal akzeptiertes Zeitfenster in die Zukunft. Wenn ein Token zu weit in der Zukunft liegt, wird es abgewiesen (Sicherheitsmaßnahme).
Wichtig:
- Der Server akzeptiert TURN REST Credentials als Fallback **nur wenn** der Username nicht im CredentialStore gefunden wird.
- Ablauf/TTL:
- `expiry` muss >= `now` sein
- und `expiry - now <= rest_max_ttl_seconds`
## Lokal Tokens generieren (ohne Backend)
Das Repo enthält ein kleines CLI:
- Binary: [src/bin/turn_rest_cred.rs](../src/bin/turn_rest_cred.rs)
Beispiele:
1) Einfach (stdout als ENV-Zeilen)
`cargo run --bin turn_rest_cred -- --secret "SUPER_SECRET" --user alice --ttl 600`
2) JSON Ausgabe
`cargo run --bin turn_rest_cred -- --secret "SUPER_SECRET" --user alice --ttl 600 --json`
Du erhältst:
- `username` → in WebRTC als `iceServers[].username`
- `credential` → in WebRTC als `iceServers[].credential`
## WebRTC Beispiel (Frontend)
Du setzt in deiner App typischerweise:
- URLs (mindestens UDP + TLS):
- `turn:your-domain:3478?transport=udp`
- `turn:your-domain:3478?transport=tcp`
- `turns:your-domain:5349?transport=tcp`
und dann:
- `username`: vom Generator/Backend
- `credential`: vom Generator/Backend
## Betriebshinweise
- **Secret Rotation**: Plane Rotation (z.B. zwei Secrets parallel) sonst brechen Tokens im Umlauf.
- **TTL klein halten**: 510 Minuten ist typisch.
- **Logs**: Niemals `secret` oder vollständige Credentials loggen.
- **Rate Limits/Quotas**: Unbedingt ergänzen (Open-Relay/Abuse vermeiden).

View File

@ -1,20 +1,56 @@
//! Allocation manager: provisions relay sockets and forwards packets for TURN allocations. //! Allocation manager: provisions relay sockets and forwards packets for TURN allocations.
//! Backlog: permission tables, channel bindings, allocation refresh timers, and rate limits. //! Backlog: permission tables, channel bindings, allocation refresh timers, and rate limits.
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tracing::info; use tracing::info;
use crate::stun::{build_channel_data, build_data_indication}; use crate::stun::{build_channel_data, build_data_indication};
#[derive(thiserror::Error, Debug)]
pub enum AllocationError {
#[error("allocation quota exceeded")]
AllocationQuotaExceeded,
#[error("permission quota exceeded")]
PermissionQuotaExceeded,
#[error("channel binding quota exceeded")]
ChannelQuotaExceeded,
}
#[derive(Clone)]
pub enum ClientSink {
Udp { sock: Arc<UdpSocket>, addr: SocketAddr },
Stream { tx: mpsc::Sender<Vec<u8>> },
}
impl ClientSink {
pub async fn send(&self, data: Vec<u8>) -> anyhow::Result<()> {
match self {
ClientSink::Udp { sock, addr } => {
sock.send_to(&data, addr).await?;
Ok(())
}
ClientSink::Stream { tx } => {
tx.send(data)
.await
.map_err(|_| anyhow::anyhow!("client stream closed"))
}
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct Allocation { pub struct Allocation {
pub client: SocketAddr, pub client: SocketAddr,
pub relay_addr: SocketAddr, pub relay_addr: SocketAddr,
// keep the socket so it stays bound // keep the socket so it stays bound
_socket: Arc<UdpSocket>, _socket: Arc<UdpSocket>,
stop: Arc<Notify>,
permissions: Arc<Mutex<HashMap<SocketAddr, Instant>>>, permissions: Arc<Mutex<HashMap<SocketAddr, Instant>>>,
channel_bindings: Arc<Mutex<HashMap<u16, (SocketAddr, Instant)>>>, channel_bindings: Arc<Mutex<HashMap<u16, (SocketAddr, Instant)>>>,
expiry: Arc<Mutex<Instant>>, expiry: Arc<Mutex<Instant>>,
@ -23,36 +59,159 @@ pub struct Allocation {
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct AllocationManager { pub struct AllocationManager {
inner: Arc<Mutex<HashMap<SocketAddr, Allocation>>>, inner: Arc<Mutex<HashMap<SocketAddr, Allocation>>>,
opts: AllocationOptions,
}
#[derive(Clone, Debug)]
pub struct AllocationOptions {
pub relay_bind_ip: IpAddr,
pub relay_port_min: Option<u16>,
pub relay_port_max: Option<u16>,
pub advertised_ip: Option<IpAddr>,
pub max_allocations_per_ip: Option<u32>,
pub max_permissions_per_allocation: Option<u32>,
pub max_channel_bindings_per_allocation: Option<u32>,
}
impl Default for AllocationOptions {
fn default() -> Self {
Self {
relay_bind_ip: IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
relay_port_min: None,
relay_port_max: None,
advertised_ip: None,
max_allocations_per_ip: None,
max_permissions_per_allocation: None,
max_channel_bindings_per_allocation: None,
}
}
} }
impl AllocationManager { impl AllocationManager {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
inner: Arc::new(Mutex::new(HashMap::new())), inner: Arc::new(Mutex::new(HashMap::new())),
opts: AllocationOptions::default(),
}
}
pub fn new_with_options(opts: AllocationOptions) -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
opts,
}
}
/// Translate a locally-bound relay socket address into the address that should be
/// advertised to clients (e.g. replace 0.0.0.0 with a public IP).
pub fn relay_addr_for_response(&self, relay_local: SocketAddr) -> SocketAddr {
match self.opts.advertised_ip {
Some(ip) => SocketAddr::new(ip, relay_local.port()),
None => relay_local,
} }
} }
/// Create a relay UDP socket for the given client and spawn a relay loop that forwards /// Create a relay UDP socket for the given client and spawn a relay loop that forwards
/// any packets received on the relay socket back to the client via the provided server socket. /// any packets received on the relay socket back to the client via the provided client sink.
pub async fn allocate_for( pub async fn allocate_for(
&self, &self,
client: SocketAddr, client: SocketAddr,
server_sock: Arc<UdpSocket>, client_sink: ClientSink,
) -> anyhow::Result<SocketAddr> { ) -> anyhow::Result<SocketAddr> {
// bind relay socket to OS-chosen port // If an allocation already exists for this exact 5-tuple, reuse it.
let relay = UdpSocket::bind("0.0.0.0:0").await?; {
let mut guard = self.inner.lock().unwrap();
prune_expired_locked(&mut guard);
if let Some(existing) = guard.get(&client) {
return Ok(existing.relay_addr);
}
if let Some(max) = self.opts.max_allocations_per_ip {
let count_for_ip = guard
.values()
.filter(|a| a.client.ip() == client.ip())
.count() as u32;
if count_for_ip >= max {
return Err(anyhow::anyhow!(AllocationError::AllocationQuotaExceeded));
}
}
}
// bind relay socket (optionally within configured port range)
let relay = match (self.opts.relay_port_min, self.opts.relay_port_max) {
(Some(min), Some(max)) if min > 0 && max >= min => {
let mut bound: Option<UdpSocket> = None;
let mut last_err: Option<anyhow::Error> = None;
for port in min..=max {
let addr = SocketAddr::new(self.opts.relay_bind_ip, port);
match UdpSocket::bind(addr).await {
Ok(sock) => {
bound = Some(sock);
last_err = None;
break;
}
Err(e) => {
last_err = Some(anyhow::anyhow!(e));
}
}
}
match bound {
Some(sock) => sock,
None => {
let detail = last_err
.map(|e| format!("{e:?}"))
.unwrap_or_else(|| "no ports in range available".to_string());
return Err(anyhow::anyhow!(
"failed to bind relay socket in range {}-{} on {}: {}",
min,
max,
self.opts.relay_bind_ip,
detail
));
}
}
}
_ => {
let addr = SocketAddr::new(self.opts.relay_bind_ip, 0);
UdpSocket::bind(addr).await?
}
};
let relay_local = relay.local_addr()?; let relay_local = relay.local_addr()?;
let relay_arc = Arc::new(relay); let relay_arc = Arc::new(relay);
// Insert allocation before spawning relay loop to avoid races.
let stop = Arc::new(Notify::new());
let alloc = Allocation {
client,
relay_addr: relay_local,
_socket: relay_arc.clone(),
stop: stop.clone(),
permissions: Arc::new(Mutex::new(HashMap::new())),
channel_bindings: Arc::new(Mutex::new(HashMap::new())),
expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)),
};
{
let mut m = self.inner.lock().unwrap();
prune_expired_locked(&mut m);
m.insert(client, alloc);
}
// spawn relay loop // spawn relay loop
let relay_clone = relay_arc.clone(); let relay_clone = relay_arc.clone();
let server_sock_clone = server_sock.clone(); let sink_clone = client_sink.clone();
let client_clone = client; let client_clone = client;
let manager_clone = self.clone(); let manager_clone = self.clone();
let stop_clone = stop.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut buf = vec![0u8; 2048]; let mut buf = vec![0u8; 2048];
loop { loop {
match relay_clone.recv_from(&mut buf).await { tokio::select! {
_ = stop_clone.notified() => {
break;
}
res = relay_clone.recv_from(&mut buf) => match res {
Ok((len, src)) => { Ok((len, src)) => {
info!( info!(
"relay got {} bytes from {} for client {}", "relay got {} bytes from {} for client {}",
@ -70,27 +229,29 @@ impl AllocationManager {
if let Some(channel) = allocation.channel_for_peer(&src) { if let Some(channel) = allocation.channel_for_peer(&src) {
let frame = build_channel_data(channel, &buf[..len]); let frame = build_channel_data(channel, &buf[..len]);
if let Err(e) = if let Err(e) = sink_clone.send(frame).await {
server_sock_clone.send_to(&frame, client_clone).await
{
tracing::error!( tracing::error!(
"failed to send channel data {} -> {}: {:?}", "failed to send channel data {} -> {}: {:?}",
src, src,
client_clone, client_clone,
e e
); );
if matches!(sink_clone, ClientSink::Stream { .. }) {
break;
}
} }
} else { } else {
let indication = build_data_indication(&src, &buf[..len]); let indication = build_data_indication(&src, &buf[..len]);
if let Err(e) = if let Err(e) = sink_clone.send(indication).await {
server_sock_clone.send_to(&indication, client_clone).await
{
tracing::error!( tracing::error!(
"failed to send data indication {} -> {}: {:?}", "failed to send data indication {} -> {}: {:?}",
src, src,
client_clone, client_clone,
e e
); );
if matches!(sink_clone, ClientSink::Stream { .. }) {
break;
}
} }
} }
} else { } else {
@ -99,28 +260,20 @@ impl AllocationManager {
src, src,
client_clone client_clone
); );
// Allocation removed/expired: stop the relay task.
break;
} }
} }
Err(e) => { Err(e) => {
tracing::error!("relay socket error: {:?}", e); tracing::error!("relay socket error: {:?}", e);
break; break;
} }
}
} }
} }
}); });
let alloc = Allocation {
client,
relay_addr: relay_local,
_socket: relay_arc,
permissions: Arc::new(Mutex::new(HashMap::new())),
channel_bindings: Arc::new(Mutex::new(HashMap::new())),
expiry: Arc::new(Mutex::new(Instant::now() + DEFAULT_ALLOCATION_LIFETIME)),
};
tracing::info!("created allocation for {} -> {}", client, relay_local); tracing::info!("created allocation for {} -> {}", client, relay_local);
let mut m = self.inner.lock().unwrap();
prune_expired_locked(&mut m);
m.insert(client, alloc);
Ok(relay_local) Ok(relay_local)
} }
@ -140,6 +293,13 @@ impl AllocationManager {
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?; .ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
let mut perms = alloc.permissions.lock().unwrap(); let mut perms = alloc.permissions.lock().unwrap();
prune_permissions(&mut perms); prune_permissions(&mut perms);
if let Some(max) = self.opts.max_permissions_per_allocation {
let max = max as usize;
if !perms.contains_key(&peer) && perms.len() >= max {
return Err(anyhow::anyhow!(AllocationError::PermissionQuotaExceeded));
}
}
perms.insert(peer, Instant::now() + PERMISSION_LIFETIME); perms.insert(peer, Instant::now() + PERMISSION_LIFETIME);
Ok(()) Ok(())
} }
@ -158,6 +318,13 @@ impl AllocationManager {
.ok_or_else(|| anyhow::anyhow!("allocation not found"))?; .ok_or_else(|| anyhow::anyhow!("allocation not found"))?;
let mut bindings = alloc.channel_bindings.lock().unwrap(); let mut bindings = alloc.channel_bindings.lock().unwrap();
prune_channel_bindings(&mut bindings); prune_channel_bindings(&mut bindings);
if let Some(max) = self.opts.max_channel_bindings_per_allocation {
let max = max as usize;
if !bindings.contains_key(&channel) && bindings.len() >= max {
return Err(anyhow::anyhow!(AllocationError::ChannelQuotaExceeded));
}
}
bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME)); bindings.insert(channel, (peer, Instant::now() + PERMISSION_LIFETIME));
Ok(()) Ok(())
} }
@ -173,7 +340,9 @@ impl AllocationManager {
let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME); let req = requested.unwrap_or(DEFAULT_ALLOCATION_LIFETIME);
if let Some(d) = requested { if let Some(d) = requested {
if d.is_zero() { if d.is_zero() {
guard.remove(&client); if let Some(alloc) = guard.remove(&client) {
alloc.stop.notify_waiters();
}
return Ok(Duration::from_secs(0)); return Ok(Duration::from_secs(0));
} }
} }
@ -191,7 +360,42 @@ impl AllocationManager {
/// Remove allocation explicitly (e.g. on zero lifetime). Returns true if removed. /// Remove allocation explicitly (e.g. on zero lifetime). Returns true if removed.
pub fn remove_allocation(&self, client: &SocketAddr) -> bool { pub fn remove_allocation(&self, client: &SocketAddr) -> bool {
let mut guard = self.inner.lock().unwrap(); let mut guard = self.inner.lock().unwrap();
guard.remove(client).is_some() if let Some(alloc) = guard.remove(client) {
alloc.stop.notify_waiters();
true
} else {
false
}
}
/// Spawn a background housekeeping task that periodically prunes expired allocations.
/// This avoids keeping relay tasks/sockets alive indefinitely when allocations expire.
pub fn spawn_housekeeping(&self, interval: Duration) {
let mgr = self.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
mgr.prune_expired();
}
});
}
/// Remove expired allocations and notify their relay tasks to stop.
pub fn prune_expired(&self) {
let mut guard = self.inner.lock().unwrap();
let now = Instant::now();
let expired: Vec<SocketAddr> = guard
.iter()
.filter_map(|(k, alloc)| {
let expiry = alloc.expiry.lock().unwrap();
if *expiry <= now { Some(*k) } else { None }
})
.collect();
for client in expired {
if let Some(alloc) = guard.remove(&client) {
alloc.stop.notify_waiters();
}
}
} }
} }

View File

@ -6,6 +6,7 @@ use crate::models::stun::StunMessage;
use crate::stun::{find_message_integrity, validate_message_integrity}; use crate::stun::{find_message_integrity, validate_message_integrity};
use crate::traits::CredentialStore; use crate::traits::CredentialStore;
use async_trait::async_trait; use async_trait::async_trait;
use base64::Engine;
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha1::Sha1; use sha1::Sha1;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -46,6 +47,8 @@ pub struct AuthSettings {
pub realm: String, pub realm: String,
pub nonce_secret: Vec<u8>, pub nonce_secret: Vec<u8>,
pub nonce_ttl: Duration, pub nonce_ttl: Duration,
pub rest_secret: Option<Vec<u8>>,
pub rest_max_ttl: Duration,
} }
impl AuthSettings { impl AuthSettings {
@ -56,10 +59,13 @@ impl AuthSettings {
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
// Ensure TTL does not collapse to zero so challenges stay valid briefly. // Ensure TTL does not collapse to zero so challenges stay valid briefly.
let ttl = Duration::from_secs(opts.nonce_ttl_seconds.max(60)); let ttl = Duration::from_secs(opts.nonce_ttl_seconds.max(60));
let rest_max_ttl = Duration::from_secs(opts.rest_max_ttl_seconds.max(60));
Self { Self {
realm: opts.realm.clone(), realm: opts.realm.clone(),
nonce_secret: secret.into_bytes(), nonce_secret: secret.into_bytes(),
nonce_ttl: ttl, nonce_ttl: ttl,
rest_secret: opts.rest_secret.clone().map(|s| s.into_bytes()),
rest_max_ttl,
} }
} }
} }
@ -67,7 +73,7 @@ impl AuthSettings {
/// Result of validating authentication attributes on an incoming STUN/TURN request. /// Result of validating authentication attributes on an incoming STUN/TURN request.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum AuthStatus { pub enum AuthStatus {
Granted { username: String }, Granted { username: String, key: Vec<u8> },
Challenge { nonce: String }, Challenge { nonce: String },
StaleNonce { nonce: String }, StaleNonce { nonce: String },
Reject { code: u16, reason: &'static str }, Reject { code: u16, reason: &'static str },
@ -159,12 +165,15 @@ impl<S: CredentialStore + Clone> AuthManager<S> {
let password = match self.store.get_password(&username).await { let password = match self.store.get_password(&username).await {
Some(p) => p, Some(p) => p,
None => { None => match self.derive_turn_rest_password(&username) {
return AuthStatus::Reject { Some(p) => p,
code: 401, None => {
reason: "Unknown User", return AuthStatus::Reject {
code: 401,
reason: "Unknown User",
};
} }
} },
}; };
let key = self.derive_long_term_key(&username, &password); let key = self.derive_long_term_key(&username, &password);
@ -175,7 +184,7 @@ impl<S: CredentialStore + Clone> AuthManager<S> {
}; };
} }
AuthStatus::Granted { username } AuthStatus::Granted { username, key }
} }
fn attribute_utf8(&self, msg: &StunMessage, attr_type: u16) -> Option<String> { fn attribute_utf8(&self, msg: &StunMessage, attr_type: u16) -> Option<String> {
@ -190,6 +199,32 @@ impl<S: CredentialStore + Clone> AuthManager<S> {
compute_a1_md5(username, &self.settings.realm, password) compute_a1_md5(username, &self.settings.realm, password)
} }
/// TURN REST (ephemeral) password derivation.
///
/// Expected username format: `<expiry_unix_seconds>` or `<expiry_unix_seconds>:<opaque>`.
/// Password is `base64(HMAC-SHA1(rest_secret, username))`.
///
/// Security: Reject if expired or if expiry is too far in the future (bounded by rest_max_ttl).
fn derive_turn_rest_password(&self, username: &str) -> Option<String> {
let secret = self.settings.rest_secret.as_ref()?;
let expiry = parse_turn_rest_expiry(username)?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0))
.as_secs();
if now > expiry {
return None;
}
let delta = expiry.saturating_sub(now);
if delta > self.settings.rest_max_ttl.as_secs() {
return None;
}
Some(turn_rest_password_base64(secret, username))
}
pub fn mint_nonce(&self, peer: &SocketAddr) -> String { pub fn mint_nonce(&self, peer: &SocketAddr) -> String {
let now = SystemTime::now() let now = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
@ -241,6 +276,19 @@ impl<S: CredentialStore + Clone> AuthManager<S> {
} }
} }
fn parse_turn_rest_expiry(username: &str) -> Option<u64> {
let prefix = username.split(':').next().unwrap_or(username);
prefix.parse::<u64>().ok()
}
fn turn_rest_password_base64(secret: &[u8], username: &str) -> String {
type HmacSha1 = Hmac<Sha1>;
let mut mac = HmacSha1::new_from_slice(secret).expect("rest secret to build hmac");
mac.update(username.as_bytes());
let bytes = mac.finalize().into_bytes();
base64::engine::general_purpose::STANDARD.encode(bytes)
}
enum NonceValidation { enum NonceValidation {
Valid, Valid,
Expired, Expired,

76
src/bin/turn_rest_cred.rs Normal file
View File

@ -0,0 +1,76 @@
use anyhow::Context;
use base64::Engine;
use hmac::{Hmac, Mac};
use sha1::Sha1;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
fn now_unix_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0))
.as_secs()
}
fn turn_rest_password_base64(secret: &[u8], username: &str) -> String {
type HmacSha1 = Hmac<Sha1>;
let mut mac = HmacSha1::new_from_slice(secret).expect("rest secret to build hmac");
mac.update(username.as_bytes());
let bytes = mac.finalize().into_bytes();
base64::engine::general_purpose::STANDARD.encode(bytes)
}
/// Minimal offline TURN REST credential generator.
///
/// Usage:
/// - `cargo run --bin turn_rest_cred -- --secret "..." --user alice --ttl 600`
///
/// Output:
/// - `TURN_USERNAME=...`
/// - `TURN_PASSWORD=...`
fn main() -> anyhow::Result<()> {
let mut secret: Option<String> = None;
let mut user: Option<String> = None;
let mut ttl: u64 = 600;
let mut json = false;
let mut args = std::env::args().skip(1);
while let Some(arg) = args.next() {
match arg.as_str() {
"--secret" => secret = Some(args.next().context("--secret requires a value")?),
"--user" => user = Some(args.next().context("--user requires a value")?),
"--ttl" => {
let v = args.next().context("--ttl requires a value")?;
ttl = v.parse::<u64>().context("--ttl must be an integer")?;
}
"--json" => json = true,
"-h" | "--help" => {
println!(
"turn_rest_cred\n\nUSAGE:\n turn_rest_cred --secret <secret> [--user <id>] [--ttl <seconds>] [--json]\n\nNOTES:\n Username format is <expiry>[:<user>]. Password is base64(HMAC-SHA1(secret, username)).\n"
);
return Ok(());
}
other => return Err(anyhow::anyhow!("unknown arg: {}", other)),
}
}
let secret = secret.context("missing --secret")?;
let expiry = now_unix_secs().saturating_add(ttl.max(60));
let username = match user {
Some(u) if !u.is_empty() => format!("{}:{}", expiry, u),
_ => expiry.to_string(),
};
let password = turn_rest_password_base64(secret.as_bytes(), &username);
if json {
println!(
"{{\n \"username\": \"{}\",\n \"credential\": \"{}\",\n \"ttl\": {}\n}}",
username, password, ttl
);
} else {
println!("TURN_USERNAME={}", username);
println!("TURN_PASSWORD={}", password);
println!("TURN_TTL={}", ttl);
}
Ok(())
}

View File

@ -11,6 +11,26 @@ fn default_nonce_ttl_seconds() -> u64 {
300 300
} }
fn default_rest_max_ttl_seconds() -> u64 {
600
}
fn default_tls_bind() -> String {
"0.0.0.0:5349".to_string()
}
fn default_enable_udp() -> bool {
true
}
fn default_enable_tcp() -> bool {
true
}
fn default_enable_tls() -> bool {
true
}
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Deserialize, Clone)]
pub struct CredentialEntry { pub struct CredentialEntry {
pub username: String, pub username: String,
@ -28,6 +48,16 @@ pub struct AuthOptions {
/// Validity period for generated nonces in seconds. /// Validity period for generated nonces in seconds.
#[serde(default = "default_nonce_ttl_seconds")] #[serde(default = "default_nonce_ttl_seconds")]
pub nonce_ttl_seconds: u64, pub nonce_ttl_seconds: u64,
/// Optional TURN REST shared secret. When set, the server can validate ephemeral credentials
/// derived from this secret (username contains expiry; password is base64(HMAC-SHA1(secret, username))).
#[serde(default)]
pub rest_secret: Option<String>,
/// Maximum acceptable TTL window (in seconds) for TURN REST usernames.
/// If the username expiry is too far in the future, credentials are rejected.
#[serde(default = "default_rest_max_ttl_seconds")]
pub rest_max_ttl_seconds: u64,
} }
impl Default for AuthOptions { impl Default for AuthOptions {
@ -36,19 +66,99 @@ impl Default for AuthOptions {
realm: default_realm(), realm: default_realm(),
nonce_secret: None, nonce_secret: None,
nonce_ttl_seconds: default_nonce_ttl_seconds(), nonce_ttl_seconds: default_nonce_ttl_seconds(),
rest_secret: None,
rest_max_ttl_seconds: default_rest_max_ttl_seconds(),
} }
} }
} }
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Deserialize, Clone)]
pub struct ServerOptions { pub struct ServerOptions {
/// Listen address, e.g. "0.0.0.0:3478" /// Listen address (legacy/default), e.g. "0.0.0.0:3478".
/// If `udp_bind` / `tcp_bind` are not set, this value is used.
pub bind: String, pub bind: String,
/// Optional per-protocol bind addresses.
#[serde(default)]
pub udp_bind: Option<String>,
#[serde(default)]
pub tcp_bind: Option<String>,
#[serde(default = "default_tls_bind")]
pub tls_bind: String,
/// Enable/disable protocol listeners.
#[serde(default = "default_enable_udp")]
pub enable_udp: bool,
#[serde(default = "default_enable_tcp")]
pub enable_tcp: bool,
#[serde(default = "default_enable_tls")]
pub enable_tls: bool,
/// Optional TLS: cert/key paths (not used in MVP) /// Optional TLS: cert/key paths (not used in MVP)
pub tls_cert: Option<String>, pub tls_cert: Option<String>,
pub tls_key: Option<String>, pub tls_key: Option<String>,
} }
#[derive(Debug, Deserialize, Clone, Default)]
pub struct RelayOptions {
/// Optional UDP relay port range. If set, allocations bind within this range.
/// If omitted, OS chooses an ephemeral port.
#[serde(default)]
pub relay_port_min: Option<u16>,
#[serde(default)]
pub relay_port_max: Option<u16>,
/// IP address used for binding relay sockets (default: 0.0.0.0).
/// Example: "0.0.0.0" or a specific local interface IP.
#[serde(default)]
pub relay_bind_ip: Option<String>,
/// Optional public/advertised IP that is placed into XOR-RELAYED-ADDRESS.
/// This is important when running behind NAT or when relay sockets bind to 0.0.0.0.
#[serde(default)]
pub advertised_ip: Option<String>,
}
#[derive(Debug, Deserialize, Clone, Default)]
pub struct LoggingOptions {
/// Default tracing directive (overridable via RUST_LOG).
/// Example: "warn,niom_turn=info".
#[serde(default)]
pub default_directive: Option<String>,
}
#[derive(Debug, Deserialize, Clone, Default)]
pub struct LimitsOptions {
/// Max concurrent allocations per source IP (across different source ports).
/// If omitted, unlimited.
#[serde(default)]
pub max_allocations_per_ip: Option<u32>,
/// Max permissions per allocation.
/// If omitted, unlimited.
#[serde(default)]
pub max_permissions_per_allocation: Option<u32>,
/// Max channel bindings per allocation.
/// If omitted, unlimited.
#[serde(default)]
pub max_channel_bindings_per_allocation: Option<u32>,
/// Rate-limit unauthenticated responses (401/438) per source IP.
/// If omitted, unlimited.
#[serde(default)]
pub unauth_rps: Option<u32>,
#[serde(default)]
pub unauth_burst: Option<u32>,
/// Rate-limit STUN Binding success responses per source IP.
/// If omitted, unlimited.
#[serde(default)]
pub binding_rps: Option<u32>,
#[serde(default)]
pub binding_burst: Option<u32>,
}
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Deserialize, Clone)]
pub struct Config { pub struct Config {
/// Server options /// Server options
@ -59,6 +169,18 @@ pub struct Config {
/// Authentication behaviour advertised to clients. /// Authentication behaviour advertised to clients.
#[serde(default)] #[serde(default)]
pub auth: AuthOptions, pub auth: AuthOptions,
/// Relay socket behaviour (port range, bind ip, advertised ip).
#[serde(default)]
pub relay: RelayOptions,
/// Logging/tracing defaults.
#[serde(default)]
pub logging: LoggingOptions,
/// Resource limits (basic abuse-prevention).
#[serde(default)]
pub limits: LimitsOptions,
} }
impl Config { impl Config {

View File

@ -26,13 +26,23 @@ pub const ATTR_LIFETIME: u16 = 0x000D;
pub const ATTR_REALM: u16 = 0x0014; pub const ATTR_REALM: u16 = 0x0014;
pub const ATTR_NONCE: u16 = 0x0015; pub const ATTR_NONCE: u16 = 0x0015;
pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012; pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012;
pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
// RFC5389: FINGERPRINT
pub const ATTR_FINGERPRINT: u16 = 0x8028;
// TURN attrs // TURN attrs
pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016; pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016;
pub const ATTR_DATA: u16 = 0x0013; pub const ATTR_DATA: u16 = 0x0013;
pub const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019;
// IP protocol numbers used by TURN REQUESTED-TRANSPORT
pub const IPPROTO_UDP: u8 = 17;
pub const IPPROTO_TCP: u8 = 6;
// Some helper values // Some helper values
pub const FAMILY_IPV4: u8 = 0x01; pub const FAMILY_IPV4: u8 = 0x01;
pub const FAMILY_IPV6: u8 = 0x02;
// Fingerprint XOR magic (XOR with CRC32 for FINGERPRINT attribute) // Fingerprint XOR magic (XOR with CRC32 for FINGERPRINT attribute)
pub const FINGERPRINT_XOR: u32 = 0x5354554e; pub const FINGERPRINT_XOR: u32 = 0x5354554e;

View File

@ -4,10 +4,14 @@ pub mod auth;
pub mod config; pub mod config;
pub mod constants; pub mod constants;
pub mod logging; pub mod logging;
pub mod metrics;
pub mod models; pub mod models;
pub mod rate_limit;
pub mod server; pub mod server;
pub mod stun; pub mod stun;
pub mod tcp;
pub mod tls; pub mod tls;
pub mod turn_stream;
pub mod traits; pub mod traits;
pub use crate::alloc::*; pub use crate::alloc::*;

View File

@ -1,30 +1,28 @@
//! Binary entry point that wires configuration, UDP listener, optional TLS listener, and allocation handling. //! Binary entry point that wires configuration, UDP listener, optional TLS listener, and allocation handling.
//! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners. //! Backlog: graceful shutdown signals, structured metrics, and coordinated lifecycle management across listeners.
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tracing::{error, info}; use tracing::{error, info};
// Use the library crate's public modules instead of local `mod` declarations. // Use the library crate's public modules instead of local `mod` declarations.
use niom_turn::alloc::AllocationManager; use niom_turn::alloc::AllocationManager;
use niom_turn::alloc::AllocationOptions;
use niom_turn::auth::{AuthManager, InMemoryStore}; use niom_turn::auth::{AuthManager, InMemoryStore};
use niom_turn::config::{AuthOptions, Config}; use niom_turn::config::{AuthOptions, Config};
use niom_turn::server::udp_reader_loop; use niom_turn::rate_limit::RateLimiters;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
niom_turn::logging::init_tracing();
info!("niom-turn starting");
// Bootstrap configuration: prefer appsettings.json, otherwise rely on baked-in demo defaults. // Bootstrap configuration: prefer appsettings.json, otherwise rely on baked-in demo defaults.
let cfg = match Config::load_default() { let cfg = match Config::load_default() {
Ok(c) => { Ok(c) => {
info!("loaded config from appsettings.json");
c c
} }
Err(e) => { Err(e) => {
info!( eprintln!(
"no appsettings.json found or failed to load: {} — using defaults", "no appsettings.json found or failed to load: {} — using defaults",
e e
); );
@ -32,6 +30,12 @@ async fn main() -> anyhow::Result<()> {
Config { Config {
server: niom_turn::config::ServerOptions { server: niom_turn::config::ServerOptions {
bind: "0.0.0.0:3478".to_string(), bind: "0.0.0.0:3478".to_string(),
udp_bind: None,
tcp_bind: None,
tls_bind: "0.0.0.0:5349".to_string(),
enable_udp: true,
enable_tcp: true,
enable_tls: true,
tls_cert: None, tls_cert: None,
tls_key: None, tls_key: None,
}, },
@ -40,11 +44,39 @@ async fn main() -> anyhow::Result<()> {
password: "secretpassword".into(), password: "secretpassword".into(),
}], }],
auth: AuthOptions::default(), auth: AuthOptions::default(),
relay: niom_turn::config::RelayOptions::default(),
logging: niom_turn::config::LoggingOptions::default(),
limits: niom_turn::config::LimitsOptions::default(),
} }
} }
}; };
let bind_addr: SocketAddr = cfg.server.bind.parse()?; let log_directive = cfg
.logging
.default_directive
.as_deref()
.unwrap_or("warn,niom_turn=info");
niom_turn::logging::init_tracing_with_default(log_directive);
// Build per-server rate limiters (defaults to disabled when unset).
let rate_limiters = Arc::new(RateLimiters::from_limits(&cfg.limits));
info!("niom-turn starting");
info!("logging.default_directive={}", log_directive);
let udp_bind = cfg
.server
.udp_bind
.clone()
.unwrap_or_else(|| cfg.server.bind.clone());
let tcp_bind = cfg
.server
.tcp_bind
.clone()
.unwrap_or_else(|| cfg.server.bind.clone());
let tls_bind = cfg.server.tls_bind.clone();
let udp_bind_addr: SocketAddr = udp_bind.parse()?;
// Materialise the credential backend before starting network endpoints. // Materialise the credential backend before starting network endpoints.
let creds = InMemoryStore::new(); let creds = InMemoryStore::new();
@ -54,44 +86,148 @@ async fn main() -> anyhow::Result<()> {
let auth = AuthManager::new(creds.clone(), &cfg.auth); let auth = AuthManager::new(creds.clone(), &cfg.auth);
// Bind the UDP socket that receives STUN/TURN traffic from WebRTC clients. let relay_bind_ip: IpAddr = cfg
let udp = UdpSocket::bind(bind_addr).await?; .relay
let udp = Arc::new(udp); .relay_bind_ip
.as_deref()
.unwrap_or("0.0.0.0")
.parse()
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
// Allocation manager shared by UDP + TLS frontends. let advertised_ip: Option<IpAddr> = cfg
let alloc_mgr = AllocationManager::new(); .relay
.advertised_ip
.as_deref()
.and_then(|s| s.parse().ok());
// Spawn the asynchronous packet loop that handles all UDP requests. let alloc_mgr = AllocationManager::new_with_options(AllocationOptions {
let udp_clone = udp.clone(); relay_bind_ip,
let auth_clone = auth.clone(); relay_port_min: cfg.relay.relay_port_min,
let alloc_clone = alloc_mgr.clone(); relay_port_max: cfg.relay.relay_port_max,
advertised_ip,
max_allocations_per_ip: cfg.limits.max_allocations_per_ip,
max_permissions_per_allocation: cfg.limits.max_permissions_per_allocation,
max_channel_bindings_per_allocation: cfg.limits.max_channel_bindings_per_allocation,
});
// Periodically prune expired allocations so relay tasks can terminate even when idle.
alloc_mgr.spawn_housekeeping(Duration::from_secs(5));
// Periodically emit a compact metrics snapshot to logs.
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = udp_reader_loop(udp_clone, auth_clone, alloc_clone).await { loop {
error!("udp loop error: {:?}", e); tokio::time::sleep(Duration::from_secs(30)).await;
let snap = niom_turn::metrics::snapshot();
info!(
"metrics: stun={} channeldata={} streams={} auth_challenge={} auth_stale={} auth_reject={} alloc_total={} alloc_ok={} alloc_fail={} perms_added={} channels_added={} alloc_active={} rate_limited={}",
snap.stun_messages_total,
snap.channel_data_total,
snap.stream_connections_total,
snap.auth_challenge_total,
snap.auth_stale_total,
snap.auth_reject_total,
snap.allocate_total,
snap.allocate_success_total,
snap.allocate_fail_total,
snap.permissions_added_total,
snap.channel_bindings_added_total,
snap.allocations_active,
snap.rate_limited_total
);
} }
}); });
// Optionally start the TLS listener so `turns:` clients can connect via TCP/TLS. info!(
if let (Some(cert), Some(key)) = (cfg.server.tls_cert.clone(), cfg.server.tls_key.clone()) { "listeners: udp={} tcp={} tls={} udp_bind={} tcp_bind={} tls_bind={}",
let udp_for_tls = udp.clone(); cfg.server.enable_udp,
let auth_for_tls = auth.clone(); cfg.server.enable_tcp,
let alloc_for_tls = alloc_mgr.clone(); cfg.server.enable_tls,
udp_bind,
tcp_bind,
tls_bind
);
info!(
"relay: bind_ip={} port_range={:?}-{:?} advertised_ip={:?}",
relay_bind_ip,
cfg.relay.relay_port_min,
cfg.relay.relay_port_max,
advertised_ip
);
// Bind the UDP socket that receives STUN/TURN traffic from WebRTC clients.
let udp = if cfg.server.enable_udp {
let udp = UdpSocket::bind(udp_bind_addr).await?;
Some(Arc::new(udp))
} else {
None
};
// Spawn the asynchronous packet loop that handles all UDP requests.
if let Some(udp_sock) = udp.clone() {
let udp_clone = udp_sock.clone();
let auth_clone = auth.clone();
let alloc_clone = alloc_mgr.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = niom_turn::tls::serve_tls( if let Err(e) = niom_turn::server::udp_reader_loop_with_limits(
"0.0.0.0:5349", udp_clone,
&cert, auth_clone,
&key, alloc_clone,
udp_for_tls, rl,
auth_for_tls,
alloc_for_tls,
) )
.await .await
{ {
error!("tls serve failed: {:?}", e); error!("udp loop error: {:?}", e);
} }
}); });
} }
// Start a plain TCP listener for `turn:` clients that require TCP.
if cfg.server.enable_tcp {
let auth_for_tcp = auth.clone();
let alloc_for_tcp = alloc_mgr.clone();
let tcp_bind = tcp_bind.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move {
if let Err(e) = niom_turn::tcp::serve_tcp_with_limits(
&tcp_bind,
auth_for_tcp,
alloc_for_tcp,
rl,
)
.await
{
error!("tcp serve failed: {:?}", e);
}
});
}
// Optionally start the TLS listener so `turns:` clients can connect via TCP/TLS.
if cfg.server.enable_tls {
if let (Some(cert), Some(key)) = (cfg.server.tls_cert.clone(), cfg.server.tls_key.clone()) {
let auth_for_tls = auth.clone();
let alloc_for_tls = alloc_mgr.clone();
let tls_bind = tls_bind.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move {
if let Err(e) = niom_turn::tls::serve_tls_with_limits(
&tls_bind,
&cert,
&key,
auth_for_tls,
alloc_for_tls,
rl,
)
.await
{
error!("tls serve failed: {:?}", e);
}
});
} else {
info!("TLS enabled but tls_cert/tls_key not configured; skipping TLS listener");
}
}
// Keep the runtime alive while background tasks process packets. // Keep the runtime alive while background tasks process packets.
loop { loop {
tokio::time::sleep(std::time::Duration::from_secs(60)).await; tokio::time::sleep(std::time::Duration::from_secs(60)).await;

155
src/metrics.rs Normal file
View File

@ -0,0 +1,155 @@
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
#[derive(Debug)]
pub struct Metrics {
pub stun_messages_total: AtomicU64,
pub channel_data_total: AtomicU64,
pub stream_connections_total: AtomicU64,
pub auth_challenge_total: AtomicU64,
pub auth_stale_total: AtomicU64,
pub auth_reject_total: AtomicU64,
pub allocate_total: AtomicU64,
pub allocate_success_total: AtomicU64,
pub allocate_fail_total: AtomicU64,
pub permissions_added_total: AtomicU64,
pub channel_bindings_added_total: AtomicU64,
pub allocations_active: AtomicU64,
pub rate_limited_total: AtomicU64,
}
impl Default for Metrics {
fn default() -> Self {
Self {
stun_messages_total: AtomicU64::new(0),
channel_data_total: AtomicU64::new(0),
stream_connections_total: AtomicU64::new(0),
auth_challenge_total: AtomicU64::new(0),
auth_stale_total: AtomicU64::new(0),
auth_reject_total: AtomicU64::new(0),
allocate_total: AtomicU64::new(0),
allocate_success_total: AtomicU64::new(0),
allocate_fail_total: AtomicU64::new(0),
permissions_added_total: AtomicU64::new(0),
channel_bindings_added_total: AtomicU64::new(0),
allocations_active: AtomicU64::new(0),
rate_limited_total: AtomicU64::new(0),
}
}
}
static METRICS: OnceLock<Metrics> = OnceLock::new();
pub fn metrics() -> &'static Metrics {
METRICS.get_or_init(Metrics::default)
}
pub fn inc_stun_messages() {
metrics().stun_messages_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_channel_data() {
metrics().channel_data_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_stream_connections() {
metrics().stream_connections_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_auth_challenge() {
metrics().auth_challenge_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_auth_stale() {
metrics().auth_stale_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_auth_reject() {
metrics().auth_reject_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_allocate_total() {
metrics().allocate_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_allocate_success() {
metrics().allocate_success_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_allocate_fail() {
metrics().allocate_fail_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_permission_added() {
metrics().permissions_added_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_channel_binding_added() {
metrics().channel_bindings_added_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_rate_limited() {
metrics().rate_limited_total.fetch_add(1, Ordering::Relaxed);
}
pub fn inc_allocations_active() {
metrics().allocations_active.fetch_add(1, Ordering::Relaxed);
}
pub fn dec_allocations_active() {
// Saturating decrement to avoid underflow.
let m = metrics();
let mut current = m.allocations_active.load(Ordering::Relaxed);
while current > 0 {
match m.allocations_active.compare_exchange_weak(
current,
current - 1,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => return,
Err(v) => current = v,
}
}
}
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
pub stun_messages_total: u64,
pub channel_data_total: u64,
pub stream_connections_total: u64,
pub auth_challenge_total: u64,
pub auth_stale_total: u64,
pub auth_reject_total: u64,
pub allocate_total: u64,
pub allocate_success_total: u64,
pub allocate_fail_total: u64,
pub permissions_added_total: u64,
pub channel_bindings_added_total: u64,
pub allocations_active: u64,
pub rate_limited_total: u64,
}
pub fn snapshot() -> MetricsSnapshot {
let m = metrics();
MetricsSnapshot {
stun_messages_total: m.stun_messages_total.load(Ordering::Relaxed),
channel_data_total: m.channel_data_total.load(Ordering::Relaxed),
stream_connections_total: m.stream_connections_total.load(Ordering::Relaxed),
auth_challenge_total: m.auth_challenge_total.load(Ordering::Relaxed),
auth_stale_total: m.auth_stale_total.load(Ordering::Relaxed),
auth_reject_total: m.auth_reject_total.load(Ordering::Relaxed),
allocate_total: m.allocate_total.load(Ordering::Relaxed),
allocate_success_total: m.allocate_success_total.load(Ordering::Relaxed),
allocate_fail_total: m.allocate_fail_total.load(Ordering::Relaxed),
permissions_added_total: m.permissions_added_total.load(Ordering::Relaxed),
channel_bindings_added_total: m.channel_bindings_added_total.load(Ordering::Relaxed),
allocations_active: m.allocations_active.load(Ordering::Relaxed),
rate_limited_total: m.rate_limited_total.load(Ordering::Relaxed),
}
}

145
src/rate_limit.rs Normal file
View File

@ -0,0 +1,145 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Mutex;
use std::time::Instant;
use crate::config::LimitsOptions;
#[derive(Debug, Clone, Copy)]
struct TokenBucketConfig {
rate_per_sec: f64,
burst: f64,
}
#[derive(Debug)]
struct Bucket {
tokens: f64,
last: Instant,
}
#[derive(Debug)]
pub struct TokenBucketRateLimiter {
cfg: TokenBucketConfig,
state: Mutex<HashMap<IpAddr, Bucket>>,
}
impl TokenBucketRateLimiter {
pub fn new(rate_per_sec: u32, burst: u32) -> Self {
Self {
cfg: TokenBucketConfig {
rate_per_sec: rate_per_sec.max(1) as f64,
burst: burst.max(1) as f64,
},
state: Mutex::new(HashMap::new()),
}
}
fn allow_at(&self, ip: IpAddr, now: Instant) -> bool {
let mut guard = self.state.lock().unwrap();
let entry = guard.entry(ip).or_insert(Bucket {
tokens: self.cfg.burst,
last: now,
});
let elapsed = now.saturating_duration_since(entry.last);
let refill = elapsed.as_secs_f64() * self.cfg.rate_per_sec;
entry.tokens = (entry.tokens + refill).min(self.cfg.burst);
entry.last = now;
if entry.tokens >= 1.0 {
entry.tokens -= 1.0;
true
} else {
false
}
}
pub fn allow(&self, ip: IpAddr) -> bool {
self.allow_at(ip, Instant::now())
}
#[cfg(test)]
fn allow_at_for_test(&self, ip: IpAddr, now: Instant) -> bool {
self.allow_at(ip, now)
}
#[cfg(test)]
fn configured_burst(&self) -> u32 {
self.cfg.burst as u32
}
}
#[derive(Debug, Default)]
pub struct RateLimiters {
unauth: Option<TokenBucketRateLimiter>,
binding: Option<TokenBucketRateLimiter>,
}
impl RateLimiters {
pub fn disabled() -> Self {
Self {
unauth: None,
binding: None,
}
}
/// Build per-IP rate limiters from runtime configuration.
///
/// Any pair where either value is missing/zero disables that limiter.
pub fn from_limits(limits: &LimitsOptions) -> Self {
let unauth = match (limits.unauth_rps, limits.unauth_burst) {
(Some(rps), Some(burst)) if rps > 0 && burst > 0 => {
Some(TokenBucketRateLimiter::new(rps, burst))
}
_ => None,
};
let binding = match (limits.binding_rps, limits.binding_burst) {
(Some(rps), Some(burst)) if rps > 0 && burst > 0 => {
Some(TokenBucketRateLimiter::new(rps, burst))
}
_ => None,
};
Self { unauth, binding }
}
pub fn allow_unauth(&self, ip: IpAddr) -> bool {
match &self.unauth {
Some(l) => l.allow(ip),
None => true,
}
}
pub fn allow_binding(&self, ip: IpAddr) -> bool {
match &self.binding {
Some(l) => l.allow(ip),
None => true,
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
#[test]
fn token_bucket_enforces_burst_then_refills() {
let lim = TokenBucketRateLimiter::new(10, 3);
assert_eq!(lim.configured_burst(), 3);
let ip: IpAddr = "127.0.0.1".parse().unwrap();
let t0 = Instant::now();
assert!(lim.allow_at_for_test(ip, t0));
assert!(lim.allow_at_for_test(ip, t0));
assert!(lim.allow_at_for_test(ip, t0));
assert!(!lim.allow_at_for_test(ip, t0));
let t1 = t0 + Duration::from_millis(200); // 10 rps => ~2 tokens
assert!(lim.allow_at_for_test(ip, t1));
assert!(lim.allow_at_for_test(ip, t1));
}
}

View File

@ -3,13 +3,17 @@ use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tracing::{error, warn}; use tracing::{error, warn};
use crate::alloc::AllocationManager; use crate::alloc::{AllocationManager, ClientSink};
use crate::alloc::AllocationError;
use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; use crate::auth::{AuthManager, AuthStatus, InMemoryStore};
use crate::constants::*; use crate::constants::*;
use crate::rate_limit::RateLimiters;
use crate::stun::{ use crate::stun::{
build_401_response, build_allocate_success, build_error_response, build_lifetime_success, build_401_response, build_allocate_success_with_integrity, build_error_response,
build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_channel_data, build_error_response_with_integrity, build_lifetime_success_with_integrity,
parse_message, build_success_response_with_integrity, decode_xor_peer_address, extract_lifetime_seconds,
parse_channel_data, extract_requested_transport_protocol, parse_message,
validate_fingerprint_if_present,
}; };
use std::time::Duration; use std::time::Duration;
@ -18,6 +22,16 @@ pub async fn udp_reader_loop(
udp: Arc<UdpSocket>, udp: Arc<UdpSocket>,
auth: AuthManager<InMemoryStore>, auth: AuthManager<InMemoryStore>,
allocs: AllocationManager, allocs: AllocationManager,
) -> anyhow::Result<()> {
udp_reader_loop_with_limits(udp, auth, allocs, Arc::new(RateLimiters::disabled())).await
}
/// UDP reader loop with explicit (per-server) rate limiters.
pub async fn udp_reader_loop_with_limits(
udp: Arc<UdpSocket>,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
rate_limiters: Arc<RateLimiters>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut buf = vec![0u8; 1500]; let mut buf = vec![0u8; 1500];
loop { loop {
@ -25,6 +39,7 @@ pub async fn udp_reader_loop(
tracing::debug!("got {} bytes from {}", len, peer); tracing::debug!("got {} bytes from {}", len, peer);
if let Some((channel, payload)) = parse_channel_data(&buf[..len]) { if let Some((channel, payload)) = parse_channel_data(&buf[..len]) {
crate::metrics::inc_channel_data();
let allocation = match allocs.get_allocation(&peer) { let allocation = match allocs.get_allocation(&peer) {
Some(a) => a, Some(a) => a,
None => { None => {
@ -69,6 +84,11 @@ pub async fn udp_reader_loop(
} }
if let Ok(msg) = parse_message(&buf[..len]) { if let Ok(msg) = parse_message(&buf[..len]) {
if !validate_fingerprint_if_present(&msg) {
tracing::debug!("dropping STUN/TURN message from {} due to invalid FINGERPRINT", peer);
continue;
}
crate::metrics::inc_stun_messages();
tracing::info!( tracing::info!(
"STUN/TURN message from {} type=0x{:04x} len={}", "STUN/TURN message from {} type=0x{:04x} len={}",
peer, peer,
@ -85,16 +105,22 @@ pub async fn udp_reader_loop(
); );
if requires_auth { if requires_auth {
match auth.authenticate(&msg, &peer).await { let key = match auth.authenticate(&msg, &peer).await {
AuthStatus::Granted { username } => { AuthStatus::Granted { username, key } => {
tracing::debug!( tracing::debug!(
"TURN auth ok for {} as {} (0x{:04x})", "TURN auth ok for {} as {} (0x{:04x})",
peer, peer,
username, username,
msg.header.msg_type msg.header.msg_type
); );
key
} }
AuthStatus::Challenge { nonce } => { AuthStatus::Challenge { nonce } => {
crate::metrics::inc_auth_challenge();
if !rate_limiters.allow_unauth(peer.ip()) {
crate::metrics::inc_rate_limited();
continue;
}
let resp = build_401_response( let resp = build_401_response(
&msg.header, &msg.header,
auth.realm(), auth.realm(),
@ -106,6 +132,11 @@ pub async fn udp_reader_loop(
continue; continue;
} }
AuthStatus::StaleNonce { nonce } => { AuthStatus::StaleNonce { nonce } => {
crate::metrics::inc_auth_stale();
if !rate_limiters.allow_unauth(peer.ip()) {
crate::metrics::inc_rate_limited();
continue;
}
let resp = build_401_response( let resp = build_401_response(
&msg.header, &msg.header,
auth.realm(), auth.realm(),
@ -117,19 +148,55 @@ pub async fn udp_reader_loop(
continue; continue;
} }
AuthStatus::Reject { code, reason } => { AuthStatus::Reject { code, reason } => {
crate::metrics::inc_auth_reject();
let resp = build_error_response(&msg.header, code, reason); let resp = build_error_response(&msg.header, code, reason);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
} };
match msg.header.msg_type { match msg.header.msg_type {
METHOD_ALLOCATE => { METHOD_ALLOCATE => {
crate::metrics::inc_allocate_total();
// TURN Allocate MUST include REQUESTED-TRANSPORT; WebRTC expects UDP (17).
match extract_requested_transport_protocol(&msg) {
Some(IPPROTO_UDP) => {}
Some(_) => {
crate::metrics::inc_allocate_fail();
let resp = build_error_response_with_integrity(
&msg.header,
442,
"Unsupported Transport",
&key,
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
None => {
crate::metrics::inc_allocate_fail();
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Missing REQUESTED-TRANSPORT",
&key,
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
}
let requested_lifetime = extract_lifetime_seconds(&msg) let requested_lifetime = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64)) .map(|secs| Duration::from_secs(secs as u64))
.filter(|d| !d.is_zero()); .filter(|d| !d.is_zero());
match allocs.allocate_for(peer, udp.clone()).await { match allocs
.allocate_for(peer, ClientSink::Udp {
sock: udp.clone(),
addr: peer,
})
.await
{
Ok(relay_addr) => { Ok(relay_addr) => {
let applied = let applied =
match allocs.refresh_allocation(peer, requested_lifetime) { match allocs.refresh_allocation(peer, requested_lifetime) {
@ -140,10 +207,11 @@ pub async fn udp_reader_loop(
peer, peer,
e e
); );
let resp = build_error_response( let resp = build_error_response_with_integrity(
&msg.header, &msg.header,
500, 500,
"Allocate Failed", "Allocate Failed",
&key,
); );
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
@ -151,20 +219,32 @@ pub async fn udp_reader_loop(
}; };
let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32; let lifetime_secs = applied.as_secs().min(u32::MAX as u64) as u32;
let resp = let advertised = allocs.relay_addr_for_response(relay_addr);
build_allocate_success(&msg.header, &relay_addr, lifetime_secs); let resp = build_allocate_success_with_integrity(
&msg.header,
&advertised,
lifetime_secs,
&key,
);
tracing::info!( tracing::info!(
"allocated relay {} for {} lifetime={}s", "allocated relay {} for {} lifetime={}s",
relay_addr, relay_addr,
peer, peer,
lifetime_secs lifetime_secs
); );
crate::metrics::inc_allocate_success();
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
Err(e) => { Err(e) => {
tracing::error!("allocate failed: {:?}", e); tracing::error!("allocate failed: {:?}", e);
let resp = let (code, reason) = match e.downcast_ref::<AllocationError>() {
build_error_response(&msg.header, 500, "Allocate Failed"); Some(AllocationError::AllocationQuotaExceeded) => {
(486, "Allocation Quota Reached")
}
_ => (500, "Allocate Failed"),
};
crate::metrics::inc_allocate_fail();
let resp = build_error_response_with_integrity(&msg.header, code, reason, &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
} }
@ -174,7 +254,7 @@ pub async fn udp_reader_loop(
if allocs.get_allocation(&peer).is_none() { if allocs.get_allocation(&peer).is_none() {
warn!("create-permission without allocation from {}", peer); warn!("create-permission without allocation from {}", peer);
let resp = let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch"); build_error_response_with_integrity(&msg.header, 437, "Allocation Mismatch", &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -195,6 +275,7 @@ pub async fn udp_reader_loop(
peer, peer,
peer_addr peer_addr
); );
crate::metrics::inc_permission_added();
added += 1; added += 1;
} }
Err(e) => { Err(e) => {
@ -204,6 +285,19 @@ pub async fn udp_reader_loop(
peer_addr, peer_addr,
e e
); );
if matches!(
e.downcast_ref::<AllocationError>(),
Some(AllocationError::PermissionQuotaExceeded)
) {
let resp = build_error_response_with_integrity(
&msg.header,
508,
"Insufficient Capacity",
&key,
);
let _ = udp.send_to(&resp, &peer).await;
continue;
}
} }
} }
} else { } else {
@ -213,10 +307,10 @@ pub async fn udp_reader_loop(
if added == 0 { if added == 0 {
let resp = let resp =
build_error_response(&msg.header, 400, "No valid XOR-PEER-ADDRESS"); build_error_response_with_integrity(&msg.header, 400, "No valid XOR-PEER-ADDRESS", &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} else { } else {
let resp = build_success_response(&msg.header); let resp = build_success_response_with_integrity(&msg.header, &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
continue; continue;
@ -227,7 +321,7 @@ pub async fn udp_reader_loop(
None => { None => {
warn!("channel-bind without allocation from {}", peer); warn!("channel-bind without allocation from {}", peer);
let resp = let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch"); build_error_response_with_integrity(&msg.header, 437, "Allocation Mismatch", &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -242,10 +336,11 @@ pub async fn udp_reader_loop(
let (channel_attr, peer_attr) = match (channel_attr, peer_attr) { let (channel_attr, peer_attr) = match (channel_attr, peer_attr) {
(Some(c), Some(p)) => (c, p), (Some(c), Some(p)) => (c, p),
_ => { _ => {
let resp = build_error_response( let resp = build_error_response_with_integrity(
&msg.header, &msg.header,
400, 400,
"Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS", "Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS",
&key,
); );
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
@ -260,10 +355,11 @@ pub async fn udp_reader_loop(
) { ) {
Some(addr) => addr, Some(addr) => addr,
None => { None => {
let resp = build_error_response( let resp = build_error_response_with_integrity(
&msg.header, &msg.header,
400, 400,
"Invalid XOR-PEER-ADDRESS", "Invalid XOR-PEER-ADDRESS",
&key,
); );
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
@ -271,7 +367,12 @@ pub async fn udp_reader_loop(
}; };
if !allocation.is_peer_allowed(&peer_addr) { if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); let resp = build_error_response_with_integrity(
&msg.header,
403,
"Peer Not Permitted",
&key,
);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -284,13 +385,19 @@ pub async fn udp_reader_loop(
channel, channel,
e e
); );
let resp = let (code, reason) = match e.downcast_ref::<AllocationError>() {
build_error_response(&msg.header, 500, "Channel Bind Failed"); Some(AllocationError::ChannelQuotaExceeded) => {
(508, "Insufficient Capacity")
}
_ => (500, "Channel Bind Failed"),
};
let resp = build_error_response_with_integrity(&msg.header, code, reason, &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
let resp = build_success_response(&msg.header); crate::metrics::inc_channel_binding_added();
let resp = build_success_response_with_integrity(&msg.header, &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -300,7 +407,7 @@ pub async fn udp_reader_loop(
None => { None => {
warn!("send indication without allocation from {}", peer); warn!("send indication without allocation from {}", peer);
let resp = let resp =
build_error_response(&msg.header, 437, "Allocation Mismatch"); build_error_response_with_integrity(&msg.header, 437, "Allocation Mismatch", &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -314,10 +421,11 @@ pub async fn udp_reader_loop(
let (peer_attr, data_attr) = match (peer_attr, data_attr) { let (peer_attr, data_attr) = match (peer_attr, data_attr) {
(Some(p), Some(d)) => (p, d), (Some(p), Some(d)) => (p, d),
_ => { _ => {
let resp = build_error_response( let resp = build_error_response_with_integrity(
&msg.header, &msg.header,
400, 400,
"Missing DATA or XOR-PEER-ADDRESS", "Missing DATA or XOR-PEER-ADDRESS",
&key,
); );
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
@ -330,10 +438,11 @@ pub async fn udp_reader_loop(
) { ) {
Some(addr) => addr, Some(addr) => addr,
None => { None => {
let resp = build_error_response( let resp = build_error_response_with_integrity(
&msg.header, &msg.header,
400, 400,
"Invalid XOR-PEER-ADDRESS", "Invalid XOR-PEER-ADDRESS",
&key,
); );
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
@ -341,7 +450,12 @@ pub async fn udp_reader_loop(
}; };
if !allocation.is_peer_allowed(&peer_addr) { if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(&msg.header, 403, "Peer Not Permitted"); let resp = build_error_response_with_integrity(
&msg.header,
403,
"Peer Not Permitted",
&key,
);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -354,7 +468,7 @@ pub async fn udp_reader_loop(
peer, peer,
peer_addr peer_addr
); );
let resp = build_success_response(&msg.header); let resp = build_success_response_with_integrity(&msg.header, &key);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
Err(e) => { Err(e) => {
@ -364,8 +478,12 @@ pub async fn udp_reader_loop(
peer_addr, peer_addr,
e e
); );
let resp = let resp = build_error_response_with_integrity(
build_error_response(&msg.header, 500, "Peer Send Failed"); &msg.header,
500,
"Peer Send Failed",
&key,
);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
} }
@ -386,22 +504,32 @@ pub async fn udp_reader_loop(
applied.as_secs() applied.as_secs()
); );
} }
let resp = build_lifetime_success( let resp = build_lifetime_success_with_integrity(
&msg.header, &msg.header,
applied.as_secs().min(u32::MAX as u64) as u32, applied.as_secs().min(u32::MAX as u64) as u32,
&key,
); );
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
Err(_) => { Err(_) => {
let resp = let resp = build_error_response_with_integrity(
build_error_response(&msg.header, 437, "Allocation Mismatch"); &msg.header,
437,
"Allocation Mismatch",
&key,
);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
} }
} }
continue; continue;
} }
_ => { _ => {
let resp = build_error_response(&msg.header, 420, "Unknown TURN Method"); let resp = build_error_response_with_integrity(
&msg.header,
420,
"Unknown TURN Method",
&key,
);
let _ = udp.send_to(&resp, &peer).await; let _ = udp.send_to(&resp, &peer).await;
continue; continue;
} }
@ -410,10 +538,18 @@ pub async fn udp_reader_loop(
match msg.header.msg_type { match msg.header.msg_type {
METHOD_BINDING => { METHOD_BINDING => {
let resp = build_success_response(&msg.header); if rate_limiters.allow_binding(peer.ip()) {
let _ = udp.send_to(&resp, &peer).await; let resp = crate::stun::build_binding_success(&msg.header, &peer);
let _ = udp.send_to(&resp, &peer).await;
} else {
crate::metrics::inc_rate_limited();
}
} }
_ => { _ => {
if !rate_limiters.allow_unauth(peer.ip()) {
crate::metrics::inc_rate_limited();
continue;
}
let nonce = auth.mint_nonce(&peer); let nonce = auth.mint_nonce(&peer);
let resp = let resp =
build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized"); build_401_response(&msg.header, auth.realm(), &nonce, 401, "Unauthorized");

View File

@ -117,10 +117,7 @@ pub fn build_401_response(
} }
// Update length // Update length
let total_len = (buf.len() - 20) as u16; append_fingerprint(&mut buf);
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
buf.to_vec() buf.to_vec()
} }
@ -150,14 +147,46 @@ pub fn build_error_response(req: &StunHeader, code: u16, reason: &str) -> Vec<u8
buf.extend_from_slice(&[0]); buf.extend_from_slice(&[0]);
} }
let total_len = (buf.len() - 20) as u16; append_fingerprint(&mut buf);
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
buf.to_vec() buf.to_vec()
} }
/// Build a generic STUN/TURN error response including MESSAGE-INTEGRITY and FINGERPRINT.
pub fn build_error_response_with_integrity(
req: &StunHeader,
code: u16,
reason: &str,
key: &[u8],
) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_ERROR;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
let mut value = Vec::new();
let class = (code / 100) as u8;
let number = (code % 100) as u8;
value.extend_from_slice(&[0, 0]);
value.push(class);
value.push(number);
value.extend_from_slice(reason.as_bytes());
buf.extend_from_slice(&ATTR_ERROR_CODE.to_be_bytes());
buf.extend_from_slice(&(value.len() as u16).to_be_bytes());
buf.extend_from_slice(&value);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
append_message_integrity(&mut buf, key);
append_fingerprint(&mut buf);
buf.to_vec()
}
/// Build an Allocate success response containing XOR-RELAYED-ADDRESS and LIFETIME attributes. /// Build an Allocate success response containing XOR-RELAYED-ADDRESS and LIFETIME attributes.
pub fn build_allocate_success( pub fn build_allocate_success(
req: &StunHeader, req: &StunHeader,
@ -188,10 +217,43 @@ pub fn build_allocate_success(
buf.extend_from_slice(&[0]); buf.extend_from_slice(&[0]);
} }
let total_len = (buf.len() - 20) as u16; append_fingerprint(&mut buf);
let len_bytes = total_len.to_be_bytes(); buf.to_vec()
buf[2] = len_bytes[0]; }
buf[3] = len_bytes[1];
/// Build an Allocate success response including MESSAGE-INTEGRITY and FINGERPRINT.
pub fn build_allocate_success_with_integrity(
req: &StunHeader,
relay: &std::net::SocketAddr,
lifetime_secs: u32,
key: &[u8],
) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_SUCCESS;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
let relay_val = encode_xor_relayed_address(relay, &req.transaction_id);
buf.extend_from_slice(&ATTR_XOR_RELAYED_ADDRESS.to_be_bytes());
buf.extend_from_slice(&((relay_val.len() as u16).to_be_bytes()));
buf.extend_from_slice(&relay_val);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
let lifetime_bytes = lifetime_secs.to_be_bytes();
buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes());
buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(&lifetime_bytes);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
append_message_integrity(&mut buf, key);
append_fingerprint(&mut buf);
buf.to_vec() buf.to_vec()
} }
@ -213,10 +275,34 @@ pub fn build_lifetime_success(req: &StunHeader, lifetime_secs: u32) -> Vec<u8> {
buf.extend_from_slice(&[0]); buf.extend_from_slice(&[0]);
} }
let total_len = (buf.len() - 20) as u16; append_fingerprint(&mut buf);
let len_bytes = total_len.to_be_bytes(); buf.to_vec()
buf[2] = len_bytes[0]; }
buf[3] = len_bytes[1];
/// Build a Refresh success response including MESSAGE-INTEGRITY and FINGERPRINT.
pub fn build_lifetime_success_with_integrity(
req: &StunHeader,
lifetime_secs: u32,
key: &[u8],
) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_SUCCESS;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
let lifetime_bytes = lifetime_secs.to_be_bytes();
buf.extend_from_slice(&ATTR_LIFETIME.to_be_bytes());
buf.extend_from_slice(&(lifetime_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(&lifetime_bytes);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
append_message_integrity(&mut buf, key);
append_fingerprint(&mut buf);
buf.to_vec() buf.to_vec()
} }
@ -239,6 +325,22 @@ pub fn extract_lifetime_seconds(msg: &StunMessage) -> Option<u32> {
}) })
} }
/// Extract REQUESTED-TRANSPORT IP protocol number from a TURN Allocate request.
///
/// RFC5766: value is 4 bytes: (protocol, rsvd, rsvd, rsvd)
pub fn extract_requested_transport_protocol(msg: &StunMessage) -> Option<u8> {
msg.attributes
.iter()
.find(|a| a.typ == ATTR_REQUESTED_TRANSPORT)
.and_then(|attr| {
if attr.value.len() >= 4 {
Some(attr.value[0])
} else {
None
}
})
}
/// Find MESSAGE-INTEGRITY attribute (ATTR_MESSAGE_INTEGRITY) if present /// Find MESSAGE-INTEGRITY attribute (ATTR_MESSAGE_INTEGRITY) if present
pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> { pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> {
msg.attributes msg.attributes
@ -246,25 +348,67 @@ pub fn find_message_integrity(msg: &StunMessage) -> Option<&StunAttribute> {
.find(|a| a.typ == ATTR_MESSAGE_INTEGRITY) .find(|a| a.typ == ATTR_MESSAGE_INTEGRITY)
} }
/// Validate MESSAGE-INTEGRITY using provided key (password). Returns true if valid. /// Validate MESSAGE-INTEGRITY using provided key. Returns true if valid.
/// Note: This is a simplified validator that assumes the MESSAGE-INTEGRITY attribute exists and ///
/// that the message bytes passed are the full STUN message (including attributes). /// RFC5389: The HMAC is computed over the message up to (and including) the
/// MESSAGE-INTEGRITY attribute, with the header length field set to the end of
/// that attribute. Any attributes after MESSAGE-INTEGRITY (e.g. FINGERPRINT)
/// are excluded from the HMAC computation.
pub fn validate_message_integrity(msg: &StunMessage, key: &[u8]) -> bool { pub fn validate_message_integrity(msg: &StunMessage, key: &[u8]) -> bool {
if let Some(mi) = find_message_integrity(msg) { if let Some(mi) = find_message_integrity(msg) {
// MESSAGE-INTEGRITY attribute value is 20 bytes (HMAC-SHA1) // MESSAGE-INTEGRITY attribute value is 20 bytes (HMAC-SHA1)
if mi.value.len() != 20 { if mi.value.len() != 20 {
return false; return false;
} }
// Compute HMAC over the message up to (but excluding) MESSAGE-INTEGRITY attribute header and value
let mi_attr_start = mi.offset; // offset points to attribute header let mi_end = mi.offset + 4 + HMAC_SHA1_LEN;
let msg_slice = &msg.raw[..mi_attr_start]; if mi_end > msg.raw.len() {
let computed = crate::stun::compute_message_integrity(key, msg_slice); return false;
// compare first 20 bytes }
return &computed[..20] == mi.value.as_slice();
let mut signed = msg.raw[..mi_end].to_vec();
// Adjust header length to end-of-MI (exclude later attributes like FINGERPRINT).
let len = (mi_end - 20) as u16;
let len_bytes = len.to_be_bytes();
signed[2] = len_bytes[0];
signed[3] = len_bytes[1];
// Zero the MI value before computing.
let value_start = mi.offset + 4;
for b in &mut signed[value_start..value_start + HMAC_SHA1_LEN] {
*b = 0;
}
let computed = crate::stun::compute_message_integrity(key, &signed);
return &computed[..HMAC_SHA1_LEN] == mi.value.as_slice();
} }
false false
} }
fn append_message_integrity(buf: &mut bytes::BytesMut, key: &[u8]) {
// Append attribute header and placeholder; set length to end-of-MI, then compute
// HMAC over the message slice up to end-of-MI (with the MI placeholder still zero).
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes()));
let mi_val_pos = buf.len();
buf.extend_from_slice(&[0u8; HMAC_SHA1_LEN]);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0u8]);
}
let mi_end = buf.len();
// Set length to end-of-MI (excluding any later attributes like FINGERPRINT)
let len = (mi_end - 20) as u16;
let len_bytes = len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
let hmac = compute_message_integrity(key, &buf[..mi_end]);
buf[mi_val_pos..mi_val_pos + HMAC_SHA1_LEN].copy_from_slice(&hmac[..HMAC_SHA1_LEN]);
}
/// Build a simple success (200) response echoing transaction id /// Build a simple success (200) response echoing transaction id
pub fn build_success_response(req: &StunHeader) -> Vec<u8> { pub fn build_success_response(req: &StunHeader) -> Vec<u8> {
use bytes::BytesMut; use bytes::BytesMut;
@ -274,11 +418,66 @@ pub fn build_success_response(req: &StunHeader) -> Vec<u8> {
buf.extend_from_slice(&0u16.to_be_bytes()); buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES); buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id); buf.extend_from_slice(&req.transaction_id);
append_fingerprint(&mut buf);
buf.to_vec()
}
/// Build a simple success (200) response including MESSAGE-INTEGRITY and FINGERPRINT.
pub fn build_success_response_with_integrity(req: &StunHeader, key: &[u8]) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_SUCCESS;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
append_message_integrity(&mut buf, key);
append_fingerprint(&mut buf);
buf.to_vec()
}
/// Build a STUN Binding success response containing XOR-MAPPED-ADDRESS.
pub fn build_binding_success(req: &StunHeader, mapped: &std::net::SocketAddr) -> Vec<u8> {
use bytes::BytesMut;
let mut buf = BytesMut::new();
let msg_type: u16 = req.msg_type | CLASS_SUCCESS;
buf.extend_from_slice(&msg_type.to_be_bytes());
buf.extend_from_slice(&0u16.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE_BYTES);
buf.extend_from_slice(&req.transaction_id);
let mapped_val = encode_xor_address(mapped, &req.transaction_id);
buf.extend_from_slice(&ATTR_XOR_MAPPED_ADDRESS.to_be_bytes());
buf.extend_from_slice(&((mapped_val.len() as u16).to_be_bytes()));
buf.extend_from_slice(&mapped_val);
while (buf.len() % 4) != 0 {
buf.extend_from_slice(&[0]);
}
append_fingerprint(&mut buf);
buf.to_vec()
}
fn append_fingerprint(buf: &mut bytes::BytesMut) {
// FINGERPRINT must be the last attribute. We'll append a placeholder,
// update the message length to include it, compute CRC32 over the message
// up to (but excluding) the FINGERPRINT attribute, and then write it.
let fingerprint_attr_offset = buf.len();
buf.extend_from_slice(&ATTR_FINGERPRINT.to_be_bytes());
buf.extend_from_slice(&(4u16).to_be_bytes());
let fingerprint_value_pos = buf.len();
buf.extend_from_slice(&[0u8; 4]);
// Update STUN length (bytes after the 20-byte header)
let total_len = (buf.len() - 20) as u16; let total_len = (buf.len() - 20) as u16;
let len_bytes = total_len.to_be_bytes(); let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0]; buf[2] = len_bytes[0];
buf[3] = len_bytes[1]; buf[3] = len_bytes[1];
buf.to_vec()
let fp = compute_fingerprint(&buf[..fingerprint_attr_offset]);
let fp_bytes = fp.to_be_bytes();
buf[fingerprint_value_pos..fingerprint_value_pos + 4].copy_from_slice(&fp_bytes);
} }
/// Compute STUN fingerprint (XOR-32 of CRC32) /// Compute STUN fingerprint (XOR-32 of CRC32)
@ -290,6 +489,37 @@ pub fn compute_fingerprint(msg: &[u8]) -> u32 {
crc ^ FINGERPRINT_XOR crc ^ FINGERPRINT_XOR
} }
pub fn find_fingerprint(msg: &StunMessage) -> Option<&StunAttribute> {
msg.attributes
.iter()
.find(|a| a.typ == ATTR_FINGERPRINT)
}
/// Validate FINGERPRINT if present.
///
/// Returns true if the message has no FINGERPRINT attribute. If present, enforces:
/// - value length is exactly 4 bytes
/// - FINGERPRINT is the last attribute
/// - CRC32/XOR matches RFC5389
pub fn validate_fingerprint_if_present(msg: &StunMessage) -> bool {
let Some(fp) = find_fingerprint(msg) else {
return true;
};
if fp.value.len() != 4 {
return false;
}
// FINGERPRINT should be the last attribute (type+len+4)
if fp.offset + 8 != msg.raw.len() {
return false;
}
let expected = compute_fingerprint(&msg.raw[..fp.offset]);
let actual = u32::from_be_bytes([fp.value[0], fp.value[1], fp.value[2], fp.value[3]]);
expected == actual
}
/// Compute MESSAGE-INTEGRITY (HMAC-SHA1) over the message /// Compute MESSAGE-INTEGRITY (HMAC-SHA1) over the message
pub fn compute_message_integrity(key: &[u8], msg: &[u8]) -> Vec<u8> { pub fn compute_message_integrity(key: &[u8], msg: &[u8]) -> Vec<u8> {
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
@ -311,9 +541,6 @@ pub fn build_channel_data(channel: u16, data: &[u8]) -> Vec<u8> {
out.extend_from_slice(&channel.to_be_bytes()); out.extend_from_slice(&channel.to_be_bytes());
out.extend_from_slice(&(data.len() as u16).to_be_bytes()); out.extend_from_slice(&(data.len() as u16).to_be_bytes());
out.extend_from_slice(data); out.extend_from_slice(data);
while (out.len() % 4) != 0 {
out.push(0);
}
out out
} }
@ -344,10 +571,7 @@ pub fn build_data_indication(peer: &std::net::SocketAddr, data: &[u8]) -> Vec<u8
buf.extend_from_slice(&[0]); buf.extend_from_slice(&[0]);
} }
let total_len = (buf.len() - 20) as u16; append_fingerprint(&mut buf);
let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0];
buf[3] = len_bytes[1];
buf.to_vec() buf.to_vec()
} }
@ -367,7 +591,7 @@ pub fn parse_channel_data(buf: &[u8]) -> Option<(u16, &[u8])> {
Some((channel, &buf[4..4 + data_len])) Some((channel, &buf[4..4 + data_len]))
} }
fn encode_xor_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec<u8> { fn encode_xor_address(addr: &std::net::SocketAddr, trans_id: &[u8; 12]) -> Vec<u8> {
use std::net::IpAddr; use std::net::IpAddr;
let mut out = Vec::new(); let mut out = Vec::new();
match addr.ip() { match addr.ip() {
@ -385,9 +609,22 @@ fn encode_xor_address(addr: &std::net::SocketAddr, _trans_id: &[u8; 12]) -> Vec<
out.push(octets[i] ^ cookie_bytes[i]); out.push(octets[i] ^ cookie_bytes[i]);
} }
} }
IpAddr::V6(_v6) => { IpAddr::V6(v6) => {
// For now, we don't support IPv6 in this MVP implementation out.push(0);
// Return an empty vec to indicate unsupported out.push(FAMILY_IPV6);
let port = addr.port();
let xport = (port ^ ((MAGIC_COOKIE_U32 >> 16) as u16)) as u16;
out.extend_from_slice(&xport.to_be_bytes());
// RFC5389: IPv6 XOR uses magic cookie (4) concatenated with transaction id (12)
let mut xor_key = [0u8; 16];
xor_key[..4].copy_from_slice(&MAGIC_COOKIE_BYTES);
xor_key[4..].copy_from_slice(trans_id);
let octets = v6.octets();
for i in 0..16 {
out.push(octets[i] ^ xor_key[i]);
}
} }
} }
out out
@ -401,36 +638,58 @@ pub fn encode_xor_peer_address(addr: &std::net::SocketAddr, trans_id: &[u8; 12])
encode_xor_address(addr, trans_id) encode_xor_address(addr, trans_id)
} }
/// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4 only) /// Decode XOR-RELAYED-ADDRESS attribute value into SocketAddr (IPv4/IPv6)
pub fn decode_xor_relayed_address( pub fn decode_xor_relayed_address(
value: &[u8], value: &[u8],
_trans_id: &[u8; 12], trans_id: &[u8; 12],
) -> Option<std::net::SocketAddr> { ) -> Option<std::net::SocketAddr> {
if value.len() < 8 { if value.len() < 4 {
return None; return None;
} }
if value[1] != FAMILY_IPV4 {
return None;
} // not IPv4
let xport = u16::from_be_bytes([value[2], value[3]]); let xport = u16::from_be_bytes([value[2], value[3]]);
let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16); let port = xport ^ ((MAGIC_COOKIE_U32 >> 16) as u16);
let cookie_bytes = MAGIC_COOKIE_BYTES;
let mut ipb = [0u8; 4]; match value[1] {
for i in 0..4 { FAMILY_IPV4 => {
ipb[i] = value[4 + i] ^ cookie_bytes[i]; if value.len() < 8 {
return None;
}
let cookie_bytes = MAGIC_COOKIE_BYTES;
let mut ipb = [0u8; 4];
for i in 0..4 {
ipb[i] = value[4 + i] ^ cookie_bytes[i];
}
let ip = std::net::Ipv4Addr::from(ipb);
Some(std::net::SocketAddr::new(std::net::IpAddr::V4(ip), port))
}
FAMILY_IPV6 => {
if value.len() < 20 {
return None;
}
let mut xor_key = [0u8; 16];
xor_key[..4].copy_from_slice(&MAGIC_COOKIE_BYTES);
xor_key[4..].copy_from_slice(trans_id);
let mut ipb = [0u8; 16];
for i in 0..16 {
ipb[i] = value[4 + i] ^ xor_key[i];
}
let ip = std::net::Ipv6Addr::from(ipb);
Some(std::net::SocketAddr::new(std::net::IpAddr::V6(ip), port))
}
_ => None,
} }
let ip = std::net::Ipv4Addr::from(ipb);
Some(std::net::SocketAddr::new(std::net::IpAddr::V4(ip), port))
} }
/// Decode XOR-PEER-ADDRESS / XOR-MAPPED-ADDRESS style attributes (IPv4 only). /// Decode XOR-PEER-ADDRESS / XOR-MAPPED-ADDRESS style attributes (IPv4/IPv6).
pub fn decode_xor_peer_address(value: &[u8], _trans_id: &[u8; 12]) -> Option<std::net::SocketAddr> { pub fn decode_xor_peer_address(value: &[u8], trans_id: &[u8; 12]) -> Option<std::net::SocketAddr> {
decode_xor_relayed_address(value, _trans_id) decode_xor_relayed_address(value, trans_id)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
#[test] #[test]
fn parse_minimal_binding() { fn parse_minimal_binding() {
@ -445,6 +704,7 @@ mod tests {
assert_eq!(msg.header.msg_type, METHOD_BINDING); assert_eq!(msg.header.msg_type, METHOD_BINDING);
assert_eq!(msg.header.transaction_id, trans); assert_eq!(msg.header.transaction_id, trans);
assert!(msg.attributes.is_empty()); assert!(msg.attributes.is_empty());
assert!(validate_fingerprint_if_present(&msg));
} }
#[test] #[test]
@ -459,6 +719,30 @@ mod tests {
// parse back should succeed // parse back should succeed
let parsed = parse_message(&out).expect("parse resp"); let parsed = parse_message(&out).expect("parse resp");
assert!(!parsed.attributes.is_empty()); assert!(!parsed.attributes.is_empty());
assert!(validate_fingerprint_if_present(&parsed));
}
#[test]
fn fingerprint_is_appended_and_valid() {
let req = StunHeader {
msg_type: METHOD_BINDING,
length: 0,
cookie: MAGIC_COOKIE_U32,
transaction_id: [4u8; 12],
};
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 5)), 9999);
let out = build_binding_success(&req, &peer);
let parsed = parse_message(&out).expect("parse");
assert!(find_fingerprint(&parsed).is_some());
assert!(validate_fingerprint_if_present(&parsed));
// Tamper any byte before the fingerprint attribute => invalid
let fp = find_fingerprint(&parsed).unwrap();
let mut tampered = out.clone();
tampered[fp.offset.saturating_sub(1)] ^= 0xFF;
let parsed2 = parse_message(&tampered).expect("parse tampered");
assert!(!validate_fingerprint_if_present(&parsed2));
} }
#[test] #[test]
@ -486,7 +770,6 @@ mod tests {
} }
// MESSAGE-INTEGRITY placeholder (0x0008) length 20 // MESSAGE-INTEGRITY placeholder (0x0008) length 20
let mi_attr_offset = buf.len();
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes())); buf.extend_from_slice(&((HMAC_SHA1_LEN as u16).to_be_bytes()));
let mi_val_pos = buf.len(); let mi_val_pos = buf.len();
@ -495,19 +778,23 @@ mod tests {
buf.extend_from_slice(&[0u8]); buf.extend_from_slice(&[0u8]);
} }
// Fix length // Fix length to end-of-MI
let total_len = (buf.len() - 20) as u16; let mi_end = buf.len();
let total_len = (mi_end - 20) as u16;
let len_bytes = total_len.to_be_bytes(); let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0]; buf[2] = len_bytes[0];
buf[3] = len_bytes[1]; buf[3] = len_bytes[1];
// Compute HMAC over message up to MI attribute header (mi_attr_offset) // Compute HMAC over message up to end-of-MI (MI value is still zero here)
let hmac = compute_message_integrity(password.as_bytes(), &buf[..mi_attr_offset]); let hmac = compute_message_integrity(password.as_bytes(), &buf[..mi_end]);
// place first 20 bytes into mi value // place first 20 bytes into mi value
for i in 0..20 { for i in 0..20 {
buf[mi_val_pos + i] = hmac[i]; buf[mi_val_pos + i] = hmac[i];
} }
// Add FINGERPRINT after MESSAGE-INTEGRITY and ensure validation still succeeds.
append_fingerprint(&mut buf);
// Parse and validate // Parse and validate
let parsed = parse_message(&buf).expect("parsed"); let parsed = parse_message(&buf).expect("parsed");
assert!(validate_message_integrity(&parsed, password.as_bytes())); assert!(validate_message_integrity(&parsed, password.as_bytes()));
@ -518,4 +805,44 @@ mod tests {
let parsed2 = parse_message(&tampered).expect("parsed2"); let parsed2 = parse_message(&tampered).expect("parsed2");
assert!(!validate_message_integrity(&parsed2, password.as_bytes())); assert!(!validate_message_integrity(&parsed2, password.as_bytes()));
} }
#[test]
fn build_binding_success_includes_xor_mapped_ipv4() {
let req = StunHeader {
msg_type: METHOD_BINDING,
length: 0,
cookie: MAGIC_COOKIE_U32,
transaction_id: [7u8; 12],
};
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 10)), 54321);
let out = build_binding_success(&req, &peer);
let parsed = parse_message(&out).expect("parse");
let attr = parsed
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_MAPPED_ADDRESS)
.expect("xor-mapped present");
let decoded = decode_xor_peer_address(&attr.value, &req.transaction_id).expect("decode");
assert_eq!(decoded, peer);
}
#[test]
fn build_binding_success_includes_xor_mapped_ipv6() {
let req = StunHeader {
msg_type: METHOD_BINDING,
length: 0,
cookie: MAGIC_COOKIE_U32,
transaction_id: [8u8; 12],
};
let peer = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 12345);
let out = build_binding_success(&req, &peer);
let parsed = parse_message(&out).expect("parse");
let attr = parsed
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_MAPPED_ADDRESS)
.expect("xor-mapped present");
let decoded = decode_xor_peer_address(&attr.value, &req.transaction_id).expect("decode");
assert_eq!(decoded, peer);
}
} }

38
src/tcp.rs Normal file
View File

@ -0,0 +1,38 @@
//! Plain TCP listener for TURN over TCP (RFC5389 framing + ChannelData interleaving).
use tokio::net::TcpListener;
use crate::alloc::AllocationManager;
use crate::auth::{AuthManager, InMemoryStore};
use crate::rate_limit::RateLimiters;
use crate::turn_stream::handle_turn_stream_connection_with_limits;
pub async fn serve_tcp(
bind: &str,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> anyhow::Result<()> {
serve_tcp_with_limits(bind, auth, allocs, std::sync::Arc::new(RateLimiters::disabled())).await
}
pub async fn serve_tcp_with_limits(
bind: &str,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
rate_limiters: std::sync::Arc<RateLimiters>,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(bind).await?;
tracing::info!("TCP listener bound to {}", bind);
loop {
let (stream, peer) = listener.accept().await?;
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move {
if let Err(e) = handle_turn_stream_connection_with_limits(stream, peer, auth_clone, alloc_clone, rl).await {
tracing::info!("TCP connection ended for {}: {:?}", peer, e);
}
});
}
}

View File

@ -4,19 +4,15 @@ use anyhow::Context;
use std::fs::File; use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::alloc::AllocationManager; use crate::alloc::AllocationManager;
use crate::auth::{AuthManager, AuthStatus, InMemoryStore}; use crate::auth::{AuthManager, InMemoryStore};
use crate::constants::*; use crate::rate_limit::RateLimiters;
use crate::stun::{ use crate::turn_stream::{handle_turn_stream_connection, handle_turn_stream_connection_with_limits};
build_401_response, build_allocate_success, build_error_response, build_lifetime_success,
build_success_response, decode_xor_peer_address, extract_lifetime_seconds, parse_message,
};
fn load_certs(path: &str) -> anyhow::Result<Vec<Certificate>> { fn load_certs(path: &str) -> anyhow::Result<Vec<Certificate>> {
let f = File::open(path).context("opening cert file")?; let f = File::open(path).context("opening cert file")?;
@ -48,9 +44,27 @@ pub async fn serve_tls(
bind: &str, bind: &str,
cert_path: &str, cert_path: &str,
key_path: &str, key_path: &str,
udp_sock: std::sync::Arc<tokio::net::UdpSocket>,
auth: AuthManager<InMemoryStore>, auth: AuthManager<InMemoryStore>,
allocs: AllocationManager, allocs: AllocationManager,
) -> anyhow::Result<()> {
serve_tls_with_limits(
bind,
cert_path,
key_path,
auth,
allocs,
std::sync::Arc::new(RateLimiters::disabled()),
)
.await
}
pub async fn serve_tls_with_limits(
bind: &str,
cert_path: &str,
key_path: &str,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
rate_limiters: std::sync::Arc<RateLimiters>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let certs = load_certs(cert_path)?; let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?; let key = load_private_key(key_path)?;
@ -68,23 +82,15 @@ pub async fn serve_tls(
loop { loop {
let (stream, peer) = listener.accept().await?; let (stream, peer) = listener.accept().await?;
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
let udp_clone = udp_sock.clone();
let auth_clone = auth.clone(); let auth_clone = auth.clone();
let alloc_clone = allocs.clone(); let alloc_clone = allocs.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move { tokio::spawn(async move {
match acceptor.accept(stream).await { match acceptor.accept(stream).await {
Ok(mut tls_stream) => { Ok(tls_stream) => {
if let Err(e) = handle_tls_connection( if let Err(e) = handle_tls_connection_with_limits(tls_stream, peer, auth_clone, alloc_clone, rl).await {
&mut tls_stream, tracing::info!("TLS connection ended for {}: {:?}", peer, e);
peer,
udp_clone,
auth_clone,
alloc_clone,
)
.await
{
tracing::error!("TLS connection error: {:?}", e);
} }
} }
Err(e) => tracing::error!("TLS accept error: {:?}", e), Err(e) => tracing::error!("TLS accept error: {:?}", e),
@ -93,596 +99,27 @@ pub async fn serve_tls(
} }
} }
#[allow(clippy::too_many_arguments)]
pub async fn handle_tls_connection<S>( pub async fn handle_tls_connection<S>(
tls_stream: &mut S, tls_stream: S,
peer: std::net::SocketAddr, peer: std::net::SocketAddr,
udp_sock: std::sync::Arc<tokio::net::UdpSocket>,
auth: AuthManager<InMemoryStore>, auth: AuthManager<InMemoryStore>,
allocs: AllocationManager, allocs: AllocationManager,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{ {
tracing::info!("accepted TLS connection from {}", peer); handle_turn_stream_connection(tls_stream, peer, auth, allocs).await
let mut read_buf = vec![0u8; 4096]; }
let mut buffer: Vec<u8> = Vec::new();
pub async fn handle_tls_connection_with_limits<S>(
loop { tls_stream: S,
match tls_stream.read(&mut read_buf).await { peer: std::net::SocketAddr,
Ok(0) => { auth: AuthManager<InMemoryStore>,
tracing::info!("TLS client {} closed connection", peer); allocs: AllocationManager,
break; rate_limiters: std::sync::Arc<RateLimiters>,
} ) -> anyhow::Result<()>
Ok(n) => { where
buffer.extend_from_slice(&read_buf[..n]); S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
while buffer.len() >= 20 { {
let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; handle_turn_stream_connection_with_limits(tls_stream, peer, auth, allocs, rate_limiters).await
let total = len + 20;
if buffer.len() < total {
break;
}
let chunk = buffer.drain(..total).collect::<Vec<u8>>();
match parse_message(&chunk) {
Ok(msg) => {
tracing::info!(
"STUN/TURN over TLS from {} type=0x{:04x} len={}",
peer,
msg.header.msg_type,
total
);
let requires_auth = matches!(
msg.header.msg_type,
METHOD_ALLOCATE
| METHOD_CREATE_PERMISSION
| METHOD_CHANNEL_BIND
| METHOD_SEND
| METHOD_REFRESH
);
if requires_auth {
match auth.authenticate(&msg, &peer).await {
AuthStatus::Granted { username } => {
tracing::debug!(
"TURN TLS auth ok for {} as {} (0x{:04x})",
peer,
username,
msg.header.msg_type
);
}
AuthStatus::Challenge { nonce } => {
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
401,
"Unauthorized",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls challenge: {:?}",
e
);
}
continue;
}
AuthStatus::StaleNonce { nonce } => {
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
438,
"Stale Nonce",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls stale nonce: {:?}",
e
);
}
continue;
}
AuthStatus::Reject { code, reason } => {
let resp = build_error_response(&msg.header, code, reason);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls auth error: {:?}",
e
);
}
continue;
}
}
}
match msg.header.msg_type {
METHOD_ALLOCATE => {
let requested_lifetime = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64))
.filter(|d| !d.is_zero());
match allocs.allocate_for(peer, udp_sock.clone()).await {
Ok(relay_addr) => {
let applied = match allocs
.refresh_allocation(peer, requested_lifetime)
{
Ok(d) => d,
Err(e) => {
tracing::error!(
"failed to apply TLS lifetime for {}: {:?}",
peer,
e
);
let resp = build_error_response(
&msg.header,
500,
"Allocate Failed",
);
if let Err(e2) =
tls_stream.write_all(&resp).await
{
tracing::error!(
"failed to write tls allocate error: {:?}",
e2
);
}
continue;
}
};
let lifetime_secs =
applied.as_secs().min(u32::MAX as u64) as u32;
let resp = build_allocate_success(
&msg.header,
&relay_addr,
lifetime_secs,
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls allocate success: {:?}",
e
);
}
}
Err(e) => {
tracing::error!("allocate failed (tls): {:?}", e);
let resp = build_error_response(
&msg.header,
500,
"Allocate Failed",
);
if let Err(e2) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls allocate error: {:?}",
e2
);
}
}
}
continue;
}
METHOD_CREATE_PERMISSION => {
if allocs.get_allocation(&peer).is_none() {
tracing::warn!(
"create-permission without allocation from {} (tls)",
peer
);
let resp = build_error_response(
&msg.header,
437,
"Allocation Mismatch",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!("failed to write tls error: {:?}", e);
}
continue;
}
let mut added = 0usize;
for attr in msg
.attributes
.iter()
.filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS)
{
if let Some(peer_addr) = decode_xor_peer_address(
&attr.value,
&msg.header.transaction_id,
) {
match allocs.add_permission(peer, peer_addr) {
Ok(()) => {
tracing::info!(
"added TLS permission for {} -> {}",
peer,
peer_addr
);
added += 1;
}
Err(e) => {
tracing::error!(
"failed to persist TLS permission {} -> {}: {:?}",
peer,
peer_addr,
e
);
}
}
} else {
tracing::warn!(
"invalid XOR-PEER-ADDRESS via TLS from {}",
peer
);
}
}
let resp = if added == 0 {
build_error_response(
&msg.header,
400,
"No valid XOR-PEER-ADDRESS",
)
} else {
build_success_response(&msg.header)
};
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!("failed to write tls response: {:?}", e);
}
continue;
}
METHOD_CHANNEL_BIND => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
tracing::warn!(
"channel-bind without allocation from {} (tls)",
peer
);
let resp = build_error_response(
&msg.header,
437,
"Allocation Mismatch",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e
);
}
continue;
}
};
let channel_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_CHANNEL_NUMBER);
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let channel = match channel_attr.and_then(|attr| {
if attr.value.len() >= 4 {
Some(u16::from_be_bytes([attr.value[0], attr.value[1]]))
} else {
None
}
}) {
Some(c) => c,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing CHANNEL-NUMBER",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e
);
}
continue;
}
};
if channel < 0x4000 || channel > 0x7FFF {
let resp = build_error_response(
&msg.header,
400,
"Channel Out Of Range",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!("failed to write tls error: {:?}", e);
}
continue;
}
let peer_addr = match peer_attr.and_then(|attr| {
decode_xor_peer_address(
&attr.value,
&msg.header.transaction_id,
)
}) {
Some(addr) => addr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing XOR-PEER-ADDRESS",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e
);
}
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(
&msg.header,
403,
"Peer Not Permitted",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!("failed to write tls error: {:?}", e);
}
continue;
}
match allocs.add_channel_binding(peer, channel, peer_addr) {
Ok(()) => {
tracing::info!(
"bound channel 0x{:04x} for {} -> {} over TLS",
channel,
peer,
peer_addr
);
let resp = build_success_response(&msg.header);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls response: {:?}",
e
);
}
}
Err(e) => {
tracing::error!(
"failed TLS channel binding {} -> {} (0x{:04x}): {:?}",
peer,
peer_addr,
channel,
e
);
let resp = build_error_response(
&msg.header,
500,
"Channel Binding Failed",
);
if let Err(e2) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e2
);
}
}
}
continue;
}
METHOD_SEND => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
tracing::warn!(
"send without allocation from {} (tls)",
peer
);
let resp = build_error_response(
&msg.header,
437,
"Allocation Mismatch",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e
);
}
continue;
}
};
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let data_attr =
msg.attributes.iter().find(|a| a.typ == ATTR_DATA);
let peer_addr = match peer_attr.and_then(|attr| {
decode_xor_peer_address(
&attr.value,
&msg.header.transaction_id,
)
}) {
Some(addr) => addr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing XOR-PEER-ADDRESS",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e
);
}
continue;
}
};
let data_attr = match data_attr {
Some(attr) => attr,
None => {
let resp = build_error_response(
&msg.header,
400,
"Missing DATA Attribute",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e
);
}
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response(
&msg.header,
403,
"Peer Not Permitted",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!("failed to write tls error: {:?}", e);
}
continue;
}
match allocation.send_to_peer(peer_addr, &data_attr.value).await
{
Ok(sent) => {
tracing::info!(
"forwarded {} bytes from {} to {} via TLS session",
sent,
peer,
peer_addr
);
let resp = build_success_response(&msg.header);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls response: {:?}",
e
);
}
}
Err(e) => {
tracing::error!(
"failed to send payload from {} to {} via TLS: {:?}",
peer,
peer_addr,
e
);
let resp = build_error_response(
&msg.header,
500,
"Peer Send Failed",
);
if let Err(e2) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls error: {:?}",
e2
);
}
}
}
continue;
}
METHOD_REFRESH => {
let requested = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64));
match allocs.refresh_allocation(peer, requested) {
Ok(applied) => {
if applied.is_zero() {
tracing::info!(
"allocation for {} released (tls)",
peer
);
} else {
tracing::debug!(
"allocation for {} refreshed to {}s (tls)",
peer,
applied.as_secs()
);
}
let resp = build_lifetime_success(
&msg.header,
applied.as_secs().min(u32::MAX as u64) as u32,
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls refresh response: {:?}",
e
);
}
}
Err(_) => {
let resp = build_error_response(
&msg.header,
437,
"Allocation Mismatch",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls refresh error: {:?}",
e
);
}
}
}
continue;
}
METHOD_BINDING => {
let resp = build_success_response(&msg.header);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls binding response: {:?}",
e
);
}
continue;
}
_ => {
let nonce = auth.mint_nonce(&peer);
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
401,
"Unauthorized",
);
if let Err(e) = tls_stream.write_all(&resp).await {
tracing::error!(
"failed to write tls fallback challenge: {:?}",
e
);
}
continue;
}
}
}
Err(e) => {
tracing::warn!(
error = ?e,
length = chunk.len(),
"dropping unparseable STUN/TURN frame over TLS from {}",
peer
);
}
}
}
}
Err(e) => {
tracing::error!("tls read error from {}: {:?}", peer, e);
break;
}
}
}
Ok(())
} }

715
src/turn_stream.rs Normal file
View File

@ -0,0 +1,715 @@
//! Shared TURN-over-stream (TCP/TLS) handler.
//!
//! This implements TURN over a byte stream where STUN messages and ChannelData frames
//! may be interleaved on the same connection.
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
use crate::alloc::{AllocationManager, ClientSink};
use crate::alloc::AllocationError;
use crate::auth::{AuthManager, AuthStatus, InMemoryStore};
use crate::constants::*;
use crate::stun::{
build_401_response, build_allocate_success_with_integrity, build_error_response,
build_error_response_with_integrity, build_lifetime_success_with_integrity,
build_success_response_with_integrity, decode_xor_peer_address, extract_lifetime_seconds,
parse_message, validate_fingerprint_if_present, extract_requested_transport_protocol,
};
use crate::rate_limit::RateLimiters;
enum StreamFrame {
ChannelData { channel: u16, payload: Vec<u8> },
Stun(crate::models::stun::StunMessage),
}
const MAX_STREAM_BUFFER_BYTES: usize = 256 * 1024;
const MAX_STUN_BODY_BYTES: usize = 64 * 1024;
const MAX_CHANNELDATA_BYTES: usize = 64 * 1024;
fn try_pop_next_frame(buffer: &mut Vec<u8>) -> Option<anyhow::Result<StreamFrame>> {
if buffer.len() < 4 {
return None;
}
// ChannelData detection: channel number is 0x4000..0x7FFF
let channel = u16::from_be_bytes([buffer[0], buffer[1]]);
if (channel & 0xC000) == 0x4000 {
let data_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize;
if data_len > MAX_CHANNELDATA_BYTES {
buffer.drain(..1);
return Some(Err(anyhow::anyhow!(
"channeldata length {} exceeds max {}",
data_len,
MAX_CHANNELDATA_BYTES
)));
}
let total = 4 + data_len;
if buffer.len() < total {
return None;
}
let frame = buffer.drain(..total).collect::<Vec<u8>>();
let payload = frame[4..].to_vec();
return Some(Ok(StreamFrame::ChannelData { channel, payload }));
}
// Otherwise assume STUN over stream: 20-byte header + length.
if buffer.len() < 20 {
return None;
}
// Quick resync: STUN messages must include the magic cookie.
if buffer[4..8] != MAGIC_COOKIE_BYTES {
buffer.drain(..1);
return Some(Err(anyhow::anyhow!("invalid STUN magic cookie")));
}
let len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize;
if len > MAX_STUN_BODY_BYTES {
buffer.drain(..1);
return Some(Err(anyhow::anyhow!(
"stun length {} exceeds max {}",
len,
MAX_STUN_BODY_BYTES
)));
}
let total = 20 + len;
if buffer.len() < total {
return None;
}
let chunk = buffer.drain(..total).collect::<Vec<u8>>();
Some(parse_message(&chunk).map(StreamFrame::Stun).map_err(|e| anyhow::anyhow!("parse stun: {e:?}")))
}
/// Handle a single TURN-over-stream connection (plain TCP or TLS).
///
/// All server responses and relayed peer packets are written back over the same stream.
#[allow(clippy::too_many_arguments)]
pub async fn handle_turn_stream_connection<S>(
stream: S,
peer: SocketAddr,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> anyhow::Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
handle_turn_stream_connection_with_limits(
stream,
peer,
auth,
allocs,
std::sync::Arc::new(RateLimiters::disabled()),
)
.await
}
/// Handle a single TURN-over-stream connection (plain TCP or TLS) with explicit rate limiters.
#[allow(clippy::too_many_arguments)]
pub async fn handle_turn_stream_connection_with_limits<S>(
stream: S,
peer: SocketAddr,
auth: AuthManager<InMemoryStore>,
allocs: AllocationManager,
rate_limiters: std::sync::Arc<RateLimiters>,
) -> anyhow::Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
tracing::info!("accepted TURN stream connection from {}", peer);
crate::metrics::inc_stream_connections();
let (mut reader, mut writer) = tokio::io::split(stream);
// Single ordered output queue for both direct responses and relayed peer traffic.
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(256);
let writer_task = tokio::spawn(async move {
while let Some(buf) = rx.recv().await {
if let Err(e) = writer.write_all(&buf).await {
tracing::info!("stream writer closed: {:?}", e);
break;
}
}
});
let mut read_buf = vec![0u8; 4096];
let mut buffer: Vec<u8> = Vec::new();
loop {
match reader.read(&mut read_buf).await {
Ok(0) => {
tracing::info!("TURN stream client {} closed connection", peer);
break;
}
Ok(n) => {
buffer.extend_from_slice(&read_buf[..n]);
if buffer.len() > MAX_STREAM_BUFFER_BYTES {
tracing::warn!(
"closing stream connection {} due to oversized buffer ({} bytes)",
peer,
buffer.len()
);
break;
}
while let Some(frame_res) = try_pop_next_frame(&mut buffer) {
let frame = match frame_res {
Ok(f) => f,
Err(e) => {
tracing::debug!("dropping invalid frame from {}: {:?}", peer, e);
continue;
}
};
match frame {
StreamFrame::ChannelData { channel, payload } => {
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
tracing::warn!("channel data without allocation from {}", peer);
continue;
}
};
let target = match allocation.channel_peer(channel) {
Some(addr) => addr,
None => {
tracing::warn!(
"channel data with unknown channel 0x{:04x} from {}",
channel,
peer
);
continue;
}
};
if !allocation.is_peer_allowed(&target) {
tracing::warn!(
"channel data target {} no longer permitted for {}",
target,
peer
);
continue;
}
if let Err(e) = allocation.send_to_peer(target, &payload).await {
tracing::error!(
"failed to forward channel data 0x{:04x} from {} to {}: {:?}",
channel,
peer,
target,
e
);
}
}
StreamFrame::Stun(msg) => {
if !validate_fingerprint_if_present(&msg) {
tracing::debug!(
"dropping STUN/TURN over stream from {} due to invalid FINGERPRINT",
peer
);
continue;
}
crate::metrics::inc_stun_messages();
tracing::info!(
"STUN/TURN over stream from {} type=0x{:04x}",
peer,
msg.header.msg_type
);
let requires_auth = matches!(
msg.header.msg_type,
METHOD_ALLOCATE
| METHOD_CREATE_PERMISSION
| METHOD_CHANNEL_BIND
| METHOD_SEND
| METHOD_REFRESH
);
let mut auth_key: Option<Vec<u8>> = None;
if requires_auth {
match auth.authenticate(&msg, &peer).await {
AuthStatus::Granted { username, key } => {
auth_key = Some(key);
tracing::debug!(
"TURN stream auth ok for {} as {} (0x{:04x})",
peer,
username,
msg.header.msg_type
);
}
AuthStatus::Challenge { nonce } => {
crate::metrics::inc_auth_challenge();
if !rate_limiters.allow_unauth(peer.ip()) {
crate::metrics::inc_rate_limited();
continue;
}
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
401,
"Unauthorized",
);
let _ = tx.send(resp).await;
continue;
}
AuthStatus::StaleNonce { nonce } => {
crate::metrics::inc_auth_stale();
if !rate_limiters.allow_unauth(peer.ip()) {
crate::metrics::inc_rate_limited();
continue;
}
let resp = build_401_response(
&msg.header,
auth.realm(),
&nonce,
438,
"Stale Nonce",
);
let _ = tx.send(resp).await;
continue;
}
AuthStatus::Reject { code, reason } => {
crate::metrics::inc_auth_reject();
let resp = build_error_response(&msg.header, code, reason);
let _ = tx.send(resp).await;
continue;
}
}
}
match msg.header.msg_type {
METHOD_ALLOCATE => {
crate::metrics::inc_allocate_total();
let key = auth_key
.as_deref()
.expect("auth key must be set after AuthStatus::Granted");
// TURN Allocate MUST include REQUESTED-TRANSPORT; WebRTC expects UDP (17).
match extract_requested_transport_protocol(&msg) {
Some(IPPROTO_UDP) => {}
Some(_) => {
crate::metrics::inc_allocate_fail();
let resp = build_error_response_with_integrity(
&msg.header,
442,
"Unsupported Transport",
key,
);
let _ = tx.send(resp).await;
continue;
}
None => {
crate::metrics::inc_allocate_fail();
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Missing REQUESTED-TRANSPORT",
key,
);
let _ = tx.send(resp).await;
continue;
}
}
let requested_lifetime = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64))
.filter(|d| !d.is_zero());
match allocs
.allocate_for(
peer,
ClientSink::Stream {
tx: tx.clone(),
},
)
.await
{
Ok(relay_addr) => {
let applied = match allocs
.refresh_allocation(peer, requested_lifetime)
{
Ok(d) => d,
Err(e) => {
tracing::error!(
"failed to apply lifetime for {}: {:?}",
peer,
e
);
let resp = build_error_response_with_integrity(
&msg.header,
500,
"Allocate Failed",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
let lifetime_secs =
applied.as_secs().min(u32::MAX as u64) as u32;
let advertised =
allocs.relay_addr_for_response(relay_addr);
let resp = build_allocate_success_with_integrity(
&msg.header,
&advertised,
lifetime_secs,
key,
);
crate::metrics::inc_allocate_success();
let _ = tx.send(resp).await;
}
Err(e) => {
tracing::error!("allocate failed (stream): {:?}", e);
let (code, reason) = match e.downcast_ref::<AllocationError>() {
Some(AllocationError::AllocationQuotaExceeded) => {
(486, "Allocation Quota Reached")
}
_ => (500, "Allocate Failed"),
};
crate::metrics::inc_allocate_fail();
let resp = build_error_response_with_integrity(
&msg.header,
code,
reason,
key,
);
let _ = tx.send(resp).await;
}
}
}
METHOD_CREATE_PERMISSION => {
let key = auth_key
.as_deref()
.expect("auth key must be set after AuthStatus::Granted");
if allocs.get_allocation(&peer).is_none() {
let resp = build_error_response_with_integrity(
&msg.header,
437,
"Allocation Mismatch",
key,
);
let _ = tx.send(resp).await;
continue;
}
let mut added = 0usize;
for attr in msg
.attributes
.iter()
.filter(|a| a.typ == ATTR_XOR_PEER_ADDRESS)
{
if let Some(peer_addr) = decode_xor_peer_address(
&attr.value,
&msg.header.transaction_id,
) {
match allocs.add_permission(peer, peer_addr) {
Ok(()) => added += 1,
Err(e) => {
if matches!(
e.downcast_ref::<AllocationError>(),
Some(AllocationError::PermissionQuotaExceeded)
) {
let resp = build_error_response_with_integrity(
&msg.header,
508,
"Insufficient Capacity",
key,
);
let _ = tx.send(resp).await;
continue;
}
}
}
}
}
if added == 0 {
let resp = build_error_response_with_integrity(
&msg.header,
400,
"No valid XOR-PEER-ADDRESS",
key,
);
let _ = tx.send(resp).await;
} else {
let resp = build_success_response_with_integrity(&msg.header, key);
let _ = tx.send(resp).await;
}
}
METHOD_CHANNEL_BIND => {
let key = auth_key
.as_deref()
.expect("auth key must be set after AuthStatus::Granted");
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
let resp = build_error_response_with_integrity(
&msg.header,
437,
"Allocation Mismatch",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
let channel_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_CHANNEL_NUMBER);
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let (channel_attr, peer_attr) = match (channel_attr, peer_attr)
{
(Some(c), Some(p)) => (c, p),
_ => {
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Missing CHANNEL-NUMBER or XOR-PEER-ADDRESS",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
if channel_attr.value.len() < 2 {
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Invalid CHANNEL-NUMBER",
key,
);
let _ = tx.send(resp).await;
continue;
}
let channel = u16::from_be_bytes([
channel_attr.value[0],
channel_attr.value[1],
]);
let peer_addr = match decode_xor_peer_address(
&peer_attr.value,
&msg.header.transaction_id,
) {
Some(addr) => addr,
None => {
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Invalid XOR-PEER-ADDRESS",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response_with_integrity(
&msg.header,
403,
"Peer Not Permitted",
key,
);
let _ = tx.send(resp).await;
continue;
}
if let Err(e) =
allocs.add_channel_binding(peer, channel, peer_addr)
{
tracing::error!(
"failed to persist channel binding {} -> {} (0x{:04x}): {:?}",
peer,
peer_addr,
channel,
e
);
let (code, reason) = match e.downcast_ref::<AllocationError>() {
Some(AllocationError::ChannelQuotaExceeded) => {
(508, "Insufficient Capacity")
}
_ => (500, "Channel Bind Failed"),
};
let resp = build_error_response_with_integrity(
&msg.header,
code,
reason,
key,
);
let _ = tx.send(resp).await;
continue;
}
crate::metrics::inc_channel_binding_added();
let resp = build_success_response_with_integrity(&msg.header, key);
let _ = tx.send(resp).await;
}
METHOD_SEND => {
let key = auth_key
.as_deref()
.expect("auth key must be set after AuthStatus::Granted");
let allocation = match allocs.get_allocation(&peer) {
Some(a) => a,
None => {
let resp = build_error_response_with_integrity(
&msg.header,
437,
"Allocation Mismatch",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
let peer_attr = msg
.attributes
.iter()
.find(|a| a.typ == ATTR_XOR_PEER_ADDRESS);
let data_attr = msg.attributes.iter().find(|a| a.typ == ATTR_DATA);
let (peer_attr, data_attr) = match (peer_attr, data_attr) {
(Some(p), Some(d)) => (p, d),
_ => {
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Missing DATA or XOR-PEER-ADDRESS",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
let peer_addr = match decode_xor_peer_address(
&peer_attr.value,
&msg.header.transaction_id,
) {
Some(addr) => addr,
None => {
let resp = build_error_response_with_integrity(
&msg.header,
400,
"Invalid XOR-PEER-ADDRESS",
key,
);
let _ = tx.send(resp).await;
continue;
}
};
if !allocation.is_peer_allowed(&peer_addr) {
let resp = build_error_response_with_integrity(
&msg.header,
403,
"Peer Not Permitted",
key,
);
let _ = tx.send(resp).await;
continue;
}
match allocation.send_to_peer(peer_addr, &data_attr.value).await {
Ok(_) => {
let resp = build_success_response_with_integrity(&msg.header, key);
let _ = tx.send(resp).await;
}
Err(e) => {
tracing::error!(
"failed to send payload from {} to {}: {:?}",
peer,
peer_addr,
e
);
let resp = build_error_response_with_integrity(
&msg.header,
500,
"Peer Send Failed",
key,
);
let _ = tx.send(resp).await;
}
}
}
METHOD_REFRESH => {
let key = auth_key
.as_deref()
.expect("auth key must be set after AuthStatus::Granted");
let requested = extract_lifetime_seconds(&msg)
.map(|secs| Duration::from_secs(secs as u64));
match allocs.refresh_allocation(peer, requested) {
Ok(applied) => {
let resp = build_lifetime_success_with_integrity(
&msg.header,
applied.as_secs().min(u32::MAX as u64) as u32,
key,
);
let _ = tx.send(resp).await;
}
Err(_) => {
let resp = build_error_response_with_integrity(
&msg.header,
437,
"Allocation Mismatch",
key,
);
let _ = tx.send(resp).await;
}
}
}
METHOD_BINDING => {
if rate_limiters.allow_binding(peer.ip()) {
let resp = crate::stun::build_binding_success(&msg.header, &peer);
let _ = tx.send(resp).await;
} else {
crate::metrics::inc_rate_limited();
}
}
_ => {
let resp = match auth_key.as_deref() {
Some(key) => build_error_response_with_integrity(
&msg.header,
420,
"Unknown TURN Method",
key,
),
None => build_error_response(
&msg.header,
420,
"Unknown TURN Method",
),
};
let _ = tx.send(resp).await;
}
}
}
}
}
}
Err(e) => return Err(anyhow::anyhow!("stream read error: {e:?}")),
}
}
// Stop writer.
drop(tx);
let _ = writer_task.await;
Ok(())
}

View File

@ -6,7 +6,7 @@ mod support;
mod helpers; mod helpers;
use helpers::*; use helpers::*;
use niom_turn::alloc::AllocationManager; use niom_turn::alloc::{AllocationManager, ClientSink};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -16,7 +16,13 @@ async fn allocate_sample(manager: &AllocationManager) -> SocketAddr {
let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await.expect("udp bind")); let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await.expect("udp bind"));
let client = sample_client(); let client = sample_client();
manager manager
.allocate_for(client, server) .allocate_for(
client,
ClientSink::Udp {
sock: server,
addr: client,
},
)
.await .await
.expect("allocate relay"); .expect("allocate relay");
client client

View File

@ -11,14 +11,12 @@ use niom_turn::alloc::AllocationManager;
use support::{default_test_credentials, init_tracing, test_auth_manager}; use support::{default_test_credentials, init_tracing, test_auth_manager};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tls_rejects_invalid_credentials() { async fn tls_rejects_invalid_credentials() {
init_tracing(); init_tracing();
let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind");
let udp_arc = Arc::new(udp);
let (username, password) = default_test_credentials(); let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password); let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new(); let allocs = AllocationManager::new();
@ -34,7 +32,6 @@ async fn tls_rejects_invalid_credentials() {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
let udp_clone = udp_arc.clone();
let auth_clone = auth.clone(); let auth_clone = auth.clone();
let alloc_clone = allocs.clone(); let alloc_clone = allocs.clone();
tokio::spawn(async move { tokio::spawn(async move {
@ -44,21 +41,18 @@ async fn tls_rejects_invalid_credentials() {
Err(_) => break, Err(_) => break,
}; };
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
let udp_clone = udp_clone.clone();
let auth_clone = auth_clone.clone(); let auth_clone = auth_clone.clone();
let alloc_clone = alloc_clone.clone(); let alloc_clone = alloc_clone.clone();
tokio::spawn(async move { tokio::spawn(async move {
match acceptor.accept(stream).await { match acceptor.accept(stream).await {
Ok(mut tls_stream) => { Ok(tls_stream) => {
if let Err(e) = niom_turn::tls::handle_tls_connection( if let Err(e) = niom_turn::tls::handle_tls_connection(
&mut tls_stream, tls_stream,
peer, peer,
udp_clone,
auth_clone, auth_clone,
alloc_clone, alloc_clone,
) )
.await .await {
{
tracing::error!("tls connection error: {:?}", e); tracing::error!("tls connection error: {:?}", e);
} }
} }

View File

@ -12,14 +12,12 @@ use niom_turn::auth;
use std::sync::Arc; use std::sync::Arc;
use support::{default_test_credentials, init_tracing, test_auth_manager}; use support::{default_test_credentials, init_tracing, test_auth_manager};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tls_channel_bind_without_allocation_returns_mismatch() { async fn tls_channel_bind_without_allocation_returns_mismatch() {
init_tracing(); init_tracing();
let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind");
let udp_arc = Arc::new(udp);
let (username, password) = default_test_credentials(); let (username, password) = default_test_credentials();
let auth_manager = test_auth_manager(username, password); let auth_manager = test_auth_manager(username, password);
let allocs = AllocationManager::new(); let allocs = AllocationManager::new();
@ -35,7 +33,6 @@ async fn tls_channel_bind_without_allocation_returns_mismatch() {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
let udp_clone = udp_arc.clone();
let auth_clone = auth_manager.clone(); let auth_clone = auth_manager.clone();
let alloc_clone = allocs.clone(); let alloc_clone = allocs.clone();
tokio::spawn(async move { tokio::spawn(async move {
@ -45,21 +42,18 @@ async fn tls_channel_bind_without_allocation_returns_mismatch() {
Err(_) => break, Err(_) => break,
}; };
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
let udp_clone = udp_clone.clone();
let auth_clone = auth_clone.clone(); let auth_clone = auth_clone.clone();
let alloc_clone = alloc_clone.clone(); let alloc_clone = alloc_clone.clone();
tokio::spawn(async move { tokio::spawn(async move {
match acceptor.accept(stream).await { match acceptor.accept(stream).await {
Ok(mut tls_stream) => { Ok(tls_stream) => {
if let Err(e) = niom_turn::tls::handle_tls_connection( if let Err(e) = niom_turn::tls::handle_tls_connection(
&mut tls_stream, tls_stream,
peer, peer,
udp_clone,
auth_clone, auth_clone,
alloc_clone, alloc_clone,
) )
.await .await {
{
tracing::error!("tls connection error: {:?}", e); tracing::error!("tls connection error: {:?}", e);
} }
} }

View File

@ -27,6 +27,7 @@ async fn channel_sink_mock_records_payload() {
fn parse_channel_data_round_trip() { fn parse_channel_data_round_trip() {
let payload = sample_payload(); let payload = sample_payload();
let frame = build_channel_data(sample_channel_number(), &payload); let frame = build_channel_data(sample_channel_number(), &payload);
assert_eq!(frame.len(), 4 + payload.len());
let (channel, body) = parse_channel_data(&frame).expect("parse channel frame"); let (channel, body) = parse_channel_data(&frame).expect("parse channel frame");
assert_eq!(channel, sample_channel_number()); assert_eq!(channel, sample_channel_number());
assert_eq!(body, payload.as_slice()); assert_eq!(body, payload.as_slice());

View File

@ -10,15 +10,13 @@ use niom_turn::alloc::AllocationManager;
use std::sync::Arc; use std::sync::Arc;
use support::{init_tracing, test_auth_manager, default_test_credentials}; use support::{init_tracing, test_auth_manager, default_test_credentials};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::TcpListener;
use tokio::time::{timeout, Duration}; use tokio::time::{timeout, Duration};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn malformed_tls_frame_is_ignored() { async fn malformed_tls_frame_is_ignored() {
init_tracing(); init_tracing();
let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind");
let udp_arc = Arc::new(udp);
let (username, password) = default_test_credentials(); let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password); let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new(); let allocs = AllocationManager::new();
@ -34,7 +32,6 @@ async fn malformed_tls_frame_is_ignored() {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
let udp_clone = udp_arc.clone();
let auth_clone = auth.clone(); let auth_clone = auth.clone();
let alloc_clone = allocs.clone(); let alloc_clone = allocs.clone();
tokio::spawn(async move { tokio::spawn(async move {
@ -44,16 +41,14 @@ async fn malformed_tls_frame_is_ignored() {
Err(_) => break, Err(_) => break,
}; };
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
let udp_clone = udp_clone.clone();
let auth_clone = auth_clone.clone(); let auth_clone = auth_clone.clone();
let alloc_clone = alloc_clone.clone(); let alloc_clone = alloc_clone.clone();
tokio::spawn(async move { tokio::spawn(async move {
match acceptor.accept(stream).await { match acceptor.accept(stream).await {
Ok(mut tls_stream) => { Ok(tls_stream) => {
let _ = niom_turn::tls::handle_tls_connection( let _ = niom_turn::tls::handle_tls_connection(
&mut tls_stream, tls_stream,
peer, peer,
udp_clone,
auth_clone, auth_clone,
alloc_clone, alloc_clone,
) )

153
tests/rate_limit_tcp.rs Normal file
View File

@ -0,0 +1,153 @@
use std::net::SocketAddr;
use std::sync::Arc;
use niom_turn::alloc::AllocationManager;
use niom_turn::auth::InMemoryStore;
use niom_turn::config::LimitsOptions;
use niom_turn::rate_limit::RateLimiters;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{timeout, Duration};
use crate::support::stream::{StreamFrame, StreamFramer};
use crate::support::stun_builders::{build_allocate_request, build_binding_request};
use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager};
mod support;
async fn start_tcp_test_server(
auth: niom_turn::auth::AuthManager<InMemoryStore>,
allocs: AllocationManager,
rate_limiters: Arc<RateLimiters>,
) -> SocketAddr {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
tokio::spawn(async move {
loop {
let (stream, peer) = match tcp_listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
let rl = rate_limiters.clone();
tokio::spawn(async move {
let _ = niom_turn::turn_stream::handle_turn_stream_connection_with_limits(
stream, peer, auth_clone, alloc_clone, rl,
)
.await;
});
}
});
tcp_addr
}
#[tokio::test]
async fn tcp_binding_is_rate_limited_by_ip() {
init_tracing_with("warn,niom_turn=info");
// Configure a very small burst to make the test deterministic.
let mut limits = LimitsOptions::default();
limits.binding_rps = Some(1);
limits.binding_burst = Some(1);
let rate_limiters = Arc::new(RateLimiters::from_limits(&limits));
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_addr = start_tcp_test_server(auth.clone(), allocs.clone(), rate_limiters.clone()).await;
let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect");
// Fire multiple Binding requests quickly; with burst=1 we should only get 1 success response.
for _ in 0..3 {
let req = build_binding_request();
stream.write_all(&req).await.expect("write binding");
}
let mut framer = StreamFramer::new();
let mut responses = 0usize;
// Read for a short bounded period.
let deadline = tokio::time::Instant::now() + Duration::from_millis(150);
loop {
let now = tokio::time::Instant::now();
if now >= deadline {
break;
}
let remaining = deadline - now;
let frame = match timeout(remaining, framer.read_frame(&mut stream)).await {
Ok(Ok(f)) => f,
Ok(Err(e)) => panic!("read_frame error: {e:?}"),
Err(_) => break,
};
match frame {
StreamFrame::Stun(msg) => {
assert_eq!(msg.header.msg_type & 0x0110, 0x0100);
responses += 1;
}
other => panic!("expected STUN response, got: {other:?}"),
}
}
assert_eq!(responses, 1, "expected exactly 1 Binding response under burst=1");
}
#[tokio::test]
async fn tcp_unauth_challenge_is_rate_limited_by_ip() {
init_tracing_with("warn,niom_turn=info");
let mut limits = LimitsOptions::default();
limits.unauth_rps = Some(1);
limits.unauth_burst = Some(1);
let rate_limiters = Arc::new(RateLimiters::from_limits(&limits));
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_addr = start_tcp_test_server(auth.clone(), allocs.clone(), rate_limiters.clone()).await;
let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect");
for _ in 0..3 {
let req = build_allocate_request(None, None, None, None, None);
stream.write_all(&req).await.expect("write allocate");
}
let mut framer = StreamFramer::new();
let mut responses = 0usize;
let deadline = tokio::time::Instant::now() + Duration::from_millis(150);
loop {
let now = tokio::time::Instant::now();
if now >= deadline {
break;
}
let remaining = deadline - now;
let frame = match timeout(remaining, framer.read_frame(&mut stream)).await {
Ok(Ok(f)) => f,
Ok(Err(e)) => panic!("read_frame error: {e:?}"),
Err(_) => break,
};
match frame {
StreamFrame::Stun(msg) => {
assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR);
msg.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
responses += 1;
}
other => panic!("expected STUN response, got: {other:?}"),
}
}
assert_eq!(responses, 1, "expected exactly 1 unauth challenge under burst=1");
}

148
tests/rate_limit_udp.rs Normal file
View File

@ -0,0 +1,148 @@
use std::sync::Arc;
use niom_turn::alloc::AllocationManager;
use niom_turn::config::LimitsOptions;
use niom_turn::rate_limit::RateLimiters;
use tokio::net::UdpSocket;
use crate::support::stun_builders::{build_allocate_request, build_binding_request, parse};
use crate::support::{default_test_credentials, init_tracing, test_auth_manager};
mod support;
#[tokio::test]
async fn udp_binding_is_rate_limited_by_ip() {
init_tracing();
// Configure a very small burst to make the test deterministic.
// We only limit Binding here to avoid affecting other integration tests.
let mut limits = LimitsOptions::default();
limits.binding_rps = Some(1);
limits.binding_burst = Some(1);
let rate_limiters = Arc::new(RateLimiters::from_limits(&limits));
// Start a UDP server loop.
let server = UdpSocket::bind("127.0.0.1:0").await.expect("server bind");
let server_addr = server.local_addr().expect("server addr");
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
tokio::spawn({
let reader = server_arc.clone();
let auth = auth.clone();
let allocs = allocs.clone();
let rl = rate_limiters.clone();
async move {
let _ = niom_turn::server::udp_reader_loop_with_limits(reader, auth, allocs, rl).await;
}
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
// Fire multiple Binding requests quickly; with burst=1 we should only get 1 success response.
for _ in 0..3 {
let req = build_binding_request();
client
.send_to(&req, server_addr)
.await
.expect("send binding");
}
let mut buf = [0u8; 1500];
let mut responses = 0usize;
// Collect responses for a short, bounded period.
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(150);
loop {
let now = tokio::time::Instant::now();
if now >= deadline {
break;
}
let remaining = deadline - now;
match tokio::time::timeout(remaining, client.recv_from(&mut buf)).await {
Ok(Ok((len, _from))) => {
let resp = parse(&buf[..len]);
// Success class
assert_eq!(resp.header.msg_type & 0x0110, 0x0100);
responses += 1;
}
Ok(Err(e)) => panic!("recv error: {e}"),
Err(_) => break, // timeout
}
}
assert_eq!(responses, 1, "expected exactly 1 Binding response under burst=1");
}
#[tokio::test]
async fn udp_unauth_challenge_is_rate_limited_by_ip() {
init_tracing();
// Configure a very small burst so only the first unauth challenge is answered.
let mut limits = LimitsOptions::default();
limits.unauth_rps = Some(1);
limits.unauth_burst = Some(1);
let rate_limiters = Arc::new(RateLimiters::from_limits(&limits));
// Start a UDP server loop.
let server = UdpSocket::bind("127.0.0.1:0").await.expect("server bind");
let server_addr = server.local_addr().expect("server addr");
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
tokio::spawn({
let reader = server_arc.clone();
let auth = auth.clone();
let allocs = allocs.clone();
let rl = rate_limiters.clone();
async move {
let _ = niom_turn::server::udp_reader_loop_with_limits(reader, auth, allocs, rl).await;
}
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
for _ in 0..3 {
let req = build_allocate_request(None, None, None, None, None);
client
.send_to(&req, server_addr)
.await
.expect("send unauth allocate");
}
let mut buf = [0u8; 1500];
let mut responses = 0usize;
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(150);
loop {
let now = tokio::time::Instant::now();
if now >= deadline {
break;
}
let remaining = deadline - now;
match tokio::time::timeout(remaining, client.recv_from(&mut buf)).await {
Ok(Ok((len, _from))) => {
let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR);
resp.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
responses += 1;
}
Ok(Err(e)) => panic!("recv error: {e}"),
Err(_) => break,
}
}
assert_eq!(responses, 1, "expected exactly 1 unauth challenge under burst=1");
}

View File

@ -1,4 +1,5 @@
pub mod mocks; pub mod mocks;
pub mod stream;
pub mod stun_builders; pub mod stun_builders;
pub mod tls; pub mod tls;

78
tests/support/stream.rs Normal file
View File

@ -0,0 +1,78 @@
#![allow(dead_code)]
use std::io;
use niom_turn::models::stun::StunMessage;
use niom_turn::stun::parse_message;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
#[derive(Debug)]
pub enum StreamFrame {
Stun(StunMessage),
ChannelData { channel: u16, payload: Vec<u8> },
}
#[derive(Default)]
pub struct StreamFramer {
buffer: Vec<u8>,
}
impl StreamFramer {
pub fn new() -> Self {
Self { buffer: Vec::new() }
}
fn try_pop_next(&mut self) -> Option<io::Result<StreamFrame>> {
if self.buffer.len() < 4 {
return None;
}
// ChannelData: channel number 0x4000..=0x7FFF (top bits 01)
let channel = u16::from_be_bytes([self.buffer[0], self.buffer[1]]);
if (channel & 0xC000) == 0x4000 {
let len = u16::from_be_bytes([self.buffer[2], self.buffer[3]]) as usize;
let total = 4 + len;
if self.buffer.len() < total {
return None;
}
let frame = self.buffer.drain(..total).collect::<Vec<u8>>();
return Some(Ok(StreamFrame::ChannelData {
channel,
payload: frame[4..].to_vec(),
}));
}
// STUN over stream: 20 byte header + length.
if self.buffer.len() < 20 {
return None;
}
let len = u16::from_be_bytes([self.buffer[2], self.buffer[3]]) as usize;
let total = 20 + len;
if self.buffer.len() < total {
return None;
}
let chunk = self.buffer.drain(..total).collect::<Vec<u8>>();
match parse_message(&chunk) {
Ok(msg) => Some(Ok(StreamFrame::Stun(msg))),
Err(e) => Some(Err(io::Error::new(io::ErrorKind::InvalidData, format!(
"parse stun: {e:?}"
)))),
}
}
pub async fn read_frame<R: AsyncRead + Unpin>(&mut self, reader: &mut R) -> io::Result<StreamFrame> {
loop {
if let Some(frame) = self.try_pop_next() {
return frame;
}
let mut tmp = [0u8; 4096];
let n = reader.read(&mut tmp).await?;
if n == 0 {
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed"));
}
self.buffer.extend_from_slice(&tmp[..n]);
}
}
}

View File

@ -21,7 +21,30 @@ pub fn build_allocate_request(
key: Option<&[u8]>, key: Option<&[u8]>,
lifetime: Option<u32>, lifetime: Option<u32>,
) -> Vec<u8> { ) -> Vec<u8> {
build_authenticated_request( build_allocate_request_with_requested_transport(
username,
realm,
nonce,
key,
lifetime,
Some(IPPROTO_UDP),
)
}
/// Construct a TURN Allocate request with explicit REQUESTED-TRANSPORT.
///
/// - `Some(IPPROTO_UDP)` (17) is the WebRTC default.
/// - `Some(IPPROTO_TCP)` (6) should be rejected by this server (UDP relay only).
/// - `None` omits the attribute (should be rejected after auth).
pub fn build_allocate_request_with_requested_transport(
username: Option<&str>,
realm: Option<&str>,
nonce: Option<&str>,
key: Option<&[u8]>,
lifetime: Option<u32>,
requested_transport: Option<u8>,
) -> Vec<u8> {
build_request_with_body(
METHOD_ALLOCATE, METHOD_ALLOCATE,
username, username,
realm, realm,
@ -30,6 +53,24 @@ pub fn build_allocate_request(
lifetime, lifetime,
None, None,
None, None,
None,
requested_transport,
)
}
/// Build a basic STUN Binding request (no auth).
pub fn build_binding_request() -> Vec<u8> {
build_request_with_body(
METHOD_BINDING,
None,
None,
None,
None,
None,
None,
None,
None,
None,
) )
} }
@ -52,6 +93,7 @@ pub fn build_refresh_request(
None, None,
None, None,
Some(trans), Some(trans),
None,
) )
} }
@ -73,6 +115,7 @@ pub fn build_create_permission_request(
Some(peer), Some(peer),
None, None,
None, None,
None,
) )
} }
@ -95,6 +138,7 @@ pub fn build_send_request(
Some(peer), Some(peer),
Some(payload), Some(payload),
None, None,
None,
) )
} }
@ -148,7 +192,16 @@ fn build_authenticated_request(
payload: Option<&[u8]>, payload: Option<&[u8]>,
) -> Vec<u8> { ) -> Vec<u8> {
build_request_with_body( build_request_with_body(
method, username, realm, nonce, key, lifetime, peer, payload, None, method,
username,
realm,
nonce,
key,
lifetime,
peer,
payload,
None,
None,
) )
} }
@ -162,6 +215,7 @@ fn build_request_with_body(
peer: Option<&std::net::SocketAddr>, peer: Option<&std::net::SocketAddr>,
payload: Option<&[u8]>, payload: Option<&[u8]>,
override_trans: Option<[u8; 12]>, override_trans: Option<[u8; 12]>,
requested_transport: Option<u8>,
) -> Vec<u8> { ) -> Vec<u8> {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
buf.extend_from_slice(&method.to_be_bytes()); buf.extend_from_slice(&method.to_be_bytes());
@ -182,6 +236,12 @@ fn build_request_with_body(
if let Some(lifetime) = lifetime { if let Some(lifetime) = lifetime {
push_u32_attr(&mut buf, ATTR_LIFETIME, lifetime); push_u32_attr(&mut buf, ATTR_LIFETIME, lifetime);
} }
if method == METHOD_ALLOCATE {
if let Some(proto) = requested_transport {
push_bytes_attr(&mut buf, ATTR_REQUESTED_TRANSPORT, &[proto, 0, 0, 0]);
}
}
if let Some(peer) = peer { if let Some(peer) = peer {
let encoded = niom_turn::stun::encode_xor_peer_address(peer, &trans); let encoded = niom_turn::stun::encode_xor_peer_address(peer, &trans);
push_bytes_attr(&mut buf, ATTR_XOR_PEER_ADDRESS, &encoded); push_bytes_attr(&mut buf, ATTR_XOR_PEER_ADDRESS, &encoded);
@ -221,23 +281,21 @@ fn push_bytes_attr(buf: &mut BytesMut, typ: u16, data: &[u8]) {
} }
fn append_message_integrity(buf: &mut BytesMut, key: &[u8]) { fn append_message_integrity(buf: &mut BytesMut, key: &[u8]) {
// position before adding MESSAGE-INTEGRITY attribute
let attribute_start = buf.len();
// append attribute header and placeholder value // append attribute header and placeholder value
buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes()); buf.extend_from_slice(&ATTR_MESSAGE_INTEGRITY.to_be_bytes());
buf.extend_from_slice(&(20u16.to_be_bytes())); buf.extend_from_slice(&(20u16.to_be_bytes()));
let value_start = buf.len(); let value_start = buf.len();
buf.extend_from_slice(&[0u8; 20]); buf.extend_from_slice(&[0u8; 20]);
// update message length to include the attribute (spec requires this before HMAC) // update message length to end-of-MI (exclude any later attributes like FINGERPRINT)
let total_len = (buf.len() - 20) as u16; let mi_end = buf.len();
let total_len = (mi_end - 20) as u16;
let len_bytes = total_len.to_be_bytes(); let len_bytes = total_len.to_be_bytes();
buf[2] = len_bytes[0]; buf[2] = len_bytes[0];
buf[3] = len_bytes[1]; buf[3] = len_bytes[1];
// compute the HMAC over all bytes preceding the attribute (RFC 5389 §15.4) // compute the HMAC over the message up to end-of-MI (MI value is zero here)
let signed = compute_message_integrity(key, &buf[..attribute_start]); let signed = compute_message_integrity(key, &buf[..mi_end]);
// write the computed MAC into the placeholder we appended above // write the computed MAC into the placeholder we appended above
buf[value_start..value_start + 20].copy_from_slice(&signed[..20]); buf[value_start..value_start + 20].copy_from_slice(&signed[..20]);

365
tests/tcp_turn.rs Normal file
View File

@ -0,0 +1,365 @@
use std::net::SocketAddr;
use niom_turn::alloc::AllocationManager;
use niom_turn::auth::{compute_a1_md5, InMemoryStore};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::time::{timeout, Duration};
use crate::support::stream::{StreamFrame, StreamFramer};
use crate::support::stun_builders::{
build_allocate_request, build_channel_bind_request, build_create_permission_request,
build_send_request,
};
use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager};
mod support;
async fn start_tcp_test_server(
auth: niom_turn::auth::AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> SocketAddr {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
tokio::spawn(async move {
loop {
let (stream, peer) = match tcp_listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
if let Err(e) = niom_turn::turn_stream::handle_turn_stream_connection(
stream, peer, auth_clone, alloc_clone,
)
.await
{
tracing::info!("tcp connection ended: {:?}", e);
}
});
}
});
tcp_addr
}
#[tokio::test]
async fn tcp_stream_resyncs_after_garbage_and_still_processes_allocate() {
init_tracing_with("warn,niom_turn=info");
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_addr = start_tcp_test_server(auth.clone(), allocs.clone()).await;
let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect");
// Send garbage bytes that look like a STUN header with a huge length but invalid cookie.
// Without resync, the server could wait for 20+65535 bytes and stall parsing.
let garbage = [0x00, 0x01, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00];
stream.write_all(&garbage).await.expect("write garbage");
// Now send a valid Allocate request; server should still respond with a 401 challenge.
let allocate = build_allocate_request(None, None, None, None, None);
stream.write_all(&allocate).await.expect("write allocate");
let mut framer = StreamFramer::new();
let frame = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout response")
.expect("read response");
match frame {
StreamFrame::Stun(msg) => {
assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR);
msg.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
}
other => panic!("expected STUN challenge, got: {:?}", other),
}
}
#[tokio::test]
async fn tcp_peer_data_is_delivered_over_stream_as_data_indication() {
init_tracing_with("warn,niom_turn=info");
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_addr = start_tcp_test_server(auth.clone(), allocs.clone()).await;
let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect");
let client_addr = stream.local_addr().expect("client addr");
// 1) Allocate without auth -> 401 + NONCE
let allocate = build_allocate_request(None, None, None, None, None);
stream.write_all(&allocate).await.expect("write allocate");
let mut framer = StreamFramer::new();
let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout challenge")
.expect("read challenge");
let nonce = match challenge {
StreamFrame::Stun(msg) => {
assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_ERROR);
let nonce_attr = msg
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8")
}
_ => panic!("expected STUN 401 challenge"),
};
// 2) Authenticated allocate
let key = compute_a1_md5(username, auth.realm(), password);
let allocate = build_allocate_request(
Some(username),
Some(auth.realm()),
Some(&nonce),
Some(&key),
Some(600),
);
stream.write_all(&allocate).await.expect("write auth allocate");
let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout alloc success")
.expect("read alloc success");
match alloc_success {
StreamFrame::Stun(msg) => {
assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_SUCCESS);
msg.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS)
.expect("xor-relayed attr");
}
_ => panic!("expected STUN allocate success"),
}
let relay_addr = allocs
.get_allocation(&client_addr)
.expect("allocation exists")
.relay_addr;
// 3) CreatePermission
let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind");
let perm = build_create_permission_request(
username,
auth.realm(),
&nonce,
&key,
&peer_sock.local_addr().unwrap(),
);
stream.write_all(&perm).await.expect("write create permission");
let perm_resp = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout perm resp")
.expect("read perm resp");
match perm_resp {
StreamFrame::Stun(msg) => assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_SUCCESS),
_ => panic!("expected STUN permission success"),
}
// 4) Send indication -> peer should receive UDP payload
let payload = b"hello-turn-tcp";
let send = build_send_request(
username,
auth.realm(),
&nonce,
&key,
&peer_sock.local_addr().unwrap(),
payload,
);
stream.write_all(&send).await.expect("write send");
// This implementation responds to SEND with a success response.
let send_resp = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout send resp")
.expect("read send resp");
match send_resp {
StreamFrame::Stun(msg) => {
assert_eq!(
msg.header.msg_type,
niom_turn::constants::METHOD_SEND | niom_turn::constants::CLASS_SUCCESS
);
}
_ => panic!("expected STUN send success"),
}
let mut peer_buf = [0u8; 1500];
let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf))
.await
.expect("timeout peer recv")
.expect("peer recv");
assert_eq!(&peer_buf[..n], payload);
assert!(from.ip().is_loopback());
// 5) Peer -> relay -> client: should come back over TCP as Data Indication
let back = b"peer-reply";
peer_sock
.send_to(back, relay_addr)
.await
.expect("peer send back");
let frame = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout data indication")
.expect("read data indication");
match frame {
StreamFrame::Stun(msg) => {
assert_eq!(msg.header.msg_type, niom_turn::constants::METHOD_DATA | niom_turn::constants::CLASS_INDICATION);
let data_attr = msg
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_DATA)
.expect("data attr");
assert_eq!(data_attr.value.as_slice(), back);
}
_ => panic!("expected STUN data indication"),
}
// sanity: allocation exists by client addr
assert!(allocs.get_allocation(&client_addr).is_some());
}
#[tokio::test]
async fn tcp_channel_data_round_trip_works() {
init_tracing_with("warn,niom_turn=info");
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_addr = start_tcp_test_server(auth.clone(), allocs.clone()).await;
let mut stream = TcpStream::connect(server_addr).await.expect("tcp connect");
let allocate = build_allocate_request(None, None, None, None, None);
stream.write_all(&allocate).await.expect("write allocate");
let mut framer = StreamFramer::new();
let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout challenge")
.expect("read challenge");
let nonce = match challenge {
StreamFrame::Stun(msg) => {
let nonce_attr = msg
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8")
}
_ => panic!("expected STUN 401 challenge"),
};
let key = compute_a1_md5(username, auth.realm(), password);
let allocate = build_allocate_request(
Some(username),
Some(auth.realm()),
Some(&nonce),
Some(&key),
Some(600),
);
stream.write_all(&allocate).await.expect("write auth allocate");
let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout alloc success")
.expect("read alloc success");
match alloc_success {
StreamFrame::Stun(msg) => {
msg.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS)
.expect("xor-relayed attr");
}
_ => panic!("expected STUN allocate success"),
}
let client_addr = stream.local_addr().expect("client addr");
let relay_addr = allocs
.get_allocation(&client_addr)
.expect("allocation exists")
.relay_addr;
let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind");
// Permission
let perm = build_create_permission_request(
username,
auth.realm(),
&nonce,
&key,
&peer_sock.local_addr().unwrap(),
);
stream.write_all(&perm).await.expect("write create permission");
let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout perm resp")
.expect("read perm resp");
// ChannelBind
let channel: u16 = 0x4000;
let bind = build_channel_bind_request(
username,
auth.realm(),
&nonce,
&key,
channel,
&peer_sock.local_addr().unwrap(),
);
stream.write_all(&bind).await.expect("write channel bind");
let bind_resp = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout bind resp")
.expect("read bind resp");
match bind_resp {
StreamFrame::Stun(msg) => assert_eq!(msg.header.msg_type & 0x0110, niom_turn::constants::CLASS_SUCCESS),
_ => panic!("expected STUN channel bind success"),
}
// Client -> Server -> Peer via ChannelData
let payload = b"chan-hello";
let ch = niom_turn::stun::build_channel_data(channel, payload);
stream.write_all(&ch).await.expect("write channel data");
let mut peer_buf = [0u8; 1500];
let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf))
.await
.expect("timeout peer recv")
.expect("peer recv");
assert_eq!(&peer_buf[..n], payload);
assert!(from.ip().is_loopback());
// Peer -> Relay -> Client as ChannelData
let back = b"chan-back";
peer_sock.send_to(back, relay_addr).await.expect("peer send back");
let frame = timeout(Duration::from_secs(2), framer.read_frame(&mut stream))
.await
.expect("timeout channel back")
.expect("read channel back");
match frame {
StreamFrame::ChannelData { channel: chn, payload } => {
assert_eq!(chn, channel);
assert_eq!(payload.as_slice(), back);
}
_ => panic!("expected ChannelData frame"),
}
}

206
tests/tls_data_plane.rs Normal file
View File

@ -0,0 +1,206 @@
use std::net::SocketAddr;
use std::sync::Arc;
use niom_turn::alloc::AllocationManager;
use niom_turn::auth::{compute_a1_md5, InMemoryStore};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::time::{timeout, Duration};
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
use crate::support::stream::{StreamFrame, StreamFramer};
use crate::support::stun_builders::{
build_allocate_request, build_channel_bind_request, build_create_permission_request,
};
use crate::support::{default_test_credentials, init_tracing_with, test_auth_manager};
mod support;
async fn start_tls_test_server(
acceptor: TlsAcceptor,
auth: niom_turn::auth::AuthManager<InMemoryStore>,
allocs: AllocationManager,
) -> SocketAddr {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
tokio::spawn(async move {
loop {
let (stream, peer) = match tcp_listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acceptor = acceptor.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = niom_turn::tls::handle_tls_connection(
tls_stream,
peer,
auth_clone,
alloc_clone,
)
.await
{
tracing::info!("tls connection ended: {:?}", e);
}
}
Err(e) => tracing::error!("tls accept failed: {:?}", e),
}
});
}
});
tcp_addr
}
#[tokio::test]
async fn tls_channel_data_round_trip_works() {
init_tracing_with("warn,niom_turn=info");
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let (cert, key) = support::tls::generate_self_signed_cert();
let mut cfg = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key)
.expect("server config");
cfg.alpn_protocols.push(b"turn".to_vec());
let acceptor = TlsAcceptor::from(Arc::new(cfg));
let server_addr = start_tls_test_server(acceptor, auth.clone(), allocs.clone()).await;
// client config trusting generated cert
let mut root_store = tokio_rustls::rustls::RootCertStore::empty();
root_store.add(&cert).expect("add root");
let client_config = tokio_rustls::rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
let tcp_stream = TcpStream::connect(server_addr).await.expect("tcp connect");
let client_addr = tcp_stream.local_addr().expect("client addr");
let domain = tokio_rustls::rustls::ServerName::try_from("localhost").unwrap();
let mut tls_stream = connector
.connect(domain, tcp_stream)
.await
.expect("tls connect");
// allocate (unauth -> nonce)
let allocate = build_allocate_request(None, None, None, None, None);
tls_stream.write_all(&allocate).await.expect("write allocate");
let mut framer = StreamFramer::new();
let challenge = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream))
.await
.expect("timeout challenge")
.expect("read challenge");
let nonce = match challenge {
StreamFrame::Stun(msg) => {
let nonce_attr = msg
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr");
String::from_utf8(nonce_attr.value.clone()).expect("nonce utf8")
}
_ => panic!("expected STUN 401 challenge"),
};
// auth allocate
let key = compute_a1_md5(username, auth.realm(), password);
let allocate = build_allocate_request(
Some(username),
Some(auth.realm()),
Some(&nonce),
Some(&key),
Some(600),
);
tls_stream.write_all(&allocate).await.expect("write auth allocate");
let alloc_success = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream))
.await
.expect("timeout alloc success")
.expect("read alloc success");
match alloc_success {
StreamFrame::Stun(msg) => {
msg.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_XOR_RELAYED_ADDRESS)
.expect("xor-relayed attr");
}
_ => panic!("expected STUN allocate success"),
}
let relay_addr = allocs
.get_allocation(&client_addr)
.expect("allocation exists")
.relay_addr;
// set up peer + permission
let peer_sock = UdpSocket::bind("127.0.0.1:0").await.expect("peer bind");
let perm = build_create_permission_request(
username,
auth.realm(),
&nonce,
&key,
&peer_sock.local_addr().unwrap(),
);
tls_stream.write_all(&perm).await.expect("write create permission");
let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream))
.await
.expect("timeout perm resp")
.expect("read perm resp");
// channel bind
let channel: u16 = 0x4000;
let bind = build_channel_bind_request(
username,
auth.realm(),
&nonce,
&key,
channel,
&peer_sock.local_addr().unwrap(),
);
tls_stream.write_all(&bind).await.expect("write channel bind");
let _ = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream))
.await
.expect("timeout bind resp")
.expect("read bind resp");
// client -> peer via ChannelData
let payload = b"tls-chan";
let frame = niom_turn::stun::build_channel_data(channel, payload);
tls_stream.write_all(&frame).await.expect("write channel data");
let mut peer_buf = [0u8; 1500];
let (n, from) = timeout(Duration::from_secs(2), peer_sock.recv_from(&mut peer_buf))
.await
.expect("timeout peer recv")
.expect("peer recv");
assert_eq!(&peer_buf[..n], payload);
assert!(from.ip().is_loopback());
// peer -> client as ChannelData over TLS
let back = b"tls-back";
peer_sock.send_to(back, relay_addr).await.expect("peer send back");
let received = timeout(Duration::from_secs(2), framer.read_frame(&mut tls_stream))
.await
.expect("timeout channel back")
.expect("read channel back");
match received {
StreamFrame::ChannelData { channel: ch, payload } => {
assert_eq!(ch, channel);
assert_eq!(payload.as_slice(), back);
}
other => panic!("expected ChannelData frame, got: {:?}", other),
}
}

View File

@ -3,7 +3,7 @@ use std::sync::Arc;
use niom_turn::alloc::AllocationManager; use niom_turn::alloc::AllocationManager;
use niom_turn::stun::parse_message; use niom_turn::stun::parse_message;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::TcpListener;
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
use crate::support::stun_builders::{build_allocate_request, build_refresh_request}; use crate::support::stun_builders::{build_allocate_request, build_refresh_request};
@ -15,8 +15,6 @@ mod support;
async fn tls_allocate_refresh_flow() { async fn tls_allocate_refresh_flow() {
init_tracing_with("warn,niom_turn=info"); init_tracing_with("warn,niom_turn=info");
let udp = UdpSocket::bind("127.0.0.1:0").await.expect("udp bind");
let udp_arc = Arc::new(udp);
let (username, password) = default_test_credentials(); let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password); let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new(); let allocs = AllocationManager::new();
@ -32,7 +30,6 @@ async fn tls_allocate_refresh_flow() {
let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind"); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.expect("tcp bind");
let tcp_addr = tcp_listener.local_addr().expect("tcp addr"); let tcp_addr = tcp_listener.local_addr().expect("tcp addr");
let udp_clone = udp_arc.clone();
let auth_clone = auth.clone(); let auth_clone = auth.clone();
let alloc_clone = allocs.clone(); let alloc_clone = allocs.clone();
@ -43,21 +40,18 @@ async fn tls_allocate_refresh_flow() {
Err(_) => break, Err(_) => break,
}; };
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
let udp_clone = udp_clone.clone();
let auth_clone = auth_clone.clone(); let auth_clone = auth_clone.clone();
let alloc_clone = alloc_clone.clone(); let alloc_clone = alloc_clone.clone();
tokio::spawn(async move { tokio::spawn(async move {
match acceptor.accept(stream).await { match acceptor.accept(stream).await {
Ok(mut tls_stream) => { Ok(tls_stream) => {
match niom_turn::tls::handle_tls_connection( match niom_turn::tls::handle_tls_connection(
&mut tls_stream, tls_stream,
peer, peer,
udp_clone,
auth_clone, auth_clone,
alloc_clone, alloc_clone,
) )
.await .await {
{
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
tracing::error!("tls connection error: {:?}", e); tracing::error!("tls connection error: {:?}", e);

View File

@ -4,11 +4,14 @@ use std::sync::Arc;
use niom_turn::alloc::AllocationManager; use niom_turn::alloc::AllocationManager;
use niom_turn::auth::InMemoryStore; use niom_turn::auth::InMemoryStore;
use niom_turn::server::udp_reader_loop; use niom_turn::server::udp_reader_loop;
use niom_turn::stun::parse_message; use niom_turn::stun::{
find_fingerprint, find_message_integrity, parse_message, validate_fingerprint_if_present,
validate_message_integrity,
};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::support::stun_builders::{ use crate::support::stun_builders::{
build_allocate_request, build_create_permission_request, build_refresh_request, build_allocate_request, build_allocate_request_with_requested_transport, build_create_permission_request, build_refresh_request,
build_send_request, new_transaction_id, parse, build_send_request, new_transaction_id, parse,
}; };
use crate::support::{default_test_credentials, init_tracing, test_auth_manager}; use crate::support::{default_test_credentials, init_tracing, test_auth_manager};
@ -70,6 +73,120 @@ async fn allocate_requires_auth_then_succeeds() {
let (len, _) = client.recv_from(&mut buf).await.expect("recv success"); let (len, _) = client.recv_from(&mut buf).await.expect("recv success");
let resp = parse(&buf[..len]); let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, 0x0100); assert_eq!(resp.header.msg_type & 0x0110, 0x0100);
assert!(find_message_integrity(&resp).is_some());
assert!(validate_message_integrity(&resp, &key));
assert!(find_fingerprint(&resp).is_some());
assert!(validate_fingerprint_if_present(&resp));
}
#[tokio::test]
async fn authenticated_allocate_rejects_missing_requested_transport() {
init_tracing();
let (server, client_addr) = start_udp_server().await;
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
let server_clone = server_arc.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await;
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
// Get nonce via normal unauth Allocate (builder includes requested-transport=UDP)
let req = build_allocate_request(None, None, None, None, None);
client.send_to(&req, client_addr).await.expect("send unauth allocate");
let mut buf = [0u8; 1500];
let (len, _) = client.recv_from(&mut buf).await.expect("recv challenge");
let resp = parse_message(&buf[..len]).expect("parse 401");
let nonce = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr")
.value
.clone();
let nonce_str = String::from_utf8(nonce).expect("nonce utf8");
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
// Authenticated Allocate but omit REQUESTED-TRANSPORT => 400
let req = build_allocate_request_with_requested_transport(
Some(username),
Some(auth.realm()),
Some(&nonce_str),
Some(&key),
Some(600),
None,
);
client.send_to(&req, client_addr).await.expect("send auth allocate");
let (len, _) = client.recv_from(&mut buf).await.expect("recv error");
let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, 0x0110);
assert!(find_message_integrity(&resp).is_some());
assert!(validate_message_integrity(&resp, &key));
assert!(find_fingerprint(&resp).is_some());
assert!(validate_fingerprint_if_present(&resp));
let code = crate::support::stun_builders::extract_error_code(&resp).expect("error code");
assert_eq!(code, 400);
}
#[tokio::test]
async fn authenticated_allocate_rejects_tcp_requested_transport() {
init_tracing();
let (server, client_addr) = start_udp_server().await;
let (username, password) = default_test_credentials();
let auth = test_auth_manager(username, password);
let allocs = AllocationManager::new();
let server_arc = Arc::new(server);
let server_clone = server_arc.clone();
let auth_clone = auth.clone();
let alloc_clone = allocs.clone();
tokio::spawn(async move {
let _ = udp_reader_loop(server_clone, auth_clone, alloc_clone).await;
});
let client = UdpSocket::bind("127.0.0.1:0").await.expect("client bind");
// Get nonce
let req = build_allocate_request(None, None, None, None, None);
client.send_to(&req, client_addr).await.expect("send unauth allocate");
let mut buf = [0u8; 1500];
let (len, _) = client.recv_from(&mut buf).await.expect("recv challenge");
let resp = parse_message(&buf[..len]).expect("parse 401");
let nonce = resp
.attributes
.iter()
.find(|a| a.typ == niom_turn::constants::ATTR_NONCE)
.expect("nonce attr")
.value
.clone();
let nonce_str = String::from_utf8(nonce).expect("nonce utf8");
let key = niom_turn::auth::compute_a1_md5(username, auth.realm(), password);
// Request TCP relay (unsupported): 442
let req = build_allocate_request_with_requested_transport(
Some(username),
Some(auth.realm()),
Some(&nonce_str),
Some(&key),
Some(600),
Some(niom_turn::constants::IPPROTO_TCP),
);
client.send_to(&req, client_addr).await.expect("send auth allocate");
let (len, _) = client.recv_from(&mut buf).await.expect("recv error");
let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, 0x0110);
assert!(find_message_integrity(&resp).is_some());
assert!(validate_message_integrity(&resp, &key));
assert!(find_fingerprint(&resp).is_some());
assert!(validate_fingerprint_if_present(&resp));
let code = crate::support::stun_builders::extract_error_code(&resp).expect("error code");
assert_eq!(code, 442);
} }
#[tokio::test] #[tokio::test]
@ -105,6 +222,10 @@ async fn refresh_zero_lifetime_releases_allocation() {
let (len, _) = client.recv_from(&mut buf).await.expect("recv refresh resp"); let (len, _) = client.recv_from(&mut buf).await.expect("recv refresh resp");
let resp = parse(&buf[..len]); let resp = parse(&buf[..len]);
assert_eq!(resp.header.msg_type & 0x0110, 0x0100); assert_eq!(resp.header.msg_type & 0x0110, 0x0100);
assert!(find_message_integrity(&resp).is_some());
assert!(validate_message_integrity(&resp, &key));
assert!(find_fingerprint(&resp).is_some());
assert!(validate_fingerprint_if_present(&resp));
let lifetime = resp let lifetime = resp
.attributes .attributes
.iter() .iter()