Skip to content

Commit 0b0ac2c

Browse files
author
Neumann, Jan
committed
Revert changes from 45d6d6a to compute_layer_shapes.
1 parent 45d6d6a commit 0b0ac2c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/compute_graph_vectorize/vectorize/pipeline/compute_layer_shapes.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,12 @@ def compute_linear_shape_from_shapes(self, input_shape: Shape, weight_shape: Sha
128128
return ConcreteShape([*begin_shape, weight_shape[-2], input_shape[-1]])
129129
case (VariousShape(), _) | (_, VariousShape()):
130130
return VARIOUS_SHAPE
131-
case (AnyShape(), _) | (_, AnyShape()):
132-
return ANY_SHAPE
131+
case (ConcreteShape(_), _):
132+
return weight_shape
133+
case (_, ConcreteShape(_)):
134+
return input_shape
133135
case _:
134-
assert False, f"{weight_shape} {input_shape}"
136+
return ANY_SHAPE
135137

136138
def compute_linear_shape(self, batch: int, input: Input, weight: Input) -> Shape:
137139
weight_shape = self.compute_input_shape(batch, weight)

0 commit comments

Comments
 (0)