diff --git a/README.md b/README.md
index 134f35c..86cdedc 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,105 @@
-# MicroCalibrate
+# MicroCalibrate
+
+[](https://github.com/PolicyEngine/microcalibrate/actions/workflows/main.yml)
+[](https://codecov.io/gh/PolicyEngine/microcalibrate)
+[](https://badge.fury.io/py/microcalibrate)
+[](https://pypi.org/project/microcalibrate/)
+
+MicroCalibrate is a Python package for calibrating survey weights to match population targets, with advanced features including L0 regularization for sparsity, hyperparameter tuning, and robustness evaluation.
+
+## Features
+
+- **Survey Weight Calibration**: Adjust sample weights to match known population totals.
+- **L0 Regularization**: Produce sparse weights that shrink the dataset while preserving accuracy.
+- **Automatic Hyperparameter Tuning**: Find good regularization parameters via cross-validation.
+- **Robustness Evaluation**: Assess calibration stability with holdout validation.
+- **Target Assessment**: Identify which targets make calibration difficult.
+- **Performance Monitoring**: Track calibration progress with detailed logging.
+- **Interactive Dashboard**: Visualize calibration performance at https://microcalibrate.vercel.app/.
+
+## Installation
+
+```bash
+pip install microcalibrate
+```
+
+The package requires:
+- Python 3.13+
+- PyTorch 2.7.0+
+- NumPy, Pandas, Optuna, and L0-python
+
+## Quick start
+
+### Basic calibration
+
+```python
+from microcalibrate import Calibration
+import numpy as np
+import pandas as pd
+
+# Create sample data for calibration
+n_samples = 1000
+weights = np.ones(n_samples) # Initial weights are set to one
+
+# Create an estimate matrix that represents the contribution of each record to targets
+estimate_matrix = pd.DataFrame({
+ 'total_income': np.random.normal(50000, 15000, n_samples),
+ 'total_employed': np.random.binomial(1, 0.6, n_samples),
+})
+
+# Set the target values to achieve through calibration
+targets = np.array([
+    50_000_000,  # total income target
+    600,         # total employed target
+])
+
+# Initialize the calibration object and configure the optimization parameters
+cal = Calibration(
+ weights=weights,
+ targets=targets,
+ estimate_matrix=estimate_matrix,
+ epochs=500,
+ learning_rate=1e-3,
+)
+
+# Perform the calibration to adjust weights
+performance_df = cal.calibrate()
+
+# Retrieve the calibrated weights from the calibration object
+new_weights = cal.weights
+```
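+
+Conceptually, `calibrate()` runs gradient descent on a loss that compares weighted estimates to their targets. The toy loop below illustrates the idea in pure Python for a single target and a squared relative-error loss; it is a sketch only, not the package's implementation (which adds noise, dropout, and optional L0 gates):
+
+```python
+# Toy reweighting: one target, squared relative-error loss.
+contributions = [1.0, 2.0, 3.0]  # each record's contribution to the target
+weights = [1.0, 1.0, 1.0]        # initial weights (estimate starts at 6.0)
+target = 12.0                    # desired weighted total
+lr = 1.0
+
+for _ in range(200):
+    estimate = sum(w * c for w, c in zip(weights, contributions))
+    error = (estimate - target) / target  # relative error
+    # Gradient of error**2 with respect to each weight is 2 * error * c / target
+    weights = [w - lr * 2 * error * c / target
+               for w, c in zip(weights, contributions)]
+
+estimate = sum(w * c for w, c in zip(weights, contributions))
+print(round(estimate, 4))  # → 12.0
+```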
+
+## API reference
+
+### Calibration class
+
+The `Calibration` class is the package's main entry point for weight calibration.
+
+**Parameters:**
+- `weights`: Initial weight array, one entry per record.
+- `targets`: Target values to match during calibration.
+- `estimate_matrix`: DataFrame giving each record's contribution to each target.
+- `estimate_function`: Custom function used in place of `estimate_matrix`.
+- `epochs`: Number of optimization iterations (default: 32).
+- `learning_rate`: Optimizer learning rate (default: 1e-3).
+- `noise_level`: Amount of noise added to the initial weights for robustness (default: 10.0).
+- `dropout_rate`: Dropout rate for regularization (default: 0).
+- `regularize_with_l0`: Enable L0 regularization (default: False).
+- `l0_lambda`: L0 regularization strength (default: 5e-6).
+- `init_mean`: Initial proportion of non-zero weights (default: 0.999).
+- `temperature`: Sparsity control parameter (default: 0.5).
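+
+The L0 parameters map onto a stochastic gate per weight. As a rough sketch (assuming the standard hard-concrete formulation of Louizos et al.; the stretch constants and function name here are illustrative, not the package's API), `temperature` controls how sharply gates saturate toward 0 or 1, and `init_mean` sets the initial fraction of open gates:
+
+```python
+import numpy as np
+
+def hard_concrete_gates(log_alpha, temperature=0.5, gamma=-0.1, zeta=1.1, seed=0):
+    """Sample one stochastic gate in [0, 1] per weight (hard-concrete distribution)."""
+    rng = np.random.default_rng(seed)
+    u = rng.uniform(1e-6, 1 - 1e-6, size=log_alpha.shape)
+    # Binary-concrete relaxation of a Bernoulli gate
+    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1.0 - u) + log_alpha) / temperature))
+    # Stretch to (gamma, zeta), then clamp so exact zeros (and ones) can occur
+    return np.clip(s * (zeta - gamma) + gamma, 0.0, 1.0)
+
+init_mean = 0.999  # initial proportion of non-zero weights
+log_alpha = np.full(10_000, np.log(init_mean / (1.0 - init_mean)))
+gates = hard_concrete_gates(log_alpha)
+# With init_mean near 1, almost every gate starts open (close to 1)
+```
+
+During training, an L0 penalty proportional to the expected number of open gates (scaled by `l0_lambda`) pushes gates toward exact zero, which is what produces sparse weights.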
+
+**Methods:**
+- `calibrate()`: Runs the weight calibration and returns a performance DataFrame.
+- `tune_l0_hyperparameters()`: Tunes the L0 parameters automatically using cross-validation.
+- `evaluate_holdout_robustness()`: Assesses calibration stability using holdout validation.
+- `assess_analytical_solution()`: Analyzes how difficult each target combination is to achieve.
+- `summary()`: Returns a summary of the calibration results.
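+
+As a sketch of the idea behind `assess_analytical_solution()` (an assumed reconstruction, not the package's code): add targets one at a time, solve each least-squares problem with the Moore-Penrose pseudoinverse, and watch how the residual loss grows. Targets whose addition produces a large jump in loss are the ones complicating calibration:
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(0)
+n_records, n_targets = 100, 3
+contribs = rng.normal(size=(n_records, n_targets))  # record-by-target contributions
+targets = np.array([10.0, 20.0, 30.0])
+
+prev_loss = None
+for k in range(1, n_targets + 1):
+    A = contribs[:, :k].T                # (k targets) x (n_records)
+    w = np.linalg.pinv(A) @ targets[:k]  # minimum-norm least-squares weights
+    loss = float(np.sum((A @ w - targets[:k]) ** 2))
+    delta = None if prev_loss is None else loss - prev_loss
+    print(f"targets 1..{k}: loss={loss:.3e}, delta_loss={delta}")
+    prev_loss = loss
+```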
+
+## Examples and documentation
+
+For detailed examples and interactive notebooks, see the [documentation](https://policyengine.github.io/microcalibrate/).
+
+## Contributing
+
+Contributions are welcome. Please feel free to submit a pull request with your improvements.
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..8cbf276 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: minor
+ changes:
+ added:
+    - Documentation for L0 regularization, hyperparameter tuning, and robustness checks.
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 19344f5..34246d6 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -2,3 +2,5 @@ format: jb-book
root: intro
chapters:
- file: calibration.ipynb
+- file: l0_regularization.ipynb
+- file: robustness_evaluation.ipynb
diff --git a/docs/calibration.ipynb b/docs/calibration.ipynb
index 68c2a59..4132c3a 100644
--- a/docs/calibration.ipynb
+++ b/docs/calibration.ipynb
@@ -36,7 +36,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 5,
"id": "vxj460xngvg",
"metadata": {},
"outputs": [
@@ -62,9 +62,8 @@
"import plotly.graph_objs as go\n",
"from plotly.subplots import make_subplots\n",
"\n",
- "logging.basicConfig(\n",
- " level=logging.INFO,\n",
- ")\n",
+ "calibration_logger = logging.getLogger(\"microcalibrate.calibration\")\n",
+ "calibration_logger.setLevel(logging.WARNING)\n",
"\n",
"# Create a sample dataset with age and income data\n",
"random_generator = np.random.default_rng(0)\n",
@@ -96,7 +95,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 6,
"id": "xfw94bs21yl",
"metadata": {},
"outputs": [
@@ -104,172 +103,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:microcalibrate.calibration:Performing basic target assessment...\n",
- "WARNING:microcalibrate.calibration:Target income_aged_20_30 (7.37e+08) differs from initial estimate (7.37e+05) by 3.00 orders of magnitude.\n",
- "WARNING:microcalibrate.calibration:Target income_aged_71 is supported by only 0.83% of records in the loss matrix. This may make calibration unstable or ineffective.\n",
- "INFO:microcalibrate.reweight:Starting calibration process for targets ['income_aged_20_30' 'income_aged_40_50' 'income_aged_71']: [7.37032429e+08 9.76779350e+05 4.36479914e+04]\n",
- "INFO:microcalibrate.reweight:Original weights - mean: 1.0000, std: 0.0000\n",
- "INFO:microcalibrate.reweight:Initial weights after noise - mean: 1.0226, std: 0.0148\n",
- "Reweighting progress: 0%| | 0/528 [00:00, ?epoch/s, loss=0.341, weights_mean=1.02, weights_std=0.0148, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 0.00% \n",
- "\n",
- "Reweighting progress: 0%| | 1/528 [00:00<01:32, 5.70epoch/s, loss=0.333, weights_mean=1.06, weights_std=0.0523, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 10: Loss = 0.332947, Change = 0.008021 (improving)\n",
- "Reweighting progress: 0%| | 1/528 [00:00<01:32, 5.70epoch/s, loss=0.333, weights_mean=1.09, weights_std=0.093, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 20: Loss = 0.332994, Change = -0.000046 (worsening)\n",
- "Reweighting progress: 4%|▍ | 22/528 [00:00<00:05, 96.55epoch/s, loss=0.332, weights_mean=1.1, weights_std=0.131, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 30: Loss = 0.332489, Change = 0.000505 (improving)\n",
- "Reweighting progress: 4%|▍ | 22/528 [00:00<00:05, 96.55epoch/s, loss=0.332, weights_mean=1.12, weights_std=0.184, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 40: Loss = 0.332363, Change = 0.000126 (improving)\n",
- "Reweighting progress: 4%|▍ | 22/528 [00:00<00:05, 96.55epoch/s, loss=0.332, weights_mean=1.15, weights_std=0.249, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 50: Loss = 0.332187, Change = 0.000176 (improving)\n",
- "Reweighting progress: 10%|▉ | 51/528 [00:00<00:02, 168.58epoch/s, loss=0.332, weights_mean=1.19, weights_std=0.326, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 60: Loss = 0.332042, Change = 0.000145 (improving)\n",
- "Reweighting progress: 10%|▉ | 51/528 [00:00<00:02, 168.58epoch/s, loss=0.332, weights_mean=1.22, weights_std=0.418, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 70: Loss = 0.331861, Change = 0.000182 (improving)\n",
- "Reweighting progress: 15%|█▍ | 77/528 [00:00<00:02, 199.34epoch/s, loss=0.332, weights_mean=1.27, weights_std=0.527, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 80: Loss = 0.331657, Change = 0.000204 (improving)\n",
- "Reweighting progress: 15%|█▍ | 77/528 [00:00<00:02, 199.34epoch/s, loss=0.331, weights_mean=1.32, weights_std=0.658, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 90: Loss = 0.331410, Change = 0.000247 (improving)\n",
- "Reweighting progress: 19%|█▉ | 100/528 [00:00<00:02, 208.55epoch/s, loss=0.331, weights_mean=1.39, weights_std=0.817, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 100: Loss = 0.331114, Change = 0.000296 (improving)\n",
- "Reweighting progress: 19%|█▉ | 100/528 [00:00<00:02, 208.55epoch/s, loss=0.331, weights_mean=1.47, weights_std=1.01, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 110: Loss = 0.330754, Change = 0.000360 (improving)\n",
- "Reweighting progress: 19%|█▉ | 100/528 [00:00<00:02, 208.55epoch/s, loss=0.33, weights_mean=1.57, weights_std=1.25, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 120: Loss = 0.330314, Change = 0.000440 (improving)\n",
- "Reweighting progress: 24%|██▍ | 129/528 [00:00<00:01, 233.06epoch/s, loss=0.33, weights_mean=1.69, weights_std=1.54, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 130: Loss = 0.329772, Change = 0.000542 (improving)\n",
- "Reweighting progress: 24%|██▍ | 129/528 [00:00<00:01, 233.06epoch/s, loss=0.329, weights_mean=1.84, weights_std=1.9, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 140: Loss = 0.329099, Change = 0.000672 (improving)\n",
- "Reweighting progress: 24%|██▍ | 129/528 [00:00<00:01, 233.06epoch/s, loss=0.328, weights_mean=2.03, weights_std=2.35, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 150: Loss = 0.328259, Change = 0.000840 (improving)\n",
- "Reweighting progress: 30%|██▉ | 158/528 [00:00<00:01, 249.44epoch/s, loss=0.327, weights_mean=2.27, weights_std=2.92, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 160: Loss = 0.327203, Change = 0.001056 (improving)\n",
- "Reweighting progress: 30%|██▉ | 158/528 [00:00<00:01, 249.44epoch/s, loss=0.326, weights_mean=2.57, weights_std=3.65, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 170: Loss = 0.325865, Change = 0.001338 (improving)\n",
- "Reweighting progress: 30%|██▉ | 158/528 [00:00<00:01, 249.44epoch/s, loss=0.324, weights_mean=2.95, weights_std=4.57, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 180: Loss = 0.324160, Change = 0.001705 (improving)\n",
- "Reweighting progress: 35%|███▌ | 185/528 [00:00<00:01, 255.70epoch/s, loss=0.322, weights_mean=3.45, weights_std=5.76, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 190: Loss = 0.321974, Change = 0.002186 (improving)\n",
- "Reweighting progress: 35%|███▌ | 185/528 [00:00<00:01, 255.70epoch/s, loss=0.319, weights_mean=4.09, weights_std=7.3, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 200: Loss = 0.319155, Change = 0.002819 (improving)\n",
- "Reweighting progress: 35%|███▌ | 185/528 [00:00<00:01, 255.70epoch/s, loss=0.316, weights_mean=4.92, weights_std=9.3, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 210: Loss = 0.315500, Change = 0.003655 (improving)\n",
- "Reweighting progress: 40%|████ | 213/528 [00:01<00:01, 262.61epoch/s, loss=0.311, weights_mean=6.02, weights_std=11.9, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 220: Loss = 0.310741, Change = 0.004759 (improving)\n",
- "Reweighting progress: 40%|████ | 213/528 [00:01<00:01, 262.61epoch/s, loss=0.305, weights_mean=7.46, weights_std=15.4, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 230: Loss = 0.304523, Change = 0.006218 (improving)\n",
- "Reweighting progress: 40%|████ | 213/528 [00:01<00:01, 262.61epoch/s, loss=0.296, weights_mean=9.37, weights_std=20, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 240: Loss = 0.296380, Change = 0.008143 (improving)\n",
- "Reweighting progress: 46%|████▌ | 241/528 [00:01<00:01, 267.28epoch/s, loss=0.286, weights_mean=11.9, weights_std=26.1, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 250: Loss = 0.285711, Change = 0.010669 (improving)\n",
- "Reweighting progress: 46%|████▌ | 241/528 [00:01<00:01, 267.28epoch/s, loss=0.272, weights_mean=15.3, weights_std=34.3, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 260: Loss = 0.271758, Change = 0.013953 (improving)\n",
- "Reweighting progress: 51%|█████ | 268/528 [00:01<00:01, 223.19epoch/s, loss=0.254, weights_mean=19.9, weights_std=45.2, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 270: Loss = 0.253606, Change = 0.018152 (improving)\n",
- "Reweighting progress: 51%|█████ | 268/528 [00:01<00:01, 223.19epoch/s, loss=0.23, weights_mean=26, weights_std=59.9, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 280: Loss = 0.230236, Change = 0.023370 (improving)\n",
- "Reweighting progress: 51%|█████ | 268/528 [00:01<00:01, 223.19epoch/s, loss=0.201, weights_mean=34.2, weights_std=79.6, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 290: Loss = 0.200684, Change = 0.029552 (improving)\n",
- "Reweighting progress: 55%|█████▌ | 292/528 [00:01<00:01, 218.99epoch/s, loss=0.164, weights_mean=45.1, weights_std=106, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 300: Loss = 0.164429, Change = 0.036255 (improving)\n",
- "Reweighting progress: 55%|█████▌ | 292/528 [00:01<00:01, 218.99epoch/s, loss=0.122, weights_mean=59.5, weights_std=140, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 310: Loss = 0.122155, Change = 0.042274 (improving)\n",
- "Reweighting progress: 55%|█████▌ | 292/528 [00:01<00:01, 218.99epoch/s, loss=0.077, weights_mean=78.1, weights_std=185, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 320: Loss = 0.077000, Change = 0.045154 (improving)\n",
- "Reweighting progress: 61%|██████ | 321/528 [00:01<00:00, 235.49epoch/s, loss=0.0359, weights_mean=101, weights_std=239, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 330: Loss = 0.035907, Change = 0.041093 (improving)\n",
- "Reweighting progress: 61%|██████ | 321/528 [00:01<00:00, 235.49epoch/s, loss=0.00872, weights_mean=125, weights_std=299, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 66.67% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 340: Loss = 0.008719, Change = 0.027188 (improving)\n",
- "Reweighting progress: 61%|██████ | 321/528 [00:01<00:00, 235.49epoch/s, loss=0.000144, weights_mean=146, weights_std=349, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 350: Loss = 0.000144, Change = 0.008575 (improving)\n",
- "Reweighting progress: 66%|██████▋ | 351/528 [00:01<00:00, 248.71epoch/s, loss=0.000685, weights_mean=156, weights_std=373, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 360: Loss = 0.000685, Change = -0.000541 (worsening)\n",
- "Reweighting progress: 66%|██████▋ | 351/528 [00:01<00:00, 248.71epoch/s, loss=0.000525, weights_mean=155, weights_std=371, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 370: Loss = 0.000525, Change = 0.000161 (improving)\n",
- "Reweighting progress: 72%|███████▏ | 380/528 [00:01<00:00, 259.54epoch/s, loss=4.77e-5, weights_mean=151, weights_std=361, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 380: Loss = 0.000048, Change = 0.000477 (improving)\n",
- "Reweighting progress: 72%|███████▏ | 380/528 [00:01<00:00, 259.54epoch/s, loss=9.05e-6, weights_mean=149, weights_std=355, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 390: Loss = 0.000009, Change = 0.000039 (improving)\n",
- "Reweighting progress: 72%|███████▏ | 380/528 [00:01<00:00, 259.54epoch/s, loss=2.12e-5, weights_mean=148, weights_std=354, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 400: Loss = 0.000021, Change = -0.000012 (worsening)\n",
- "Reweighting progress: 77%|███████▋ | 408/528 [00:01<00:00, 264.71epoch/s, loss=5.13e-6, weights_mean=149, weights_std=355, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 410: Loss = 0.000005, Change = 0.000016 (improving)\n",
- "Reweighting progress: 77%|███████▋ | 408/528 [00:01<00:00, 264.71epoch/s, loss=4.91e-10, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 420: Loss = 0.000000, Change = 0.000005 (improving)\n",
- "Reweighting progress: 77%|███████▋ | 408/528 [00:01<00:00, 264.71epoch/s, loss=6.66e-7, weights_mean=150, weights_std=357, weights_min=1] INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 430: Loss = 0.000001, Change = -0.000001 (worsening)\n",
- "Reweighting progress: 83%|████████▎ | 436/528 [00:01<00:00, 268.41epoch/s, loss=3.05e-7, weights_mean=150, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 440: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 83%|████████▎ | 436/528 [00:01<00:00, 268.41epoch/s, loss=7.69e-9, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 450: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 83%|████████▎ | 436/528 [00:01<00:00, 268.41epoch/s, loss=1.85e-8, weights_mean=149, weights_std=356, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 460: Loss = 0.000000, Change = -0.000000 (worsening)\n",
- "Reweighting progress: 88%|████████▊ | 464/528 [00:02<00:00, 270.16epoch/s, loss=1.47e-8, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 470: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 88%|████████▊ | 464/528 [00:02<00:00, 270.16epoch/s, loss=9.94e-10, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 480: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 88%|████████▊ | 464/528 [00:02<00:00, 270.16epoch/s, loss=5.02e-10, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 490: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 93%|█████████▎| 492/528 [00:02<00:00, 262.90epoch/s, loss=6.49e-10, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 500: Loss = 0.000000, Change = -0.000000 (worsening)\n",
- "Reweighting progress: 93%|█████████▎| 492/528 [00:02<00:00, 262.90epoch/s, loss=6.12e-11, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 510: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 93%|█████████▎| 492/528 [00:02<00:00, 262.90epoch/s, loss=1.61e-11, weights_mean=149, weights_std=357, weights_min=1]INFO:microcalibrate.reweight:Within 10% from targets: 100.00% \n",
- "\n",
- "INFO:microcalibrate.reweight:Epoch 520: Loss = 0.000000, Change = 0.000000 (improving)\n",
- "Reweighting progress: 100%|██████████| 528/528 [00:02<00:00, 236.82epoch/s, loss=1.61e-11, weights_mean=149, weights_std=357, weights_min=1]\n",
- "INFO:microcalibrate.reweight:Reweighting completed. Final sample size: 121\n"
+ "Target income_aged_20_30 (7.37e+08) differs from initial estimate (7.37e+05) by 3.00 orders of magnitude.\n",
+ "Target income_aged_71 is supported by only 0.83% of records in the loss matrix. This may make calibration unstable or ineffective.\n",
+ "Reweighting progress: 100%|██████████| 528/528 [00:00<00:00, 2153.20epoch/s, loss=1.61e-11, weights_mean=150, weights_std=357, weights_min=1]"
]
},
{
@@ -280,6 +116,13 @@
"Calibrated dataset size: 121\n",
"Number of calibrated weights: 121\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
}
],
"source": [
@@ -302,9 +145,231 @@
"print(f\"Number of calibrated weights: {len(calibrator.weights)}\")"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "t1pbyy86uic",
+ "metadata": {},
+ "source": [
+ "## Pre-calibration target assessment\n",
+ "\n",
+ "Before running the calibration, it's important to understand whether your targets are achievable and well-posed. The Calibration class provides a key method for this:\n",
+ "\n",
+ "**`assess_analytical_solution()`** - Analyzes the mathematical difficulty of achieving your target combination\n",
+ "\n",
+ "Additionally, you should manually check for:\n",
+ "- Order of magnitude differences between targets\n",
+ "- Highly correlated targets\n",
+ "- Redundant or conflicting constraints\n",
+ "\n",
+ "### Analytical solution assessment\n",
+ "\n",
+    "The analytical solution assessment gauges optimization difficulty by solving the least-squares problem with the Moore-Penrose pseudoinverse. It shows:\n",
+ "- How the loss increases as each target is added\n",
+ "- Which targets contribute most to the optimization difficulty\n",
+ "- Targets with large delta_loss values that complicate calibration\n",
+ "\n",
+ "This is particularly useful when you have many correlated targets or when targets overlap (e.g., total income includes all regional incomes)."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 7,
+ "id": "sxds7gqd37o",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Assessing analytical solution feasibility...\n",
+ "\n",
+ "============================================================\n",
+ "ANALYTICAL SOLUTION ASSESSMENT\n",
+ "============================================================\n",
+ "\n",
+ "This shows how the loss increases as each target is added:\n",
+ " target_added loss delta_loss\n",
+ "income_aged_20_30 0.000000e+00 NaN\n",
+ "income_aged_40_50 6.776264e-21 6.776264e-21\n",
+ " income_aged_71 4.535156e-21 -2.241108e-21\n",
+ "\n",
+ "Targets with large delta_loss values complicate the optimization problem.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Assess the analytical solution before calibration\n",
+ "print(\"Assessing analytical solution feasibility...\")\n",
+ "analytical_assessment = calibrator.assess_analytical_solution()\n",
+ "\n",
+ "# Display the assessment results\n",
+ "print(\"\\n\" + \"=\"*60)\n",
+ "print(\"ANALYTICAL SOLUTION ASSESSMENT\")\n",
+ "print(\"=\"*60)\n",
+ "print(\"\\nThis shows how the loss increases as each target is added:\")\n",
+ "print(analytical_assessment.to_string(index=False))\n",
+ "print(\"\\nTargets with large delta_loss values complicate the optimization problem.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "jfzf3z2w4mi",
+ "metadata": {},
+ "source": [
+ "### Understanding the assessment\n",
+ "\n",
+ "The assessment provides several key insights:\n",
+ "\n",
+ "1. **Condition number**: Values above 1e10 suggest numerical instability. High condition numbers mean small changes in targets can lead to large changes in weights.\n",
+ "\n",
+ "2. **Rank analysis**: If the rank is less than the number of targets, some targets are redundant or conflicting.\n",
+ "\n",
+ "3. **Recommendations**: The assessment provides specific guidance based on the analysis.\n",
+ "\n",
+ "### Target tolerance checking\n",
+ "\n",
+ "Another important pre-calibration check is to ensure your targets are on similar scales. Targets that differ by many orders of magnitude can cause convergence issues."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "l46es1rx02p",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Checking target scales...\n",
+ "\n",
+ "Target scale analysis:\n",
+ " income_aged_20_30: 7.37e+08 (order difference from median: 2.9)\n",
+ " income_aged_40_50: 9.77e+05 (order difference from median: 0.0)\n",
+ " income_aged_71: 4.36e+04 (order difference from median: -1.3)\n",
+ "\n",
+ "✓ All targets are within reasonable scale differences\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Check for order of magnitude differences in targets\n",
+ "print(\"Checking target scales...\")\n",
+ "\n",
+ "# Manually check for large differences in target scales\n",
+ "target_values = targets\n",
+ "target_names = [\"income_aged_20_30\", \"income_aged_40_50\", \"income_aged_71\"]\n",
+ "\n",
+ "# Calculate order of magnitude for each target\n",
+ "orders_of_magnitude = np.log10(np.abs(target_values) + 1e-10)\n",
+ "median_order = np.median(orders_of_magnitude)\n",
+ "\n",
+ "print(\"\\nTarget scale analysis:\")\n",
+ "for i, name in enumerate(target_names):\n",
+ " order_diff = orders_of_magnitude[i] - median_order\n",
+ " print(f\" {name}: {target_values[i]:.2e} (order difference from median: {order_diff:.1f})\")\n",
+ " \n",
+ "# Identify targets with extreme scale differences\n",
+ "large_scale_diff = np.abs(orders_of_magnitude - median_order) > 3\n",
+ "if np.any(large_scale_diff):\n",
+ " print(\"\\n⚠️ WARNING: The following targets differ by more than 3 orders of magnitude from the median:\")\n",
+ " for i in np.where(large_scale_diff)[0]:\n",
+ " print(f\" - {target_names[i]}: {target_values[i]:.2e}\")\n",
+ " print(\"\\nConsider rescaling these targets for better numerical stability.\")\n",
+ "else:\n",
+ " print(\"\\n✓ All targets are within reasonable scale differences\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5l3q4s17a2r",
+ "metadata": {},
+ "source": [
+ "### Interpreting tolerance warnings\n",
+ "\n",
+ "The tolerance check identifies targets that:\n",
+ "- Differ by more than 3 orders of magnitude from the median target\n",
+ "- May cause numerical instability during optimization\n",
+ "- Should potentially be rescaled or reviewed\n",
+ "\n",
+ "In our example, we have one target that is much larger than the others (income_aged_20_30), which the system warns about. This is expected since we artificially multiplied it by 1000 in our setup."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "gzefuat6c78",
+ "metadata": {},
+ "source": [
+ "### Excluded targets\n",
+ "\n",
+ "If the pre-calibration assessment identifies problematic targets, you can exclude them from calibration:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "ytre6ko9ru",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Calibration setup with excluded targets:\n",
+ "Original targets: ['income_aged_20_30' 'income_aged_40_50' 'income_aged_71']\n",
+ "Excluded targets: ['income_aged_71']\n",
+ "Active targets for calibration: ['income_aged_20_30' 'income_aged_40_50']\n",
+ "Number of active targets: 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Example: Excluding problematic targets\n",
+ "# If you identify targets that cause issues, you can exclude them:\n",
+ "\n",
+ "# Create a calibrator with excluded targets\n",
+ "calibrator_with_exclusions = Calibration(\n",
+ " estimate_matrix=targets_matrix,\n",
+ " weights=weights.copy(), \n",
+ " targets=targets,\n",
+ " excluded_targets=[\"income_aged_71\"], # Exclude the smallest target\n",
+ " noise_level=0.05,\n",
+ " epochs=100,\n",
+ " learning_rate=0.01,\n",
+ ")\n",
+ "\n",
+ "print(\"Calibration setup with excluded targets:\")\n",
+ "print(f\"Original targets: {calibrator_with_exclusions.original_target_names}\")\n",
+ "print(f\"Excluded targets: {calibrator_with_exclusions.excluded_targets}\")\n",
+ "print(f\"Active targets for calibration: {calibrator_with_exclusions.target_names}\")\n",
+ "print(f\"Number of active targets: {len(calibrator_with_exclusions.targets)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "wl7a94tjmnj",
+ "metadata": {},
+ "source": [
+ "## Best practices for pre-calibration assessment\n",
+ "\n",
+ "1. **Always run analytical assessment first** - This helps identify fundamental issues with your target specification before spending time on calibration.\n",
+ "\n",
+ "2. **Check target scales** - Targets that differ by many orders of magnitude should be rescaled or normalized to improve convergence.\n",
+ "\n",
+ "3. **Look for redundant targets** - If targets are highly correlated or mathematically dependent, consider removing redundant ones.\n",
+ "\n",
+ "4. **Consider the degrees of freedom** - Having more targets than observations makes exact calibration impossible. The system will find a best-fit solution.\n",
+ "\n",
+ "5. **Use exclusions strategically** - Temporarily exclude problematic targets to get initial calibration working, then gradually add them back.\n",
+ "\n",
+ "6. **Monitor condition numbers** - High condition numbers (>1e10) indicate numerical instability. Consider reformulating your targets or adding regularization.\n",
+ "\n",
+ "Now let's proceed with the actual calibration:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
"id": "9djvjpfrhxb",
"metadata": {},
"outputs": [
@@ -313,9 +378,9 @@
"output_type": "stream",
"text": [
"Target totals: [7.37032429e+08 9.76779350e+05 4.36479914e+04]\n",
- "Final calibrated totals: [7.37025225e+08 9.76778430e+05 4.36469951e+04]\n",
- "Difference: [-7.20317294e+03 -9.19358560e-01 -9.96308503e-01]\n",
- "Relative error: [-9.77321032e-04 -9.41214165e-05 -2.28259874e-03]\n"
+ "Final calibrated totals: [7.37025243e+08 9.76778358e+05 4.36469951e+04]\n",
+ "Difference: [-7.18606246e+03 -9.91859882e-01 -9.96308503e-01]\n",
+ "Relative error: [-0.000975 -0.00010154 -0.0022826 ]\n"
]
}
],
@@ -331,7 +396,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 12,
"id": "96cc818b",
"metadata": {},
"outputs": [],
@@ -346,7 +411,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 13,
"id": "923d79dd",
"metadata": {},
"outputs": [
@@ -385,79 +450,79 @@
"
\n",
" | 0 | \n",
" 0 | \n",
- " 0.340969 | \n",
+ " 0.339299 | \n",
" income_aged_20_30 | \n",
" 7.370324e+08 | \n",
- " 752926.625000 | \n",
- " -7.362795e+08 | \n",
- " 7.362795e+08 | \n",
- " 0.998978 | \n",
+ " 7.595331e+05 | \n",
+ " -7.362729e+08 | \n",
+ " 7.362729e+08 | \n",
+ " 0.998969 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0 | \n",
- " 0.340969 | \n",
+ " 0.339299 | \n",
" income_aged_40_50 | \n",
" 9.767794e+05 | \n",
- " 872380.750000 | \n",
- " -1.043986e+05 | \n",
- " 1.043986e+05 | \n",
- " 0.106880 | \n",
+ " 8.723239e+05 | \n",
+ " -1.044555e+05 | \n",
+ " 1.044555e+05 | \n",
+ " 0.106939 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0 | \n",
- " 0.340969 | \n",
+ " 0.339299 | \n",
" income_aged_71 | \n",
" 4.364799e+04 | \n",
- " 38570.789062 | \n",
- " -5.077203e+03 | \n",
- " 5.077203e+03 | \n",
- " 0.116322 | \n",
+ " 3.961747e+04 | \n",
+ " -4.030520e+03 | \n",
+ " 4.030520e+03 | \n",
+ " 0.092341 | \n",
"
\n",
" \n",
" | 3 | \n",
- " 10 | \n",
- " 0.332947 | \n",
+ " 52 | \n",
+ " 0.332147 | \n",
" income_aged_20_30 | \n",
" 7.370324e+08 | \n",
- " 832346.250000 | \n",
- " -7.362001e+08 | \n",
- " 7.362001e+08 | \n",
- " 0.998871 | \n",
+ " 1.320273e+06 | \n",
+ " -7.357122e+08 | \n",
+ " 7.357122e+08 | \n",
+ " 0.998209 | \n",
"
\n",
" \n",
" | 4 | \n",
- " 10 | \n",
- " 0.332947 | \n",
+ " 52 | \n",
+ " 0.332147 | \n",
" income_aged_40_50 | \n",
" 9.767794e+05 | \n",
- " 958994.500000 | \n",
- " -1.778488e+04 | \n",
- " 1.778488e+04 | \n",
- " 0.018208 | \n",
+ " 9.774676e+05 | \n",
+ " 6.882500e+02 | \n",
+ " 6.882500e+02 | \n",
+ " 0.000705 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " epoch loss target_name target estimate \\\n",
- "0 0 0.340969 income_aged_20_30 7.370324e+08 752926.625000 \n",
- "1 0 0.340969 income_aged_40_50 9.767794e+05 872380.750000 \n",
- "2 0 0.340969 income_aged_71 4.364799e+04 38570.789062 \n",
- "3 10 0.332947 income_aged_20_30 7.370324e+08 832346.250000 \n",
- "4 10 0.332947 income_aged_40_50 9.767794e+05 958994.500000 \n",
+ " epoch loss target_name target estimate \\\n",
+ "0 0 0.339299 income_aged_20_30 7.370324e+08 7.595331e+05 \n",
+ "1 0 0.339299 income_aged_40_50 9.767794e+05 8.723239e+05 \n",
+ "2 0 0.339299 income_aged_71 4.364799e+04 3.961747e+04 \n",
+ "3 52 0.332147 income_aged_20_30 7.370324e+08 1.320273e+06 \n",
+ "4 52 0.332147 income_aged_40_50 9.767794e+05 9.774676e+05 \n",
"\n",
" error abs_error rel_abs_error \n",
- "0 -7.362795e+08 7.362795e+08 0.998978 \n",
- "1 -1.043986e+05 1.043986e+05 0.106880 \n",
- "2 -5.077203e+03 5.077203e+03 0.116322 \n",
- "3 -7.362001e+08 7.362001e+08 0.998871 \n",
- "4 -1.778488e+04 1.778488e+04 0.018208 "
+ "0 -7.362729e+08 7.362729e+08 0.998969 \n",
+ "1 -1.044555e+05 1.044555e+05 0.106939 \n",
+ "2 -4.030520e+03 4.030520e+03 0.092341 \n",
+ "3 -7.357122e+08 7.357122e+08 0.998209 \n",
+ "4 6.882500e+02 6.882500e+02 0.000705 "
]
},
- "execution_count": 5,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -468,7 +533,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 14,
"id": "da828d30",
"metadata": {},
"outputs": [
@@ -489,103 +554,19 @@
"type": "scatter",
"x": [
0,
- 10,
- 20,
- 30,
- 40,
- 50,
- 60,
- 70,
- 80,
- 90,
- 100,
- 110,
- 120,
- 130,
- 140,
- 150,
- 160,
- 170,
- 180,
- 190,
- 200,
- 210,
- 220,
- 230,
- 240,
- 250,
+ 52,
+ 104,
+ 156,
+ 208,
260,
- 270,
- 280,
- 290,
- 300,
- 310,
- 320,
- 330,
- 340,
- 350,
- 360,
- 370,
- 380,
- 390,
- 400,
- 410,
- 420,
- 430,
- 440,
- 450,
- 460,
- 470,
- 480,
- 490,
- 500,
- 510,
+ 312,
+ 364,
+ 416,
+ 468,
520
],
"xaxis": "x",
"y": [
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
- 737032448,
737032448,
737032448,
737032448,
@@ -609,113 +590,29 @@
"type": "scatter",
"x": [
0,
- 10,
- 20,
- 30,
- 40,
- 50,
- 60,
- 70,
- 80,
- 90,
- 100,
- 110,
- 120,
- 130,
- 140,
- 150,
- 160,
- 170,
- 180,
- 190,
- 200,
- 210,
- 220,
- 230,
- 240,
- 250,
+ 52,
+ 104,
+ 156,
+ 208,
260,
- 270,
- 280,
- 290,
- 300,
- 310,
- 320,
- 330,
- 340,
- 350,
- 360,
- 370,
- 380,
- 390,
- 400,
- 410,
- 420,
- 430,
- 440,
- 450,
- 460,
- 470,
- 480,
- 490,
- 500,
- 510,
+ 312,
+ 364,
+ 416,
+ 468,
520
],
"xaxis": "x",
"y": [
- 752926.625,
- 832346.25,
- 921668.5,
- 1023621.5625,
- 1141370.5,
- 1278621.25,
- 1439833.625,
- 1630507.875,
- 1857545,
- 2129699.25,
- 2458166.5,
- 2857359,
- 3345938.25,
- 3948207.75,
- 4695982.5,
- 5631132,
- 6809027,
- 8303248.5,
- 10212014,
- 12666988,
- 15845385,
- 19986564,
- 25414842,
- 32570676,
- 42053112,
- 54676848,
- 71547424,
- 94156512,
- 124493704,
- 165154752,
- 219382320,
- 290860512,
- 382795616,
- 495131744,
- 617828800,
- 721720128,
- 770444416,
- 766268800,
- 745850304,
- 733191488,
- 731150848,
- 734141312,
- 737060736,
- 738073920,
- 737737984,
- 737144384,
- 736858624,
- 736877632,
- 736992192,
- 737061056,
- 737064960,
- 737042432,
+ 759533.0625,
+ 1320273,
+ 2631305,
+ 6359750.5,
+ 19233504,
+ 72160320,
+ 309902368,
+ 773134208,
+ 736156928,
+ 736861504,
737027328
],
"yaxis": "y"
@@ -730,103 +627,19 @@
"type": "scatter",
"x": [
0,
- 10,
- 20,
- 30,
- 40,
- 50,
- 60,
- 70,
- 80,
- 90,
- 100,
- 110,
- 120,
- 130,
- 140,
- 150,
- 160,
- 170,
- 180,
- 190,
- 200,
- 210,
- 220,
- 230,
- 240,
- 250,
+ 52,
+ 104,
+ 156,
+ 208,
260,
- 270,
- 280,
- 290,
- 300,
- 310,
- 320,
- 330,
- 340,
- 350,
- 360,
- 370,
- 380,
- 390,
- 400,
- 410,
- 420,
- 430,
- 440,
- 450,
- 460,
- 470,
- 480,
- 490,
- 500,
- 510,
+ 312,
+ 364,
+ 416,
+ 468,
520
],
"xaxis": "x2",
"y": [
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
- 976779.375,
976779.375,
976779.375,
976779.375,
@@ -850,113 +663,29 @@
"type": "scatter",
"x": [
0,
- 10,
- 20,
- 30,
- 40,
- 50,
- 60,
- 70,
- 80,
- 90,
- 100,
- 110,
- 120,
- 130,
- 140,
- 150,
- 160,
- 170,
- 180,
- 190,
- 200,
- 210,
- 220,
- 230,
- 240,
- 250,
+ 52,
+ 104,
+ 156,
+ 208,
260,
- 270,
- 280,
- 290,
- 300,
- 310,
- 320,
- 330,
- 340,
- 350,
- 360,
- 370,
- 380,
- 390,
- 400,
- 410,
- 420,
- 430,
- 440,
- 450,
- 460,
- 470,
- 480,
- 490,
- 500,
- 510,
+ 312,
+ 364,
+ 416,
+ 468,
520
],
"xaxis": "x2",
"y": [
- 872380.75,
- 958994.5,
- 1003960.5,
- 984014.625,
- 966661.75,
- 975252.125,
- 980736.625,
- 976016.625,
- 975696.5,
- 977566.9375,
- 976717.0625,
- 976543.625,
- 976956,
- 976731.375,
- 976753.5,
- 976814.25,
- 976756.8125,
- 976785,
- 976779.625,
- 976775,
- 976781.1875,
- 976776.6875,
- 976779,
- 976778.25,
- 976778.375,
- 976778.5625,
+ 872323.875,
+ 977467.625,
+ 976435.125,
+ 976771.8125,
+ 976777.25,
976778.25,
976778.375,
976778.375,
976778.375,
976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
- 976778.375,
976778.375
],
"yaxis": "y2"
@@ -970,113 +699,29 @@
"type": "scatter",
"x": [
0,
- 10,
- 20,
- 30,
- 40,
- 50,
- 60,
- 70,
- 80,
- 90,
- 100,
- 110,
- 120,
- 130,
- 140,
- 150,
- 160,
- 170,
- 180,
- 190,
- 200,
- 210,
- 220,
- 230,
- 240,
- 250,
+ 52,
+ 104,
+ 156,
+ 208,
260,
- 270,
- 280,
- 290,
- 300,
- 310,
- 320,
- 330,
- 340,
- 350,
- 360,
- 370,
- 380,
- 390,
- 400,
- 410,
- 420,
- 430,
- 440,
- 450,
- 460,
- 470,
- 480,
- 490,
- 500,
- 510,
+ 312,
+ 364,
+ 416,
+ 468,
520
],
"xaxis": "x3",
"y": [
- 0.9989784349019597,
- 0.998870678960935,
- 0.9987494872138927,
- 0.9986111580768557,
- 0.9984513972171819,
- 0.998265176447157,
- 0.9980464447272205,
- 0.9977877393601536,
- 0.9974796971218287,
- 0.9971104403126632,
- 0.9966647784549101,
- 0.9961231571177719,
- 0.9954602565204836,
- 0.9946431018597434,
- 0.993628526786381,
- 0.9923597230823683,
- 0.9907615641367258,
- 0.9887342158102651,
- 0.9861444173486376,
- 0.9828135273631806,
- 0.9785011025728951,
- 0.9728823825135037,
- 0.9655173363547761,
- 0.955808355400928,
- 0.9429426586114157,
- 0.9258148699580727,
- 0.9029250012070025,
- 0.8722491631738851,
- 0.8310878926188063,
- 0.7759192930404063,
- 0.7023437426733212,
- 0.6053626773295061,
- 0.4806258299227553,
- 0.3282090288649001,
- 0.16173459977707794,
- 0.020775638903756923,
- 0.04533310316345801,
- 0.0396676592453091,
- 0.011963999717960857,
- 0.005211385211631822,
- 0.007980109988318995,
- 0.003922671258023093,
- 0.00003838094249006524,
- 0.0014130612604996192,
- 0.0009572658597522154,
- 0.00015187391044145727,
- 0.00023584307647741447,
- 0.00021005316715716702,
- 0.00005461903354355438,
- 0.00003881511604764517,
- 0.00004411203345012023,
- 0.000013546214996493614,
+ 0.9989694713379837,
+ 0.9982086636706665,
+ 0.9964298654596006,
+ 0.9913711390627947,
+ 0.973904128573726,
+ 0.9020934285921697,
+ 0.5795268324468668,
+ 0.04898259241905181,
+ 0.001187898853538671,
+ 0.00023193551445919514,
0.000006946776921278777
],
"yaxis": "y3"
@@ -1090,113 +735,29 @@
"type": "scatter",
"x": [
0,
- 10,
- 20,
- 30,
- 40,
- 50,
- 60,
- 70,
- 80,
- 90,
- 100,
- 110,
- 120,
- 130,
- 140,
- 150,
- 160,
- 170,
- 180,
- 190,
- 200,
- 210,
- 220,
- 230,
- 240,
- 250,
+ 52,
+ 104,
+ 156,
+ 208,
260,
- 270,
- 280,
- 290,
- 300,
- 310,
- 320,
- 330,
- 340,
- 350,
- 360,
- 370,
- 380,
- 390,
- 400,
- 410,
- 420,
- 430,
- 440,
- 450,
- 460,
- 470,
- 480,
- 490,
- 500,
- 510,
+ 312,
+ 364,
+ 416,
+ 468,
520
],
"xaxis": "x4",
"y": [
- 0.10688045598833411,
- 0.018207668440992624,
- 0.02782729211496711,
- 0.0074072509976984315,
- 0.010358147662567097,
- 0.001563556765313559,
- 0.004051324281903475,
- 0.000780882581596279,
- 0.0011086177981593848,
- 0.0008062849402404714,
- 0.00006379383266564161,
- 0.0002413544000148447,
- 0.00018082384264102628,
- 0.00004914108674745513,
- 0.000026490117074800028,
- 0.00003570407083994786,
- 0.000023098870202905338,
- 0.000005758721103217397,
- 2.5594316014299544e-7,
- 0.00000447900530250242,
- 0.000001855587911036717,
- 0.000002751388971537201,
- 3.839147402144932e-7,
- 0.0000011517442206434795,
- 0.0000010237726405719818,
- 8.318152704647351e-7,
+ 0.10693868305726664,
+ 0.0007046115198736664,
+ 0.0003524337315169047,
+ 0.000007742280594325611,
+ 0.0000021755168612154613,
0.0000011517442206434795,
0.0000010237726405719818,
0.0000010237726405719818,
0.0000010237726405719818,
0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
- 0.0000010237726405719818,
0.0000010237726405719818
],
"yaxis": "y4"
@@ -2259,7 +1820,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 15,
"id": "fefaf031",
"metadata": {},
"outputs": [
@@ -2323,7 +1884,7 @@
"2 income_aged_71 4.364799e+04 4.364699e+04 -0.000023"
]
},
- "execution_count": 7,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -2332,11 +1893,40 @@
"summary = calibrator.summary()\n",
"summary"
]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "zizmg24l2dh",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "The Calibration class provides comprehensive tools for survey weight calibration:\n",
+ "\n",
+ "1. **Pre-calibration assessment**:\n",
+ " - Analytical solution feasibility analysis\n",
+ " - Target tolerance checking\n",
+ " - Correlation analysis\n",
+ " - Target exclusion capabilities\n",
+ "\n",
+ "2. **Standard calibration**:\n",
+ " - Gradient-based optimization\n",
+ " - Multi-target support\n",
+ " - Progress monitoring\n",
+ " - Performance logging\n",
+ "\n",
+ "3. **Advanced features**:\n",
+ " - L0 regularization for sparse weights\n",
+ " - Hyperparameter tuning (see [hyperparameter tuning notebook](hyperparameter_tuning.ipynb))\n",
+ " - Robustness evaluation (see [robustness evaluation notebook](robustness_evaluation.ipynb))\n",
+ "\n",
+ "By using the pre-calibration assessment tools, you can identify and address potential issues before running the calibration, saving time and improving results. The L0 regularization feature allows you to maintain calibration accuracy while significantly reducing dataset size, which is valuable for large-scale applications."
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "pe",
+ "display_name": "pe3.13",
"language": "python",
"name": "python3"
},
@@ -2350,7 +1940,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.13.0"
}
},
"nbformat": 4,
diff --git a/docs/intro.md b/docs/intro.md
index 9320689..9e4e4cf 100644
--- a/docs/intro.md
+++ b/docs/intro.md
@@ -1,15 +1,92 @@
## MicroCalibrate
-MicroCalibrate is a framework that enables survey weight calibration integrating informative diagnostics and visualizations for the user to easily identify and work through hard-to-hit targets. A dashboard allows the user to explore the reweighting process on sample data as well as supporting uploading or pasting a link to the performance log of a specific reweighting job.
-
-To make sure that your data can be loaded and visualized by the dashboard, it must be in csv format and contain the following columns:
-- epoch (int)
-- target_name (str)
-- target (float)
-- estimate (float)
-- error (float)
-- abs_error (float)
-- rel_abs_error (float)
-- loss (float)
-
-To access the performance dashboard see: https://microcalibrate.vercel.app/
+MicroCalibrate is a comprehensive framework for survey weight calibration that combines traditional calibration techniques with modern machine learning approaches. It enables users to adjust sample weights to match population targets while providing advanced features for sparsity, optimization, and robustness analysis.
+
+## Key features
+
+### 1. Core calibration
+- **Survey weight adjustment**: The system calibrates sample weights to match known population totals using gradient-based optimization.
+- **Multi-target support**: The calibration process can handle multiple calibration targets simultaneously.
+- **Custom estimate functions**: Users can use either estimate matrices or custom functions for flexible calibration scenarios.
+
+### 2. L0 regularization for sparsity
+- **Dataset reduction**: The algorithm automatically identifies and zeros out unnecessary weights to reduce dataset size.
+- **Sparse weight generation**: The system creates compact datasets while maintaining calibration accuracy.
+- **Configurable sparsity**: Users can control the trade-off between dataset size and calibration precision.
+
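The L0 documentation in this repository describes the sparsity mechanism as a differentiable approximation of the L0 norm via the Hard Concrete distribution. A minimal NumPy sketch of such a gate is shown below; it illustrates the idea only, not the package's actual implementation, and all parameter values (`log_alpha`, `beta`, `gamma`, `zeta`) are assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def hard_concrete_gate(log_alpha, beta=0.5, gamma=-0.1, zeta=1.1, n=1):
    """Sample hard-concrete gates in [0, 1]; values can be exactly 0 or 1."""
    u = rng.uniform(1e-6, 1 - 1e-6, size=n)
    # Sigmoid of a shifted, tempered logistic sample.
    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1.0 - u) + log_alpha) / beta))
    s_bar = s * (zeta - gamma) + gamma   # stretch slightly beyond [0, 1]
    return np.clip(s_bar, 0.0, 1.0)     # clipping yields exact zeros and ones

gates = hard_concrete_gate(log_alpha=0.0, n=10_000)
print(f"exact zeros: {np.mean(gates == 0.0):.1%}")
```

Multiplying record weights by such gates during training is what lets many weights become exactly zero rather than merely small.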
+### 3. Automatic hyperparameter tuning
+- **Cross-validation**: The system uses holdout validation to find optimal regularization parameters.
+- **Multi-objective optimization**: The optimization process balances calibration loss, accuracy, and sparsity.
+- **Optuna integration**: The package leverages state-of-the-art hyperparameter optimization through Optuna.
+
+### 4. Robustness evaluation
+- **Generalization assessment**: The evaluation module assesses how well calibration performs on unseen targets.
+- **Stability analysis**: The system identifies targets that are difficult to calibrate reliably.
+- **Actionable recommendations**: Users receive specific suggestions for improving calibration robustness.
+
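The holdout idea behind this assessment can be sketched in a few lines: repeatedly set aside a fraction of the targets, calibrate on the rest, and measure error on the held-out ones. This is an illustration of the concept only, not the internals of `evaluate_holdout_robustness`, and the target names are placeholders:

```python
import numpy as np

rng = np.random.default_rng(0)
target_names = np.array(["income_18-30", "employed_18-30", "income_31-50",
                         "employed_31-50", "total_income", "total_employed"])

def holdout_splits(names, n_splits=3, holdout_frac=0.2):
    """Yield (train, holdout) target-name splits for robustness checks."""
    n_hold = max(1, int(len(names) * holdout_frac))
    for _ in range(n_splits):
        held = rng.choice(names, size=n_hold, replace=False)
        yield np.setdiff1d(names, held), held

for train, held in holdout_splits(target_names):
    # In practice: calibrate on `train` only, then score error on `held`.
    assert set(train) | set(held) == set(target_names)
```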
+### 5. Target analysis
+- **Pre-calibration assessment**: The system identifies problematic targets before calibration begins.
+- **Analytical solutions**: The analysis helps users understand the mathematical difficulty of target combinations.
+- **Order of magnitude warnings**: The system detects targets that differ significantly from initial estimates.
+
+### 6. Performance monitoring
+- **Detailed logging**: The system tracks calibration progress across epochs.
+- **Performance dashboard**: Users can visualize calibration results at https://microcalibrate.vercel.app/.
+- **CSV export**: The system can save detailed performance metrics for further analysis.
+
+## Dashboard requirements
+
+To use the performance dashboard for visualization, users must ensure their calibration log CSV contains the following fields:
+- epoch (int): The iteration number during calibration.
+- target_name (str): The name of each calibration target.
+- target (float): The target value to achieve.
+- estimate (float): The estimated value at each epoch.
+- error (float): The difference between target and estimate.
+- abs_error (float): The absolute error value.
+- rel_abs_error (float): The relative absolute error.
+- loss (float): The loss value at each epoch.
+
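Before uploading a log to the dashboard, a quick schema check catches missing columns early. A small sketch using pandas (the in-memory frame stands in for `pd.read_csv` on your own log file):

```python
import pandas as pd

# The column names the dashboard expects, per the list above.
REQUIRED = ["epoch", "target_name", "target", "estimate",
            "error", "abs_error", "rel_abs_error", "loss"]

def validate_log(df: pd.DataFrame) -> list[str]:
    """Return the required dashboard columns missing from a calibration log."""
    return [col for col in REQUIRED if col not in df.columns]

log = pd.DataFrame(columns=REQUIRED)  # stand-in for reading your CSV
missing = validate_log(log)
assert not missing, f"log is missing columns: {missing}"
```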
+## Getting started
+
+```python
+from microcalibrate import Calibration
+import numpy as np
+import pandas as pd
+
+# Basic calibration
+cal = Calibration(
+ weights=initial_weights,
+ targets=target_values,
+ estimate_matrix=contribution_matrix
+)
+performance = cal.calibrate()
+
+# With L0 regularization
+cal = Calibration(
+ weights=initial_weights,
+ targets=target_values,
+ estimate_matrix=contribution_matrix,
+ regularize_with_l0=True
+)
+performance = cal.calibrate()
+sparse_weights = cal.sparse_weights
+
+# Hyperparameter tuning
+best_params = cal.tune_l0_hyperparameters(n_trials=30)
+
+# Robustness evaluation
+robustness = cal.evaluate_holdout_robustness()
+print(robustness['recommendation'])
+```
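The snippet above assumes `initial_weights`, `target_values`, and `contribution_matrix` already exist. A minimal synthetic construction, following the conventions used in the package's example notebooks (shapes and distribution parameters are illustrative):

```python
import numpy as np
import pandas as pd

n_samples = 1_000
rng = np.random.default_rng(0)

# One row per record, one column per target: each entry is the record's
# contribution to that target's total.
income = rng.lognormal(10.5, 0.8, n_samples)
employed = rng.binomial(1, 0.65, n_samples)
contribution_matrix = pd.DataFrame(
    {"total_income": income, "total_employed": employed}
)

initial_weights = np.ones(n_samples)                    # start from uniform weights
target_values = contribution_matrix.sum().values * 1.1  # population totals to hit
```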
+
+## Documentation structure
+
+- **[Basic calibration](calibration.ipynb)**: Core calibration concepts and basic usage
+- **[L0 regularization](l0_regularization.ipynb)**: Creating sparse weights for dataset reduction
+- **[Robustness evaluation](robustness_evaluation.ipynb)**: Assessing calibration stability and generalization
+
+## Support
+
+- **GitHub issues**: https://github.com/PolicyEngine/microcalibrate/issues
+- **Documentation**: https://policyengine.github.io/microcalibrate/
+- **Performance dashboard**: https://microcalibrate.vercel.app/
\ No newline at end of file
diff --git a/docs/l0_regularization.ipynb b/docs/l0_regularization.ipynb
new file mode 100644
index 0000000..2bb6b27
--- /dev/null
+++ b/docs/l0_regularization.ipynb
@@ -0,0 +1,1117 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# L0 regularization for sparse weights\n",
+ "\n",
+ "L0 regularization is a powerful technique that creates sparse weights during calibration, effectively reducing the dataset size by setting many weights to zero. This is particularly useful when:\n",
+ "\n",
+ "- You need to reduce computational costs in downstream processing\n",
+ "- You want to identify the most important records in your dataset\n",
+ "- You need a smaller, representative sample that still matches population targets\n",
+ "\n",
+ "## How L0 regularization works\n",
+ "\n",
+ "L0 regularization adds a penalty term to the calibration loss that encourages weights to be exactly zero. Unlike L1 regularization (which shrinks weights), L0 creates truly sparse solutions by using a differentiable approximation of the L0 norm through the Hard Concrete distribution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from microcalibrate import Calibration\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import logging\n",
+ "\n",
+ "calibration_logger = logging.getLogger(\"microcalibrate.calibration\")\n",
+ "calibration_logger.setLevel(logging.WARNING)\n",
+ "\n",
+ "np.random.seed(42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Example 1: Basic L0 regularization\n",
+ "\n",
+ "Let's create a synthetic dataset and apply L0 regularization to reduce its size while maintaining calibration accuracy."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset size: 5000 records\n",
+ "Number of targets: 10\n",
+ "Target names: ['income_18-30', 'employed_18-30', 'income_31-50', 'employed_31-50', 'income_51-65', 'employed_51-65', 'income_65+', 'employed_65+', 'total_income', 'total_employed']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create synthetic data\n",
+ "n_samples = 5000\n",
+ "n_targets = 10\n",
+ "\n",
+ "# Generate random data with some structure\n",
+ "age_groups = np.random.choice(['18-30', '31-50', '51-65', '65+'], n_samples)\n",
+ "income = np.random.lognormal(10.5, 0.8, n_samples) # Log-normal income distribution\n",
+ "employed = np.random.binomial(1, 0.65, n_samples)\n",
+ "\n",
+ "# Create estimate matrix with various demographic combinations\n",
+ "estimate_matrix = pd.DataFrame()\n",
+ "for age in ['18-30', '31-50', '51-65', '65+']:\n",
+ " mask = age_groups == age\n",
+ " estimate_matrix[f'income_{age}'] = mask * income\n",
+ " estimate_matrix[f'employed_{age}'] = mask * employed\n",
+ "\n",
+ "estimate_matrix['total_income'] = income\n",
+ "estimate_matrix['total_employed'] = employed\n",
+ "\n",
+ "# Set realistic targets (scaled population values)\n",
+ "targets = estimate_matrix.sum().values * 1.1 # 10% higher than unweighted\n",
+ "\n",
+ "print(f\"Dataset size: {n_samples} records\")\n",
+ "print(f\"Number of targets: {len(targets)}\")\n",
+ "print(f\"Target names: {list(estimate_matrix.columns)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Comparing standard vs L0 calibration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running standard calibration...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2660.14epoch/s, loss=13.2, weights_mean=5.1, weights_std=2.42, weights_min=0.842]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Standard calibration results:\n",
+ "Non-zero weights: 5000 (100.0%)\n",
+ "Weight range: [0.835, 9.164]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Standard calibration (no sparsity)\n",
+ "weights_init = np.ones(n_samples)\n",
+ "\n",
+ "cal_standard = Calibration(\n",
+ " weights=weights_init.copy(),\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=200,\n",
+ " learning_rate=1e-3,\n",
+ " regularize_with_l0=False\n",
+ ")\n",
+ "\n",
+ "print(\"Running standard calibration...\")\n",
+ "perf_standard = cal_standard.calibrate()\n",
+ "weights_standard = cal_standard.weights\n",
+ "\n",
+ "print(f\"\\nStandard calibration results:\")\n",
+ "print(f\"Non-zero weights: {np.sum(weights_standard != 0)} ({100*np.mean(weights_standard != 0):.1f}%)\")\n",
+ "print(f\"Weight range: [{weights_standard.min():.3f}, {weights_standard.max():.3f}]\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running L0 regularized calibration...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1776.92epoch/s, loss=13.5, weights_mean=5.13, weights_std=2.43, weights_min=0.84] \n",
+ "Sparse reweighting progress: 100%|██████████| 400/400 [00:00<00:00, 722.65epoch/s, loss=0.0103, loss_rel_change=-0.691]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "L0 calibration results:\n",
+ "Non-zero weights: 1998 (40.0%)\n",
+ "Dataset reduction: 60.0%\n",
+ "Weight range: [0.010, 14.061]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# L0 regularized calibration\n",
+ "cal_l0 = Calibration(\n",
+ " weights=weights_init.copy(),\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=200,\n",
+ " learning_rate=1e-3,\n",
+ " regularize_with_l0=True,\n",
+ " l0_lambda=5e-6, # Regularization strength\n",
+ " init_mean=0.999, # Start with most weights active\n",
+ " temperature=0.5, # Controls sparsity gradient\n",
+ ")\n",
+ "\n",
+ "print(\"Running L0 regularized calibration...\")\n",
+ "perf_l0 = cal_l0.calibrate()\n",
+ "weights_l0 = cal_l0.sparse_weights\n",
+ "\n",
+ "print(f\"\\nL0 calibration results:\")\n",
+ "print(f\"Non-zero weights: {np.sum(weights_l0 != 0)} ({100*np.mean(weights_l0 != 0):.1f}%)\")\n",
+ "print(f\"Dataset reduction: {100*(1 - np.mean(weights_l0 != 0)):.1f}%\")\n",
+ "print(f\"Weight range: [{weights_l0[weights_l0>0].min():.3f}, {weights_l0.max():.3f}]\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Hyperparameter tuning for L0 regularization\n",
+ "\n",
+    "Finding the optimal L0 regularization parameters is crucial for achieving the right balance between sparsity and calibration accuracy. This section demonstrates how to use the automatic hyperparameter tuning feature to find the best parameters for your specific dataset.\n",
+ "\n",
+ "## Why hyperparameter tuning matters\n",
+ "\n",
+ "L0 regularization has three key parameters that interact in complex ways:\n",
+ "- **l0_lambda**: Controls the strength of sparsity penalty\n",
+ "- **init_mean**: Sets the initial proportion of active weights\n",
+ "- **temperature**: Determines how \"hard\" the sparsity decisions are\n",
+ "\n",
+ "Manual tuning can be time-consuming and may miss optimal combinations. The automatic tuning uses Optuna to efficiently search the parameter space."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Basic hyperparameter tuning\n",
+ "\n",
+ "Let's start with a simple tuning run to find good L0 parameters. The tuning process will:\n",
+ "1. Create multiple holdout sets for cross-validation\n",
+ "2. Try different parameter combinations\n",
+ "3. Evaluate each combination on both training and validation targets\n",
+ "4. Select the best parameters based on a multi-objective criterion"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:microcalibrate.hyperparameter_tuning:Multi-holdout hyperparameter tuning:\n",
+ " - 3 holdout sets\n",
+ " - 2 targets per holdout (20.0%)\n",
+ " - Aggregation: mean\n",
+ "\n",
+ "WARNING:microcalibrate.hyperparameter_tuning:Data leakage warning: Targets often share overlapping information (e.g., geographic breakdowns like 'snap in CA' and 'snap in US'). Holdout validation may not provide complete isolation between training and validation sets. The robustness metrics should be interpreted with this limitation in mind - they may overestimate the model's true generalization performance.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Starting hyperparameter tuning...\n",
+ "This will take a few minutes as it explores different parameter combinations.\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "85fd8b3742ea4b07b12ca58d3ea857ff",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/20 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1549.46epoch/s, loss=17.2, weights_mean=5.7, weights_std=2.74, weights_min=0.962]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 347.13epoch/s, loss=0.0279, loss_rel_change=-0.485]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1961.55epoch/s, loss=68, weights_mean=10.2, weights_std=3.83, weights_min=1.05]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 357.60epoch/s, loss=0.0297, loss_rel_change=-0.998]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1983.82epoch/s, loss=148, weights_mean=14.5, weights_std=4.58, weights_min=1.48]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 296.82epoch/s, loss=0.0652, loss_rel_change=-0.999]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 0:\n",
+ " Objectives by holdout: ['109.9197', '110.0278', '10.0005']\n",
+ " Mean objective: 76.6493\n",
+ " Mean val accuracy: 33.33% (±47.14%)\n",
+ " Sparsity: 0.00%\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1506.98epoch/s, loss=252, weights_mean=18.6, weights_std=5.22, weights_min=2.08]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 342.20epoch/s, loss=0.0483, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1877.67epoch/s, loss=377, weights_mean=22.5, weights_std=5.71, weights_min=6.05]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 293.54epoch/s, loss=0.0579, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 760.92epoch/s, loss=519, weights_mean=26.3, weights_std=6.09, weights_min=6.69]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 285.17epoch/s, loss=0.122, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1814.74epoch/s, loss=680, weights_mean=29.9, weights_std=6.44, weights_min=8.01]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 333.69epoch/s, loss=0.012, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1365.93epoch/s, loss=850, weights_mean=33.2, weights_std=6.73, weights_min=9.59]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 306.98epoch/s, loss=0.00717, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 684.75epoch/s, loss=1.03e+3, weights_mean=36.5, weights_std=7.01, weights_min=14.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 330.39epoch/s, loss=0.0122, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1853.43epoch/s, loss=1.22e+3, weights_mean=39.6, weights_std=7.17, weights_min=15.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 270.74epoch/s, loss=0.18, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1673.64epoch/s, loss=1.41e+3, weights_mean=42.5, weights_std=7.37, weights_min=17.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 314.66epoch/s, loss=0.162, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1858.19epoch/s, loss=1.61e+3, weights_mean=45.4, weights_std=7.5, weights_min=17.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 300.71epoch/s, loss=0.148, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1888.00epoch/s, loss=1.81e+3, weights_mean=48, weights_std=7.69, weights_min=19.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 320.72epoch/s, loss=0.154, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1826.56epoch/s, loss=2e+3, weights_mean=50.5, weights_std=7.84, weights_min=25.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 345.48epoch/s, loss=0.144, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1912.76epoch/s, loss=2.2e+3, weights_mean=52.9, weights_std=8.03, weights_min=25.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 346.88epoch/s, loss=0.139, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1808.84epoch/s, loss=2.41e+3, weights_mean=55.2, weights_std=8.14, weights_min=27]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 297.11epoch/s, loss=0.0109, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 2018.82epoch/s, loss=2.6e+3, weights_mean=57.4, weights_std=8.17, weights_min=31.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 266.95epoch/s, loss=0.0126, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1980.28epoch/s, loss=2.81e+3, weights_mean=59.5, weights_std=8.27, weights_min=31.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 311.84epoch/s, loss=0.0114, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 5:\n",
+ " Objectives by holdout: ['110.2807', '110.2407', '110.2354']\n",
+ " Mean objective: 110.2523\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 3.16%\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 2016.38epoch/s, loss=3.01e+3, weights_mean=61.5, weights_std=8.41, weights_min=30.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 300.27epoch/s, loss=0.0303, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1813.63epoch/s, loss=3.19e+3, weights_mean=63.3, weights_std=8.48, weights_min=30.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 348.51epoch/s, loss=0.0383, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1943.34epoch/s, loss=3.39e+3, weights_mean=65.1, weights_std=8.49, weights_min=34]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 291.53epoch/s, loss=0.0413, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1881.48epoch/s, loss=3.58e+3, weights_mean=66.9, weights_std=8.61, weights_min=37.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 354.18epoch/s, loss=0.0203, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1999.21epoch/s, loss=3.75e+3, weights_mean=68.5, weights_std=8.63, weights_min=37.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 355.29epoch/s, loss=0.00945, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 731.52epoch/s, loss=3.94e+3, weights_mean=70.2, weights_std=8.67, weights_min=38.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 318.25epoch/s, loss=0.0231, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 2034.75epoch/s, loss=4.12e+3, weights_mean=71.7, weights_std=8.75, weights_min=41.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 297.68epoch/s, loss=0.658, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1994.67epoch/s, loss=4.27e+3, weights_mean=73.2, weights_std=8.79, weights_min=40.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 316.13epoch/s, loss=0.665, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 737.44epoch/s, loss=4.44e+3, weights_mean=74.5, weights_std=8.83, weights_min=42.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 311.54epoch/s, loss=0.825, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1816.41epoch/s, loss=4.6e+3, weights_mean=75.8, weights_std=8.89, weights_min=43.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 283.56epoch/s, loss=0.359, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1865.69epoch/s, loss=4.75e+3, weights_mean=77, weights_std=8.86, weights_min=44.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 369.63epoch/s, loss=0.402, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 2012.95epoch/s, loss=4.91e+3, weights_mean=78.2, weights_std=8.88, weights_min=44.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 319.55epoch/s, loss=0.463, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 2031.77epoch/s, loss=5.05e+3, weights_mean=79.3, weights_std=8.87, weights_min=48.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 296.65epoch/s, loss=0.386, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1906.16epoch/s, loss=5.18e+3, weights_mean=80.4, weights_std=8.89, weights_min=51.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 262.07epoch/s, loss=0.375, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1541.84epoch/s, loss=5.32e+3, weights_mean=81.4, weights_std=8.93, weights_min=51]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 221.70epoch/s, loss=0.472, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 10:\n",
+ " Objectives by holdout: ['110.1258', '110.1155', '110.1133']\n",
+ " Mean objective: 110.1182\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 0.08%\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1374.11epoch/s, loss=5.45e+3, weights_mean=82.4, weights_std=8.92, weights_min=52.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 224.13epoch/s, loss=3.02, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1509.43epoch/s, loss=5.57e+3, weights_mean=83.3, weights_std=8.89, weights_min=50]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 232.38epoch/s, loss=3.29, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1050.00epoch/s, loss=5.69e+3, weights_mean=84.1, weights_std=8.96, weights_min=50.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 227.71epoch/s, loss=3.33, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1599.43epoch/s, loss=5.79e+3, weights_mean=84.9, weights_std=8.94, weights_min=53.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 232.27epoch/s, loss=3.73, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1640.21epoch/s, loss=5.9e+3, weights_mean=85.7, weights_std=9, weights_min=52.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 205.89epoch/s, loss=3.88, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1071.53epoch/s, loss=6.03e+3, weights_mean=86.5, weights_std=9.03, weights_min=56.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 178.22epoch/s, loss=3.88, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1378.17epoch/s, loss=6.13e+3, weights_mean=87.2, weights_std=9.02, weights_min=58.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 187.56epoch/s, loss=1.11, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 990.38epoch/s, loss=6.22e+3, weights_mean=88, weights_std=9.08, weights_min=58.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 235.56epoch/s, loss=1.28, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1169.07epoch/s, loss=6.34e+3, weights_mean=88.6, weights_std=9.05, weights_min=56.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 189.05epoch/s, loss=1.24, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1263.08epoch/s, loss=6.42e+3, weights_mean=89.2, weights_std=9, weights_min=59.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 186.46epoch/s, loss=1.8, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1296.58epoch/s, loss=6.5e+3, weights_mean=89.8, weights_std=9.05, weights_min=59.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 187.17epoch/s, loss=1.74, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1361.70epoch/s, loss=6600.0, weights_mean=90.4, weights_std=9.1, weights_min=59.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 208.56epoch/s, loss=1.91, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1469.60epoch/s, loss=6.67e+3, weights_mean=90.9, weights_std=9.05, weights_min=57.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 192.95epoch/s, loss=1.39, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1604.42epoch/s, loss=6.74e+3, weights_mean=91.4, weights_std=9.13, weights_min=59]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 193.51epoch/s, loss=1.31, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1444.58epoch/s, loss=6.83e+3, weights_mean=91.9, weights_std=9.18, weights_min=62.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 189.72epoch/s, loss=1.39, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 15:\n",
+ " Objectives by holdout: ['9.9997', '10.0021', '9.9987']\n",
+ " Mean objective: 10.0001\n",
+ " Mean val accuracy: 100.00% (±0.00%)\n",
+ " Sparsity: 0.04%\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 862.93epoch/s, loss=6.91e+3, weights_mean=92.4, weights_std=9.19, weights_min=63.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 186.64epoch/s, loss=1.34, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 986.41epoch/s, loss=6.97e+3, weights_mean=92.9, weights_std=9.12, weights_min=63.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 162.70epoch/s, loss=1.33, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1075.04epoch/s, loss=7.05e+3, weights_mean=93.4, weights_std=9.14, weights_min=61.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 195.61epoch/s, loss=1.39, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1250.63epoch/s, loss=7.1e+3, weights_mean=93.7, weights_std=9.12, weights_min=62.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 162.61epoch/s, loss=1.02, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1480.33epoch/s, loss=7.16e+3, weights_mean=94.1, weights_std=9.16, weights_min=64.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 213.27epoch/s, loss=1.09, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1535.44epoch/s, loss=7.21e+3, weights_mean=94.4, weights_std=9.22, weights_min=63.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 226.69epoch/s, loss=0.932, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1697.40epoch/s, loss=7.27e+3, weights_mean=94.8, weights_std=9.26, weights_min=63.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 195.13epoch/s, loss=1.27, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1298.33epoch/s, loss=7.3e+3, weights_mean=95.1, weights_std=9.23, weights_min=64.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 218.78epoch/s, loss=1.2, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1068.55epoch/s, loss=7.36e+3, weights_mean=95.4, weights_std=9.22, weights_min=66.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 199.13epoch/s, loss=1.38, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1222.49epoch/s, loss=7.4e+3, weights_mean=95.6, weights_std=9.21, weights_min=66.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 173.33epoch/s, loss=0.891, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1347.59epoch/s, loss=7.41e+3, weights_mean=95.9, weights_std=9.22, weights_min=63.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 194.97epoch/s, loss=0.814, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 50/50 [00:00<00:00, 1135.51epoch/s, loss=7.48e+3, weights_mean=96.2, weights_std=9.25, weights_min=65.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 188.28epoch/s, loss=0.814, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:\n",
+ "Multi-holdout tuning completed!\n",
+ "Best parameters:\n",
+ " - l0_lambda: 7.43e-05\n",
+ " - init_mean: 0.8443\n",
+ " - temperature: 1.7891\n",
+ "Performance across 3 holdouts:\n",
+ " - Mean val loss: 0.002145 (±0.000429)\n",
+ " - Mean val accuracy: 100.00% (±0.00%)\n",
+ " - Individual objectives: ['9.9997', '10.0021', '9.9987']\n",
+ " - Sparsity: 0.04%\n",
+ "\n",
+ "Evaluation history saved with 60 records across 20 trials.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "==================================================\n",
+ "Tuning completed!\n",
+ "==================================================\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize calibration object\n",
+ "weights_init = np.ones(n_samples)\n",
+ "\n",
+ "cal = Calibration(\n",
+ " weights=weights_init,\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=100, # Will be overridden during tuning\n",
+ " learning_rate=1e-3,\n",
+ ")\n",
+ "\n",
+ "print(\"Starting hyperparameter tuning...\")\n",
+ "print(\"This will take a few minutes as it explores different parameter combinations.\\n\")\n",
+ "\n",
+ "# Run hyperparameter tuning\n",
+ "best_params = cal.tune_l0_hyperparameters(\n",
+ " n_trials=20, # Number of parameter combinations to try\n",
+ " objectives_balance={\n",
+ " 'loss': 1.0, # Weight for calibration loss\n",
+ " 'accuracy': 100.0, # Weight for accuracy (targets within 10%)\n",
+ " 'sparsity': 10.0, # Weight for sparsity\n",
+ " },\n",
+ " n_holdout_sets=3, # Number of cross-validation folds\n",
+ " holdout_fraction=0.2, # Fraction of targets to hold out\n",
+ " epochs_per_trial=50, # Epochs per trial (faster for tuning)\n",
+ ")\n",
+ "\n",
+ "print(\"\\n\" + \"=\"*50)\n",
+ "print(\"Tuning completed!\")\n",
+ "print(\"=\"*50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Analyzing tuning results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best parameters found:\n",
+ " l0_lambda: 7.43e-05\n",
+ " init_mean: 0.8443\n",
+ " temperature: 1.79\n",
+ "\n",
+ "Performance metrics:\n",
+ " Mean validation loss: 0.002145 (±0.000429)\n",
+ " Mean validation accuracy: 100.0% (±0.0%)\n",
+ " Sparsity achieved: 0.0%\n",
+ "\n",
+ "Cross-validation results:\n",
+ " Holdout objectives: [np.float64(9.999662299386227), np.float64(10.002067347057164), np.float64(9.998705346327275)]\n",
+ " Number of holdout sets: 3\n",
+ " Aggregation method: mean\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Display best parameters\n",
+ "print(\"Best parameters found:\")\n",
+ "print(f\" l0_lambda: {best_params['l0_lambda']:.2e}\")\n",
+ "print(f\" init_mean: {best_params['init_mean']:.4f}\")\n",
+ "print(f\" temperature: {best_params['temperature']:.2f}\")\n",
+ "print()\n",
+ "print(\"Performance metrics:\")\n",
+ "print(f\" Mean validation loss: {best_params['mean_val_loss']:.6f} (±{best_params['std_val_loss']:.6f})\")\n",
+ "print(f\" Mean validation accuracy: {best_params['mean_val_accuracy']:.1%} (±{best_params['std_val_accuracy']:.1%})\")\n",
+ "print(f\" Sparsity achieved: {best_params['sparsity']:.1%}\")\n",
+ "print()\n",
+ "print(\"Cross-validation results:\")\n",
+ "print(f\" Holdout objectives: {best_params['holdout_objectives']}\")\n",
+ "print(f\" Number of holdout sets: {best_params['n_holdout_sets']}\")\n",
+ "print(f\" Aggregation method: {best_params['aggregation']}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Applying the best parameters\n",
+ "\n",
+ "Now let's apply the best parameters found through tuning and compare with default parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Calibrating with tuned parameters...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2654.50epoch/s, loss=12.8, weights_mean=5.02, weights_std=2.41, weights_min=0.84]\n",
+ "Sparse reweighting progress: 100%|██████████| 400/400 [00:00<00:00, 678.76epoch/s, loss=0.105, loss_rel_change=-0.786]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Calibrating with default parameters...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2813.69epoch/s, loss=13.1, weights_mean=5.09, weights_std=2.43, weights_min=0.842]\n",
+ "Sparse reweighting progress: 100%|██████████| 400/400 [00:00<00:00, 593.57epoch/s, loss=0.0103, loss_rel_change=-0.691]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Comparison complete!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Calibration with tuned parameters\n",
+ "cal_tuned = Calibration(\n",
+ " weights=weights_init.copy(),\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=200,\n",
+ " learning_rate=1e-3,\n",
+ " regularize_with_l0=True,\n",
+ " l0_lambda=best_params['l0_lambda'],\n",
+ " init_mean=best_params['init_mean'],\n",
+ " temperature=best_params['temperature'],\n",
+ ")\n",
+ "\n",
+ "print(\"Calibrating with tuned parameters...\")\n",
+ "perf_tuned = cal_tuned.calibrate()\n",
+ "weights_tuned = cal_tuned.sparse_weights\n",
+ "\n",
+ "# Calibration with default parameters\n",
+ "cal_default = Calibration(\n",
+ " weights=weights_init.copy(),\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=200,\n",
+ " learning_rate=1e-3,\n",
+ " regularize_with_l0=True,\n",
+ " l0_lambda=5e-6, # Default\n",
+ " init_mean=0.999, # Default\n",
+ " temperature=0.5, # Default\n",
+ ")\n",
+ "\n",
+ "print(\"Calibrating with default parameters...\")\n",
+ "perf_default = cal_default.calibrate()\n",
+ "weights_default = cal_default.sparse_weights\n",
+ "\n",
+ "print(\"\\nComparison complete!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Parameter comparison:\n",
+ " Label Non-zero weights Sparsity Mean rel error Max rel error Within 1% Within 5% Within 10%\n",
+ "Default params 1998 60.0% 0.0250 0.0620 30.0% 90.0% 100.0%\n",
+ " Tuned params 1372 72.6% 0.1008 0.1155 0.0% 0.0% 40.0%\n",
+ "\n",
+ "==================================================\n",
+ "Improvement summary:\n",
+ "Sparsity improvement: 12.5%\n",
+ "Dataset reduction: 72.6% of records can be dropped\n",
+ "Remaining records: 1,372 out of 5,000\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Compare results\n",
+ "def evaluate_calibration(weights, estimate_matrix, targets, label):\n",
+ " estimates = (estimate_matrix.T * weights).sum(axis=1).values\n",
+ " rel_errors = np.abs((estimates - targets) / targets)\n",
+ " \n",
+ " return {\n",
+ " 'Label': label,\n",
+ " 'Non-zero weights': np.sum(weights != 0),\n",
+ " 'Sparsity': f\"{100 * np.mean(weights == 0):.1f}%\",\n",
+ " 'Mean rel error': f\"{np.mean(rel_errors):.4f}\",\n",
+ " 'Max rel error': f\"{np.max(rel_errors):.4f}\",\n",
+ " 'Within 1%': f\"{100 * np.mean(rel_errors < 0.01):.1f}%\",\n",
+ " 'Within 5%': f\"{100 * np.mean(rel_errors < 0.05):.1f}%\",\n",
+ " 'Within 10%': f\"{100 * np.mean(rel_errors < 0.10):.1f}%\",\n",
+ " }\n",
+ "\n",
+ "comparison = pd.DataFrame([\n",
+ " evaluate_calibration(weights_default, estimate_matrix, targets, 'Default params'),\n",
+ " evaluate_calibration(weights_tuned, estimate_matrix, targets, 'Tuned params'),\n",
+ "])\n",
+ "\n",
+ "print(\"\\nParameter comparison:\")\n",
+ "print(comparison.to_string(index=False))\n",
+ "\n",
+ "print(\"\\n\" + \"=\"*50)\n",
+ "print(\"Improvement summary:\")\n",
+ "sparsity_default = np.mean(weights_default == 0)\n",
+ "sparsity_tuned = np.mean(weights_tuned == 0)\n",
+ "print(f\"Sparsity improvement: {sparsity_tuned - sparsity_default:.1%}\")\n",
+ "print(f\"Dataset reduction: {100*sparsity_tuned:.1f}% of records can be dropped\")\n",
+ "print(f\"Remaining records: {np.sum(weights_tuned != 0):,} out of {len(weights_tuned):,}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Advanced tuning with custom objectives\n",
+ "\n",
+ "You can customize the tuning process by adjusting the objective balance. Here's how different balances affect the results:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:microcalibrate.hyperparameter_tuning:Multi-holdout hyperparameter tuning:\n",
+ " - 2 holdout sets\n",
+ " - 2 targets per holdout (20.0%)\n",
+ " - Aggregation: mean\n",
+ "\n",
+ "WARNING:microcalibrate.hyperparameter_tuning:Data leakage warning: Targets often share overlapping information (e.g., geographic breakdowns like 'snap in CA' and 'snap in US'). Holdout validation may not provide complete isolation between training and validation sets. The robustness metrics should be interpreted with this limitation in mind - they may overestimate the model's true generalization performance.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Tuning with Accuracy-focused objectives...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cd9f50e3d9f8423493290be31087c02d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1640.43epoch/s, loss=18.8, weights_mean=5.88, weights_std=2.83, weights_min=0.98]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 334.73epoch/s, loss=0.0282, loss_rel_change=-0.479]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1762.34epoch/s, loss=75, weights_mean=10.6, weights_std=3.91, weights_min=1.27]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 327.34epoch/s, loss=0.112, loss_rel_change=-0.993]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 0:\n",
+ " Objectives by holdout: ['201.0131', '201.0256']\n",
+ " Mean objective: 201.0193\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 0.00%\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1506.32epoch/s, loss=163, weights_mean=15.2, weights_std=4.78, weights_min=2]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 308.44epoch/s, loss=0.0668, loss_rel_change=-0.996]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1833.63epoch/s, loss=284, weights_mean=19.6, weights_std=5.45, weights_min=2.41]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 239.72epoch/s, loss=0.179, loss_rel_change=-0.996]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1667.23epoch/s, loss=432, weights_mean=24, weights_std=6, weights_min=2.72]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 298.97epoch/s, loss=0.0608, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1786.79epoch/s, loss=607, weights_mean=28.2, weights_std=6.53, weights_min=8.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 391.59epoch/s, loss=0.023, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1736.41epoch/s, loss=802, weights_mean=32.3, weights_std=6.95, weights_min=10.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 303.16epoch/s, loss=0.271, loss_rel_change=-0.998]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1771.79epoch/s, loss=1.02e+3, weights_mean=36.3, weights_std=7.32, weights_min=12.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 281.84epoch/s, loss=0.256, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1773.67epoch/s, loss=1.26e+3, weights_mean=40.1, weights_std=7.64, weights_min=13.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 286.69epoch/s, loss=0.234, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 516.45epoch/s, loss=1.51e+3, weights_mean=43.9, weights_std=7.94, weights_min=14.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 402.53epoch/s, loss=0.177, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 486.85epoch/s, loss=1.77e+3, weights_mean=47.5, weights_std=8.21, weights_min=19]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 287.30epoch/s, loss=0.0374, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1961.57epoch/s, loss=2.06e+3, weights_mean=51.1, weights_std=8.48, weights_min=20.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 332.78epoch/s, loss=0.0206, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 5:\n",
+ " Objectives by holdout: ['201.5720', '201.5335']\n",
+ " Mean objective: 201.5527\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 2.08%\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1934.37epoch/s, loss=2.35e+3, weights_mean=54.5, weights_std=8.7, weights_min=24.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 320.62epoch/s, loss=0.0319, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1842.98epoch/s, loss=2.65e+3, weights_mean=57.8, weights_std=8.89, weights_min=25.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 270.79epoch/s, loss=0.0323, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1863.89epoch/s, loss=2.96e+3, weights_mean=61.1, weights_std=9.06, weights_min=27.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 306.88epoch/s, loss=0.00722, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1825.70epoch/s, loss=3.29e+3, weights_mean=64.3, weights_std=9.17, weights_min=30.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 327.58epoch/s, loss=0.0185, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1968.91epoch/s, loss=3.61e+3, weights_mean=67.4, weights_std=9.41, weights_min=31.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 256.84epoch/s, loss=0.677, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1961.18epoch/s, loss=3.95e+3, weights_mean=70.3, weights_std=9.54, weights_min=31.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 337.01epoch/s, loss=0.796, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1860.91epoch/s, loss=4.29e+3, weights_mean=73.3, weights_std=9.68, weights_min=37.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 252.18epoch/s, loss=0.364, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1951.72epoch/s, loss=4.61e+3, weights_mean=76, weights_std=9.85, weights_min=40.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 340.44epoch/s, loss=0.539, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:\n",
+ "Multi-holdout tuning completed!\n",
+ "Best parameters:\n",
+ " - l0_lambda: 1.31e-06\n",
+ " - init_mean: 0.9322\n",
+ " - temperature: 1.4017\n",
+ "Performance across 2 holdouts:\n",
+ " - Mean val loss: 0.020892 (±0.014057)\n",
+ " - Mean val accuracy: 50.00% (±50.00%)\n",
+ " - Individual objectives: ['201.0349', '1.0068']\n",
+ " - Sparsity: 0.00%\n",
+ "\n",
+ "Evaluation history saved with 20 records across 10 trials.\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Multi-holdout hyperparameter tuning:\n",
+ " - 2 holdout sets\n",
+ " - 2 targets per holdout (20.0%)\n",
+ " - Aggregation: mean\n",
+ "\n",
+ "WARNING:microcalibrate.hyperparameter_tuning:Data leakage warning: Targets often share overlapping information (e.g., geographic breakdowns like 'snap in CA' and 'snap in US'). Holdout validation may not provide complete isolation between training and validation sets. The robustness metrics should be interpreted with this limitation in mind - they may overestimate the model's true generalization performance.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Tuning with Sparsity-focused objectives...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0818d35a3df04db0a94f17031644e061",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 567.32epoch/s, loss=18.4, weights_mean=5.85, weights_std=2.81, weights_min=0.981]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 257.40epoch/s, loss=0.0282, loss_rel_change=-0.479]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1167.63epoch/s, loss=74.1, weights_mean=10.6, weights_std=3.97, weights_min=1.08]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 229.01epoch/s, loss=0.106, loss_rel_change=-0.993]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 0:\n",
+ " Objectives by holdout: ['99.9935', '100.0218']\n",
+ " Mean objective: 100.0076\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 0.00%\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1799.85epoch/s, loss=164, weights_mean=15.2, weights_std=4.77, weights_min=2.37]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 279.03epoch/s, loss=0.0647, loss_rel_change=-0.997]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1666.83epoch/s, loss=285, weights_mean=19.7, weights_std=5.48, weights_min=3.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 292.92epoch/s, loss=0.182, loss_rel_change=-0.996]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1645.45epoch/s, loss=433, weights_mean=24.1, weights_std=6.03, weights_min=4.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 288.71epoch/s, loss=0.0601, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1507.39epoch/s, loss=608, weights_mean=28.3, weights_std=6.51, weights_min=5.81]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 274.26epoch/s, loss=0.0228, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1698.17epoch/s, loss=808, weights_mean=32.5, weights_std=6.96, weights_min=8.26]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 401.21epoch/s, loss=0.269, loss_rel_change=-0.998]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1829.26epoch/s, loss=1.03e+3, weights_mean=36.4, weights_std=7.29, weights_min=9.09]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 279.44epoch/s, loss=0.253, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1618.55epoch/s, loss=1.27e+3, weights_mean=40.4, weights_std=7.57, weights_min=13.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 270.84epoch/s, loss=0.235, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1770.25epoch/s, loss=1.53e+3, weights_mean=44.1, weights_std=7.94, weights_min=15.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 441.84epoch/s, loss=0.174, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1870.29epoch/s, loss=1.79e+3, weights_mean=47.7, weights_std=8.29, weights_min=18.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 327.45epoch/s, loss=0.0345, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1893.19epoch/s, loss=2.08e+3, weights_mean=51.3, weights_std=8.57, weights_min=19.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 437.13epoch/s, loss=0.0204, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 5:\n",
+ " Objectives by holdout: ['99.9053', '99.5356']\n",
+ " Mean objective: 99.7205\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 2.04%\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1919.30epoch/s, loss=2.37e+3, weights_mean=54.8, weights_std=8.76, weights_min=26.3]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 328.66epoch/s, loss=0.0325, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1895.25epoch/s, loss=2.68e+3, weights_mean=58.1, weights_std=8.96, weights_min=29.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 275.38epoch/s, loss=0.0321, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1853.24epoch/s, loss=2.99e+3, weights_mean=61.3, weights_std=9.17, weights_min=28.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 330.65epoch/s, loss=0.00741, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1940.52epoch/s, loss=3.31e+3, weights_mean=64.4, weights_std=9.37, weights_min=30.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 343.55epoch/s, loss=0.0188, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 2020.47epoch/s, loss=3.64e+3, weights_mean=67.5, weights_std=9.52, weights_min=35.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 330.64epoch/s, loss=0.686, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 506.36epoch/s, loss=3.97e+3, weights_mean=70.5, weights_std=9.68, weights_min=35.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 334.39epoch/s, loss=0.806, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1901.29epoch/s, loss=4.31e+3, weights_mean=73.4, weights_std=9.78, weights_min=39.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 261.74epoch/s, loss=0.37, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1848.88epoch/s, loss=4.66e+3, weights_mean=76.3, weights_std=9.85, weights_min=41.4]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 335.97epoch/s, loss=0.565, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:\n",
+ "Multi-holdout tuning completed!\n",
+ "Best parameters:\n",
+ " - l0_lambda: 1.58e-05\n",
+ " - init_mean: 0.5779\n",
+ " - temperature: 0.7340\n",
+ "Performance across 2 holdouts:\n",
+ " - Mean val loss: 0.053487 (±0.052596)\n",
+ " - Mean val accuracy: 50.00% (±50.00%)\n",
+ " - Individual objectives: ['98.6861', '49.1909']\n",
+ " - Sparsity: 1.62%\n",
+ "\n",
+ "Evaluation history saved with 20 records across 10 trials.\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Multi-holdout hyperparameter tuning:\n",
+ " - 2 holdout sets\n",
+ " - 2 targets per holdout (20.0%)\n",
+ " - Aggregation: mean\n",
+ "\n",
+ "WARNING:microcalibrate.hyperparameter_tuning:Data leakage warning: Targets often share overlapping information (e.g., geographic breakdowns like 'snap in CA' and 'snap in US'). Holdout validation may not provide complete isolation between training and validation sets. The robustness metrics should be interpreted with this limitation in mind - they may overestimate the model's true generalization performance.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Tuning with Balanced objectives...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f63b5da1620e4b058f380ebb243d0eab",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1818.97epoch/s, loss=19.6, weights_mean=5.94, weights_std=2.84, weights_min=0.982]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 325.64epoch/s, loss=0.0282, loss_rel_change=-0.479]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1941.75epoch/s, loss=76.5, weights_mean=10.7, weights_std=3.96, weights_min=1.13]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 336.02epoch/s, loss=0.114, loss_rel_change=-0.993]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 0:\n",
+ " Objectives by holdout: ['110.0095', '110.0267']\n",
+ " Mean objective: 110.0181\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 0.00%\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1966.42epoch/s, loss=165, weights_mean=15.3, weights_std=4.77, weights_min=1.73]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 266.74epoch/s, loss=0.0703, loss_rel_change=-0.996]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1976.05epoch/s, loss=287, weights_mean=19.8, weights_std=5.37, weights_min=3.56]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 341.35epoch/s, loss=0.182, loss_rel_change=-0.996]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1937.71epoch/s, loss=433, weights_mean=24.1, weights_std=5.95, weights_min=5.88]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 326.82epoch/s, loss=0.0603, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 833.21epoch/s, loss=610, weights_mean=28.3, weights_std=6.45, weights_min=8.51]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 326.70epoch/s, loss=0.0218, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1949.33epoch/s, loss=804, weights_mean=32.4, weights_std=6.78, weights_min=10.8]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 338.99epoch/s, loss=0.271, loss_rel_change=-0.998]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1983.25epoch/s, loss=1.02e+3, weights_mean=36.3, weights_std=7.12, weights_min=11.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 272.58epoch/s, loss=0.254, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1898.59epoch/s, loss=1.26e+3, weights_mean=40.2, weights_std=7.54, weights_min=12.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 335.71epoch/s, loss=0.232, loss_rel_change=-0.999]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1959.13epoch/s, loss=1.52e+3, weights_mean=43.9, weights_std=7.9, weights_min=17.1]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 213.15epoch/s, loss=0.177, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1086.84epoch/s, loss=1.78e+3, weights_mean=47.5, weights_std=8.24, weights_min=17.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 308.13epoch/s, loss=0.0344, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1482.00epoch/s, loss=2.06e+3, weights_mean=51, weights_std=8.5, weights_min=17.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 326.75epoch/s, loss=0.0214, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:Trial 5:\n",
+ " Objectives by holdout: ['110.4485', '110.3485']\n",
+ " Mean objective: 110.3985\n",
+ " Mean val accuracy: 0.00% (±0.00%)\n",
+ " Sparsity: 2.06%\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1841.41epoch/s, loss=2.35e+3, weights_mean=54.5, weights_std=8.79, weights_min=22.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 450.16epoch/s, loss=0.0335, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1908.90epoch/s, loss=2.66e+3, weights_mean=57.8, weights_std=9.01, weights_min=22.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 335.56epoch/s, loss=0.0317, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1957.21epoch/s, loss=2.96e+3, weights_mean=61.1, weights_std=9.19, weights_min=26.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 447.14epoch/s, loss=0.00758, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1873.35epoch/s, loss=3.28e+3, weights_mean=64.2, weights_std=9.33, weights_min=33.2]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 318.26epoch/s, loss=0.0161, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1921.99epoch/s, loss=3.6e+3, weights_mean=67.2, weights_std=9.47, weights_min=35.5]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 343.78epoch/s, loss=0.668, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1958.86epoch/s, loss=3.94e+3, weights_mean=70.2, weights_std=9.63, weights_min=34.6]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 260.00epoch/s, loss=0.782, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1924.67epoch/s, loss=4.27e+3, weights_mean=73.1, weights_std=9.79, weights_min=36.7]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 332.87epoch/s, loss=0.361, loss_rel_change=-1]\n",
+ "Reweighting progress: 100%|██████████| 30/30 [00:00<00:00, 1966.30epoch/s, loss=4.62e+3, weights_mean=76, weights_std=9.87, weights_min=44.9]\n",
+ "Sparse reweighting progress: 100%|██████████| 60/60 [00:00<00:00, 277.73epoch/s, loss=0.554, loss_rel_change=-1]\n",
+ "INFO:microcalibrate.hyperparameter_tuning:\n",
+ "Multi-holdout tuning completed!\n",
+ "Best parameters:\n",
+ " - l0_lambda: 1.58e-05\n",
+ " - init_mean: 0.5779\n",
+ " - temperature: 0.7340\n",
+ "Performance across 2 holdouts:\n",
+ " - Mean val loss: 0.050039 (±0.049977)\n",
+ " - Mean val accuracy: 50.00% (±50.00%)\n",
+ " - Individual objectives: ['109.8460', '9.8461']\n",
+ " - Sparsity: 1.54%\n",
+ "\n",
+ "Evaluation history saved with 20 records across 10 trials.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "==================================================\n",
+ "Objective balance comparison:\n",
+ " Config l0_lambda Accuracy Sparsity\n",
+ "Accuracy-focused 1.31e-06 50.0% 0.0%\n",
+ "Sparsity-focused 1.58e-05 50.0% 1.6%\n",
+ " Balanced 1.58e-05 50.0% 1.5%\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Different objective balances for different use cases\n",
+ "objective_configs = {\n",
+ " 'Accuracy-focused': {'loss': 1.0, 'accuracy': 200.0, 'sparsity': 1.0},\n",
+ " 'Sparsity-focused': {'loss': 1.0, 'accuracy': 50.0, 'sparsity': 50.0},\n",
+ " 'Balanced': {'loss': 1.0, 'accuracy': 100.0, 'sparsity': 10.0},\n",
+ "}\n",
+ "\n",
+ "results = []\n",
+ "\n",
+ "for name, objectives in objective_configs.items():\n",
+ " print(f\"\\nTuning with {name} objectives...\")\n",
+ " \n",
+ " cal_temp = Calibration(\n",
+ " weights=weights_init.copy(),\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=100,\n",
+ " learning_rate=1e-3,\n",
+ " )\n",
+ " \n",
+ " params = cal_temp.tune_l0_hyperparameters(\n",
+ " n_trials=10, # Fewer trials for demonstration\n",
+ " objectives_balance=objectives,\n",
+ " n_holdout_sets=2,\n",
+ " holdout_fraction=0.2,\n",
+ " epochs_per_trial=30,\n",
+ " )\n",
+ " \n",
+ " results.append({\n",
+ " 'Config': name,\n",
+ " 'l0_lambda': f\"{params['l0_lambda']:.2e}\",\n",
+ " 'Accuracy': f\"{params['mean_val_accuracy']:.1%}\",\n",
+ " 'Sparsity': f\"{params['sparsity']:.1%}\",\n",
+ " })\n",
+ "\n",
+ "results_df = pd.DataFrame(results)\n",
+ "print(\"\\n\" + \"=\"*50)\n",
+ "print(\"Objective balance comparison:\")\n",
+ "print(results_df.to_string(index=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Best practices for hyperparameter tuning\n",
+ "\n",
+ "### 1. Start with fewer trials\n",
+ "Begin with 10-20 trials to get a sense of the parameter space, then increase if needed.\n",
+ "\n",
+ "### 2. Adjust objective balance based on your needs\n",
+ "- **High accuracy weight**: When precision is critical\n",
+ "- **High sparsity weight**: When dataset reduction is the priority\n",
+ "- **Balanced**: Good starting point for most use cases\n",
+ "\n",
+ "### 3. Use appropriate cross-validation\n",
+ "- **More holdout sets**: Better generalization estimates but slower\n",
+ "- **Larger holdout fraction**: More robust validation but less training data\n",
+ "\n",
+ "### 4. Consider computational resources\n",
+ "- Reduce `epochs_per_trial` for faster exploration\n",
+ "- Use `n_jobs=-1` for parallel trials if you have multiple cores\n",
+ "\n",
+ "### 5. Monitor for overfitting\n",
+ "Watch for large gaps between training and validation performance.\n",
+ "\n",
+ "### 6. Data leakage awareness\n",
+ "Remember that targets often share information (e.g., 'income_north' and 'total_income'), so validation metrics may be optimistic."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Next steps\n",
+ "\n",
+ "After finding optimal hyperparameters:\n",
+ "1. Apply them to your full calibration with more epochs\n",
+ "2. Evaluate robustness using the [Robustness evaluation](robustness_evaluation.ipynb) notebook\n",
+ "3. Save the parameters for future use\n",
+ "4. Consider fine-tuning if results aren't satisfactory\n",
+ "\n",
+ "The tuned parameters are specific to your dataset and target configuration, so re-tune if these change significantly."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pe3.13",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/robustness_evaluation.ipynb b/docs/robustness_evaluation.ipynb
new file mode 100644
index 0000000..d550254
--- /dev/null
+++ b/docs/robustness_evaluation.ipynb
@@ -0,0 +1,584 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Robustness evaluation for calibration\n",
+ "\n",
+ "Evaluating the robustness of your calibration ensures that your weights will generalize well to related targets not explicitly used during calibration. This notebook demonstrates how to use the robustness evaluation feature to assess and improve calibration stability.\n",
+ "\n",
+ "## What is robustness evaluation\n",
+ "\n",
+ "Robustness evaluation uses holdout validation to test how well the calibration performs on unseen targets. The process:\n",
+ "1. Randomly holds out a subset of targets\n",
+ "2. Calibrates on the remaining targets\n",
+ "3. Evaluates performance on the held-out targets\n",
+ "4. Repeats multiple times to assess consistency\n",
+ "\n",
+ "This helps identify whether your calibration is overfitting to specific targets or if it will generalize well."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from microcalibrate import Calibration\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import logging\n",
+ "\n",
+ "calibration_logger = logging.getLogger(\"microcalibrate.calibration\")\n",
+ "calibration_logger.setLevel(logging.WARNING)\n",
+ "\n",
+ "np.random.seed(42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating a test dataset\n",
+ "\n",
+ "We'll create a dataset with correlated targets to demonstrate the robustness evaluation. Some targets will be combinations of others, making them easier to predict from partial information."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset: 5000 records\n",
+ "Number of targets: 23\n",
+ "Target categories:\n",
+ " - State-level: 12 targets (4 states × 3 metrics)\n",
+ " - Gender: 4 targets (2 genders × 2 metrics)\n",
+ " - Age groups: 4 targets\n",
+ " - Overall: 3 targets\n",
+ "\n",
+ "Note: Many targets overlap (e.g., total_income = sum of state incomes)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create synthetic data with structure\n",
+ "n_samples = 5000\n",
+ "\n",
+ "# Demographics\n",
+ "age = np.random.randint(18, 80, n_samples)\n",
+ "gender = np.random.choice(['M', 'F'], n_samples)\n",
+ "state = np.random.choice(['CA', 'NY', 'TX', 'FL'], n_samples, p=[0.35, 0.25, 0.25, 0.15])\n",
+ "\n",
+ "# Income (correlated with age and state)\n",
+ "base_income = 30000\n",
+ "state_multiplier = {'CA': 1.3, 'NY': 1.2, 'TX': 1.0, 'FL': 0.95}\n",
+ "income = (base_income + (age - 18) * 800).astype(float) # Ensure float type\n",
+ "for s in ['CA', 'NY', 'TX', 'FL']:\n",
+ " mask = state == s\n",
+ " income[mask] *= state_multiplier[s]\n",
+ "income += np.random.normal(0, 10000, n_samples)\n",
+ "income = np.maximum(income, 15000)\n",
+ "\n",
+ "# Employment (correlated with age)\n",
+ "emp_prob = 0.8 - np.maximum(0, (age - 60) / 20)\n",
+ "emp_prob = np.clip(emp_prob, 0, 1) # Ensure valid probability range\n",
+ "employed = np.random.binomial(1, emp_prob)\n",
+ "\n",
+ "# Create estimate matrix with various overlapping targets\n",
+ "estimate_matrix = pd.DataFrame()\n",
+ "\n",
+ "# State-level targets\n",
+ "for s in ['CA', 'NY', 'TX', 'FL']:\n",
+ " mask = state == s\n",
+ " estimate_matrix[f'pop_{s}'] = mask.astype(float)\n",
+ " estimate_matrix[f'income_{s}'] = mask * income\n",
+ " estimate_matrix[f'employed_{s}'] = mask * employed\n",
+ "\n",
+ "# Gender targets\n",
+ "for g in ['M', 'F']:\n",
+ " mask = gender == g\n",
+ " estimate_matrix[f'pop_{g}'] = mask.astype(float)\n",
+ " estimate_matrix[f'income_{g}'] = mask * income\n",
+ "\n",
+ "# Age group targets\n",
+ "age_groups = pd.cut(age, bins=[0, 35, 50, 65, 100], labels=['18-35', '36-50', '51-65', '65+'])\n",
+ "for ag in age_groups.unique():\n",
+ " mask = age_groups == ag\n",
+ " estimate_matrix[f'pop_age_{ag}'] = mask.astype(float)\n",
+ "\n",
+ "# Overall targets\n",
+ "estimate_matrix['total_population'] = 1.0\n",
+ "estimate_matrix['total_income'] = income\n",
+ "estimate_matrix['total_employed'] = employed\n",
+ "\n",
+ "# Create realistic targets\n",
+ "true_totals = estimate_matrix.sum().values\n",
+ "# Add some noise to make calibration non-trivial\n",
+ "noise = np.random.normal(1.0, 0.03, len(true_totals))\n",
+ "targets = true_totals * noise\n",
+ "\n",
+ "print(f\"Dataset: {n_samples} records\")\n",
+ "print(f\"Number of targets: {len(targets)}\")\n",
+ "print(f\"Target categories:\")\n",
+ "print(f\" - State-level: 12 targets (4 states × 3 metrics)\")\n",
+ "print(f\" - Gender: 4 targets (2 genders × 2 metrics)\")\n",
+ "print(f\" - Age groups: 4 targets\")\n",
+ "print(f\" - Overall: 3 targets\")\n",
+ "print(f\"\\nNote: Many targets overlap (e.g., total_income = sum of state incomes)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Basic robustness evaluation\n",
+ "\n",
+ "Let's evaluate the robustness of a standard calibration without regularization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Data leakage warning: Targets often share overlapping information (e.g., geographic breakdowns like 'snap in CA' and 'snap in US'). Holdout validation may not provide complete isolation between training and validation sets. The robustness metrics should be interpreted with this limitation in mind - they may overestimate the model's true generalization performance.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating calibration robustness...\n",
+ "This will perform multiple rounds of holdout validation.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1703.46epoch/s, loss=16.4, weights_mean=5.08, weights_std=2.41, weights_min=0.84] \n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1172.65epoch/s, loss=15.7, weights_mean=4.98, weights_std=2.41, weights_min=0.84] \n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1848.77epoch/s, loss=16.5, weights_mean=5.02, weights_std=2.41, weights_min=0.839]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1536.10epoch/s, loss=16.3, weights_mean=5.03, weights_std=2.44, weights_min=0.844]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2512.73epoch/s, loss=16.4, weights_mean=5.01, weights_std=2.42, weights_min=0.839]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2495.60epoch/s, loss=16.1, weights_mean=5.03, weights_std=2.42, weights_min=0.84]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2466.74epoch/s, loss=16.1, weights_mean=5.06, weights_std=2.42, weights_min=0.839]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2451.46epoch/s, loss=16.2, weights_mean=5.02, weights_std=2.4, weights_min=0.84]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1529.99epoch/s, loss=16.6, weights_mean=5.05, weights_std=2.43, weights_min=0.842]\n",
+ "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2534.02epoch/s, loss=16.7, weights_mean=5.05, weights_std=2.4, weights_min=0.846]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "============================================================\n",
+ "Robustness evaluation complete!\n",
+ "============================================================\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize calibration\n",
+ "weights_init = np.ones(n_samples)\n",
+ "\n",
+ "cal = Calibration(\n",
+ " weights=weights_init,\n",
+ " targets=targets,\n",
+ " estimate_matrix=estimate_matrix,\n",
+ " epochs=200,\n",
+ " learning_rate=1e-3,\n",
+ ")\n",
+ "\n",
+ "print(\"Evaluating calibration robustness...\")\n",
+ "print(\"This will perform multiple rounds of holdout validation.\\n\")\n",
+ "\n",
+ "# Evaluate robustness\n",
+ "robustness_results = cal.evaluate_holdout_robustness(\n",
+ " n_holdout_sets=10, # Number of random holdout sets to test\n",
+ " holdout_fraction=0.3, # Hold out 30% of targets each round\n",
+ ")\n",
+ "\n",
+ "print(\"\\n\" + \"=\"*60)\n",
+ "print(\"Robustness evaluation complete!\")\n",
+ "print(\"=\"*60)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Analyzing robustness results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overall robustness metrics:\n",
+ " Average holdout accuracy: 0.0%\n",
+ " Std dev of accuracies: 0.0%\n",
+ " Worst holdout accuracy: 0.0%\n",
+ " Best holdout accuracy: 0.0%\n",
+ "\n",
+ "Generalization gap:\n",
+ " Average training accuracy: 0.0%\n",
+ " Average holdout accuracy: 0.0%\n",
+ " Gap: 0.0%\n",
+ "\n",
+ "Consistency score: 1.00/1.00\n",
+ " (Higher is better - measures stability across rounds)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Display overall metrics\n",
+ "metrics = robustness_results['overall_metrics']\n",
+ "print(\"Overall robustness metrics:\")\n",
+ "print(f\" Average holdout accuracy: {metrics['mean_holdout_accuracy']:.1%}\")\n",
+ "print(f\" Std dev of accuracies: {metrics['std_holdout_accuracy']:.1%}\")\n",
+ "print(f\" Worst holdout accuracy: {metrics['worst_holdout_accuracy']:.1%}\")\n",
+ "print(f\" Best holdout accuracy: {metrics['best_holdout_accuracy']:.1%}\")\n",
+ "print()\n",
+ "print(f\"Generalization gap:\")\n",
+ "print(f\" Average training accuracy: {metrics['mean_train_accuracy']:.1%}\")\n",
+ "print(f\" Average holdout accuracy: {metrics['mean_holdout_accuracy']:.1%}\")\n",
+ "print(f\" Gap: {metrics['mean_train_accuracy'] - metrics['mean_holdout_accuracy']:.1%}\")\n",
+ "print()\n",
+ "# Calculate consistency score (1 - coefficient of variation)\n",
+ "consistency_score = 1 - (metrics['std_holdout_accuracy'] / max(metrics['mean_holdout_accuracy'], 0.01))\n",
+ "print(f\"Consistency score: {consistency_score:.2f}/1.00\")\n",
+ "print(f\" (Higher is better - measures stability across rounds)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Most difficult targets to predict (when held out):\n",
+ "target_name holdout_accuracy_rate times_held_out\n",
+ " pop_CA 0.0 2\n",
+ " income_CA 0.0 3\n",
+ "employed_CA 0.0 2\n",
+ " income_NY 0.0 1\n",
+ "employed_NY 0.0 3\n",
+ " pop_TX 0.0 3\n",
+ " income_TX 0.0 2\n",
+ "employed_TX 0.0 1\n",
+ " pop_FL 0.0 6\n",
+ " income_FL 0.0 3\n",
+ "\n",
+ "Easiest targets to predict (when held out):\n",
+ " target_name holdout_accuracy_rate times_held_out\n",
+ " pop_age_36-50 0.0 2\n",
+ " pop_age_18-35 0.0 1\n",
+ "total_population 0.0 1\n",
+ " total_income 0.0 3\n",
+ " total_employed 0.0 3\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Show target-level difficulty\n",
+ "target_robustness = robustness_results['target_robustness']\n",
+ "\n",
+ "# Sort by accuracy (lower accuracy = higher difficulty)\n",
+ "target_robustness = target_robustness.sort_values('holdout_accuracy_rate', ascending=True)\n",
+ "\n",
+ "print(\"\\nMost difficult targets to predict (when held out):\")\n",
+ "print(target_robustness.head(10)[['target_name', 'holdout_accuracy_rate', 'times_held_out']].to_string(index=False))\n",
+ "\n",
+ "print(\"\\nEasiest targets to predict (when held out):\")\n",
+ "print(target_robustness.tail(5)[['target_name', 'holdout_accuracy_rate', 'times_held_out']].to_string(index=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Understanding the recommendation\n",
+ "\n",
+ "The robustness evaluation provides actionable recommendations based on the results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ROBUSTNESS EVALUATION RECOMMENDATION\n",
+ "============================================================\n",
+ "❌ POOR ROBUSTNESS: The calibration shows weak generalization.\n",
+ "On average, 0.0% of held-out targets are within 10% of their true values.\n",
+ " ⚠️ Worst-case scenario: Only 0.0% accuracy in some holdout sets.\n",
+ "\n",
+ "📊 Targets with poor holdout performance (<50% accuracy):\n",
+ " - pop_CA: 0.0% accuracy\n",
+ " - total_population: 0.0% accuracy\n",
+ " - pop_age_18-35: 0.0% accuracy\n",
+ " - pop_age_36-50: 0.0% accuracy\n",
+ " - pop_age_65+: 0.0% accuracy\n",
+ "\n",
+ "💡 RECOMMENDATIONS:\n",
+ " 1. Consider enabling L0 regularization for better generalization\n",
+ " 2. Increase the noise_level parameter to improve robustness\n",
+ " 3. Try increasing dropout_rate to reduce overfitting\n",
+ " 4. Investigate why these targets are hard to predict: pop_CA, total_population, pop_age_18-35\n",
+ " 5. Consider if these targets have sufficient support in the microdata\n",
+ " 6. Generalization gap of 0.1811 suggests some overfitting - consider regularization\n",
+ "============================================================\n",
+ "\n",
+ "Additional suggestions for poor robustness:\n",
+ "1. Consider using L0 regularization to reduce overfitting\n",
+ "2. Review targets with highest difficulty scores\n",
+ "3. Check for data quality issues in difficult targets\n",
+ "4. Consider removing highly correlated redundant targets\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"ROBUSTNESS EVALUATION RECOMMENDATION\")\n",
+ "print(\"=\"*60)\n",
+ "print(robustness_results['recommendation'])\n",
+ "print(\"=\"*60)\n",
+ "\n",
+ "# Additional analysis based on results\n",
+ "metrics = robustness_results['overall_metrics']\n",
+ "if metrics['mean_holdout_accuracy'] < 0.8:\n",
+ " print(\"\\nAdditional suggestions for poor robustness:\")\n",
+ " print(\"1. Consider using L0 regularization to reduce overfitting\")\n",
+ " print(\"2. Review targets with highest difficulty scores\")\n",
+ " print(\"3. Check for data quality issues in difficult targets\")\n",
+ " print(\"4. Consider removing highly correlated redundant targets\")\n",
+ "elif metrics['std_holdout_accuracy'] > 0.1:\n",
+ " print(\"\\nAdditional suggestions for high variability:\")\n",
+ " print(\"1. Some target combinations may be inherently difficult\")\n",
+ " print(\"2. Consider grouping related targets\")\n",
+ " print(\"3. Increase epochs to ensure convergence\")\n",
+ "else:\n",
+ " print(\"\\nYour calibration shows good robustness!\")\n",
+ " print(\"Consider saving these settings for production use.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Advanced: Custom holdout strategies\n",
+ "\n",
+ "You can implement custom holdout strategies for specific evaluation needs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2407.43epoch/s, loss=20.1, weights_mean=5.5, weights_std=2.64, weights_min=0.918]\n",
+ "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2500.32epoch/s, loss=20.4, weights_mean=5.5, weights_std=2.62, weights_min=0.919]\n",
+ "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2694.62epoch/s, loss=21, weights_mean=5.56, weights_std=2.67, weights_min=0.918]\n",
+ "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2461.17epoch/s, loss=20, weights_mean=5.43, weights_std=2.68, weights_min=0.917]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Robustness by target category:\n",
+ " Category N targets Mean error Max error Within 10%\n",
+ " State targets 12 449.3% 494.2% 0%\n",
+ "Gender targets 7 445.0% 477.2% 0%\n",
+ " Age targets 4 450.4% 470.4% 0%\n",
+ " Total targets 3 422.9% 425.1% 0%\n",
+ "\n",
+ "Interpretation:\n",
+ "- Lower errors indicate targets that can be predicted from others\n",
+ "- High errors suggest independent information in those targets\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Example: Evaluate robustness by holding out entire target categories\n",
+ "def evaluate_by_category():\n",
+ " categories = {\n",
+ " 'State targets': [i for i, name in enumerate(estimate_matrix.columns) \n",
+ " if any(s in name for s in ['CA', 'NY', 'TX', 'FL'])],\n",
+ " 'Gender targets': [i for i, name in enumerate(estimate_matrix.columns) \n",
+ " if any(g in name for g in ['_M', '_F'])],\n",
+ " 'Age targets': [i for i, name in enumerate(estimate_matrix.columns) \n",
+ " if 'age' in name],\n",
+ " 'Total targets': [i for i, name in enumerate(estimate_matrix.columns) \n",
+ " if 'total' in name],\n",
+ " }\n",
+ " \n",
+ " results = []\n",
+ " \n",
+ " for category, indices in categories.items():\n",
+ " if len(indices) == 0:\n",
+ " continue\n",
+ " \n",
+ " # Create masks for train and holdout\n",
+ " train_mask = np.ones(len(targets), dtype=bool)\n",
+ " train_mask[indices] = False\n",
+ " \n",
+ " # Skip if too few training targets remain\n",
+ " if train_mask.sum() < 3:\n",
+ " continue\n",
+ " \n",
+ " # Calibrate on subset\n",
+ " cal_temp = Calibration(\n",
+ " weights=weights_init.copy(),\n",
+ " targets=targets[train_mask],\n",
+ " estimate_matrix=estimate_matrix.iloc[:, train_mask],\n",
+ " epochs=100,\n",
+ " learning_rate=1e-3,\n",
+ " )\n",
+ " \n",
+ " # Suppress logging for cleaner output\n",
+ " import logging\n",
+ " original_level = logging.getLogger().level\n",
+ " logging.getLogger().setLevel(logging.WARNING)\n",
+ " \n",
+ " try:\n",
+ " cal_temp.calibrate()\n",
+ " \n",
+ " # Evaluate on holdout\n",
+ " holdout_estimates = (estimate_matrix.iloc[:, indices].T * cal_temp.weights).sum(axis=1).values\n",
+ " holdout_targets = targets[indices]\n",
+ " holdout_errors = np.abs((holdout_estimates - holdout_targets) / holdout_targets)\n",
+ " \n",
+ " results.append({\n",
+ " 'Category': category,\n",
+ " 'N targets': len(indices),\n",
+ " 'Mean error': f\"{np.mean(holdout_errors):.1%}\",\n",
+ " 'Max error': f\"{np.max(holdout_errors):.1%}\",\n",
+ " 'Within 10%': f\"{100*np.mean(holdout_errors < 0.1):.0f}%\"\n",
+ " })\n",
+ " finally:\n",
+ " # Restore logging level\n",
+ " logging.getLogger().setLevel(original_level)\n",
+ " \n",
+ " return pd.DataFrame(results)\n",
+ "\n",
+ "category_results = evaluate_by_category()\n",
+ "print(\"Robustness by target category:\")\n",
+ "print(category_results.to_string(index=False))\n",
+ "print(\"\\nInterpretation:\")\n",
+ "print(\"- Lower errors indicate targets that can be predicted from others\")\n",
+ "print(\"- High errors suggest independent information in those targets\")"
+ ]
+ },
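+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The interpretation above hinges on how much information the targets share. As a rough redundancy check, you can look at pairwise correlations between target columns of the estimate matrix. The sketch below uses a small synthetic frame (the column names are illustrative stand-ins, not the notebook's actual `estimate_matrix`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: flag highly correlated (redundant) target columns on synthetic data\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "rng = np.random.default_rng(0)\n",
+ "demo = pd.DataFrame({\"total_population\": rng.uniform(0, 1, 200)})\n",
+ "# pop_CA is almost a fixed share of the total, so it adds little new information\n",
+ "demo[\"pop_CA\"] = 0.4 * demo[\"total_population\"] + rng.normal(0, 0.01, 200)\n",
+ "demo[\"pop_age_18-35\"] = rng.uniform(0, 1, 200)\n",
+ "\n",
+ "corr = demo.corr().abs()\n",
+ "# Keep only the upper triangle so each pair is listed once\n",
+ "upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))\n",
+ "pairs = upper.stack()\n",
+ "redundant = pairs[pairs > 0.95]\n",
+ "print(redundant)"
+ ]
+ },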
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Best practices for robustness evaluation\n",
+ "\n",
+ "### 1. Choose appropriate holdout parameters\n",
+ "- **Holdout fraction**: 20-30% is typically good\n",
+ "- **Number of rounds**: At least 5-10 for reliable estimates\n",
+ "- **Epochs per round**: Enough to converge (check loss curves)\n",
+ "\n",
+ "### 2. Interpret results carefully\n",
+ "- **High variability**: Indicates unstable calibration\n",
+ "- **Large generalization gap**: Suggests overfitting\n",
+ "- **Low consistency**: Some target combinations are problematic\n",
+ "\n",
+ "### 3. Be aware of data leakage\n",
+ "Since many calibration targets share information (e.g., 'total_income' includes all state incomes), holdout validation may give optimistic results. The evaluation includes a warning about this.\n",
+ "\n",
+ "### 4. Use results to improve calibration\n",
+ "- Add regularization if overfitting is detected\n",
+ "- Remove or combine highly correlated targets\n",
+ "- Investigate targets with high difficulty scores\n",
+ "- Consider different optimization parameters\n",
+ "\n",
+ "### 5. Document your evaluation\n",
+ "Save robustness results along with your calibration parameters for reproducibility and comparison."
+ ]
+ },
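+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Following point 5 above, one lightweight way to document an evaluation is to write the headline metrics and the settings that produced them to a single JSON file. The keys below are hypothetical stand-ins for your actual configuration (the metric values echo the run shown earlier in this notebook):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: persist robustness metrics next to the calibration settings (hypothetical keys)\n",
+ "import json\n",
+ "\n",
+ "run_record = {\n",
+ "    \"calibration\": {\"epochs\": 100, \"learning_rate\": 1e-3},\n",
+ "    \"robustness\": {\"mean_holdout_accuracy\": 0.0, \"generalization_gap\": 0.1811},\n",
+ "}\n",
+ "with open(\"robustness_run.json\", \"w\") as f:\n",
+ "    json.dump(run_record, f, indent=2)\n",
+ "\n",
+ "# Reload to confirm the record round-trips\n",
+ "with open(\"robustness_run.json\") as f:\n",
+ "    print(json.load(f)[\"robustness\"])"
+ ]
+ },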
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Next steps\n",
+ "\n",
+ "After evaluating robustness:\n",
+ "\n",
+ "1. If robustness is poor, try:\n",
+ " - Hyperparameter tuning to find better L0 parameters\n",
+ " - Reviewing and cleaning your targets\n",
+ " - Increasing the dataset size\n",
+ "\n",
+ "2. If robustness is good:\n",
+ " - Save your calibration configuration\n",
+ " - Apply to production data\n",
+ " - Monitor performance over time\n",
+ "\n",
+ "3. For specific issues:\n",
+ " - High difficulty targets → Check data quality\n",
+ " - Large generalization gap → Add regularization\n",
+ " - High variability → Increase epochs or adjust learning rate"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pe3.13",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}