Skip to content

Commit

Permalink
Added fit method to bro.BO
Browse files Browse the repository at this point in the history
  • Loading branch information
b-shields authored Dec 8, 2020
1 parent f32b6dd commit 3ca50f3
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions edbo/bro.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,42 @@ def init_sample(self, seed=None, append=False, export_path=None,
self.obj.get_results(self.proposed_experiments, append=append)

return self.proposed_experiments

# Fit model
def fit(self, n_restarts=0, learning_rate=0.1, training_iters=100):
"""Fit surrogate model.
Parameters
----------
n_restarts : int
Number of restarts used when optimizing GPyTorch model parameters.
learning_rate : float
ADAM learning rate used when optimizing GPyTorch model parameters.
training_iters : int
Number of iterations to run ADAM when optimizin GPyTorch models
parameters.
Returns
----------
None
"""

# Initialize and train model
self.model = self.base_model(self.obj.X,
self.obj.y,
gpu=self.gpu,
nu=self.nu,
noise_constraint=self.noise_constraint,
lengthscale_prior=self.lengthscale_prior,
outputscale_prior=self.outputscale_prior,
noise_prior=self.noise_prior,
n_restarts=n_restarts,
learning_rate=learning_rate,
training_iters=training_iters
)

self.model.fit()

# Run algorithm and get next round of experiments
def run(self, append=False, n_restarts=0, learning_rate=0.1,
training_iters=100):
Expand Down

0 comments on commit 3ca50f3

Please sign in to comment.