We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 540a2e7 commit 37ed7e2Copy full SHA for 37ed7e2
1 file changed
src/e3tools/nn/_conv.py
@@ -229,3 +229,20 @@ def forward(
229
out: [N, irreps_out.dim]
230
"""
231
return self.gated_conv(node_attr, edge_index, edge_attr, edge_sh)
232
+
233
234
+class SeparableConvBlock(ConvBlock):
235
+ """e3tools.nn.ConvBlock with SeparableConv as the underlying convolution layer."""
236
237
+ def __init__(self, *args, **kwargs):
238
+ """
239
+ Initializes the SeparableConvBlock.
240
241
+ All arguments are passed directly to the parent ConvBlock,
242
+ with the 'conv' argument specifically set to SeparableConv.
243
244
+ super().__init__(
245
+ *args,
246
+ **kwargs,
247
+ conv=SeparableConv, # Explicitly set the convolution type to SeparableConv
248
+ )
0 commit comments