Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
aws-region: eu-central-1
- name: Run tests
run: |
poetry run pytest -m "aws or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt
poetry run pytest -m "aws and not(deprecated) or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt
echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV
- name: Coverage comment
uses: MishaKav/pytest-coverage-comment@main
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ serve = "scripts.serve:main"

[tool.pytest.ini_options]
markers = [
"aws: requires aws credentials"
"aws: requires aws credentials",
"deprecated: tests for deprecated features",
]
addopts = "-m 'not aws'"

Expand Down
34 changes: 19 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,10 @@
REGION_NAME: BucketLocationConstraintType = "eu-central-1"


@pytest.fixture(scope="session")
def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]:
    """Session-scoped MonkeyPatch (the builtin ``monkeypatch`` fixture is function-scoped).

    All patches applied through the yielded object are undone when the
    session ends, exactly as ``pytest.MonkeyPatch.context()`` would do.
    """
    patcher = pytest.MonkeyPatch()
    try:
        yield patcher
    finally:
        # Mirror MonkeyPatch.context(): roll back every patch on teardown.
        patcher.undo()


@pytest.fixture(autouse=True, scope="session")
def patch_update_job(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock:
@pytest.fixture(autouse=True)
def patch_update_job(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock_update_job = MagicMock()
monkeypatch_module.setattr(job_tracking, "update_job", mock_update_job)
monkeypatch.setattr(job_tracking, "update_job", mock_update_job)
return mock_update_job


Expand All @@ -40,7 +34,7 @@ def create(self) -> None:
self.add_ingress_rule()
self.db_url = self.create_db_url()
self.engine = self.get_engine()
self.delete_db_tables()
self.cleanup()

def get_engine(self) -> Engine:
for _ in range(5):
Expand Down Expand Up @@ -79,7 +73,20 @@ def add_ingress_rule(self) -> None:
else:
raise e

def delete_db_tables(self) -> None:
def remove_ingress_rules(self) -> None:
    """Revoke every port-5432 (Postgres) ingress rule on the test security group.

    Inspects all rules currently attached to the group named in
    ``self.vpc_sg_rule_params`` — not just the one this run added — so rules
    left behind by earlier failed test runs are cleaned up as well.
    """
    group_name = self.vpc_sg_rule_params["GroupName"]
    described = self.ec2_client.describe_security_groups(GroupNames=[group_name])
    for group in described["SecurityGroups"]:
        # Match on the Postgres port range; revoke each matching rule one by one.
        matching = [
            perm
            for perm in group["IpPermissions"]
            if perm.get("FromPort") == 5432 and perm.get("ToPort") == 5432
        ]
        for perm in matching:
            self.ec2_client.revoke_security_group_ingress(
                GroupId=group["GroupId"],
                IpPermissions=[perm],  # type: ignore
            )

def cleanup(self) -> None:
metadata = MetaData()
engine = self.engine
metadata.reflect(engine)
Expand Down Expand Up @@ -138,14 +145,11 @@ def create_db_url(self) -> str:
address = response["DBInstances"][0]["Endpoint"]["Address"]
return f"postgresql://{user}:{password}@{address}:5432/{self.db_name}"

def cleanup(self) -> None:
self.delete_db_tables()
self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params)

def delete(self) -> None:
# never used (AWS tests skipped)
if not hasattr(self, "rds_client"):
return
self.remove_ingress_rules()
self.rds_client.delete_db_instance(
DBInstanceIdentifier=self.db_name,
SkipFinalSnapshot=True,
Expand Down
187 changes: 111 additions & 76 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
import datetime
import shutil
from typing import Any, Generator, cast
from typing import Generator, cast

import pytest
from mypy_boto3_s3 import S3Client

from tests.conftest import RDSTestingInstance, S3TestingBucket
from workerfacing_api import settings
from workerfacing_api.core.auth import APIKeyDependency, GroupClaims
from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem
from workerfacing_api.core.queue import RDSJobQueue
from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue
from workerfacing_api.dependencies import (
APIKeyDependency,
GroupClaims,
authorizer,
current_user_dep,
filesystem_dep,
queue_dep,
)
from workerfacing_api.main import workerfacing_app
from workerfacing_api.schemas.queue_jobs import (
AppSpecs,
EnvironmentTypes,
HandlerSpecs,
HardwareSpecs,
JobSpecs,
MetaSpecs,
PathsUploadSpecs,
SubmittedJob,
)


@pytest.fixture(scope="session")
Expand All @@ -24,8 +34,8 @@ def test_username() -> str:


@pytest.fixture(scope="session")
def base_dir() -> str:
return "int_test_dir"
def base_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
return str(tmp_path_factory.mktemp("int_test_dir"))


@pytest.fixture(scope="session")
Expand All @@ -35,101 +45,97 @@ def internal_api_key_secret() -> str:

@pytest.fixture(
scope="session",
params=["local", pytest.param("aws", marks=pytest.mark.aws)],
params=["local-fs", pytest.param("aws-fs", marks=pytest.mark.aws)],
)
def env(
request: pytest.FixtureRequest,
rds_testing_instance: RDSTestingInstance,
s3_testing_bucket: S3TestingBucket,
) -> Generator[str, Any, None]:
env = cast(str, request.param)
if env == "aws":
rds_testing_instance.create()
s3_testing_bucket.create()
yield env
if env == "aws":
rds_testing_instance.cleanup()
s3_testing_bucket.cleanup()


@pytest.fixture(scope="session")
def base_filesystem(
env: str,
base_dir: str,
monkeypatch_module: pytest.MonkeyPatch,
s3_testing_bucket: S3TestingBucket,
) -> Generator[FileSystem, Any, None]:
monkeypatch_module.setattr(
settings,
"user_data_root_path",
base_dir,
)
monkeypatch_module.setattr(
settings,
"filesystem",
"local" if env == "local" else "s3",
)

if env == "local":
shutil.rmtree(base_dir, ignore_errors=True)
yield LocalFilesystem(base_dir, base_dir)
shutil.rmtree(base_dir, ignore_errors=True)

elif env == "aws":
# Update settings to use the actual unique bucket name created by S3TestingBucket
monkeypatch_module.setattr(
settings,
"s3_bucket",
s3_testing_bucket.bucket_name,
)
yield S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name)
s3_testing_bucket.cleanup()

request: pytest.FixtureRequest,
) -> FileSystem:
if request.param == "local-fs":
return LocalFilesystem(base_dir, base_dir)
elif request.param == "aws-fs":
s3_testing_bucket.create()
return S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name)
else:
raise NotImplementedError


