diff --git a/README.md b/README.md
index 134f35c..86cdedc 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,105 @@
-# MicroCalibrate
+# MicroCalibrate
+
+[![CI](https://github.com/PolicyEngine/microcalibrate/actions/workflows/main.yml/badge.svg)](https://github.com/PolicyEngine/microcalibrate/actions/workflows/main.yml)
+[![codecov](https://codecov.io/gh/PolicyEngine/microcalibrate/branch/main/graph/badge.svg)](https://codecov.io/gh/PolicyEngine/microcalibrate)
+[![PyPI version](https://badge.fury.io/py/microcalibrate.svg)](https://badge.fury.io/py/microcalibrate)
+[![Python Version](https://img.shields.io/pypi/pyversions/microcalibrate)](https://pypi.org/project/microcalibrate/)
+
+MicroCalibrate is a Python package for calibrating survey weights to match population targets, with advanced features including L0 regularization for sparsity, automatic hyperparameter tuning, and robustness evaluation.
+
+## Features
+
+- **Survey weight calibration**: Adjust sample weights to match known population totals.
+- **L0 regularization**: Produce sparse weights that shrink dataset size while maintaining accuracy.
+- **Automatic hyperparameter tuning**: Find optimal regularization parameters through cross-validation.
+- **Robustness evaluation**: Assess calibration stability with holdout validation.
+- **Target assessment**: Identify targets that make calibration difficult before you start.
+- **Performance monitoring**: Track calibration progress with detailed logging.
+- **Interactive dashboard**: Visualize calibration performance at https://microcalibrate.vercel.app/.
+
+## Installation
+
+```bash
+pip install microcalibrate
+```
+
+Dependencies:
+- Python 3.13 or higher
+- PyTorch 2.7.0 or higher
+- NumPy, Pandas, Optuna, and L0-python
+
+## Quick start
+
+### Basic calibration
+
+```python
+from microcalibrate import Calibration
+import numpy as np
+import pandas as pd
+
+# Create sample data for calibration
+n_samples = 1000
+weights = np.ones(n_samples)  # Start from uniform weights
+
+# Estimate matrix: each record's contribution to each target
+estimate_matrix = pd.DataFrame({
+    'total_income': np.random.normal(50000, 15000, n_samples),
+    'total_employed': np.random.binomial(1, 0.6, n_samples),
+})
+
+# Target values to achieve through calibration
+targets = np.array([
+    50_000_000,  # Total income target
+    600,         # Total employed target
+])
+
+# Initialize the calibration object and configure the optimization
+cal = Calibration(
+    weights=weights,
+    targets=targets,
+    estimate_matrix=estimate_matrix,
+    epochs=500,
+    learning_rate=1e-3,
+)
+
+# Perform the calibration to adjust the weights
+performance_df = cal.calibrate()
+
+# Retrieve the calibrated weights
+new_weights = cal.weights
+```
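+
+Continuing the example above, the same setup extends to sparse calibration with the L0 options documented in the API reference below. This is a minimal sketch mirroring the L0 regularization notebook; the hyperparameter values are illustrative, not recommendations:
+
+```python
+# Sparse calibration: L0 regularization drives many weights to exactly zero
+cal_sparse = Calibration(
+    weights=np.ones(n_samples),
+    targets=targets,
+    estimate_matrix=estimate_matrix,
+    epochs=500,
+    learning_rate=1e-3,
+    regularize_with_l0=True,
+    l0_lambda=5e-6,  # regularization strength (illustrative value)
+)
+cal_sparse.calibrate()
+sparse_weights = cal_sparse.sparse_weights  # zeroed records can be dropped
+```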
+
+## API reference
+
+### Calibration class
+
+`Calibration` is the main class for weight calibration.
+
+**Parameters:**
+- `weights`: The initial weights array, one entry per record.
+- `targets`: The target values to match during calibration.
+- `estimate_matrix`: A DataFrame containing the contribution of each record to each target.
+- `estimate_function`: An alternative to `estimate_matrix` that uses a custom function.
+- `epochs`: The number of optimization iterations to perform (default is 32).
+- `learning_rate`: The optimization learning rate (default is 1e-3).
+- `noise_level`: The amount of noise added for robustness (default is 10.0).
+- `dropout_rate`: The dropout rate for regularization (default is 0).
+- `regularize_with_l0`: Enables L0 regularization (default is False).
+- `l0_lambda`: The L0 regularization strength (default is 5e-6).
+- `init_mean`: The initial proportion of non-zero weights (default is 0.999).
+- `temperature`: The sparsity control parameter (default is 0.5).
+
+**Methods:**
+- `calibrate()`: Performs the weight calibration and returns a performance log.
+- `tune_l0_hyperparameters()`: Automatically tunes L0 parameters using cross-validation.
+- `evaluate_holdout_robustness()`: Assesses calibration stability using holdout validation.
+- `assess_analytical_solution()`: Analyzes the difficulty of achieving target combinations.
+- `summary()`: Returns a summary of the calibration results.
+
+## Examples and documentation
+
+For detailed examples and interactive notebooks, see the [documentation](https://policyengine.github.io/microcalibrate/).
+
+## Contributing
+
+Contributions are welcome. Please feel free to submit a pull request with your improvements.
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..8cbf276 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: minor
+  changes:
+    added:
+      - Documentation for L0 regularization, hyperparameter tuning, and robustness evaluation.
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 19344f5..34246d6 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -2,3 +2,5 @@ format: jb-book
 root: intro
 chapters:
 - file: calibration.ipynb
+- file: l0_regularization.ipynb
+- file: robustness_evaluation.ipynb
diff --git a/docs/calibration.ipynb b/docs/calibration.ipynb
index 68c2a59..4132c3a 100644
--- a/docs/calibration.ipynb
+++ b/docs/calibration.ipynb
@@ -36,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 5,
    "id": "vxj460xngvg",
    "metadata": {},
    "outputs": [
@@ -62,9 +62,8 @@
    "import plotly.graph_objs as go\n",
    "from plotly.subplots import make_subplots\n",
    "\n",
-    "logging.basicConfig(\n",
-    "    level=logging.INFO,\n",
-    ")\n",
+    "calibration_logger = logging.getLogger(\"microcalibrate.calibration\")\n",
+    "calibration_logger.setLevel(logging.WARNING)\n",
    "\n",
    "# Create a sample dataset with age and income data\n",
    "random_generator = np.random.default_rng(0)\n",
@@ -96,7 +95,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 6,
    "id": "xfw94bs21yl",
    "metadata": {},
    "outputs": [
@@ -104,172 +103,9 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "INFO:microcalibrate.calibration:Performing basic target assessment...\n",
-     "WARNING:microcalibrate.calibration:Target income_aged_20_30 (7.37e+08) differs from initial estimate (7.37e+05) by 3.00 orders of magnitude.\n",
-     "WARNING:microcalibrate.calibration:Target income_aged_71 is supported by only 0.83% of records in the loss matrix.
This may make calibration unstable or ineffective.\n", - "INFO:microcalibrate.reweight:Starting calibration process for targets ['income_aged_20_30' 'income_aged_40_50' 'income_aged_71']: [7.37032429e+08 9.76779350e+05 4.36479914e+04]\n", - "INFO:microcalibrate.reweight:Original weights - mean: 1.0000, std: 0.0000\n", - "INFO:microcalibrate.reweight:Initial weights after noise - mean: 1.0226, std: 0.0148\n", - "Reweighting progress: 0%| | 0/528 [00:00 3\n", + "if np.any(large_scale_diff):\n", + " print(\"\\n⚠️ WARNING: The following targets differ by more than 3 orders of magnitude from the median:\")\n", + " for i in np.where(large_scale_diff)[0]:\n", + " print(f\" - {target_names[i]}: {target_values[i]:.2e}\")\n", + " print(\"\\nConsider rescaling these targets for better numerical stability.\")\n", + "else:\n", + " print(\"\\n✓ All targets are within reasonable scale differences\")" + ] + }, + { + "cell_type": "markdown", + "id": "5l3q4s17a2r", + "metadata": {}, + "source": [ + "### Interpreting tolerance warnings\n", + "\n", + "The tolerance check identifies targets that:\n", + "- Differ by more than 3 orders of magnitude from the median target\n", + "- May cause numerical instability during optimization\n", + "- Should potentially be rescaled or reviewed\n", + "\n", + "In our example, we have one target that is much larger than the others (income_aged_20_30), which the system warns about. This is expected since we artificially multiplied it by 1000 in our setup." + ] + }, + { + "cell_type": "markdown", + "id": "gzefuat6c78", + "metadata": {}, + "source": [ + "### Excluded targets\n", + "\n", + "If the pre-calibration assessment identifies problematic targets, you can exclude them from calibration:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ytre6ko9ru", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calibration setup with excluded targets:\n", + "Original targets: ['income_aged_20_30' 'income_aged_40_50' 'income_aged_71']\n", + "Excluded targets: ['income_aged_71']\n", + "Active targets for calibration: ['income_aged_20_30' 'income_aged_40_50']\n", + "Number of active targets: 2\n" + ] + } + ], + "source": [ + "# Example: Excluding problematic targets\n", + "# If you identify targets that cause issues, you can exclude them:\n", + "\n", + "# Create a calibrator with excluded targets\n", + "calibrator_with_exclusions = Calibration(\n", + " estimate_matrix=targets_matrix,\n", + " weights=weights.copy(), \n", + " targets=targets,\n", + " excluded_targets=[\"income_aged_71\"], # Exclude the smallest target\n", + " noise_level=0.05,\n", + " epochs=100,\n", + " learning_rate=0.01,\n", + ")\n", + "\n", + "print(\"Calibration setup with excluded targets:\")\n", + "print(f\"Original targets: {calibrator_with_exclusions.original_target_names}\")\n", + "print(f\"Excluded targets: {calibrator_with_exclusions.excluded_targets}\")\n", + "print(f\"Active targets for calibration: {calibrator_with_exclusions.target_names}\")\n", + "print(f\"Number of active targets: {len(calibrator_with_exclusions.targets)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "wl7a94tjmnj", + "metadata": {}, + "source": [ + "## Best practices for pre-calibration assessment\n", + "\n", + "1. **Always run analytical assessment first** - This helps identify fundamental issues with your target specification before spending time on calibration.\n", + "\n", + "2. 
**Check target scales** - Targets that differ by many orders of magnitude should be rescaled or normalized to improve convergence.\n", + "\n", + "3. **Look for redundant targets** - If targets are highly correlated or mathematically dependent, consider removing redundant ones.\n", + "\n", + "4. **Consider the degrees of freedom** - Having more targets than observations makes exact calibration impossible. The system will find a best-fit solution.\n", + "\n", + "5. **Use exclusions strategically** - Temporarily exclude problematic targets to get initial calibration working, then gradually add them back.\n", + "\n", + "6. **Monitor condition numbers** - High condition numbers (>1e10) indicate numerical instability. Consider reformulating your targets or adding regularization.\n", + "\n", + "Now let's proceed with the actual calibration:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "id": "9djvjpfrhxb", "metadata": {}, "outputs": [ @@ -313,9 +378,9 @@ "output_type": "stream", "text": [ "Target totals: [7.37032429e+08 9.76779350e+05 4.36479914e+04]\n", - "Final calibrated totals: [7.37025225e+08 9.76778430e+05 4.36469951e+04]\n", - "Difference: [-7.20317294e+03 -9.19358560e-01 -9.96308503e-01]\n", - "Relative error: [-9.77321032e-04 -9.41214165e-05 -2.28259874e-03]\n" + "Final calibrated totals: [7.37025243e+08 9.76778358e+05 4.36469951e+04]\n", + "Difference: [-7.18606246e+03 -9.91859882e-01 -9.96308503e-01]\n", + "Relative error: [-0.000975 -0.00010154 -0.0022826 ]\n" ] } ], @@ -331,7 +396,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "id": "96cc818b", "metadata": {}, "outputs": [], @@ -346,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "id": "923d79dd", "metadata": {}, "outputs": [ @@ -385,79 +450,79 @@ " \n", " 0\n", " 0\n", - " 0.340969\n", + " 0.339299\n", " income_aged_20_30\n", " 7.370324e+08\n", - " 752926.625000\n", - " -7.362795e+08\n", - " 7.362795e+08\n", - " 0.998978\n", + " 7.595331e+05\n", + " -7.362729e+08\n", + " 7.362729e+08\n", + " 0.998969\n", " \n", " \n", " 1\n", " 0\n", - " 0.340969\n", + " 0.339299\n", " income_aged_40_50\n", " 9.767794e+05\n", - " 872380.750000\n", - " -1.043986e+05\n", - " 1.043986e+05\n", - " 0.106880\n", + " 8.723239e+05\n", + " -1.044555e+05\n", + " 1.044555e+05\n", + " 0.106939\n", " \n", " \n", " 2\n", " 0\n", - " 0.340969\n", + " 0.339299\n", " income_aged_71\n", " 4.364799e+04\n", - " 38570.789062\n", - " -5.077203e+03\n", - " 5.077203e+03\n", - " 0.116322\n", + " 3.961747e+04\n", + " -4.030520e+03\n", + " 4.030520e+03\n", + " 0.092341\n", " \n", " \n", " 3\n", - " 10\n", - " 0.332947\n", + " 52\n", + " 0.332147\n", " income_aged_20_30\n", " 7.370324e+08\n", - " 832346.250000\n", - " -7.362001e+08\n", - " 7.362001e+08\n", - " 0.998871\n", + " 1.320273e+06\n", + " -7.357122e+08\n", + " 7.357122e+08\n", + " 0.998209\n", " \n", " \n", " 4\n", - " 10\n", - " 0.332947\n", + " 52\n", + " 0.332147\n", " income_aged_40_50\n", " 9.767794e+05\n", - " 958994.500000\n", - " -1.778488e+04\n", - " 1.778488e+04\n", - " 0.018208\n", + " 9.774676e+05\n", + " 6.882500e+02\n", + " 6.882500e+02\n", + " 0.000705\n", " \n", " \n", "\n", "" ], "text/plain": [ - " epoch loss target_name target estimate \\\n", - "0 0 0.340969 income_aged_20_30 7.370324e+08 752926.625000 \n", - "1 0 0.340969 income_aged_40_50 9.767794e+05 872380.750000 \n", - "2 0 0.340969 income_aged_71 4.364799e+04 38570.789062 \n", - "3 10 0.332947 income_aged_20_30 7.370324e+08 832346.250000 \n", - "4 10 
0.332947 income_aged_40_50 9.767794e+05 958994.500000 \n", + " epoch loss target_name target estimate \\\n", + "0 0 0.339299 income_aged_20_30 7.370324e+08 7.595331e+05 \n", + "1 0 0.339299 income_aged_40_50 9.767794e+05 8.723239e+05 \n", + "2 0 0.339299 income_aged_71 4.364799e+04 3.961747e+04 \n", + "3 52 0.332147 income_aged_20_30 7.370324e+08 1.320273e+06 \n", + "4 52 0.332147 income_aged_40_50 9.767794e+05 9.774676e+05 \n", "\n", " error abs_error rel_abs_error \n", - "0 -7.362795e+08 7.362795e+08 0.998978 \n", - "1 -1.043986e+05 1.043986e+05 0.106880 \n", - "2 -5.077203e+03 5.077203e+03 0.116322 \n", - "3 -7.362001e+08 7.362001e+08 0.998871 \n", - "4 -1.778488e+04 1.778488e+04 0.018208 " + "0 -7.362729e+08 7.362729e+08 0.998969 \n", + "1 -1.044555e+05 1.044555e+05 0.106939 \n", + "2 -4.030520e+03 4.030520e+03 0.092341 \n", + "3 -7.357122e+08 7.357122e+08 0.998209 \n", + "4 6.882500e+02 6.882500e+02 0.000705 " ] }, - "execution_count": 5, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -468,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "id": "da828d30", "metadata": {}, "outputs": [ @@ -489,103 +554,19 @@ "type": "scatter", "x": [ 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90, - 100, - 110, - 120, - 130, - 140, - 150, - 160, - 170, - 180, - 190, - 200, - 210, - 220, - 230, - 240, - 250, + 52, + 104, + 156, + 208, 260, - 270, - 280, - 290, - 300, - 310, - 320, - 330, - 340, - 350, - 360, - 370, - 380, - 390, - 400, - 410, - 420, - 430, - 440, - 450, - 460, - 470, - 480, - 490, - 500, - 510, + 312, + 364, + 416, + 468, 520 ], "xaxis": "x", "y": [ - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, - 737032448, 737032448, 737032448, 737032448, @@ -609,113 +590,29 @@ "type": "scatter", "x": [ 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90, - 100, - 110, - 120, - 130, - 140, - 150, - 160, - 170, - 180, - 190, - 200, - 210, - 220, - 230, - 240, - 250, + 52, + 104, + 156, + 208, 260, - 270, - 280, - 290, - 300, - 310, - 320, - 330, - 340, - 350, - 360, - 370, - 380, - 390, - 400, - 410, - 420, - 430, - 440, - 450, - 460, - 470, - 480, - 490, - 500, - 510, + 312, + 364, + 416, + 468, 520 ], "xaxis": "x", "y": [ - 752926.625, - 832346.25, - 921668.5, - 1023621.5625, - 1141370.5, - 1278621.25, - 1439833.625, - 1630507.875, - 1857545, - 2129699.25, - 2458166.5, - 2857359, - 3345938.25, - 3948207.75, - 4695982.5, - 5631132, - 6809027, - 8303248.5, - 10212014, - 12666988, - 15845385, - 19986564, - 25414842, - 32570676, - 42053112, - 54676848, - 71547424, - 94156512, - 124493704, - 165154752, - 219382320, - 290860512, - 382795616, - 495131744, - 617828800, - 721720128, - 770444416, - 766268800, - 745850304, - 733191488, - 731150848, - 734141312, - 737060736, - 738073920, - 737737984, - 737144384, - 736858624, - 736877632, - 736992192, - 737061056, - 737064960, - 737042432, + 759533.0625, + 1320273, + 2631305, + 6359750.5, + 19233504, + 72160320, + 309902368, + 773134208, + 736156928, + 736861504, 737027328 ], "yaxis": "y" @@ -730,103 
+627,19 @@ "type": "scatter", "x": [ 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90, - 100, - 110, - 120, - 130, - 140, - 150, - 160, - 170, - 180, - 190, - 200, - 210, - 220, - 230, - 240, - 250, + 52, + 104, + 156, + 208, 260, - 270, - 280, - 290, - 300, - 310, - 320, - 330, - 340, - 350, - 360, - 370, - 380, - 390, - 400, - 410, - 420, - 430, - 440, - 450, - 460, - 470, - 480, - 490, - 500, - 510, + 312, + 364, + 416, + 468, 520 ], "xaxis": "x2", "y": [ - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, - 976779.375, 976779.375, 976779.375, 976779.375, @@ -850,113 +663,29 @@ "type": "scatter", "x": [ 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90, - 100, - 110, - 120, - 130, - 140, - 150, - 160, - 170, - 180, - 190, - 200, - 210, - 220, - 230, - 240, - 250, + 52, + 104, + 156, + 208, 260, - 270, - 280, - 290, - 300, - 310, - 320, - 330, - 340, - 350, - 360, - 370, - 380, - 390, - 400, - 410, - 420, - 430, - 440, - 450, - 460, - 470, - 480, - 490, - 500, - 510, + 312, + 364, + 416, + 468, 520 ], "xaxis": "x2", "y": [ - 872380.75, - 958994.5, - 1003960.5, - 984014.625, - 966661.75, - 975252.125, - 980736.625, - 976016.625, - 975696.5, - 977566.9375, - 976717.0625, - 976543.625, - 976956, - 976731.375, - 976753.5, - 976814.25, - 976756.8125, - 976785, - 976779.625, - 976775, - 976781.1875, - 976776.6875, - 976779, - 976778.25, - 976778.375, - 976778.5625, + 872323.875, + 977467.625, + 976435.125, + 976771.8125, + 976777.25, 976778.25, 976778.375, 976778.375, 976778.375, 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, - 976778.375, 976778.375 ], "yaxis": "y2" @@ -970,113 +699,29 @@ "type": "scatter", "x": [ 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90, - 100, - 110, - 120, - 130, - 140, - 150, - 160, - 170, - 180, - 190, - 200, - 210, - 220, - 230, - 240, - 250, + 52, + 104, + 156, + 208, 260, - 270, - 280, - 290, - 300, - 310, - 320, - 330, - 340, - 350, - 360, - 370, - 380, - 390, - 400, - 410, - 420, - 430, - 440, - 450, - 460, - 470, - 480, - 490, - 500, - 510, + 312, + 364, + 416, + 468, 520 ], "xaxis": "x3", "y": [ - 0.9989784349019597, - 0.998870678960935, - 0.9987494872138927, - 0.9986111580768557, - 0.9984513972171819, - 0.998265176447157, - 0.9980464447272205, - 0.9977877393601536, - 0.9974796971218287, - 0.9971104403126632, - 0.9966647784549101, - 0.9961231571177719, - 0.9954602565204836, - 0.9946431018597434, - 0.993628526786381, - 0.9923597230823683, - 0.9907615641367258, - 0.9887342158102651, - 0.9861444173486376, - 0.9828135273631806, - 0.9785011025728951, - 0.9728823825135037, - 0.9655173363547761, - 0.955808355400928, - 0.9429426586114157, - 0.9258148699580727, - 0.9029250012070025, - 0.8722491631738851, - 0.8310878926188063, - 0.7759192930404063, - 0.7023437426733212, - 0.6053626773295061, - 
0.4806258299227553, - 0.3282090288649001, - 0.16173459977707794, - 0.020775638903756923, - 0.04533310316345801, - 0.0396676592453091, - 0.011963999717960857, - 0.005211385211631822, - 0.007980109988318995, - 0.003922671258023093, - 0.00003838094249006524, - 0.0014130612604996192, - 0.0009572658597522154, - 0.00015187391044145727, - 0.00023584307647741447, - 0.00021005316715716702, - 0.00005461903354355438, - 0.00003881511604764517, - 0.00004411203345012023, - 0.000013546214996493614, + 0.9989694713379837, + 0.9982086636706665, + 0.9964298654596006, + 0.9913711390627947, + 0.973904128573726, + 0.9020934285921697, + 0.5795268324468668, + 0.04898259241905181, + 0.001187898853538671, + 0.00023193551445919514, 0.000006946776921278777 ], "yaxis": "y3" @@ -1090,113 +735,29 @@ "type": "scatter", "x": [ 0, - 10, - 20, - 30, - 40, - 50, - 60, - 70, - 80, - 90, - 100, - 110, - 120, - 130, - 140, - 150, - 160, - 170, - 180, - 190, - 200, - 210, - 220, - 230, - 240, - 250, + 52, + 104, + 156, + 208, 260, - 270, - 280, - 290, - 300, - 310, - 320, - 330, - 340, - 350, - 360, - 370, - 380, - 390, - 400, - 410, - 420, - 430, - 440, - 450, - 460, - 470, - 480, - 490, - 500, - 510, + 312, + 364, + 416, + 468, 520 ], "xaxis": "x4", "y": [ - 0.10688045598833411, - 0.018207668440992624, - 0.02782729211496711, - 0.0074072509976984315, - 0.010358147662567097, - 0.001563556765313559, - 0.004051324281903475, - 0.000780882581596279, - 0.0011086177981593848, - 0.0008062849402404714, - 0.00006379383266564161, - 0.0002413544000148447, - 0.00018082384264102628, - 0.00004914108674745513, - 0.000026490117074800028, - 0.00003570407083994786, - 0.000023098870202905338, - 0.000005758721103217397, - 2.5594316014299544e-7, - 0.00000447900530250242, - 0.000001855587911036717, - 0.000002751388971537201, - 3.839147402144932e-7, - 0.0000011517442206434795, - 0.0000010237726405719818, - 8.318152704647351e-7, + 0.10693868305726664, + 0.0007046115198736664, + 0.0003524337315169047, + 0.000007742280594325611, + 0.0000021755168612154613, 0.0000011517442206434795, 0.0000010237726405719818, 0.0000010237726405719818, 0.0000010237726405719818, 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, - 0.0000010237726405719818, 0.0000010237726405719818 ], "yaxis": "y4" @@ -2259,7 +1820,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "id": "fefaf031", "metadata": {}, "outputs": [ @@ -2323,7 +1884,7 @@ "2 income_aged_71 4.364799e+04 4.364699e+04 -0.000023" ] }, - "execution_count": 7, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -2332,11 +1893,40 @@ "summary = calibrator.summary()\n", "summary" ] + }, + { + "cell_type": "markdown", + "id": "zizmg24l2dh", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "The Calibration class provides comprehensive tools for survey weight calibration:\n", + "\n", + "1. 
**Pre-calibration assessment**:\n",
+    "   - Analytical solution feasibility analysis\n",
+    "   - Target tolerance checking\n",
+    "   - Correlation analysis\n",
+    "   - Target exclusion capabilities\n",
+    "\n",
+    "2. **Standard calibration**:\n",
+    "   - Gradient-based optimization\n",
+    "   - Multi-target support\n",
+    "   - Progress monitoring\n",
+    "   - Performance logging\n",
+    "\n",
+    "3. **Advanced features**:\n",
+    "   - L0 regularization for sparse weights\n",
+    "   - Hyperparameter tuning (see the [L0 regularization notebook](l0_regularization.ipynb))\n",
+    "   - Robustness evaluation (see the [robustness evaluation notebook](robustness_evaluation.ipynb))\n",
+    "\n",
+    "By using the pre-calibration assessment tools, you can identify and address potential issues before running the calibration, saving time and improving results. The L0 regularization feature allows you to maintain calibration accuracy while significantly reducing dataset size, which is valuable for large-scale applications."
+   ]
+  }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "pe",
+   "display_name": "pe3.13",
    "language": "python",
    "name": "python3"
   },
@@ -2350,7 +1940,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.11"
+   "version": "3.13.0"
   }
 },
 "nbformat": 4,
diff --git a/docs/intro.md b/docs/intro.md
index 9320689..9e4e4cf 100644
--- a/docs/intro.md
+++ b/docs/intro.md
@@ -1,15 +1,92 @@
 ## MicroCalibrate
 
-MicroCalibrate is a framework that enables survey weight calibration integrating informative diagnostics and visualizations for the user to easily identify and work through hard-to-hit targets. A dashboard allows the user to explore the reweighting process on sample data as well as supporting uploading or pasting a link to the performance log of a specific reweighting job.
-
-To make sure that your data can be loaded and visualized by the dashboard, it must be in csv format and contain the following columns:
-- epoch (int)
-- target_name (str)
-- target (float)
-- estimate (float)
-- error (float)
-- abs_error (float)
-- rel_abs_error (float)
-- loss (float)
-
-To access the performance dashboard see: https://microcalibrate.vercel.app/
+MicroCalibrate is a comprehensive framework for survey weight calibration that combines traditional calibration techniques with modern machine learning approaches. It enables users to adjust sample weights to match population targets while providing advanced features for sparsity, optimization, and robustness analysis.
+
+## Key features
+
+### 1. Core calibration
+- **Survey weight adjustment**: The system calibrates sample weights to match known population totals using gradient-based optimization.
+- **Multi-target support**: The calibration process can handle multiple targets simultaneously.
+- **Custom estimate functions**: Users can supply either an estimate matrix or a custom function for flexible calibration scenarios.
+
+### 2. L0 regularization for sparsity
+- **Dataset reduction**: The algorithm automatically identifies and zeros out unnecessary weights to reduce dataset size.
+- **Sparse weight generation**: The system creates compact datasets while maintaining calibration accuracy.
+- **Configurable sparsity**: Users can control the trade-off between dataset size and calibration precision.
+
+### 3. Automatic hyperparameter tuning
+- **Cross-validation**: The system uses holdout validation to find optimal regularization parameters.
+- **Multi-objective optimization**: The optimization process balances calibration loss, accuracy, and sparsity.
+- **Optuna integration**: The package leverages state-of-the-art hyperparameter optimization through Optuna.
+
+### 4. Robustness evaluation
+- **Generalization assessment**: The evaluation module measures how well calibration generalizes to unseen targets.
+- **Stability analysis**: The system identifies targets that are difficult to calibrate reliably.
+- **Actionable recommendations**: Users receive specific suggestions for improving calibration robustness.
+
+### 5. Target analysis
+- **Pre-calibration assessment**: The system identifies problematic targets before calibration begins.
+- **Analytical solutions**: The analysis helps users understand the mathematical difficulty of target combinations.
+- **Order-of-magnitude warnings**: The system detects targets that differ significantly from initial estimates.
+
+### 6. Performance monitoring
+- **Detailed logging**: The system tracks calibration progress across epochs.
+- **Performance dashboard**: Users can visualize calibration results at https://microcalibrate.vercel.app/.
+- **CSV export**: The system can save detailed performance metrics for further analysis.
+
+## Dashboard requirements
+
+To use the performance dashboard, ensure your calibration log CSV contains the following fields:
+- epoch (int): The iteration number during calibration.
+- target_name (str): The name of each calibration target.
+- target (float): The target value to achieve.
+- estimate (float): The estimated value at each epoch.
+- error (float): The difference between target and estimate.
+- abs_error (float): The absolute error value.
+- rel_abs_error (float): The relative absolute error.
+- loss (float): The loss value at each epoch.
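+
+The performance log returned by `calibrate()` already follows this schema, so it can be written straight to a dashboard-ready file (a minimal sketch; it assumes `cal` is a configured `Calibration` instance as in the Getting started section below):
+
+```python
+# Export the per-epoch calibration log in the dashboard's CSV schema
+performance_df = cal.calibrate()
+performance_df.to_csv("calibration_log.csv", index=False)
+```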
+ +## Getting started + +```python +from microcalibrate import Calibration +import numpy as np +import pandas as pd + +# Basic calibration +cal = Calibration( + weights=initial_weights, + targets=target_values, + estimate_matrix=contribution_matrix +) +performance = cal.calibrate() + +# With L0 regularization +cal = Calibration( + weights=initial_weights, + targets=target_values, + estimate_matrix=contribution_matrix, + regularize_with_l0=True +) +performance = cal.calibrate() +sparse_weights = cal.sparse_weights + +# Hyperparameter tuning +best_params = cal.tune_l0_hyperparameters(n_trials=30) + +# Robustness evaluation +robustness = cal.evaluate_holdout_robustness() +print(robustness['recommendation']) +``` + +## Documentation structure + +- **[Basic calibration](calibration.ipynb)**: Core calibration concepts and basic usage +- **[L0 regularization](l0_regularization.ipynb)**: Creating sparse weights for dataset reduction +- **[Robustness evaluation](robustness_evaluation.ipynb)**: Assessing calibration stability and generalization + +## Support + +- **GitHub issues**: https://github.com/PolicyEngine/microcalibrate/issues +- **Documentation**: https://policyengine.github.io/microcalibrate/ +- **Performance dashboard**: https://microcalibrate.vercel.app/ \ No newline at end of file diff --git a/docs/l0_regularization.ipynb b/docs/l0_regularization.ipynb new file mode 100644 index 0000000..2bb6b27 --- /dev/null +++ b/docs/l0_regularization.ipynb @@ -0,0 +1,1117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# L0 regularization for sparse weights\n", + "\n", + "L0 regularization is a powerful technique that creates sparse weights during calibration, effectively reducing the dataset size by setting many weights to zero. This is particularly useful when:\n", + "\n", + "- You need to reduce computational costs in downstream processing\n", + "- You want to identify the most important records in your dataset\n", + "- You need a smaller, representative sample that still matches population targets\n", + "\n", + "## How L0 regularization works\n", + "\n", + "L0 regularization adds a penalty term to the calibration loss that encourages weights to be exactly zero. Unlike L1 regularization (which shrinks weights), L0 creates truly sparse solutions by using a differentiable approximation of the L0 norm through the Hard Concrete distribution." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from microcalibrate import Calibration\n", + "import numpy as np\n", + "import pandas as pd\n", + "import logging\n", + "\n", + "calibration_logger = logging.getLogger(\"microcalibrate.calibration\")\n", + "calibration_logger.setLevel(logging.WARNING)\n", + "\n", + "np.random.seed(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Basic L0 regularization\n", + "\n", + "Let's create a synthetic dataset and apply L0 regularization to reduce its size while maintaining calibration accuracy." 
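+    ,
+    "\n",
+    "Before running the example, here is the objective in rough form (a sketch of the standard Hard Concrete relaxation; our notation, not necessarily the package's exact internals):\n",
+    "\n",
+    "$$\\min_{w,\\,z} \\; \\mathcal{L}(w \\odot z) + \\lambda \\sum_i \\Pr[z_i \\neq 0]$$\n",
+    "\n",
+    "where $z_i$ are stochastic gates multiplying the weights and $\\lambda$ corresponds to `l0_lambda`: larger values push more gates, and hence more weights, to exactly zero."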
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset size: 5000 records\n", + "Number of targets: 10\n", + "Target names: ['income_18-30', 'employed_18-30', 'income_31-50', 'employed_31-50', 'income_51-65', 'employed_51-65', 'income_65+', 'employed_65+', 'total_income', 'total_employed']\n" + ] + } + ], + "source": [ + "# Create synthetic data\n", + "n_samples = 5000\n", + "n_targets = 10\n", + "\n", + "# Generate random data with some structure\n", + "age_groups = np.random.choice(['18-30', '31-50', '51-65', '65+'], n_samples)\n", + "income = np.random.lognormal(10.5, 0.8, n_samples) # Log-normal income distribution\n", + "employed = np.random.binomial(1, 0.65, n_samples)\n", + "\n", + "# Create estimate matrix with various demographic combinations\n", + "estimate_matrix = pd.DataFrame()\n", + "for age in ['18-30', '31-50', '51-65', '65+']:\n", + " mask = age_groups == age\n", + " estimate_matrix[f'income_{age}'] = mask * income\n", + " estimate_matrix[f'employed_{age}'] = mask * employed\n", + "\n", + "estimate_matrix['total_income'] = income\n", + "estimate_matrix['total_employed'] = employed\n", + "\n", + "# Set realistic targets (scaled population values)\n", + "targets = estimate_matrix.sum().values * 1.1 # 10% higher than unweighted\n", + "\n", + "print(f\"Dataset size: {n_samples} records\")\n", + "print(f\"Number of targets: {len(targets)}\")\n", + "print(f\"Target names: {list(estimate_matrix.columns)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing standard vs L0 calibration" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running standard calibration...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 2660.14epoch/s, loss=13.2, weights_mean=5.1, weights_std=2.42, weights_min=0.842]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Standard calibration results:\n", + "Non-zero weights: 5000 (100.0%)\n", + "Weight range: [0.835, 9.164]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Standard calibration (no sparsity)\n", + "weights_init = np.ones(n_samples)\n", + "\n", + "cal_standard = Calibration(\n", + " weights=weights_init.copy(),\n", + " targets=targets,\n", + " estimate_matrix=estimate_matrix,\n", + " epochs=200,\n", + " learning_rate=1e-3,\n", + " regularize_with_l0=False\n", + ")\n", + "\n", + "print(\"Running standard calibration...\")\n", + "perf_standard = cal_standard.calibrate()\n", + "weights_standard = cal_standard.weights\n", + "\n", + "print(f\"\\nStandard calibration results:\")\n", + "print(f\"Non-zero weights: {np.sum(weights_standard != 0)} ({100*np.mean(weights_standard != 0):.1f}%)\")\n", + "print(f\"Weight range: [{weights_standard.min():.3f}, {weights_standard.max():.3f}]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running L0 regularized calibration...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reweighting progress: 100%|██████████| 200/200 [00:00<00:00, 1776.92epoch/s, loss=13.5, weights_mean=5.13, weights_std=2.43, weights_min=0.84] \n", + "Sparse 
reweighting progress: 100%|██████████| 400/400 [00:00<00:00, 722.65epoch/s, loss=0.0103, loss_rel_change=-0.691]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "L0 calibration results:\n",
+      "Non-zero weights: 1998 (40.0%)\n",
+      "Dataset reduction: 60.0%\n",
+      "Weight range: [0.010, 14.061]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# L0 regularized calibration\n",
+    "cal_l0 = Calibration(\n",
+    "    weights=weights_init.copy(),\n",
+    "    targets=targets,\n",
+    "    estimate_matrix=estimate_matrix,\n",
+    "    epochs=200,\n",
+    "    learning_rate=1e-3,\n",
+    "    regularize_with_l0=True,\n",
+    "    l0_lambda=5e-6,  # Regularization strength\n",
+    "    init_mean=0.999,  # Start with most weights active\n",
+    "    temperature=0.5,  # Controls sparsity gradient\n",
+    ")\n",
+    "\n",
+    "print(\"Running L0 regularized calibration...\")\n",
+    "perf_l0 = cal_l0.calibrate()\n",
+    "weights_l0 = cal_l0.sparse_weights\n",
+    "\n",
+    "print(f\"\\nL0 calibration results:\")\n",
+    "print(f\"Non-zero weights: {np.sum(weights_l0 != 0)} ({100*np.mean(weights_l0 != 0):.1f}%)\")\n",
+    "print(f\"Dataset reduction: {100*(1 - np.mean(weights_l0 != 0)):.1f}%\")\n",
+    "print(f\"Weight range: [{weights_l0[weights_l0>0].min():.3f}, {weights_l0.max():.3f}]\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Hyperparameter tuning for L0 regularization\n",
+    "\n",
+    "Finding the optimal L0 regularization parameters is crucial for achieving the right balance between sparsity and calibration accuracy. This section demonstrates how to use the automatic hyperparameter tuning feature to find the best parameters for your specific dataset.\n",
+    "\n",
+    "### Why hyperparameter tuning matters\n",
+    "\n",
+    "L0 regularization has three key parameters that interact in complex ways:\n",
+    "- **l0_lambda**: Controls the strength of the sparsity penalty\n",
+    "- **init_mean**: Sets the initial proportion of active weights\n",
+    "- **temperature**: Determines how \"hard\" the sparsity decisions are\n",
+    "\n",
+    "Manual tuning can be time-consuming and may miss optimal combinations. The automatic tuning uses Optuna to search the parameter space efficiently."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Basic hyperparameter tuning\n",
+    "\n",
+    "Let's start with a simple tuning run to find good L0 parameters. The tuning process will:\n",
+    "1. Create multiple holdout sets for cross-validation\n",
+    "2. Try different parameter combinations\n",
+    "3. Evaluate each combination on both training and validation targets\n",
+    "4. Select the best parameters based on a multi-objective criterion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:microcalibrate.hyperparameter_tuning:Multi-holdout hyperparameter tuning:\n",
+      "  - 3 holdout sets\n",
+      "  - 2 targets per holdout (20.0%)\n",
+      "  - Aggregation: mean\n",
+      "\n",
+      "WARNING:microcalibrate.hyperparameter_tuning:Data leakage warning: Targets often share overlapping information (e.g., geographic breakdowns like 'snap in CA' and 'snap in US'). Holdout validation may not provide complete isolation between training and validation sets.
The robustness metrics should be interpreted with this limitation in mind - they may overestimate the model's true generalization performance.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting hyperparameter tuning...\n", + "This will take a few minutes as it explores different parameter combinations.\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "85fd8b3742ea4b07b12ca58d3ea857ff", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/20 [00:00 0.1:\n", + " print(\"\\nAdditional suggestions for high variability:\")\n", + " print(\"1. Some target combinations may be inherently difficult\")\n", + " print(\"2. Consider grouping related targets\")\n", + " print(\"3. Increase epochs to ensure convergence\")\n", + "else:\n", + " print(\"\\nYour calibration shows good robustness!\")\n", + " print(\"Consider saving these settings for production use.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced: Custom holdout strategies\n", + "\n", + "You can implement custom holdout strategies for specific evaluation needs." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2407.43epoch/s, loss=20.1, weights_mean=5.5, weights_std=2.64, weights_min=0.918]\n", + "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2500.32epoch/s, loss=20.4, weights_mean=5.5, weights_std=2.62, weights_min=0.919]\n", + "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2694.62epoch/s, loss=21, weights_mean=5.56, weights_std=2.67, weights_min=0.918]\n", + "Reweighting progress: 100%|██████████| 100/100 [00:00<00:00, 2461.17epoch/s, loss=20, weights_mean=5.43, weights_std=2.68, weights_min=0.917]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Robustness by target category:\n", + " Category N targets Mean error Max error Within 10%\n", + " State targets 12 449.3% 494.2% 0%\n", + "Gender targets 7 445.0% 477.2% 0%\n", + " Age targets 4 450.4% 470.4% 0%\n", + " Total targets 3 422.9% 425.1% 0%\n", + "\n", + "Interpretation:\n", + "- Lower errors indicate targets that can be predicted from others\n", + "- High errors suggest independent information in those targets\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Example: Evaluate robustness by holding out entire target categories\n", + "def evaluate_by_category():\n", + " categories = {\n", + " 'State targets': [i for i, name in enumerate(estimate_matrix.columns) \n", + " if any(s in name for s in ['CA', 'NY', 'TX', 'FL'])],\n", + " 'Gender targets': [i for i, name in enumerate(estimate_matrix.columns) \n", + " if any(g in name for g in ['_M', '_F'])],\n", + " 'Age targets': [i for i, name in enumerate(estimate_matrix.columns) \n", + " if 'age' in name],\n", + " 'Total targets': [i for i, name in enumerate(estimate_matrix.columns) \n", + " if 'total' in name],\n", + " }\n", + " \n", + " results = []\n", + " \n", + " for category, indices in categories.items():\n", + " if len(indices) == 0:\n", + " continue\n", + " \n", + " # Create masks for train and holdout\n", + " train_mask = np.ones(len(targets), dtype=bool)\n", + " train_mask[indices] = False\n", + " \n", + " # Skip if too few training targets remain\n", + " if train_mask.sum() < 3:\n", + " 
continue\n", + " \n", + " # Calibrate on subset\n", + " cal_temp = Calibration(\n", + " weights=weights_init.copy(),\n", + " targets=targets[train_mask],\n", + " estimate_matrix=estimate_matrix.iloc[:, train_mask],\n", + " epochs=100,\n", + " learning_rate=1e-3,\n", + " )\n", + " \n", + " # Suppress logging for cleaner output\n", + " import logging\n", + " original_level = logging.getLogger().level\n", + " logging.getLogger().setLevel(logging.WARNING)\n", + " \n", + " try:\n", + " cal_temp.calibrate()\n", + " \n", + " # Evaluate on holdout\n", + " holdout_estimates = (estimate_matrix.iloc[:, indices].T * cal_temp.weights).sum(axis=1).values\n", + " holdout_targets = targets[indices]\n", + " holdout_errors = np.abs((holdout_estimates - holdout_targets) / holdout_targets)\n", + " \n", + " results.append({\n", + " 'Category': category,\n", + " 'N targets': len(indices),\n", + " 'Mean error': f\"{np.mean(holdout_errors):.1%}\",\n", + " 'Max error': f\"{np.max(holdout_errors):.1%}\",\n", + " 'Within 10%': f\"{100*np.mean(holdout_errors < 0.1):.0f}%\"\n", + " })\n", + " finally:\n", + " # Restore logging level\n", + " logging.getLogger().setLevel(original_level)\n", + " \n", + " return pd.DataFrame(results)\n", + "\n", + "category_results = evaluate_by_category()\n", + "print(\"Robustness by target category:\")\n", + "print(category_results.to_string(index=False))\n", + "print(\"\\nInterpretation:\")\n", + "print(\"- Lower errors indicate targets that can be predicted from others\")\n", + "print(\"- High errors suggest independent information in those targets\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Best practices for robustness evaluation\n", + "\n", + "### 1. Choose appropriate holdout parameters\n", + "- **Holdout fraction**: 20-30% is typically good\n", + "- **Number of rounds**: At least 5-10 for reliable estimates\n", + "- **Epochs per round**: Enough to converge (check loss curves)\n", + "\n", + "### 2. Interpret results carefully\n", + "- **High variability**: Indicates unstable calibration\n", + "- **Large generalization gap**: Suggests overfitting\n", + "- **Low consistency**: Some target combinations are problematic\n", + "\n", + "### 3. Be aware of data leakage\n", + "Since many calibration targets share information (e.g., 'total_income' includes all state incomes), holdout validation may give optimistic results. The evaluation includes a warning about this.\n", + "\n", + "### 4. Use results to improve calibration\n", + "- Add regularization if overfitting is detected\n", + "- Remove or combine highly correlated targets\n", + "- Investigate targets with high difficulty scores\n", + "- Consider different optimization parameters\n", + "\n", + "### 5. Document your evaluation\n", + "Save robustness results along with your calibration parameters for reproducibility and comparison." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next steps\n", + "\n", + "After evaluating robustness:\n", + "\n", + "1. If robustness is poor, try:\n", + " - Hyperparameter tuning to find better L0 parameters\n", + " - Reviewing and cleaning your targets\n", + " - Increasing the dataset size\n", + "\n", + "2. If robustness is good:\n", + " - Save your calibration configuration\n", + " - Apply to production data\n", + " - Monitor performance over time\n", + "\n", + "3. 
For specific issues:\n", + " - High difficulty targets → Check data quality\n", + " - Large generalization gap → Add regularization\n", + " - High variability → Increase epochs or adjust learning rate" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pe3.13", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}