Skip to content

Conversation

@christiangnrd
Copy link
Member

Ported from CUDA.jl.
Closes #691.

@github-actions
Copy link
Contributor

github-actions bot commented Nov 23, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/random.jl b/src/random.jl
index a545533b..3c55d9a9 100644
--- a/src/random.jl
+++ b/src/random.jl
@@ -12,7 +12,7 @@ mutable struct RNG <: AbstractRNG
     counter::UInt32
 
     function RNG(seed::Integer)
-        new(seed%UInt32, 0)
+        return new(seed % UInt32, 0)
     end
     RNG(seed::UInt32, counter::UInt32) = new(seed, counter)
 end
@@ -27,7 +27,7 @@ Base.:(==)(a::RNG, b::RNG) = (a.seed == b.seed) && (a.counter == b.counter)
 
 function Random.seed!(rng::RNG, seed::Integer)
     rng.seed = seed % UInt32
-    rng.counter = 0
+    return rng.counter = 0
 end
 
 Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())
@@ -68,21 +68,21 @@ function Random.rand!(rng::RNG, A::WrappedMtlArray)
     threads = 32
     groups = cld(length(A), threads)
 
-    @metal threads groups name="rand!" kernel(A, rng.seed, rng.counter)
+    @metal threads groups name = "rand!" kernel(A, rng.seed, rng.counter)
 
     new_counter = Int64(rng.counter) + length(A)
     overflow, remainder = fldmod(new_counter, typemax(UInt32))
     rng.seed += overflow     # XXX: is this OK?
     rng.counter = remainder
 
-    A
+    return A
 end
 
-function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Complex{<:AbstractFloat}}})
+function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat, Complex{<:AbstractFloat}}})
     isempty(A) && return A
 
     ## COV_EXCL_START
-    function kernel(A::AbstractArray{T}, seed::UInt32, counter::UInt32) where {T<:Real}
+    function kernel(A::AbstractArray{T}, seed::UInt32, counter::UInt32) where {T <: Real}
         device_rng = Random.default_rng()
 
         # initialize the state
@@ -102,20 +102,20 @@ function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Comple
                     U1 = Random.rand(device_rng, T)
                 end
                 U2 = Random.rand(device_rng, T)
-                Z0 = sqrt(T(-2.0)*log(U1))*cos(T(2pi)*U2)
-                Z1 = sqrt(T(-2.0)*log(U1))*sin(T(2pi)*U2)
+                Z0 = sqrt(T(-2.0) * log(U1)) * cos(T(2pi) * U2)
+                Z1 = sqrt(T(-2.0) * log(U1)) * sin(T(2pi) * U2)
                 @inbounds A[i] = Z0
                 if j <= length(A)
                     @inbounds A[j] = Z1
                 end
             end
 
-            offset += 2*window
+            offset += 2 * window
         end
         return
     end
 
-    function kernel(A::AbstractArray{Complex{T}}, seed::UInt32, counter::UInt32) where {T<:Real}
+    function kernel(A::AbstractArray{Complex{T}}, seed::UInt32, counter::UInt32) where {T <: Real}
         device_rng = Random.default_rng()
 
         # initialize the state
@@ -134,8 +134,8 @@ function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Comple
                     U1 = Random.rand(device_rng, T)
                 end
                 U2 = Random.rand(device_rng, T)
-                Z0 = sqrt(-log(U1))*cos(T(2pi)*U2)
-                Z1 = sqrt(-log(U1))*sin(T(2pi)*U2)
+                Z0 = sqrt(-log(U1)) * cos(T(2pi) * U2)
+                Z1 = sqrt(-log(U1)) * sin(T(2pi) * U2)
                 @inbounds A[i] = complex(Z0, Z1)
             end
 
@@ -149,14 +149,14 @@ function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Comple
     threads = 32
     groups = cld(cld(length(A), 2), threads)
 
-    @metal threads groups name="randn!" kernel(A, rng.seed, rng.counter)
+    @metal threads groups name = "randn!" kernel(A, rng.seed, rng.counter)
 
     new_counter = Int64(rng.counter) + length(A)
     overflow, remainder = fldmod(new_counter, typemax(UInt32))
     rng.seed += overflow     # XXX: is this OK?
     rng.counter = remainder
 
