diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 417b90d4..616f0ea2 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -2,6 +2,7 @@ module Optimisers
 
 using Functors: functor, fmap, isleaf
 using LinearAlgebra
+using Base.Broadcast: broadcast_preserving_zero_d, broadcasted
 
 include("interface.jl")
 
diff --git a/src/interface.jl b/src/interface.jl
index 235c2e94..d4b0b542 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -21,7 +21,13 @@ function setup(rule, x; seen = Base.IdSet())
   end
 end
 
-subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
+function subtract!(x, x̄)
+  if iswriteable(x)
+    x .= x .- x̄
+  else
+    broadcast_preserving_zero_d(eltype(x), broadcasted(-, x, x̄))
+  end
+end
 
 update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
 update!(::Nothing, x, x̄s...) = nothing, x
diff --git a/test/rules.jl b/test/rules.jl
index ffb4ca65..63aef5bb 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -226,3 +226,31 @@ end
     @test static_loss(static_model) < 1.9
   end
 end
+
+@testset "zero-dim arrays" begin
+  empty!(LOG)
+  @testset "$(name(o))" for o in RULES
+    m = (; arr = fill(1.0), pda = PermutedDimsArray(fill(1.0), ()), ref = Ref(1.0))
+    # The point of PermutedDimsArray here is to test the out-of-place path, so check:
+    @test Optimisers.iswriteable(m.arr)
+    @test !Optimisers.iswriteable(m.pda)
+    s = Optimisers.setup(o, m)
+    for _ in 1:10^3
+      g = loggradient(o)(x -> abs2(first(x.arr) + first(x.pda) + first(x.ref)), m)[1]
+      s, m = Optimisers.update(s, m, g)
+    end
+    # Goal is to check that broadcasting does not accidentally make a scalar,
+    # but `m.arr` is copied & mutated, so only `m.pda` is a real test:
+    @test m.arr isa Array{Float64, 0}
+    @test m.pda isa AbstractArray{Float64, 0}
+    @test m.ref isa Ref  # because it's mutated; broadcast_preserving_zero_d would make an array
+    if o isa RADAM
+      @test sum(m.arr) < 0.7
+      @test_broken sum(m.arr) < 0.3
+    else
+      @test sum(m.arr) < 0.3
+      @test sum(m.pda) < 0.3
+    end
+    @test only(m.ref) ≈ 1  # not currently regarded as trainable
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d47bce08..e01ca411 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -80,7 +80,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   end
 
   @testset "trainable subset" begin
-    @info "ignore these warnings about trainable, testing the old path"
+    @info "ignore these warnings about `trainable`, they are testing the path for old-style methods"
     # Foo has an old-style tuple trainable, both elements
     mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5))
     sf = Optimisers.setup(Descent(0.1), mf)
@@ -131,6 +131,30 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
     @test eltype(m4[2]) == Float32
   end
 
+  @testset "zero dimension" begin
+    # Mutable Array{T,0}
+    m = fill(1.0)
+    s = Optimisers.setup(Descent(0.1), m)
+    s2, m2 = Optimisers.update!(s, m, fill(2.0))
+    @test m2 === m
+    @test only(m) ≈ 0.8
+
+    # "Immutable" zero-dim array, takes the out-of-place path:
+    m3 = PermutedDimsArray(fill(1.0), ())
+    @test !Optimisers.iswriteable(m3)  # note that there's Base.iswritable, it seems I can't spell
+    s3 = Optimisers.setup(Descent(0.1), m3)
+    s4, m4 = Optimisers.update!(s3, m3, fill(2.0))
+    @test m4 !== m3
+    @test only(m4) ≈ 0.8
+
+    # Ref: should this be regarded as holding a parameter? At present it's not:
+    m5 = Ref(1.0)
+    s5 = Optimisers.setup(Descent(0.1), m5)
+    g5 = gradient(m -> m[]^2, m5)[1]  # (x = 2.0,)
+    s6, m6 = Optimisers.update!(s5, m5, g5)
+    @test m6[] ≈ 1
+  end
+
   @testset "forgotten gradient" begin
     x = [1.0, 2.0]
     sx = Optimisers.setup(Descent(), x)
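
For reference (not part of the patch), a quick REPL sketch of why the out-of-place path uses broadcast_preserving_zero_d: ordinary dot-broadcasting over 0-dimensional arrays collapses the result to a plain scalar, which is exactly what the new subtract! avoids. Output shown is indicative and may vary slightly between Julia versions.

julia> using Base.Broadcast: broadcast_preserving_zero_d, broadcasted

julia> x = fill(1.0); x̄ = fill(0.1);   # both 0-dimensional Array{Float64,0}

julia> x .- x̄                          # plain broadcast drops the 0-dim container
0.9

julia> broadcast_preserving_zero_d(eltype(x), broadcasted(-, x, x̄))
0-dimensional Array{Float64, 0}:
0.9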