From 7f0de9fcbfd9630f599f1526d7fd5904c9f12c47 Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Thu, 20 Feb 2025 17:09:40 -0800 Subject: [PATCH 1/6] add query operators and facet operator for tasks --- mp_api/client/routes/materials/tasks.py | 28 +++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 5e12f9aa1..5a37c7cf7 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -39,11 +39,17 @@ def search( elements: list[str] | None = None, exclude_elements: list[str] | None = None, formula: str | list[str] | None = None, + calc_type: str | None = None, + run_type: str | None = None, + task_type: str | None = None, + chemsys: str | list[str] | None = None, last_updated: tuple[datetime, datetime] | None = None, + batches: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, - all_fields: bool = True, + all_fields: bool = False, fields: list[str] | None = None, + facets: str | list[str] | None = None, ) -> list[TaskDoc] | list[dict]: """Query core task docs using a variety of search criteria. @@ -73,7 +79,10 @@ def search( query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) if formula: - query_params.update({"formula": formula}) + query_params.update({"formula": ",".join(formula) if isinstance(formula, list) else formula}) + + if chemsys: + query_params.update({"chemsys": ",".join(chemsys) if isinstance(chemsys, list) else chemsys}) if elements: query_params.update({"elements": ",".join(elements)}) @@ -89,6 +98,21 @@ def search( } ) + if task_type: + query_params.update({"task_type": task_type}) + + if calc_type: + query_params.update({"calc_type": calc_type}) + + if run_type: + query_params.update({"run_type": run_type}) + + if batches: + query_params.update({"batches": ".".join(batches) if isinstance(batches, list) else batches}) + + if facets: + query_params.update({"facets": ",".join(facets) if isinstance(facets, list) else facets}) + return super()._search( num_chunks=num_chunks, chunk_size=chunk_size, From b4c9e365dfea6b6dbe32505a03d813f660fc11a4 Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Mon, 24 Feb 2025 16:12:24 -0800 Subject: [PATCH 2/6] return facet if specified --- mp_api/client/core/client.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 9627f5d62..84cd4736e 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -888,6 +888,7 @@ def _submit_requests( # noqa if data_tuples and "meta" in data_tuples[0][0]: total_data["meta"]["time_stamp"] = data_tuples[0][0]["meta"]["time_stamp"] + total_data["meat"]["facets"] = data_tuples[0][0]["meta"]["facets"] if pbar is not None: pbar.close() @@ -1236,6 +1237,7 @@ def _get_all_documents( fields=None, chunk_size=1000, num_chunks=None, + facets=None, ) -> list[T] | list[dict]: """Iterates over pages until all documents are retrieved. Displays progress using tqdm. This method is designed to give a common @@ -1267,7 +1269,6 @@ def _get_all_documents( ) chosen_param = list_entries[0][0] if len(list_entries) > 0 else None - results = self._query_resource( query_params, fields=fields, @@ -1275,8 +1276,10 @@ def _get_all_documents( chunk_size=chunk_size, num_chunks=num_chunks, ) - - return results["data"] + if facets: + return results["data"], results["meta"] + else: + return results["data"] def count(self, criteria: dict | None = None) -> int | str: """Return a count of total documents. From 0672c7e39637f5aa5f1496da67d4656dce3a0c15 Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Thu, 27 Feb 2025 14:36:50 -0800 Subject: [PATCH 3/6] add warnings for including fields --- mp_api/client/routes/materials/tasks.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 5a37c7cf7..96ecd59ed 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime - +import warnings from emmet.core.tasks import TaskDoc from mp_api.client.core import BaseRester, MPRestError @@ -44,7 +44,7 @@ def search( task_type: str | None = None, chemsys: str | list[str] | None = None, last_updated: tuple[datetime, datetime] | None = None, - batches: str | list[str] | None = None, + batches: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = False, @@ -55,12 +55,18 @@ def search( Arguments: task_ids (str, List[str]): List of Materials Project IDs to return data for. - elements (List[str]): A list of elements. + chemsys: (str, List[str]): A list of chemical systems to search for. exclude_elements (List[str]): A list of elements to exclude. formula (str, List[str]): A formula including anonymized formula or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed (e.g., [Fe2O3, ABO3]). last_updated (tuple[datetime, datetime]): A tuple of min and max UTC formatted datetimes. + batches (str, List[str]): A list of batch IDs to search for. + run_type (str): The type of task to search for. Can be one of the following: + #TODO: check enum "GGA", "GGA+U", "SCAN" + task_type (str): The type of task to search for. Can be one of the following: + #TODO check enum NSCF Line + calc_type (str): The type of calculation to search for. A combination of the run_type and task_type. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. Max size is 100. all_fields (bool): Whether to return all fields in the document. Defaults to True. @@ -113,6 +119,17 @@ def search( if facets: query_params.update({"facets": ",".join(facets) if isinstance(facets, list) else facets}) + if all_fields: + warnings.warn( + """Please only use all_fields=True when necessary, as it may cause slow query. + """ + ) + if fields and ("calcs_reversed" in fields or "orig_inputs" in fields): + warnings.warn( + """Please only include calcs_reversed and orig_inputs when necessary, as it may cause slow query. + """ + ) + return super()._search( num_chunks=num_chunks, chunk_size=chunk_size, From 46ad5a2f1ecc62fdf79e53544d6438947f972f59 Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Thu, 27 Feb 2025 14:36:56 -0800 Subject: [PATCH 4/6] fix type --- mp_api/client/core/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 84cd4736e..2d64ec540 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -888,7 +888,7 @@ def _submit_requests( # noqa if data_tuples and "meta" in data_tuples[0][0]: total_data["meta"]["time_stamp"] = data_tuples[0][0]["meta"]["time_stamp"] - total_data["meat"]["facets"] = data_tuples[0][0]["meta"]["facets"] + total_data["meta"]["facets"] = data_tuples[0][0]["meta"]["facet"] if pbar is not None: pbar.close() From 6d3c6428230857982d82b47d4d8c6bb167ed298a Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Wed, 5 Mar 2025 15:20:26 -0800 Subject: [PATCH 5/6] lint --- mp_api/client/routes/materials/tasks.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 96ecd59ed..cb159d19f 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -1,7 +1,8 @@ from __future__ import annotations -from datetime import datetime import warnings +from datetime import datetime + from emmet.core.tasks import TaskDoc from mp_api.client.core import BaseRester, MPRestError @@ -56,6 +57,7 @@ def search( Arguments: task_ids (str, List[str]): List of Materials Project IDs to return data for. chemsys: (str, List[str]): A list of chemical systems to search for. + elements: (List[str]): A list of elements to search for. exclude_elements (List[str]): A list of elements to exclude. formula (str, List[str]): A formula including anonymized formula or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed @@ -63,15 +65,16 @@ def search( last_updated (tuple[datetime, datetime]): A tuple of min and max UTC formatted datetimes. batches (str, List[str]): A list of batch IDs to search for. run_type (str): The type of task to search for. Can be one of the following: - #TODO: check enum "GGA", "GGA+U", "SCAN" + #TODO: check enum task_type (str): The type of task to search for. Can be one of the following: - #TODO check enum NSCF Line + #TODO check enum calc_type (str): The type of calculation to search for. A combination of the run_type and task_type. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. Max size is 100. all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in TaskDoc to return data for. Default is material_id, last_updated, and formula_pretty if all_fields is False. + facets (str, List[str]): List of facets to return data for. Returns: ([TaskDoc], [dict]) List of task documents or dictionaries. @@ -85,10 +88,14 @@ def search( query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) if formula: - query_params.update({"formula": ",".join(formula) if isinstance(formula, list) else formula}) + query_params.update( + {"formula": ",".join(formula) if isinstance(formula, list) else formula} + ) if chemsys: - query_params.update({"chemsys": ",".join(chemsys) if isinstance(chemsys, list) else chemsys}) + query_params.update( + {"chemsys": ",".join(chemsys) if isinstance(chemsys, list) else chemsys} + ) if elements: query_params.update({"elements": ",".join(elements)}) @@ -114,10 +121,14 @@ def search( query_params.update({"run_type": run_type}) if batches: - query_params.update({"batches": ".".join(batches) if isinstance(batches, list) else batches}) + query_params.update( + {"batches": ".".join(batches) if isinstance(batches, list) else batches} + ) if facets: - query_params.update({"facets": ",".join(facets) if isinstance(facets, list) else facets}) + query_params.update( + {"facets": ",".join(facets) if isinstance(facets, list) else facets} + ) if all_fields: warnings.warn( From 9c93bd65af6f49255220c10ab0019a6e5a230057 Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Fri, 7 Mar 2025 13:05:40 -0800 Subject: [PATCH 6/6] facet fix --- mp_api/client/core/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 2d64ec540..9449e0d3e 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -888,7 +888,7 @@ def _submit_requests( # noqa if data_tuples and "meta" in data_tuples[0][0]: total_data["meta"]["time_stamp"] = data_tuples[0][0]["meta"]["time_stamp"] - total_data["meta"]["facets"] = data_tuples[0][0]["meta"]["facet"] + total_data["meta"]["facets"] = data_tuples[0][0]["meta"].get("facet", None) if pbar is not None: pbar.close()