Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/db access #172

Merged
merged 11 commits into from
Jul 19, 2024
92 changes: 21 additions & 71 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,13 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""

import datetime
from collections import defaultdict
from typing import Iterable

from schemas.datasets.openml import Feature, Quality
from schemas.datasets.openml import Feature
from sqlalchemy import Connection, text
from sqlalchemy.engine import Row


def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
    """Fetch all quality rows stored for one dataset.

    The dataset id is passed as a bound parameter and matched against the
    `data` column of `data_quality`; each resulting row is wrapped in a
    `Quality` built from its `quality` and `value` columns.
    """
    query = text(
        """
SELECT `quality`,`value`
FROM data_quality
WHERE `data`=:dataset_id
""",
    )
    result = connection.execute(query, parameters={"dataset_id": dataset_id})
    return [Quality(name=record.quality, value=record.value) for record in result]


def _get_qualities_for_datasets(
    dataset_ids: Iterable[int],
    qualities: Iterable[str],
    connection: Connection,
) -> dict[int, list[Quality]]:
    """Fetch the requested qualities for several datasets with one query.

    Both filters are sent as expanding bind parameters, so no value is ever
    interpolated into the SQL text.

    Args:
        dataset_ids: identifiers matched against the `data` column.
        qualities: quality names matched against the `quality` column.
        connection: connection the query is executed on.

    Returns:
        Mapping from dataset id to its `Quality` objects. Rows whose value is
        NULL are skipped; a dataset left without any row has no key at all.
        An empty mapping is returned when either filter is empty.
    """
    from sqlalchemy import bindparam  # local import: only this query needs it

    # Materialize once: the inputs may be one-shot iterators, and an empty
    # filter would otherwise render as `IN ()`, which is a syntax error in
    # many SQL dialects (including MySQL).
    dids = list(dataset_ids)
    quality_names = list(qualities)
    if not dids or not quality_names:
        return {}

    qualities_query = text(
        """
        SELECT `data`, `quality`, `value`
        FROM data_quality
        WHERE `data` IN :dataset_ids AND `quality` IN :quality_names
        """,
    ).bindparams(
        # `expanding=True` renders one placeholder per list element.
        bindparam("dataset_ids", expanding=True),
        bindparam("quality_names", expanding=True),
    )
    rows = connection.execute(
        qualities_query,
        parameters={"dataset_ids": dids, "quality_names": quality_names},
    )
    qualities_by_id: defaultdict[int, list[Quality]] = defaultdict(list)
    for did, quality, value in rows:
        if value is not None:
            qualities_by_id[did].append(Quality(name=quality, value=value))
    return dict(qualities_by_id)


def list_all_qualities(connection: Connection) -> list[str]:
    """Return the distinct quality names that occur in `data_quality`."""
    # Only qualities that are actually *used* are reported. To list every
    # known quality instead, query:
    #   SELECT `name` FROM `quality` WHERE `type`='DataQuality'
    distinct_query = text(
        """
SELECT DISTINCT(`quality`)
FROM data_quality
""",
    )
    return [row.quality for row in connection.execute(distinct_query)]


