diff --git a/mojo/stdlib/stdlib/memory/span.mojo b/mojo/stdlib/stdlib/memory/span.mojo
index 1dc09eacbbf..6c3aa183372 100644
--- a/mojo/stdlib/stdlib/memory/span.mojo
+++ b/mojo/stdlib/stdlib/memory/span.mojo
@@ -608,3 +608,51 @@ struct Span[
             ptr=self._data.origin_cast[result.mut, result.origin](),
             length=self._len,
         )
+
+    fn reverse[
+        dtype: DType, O: MutableOrigin, //
+    ](self: Span[Scalar[dtype], O]):
+        """Reverse the elements of the `Span` inplace.
+
+        Pairs of elements are swapped from both ends towards the middle,
+        using the widest SIMD vectors available for `dtype` and a scalar
+        swap for any pair left over. Spans of length 0 or 1 are unchanged.
+
+        Parameters:
+            dtype: The DType of the scalars the `Span` stores.
+            O: The origin of the `Span`.
+        """
+
+        alias widths = (256, 128, 64, 32, 16, 8, 4, 2)
+        var ptr = self.unsafe_ptr()
+        var length = len(self)
+        # Number of (front, back) pairs to swap; for odd lengths the middle
+        # element stays in place.
+        var middle = length // 2
+        var processed = 0
+
+        @parameter
+        for i in range(len(widths)):
+            alias w = widths[i]
+
+            @parameter
+            if simdwidthof[dtype]() >= w:
+                # Swap `w` front elements with `w` back elements, reversing
+                # each vector so the elements land in mirrored order.
+                for _ in range((middle - processed) // w):
+                    var lhs_ptr = ptr + processed
+                    var rhs_ptr = ptr + length - (processed + w)
+                    var lhs_v = lhs_ptr.load[width=w]().reversed()
+                    var rhs_v = rhs_ptr.load[width=w]().reversed()
+                    lhs_ptr.store(rhs_v)
+                    rhs_ptr.store(lhs_v)
+                    processed += w
+
+        # Scalar tail: the SIMD passes stop at width 2, so a pair can be
+        # left over for both odd and even lengths (e.g. lengths 2, 3, 6, 7).
+        for idx in range(processed, middle):
+            var front = ptr + idx
+            var back = ptr + length - 1 - idx
+            var tmp = front.load()
+            front.store(back.load())
+            back.store(tmp)
diff --git a/mojo/stdlib/test/memory/test_span.mojo b/mojo/stdlib/test/memory/test_span.mojo
index 67f802f4173..26658b572c6 100644
--- a/mojo/stdlib/test/memory/test_span.mojo
+++ b/mojo/stdlib/test/memory/test_span.mojo
@@ -271,6 +271,41 @@ def test_span_repr():
     assert_equal(s.__repr__(), "[1, 2]")
 
 
+def test_reverse():
+    def _test_dtype[D: DType]():
+        forward = InlineArray[Scalar[D], 11](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
+        backward = InlineArray[Scalar[D], 11](11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
+        s = Span(forward)
+        s.reverse()
+        i = 0
+        for num in s:
+            assert_equal(num, backward[i])
+            i += 1
+
+        # Exercise every length from 0 through 39 so that the empty,
+        # single-element and leftover-pair cases (odd and even) are covered.
+        for length in range(40):
+            items = List[Scalar[D]]()
+            for j in range(length):
+                items.append(Scalar[D](j))
+            span = Span(items)
+            span.reverse()
+            for j in range(length):
+                assert_equal(items[j], Scalar[D](length - 1 - j))
+
+    _test_dtype[DType.uint8]()
+    _test_dtype[DType.uint16]()
+    _test_dtype[DType.uint32]()
+    _test_dtype[DType.uint64]()
+    _test_dtype[DType.int8]()
+    _test_dtype[DType.int16]()
+    _test_dtype[DType.int32]()
+    _test_dtype[DType.int64]()
+    _test_dtype[DType.float16]()
+    _test_dtype[DType.float32]()
+    _test_dtype[DType.float64]()
+
+
 def main():
     test_span_list_int()
     test_span_list_str()
@@ -288,3 +323,4 @@ def main():
     test_merge()
     test_span_to_string()
     test_span_repr()
+    test_reverse()