diff --git a/src/layers/graph_conv.jl b/src/layers/graph_conv.jl index 3d71e978e..02cb3cfaf 100644 --- a/src/layers/graph_conv.jl +++ b/src/layers/graph_conv.jl @@ -556,7 +556,7 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real H = vcat(H, Hpad) end for i = 1:l.num_layers - M = _matmul(selectdim(l.weight, 3, i), H) + M = _matmul(l.weight[:, :, i], H) _, M = propagate(l, el, nothing, M, nothing, l.aggr, nothing, nothing) H = apply_gru(l.gru, H, M) end diff --git a/test/cuda/graph_conv.jl b/test/cuda/graph_conv.jl index 579b0d84b..dfe797e12 100644 --- a/test/cuda/graph_conv.jl +++ b/test/cuda/graph_conv.jl @@ -149,7 +149,7 @@ @test size(Y) == (out_channel, N, batch_size) g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc)) - @test length(g.grads) == 6 + @test length(g.grads) == 7 end end