-    A
+    return A
 end
 
 function default_rng()
@@ -165,13 +165,13 @@ function default_rng()
     # every task maintains library state per device
     LibraryState = @NamedTuple{rng::RNG}
     states = get!(task_local_storage(), :RNG) do
-        Dict{MTLDevice,LibraryState}()
-    end::Dict{MTLDevice,LibraryState}
+        Dict{MTLDevice, LibraryState}()
+    end::Dict{MTLDevice, LibraryState}
 
     # get library state
     @noinline function new_state(dev)
         # Metal RNG objects are cheap, so we don't need to cache them
-        (; rng = RNG())
+        return (; rng = RNG())
     end
     state = get!(states, dev) do
         new_state(dev)
@@ -213,17 +213,17 @@ Random.randn(rng::RNG, T::Type, dim1::Integer, dims::Integer...) =
 function Random.rand!(rng::RNG, A::AbstractArray{T}) where {T}
     B = MtlArray{T}(undef, size(A))
     rand!(rng, B)
-    copyto!(A, B)
+    return copyto!(A, B)
 end
 function Random.randn!(rng::RNG, A::AbstractArray{T}) where {T}
     B = MtlArray{T}(undef, size(A))
     randn!(rng, B)
-    copyto!(A, B)
+    return copyto!(A, B)
 end
 
 # scalars
-Random.rand(rng::RNG, T::Type=Float32) = Random.rand(rng, T, 1)[]
-Random.randn(rng::RNG, T::Type=Float32) = Random.randn(rng, T, 1)[]
+Random.rand(rng::RNG, T::Type = Float32) = Random.rand(rng, T, 1)[]
+Random.randn(rng::RNG, T::Type = Float32) = Random.randn(rng, T, 1)[]
 # resolve ambiguities
 Random.randn(rng::RNG, T::Random.BitFloatType) = Random.randn(rng, T, 1)[]
 
@@ -248,9 +248,9 @@ function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
     return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
 end
 rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
-    Random.rand!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
+    Random.rand!(mtl_rng(), MtlArray{T, length(dims), storage}(undef, dims...))
 randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
-    Random.randn!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
+    Random.randn!(mtl_rng(), MtlArray{T, length(dims), storage}(undef, dims...))
 
 # support all dimension specifications
 function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
@@ -261,9 +261,9 @@ function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=Defau
 end
 
 rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
-    Random.rand!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
+    Random.rand!(mtl_rng(), MtlArray{T, length(dims) + 1, storage}(undef, dim1, dims...))
 randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
-    Random.randn!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
+    Random.randn!(mtl_rng(), MtlArray{T, length(dims) + 1, storage}(undef, dim1, dims...))
 
 # untyped out-of-place
 rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
diff --git a/test/random.jl b/test/random.jl
index 7a8e28fb..25a48c65 100644
--- a/test/random.jl
+++ b/test/random.jl
@@ -12,7 +12,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
 
     @testset "$f with $T" for (f, T) in INPLACE_TUPLES
         # d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
-        @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
+        @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000, 1000))
             A = MtlArray{T}(undef, d)
 
             # default_rng
@@ -34,7 +34,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
 
             # default_rng
             f(A)
-            @test A isa MtlArray{T,1}
+            @test A isa MtlArray{T, 1}
             @test Array(A) == fill(1, 0)
 
             # specified MPS rng
@@ -116,7 +116,7 @@ end
     # Test when views try to use rand!(rng, args..)
     @testset "MPS.RNG with views" begin
         rng = Metal.MPS.RNG()
-        @testset "$f with $T" for (f, T) in ((randn!, Float32),(rand!, Int64),(rand!, Float32), (rand!, UInt16), (rand!,Int8))
+        @testset "$f with $T" for (f, T) in ((randn!, Float32), (rand!, Int64), (rand!, Float32), (rand!, UInt16), (rand!, Int8))
             A = MtlArray{T}(undef, 100)
 
             ## Offset > 0
@@ -146,9 +146,9 @@ end
 # out-of-place
 @testset "out-of-place" begin
     @testset "$fr with implicit type" for (fm, fr, T) in
-                                            ((Metal.rand, rand, Float32), (Metal.randn, randn, Float32))
+        ((Metal.rand, rand, Float32), (Metal.randn, randn, Float32))
         rng = Metal.MPS.RNG()
