Skip to content

Commit dead47c

Browse files
author
Neumann, Jan
committed
Replace Modules' total_items and optimal_period with compute_output_shape() and comput_optimal_shape()
1 parent 0bd4c2c commit dead47c

19 files changed

+427
-534
lines changed

lib/nn/aggregation/fixed_count.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,24 @@
33

44
import torch
55
import torch.nn.functional as F
6+
from torch.jit import unused
67

78
from lib.nn.definitions.ops import AggregationDef
9+
from lib.nn.utils import ShapeTransformable
810

911

10-
class FixedCountAggregation(torch.nn.Module):
12+
class FixedCountAggregation(torch.nn.Module, ShapeTransformable):
1113
def __init__(self, dim: int = 1) -> None:
1214
super().__init__()
1315
self.dim = dim
1416

1517
def forward(self, x: torch.Tensor) -> torch.Tensor:
1618
raise NotImplementedError()
1719

20+
@unused
21+
def compute_output_shape(self, shape: list[int]) -> list[int]:
22+
return [shape[0], *shape[2:]]
23+
1824
def extra_repr(self) -> str:
1925
return f"dim={self.dim}"
2026

lib/nn/aggregation/universal.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from typing import List, Protocol, Union
22

33
import torch
4+
from torch.jit import unused
45

56
from lib.nn.aggregation.fixed_count import FixedCountAggregation, build_fixed_count_aggregate
67
from lib.nn.aggregation.scatter import build_optimal_scatter_aggregate
78
from lib.nn.definitions.ops import AggregationDef
89
from lib.nn.gather import ViewWithPeriod
10+
from lib.nn.scatter import Scatter, SegmentCOO, SegmentCSR
11+
from lib.nn.utils import ShapeTransformable
912

1013

11-
class ReshapeAggregateLike(Protocol):
14+
class ReshapeAggregateLike(ShapeTransformable, Protocol):
1215
@property
1316
def is_matching_dimension(self) -> bool:
1417
...
@@ -33,6 +36,12 @@ def __init__(self, view: ViewWithPeriod, aggregate: FixedCountAggregation) -> No
3336
self.reshape = view
3437
self.aggregate = aggregate
3538

39+
@unused
40+
def compute_output_shape(self, shape_like) -> list[int]:
41+
shape_like = self.reshape.compute_output_shape(shape_like)
42+
shape_like = self.aggregate.compute_output_shape(shape_like)
43+
return shape_like
44+
3645
@property
3746
def is_matching_dimension(self) -> bool:
3847
return True
@@ -50,10 +59,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5059

5160

5261
class ScatterAggregate(torch.nn.Module, ReshapeAggregateModuleLike):
53-
def __init__(self, scatter: torch.nn.Module) -> None:
62+
def __init__(self, scatter: Scatter | SegmentCOO | SegmentCSR) -> None:
5463
super().__init__()
5564
self.delegate = scatter
5665

66+
@unused
67+
def compute_output_shape(self, shape_like) -> list[int]:
68+
return self.delegate.compute_output_shape(shape_like)
69+
5770
def forward(self, x):
5871
return self.delegate(x)
5972

@@ -89,7 +102,7 @@ def build_optimal_reshape_aggregate(
89102
if (counts[1:] == counts[0]).all():
90103
period = int(counts[0].item())
91104
return ViewAndAggregate(
92-
view=ViewWithPeriod(input_length=period * counts.shape[0], period=period),
105+
view=ViewWithPeriod(period=period),
93106
aggregate=build_fixed_count_aggregate(aggregation=aggregation, dim=1),
94107
)
95108

0 commit comments

Comments
 (0)