Skip to content

Commit

Permalink
Revert "removing heuristic"
Browse files Browse the repository at this point in the history
This reverts commit 9a7a84a.
  • Loading branch information
leios committed Sep 16, 2024
1 parent e8ffa59 commit cfe7eac
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/gpuarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,27 @@

# GPUArrays device query for Metal arrays: hand back the array's `dev` property.
function GPUArrays.device(x::MtlArray)
    return x.dev
end

import KernelAbstractions
import KernelAbstractions: Backend

"""
    GPUArrays.launch_heuristic(::MetalBackend, obj, args...;
                               elements, elements_per_thread)

Suggest a launch configuration for running the kernel `obj` with `args` over
`elements` items.

The kernel object is built with `@metal launch=false`, and the thread count is
capped by the `maxTotalThreadsPerThreadgroup` value of its pipeline state.

# Keywords
- `elements::Int`: total number of elements to process.
- `elements_per_thread::Int`: number of elements handled by each thread. It only
  sizes the ndrange used to build the kernel context; `threads` and `blocks` are
  derived from `elements`.

# Returns
A named tuple `(; threads, blocks)` of `Int`s. Both values are at least 1, so an
empty workload (`elements == 0`) yields `(threads = 1, blocks = 1)` instead of
raising a `DivideError`.
"""
@inline function GPUArrays.launch_heuristic(::MetalBackend, obj::O, args::Vararg{Any,N};
                                            elements::Int,
                                            elements_per_thread::Int) where {O,N}
    ndrange = ceil(Int, elements / elements_per_thread)
    # Only the normalized ndrange and the iteration space are needed here.
    ndrange, _, iterspace, _ = KA.launch_config(obj, ndrange, nothing)

    ctx = KA.mkcontext(obj, ndrange, iterspace)

    kernel = @metal launch=false obj.f(ctx, args...)

    # The pipeline state automatically computes occupancy stats
    maxthreads = Int(kernel.pipeline.maxTotalThreadsPerThreadgroup)

    # Keep `threads` >= 1: with `elements == 0` a plain `min` would give 0 and the
    # division below would throw.
    threads = clamp(elements, 1, maxthreads)
    blocks = max(cld(elements, threads), 1)

    return (; threads=Int(threads), blocks=Int(blocks))
end

# Module-level table mapping each `MTLDevice` to a `GPUArrays.RNG`.
# NOTE(review): presumably populated on demand by `GPUArrays.default_rng` below —
# its body is not fully visible here; confirm.
const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
function GPUArrays.default_rng(::Type{<:MtlArray})
dev = device()
Expand Down

0 comments on commit cfe7eac

Please sign in to comment.