diff --git a/backend/src/main.py b/backend/src/main.py index 655b9f50..ef077948 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -28,6 +28,7 @@ from services.report_generator import ConversionReportGenerator from services.error_handlers import register_exception_handlers from services.rate_limiter import RateLimitMiddleware, get_rate_limiter, init_rate_limiter, close_rate_limiter, create_global_limiter +from services.security_headers import SecurityHeadersMiddleware # Import API routers from api import performance, behavioral_testing, validation, comparison, embeddings, feedback, experiments, behavior_files, behavior_templates, behavior_export, advanced_events, conversions, mod_imports @@ -125,6 +126,9 @@ async def lifespan(app: FastAPI): rate_limiter = create_global_limiter() app.add_middleware(RateLimitMiddleware, rate_limiter=rate_limiter) +# Security headers middleware +app.add_middleware(SecurityHeadersMiddleware) + @app.on_event("startup") async def startup_event(): """Initialize rate limiter on startup""" diff --git a/backend/src/services/security_headers.py b/backend/src/services/security_headers.py new file mode 100644 index 00000000..43829220 --- /dev/null +++ b/backend/src/services/security_headers.py @@ -0,0 +1,14 @@ +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import Request + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """ + Middleware to add security headers to all responses. + """ + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + return response diff --git a/backend/src/tests/unit/test_security_headers.py b/backend/src/tests/unit/test_security_headers.py new file mode 100644 index 00000000..ff17a20b --- /dev/null +++ b/backend/src/tests/unit/test_security_headers.py @@ -0,0 +1,14 @@ +from fastapi.testclient import TestClient +from src.main import app + +client = TestClient(app) + +def test_security_headers(): + response = client.get("/api/v1/health") + assert response.status_code == 200 + + headers = response.headers + assert headers["X-Content-Type-Options"] == "nosniff" + assert headers["X-Frame-Options"] == "DENY" + assert headers["X-XSS-Protection"] == "1; mode=block" + assert headers["Referrer-Policy"] == "strict-origin-when-cross-origin"