Skip to content

Commit 6cb4387

Browse files
authored
Add ion capabilities (#146)
* add dgl ion capabilities * move files around and update tests * add test model and tidy files * add ion tests * remove commented out lines * add pointer to apache 2 license * update changelog [skip ci]
1 parent dffc164 commit 6cb4387

File tree

102 files changed

+3833
-83
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+3833
-83
lines changed

docs/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ The rules for this file:
2828

2929
### Fixed
3030
- Check lookup tables for allowed molecules before ChemicalDomain for forbidden ones (PR #145, Issue #144)
31+
- Add support for single atoms (PR #146, Issue #138)
3132

3233

3334
## v0.4.0 -- 2024-07-18

openff/nagl/molecule/_dgl/utils.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, List, TYPE_CHECKING, Optional
22

33
import torch
4+
import numpy as np
45
from openff.utilities import requires_package
56

67
from openff.nagl.features.atoms import AtomFeature
@@ -39,7 +40,8 @@ def openff_molecule_to_base_dgl_graph(
3940
{
4041
("atom", forward, "atom"): (indices_a, indices_b),
4142
("atom", reverse, "atom"): (indices_b, indices_a),
42-
}
43+
},
44+
num_nodes_dict={"atom": molecule.n_atoms},
4345
)
4446
return molecule_graph
4547

@@ -99,7 +101,7 @@ def openff_molecule_to_dgl_graph(
99101

100102
for direction in (forward, reverse):
101103
n_bonds = len(molecule.bonds)
102-
if bond_feature_tensor is not None:
104+
if bond_feature_tensor is not None and n_bonds:
103105
bond_feature_tensor = bond_feature_tensor.reshape(n_bonds, -1)
104106
else:
105107
bond_feature_tensor = torch.zeros((n_bonds, 0))
@@ -108,13 +110,87 @@ def openff_molecule_to_dgl_graph(
108110

109111
return molecule_graph
110112

113+
@requires_package("dgl")
def heterograph_to_homograph_no_edges(
    G: "dgl.DGLHeteroGraph",
    ndata=None,
    edata=None,
) -> "dgl.DGLGraph":
    """
    Build a homogeneous graph from a heterograph, combining node frames only.

    Copied and modified from dgl.python.dgl.convert.to_homogeneous,
    but with the edges removed.

    This part of the code is licensed under the Apache 2.0 license according
    to the terms of DGL (https://github.com/dmlc/dgl?tab=Apache-2.0-1-ov-file).

    Please see our third-party license file for more information
    (https://github.com/openforcefield/openff-nagl/blob/main/LICENSE-3RD-PARTY)

    Parameters
    ----------
    G
        The heterograph to convert.
    ndata
        Node column names handed to ``combine_frames``. ``None`` is
        treated as an empty list.
    edata
        Accepted but never read by this function; no edge frames are
        combined here.

    Returns
    -------
    The graph created by ``dgl.graph``, with ``NID`` / ``NTYPE`` stored
    in its node data and ``EID`` / ``ETYPE`` stored in its edge data.
    """
    import dgl
    from dgl import backend as F
    from dgl.base import EID, NID, ETYPE, NTYPE
    from dgl.heterograph import combine_frames

    device = G.device
    idtype = G.idtype

    # TODO: revisit in case DGL accounts for this in the future
    node_counts = [G.num_nodes(ntype) for ntype in G.ntypes]
    # type_offsets[i] is the number of nodes in all node types before type i
    type_offsets = np.insert(np.cumsum(node_counts), 0, 0)

    # Type ID is always in int64
    node_type_tensors = [
        F.full_1d(count, type_index, F.int64, device)
        for type_index, count in enumerate(node_counts)
    ]
    node_id_tensors = [F.arange(0, count, idtype, device) for count in node_counts]

    sources = []
    targets = []
    edge_id_tensors = []
    edge_type_tensors = []
    for type_index, canonical in enumerate(G.canonical_etypes):
        source_type, _, target_type = canonical
        src, dst = G.all_edges(etype=canonical, order="eid")
        edge_count = len(src)
        # shift per-type node indices into the flattened numbering
        source_shift = int(type_offsets[G.get_ntype_id(source_type)])
        target_shift = int(type_offsets[G.get_ntype_id(target_type)])
        sources.append(src + source_shift)
        targets.append(dst + target_shift)
        edge_type_tensors.append(F.full_1d(edge_count, type_index, F.int64, device))
        edge_id_tensors.append(F.arange(0, edge_count, idtype, device))

    retg = dgl.graph(
        (F.cat(sources, 0), F.cat(targets, 0)),
        num_nodes=sum(node_counts),
        idtype=idtype,
        device=device,
    )

    # copy node features; combine_frames may return None, in which case
    # there is nothing to copy
    node_columns = [] if ndata is None else ndata
    combined_node_frame = combine_frames(
        G._node_frames, range(len(G.ntypes)), col_names=node_columns
    )
    if combined_node_frame is not None:
        retg.ndata.update(combined_node_frame)

    retg.ndata[NID] = F.cat(node_id_tensors, 0)
    retg.edata[EID] = F.cat(edge_id_tensors, 0)
    retg.ndata[NTYPE] = F.cat(node_type_tensors, 0)
    retg.edata[ETYPE] = F.cat(edge_type_tensors, 0)

    return retg
181+
182+
183+
111184

112185
@requires_package("dgl")
113186
def dgl_heterograph_to_homograph(graph: "dgl.DGLHeteroGraph") -> "dgl.DGLGraph":
114187
import dgl
115188

116189
try:
117190
homo_graph = dgl.to_homogeneous(graph, ndata=[FEATURE], edata=[FEATURE])
191+
except TypeError as e:
192+
if graph.num_edges() == 0:
193+
homo_graph = heterograph_to_homograph_no_edges(graph)
118194
except KeyError:
119195
# A nasty workaround to check when we don't have any atom / bond features as
120196
# DGL doesn't allow easy querying of features dicts for heterographs with

openff/nagl/molecule/_graph/_graph.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@ def in_edges(self, nodes, form="uv"):
154154
raise ValueError("Unknown form: {}".format(form))
155155

156156
def _bond_indices(self):
    """
    Return the endpoints of every edge as two parallel index tensors.

    ``self.graph.edges()`` is read once and materialised, so the
    emptiness check works whether it returns a sized view or a lazy
    iterable.

    Returns
    -------
    tuple
        ``(U, V)``, two ``torch.long`` tensors holding the first and
        second endpoint of each edge. Both are empty when the graph has
        no edges.

    Raises
    ------
    ValueError
        If the edges are non-empty but do not unpack into exactly two
        columns.
    """
    edges = list(self.graph.edges())
    if not edges:
        # zip(*[]) yields nothing to unpack into u, v, so a graph without
        # bonds is answered directly instead of via a caught ValueError
        return (
            torch.tensor([], dtype=torch.long),
            torch.tensor([], dtype=torch.long),
        )

    u, v = map(list, zip(*edges))
    U = torch.tensor(u, dtype=torch.long)
    V = torch.tensor(v, dtype=torch.long)
    return U, V

0 commit comments

Comments
 (0)