diff --git a/src/database/datasets.py b/src/database/datasets.py
index 5efddcc..fbc203a 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,66 +1,13 @@
 """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
 import datetime
-from collections import defaultdict
-from typing import Iterable

-from schemas.datasets.openml import Feature, Quality
+from schemas.datasets.openml import Feature
 from sqlalchemy import Connection, text
 from sqlalchemy.engine import Row


-def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
-    rows = connection.execute(
-        text(
-            """
-            SELECT `quality`,`value`
-            FROM data_quality
-            WHERE `data`=:dataset_id
-            """,
-        ),
-        parameters={"dataset_id": dataset_id},
-    )
-    return [Quality(name=row.quality, value=row.value) for row in rows]
-
-
-def _get_qualities_for_datasets(
-    dataset_ids: Iterable[int],
-    qualities: Iterable[str],
-    connection: Connection,
-) -> dict[int, list[Quality]]:
-    """Don't call with user-provided input, as query is not parameterized."""
-    qualities_filter = ",".join(f"'{q}'" for q in qualities)
-    dids = ",".join(str(did) for did in dataset_ids)
-    qualities_query = text(
-        f"""
-        SELECT `data`, `quality`, `value`
-        FROM data_quality
-        WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
-        """,  # nosec - dids and qualities are not user-provided
-    )
-    rows = connection.execute(qualities_query)
-    qualities_by_id = defaultdict(list)
-    for did, quality, value in rows:
-        if value is not None:
-            qualities_by_id[did].append(Quality(name=quality, value=value))
-    return dict(qualities_by_id)
-
-
-def list_all_qualities(connection: Connection) -> list[str]:
-    # The current implementation only fetches *used* qualities, otherwise you should
-    # query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
-    qualities = connection.execute(
-        text(
-            """
-            SELECT DISTINCT(`quality`)
-            FROM data_quality
-            """,
-        ),
-    )
-    return [quality.quality for quality in qualities]
-
-
-def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
+def get(id_: int, connection: Connection) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -69,12 +16,12 @@ def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
             WHERE did = :dataset_id
             """,
         ),
-        parameters={"dataset_id": dataset_id},
+        parameters={"dataset_id": id_},
     )
     return row.one_or_none()


-def get_file(file_id: int, connection: Connection) -> Row | None:
+def get_file(*, file_id: int, connection: Connection) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -88,7 +35,7 @@
     return row.one_or_none()


-def get_tags(dataset_id: int, connection: Connection) -> list[str]:
+def get_tags_for(id_: int, connection: Connection) -> list[str]:
     rows = connection.execute(
         text(
             """
@@ -97,12 +44,12 @@
             WHERE id = :dataset_id
             """,
         ),
-        parameters={"dataset_id": dataset_id},
+        parameters={"dataset_id": id_},
     )
     return [row.tag for row in rows]


-def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection) -> None:
+def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
     connection.execute(
         text(
             """
@@ -111,17 +58,18 @@
             """,
         ),
         parameters={
-            "dataset_id": dataset_id,
+            "dataset_id": id_,
             "user_id": user_id,
-            "tag": tag,
+            "tag": tag_,
         },
     )


-def get_latest_dataset_description(
-    dataset_id: int,
+def get_description(
+    id_: int,
     connection: Connection,
 ) -> Row | None:
+    """Get the most recent description for the dataset."""
     row = connection.execute(
         text(
             """
@@ -131,12 +79,13 @@
             ORDER BY version DESC
             """,
         ),
-        parameters={"dataset_id": dataset_id},
+        parameters={"dataset_id": id_},
     )
     return row.first()


-def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | None:
+def get_status(id_: int, connection: Connection) -> Row | None:
+    """Get the most recent status for the dataset."""
     row = connection.execute(
         text(
             """
@@ -146,7 +95,7 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | N
             ORDER BY status_date DESC
             """,
         ),
-        parameters={"dataset_id": dataset_id},
+        parameters={"dataset_id": id_},
     )
     return row.first()

@@ -166,7 +115,7 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row
     return row.one_or_none()


-def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Feature]:
+def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
     rows = connection.execute(
         text(
             """
@@ -181,7 +130,7 @@ def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Fe
     return [Feature(**row, nominal_values=None) for row in rows.mappings()]


-def get_feature_values(dataset_id: int, feature_index: int, connection: Connection) -> list[str]:
+def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]:
     rows = connection.execute(
         text(
             """
@@ -195,10 +144,11 @@ def get_feature_values(dataset_id: int, feature_index: int, connection: Connecti
     return [row.value for row in rows]


-def insert_status_for_dataset(
+def update_status(
     dataset_id: int,
-    user_id: int,
     status: str,
+    *,
+    user_id: int,
     connection: Connection,
 ) -> None:
     connection.execute(
diff --git a/src/database/flows.py b/src/database/flows.py
index c6c8807..52bd867 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -63,7 +63,7 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No
     ).one_or_none()


-def get_by_id(flow_id: int, expdb: Connection) -> Row | None:
+def get(id_: int, expdb: Connection) -> Row | None:
     return expdb.execute(
         text(
             """
@@ -72,5 +72,5 @@ def get_by_id(flow_id: int, expdb: Connection) -> Row | None:
             WHERE id = :flow_id
             """,
         ),
-        parameters={"flow_id": flow_id},
+        parameters={"flow_id": id_},
     ).one_or_none()
diff --git a/src/database/qualities.py b/src/database/qualities.py
new file mode 100644
index 0000000..a48b7b2
--- /dev/null
+++ b/src/database/qualities.py
@@ -0,0 +1,56 @@
+from collections import defaultdict
+from typing import Iterable
+
+from schemas.datasets.openml import Quality
+from sqlalchemy import Connection, text
+
+
+def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
+    rows = connection.execute(
+        text(
+            """
+            SELECT `quality`,`value`
+            FROM data_quality
+            WHERE `data`=:dataset_id
+            """,
+        ),
+        parameters={"dataset_id": dataset_id},
+    )
+    return [Quality(name=row.quality, value=row.value) for row in rows]
+
+
+def _get_for_datasets(
+    dataset_ids: Iterable[int],
+    quality_names: Iterable[str],
+    connection: Connection,
+) -> dict[int, list[Quality]]:
+    """Don't call with user-provided input, as query is not parameterized."""
+    qualities_filter = ",".join(f"'{q}'" for q in quality_names)
+    dids = ",".join(str(did) for did in dataset_ids)
+    qualities_query = text(
+        f"""
+        SELECT `data`, `quality`, `value`
+        FROM data_quality
+        WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
+        """,  # nosec - dids and qualities are not user-provided
+    )
+    rows = connection.execute(qualities_query)
+    qualities_by_id = defaultdict(list)
+    for did, quality, value in rows:
+        if value is not None:
+            qualities_by_id[did].append(Quality(name=quality, value=value))
+    return dict(qualities_by_id)
+
+
+def list_all_qualities(connection: Connection) -> list[str]:
+    # The current implementation only fetches qualities that are actually used; to list
+    # all defined qualities instead, query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
+    qualities_ = connection.execute(
+        text(
+            """
+            SELECT DISTINCT(`quality`)
+            FROM data_quality
+            """,
+        ),
+    )
+    return [quality.quality for quality in qualities_]
diff --git a/src/database/studies.py b/src/database/studies.py
index 3c7c166..31ed6f2 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -8,7 +8,7 @@
 from database.users import User


-def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
+def get_by_id(id_: int, connection: Connection) -> Row | None:
     return connection.execute(
         text(
             """
@@ -17,11 +17,11 @@ def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
             WHERE id = :study_id
             """,
         ),
-        parameters={"study_id": study_id},
+        parameters={"study_id": id_},
     ).one_or_none()


-def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
+def get_by_alias(alias: str, connection: Connection) -> Row | None:
     return connection.execute(
         text(
             """
@@ -35,6 +35,11 @@


 def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
+    """Return data related to the study; its content depends on the study type.
+
+    For task studies: (task id, dataset id)
+    For run studies: (run id, task id, setup id, dataset id, flow id)
+    """
     if study.type_ == StudyType.TASK:
         return cast(
             Sequence[Row],
@@ -72,7 +77,7 @@
     )


-def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
+def create(study: CreateStudy, user: User, expdb: Connection) -> int:
     expdb.execute(
         text(
             """
@@ -100,7 +105,7 @@
     return cast(int, study_id)


-def attach_task_to_study(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
+def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
     expdb.execute(
         text(
             """
@@ -112,7 +117,7 @@ def attach_task_to_study(task_id: int, study_id: int, user: User, expdb: Connect
     )


-def attach_run_to_study(run_id: int, study_id: int, user: User, expdb: Connection) -> None:
+def attach_run(*, run_id: int, study_id: int, user: User, expdb: Connection) -> None:
     expdb.execute(
         text(
             """
@@ -124,7 +129,8 @@ def attach_run_to_study(run_id: int, study_id: int, user: User, expdb: Connectio
     )


-def attach_tasks_to_study(
+def attach_tasks(
+    *,
     study_id: int,
     task_ids: list[int],
     user: User,
@@ -155,9 +161,9 @@
         raise ValueError(msg) from e


-def attach_runs_to_study(
+def attach_runs(
     study_id: int,  # noqa: ARG001
-    task_ids: list[int],  # noqa: ARG001
+    run_ids: list[int],  # noqa: ARG001
     user: User,  # noqa: ARG001
     connection: Connection,  # noqa: ARG001
 ) -> None:
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index a1c5c1b..e0a0a0e 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -1,14 +1,11 @@
-"""
-We add separate endpoints for old-style JSON responses, so they don't clutter the schema of the
-new API, and are easily removed later.
-""" - import http.client import re from datetime import datetime from enum import StrEnum from typing import Annotated, Any, Literal, NamedTuple +import database.datasets +import database.qualities from core.access import _user_has_access from core.errors import DatasetError from core.formatting import ( @@ -17,20 +14,6 @@ _format_error, _format_parquet_url, ) -from database.datasets import ( - _get_qualities_for_datasets, - get_feature_values, - get_features_for_dataset, - get_file, - get_latest_dataset_description, - get_latest_processing_update, - get_latest_status_update, - get_tags, - insert_status_for_dataset, - remove_deactivated_status, -) -from database.datasets import get_dataset as db_get_dataset -from database.datasets import tag_dataset as db_tag_dataset from database.users import User, UserGroup from fastapi import APIRouter, Body, Depends, HTTPException from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType @@ -52,7 +35,7 @@ def tag_dataset( user: Annotated[User | None, Depends(fetch_user)] = None, expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, ) -> dict[str, dict[str, Any]]: - tags = get_tags(data_id, expdb_db) + tags = database.datasets.get_tags_for(data_id, expdb_db) if tag.casefold() in [t.casefold() for t in tags]: raise HTTPException( status_code=http.client.INTERNAL_SERVER_ERROR, @@ -68,7 +51,7 @@ def tag_dataset( status_code=http.client.PRECONDITION_FAILED, detail={"code": "103", "message": "Authentication failed"}, ) from None - db_tag_dataset(user.user_id, data_id, tag, connection=expdb_db) + database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) all_tags = [*tags, tag] tag_value = all_tags if len(all_tags) > 1 else all_tags[0] @@ -241,9 +224,9 @@ def quality_clause(quality: str, range_: str | None) -> str: "NumberOfNumericFeatures", "NumberOfSymbolicFeatures", ] - qualities_by_dataset = _get_qualities_for_datasets( + qualities_by_dataset = database.qualities._get_for_datasets( dataset_ids=datasets.keys(), - qualities=qualities_to_show, + quality_names=qualities_to_show, connection=expdb_db, ) for did, qualities in qualities_by_dataset.items(): @@ -259,7 +242,9 @@ class ProcessingInformation(NamedTuple): def _get_processing_information(dataset_id: int, connection: Connection) -> ProcessingInformation: """Return processing information, if any. Otherwise, all fields `None`.""" - if not (data_processed := get_latest_processing_update(dataset_id, connection)): + if not ( + data_processed := database.datasets.get_latest_processing_update(dataset_id, connection) + ): return ProcessingInformation(date=None, warning=None, error=None) date_processed = data_processed.processing_date @@ -277,7 +262,7 @@ def _get_dataset_raise_otherwise( Raises HTTPException if the dataset does not exist or the user can not access it. 
""" - if not (dataset := db_get_dataset(dataset_id, expdb)): + if not (dataset := database.datasets.get(dataset_id, expdb)): error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset") raise HTTPException(status_code=http.client.NOT_FOUND, detail=error) @@ -295,12 +280,16 @@ def get_dataset_features( expdb: Annotated[Connection, Depends(expdb_connection)] = None, ) -> list[Feature]: _get_dataset_raise_otherwise(dataset_id, user, expdb) - features = get_features_for_dataset(dataset_id, expdb) + features = database.datasets.get_features(dataset_id, expdb) for feature in [f for f in features if f.data_type == FeatureType.NOMINAL]: - feature.nominal_values = get_feature_values(dataset_id, feature.index, expdb) + feature.nominal_values = database.datasets.get_feature_values( + dataset_id, + feature_index=feature.index, + connection=expdb, + ) if not features: - processing_state = get_latest_processing_update(dataset_id, expdb) + processing_state = database.datasets.get_latest_processing_update(dataset_id, expdb) if processing_state is None: code, msg = ( 273, @@ -349,7 +338,7 @@ def update_dataset_status( detail={"code": 696, "message": "Only administrators can activate datasets."}, ) - current_status = get_latest_status_update(dataset_id, expdb) + current_status = database.datasets.get_status(dataset_id, expdb) if current_status and current_status.status == status: raise HTTPException( status_code=http.client.PRECONDITION_FAILED, @@ -363,9 +352,9 @@ def update_dataset_status( # - active => deactivated (add a row) # - deactivated => active (delete a row) if current_status is None or status == DatasetStatus.DEACTIVATED: - insert_status_for_dataset(dataset_id, user.user_id, status, expdb) + database.datasets.update_status(dataset_id, status, user_id=user.user_id, connection=expdb) elif current_status.status == DatasetStatus.DEACTIVATED: - remove_deactivated_status(dataset_id, expdb) + database.datasets.remove_deactivated_status(dataset_id, expdb) else: raise HTTPException( status_code=http.client.INTERNAL_SERVER_ERROR, @@ -386,17 +375,19 @@ def get_dataset( expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, ) -> DatasetMetadata: dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db) - if not (dataset_file := get_file(dataset.file_id, user_db)): + if not ( + dataset_file := database.datasets.get_file(file_id=dataset.file_id, connection=user_db) + ): error = _format_error( code=DatasetError.NO_DATA_FILE, message="No data file found", ) raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error) - tags = get_tags(dataset_id, expdb_db) - description = get_latest_dataset_description(dataset_id, expdb_db) + tags = database.datasets.get_tags_for(dataset_id, expdb_db) + description = database.datasets.get_description(dataset_id, expdb_db) processing_result = _get_processing_information(dataset_id, expdb_db) - status = get_latest_status_update(dataset_id, expdb_db) + status = database.datasets.get_status(dataset_id, expdb_db) status_ = DatasetStatus(status.status) if status else DatasetStatus.IN_PREPARATION diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py index 4739d47..7d8b76c 100644 --- a/src/routers/openml/estimation_procedure.py +++ b/src/routers/openml/estimation_procedure.py @@ -1,6 +1,6 @@ from typing import Annotated, Iterable -from database.evaluations import get_estimation_procedures as db_get_estimation_procedures +import database.evaluations from fastapi import APIRouter, 
 from schemas.datasets.openml import EstimationProcedure
 from sqlalchemy import Connection
@@ -14,4 +14,4 @@
 def get_estimation_procedures(
     expdb: Annotated[Connection, Depends(expdb_connection)],
 ) -> Iterable[EstimationProcedure]:
-    return db_get_estimation_procedures(expdb)
+    return database.evaluations.get_estimation_procedures(expdb)
diff --git a/src/routers/openml/evaluations.py b/src/routers/openml/evaluations.py
index c641b5b..bdd12f3 100644
--- a/src/routers/openml/evaluations.py
+++ b/src/routers/openml/evaluations.py
@@ -1,6 +1,6 @@
 from typing import Annotated

-from database.evaluations import get_math_functions
+import database.evaluations
 from fastapi import APIRouter, Depends
 from sqlalchemy import Connection

@@ -11,5 +11,8 @@

 @router.get("/list")
 def get_evaluation_measures(expdb: Annotated[Connection, Depends(expdb_connection)]) -> list[str]:
-    functions = get_math_functions(function_type="EvaluationFunction", connection=expdb)
+    functions = database.evaluations.get_math_functions(
+        function_type="EvaluationFunction",
+        connection=expdb,
+    )
     return [function.name for function in functions]
diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py
index 9b73084..3a0d3ba 100644
--- a/src/routers/openml/flows.py
+++ b/src/routers/openml/flows.py
@@ -30,7 +30,7 @@ def flow_exists(

 @router.get("/{flow_id}")
 def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
-    flow = database.flows.get_by_id(flow_id, expdb)
+    flow = database.flows.get(flow_id, expdb)
     if not flow:
         raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found")
diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py
index c498a54..e4d4976 100644
--- a/src/routers/openml/qualities.py
+++ b/src/routers/openml/qualities.py
@@ -1,9 +1,10 @@
 import http.client
 from typing import Annotated, Literal

+import database.datasets
+import database.qualities
 from core.access import _user_has_access
 from core.errors import DatasetError
-from database.datasets import get_dataset, get_qualities_for_dataset, list_all_qualities
 from database.users import User
 from fastapi import APIRouter, Depends, HTTPException
 from schemas.datasets.openml import Quality
@@ -18,7 +19,7 @@
 def list_qualities(
     expdb: Annotated[Connection, Depends(expdb_connection)],
 ) -> dict[Literal["data_qualities_list"], dict[Literal["quality"], list[str]]]:
-    qualities = list_all_qualities(connection=expdb)
+    qualities = database.qualities.list_all_qualities(connection=expdb)
     return {
         "data_qualities_list": {
             "quality": qualities,
@@ -32,13 +33,13 @@ def get_qualities(
     user: Annotated[User | None, Depends(fetch_user)],
     expdb: Annotated[Connection, Depends(expdb_connection)],
 ) -> list[Quality]:
-    dataset = get_dataset(dataset_id, expdb)
+    dataset = database.datasets.get(dataset_id, expdb)
     if not dataset or not _user_has_access(dataset, user):
         raise HTTPException(
             status_code=http.client.PRECONDITION_FAILED,
             detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"},
         ) from None
-    return get_qualities_for_dataset(dataset_id, expdb)
+    return database.qualities.get_for_dataset(dataset_id, expdb)
     # The PHP API provided (sometime) helpful error messages
     # if not qualities:
     #   check if dataset exists: error 360
diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py
index f4153f9..a06fc9f 100644
--- a/src/routers/openml/study.py
+++ b/src/routers/openml/study.py
@@ -1,17 +1,8 @@
 import http.client
 from typing import Annotated, Literal
+import database.studies
 from core.formatting import _str_to_bool
-from database.studies import (
-    attach_run_to_study,
-    attach_runs_to_study,
-    attach_task_to_study,
-    attach_tasks_to_study,
-    get_study_by_alias,
-    get_study_by_id,
-    get_study_data,
-)
-from database.studies import create_study as db_create_study
 from database.users import User, UserGroup
 from fastapi import APIRouter, Body, Depends, HTTPException
 from pydantic import BaseModel
@@ -26,9 +17,9 @@
 def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: Connection) -> Row:
     if isinstance(id_or_alias, int) or id_or_alias.isdigit():
-        study = get_study_by_id(int(id_or_alias), expdb)
+        study = database.studies.get_by_id(int(id_or_alias), expdb)
     else:
-        study = get_study_by_alias(id_or_alias, expdb)
+        study = database.studies.get_by_alias(id_or_alias, expdb)

     if study is None:
         raise HTTPException(status_code=http.client.NOT_FOUND, detail="Study not found.")
@@ -77,9 +68,16 @@ def attach_to_study(

     # We let the database handle the constraints on whether
     # the entity is already attached or if it even exists.
-    attach = attach_tasks_to_study if study.type_ == StudyType.TASK else attach_runs_to_study
+    attach_kwargs = {
+        "study_id": study_id,
+        "user": user,
+        "connection": expdb,
+    }
     try:
-        attach(study_id, entity_ids, user, expdb)
+        if study.type_ == StudyType.TASK:
+            database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs)
+        else:
+            database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs)
     except ValueError as e:
         raise HTTPException(
             status_code=http.client.CONFLICT,
@@ -109,18 +107,18 @@ def create_study(
             status_code=http.client.BAD_REQUEST,
             detail="Cannot create a task study with runs.",
         )
-    if study.alias and get_study_by_alias(study.alias, expdb):
+    if study.alias and database.studies.get_by_alias(study.alias, expdb):
         raise HTTPException(
             status_code=http.client.CONFLICT,
             detail="Study alias already exists.",
         )
-    study_id = db_create_study(study, user, expdb)
+    study_id = database.studies.create(study, user, expdb)
     if study.main_entity_type == StudyType.TASK:
         for task_id in study.tasks:
-            attach_task_to_study(task_id, study_id, user, expdb)
+            database.studies.attach_task(task_id, study_id, user, expdb)
     if study.main_entity_type == StudyType.RUN:
         for run_id in study.runs:
-            attach_run_to_study(run_id, study_id, user, expdb)
+            database.studies.attach_run(run_id=run_id, study_id=study_id, user=user, expdb=expdb)
     # Make sure that invalid fields raise an error (e.g., "task_ids")
     return {"study_id": study_id}
@@ -132,7 +130,7 @@ def get_study(
     expdb: Annotated[Connection, Depends(expdb_connection)] = None,
 ) -> Study:
     study = _get_study_raise_otherwise(alias_or_id, user, expdb)
-    study_data = get_study_data(study, expdb)
+    study_data = database.studies.get_study_data(study, expdb)
     return Study(
         _legacy=_str_to_bool(study.legacy),
         id_=study.id,
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index e01b661..020453b 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -3,15 +3,9 @@
 import re
 from typing import Annotated, Any

+import database.datasets
+import database.tasks
 import xmltodict
-from database.datasets import get_dataset
-from database.tasks import (
-    get_input_for_task,
-    get_tags_for_task,
-    get_task_type,
-    get_task_type_inout_with_template,
-)
-from database.tasks import get_task as db_get_task
 from fastapi import APIRouter, Depends, HTTPException
 from schemas.datasets.openml import Task
 from sqlalchemy import Connection, RowMapping, text
@@ -154,9 +148,9 @@ def get_task(
     # user: Annotated[User | None, Depends(fetch_user)] = None,  # Privacy is not respected
     expdb: Annotated[Connection, Depends(expdb_connection)] = None,
 ) -> Task:
-    if not (task := db_get_task(task_id, expdb)):
+    if not (task := database.tasks.get(task_id, expdb)):
         raise HTTPException(status_code=http.client.NOT_FOUND, detail="Task not found")
-    if not (task_type := get_task_type(task.ttid, expdb)):
+    if not (task_type := database.tasks.get_task_type(task.ttid, expdb)):
         raise HTTPException(
             status_code=http.client.INTERNAL_SERVER_ERROR,
             detail="Task type not found",
@@ -164,9 +158,9 @@
         )
     task_inputs = {
         row.input: int(row.value) if row.value.isdigit() else row.value
-        for row in get_input_for_task(task_id, expdb)
+        for row in database.tasks.get_input_for_task(task_id, expdb)
     }
-    ttios = get_task_type_inout_with_template(task_type.ttid, expdb)
+    ttios = database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb)
     templates = [(tt_io.name, tt_io.io, tt_io.requirement, tt_io.template_api) for tt_io in ttios]
     inputs = [
         fill_template(template, task, task_inputs, expdb) | {"name": name}
@@ -178,10 +172,10 @@
         for name, io, required, template in templates
         if io == "output"
     ]
-    tags = get_tags_for_task(task_id, expdb)
+    tags = database.tasks.get_tags(task_id, expdb)
     name = f"Task {task_id} ({task_type.name})"
     dataset_id = task_inputs.get("source_data")
-    if dataset_id and (dataset := get_dataset(dataset_id, expdb)):
+    if dataset_id and (dataset := database.datasets.get(dataset_id, expdb)):
         name = f"Task {task_id}: {dataset.name} ({task_type.name})"

     return Task(
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 9a982b6..51300bc 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -1,7 +1,7 @@
 import http.client

 import pytest
-from database.datasets import get_tags
+from database.datasets import get_tags_for
 from sqlalchemy import Connection
 from starlette.testclient import TestClient

@@ -38,7 +38,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) ->
     assert response.status_code == http.client.OK
     assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": tag}}

-    tags = get_tags(dataset_id=dataset_id, connection=expdb_test)
+    tags = get_tags_for(id_=dataset_id, connection=expdb_test)
     assert tag in tags
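Note (not part of the patch): a minimal sketch of the call-site convention this diff standardizes on, i.e. module-qualified access such as database.datasets.get(...) instead of aliased imports like "from database.datasets import get_dataset as db_get_dataset". Only the database.* function names come from this diff; the engine URL is a placeholder and the connection setup is assumed (in the application it is provided by the expdb_connection dependency).

    import database.datasets
    import database.qualities
    from sqlalchemy import create_engine

    # Placeholder DSN; in the real service the Connection comes from FastAPI dependencies.
    engine = create_engine("mysql+pymysql://user:password@localhost/expdb")

    with engine.connect() as expdb:
        # Short names stay unambiguous because the module qualifies them.
        dataset = database.datasets.get(1, expdb)  # hypothetical dataset id
        if dataset is not None:
            tags = database.datasets.get_tags_for(1, expdb)
            qualities = database.qualities.get_for_dataset(1, expdb)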