Skip to content

Commit

Permalink
Fix issues with in-place map/broadcast (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jan 15, 2025
1 parent a57e030 commit 3d774aa
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
74 changes: 37 additions & 37 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,25 @@ end
return SparseArrayDOK{T}(size...)
end

# map over a specified subset of indices of the inputs.
function map_indices! end

@interface interface::AbstractArrayInterface function map_indices!(
indices, f, a_dest::AbstractArray, as::AbstractArray...
)
for I in indices
a_dest[I] = f(map(a -> a[I], as)...)
end
return a_dest
end

# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in eachstoredindex(as...)
a_dest[I] = f(map(a -> a[I], as)...)
end
@interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...)
return a_dest
end

Expand All @@ -221,9 +231,7 @@ function map_all! end
@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in eachindex(as...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
@interface interface map_indices!(eachindex(as...), f, a_dest, as...)
return a_dest
end

Expand All @@ -242,38 +250,32 @@ using ArrayLayouts: ArrayLayouts, zero!
return @interface interface map_stored!(f, a, a)
end

# Determines if a function preserves the stored values
# of the destination sparse array.
# The current code may be inefficient since it actually
# accesses an unstored element, which in the case of a
# sparse array of arrays can allocate an array.
# Sparse arrays could be expected to define a cheap
# unstored element allocator, for example
# `get_prototypical_unstored(a::AbstractArray)`.
function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
I = first(eachindex(as...))
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
end

@interface interface::AbstractSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
# TODO: Define a function `preserves_unstored(a_dest, f, as...)`
# to determine if a function preserves the stored values
# of the destination sparse array.
# The current code may be inefficient since it actually
# accesses an unstored element, which in the case of a
# sparse array of arrays can allocate an array.
# Sparse arrays could be expected to define a cheap
# unstored element allocator, for example
# `get_prototypical_unstored(a::AbstractArray)`.
I = first(eachindex(as...))
preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...))
if !preserves_unstored
# Doesn't preserve unstored values, loop over all elements.
@interface interface map_all!(f, a_dest, as...)
return a_dest
indices = if !preserves_unstored(f, a_dest, as...)
eachindex(a_dest)
elseif any(a -> a_dest !== a, as)
as = map(a -> Base.unalias(a_dest, a), as)
@interface interface zero!(a_dest)
eachstoredindex(as...)
else
eachstoredindex(a_dest)
end
# First zero out the destination.
# TODO: Make this more nuanced, skip when possible, for
# example if the sparsity of the destination is a subset of
# the sparsity of the sources, i.e.:
# ```julia
# if eachstoredindex(as...) ∉ eachstoredindex(a_dest)
# zero!(a_dest)
# end
# ```
# This is the safest thing to do in general, for example
# if the destination is dense but the sources are sparse.
@interface interface zero!(a_dest)
@interface interface map_stored!(f, a_dest, as...)
@interface interface map_indices!(indices, f, a_dest, as...)
return a_dest
end

Expand Down Expand Up @@ -357,9 +359,7 @@ function sparse_mul!(
β::Number=false;
(mul!!)=(default_mul!!),
)
# TODO: Change to: `a_dest .*= β`
# once https://github.com/ITensor/SparseArraysBase.jl/issues/19 is fixed.
storedvalues(a_dest) .*= β
a_dest .*= β
β′ = one(Bool)
for I1 in eachstoredindex(a1)
for I2 in eachstoredindex(a2)
Expand Down
38 changes: 38 additions & 0 deletions test/basics/test_sparsearraydok.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,42 @@ arrayts = (Array,)
a[1, 2] = 12
@test sprint(show, "text/plain", a) == "$(summary(a)):\n$(eltype(a)(12))\n ⋅ ⋅"
end

# Regression test for:
# https://github.com/ITensor/SparseArraysBase.jl/issues/19
a = SparseArrayDOK{elt}(2, 2)
a[1, 1] = 1
a .*= 2
@test a == [2 0; 0 0]
@test storedlength(a) == 1

# Test aliasing behavior.
a = SparseArrayDOK{elt}(2, 2)
a[1, 1] = 11
a[1, 2] = 12
a[2, 2] = 22
c1 = @view a[:, 1]
r1 = @view a[1, :]
r1 .= c1
@test c1 == [11, 0]
@test storedlength(c1) == 1
@test r1 == [11, 0]
@test storedlength(r1) == 2
@test a == [11 0; 0 22]
@test storedlength(a) == 3

# Test aliasing behavior.
a = SparseArrayDOK{elt}(2, 2)
a[1, 1] = 11
a[1, 2] = 12
a[2, 2] = 22
c1 = @view a[:, 1]
r1 = @view a[1, :]
c1 .= r1
@test c1 == [11, 12]
@test storedlength(c1) == 2
@test r1 == [11, 12]
@test storedlength(r1) == 2
@test a == [11 12; 12 22]
@test storedlength(a) == 4
end

0 comments on commit 3d774aa

Please sign in to comment.