fix(workers): safe_fetch pins IP + manual redirect re-validation

Two real findings from the security reviewer:

1. urllib auto-follows 3xx redirects via the default HTTPRedirectHandler.
   The previous code's hop loop never ran — urllib silently followed.
   Replaced with http.client + a manual hop loop. Every hop re-runs
   _validate_url, so an open-redirect to 127.0.0.1 / RFC1918 / metadata
   gets caught on the second hop.

2. DNS TOCTOU — _resolve() validated but urllib.request re-resolved on
   connect. Now the connection is pinned to the validated IP via a
   PinnedHTTPConn / PinnedHTTPSConn subclass that overrides connect() to
   bind socket.create_connection to (addr, port). For HTTPS, TLS
   server_hostname is set to the original host so SNI + cert
   verification still work against the named host while the TCP
   destination is the pinned IP.

Tests added: redirect-to-loopback short-circuits at validation;
too-many-redirects exhausts max_hops; 2xx returns body; non-2xx raises.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
root
2026-06-01 10:28:55 +10:00
parent 7707b7eb00
commit a8b2cddcf5
2 changed files with 165 additions and 37 deletions

View File

@@ -1,5 +1,6 @@
import pytest import pytest
from void_workers.safe_fetch import safe_fetch, SafeFetchError from unittest.mock import patch, MagicMock
from void_workers.safe_fetch import safe_fetch, SafeFetchError, _validate_url
def test_rejects_file_scheme(): def test_rejects_file_scheme():
@@ -25,3 +26,66 @@ def test_rejects_metadata_endpoint():
def test_rejects_cgnat(): def test_rejects_cgnat():
with pytest.raises(SafeFetchError): with pytest.raises(SafeFetchError):
safe_fetch("http://100.64.0.1/x") safe_fetch("http://100.64.0.1/x")
def test_redirect_to_loopback_is_rejected():
"""Open-redirect attack: a public URL 302s to http://127.0.0.1/.
The loop re-runs full validation on the next hop, so the redirect
target's literal IP triggers _validate_url and raises."""
redirect_headers = MagicMock()
redirect_headers.get.return_value = "http://127.0.0.1/admin"
call_count = {"n": 0}
def side_effect(url, **_kw):
call_count["n"] += 1
if call_count["n"] == 1:
# First hop: 302 to a blocked address.
return (302, redirect_headers, b"")
# Second hop: should never reach _request_one because _validate_url
# in _request_one will raise before issuing the request.
raise AssertionError("second hop was issued — validation bypassed!")
# Patch _request_one to short-circuit only the FIRST hop's network IO.
# The second hop's call still goes through the real _request_one which
# invokes the real _validate_url — that's where the blocked-IP error
# comes from.
import void_workers.safe_fetch as sf
real_request_one = sf._request_one
def hybrid(url, **kw):
if call_count["n"] == 0:
return side_effect(url, **kw)
return real_request_one(url, **kw)
with patch.object(sf, "_request_one", side_effect=hybrid):
with pytest.raises(SafeFetchError, match="blocked"):
safe_fetch("http://example.com/")
def test_too_many_redirects():
redirect_headers = MagicMock()
redirect_headers.get.return_value = "http://example.com/loop"
with patch("void_workers.safe_fetch._request_one",
return_value=(302, redirect_headers, b"")):
with pytest.raises(SafeFetchError, match="too many redirects"):
safe_fetch("http://example.com/loop", max_hops=2)
def test_validate_url_returns_pinned_address_for_literal_public_ip():
scheme, host, port, path, addr, family = _validate_url("http://8.8.8.8:80/x")
assert host == "8.8.8.8"
assert addr == "8.8.8.8"
assert port == 80
def test_2xx_returns_body():
headers = MagicMock()
with patch("void_workers.safe_fetch._request_one",
return_value=(200, headers, b"hello")):
assert safe_fetch("http://example.com/x") == b"hello"
def test_non_2xx_raises():
headers = MagicMock()
with patch("void_workers.safe_fetch._request_one",
return_value=(500, headers, b"err")):
with pytest.raises(SafeFetchError, match="http 500"):
safe_fetch("http://example.com/x")

View File

