Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions api/core/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shutil
import zipfile
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, Generator, cast
from typing import Any, BinaryIO, Callable, Generator

import boto3
import humanize
Expand All @@ -17,7 +17,7 @@
from mypy_boto3_s3 import S3Client
from mypy_boto3_s3.type_defs import ObjectIdentifierTypeDef

from api import models, settings
from api import models
from api.schemas.file import FileHTTPRequest, FileInfo, FileTypes


Expand Down Expand Up @@ -387,29 +387,43 @@ def download_url(
)


def get_filesystem_with_root(root_path: str) -> FileSystem:
def get_filesystem_with_root(
root_path: str,
filesystem: str,
s3_region: str,
s3_bucket: str | None,
) -> FileSystem:
"""Get the filesystem to use."""
predef_dirs = [e.value for e in models.UploadFileTypes] + [
e.value for e in models.OutputEndpoints
]
if settings.filesystem == "s3":
if filesystem == "s3":
assert s3_bucket is not None, "S3 bucket must be provided for S3 filesystem"
s3_client = boto3.client(
"s3",
region_name=settings.s3_region,
endpoint_url=f"https://s3.{settings.s3_region}.amazonaws.com",
region_name=s3_region,
endpoint_url=f"https://s3.{s3_region}.amazonaws.com",
config=Config(signature_version="v4", s3={"addressing_style": "path"}),
)
# this and config=... required to avoid DNS problems with new buckets
s3_client.meta.events.unregister("before-sign.s3", fix_s3_host)
return S3Filesystem(
root_path, s3_client, cast(str, settings.s3_bucket), predef_dirs=predef_dirs
)
elif settings.filesystem == "local":
return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs)
elif filesystem == "local":
return LocalFilesystem(root_path, predef_dirs=predef_dirs)
else:
raise ValueError("Invalid filesystem setting")


def get_user_filesystem(user_id: str) -> FileSystem:
def user_filesystem_getter(
user_data_root_path: str,
filesystem: str,
s3_region: str,
s3_bucket: str | None,
) -> Callable[[str], FileSystem]:
"""Get the filesystem to use for a user."""
return get_filesystem_with_root(str(Path(settings.user_data_root_path) / user_id))
return lambda user_id: get_filesystem_with_root(
str(Path(user_data_root_path) / user_id),
filesystem=filesystem,
s3_region=s3_region,
s3_bucket=s3_bucket,
)
30 changes: 15 additions & 15 deletions api/crud/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from sqlalchemy.orm import Session

from api import models, settings
from api.core.filesystem import FileSystem, get_user_filesystem
from api.core.filesystem import FileSystem
from api.schemas import job as schemas


def enqueue_job(
job: models.Job, enqueueing_func: Callable[[schemas.QueueJob], None]
job: models.Job,
filesystem: FileSystem,
enqueueing_func: Callable[[schemas.QueueJob], None],
) -> None:
user_fs = get_user_filesystem(user_id=job.user_id)

app = job.application
job_config = settings.application_config.config[app["application"]][app["version"]][
app["entrypoint"]
Expand Down Expand Up @@ -47,14 +47,14 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
f"artifact/{artifact_id}"
for artifact_id in job.attributes["files_down"]["artifact_ids"]
]
_validate_files(user_fs, [config_path] + data_paths + artifact_paths)
_validate_files(filesystem, [config_path] + data_paths + artifact_paths)
roots_down = handler_config["files_down"]
files_down = prepare_files(config_path, roots_down["config_id"], user_fs)
files_down = prepare_files(config_path, roots_down["config_id"], filesystem)
for data_path in data_paths:
files_down.update(prepare_files(data_path, roots_down["data_ids"], user_fs))
files_down.update(prepare_files(data_path, roots_down["data_ids"], filesystem))
for artifact_path in artifact_paths:
files_down.update(
prepare_files(artifact_path, roots_down["artifact_ids"], user_fs)
prepare_files(artifact_path, roots_down["artifact_ids"], filesystem)
)

app_specs = schemas.AppSpecs(
Expand All @@ -76,9 +76,9 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
)

paths_upload = {
"output": user_fs.full_path_uri(job.paths_out["output"]),
"log": user_fs.full_path_uri(job.paths_out["log"]),
"artifact": user_fs.full_path_uri(job.paths_out["artifact"]),
"output": filesystem.full_path_uri(job.paths_out["output"]),
"log": filesystem.full_path_uri(job.paths_out["log"]),
"artifact": filesystem.full_path_uri(job.paths_out["artifact"]),
}

