Skip to content

Commit

Permalink
fix cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Sep 13, 2022
1 parent 1c41b77 commit dc71ab7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/layers/graph_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real
H = vcat(H, Hpad)
end
for i = 1:l.num_layers
M = _matmul(selectdim(l.weight, 3, i), H)
M = _matmul(l.weight[:, :, i], H)
_, M = propagate(l, el, nothing, M, nothing, l.aggr, nothing, nothing)
H = apply_gru(l.gru, H, M)
end
Expand Down
2 changes: 1 addition & 1 deletion test/cuda/graph_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
@test size(Y) == (out_channel, N, batch_size)

g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
@test length(g.grads) == 6
@test length(g.grads) == 7
end
end

Expand Down

0 comments on commit dc71ab7

Please sign in to comment.