Merge refactor #2

Merged 31 commits on Jul 11, 2024
Commits
27c1b8e
standardize trigger logic
joannajzou May 16, 2024
3d41840
use AtomsCalculators logic for interactions
joannajzou May 21, 2024
a0eb8fe
reorganize training functions
joannajzou May 21, 2024
e62ef90
standardize triggers to receive sys.data Dict
joannajzou May 21, 2024
e25d60d
move get_values, has_step_property
joannajzou May 21, 2024
1df4946
add observable value to TriggerLogger
joannajzou May 21, 2024
7cbf74e
make fixed sys in sSVGD optional
joannajzou May 21, 2024
2f47ec4
consolidate system/ensemble functions
joannajzou May 21, 2024
215e250
update Molly and add AtomsCalculators dep
joannajzou May 21, 2024
8b6bafb
dispatch fisher_divergence on Integrator type
joannajzou May 21, 2024
7b2d387
interface AtomsBase & Molly Systems
joannajzou May 22, 2024
75038e1
create tests
joannajzou May 22, 2024
1a387b4
update compute_descriptor methods for PolynomialChaos
joannajzou May 22, 2024
fc9fabd
rearrangements
joannajzou May 22, 2024
837909d
add subselector methods
joannajzou May 23, 2024
685233e
update examples
joannajzou May 23, 2024
855e3f8
minor corrections to plots
joannajzou May 23, 2024
e23537b
restructure ALRoutine
joannajzou May 23, 2024
cafc131
restructure activelearning dir
joannajzou Jun 11, 2024
e6fd792
update compute_local_descriptors to save descriptor data
joannajzou Jun 11, 2024
7029343
GP variance methods and tests
joannajzou Jun 11, 2024
55c3afa
update training dir
joannajzou Jun 11, 2024
26d8411
update System definition in data/system
joannajzou Jun 11, 2024
e6b6048
add ProgressBars with Threads
joannajzou Jun 11, 2024
5912357
minor fix
joannajzou Jun 20, 2024
d6cc7ff
add MaxVolSubset <: SubsetSelector
joannajzou Jun 20, 2024
adf7f06
add KMeans <: SubsetSelector
joannajzou Jun 20, 2024
c407462
add Clustering dep
joannajzou Jun 20, 2024
cd6cffc
minor fix to kmeans
joannajzou Jun 20, 2024
f038e84
update README
joannajzou Jul 11, 2024
a9c6794
minor fix to kmeans
joannajzou Jul 11, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -6,14 +6,17 @@ version = "0.1.0"
[deps]
AtomisticQoIs = "895e25ce-6034-4689-a3ba-4ac45d83446c"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
+AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InteratomicPotentials = "a9efe35a-c65d-452d-b8a8-82646cd5cb04"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Maxvol = "4cc553b9-be87-484b-81d9-b5ae2a4e3958"
Molly = "aa0f7f06-fcc0-5ec4-a7f3-a573f33f9c4c"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
PotentialLearning = "82b0a93c-c2e3-44bc-a418-f0f89b0ae5c2"
+ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
@@ -28,7 +31,6 @@ UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
[compat]
AtomsBase = "0.3"
Distributions = "0.25"
-Molly = "0.18.3"
julia = "1.9"

[extras]
23 changes: 21 additions & 2 deletions README.md
@@ -4,6 +4,25 @@
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://cesmix-mit.github.io/Cairn.jl/dev/)
[![Build Status](https://github.com/cesmix-mit/Cairn.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/cesmix-mit/Cairn.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Build Status](https://travis-ci.com/cesmix-mit/Cairn.jl.svg?branch=main)](https://travis-ci.com/cesmix-mit/Cairn.jl)
-[![Build Status](https://ci.appveyor.com/api/projects/status/github/cesmix-mit/Cairn.jl?svg=true)](https://ci.appveyor.com/project/cesmix-mit/Cairn-jl)
+<!-- [![Build Status](https://ci.appveyor.com/api/projects/status/github/cesmix-mit/Cairn.jl?svg=true)](https://ci.appveyor.com/project/cesmix-mit/Cairn-jl)
+[![Coverage](https://codecov.io/gh/cesmix-mit/Cairn.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/cesmix-mit/Cairn.jl)
-[![Coverage](https://coveralls.io/repos/github/cesmix-mit/Cairn.jl/badge.svg?branch=main)](https://coveralls.io/github/cesmix-mit/Cairn.jl?branch=main)
+[![Coverage](https://coveralls.io/repos/github/cesmix-mit/Cairn.jl/badge.svg?branch=main)](https://coveralls.io/github/cesmix-mit/Cairn.jl?branch=main) -->

Cairn.jl is a toolkit of active learning algorithms for training machine learning interatomic potentials (ML-IPs) for molecular dynamics simulation.

Cairn.jl is built as an extension to [Molly.jl](https://github.com/JuliaMolSim/Molly.jl), implementing enhanced MD samplers, and it interfaces with other packages for molecular simulation in the Julia ecosystem developed by [CESMIX](https://github.com/cesmix-mit) and [JuliaMolSim](https://github.com/JuliaMolSim).

Active learning algorithms build efficient training datasets that maximally improve the accuracy of a scientific machine learning model. These algorithms have an iterative structure, looping through the following steps:

1. **Data generation.** A system's potential energy landscape is sampled by generating trajectories of molecular configurations through simulation of Newton's equations of motion or their modifications. Users can choose between standard MD simulators, such as Langevin dynamics or velocity Verlet, and enhanced sampling algorithms, such as uncertainty-driven dynamics ([UDD](https://www.nature.com/articles/s43588-023-00406-5)), Stein repulsive Langevin dynamics, or Stein variational molecular dynamics. These methods are specified under the abstract type `Simulator`.

2. **Trigger for retraining.** Sampling is terminated and retraining is triggered when the trajectory meets a certain criterion. A "fixed trigger" initiates retraining after a fixed number of simulation steps. "Adaptive triggers" are based on metrics of uncertainty, such as Gaussian-process or ensemble-based estimates of variance; metrics of extrapolation, based on a MaxVol vector basis; or metrics of diversity, such as a determinantal point process (DPP) inclusion probability. These methods are specified under the abstract type `ActiveLearningTrigger`.

3. **Data subset selection and labelling.** A subset of the data from the simulated trajectory is selected for labelling with reference calculations and appended to the training set. The most basic selection is a random subset of the trajectory. "Adaptive" selections can draw on data that exceed a threshold or data chosen by a subset selection algorithm, such as MaxVol, k-means clustering, or DPPs. These methods are specified under the abstract type `SubsetSelector`.

4. **Model updating.** The machine learning model is retrained on the augmented dataset according to the choice of objective function, defined by the abstract type `LinearProblem`. These methods live in the package [PotentialLearning.jl](https://github.com/cesmix-mit/PotentialLearning.jl). A minimal sketch of one loop iteration is shown below.
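
To see how the four steps fit together, here is a minimal sketch of the loop. The trigger and selector names and helper functions (`FixedTrigger`, `RandomSelector`, `trigger_activated`, `select_subset`, `label`) are hypothetical stand-ins rather than the exact Cairn.jl API, and the sketch assumes a Molly `System` `sys`, a reference potential `ref`, an ML-IP `mlip`, and a training set `trainset` have already been defined:

```julia
sim     = OverdampedLangevin(dt=0.002u"ps", temperature=500.0u"K", friction=4.0u"ps^-1")  # <: Simulator
trigger = FixedTrigger(1000)   # <: ActiveLearningTrigger (hypothetical name)
select  = RandomSelector(20)   # <: SubsetSelector (hypothetical name)

for iter in 1:n_iters
    simulate!(sys, sim, n_steps)                 # 1. data generation
    if trigger_activated(trigger, sys.data)      # 2. trigger for retraining
        ids = select_subset(select, sys)         # 3. subset selection ...
        append!(trainset, label(ref, sys, ids))  #    ... and labelling
        learn!(trainset, ref, mlip)              # 4. model updating (PotentialLearning.jl)
    end
end
```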


For a technical manual on the package, see the [docs](https://cesmix-mit.github.io/Cairn.jl/). For a demo, see the Jupyter notebooks in the `examples` folder.


219 changes: 219 additions & 0 deletions examples/himmelblau_train.jl
@@ -0,0 +1,219 @@
using Cairn
using LinearAlgebra, Random, Statistics, StatsBase, Distributions
using PotentialLearning
using Molly, AtomsCalculators
using AtomisticQoIs
using SpecialPolynomials, SpecialFunctions

include("./src/makie/makie.jl")
include("./examples/utils.jl")



## define models ------------------------------------------------------------------
# choose reference model
ref = Himmelblau()

# define main support
limits = [[-6.5,6.5],[-6,6]]
# limits = [[-3.5,1.5],[-1.5,3.5]]
coord_grid = coord_grid_2d(limits, 0.1)
ctr_lvls = 0:25:400

# PCE properties
basisfam = Jacobi{0.5,0.5}
order = 5
pce0 = PolynomialChaos(order, 2, basisfam, xscl=limits)
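# (pce0 is an untrained degree-5, two-dimensional polynomial chaos surrogate
# with a Jacobi(0.5, 0.5) basis, rescaled to the support `limits`)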

# grid over main support
coords_eval = potential_grid_2d(ref, limits, 0.1, cutoff = 400)
sys_eval = define_ens(ref, coords_eval)

# use grid to define uniform quadrature points
ξ = [ustrip.(Vector(coords)) for coords in coords_eval]
GQint = GaussQuadrature(ξ, ones(length(ξ))./length(ξ))
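# (this uniform-weight quadrature rule is reused below as the integration rule
# for the Fisher divergence)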

# plot
f0, ax0 = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords_eval))'
scatter!(ax0, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="test points")
axislegend(ax0)
f0

# plot density
f, _ = plot_density(ref, coord_grid, GQint)


# reference: train to test set
# pce = deepcopy(pce0)
# lp = learn!(sys_eval, ref, pce, [1000,1], false; e_flag=true, f_flag=true)
# p = define_gibbs_dist(ref)
# q = define_gibbs_dist(pce, θ=lp.β)
# fish = FisherDivergence(GQint)
# fd_best = compute_divergence(p, q, fish)


## training set 1: grid over main support ---------------------------------------
# sample from grid
coords1 = potential_grid_2d(ref, limits, 0.2, cutoff = 400)
trainset1 = define_ens(deepcopy(pce0), coords1)

# plot
f0, ax0 = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords1))'
scatter!(ax0, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="train set 1")
axislegend(ax0)
f0



## training set 2: samples from Langevin MD -------------------------------------
# Langevin simulator
sim_langevin = OverdampedLangevin(
    dt=0.002u"ps",
    temperature=500.0u"K",
    friction=4.0u"ps^-1",
)

x0arr = [[4.5, -2], [-3.5, 3], [-3.5, -3]]
sys_langevin = Vector(undef, 3)
for (i, x0) in enumerate(x0arr)
    sys0 = define_sys(
        ref,
        x0,
        loggers=(coords=CoordinateLogger(100; dims=2),),
    )
    # simulate
    sys2 = deepcopy(sys0)
    simulate!(sys2, sim_langevin, 1_000_000)
    sys_langevin[i] = sys2
end


# subselect train data from the trajectory
n = [1335, 669, 669]
coords2 = [[sys_langevin[j].loggers.coords.history[i][1] for i=1:n[j]] for j=1:3]
coords2 = reduce(vcat, coords2)
trainset2 = define_ens(deepcopy(pce0), coords2)

# plot
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords2))'
scatter!(ax, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="train set 2")
axislegend(ax)
f



## training set 3: samples from high-T MD -------------------------------------
# high-temp Langevin simulator
sim_highT = OverdampedLangevin(
    dt=0.002u"ps",
    temperature=2000.0u"K",
    friction=4.0u"ps^-1",
)
# simulate (sys0 above is local to the loop, so the starting system is
# re-created here; starting from x0arr[end] is an assumption)
sys0 = define_sys(
    ref,
    x0arr[end],
    loggers=(coords=CoordinateLogger(100; dims=2),),
)
sys3 = deepcopy(sys0)
simulate!(sys3, sim_highT, 2_000_000)
# f = plot_md_trajectory(sys3, coord_grid, fill=false, lvls=ctr_lvls, showpath=false)

# subselect train data from the trajectory
id = StatsBase.sample(1:length(sys3.loggers.coords.history), length(coords1), replace=false)
coords3 = [sys3.loggers.coords.history[i][1] for i in id]
trainset3 = define_ens(deepcopy(pce0), coords3)

# plot
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords3))'
scatter!(ax, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="train set 3")
axislegend(ax)
f


# train with changing weight λ --------------------------------------------------------------
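# λ weights the energy residual relative to the force residual in the joint E+F
# objective (learn! is called with weights [λ, 1] below), so small λ approaches
# force-only training and large λ approaches energy-only training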
λarr = [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4]
trainsets = [trainset1, trainset2, trainset3]
p = define_gibbs_dist(ref)
fish = FisherDivergence(GQint)


# store results
param_dict = Dict("ts$j" => Dict(
        "E" => zeros(length(pce0.basis)),  # pce0: pce is not defined until the training loops below
        "F" => zeros(length(pce0.basis)),
        "EF" => Vector{Vector}(undef, length(λarr)),
    ) for j = 1:length(trainsets))

err_dict = Dict("ts$j" => Dict(
        "E" => 0.0,
        "F" => 0.0,
        "EF" => zeros(length(λarr)),
    ) for j = 1:length(trainsets))

fd_dict = Dict("ts$j" => Dict(
        "E" => 0.0,
        "F" => 0.0,
        "EF" => zeros(length(λarr)),
    ) for j = 1:length(trainsets))


# train on E or F only (UnivariateLinearProblem)
for (j, ts) in enumerate(trainsets)
    # E objective
    println("train set $j, E only")
    pce = deepcopy(pce0)
    lpe = learn!(ts, ref, pce; e_flag=true, f_flag=false)
    q = define_gibbs_dist(pce, θ=lpe.β)
    fd_dict["ts$j"]["E"] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["E"] = lpe.β

    # F objective
    println("train set $j, F only")
    pce = deepcopy(pce0)
    lpf = learn!(ts, ref, pce; e_flag=false, f_flag=true)
    q = define_gibbs_dist(pce, θ=lpf.β)
    fd_dict["ts$j"]["F"] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["F"] = lpf.β
end

# train on EF (CovariateLinearProblem)
for (i, λ) in enumerate(λarr)
    for (j, ts) in enumerate(trainsets)
        # EF objective
        println("train set $j, EF (λ=$λ)")
        pce = deepcopy(pce0)
        lpef = learn!(ts, ref, pce, [λ, 1], false; e_flag=true, f_flag=true)
        q = define_gibbs_dist(pce, θ=lpef.β)
        fd_dict["ts$j"]["EF"][i] = compute_divergence(p, q, fish)
        param_dict["ts$j"]["EF"][i] = lpef.β
    end
end



# plot results
# (the extra ticks at 1e-5 and 1e5 serve as axis positions for the
# force-only "F" and energy-only "E" endpoint results)
λlab = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4, 1e5]
f = Figure(resolution=(550, 450))
ax = Axis(f[1, 1],
    xlabel="λ",
    ylabel="Fisher divergence",
    title="Model Error vs. Weight λ",
    xscale=log10,
    yscale=log10,
    xticks=(λlab, ["F", "1e-4", "1e-3", "1e-2", "1e-1", "1", "1e1", "1e2", "1e3", "1e4", "E"]),
)

for j = 1:3
    fd_all = reduce(vcat, [[fd_dict["ts$j"]["F"]], fd_dict["ts$j"]["EF"], [fd_dict["ts$j"]["E"]]])
    scatterlines!(ax, λlab, fd_all, label="train set $j")
end
axislegend(ax, position=:lt)
f

# plot the model trained on set 2 with the E-only objective
# (pce from the loops above is loop-local, so re-create it before setting params)
pce = deepcopy(pce0)
pce.params = param_dict["ts2"]["E"]
ctr_lvls2 = -20:5:50 # contour levels for forces
f, _ = plot_contours_2d(pce, coord_grid, fill=true, lvls=ctr_lvls)