diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index cf1dc175e..6cb600d59 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -1,16 +1,40 @@ from __future__ import annotations -from emmet.core.similarity import SimilarityDoc +from typing import TYPE_CHECKING -from mp_api.client.core import BaseRester +from emmet.core.mpid import MPID, AlphaID +from emmet.core.similarity import ( + CrystalNNSimilarity, + SimilarityDoc, + SimilarityEntry, + _vector_to_hex_and_norm, +) +from pymatgen.core import Composition, Structure + +from mp_api.client.core import BaseRester, MPRestError from mp_api.client.core.utils import validate_ids +if TYPE_CHECKING: + import numpy as np + from emmet.core.similarity import SimilarityScorer + +# This limit seems to be associated with MongoDB vector search +MAX_VECTOR_SEARCH_RESULTS = 10_000 + class SimilarityRester(BaseRester): suffix = "materials/similarity" document_model = SimilarityDoc # type: ignore primary_key = "material_id" + _fingerprinter: SimilarityScorer | None = None + + def fingerprint_structure(self, structure: Structure) -> np.ndarray: + """Get the fingerprint of a user-submitted structures.""" + if self._fingerprinter is None: + self._fingerprinter = CrystalNNSimilarity() + return self._fingerprinter._featurize_structure(structure) + def search( self, material_ids: str | list[str] | None = None, @@ -53,3 +77,83 @@ def search( fields=fields, **query_params, ) + + def find_similar( + self, + structure_or_mpid: Structure | str | MPID | AlphaID, + top: int | None = 50, + num_chunks: int | None = None, + chunk_size: int | None = 1000, + ) -> list[SimilarityEntry] | list[dict]: + """Find structures most similar to a user-submitted structure. + + Arguments: + structure_or_mpid : pymatgen .Structure, or str, MPID, AlphaID + If a .Structure, the feature vector is computed on the fly + If a str, MPID, or AlphaID, attempts to retrieve a pre-computed + feature vector using the input as a material ID + top : int + The number of most similar materials to return, defaults to 50. + Setting to None will return the maximum possible number of + most similar materials.. + num_chunks (int or None): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int or None): Number of data entries per chunk. + The chunk_size is also used to limit the number of responses returned. + + Returns: + ([SimilarityEntry] | [dict]) List of SimilarityEntry documents + (if `use_document_model`) or dict (otherwise) listing + structures most similar to the input structure. + """ + if isinstance(structure_or_mpid, str | MPID | AlphaID): + fmt_idx = AlphaID(structure_or_mpid).string + + docs = self.search(material_ids=[fmt_idx], fields=["feature_vector"]) + if not docs: + raise MPRestError(f"No similarity data available for {fmt_idx}") + feature_vector = docs[0]["feature_vector"] + + elif isinstance(structure_or_mpid, Structure): + feature_vector = self.fingerprint_structure(structure_or_mpid) + + else: + raise ValueError("Please submit a pymatgen Structure or MP ID.") + + top = top or MAX_VECTOR_SEARCH_RESULTS + if not isinstance(top, int) or top < 1: + raise ValueError( + f"Invalid number of possible top matches specified = {top}." + "Please specify a positive integer or `None` to return all results." + ) + + vector_hex, vector_norm = _vector_to_hex_and_norm(feature_vector) + result = self._query_resource( + criteria={ + "feature_vector_hex": vector_hex, + "feature_vector_norm": vector_norm, + "_limit": top, + }, + suburl="match", + use_document_model=False, # Return type is not exactly a SimilarityDoc, closer to SimilarityEntry + chunk_size=chunk_size, + num_chunks=num_chunks, + ).get("data", None) + + if result is None: + raise MPRestError( + "Could not find any structures similar to the input structure." + ) + + sim_docs = [ + { + "formula": entry["formula_pretty"], + "task_id": entry["material_id"], + "nelements": len(Composition(entry["formula_pretty"]).elements), + "dissimilarity": 100 * (1.0 - entry["score"]), + } + for entry in result + ] + + if self.use_document_model: + return [SimilarityEntry(**doc) for doc in sim_docs] + return sim_docs diff --git a/pyproject.toml b/pyproject.toml index afefc1e06..093f8312b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", "monty>=2024.12.10", - "emmet-core>=0.85.1rc0", + "emmet-core>=0.86.2rc1", "smart_open", "boto3", "orjson >= 3.10,<4", @@ -33,7 +33,7 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -all = ["emmet-core[all]>=0.85.1rc0", "custodian", "mpcontribs-client>=5.10"] +all = ["emmet-core[all]>=0.86.2rc1", "custodian", "mpcontribs-client>=5.10"] test = [ "pre-commit", "pytest", diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index d3ff66cb4..11c93cf9d 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -24,7 +24,7 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib -emmet-core==0.86.0 +emmet-core==0.86.2rc1 # via mp-api (pyproject.toml) fonttools==4.60.1 # via matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index 59ac2f166..c94daa297 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -62,7 +62,7 @@ dnspython==2.8.0 # pymongo docutils==0.21.2 # via sphinx -emmet-core[all]==0.86.0 +emmet-core[all]==0.86.2rc1 # via mp-api (pyproject.toml) execnet==2.1.1 # via pytest-xdist diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index 29ee10749..8ecc55fff 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -24,7 +24,7 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib -emmet-core==0.86.0 +emmet-core==0.86.2rc1 # via mp-api (pyproject.toml) fonttools==4.60.1 # via matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index 9b4c609a5..c5f57a7d6 100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -62,7 +62,7 @@ dnspython==2.8.0 # pymongo docutils==0.21.2 # via sphinx -emmet-core[all]==0.86.0 +emmet-core[all]==0.86.2rc1 # via mp-api (pyproject.toml) execnet==2.1.1 # via pytest-xdist