diff --git a/workers/tests/test_safe_fetch.py b/workers/tests/test_safe_fetch.py
index a9cb95b..a1b3a4a 100644
--- a/workers/tests/test_safe_fetch.py
+++ b/workers/tests/test_safe_fetch.py
@@ -1,5 +1,6 @@
 import pytest
-from void_workers.safe_fetch import safe_fetch, SafeFetchError
+from unittest.mock import patch, MagicMock
+from void_workers.safe_fetch import safe_fetch, SafeFetchError, _validate_url
 
 
 def test_rejects_file_scheme():
@@ -25,3 +26,66 @@ def test_rejects_metadata_endpoint():
 def test_rejects_cgnat():
     with pytest.raises(SafeFetchError):
         safe_fetch("http://100.64.0.1/x")
+
+
+def test_redirect_to_loopback_is_rejected():
+    """Open-redirect attack: a public URL 302s to http://127.0.0.1/.
+    The loop re-runs full validation on the next hop, so the redirect
+    target's literal IP triggers _validate_url and raises."""
+    redirect_headers = MagicMock()
+    redirect_headers.get.return_value = "http://127.0.0.1/admin"
+    call_count = {"n": 0}
+
+    def side_effect(url, **_kw):
+        call_count["n"] += 1
+        if call_count["n"] == 1:
+            # First hop: 302 to a blocked address.
+            return (302, redirect_headers, b"")
+        # Second hop: should never reach _request_one because _validate_url
+        # in _request_one will raise before issuing the request.
+        raise AssertionError("second hop was issued — validation bypassed!")
+
+    # Patch _request_one to short-circuit only the FIRST hop's network IO.
+    # The second hop's call still goes through the real _request_one which
+    # invokes the real _validate_url — that's where the blocked-IP error
+    # comes from.
+    import void_workers.safe_fetch as sf
+    real_request_one = sf._request_one
+    def hybrid(url, **kw):
+        if call_count["n"] == 0:
+            return side_effect(url, **kw)
+        return real_request_one(url, **kw)
+    with patch.object(sf, "_request_one", side_effect=hybrid):
+        with pytest.raises(SafeFetchError, match="blocked"):
+            safe_fetch("http://example.com/")
+
+
+def test_too_many_redirects():
+    redirect_headers = MagicMock()
+    redirect_headers.get.return_value = "http://example.com/loop"
+    with patch("void_workers.safe_fetch._request_one",
+               return_value=(302, redirect_headers, b"")):
+        with pytest.raises(SafeFetchError, match="too many redirects"):
+            safe_fetch("http://example.com/loop", max_hops=2)
+
+
+def test_validate_url_returns_pinned_address_for_literal_public_ip():
+    scheme, host, port, path, addr, family = _validate_url("http://8.8.8.8:80/x")
+    assert host == "8.8.8.8"
+    assert addr == "8.8.8.8"
+    assert port == 80
+
+
+def test_2xx_returns_body():
+    headers = MagicMock()
+    with patch("void_workers.safe_fetch._request_one",
+               return_value=(200, headers, b"hello")):
+        assert safe_fetch("http://example.com/x") == b"hello"
+
+
+def test_non_2xx_raises():
+    headers = MagicMock()
+    with patch("void_workers.safe_fetch._request_one",
+               return_value=(500, headers, b"err")):
+        with pytest.raises(SafeFetchError, match="http 500"):
+            safe_fetch("http://example.com/x")
diff --git a/workers/void_workers/safe_fetch.py b/workers/void_workers/safe_fetch.py
index c7bc89f..ee21b28 100644
--- a/workers/void_workers/safe_fetch.py
+++ b/workers/void_workers/safe_fetch.py
@@ -1,18 +1,22 @@
-"""Python port of lib/ingest/safe_fetch.js.
+"""SSRF-safe HTTP client used by sync.source_doc (and any future workers).
 
-Same SSRF mitigations the Node side ships:
+Same contract as lib/ingest/safe_fetch.js on the Node side:
 - http/https only
