From 1e6e9d30ea1498dbcecf211b1cbcaebc40205e18 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 14:34:53 +0200
Subject: [PATCH 1/6] itertools.product & unvectorized fw index

---
 basalt/autograd/ops/mlops.mojo | 78 +++++++++++++++++++++++++++++++++-
 basalt/autograd/ops/ops.mojo   |  9 +++-
 basalt/utils/itertools.mojo    | 47 ++++++++++++++++++++
 tests/mojo/test_mlops.mojo     | 76 +++++++++++++++++++++++----------
 4 files changed, 185 insertions(+), 25 deletions(-)
 create mode 100644 basalt/utils/itertools.mojo

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 0869919..0f9bb1f 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -4,6 +4,7 @@ from math.limit import min_finite, max_finite
 
 from basalt import Tensor, TensorShape
 from basalt.utils.tensorutils import elwise_transform
+from basalt.utils.itertools import product
 from basalt.autograd.attributes import Attribute, AttributeVector
 
 
@@ -491,4 +492,79 @@ struct SLICE:
 
         Self.slice_kernel[ug_shape, t1_shape, steps, starts, ends, True](res_grad, ug)
 
-        return res_grad ^
\ No newline at end of file
+        return res_grad ^
+
+
+struct INDEX:
+    @staticmethod
+    fn adjust_boundary(slice: Int, dim_size: Int) -> Int:
+        # Adjust negative indices & ensure they are within bounds.
+        var s = slice if slice >= 0 else dim_size + slice
+        return max(min(s, dim_size), 0)
+
+    @staticmethod
+    fn to_indeces(shape: TensorShape, attrs: AttributeVector) -> List[List[Int]]:
+        var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s")
+        var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i")
+
+        var indeces = List[List[Int]]()
+        for dim in range(shape.rank()):
+            var temp = List[Int]()
+
+            # Option 1: Slice
+            if attrs[SLICE_LITERALS[dim]]:
+                var slice = attrs[SLICE_LITERALS[dim]].value().to_shape()
+                var step = slice[2] if slice.rank() == 3 else 1
+                for i in range(
+                    start=Self.adjust_boundary(slice[0], shape[dim]),
+                    end=Self.adjust_boundary(slice[1], shape[dim]),
+                    step=step
+                ):
+                    temp.append(i)
+
+            # Option 2: Indices
+            elif attrs[INDEX_LITERALS[dim]]:
+                var indeces = attrs[INDEX_LITERALS[dim]].value().to_shape()
+                for i in range(indeces.rank()):
+                    temp.append(indeces[i])
+
+            # All indices
+            else:
+                for i in range(shape[dim]):
+                    temp.append(i)
+
+            indeces.append(temp)
+
+        return indeces ^
+
+    @staticmethod
+    fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape:
+        var indeces = Self.to_indeces(shape, attrs)
+        var new_shape = List[Int]()
+        for i in range(shape.rank()):
+            new_shape.append(len(indeces[i]))
+        return TensorShape(new_shape)
+
+    @staticmethod
+    fn forward[
+        t1_shape: TensorShape,
+        attributes: AttributeVector,
+    ](inout res: Tensor[dtype], t1: Tensor[dtype]):
+        alias indeces = Self.to_indeces(t1_shape, attributes)
+        alias strides = t1_shape.strides()
+
+        var j = 0
+        for comb in product(indeces):
+            var flat_index = 0
+            for dim in range(t1_shape.rank()):
+                flat_index += comb[dim] * strides[dim]
+            res[j] = t1[flat_index]
+            j += 1
+
+    @staticmethod
+    fn backward[
+        ug_shape: TensorShape,
+        t1_shape: TensorShape,
+        attributes: AttributeVector = AttributeVector(),
+    ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
+        return Tensor[dtype]()
\ No newline at end of file
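Note: the INDEX forward pass above enumerates the Cartesian product of the per-dimension index lists and turns each combination into a flat buffer offset using the tensor's row-major strides. A minimal Python sketch of that mapping (illustrative names, not part of the patch):

    import itertools

    def index_forward(buf, shape, indices):
        # Row-major strides: strides[d] = product of shape[d+1:]
        strides = [1] * len(shape)
        for d in reversed(range(len(shape) - 1)):
            strides[d] = strides[d + 1] * shape[d + 1]
        out = []
        for comb in itertools.product(*indices):
            flat = sum(i * s for i, s in zip(comb, strides))
            out.append(buf[flat])
        return out

    # t[:, [0, 0], 0:5:2] on a (2, 3, 5) tensor, as in the test further down:
    t = list(range(30))
    out = index_forward(t, [2, 3, 5], [[0, 1], [0, 0], [0, 2, 4]])
    assert len(out) == 12  # matches result_shape (2, 2, 3)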
diff --git a/basalt/autograd/ops/ops.mojo b/basalt/autograd/ops/ops.mojo
index 7198270..c737821 100644
--- a/basalt/autograd/ops/ops.mojo
+++ b/basalt/autograd/ops/ops.mojo
@@ -15,7 +15,7 @@ from .basics import (
     TRANSPOSE,
     FMA,
 )
-from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE
+from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE, INDEX
 from .dynamics import CONCAT, SPLIT
 from .conv import CONV2D
 from .pool import MAXPOOL2D
@@ -61,6 +61,7 @@ struct OP(Stringable):
     alias CONCAT = OP(23, "CONCAT", dynamic=True)
     alias SPLIT = OP(24, "SPLIT", dynamic=True)
     alias SLICE = OP(25, "SLICE")
+    alias INDEX = OP(26, "INDEX")
 
     var id: UInt8
     var name: Bytes[16]
@@ -135,6 +136,8 @@ fn static_result_shape(
         return UNSQUEEZE.result_shape(t1_shape, attributes)
     elif op == OP.SLICE:
         return SLICE.result_shape(t1_shape, attributes)
+    elif op == OP.INDEX:
+        return INDEX.result_shape(t1_shape, attributes)
     else:
         print("[ERROR] Operator not found.")
         return TensorShape(-1)
@@ -249,6 +252,8 @@ fn forward_op[
         UNSQUEEZE.forward[t1_shape, attributes](res, t1)
     elif op == OP.SLICE:
         SLICE.forward[t1_shape, attributes](res, t1)
+    elif op == OP.INDEX:
+        INDEX.forward[t1_shape, attributes](res, t1)
     else:
         print("[ERROR] Operator not found.")
@@ -361,6 +366,8 @@ fn backward_op[
         res_grad = UNSQUEEZE.backward[ug_shape, t1_shape](ug, t1)
     elif op == OP.SLICE:
         res_grad = SLICE.backward[ug_shape, t1_shape, attributes](ug, t1)
+    elif op == OP.INDEX:
+        res_grad = INDEX.backward[ug_shape, t1_shape, attributes](ug, t1)
     else:
         print("[ERROR] Operator not found.")
         res_grad = Tensor[dtype](-1)

diff --git a/basalt/utils/itertools.mojo b/basalt/utils/itertools.mojo
new file mode 100644
index 0000000..aceda31
--- /dev/null
+++ b/basalt/utils/itertools.mojo
@@ -0,0 +1,47 @@
+
+@value
+struct _ProductIterator(Sized):
+    var lists: List[List[Int]]
+    var indeces: List[Int]
+    var _iters: Int
+
+    @always_inline("nodebug")
+    fn __init__(inout self, lists: List[List[Int]]):
+        self.lists = lists
+        self.indeces = List[Int]()
+        for i in range(len(lists)):
+            self.indeces.append(0)
+
+        self._iters = 1
+        for lst in self.lists:
+            self._iters *= len(lst[])
+
+    @always_inline("nodebug")
+    fn __len__(self) -> Int:
+        return self._iters
+
+    @always_inline("nodebug")
+    fn __iter__(self) -> Self:
+        return self
+
+    @always_inline("nodebug")
+    fn __next__(inout self) -> List[Int]:
+        var res = List[Int]()
+        for i in range(len(self.lists)):
+            res.append(self.lists[i][self.indeces[i]])
+        self._increment_indeces()
+        self._iters -= 1
+        return res ^
+
+    @always_inline("nodebug")
+    fn _increment_indeces(inout self):
+        for i in reversed(range(len(self.indeces))):
+            self.indeces[i] += 1
+            if self.indeces[i] < len(self.lists[i]):
+                break
+            self.indeces[i] = 0
+
+
+@always_inline("nodebug")
+fn product(lists: List[List[Int]]) -> _ProductIterator:
+    return _ProductIterator(lists)
\ No newline at end of file
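Note: _ProductIterator mirrors Python's itertools.product for List[Int] inputs using the classic odometer scheme: one counter per list, each __next__ emits the current combination and then increments the rightmost counter, carrying leftwards on overflow, so the last list varies fastest. The reference behavior in Python (illustrative):

    import itertools

    # Same iteration order the struct produces: rightmost dimension varies fastest.
    combos = list(itertools.product([0, 1], [0, 0], [0, 2, 4]))
    assert len(combos) == 2 * 2 * 3          # matches __len__ above
    assert combos[0] == (0, 0, 0)
    assert combos[1] == (0, 0, 2)            # last list advanced first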
diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo
index 2ba723e..4d87bb1 100644
--- a/tests/mojo/test_mlops.mojo
+++ b/tests/mojo/test_mlops.mojo
@@ -620,33 +620,63 @@ fn test_backward_SLICE_multiple_axes() raises:
     ](t1, ug, expected_ug)
 
 
+from basalt.autograd.ops.mlops import INDEX
+
+fn test_INDEX() raises:
+    alias t1_shape = TensorShape(2, 3, 5)
+    var t = Tensor[dtype](t1_shape)
+    for i in range(t.num_elements()):
+        t[i] = i
+
+    # t[:, [0, 0], 0:5:2]
+    # TODO: needs a list attribute, as this only supports specifying up to MAX_RANK indices
+    alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
+    alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))
+
+    var expected = Tensor[dtype](2, 2, 3)
+    for i in range(2):
+        for j in range(2):
+            for k in range(3):
+                expected[i*2*3 + j*3 + k] = i * 3 * 5 + k * 2
+
+    test_unary_op[
+        OP.INDEX, t1_shape, AttributeVector(
+            attr_1,
+            attr_2,
+        )
+    ](t, expected)
+
+    print(expected)
+
+
 fn main():
     try:
-        test_SIGMOID()
-        test_RELU()
-        test_TANH()
-        test_CLIP()
-        test_SQUEEZE()
-        test_UNSQUEEZE()
-        test_SLICE()
-        test_SLICE_step()
-        test_SLICE_neg()
-        test_SLICE_multiple_axes()
+        # test_SIGMOID()
+        # test_RELU()
+        # test_TANH()
+        # test_CLIP()
+        # test_SQUEEZE()
+        # test_UNSQUEEZE()
+        # test_SLICE()
+        # test_SLICE_step()
+        # test_SLICE_neg()
+        # test_SLICE_multiple_axes()
+        test_INDEX()
     except e:
         print("[ERROR] Error in forward mlops")
         print(e)
         return
 
-    try:
-        test_backward_SIGMOID()
-        test_backward_RELU()
-        test_backward_TANH()
-        test_backward_CLIP()
-        test_backward_SQUEEZE()
-        test_backward_UNSQUEEZE()
-        test_backward_SLICE()
-        test_backward_SLICE_multiple_axes()
-    except e:
-        print("[ERROR] Error in backward mlops")
-        print(e)
-        return
+    # try:
+    #     test_backward_SIGMOID()
+    #     test_backward_RELU()
+    #     test_backward_TANH()
+    #     test_backward_CLIP()
+    #     test_backward_SQUEEZE()
+    #     test_backward_UNSQUEEZE()
+    #     test_backward_SLICE()
+    #     test_backward_SLICE_multiple_axes()
+    # except e:
+    #     print("[ERROR] Error in backward mlops")
+    #     print(e)
+    #     return

From 60e510844596ede4fe51a22ed1c54fe27fc50292 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 14:56:27 +0200
Subject: [PATCH 2/6] unoptimized index bw

---
 basalt/autograd/ops/mlops.mojo | 15 ++++++++-
 tests/mojo/test_mlops.mojo     | 57 ++++++++++++++++++++++++++--------
 2 files changed, 58 insertions(+), 14 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 0f9bb1f..5aa2d8b 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -567,4 +567,17 @@ struct INDEX:
         t1_shape: TensorShape,
         attributes: AttributeVector = AttributeVector(),
     ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
-        return Tensor[dtype]()
\ No newline at end of file
+        alias indeces = Self.to_indeces(t1_shape, attributes)
+        alias strides = t1_shape.strides()
+
+        var res_grad = Tensor[dtype](t1_shape)
+
+        var j = 0
+        for comb in product(indeces):
+            var flat_index = 0
+            for dim in range(t1_shape.rank()):
+                flat_index += comb[dim] * strides[dim]
+            res_grad[flat_index] += ug[j]
+            j += 1
+
+        return res_grad^
\ No newline at end of file
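Note: the backward kernel above is the mirror of the forward gather: it scatter-adds each upstream-gradient element back to the source offset the forward pass read from. The `+=` is load-bearing, because an index list like [0, 0] selects the same element twice, and both contributions must accumulate. A Python sketch of the accumulation (illustrative names, not part of the patch):

    import itertools

    def index_backward(ug, shape, indices):
        strides = [1] * len(shape)
        for d in reversed(range(len(shape) - 1)):
            strides[d] = strides[d + 1] * shape[d + 1]
        grad = [0.0] * (shape[0] * strides[0])
        for j, comb in enumerate(itertools.product(*indices)):
            flat = sum(i * s for i, s in zip(comb, strides))
            grad[flat] += ug[j]  # accumulate: repeated indices hit the same slot
        return grad

    g = index_backward([1.0] * 12, [2, 3, 5], [[0, 1], [0, 0], [0, 2, 4]])
    assert g[0] == 2.0  # element 0 was gathered twice via [0, 0]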
diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo
index 4d87bb1..964e134 100644
--- a/tests/mojo/test_mlops.mojo
+++ b/tests/mojo/test_mlops.mojo
@@ -649,6 +649,36 @@ fn test_INDEX() raises:
     print(expected)
 
 
+fn test_INDEX_backward() raises:
+    alias t1_shape = TensorShape(2, 3, 5)
+    var t = Tensor[dtype](t1_shape)
+    for i in range(t.num_elements()):
+        t[i] = i
+
+    alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
+    alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))
+
+    alias ug_shape = TensorShape(2, 2, 3)
+    var ug = Tensor[dtype](ug_shape)
+    fill(ug, 1.0)
+
+    var expected = Tensor[dtype](t1_shape)
+    for i in range(2):
+        for j in range(2):
+            for k in range(3):
+                # NOTE: `+=` because selected indices [0, 0] can repeat
+                expected[i * 3 * 5 + k * 2] += 1.0
+
+    test_unary_op_backward[
+        OP.INDEX, t1_shape, ug_shape, AttributeVector(
+            attr_1,
+            attr_2,
+        )
+    ](t, ug, expected)
+
+    print(expected)
+
+
 fn main():
     try:
         # test_SIGMOID()
@@ -655,28 +685,29 @@ fn main():
         # test_RELU()
         # test_TANH()
         # test_CLIP()
         # test_SQUEEZE()
         # test_UNSQUEEZE()
         # test_SLICE()
         # test_SLICE_step()
         # test_SLICE_neg()
         # test_SLICE_multiple_axes()
         test_INDEX()
     except e:
         print("[ERROR] Error in forward mlops")
         print(e)
         return
 
-    # try:
-    #     test_backward_SIGMOID()
-    #     test_backward_RELU()
-    #     test_backward_TANH()
-    #     test_backward_CLIP()
-    #     test_backward_SQUEEZE()
-    #     test_backward_UNSQUEEZE()
-    #     test_backward_SLICE()
-    #     test_backward_SLICE_multiple_axes()
-    # except e:
-    #     print("[ERROR] Error in backward mlops")
-    #     print(e)
-    #     return
+    try:
+        # test_backward_SIGMOID()
+        # test_backward_RELU()
+        # test_backward_TANH()
+        # test_backward_CLIP()
+        # test_backward_SQUEEZE()
+        # test_backward_UNSQUEEZE()
+        # test_backward_SLICE()
+        # test_backward_SLICE_multiple_axes()
+        test_INDEX_backward()
+    except e:
+        print("[ERROR] Error in backward mlops")
+        print(e)
+        return

From 8d90c09d852356ac5f30c8328d5549d785e9c827 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 18:10:38 +0200
Subject: [PATCH 3/6] getindex to product & vectorized fw

---
 basalt/autograd/ops/mlops.mojo | 55 +++++++++++++++++++++++++++++-----
 basalt/utils/itertools.mojo    | 34 +++++++++++----------
 2 files changed, 65 insertions(+), 24 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 5aa2d8b..fd871fd 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -545,6 +545,26 @@ struct INDEX:
             new_shape.append(len(indeces[i]))
         return TensorShape(new_shape)
 
+    @staticmethod
+    fn map_indeces[
+        nelts: Int,
+        strides: TensorShape,
+        indeces: List[List[Int]],
+    ](idx: Int) -> SIMD[DType.int64, nelts]:
+        alias indeces_product = product(indeces)
+
+        var temp = SIMD[DType.int64, nelts]()
+        for i in range(idx, idx + nelts):
+            var comb = indeces_product[i]
+            var flat_index = 0
+
+            for dim in range(len(comb)):
+                flat_index += comb[dim] * strides[dim]
+
+            temp[i % nelts] = flat_index
+
+        return temp
+
     @staticmethod
     fn forward[
         t1_shape: TensorShape,
@@ -552,14 +572,17 @@ struct INDEX:
     ](inout res: Tensor[dtype], t1: Tensor[dtype]):
         alias indeces = Self.to_indeces(t1_shape, attributes)
         alias strides = t1_shape.strides()
+        alias total_length = len(product(indeces))
+
+        @parameter
+        fn vec_index[nelts: Int](i: Int):
+
+            res.store[nelts](i,
+                t1.data().gather(Self.map_indeces[nelts, strides, indeces](i))
+            )
+
+        vectorize[vec_index, nelts](total_length)
 
-        var j = 0
-        for comb in product(indeces):
-            var flat_index = 0
-            for dim in range(t1_shape.rank()):
-                flat_index += comb[dim] * strides[dim]
-            res[j] = t1[flat_index]
-            j += 1
+
 
     @staticmethod
     fn backward[
@@ -569,9 +592,25 @@ struct INDEX:
     ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
         alias indeces = Self.to_indeces(t1_shape, attributes)
         alias strides = t1_shape.strides()
+        # alias total_length = len(product(indeces))
 
         var res_grad = Tensor[dtype](t1_shape)
 
+        # @parameter
+        # fn vec_index[nelts: Int](i: Int):
+
+        #     var offset = Self.map_indeces[nelts, strides, indeces](i)
+        #     res_grad.data().scatter(
+        #         offset,
+        #         res_grad.data().gather(offset) + ug.load[nelts](i),
+        #     )
+
+        # vectorize[vec_index, nelts](total_length)
+
+        # BUG: Edge case in vectorization:
+        # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
+        # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+
         var j = 0
         for comb in product(indeces):
             var flat_index = 0
             for dim in range(t1_shape.rank()):
                 flat_index += comb[dim] * strides[dim]
             res_grad[flat_index] += ug[j]
             j += 1
-
+
         return res_grad^
\ No newline at end of file
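Note: the BUG called out above is the classic scatter-collision hazard. A SIMD scatter with duplicate offsets performs one independent store per lane rather than an accumulation, so when two lanes target the same address only one write survives. NumPy shows the same effect (illustrative, not the Mojo intrinsics):

    import numpy as np

    grad = np.zeros(5)
    offset = np.array([0, 2, 4, 0])  # index 0 appears twice
    ug = np.ones(4)

    grad[offset] += ug           # buffered read-modify-write: grad[0] ends up 1, not 2
    print(grad)                  # [1. 0. 1. 0. 1.]

    grad = np.zeros(5)
    np.add.at(grad, offset, ug)  # unbuffered accumulation handles duplicates
    print(grad)                  # [2. 0. 1. 0. 1.]

This is why the scalar accumulation loop stays in place here; the next patch vectorizes only the offset computation and keeps the per-lane accumulate scalar.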
diff --git a/basalt/utils/itertools.mojo b/basalt/utils/itertools.mojo
index aceda31..fd7a6ce 100644
--- a/basalt/utils/itertools.mojo
+++ b/basalt/utils/itertools.mojo
@@ -2,16 +2,14 @@
 @value
 struct _ProductIterator(Sized):
     var lists: List[List[Int]]
-    var indeces: List[Int]
+    var _current: Int
     var _iters: Int
 
     @always_inline("nodebug")
     fn __init__(inout self, lists: List[List[Int]]):
         self.lists = lists
-        self.indeces = List[Int]()
-        for i in range(len(lists)):
-            self.indeces.append(0)
-
+        self._current = 0
+
         self._iters = 1
         for lst in self.lists:
             self._iters *= len(lst[])
@@ -26,20 +24,24 @@ struct _ProductIterator(Sized):
 
     @always_inline("nodebug")
     fn __next__(inout self) -> List[Int]:
-        var res = List[Int]()
-        for i in range(len(self.lists)):
-            res.append(self.lists[i][self.indeces[i]])
-        self._increment_indeces()
+        self._current += 1
         self._iters -= 1
-        return res ^
+        return self._get_combination(self._current - 1)
+
+    @always_inline("nodebug")
+    fn _get_combination(self, current: Int) -> List[Int]:
+        var combination = List[Int]()
+        var count = current
+        for i in reversed(range(len(self.lists))):
+            var index = count % len(self.lists[i])
+            combination.append(self.lists[i][index])
+            count //= len(self.lists[i])
+        combination._reverse()
+        return combination ^
 
     @always_inline("nodebug")
-    fn _increment_indeces(inout self):
-        for i in reversed(range(len(self.indeces))):
-            self.indeces[i] += 1
-            if self.indeces[i] < len(self.lists[i]):
-                break
-            self.indeces[i] = 0
+    fn __getitem__(self, index: Int) -> List[Int]:
+        return self._get_combination(index)
 
 
     @always_inline("nodebug")

From 113b5aeabbf5479638addbf3b9afe24c189c73aa Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 19:19:02 +0200
Subject: [PATCH 4/6] something inbetween

---
 basalt/autograd/ops/mlops.mojo | 39 ++++++++++++++++------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index fd871fd..6e38aaa 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -592,31 +592,28 @@ struct INDEX:
     ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
         alias indeces = Self.to_indeces(t1_shape, attributes)
         alias strides = t1_shape.strides()
-        # alias total_length = len(product(indeces))
+        alias total_length = len(product(indeces))
 
         var res_grad = Tensor[dtype](t1_shape)
 
-        # @parameter
-        # fn vec_index[nelts: Int](i: Int):
+        @parameter
+        fn vec_index[nelts: Int](i: Int):
 
-        #     var offset = Self.map_indeces[nelts, strides, indeces](i)
-        #     res_grad.data().scatter(
-        #         offset,
-        #         res_grad.data().gather(offset) + ug.load[nelts](i),
-        #     )
-
-        # vectorize[vec_index, nelts](total_length)
-
-        # BUG: Edge case in vectorization:
-        # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
-        # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+            var offset = Self.map_indeces[nelts, strides, indeces](i)
+
+            # res_grad.data().scatter(
+            #     offset,
+            #     res_grad.data().gather(offset) + ug.load[nelts](i),
+            # )
+            # BUG: Edge case in vectorization:
+            # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
+            # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+
+            # Workaround
+            var u = ug.load[nelts](i)
+            for j in range(nelts):
+                res_grad[int(offset[j])] += u[j]
 
-        var j = 0
-        for comb in product(indeces):
-            var flat_index = 0
-            for dim in range(t1_shape.rank()):
-                flat_index += comb[dim] * strides[dim]
-            res_grad[flat_index] += ug[j]
-            j += 1
+        vectorize[vec_index, nelts](total_length)
 
         return res_grad^
\ No newline at end of file
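Note: the random access that patch 3 added to the product iterator (__getitem__ via _get_combination) decodes a flat counter as a mixed-radix number whose digits are the per-list positions, last list fastest. That is what lets map_indeces jump straight to combination i inside the vectorized kernels. A small Python check of the decoding (illustrative):

    import itertools

    def get_combination(lists, current):
        comb = []
        for lst in reversed(lists):
            comb.append(lst[current % len(lst)])
            current //= len(lst)
        return tuple(reversed(comb))

    lists = [[0, 1], [0, 0], [0, 2, 4]]
    assert all(
        get_combination(lists, i) == ref
        for i, ref in enumerate(itertools.product(*lists))
    )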
From 0c56fa4673263929ae86fee75d7a99e29d9301d6 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Thu, 16 May 2024 00:57:56 +0200
Subject: [PATCH 5/6] upsampling nearest

---
 basalt/autograd/attributes.mojo |   5 +
 basalt/nn/__init__.mojo         |   1 +
 basalt/nn/layers/upsample.mojo  | 117 +++++++++++++++++++++
 tests/python/test_upsample.mojo | 159 ++++++++++++++++++++++++++++++++
 4 files changed, 282 insertions(+)
 create mode 100644 basalt/nn/layers/upsample.mojo
 create mode 100644 tests/python/test_upsample.mojo

diff --git a/basalt/autograd/attributes.mojo b/basalt/autograd/attributes.mojo
index 9be1822..2e87300 100644
--- a/basalt/autograd/attributes.mojo
+++ b/basalt/autograd/attributes.mojo
@@ -67,6 +67,11 @@ struct AttributeVector(Sized, Stringable, CollectionElement):
                 return self.attributes[i]
         return None
 
+    @always_inline("nodebug")
+    fn append(inout self, attribute: Attribute):
+        self.attributes[self.size] = attribute
+        self.size += 1
+
     @always_inline("nodebug")
     fn __str__(self) -> String:
         var s: String = "["

diff --git a/basalt/nn/__init__.mojo b/basalt/nn/__init__.mojo
index 99b30a3..9c994a4 100644
--- a/basalt/nn/__init__.mojo
+++ b/basalt/nn/__init__.mojo
@@ -4,6 +4,7 @@ from .model import Model
 from .layers.linear import Linear
 from .layers.conv import Conv2d
 from .layers.pool import MaxPool2d
+from .layers.upsample import Upsample
 from .loss import MSELoss, CrossEntropyLoss
 from .activations import Softmax, LogSoftmax, ReLU, Sigmoid, Tanh

diff --git a/basalt/nn/layers/upsample.mojo b/basalt/nn/layers/upsample.mojo
new file mode 100644
index 0000000..c70de00
--- /dev/null
+++ b/basalt/nn/layers/upsample.mojo
@@ -0,0 +1,117 @@
+from basalt import dtype
+from basalt import Graph, Symbol, OP
+from basalt import Tensor, TensorShape
+from basalt.autograd.attributes import AttributeVector, Attribute
+from basalt.utils.itertools import product
+
+
+fn _scale_indeces(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> List[Scalar[dtype]]:
+    var M = int(scale * N)
+    var indeces = List[Scalar[dtype]]()
+    if align_corners:
+        for i in range(M):
+            indeces.append(i * ((N - 1) / (M - 1)))
+    else:
+        var step = 1 / scale
+        var start = ((M - 1) * step - N + 1) / 2
+        for i in range(M):
+            indeces.append(i * step - start)
+
+    return indeces ^
+
+
+fn nearest_coeffs(N: Int, scale: Scalar[dtype], dim: Int, ndims: Int) -> List[Int]:
+
+    @parameter
+    fn round_to_index(number: Scalar[dtype]) -> Int:
+        return int(number + 0.5) if number > 0 else int(number - 0.5)
+
+    var indeces = List[Int]()
+    var scaled = _scale_indeces(N, scale, True, dim, ndims)
+    for i in range(len(scaled)):
+        indeces.append(round_to_index(scaled[i]))
+    return indeces ^
+
+
+fn linear_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Tuple[List[Int], List[Int]]:
+    # TODO
+    return (List[Int](), List[Int]())
+
+
+fn cubic_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Tuple[List[Int], List[Int]]:
+    # TODO
+    return (List[Int](), List[Int]())
+
+
+fn interpolate_nd[
+    indices_fn: fn (Int, Scalar[dtype], Bool, Int, Int) -> Tuple[List[Int], List[Int]],
+](inout g: Graph, input: Symbol, scale_factors: List[Scalar[dtype]], align_corners: Bool) -> Symbol:
+
+    var spatial_dims = input.shape.rank() - 2
+
+    var indeces_weights = List[Tuple[List[Int], List[Int]]]()
+    for i in range(spatial_dims):
+        indeces_weights.append(
+            indices_fn(
+                input.shape[i + 2],
+                scale_factors[i],
+                align_corners,
+                i,
+                spatial_dims,
+            )
+        )
+
+    # TODO: interpolation logic
+    # for idx_weight in product(indeces_weights):
+    #     ...
+
+    return Symbol(-1, dtype, TensorShape(), False)
+
+
+fn Upsample(
+    inout g: Graph,
+    input: Symbol,
+    mode: StringLiteral,
+    scale_factors: List[Scalar[dtype]],
+    align_corners: Bool = False,
+) -> Symbol:
+
+    # Assumption: A scale needs to be provided for each spatial dimension.
+    # input shape (B, C, *N) with batch and channel considered as non-spatial dimensions.
+    # input.shape.rank() - 2 == len(scale_factor)
+    var spatial_dims = input.shape.rank() - 2
+
+    var res: Symbol
+    var attributes = AttributeVector()
+    var INDEX_LITERALS = List[StringLiteral]("dim_2i", "dim_3i", "dim_4i")
+
+    if mode == "nearest":
+        # Nearest neighbor interpolation --> input[:, :, *indeces]
+        for i in range(spatial_dims):
+            attributes.append(
+                Attribute(
+                    INDEX_LITERALS[i],
+                    nearest_coeffs(input.shape[i + 2], scale_factors[i], i, spatial_dims)
+                )
+            )
+
+        res = g.op(OP.INDEX, input, attributes=attributes)
+
+    # elif mode == "linear":
+    #     res = interpolate_nd[linear_coeffs](g,
+    #         input,
+    #         scale_factor,
+    #         align_corners
+    #     )
+
+    # elif mode == "cubic":
+    #     res = interpolate_nd[cubic_coeffs](g,
+    #         input,
+    #         scale_factor,
+    #         align_corners
+    #     )
+    else:
+        res = input
+
+    return res
+
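Note: in "nearest" mode the layer never interpolates. It precomputes, per spatial dimension, the source index of every output position, stores those lists in the dim_2i/dim_3i/dim_4i attributes, and the whole upsample reduces to the OP.INDEX gather from patches 1-4. A Python sketch of what nearest_coeffs yields, mirroring the align_corners-style positions it hardcodes and its round-half-away-from-zero (illustrative):

    def nearest_coeffs(N, scale):
        # Sample positions spread as with align_corners, then round half away from zero.
        M = int(scale * N)
        pos = [i * (N - 1) / (M - 1) for i in range(M)]
        return [int(p + 0.5) if p > 0 else int(p - 0.5) for p in pos]

    print(nearest_coeffs(2, 2.0))  # [0, 0, 1, 1]: every input row/column repeated twice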
diff --git a/tests/python/test_upsample.mojo b/tests/python/test_upsample.mojo
new file mode 100644
index 0000000..c5918ff
--- /dev/null
+++ b/tests/python/test_upsample.mojo
@@ -0,0 +1,159 @@
+from python.python import Python, PythonObject
+
+import basalt.nn as nn
+from basalt import dtype, Graph
+from basalt import Tensor, TensorShape
+from tests import assert_tensors_equal, to_numpy, to_tensor
+
+
+fn test_upsample[
+    shape: TensorShape,
+    mode: StringLiteral,
+    scale_factors: List[Scalar[dtype]],
+    align_corners: Bool
+](
+    t1: Tensor[dtype],
+    ug: Tensor[dtype],
+    expected: Tensor[dtype],
+    expected_grad: Tensor[dtype]
+) raises:
+
+    fn create_graph() -> Graph:
+        var g = Graph()
+        var t1 = g.input(shape, trainable=True)
+        var t2 = nn.Upsample(g, t1, mode, scale_factors, align_corners)
+        g.out(t2)
+        return g ^
+
+    alias graph = create_graph()
+    var model = nn.Model[graph](inference_only=True)
+    var res = model.inference(t1)[0]
+
+    model.backward(ug)
+    var res_grad = model.parameters.grads[graph.inputs[0]]
+
+    assert_tensors_equal["almost"](res, expected)
+    assert_tensors_equal["almost"](res_grad, expected_grad)
+
+
+@value
+struct torch_upsample_result:
+    var expected: Tensor[dtype]
+    var grad: Tensor[dtype]
+
+
+fn test_upsample_torch[
+    shape: TensorShape,
+    mode: StringLiteral,
+    scale_factors: List[Scalar[dtype]],
+    align_corners: Bool
+](data: PythonObject, ug: PythonObject) raises -> torch_upsample_result:
+
+    var py = Python.import_module("builtins")
+    var np = Python.import_module("numpy")
+    var torch = Python.import_module("torch")
+
+    var py_scales = py.list()
+    for i in range(len(scale_factors)):
+        py_scales.append(scale_factors[i])
+
+    # if mode == "nearest":
+    #     var ups = torch.nn.Upsample(scale_factor=py.tuple(py_scales), mode=mode)
+    # else:
+    #     var ups = torch.nn.Upsample(scale_factor=py.tuple(py_scales), mode=mode, align_corners=align_corners)
+
+    var ups = torch.nn.Upsample(scale_factor=py.tuple(py_scales), mode=mode)
+
+    var tensor = torch.from_numpy(data).requires_grad_(True)
+    var expected = ups(tensor)
+    var upper_grad = torch.from_numpy(ug)
+    _ = expected.backward(upper_grad)
+
+    return torch_upsample_result(
+        to_tensor(expected.detach().numpy()),
+        to_tensor(tensor.grad.numpy()),
+    )
+
+
+
+fn test_UPSAMPLE_nearest() raises:
+    var np = Python.import_module("numpy")
+
+    alias shape = TensorShape(1, 1, 2, 2)
+    alias mode: StringLiteral = "nearest"
+    alias scales = List[Scalar[dtype]](2.0, 3.0)
+    alias align_corners = False
+
+    var data = np.array([
+        1, 2,
+        3, 4
+    ], dtype=np.float32).reshape(1, 1, 2, 2)
+
+    var ug = np.ones((1, 1, 4, 6))
+
+    var torch_out = test_upsample_torch[shape, mode, scales, align_corners](data, ug)
+    test_upsample[shape, mode, scales, align_corners](
+        to_tensor(data),
+        to_tensor(ug),
+        torch_out.expected,
+        torch_out.grad
+    )
+
+    _ = data
+
+
+fn test_UPSAMPLE_linear() raises:
+    var np = Python.import_module("numpy")
+
+    alias shape = TensorShape(1, 1, 2, 2)
+    alias mode: StringLiteral = "linear"
+    alias scales = List[Scalar[dtype]](2.0, 2.0)
+
+    var data = np.array([
+        1, 2,
+        3, 4
+    ], dtype=np.float32).reshape(1, 1, 2, 2)
+
+    # var expected = np.array([
+    #     1., 1.25, 1.75, 2. ,
+    #     1.5, 1.75, 2.25, 2.5 ,
+    #     2.5, 2.75, 3.25, 3.5 ,
+    #     3., 3.25, 3.75, 4. ,
+    # ], dtype=np.float32).reshape(1, 1, 4, 4)
+
+
+fn test_UPSAMPLE_cubic() raises:
+    var np = Python.import_module("numpy")
+
+    alias shape = TensorShape(1, 1, 4, 4)
+    alias mode: StringLiteral = "cubic"
+    alias scales = List[Scalar[dtype]](2.0, 2.0)
+
+    var data = np.array([
+        1, 2, 3, 4,
+        5, 6, 7, 8,
+        9, 10, 11, 12,
+        13, 14, 15, 16,
+    ], dtype=np.float32).reshape(1, 1, 4, 4)
+
+    # var expected = np.array([
+    #     0.47265625, 0.76953125, 1.24609375, 1.875, 2.28125, 2.91015625, 3.38671875, 3.68359375,
+    #     1.66015625, 1.95703125, 2.43359375, 3.0625, 3.46875, 4.09765625, 4.57421875, 4.87109375,
+    #     3.56640625, 3.86328125, 4.33984375, 4.96875, 5.375, 6.00390625, 6.48046875, 6.77734375,
+    #     6.08203125, 6.37890625, 6.85546875, 7.484375, 7.890625, 8.51953125, 8.99609375, 9.29296875,
+    #     7.70703125, 8.00390625, 8.48046875, 9.109375, 9.515625, 10.14453125, 10.62109375, 10.91796875,
+    #     10.22265625, 10.51953125, 10.99609375, 11.625, 12.03125, 12.66015625, 13.13671875, 13.43359375,
+    #     12.12890625, 12.42578125, 12.90234375, 13.53125, 13.9375, 14.56640625, 15.04296875, 15.33984375,
+    #     13.31640625, 13.61328125, 14.08984375, 14.71875, 15.125, 15.75390625, 16.23046875, 16.52734375
+    # ], dtype=np.float32).reshape(1, 1, 8, 8)
+
+
+fn main():
+
+    try:
+        test_UPSAMPLE_nearest()
+        # test_UPSAMPLE_linear()
+        # test_UPSAMPLE_cubic()
+    except e:
+        print("[Error] Error in Upsample")
+        print(e)
\ No newline at end of file

From d00886d1511e5ef533f8e67c2b2184e86251cdaa Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Fri, 31 May 2024 09:30:10 +0200
Subject: [PATCH 6/6] temporary results

---
 basalt/nn/layers/upsample.mojo | 108 ++++++++++++++++++++++++++-------
 1 file changed, 86 insertions(+), 22 deletions(-)

diff --git a/basalt/nn/layers/upsample.mojo b/basalt/nn/layers/upsample.mojo
index c70de00..f30ffe4 100644
--- a/basalt/nn/layers/upsample.mojo
+++ b/basalt/nn/layers/upsample.mojo
@@ -1,3 +1,5 @@
+from math import min, max, floor, ceil
+
 from basalt import dtype
 from basalt import Graph, Symbol, OP
 from basalt import Tensor, TensorShape
@@ -33,23 +35,52 @@ fn nearest_coeffs(N: Int, scale: Scalar[dtype], dim: Int, ndims: Int) -> List[In
     return indeces ^
 
 
-fn linear_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Tuple[List[Int], List[Int]]:
-    # TODO
-    return (List[Int](), List[Int]())
+alias Coeff = Tuple[List[Int], List[Scalar[dtype]]]
+alias Coeffs = List[Coeff]
+
+fn linear_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Coeffs:
+
+    var indeces_l = List[Int]()
+    var indeces_r = List[Int]()
+    var weights_l = List[Scalar[dtype]]()
+    var weights_r = List[Scalar[dtype]]()
+    for value in _scale_indeces(N, scale, align_corners, dim, ndims):
+        var clipped = min[dtype]((max[dtype](value[], 0)), N-1)
+        var idx_l = floor(clipped)
+        var idx_r = ceil(clipped)
+
+        indeces_l.append(int(idx_l))
+        indeces_r.append(int(idx_r))
+        weights_l.append(1 - (clipped - idx_l))
+        weights_r.append(clipped - idx_l)
+
+    print(len(indeces_l), len(indeces_r), len(weights_l), len(weights_r))
+    return List[Coeff](
+        Tuple[List[Int]](indeces_l, weights_l),
+        Tuple(indeces_r, weights_r),
+    )
 
-fn cubic_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Tuple[List[Int], List[Int]]:
+
+fn cubic_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Coeffs:
     # TODO
-    return (List[Int](), List[Int]())
+    return List[Coeff](
+        Tuple(List[Int](), List[Scalar[dtype]]()),
+        Tuple(List[Int](), List[Scalar[dtype]]()),
+    )
+
+
+
 
 fn interpolate_nd[
-    indices_fn: fn (Int, Scalar[dtype], Bool, Int, Int) -> Tuple[List[Int], List[Int]],
+    indices_fn: fn (Int, Scalar[dtype], Bool, Int, Int) -> Coeffs,
 ](inout g: Graph, input: Symbol, scale_factors: List[Scalar[dtype]], align_corners: Bool) -> Symbol:
 
     var spatial_dims = input.shape.rank() - 2
 
-    var indeces_weights = List[Tuple[List[Int], List[Int]]]()
+    var temp = List[Int]()
+    var indeces_weights = List[Coeffs]()
     for i in range(spatial_dims):
         indeces_weights.append(
             indices_fn(
@@ -61,9 +92,40 @@ fn interpolate_nd[
             )
         )
 
-    # TODO: interpolation logic
-    # for idx_weight in product(indeces_weights):
-    #     ...
+        temp.append(i)
+
+    @parameter
+    fn get_comb_idx(dim: Int, coeff_id: Int) -> List[Int]:
+        return indeces_weights[dim][coeff_id].get[0, List[Int]]()
+
+    @parameter
+    fn get_comb_weight(dim: Int, coeff_id: Int) -> List[Scalar[dtype]]:
+        return indeces_weights[dim][coeff_id].get[1, List[Scalar[dtype]]]()
+
+    var indeces_weights_copy = indeces_weights
+
+    for comb_id in product(List[List[Int]](temp, temp)):
+        print("----")
+
+        for i in range(spatial_dims):
+            print("D", i, "COMB", comb_id[i])
+            print(len(indeces_weights), len(indeces_weights[i]))
+            # var temp = indeces_weights[i][comb_id[i]].get[0, List[Int]]()
+            var temp = indeces_weights_copy[i][comb_id[i]].get[0, List[Int]]()[0]
+            # print(len(temp))
+            # var idx = get_comb_idx(i, comb_id[i])
+            # var weight = get_comb_weight(i, comb_id[i])
+
+            # for j in range(len(idx)):
+            #     print(idx[j], weight[j])
+
+
+    # # for i in range(len(comb_id)):
+    # #     var iw_l = indeces_weights[0]
+    # #     var iw_r = indeces_weights[1]
+
+    # #     print(comb_id[i])
 
     return Symbol(-1, dtype, TensorShape(), False)
 
@@ -97,20 +159,22 @@ fn Upsample(
 
         res = g.op(OP.INDEX, input, attributes=attributes)
 
-    # elif mode == "linear":
-    #     res = interpolate_nd[linear_coeffs](g,
-    #         input,
-    #         scale_factor,
-    #         align_corners
-    #     )
+    elif mode == "linear":
+        res = interpolate_nd[linear_coeffs](g,
+            input,
+            scale_factors,
+            align_corners
+        )
 
-    # elif mode == "cubic":
-    #     res = interpolate_nd[cubic_coeffs](g,
-    #         input,
-    #         scale_factor,
-    #         align_corners
-    #     )
+    elif mode == "cubic":
+        res = interpolate_nd[cubic_coeffs](g,
+            input,
+            scale_factors,
+            align_corners
+        )
+
     else:
+        print("[ERROR] Upsampling mode not supported")
         res = input
 
     return res
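Note: linear_coeffs in the last patch pairs each output position with its two bracketing source indices (floor and ceil of the clipped sample position) and complementary weights, so the eventual interpolation per dimension is out = w_l * in[idx_l] + w_r * in[idx_r]. A Python sketch using the same half-pixel positions as _scale_indeces (illustrative, not part of the patch):

    import math

    def linear_coeffs(N, scale, align_corners):
        M = int(scale * N)
        if align_corners:
            pos = [i * (N - 1) / (M - 1) for i in range(M)]
        else:
            step = 1 / scale
            start = ((M - 1) * step - N + 1) / 2
            pos = [i * step - start for i in range(M)]
        idx_l, idx_r, w_l, w_r = [], [], [], []
        for p in pos:
            clipped = min(max(p, 0), N - 1)
            lo, hi = math.floor(clipped), math.ceil(clipped)
            idx_l.append(lo)
            idx_r.append(hi)
            w_l.append(1 - (clipped - lo))
            w_r.append(clipped - lo)
        return (idx_l, w_l), (idx_r, w_r)

    # 2 -> 4 with align_corners=False: positions [-0.25, 0.25, 0.75, 1.25] clip to
    # [0, 0.25, 0.75, 1], giving row [1, 1.25, 1.75, 2] on input [1, 2], which matches
    # the first row of the commented torch expectation in test_UPSAMPLE_linear.
    print(linear_coeffs(2, 2.0, False))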