diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..8558889 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Final weights are now consistent with the training log. diff --git a/src/microcalibrate/reweight.py b/src/microcalibrate/reweight.py index df40088..3145a70 100644 --- a/src/microcalibrate/reweight.py +++ b/src/microcalibrate/reweight.py @@ -20,7 +20,7 @@ def reweight( estimate_function: Callable[[Tensor], Tensor], targets_array: np.ndarray, target_names: np.ndarray, - dropout_rate: Optional[float] = 0.1, + dropout_rate: Optional[float] = 0.05, epochs: Optional[int] = 2_000, noise_level: Optional[float] = 10.0, learning_rate: Optional[float] = 1e-3, @@ -96,22 +96,15 @@ def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor: loss_over_epochs = [] estimates_over_epochs = [] pct_close_over_epochs = [] + max_epochs = epochs - 1 if epochs > 0 else 0 epochs = [] for i in iterator: optimizer.zero_grad() - running_loss = None - for j in range(2): - weights_ = dropout_weights(weights, dropout_rate) - estimate = estimate_function(torch.exp(weights_)) - l = loss(estimate, targets) - close = pct_close(estimate, targets) - if running_loss is None: - running_loss = l - else: - running_loss += l - - l = running_loss / 2 + weights_ = dropout_weights(weights, dropout_rate) + estimate = estimate_function(torch.exp(weights_)) + l = loss(estimate, targets) + close = pct_close(estimate, targets) if i % progress_update_interval == 0: iterator.set_postfix( @@ -139,8 +132,9 @@ def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor: f"({'improving' if loss_change > 0 else 'worsening'})" ) - l.backward() - optimizer.step() + if i != max_epochs - 1: + l.backward() + optimizer.step() tracker_dict = { "epochs": epochs, @@ -160,7 +154,7 @@ def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor: logger.info(f"Reweighting completed. Final sample size: {len(weights)}") - final_weights = torch.exp(weights).detach().cpu().numpy() + final_weights = torch.exp(weights_).detach().cpu().numpy() return ( final_weights,