From 9de5a3185d2df85593eaaac1159dfe398c7454bc Mon Sep 17 00:00:00 2001
From: martinvuyk
Date: Thu, 12 Dec 2024 19:20:51 -0300
Subject: [PATCH 1/2] add vectorized Span.reverse()

Signed-off-by: martinvuyk
---
 stdlib/src/memory/span.mojo       | 35 +++++++++++++++++++++++++++++++
 stdlib/test/memory/test_span.mojo | 25 ++++++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo
index 03f860f899..987aec2beb 100644
--- a/stdlib/src/memory/span.mojo
+++ b/stdlib/src/memory/span.mojo
@@ -23,6 +23,7 @@ from memory import Span
 
 from collections import InlineArray
 from memory import Pointer, UnsafePointer
+from sys.info import simdwidthof
 
 
 trait AsBytes:
@@ -371,3 +372,37 @@ struct Span[
         return Span[T, ImmutableOrigin.cast_from[origin].result](
             ptr=self._data, length=self._len
         )
+
+    fn reverse[D: DType, O: MutableOrigin, //](self: Span[Scalar[D], O]):
+        """Reverse the elements of the Span inplace.
+
+        Parameters:
+            D: The DType of the scalars the Span stores.
+            O: The origin of the Span.
+        """
+
+        alias widths = (256, 128, 64, 32, 16, 8, 4, 2)
+        var ptr = self.unsafe_ptr()
+        var length = len(self)
+        var middle = length // 2
+        var processed = 0
+
+        @parameter
+        for i in range(len(widths)):
+            alias w = widths.get[i, Int]()
+
+            @parameter
+            if simdwidthof[D]() >= w:
+                for _ in range((middle - processed) // w):
+                    var lhs_ptr = ptr + processed
+                    var rhs_ptr = ptr + length - (processed + w)
+                    var lhs_v = lhs_ptr.load[width=w]().reversed()
+                    var rhs_v = rhs_ptr.load[width=w]().reversed()
+                    lhs_ptr.store(rhs_v)
+                    rhs_ptr.store(lhs_v)
+                    processed += w
+
+        if middle * 2 != length:
+            var value = ptr[middle + 1]
+            (ptr + middle - 1).move_pointee_into(ptr + middle + 1)
+            (ptr + middle - 1).init_pointee_move(value)
diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo
index 92c49210c6..7e956113ff 100644
--- a/stdlib/test/memory/test_span.mojo
+++ b/stdlib/test/memory/test_span.mojo
@@ -199,6 +199,30 @@ def test_reversed():
         i += 1
 
 
+def test_reverse():
+    def _test_dtype[D: DType]():
+        forward = InlineArray[Scalar[D], 11](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
+        backward = InlineArray[Scalar[D], 11](11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
+        s = Span(forward)
+        s.reverse()
+        i = 0
+        for num in s:
+            assert_equal(num[], backward[i])
+            i += 1
+
+    _test_dtype[DType.uint8]()
+    _test_dtype[DType.uint16]()
+    _test_dtype[DType.uint32]()
+    _test_dtype[DType.uint64]()
+    _test_dtype[DType.int8]()
+    _test_dtype[DType.int16]()
+    _test_dtype[DType.int32]()
+    _test_dtype[DType.int64]()
+    _test_dtype[DType.float16]()
+    _test_dtype[DType.float32]()
+    _test_dtype[DType.float64]()
+
+
 def main():
     test_span_list_int()
     test_span_list_str()
@@ -211,3 +235,4 @@ def main():
     test_fill()
     test_ref()
     test_reversed()
+    test_reverse()
From b7793960f81dcfc3722922be816bf04e79b58dfc Mon Sep 17 00:00:00 2001
From: martinvuyk
Date: Thu, 12 Dec 2024 19:34:15 -0300
Subject: [PATCH 2/2] fix detail

The scalar fix-up after the vectorized passes was gated on the length
being odd and always swapped `middle - 1` with `middle + 1`. Whether a
pair is left over depends on `middle - processed`, not on parity:

- length 5 or 9: every pair is already swapped, so the extra swap
  undid one of them;
- length 2 or 6: one pair is left over and was never swapped;
- length 1: the fix-up read `ptr[1]` and wrote `ptr - 1`, both out of
  bounds.

Replace it with a scalar loop over the remaining pairs and add
regression cases for lengths 5 and 6.

Signed-off-by: martinvuyk
---
 stdlib/src/memory/span.mojo       | 14 ++++++++++----
 stdlib/test/memory/test_span.mojo | 20 ++++++++++++++++++++
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo
--- a/stdlib/src/memory/span.mojo
+++ b/stdlib/src/memory/span.mojo
@@ -402,7 +402,13 @@ struct Span[
                     rhs_ptr.store(lhs_v)
                     processed += w
 
-        if middle * 2 != length:
-            var value = ptr[middle + 1]
-            (ptr + middle - 1).move_pointee_into(ptr + middle + 1)
-            (ptr + middle - 1).init_pointee_move(value)
+        # Swap the pairs that the vectorized passes did not reach. How many
+        # are left depends on `middle - processed`, not on the parity of
+        # `length`: an odd length only means the middle element stays put.
+        while processed < middle:
+            var lhs_idx = processed
+            var rhs_idx = length - 1 - processed
+            var value = ptr[lhs_idx]
+            ptr[lhs_idx] = ptr[rhs_idx]
+            ptr[rhs_idx] = value
+            processed += 1
diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo
--- a/stdlib/test/memory/test_span.mojo
+++ b/stdlib/test/memory/test_span.mojo
@@ -209,6 +209,26 @@ def test_reverse():
         for num in s:
             assert_equal(num[], backward[i])
             i += 1
+
+        # Odd length that typically leaves no pair for the scalar tail.
+        forward5 = InlineArray[Scalar[D], 5](1, 2, 3, 4, 5)
+        backward5 = InlineArray[Scalar[D], 5](5, 4, 3, 2, 1)
+        s5 = Span(forward5)
+        s5.reverse()
+        i = 0
+        for num5 in s5:
+            assert_equal(num5[], backward5[i])
+            i += 1
+
+        # Even length that leaves one pair for the scalar tail.
+        forward6 = InlineArray[Scalar[D], 6](1, 2, 3, 4, 5, 6)
+        backward6 = InlineArray[Scalar[D], 6](6, 5, 4, 3, 2, 1)
+        s6 = Span(forward6)
+        s6.reverse()
+        i = 0
+        for num6 in s6:
+            assert_equal(num6[], backward6[i])
+            i += 1
 
     _test_dtype[DType.uint8]()
     _test_dtype[DType.uint16]()