queue_item = schemas.QueueJob(
Expand Down Expand Up @@ -117,6 +117,7 @@ def _validate_files(filesystem: FileSystem, paths: list[str]) -> None:

def create_job(
db: Session,
filesystem: FileSystem,
enqueueing_func: Callable[[schemas.QueueJob], None],
job: schemas.JobCreate,
user_id: int,
Expand Down Expand Up @@ -146,18 +147,17 @@ def create_job(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ve,
)
enqueue_job(db_job, enqueueing_func)
enqueue_job(db_job, filesystem, enqueueing_func)
db.commit()
db.refresh(db_job)
return db_job


def delete_job(db: Session, db_job: models.Job) -> models.Job:
def delete_job(db: Session, filesystem: FileSystem, db_job: models.Job) -> models.Job:
db.delete(db_job)
user_fs = get_user_filesystem(user_id=db_job.user_id)
for path in db_job.paths_out.values():
if path[-1] != "/":
path += "/"
user_fs.delete(path)
filesystem.delete(path)
db.commit()
return db_job
17 changes: 14 additions & 3 deletions api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from api import settings
from api.core import notifications
from api.core.filesystem import FileSystem, get_user_filesystem
from api.core.filesystem import FileSystem, user_filesystem_getter
from api.schemas.job import QueueJob


Expand Down Expand Up @@ -52,11 +52,22 @@ async def current_user_global_dep(
return current_user


async def filesystem_dep(
async def filesystem_getter_dep() -> Callable[[str], FileSystem]:
"""Get the user's filesystem getter."""
return user_filesystem_getter(
user_data_root_path=settings.user_data_root_path,
filesystem=settings.filesystem,
s3_region=settings.s3_region,
s3_bucket=settings.s3_bucket,
)


async def user_filesystem_dep(
filesystem_getter: Callable[[str], FileSystem] = Depends(filesystem_getter_dep),
current_user: CognitoClaims = Depends(current_user_dep),
) -> FileSystem:
"""Get the user's filesystem."""
return get_user_filesystem(current_user.username)
return filesystem_getter(current_user.username)


class APIKeyDependency:
Expand Down
10 changes: 7 additions & 3 deletions api/endpoints/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
as the authentication is handled by the Cognito service.
"""

from typing import Callable

import boto3
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm

from api.core.aws import calculate_secret_hash
from api.core.filesystem import get_user_filesystem
from api.core.filesystem import FileSystem
from api.schemas.token import TokenResponse
from api.schemas.user import User, UserGroups
from api.settings import cognito_client_id, cognito_secret, cognito_user_pool_id
Expand All @@ -25,7 +27,9 @@
description="Register a new user",
)
def register_user(
user: OAuth2PasswordRequestForm = Depends(), groups: list[UserGroups] | None = None
user: OAuth2PasswordRequestForm = Depends(),
filesystem_getter_dep: Callable[[str], FileSystem] = Depends(),
groups: list[UserGroups] | None = None,
) -> User:
client = boto3.client("cognito-idp")
try:
Expand All @@ -52,7 +56,7 @@ def register_user(
Password=user.password,
Permanent=True,
)
filesystem = get_user_filesystem(response["User"]["Username"])
filesystem = filesystem_getter_dep(response["User"]["Username"])
filesystem.init()
except client.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "UsernameExistsException":
Expand Down
20 changes: 11 additions & 9 deletions api/endpoints/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from api import models
from api.core.filesystem import FileSystem
from api.dependencies import filesystem_dep
from api.dependencies import user_filesystem_dep
from api.schemas import file as file_schemas

router = APIRouter()
Expand All @@ -19,7 +19,7 @@
description="Download a file",
)
def download_file(
file_path: str, filesystem: FileSystem = Depends(filesystem_dep)
file_path: str, filesystem: FileSystem = Depends(user_filesystem_dep)
) -> FileResponse | StreamingResponse:
try:
return filesystem.download(file_path)
Expand All @@ -33,7 +33,9 @@ def download_file(
description="Get request parameters (pre-signed URL) to download a file",
)
def get_download_presigned_url(
file_path: str, request: Request, filesystem: FileSystem = Depends(filesystem_dep)
file_path: str,
request: Request,
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> file_schemas.FileHTTPRequest:
try:
return filesystem.download_url(
Expand All @@ -52,7 +54,7 @@ def list_files(
base_path: str = "",
show_dirs: bool = True,
recursive: bool = False,
filesystem: FileSystem = Depends(filesystem_dep),
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> list[file_schemas.FileInfo]:
try:
return sorted(
Expand All @@ -73,7 +75,7 @@ def upload_file(
f_type: models.UploadFileTypes,
base_path: str,
file: UploadFile,
filesystem: FileSystem = Depends(filesystem_dep),
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> file_schemas.FileInfo:
base_path = f"{f_type.value}/" + base_path
file_path = os.path.join(base_path, file.filename or "unnamed")
Expand All @@ -91,7 +93,7 @@ def get_upload_presigned_url(
f_type: models.UploadFileTypes,
base_path: str,
request: Request,
filesystem: FileSystem = Depends(filesystem_dep),
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> file_schemas.FileHTTPRequest:
base_path = f"{f_type.value}/" + base_path
return filesystem.create_file_url(
Expand All @@ -107,7 +109,7 @@ def get_upload_presigned_url(
def create_directory(
f_type: models.UploadFileTypes,
base_path: str,
filesystem: FileSystem = Depends(filesystem_dep),
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> None:
return filesystem.create_directory(f"{f_type.value}/{base_path}/")

Expand All @@ -120,7 +122,7 @@ def create_directory(
def rename_file(
file_path: str,
file: file_schemas.FileUpdate,
filesystem: FileSystem = Depends(filesystem_dep),
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> file_schemas.FileInfo:
try:
filesystem.rename(file_path, file.path)
Expand All @@ -141,6 +143,6 @@ def rename_file(
description="Delete a file or directory",
)
def delete_file(
file_path: str, filesystem: FileSystem = Depends(filesystem_dep)
file_path: str, filesystem: FileSystem = Depends(user_filesystem_dep)
) -> None:
filesystem.delete(file_path)
12 changes: 9 additions & 3 deletions api/endpoints/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from sqlalchemy.orm import Session

import api.database as database
from api.core.filesystem import FileSystem
from api.crud import job as crud
from api.dependencies import enqueueing_function_dep
from api.dependencies import enqueueing_function_dep, user_filesystem_dep
from api.schemas.job import Job, JobCreate, QueueJob
from api.settings import application_config

Expand Down Expand Up @@ -60,11 +61,13 @@ def start_job(
request: Request,
job: JobCreate,
db: Session = Depends(database.get_db),
filesystem: FileSystem = Depends(user_filesystem_dep),
enqueueing_func: Callable[[QueueJob], None] = Depends(enqueueing_function_dep),
) -> Job:
try:
return crud.create_job(
db,
filesystem,
enqueueing_func,
job,
user_id=request.state.current_user.username,
Expand All @@ -81,11 +84,14 @@ def start_job(
description="Delete a job",
)
def delete_job(
request: Request, job_id: int, db: Session = Depends(database.get_db)
request: Request,
job_id: int,
db: Session = Depends(database.get_db),
filesystem: FileSystem = Depends(user_filesystem_dep),
) -> None:
db_job = crud.get_job(db, job_id)
if db_job is None or db_job.user_id != request.state.current_user.username:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Job not found"
)
crud.delete_job(db, db_job)
crud.delete_job(db, filesystem, db_job)
2 changes: 1 addition & 1 deletion api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _load_possibly_aws_secret(name: str) -> str | None:
if os.environ.get("DATABASE_SECRET"): # set and not None
database_secret = _load_possibly_aws_secret("DATABASE_SECRET")
database_url = database_url.format(database_secret)
filesystem = os.environ.get("FILESYSTEM")
filesystem = os.environ.get("FILESYSTEM", "local")
s3_bucket = os.environ.get("S3_BUCKET")
s3_region = os.environ.get("S3_REGION", "eu-central-1")
user_data_root_path = os.environ.get("USER_DATA_ROOT_PATH", "/data")
Expand Down
14 changes: 4 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import secrets
import time
from typing import Any, Generator
from typing import Any
from unittest.mock import MagicMock

import boto3
Expand All @@ -18,16 +18,10 @@
REGION_NAME: BucketLocationConstraintType = "eu-central-1"


@pytest.fixture(scope="session")
def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]:
with pytest.MonkeyPatch.context() as mp:
yield mp


@pytest.fixture(autouse=True, scope="function")
def enqueueing_func(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock:
@pytest.fixture(autouse=True)
def enqueueing_func(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock_enqueueing_function = MagicMock()
monkeypatch_module.setitem(
monkeypatch.setitem(
app.dependency_overrides,
enqueueing_function_dep, # type: ignore
lambda: mock_enqueueing_function,
Expand Down
Loading
Loading