-- DNS-resolved hostnames checked against loopback / RFC1918 /
-  link-local / CGNAT / IPv6 ULA + link-local
-- Redirects followed manually with the same checks on each hop
-- VOID_INGEST_ALLOW_PRIVATE=true gate for offline-fixture tests
+- DNS-resolve and reject loopback / RFC1918 / link-local / CGNAT / metadata /
+  IPv6 ULA + link-local
+- Pin the validated IP into the connection so a rebind between our DNS check
+  and the TCP connect cannot point us at an internal address.
+- Follow redirects MANUALLY, re-validating every hop. We talk to http.client
+  directly, which never auto-follows redirects.
+- `VOID_INGEST_ALLOW_PRIVATE=true` gate for offline-fixture tests.
 """
-import socket
+import http.client
 import ipaddress
-import urllib.request
-import urllib.error
 import os
-from urllib.parse import urlparse
+import socket
+import ssl
+import urllib.parse
+
 
 BLOCK_V4_NETS = [ipaddress.ip_network(c) for c in [
     "0.0.0.0/8", "127.0.0.0/8", "10.0.0.0/8",
@@ -42,41 +46,104 @@ def _is_blocked(addr):
     return False
 
 
-def _resolve(host):
+def _resolve_validated(host):
+    """Resolve the host and return (address, family). Raises if any returned
+    address is in a blocked range."""
     try:
         infos = socket.getaddrinfo(host, None)
     except socket.gaierror as e:
         raise SafeFetchError(f"no DNS for {host}: {e}")
-    addrs = list({i[4][0] for i in infos})
-    for a in addrs:
-        if _is_blocked(a):
-            raise SafeFetchError(f"{host} resolves to blocked address {a}")
+    addrs = list(dict.fromkeys((i[4][0], i[0]) for i in infos))  # ordered de-dupe
     if not addrs:
         raise SafeFetchError(f"no addresses for {host}")
-    return addrs[0]
+    for a, _fam in addrs:
+        if _is_blocked(a):
+            raise SafeFetchError(f"{host} resolves to blocked address {a}")
+    # Pick the first record in resolver order. Caller pins this exact IP.
+    address, family = addrs[0]
+    return address, family
+
+
+def _validate_url(url):
+    """Returns (scheme, hostname, port, path-with-query, pinned_addr, family)."""
+    u = urllib.parse.urlparse(url)
+    if u.scheme not in ("http", "https"):
+        raise SafeFetchError(f"unsupported scheme {u.scheme}")
+    host = u.hostname
+    if not host:
+        raise SafeFetchError(f"no hostname in {url}")
+    # Literal IP path
+    try:
+        ipaddress.ip_address(host)
+        if _is_blocked(host):
+            raise SafeFetchError(f"blocked literal IP {host}")
+        addr, family = host, (socket.AF_INET6 if ":" in host else socket.AF_INET)
+    except ValueError:
+        addr, family = _resolve_validated(host)
+    port = u.port or (443 if u.scheme == "https" else 80)
+    path = (u.path or "/") + (("?" + u.query) if u.query else "")
+    return u.scheme, host, port, path, addr, family
+
+
+def _request_one(url, *, headers, timeout):
+    """Issue one HTTP request with the IP pinned. Returns
+    (status, headers_obj, body_bytes). Does NOT follow redirects."""
+    scheme, host, port, path, addr, family = _validate_url(url)
+
+    # Build a socket bound to the validated IP. http.client lets us pass a
+    # custom socket via a connection subclass.
+    class PinnedHTTPConn(http.client.HTTPConnection):
+        def connect(self):
+            self.sock = socket.create_connection(
+                (addr, port), timeout=timeout,
+                source_address=None
+            )
+    class PinnedHTTPSConn(http.client.HTTPSConnection):
+        def connect(self):
+            sock = socket.create_connection(
+                (addr, port), timeout=timeout
+            )
+            ctx = ssl.create_default_context()
+            # TLS SNI + cert verification against the original hostname,
+            # while the TCP connection is pinned to the validated IP.
+            self.sock = ctx.wrap_socket(sock, server_hostname=host)
+
+    if scheme == "https":
+        conn = PinnedHTTPSConn(host, port, timeout=timeout)
+    else:
+        conn = PinnedHTTPConn(host, port, timeout=timeout)
+
+    try:
+        # Let http.client derive Host from the (host, port) the connection
+        # was built with: it appends a non-default port and brackets IPv6
+        # literals, both of which a hand-built "Host: <hostname>" drops.
+        req_headers = dict(headers or {})
+        conn.request("GET", path, headers=req_headers)
+        resp = conn.getresponse()
+        body = resp.read()
+        return resp.status, resp.headers, body
+    finally:
+        conn.close()
 
 
 def safe_fetch(url, *, headers=None, timeout=15, max_hops=5):
+    """GET `url` with SSRF mitigations. Returns body bytes on 2xx, raises on
+    non-2xx (after exhausting redirect budget)."""
     current = url
     for hop in range(max_hops + 1):
-        u = urlparse(current)
-        if u.scheme not in ("http", "https"):
-            raise SafeFetchError(f"unsupported scheme {u.scheme}")
-        host = u.hostname
-        try:
-            ipaddress.ip_address(host)
-            if _is_blocked(host):
-                raise SafeFetchError(f"blocked literal IP {host}")
-        except ValueError:
-            _resolve(host)
-        req = urllib.request.Request(current, headers=headers or {})
-        try:
-            opener = urllib.request.build_opener()
-            with opener.open(req, timeout=timeout) as r:
-                return r.read()
-        except urllib.error.HTTPError as e:
-            if e.code in (301, 302, 303, 307, 308) and "Location" in e.headers and hop < max_hops:
-                current = e.headers["Location"]
-                continue
-            raise
+        status, resp_headers, body = _request_one(
+            current, headers=headers, timeout=timeout
+        )
+        if status in (301, 302, 303, 307, 308):
+            loc = resp_headers.get("Location")
+            if not loc:
+                raise SafeFetchError("redirect without Location")
+            if hop >= max_hops:
+                raise SafeFetchError(f"too many redirects ({max_hops})")
+            # Resolve relative redirects + re-validate on the next loop pass.
+            current = urllib.parse.urljoin(current, loc)
+            continue
+        if 200 <= status < 300:
+            return body
+        raise SafeFetchError(f"http {status} from {current}")
     raise SafeFetchError(f"too many redirects ({max_hops})")