-        @testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000,1000))
+        @testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000, 1000))
             # default_rng
             A = fm(args...)
             @test eltype(A) == T
@@ -169,16 +169,18 @@ end
     # out-of-place, with type specified
     @testset "$fr with $T" for (fm, fr, T) in OOPLACE_TUPLES
         rng = Metal.MPS.RNG()
-        @testset "$args" for args in ((T, 0),
-                                        (T, 1),
-                                        (T, 3),
-                                        (T, 3, 3),
-                                        (T, (3, 3)),
-                                        (T, 16),
-                                        (T, 16, 16),
-                                        (T, (16, 16)),
-                                        (T, 1000),
-                                        (T, 1000, 1000),)
+        @testset "$args" for args in (
+                (T, 0),
+                (T, 1),
+                (T, 3),
+                (T, 3, 3),
+                (T, (3, 3)),
+                (T, 16),
+                (T, 16, 16),
+                (T, (16, 16)),
+                (T, 1000),
+                (T, 1000, 1000),
+            )
             # default_rng
             A = fm(args...)
             @test eltype(A) == T
@@ -207,7 +209,7 @@ end
     rng = Metal.MPS.RNG()
     @testset "$f with $T" for (f, T) in mps_tuples
         # d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
-        @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000), 16384, 16385)
+        @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000, 1000), 16384, 16385)
             A = zeros(T, d)
 
             f(rng, A)
@@ -217,11 +219,13 @@ end
 end
 
 ## seeding
-@testset "Seeding $L" for (f,T,L) in [(Metal.rand,UInt32,"Uniform Integers MPS"),
-                                        (Metal.rand,Float32,"Uniform Float32 MPS"),
-                                        (Metal.randn,Float32,"Normal Float32 MPS"),
-                                        (Metal.randn,Float16,"Float16 Native")]
-    @testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
+@testset "Seeding $L" for (f, T, L) in [
+        (Metal.rand, UInt32, "Uniform Integers MPS"),
+        (Metal.rand, Float32, "Uniform Float32 MPS"),
+        (Metal.randn, Float32, "Normal Float32 MPS"),
+        (Metal.randn, Float16, "Float16 Native"),
+    ]
+    @testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000, 1000))
         Metal.seed!(1)
         a = f(T, d)
         Metal.seed!(1)
@@ -238,11 +242,13 @@ end
     ## in-place
 
     # uniform
-    for T in (Float16, Float32,
-              ComplexF16, ComplexF32,
-              Int8, Int16, Int32, Int64,
-              UInt8, UInt16, UInt32, UInt64),
-        dims = (0, 2, (2,2), (2,2,2))
+    for T in (
+                Float16, Float32,
+                ComplexF16, ComplexF32,
+                Int8, Int16, Int32, Int64,
+                UInt8, UInt16, UInt32, UInt64,
+            ),
+            dims in (0, 2, (2, 2), (2, 2, 2))
         A = MtlArray{T}(undef, dims)
         rand!(rng, A)
 
@@ -251,9 +257,11 @@ end
     end
 
     # normal
-    for T in (Float16, Float32,
-              ComplexF16, ComplexF32),
-        dims = (0, 2, (2,2), (2,2,2))
+    for T in (
+                Float16, Float32,
+                ComplexF16, ComplexF32,
+            ),
+            dims in (0, 2, (2, 2), (2, 2, 2))
         A = MtlArray{T}(undef, dims)
         randn!(rng, A)
 
@@ -268,12 +276,14 @@ end
         @test rand(rng) isa Number
         @test rand(rng, Float32) isa Float32
     end
-    for dims in (0, 2, (2,2), (2,2,2))
+    for dims in (0, 2, (2, 2), (2, 2, 2))
         @test rand(rng, dims) isa MtlArray
-        for T in (Float16, Float32,
-                  ComplexF16, ComplexF32,
-                  Int8, Int16, Int32, Int64,
-                  UInt8, UInt16, UInt32, UInt64)
+        for T in (
+                Float16, Float32,
+                ComplexF16, ComplexF32,
+                Int8, Int16, Int32, Int64,
+                UInt8, UInt16, UInt32, UInt64,
+            )
             @test rand(rng, T, dims) isa MtlArray{T}
         end
     end
