Skip to content

Commit

Permalink
fix ad tests
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Jan 4, 2025
1 parent 9ed3264 commit bbad0fb
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lib/YaoBlocks/src/autodiff/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module AD
using BitBasis, YaoArrayRegister, YaoAPI
using ..YaoBlocks
import ChainRulesCore:
rrule, @non_differentiable, NoTangent, Tangent, backing, AbstractTangent, ZeroTangent
rrule, @non_differentiable, NoTangent, Tangent, backing, AbstractTangent, ZeroTangent, AbstractThunk, unthunk
import YaoAPI: mat_back!, apply_back!
using SparseArrays, LuxurySparse, LinearAlgebra

Expand Down
5 changes: 3 additions & 2 deletions lib/YaoBlocks/src/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ unsafe_primitive_tangent(x::Number) = x
for GT in [:RotationGate, :ShiftGate, :PhaseGate, :(Scale{<:Number})]
@eval function recursive_create_tangent(c::$GT)
lst = map(fieldnames(typeof(c))) do fn
fn => unsafe_primitive_tangent(getfield(c, fn))
fn => unsafe_primitive_tangent(unthunk(getfield(c, fn)))
end
nt = NamedTuple(lst)
Tangent{typeof(c),typeof(nt)}(nt)
Expand Down Expand Up @@ -46,7 +46,7 @@ for GT in [
]
@eval function recursive_create_tangent(c::$GT)
lst = map(fieldnames(typeof(c))) do fn
fn => unsafe_composite_tangent(getfield(c, fn))
fn => unsafe_composite_tangent(unthunk(getfield(c, fn)))
end
nt = NamedTuple(lst)
Tangent{typeof(c),typeof(nt)}(nt)
Expand Down Expand Up @@ -209,6 +209,7 @@ rrule(::typeof(parent), reg::AdjointArrayReg) = parent(reg), adjy -> (NoTangent(
rrule(::typeof(Base.adjoint), reg::AbstractArrayReg) =
Base.adjoint(reg), adjy -> (NoTangent(), parent(adjy))

_totype(::Type{T}, x::AbstractThunk) where {T} = _totype(T, unthunk(x))
_totype(::Type{T}, x::AbstractArray{T}) where {T} = x
_totype(::Type{T}, x::AbstractArray{T2}) where {T,T2} = convert.(T, x)
_match_type(::ArrayReg{D}, mat) where D = ArrayReg{D}(mat)
Expand Down
15 changes: 6 additions & 9 deletions lib/YaoBlocks/test/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Zygote, ForwardDiff
using Random, Test
using YaoBlocks, YaoArrayRegister
using ChainRulesCore: Tangent
using ChainRulesCore: Tangent, unthunk, AbstractThunk

@testset "recursive_create_tangent" begin
c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
Expand All @@ -10,16 +10,13 @@ using ChainRulesCore: Tangent
end

@testset "construtors" begin
@test Zygote.gradient(x -> x.list[1].blocks[1].theta, sum([chain(1, Rz(0.3))]))[1] == (n=nothing,
list = NamedTuple{
(:n, :blocks,),
Tuple{Nothing, Vector{NamedTuple{(:block, :theta),Tuple{Nothing,Float64}}}},
}[(n=nothing, blocks = [(block = nothing, theta = 1.0)],)],
)
@test Zygote.gradient(
res = Zygote.gradient(x -> x.list[1].blocks[1].theta, sum([chain(1, Rz(0.3))]))[1]
@test res.list[].blocks[1].theta 1.0
res = Zygote.gradient(
x -> getfield(getfield(x, :content), :theta),
Daggered(Rx(0.5)),
)[1] == (content = (block = nothing, theta = 1.0),)
)[1]
@test res.content.theta 1.0
end

@testset "rules" begin
Expand Down

0 comments on commit bbad0fb

Please sign in to comment.