Skip to content

Commit f58e3b5

Browse files
Similarity search with vector search (#1032)
2 parents 52a3c57 + 2e8d9f8 commit f58e3b5

File tree

6 files changed

+112
-8
lines changed

6 files changed

+112
-8
lines changed

mp_api/client/routes/materials/similarity.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,40 @@
11
from __future__ import annotations
22

3-
from emmet.core.similarity import SimilarityDoc
3+
from typing import TYPE_CHECKING
44

5-
from mp_api.client.core import BaseRester
5+
from emmet.core.mpid import MPID, AlphaID
6+
from emmet.core.similarity import (
7+
CrystalNNSimilarity,
8+
SimilarityDoc,
9+
SimilarityEntry,
10+
_vector_to_hex_and_norm,
11+
)
12+
from pymatgen.core import Composition, Structure
13+
14+
from mp_api.client.core import BaseRester, MPRestError
615
from mp_api.client.core.utils import validate_ids
716

17+
if TYPE_CHECKING:
18+
import numpy as np
19+
from emmet.core.similarity import SimilarityScorer
20+
21+
# This limit seems to be associated with MongoDB vector search
22+
MAX_VECTOR_SEARCH_RESULTS = 10_000
23+
824

925
class SimilarityRester(BaseRester):
1026
suffix = "materials/similarity"
1127
document_model = SimilarityDoc # type: ignore
1228
primary_key = "material_id"
1329

30+
_fingerprinter: SimilarityScorer | None = None
31+
32+
def fingerprint_structure(self, structure: Structure) -> np.ndarray:
    """Compute the similarity fingerprint of a user-submitted structure.

    The scorer is built on first use and stored on the instance as
    `_fingerprinter`, so later calls reuse the same object.

    Arguments:
        structure : pymatgen .Structure
            The structure to featurize.

    Returns:
        The value returned by the scorer's `_featurize_structure` for
        the input structure.
    """
    scorer = self._fingerprinter
    if scorer is None:
        # Lazily construct the scorer and cache it for subsequent calls.
        scorer = CrystalNNSimilarity()
        self._fingerprinter = scorer
    return scorer._featurize_structure(structure)
37+
1438
def search(
1539
self,
1640
material_ids: str | list[str] | None = None,
@@ -53,3 +77,83 @@ def search(
5377
fields=fields,
5478
**query_params,
5579
)
80+
81+
def find_similar(
    self,
    structure_or_mpid: Structure | str | MPID | AlphaID,
    top: int | None = 50,
    num_chunks: int | None = None,
    chunk_size: int | None = 1000,
) -> list[SimilarityEntry] | list[dict]:
    """Find structures most similar to a user-submitted structure.

    Arguments:
        structure_or_mpid : pymatgen .Structure, or str, MPID, AlphaID
            If a .Structure, the feature vector is computed on the fly.
            If a str, MPID, or AlphaID, attempts to retrieve a pre-computed
            feature vector using the input as a material ID.
        top : int or None
            The number of most similar materials to return, defaults to 50.
            Setting to None will return the maximum possible number of
            most similar materials.
        num_chunks (int or None): Maximum number of chunks of data to yield. None will yield all possible.
        chunk_size (int or None): Number of data entries per chunk.
            The chunk_size is also used to limit the number of responses returned.

    Returns:
        ([SimilarityEntry] | [dict]) List of SimilarityEntry documents
            (if `use_document_model`) or dict (otherwise) listing
            structures most similar to the input structure.

    Raises:
        ValueError: If `top` is neither `None` nor a positive integer, or if
            `structure_or_mpid` is not a Structure, str, MPID, or AlphaID.
        MPRestError: If no similarity document is found for the given
            material ID, or if the match query response has no "data".
    """
    # Validate `top` first so a bad value fails fast, before any request is
    # made or any structure is featurized. Only `None` means "no limit":
    # `top or MAX` would silently turn an invalid 0 into the maximum.
    if top is None:
        top = MAX_VECTOR_SEARCH_RESULTS
    elif not isinstance(top, int) or top < 1:
        raise ValueError(
            f"Invalid number of possible top matches specified = {top}. "
            "Please specify a positive integer or `None` to return all results."
        )

    if isinstance(structure_or_mpid, str | MPID | AlphaID):
        # Normalize the identifier, then look up its stored feature vector.
        fmt_idx = AlphaID(structure_or_mpid).string

        docs = self.search(material_ids=[fmt_idx], fields=["feature_vector"])
        if not docs:
            raise MPRestError(f"No similarity data available for {fmt_idx}")
        feature_vector = docs[0]["feature_vector"]

    elif isinstance(structure_or_mpid, Structure):
        feature_vector = self.fingerprint_structure(structure_or_mpid)

    else:
        raise ValueError("Please submit a pymatgen Structure or MP ID.")

    vector_hex, vector_norm = _vector_to_hex_and_norm(feature_vector)
    result = self._query_resource(
        criteria={
            "feature_vector_hex": vector_hex,
            "feature_vector_norm": vector_norm,
            "_limit": top,
        },
        suburl="match",
        use_document_model=False,  # Return type is not exactly a SimilarityDoc, closer to SimilarityEntry
        chunk_size=chunk_size,
        num_chunks=num_chunks,
    ).get("data", None)

    if result is None:
        raise MPRestError(
            "Could not find any structures similar to the input structure."
        )

    # Reshape each raw match into the field layout used to build SimilarityEntry.
    sim_docs = [
        {
            "formula": entry["formula_pretty"],
            "task_id": entry["material_id"],
            "nelements": len(Composition(entry["formula_pretty"]).elements),
            "dissimilarity": 100 * (1.0 - entry["score"]),
        }
        for entry in result
    ]

    if self.use_document_model:
        return [SimilarityEntry(**doc) for doc in sim_docs]
    return sim_docs

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ dependencies = [
2525
"typing-extensions>=3.7.4.1",
2626
"requests>=2.23.0",
2727
"monty>=2024.12.10",
28-
"emmet-core>=0.85.1rc0",
28+
"emmet-core>=0.86.2rc1",
2929
"smart_open",
3030
"boto3",
3131
"orjson >= 3.10,<4",
3232
]
3333
dynamic = ["version"]
3434

3535
[project.optional-dependencies]
36-
all = ["emmet-core[all]>=0.85.1rc0", "custodian", "mpcontribs-client>=5.10"]
36+
all = ["emmet-core[all]>=0.86.2rc1", "custodian", "mpcontribs-client>=5.10"]
3737
test = [
3838
"pre-commit",
3939
"pytest",

requirements/requirements-ubuntu-latest_py3.11.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ contourpy==1.3.3
2424
# via matplotlib
2525
cycler==0.12.1
2626
# via matplotlib
27-
emmet-core==0.86.0
27+
emmet-core==0.86.2rc1
2828
# via mp-api (pyproject.toml)
2929
fonttools==4.60.1
3030
# via matplotlib

requirements/requirements-ubuntu-latest_py3.11_extras.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ dnspython==2.8.0
6262
# pymongo
6363
docutils==0.21.2
6464
# via sphinx
65-
emmet-core[all]==0.86.0
65+
emmet-core[all]==0.86.2rc1
6666
# via mp-api (pyproject.toml)
6767
execnet==2.1.1
6868
# via pytest-xdist

requirements/requirements-ubuntu-latest_py3.12.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ contourpy==1.3.3
2424
# via matplotlib
2525
cycler==0.12.1
2626
# via matplotlib
27-
emmet-core==0.86.0
27+
emmet-core==0.86.2rc1
2828
# via mp-api (pyproject.toml)
2929
fonttools==4.60.1
3030
# via matplotlib

requirements/requirements-ubuntu-latest_py3.12_extras.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ dnspython==2.8.0
6262
# pymongo
6363
docutils==0.21.2
6464
# via sphinx
65-
emmet-core[all]==0.86.0
65+
emmet-core[all]==0.86.2rc1
6666
# via mp-api (pyproject.toml)
6767
execnet==2.1.1
6868
# via pytest-xdist

0 commit comments

Comments
 (0)