-
Notifications
You must be signed in to change notification settings - Fork 48
Native Metal rand #705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Native Metal rand #705
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/random.jl b/src/random.jl
index a545533b..3c55d9a9 100644
--- a/src/random.jl
+++ b/src/random.jl
@@ -12,7 +12,7 @@ mutable struct RNG <: AbstractRNG
counter::UInt32
function RNG(seed::Integer)
- new(seed%UInt32, 0)
+ return new(seed % UInt32, 0)
end
RNG(seed::UInt32, counter::UInt32) = new(seed, counter)
end
@@ -27,7 +27,7 @@ Base.:(==)(a::RNG, b::RNG) = (a.seed == b.seed) && (a.counter == b.counter)
function Random.seed!(rng::RNG, seed::Integer)
rng.seed = seed % UInt32
- rng.counter = 0
+ return rng.counter = 0
end
Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())
@@ -68,21 +68,21 @@ function Random.rand!(rng::RNG, A::WrappedMtlArray)
threads = 32
groups = cld(length(A), threads)
- @metal threads groups name="rand!" kernel(A, rng.seed, rng.counter)
+ @metal threads groups name = "rand!" kernel(A, rng.seed, rng.counter)
new_counter = Int64(rng.counter) + length(A)
overflow, remainder = fldmod(new_counter, typemax(UInt32))
rng.seed += overflow # XXX: is this OK?
rng.counter = remainder
- A
+ return A
end
-function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Complex{<:AbstractFloat}}})
+function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat, Complex{<:AbstractFloat}}})
isempty(A) && return A
## COV_EXCL_START
- function kernel(A::AbstractArray{T}, seed::UInt32, counter::UInt32) where {T<:Real}
+ function kernel(A::AbstractArray{T}, seed::UInt32, counter::UInt32) where {T <: Real}
device_rng = Random.default_rng()
# initialize the state
@@ -102,20 +102,20 @@ function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Comple
U1 = Random.rand(device_rng, T)
end
U2 = Random.rand(device_rng, T)
- Z0 = sqrt(T(-2.0)*log(U1))*cos(T(2pi)*U2)
- Z1 = sqrt(T(-2.0)*log(U1))*sin(T(2pi)*U2)
+ Z0 = sqrt(T(-2.0) * log(U1)) * cos(T(2pi) * U2)
+ Z1 = sqrt(T(-2.0) * log(U1)) * sin(T(2pi) * U2)
@inbounds A[i] = Z0
if j <= length(A)
@inbounds A[j] = Z1
end
end
- offset += 2*window
+ offset += 2 * window
end
return
end
- function kernel(A::AbstractArray{Complex{T}}, seed::UInt32, counter::UInt32) where {T<:Real}
+ function kernel(A::AbstractArray{Complex{T}}, seed::UInt32, counter::UInt32) where {T <: Real}
device_rng = Random.default_rng()
# initialize the state
@@ -134,8 +134,8 @@ function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Comple
U1 = Random.rand(device_rng, T)
end
U2 = Random.rand(device_rng, T)
- Z0 = sqrt(-log(U1))*cos(T(2pi)*U2)
- Z1 = sqrt(-log(U1))*sin(T(2pi)*U2)
+ Z0 = sqrt(-log(U1)) * cos(T(2pi) * U2)
+ Z1 = sqrt(-log(U1)) * sin(T(2pi) * U2)
@inbounds A[i] = complex(Z0, Z1)
end
@@ -149,14 +149,14 @@ function Random.randn!(rng::RNG, A::WrappedMtlArray{<:Union{AbstractFloat,Comple
threads = 32
groups = cld(cld(length(A), 2), threads)
- @metal threads groups name="randn!" kernel(A, rng.seed, rng.counter)
+ @metal threads groups name = "randn!" kernel(A, rng.seed, rng.counter)
new_counter = Int64(rng.counter) + length(A)
overflow, remainder = fldmod(new_counter, typemax(UInt32))
rng.seed += overflow # XXX: is this OK?
rng.counter = remainder
- A
+ return A
end
function default_rng()
@@ -165,13 +165,13 @@ function default_rng()
# every task maintains library state per device
LibraryState = @NamedTuple{rng::RNG}
states = get!(task_local_storage(), :RNG) do
- Dict{MTLDevice,LibraryState}()
- end::Dict{MTLDevice,LibraryState}
+ Dict{MTLDevice, LibraryState}()
+ end::Dict{MTLDevice, LibraryState}
# get library state
@noinline function new_state(dev)
# Metal RNG objects are cheap, so we don't need to cache them
- (; rng = RNG())
+ return (; rng = RNG())
end
state = get!(states, dev) do
new_state(dev)
@@ -213,17 +213,17 @@ Random.randn(rng::RNG, T::Type, dim1::Integer, dims::Integer...) =
function Random.rand!(rng::RNG, A::AbstractArray{T}) where {T}
B = MtlArray{T}(undef, size(A))
rand!(rng, B)
- copyto!(A, B)
+ return copyto!(A, B)
end
function Random.randn!(rng::RNG, A::AbstractArray{T}) where {T}
B = MtlArray{T}(undef, size(A))
randn!(rng, B)
- copyto!(A, B)
+ return copyto!(A, B)
end
# scalars
-Random.rand(rng::RNG, T::Type=Float32) = Random.rand(rng, T, 1)[]
-Random.randn(rng::RNG, T::Type=Float32) = Random.randn(rng, T, 1)[]
+Random.rand(rng::RNG, T::Type = Float32) = Random.rand(rng, T, 1)[]
+Random.randn(rng::RNG, T::Type = Float32) = Random.randn(rng, T, 1)[]
# resolve ambiguities
Random.randn(rng::RNG, T::Random.BitFloatType) = Random.randn(rng, T, 1)[]
@@ -248,9 +248,9 @@ function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
end
rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
- Random.rand!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
+ Random.rand!(mtl_rng(), MtlArray{T, length(dims), storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
- Random.randn!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
+ Random.randn!(mtl_rng(), MtlArray{T, length(dims), storage}(undef, dims...))
# support all dimension specifications
function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
@@ -261,9 +261,9 @@ function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=Defau
end
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
- Random.rand!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
+ Random.rand!(mtl_rng(), MtlArray{T, length(dims) + 1, storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
- Random.randn!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
+ Random.randn!(mtl_rng(), MtlArray{T, length(dims) + 1, storage}(undef, dim1, dims...))
# untyped out-of-place
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
diff --git a/test/random.jl b/test/random.jl
index 7a8e28fb..25a48c65 100644
--- a/test/random.jl
+++ b/test/random.jl
@@ -12,7 +12,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
@testset "$f with $T" for (f, T) in INPLACE_TUPLES
# d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
- @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
+ @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000, 1000))
A = MtlArray{T}(undef, d)
# default_rng
@@ -34,7 +34,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
# default_rng
f(A)
- @test A isa MtlArray{T,1}
+ @test A isa MtlArray{T, 1}
@test Array(A) == fill(1, 0)
# specified MPS rng
@@ -116,7 +116,7 @@ end
# Test when views try to use rand!(rng, args..)
@testset "MPS.RNG with views" begin
rng = Metal.MPS.RNG()
- @testset "$f with $T" for (f, T) in ((randn!, Float32),(rand!, Int64),(rand!, Float32), (rand!, UInt16), (rand!,Int8))
+ @testset "$f with $T" for (f, T) in ((randn!, Float32), (rand!, Int64), (rand!, Float32), (rand!, UInt16), (rand!, Int8))
A = MtlArray{T}(undef, 100)
## Offset > 0
@@ -146,9 +146,9 @@ end
# out-of-place
@testset "out-of-place" begin
@testset "$fr with implicit type" for (fm, fr, T) in
- ((Metal.rand, rand, Float32), (Metal.randn, randn, Float32))
+ ((Metal.rand, rand, Float32), (Metal.randn, randn, Float32))
rng = Metal.MPS.RNG()
- @testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000,1000))
+ @testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000, 1000))
# default_rng
A = fm(args...)
@test eltype(A) == T
@@ -169,16 +169,18 @@ end
# out-of-place, with type specified
@testset "$fr with $T" for (fm, fr, T) in OOPLACE_TUPLES
rng = Metal.MPS.RNG()
- @testset "$args" for args in ((T, 0),
- (T, 1),
- (T, 3),
- (T, 3, 3),
- (T, (3, 3)),
- (T, 16),
- (T, 16, 16),
- (T, (16, 16)),
- (T, 1000),
- (T, 1000, 1000),)
+ @testset "$args" for args in (
+ (T, 0),
+ (T, 1),
+ (T, 3),
+ (T, 3, 3),
+ (T, (3, 3)),
+ (T, 16),
+ (T, 16, 16),
+ (T, (16, 16)),
+ (T, 1000),
+ (T, 1000, 1000),
+ )
# default_rng
A = fm(args...)
@test eltype(A) == T
@@ -207,7 +209,7 @@ end
rng = Metal.MPS.RNG()
@testset "$f with $T" for (f, T) in mps_tuples
# d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
- @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000), 16384, 16385)
+ @testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000, 1000), 16384, 16385)
A = zeros(T, d)
f(rng, A)
@@ -217,11 +219,13 @@ end
end
## seeding
-@testset "Seeding $L" for (f,T,L) in [(Metal.rand,UInt32,"Uniform Integers MPS"),
- (Metal.rand,Float32,"Uniform Float32 MPS"),
- (Metal.randn,Float32,"Normal Float32 MPS"),
- (Metal.randn,Float16,"Float16 Native")]
- @testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
+@testset "Seeding $L" for (f, T, L) in [
+ (Metal.rand, UInt32, "Uniform Integers MPS"),
+ (Metal.rand, Float32, "Uniform Float32 MPS"),
+ (Metal.randn, Float32, "Normal Float32 MPS"),
+ (Metal.randn, Float16, "Float16 Native"),
+ ]
+ @testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000, 1000))
Metal.seed!(1)
a = f(T, d)
Metal.seed!(1)
@@ -238,11 +242,13 @@ end
## in-place
# uniform
- for T in (Float16, Float32,
- ComplexF16, ComplexF32,
- Int8, Int16, Int32, Int64,
- UInt8, UInt16, UInt32, UInt64),
- dims = (0, 2, (2,2), (2,2,2))
+ for T in (
+ Float16, Float32,
+ ComplexF16, ComplexF32,
+ Int8, Int16, Int32, Int64,
+ UInt8, UInt16, UInt32, UInt64,
+ ),
+ dims in (0, 2, (2, 2), (2, 2, 2))
A = MtlArray{T}(undef, dims)
rand!(rng, A)
@@ -251,9 +257,11 @@ end
end
# normal
- for T in (Float16, Float32,
- ComplexF16, ComplexF32),
- dims = (0, 2, (2,2), (2,2,2))
+ for T in (
+ Float16, Float32,
+ ComplexF16, ComplexF32,
+ ),
+ dims in (0, 2, (2, 2), (2, 2, 2))
A = MtlArray{T}(undef, dims)
randn!(rng, A)
@@ -268,12 +276,14 @@ end
@test rand(rng) isa Number
@test rand(rng, Float32) isa Float32
end
- for dims in (0, 2, (2,2), (2,2,2))
+ for dims in (0, 2, (2, 2), (2, 2, 2))
@test rand(rng, dims) isa MtlArray
- for T in (Float16, Float32,
- ComplexF16, ComplexF32,
- Int8, Int16, Int32, Int64,
- UInt8, UInt16, UInt32, UInt64)
+ for T in (
+ Float16, Float32,
+ ComplexF16, ComplexF32,
+ Int8, Int16, Int32, Int64,
+ UInt8, UInt16, UInt32, UInt64,
+ )
@test rand(rng, T, dims) isa MtlArray{T}
end
end
@@ -283,10 +293,12 @@ end
@test randn(rng) isa Number
@test randn(rng, Float32) isa Float32
end
- for dims in (0, 2, (2,2), (2,2,2))
+ for dims in (0, 2, (2, 2), (2, 2, 2))
@test randn(rng, dims) isa MtlArray
- for T in (Float16, Float32,
- ComplexF16, ComplexF32)
+ for T in (
+ Float16, Float32,
+ ComplexF16, ComplexF32,
+ )
@test randn(rng, T, dims) isa MtlArray{T}
end
end
@@ -304,7 +316,7 @@ end
# bound of around 4 or 4.1, while CURAND.rand gets down to u = 2f0^(-33) giving an upper
# bound of around 4.8. In contrast, incorrectly reusing the real Box-Muller transform
# gives typical real parts in the hundreds.
- @test maximum(real(randn(rng, ComplexF32, 32))) <= sqrt(-log(2f0^(-33)))
+ @test maximum(real(randn(rng, ComplexF32, 32))) <= sqrt(-log(2.0f0^(-33)))
end
@testset "seeding idempotency" begin |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #705 +/- ##
==========================================
+ Coverage 80.96% 81.26% +0.30%
==========================================
Files 62 62
Lines 2837 2899 +62
==========================================
+ Hits 2297 2356 +59
- Misses 540 543 +3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Metal Benchmarks
| Benchmark suite | Current: 6a85440 | Previous: f1ec854 | Ratio |
|---|---|---|---|
latency/precompile |
25179033792 ns |
24360634500 ns |
1.03 |
latency/ttfp |
2302321042 ns |
2326898000 ns |
0.99 |
latency/import |
1461697042 ns |
1425799334 ns |
1.03 |
integration/metaldevrt |
844000 ns |
833666 ns |
1.01 |
integration/byval/slices=1 |
1588687 ns |
1573625.5 ns |
1.01 |
integration/byval/slices=3 |
20020959 ns |
19669333 ns |
1.02 |
integration/byval/reference |
1591979.5 ns |
1572021 ns |
1.01 |
integration/byval/slices=2 |
2682458 ns |
2716187 ns |
0.99 |
kernel/indexing |
514271 ns |
475833 ns |
1.08 |
kernel/indexing_checked |
504458 ns |
484125 ns |
1.04 |
kernel/launch |
12541 ns |
12416 ns |
1.01 |
kernel/rand |
528396 ns |
528000 ns |
1.00 |
array/construct |
6333 ns |
6291 ns |
1.01 |
array/broadcast |
544834 ns |
553000 ns |
0.99 |
array/random/randn/Float32 |
917375 ns |
913687 ns |
1.00 |
array/random/randn!/Float32 |
584458 ns |
580854 ns |
1.01 |
array/random/rand!/Int64 |
539542 ns |
542292 ns |
0.99 |
array/random/rand!/Float32 |
536875 ns |
541291.5 ns |
0.99 |
array/random/rand/Int64 |
903708 ns |
928667 ns |
0.97 |
array/random/rand/Float32 |
849334 ns |
833937.5 ns |
1.02 |
array/accumulate/Int64/1d |
1313834 ns |
1303916 ns |
1.01 |
array/accumulate/Int64/dims=1 |
1875042 ns |
1850791.5 ns |
1.01 |
array/accumulate/Int64/dims=2 |
2238417 ns |
2214833 ns |
1.01 |
array/accumulate/Int64/dims=1L |
12217500 ns |
12158375 ns |
1.00 |
array/accumulate/Int64/dims=2L |
9662125 ns |
9771875 ns |
0.99 |
array/accumulate/Float32/1d |
1085083 ns |
1077312 ns |
1.01 |
array/accumulate/Float32/dims=1 |
1616500 ns |
1581250 ns |
1.02 |
array/accumulate/Float32/dims=2 |
2008645.5 ns |
1946979 ns |
1.03 |
array/accumulate/Float32/dims=1L |
10442166.5 ns |
10346124.5 ns |
1.01 |
array/accumulate/Float32/dims=2L |
7297146 ns |
7465979.5 ns |
0.98 |
array/reductions/reduce/Int64/1d |
1301500 ns |
1286833 ns |
1.01 |
array/reductions/reduce/Int64/dims=1 |
1100854 ns |
1117167 ns |
0.99 |
array/reductions/reduce/Int64/dims=2 |
1146854 ns |
1162416.5 ns |
0.99 |
array/reductions/reduce/Int64/dims=1L |
2041833 ns |
2032312.5 ns |
1.00 |
array/reductions/reduce/Int64/dims=2L |
3898917 ns |
3847000.5 ns |
1.01 |
array/reductions/reduce/Float32/1d |
745791 ns |
746228.5 ns |
1.00 |
array/reductions/reduce/Float32/dims=1 |
799020.5 ns |
799249.5 ns |
1.00 |
array/reductions/reduce/Float32/dims=2 |
841541.5 ns |
838042 ns |
1.00 |
array/reductions/reduce/Float32/dims=1L |
1324208.5 ns |
1338500 ns |
0.99 |
array/reductions/reduce/Float32/dims=2L |
1823875 ns |
1796542 ns |
1.02 |
array/reductions/mapreduce/Int64/1d |
1324833.5 ns |
1311833 ns |
1.01 |
array/reductions/mapreduce/Int64/dims=1 |
1110416.5 ns |
1109625 ns |
1.00 |
array/reductions/mapreduce/Int64/dims=2 |
1158459 ns |
1147916 ns |
1.01 |
array/reductions/mapreduce/Int64/dims=1L |
2013750 ns |
2003166.5 ns |
1.01 |
array/reductions/mapreduce/Int64/dims=2L |
3605499.5 ns |
3590937.5 ns |
1.00 |
array/reductions/mapreduce/Float32/1d |
837083 ns |
807333 ns |
1.04 |
array/reductions/mapreduce/Float32/dims=1 |
800916 ns |
802167 ns |
1.00 |
array/reductions/mapreduce/Float32/dims=2 |
826583 ns |
823146 ns |
1.00 |
array/reductions/mapreduce/Float32/dims=1L |
1338375 ns |
1348667 ns |
0.99 |
array/reductions/mapreduce/Float32/dims=2L |
1812750 ns |
1809125 ns |
1.00 |
array/private/copyto!/gpu_to_gpu |
551666 ns |
526729.5 ns |
1.05 |
array/private/copyto!/cpu_to_gpu |
748166.5 ns |
756229 ns |
0.99 |
array/private/copyto!/gpu_to_cpu |
740583.5 ns |
757041.5 ns |
0.98 |
array/private/iteration/findall/int |
1576417 ns |
1574021 ns |
1.00 |
array/private/iteration/findall/bool |
1481416 ns |
1469833 ns |
1.01 |
array/private/iteration/findfirst/int |
2108958 ns |
2086312 ns |
1.01 |
array/private/iteration/findfirst/bool |
2028791.5 ns |
2021375 ns |
1.00 |
array/private/iteration/scalar |
3555167 ns |
3490125 ns |
1.02 |
array/private/iteration/logical |
2670833 ns |
2683791.5 ns |
1.00 |
array/private/iteration/findmin/1d |
2279187.5 ns |
2264708.5 ns |
1.01 |
array/private/iteration/findmin/2d |
1549354 ns |
1546667 ns |
1.00 |
array/private/copy |
865958.5 ns |
896333 ns |
0.97 |
array/shared/copyto!/gpu_to_gpu |
84166 ns |
84459 ns |
1.00 |
array/shared/copyto!/cpu_to_gpu |
83333 ns |
83458 ns |
1.00 |
array/shared/copyto!/gpu_to_cpu |
83083 ns |
83583 ns |
0.99 |
array/shared/iteration/findall/int |
1574604.5 ns |
1574563 ns |
1.00 |
array/shared/iteration/findall/bool |
1493020.5 ns |
1481104.5 ns |
1.01 |
array/shared/iteration/findfirst/int |
1707500 ns |
1691520.5 ns |
1.01 |
array/shared/iteration/findfirst/bool |
1637209 ns |
1635583 ns |
1.00 |
array/shared/iteration/scalar |
203875 ns |
206458 ns |
0.99 |
array/shared/iteration/logical |
2413250 ns |
2265000 ns |
1.07 |
array/shared/iteration/findmin/1d |
1887916.5 ns |
1896625 ns |
1.00 |
array/shared/iteration/findmin/2d |
1546458 ns |
1542708 ns |
1.00 |
array/shared/copy |
215209 ns |
213500 ns |
1.01 |
array/permutedims/4d |
2492875 ns |
2462125 ns |
1.01 |
array/permutedims/2d |
1169959 ns |
1182146 ns |
0.99 |
array/permutedims/3d |
1795958.5 ns |
1792209 ns |
1.00 |
metal/synchronization/stream |
19417 ns |
19750 ns |
0.98 |
metal/synchronization/context |
20125 ns |
20083 ns |
1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
b66110f to
4a5e04d
Compare
Ported from CUDA.jl
Close #691