Skip to content

Commit

Permalink
Merge pull request #85 from cesmix-mit/rabab/subsample
Browse files Browse the repository at this point in the history
Rabab/subsample
  • Loading branch information
emmanuellujan authored Sep 5, 2024
2 parents 8ef7cdb + 0a58996 commit 9f11cfe
Showing 1 changed file with 38 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ using AtomsBase, InteratomicPotentials, PotentialLearning
using Unitful, UnitfulAtomic
using LinearAlgebra, Random, DisplayAs
using DataFrames, Plots
using MPI

MPI.Init()
comm = MPI.COMM_WORLD
size = MPI.Comm_size(comm)
rank = MPI.Comm_rank(comm)

# Define paths.
base_path = haskey(ENV, "BASE_PATH") ? ENV["BASE_PATH"] : "../../"
Expand Down Expand Up @@ -101,8 +107,8 @@ end

# Load training and test configuration datasets ################################

paths = [
"$ds_path/Hf2_gas_form_sorted.extxyz",
# Dataset 1 (28k)
paths = ["$ds_path/Hf2_gas_form_sorted.extxyz",
"$ds_path/Hf2_mp103_EOS_1D_form_sorted.extxyz", # 200
"$ds_path/Hf2_mp103_EOS_3D_form_sorted.extxyz", # 9377
"$ds_path/Hf2_mp103_EOS_6D_form_sorted.extxyz", # 17.2k
Expand All @@ -113,6 +119,25 @@ paths = [
"$ds_path/Hf_mp100_primitive_EOS_1D_form_sorted.extxyz"
]

# Dataset 2
#paths = [
# "$ds_path/HfO2_figshare_form_sorted.extxyz",
# "$ds_path/HfO2_mp550893_EOS_1D_form_sorted.extxyz",
# "$ds_path/HfO_gas_form_sorted.extxyz",
# "$ds_path/HfO2_figshare_form_sorted.extxyz",
# "$ds_path/HfO2_mp352_EOS_1D_form_sorted.extxyz",
# "$ds_path/HfO2_mp550893_EOS_6D_form_sorted.extxyz",
# "$ds_path/Hf2_gas_form_sorted.extxyz",
# "$ds_path/Hf2_mp103_EOS_1D_form_sorted.extxyz",
# "$ds_path/Hf2_mp103_EOS_3D_form_sorted.extxyz",
# "$ds_path/Hf2_mp103_EOS_6D_form_sorted.extxyz",
# "$ds_path/Hf_mp100_EOS_1D_form_sorted.extxyz",
# "$ds_path/Hf128_MC_rattled_mp100_form_sorted.extxyz",
# "$ds_path/Hf128_MC_rattled_mp103_form_sorted.extxyz",
# "$ds_path/Hf128_MC_rattled_random_form_sorted.extxyz",
# "$ds_path/Hf_mp100_primitive_EOS_1D_form_sorted.extxyz",
#]

confs = []
for ds_path in paths
push!(confs, load_data(ds_path, uparse("eV"), uparse(""))...)
Expand All @@ -121,7 +146,7 @@ confs = DataSet(confs)
n = length(confs)
GC.gc()

#ds_path = string("../data/HfO2_large/HfO2_figshare_form_sorted.extxyz")
#ds_path = string("$ds_path/a-HfO2-300K-NVT-6000.extxyz")
#confs = load_data(ds_path, uparse("eV"), uparse("Å"))
#n = length(confs)

Expand Down Expand Up @@ -160,7 +185,16 @@ metrics = DataFrame([Any[] for _ in 1:length(metric_names)], metric_names)

# Subsampling experiments: subsample full dataset vs subsample dataset by chunks
n_experiments = 100
for j in 1:n_experiments
local_exp = ceil(Int, n_experiments / size)
for nc in 1:local_exp

#check it there is left over
j = rank + size * (nc-1) + 1

if j > n_experiments
break
end

global metrics

# Define randomized training and test dataset
Expand Down

0 comments on commit 9f11cfe

Please sign in to comment.