diff --git a/mojo/stdlib/src/memory/span.mojo b/mojo/stdlib/src/memory/span.mojo index 70a0482a52c..8cde9856fe3 100644 --- a/mojo/stdlib/src/memory/span.mojo +++ b/mojo/stdlib/src/memory/span.mojo @@ -473,3 +473,37 @@ struct Span[ address_space=address_space, alignment=alignment, ](ptr=self._data, length=self._len) + + fn reverse[D: DType, O: MutableOrigin, //](self: Span[Scalar[D], O]): + """Reverse the elements of the Span inplace. + + Parameters: + D: The DType of the scalars the Span stores. + O: The origin of the Span. + """ + + alias widths = (256, 128, 64, 32, 16, 8, 4, 2) + var ptr = self.unsafe_ptr() + var length = len(self) + middle, is_odd = length // 2, length % 2 != 0 + var processed = 0 + + @parameter + for i in range(len(widths)): + alias w = widths[i] + + @parameter + if simdwidthof[D]() >= w: + for _ in range((middle - processed) // w): + var lhs_ptr = ptr + processed + var rhs_ptr = ptr + length - (processed + w) + var lhs_v = lhs_ptr.load[width=w]().reversed() + var rhs_v = rhs_ptr.load[width=w]().reversed() + lhs_ptr.store(rhs_v) + rhs_ptr.store(lhs_v) + processed += w + + if is_odd: + var value = ptr[middle + 1] + (ptr + middle - 1).move_pointee_into(ptr + middle + 1) + (ptr + middle - 1).init_pointee_move(value) diff --git a/mojo/stdlib/test/memory/test_span.mojo b/mojo/stdlib/test/memory/test_span.mojo index 629f8b95557..72d7419a65a 100644 --- a/mojo/stdlib/test/memory/test_span.mojo +++ b/mojo/stdlib/test/memory/test_span.mojo @@ -221,6 +221,30 @@ def test_span_coerce(): takes_span(a) +def test_reverse(): + def _test_dtype[D: DType](): + forward = InlineArray[Scalar[D], 11](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) + backward = InlineArray[Scalar[D], 11](11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) + s = Span(forward) + s.reverse() + i = 0 + for num in s: + assert_equal(num[], backward[i]) + i += 1 + + _test_dtype[DType.uint8]() + _test_dtype[DType.uint16]() + _test_dtype[DType.uint32]() + _test_dtype[DType.uint64]() + _test_dtype[DType.int8]() + _test_dtype[DType.int16]() + _test_dtype[DType.int32]() + _test_dtype[DType.int64]() + _test_dtype[DType.float16]() + _test_dtype[DType.float32]() + _test_dtype[DType.float64]() + + def main(): test_span_list_int() test_span_list_str() @@ -234,3 +258,4 @@ def main(): test_fill() test_ref() test_reversed() + test_reverse()