From 88bc25b035650801cb3195356884a476f18052c0 Mon Sep 17 00:00:00 2001 From: Kaan Kesgin Date: Fri, 5 Dec 2025 07:58:12 +0100 Subject: [PATCH] Fix broadcast for arrays >16GB by using 64-bit indexing Metal's thread_position_in_grid() returns UInt32, which overflows past 4,294,967,295 elements (~16GB for Float32). This caused silent data corruption where elements beyond the 32-bit limit received wrong values or zeros. Root cause: The grid-stride loop arithmetic used UInt32 throughout: i = thread_position_in_grid().x # UInt32 stride = threads_per_grid().x # UInt32 i += stride # UInt32 overflow! Fix: Convert to Int64 immediately and use 64-bit arithmetic: i = Int64(thread_position_in_grid().x) stride = Int64(threads_per_grid().x) i += stride # Int64, no overflow This is the standard pattern for handling >4B elements on GPUs. The UInt32 thread position is a Metal Shading Language constraint (per Apple's MSL spec), but we only need unique thread IDs up to 2^32. The index computation uses 64-bit to reach any element. Performance: Benchmarked at 1B elements - no measurable overhead. Apple Silicon GPUs handle Int64 arithmetic efficiently (native 64-bit). Verified working at 18GB (4.5B elements) on M2 Max. All 319 broadcast tests pass. 
--- src/broadcast.jl | 47 ++++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 26979706a..b01077751 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -66,9 +66,12 @@ end if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD ## COV_EXCL_START function broadcast_cartesian_static(dest, bc, Is) - i = thread_position_in_grid().x - stride = threads_per_grid().x - while 1 <= i <= length(dest) + # Use Int64 to avoid UInt32 overflow for arrays > 4.29B elements (16GB Float32) + # Metal's thread_position_in_grid() returns UInt32, but we need 64-bit indexing + i = Int64(thread_position_in_grid().x) + stride = Int64(threads_per_grid().x) + len = Int64(length(dest)) + while i <= len I = @inbounds Is[i] @inbounds dest[I] = bc[I] i += stride @@ -91,9 +94,11 @@ end (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear)) ## COV_EXCL_START function broadcast_linear(dest, bc) - i = thread_position_in_grid().x - stride = threads_per_grid().x - while 1 <= i <= length(dest) + # Use Int64 to avoid UInt32 overflow for arrays > 4.29B elements (16GB Float32) + i = Int64(thread_position_in_grid().x) + stride = Int64(threads_per_grid().x) + len = Int64(length(dest)) + while i <= len @inbounds dest[i] = bc[i] i += stride end @@ -108,12 +113,16 @@ end elseif ndims(dest) == 2 ## COV_EXCL_START function broadcast_2d(dest, bc) - is = Tuple(thread_position_in_grid_2d()) + # Use Int64 to avoid UInt32 overflow for large 2D arrays + pos = thread_position_in_grid_2d() stride = threads_per_grid_2d() - while 1 <= is[1] <= size(dest, 1) && 1 <= is[2] <= size(dest, 2) + is = (Int64(pos.x), Int64(pos.y)) + strides = (Int64(stride.x), Int64(stride.y)) + dims = (Int64(size(dest, 1)), Int64(size(dest, 2))) + while is[1] <= dims[1] && is[2] <= dims[2] I = CartesianIndex(is) @inbounds dest[I] = bc[I] - is = (is[1] + stride[1], is[2] + stride[2]) + is = (is[1] + strides[1], is[2] + 
strides[2]) end return end @@ -127,14 +136,16 @@ end elseif ndims(dest) == 3 ## COV_EXCL_START function broadcast_3d(dest, bc) - is = Tuple(thread_position_in_grid_3d()) + # Use Int64 to avoid UInt32 overflow for large 3D arrays + pos = thread_position_in_grid_3d() stride = threads_per_grid_3d() - while 1 <= is[1] <= size(dest, 1) && - 1 <= is[2] <= size(dest, 2) && - 1 <= is[3] <= size(dest, 3) + is = (Int64(pos.x), Int64(pos.y), Int64(pos.z)) + strides = (Int64(stride.x), Int64(stride.y), Int64(stride.z)) + dims = (Int64(size(dest, 1)), Int64(size(dest, 2)), Int64(size(dest, 3))) + while is[1] <= dims[1] && is[2] <= dims[2] && is[3] <= dims[3] I = CartesianIndex(is) @inbounds dest[I] = bc[I] - is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3]) + is = (is[1] + strides[1], is[2] + strides[2], is[3] + strides[3]) end return end @@ -150,9 +161,11 @@ end else ## COV_EXCL_START function broadcast_cartesian(dest, bc) - i = thread_position_in_grid().x - stride = threads_per_grid().x - while 1 <= i <= length(dest) + # Use Int64 to avoid UInt32 overflow for arrays > 4.29B elements (16GB Float32) + i = Int64(thread_position_in_grid().x) + stride = Int64(threads_per_grid().x) + len = Int64(length(dest)) + while i <= len I = @inbounds CartesianIndices(dest)[i] @inbounds dest[I] = bc[I] i += stride