@pytest.fixture(scope="session")
@pytest.fixture(
scope="session",
params=["local-queue", pytest.param("aws-queue", marks=pytest.mark.aws)],
)
def queue(
env: str,
base_filesystem: FileSystem,
s3_testing_bucket: S3TestingBucket,
rds_testing_instance: RDSTestingInstance,
tmpdir_factory: pytest.TempdirFactory,
) -> Generator[RDSJobQueue, Any, None]:
if env == "local":
queue = RDSJobQueue(
f"sqlite:///{tmpdir_factory.mktemp('integration')}/local.db"
request: pytest.FixtureRequest,
) -> RDSJobQueue:
retry_different = False # allow retries on same worker
if request.param == "local-queue":
queue_path = tmpdir_factory.mktemp("integration") / "local.db"
s3_bucket: str | None = None
s3_client: S3Client | None = None
if isinstance(base_filesystem, S3Filesystem):
s3_bucket = s3_testing_bucket.bucket_name
s3_client = s3_testing_bucket.s3_client
return SQLiteRDSJobQueue(
f"sqlite:///{queue_path}",
retry_different=retry_different,
s3_client=s3_client,
s3_bucket=s3_bucket,
)
elif request.param == "aws-queue":
if isinstance(base_filesystem, LocalFilesystem):
pytest.skip("Only testing RDS queue in combination with S3 filesystem")
rds_testing_instance.create()
return RDSJobQueue(rds_testing_instance.db_url, retry_different=retry_different)
else:
queue = RDSJobQueue(rds_testing_instance.db_url)
queue.create(err_on_exists=True)
yield queue
raise NotImplementedError


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_filesystem_dep(
base_filesystem: FileSystem, monkeypatch_module: pytest.MonkeyPatch
) -> None:
monkeypatch_module.setitem(
base_filesystem: FileSystem,
s3_testing_bucket: S3TestingBucket,
base_dir: str,
monkeypatch: pytest.MonkeyPatch,
) -> Generator[None, None, None]:
monkeypatch.setitem(
workerfacing_app.dependency_overrides, # type: ignore
filesystem_dep,
lambda: base_filesystem,
)
yield
# cleanup after every test
if isinstance(base_filesystem, S3Filesystem):
s3_testing_bucket.cleanup()
else:
shutil.rmtree(base_dir, ignore_errors=True)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_queue_dep(
queue: RDSJobQueue, monkeypatch_module: pytest.MonkeyPatch
) -> None:
monkeypatch_module.setitem(
queue: RDSJobQueue,
rds_testing_instance: RDSTestingInstance,
monkeypatch: pytest.MonkeyPatch,
) -> Generator[None, None, None]:
monkeypatch.setitem(
workerfacing_app.dependency_overrides, # type: ignore
queue_dep,
lambda: queue,
)
yield
if isinstance(queue, SQLiteRDSJobQueue):
queue.delete()
else:
rds_testing_instance.cleanup()