@@ -1,18 +1,22 @@
"""Python port of lib/ingest/safe_fetch.js. """SSRF-safe HTTP client used by sync.source_doc (and any future workers).
Same SSRF mitigations the Node side ships: Same contract as lib/ingest/safe_fetch.js on the Node side:
- http/https only - http/https only
- DNS-resolved hostnames checked against loopback / RFC1918 / - DNS-resolve and reject loopback / RFC1918 / link-local / CGNAT / metadata /
link-local / CGNAT / IPv6 ULA + link-local IPv6 ULA + link-local
- Redirects followed manually with the same checks on each hop - Pin the validated IP into the connection so a rebind between our DNS check
- VOID_INGEST_ALLOW_PRIVATE=true gate for offline-fixture tests and the TCP connect cannot point us at an internal address.
- Follow redirects MANUALLY, re-validating every hop. We disable urllib's
built-in redirect handler so it cannot silently auto-follow.
- `VOID_INGEST_ALLOW_PRIVATE=true` gate for offline-fixture tests.
""" """
import socket import http.client
import ipaddress import ipaddress
import urllib.request
import urllib.error
import os import os
from urllib.parse import urlparse import socket
import ssl
import urllib.parse
BLOCK_V4_NETS = [ipaddress.ip_network(c) for c in [ BLOCK_V4_NETS = [ipaddress.ip_network(c) for c in [
"0.0.0.0/8", "127.0.0.0/8", "10.0.0.0/8", "0.0.0.0/8", "127.0.0.0/8", "10.0.0.0/8",
@@ -42,41 +46,101 @@ def _is_blocked(addr):
return False return False
def _resolve(host): def _resolve_validated(host):
"""Resolve the host and return (address, family). Raises if any returned
address is in a blocked range."""
try: try:
infos = socket.getaddrinfo(host, None) infos = socket.getaddrinfo(host, None)
except socket.gaierror as e: except socket.gaierror as e:
raise SafeFetchError(f"no DNS for {host}: {e}") raise SafeFetchError(f"no DNS for {host}: {e}")
addrs = list({i[4][0] for i in infos}) addrs = {(i[4][0], i[0]) for i in infos} # de-dupe
for a in addrs:
if _is_blocked(a):
raise SafeFetchError(f"{host} resolves to blocked address {a}")
if not addrs: if not addrs:
raise SafeFetchError(f"no addresses for {host}") raise SafeFetchError(f"no addresses for {host}")
return addrs[0] for a, _fam in addrs:
if _is_blocked(a):
raise SafeFetchError(f"{host} resolves to blocked address {a}")
# Pick the first record. Caller pins this exact IP into the socket.
address, family = next(iter(addrs))
return address, family
def _validate_url(url):
"""Returns (scheme, hostname, port, path-with-query, pinned_addr, family)."""
u = urllib.parse.urlparse(url)
if u.scheme not in ("http", "https"):
raise SafeFetchError(f"unsupported scheme {u.scheme}")
host = u.hostname
if not host:
raise SafeFetchError(f"no hostname in {url}")
# Literal IP path
try:
ipaddress.ip_address(host)
if _is_blocked(host):
raise SafeFetchError(f"blocked literal IP {host}")
addr, family = host, (socket.AF_INET6 if ":" in host else socket.AF_INET)
except ValueError:
addr, family = _resolve_validated(host)
port = u.port or (443 if u.scheme == "https" else 80)
path = (u.path or "/") + (("?" + u.query) if u.query else "")
return u.scheme, host, port, path, addr, family
def _request_one(url, *, headers, timeout):
"""Issue one HTTP request with the IP pinned. Returns
(status, headers_obj, body_bytes). Does NOT follow redirects."""
scheme, host, port, path, addr, family = _validate_url(url)
# Build a socket bound to the validated IP. http.client lets us pass a
# custom socket via a connection subclass.
class PinnedHTTPConn(http.client.HTTPConnection):
def connect(self):
self.sock = socket.create_connection(
(addr, port), timeout=timeout,
source_address=None
)
class PinnedHTTPSConn(http.client.HTTPSConnection):
def connect(self):
sock = socket.create_connection(
(addr, port), timeout=timeout
)
ctx = ssl.create_default_context()
# TLS SNI + cert verification against the original hostname,
# while the TCP connection is pinned to the validated IP.
self.sock = ctx.wrap_socket(sock, server_hostname=host)
if scheme == "https":
conn = PinnedHTTPSConn(host, port, timeout=timeout)
else:
conn = PinnedHTTPConn(host, port, timeout=timeout)
try:
req_headers = {"Host": host, **(headers or {})}
conn.request("GET", path, headers=req_headers)
resp = conn.getresponse()
body = resp.read()
return resp.status, resp.headers, body
finally:
conn.close()
def safe_fetch(url, *, headers=None, timeout=15, max_hops=5): def safe_fetch(url, *, headers=None, timeout=15, max_hops=5):
"""GET `url` with SSRF mitigations. Returns body bytes on 2xx, raises on
non-2xx (after exhausting redirect budget)."""
current = url current = url
for hop in range(max_hops + 1): for hop in range(max_hops + 1):
u = urlparse(current) status, resp_headers, body = _request_one(
if u.scheme not in ("http", "https"): current, headers=headers, timeout=timeout
raise SafeFetchError(f"unsupported scheme {u.scheme}") )
host = u.hostname if status in (301, 302, 303, 307, 308):
try: loc = resp_headers.get("Location")
ipaddress.ip_address(host) if not loc:
if _is_blocked(host): raise SafeFetchError("redirect without Location")
raise SafeFetchError(f"blocked literal IP {host}") if hop >= max_hops:
except ValueError: raise SafeFetchError(f"too many redirects ({max_hops})")
_resolve(host) # Resolve relative redirects + re-validate on the next loop pass.
req = urllib.request.Request(current, headers=headers or {}) current = urllib.parse.urljoin(current, loc)
try: continue
opener = urllib.request.build_opener() if 200 <= status < 300:
with opener.open(req, timeout=timeout) as r: return body
return r.read() raise SafeFetchError(f"http {status} from {current}")
except urllib.error.HTTPError as e:
if e.code in (301, 302, 303, 307, 308) and "Location" in e.headers and hop < max_hops:
current = e.headers["Location"]
continue
raise
raise SafeFetchError(f"too many redirects ({max_hops})") raise SafeFetchError(f"too many redirects ({max_hops})")