diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl
index 37b8f852..84725512 100644
--- a/src/gpuarrays.jl
+++ b/src/gpuarrays.jl
@@ -2,6 +2,40 @@
 
 GPUArrays.device(x::MtlArray) = x.dev
 
+import KernelAbstractions
+import KernelAbstractions: Backend
+
+# Launch-configuration heuristic GPUArrays uses for kernels on the Metal backend.
+#
+# Builds the kernel `obj.f` for `args` via `@metal launch=false`, then sizes the launch
+# from the resulting pipeline's `maxTotalThreadsPerThreadgroup`.
+#
+# Keywords:
+# - `elements`: total number of elements to process.
+# - `elements_per_thread`: elements handled per thread; only used to size the range
+#   passed to `KernelAbstractions.launch_config`.
+#
+# Returns the named tuple `(; threads, blocks)`, both converted to `Int`.
+@inline function GPUArrays.launch_heuristic(::MetalBackend, obj::O, args::Vararg{Any,N};
+                                             elements::Int,
+                                             elements_per_thread::Int) where {O,N}
+    # Exact integer ceiling division; no Float64 round-trip as with `ceil(Int, a / b)`.
+    nitems = cld(elements, elements_per_thread)
+    ndrange, _, iterspace, _ = KernelAbstractions.launch_config(obj, nitems, nothing)
+
+    ctx = KernelAbstractions.mkcontext(obj, ndrange, iterspace)
+
+    kernel = @metal launch=false obj.f(ctx, args...)
+
+    # The pipeline state automatically computes occupancy stats. Keep at least one thread
+    # so that `elements == 0` yields `blocks == 0` instead of a `DivideError` in `cld`.
+    max_threads = kernel.pipeline.maxTotalThreadsPerThreadgroup
+    threads = max(min(elements, max_threads), 1)
+    blocks = cld(elements, threads)
+
+    return (; threads=Int(threads), blocks=Int(blocks))
+end
+
 const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
 function GPUArrays.default_rng(::Type{<:MtlArray})
     dev = device()