Skip to content

Commit 37ed7e2

Browse files
authored
Add SeparableConvBlock
1 parent 540a2e7 commit 37ed7e2

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

src/e3tools/nn/_conv.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,20 @@ def forward(
229229
out: [N, irreps_out.dim]
230230
"""
231231
return self.gated_conv(node_attr, edge_index, edge_attr, edge_sh)
232+
233+
234+
class SeparableConvBlock(ConvBlock):
235+
"""e3tools.nn.ConvBlock with SeparableConv as the underlying convolution layer."""
236+
237+
def __init__(self, *args, **kwargs):
238+
"""
239+
Initializes the SeparableConvBlock.
240+
241+
All arguments are passed directly to the parent ConvBlock,
242+
with the 'conv' argument specifically set to SeparableConv.
243+
"""
244+
super().__init__(
245+
*args,
246+
**kwargs,
247+
conv=SeparableConv, # Explicitly set the convolution type to SeparableConv
248+
)

0 commit comments

Comments
 (0)