diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo
index 2a0c2d020b..884a259c19 100644
--- a/stdlib/src/memory/span.mojo
+++ b/stdlib/src/memory/span.mojo
@@ -406,3 +406,48 @@ struct Span[
         return Span[T, ImmutableOrigin.cast_from[origin].result](
             ptr=self._data, length=self._len
         )
+
+    fn reverse[D: DType, O: MutableOrigin, //](self: Span[Scalar[D], O]):
+        """Reverse the elements of the Span inplace.
+
+        Mirrored chunks from both ends are swapped with SIMD loads and stores;
+        whatever is too short for the narrowest chunk is swapped element-wise.
+        The middle element of an odd-length Span is left where it is.
+
+        Parameters:
+            D: The DType of the scalars the Span stores.
+            O: The origin of the Span.
+        """
+
+        alias widths = (256, 128, 64, 32, 16, 8, 4, 2)
+        var ptr = self.unsafe_ptr()
+        var length = len(self)
+        # Number of (front, back) pairs that have to trade places.
+        var middle = length // 2
+        var processed = 0
+
+        @parameter
+        for i in range(len(widths)):
+            alias w = widths.get[i, Int]()
+
+            @parameter
+            if simdwidthof[D]() >= w:
+                for _ in range((middle - processed) // w):
+                    var lhs_ptr = ptr + processed
+                    var rhs_ptr = ptr + length - (processed + w)
+                    # Reverse each chunk so it lands mirrored on the far side.
+                    var lhs_v = lhs_ptr.load[width=w]().reversed()
+                    var rhs_v = rhs_ptr.load[width=w]().reversed()
+                    lhs_ptr.store(rhs_v)
+                    rhs_ptr.store(lhs_v)
+                    processed += w
+
+        # The narrowest chunk swaps 2 pairs at a time, so a pair can still be
+        # pending here for odd and even lengths alike (e.g. 2, 3, 10, 11).
+        while processed < middle:
+            var lhs_ptr = ptr + processed
+            var rhs_ptr = ptr + length - (processed + 1)
+            var lhs_value = lhs_ptr[]
+            lhs_ptr[] = rhs_ptr[]
+            rhs_ptr[] = lhs_value
+            processed += 1
diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo
index 4a3b6dd980..ef62fb6a92 100644
--- a/stdlib/test/memory/test_span.mojo
+++ b/stdlib/test/memory/test_span.mojo
@@ -208,6 +208,33 @@ def test_reversed():
         i += 1
 
 
+def test_reverse():
+    def _test_dtype[D: DType]():
+        # Lengths 0 through 128 cover empty, single-element, odd and even
+        # spans, with and without a leftover pair after the SIMD chunks.
+        # Values stop at 127, the largest one int8 can hold.
+        for length in range(129):
+            items = List[Scalar[D]]()
+            for i in range(length):
+                items.append(Scalar[D](i))
+            s = Span(items)
+            s.reverse()
+            for i in range(length):
+                assert_equal(items[i], Scalar[D](length - 1 - i))
+
+    _test_dtype[DType.uint8]()
+    _test_dtype[DType.uint16]()
+    _test_dtype[DType.uint32]()
+    _test_dtype[DType.uint64]()
+    _test_dtype[DType.int8]()
+    _test_dtype[DType.int16]()
+    _test_dtype[DType.int32]()
+    _test_dtype[DType.int64]()
+    _test_dtype[DType.float16]()
+    _test_dtype[DType.float32]()
+    _test_dtype[DType.float64]()
+
+
 def main():
     test_span_list_int()
     test_span_list_str()
@@ -221,3 +248,4 @@ def main():
     test_fill()
     test_ref()
     test_reversed()
+    test_reverse()