Merge pull request #313 from FluxML/doc
Update doc for graph network
yuehhua authored Jul 4, 2022
2 parents bc575ca + a688bfe commit 65c0162
Showing 5 changed files with 205 additions and 11 deletions.
20 changes: 20 additions & 0 deletions docs/bibliography.bib
@@ -183,3 +183,23 @@ @inproceedings{Hamilton2017
title = {Inductive Representation Learning on Large Graphs},
year = {2017},
}

@article{Battaglia2018,
abstract = {Artificial intelligence (AI) has undergone a renaissance recently, making major progress in key domains such as vision, language, control, and decision-making. This has been due, in part, to cheap data and cheap compute resources, which have fit the natural strengths of deep learning. However, many defining characteristics of human intelligence, which developed under much different pressures, remain out of reach for current approaches. In particular, generalizing beyond one's experiences--a hallmark of human intelligence from infancy--remains a formidable challenge for modern AI. The following is part position paper, part review, and part unification. We argue that combinatorial generalization must be a top priority for AI to achieve human-like abilities, and that structured representations and computations are key to realizing this objective. Just as biology uses nature and nurture cooperatively, we reject the false choice between "hand-engineering" and "end-to-end" learning, and instead advocate for an approach which benefits from their complementary strengths. We explore how using relational inductive biases within deep learning architectures can facilitate learning about entities, relations, and rules for composing them. We present a new building block for the AI toolkit with a strong relational inductive bias--the graph network--which generalizes and extends various approaches for neural networks that operate on graphs, and provides a straightforward interface for manipulating structured knowledge and producing structured behaviors. We discuss how graph networks can support relational reasoning and combinatorial generalization, laying the foundation for more sophisticated, interpretable, and flexible patterns of reasoning. As a companion to this paper, we have released an open-source software library for building graph networks, with demonstrations of how to use them in practice.},
author = {Peter W. Battaglia and Jessica B. Hamrick and Victor Bapst and Alvaro Sanchez-Gonzalez and Vinicius Zambaldi and Mateusz Malinowski and Andrea Tacchetti and David Raposo and Adam Santoro and Ryan Faulkner and Caglar Gulcehre and Francis Song and Andrew Ballard and Justin Gilmer and George Dahl and Ashish Vaswani and Kelsey Allen and Charles Nash and Victoria Langston and Chris Dyer and Nicolas Heess and Daan Wierstra and Pushmeet Kohli and Matt Botvinick and Oriol Vinyals and Yujia Li and Razvan Pascanu},
journal = {ArXiv},
month = {6},
title = {Relational inductive biases, deep learning, and graph networks},
url = {http://arxiv.org/abs/1806.01261},
year = {2018},
}

@inproceedings{Gilmer2017,
abstract = {Supervised learning on molecules has incredible potential to be useful in chemistry, drug discovery, and materials science. Luckily, several promising and closely related neural network models invariant to molecular symmetries have already been described in the literature. These models learn a message passing algorithm and aggregation procedure to compute a function of their entire input graph. At this point, the next step is to find a particularly effective variant of this general approach and apply it to chemical prediction benchmarks until we either solve them or reach the limits of the approach. In this paper, we reformulate existing models into a single common framework we call Message Passing Neural Networks (MPNNs) and explore additional novel variations within this framework. Using MPNNs we demonstrate state of the art results on an important molecular property prediction benchmark; these results are strong enough that we believe future work should focus on datasets with larger molecules or more accurate ground truth labels.},
author = {Justin Gilmer and Samuel S. Schoenholz and Patrick F. Riley and Oriol Vinyals and George E. Dahl},
booktitle = {ICML 2017},
month = {4},
title = {Neural Message Passing for Quantum Chemistry},
url = {http://arxiv.org/abs/1704.01212},
year = {2017},
}
40 changes: 38 additions & 2 deletions docs/src/abstractions/gn.md
@@ -1,6 +1,26 @@
# Graph network block

Graph network (GN) is a more generic model for graph neural networks. It describes an update order: edge, node, and then global. There are three corresponding update functions for edges, nodes, and the global state, respectively. The three update functions return their default values as follows:
Graph network (GN) is a more generic model for graph neural networks. In detail, a graph network block is defined as follows:

```math
\begin{aligned}
e_{ij}^{\prime} &= \phi^{e}(e_{ij}, v_i, v_j, u) \\
v_{i}^{\prime} &= \phi^{v}(\bar{e}_{i}^{\prime}, v_i, u) \\
u^{\prime} &= \phi^{u}(\bar{e}^{\prime}, \bar{v}^{\prime}, u)
\end{aligned}
\ \ \ \
\begin{aligned}
\bar{e}_{i}^{\prime} &= \rho^{e \rightarrow v}(E_i^{\prime}) \\
\bar{e}^{\prime} &= \rho^{e \rightarrow u}(E^{\prime}) \\
\bar{v}^{\prime} &= \rho^{v \rightarrow u}(V^{\prime})
\end{aligned}
```

where ``\phi`` and ``\rho`` denote update functions and aggregate functions, respectively. ``v_i`` and ``v_j`` are the node features of node ``i`` and its neighbor node ``j``, ``e_{ij}`` is the edge feature of edge ``(i, j)``, and ``u`` is the global feature of the whole graph. ``e_{ij}^{\prime}``, ``v_{i}^{\prime}`` and ``u^{\prime}`` are the new edge, node, and global features. ``\bar{e}_{i}^{\prime}``, ``\bar{e}^{\prime}`` and ``\bar{v}^{\prime}`` are the aggregated new features at the node and global levels.

> Reference: [Battaglia2018](@cite)
In a graph network, updates follow the order: edge, node, and then global. There are three corresponding update functions for edges, nodes, and the global state, respectively. The three update functions return their default values as follows:

```
update_edge(gn, e, vi, vj, u) = e
update_vertex(gn, ē, vi, u) = vi
update_global(gn, ē, v̄, u) = u
```

@@ -16,9 +36,25 @@ The GN block is realized as an abstract type `GraphNet`. Users can make a subtype of

`update_edge` acts as the first update function, applied to edge states. It takes the edge state `e`, node `i`'s state `vi`, node `j`'s state `vj`, and the global state `u`, and is expected to return a feature vector for the new edge state. `update_vertex` then updates node states by taking the aggregated edge state `ē`, node `i`'s state `vi`, and the global state `u`; it is expected to return a feature vector for the new node state. `update_global` updates the global state with aggregated information from edges and nodes. It takes the aggregated edge state `ē`, the aggregated node state `v̄`, and the global state `u`, and is expected to return a feature vector for the new global state. Users can define their own behavior by overriding these update functions, as in the sketch below.
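The following is a minimal sketch of a custom GN block. The struct name, field layout, and feature dimensions (`MyGNBlock`, `ϕe`, `ϕv`) are hypothetical illustrations, not APIs from this page:

```julia
using Flux, GeometricFlux

# Hypothetical GN block: `ϕe` and `ϕv` are any callable models, e.g. Dense
# layers sized to match the concatenated feature dimensions.
struct MyGNBlock{E,V} <: GeometricFlux.GraphNet
    ϕe::E   # edge update model
    ϕv::V   # node update model
end

Flux.@functor MyGNBlock  # register trainable fields with Flux

# New edge state from the edge state `e` and endpoint states `vi`, `vj`;
# the global state `u` is ignored in this sketch.
GeometricFlux.update_edge(gn::MyGNBlock, e, vi, vj, u) = gn.ϕe(vcat(e, vi, vj))

# New node state from the aggregated edge state `ē` and node state `vi`.
GeometricFlux.update_vertex(gn::MyGNBlock, ē, vi, u) = gn.ϕv(vcat(ē, vi))
```

Update functions that are not overridden fall back to the defaults above, which return the corresponding state unchanged.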

```@docs
GeometricFlux.update_edge
GeometricFlux.update_vertex
GeometricFlux.update_global
GeometricFlux.update_batch_edge
GeometricFlux.update_batch_vertex
```

## Aggregate functions

An aggregate function `aggregate_neighbors` aggregates the edge states of edges incident to node `i` into node-level information. The aggregate function `aggregate_edges` aggregates all edge states into global-level information. The last aggregate function, `aggregate_vertices`, aggregates all node states into global-level information. Aggregate functions can be assigned by passing aggregate operations to the `propagate` function.
An aggregate function `aggregate_neighbors` aggregates the edge states of edges incident to node `i` into node-level information. The aggregate function `aggregate_edges` aggregates all edge states into global-level information. The last aggregate function, `aggregate_vertices`, aggregates all node states into global-level information.

```@docs
GeometricFlux.aggregate_neighbors
GeometricFlux.aggregate_edges
GeometricFlux.aggregate_vertices
```

Aggregate functions can be assigned by passing aggregate operations to the `propagate` function:

```
propagate(gn, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
```
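As a usage sketch only — `gn` and `fg` here are assumed to be a `GraphNet` subtype (such as the hypothetical `MyGNBlock` above) and a `FeaturedGraph`, and `+` is one possible choice of aggregate operation:

```julia
# Sum edge states over each node's incident edges (`naggr = +`), leaving the
# edge-to-global and node-to-global aggregations disabled (`nothing`).
propagate(gn, fg, +, nothing, nothing)
```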
12 changes: 12 additions & 0 deletions docs/src/abstractions/msgpass.md
@@ -2,6 +2,18 @@

The message passing scheme is a popular GNN scheme used in many frameworks. It leverages the connectivity between neighboring nodes and forms a general approach for spatial graph convolutional neural networks. It comprises two user-defined functions and one aggregate function. A message function processes information from the edge states and the node states of a node and its neighbors. The messages from each node are then aggregated by the aggregate function into node-level information for the update function. The update function takes the current node state and the aggregated message and returns a new node state.

```math
\begin{aligned}
m_{ij}^{(l+1)} &= message(h_i^{(l)}, h_j^{(l)}, e_{ij}^{(l)}) \\
m_{i}^{(l+1)} &= \Box_{j \in \mathcal{N}(i)} m_{ij}^{(l+1)} \\
h_i^{(l+1)} &= update(h_i^{(l)}, m_{i}^{(l+1)})
\end{aligned}
```

where ``h_i`` and ``h_j`` are the node features of node ``i`` and its neighbor node ``j``, and ``e_{ij}`` is the edge feature of edge ``(i, j)``. ``m_{ij}^{(l+1)}`` denotes the message for edge ``(i, j)`` in the ``l``-th layer, and ``m_{i}^{(l+1)}`` the aggregated message for node ``i``. ``message`` and ``update`` are the message function and update function, respectively. The aggregate function ``\Box`` can be any supported aggregate operation, e.g. `max`, `sum`, or `mean`.

> Reference: [Gilmer2017](@cite)
The message passing scheme is realized as an abstract type `MessagePassing`. Any subtype of `MessagePassing` is a message passing layer which utilizes the default message and update functions.

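The default definitions are not shown on this page. As a sketch under that caveat — the `message` signature below is an assumption modeled on the GN update functions, so verify it against the GeometricFlux docstrings before relying on it — a custom message passing layer might look like:

```julia
using Flux, GeometricFlux

# Hypothetical layer; `ϕ` is any callable model, e.g. a Dense layer.
struct MyMPLayer{M} <: GeometricFlux.MessagePassing
    ϕ::M
end

Flux.@functor MyMPLayer

# Assumed signature: the message for edge (i, j) is computed from the state
# of node i (`x_i`), its neighbor j (`x_j`), and the edge feature `e_ij`.
GeometricFlux.message(mp::MyMPLayer, x_i, x_j, e_ij) = mp.ϕ(vcat(x_i, x_j))
```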
139 changes: 131 additions & 8 deletions src/layers/gn.jl
@@ -1,9 +1,77 @@
abstract type GraphNet <: AbstractGraphLayer end

"""
update_edge(gn, e, vi, vj, u)
Update function for edge feature in graph network.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `e`: Edge feature.
- `vi`: Node feature for node `i`.
- `vj`: Node feature for neighbors of node `i`.
- `u`: Global feature.
See also [`update_vertex`](@ref), [`update_global`](@ref), [`update_batch_edge`](@ref),
[`update_batch_vertex`](@ref), [`aggregate_neighbors`](@ref), [`aggregate_edges`](@ref),
[`aggregate_vertices`](@ref).
"""
update_edge(::GraphNet, e, vi, vj, u) = e

"""
update_vertex(gn, ē, vi, u)
Update function for node feature in graph network.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `ē`: Aggregated edge feature.
- `vi`: Node feature for node `i`.
- `u`: Global feature.
See also [`update_edge`](@ref), [`update_global`](@ref), [`update_batch_edge`](@ref),
[`update_batch_vertex`](@ref), [`aggregate_neighbors`](@ref), [`aggregate_edges`](@ref),
[`aggregate_vertices`](@ref).
"""
update_vertex(::GraphNet, ē, vi, u) = vi

"""
update_global(gn, ē, v̄, u)
Update function for global feature in graph network.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `ē`: Aggregated edge feature.
- `v̄`: Aggregated node feature for node `i`.
- `u`: Global feature.
See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_batch_edge`](@ref),
[`update_batch_vertex`](@ref), [`aggregate_neighbors`](@ref), [`aggregate_edges`](@ref),
[`aggregate_vertices`](@ref).
"""
update_global(::GraphNet, ē, v̄, u) = u

"""
update_batch_edge(gn, el, E, V, u)
Returns new edge features of size `(E_out_dim, #E, [batch_size])`.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `el::NamedTuple`: Collection of graph information.
- `E`: All edge features. Its size should be `(E_in_dim, #E, [batch_size])`.
- `V`: All node features.
- `u`: Global features.
See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref),
[`update_batch_vertex`](@ref), [`aggregate_neighbors`](@ref), [`aggregate_edges`](@ref),
[`aggregate_vertices`](@ref).
"""
update_batch_edge(gn::GraphNet, el::NamedTuple, E, V, u) =
update_edge(
gn,
@@ -13,8 +81,41 @@ update_batch_edge(gn::GraphNet, el::NamedTuple, E, V, u) =
u
)

"""
update_batch_vertex(gn, el, Ē, V, u)
Returns new node features of size `(V_out_dim, #V, [batch_size])`.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `el::NamedTuple`: Collection of graph information.
- `Ē`: All edge features. Its size should be `(E_in_dim, #V, [batch_size])`.
- `V`: All node features. Its size should be `(V_in_dim, #V, [batch_size])`.
- `u`: Global features.
See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref),
[`update_batch_edge`](@ref), [`aggregate_neighbors`](@ref), [`aggregate_edges`](@ref),
[`aggregate_vertices`](@ref).
"""
update_batch_vertex(gn::GraphNet, ::NamedTuple, Ē, V, u) = update_vertex(gn, Ē, V, u)

"""
aggregate_neighbors(gn, el, aggr, E)
Returns aggregated neighbor features of size `(E_out_dim, #V, [batch_size])`.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `el::NamedTuple`: Collection of graph information.
- `aggr`: Aggregate function to apply on neighbor features.
- `E`: All edge features from neighbors. Its size should be `(E_out_dim, #E, [batch_size])`.
See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref),
[`update_batch_edge`](@ref), [`update_batch_vertex`](@ref), [`aggregate_edges`](@ref),
[`aggregate_vertices`](@ref).
"""
function aggregate_neighbors(::GraphNet, el::NamedTuple, aggr, E)
batch_size = size(E)[end]
dstsize = (size(E, 1), el.N, batch_size)
@@ -27,9 +128,39 @@ aggregate_neighbors(::GraphNet, el::NamedTuple, aggr, E::AbstractMatrix) = _scat
@inline aggregate_neighbors(::GraphNet, ::NamedTuple, ::Nothing, E) = nothing
@inline aggregate_neighbors(::GraphNet, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing
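# Example (sketch): with `aggr = +` and `el` encoding edge-to-node incidence,
# `aggregate_neighbors` scatter-sums the `#E` edge columns of `E` into `#V`
# node columns, matching the `(E_out_dim, #V, [batch_size])` output above.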

"""
aggregate_edges(gn, aggr, E)
Returns aggregated edge features of size `(E_out_dim, 1, [batch_size])` for updating global feature.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `aggr`: Aggregate function to apply on edge features.
- `E`: All edge features. Its size should be `(E_out_dim, #E, [batch_size])`.
See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref),
[`update_batch_edge`](@ref), [`update_batch_vertex`](@ref), [`aggregate_neighbors`](@ref),
[`aggregate_vertices`](@ref).
"""
aggregate_edges(::GraphNet, aggr, E) = aggregate(aggr, E)
@inline aggregate_edges(::GraphNet, ::Nothing, E) = nothing

"""
aggregate_vertices(gn, aggr, V)
Returns aggregated node features of size `(V_out_dim, 1, [batch_size])` for updating global feature.
# Arguments
- `gn::GraphNet`: A graph network layer.
- `aggr`: Aggregate function to apply on node features.
- `V`: All node features. Its size should be `(V_out_dim, #V, [batch_size])`.
See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref),
[`update_batch_edge`](@ref), [`update_batch_vertex`](@ref), [`aggregate_neighbors`](@ref),
[`aggregate_edges`](@ref).
"""
aggregate_vertices(::GraphNet, aggr, V) = aggregate(aggr, V)
@inline aggregate_vertices(::GraphNet, ::Nothing, V) = nothing
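# Example (sketch): with `aggr = +`, `aggregate_edges`/`aggregate_vertices`
# reduce over the entity axis, e.g. node features of size
# `(V_out_dim, #V, [batch_size])` collapse to `(V_out_dim, 1, [batch_size])`
# for the global update.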

@@ -38,14 +169,6 @@ function propagate(gn::GraphNet, sg::SparseGraph, E, V, u, naggr, eaggr, vaggr)
return propagate(gn, el, E, V, u, naggr, eaggr, vaggr)
end

"""
- `update_batch_edge`: (E_in_dim, E) -> (E_out_dim, E)
- `aggregate_neighbors`: (E_out_dim, E) -> (E_out_dim, V)
- `update_batch_vertex`: (V_in_dim, V) -> (V_out_dim, V)
- `aggregate_edges`: (E_out_dim, E) -> (E_out_dim, 1)
- `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim, 1)
- `update_global`: (dim, 1) -> (dim, 1)
"""
function propagate(gn::GraphNet, el::NamedTuple, E, V, u, naggr, eaggr, vaggr)
E = update_batch_edge(gn, el, E, V, u)
Ē = aggregate_neighbors(gn, el, naggr, E)
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -25,7 +25,6 @@ tests = [
"layers/pool",
"layers/graphlayers",
"sampling",
"embedding/node2vec",
"models",
]

@@ -35,6 +34,10 @@
@warn "CUDA unavailable, not testing GPU support"
end

if !Sys.iswindows()
push!(tests, "embedding/node2vec")
end

@testset "GeometricFlux" begin
for t in tests
include("$(t).jl")
