Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 106 additions & 2 deletions mp_api/client/routes/materials/similarity.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,40 @@
from __future__ import annotations

from emmet.core.similarity import SimilarityDoc
from typing import TYPE_CHECKING

from mp_api.client.core import BaseRester
from emmet.core.mpid import MPID, AlphaID
from emmet.core.similarity import (
CrystalNNSimilarity,
SimilarityDoc,
SimilarityEntry,
_vector_to_hex_and_norm,
)
from pymatgen.core import Composition, Structure

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.utils import validate_ids

if TYPE_CHECKING:
import numpy as np
from emmet.core.similarity import SimilarityScorer

# This limit seems to be associated with MongoDB vector search
MAX_VECTOR_SEARCH_RESULTS = 10_000


class SimilarityRester(BaseRester):
suffix = "materials/similarity"
document_model = SimilarityDoc # type: ignore
primary_key = "material_id"

_fingerprinter: SimilarityScorer | None = None

def fingerprint_structure(self, structure: Structure) -> np.ndarray:
"""Get the fingerprint of a user-submitted structures."""
if self._fingerprinter is None:
self._fingerprinter = CrystalNNSimilarity()
return self._fingerprinter._featurize_structure(structure)

def search(
self,
material_ids: str | list[str] | None = None,
Expand Down Expand Up @@ -53,3 +77,83 @@ def search(
fields=fields,
**query_params,
)

def find_similar(
self,
structure_or_mpid: Structure | str | MPID | AlphaID,
top: int | None = 50,
num_chunks: int | None = None,
chunk_size: int | None = 1000,
) -> list[SimilarityEntry] | list[dict]:
"""Find structures most similar to a user-submitted structure.

Arguments:
structure_or_mpid : pymatgen .Structure, or str, MPID, AlphaID
If a .Structure, the feature vector is computed on the fly
If a str, MPID, or AlphaID, attempts to retrieve a pre-computed
feature vector using the input as a material ID
top : int
The number of most similar materials to return, defaults to 50.
Setting to None will return the maximum possible number of
most similar materials..
num_chunks (int or None): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int or None): Number of data entries per chunk.
The chunk_size is also used to limit the number of responses returned.

Returns:
([SimilarityEntry] | [dict]) List of SimilarityEntry documents
(if `use_document_model`) or dict (otherwise) listing
structures most similar to the input structure.
"""
if isinstance(structure_or_mpid, str | MPID | AlphaID):
fmt_idx = AlphaID(structure_or_mpid).string

docs = self.search(material_ids=[fmt_idx], fields=["feature_vector"])
if not docs:
raise MPRestError(f"No similarity data available for {fmt_idx}")
feature_vector = docs[0]["feature_vector"]

elif isinstance(structure_or_mpid, Structure):
feature_vector = self.fingerprint_structure(structure_or_mpid)

else:
raise ValueError("Please submit a pymatgen Structure or MP ID.")

top = top or MAX_VECTOR_SEARCH_RESULTS
if not isinstance(top, int) or top < 1:
raise ValueError(
f"Invalid number of possible top matches specified = {top}."
"Please specify a positive integer or `None` to return all results."
)

vector_hex, vector_norm = _vector_to_hex_and_norm(feature_vector)
result = self._query_resource(
criteria={
"feature_vector_hex": vector_hex,
"feature_vector_norm": vector_norm,
"_limit": top,
},
suburl="match",
use_document_model=False, # Return type is not exactly a SimilarityDoc, closer to SimilarityEntry
chunk_size=chunk_size,
num_chunks=num_chunks,
).get("data", None)

if result is None:
raise MPRestError(
"Could not find any structures similar to the input structure."
)

sim_docs = [
{
"formula": entry["formula_pretty"],
"task_id": entry["material_id"],
"nelements": len(Composition(entry["formula_pretty"]).elements),
"dissimilarity": 100 * (1.0 - entry["score"]),
}
for entry in result
]

if self.use_document_model:
return [SimilarityEntry(**doc) for doc in sim_docs]
return sim_docs
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ dependencies = [
"typing-extensions>=3.7.4.1",
"requests>=2.23.0",
"monty>=2024.12.10",
"emmet-core>=0.85.1rc0",
"emmet-core>=0.86.2rc1",
"smart_open",
"boto3",
"orjson >= 3.10,<4",
]
dynamic = ["version"]

[project.optional-dependencies]
all = ["emmet-core[all]>=0.85.1rc0", "custodian", "mpcontribs-client>=5.10"]
all = ["emmet-core[all]>=0.86.2rc1", "custodian", "mpcontribs-client>=5.10"]
test = [
"pre-commit",
"pytest",
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-ubuntu-latest_py3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contourpy==1.3.3
# via matplotlib
cycler==0.12.1
# via matplotlib
emmet-core==0.86.0
emmet-core==0.86.2rc1
# via mp-api (pyproject.toml)
fonttools==4.60.1
# via matplotlib
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-ubuntu-latest_py3.11_extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dnspython==2.8.0
# pymongo
docutils==0.21.2
# via sphinx
emmet-core[all]==0.86.0
emmet-core[all]==0.86.2rc1
# via mp-api (pyproject.toml)
execnet==2.1.1
# via pytest-xdist
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-ubuntu-latest_py3.12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contourpy==1.3.3
# via matplotlib
cycler==0.12.1
# via matplotlib
emmet-core==0.86.0
emmet-core==0.86.2rc1
# via mp-api (pyproject.toml)
fonttools==4.60.1
# via matplotlib
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-ubuntu-latest_py3.12_extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dnspython==2.8.0
# pymongo
docutils==0.21.2
# via sphinx
emmet-core[all]==0.86.0
emmet-core[all]==0.86.2rc1
# via mp-api (pyproject.toml)
execnet==2.1.1
# via pytest-xdist
Expand Down