Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ end
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
## COV_EXCL_START
function broadcast_cartesian_static(dest, bc, Is)
i = thread_position_in_grid().x
stride = threads_per_grid().x
while 1 <= i <= length(dest)
# Use Int64 to avoid UInt32 overflow for arrays > 4.29B elements (16GB Float32)
# Metal's thread_position_in_grid() returns UInt32, but we need 64-bit indexing
i = Int64(thread_position_in_grid().x)
stride = Int64(threads_per_grid().x)
len = Int64(length(dest))
while i <= len
I = @inbounds Is[i]
@inbounds dest[I] = bc[I]
i += stride
Expand All @@ -91,9 +94,11 @@ end
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
## COV_EXCL_START
function broadcast_linear(dest, bc)
i = thread_position_in_grid().x
stride = threads_per_grid().x
while 1 <= i <= length(dest)
# Use Int64 to avoid UInt32 overflow for arrays > 4.29B elements (16GB Float32)
i = Int64(thread_position_in_grid().x)
stride = Int64(threads_per_grid().x)
len = Int64(length(dest))
while i <= len
@inbounds dest[i] = bc[i]
i += stride
end
Expand All @@ -108,12 +113,16 @@ end
elseif ndims(dest) == 2
## COV_EXCL_START
function broadcast_2d(dest, bc)
    # Kernel body for the two-dimensional case: each thread starts at its own
    # (x, y) grid position and copies `bc[I]` into `dest[I]`.
    #
    # The position, grid extent and array extents are widened to Int64 up
    # front so the index arithmetic below cannot overflow UInt32 for large
    # 2D arrays.
    pos = thread_position_in_grid_2d()
    stride = threads_per_grid_2d()
    is = (Int64(pos.x), Int64(pos.y))
    strides = (Int64(stride.x), Int64(stride.y))
    dims = (Int64(size(dest, 1)), Int64(size(dest, 2)))

    # NOTE(review): both coordinates advance together on every iteration, so
    # this loop presumably relies on the launch grid spanning the whole array
    # (the body then runs at most once per thread) -- confirm at the launch site.
    while is[1] <= dims[1] && is[2] <= dims[2]
        I = CartesianIndex(is)
        # Bounds were checked by the loop condition above.
        @inbounds dest[I] = bc[I]
        is = (is[1] + strides[1], is[2] + strides[2])
    end
    return
end
Expand All @@ -127,14 +136,16 @@ end
elseif ndims(dest) == 3
## COV_EXCL_START
function broadcast_3d(dest, bc)
    # Kernel body for the three-dimensional case: each thread starts at its
    # own (x, y, z) grid position and copies `bc[I]` into `dest[I]`.
    #
    # The position, grid extent and array extents are widened to Int64 up
    # front so the index arithmetic below cannot overflow UInt32 for large
    # 3D arrays.
    pos = thread_position_in_grid_3d()
    stride = threads_per_grid_3d()
    is = (Int64(pos.x), Int64(pos.y), Int64(pos.z))
    strides = (Int64(stride.x), Int64(stride.y), Int64(stride.z))
    dims = (Int64(size(dest, 1)), Int64(size(dest, 2)), Int64(size(dest, 3)))

    # NOTE(review): all three coordinates advance together on every iteration,
    # so this loop presumably relies on the launch grid spanning the whole
    # array (the body then runs at most once per thread) -- confirm at the
    # launch site.
    while is[1] <= dims[1] && is[2] <= dims[2] && is[3] <= dims[3]
        I = CartesianIndex(is)
        # Bounds were checked by the loop condition above.
        @inbounds dest[I] = bc[I]
        is = (is[1] + strides[1], is[2] + strides[2], is[3] + strides[3])
    end
    return
end
Expand All @@ -150,9 +161,11 @@ end
else
## COV_EXCL_START
function broadcast_cartesian(dest, bc)
i = thread_position_in_grid().x
stride = threads_per_grid().x
while 1 <= i <= length(dest)
# Use Int64 to avoid UInt32 overflow for arrays > 4.29B elements (16GB Float32)
i = Int64(thread_position_in_grid().x)
stride = Int64(threads_per_grid().x)
len = Int64(length(dest))
while i <= len
I = @inbounds CartesianIndices(dest)[i]
@inbounds dest[I] = bc[I]
i += stride
Expand Down