diff --git a/src/mavedb/lib/score_sets.py b/src/mavedb/lib/score_sets.py
index c071898f..5a7d44c8 100644
--- a/src/mavedb/lib/score_sets.py
+++ b/src/mavedb/lib/score_sets.py
@@ -3,7 +3,7 @@
 import logging
 import re
 from operator import attrgetter
-from typing import Any, BinaryIO, Iterable, Optional, TYPE_CHECKING, Sequence, Literal
+from typing import Any, BinaryIO, Iterable, List, Optional, TYPE_CHECKING, Sequence, Literal
 
 from mavedb.models.mapped_variant import MappedVariant
 import numpy as np
@@ -401,12 +401,12 @@ def find_publish_or_private_superseded_score_set_tail(
 def get_score_set_variants_as_csv(
     db: Session,
     score_set: ScoreSet,
-    data_type: Literal["scores", "counts"],
+    data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]],
     start: Optional[int] = None,
     limit: Optional[int] = None,
     drop_na_columns: Optional[bool] = None,
-    include_custom_columns: bool = True,
-    include_post_mapped_hgvs: bool = False,
+    include_custom_columns: Optional[bool] = True,
+    include_post_mapped_hgvs: Optional[bool] = False,
 ) -> str:
     """
     Get the variant data from a score set as a CSV string.
@@ -417,8 +417,8 @@ def get_score_set_variants_as_csv(
         The database session to use.
     score_set : ScoreSet
         The score set to get the variants from.
-    data_type : {'scores', 'counts'}
-        The type of data to get. Either 'scores' or 'counts'.
+    data_types : List[Literal["scores", "counts", "clinVar", "gnomAD"]]
+        The data types to get: any combination of 'scores', 'counts', 'clinVar', and 'gnomAD'.
     start : int, optional
         The index to start from. If None, starts from the beginning.
     limit : int, optional
@@ -437,18 +437,33 @@ def get_score_set_variants_as_csv(
         The CSV string containing the variant data.
     """
     assert type(score_set.dataset_columns) is dict
-    custom_columns_set = "score_columns" if data_type == "scores" else "count_columns"
-    type_column = "score_data" if data_type == "scores" else "count_data"
-
+    custom_columns = {
+        "scores": "score_columns",
+        "counts": "count_columns",
+    }
+    custom_columns_set = [custom_columns[dt] for dt in data_types if dt in custom_columns]
+    type_to_column = {
+        "scores": "score_data",
+        "counts": "count_data",
+    }
+    type_columns = [type_to_column[dt] for dt in data_types if dt in type_to_column]
     columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
     if include_post_mapped_hgvs:
         columns.append("post_mapped_hgvs_g")
         columns.append("post_mapped_hgvs_p")
     if include_custom_columns:
-        custom_columns = [str(x) for x in list(score_set.dataset_columns.get(custom_columns_set, []))]
-        columns += custom_columns
-    elif data_type == "scores":
+        for column in custom_columns_set:
+            dataset_columns = [str(x) for x in list(score_set.dataset_columns.get(column, []))]
+            if column == "score_columns":
+                for c in dataset_columns:
+                    prefixed = "scores." + c
+                    columns.append(prefixed)
+            elif column == "count_columns":
+                for c in dataset_columns:
+                    prefixed = "counts." + c
+                    columns.append(prefixed)
+    elif len(data_types) == 1 and data_types[0] == "scores":
         columns.append(REQUIRED_SCORE_COLUMN)
 
     variants: Sequence[Variant] = []
@@ -488,7 +503,35 @@ def get_score_set_variants_as_csv(
         variants_query = variants_query.limit(limit)
 
     variants = db.scalars(variants_query).all()
-    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column, mappings=mappings)  # type: ignore
+    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_columns, mappings=mappings)  # type: ignore
+
+    # TODO: handle the case where len(data_types) == 1, the data type is neither "scores" nor
+    # "counts", and include_post_mapped_hgvs is set, once clinVar and gnomAD data are available.
+    if len(data_types) > 1 and include_post_mapped_hgvs:
+        rename_map = {}
+        rename_map["post_mapped_hgvs_g"] = "mavedb.post_mapped_hgvs_g"
+        rename_map["post_mapped_hgvs_p"] = "mavedb.post_mapped_hgvs_p"
+
+        # Update the column order list (preserve original order)
+        columns = [rename_map.get(col, col) for col in columns]
+
+        # Rename keys in each row
+        renamed_rows_data = []
+        for row in rows_data:
+            renamed_row = {rename_map.get(k, k): v for k, v in row.items()}
+            renamed_rows_data.append(renamed_row)
+
+        rows_data = renamed_rows_data
+    elif len(data_types) == 1:
+        prefix = f"{data_types[0]}."
+        columns = [col[len(prefix):] if col.startswith(prefix) else col for col in columns]
+
+        # Remove the same prefix from the keys of each row
+        renamed_rows_data = []
+        for row in rows_data:
+            renamed_row = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in row.items()}
+            renamed_rows_data.append(renamed_row)
+        rows_data = renamed_rows_data
 
     if drop_na_columns:
         rows_data, columns = drop_na_columns_from_csv_file_rows(rows_data, columns)
@@ -532,7 +575,7 @@ def is_null(value):
 def variant_to_csv_row(
     variant: Variant,
     columns: list[str],
-    dtype: str,
+    dtype: list[str],
     mapping: Optional[MappedVariant] = None,
     na_rep="NA",
 ) -> dict[str, Any]:
@@ -546,7 +589,7 @@ def variant_to_csv_row(
     columns : list[str]
         Columns to serialize.
-    dtype : str, {'scores', 'counts'}
-        The type of data requested. Either the 'score_data' or 'count_data'.
+    dtype : list[str], {'score_data', 'count_data'}
+        The types of data requested: ['score_data'], ['count_data'], or both.
     na_rep : str
         String to represent null values.
@@ -577,8 +620,19 @@ def variant_to_csv_row(
             else:
                 value = ""
         else:
-            parent = variant.data.get(dtype) if variant.data else None
-            value = str(parent.get(column_key)) if parent else na_rep
+            value = na_rep  # default when no requested data type contains this column
+            for dt in dtype:
+                parent = variant.data.get(dt) if variant.data else None
+                if column_key.startswith("scores."):
+                    inner_key = column_key.replace("scores.", "")
+                elif column_key.startswith("counts."):
+                    inner_key = column_key.replace("counts.", "")
+                else:
+                    # fallback for non-prefixed columns
+                    inner_key = column_key
+                if parent and inner_key in parent:
+                    value = str(parent[inner_key])
+                    break
         if is_null(value):
             value = na_rep
         row[column_key] = value
@@ -589,7 +642,7 @@ def variant_to_csv_row(
 def variants_to_csv_rows(
     variants: Sequence[Variant],
     columns: list[str],
-    dtype: str,
+    dtype: List[str],
     mappings: Optional[Sequence[Optional[MappedVariant]]] = None,
     na_rep="NA",
 ) -> Iterable[dict[str, Any]]:
@@ -602,8 +655,8 @@ def variants_to_csv_rows(
         List of variants.
     columns : list[str]
         Columns to serialize.
-    dtype : str, {'scores', 'counts'}
-        The type of data requested. Either the 'score_data' or 'count_data'.
+    dtype : list[str], {'score_data', 'count_data'}
+        The types of data requested: ['score_data'], ['count_data'], or both.
     na_rep : str
         String to represent null values.
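The column-naming contract above is easiest to see outside the diff. The following is a minimal, illustrative sketch, not MaveDB code: build_columns is an invented helper, and the dataset columns are hypothetical stand-ins for score_set.dataset_columns["score_columns"] and ["count_columns"]. It mirrors the prefix/strip logic in get_score_set_variants_as_csv:

# Illustrative sketch only; mirrors the prefix/strip logic in get_score_set_variants_as_csv.
score_columns = ["score"]       # hypothetical score_set.dataset_columns["score_columns"]
count_columns = ["c_0", "c_1"]  # hypothetical score_set.dataset_columns["count_columns"]


def build_columns(data_types: list[str]) -> list[str]:
    columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
    for dt in data_types:
        if dt == "scores":
            columns += ["scores." + c for c in score_columns]
        elif dt == "counts":
            columns += ["counts." + c for c in count_columns]
    if len(data_types) == 1:
        # Single-type exports keep their historical, unprefixed headers.
        prefix = f"{data_types[0]}."
        columns = [c[len(prefix):] if c.startswith(prefix) else c for c in columns]
    return columns


print(build_columns(["scores"]))
# ['accession', 'hgvs_nt', 'hgvs_splice', 'hgvs_pro', 'score']
print(build_columns(["scores", "counts"]))
# ['accession', 'hgvs_nt', 'hgvs_splice', 'hgvs_pro', 'scores.score', 'counts.c_0', 'counts.c_1']

Prefixing only the multi-type export keeps the existing scores-only and counts-only CSV formats unchanged, while namespacing disambiguates identically named columns when both types appear in one file.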
diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py
index 453c4b93..f210eefe 100644
--- a/src/mavedb/routers/score_sets.py
+++ b/src/mavedb/routers/score_sets.py
@@ -1,6 +1,6 @@
 import logging
 from datetime import date
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, List, Literal, Optional, Sequence, Union
 
 import pandas as pd
 from arq import ArqRedis
@@ -249,7 +249,13 @@ def get_score_set_variants_csv(
     urn: str,
     start: int = Query(default=None, description="Start index for pagination"),
     limit: int = Query(default=None, description="Maximum number of variants to return"),
+    data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]] = Query(
+        default=["scores"],
+        description="One or more data types to include: scores, counts, clinVar, gnomAD",
+    ),
     drop_na_columns: Optional[bool] = None,
+    include_custom_columns: Optional[bool] = None,
+    include_post_mapped_hgvs: Optional[bool] = None,
     db: Session = Depends(deps.get_db),
     user_data: Optional[UserData] = Depends(get_current_user),
 ) -> Any:
@@ -262,9 +268,6 @@ def get_score_set_variants_csv(
     TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may want to turn this into a general-purpose
     CSV export endpoint, with options governing which columns to include.
 
-    Parameters
-    __________
-
     Parameters
     __________
     urn : str
@@ -312,12 +315,12 @@ def get_score_set_variants_csv(
     csv_str = get_score_set_variants_as_csv(
         db,
         score_set,
-        "scores",
+        data_types,
         start,
         limit,
         drop_na_columns,
-        include_custom_columns=False,
-        include_post_mapped_hgvs=True,
+        include_custom_columns,
+        include_post_mapped_hgvs,
     )
 
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
@@ -373,7 +376,7 @@ def get_score_set_scores_csv(
 
     assert_permission(user_data, score_set, Action.READ)
 
-    csv_str = get_score_set_variants_as_csv(db, score_set, "scores", start, limit, drop_na_columns)
+    csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"], start, limit, drop_na_columns)
 
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
@@ -428,7 +431,7 @@ async def get_score_set_counts_csv(
 
     assert_permission(user_data, score_set, Action.READ)
 
-    csv_str = get_score_set_variants_as_csv(db, score_set, "counts", start, limit, drop_na_columns)
+    csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"], start, limit, drop_na_columns)
 
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
@@ -1252,12 +1255,12 @@ async def update_score_set(
             ] + item.dataset_columns["count_columns"]
 
             scores_data = pd.DataFrame(
-                variants_to_csv_rows(item.variants, columns=score_columns, dtype="score_data")
+                variants_to_csv_rows(item.variants, columns=score_columns, dtype=["score_data"])
             ).replace("NA", pd.NA)
 
             if item.dataset_columns["count_columns"]:
                 count_data = pd.DataFrame(
-                    variants_to_csv_rows(item.variants, columns=count_columns, dtype="count_data")
+                    variants_to_csv_rows(item.variants, columns=count_columns, dtype=["count_data"])
                 ).replace("NA", pd.NA)
             else:
                 count_data = None
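For reference, here is one way a client could exercise the new endpoint parameters. This is a hedged usage sketch, not part of the patch: the URN is a placeholder, and the host assumes the public MaveDB API base URL. requests encodes the list value as repeated data_types query parameters, which is the form FastAPI expects for a List-typed Query parameter:

import requests

# Hypothetical score set URN; substitute a real one.
url = "https://api.mavedb.org/api/v1/score-sets/urn:mavedb:00000000-a-1/variants/data"
resp = requests.get(
    url,
    params={
        "data_types": ["scores", "counts"],  # -> ?data_types=scores&data_types=counts
        "include_custom_columns": True,
        "include_post_mapped_hgvs": True,
        "drop_na_columns": True,
    },
)
resp.raise_for_status()
print(resp.text.splitlines()[0])  # the namespaced CSV header row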
diff --git a/src/mavedb/scripts/export_public_data.py b/src/mavedb/scripts/export_public_data.py
index 9d7d8e7f..2172878d 100644
--- a/src/mavedb/scripts/export_public_data.py
+++ b/src/mavedb/scripts/export_public_data.py
@@ -147,12 +147,12 @@ def export_public_data(db: Session):
             logger.info(f"{i + 1}/{num_score_sets} Exporting variants for score set {score_set.urn}")
 
             csv_filename_base = score_set.urn.replace(":", "-")
-            csv_str = get_score_set_variants_as_csv(db, score_set, "scores")
+            csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"])
             zipfile.writestr(f"csv/{csv_filename_base}.scores.csv", csv_str)
 
             count_columns = score_set.dataset_columns["count_columns"] if score_set.dataset_columns else None
             if count_columns and len(count_columns) > 0:
-                csv_str = get_score_set_variants_as_csv(db, score_set, "counts")
+                csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"])
                 zipfile.writestr(f"csv/{csv_filename_base}.counts.csv", csv_str)
diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py
index 7d056bba..cfde04d1 100644
--- a/tests/routers/test_score_set.py
+++ b/tests/routers/test_score_set.py
@@ -2473,7 +2473,7 @@ def test_download_variants_data_file(
         worker_queue.assert_called_once()
 
     download_scores_csv_response = client.get(
-        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?drop_na_columns=true"
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?drop_na_columns=true&include_post_mapped_hgvs=true"
    )
     assert download_scores_csv_response.status_code == 200
     download_scores_csv = download_scores_csv_response.text
@@ -2545,6 +2545,127 @@ def test_download_counts_file(session, data_provider, client, setup_router_db, data_files):
     assert "hgvs_splice" not in columns
+
+
+# Namespaced variant CSV export tests.
+def test_download_scores_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_scores_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?data_types=scores&drop_na_columns=true"
+    )
+    assert download_scores_csv_response.status_code == 200
+    download_scores_csv = download_scores_csv_response.text
+    reader = csv.reader(StringIO(download_scores_csv))
+    columns = next(reader)
+    assert "hgvs_nt" in columns
+    assert "hgvs_pro" in columns
+    assert "hgvs_splice" not in columns
+    assert "score" in columns
+
+
+def test_download_counts_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_counts_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?data_types=counts&include_custom_columns=true&drop_na_columns=true"
+    )
+    assert download_counts_csv_response.status_code == 200
+    download_counts_csv = download_counts_csv_response.text
+    reader = csv.reader(StringIO(download_counts_csv))
+    columns = next(reader)
+    assert "hgvs_nt" in columns
+    assert "hgvs_pro" in columns
+    assert "hgvs_splice" not in columns
+    assert "c_0" in columns
+    assert "c_1" in columns
+
+
+def test_download_scores_and_counts_file(session, data_provider, client, setup_router_db, data_files):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_scores_and_counts_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?data_types=counts&data_types=scores&include_custom_columns=true&drop_na_columns=true"
+    )
+    assert download_scores_and_counts_csv_response.status_code == 200
+    download_scores_and_counts_csv = download_scores_and_counts_csv_response.text
+    reader = csv.DictReader(StringIO(download_scores_and_counts_csv))
+    assert sorted(reader.fieldnames) == sorted(
+        [
+            "accession",
+            "hgvs_nt",
+            "hgvs_pro",
+            "scores.score",
+            "counts.c_0",
+            "counts.c_1",
+        ]
+    )
+
+
+@pytest.mark.parametrize(
+    "mapped_variant,has_hgvs_g,has_hgvs_p",
+    [
+        (None, False, False),
+        (TEST_MAPPED_VARIANT_WITH_HGVS_G_EXPRESSION, True, False),
+        (TEST_MAPPED_VARIANT_WITH_HGVS_P_EXPRESSION, False, True),
+    ],
+    ids=["without_post_mapped_vrs", "with_post_mapped_hgvs_g", "with_post_mapped_hgvs_p"],
+)
+def test_download_scores_counts_and_post_mapped_variants_file(
+    session, data_provider, client, setup_router_db, data_files, mapped_variant, has_hgvs_g, has_hgvs_p
+):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    if mapped_variant is not None:
+        create_mapped_variants_for_score_set(session, score_set["urn"], mapped_variant)
+
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_multiple_data_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?data_types=scores&data_types=counts&include_custom_columns=true&include_post_mapped_hgvs=true&drop_na_columns=true"
+    )
+    assert download_multiple_data_csv_response.status_code == 200
+    download_multiple_data_csv = download_multiple_data_csv_response.text
+    reader = csv.DictReader(StringIO(download_multiple_data_csv))
+    assert sorted(reader.fieldnames) == sorted(
+        [
+            "accession",
+            "hgvs_nt",
+            "hgvs_pro",
+            "mavedb.post_mapped_hgvs_g",
+            "mavedb.post_mapped_hgvs_p",
+            "scores.score",
+            "counts.c_0",
+            "counts.c_1",
+        ]
+    )
+
+
 ########################################################################################################################
 # Fetching clinical controls and control options for a score set
 ########################################################################################################################
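The namespaced headers asserted in these tests also make the combined CSV straightforward to split back into per-type tables downstream. A sketch, illustrative only, using a fabricated one-row CSV in place of a real response body:

from io import StringIO

import pandas as pd

# Fabricated response body for illustration; real exports come from the endpoint above.
combined_csv = (
    "accession,hgvs_nt,hgvs_pro,scores.score,counts.c_0,counts.c_1\n"
    "urn:x-example:1,c.1A>G,p.Thr1Ala,1.5,10,12\n"
)

df = pd.read_csv(StringIO(combined_csv))
shared = [c for c in df.columns if "." not in c]
scores = df[shared + [c for c in df.columns if c.startswith("scores.")]]
counts = df[shared + [c for c in df.columns if c.startswith("counts.")]]
# Strip the namespace to recover the single-type header format.
scores = scores.rename(columns=lambda c: c.removeprefix("scores."))
counts = counts.rename(columns=lambda c: c.removeprefix("counts."))
print(list(scores.columns))  # ['accession', 'hgvs_nt', 'hgvs_pro', 'score']
print(list(counts.columns))  # ['accession', 'hgvs_nt', 'hgvs_pro', 'c_0', 'c_1']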