Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import botocore.exceptions
import pytest
import requests
from moto import mock_aws
from mypy_boto3_s3.literals import BucketLocationConstraintType
from sqlalchemy import Engine, MetaData, create_engine

Expand All @@ -17,6 +18,14 @@
REGION_NAME: BucketLocationConstraintType = "eu-central-1"


# Apply moto AWS mock to all tests
@pytest.fixture(scope="session", autouse=True)
def _mock_aws() -> Generator[None, None, None]:
"""Mock AWS services for all tests to avoid needing real AWS credentials."""
with mock_aws():
yield


@pytest.fixture(autouse=True)
def patch_update_job(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock_update_job = MagicMock()
Expand Down Expand Up @@ -154,11 +163,13 @@ def __init__(self, bucket_name_suffix: str):
self.region_name: BucketLocationConstraintType = REGION_NAME

def create(self) -> None:
# Use path-style addressing for better compatibility with moto
# This avoids virtual-hosted-style URLs that don't work well with presigned URLs in moto
from botocore.config import Config
self.s3_client = boto3.client(
"s3",
region_name=self.region_name,
# required for pre-signing URLs to work
endpoint_url=f"https://s3.{self.region_name}.amazonaws.com",
config=Config(s3={"addressing_style": "path"}),
)
exists = self.cleanup()
if not exists:
Expand Down Expand Up @@ -215,22 +226,26 @@ def s3_testing_bucket() -> Generator[S3TestingBucket, Any, None]:
bucket.delete()


@pytest.mark.aws
@pytest.fixture(scope="session", autouse=True)
def cleanup_old_test_buckets() -> None:
"""
Find and delete all S3 buckets with the test prefix.
This helps clean up buckets from previous test runs.
"""
s3_client = boto3.client("s3", region_name=REGION_NAME)
response = s3_client.list_buckets(Prefix=TEST_BUCKET_PREFIX)
for bucket in response["Buckets"]:
bucket_name = bucket["Name"]
s3 = boto3.resource("s3", region_name=REGION_NAME)
s3_bucket = s3.Bucket(bucket_name)
bucket_versioning = s3.BucketVersioning(bucket_name)
if bucket_versioning.status == "Enabled":
s3_bucket.object_versions.delete()
else:
s3_bucket.objects.all().delete()
s3_client.delete_bucket(Bucket=bucket_name)
try:
s3_client = boto3.client("s3", region_name=REGION_NAME)
response = s3_client.list_buckets()
for bucket in response.get("Buckets", []):
bucket_name = bucket["Name"]
if bucket_name.startswith(TEST_BUCKET_PREFIX):
s3 = boto3.resource("s3", region_name=REGION_NAME)
s3_bucket = s3.Bucket(bucket_name)
bucket_versioning = s3.BucketVersioning(bucket_name)
if bucket_versioning.status == "Enabled":
s3_bucket.object_versions.delete()
else:
s3_bucket.objects.all().delete()
s3_client.delete_bucket(Bucket=bucket_name)
except botocore.exceptions.NoCredentialsError:
# Skip cleanup if no AWS credentials are available
pass
16 changes: 16 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import datetime
import shutil
from typing import Generator, cast
from unittest.mock import MagicMock, patch

import pytest
from mypy_boto3_s3 import S3Client

# Patch fastapi_cloudauth.cognito.WorkerGroupCognitoCurrentUser to avoid Cognito initialization
# This prevents attempting to connect to AWS Cognito during test setup
original_cognito_current_user = None

def _mock_cognito_init(self, region=None, userPoolId=None, client_id=None):
"""Mock Cognito init that doesn't make HTTP requests"""
self.region = region
self.userPoolId = userPoolId
self.client_id = client_id

# Patch the auth module's Cognito class before importing anything that uses it
from workerfacing_api.core import auth
if hasattr(auth, "WorkerGroupCognitoCurrentUser"):
auth.WorkerGroupCognitoCurrentUser.__init__ = _mock_cognito_init

from tests.conftest import RDSTestingInstance, S3TestingBucket
from workerfacing_api.core.auth import APIKeyDependency, GroupClaims
from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem
Expand Down
8 changes: 5 additions & 3 deletions tests/integration/endpoints/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,12 @@ def test_job_files_post(
params={"type": "output", "base_path": "test"},
)
assert res.status_code == 201
if isinstance(queue, SQLiteRDSJobQueue):
req_base = client
else:
# When using S3 filesystem, we need to use requests library to make HTTP calls
# to the presigned S3 URL. TestClient can't make external HTTP requests.
if isinstance(base_filesystem, S3Filesystem):
req_base = requests # type: ignore
else:
req_base = client
res = req_base.request(
**res.json(),
files={
Expand Down
31 changes: 26 additions & 5 deletions workerfacing_api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,32 @@ def queue_dep() -> queue.RDSJobQueue:


# Worker authentication
current_user_dep = auth.WorkerGroupCognitoCurrentUser(
region=settings.cognito_region,
userPoolId=settings.cognito_user_pool_id,
client_id=settings.cognito_client_id,
)
# Lazy initialization to avoid HTTP calls during module import (important for testing)
_current_user_dep: auth.WorkerGroupCognitoCurrentUser | None = None


def _get_current_user_dep() -> auth.WorkerGroupCognitoCurrentUser:
global _current_user_dep
if _current_user_dep is None:
_current_user_dep = auth.WorkerGroupCognitoCurrentUser(
region=settings.cognito_region,
userPoolId=settings.cognito_user_pool_id,
client_id=settings.cognito_client_id,
)
return _current_user_dep


# Create a property-like object that behaves like the original current_user_dep
# but initializes lazily on first access
class _CurrentUserDepProxy:
def __call__(self, *args, **kwargs): # type: ignore
return _get_current_user_dep()(*args, **kwargs)

def __getattr__(self, name): # type: ignore
return getattr(_get_current_user_dep(), name)


current_user_dep = _CurrentUserDepProxy()


async def current_user_global_dep(
Expand Down