diff --git a/Project.toml b/Project.toml
index ef0a59e..d059e87 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,10 +14,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 
 [extensions]
 NNopAMDGPUExt = "AMDGPU"
 NNopCUDAExt = "CUDA"
+NNopMetalExt = "Metal"
 
 [compat]
 AMDGPU = "1.2.5, 2"
diff --git a/ext/NNopMetalExt.jl b/ext/NNopMetalExt.jl
new file mode 100644
index 0000000..28ca1e7
--- /dev/null
+++ b/ext/NNopMetalExt.jl
@@ -0,0 +1,17 @@
+module NNopMetalExt
+
+using Metal
+using NNop
+
+"""
+    NNop._shared_memory(::MetalBackend, device_id::Integer)
+
+Return the `maxThreadgroupMemoryLength` property of the Metal device stored at
+index `device_id` of `Metal.devices()`, converted to `UInt64`.
+"""
+function NNop._shared_memory(::MetalBackend, device_id::Integer)
+    threadgroup_limit = Metal.devices()[device_id].maxThreadgroupMemoryLength
+    return UInt64(threadgroup_limit)
+end
+
+end