@@ -283,10 +293,12 @@ end
         @test randn(rng) isa Number
         @test randn(rng, Float32) isa Float32
     end
-    for dims in (0, 2, (2,2), (2,2,2))
+    for dims in (0, 2, (2, 2), (2, 2, 2))
         @test randn(rng, dims) isa MtlArray
-        for T in (Float16, Float32,
-                  ComplexF16, ComplexF32)
+        for T in (
+                Float16, Float32,
+                ComplexF16, ComplexF32,
+            )
             @test randn(rng, T, dims) isa MtlArray{T}
         end
     end
@@ -304,7 +316,7 @@ end
     # bound of around 4 or 4.1, while CURAND.rand gets down to u = 2f0^(-33) giving an upper
     # bound of around 4.8. In contrast, incorrectly reusing the real Box-Muller transform
     # gives typical real parts in the hundreds.
-    @test maximum(real(randn(rng, ComplexF32, 32))) <= sqrt(-log(2f0^(-33)))
+    @test maximum(real(randn(rng, ComplexF32, 32))) <= sqrt(-log(2.0f0^(-33)))
 end
 
 @testset "seeding idempotency" begin

@codecov
Copy link

codecov bot commented Nov 23, 2025

Codecov Report

❌ Patch coverage is 98.46154% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 81.26%. Comparing base (f1ec854) to head (6a85440).

Files with missing lines Patch % Lines
src/random.jl 98.46% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #705      +/-   ##
==========================================
+ Coverage   80.96%   81.26%   +0.30%     
==========================================
  Files          62       62              
  Lines        2837     2899      +62     
