diff --git a/edbo/bro.py b/edbo/bro.py index f8a18e7..233aa1f 100644 --- a/edbo/bro.py +++ b/edbo/bro.py @@ -193,7 +193,42 @@ def init_sample(self, seed=None, append=False, export_path=None, self.obj.get_results(self.proposed_experiments, append=append) return self.proposed_experiments + + # Fit model + def fit(self, n_restarts=0, learning_rate=0.1, training_iters=100): + """Fit surrogate model. + + Parameters + ---------- + n_restarts : int + Number of restarts used when optimizing GPyTorch model parameters. + learning_rate : float + ADAM learning rate used when optimizing GPyTorch model parameters. + training_iters : int + Number of iterations to run ADAM when optimizin GPyTorch models + parameters. + + Returns + ---------- + None + """ + # Initialize and train model + self.model = self.base_model(self.obj.X, + self.obj.y, + gpu=self.gpu, + nu=self.nu, + noise_constraint=self.noise_constraint, + lengthscale_prior=self.lengthscale_prior, + outputscale_prior=self.outputscale_prior, + noise_prior=self.noise_prior, + n_restarts=n_restarts, + learning_rate=learning_rate, + training_iters=training_iters + ) + + self.model.fit() + # Run algorithm and get next round of experiments def run(self, append=False, n_restarts=0, learning_rate=0.1, training_iters=100):