diff --git a/Project.toml b/Project.toml index dc07457..3ed0c54 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PotentialLearning" uuid = "82b0a93c-c2e3-44bc-a418-f0f89b0ae5c2" authors = ["CESMIX Team"] -version = "0.2.3" +version = "0.2.4" [deps] AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" diff --git a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl index 92cd126..3f0a186 100644 --- a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl +++ b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl @@ -52,12 +52,12 @@ end; model = ACE pars = OrderedDict( :body_order => [2, 3, 4], :polynomial_degree => [3, 4, 5], - :rcutoff => [4.5, 5.0, 5.5], - :wL => [0.5, 1.0, 1.5], - :csp => [0.5, 1.0, 1.5], - :r0 => [0.5, 1.0, 1.5]); + :rcutoff => LinRange(3.5, 6.5, 10), + :wL => LinRange(0.3, 1.8, 10), + :csp => LinRange(0.3, 1.8, 10), + :r0 => LinRange(0.3, 1.8, 10)); -# Use random sampling to find the optimal hyper-parameters. +# Use **random sampling** to find the optimal hyper-parameters. iap, res = hyperlearn!(model, pars, conf_train; n_samples = 10, sampler = RandomSampler(), loss = custom_loss, ws = [1.0, 1.0], int = true); @@ -74,21 +74,21 @@ err_time = plot_err_time(res) @save_fig res_path err_time DisplayAs.PNG(err_time) - -# Alternatively, use latin hypercube sampling to find the optimal hyper-parameters. -iap, res = hyperlearn!(model, pars, conf_train; - n_samples = 3, sampler = LHSampler(), - loss = custom_loss, ws = [1.0, 1.0], int = true); +# Alternatively, use **latin hypercube sampling** to find the optimal hyper-parameters. +sampler = CLHSampler(dims=[Categorical(3), Categorical(3), Continuous(), + Continuous(), Continuous(), Continuous()]) +iap2, res2 = hyperlearn!(model, pars, conf_train; + n_samples = 10, sampler = sampler, + loss = custom_loss, ws = [1.0, 1.0], int = true); # Save and show results. -@save_var res_path iap.β -@save_var res_path iap.β0 -@save_var res_path iap.basis -@save_dataframe res_path res -res +@save_var res_path iap2.β +@save_var res_path iap2.β0 +@save_var res_path iap2.basis +@save_dataframe res_path res2 +res2 # Plot error vs time. -err_time = plot_err_time(res) -@save_fig res_path err_time -DisplayAs.PNG(err_time) - +err_time2 = plot_err_time(res2) +@save_fig res_path err_time2 +DisplayAs.PNG(err_time2) diff --git a/src/HyperLearning/linear-hyperlearn.jl b/src/HyperLearning/linear-hyperlearn.jl index 5230c16..fc43dff 100644 --- a/src/HyperLearning/linear-hyperlearn.jl +++ b/src/HyperLearning/linear-hyperlearn.jl @@ -57,9 +57,9 @@ function hyperlearn!( model::DataType, pars::OrderedDict, conf_train::DataSet; - n_samples = 5, + n_samples = 10, sampler = RandomSampler(), - loss = loss, + loss = hyperlearn!, ws = [1.0, 1.0], int = true )