Skip to content

Commit

Permalink
Merge pull request #20 from takagi/remove-onnx-gather
Browse files Browse the repository at this point in the history
Reduce Gather op when ONNX-exported
  • Loading branch information
corochann authored Nov 14, 2023
2 parents a5b9fae + 51d5c32 commit c690dd6
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions torch_dftd/functions/dftd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,15 @@ def _getc6_impl(
Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
) -> Tensor:
# gather the relevant entries from the table
# c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
c6ab_ = c6ab[Zi, Zj].type(nci.dtype)
# calculate c6 coefficients

# cn0, cn1, cn2 (n_edges, 5, 5)
cn0 = c6ab_[:, :, :, 0]
cn1 = c6ab_[:, :, :, 1]
cn2 = c6ab_[:, :, :, 2]
# c6ab (95, 95, 5, 5, 3) --> cni (9025, 5, 5, 1)
cn0, cn1, cn2 = c6ab.reshape(-1, 5, 5, 3).split(1, dim=3)
index = Zi * c6ab.size(1) + Zj

# cni (9025, 5, 5, 1) --> cni (n_edges, 5, 5)
cn0 = cn0.squeeze(dim=3)[index].type(nci.dtype)
cn1 = cn1.squeeze(dim=3)[index].type(nci.dtype)
cn2 = cn2.squeeze(dim=3)[index].type(nci.dtype)

r = (cn1 - nci[:, None, None]) ** 2 + (cn2 - ncj[:, None, None]) ** 2

n_edges = r.shape[0]
Expand Down

0 comments on commit c690dd6

Please sign in to comment.