1
1
from typing import Dict , List , TYPE_CHECKING , Optional
2
2
3
3
import torch
4
+ import numpy as np
4
5
from openff .utilities import requires_package
5
6
6
7
from openff .nagl .features .atoms import AtomFeature
@@ -39,7 +40,8 @@ def openff_molecule_to_base_dgl_graph(
39
40
{
40
41
("atom" , forward , "atom" ): (indices_a , indices_b ),
41
42
("atom" , reverse , "atom" ): (indices_b , indices_a ),
42
- }
43
+ },
44
+ num_nodes_dict = {"atom" : molecule .n_atoms },
43
45
)
44
46
return molecule_graph
45
47
@@ -99,7 +101,7 @@ def openff_molecule_to_dgl_graph(
99
101
100
102
for direction in (forward , reverse ):
101
103
n_bonds = len (molecule .bonds )
102
- if bond_feature_tensor is not None :
104
+ if bond_feature_tensor is not None and n_bonds :
103
105
bond_feature_tensor = bond_feature_tensor .reshape (n_bonds , - 1 )
104
106
else :
105
107
bond_feature_tensor = torch .zeros ((n_bonds , 0 ))
@@ -108,13 +110,87 @@ def openff_molecule_to_dgl_graph(
108
110
109
111
return molecule_graph
110
112
113
@requires_package("dgl")
def heterograph_to_homograph_no_edges(
    G: "dgl.DGLHeteroGraph", ndata=None, edata=None
) -> "dgl.DGLGraph":
    """
    Flatten a DGL heterograph into one homogeneous graph, copying node
    features only.

    Copied and modified from dgl.python.dgl.convert.to_homogeneous. The
    edge connectivity is still rebuilt, but the step that combines the
    per-edge-type feature frames is left out, so no edge features are
    transferred.

    This part of the code is licensed under the Apache 2.0 license according
    to the terms of DGL (https://github.com/dmlc/dgl?tab=Apache-2.0-1-ov-file).

    Please see our third-party license file for more information
    (https://github.com/openforcefield/openff-nagl/blob/main/LICENSE-3RD-PARTY)

    Parameters
    ----------
    G
        The heterograph to convert.
    ndata
        Names of the node feature columns to combine across node types and
        copy onto the result. ``None`` is treated as an empty list.
    edata
        Accepted so the signature mirrors ``dgl.to_homogeneous``; it is not
        read, because edge features are never copied here.

    Returns
    -------
    dgl.DGLGraph
        A homogeneous graph whose nodes are the nodes of every node type laid
        out one type after another. ``NID``/``NTYPE`` node data and
        ``EID``/``ETYPE`` edge data record each element's original per-type
        ID and type index.
    """
    import dgl
    from dgl import backend as F
    from dgl.base import EID, NID, ETYPE, NTYPE
    from dgl.heterograph import combine_frames

    # TODO: revisit in case DGL accounts for this in the future
    node_counts = [G.num_nodes(ntype) for ntype in G.ntypes]
    # type_offsets[i] is the index of the first node of type i in the
    # flattened graph (leading 0, then the running totals).
    type_offsets = np.insert(np.cumsum(node_counts), 0, 0)
    n_nodes_total = sum(node_counts)

    # Per-node-type bookkeeping tensors; type IDs are always int64.
    node_type_ids = [
        F.full_1d(count, type_index, F.int64, G.device)
        for type_index, count in enumerate(node_counts)
    ]
    node_ids = [
        F.arange(0, count, G.idtype, G.device) for count in node_counts
    ]

    # Per-edge-type endpoints (shifted into the flattened node numbering)
    # together with the matching bookkeeping tensors.
    sources = []
    targets = []
    edge_ids = []
    edge_type_ids = []
    for type_index, canonical_etype in enumerate(G.canonical_etypes):
        source_type, _, target_type = canonical_etype
        src, dst = G.all_edges(etype=canonical_etype, order="eid")
        n_edges = len(src)
        source_shift = int(type_offsets[G.get_ntype_id(source_type)])
        target_shift = int(type_offsets[G.get_ntype_id(target_type)])
        sources.append(src + source_shift)
        targets.append(dst + target_shift)
        edge_type_ids.append(F.full_1d(n_edges, type_index, F.int64, G.device))
        edge_ids.append(F.arange(0, n_edges, G.idtype, G.device))

    homograph = dgl.graph(
        (F.cat(sources, 0), F.cat(targets, 0)),
        num_nodes=n_nodes_total,
        idtype=G.idtype,
        device=G.device,
    )

    # Copy node features only; the edge frames are intentionally not combined.
    node_columns = [] if ndata is None else ndata
    combined_node_frame = combine_frames(
        G._node_frames, range(len(G.ntypes)), col_names=node_columns
    )
    if combined_node_frame is not None:
        homograph.ndata.update(combined_node_frame)

    homograph.ndata[NID] = F.cat(node_ids, 0)
    homograph.edata[EID] = F.cat(edge_ids, 0)
    homograph.ndata[NTYPE] = F.cat(node_type_ids, 0)
    homograph.edata[ETYPE] = F.cat(edge_type_ids, 0)

    return homograph
181
+
182
+
183
+
111
184
112
185
@requires_package ("dgl" )
113
186
def dgl_heterograph_to_homograph (graph : "dgl.DGLHeteroGraph" ) -> "dgl.DGLGraph" :
114
187
import dgl
115
188
116
189
try :
117
190
homo_graph = dgl .to_homogeneous (graph , ndata = [FEATURE ], edata = [FEATURE ])
191
+ except TypeError as e :
192
+ if graph .num_edges () == 0 :
193
+ homo_graph = heterograph_to_homograph_no_edges (graph )
118
194
except KeyError :
119
195
# A nasty workaround to check when we don't have any atom / bond features as
120
196
# DGL doesn't allow easy querying of features dicts for hetereographs with
0 commit comments