==========================================
+ Hits         2297     2356      +59     
- Misses        540      543       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Benchmark suite Current: 6a85440 Previous: f1ec854 Ratio
latency/precompile 25179033792 ns 24360634500 ns 1.03
latency/ttfp 2302321042 ns 2326898000 ns 0.99
latency/import 1461697042 ns 1425799334 ns 1.03
integration/metaldevrt 844000 ns 833666 ns 1.01
integration/byval/slices=1 1588687 ns 1573625.5 ns 1.01
integration/byval/slices=3 20020959 ns 19669333 ns 1.02
integration/byval/reference 1591979.5 ns 1572021 ns 1.01
integration/byval/slices=2 2682458 ns 2716187 ns 0.99
kernel/indexing 514271 ns 475833 ns 1.08
kernel/indexing_checked 504458 ns 484125 ns 1.04
kernel/launch 12541 ns 12416 ns 1.01
kernel/rand 528396 ns 528000 ns 1.00
array/construct 6333 ns 6291 ns 1.01
array/broadcast 544834 ns 553000 ns 0.99
array/random/randn/Float32 917375 ns 913687 ns 1.00
array/random/randn!/Float32 584458 ns 580854 ns 1.01
array/random/rand!/Int64 539542 ns 542292 ns 0.99
array/random/rand!/Float32 536875 ns 541291.5 ns 0.99
array/random/rand/Int64 903708 ns 928667 ns 0.97
array/random/rand/Float32 849334 ns 833937.5 ns 1.02
array/accumulate/Int64/1d 1313834 ns 1303916 ns 1.01
array/accumulate/Int64/dims=1 1875042 ns 1850791.5 ns 1.01
array/accumulate/Int64/dims=2 2238417 ns 2214833 ns 1.01
array/accumulate/Int64/dims=1L 12217500 ns 12158375 ns 1.00
array/accumulate/Int64/dims=2L 9662125 ns 9771875 ns 0.99
array/accumulate/Float32/1d 1085083 ns 1077312 ns 1.01
array/accumulate/Float32/dims=1 1616500 ns 1581250 ns 1.02
array/accumulate/Float32/dims=2 2008645.5 ns 1946979 ns 1.03
array/accumulate/Float32/dims=1L 10442166.5 ns 10346124.5 ns 1.01
array/accumulate/Float32/dims=2L 7297146 ns 7465979.5 ns 0.98
array/reductions/reduce/Int64/1d 1301500 ns 1286833 ns 1.01
array/reductions/reduce/Int64/dims=1 1100854 ns 1117167 ns 0.99
array/reductions/reduce/Int64/dims=2 1146854 ns 1162416.5 ns 0.99
array/reductions/reduce/Int64/dims=1L 2041833 ns 2032312.5 ns 1.00
array/reductions/reduce/Int64/dims=2L 3898917 ns 3847000.5 ns 1.01
array/reductions/reduce/Float32/1d 745791 ns 746228.5 ns 1.00
array/reductions/reduce/Float32/dims=1 799020.5 ns 799249.5 ns 1.00
array/reductions/reduce/Float32/dims=2 841541.5 ns 838042 ns 1.00
array/reductions/reduce/Float32/dims=1L 1324208.5 ns 1338500 ns 0.99
array/reductions/reduce/Float32/dims=2L 1823875 ns 1796542 ns 1.02
array/reductions/mapreduce/Int64/1d 1324833.5 ns 1311833 ns 1.01
array/reductions/mapreduce/Int64/dims=1 1110416.5 ns 1109625 ns 1.00
array/reductions/mapreduce/Int64/dims=2 1158459 ns 1147916 ns 1.01
array/reductions/mapreduce/Int64/dims=1L 2013750 ns 2003166.5 ns 1.01
array/reductions/mapreduce/Int64/dims=2L 3605499.5 ns 3590937.5 ns 1.00
array/reductions/mapreduce/Float32/1d 837083 ns 807333 ns 1.04
array/reductions/mapreduce/Float32/dims=1 800916 ns 802167 ns 1.00
array/reductions/mapreduce/Float32/dims=2 826583 ns 823146 ns 1.00
array/reductions/mapreduce/Float32/dims=1L 1338375 ns 1348667 ns 0.99
array/reductions/mapreduce/Float32/dims=2L 1812750 ns 1809125 ns 1.00
array/private/copyto!/gpu_to_gpu 551666 ns 526729.5 ns 1.05
array/private/copyto!/cpu_to_gpu 748166.5 ns 756229 ns 0.99
array/private/copyto!/gpu_to_cpu 740583.5 ns 757041.5 ns 0.98
array/private/iteration/findall/int 1576417 ns 1574021 ns 1.00
array/private/iteration/findall/bool 1481416 ns 1469833 ns 1.01
array/private/iteration/findfirst/int 2108958 ns 2086312 ns 1.01
array/private/iteration/findfirst/bool 2028791.5 ns 2021375 ns 1.00
array/private/iteration/scalar 3555167 ns 3490125 ns 1.02
array/private/iteration/logical 2670833 ns 2683791.5 ns 1.00
array/private/iteration/findmin/1d 2279187.5 ns 2264708.5 ns 1.01
array/private/iteration/findmin/2d 1549354 ns 1546667 ns 1.00
array/private/copy 865958.5 ns 896333 ns 0.97
array/shared/copyto!/gpu_to_gpu 84166 ns 84459 ns 1.00
array/shared/copyto!/cpu_to_gpu 83333 ns 83458 ns 1.00
array/shared/copyto!/gpu_to_cpu 83083 ns 83583 ns 0.99
array/shared/iteration/findall/int 1574604.5 ns 1574563 ns 1.00
array/shared/iteration/findall/bool 1493020.5 ns 1481104.5 ns 1.01
array/shared/iteration/findfirst/int 1707500 ns 1691520.5 ns 1.01
array/shared/iteration/findfirst/bool 1637209 ns 1635583 ns 1.00
array/shared/iteration/scalar 203875 ns 206458 ns 0.99
array/shared/iteration/logical 2413250 ns 2265000 ns 1.07
array/shared/iteration/findmin/1d 1887916.5 ns 1896625 ns 1.00
array/shared/iteration/findmin/2d 1546458 ns 1542708 ns 1.00
array/shared/copy 215209 ns 213500 ns 1.01
array/permutedims/4d 2492875 ns 2462125 ns 1.01
array/permutedims/2d 1169959 ns 1182146 ns 0.99
array/permutedims/3d 1795958.5 ns 1792209 ns 1.00
metal/synchronization/stream 19417 ns 19750 ns 0.98
metal/synchronization/context 20125 ns 20083 ns 1.00

This comment was automatically generated by a workflow using github-action-benchmark.

@christiangnrd christiangnrd force-pushed the rand branch 2 times, most recently from b66110f to 4a5e04d Compare December 2, 2025 15:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Replace GPUArrays-based RNG with native one

2 participants