Skip to content

Adding AQSOL Dataset #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 22, 2024
1 change: 1 addition & 0 deletions docs/src/datasets/graphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ MLDatasets.HeteroGraph
```

```@docs
AQSOL
ChickenPox
CiteSeer
Cora
Expand Down
4 changes: 3 additions & 1 deletion src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ include("datasets/graphs/citeseer.jl")
export CiteSeer
include("datasets/graphs/karateclub.jl")
export KarateClub
include("datasets/graphs/AQSOL.jl")
export AQSOL
include("datasets/graphs/movielens.jl")
export MovieLens
include("datasets/graphs/ogbdataset.jl")
Expand Down Expand Up @@ -151,6 +153,7 @@ function __init__()
# TODO automatically find and execute __init__xxx functions

# graph
__init__aqsol()
__init__chickenpox()
__init__citeseer()
__init__cora()
Expand All @@ -166,7 +169,6 @@ function __init__()
__init__temporalbrains()
__init__windmillenergy()


# misc
__init__iris()
__init__mutagenesis()
Expand Down
100 changes: 100 additions & 0 deletions src/datasets/graphs/AQSOL.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
function __init__aqsol()
DEPNAME = "AQSOL"
LINK = "https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1"
register(DataDep(DEPNAME,
"""
Dataset: The AQSOL dataset.
Website: http://arxiv.org/abs/2003.00982
""",
LINK,
post_fetch_method = unpack))
end

struct AQSOL <: AbstractDataset
split::Symbol
metadata::Dict{String,Any}
graphs::Vector{Graph}
end

"""
AQSOL(; split=:train, dir=nothing)

The AQSOL (Aqueous Solubility) dataset from the paper
[Graph Neural Network for Predicting Aqueous Solubility of Organic Molecules](http://arxiv.org/abs/2003.00982).

The dataset contains 9,882 graphs representing small organic molecules. Each graph represents a molecule, where nodes correspond to atoms and edges to bonds. The node features represent the atomic number, and the edge features represent the bond type. The target is the aqueous solubility of the molecule, measured in mol/L.

# Arguments

- `split`: Which split of the dataset to load. Can be one of `:train`, `:val`, or `:test`. Defaults to `:train`.
- `dir`: Directory in which the dataset is in.

# Examples

```julia-repl
julia> using MLDatasets

julia> data = AQSOL()
dataset AQSOL:
split => :train
metadata => Dict{String, Any} with 1 entry
graphs => 7985-element Vector{MLDatasets.Graph}

julia> length(data)
7985

julia> g = data[1]
Graph:
num_nodes => 23
num_edges => 42
edge_index => ("42-element Vector{Int64}", "42-element Vector{Int64}")
node_data => (features = "23-element Vector{Int64}",)
edge_data => (features = "42-element Vector{Int64}",)

julia> g.num_nodes
23

julia> g.node_data.features
23-element Vector{Int64}:
0
1
1
1
1
1

julia> g.edge_index
([2, 3, 3, 4, 4, 5, 5, 6, 6, 7 … 18, 19, 19, 20, 20, 21, 20, 22, 20, 23], [3, 2, 4, 3, 5, 4, 6, 5, 7, 6 … 19, 18, 20, 19, 21, 20, 22, 20, 23, 20])
```
"""
function AQSOL(;split=:train, dir=nothing)
@assert split ∈ [:train, :val, :test]
DEPNAME = "AQSOL"
path = datafile(DEPNAME, "asqol_graph_raw/$(split).pickle", dir)
graphs = Pickle.npyload(path)
g = [create_aqsol_graph(g...) for g in graphs]
metadata = Dict{String, Any}("n_observations" => length(g))
return AQSOL(split, metadata, g)
end

function create_aqsol_graph(x, edge_attr, edge_index, y)
x = Int.(x)
edge_attr = Int.(edge_attr)
edge_index = Int.(edge_index .+ 1)

if size(edge_index, 2) == 0
s, t = Int[], Int[]
else
s, t = edge_index[1, :], edge_index[2, :]
end

return Graph(; num_nodes = length(x),
edge_index = (s, t),
node_data = (features = x,),
edge_data = (features = edge_attr,))
end

Base.length(d::AQSOL) = length(d.graphs)
Base.getindex(d::AQSOL, ::Colon) = d.graphs
Base.getindex(d::AQSOL, i) = getindex(d.graphs, i)
2 changes: 1 addition & 1 deletion src/datasets/graphs/movielens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,4 @@ Base.length(data::MovieLens) = length(data.graphs)
function Base.getindex(data::MovieLens, ::Colon)
length(data.graphs) == 1 ? data.graphs[1] : data.graphs
end
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
20 changes: 20 additions & 0 deletions test/datasets/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,23 @@ end
@test size(g.edge_data.features) == (2, g.num_edges)
@test size(g.edge_data.targets) == (g.num_edges,)
end

@testset "AQSOL" begin
split_counts = Dict(:train => 7985, :val => 998, :test => 999)
for split in [:train, :val, :test]
data = AQSOL(split=split)
@test data isa AbstractDataset
@test data.split == split
@test length(data) == data.metadata["n_observations"]
@test length(data.graphs) == split_counts[split]

i = rand(1:length(data))
g = data[i]
@test g isa MLDatasets.Graph
s, t = g.edge_index
@test all(1 .<= s .<= g.num_nodes)
@test all(1 .<= t .<= g.num_nodes)
@test length(s) == g.num_edges
@test length(t) == g.num_edges
end
end
Loading