1
1
from typing import List , Protocol , Union
2
2
3
3
import torch
4
+ from torch .jit import unused
4
5
5
6
from lib .nn .aggregation .fixed_count import FixedCountAggregation , build_fixed_count_aggregate
6
7
from lib .nn .aggregation .scatter import build_optimal_scatter_aggregate
7
8
from lib .nn .definitions .ops import AggregationDef
8
9
from lib .nn .gather import ViewWithPeriod
10
+ from lib .nn .scatter import Scatter , SegmentCOO , SegmentCSR
11
+ from lib .nn .utils import ShapeTransformable
9
12
10
13
11
- class ReshapeAggregateLike (Protocol ):
14
+ class ReshapeAggregateLike (ShapeTransformable , Protocol ):
12
15
@property
13
16
def is_matching_dimension (self ) -> bool :
14
17
...
@@ -33,6 +36,12 @@ def __init__(self, view: ViewWithPeriod, aggregate: FixedCountAggregation) -> No
33
36
self .reshape = view
34
37
self .aggregate = aggregate
35
38
39
+ @unused
40
+ def compute_output_shape (self , shape_like ) -> list [int ]:
41
+ shape_like = self .reshape .compute_output_shape (shape_like )
42
+ shape_like = self .aggregate .compute_output_shape (shape_like )
43
+ return shape_like
44
+
36
45
@property
37
46
def is_matching_dimension (self ) -> bool :
38
47
return True
@@ -50,10 +59,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
50
59
51
60
52
61
class ScatterAggregate (torch .nn .Module , ReshapeAggregateModuleLike ):
53
- def __init__ (self , scatter : torch . nn . Module ) -> None :
62
+ def __init__ (self , scatter : Scatter | SegmentCOO | SegmentCSR ) -> None :
54
63
super ().__init__ ()
55
64
self .delegate = scatter
56
65
66
+ @unused
67
+ def compute_output_shape (self , shape_like ) -> list [int ]:
68
+ return self .delegate .compute_output_shape (shape_like )
69
+
57
70
def forward (self , x ):
58
71
return self .delegate (x )
59
72
@@ -89,7 +102,7 @@ def build_optimal_reshape_aggregate(
89
102
if (counts [1 :] == counts [0 ]).all ():
90
103
period = int (counts [0 ].item ())
91
104
return ViewAndAggregate (
92
- view = ViewWithPeriod (input_length = period * counts . shape [ 0 ], period = period ),
105
+ view = ViewWithPeriod (period = period ),
93
106
aggregate = build_fixed_count_aggregate (aggregation = aggregation , dim = 1 ),
94
107
)
95
108
0 commit comments