Skip to content

Commit

Permalink
fix GatedGraphConv
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Sep 13, 2022
1 parent f1e442e commit 1c41b77
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
14 changes: 10 additions & 4 deletions src/layers/graph_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,17 +556,23 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real
H = vcat(H, Hpad)
end
for i = 1:l.num_layers
M = _matmul(view(l.weight, :, :, i), H)
M = _matmul(selectdim(l.weight, 3, i), H)
_, M = propagate(l, el, nothing, M, nothing, l.aggr, nothing, nothing)
H, _ = l.gru(H, M)
H = apply_gru(l.gru, H, M)
end
return H
end

function apply_gru(gru, H::AbstractArray, M::AbstractArray)
H′ = apply_gru(gru, reshape(H, size(H, 1), :), reshape(M, size(M, 1), :))
return reshape(H′, size(H′, 1), size(H)[2:end]...)
end

apply_gru(gru, H::AbstractMatrix, M::AbstractMatrix) = gru(H, M)[1]

function Base.show(io::IO, l::GatedGraphConv)
print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)")
print(io, ", aggr=", l.aggr)
print(io, ")")
print(io, ", aggr=", l.aggr, ")")
end


Expand Down
18 changes: 9 additions & 9 deletions test/cuda/graph_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
in_channel = 3
out_channel = 5
batch_size = 10

N = 4
adj = T[0 1 0 1;
1 0 1 0;
Expand Down Expand Up @@ -48,7 +48,7 @@
@test size(cc.weight) == (out_channel, in_channel, k)
@test size(cc.bias) == (out_channel,)
@test cc.k == k

fg = FeaturedGraph(adj, nf=X) |> gpu
fg_ = cc(fg)
@test size(node_feature(fg_)) == (out_channel, N)
Expand Down Expand Up @@ -106,12 +106,12 @@
@test size(gat.weight) == (out_channel * heads, in_channel)
@test size(gat.bias) == (out_channel * heads,)
@test size(gat.a) == (2*out_channel, heads)

X = rand(T, in_channel, N)
fg = FeaturedGraph(adj, nf=X) |> gpu
fg_ = gat(fg)
@test size(node_feature(fg_)) == (out_channel * heads, N)

g = gradient(() -> sum(node_feature(gat(fg))), Flux.params(gat))
@test length(g.grads) == 5
end
Expand All @@ -121,7 +121,7 @@
gat = WithGraph(fg, GATConv(in_channel=>out_channel, heads=2)) |> gpu
Y = gat(X |> gpu)
@test size(Y) == (out_channel * heads, N, batch_size)

g = gradient(() -> sum(gat(X |> gpu)), Flux.params(gat))
@test length(g.grads) == 4
end
Expand All @@ -145,11 +145,11 @@
@testset "layer with static graph" begin
X = rand(T, in_channel, N, batch_size)
ggc = WithGraph(fg, GatedGraphConv(out_channel, num_layers)) |> gpu
@test_broken Y = ggc(X |> gpu)
@test_broken size(Y) == (out_channel, N, batch_size)
Y = ggc(X |> gpu)
@test size(Y) == (out_channel, N, batch_size)

@test_broken g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
@test_broken length(g.grads) == 6
g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
@test length(g.grads) == 6
end
end

Expand Down
8 changes: 4 additions & 4 deletions test/layers/graph_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,11 @@
@testset "layer with static graph" begin
X = rand(T, in_channel, N, batch_size)
ggc = WithGraph(fg, GatedGraphConv(out_channel, num_layers))
@test_broken Y = ggc(X)
@test_broken size(Y) == (out_channel, N, batch_size)
Y = ggc(X)
@test size(Y) == (out_channel, N, batch_size)

@test_broken g = gradient(() -> sum(ggc(X)), Flux.params(ggc))
@test_broken length(g.grads) == 6
g = gradient(() -> sum(ggc(X)), Flux.params(ggc))
@test length(g.grads) == 6
end
end

Expand Down

0 comments on commit 1c41b77

Please sign in to comment.