From 51d5c32c468a8bc00e2cf9c392bfcfb3419483d1 Mon Sep 17 00:00:00 2001 From: Masayuki Takagi Date: Wed, 8 Nov 2023 08:57:49 +0000 Subject: [PATCH] Remove Gather op when exported --- torch_dftd/functions/dftd3.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torch_dftd/functions/dftd3.py b/torch_dftd/functions/dftd3.py index f4a22b6..46a993b 100644 --- a/torch_dftd/functions/dftd3.py +++ b/torch_dftd/functions/dftd3.py @@ -109,14 +109,15 @@ def _getc6_impl( Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3 ) -> Tensor: # gather the relevant entries from the table - # c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3) - c6ab_ = c6ab[Zi, Zj].type(nci.dtype) - # calculate c6 coefficients - - # cn0, cn1, cn2 (n_edges, 5, 5) - cn0 = c6ab_[:, :, :, 0] - cn1 = c6ab_[:, :, :, 1] - cn2 = c6ab_[:, :, :, 2] + # c6ab (95, 95, 5, 5, 3) --> cni (9025, 5, 5, 1) + cn0, cn1, cn2 = c6ab.reshape(-1, 5, 5, 3).split(1, dim=3) + index = Zi * c6ab.size(1) + Zj + + # cni (9025, 5, 5, 1) --> cni (n_edges, 5, 5) + cn0 = cn0.squeeze(dim=3)[index].type(nci.dtype) + cn1 = cn1.squeeze(dim=3)[index].type(nci.dtype) + cn2 = cn2.squeeze(dim=3)[index].type(nci.dtype) + r = (cn1 - nci[:, None, None]) ** 2 + (cn2 - ncj[:, None, None]) ** 2 n_edges = r.shape[0]