diff --git a/src/microcalibrate/calibration.py b/src/microcalibrate/calibration.py index 969c4f1..c1996ff 100644 --- a/src/microcalibrate/calibration.py +++ b/src/microcalibrate/calibration.py @@ -14,9 +14,8 @@ def __init__( self, weights: np.ndarray, targets: np.ndarray, - target_names: Optional[np.ndarray] = None, - estimate_matrix: Optional[pd.DataFrame] = None, - estimate_function: Optional[Callable[[Tensor], Tensor]] = None, + target_names: np.ndarray, + estimate_function: Callable[[Tensor], Tensor], epochs: Optional[int] = 32, noise_level: Optional[float] = 10.0, learning_rate: Optional[float] = 1e-3, @@ -29,19 +28,16 @@ def __init__( Args: weights (np.ndarray): Array of original weights. targets (np.ndarray): Array of target values. - target_names (Optional[np.ndarray]): Optional names of the targets for logging. Defaults to None. You MUST pass these names if you are not passing in an estimate matrix, and just passing in an estimate function. - estimate_matrix (pd.DataFrame): DataFrame containing the estimate matrix. - estimate_function (Optional[Callable[[Tensor], Tensor]]): Function to estimate targets from weights. Defaults to None, in which case it will use the estimate_matrix. + target_names (np.ndarray): Array of target names. + estimate_function (Callable[[Tensor], Tensor]): Function to estimate targets from weights. epochs (int): Optional number of epochs for calibration. Defaults to 32. noise_level (float): Optional level of noise to add to weights. Defaults to 10.0. learning_rate (float): Optional learning rate for the optimizer. Defaults to 1e-3. dropout_rate (float): Optional probability of dropping weights during training. Defaults to 0.1. csv_path (str): Optional path to save performance logs as CSV. Defaults to None. + device (str): Optional device specification. Defaults to None. 
""" - self.estimate_function = estimate_function - self.target_names = target_names - if device is not None: self.device = torch.device(device) else: @@ -51,18 +47,8 @@ def __init__( else "mps" if torch.mps.is_available() else "cpu" ) - if self.estimate_function is None: - self.estimate_function = ( - lambda weights: weights @ self.estimate_matrix_tensor - ) - if estimate_matrix is not None: - self.estimate_matrix = estimate_matrix - self.estimate_matrix_tensor = torch.tensor( - estimate_matrix.values, dtype=torch.float32, device=self.device - ) - self.target_names = estimate_matrix.columns.to_numpy() - else: - self.estimate_matrix = None + self.target_names = target_names + self.estimate_function = estimate_function self.weights = weights self.targets = targets self.epochs = epochs @@ -77,7 +63,7 @@ def calibrate(self) -> None: self._assess_targets( estimate_function=self.estimate_function, - estimate_matrix=self.estimate_matrix, + estimate_matrix=getattr(self, "estimate_matrix", None), weights=self.weights, targets=self.targets, target_names=self.target_names, @@ -218,3 +204,61 @@ def summary( ) / df["Official target"] df = df.reset_index(drop=True) return df + + +class ColumnSumCalibration(Calibration): + def __init__( + self, + weights: np.ndarray, + targets: np.ndarray, + estimate_matrix: pd.DataFrame, + epochs: Optional[int] = 32, + noise_level: Optional[float] = 10.0, + learning_rate: Optional[float] = 1e-3, + dropout_rate: Optional[float] = 0.1, + csv_path: Optional[str] = None, + device: str = None, + ): + """Initialize the ColumnSumCalibration class. + + This class inherits from Calibration and provides column sum estimation functionality. + + Args: + weights (np.ndarray): Array of original weights. + targets (np.ndarray): Array of target values. + estimate_matrix (pd.DataFrame): DataFrame containing the estimate matrix. + epochs (int): Optional number of epochs for calibration. Defaults to 32. 
+ noise_level (float): Optional level of noise to add to weights. Defaults to 10.0. + learning_rate (float): Optional learning rate for the optimizer. Defaults to 1e-3. + dropout_rate (float): Optional probability of dropping weights during training. Defaults to 0.1. + csv_path (str): Optional path to save performance logs as CSV. Defaults to None. + device (str): Optional torch device string (e.g. "cpu" or "cuda"). Defaults to None, in which case CUDA is used if available, then MPS, then CPU. + """ + if device is not None: + device_obj = torch.device(device) + else: + device_obj = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.mps.is_available() else "cpu" + ) + + self.estimate_matrix = estimate_matrix + estimate_matrix_tensor = torch.tensor( + estimate_matrix.values, dtype=torch.float32, device=device_obj + ) + target_names = estimate_matrix.columns.to_numpy() + estimate_function = lambda weights: weights @ estimate_matrix_tensor + + super().__init__( + weights=weights, + targets=targets, + target_names=target_names, + estimate_function=estimate_function, + epochs=epochs, + noise_level=noise_level, + learning_rate=learning_rate, + dropout_rate=dropout_rate, + csv_path=csv_path, + device=device, + ) diff --git a/tests/test_calibration.py b/tests/test_calibration.py index 7b0719e..2acd5f8 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -2,7 +2,7 @@ Test the calibration process. 
""" -from src.microcalibrate.calibration import Calibration +from src.microcalibrate.calibration import ColumnSumCalibration import logging import numpy as np import pandas as pd @@ -39,7 +39,7 @@ def test_calibration_basic() -> None: ] ) - calibrator = Calibration( + calibrator = ColumnSumCalibration( estimate_matrix=targets_matrix, weights=weights, targets=targets, @@ -94,7 +94,7 @@ def test_calibration_harder_targets() -> None: ] ) - calibrator = Calibration( + calibrator = ColumnSumCalibration( estimate_matrix=targets_matrix, weights=weights, targets=targets, @@ -161,7 +161,7 @@ def test_calibration_warnings_system(caplog) -> None: ] ) - calibrator = Calibration( + calibrator = ColumnSumCalibration( estimate_matrix=targets_matrix, weights=weights, targets=targets,