def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
def get(id_: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -69,12 +16,12 @@ def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
WHERE did = :dataset_id
""",
),
parameters={"dataset_id": dataset_id},
parameters={"dataset_id": id_},
)
return row.one_or_none()


def get_file(file_id: int, connection: Connection) -> Row | None:
def get_file(*, file_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -88,7 +35,7 @@ def get_file(file_id: int, connection: Connection) -> Row | None:
return row.one_or_none()


def get_tags(dataset_id: int, connection: Connection) -> list[str]:
def get_tags_for(id_: int, connection: Connection) -> list[str]:
rows = connection.execute(
text(
"""
Expand All @@ -97,12 +44,12 @@ def get_tags(dataset_id: int, connection: Connection) -> list[str]:
WHERE id = :dataset_id
""",
),
parameters={"dataset_id": dataset_id},
parameters={"dataset_id": id_},
)
return [row.tag for row in rows]


def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection) -> None:
def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
connection.execute(
text(
"""
Expand All @@ -111,17 +58,18 @@ def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection)
""",
),
parameters={
"dataset_id": dataset_id,
"dataset_id": id_,
"user_id": user_id,
"tag": tag,
"tag": tag_,
},
)


def get_latest_dataset_description(
dataset_id: int,
def get_description(
id_: int,
connection: Connection,
) -> Row | None:
"""Get the most recent description for the dataset."""
row = connection.execute(
text(
"""
Expand All @@ -131,12 +79,13 @@ def get_latest_dataset_description(
ORDER BY version DESC
""",
),
parameters={"dataset_id": dataset_id},
parameters={"dataset_id": id_},
)
return row.first()


def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | None:
def get_status(id_: int, connection: Connection) -> Row | None:
"""Get most recent status for the dataset."""
row = connection.execute(
text(
"""
Expand All @@ -146,7 +95,7 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | N
ORDER BY status_date DESC
""",
),
parameters={"dataset_id": dataset_id},
parameters={"dataset_id": id_},
)
return row.first()

Expand All @@ -166,7 +115,7 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row
return row.one_or_none()


def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Feature]:
def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
rows = connection.execute(
text(
"""
Expand All @@ -181,7 +130,7 @@ def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Fe
return [Feature(**row, nominal_values=None) for row in rows.mappings()]


def get_feature_values(dataset_id: int, feature_index: int, connection: Connection) -> list[str]:
def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]:
rows = connection.execute(
text(
"""
Expand All @@ -195,10 +144,11 @@ def get_feature_values(dataset_id: int, feature_index: int, connection: Connecti
return [row.value for row in rows]


def insert_status_for_dataset(
def update_status(
dataset_id: int,
user_id: int,
status: str,
*,
user_id: int,
connection: Connection,
) -> None:
connection.execute(
Expand Down
4 changes: 2 additions & 2 deletions src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No
).one_or_none()


def get_by_id(flow_id: int, expdb: Connection) -> Row | None:
def get(id_: int, expdb: Connection) -> Row | None:
return expdb.execute(
text(
"""
Expand All @@ -72,5 +72,5 @@ def get_by_id(flow_id: int, expdb: Connection) -> Row | None:
WHERE id = :flow_id
""",
),
parameters={"flow_id": flow_id},
parameters={"flow_id": id_},
).one_or_none()
56 changes: 56 additions & 0 deletions src/database/qualities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from collections import defaultdict
from typing import Iterable

from schemas.datasets.openml import Quality
from sqlalchemy import Connection, text


def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
    """Return every quality recorded in `data_quality` for `dataset_id`.

    The id is supplied as a bound parameter; each (quality, value) row that
    comes back is converted into a `Quality`.
    """
    select_qualities = text(
        """
SELECT `quality`,`value`
FROM data_quality
WHERE `data`=:dataset_id
""",
    )
    bound = {"dataset_id": dataset_id}
    found = connection.execute(select_qualities, parameters=bound)
    return [Quality(name=entry.quality, value=entry.value) for entry in found]


def _get_for_datasets(
    dataset_ids: Iterable[int],
    quality_names: Iterable[str],
    connection: Connection,
) -> dict[int, list[Quality]]:
    """Fetch the requested qualities for several datasets with one query.

    Both filters are sent as expanding bind parameters, so no value is ever
    interpolated into the SQL text.

    Args:
        dataset_ids: identifiers matched against the `data` column.
        quality_names: quality names matched against the `quality` column.
        connection: connection the query is executed on.

    Returns:
        Mapping from dataset id to its `Quality` objects. Rows whose value is
        NULL are skipped; a dataset left without any row has no key at all.
        An empty mapping is returned when either filter is empty.
    """
    from sqlalchemy import bindparam  # local import: only this query needs it

    # Materialize once: the inputs may be one-shot iterators, and an empty
    # filter would otherwise render as `IN ()`, which is a syntax error in
    # many SQL dialects (including MySQL).
    dids = list(dataset_ids)
    names = list(quality_names)
    if not dids or not names:
        return {}

    qualities_query = text(
        """
        SELECT `data`, `quality`, `value`
        FROM data_quality
        WHERE `data` IN :dataset_ids AND `quality` IN :quality_names
        """,
    ).bindparams(
        # `expanding=True` renders one placeholder per list element.
        bindparam("dataset_ids", expanding=True),
        bindparam("quality_names", expanding=True),
    )
    rows = connection.execute(
        qualities_query,
        parameters={"dataset_ids": dids, "quality_names": names},
    )
    qualities_by_id: defaultdict[int, list[Quality]] = defaultdict(list)
    for did, quality, value in rows:
        if value is not None:
            qualities_by_id[did].append(Quality(name=quality, value=value))
    return dict(qualities_by_id)


def list_all_qualities(connection: Connection) -> list[str]:
    """List each distinct quality name found in the `data_quality` table."""
    # This reports only qualities that are *used* by some dataset. For the
    # full catalogue you should instead query:
    #   SELECT `name` FROM `quality` WHERE `type`='DataQuality'
    used_qualities = text(
        """
SELECT DISTINCT(`quality`)
FROM data_quality
""",
    )
    result = connection.execute(used_qualities)
    return [record.quality for record in result]
24 changes: 15 additions & 9 deletions src/database/studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from database.users import User


def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
def get_by_id(id_: int, connection: Connection) -> Row | None:
return connection.execute(
text(
"""
Expand All @@ -17,11 +17,11 @@ def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
WHERE id = :study_id
""",
),
parameters={"study_id": study_id},
parameters={"study_id": id_},
).one_or_none()


def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
def get_by_alias(alias: str, connection: Connection) -> Row | None:
return connection.execute(
text(
"""
Expand All @@ -35,6 +35,11 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row | None:


def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
"""Return data related to the study, content depends on the study type.

For task studies: (task id, dataset id)
For run studies: (run id, task id, setup id, dataset id, flow id)
"""
if study.type_ == StudyType.TASK:
return cast(
Sequence[Row],
Expand Down Expand Up @@ -72,7 +77,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
)


def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
def create(study: CreateStudy, user: User, expdb: Connection) -> int:
expdb.execute(
text(
"""
Expand Down Expand Up @@ -100,7 +105,7 @@ def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
return cast(int, study_id)


def attach_task_to_study(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
expdb.execute(
text(
"""
Expand All @@ -112,7 +117,7 @@ def attach_task_to_study(task_id: int, study_id: int, user: User, expdb: Connect
)


def attach_run_to_study(run_id: int, study_id: int, user: User, expdb: Connection) -> None:
def attach_run(*, run_id: int, study_id: int, user: User, expdb: Connection) -> None:
expdb.execute(
text(
"""
Expand All @@ -124,7 +129,8 @@ def attach_run_to_study(run_id: int, study_id: int, user: User, expdb: Connectio
)


def attach_tasks_to_study(
def attach_tasks(
*,
study_id: int,
task_ids: list[int],
user: User,
Expand Down Expand Up @@ -155,9 +161,9 @@ def attach_tasks_to_study(
raise ValueError(msg) from e


def attach_runs_to_study(
def attach_runs(
study_id: int, # noqa: ARG001
task_ids: list[int], # noqa: ARG001
run_ids: list[int], # noqa: ARG001
user: User, # noqa: ARG001
connection: Connection, # noqa: ARG001
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy import Connection, Row, text


def get_task(task_id: int, expdb: Connection) -> Row | None:
def get(id_: int, expdb: Connection) -> Row | None:
return expdb.execute(
text(
"""
Expand All @@ -12,7 +12,7 @@ def get_task(task_id: int, expdb: Connection) -> Row | None:
WHERE `task_id` = :task_id
""",
),
parameters={"task_id": task_id},
parameters={"task_id": id_},
).one_or_none()


Expand Down Expand Up @@ -59,7 +59,7 @@ def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Ro
)


def get_input_for_task(task_id: int, expdb: Connection) -> Sequence[Row]:
def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
Expand All @@ -70,7 +70,7 @@ def get_input_for_task(task_id: int, expdb: Connection) -> Sequence[Row]:
WHERE task_id = :task_id
""",
),
parameters={"task_id": task_id},
parameters={"task_id": id_},
).all(),
)

Expand All @@ -91,7 +91,7 @@ def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequ
)


def get_tags_for_task(task_id: int, expdb: Connection) -> list[str]:
def get_tags(id_: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
text(
"""
Expand All @@ -100,6 +100,6 @@ def get_tags_for_task(task_id: int, expdb: Connection) -> list[str]:
WHERE `id` = :task_id
""",
),
parameters={"task_id": task_id},
parameters={"task_id": id_},
)
return [row.tag for row in tag_rows]
Loading
Loading