Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/MultiBroadcastFusionCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import CUDA, Adapt
import MultiBroadcastFusion as MBF
import MultiBroadcastFusion: fused_copyto!

MBF.device(x::CUDA.CuArray) = MBF.GPU()
MBF.device(x::CUDA.CuArray) = MBF.MBF_CUDA()

function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.GPU)
function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
(; pairs) = fmb
dest = first(pairs).first
nitems = length(parent(dest))
Expand Down
10 changes: 5 additions & 5 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
@make_fused fused_direct FusedMultiBroadcast fused_direct
@make_fused fused_assemble FusedMultiBroadcast fused_assemble

struct CPU end
struct GPU end
device(x::AbstractArray) = CPU()
struct MBF_CPU end
struct MBF_CUDA end
device(x::AbstractArray) = MBF_CPU()

function Base.copyto!(fmb::FusedMultiBroadcast)
pairs = fmb.pairs # (Pair(dest1, bc1),Pair(dest2, bc2),...)
Expand All @@ -26,7 +26,7 @@ Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, i...) =
@inline rcopyto_at!(pairs::Tuple{}, i...) = nothing

# This is better than the baseline.
function fused_copyto!(fmb::FusedMultiBroadcast, ::CPU)
function fused_copyto!(fmb::FusedMultiBroadcast, ::MBF_CPU)
(; pairs) = fmb
destinations = map(x -> x.first, pairs)
ei = if eltype(destinations) <: Vector
Expand All @@ -44,7 +44,7 @@ end

# This should, in theory be better, but it seems like inlining is
# failing somewhere.
# function fused_copyto!(fmb::FusedMultiBroadcast, ::CPU)
# function fused_copyto!(fmb::FusedMultiBroadcast, ::MBF_CPU)
# (; pairs) = fmb
# destinations = map(x -> x.first, pairs)
# ei = if eltype(destinations) <: Vector
Expand Down
4 changes: 2 additions & 2 deletions test/execution/bm_fused_reads_vs_hard_coded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import MultiBroadcastFusion as MBF
# =========================================== hard-coded implementations
perf_kernel_hard_coded!(X, Y) = perf_kernel_hard_coded!(X, Y, MBF.device(X.x1))

function perf_kernel_hard_coded!(X, Y, ::MBF.CPU)
function perf_kernel_hard_coded!(X, Y, ::MBF.MBF_CPU)
(; x1, x2, x3, x4) = X
(; y1, y2, y3, y4) = Y
@inbounds for i in eachindex(x1)
Expand All @@ -24,7 +24,7 @@ end
@static get(ENV, "USE_CUDA", nothing) == "true" && using CUDA
use_cuda = @isdefined(CUDA) && CUDA.has_cuda() # will be true if you first run `using CUDA`
@static if use_cuda
function perf_kernel_hard_coded!(X, Y, ::MBF.GPU)
function perf_kernel_hard_coded!(X, Y, ::MBF.MBF_CUDA)
x1 = X.x1
nitems = length(parent(x1))
max_threads = 256 # can be higher if conditions permit
Expand Down