@pytest.fixture(scope="session", autouse=True)
def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> None:
monkeypatch_module.setitem(
@pytest.fixture(autouse=True)
def override_auth(monkeypatch: pytest.MonkeyPatch, test_username: str) -> None:
monkeypatch.setitem(
workerfacing_app.dependency_overrides, # type: ignore
current_user_dep,
lambda: GroupClaims(
Expand All @@ -142,13 +148,42 @@ def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) ->
)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_internal_api_key_secret(
monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str
monkeypatch: pytest.MonkeyPatch, internal_api_key_secret: str
) -> str:
monkeypatch_module.setitem(
monkeypatch.setitem(
workerfacing_app.dependency_overrides, # type: ignore
authorizer,
APIKeyDependency(internal_api_key_secret),
)
return internal_api_key_secret


@pytest.fixture
def base_job(base_filesystem: FileSystem, test_username: str) -> SubmittedJob:
    """Build a minimal SubmittedJob whose upload paths live on the filesystem under test."""
    created_at = datetime.datetime.now(datetime.timezone.utc).isoformat()

    # Choose the path prefix that matches the active filesystem backend.
    if isinstance(base_filesystem, S3Filesystem):
        root = f"s3://{base_filesystem.bucket}"
    else:
        root = cast(LocalFilesystem, base_filesystem).base_post_path

    uploads = PathsUploadSpecs(
        output=f"{root}/{test_username}/test_out/1",
        log=f"{root}/{test_username}/test_log/1",
        artifact=f"{root}/{test_username}/test_arti/1",
    )
    job_spec = JobSpecs(
        app=AppSpecs(cmd=["cmd"], env={"env": "var"}),
        handler=HandlerSpecs(image_url="u", files_up={"output": "out"}),
        hardware=HardwareSpecs(),
        meta=MetaSpecs(
            job_id=1,
            date_created=created_at,
        ),
    )
    return SubmittedJob(
        job=job_spec,
        environment=EnvironmentTypes.local,
        group=None,
        priority=1,
        paths_upload=uploads,
    )
13 changes: 8 additions & 5 deletions tests/integration/endpoints/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Generator

import pytest
from fastapi.testclient import TestClient
Expand All @@ -9,6 +9,13 @@
from workerfacing_api.main import workerfacing_app


@pytest.fixture
def client() -> Generator[TestClient, None, None]:
    """Yield a TestClient with the app's lifespan events active for the test."""
    test_client = TestClient(workerfacing_app)
    # Entering the client as a context manager runs startup/shutdown (lifespan).
    with test_client:
        yield test_client


@dataclass
class EndpointParams:
method: str
Expand All @@ -24,10 +31,6 @@ class _TestEndpoint(abc.ABC):
def passing_params(self, *args: Any, **kwargs: Any) -> list[EndpointParams]:
raise NotImplementedError

@pytest.fixture(scope="session")
def client(self) -> TestClient:
return TestClient(workerfacing_app)

def test_required_auth(
self,
monkeypatch: pytest.MonkeyPatch,
Expand Down
Loading