Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Normalization parameter to handle multi-level geography calibration.
6 changes: 5 additions & 1 deletion src/microcalibrate/reweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def reweight(
epochs: Optional[int] = 2_000,
noise_level: Optional[float] = 10.0,
learning_rate: Optional[float] = 1e-3,
normalization_factor: Optional[torch.Tensor] = None,
csv_path: Optional[str] = None,
device: Optional[str] = None,
) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -38,6 +39,9 @@ def reweight(
epochs (int): Optional number of epochs for training.
noise_level (float): Optional level of noise to add to the original weights.
learning_rate (float): Optional learning rate for the optimizer.
normalization_factor (Optional[torch.Tensor]): Optional normalization factor for the loss (handles multi-level geographical calibration).
csv_path (Optional[str]): Optional path to save the performance metrics as a CSV file.
device (Optional[str]): Device to run the calibration on (e.g., 'cpu' or 'cuda'). If None, uses the default device.

Returns:
np.ndarray: Reweighted weights.
Expand Down Expand Up @@ -103,7 +107,7 @@ def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor:
optimizer.zero_grad()
weights_ = dropout_weights(weights, dropout_rate)
estimate = estimate_function(torch.exp(weights_))
l = loss(estimate, targets)
l = loss(estimate, targets, normalization_factor)
close = pct_close(estimate, targets)

if i % progress_update_interval == 0:
Expand Down
6 changes: 6 additions & 0 deletions src/microcalibrate/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
def loss(
estimate: torch.Tensor,
targets_array: torch.Tensor,
normalization_factor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculate the loss based on the current weights and targets.

Args:
estimate (torch.Tensor): Current estimates in log space.
targets_array (torch.Tensor): Array of target values.
normalization_factor (Optional[torch.Tensor]): Optional normalization factor for the loss (handles multi-level geographical calibration).

Returns:
torch.Tensor: Mean squared relative error between estimated and target values.
"""
rel_error = (((estimate - targets_array) + 1) / (targets_array + 1)) ** 2
if normalization_factor is not None:
rel_error *= normalization_factor
if torch.isnan(rel_error).any():
raise ValueError("Relative error contains NaNs")
return rel_error.mean()


Expand Down