diff --git a/backend/src/file_processor.py b/backend/src/file_processor.py index 4ed314c8..def2934a 100644 --- a/backend/src/file_processor.py +++ b/backend/src/file_processor.py @@ -86,19 +86,12 @@ def _sanitize_filename(self, filename: str) -> str: return "default_sanitized_filename" return sanitized - async def _is_safe_url(self, url: str) -> bool: + async def _resolve_and_validate(self, hostname: str) -> Optional[str]: """ - Validates URL to prevent SSRF by resolving hostname and checking for private/loopback IPs. + Resolves hostname to IP and checks if it's a safe (public) IP. + Returns the first safe IP address found, or None if validation fails. """ try: - parsed = urllib.parse.urlparse(url) - if parsed.scheme not in ('http', 'https'): - logger.warning(f"Unsafe scheme in URL for hostname: {parsed.hostname}") - return False - hostname = parsed.hostname - if not hostname: - return False - # Run blocking DNS resolution in executor loop = asyncio.get_running_loop() # Use getaddrinfo to support both IPv4 and IPv6 @@ -106,8 +99,9 @@ async def _is_safe_url(self, url: str) -> bool: infos = await loop.run_in_executor(None, socket.getaddrinfo, hostname, None) except socket.gaierror: logger.warning(f"Could not resolve hostname: {hostname}") - return False + return None + safe_ip = None for info in infos: ip = info[4][0] # Strip IPv6 scope ID if present (e.g. fe80::1%eth0 -> fe80::1) @@ -118,15 +112,36 @@ async def _is_safe_url(self, url: str) -> bool: ip_obj = ipaddress.ip_address(ip) if ip_obj.is_loopback or ip_obj.is_private or ip_obj.is_link_local: logger.warning(f"Blocked attempt to access private IP: {ip} for hostname: {hostname}") - return False + return None # Fail immediately if ANY resolved IP is private except ValueError: # Fail securely if IP cannot be parsed logger.warning(f"Could not parse IP address: {ip} for hostname: {hostname}") - return False - return True + return None + + if safe_ip is None: + safe_ip = ip + + return safe_ip except Exception as e: - # Don't log full URL in exception either - logger.error(f"Error validating URL for hostname {parsed.hostname if 'parsed' in locals() else 'unknown'}: {e}") + logger.error(f"Error resolving/validating hostname {hostname}: {e}") + return None + + async def _is_safe_url(self, url: str) -> bool: + """ + Validates URL to prevent SSRF by resolving hostname and checking for private/loopback IPs. + Wrapper for backward compatibility. + """ + try: + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ('http', 'https'): + logger.warning(f"Unsafe scheme in URL for hostname: {parsed.hostname}") + return False + hostname = parsed.hostname + if not hostname: + return False + return await self._resolve_and_validate(hostname) is not None + except Exception as e: + logger.error(f"Error validating URL: {e}") return False def validate_upload(self, file: UploadFile) -> ValidationResult: @@ -615,15 +630,40 @@ async def download_from_url(self, url: str, job_id: str) -> DownloadResult: # Manual redirect handling with SSRF protection for _ in range(5): # Max 5 redirects - if not await self._is_safe_url(current_url): - # Sanitize URL for logging (strip query params/fragments) - parsed_unsafe = urllib.parse.urlparse(current_url) - sanitized_unsafe = f"{parsed_unsafe.scheme}://{parsed_unsafe.hostname}{parsed_unsafe.path}" - msg = f"Unsafe URL detected: {sanitized_unsafe}" + parsed_current = urllib.parse.urlparse(current_url) + hostname = parsed_current.hostname + + if not hostname: + return DownloadResult(success=False, message="Invalid URL: Missing hostname") + + # Resolve and validate IP to prevent SSRF + safe_ip = await self._resolve_and_validate(hostname) + if not safe_ip: + sanitized_unsafe = f"{parsed_current.scheme}://{parsed_current.hostname}{parsed_current.path}" + msg = f"Unsafe URL detected or resolution failed: {sanitized_unsafe}" logger.warning(msg) return DownloadResult(success=False, message="Unsafe URL detected") - response = await client.get(current_url, follow_redirects=False, timeout=30.0) + if parsed_current.scheme == "http": + # TOCTOU Protection for HTTP: Use IP directly + ip_netloc = f"[{safe_ip}]" if ":" in safe_ip else safe_ip + if parsed_current.port: + ip_netloc += f":{parsed_current.port}" + + # Preserve credentials if present + final_netloc = ip_netloc + if "@" in parsed_current.netloc: + auth_part = parsed_current.netloc.rsplit("@", 1)[0] + final_netloc = f"{auth_part}@{ip_netloc}" + + url_with_ip = parsed_current._replace(netloc=final_netloc).geturl() + headers = {"Host": hostname} + logger.info(f"Using DNS pinning for HTTP: {hostname} -> {safe_ip}") + response = await client.get(url_with_ip, headers=headers, follow_redirects=False, timeout=30.0) + else: + # HTTPS: Cannot rewrite URL without breaking SSL verification (TOCTOU risk remains) + logger.warning(f"HTTPS URL used: {hostname}. TOCTOU protection limited.") + response = await client.get(current_url, follow_redirects=False, timeout=30.0) if 300 <= response.status_code < 400: location = response.headers.get("Location") diff --git a/backend/src/tests/unit/test_file_processor.py b/backend/src/tests/unit/test_file_processor.py index b3eb2c12..e135ef86 100644 --- a/backend/src/tests/unit/test_file_processor.py +++ b/backend/src/tests/unit/test_file_processor.py @@ -17,8 +17,8 @@ def file_processor(): """Pytest fixture to provide a FileProcessor instance.""" fp = FileProcessor() - # Mock _is_safe_url to always return True by default for existing tests to avoid network calls - fp._is_safe_url = mock.AsyncMock(return_value=True) + # Mock _resolve_and_validate to return a safe IP by default + fp._resolve_and_validate = mock.AsyncMock(return_value="93.184.216.34") return fp @@ -97,9 +97,10 @@ async def mock_aiter_bytes(): # But wait, we mocked _is_safe_url, so the loop runs. # The loop calls client.get(..., follow_redirects=False). - # Verify that get was called with follow_redirects=False + # Verify that get was called with follow_redirects=False and IP pinning + expected_url = "http://93.184.216.34/download.zip" MockAsyncClient.return_value.__aenter__.return_value.get.assert_called_with( - url, follow_redirects=False, timeout=30.0 + expected_url, headers={"Host": "example.com"}, follow_redirects=False, timeout=30.0 ) @pytest.mark.asyncio @@ -135,6 +136,12 @@ async def mock_aiter_bytes(): assert result.file_name == "another_example.jar" assert expected_file_path.read_bytes() == b"jar content" + # Verify IP pinning + expected_url = "http://93.184.216.34/another_example.jar" + MockAsyncClient.return_value.__aenter__.return_value.get.assert_called_with( + expected_url, headers={"Host": "example.com"}, follow_redirects=False, timeout=30.0 + ) + @pytest.mark.asyncio @mock.patch("file_processor.httpx.AsyncClient") async def test_download_from_url_success_content_type_extension( @@ -166,6 +173,12 @@ async def mock_aiter_bytes(): assert result.file_name == "some_file_no_ext" assert expected_file_path.read_bytes() == b"java archive" + # Verify IP pinning + expected_url = "http://93.184.216.34/some_file_no_ext" + MockAsyncClient.return_value.__aenter__.return_value.get.assert_called_with( + expected_url, headers={"Host": "example.com"}, follow_redirects=False, timeout=30.0 + ) + @pytest.mark.asyncio @pytest.mark.parametrize( "status_code, error_type_expected", diff --git a/backend/src/tests/unit/test_file_processor_security.py b/backend/src/tests/unit/test_file_processor_security.py new file mode 100644 index 00000000..f30dc6ee --- /dev/null +++ b/backend/src/tests/unit/test_file_processor_security.py @@ -0,0 +1,110 @@ +import pytest +import socket +import logging +from unittest import mock +import httpx +from file_processor import FileProcessor + +@pytest.fixture +def file_processor(): + return FileProcessor() + +@pytest.mark.asyncio +async def test_resolve_and_validate_public_ip(file_processor): + with mock.patch("socket.getaddrinfo") as mock_getaddrinfo: + # Mock public IP: 93.184.216.34 (example.com) + mock_getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('93.184.216.34', 80))] + + result = await file_processor._resolve_and_validate("example.com") + assert result == "93.184.216.34" + +@pytest.mark.asyncio +async def test_resolve_and_validate_private_ip(file_processor): + with mock.patch("socket.getaddrinfo") as mock_getaddrinfo: + # Mock private IP: 192.168.1.1 + mock_getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.168.1.1', 80))] + + result = await file_processor._resolve_and_validate("internal.local") + assert result is None + +@pytest.mark.asyncio +async def test_resolve_and_validate_mixed_ips(file_processor): + with mock.patch("socket.getaddrinfo") as mock_getaddrinfo: + # Mock mixed IPs: one public, one private (should fail) + mock_getaddrinfo.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('93.184.216.34', 80)), + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.168.1.1', 80)) + ] + + result = await file_processor._resolve_and_validate("mixed.local") + assert result is None + +@pytest.mark.asyncio +async def test_download_from_url_http_rewrites_to_ip(file_processor): + job_id = "test_job_http" + url = "http://example.com/file.zip" + + with mock.patch.object(file_processor, "_resolve_and_validate", return_value="93.184.216.34") as mock_resolve: + with mock.patch("httpx.AsyncClient") as MockAsyncClient: + mock_client = MockAsyncClient.return_value.__aenter__.return_value + mock_response = mock.AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.headers = {"Content-Disposition": 'attachment; filename="file.zip"'} + mock_response.url = httpx.URL("http://93.184.216.34/file.zip") + async def mock_aiter_bytes(): + yield b"content" + mock_response.aiter_bytes = mock_aiter_bytes + mock_client.get.return_value = mock_response + + # Mock file writing to avoid disk usage/errors + with mock.patch("builtins.open", mock.mock_open()): + # Also mock Path.mkdir to avoid actual FS + with mock.patch("pathlib.Path.mkdir"): + # Mock Path.stat to return size > 0 + with mock.patch("pathlib.Path.stat") as mock_stat: + mock_stat.return_value.st_size = 100 + + result = await file_processor.download_from_url(url, job_id) + + assert result.success is True + # Check that client.get was called with IP URL and Host header + expected_url = "http://93.184.216.34/file.zip" + mock_client.get.assert_called_with( + expected_url, + headers={"Host": "example.com"}, + follow_redirects=False, + timeout=30.0 + ) + +@pytest.mark.asyncio +async def test_download_from_url_https_uses_hostname(file_processor): + job_id = "test_job_https" + url = "https://example.com/file.zip" + + with mock.patch.object(file_processor, "_resolve_and_validate", return_value="93.184.216.34") as mock_resolve: + with mock.patch("httpx.AsyncClient") as MockAsyncClient: + mock_client = MockAsyncClient.return_value.__aenter__.return_value + mock_response = mock.AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.headers = {"Content-Disposition": 'attachment; filename="file.zip"'} + mock_response.url = httpx.URL("https://example.com/file.zip") + async def mock_aiter_bytes(): + yield b"content" + mock_response.aiter_bytes = mock_aiter_bytes + mock_client.get.return_value = mock_response + + # Mock file writing + with mock.patch("builtins.open", mock.mock_open()): + with mock.patch("pathlib.Path.mkdir"): + with mock.patch("pathlib.Path.stat") as mock_stat: + mock_stat.return_value.st_size = 100 + + result = await file_processor.download_from_url(url, job_id) + + assert result.success is True + # Check that client.get was called with ORIGINAL URL (no IP rewrite for HTTPS) + mock_client.get.assert_called_with( + url, + follow_redirects=False, + timeout=30.0 + )