This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Upsample #81 (Draft)
Wants to merge 6 commits into base: main

Changes from all commits
5 changes: 5 additions & 0 deletions basalt/autograd/attributes.mojo
@@ -67,6 +67,11 @@ struct AttributeVector(Sized, Stringable, CollectionElement):
return self.attributes[i]
return None

@always_inline("nodebug")
fn append(inout self, attribute: Attribute):
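# Assumes the underlying fixed-size storage still has spare capacity; no bounds check is performed here.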
self.attributes[self.size] = attribute
self.size += 1

@always_inline("nodebug")
fn __str__(self) -> String:
var s: String = "["
127 changes: 126 additions & 1 deletion basalt/autograd/ops/mlops.mojo
@@ -4,6 +4,7 @@ from math.limit import min_finite, max_finite

from basalt import Tensor, TensorShape
from basalt.utils.tensorutils import elwise_transform
from basalt.utils.itertools import product
from basalt.autograd.attributes import Attribute, AttributeVector


@@ -491,4 +492,128 @@ struct SLICE:

Self.slice_kernel[ug_shape, t1_shape, steps, starts, ends, True](res_grad, ug)

return res_grad ^


struct INDEX:
@staticmethod
fn adjust_boundary(slice: Int, dim_size: Int) -> Int:
# Adjust negative indices & ensure they are within bounds.
var s = slice if slice >= 0 else dim_size + slice
return max(min(s, dim_size), 0)

@staticmethod
fn to_indeces(shape: TensorShape, attrs: AttributeVector) -> List[List[Int]]:
var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s")
var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i")

var indeces = List[List[Int]]()
for dim in range(shape.rank()):
var temp = List[Int]()

# Option 1: Slice
if attrs[SLICE_LITERALS[dim]]:
var slice = attrs[SLICE_LITERALS[dim]].value().to_shape()
var step = slice[2] if slice.rank() == 3 else 1
for i in range(
start=Self.adjust_boundary(slice[0], shape[dim]),
end=Self.adjust_boundary(slice[1], shape[dim]),
step=step
):
temp.append(i)

# Option 2: Indices given explicitly
elif attrs[INDEX_LITERALS[dim]]:
var index_attr = attrs[INDEX_LITERALS[dim]].value().to_shape()
for i in range(index_attr.rank()):
temp.append(index_attr[i])

# Option 3: All indices along this dimension
else:
for i in range(shape[dim]):
temp.append(i)

indeces.append(temp)

return indeces ^

@staticmethod
fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape:
var indeces = Self.to_indeces(shape, attrs)
var new_shape = List[Int]()
for i in range(shape.rank()):
new_shape.append(len(indeces[i]))
return TensorShape(new_shape)

@staticmethod
fn map_indeces[
nelts: Int,
strides: TensorShape,
indeces: List[List[Int]],
](idx: Int) -> SIMD[DType.int64, nelts]:
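# Map output positions idx .. idx + nelts to flattened input offsets by
# walking the Cartesian product of the per-dimension index lists.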
alias indeces_product = product(indeces)

var temp = SIMD[DType.int64, nelts]()
for i in range(idx, idx + nelts):
var comb = indeces_product[i]
var flat_index = 0

for dim in range(len(comb)):
flat_index += comb[dim] * strides[dim]

temp[i % nelts] = flat_index

return temp

@staticmethod
fn forward[
t1_shape: TensorShape,
attributes: AttributeVector,
](inout res: Tensor[dtype], t1: Tensor[dtype]):
alias indeces = Self.to_indeces(t1_shape, attributes)
alias strides = t1_shape.strides()
alias total_length = len(product(indeces))

@parameter
fn vec_index[nelts: Int](i: Int):
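# Gather the input values at the mapped flat offsets and store them contiguously in the result.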

res.store[nelts](i,
t1.data().gather(Self.map_indeces[nelts, strides, indeces](i))
)

vectorize[vec_index, nelts](total_length)


@staticmethod
fn backward[
ug_shape: TensorShape,
t1_shape: TensorShape,
attributes: AttributeVector = AttributeVector(),
](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
alias indeces = Self.to_indeces(t1_shape, attributes)
alias strides = t1_shape.strides()
alias total_length = len(product(indeces))

var res_grad = Tensor[dtype](t1_shape)

@parameter
fn vec_index[nelts: Int](i: Int):

var offset = Self.map_indeces[nelts, strides, indeces](i)

# res_grad.data().scatter(
#     offset,
#     res_grad.data().gather(offset) + ug.load[nelts](i),
# )
# BUG: Vectorization edge case. When offset contains duplicates,
# e.g. offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1], the gather/scatter
# pair writes index 0 only once instead of accumulating res_grad[0] += 1 + 1.

# Workaround: accumulate element-wise so duplicate offsets add up.
var u = ug.load[nelts](i)
for j in range(nelts):
res_grad[int(offset[j])] += u[j]

vectorize[vec_index, nelts](total_length)

return res_grad^
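For reference, the duplicate-offset hazard described in the BUG note above is easy to reproduce outside Mojo. The following NumPy sketch (hypothetical shapes and names, not Basalt's API) mirrors what INDEX.forward gathers and why the backward pass must accumulate element-wise rather than scatter:

import numpy as np
from itertools import product

# Per-dimension index lists, as to_indeces would produce for a (2, 3)
# input with no slice/index attributes given.
indices = [[0, 1], [0, 1, 2]]
strides = [3, 1]  # row-major strides of the (2, 3) input

# Forward: each combination of per-dim indices maps to one flat offset.
offsets = [sum(i * s for i, s in zip(comb, strides)) for comb in product(*indices)]
x = np.arange(6, dtype=np.float32)  # flattened (2, 3) input
gathered = x[offsets]               # what the forward pass gathers

# Backward: a naive scatter drops updates when offsets repeat ...
dup = np.array([0, 2, 4, 0])
ug = np.ones(4, dtype=np.float32)
grad_bad = np.zeros(6, dtype=np.float32)
grad_bad[dup] = grad_bad[dup] + ug  # last write wins: grad_bad[0] == 1.0

# ... while accumulating per element, like the workaround, is correct.
grad_ok = np.zeros(6, dtype=np.float32)
np.add.at(grad_ok, dup, ug)         # grad_ok[0] == 2.0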
9 changes: 8 additions & 1 deletion basalt/autograd/ops/ops.mojo
@@ -15,7 +15,7 @@ from .basics import (
TRANSPOSE,
FMA,
)
from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE, INDEX
from .dynamics import CONCAT, SPLIT
from .conv import CONV2D
from .pool import MAXPOOL2D
@@ -61,6 +61,7 @@ struct OP(Stringable):
alias CONCAT = OP(23, "CONCAT", dynamic=True)
alias SPLIT = OP(24, "SPLIT", dynamic=True)
alias SLICE = OP(25, "SLICE")
alias INDEX = OP(26, "INDEX")

var id: UInt8
var name: Bytes[16]
@@ -135,6 +136,8 @@ fn static_result_shape(
return UNSQUEEZE.result_shape(t1_shape, attributes)
elif op == OP.SLICE:
return SLICE.result_shape(t1_shape, attributes)
elif op == OP.INDEX:
return INDEX.result_shape(t1_shape, attributes)
else:
print("[ERROR] Operator not found.")
return TensorShape(-1)
@@ -249,6 +252,8 @@ fn forward_op[
UNSQUEEZE.forward[t1_shape, attributes](res, t1)
elif op == OP.SLICE:
SLICE.forward[t1_shape, attributes](res, t1)
elif op == OP.INDEX:
INDEX.forward[t1_shape, attributes](res, t1)
else:
print("[ERROR] Operator not found.")

@@ -361,6 +366,8 @@ fn backward_op[
res_grad = UNSQUEEZE.backward[ug_shape, t1_shape](ug, t1)
elif op == OP.SLICE:
res_grad = SLICE.backward[ug_shape, t1_shape, attributes](ug, t1)
elif op == OP.INDEX:
res_grad = INDEX.backward[ug_shape, t1_shape, attributes](ug, t1)
else:
print("[ERROR] Operator not found.")
res_grad = Tensor[dtype](-1)
1 change: 1 addition & 0 deletions basalt/nn/__init__.mojo
@@ -4,6 +4,7 @@ from .model import Model
from .layers.linear import Linear
from .layers.conv import Conv2d
from .layers.pool import MaxPool2d
from .layers.upsample import Upsample

from .loss import MSELoss, CrossEntropyLoss
from .activations import Softmax, LogSoftmax, ReLU, Sigmoid, Tanh
181 changes: 181 additions & 0 deletions basalt/nn/layers/upsample.mojo
@@ -0,0 +1,181 @@
from math import min, max, floor, ceil

from basalt import dtype
from basalt import Graph, Symbol, OP
from basalt import Tensor, TensorShape
from basalt.autograd.attributes import AttributeVector, Attribute
from basalt.utils.itertools import product


fn _scale_indeces(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> List[Scalar[dtype]]:
var M = int(scale * N)
var indeces = List[Scalar[dtype]]()
if align_corners:
for i in range(M):
indeces.append(i * ((N - 1) / (M - 1)))
else:
var step = 1 / scale
var start = ((M - 1) * step - N + 1) / 2
for i in range(M):
indeces.append(i * step - start)

return indeces ^


fn nearest_coeffs(N: Int, scale: Scalar[dtype], dim: Int, ndims: Int) -> List[Int]:

@parameter
fn round_to_index(number: Scalar[dtype]) -> Int:
return int(number + 0.5) if number > 0 else int(number - 0.5)

var indeces = List[Int]()
var scaled = _scale_indeces(N, scale, True, dim, ndims)
for i in range(len(scaled)):
indeces.append(round_to_index(scaled[i]))
return indeces ^


alias Coeff = Tuple[List[Int], List[Scalar[dtype]]]
alias Coeffs = List[Coeff]

fn linear_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Coeffs:

var indeces_l = List[Int]()
var indeces_r = List[Int]()
var weights_l = List[Scalar[dtype]]()
var weights_r = List[Scalar[dtype]]()
for value in _scale_indeces(N, scale, align_corners, dim, ndims):
var clipped = min[dtype]((max[dtype](value[], 0)), N-1)
var idx_l = floor(clipped)
var idx_r = ceil(clipped)

indeces_l.append(int(idx_l))
indeces_r.append(int(idx_r))
weights_l.append(1 - (clipped - idx_l))
weights_r.append(clipped - idx_l)

return List[Coeff](
Tuple(indeces_l, weights_l),
Tuple(indeces_r, weights_r),
)


fn cubic_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Coeffs:
# TODO
return List[Coeff](
Tuple(List[Int](), List[Scalar[dtype]]()),
Tuple(List[Int](), List[Scalar[dtype]]()),
)


fn interpolate_nd[
indices_fn: fn (Int, Scalar[dtype], Bool, Int, Int) -> Coeffs,
](inout g: Graph, input: Symbol, scale_factors: List[Scalar[dtype]], align_corners: Bool) -> Symbol:

var spatial_dims = input.shape.rank() - 2

var temp = List[Int]()
var indeces_weights = List[Coeffs]()
for i in range(spatial_dims):
indeces_weights.append(
indices_fn(
input.shape[i + 2],
scale_factors[i],
align_corners,
i,
spatial_dims,
)
)

temp.append(i)

@parameter
fn get_comb_idx(dim: Int, coeff_id: Int) -> List[Int]:
return indeces_weights[dim][coeff_id].get[0, List[Int]]()

@parameter
fn get_comb_weight(dim: Int, coeff_id: Int) -> List[Scalar[dtype]]:
return indeces_weights[dim][coeff_id].get[1, List[Scalar[dtype]]]()

# Workaround: accessing the captured lists through the parameter functions
# above misbehaves, so read from an explicit copy instead.
var indeces_weights_copy = indeces_weights

# TODO: Interpolation itself is not implemented yet. For every combination
# of coefficient sets (left/right per spatial dimension), gather the input
# at the combined indices, multiply by the product of the corresponding
# weights, and sum the partial results.
# NOTE: product(temp, temp) assumes exactly 2 spatial dimensions; temp is
# [0, ..., spatial_dims - 1], which only equals the coefficient ids [0, 1]
# when spatial_dims == 2.
for comb_id in product(List[List[Int]](temp, temp)):
for i in range(spatial_dims):
var idx = indeces_weights_copy[i][comb_id[i]].get[0, List[Int]]()
var weights = indeces_weights_copy[i][comb_id[i]].get[1, List[Scalar[dtype]]]()
_ = idx
_ = weights

# Placeholder until the gather-and-blend above is implemented.
return Symbol(-1, dtype, TensorShape(), False)


fn Upsample(
inout g: Graph,
input: Symbol,
mode: StringLiteral,
scale_factors: List[Scalar[dtype]],
align_corners: Bool = False,
) -> Symbol:

# Assumption: a scale factor is provided for each spatial dimension.
# Input shape is (B, C, *spatial); batch and channel are non-spatial,
# so input.shape.rank() - 2 == len(scale_factors).
var spatial_dims = input.shape.rank() - 2

var res: Symbol
var attributes = AttributeVector()
var INDEX_LITERALS = List[StringLiteral]("dim_2i", "dim_3i", "dim_4i")

if mode == "nearest":
# Nearest-neighbor interpolation: equivalent to input[:, :, *indices]
for i in range(spatial_dims):
attributes.append(
Attribute(
INDEX_LITERALS[i],
nearest_coeffs(input.shape[i + 2], scale_factors[i], i, spatial_dims)
)
)

res = g.op(OP.INDEX, input, attributes=attributes)

elif mode == "linear":
res = interpolate_nd[linear_coeffs](g,
input,
scale_factors,
align_corners
)

elif mode == "cubic":
res = interpolate_nd[cubic_coeffs](g,
input,
scale_factors,
align_corners
)

else:
print("[ERROR] Upsampling mode not supported")
res = input

return res
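As a sanity check on the coefficient helpers, the nearest path reduces to a pure index gather. A minimal NumPy sketch (hypothetical helper names, mirroring _scale_indeces and nearest_coeffs under the same rounding convention):

import numpy as np

def scale_indices(n, scale, align_corners):
    # Fractional source coordinates for m = scale * n outputs (cf. _scale_indeces).
    m = int(scale * n)
    if align_corners:
        return [i * (n - 1) / (m - 1) for i in range(m)]
    step = 1 / scale
    start = ((m - 1) * step - n + 1) / 2
    return [i * step - start for i in range(m)]

def nearest_indices(n, scale):
    # Round half away from zero, like round_to_index in nearest_coeffs.
    coords = scale_indices(n, scale, align_corners=True)
    return [int(c + 0.5) if c > 0 else int(c - 0.5) for c in coords]

# Upsample a (1, 1, 2, 2) input by 2x per spatial dimension, nearest mode.
x = np.array([[[[1.0, 2.0], [3.0, 4.0]]]])
rows = nearest_indices(x.shape[2], 2.0)  # [0, 0, 1, 1]
cols = nearest_indices(x.shape[3], 2.0)  # [0, 0, 1, 1]
y = x[:, :, rows][:, :, :, cols]         # the input[:, :, *indices] gather

# Linear mode analog (cf. linear_coeffs): each clipped coordinate c splits its
# weight between floor(c) and ceil(c) as (1 - (c - floor(c)), c - floor(c)).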
