diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..4d91dbb 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,7 @@ +- bump: minor + changes: + added: + - Logic to create estimate matrix for calibration from a database. + - Conversion functions between dataset classes to enable stacking datasets. + - Logic to calibrate for multiple geographic levels with two different routines. + - Calibration documentation. diff --git a/docs/_toc.yml b/docs/_toc.yml index 3439a33..9c74b64 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -2,3 +2,5 @@ format: jb-book root: intro chapters: - file: dataset.ipynb + - file: normalise_keys.md + - file: calibration.ipynb diff --git a/docs/calibration.ipynb b/docs/calibration.ipynb new file mode 100644 index 0000000..47a50f7 --- /dev/null +++ b/docs/calibration.ipynb @@ -0,0 +1,394 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5b9560b0", + "metadata": {}, + "source": [ + "# PolicyEngine survey weight calibration guide\n", + "\n", + "This notebook demonstrates how to use the two main calibration routines available in PolicyEngine Data:\n", + "\n", + "1. **Geographic level iteration**: Calibrating one geographic level at a time from lowest to highest in hierarchy\n", + "2. **All levels at once**: Stacking datasets at the lowest level and calibrating for all geographic levels simultaneously\n", + "\n", + "Both methods adjust household weights to match official statistics (targets) while maintaining data representativeness with a gradient descent algorithm implemented in PolicyEngine's [`microcalibrate`](https://policyengine.github.io/microcalibrate/) package." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "25555d04", + "metadata": {}, + "outputs": [], + "source": [ + "# Import required libraries\n", + "import logging\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from policyengine_data.calibration.calibrate import (\n", + " calibrate_single_geography_level,\n", + " calibrate_all_levels,\n", + ")\n", + "from policyengine_data.calibration.target_rescaling import (\n", + " download_database,\n", + " rescale_calibration_targets,\n", + ")\n", + "from policyengine_data.calibration.target_uprating import (\n", + " uprate_calibration_targets,\n", + ")\n", + "from policyengine_data.tools.legacy_class_conversions import (\n", + " SingleYearDataset_to_Dataset,\n", + ")\n", + "from policyengine_data.calibration.target_rescaling import download_database\n", + "\n", + "from policyengine_us import Microsimulation\n", + "from policyengine_us.system import system\n", + "\n", + "# Set up logging to see calibration progress\n", + "logging.basicConfig(level=logging.ERROR)\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "calibration_logger = logging.getLogger(\"microcalibrate.calibration\")\n", + "calibration_logger.setLevel(logging.ERROR)\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\", category=FutureWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5a58bd2b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "30e0dcf3dd8741d7910119daae1dc240", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "policy_data.db: 0%| | 0.00/8.52M [00:00=0.25.1", "tables", "policyengine-core>=3.6.4", + "policyengine-us", # remove as soon as we fix UCGID "microdf-python", + "microcalibrate", "sqlalchemy", - "huggingface_hub", + "torch", ] [project.optional-dependencies] @@ -30,6 +32,7 @@ dev = [ "build", "linecheck", "yaml-changelog>=0.1.7", + "policyengine-us>=1.366.0", ] docs 
= [ diff --git a/src/policyengine_data/calibration/__init__.py b/src/policyengine_data/calibration/__init__.py index bc4d3c1..e30d3b9 100644 --- a/src/policyengine_data/calibration/__init__.py +++ b/src/policyengine_data/calibration/__init__.py @@ -1 +1,12 @@ +from .calibrate import calibrate_all_levels, calibrate_single_geography_level +from .dataset_duplication import ( + load_dataset_for_geography_legacy, + minimize_calibrated_dataset_legacy, +) +from .metrics_matrix_creation import ( + create_metrics_matrix, + validate_metrics_matrix, +) from .target_rescaling import download_database, rescale_calibration_targets +from .target_uprating import uprate_calibration_targets +from .utils import create_geographic_normalization_factor diff --git a/src/policyengine_data/calibration/calibrate.py b/src/policyengine_data/calibration/calibrate.py index 6960329..072ee61 100644 --- a/src/policyengine_data/calibration/calibrate.py +++ b/src/policyengine_data/calibration/calibrate.py @@ -1,3 +1,583 @@ """ -This file will contain the logic for calibrating policy engine data from start to end. It will include functions for target rescaling, matrix creation, household duplication and assignment to new geographic areas, and final calibration. +This file will contain the logic for calibrating policy engine data from start to end. It will include different calibration routine options, from calibration at one geographic level to full calibration across all levels. 
""" + +import logging +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd + +from policyengine_data import SingleYearDataset, normalise_table_keys +from policyengine_data.calibration.dataset_duplication import ( + load_dataset_for_geography_legacy, + minimize_calibrated_dataset_legacy, +) +from policyengine_data.calibration.metrics_matrix_creation import ( + create_metrics_matrix, + validate_metrics_matrix, +) +from policyengine_data.calibration.target_rescaling import ( + download_database, + rescale_calibration_targets, +) +from policyengine_data.calibration.target_uprating import ( + uprate_calibration_targets, +) +from policyengine_data.calibration.utils import ( + create_geographic_normalization_factor, +) +from policyengine_data.tools.legacy_class_conversions import ( + SingleYearDataset_to_Dataset, +) + +logger = logging.getLogger(__name__) + + +def calibrate_single_geography_level( + microsimulation_class, + calibration_areas: Dict[str, str], + dataset: str, + stack_datasets: Optional[bool] = True, + dataset_subsample_size: Optional[int] = None, + geo_db_filter_variable: Optional[str] = "ucgid_str", + geo_sim_filter_variable: Optional[str] = "ucgid", + year: Optional[int] = 2023, + db_uri: Optional[str] = None, + noise_level: Optional[float] = 10.0, + use_dataset_weights: Optional[bool] = True, + regularize_with_l0: Optional[bool] = False, + calibration_log_path: Optional[str] = None, + raise_error: Optional[bool] = True, +) -> "SingleYearDataset": + """ + This function will calibrate the dataset for a specific geography level, defaulting to stacking the base dataset per area within it. + It will handle conversion between dataset classes to enable: + 1. Loading the base dataset and reassigning it to the specified geography. + 2. Selecting the appropriate targets that match each area at the geography level. + 3. Creating a metrics matrix that enables computing estimates for those targets. + 4. 
Calibrating the dataset's household weights with regularization. + 5. Filtering the resulting dataset to only include households with non-zero weights. + 6. Stacking all areas at that level into a single dataset. + + Args: + microsimulation_class: The Microsimulation class to use for creating simulations. + calibration_areas (Dict[str, str]): A dictionary mapping area names to their corresponding geography level. + dataset (str): The name of the dataset to be calibrated. + stack_datasets (Optional[bool]): Whether to assign the dataset to each area in the geography level and combine them. Default: True. + dataset_subsample_size (Optional[int]): The size of the base dataset subsample to use for calibration. If None, the full dataset will be used for stacking when enabled. + year (Optional[int]): The year for which the calibration is being performed. Default: 2023. + geo_db_filter_variable (str): The variable used to filter the database by geography. Default in the US: "ucgid_str". + geo_sim_filter_variable (str): The variable used to filter the simulation by geography. Default in the US: "ucgid". + db_uri (Optional[str]): The URI of the database to use for rescaling targets. If None, it will download the database from the default URI. + noise_level (Optional[float]): The level of noise to apply during calibration. Default: 10.0. + use_dataset_weights (Optional[bool]): Whether to use original dataset weights as the starting weights for calibration. Default: True. + regularize_with_l0 (Optional[bool]): Whether to use L0 regularization during calibration. Default: False. + calibration_log_path (Optional[str]): The path to the calibration log file. If None, calibration log CSVs will not be saved. + raise_error (Optional[bool]): Whether to raise an error if matrix creation fails. Default: True. + + Returns: + geography_level_calibrated_dataset (SingleYearDataset): The calibrated dataset for the specified geography level. 
+ """ + if db_uri is None: + db_uri = download_database() + + geography_level_calibrated_dataset = None + for area, geo_identifier in calibration_areas.items(): + logger.info(f"Calibrating dataset for {area}...") + + if stack_datasets: + # Load dataset configured for the specific geography first + # TODO: move away from hardcoding UCGID for geographic identification once -us is updated + from policyengine_us.variables.household.demographic.geographic.ucgid.ucgid_enum import ( + UCGID, + ) + + sim_data_to_calibrate = load_dataset_for_geography_legacy( + microsimulation_class=microsimulation_class, + year=year, + dataset=dataset, + dataset_subsample_size=dataset_subsample_size, + geography_variable=geo_sim_filter_variable, + geography_identifier=UCGID( + geo_identifier + ), # will need a non-hardcoded solution to assign geography_identifier in the future + ) + else: + sim_data_to_calibrate = microsimulation_class(dataset=dataset) + sim_data_to_calibrate.default_input_period = year + sim_data_to_calibrate.build_from_dataset() + + # Create metrics matrix for the area based on strata constraints using configured simulation + metrics_matrix, targets, target_info = create_metrics_matrix( + db_uri=db_uri, + time_period=year, + microsimulation_class=microsimulation_class, + sim=sim_data_to_calibrate, + stratum_filter_variable=geo_db_filter_variable, + stratum_filter_value=geo_identifier, + stratum_filter_operation="in", + ) + metrics_evaluation = validate_metrics_matrix( + metrics_matrix, + targets, + target_info=target_info, + raise_error=raise_error, + ) + + target_names = [] + excluded_targets = [] + for target_id, info in target_info.items(): + target_names.append(info["name"]) + if not info["active"]: + excluded_targets.append(target_id) + target_names = np.array(target_names) + + if use_dataset_weights: + weights = sim_data_to_calibrate.calculate( + "household_weight" + ).values + else: + weights = np.ones(len(metrics_matrix)) + + # Calibrate with L0 regularization + 
from microcalibrate import Calibration + + calibrator = Calibration( + weights=weights, + targets=targets, + target_names=target_names, + estimate_matrix=metrics_matrix, + epochs=600, + learning_rate=0.2, + noise_level=noise_level, + excluded_targets=( + excluded_targets if len(excluded_targets) > 0 else None + ), + sparse_learning_rate=0.1, + regularize_with_l0=regularize_with_l0, + csv_path=calibration_log_path, + ) + performance_log = calibrator.calibrate() + optimized_sparse_weights = calibrator.sparse_weights + optimized_weights = calibrator.weights + + # Minimize the calibrated dataset storing only records with non-zero weights + single_year_calibrated_dataset = minimize_calibrated_dataset_legacy( + microsimulation_class=microsimulation_class, + sim=sim_data_to_calibrate, + year=year, + optimized_weights=( + optimized_sparse_weights + if regularize_with_l0 + else optimized_weights + ), + ) + + # Detect ids that require resetting after minimization + primary_id_variables = {} + for entity in single_year_calibrated_dataset.entities: + primary_id_variables[entity] = f"{entity}_id" + + foreign_id_variables = {} + for entity in single_year_calibrated_dataset.entities: + entity_foreign_keys = {} + for target_entity in single_year_calibrated_dataset.entities: + if entity != target_entity: + foreign_key_name = f"{entity}_{target_entity}_id" + if ( + foreign_key_name + in sim_data_to_calibrate.tax_benefit_system.variables + ) and ( + foreign_key_name + in single_year_calibrated_dataset.entities[ + entity + ].columns + ): + entity_foreign_keys[foreign_key_name] = target_entity + + if entity_foreign_keys: + foreign_id_variables[entity] = entity_foreign_keys + + # Combine area datasets + if geography_level_calibrated_dataset is None: + geography_level_calibrated_dataset = single_year_calibrated_dataset + single_year_calibrated_dataset.entities = normalise_table_keys( + single_year_calibrated_dataset.entities, + primary_keys=primary_id_variables, + 
foreign_keys=foreign_id_variables, + start_index=None, + ) + else: + previous_max_ids = {} + for entity in single_year_calibrated_dataset.entities: + previous_max_ids[entity] = ( + geography_level_calibrated_dataset.entities[entity][ + f"{entity}_id" + ].max() + + 1 + ) + + single_year_calibrated_dataset.entities = normalise_table_keys( + single_year_calibrated_dataset.entities, + primary_keys=primary_id_variables, + foreign_keys=foreign_id_variables, + start_index=previous_max_ids, + ) + + geography_level_calibrated_dataset.entities = { + entity: pd.concat( + [ + geography_level_calibrated_dataset.entities[entity], + single_year_calibrated_dataset.entities[entity], + ], + ignore_index=True, + ) + for entity in geography_level_calibrated_dataset.entities.keys() + } + + return geography_level_calibrated_dataset + + +def calibrate_all_levels( + microsimulation_class, + database_stacking_areas: Dict[str, str], + dataset: str, + dataset_subsample_size: Optional[int] = None, + geo_sim_filter_variable: Optional[str] = "ucgid", + geo_hierarchy: Optional[List[str]] = None, + year: Optional[int] = 2023, + db_uri: Optional[str] = None, + noise_level: Optional[float] = 10.0, + regularize_with_l0: Optional[bool] = False, + raise_error: Optional[bool] = True, +) -> "SingleYearDataset": + """ + This function will calibrate the dataset for all geography levels in the database, defaulting to stacking the base dataset per area within the specified level (it is recommended to use the lowest in the hierarchy for stacking). (Eg. when calibrating for district, state and national levels in the US, this function will stack the CPS dataset for each district and calibrate the stacked dataset for the three levels' targets.) + It will handle conversion between dataset classes to enable: + 1. Loading the base dataset and reassigning it to the geographic areas within the lowest level. + 2. Stacking all areas at that level into a single dataset. + 3. 
Selecting all targets that match the specified geography levels. + 4. Creating a metrics matrix that enables computing estimates for those targets. + 5. Calibrating the dataset's household weights with regularization. + 6. Filtering the resulting dataset to only include households with non-zero weights. + + Args: + microsimulation_class: The Microsimulation class to use for creating simulations. + database_stacking_areas (Dict[str, str]): A dictionary mapping area names to their identifiers for base dataset stacking. + dataset (str): Path to the base dataset to stack. + dataset_subsample_size (Optional[int]): The size of the subsample to use for calibration. + geo_sim_filter_variable (Optional[str]): The variable to use for geographic similarity filtering. Default in the US: "ucgid". + geo_hierarchy (Optional[List[str]]): The geographic hierarchy to use for calibration. + year (Optional[int]): The year to use for calibration. Default: 2023. + db_uri (Optional[str]): The database URI to use for calibration. If None, it will download the database from the default URI. + noise_level (Optional[float]): The noise level to use for calibration. Default: 10.0. + regularize_with_l0 (Optional[bool]): Whether to use L0 regularization for calibration. Default: False. + raise_error (Optional[bool]): Whether to raise an error if matrix creation fails. Default: True. + + Returns: + fully_calibrated_dataset (SingleYearDataset): The calibrated dataset for all geography levels. 
+ """ + if db_uri is None: + db_uri = download_database() + + stacked_dataset = None + for area, geo_identifier in database_stacking_areas.items(): + logger.info(f"Stacking dataset for {area}...") + + # Load dataset configured for the specific geographic area + from policyengine_us.variables.household.demographic.geographic.ucgid.ucgid_enum import ( + UCGID, + ) + + sim_data_to_stack = load_dataset_for_geography_legacy( + microsimulation_class=microsimulation_class, + year=year, + dataset=dataset, + dataset_subsample_size=dataset_subsample_size, + geography_variable=geo_sim_filter_variable, + geography_identifier=UCGID( + geo_identifier + ), # will need a non-hardcoded solution to assign geography_identifier in the future + ) + + single_year_dataset = SingleYearDataset.from_simulation( + simulation=sim_data_to_stack, + time_period=year, + ) + + # Detect ids that require resetting + primary_id_variables = {} + for entity in single_year_dataset.entities: + primary_id_variables[entity] = f"{entity}_id" + + foreign_id_variables = {} + for entity in single_year_dataset.entities: + entity_foreign_keys = {} + for target_entity in single_year_dataset.entities: + if entity != target_entity: + foreign_key_name = f"{entity}_{target_entity}_id" + if ( + foreign_key_name + in sim_data_to_stack.tax_benefit_system.variables + ) and ( + foreign_key_name + in single_year_dataset.entities[entity].columns + ): + entity_foreign_keys[foreign_key_name] = target_entity + + if entity_foreign_keys: + foreign_id_variables[entity] = entity_foreign_keys + + # Combine datasets + if stacked_dataset is None: + stacked_dataset = single_year_dataset + single_year_dataset.entities = normalise_table_keys( + single_year_dataset.entities, + primary_keys=primary_id_variables, + foreign_keys=foreign_id_variables, + start_index=None, + ) + else: + previous_max_ids = {} + for entity in single_year_dataset.entities: + previous_max_ids[entity] = ( + stacked_dataset.entities[entity][f"{entity}_id"].max() + 1 
+ ) + + single_year_dataset.entities = normalise_table_keys( + single_year_dataset.entities, + primary_keys=primary_id_variables, + foreign_keys=foreign_id_variables, + start_index=previous_max_ids, + ) + + stacked_dataset.entities = { + entity: pd.concat( + [ + stacked_dataset.entities[entity], + single_year_dataset.entities[entity], + ], + ignore_index=True, + ) + for entity in stacked_dataset.entities.keys() + } + + SingleYearDataset_to_Dataset( + stacked_dataset, + output_path="Dataset_stacked.h5", + ) + + logger.info( + "Stacked dataset created successfully, starting to process it for calibration..." + ) + + metrics_matrix, targets, target_info = create_metrics_matrix( + db_uri=db_uri, + time_period=year, + microsimulation_class=microsimulation_class, + dataset="Dataset_stacked.h5", + reform_id=0, + ) + metrics_evaluation = validate_metrics_matrix( + metrics_matrix, + targets, + target_info=target_info, + raise_error=raise_error, + ) + + normalization_factor = create_geographic_normalization_factor( + geo_hierarchy=geo_hierarchy, target_info=target_info + ) + + target_names = [] + excluded_targets = [] + for target_id, info in target_info.items(): + target_names.append(info["name"]) + if not info["active"]: + excluded_targets.append(target_id) + target_names = np.array(target_names) + + weights = np.ones(len(metrics_matrix)) + + # Calibrate with L0 regularization + from microcalibrate import Calibration + + calibrator = Calibration( + weights=weights, + targets=targets, + target_names=target_names, + estimate_matrix=metrics_matrix, + epochs=600, + learning_rate=0.2, + noise_level=noise_level, + excluded_targets=( + excluded_targets if len(excluded_targets) > 0 else None + ), + normalization_factor=normalization_factor, + sparse_learning_rate=0.1, + regularize_with_l0=regularize_with_l0, + csv_path=f"full_calibration.csv", + ) + performance_log = calibrator.calibrate() + optimized_sparse_weights = calibrator.sparse_weights + optimized_weights = 
calibrator.weights + + sim_stacked_dataset = microsimulation_class(dataset="Dataset_stacked.h5") + sim_stacked_dataset.default_input_period = year + sim_stacked_dataset.build_from_dataset() + + # Minimize the calibrated dataset storing only records with non-zero weights + fully_calibrated_dataset = minimize_calibrated_dataset_legacy( + microsimulation_class=microsimulation_class, + sim=sim_stacked_dataset, + year=year, + optimized_weights=( + optimized_sparse_weights + if regularize_with_l0 + else optimized_weights + ), + ) + + return fully_calibrated_dataset + + +if __name__ == "__main__": + from policyengine_us import Microsimulation + from policyengine_us.system import system + + print("US calibration example:") + + areas_in_national_level = { + "United States": "0100000US", + } + + areas_in_state_level = { + "Alabama": "0400000US01", + "Alaska": "0400000US02", + "Arizona": "0400000US04", + "Arkansas": "0400000US05", + "California": "0400000US06", + "Colorado": "0400000US08", + "Connecticut": "0400000US09", + "Delaware": "0400000US10", + "District of Columbia": "0400000US11", + "Florida": "0400000US12", + "Georgia": "0400000US13", + "Hawaii": "0400000US15", + "Idaho": "0400000US16", + "Illinois": "0400000US17", + "Indiana": "0400000US18", + "Iowa": "0400000US19", + "Kansas": "0400000US20", + "Kentucky": "0400000US21", + "Louisiana": "0400000US22", + "Maine": "0400000US23", + "Maryland": "0400000US24", + "Massachusetts": "0400000US25", + "Michigan": "0400000US26", + "Minnesota": "0400000US27", + "Mississippi": "0400000US28", + "Missouri": "0400000US29", + "Montana": "0400000US30", + "Nebraska": "0400000US31", + "Nevada": "0400000US32", + "New Hampshire": "0400000US33", + "New Jersey": "0400000US34", + "New Mexico": "0400000US35", + "New York": "0400000US36", + "North Carolina": "0400000US37", + "North Dakota": "0400000US38", + "Ohio": "0400000US39", + "Oklahoma": "0400000US40", + "Oregon": "0400000US41", + "Pennsylvania": "0400000US42", + "Rhode Island": 
"0400000US44", + "South Carolina": "0400000US45", + "South Dakota": "0400000US46", + "Tennessee": "0400000US47", + "Texas": "0400000US48", + "Utah": "0400000US49", + "Vermont": "0400000US50", + "Virginia": "0400000US51", + "Washington": "0400000US53", + "West Virginia": "0400000US54", + "Wisconsin": "0400000US55", + "Wyoming": "0400000US56", + } + + db_uri = download_database() + + # Rescale targets for consistency across geography areas + rescaling_results = rescale_calibration_targets( + db_uri=db_uri, update_database=True + ) + + # Uprate targets for consistency across definition year (disabled until IRS SOI variables are renamed to avoid errors) + # uprating_results = uprate_calibration_targets( + # system=system, + # db_uri=db_uri, + # from_period=2022, + # to_period=2023, + # update_database=True, + # ) + + state_level_calibrated_dataset = calibrate_single_geography_level( + Microsimulation, + areas_in_state_level, + "hf://policyengine/policyengine-us-data/cps_2023.h5", + db_uri=db_uri, + use_dataset_weights=False, + regularize_with_l0=True, + ) + + state_level_weights = state_level_calibrated_dataset.entities["household"][ + "household_weight" + ].values + + SingleYearDataset_to_Dataset( + state_level_calibrated_dataset, + output_path="Dataset_state_level_age_medicaid_snap_eitc_agi_targets.h5", + ) + + print("Completed calibration for state level dataset.") + + print( + "Number of household records at the state level:", + len(state_level_calibrated_dataset.entities["household"]), + ) + + national_level_calibrated_dataset = calibrate_single_geography_level( + Microsimulation, + areas_in_national_level, + dataset="Dataset_state_level_age_medicaid_snap_eitc_agi_targets.h5", + db_uri=db_uri, + stack_datasets=False, + noise_level=0.0, + use_dataset_weights=True, + regularize_with_l0=False, + ) + + national_level_weights = national_level_calibrated_dataset.entities[ + "household" + ]["household_weight"].values + + SingleYearDataset_to_Dataset( + 
national_level_calibrated_dataset, + output_path="Dataset_national_level_age_medicaid_snap_eitc_agi_targets.h5", + ) + + print("Completed calibration for national level dataset.") + + print( + "Number of household records at the national level:", + len(national_level_calibrated_dataset.entities["household"]), + ) + + print("Weights comparison:") + print( + f"Weights from state level calibration - mean: {state_level_weights.mean()}; range: [{state_level_weights.min()}, {state_level_weights.max()}]" + ) + print( + f"Weights from national level calibration - mean: {national_level_weights.mean()}; range: [{national_level_weights.min()}, {national_level_weights.max()}]" + ) diff --git a/src/policyengine_data/calibration/dataset_duplication.py b/src/policyengine_data/calibration/dataset_duplication.py new file mode 100644 index 0000000..4eca09b --- /dev/null +++ b/src/policyengine_data/calibration/dataset_duplication.py @@ -0,0 +1,223 @@ +from typing import Any, Optional + +import numpy as np +import pandas as pd +from policyengine_us.variables.household.demographic.geographic.ucgid.ucgid_enum import ( + UCGID, +) + +from ..dataset_legacy import Dataset +from ..single_year_dataset import SingleYearDataset + +""" +Functions using the legacy Dataset class to operate datasets given their dependency on Microsimulation objects. +""" + + +def load_dataset_for_geography_legacy( + microsimulation_class, + year: Optional[int] = 2023, + dataset: Optional[str] = None, + dataset_subsample_size: Optional[int] = None, + geography_variable: Optional[str] = "ucgid", + geography_identifier: Optional[Any] = UCGID("0100000US"), +): + """ + Load the necessary dataset from the legacy Dataset class, making it specific to a geography area. (e.g., CPS for the state of California). + + Args: + microsimulation_class: The Microsimulation class to use for creating simulations. + year (Optional[int]): The year for which to calibrate the dataset. + dataset (Optional[None]): The dataset to load. 
If None, defaults to the CPS dataset for the specified year. + dataset_subsample_size (Optional[int]): The size of the base dataset subsample to use for calibration. If None, the full dataset will be used for stacking. + geography_variable (Optional[str]): The variable representing the geography in the dataset. + geography_identifier (Optional[str]): The identifier for the geography to calibrate. + + Returns: + Microsimulation: The Microsimulation object with the specified geography. + """ + if dataset is None: + dataset = f"hf://policyengine/policyengine-us-data/cps_{year}.h5" + + sim = microsimulation_class(dataset=dataset) + sim.default_input_period = year + sim.build_from_dataset() + + if dataset_subsample_size is not None: + df = sim.to_input_dataframe() + + # Find the household ID column (it should be named with the year) + household_id_column = None + for col in df.columns: + if col.startswith("household_id__"): + household_id_column = col + break + + if household_id_column is None: + raise KeyError( + "Could not find household_id column in simulation dataframe" + ) + + # Get unique household IDs + unique_household_ids = df[household_id_column].unique() + + # Subsample households if we have more than requested + if len(unique_household_ids) > dataset_subsample_size: + np.random.seed(42) # For reproducible results + subsampled_household_ids = np.random.choice( + unique_household_ids, + size=dataset_subsample_size, + replace=False, + ) + + # Filter dataframe to only include subsampled households + subset_df = df[ + df[household_id_column].isin(subsampled_household_ids) + ].copy() + + # Create new simulation from subsampled data + sim = microsimulation_class() + sim.dataset = Dataset.from_dataframe(subset_df, year) + sim.default_input_period = year + sim.build_from_dataset() + + hhs = len(sim.calculate("household_id").values) + geo_values = [geography_identifier] * hhs + sim.set_input(geography_variable, year, geo_values) + + ucgid_values = 
sim.calculate(geography_variable).values + assert all(val == geography_identifier.name for val in ucgid_values) + + return sim + + +def minimize_calibrated_dataset_legacy( + microsimulation_class, sim, year: int, optimized_weights: pd.Series +) -> "SingleYearDataset": + """ + Use sparse weights to minimize the calibrated dataset storing in the legacy Dataset class. + + Args: + microsimulation_class: The Microsimulation class to use for creating simulations. + sim: The Microsimulation object with the dataset to minimize. + year (int): Year the dataset is representing. + optimized_weights (pd.Series): The calibrated, regularized weights used to minimize the dataset. + + Returns: + SingleYearDataset: The regularized dataset + """ + # Copy all existing variable data to the target year + for variable_name in sim.tax_benefit_system.variables: + holder = sim.get_holder(variable_name) + known_periods = holder.get_known_periods() + if known_periods and variable_name != "household_weight": + # Copy from the first available period to target year + source_period = known_periods[0] + try: + values = sim.calculate(variable_name, source_period).values + sim.set_input(variable_name, year, values) + except Exception: + # Skip variables that can't be copied + continue + + # Set the calibrated household weights for the target year + sim.set_input("household_weight", year, optimized_weights) + + df = sim.to_input_dataframe() + + # Use the target year for column names + household_weight_column = f"household_weight__{year}" + df_household_id_column = f"household_id__{year}" + + # Fallback: if target year columns don't exist, detect the actual year from column names + if ( + household_weight_column not in df.columns + or df_household_id_column not in df.columns + ): + for col in df.columns: + if col.startswith("household_weight__"): + detected_year = col.split("__")[1].split("-")[0] + household_weight_column = f"household_weight__{detected_year}" + df_household_id_column = 
f"household_id__{detected_year}" + break + else: + raise KeyError( + "Could not find household_weight or household_id columns" + ) + + # Group by household ID and get the first entry for each group + h_df = df.groupby(df_household_id_column).first() + h_ids = pd.Series(h_df.index) + h_weights = pd.Series(h_df[household_weight_column].values) + + # Filter to housholds with non-zero weights + h_ids = h_ids[h_weights > 0] + h_weights = h_weights[h_weights > 0] + + subset_df = df[df[df_household_id_column].isin(h_ids)].copy() + + # Update the dataset and rebuild the simulation + sim = microsimulation_class() + sim.dataset = Dataset.from_dataframe(subset_df, year) + sim.default_input_period = year + sim.build_from_dataset() + + single_year_dataset = SingleYearDataset.from_simulation(sim, year) + + return single_year_dataset + + +""" +Functions using the new SingleYearDataset class once the Microsimulation object is adapted to it. +""" + + +def load_dataset_for_geography( + microsimulation_class, + year: Optional[int] = 2023, + dataset: Optional[str] = None, + geography_variable: Optional[str] = "ucgid", + geography_identifier: Optional[Any] = UCGID("0100000US"), +) -> "SingleYearDataset": + """ + Load the necessary dataset from the legacy Dataset class into the new SingleYearDataset, or directly from it, making it specific to a geography area. (e.g., CPS for the state of California). + + Args: + microsimulation_class: The Microsimulation class to use for creating simulations. + year (Optional[int]): The year for which to calibrate the dataset. + dataset (Optional[None]): The dataset to load. If None, defaults to the CPS dataset for the specified year. + geography_variable (Optional[str]): The variable representing the geography in the dataset. + geography_identifier (Optional[str]): The identifier for the geography to calibrate. + + Returns: + SingleYearDataset: The calibrated dataset after applying regularization. 
+ """ + if dataset is None: + dataset = f"hf://policyengine/policyengine-us-data/cps_{year}.h5" + + sim = microsimulation_class(dataset=dataset) + + # To load from the Microsimulation object for compatibility with legacy Dataset class + single_year_dataset = SingleYearDataset.from_simulation( + sim, time_period=year + ) + # To load from the SingleYearDataset class directly + # single_year_dataset = SingleYearDataset(file_path=dataset) + single_year_dataset.time_period = year + + household_vars = single_year_dataset.entities["household"] + household_vars[geography_variable] = geography_identifier + single_year_dataset.entities["household"] = household_vars + + return single_year_dataset + + +def minimize_calibrated_dataset( + dataset: SingleYearDataset, +) -> "SingleYearDataset": + """ + Use sparse weights to minimize the calibrated dataset. + + To come after policyengine_core adaptation. + """ + pass diff --git a/src/policyengine_data/calibration/metrics_matrix_creation.py b/src/policyengine_data/calibration/metrics_matrix_creation.py new file mode 100644 index 0000000..7997893 --- /dev/null +++ b/src/policyengine_data/calibration/metrics_matrix_creation.py @@ -0,0 +1,596 @@ +import logging +from typing import Dict, Optional, Tuple + +import numpy as np +import pandas as pd +from sqlalchemy import create_engine + +from .target_rescaling import download_database + +logger = logging.getLogger(__name__) + + +# NOTE (juaristi22): This could fail if trying to filter by more than one stratum constraint if there are mismatches between the filtering variable, value and operation. +def fetch_targets_from_database( + engine, + time_period: int, + reform_id: Optional[int] = 0, + stratum_filter_variable: Optional[str] = None, + stratum_filter_value: Optional[str] = None, + stratum_filter_operation: Optional[str] = None, +) -> pd.DataFrame: + """ + Fetch all targets for a specific time period and reform from the database. 
+ + Args: + engine: SQLAlchemy engine + time_period: The year to fetch targets for + reform_id: The reform scenario ID (0 for baseline) + stratum_filter_variable: Optional variable name to filter strata by + stratum_filter_value: Optional value to filter strata by + stratum_filter_operation: Optional operation for filtering ('equals', 'in', etc.) + + Returns: + DataFrame with target data including target_id, variable, value, etc. + """ + # Base query + query = """ + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.period, + t.reform_id, + t.value, + t.active, + t.tolerance, + t.notes, + s.stratum_group_id, + s.parent_stratum_id + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE t.period = :period + AND t.reform_id = :reform_id + """ + + params = {"period": time_period, "reform_id": reform_id} + + # Add stratum filtering if specified + if all( + [ + stratum_filter_variable, + stratum_filter_value, + stratum_filter_operation, + ] + ): + # Add join with stratum_constraints and apply filter + query += """ + AND t.stratum_id IN ( + SELECT sc.stratum_id + FROM stratum_constraints sc + WHERE sc.constraint_variable = :filter_variable + AND sc.operation = :filter_operation + AND sc.value = :filter_value + ) + """ + params.update( + { + "filter_variable": stratum_filter_variable, + "filter_operation": stratum_filter_operation, + "filter_value": stratum_filter_value, + } + ) + + query += " ORDER BY t.target_id" + + return pd.read_sql(query, engine, params=params) + + +def fetch_stratum_constraints(engine, stratum_id: int) -> pd.DataFrame: + """ + Fetch all constraints for a specific stratum from the database. 
def parse_constraint_value(value: str, operation: str):
    """
    Convert a raw constraint string into a typed Python value.

    Args:
        value: Raw string value stored in the constraint row.
        operation: Constraint operation name (affects list parsing for "in").

    Returns:
        A list of trimmed strings for comma-separated "in" constraints, a
        bool for "true"/"false", an int or float for numeric text, or the
        original string when no conversion applies.
    """
    # Comma-separated "in" constraints become a list of trimmed items.
    if operation == "in" and "," in value:
        return [item.strip() for item in value.split(",")]

    # Boolean literals, case-insensitive.
    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    # Numeric text: prefer int when the float value is integral.
    try:
        parsed = float(value)
    except ValueError:
        return value
    return int(parsed) if parsed.is_integer() else parsed
def apply_single_constraint(
    values: np.ndarray, operation: str, constraint_value
) -> np.ndarray:
    """
    Build a boolean mask selecting the entries of ``values`` that satisfy
    a single constraint.

    Args:
        values: Array of values the constraint is evaluated against.
        operation: Operation name (e.g. "equals", "less_than", "in").
        constraint_value: Parsed constraint value (scalar, or list for "in").

    Returns:
        Boolean numpy array with one entry per element of ``values``.

    Raises:
        ValueError: If ``operation`` is not a recognised operation name.
    """
    # "in" is substring containment against the string form of each value;
    # a list of candidates is OR-ed together.
    if operation == "in":
        candidates = (
            constraint_value
            if isinstance(constraint_value, list)
            else [constraint_value]
        )
        mask = np.zeros(len(values), dtype=bool)
        for candidate in candidates:
            needle = str(candidate)
            mask |= np.array([needle in str(v) for v in values], dtype=bool)
        return mask

    # Comparison operations ("is_greater_than" is a legacy alias).
    if operation in ("is_greater_than", "greater_than"):
        result = values > constraint_value
    elif operation == "greater_than_or_equal":
        result = values >= constraint_value
    elif operation == "less_than":
        result = values < constraint_value
    elif operation == "less_than_or_equal":
        result = values <= constraint_value
    elif operation == "equals":
        result = values == constraint_value
    elif operation == "not_equals":
        result = values != constraint_value
    else:
        raise ValueError(f"Unknown operation: {operation}")

    return np.array(result, dtype=bool)
+ + Returns: + Boolean array at the target entity level + """ + # Get the number of entities at the target level + entity_count = len(sim.calculate(f"{target_entity}_id").values) + + # Start with all True + if constraints_df.empty: + return np.ones(entity_count, dtype=bool) + combined_mask = np.ones(entity_count, dtype=bool) + + # Apply each constraint + for _, constraint in constraints_df.iterrows(): + constraint_var = constraint["constraint_variable"] + + constraint_values = sim.calculate(constraint_var).values + constraint_entity = sim.tax_benefit_system.variables[ + constraint_var + ].entity.key + + parsed_value = parse_constraint_value( + constraint["value"], constraint["operation"] + ) + + # Apply the constraint at its native level + constraint_mask = apply_single_constraint( + constraint_values, constraint["operation"], parsed_value + ) + + # Map the constraint mask to the target entity level if needed + if constraint_entity != target_entity: + constraint_mask = sim.map_result( + constraint_mask, constraint_entity, target_entity + ) + + # Ensure it's boolean + constraint_mask = np.array(constraint_mask, dtype=bool) + + # Combine + combined_mask = combined_mask & constraint_mask + + assert ( + len(combined_mask) == entity_count + ), f"Combined mask length {len(combined_mask)} does not match entity count {entity_count}." + + return combined_mask + + +def process_single_target( + sim, + target: pd.Series, + constraints_df: pd.DataFrame, +) -> Tuple[np.ndarray, Dict[str, any]]: + """ + Process a single target by applying constraints at the appropriate entity level. 
def parse_constraint_for_name(constraint: pd.Series) -> str:
    """
    Render one constraint row as a short human-readable token.

    Args:
        constraint: pandas Series with ``constraint_variable``,
            ``operation`` and ``value`` fields.

    Returns:
        A token such as ``age>=65`` or ``state_in_CA_NY``.
    """
    variable = constraint["constraint_variable"]
    operation = constraint["operation"]
    raw_value = constraint["value"]

    # "in" constraints read better with underscores instead of commas.
    if operation == "in":
        return f"{variable}_in_{raw_value.replace(',', '_')}"

    # Comparison operations map to symbols; unknown operations fall back
    # to their raw name so the token is still informative.
    symbol_by_operation = {
        "equals": "=",
        "is_greater_than": ">",
        "greater_than": ">",
        "greater_than_or_equal": ">=",
        "less_than": "<",
        "less_than_or_equal": "<=",
        "not_equals": "!=",
        "in": "in",
    }
    symbol = symbol_by_operation.get(operation, operation)
    return f"{variable}{symbol}{raw_value}"


def build_target_name(variable: str, constraints_df: pd.DataFrame) -> str:
    """
    Build a descriptive name for a target from its variable and constraints.

    Args:
        variable: Target variable name.
        constraints_df: DataFrame of constraint rows for the target's stratum.

    Returns:
        Underscore-joined name, e.g. ``person_count_ucgid_in_0400000US06_age>=65``.
    """
    parts = [variable]

    if not constraints_df.empty:
        # Deterministic ordering: ucgid (geography) constraints first,
        # then alphabetically by constraint variable.
        ordered = constraints_df.copy()
        ordered["is_ucgid"] = ordered["constraint_variable"].str.contains(
            "ucgid"
        )
        ordered = ordered.sort_values(
            ["is_ucgid", "constraint_variable"], ascending=[False, True]
        )
        parts.extend(
            parse_constraint_for_name(row) for _, row in ordered.iterrows()
        )

    return "_".join(parts)
None, +) -> Tuple[pd.DataFrame, np.ndarray, Dict[int, Dict[str, any]]]: + """ + Create the metrics matrix from the targets database. + + This function processes all targets in the database to create a matrix where: + - Rows represent households + - Columns represent targets + - Values represent the metric calculation for each household-target combination + + Args: + db_uri: Database connection string + time_period: Time period for the simulation + microsimulation_class: The Microsimulation class to use for creating simulations + sim: Optional existing Microsimulation instance + dataset: Optional dataset type for creating new simulation + reform_id: Reform scenario ID (0 for baseline) + stratum_filter_variable: Optional variable name to filter strata by + stratum_filter_value: Optional value to filter strata by + stratum_filter_operation: Optional operation for filtering ('equals', 'in', etc.) + + Returns: + Tuple of: + - metrics_matrix: DataFrame with target_id as columns, households as rows + - target_values: Array of target values in same order as columns + - target_info: Dictionary mapping target_id to info dict with keys: + - name: Descriptive name + - active: Boolean active status + - tolerance: Tolerance percentage (or None) + """ + # Setup database connection + engine = create_engine(db_uri) + + # Initialize simulation + if sim is None: + if dataset is None: + raise ValueError("Either 'sim' or 'dataset' must be provided") + sim = microsimulation_class(dataset=dataset) + sim.default_calculation_period = time_period + sim.build_from_dataset() + + # Get household IDs for matrix index + household_ids = sim.calculate("household_id").values + n_households = len(household_ids) + + # Fetch all targets from database + targets_df = fetch_targets_from_database( + engine, + time_period, + reform_id, + stratum_filter_variable, + stratum_filter_value, + stratum_filter_operation, + ) + logger.info( + f"Processing {len(targets_df)} targets for period {time_period}" + ) + + 
# Initialize outputs + target_values = [] + target_info = {} + metrics_list = [] + target_ids = [] + + # Process each target + for _, target in targets_df.iterrows(): + target_id = target["target_id"] + + try: + # Fetch constraints for this target's stratum + constraints_df = fetch_stratum_constraints( + engine, int(target["stratum_id"]) + ) + + # Process the target + household_values, info_dict = process_single_target( + sim, target, constraints_df + ) + + # Store results + metrics_list.append(household_values) + target_ids.append(target_id) + target_values.append(target["value"]) + target_info[target_id] = info_dict + + logger.debug( + f"Processed target {target_id}: {info_dict['name']} " + f"(active={info_dict['active']}, tolerance={info_dict['tolerance']})" + ) + + except Exception as e: + logger.error(f"Error processing target {target_id}: {str(e)}") + # Add zero column for failed targets + metrics_list.append(np.zeros(n_households)) + target_ids.append(target_id) + target_values.append(target["value"]) + target_info[target_id] = { + "name": f"ERROR_{target['variable']}", + "active": False, + "tolerance": None, + } + + # Create the metrics matrix DataFrame + metrics_matrix = pd.DataFrame( + data=np.column_stack(metrics_list), + index=household_ids, + columns=target_ids, + ) + + # Convert target values to numpy array + target_values = np.array(target_values) + + logger.info(f"Created metrics matrix with shape {metrics_matrix.shape}") + logger.info( + f"Active targets: {sum(info['active'] for info in target_info.values())}" + ) + + return metrics_matrix, target_values, target_info + + +def validate_metrics_matrix( + metrics_matrix: pd.DataFrame, + target_values: np.ndarray, + weights: Optional[np.ndarray] = None, + target_info: Optional[Dict[int, Dict[str, any]]] = None, + raise_error: Optional[bool] = False, +) -> pd.DataFrame: + """ + Validate the metrics matrix by checking estimates vs targets. 
def validate_metrics_matrix(
    metrics_matrix: pd.DataFrame,
    target_values: np.ndarray,
    weights: Optional[np.ndarray] = None,
    target_info: Optional[Dict[int, Dict[str, any]]] = None,
    raise_error: Optional[bool] = False,
) -> pd.DataFrame:
    """
    Compare weighted estimates from the metrics matrix against target values.

    Args:
        metrics_matrix: Household-by-target metrics matrix.
        target_values: Target values, one per matrix column.
        weights: Household weights; defaults to uniform weights summing to 1.
        target_info: Optional per-target info (name/active/tolerance).
        raise_error: If True, raise when a household satisfies no target or
            when any estimate is exactly zero.

    Returns:
        DataFrame with one row per target holding the target value, the
        weighted estimate and absolute/relative errors, plus name/active/
        tolerance columns when ``target_info`` is supplied.
    """
    n_households = len(metrics_matrix)
    if weights is None:
        # Uniform weights that sum to one across households.
        weights = np.ones(n_households) / n_households

    estimates = weights @ metrics_matrix.values

    if raise_error:
        # A household whose row is all zeros satisfied no target at all.
        for _, record in metrics_matrix.iterrows():
            if record.sum() == 0:
                raise ValueError(
                    f"Record {record.name} has all zero estimates. None of the target constraints were met by this household and its individuals."
                )
        zero_mask = estimates == 0
        if zero_mask.any():
            zero_targets = [
                metrics_matrix.columns[i] for i in np.where(zero_mask)[0]
            ]
            raise ValueError(
                f"{zero_mask.sum()} estimate(s) contain zero values for targets: {zero_targets}"
            )

    # Small epsilon in the denominator guards against zero-valued targets.
    results = {
        "target_id": metrics_matrix.columns,
        "target_value": target_values,
        "estimate": estimates,
        "absolute_error": np.abs(estimates - target_values),
        "relative_error": np.abs(
            (estimates - target_values) / (target_values + 1e-10)
        ),
    }

    if target_info is not None:
        # Attach descriptive metadata per target when available.
        def lookup(target_id, key, default):
            return target_info.get(target_id, {}).get(key, default)

        results["name"] = [
            lookup(tid, "name", "Unknown") for tid in metrics_matrix.columns
        ]
        results["active"] = [
            lookup(tid, "active", False) for tid in metrics_matrix.columns
        ]
        results["tolerance"] = [
            lookup(tid, "tolerance", None) for tid in metrics_matrix.columns
        ]

    return pd.DataFrame(results)
time_period=2023, + microsimulation_class=Microsimulation, + dataset="hf://policyengine/policyengine-us-data/cps_2023.h5", + reform_id=0, + ) + + # Validate the matrix + validation_results = validate_metrics_matrix( + metrics_matrix, target_values, target_info=target_info + ) + + print("\nValidation Results Summary:") + print(f"Total targets: {len(validation_results)}") + print(f"Active targets: {validation_results['active'].sum()}") + print(validation_results) diff --git a/src/policyengine_data/calibration/target_rescaling.py b/src/policyengine_data/calibration/target_rescaling.py index cbd7199..ffea269 100644 --- a/src/policyengine_data/calibration/target_rescaling.py +++ b/src/policyengine_data/calibration/target_rescaling.py @@ -14,7 +14,7 @@ def download_database( filename: Optional[str] = "policy_data.db", - repo_id: Optional[str] = "policyengine/test", + repo_id: Optional[str] = "policyengine/policyengine-us-data", ) -> create_engine: """ Download the SQLite database from Hugging Face Hub and return the connection string. @@ -35,8 +35,9 @@ def download_database( downloaded_path = hf_hub_download( repo_id=repo_id, filename=filename, - local_dir=".", # Use "." 
for the current working directory + local_dir="download/", local_dir_use_symlinks=False, # Recommended to avoid symlinks + force_download=True, # Always download, ignore cache ) path = os.path.abspath(downloaded_path) logger.info(f"File downloaded successfully to: {path}") diff --git a/src/policyengine_data/calibration/target_uprating.py b/src/policyengine_data/calibration/target_uprating.py new file mode 100644 index 0000000..3a06cad --- /dev/null +++ b/src/policyengine_data/calibration/target_uprating.py @@ -0,0 +1,424 @@ +import logging +from typing import Dict, List, Optional + +import pandas as pd +from sqlalchemy import create_engine, text + +logger = logging.getLogger(__name__) + + +""" +Database connection and structure functions +""" + + +def fetch_targets( + engine, period: int, reform_id: Optional[int] = 0 +) -> pd.DataFrame: + """ + Fetch targets for a specific period, and reform scenario. + + Args: + engine: SQLAlchemy engine + period: Time period (typically year) + reform_id: Reform scenario ID (0 for baseline) + + Returns: + DataFrame with target data joined with stratum information + """ + query = """ + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.period, + t.reform_id, + t.value, + t.active, + t.tolerance, + s.stratum_group_id, + s.parent_stratum_id, + s.definition_hash + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE t.period = :period + AND t.reform_id = :reform_id + AND t.active = true + ORDER BY s.parent_stratum_id NULLS FIRST, s.stratum_group_id, s.stratum_id + """ + + return pd.read_sql( + query, + engine, + params={ + "period": period, + "reform_id": reform_id, + }, + ) + + +def get_uprating_factors( + system, + population_path: str = "calibration.gov.census.populations.total", + inflation_path: str = "gov.bls.cpi.cpi_u", + current_year: int = 2023, + start_year: Optional[int] = 2020, + end_year: Optional[int] = 2034, +): + """ + Get population growth factors and inflation factors as a DataFrame indexed to 
current_year = 1.000. + + Args: + system: The policy engine country system instance to retrieve uprating factors from. + population_path (str): The parameter path for population data. + inflation_path (str): The parameter path for inflation data. + current_year (int): The current year for which to retrieve factors. + start_year (Optional[int]): The start year for the range of years to retrieve factors. + end_year (Optional[int]): The end year for the range of years to retrieve factors. + + Returns: + pd.DataFrame: A DataFrame containing the population and inflation factors. + """ + # Get parameters + population = system.parameters.get_child(population_path) + cpi_u = system.parameters.get_child(inflation_path) + + # Get base year values + base_population = population(current_year) + base_cpi = cpi_u(current_year) + + # Create DataFrame + years = list(range(start_year, end_year + 1)) + population_factors = [ + round(population(year) / base_population, 3) for year in years + ] + inflation_factors = [round(cpi_u(year) / base_cpi, 3) for year in years] + + df = pd.DataFrame( + { + "Year": years, + "Population_factor": population_factors, + "Inflation_factor": inflation_factors, + } + ) + + return df + + +""" +Uprating calculation functions +""" + + +def calculate_uprating_factor( + uprating_factors_df: pd.DataFrame, + from_year: int, + to_year: int, + factor_type: str = "inflation", +) -> float: + """ + Calculate uprating factor from one year to another. 
def calculate_uprating_factor(
    uprating_factors_df: pd.DataFrame,
    from_year: int,
    to_year: int,
    factor_type: str = "inflation",
) -> float:
    """
    Compute the multiplicative uprating factor between two years.

    Args:
        uprating_factors_df: DataFrame with a "Year" column plus
            "Inflation_factor" / "Population_factor" columns.
        from_year: Source year.
        to_year: Target year.
        factor_type: Which factor column to use ("inflation" or "population").

    Returns:
        Ratio of the target-year factor to the source-year factor, or 1.0
        (with a warning) when either year is missing from the table.
    """
    column = f"{factor_type.title()}_factor"

    source_rows = uprating_factors_df.loc[
        uprating_factors_df["Year"] == from_year, column
    ]
    destination_rows = uprating_factors_df.loc[
        uprating_factors_df["Year"] == to_year, column
    ]

    # Missing years degrade gracefully to a no-op factor.
    if source_rows.empty or destination_rows.empty:
        logger.warning(
            f"Missing {factor_type} factor for year {from_year} or {to_year}"
        )
        return 1.0

    return float(destination_rows.iloc[0] / source_rows.iloc[0])


def uprate_targets_for_period(
    targets_df: pd.DataFrame,
    uprating_factors_df: pd.DataFrame,
    from_period: int,
    to_period: int,
    factor_type: str = "inflation",
) -> pd.DataFrame:
    """
    Apply a single uprating factor to every target in ``targets_df``.

    Args:
        targets_df: DataFrame with target data (must include "value").
        uprating_factors_df: DataFrame with uprating factors by year.
        from_period: Source period (year).
        to_period: Target period (year).
        factor_type: Type of uprating factor ("inflation" or "population").

    Returns:
        Copy of ``targets_df`` with uprated values and provenance columns
        (original/uprated period, factor used, factor type); "period" is
        rewritten to ``to_period``.
    """
    factor = calculate_uprating_factor(
        uprating_factors_df, from_period, to_period, factor_type
    )

    result = targets_df.copy()
    result["uprated_value"] = targets_df["value"] * factor
    result["uprating_factor"] = factor
    result["original_period"] = from_period
    result["uprated_period"] = to_period
    result["factor_type"] = factor_type
    result["period"] = to_period

    logger.info(
        f"Uprated {len(targets_df)} targets from {from_period} to {to_period} using {factor_type} factor ({factor:.4f})"
    )

    return result
def prepare_insert_data(uprated_df: pd.DataFrame) -> List[Dict]:
    """
    Build a list of row dictionaries for database insertion of uprated targets.

    Args:
        uprated_df: DataFrame produced by ``uprate_targets_for_period``.

    Returns:
        One dict per row with the columns the ``targets`` table expects;
        the uprated value/period are mapped to ``value``/``period`` and
        ``active`` is forced to True.
    """
    return [
        {
            "stratum_id": row["stratum_id"],
            "variable": row["variable"],
            "period": row["uprated_period"],
            "reform_id": row["reform_id"],
            "value": row["uprated_value"],
            "active": True,
            "tolerance": row["tolerance"],
        }
        for _, row in uprated_df.iterrows()
    ]


def insert_uprated_targets_in_db(engine, inserts: List[Dict]) -> int:
    """
    Upsert uprated target values into the ``targets`` table.

    Existing rows (same stratum/variable/period/reform) are updated in
    place; otherwise a new row is inserted. The whole batch runs in one
    transaction via ``engine.begin()``.

    Returns:
        Number of records processed (inserted or updated).
    """
    if not inserts:
        return 0

    with engine.begin() as conn:
        for record in inserts:
            # Look for an existing target with the same identifying columns.
            existing = conn.execute(
                text(
                    """
                SELECT target_id FROM targets
                WHERE stratum_id = :stratum_id
                AND variable = :variable
                AND period = :period
                AND reform_id = :reform_id
                """
                ),
                record,
            ).fetchone()

            if existing:
                # Update the existing target in place.
                conn.execute(
                    text(
                        """
                    UPDATE targets
                    SET value = :value, tolerance = :tolerance, active = :active
                    WHERE target_id = :target_id
                    """
                    ),
                    {**record, "target_id": existing[0]},
                )
            else:
                # No match: insert a fresh row.
                conn.execute(
                    text(
                        """
                    INSERT INTO targets (stratum_id, variable, period, reform_id, value, active, tolerance)
                    VALUES (:stratum_id, :variable, :period, :reform_id, :value, :active, :tolerance)
                    """
                    ),
                    record,
                )

    return len(inserts)
one period to another. + + Automatically selects uprating factor based on variable name: + - Variables containing "_count" use population uprating factor + - All other variables use inflation uprating factor + + Args: + system: Tax benefit system object from which to retrieve uprating parameters + db_uri: Database connection string + from_period: Source period (year) to uprate from + to_period: Target period (year) to uprate to + variable: Target variable to uprate (None = all variables) + reform_id: Reform scenario ID (0 for baseline) + population_path: Parameter path for population data + inflation_path: Parameter path for inflation data + update_database: If True, insert uprated targets into database + + Returns: + DataFrame with original and uprated values + """ + # Connect to database + engine = create_engine(db_uri) + + # Get uprating factors + uprating_factors_df = get_uprating_factors( + system=system, + population_path=population_path, + inflation_path=inflation_path, + current_year=from_period, + start_year=min(from_period, to_period), + end_year=max(from_period, to_period), + ) + + # Fetch targets for the source period + targets_df = fetch_targets(engine, from_period, reform_id) + + if targets_df.empty: + logger.warning(f"No targets found for period {from_period}") + return pd.DataFrame() + + # Filter by variable if specified + if variable is not None: + targets_df = targets_df[targets_df["variable"] == variable] + if targets_df.empty: + logger.warning( + f"No targets found for variable '{variable}' in period {from_period}" + ) + return pd.DataFrame() + + logger.info( + f"Found {len(targets_df)} targets to uprate from {from_period} to {to_period}" + ) + + # Group targets by variable and apply appropriate uprating factor + all_uprated_dfs = [] + + for var_name in targets_df["variable"].unique(): + var_targets = targets_df[targets_df["variable"] == var_name] + + if "_count" in var_name: + factor_type = "population" + logger.info( + f"Using population 
factor for variable '{var_name}' (contains '_count')" + ) + else: + factor_type = "inflation" + logger.info(f"Using inflation factor for variable '{var_name}'") + + uprated_var_df = uprate_targets_for_period( + var_targets, + uprating_factors_df, + from_period, + to_period, + factor_type, + ) + + all_uprated_dfs.append(uprated_var_df) + + # Combine all uprated results + if all_uprated_dfs: + uprated_df = pd.concat(all_uprated_dfs, ignore_index=True) + else: + logger.warning("No targets to uprate") + return pd.DataFrame() + + results_df = uprated_df[ + [ + "target_id", + "stratum_id", + "stratum_group_id", + "parent_stratum_id", + "variable", + "original_period", + "uprated_period", + "reform_id", + "value", + "uprated_value", + "uprating_factor", + "factor_type", + "tolerance", + ] + ].copy() + + # Update database if requested + if update_database: + inserts = prepare_insert_data(uprated_df) + inserted_count = insert_uprated_targets_in_db(engine, inserts) + logger.info( + f"Inserted/updated {inserted_count} uprated targets in database" + ) + else: + logger.info( + "Update database was set to False - no database updates performed" + ) + + logger.info(f"Total targets uprated: {len(results_df)}") + + return results_df + + +if __name__ == "__main__": + from policyengine_us.system import system + + from policyengine_data.calibration.target_rescaling import ( + download_database, + ) + + # Connection to database in huggingface hub + db_uri = download_database() + + # Example: uprate 2022 targets to 2023 + results = uprate_calibration_targets( + system=system, db_uri=db_uri, from_period=2022, to_period=2023 + ) + + print("\nUprating Results:") + print(results) + + # Show factor type breakdown + if not results.empty: + print(f"\nFactor Type Summary:") + factor_summary = results.groupby("factor_type")["variable"].unique() + for factor_type, variables in factor_summary.items(): + print(f"{factor_type.title()} factor used for: {list(variables)}") diff --git 
def create_geographic_normalization_factor(
    geo_hierarchy: List[str],
    target_info: Dict[int, Dict[str, any]],
) -> torch.Tensor:
    """
    Build per-target normalization factors that balance calibration targets
    belonging to different geographic levels.

    Args:
        geo_hierarchy (List[str]): Geographic level code prefixes shared by
            all areas within a level (e.g. ["0100000US", "0400000US",
            "0500000US"]).
        target_info (Dict[int, Dict[str, any]]): Per-target info dicts; the
            "name" encodes the target's geographic area and "active" marks
            whether the target participates in calibration.

    Returns:
        torch.Tensor: One entry per ACTIVE target. All ones when active
        targets span a single geographic level, zeros when no active target
        name matches any level code, otherwise 1/count-per-level rescaled so
        the active factors average to one.
    """
    active_flags = []
    target_levels = []
    active_count_by_level = {code: 0 for code in geo_hierarchy}

    # Tag each target with the first hierarchy code found in its name and
    # count active targets per geographic level.
    for info in target_info.values():
        is_target_active = info["active"]
        active_flags.append(is_target_active)

        matched_level = next(
            (code for code in geo_hierarchy if code in info["name"]), None
        )
        if matched_level is not None and is_target_active:
            active_count_by_level[matched_level] += 1
        target_levels.append(matched_level)

    is_active = torch.tensor(active_flags, dtype=torch.float32)
    active_mask = is_active.bool()

    # Raw factor: each active, matched target is weighted inversely to how
    # many active targets share its geographic level.
    factors = torch.zeros_like(is_active)
    for idx, level in enumerate(target_levels):
        if (
            active_flags[idx]
            and level is not None
            and active_count_by_level[level] > 0
        ):
            factors[idx] = 1.0 / active_count_by_level[level]

    levels_in_use = {
        level
        for flag, level in zip(active_flags, target_levels)
        if flag and level is not None
    }

    # No active target matched any level: nothing to balance, return zeros
    # (one entry per active target).
    if not levels_in_use:
        return torch.zeros(int(active_mask.sum().item()))

    # A single level needs no balancing: weight every active target equally.
    if len(levels_in_use) <= 1:
        factors = torch.where(
            active_mask, torch.tensor(1.0), torch.tensor(0.0)
        )
    else:
        # Rescale so the active factors average to one.
        active_factors = factors[active_mask]
        if len(active_factors) > 0 and active_factors.sum() > 0:
            factors = factors / active_factors.mean()

    return factors[active_mask]
b/src/policyengine_data/dataset_legacy.py index ad516f8..040e693 100644 --- a/src/policyengine_data/dataset_legacy.py +++ b/src/policyengine_data/dataset_legacy.py @@ -13,8 +13,9 @@ import numpy as np import pandas as pd import requests -from policyengine_core.tools.hugging_face import * -from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager + +from policyengine_data.tools.hugging_face import * +from policyengine_data.tools.win_file_manager import WindowsAtomicFileManager def atomic_write(file: Path, content: bytes) -> None: diff --git a/src/policyengine_data/normalise_keys.py b/src/policyengine_data/normalise_keys.py index 49189b4..cfd317f 100644 --- a/src/policyengine_data/normalise_keys.py +++ b/src/policyengine_data/normalise_keys.py @@ -15,7 +15,7 @@ def normalise_table_keys( tables: Dict[str, pd.DataFrame], primary_keys: Dict[str, str], foreign_keys: Optional[Dict[str, Dict[str, str]]] = None, - start_index: Optional[int] = 0, + start_index: Optional[Dict[str, int]] = None, ) -> Dict[str, pd.DataFrame]: """ Normalise primary and foreign keys across multiple tables to zero-based indices. @@ -31,7 +31,7 @@ def normalise_table_keys( relationships. Format: {table_name: {fk_column: referenced_table}} If None, foreign keys will be auto-detected based on column names matching primary key names from other tables. - start_index: Starting index for normalisation (default: 0). + start_index: Dictionary mapping table names to their starting index for normalisation (default: 0). 
Returns: Dictionary of normalised tables with `start_index`-based integer keys @@ -56,6 +56,9 @@ def normalise_table_keys( if not tables: return {} + if not start_index: + start_index = {} + if foreign_keys is None: foreign_keys = _auto_detect_foreign_keys(tables, primary_keys) @@ -79,8 +82,10 @@ def normalise_table_keys( # Get unique values and create zero-based mapping unique_keys = df[pk_column].unique() key_mappings[table_name] = { - old_key: new_key + start_index - for new_key, old_key in enumerate(unique_keys) + old_key: new_key + for new_key, old_key in enumerate( + unique_keys, start=start_index.get(table_name, 0) + ) } # Second pass: apply mappings to all tables diff --git a/src/policyengine_data/tools/__init__.py b/src/policyengine_data/tools/__init__.py index 59fb25b..4a350a1 100644 --- a/src/policyengine_data/tools/__init__.py +++ b/src/policyengine_data/tools/__init__.py @@ -1,2 +1,3 @@ from .hugging_face import download_huggingface_dataset, get_or_prompt_hf_token +from .legacy_class_conversions import SingleYearDataset_to_Dataset from .win_file_manager import WindowsAtomicFileManager diff --git a/src/policyengine_data/tools/legacy_class_conversions.py b/src/policyengine_data/tools/legacy_class_conversions.py new file mode 100644 index 0000000..c0fb17a --- /dev/null +++ b/src/policyengine_data/tools/legacy_class_conversions.py @@ -0,0 +1,73 @@ +""" +Utilities to convert back from SingleYearDataset to the legacy Dataset class. +""" + +from pathlib import Path +from typing import Union + +import h5py +import numpy as np + +from ..single_year_dataset import SingleYearDataset + + +def SingleYearDataset_to_Dataset( + dataset: SingleYearDataset, + output_path: Union[str, Path], + time_period: int = None, +) -> None: + """ + Convert a SingleYearDataset to legacy Dataset format and save as h5 file. 
+ + This function loads entity tables from a SingleYearDataset, separates them into + variable arrays, and saves them in the legacy ARRAYS format used + by the legacy Dataset class. + + Args: + dataset: SingleYearDataset instance with entity tables + output_path: Path where to save the h5 file + time_period: Time period for the data (defaults to dataset.time_period) + """ + if time_period is None: + time_period = dataset.time_period + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert entity tables to variable arrays dictionary with proper type handling + variable_arrays = {} + + for entity_name, entity_df in dataset.entities.items(): + # Extract each column as a separate variable array + for column_name in entity_df.columns: + values = entity_df[column_name].values + + # Handle special data type conversions following CPS pattern + if values.dtype == object: + # Try to determine if this should be string or numeric + try: + # Check if it's actually string data that should be encoded + if hasattr(values, "decode_to_str"): + values = values.decode_to_str().astype("S") + elif column_name == "county_fips": + values = values.astype("int32") + else: + # For other object types, try to preserve as string + values = np.array(values, dtype="S") + except: + # Fallback: convert to string + values = np.array( + [str(v).encode() for v in values], dtype="S" + ) + + variable_arrays[column_name] = values + + # Save in ARRAYS format (direct variable datasets) + with h5py.File(output_path, "w") as f: + for variable_name, values in variable_arrays.items(): + try: + # Store each variable directly as a dataset (no time period grouping) + f.create_dataset(variable_name, data=values) + except Exception as e: + print(f" Warning: Could not save {variable_name}: {e}") + continue diff --git a/tests/test_calibration/test_calibration.py b/tests/test_calibration/test_calibration.py new file mode 100644 index 0000000..1fecc92 --- /dev/null +++ 
b/tests/test_calibration/test_calibration.py @@ -0,0 +1,208 @@ +""" +Test the calibration logic for different geographic levels that integrates all other calibration pipeline components. +""" + +import pytest + +areas_in_national_level = { + "United States": "0100000US", +} + +areas_in_state_level = { + "Alabama": "0400000US01", + "Alaska": "0400000US02", + "Arizona": "0400000US04", + "Arkansas": "0400000US05", + "California": "0400000US06", + "Colorado": "0400000US08", + "Connecticut": "0400000US09", + "Delaware": "0400000US10", + "District of Columbia": "0400000US11", + "Florida": "0400000US12", + "Georgia": "0400000US13", + "Hawaii": "0400000US15", + "Idaho": "0400000US16", + "Illinois": "0400000US17", + "Indiana": "0400000US18", + "Iowa": "0400000US19", + "Kansas": "0400000US20", + "Kentucky": "0400000US21", + "Louisiana": "0400000US22", + "Maine": "0400000US23", + "Maryland": "0400000US24", + "Massachusetts": "0400000US25", + "Michigan": "0400000US26", + "Minnesota": "0400000US27", + "Mississippi": "0400000US28", + "Missouri": "0400000US29", + "Montana": "0400000US30", + "Nebraska": "0400000US31", + "Nevada": "0400000US32", + "New Hampshire": "0400000US33", + "New Jersey": "0400000US34", + "New Mexico": "0400000US35", + "New York": "0400000US36", + "North Carolina": "0400000US37", + "North Dakota": "0400000US38", + "Ohio": "0400000US39", + "Oklahoma": "0400000US40", + "Oregon": "0400000US41", + "Pennsylvania": "0400000US42", + "Rhode Island": "0400000US44", + "South Carolina": "0400000US45", + "South Dakota": "0400000US46", + "Tennessee": "0400000US47", + "Texas": "0400000US48", + "Utah": "0400000US49", + "Vermont": "0400000US50", + "Virginia": "0400000US51", + "Washington": "0400000US53", + "West Virginia": "0400000US54", + "Wisconsin": "0400000US55", + "Wyoming": "0400000US56", +} + + +def test_calibration_per_geographic_level_iteration(): + """ + Test and example of the calibration routine involving calibrating one geographic level at a time from lowest to 
highest in the hierarchy and generating sparsity in all but the last levels. + + Conversion between dataset class types is necessary until full migration to the new SingleYearDataset class in the policyengine_core repository. + """ + from policyengine_us import Microsimulation + from policyengine_data.tools.legacy_class_conversions import ( + SingleYearDataset_to_Dataset, + ) + from policyengine_data.calibration.target_rescaling import ( + download_database, + rescale_calibration_targets, + ) + from policyengine_data.calibration.target_uprating import ( + uprate_calibration_targets, + ) + from policyengine_data.calibration.calibrate import ( + calibrate_single_geography_level, + ) + + db_uri = download_database() + + # Rescale targets for consistency across geography areas + rescaling_results = rescale_calibration_targets( + db_uri=db_uri, update_database=True + ) + + # Uprate targets for consistency across definition year (disabled until IRS SOI variables are renamed to avoid errors) + # uprating_results = uprate_calibration_targets( + # system=system, + # db_uri=db_uri, + # from_period=2022, + # to_period=2023, + # update_database=True, + # ) + + # Calibrate the state level dataset with sparsity + state_level_calibrated_dataset = calibrate_single_geography_level( + Microsimulation, + areas_in_state_level, + "hf://policyengine/policyengine-us-data/cps_2023.h5", + dataset_subsample_size=1000, # approximately 5% of the base dataset to decrease computation costs + use_dataset_weights=False, + regularize_with_l0=True, + ) + + state_level_weights = state_level_calibrated_dataset.entities["household"][ + "household_weight" + ].values + + SingleYearDataset_to_Dataset( + state_level_calibrated_dataset, output_path="Dataset_state_level.h5" + ) + + # Calibrate the national level dataset using the previously calibrated state dataset, without sparsity, and without initial noise (trying to minimize deviation from state-calibrated weights) + national_level_calibrated_dataset = 
+    ).sum() > 0, "Household weights do not differ between state and national levels, suggesting national calibration was unsuccessful."
+ """ + from policyengine_us import Microsimulation + from policyengine_data.tools.legacy_class_conversions import ( + SingleYearDataset_to_Dataset, + ) + from policyengine_data.calibration.target_rescaling import ( + download_database, + rescale_calibration_targets, + ) + from policyengine_data.calibration.target_uprating import ( + uprate_calibration_targets, + ) + from policyengine_data.calibration.calibrate import ( + calibrate_all_levels, + ) + + db_uri = download_database() + + # Rescale targets for consistency across geography areas + rescaling_results = rescale_calibration_targets( + db_uri=db_uri, update_database=True + ) + + # Uprate targets for consistency across definition year (disabled until IRS SOI variables are renamed to avoid errors) + # uprating_results = uprate_calibration_targets( + # system=system, + # db_uri=db_uri, + # from_period=2022, + # to_period=2023, + # update_database=True, + # ) + + # Calibrate the full dataset at once (only passing the identifyers of the areas for which the base dataset will be stacked) + fully_calibrated_dataset = calibrate_all_levels( + Microsimulation, + areas_in_state_level, + "hf://policyengine/policyengine-us-data/cps_2023.h5", + geo_hierarchy=["0100000US", "0400000US"], + dataset_subsample_size=1000, + regularize_with_l0=True, + raise_error=False, # this will avoid raising an error if some targets have no records contributing to them (given sampling) + ) + + weights = fully_calibrated_dataset.entities["household"][ + "household_weight" + ].values + + SingleYearDataset_to_Dataset( + fully_calibrated_dataset, output_path="Dataset_fully_calibrated.h5" + ) + + assert len(weights) < 1000 * len( + areas_in_state_level + ), "Weight vector length should be less than the sampled 1000 per area after regularization." 
diff --git a/tests/test_calibration/test_dataset_duplication.py b/tests/test_calibration/test_dataset_duplication.py new file mode 100644 index 0000000..a8e9987 --- /dev/null +++ b/tests/test_calibration/test_dataset_duplication.py @@ -0,0 +1,152 @@ +""" +Test the logic for assigning a dataset to a geographic level and minimizing it. +""" + +from policyengine_us.variables.household.demographic.geographic.ucgid.ucgid_enum import ( + UCGID, +) +from policyengine_data import SingleYearDataset + + +def test_dataset_assignment_to_geography() -> None: + """Test that a dataset can be assigned to a geographic level without errors.""" + from policyengine_us import Microsimulation + from policyengine_data.calibration import load_dataset_for_geography_legacy + + sim = load_dataset_for_geography_legacy(Microsimulation) + + assert hasattr(sim, "dataset") + assert hasattr(sim, "default_input_period") + assert sim.default_input_period == 2023 + + # Verify household data exists + household_ids = sim.calculate("household_id").values + assert len(household_ids) > 0 + + # Verify geography is set correctly + ucgid_values = sim.calculate("ucgid").values + expected_ucgid = UCGID("0100000US") + # The system returns enum names as strings, so compare with the name + assert all(val == expected_ucgid.name for val in ucgid_values) + + # Test with California state identifier + california_ucgid = UCGID("0400000US06") + sim = load_dataset_for_geography_legacy( + Microsimulation, geography_identifier=california_ucgid + ) + + # Verify geography is set correctly + ucgid_values = sim.calculate("ucgid").values + # The system returns enum names as strings, so compare with the name + assert all(val == california_ucgid.name for val in ucgid_values) + + +def test_dataset_minimization() -> None: + """Test that a dataset can be minimized using sparse weights.""" + from policyengine_data.calibration import ( + minimize_calibrated_dataset_legacy, + ) + from policyengine_us import Microsimulation + import 
pandas as pd + + # Load the dataset + sim = Microsimulation( + dataset="hf://policyengine/policyengine-us-data/cps_2023.h5" + ) + sim.default_input_period = 2023 + sim.build_from_dataset() + + before_minimizing = SingleYearDataset.from_simulation( + sim, time_period=2023 + ) + before_minimizing.time_period = 2023 + + # Create dummy sparse weights + household_ids = sim.calculate("household_id").values + optimized_sparse_weights = pd.Series( + [1.0] * (len(household_ids) // 2) + + [0.0] * (len(household_ids) - (len(household_ids) // 2)) + ) + + # Get age values before minimization for comparison + age_before = sim.calculate("age", 2023).values + + # Minimize the dataset + after_minimizing = minimize_calibrated_dataset_legacy( + Microsimulation, + sim, + year=2023, + optimized_weights=optimized_sparse_weights, + ) + + assert len(before_minimizing.entities["household"]) > len( + after_minimizing.entities["household"] + ) + assert ( + abs( + len(before_minimizing.entities["household"]) + - 2 * len(after_minimizing.entities["household"]) + ) + < 2 + ) + + # Check that age values did not change for the records that were kept + age_after = after_minimizing.entities["person"]["age"].values + kept_person_ids = after_minimizing.entities["person"]["person_id"].values + + # Find the indices of these person IDs in the original dataset + original_person_ids = before_minimizing.entities["person"][ + "person_id" + ].values + kept_indices = [ + i + for i, pid in enumerate(original_person_ids) + if pid in kept_person_ids + ] + + # Compare age values for kept records + age_before_kept = age_before[kept_indices] + assert pd.Series(age_before_kept).equals( + pd.Series(age_after) + ), "Age values should not change for records that were kept" + + +def test_dataset_subsampling() -> None: + """Test that dataset subsampling works correctly.""" + from policyengine_us import Microsimulation + from policyengine_data.calibration import load_dataset_for_geography_legacy + + # Load full dataset 
first + sim_full = load_dataset_for_geography_legacy(Microsimulation) + full_households = len(sim_full.calculate("household_id").unique()) + + # Test subsampling with a smaller size + subsample_size = min( + 100, full_households // 2 + ) # Ensure we're actually reducing the size + sim_subsampled = load_dataset_for_geography_legacy( + Microsimulation, dataset_subsample_size=subsample_size + ) + + subsampled_households = len( + sim_subsampled.calculate("household_id").unique() + ) + + # Verify the subsampled dataset has the expected number of households + assert ( + subsampled_households == subsample_size + ), f"Expected {subsample_size} households, got {subsampled_households}" + + # Verify geography is still set correctly after subsampling + expected_ucgid = UCGID("0100000US") + ucgid_values = sim_subsampled.calculate("ucgid").values + assert all(val == expected_ucgid.name for val in ucgid_values) + + # Test with a subsample size larger than available households (should return original) + sim_large_subsample = load_dataset_for_geography_legacy( + Microsimulation, dataset_subsample_size=full_households + 1000 + ) + large_subsample_households = len( + sim_large_subsample.calculate("household_id").unique() + ) + assert large_subsample_households == full_households diff --git a/tests/test_calibration/test_matrix_creation.py b/tests/test_calibration/test_matrix_creation.py new file mode 100644 index 0000000..fccef74 --- /dev/null +++ b/tests/test_calibration/test_matrix_creation.py @@ -0,0 +1,385 @@ +""" +Test the logic for creating an estimate matrix from a database. 
+""" + +import numpy as np +import pandas as pd +import pytest + + +def test_matrix_creation() -> None: + from policyengine_us import Microsimulation + from policyengine_data.calibration import ( + create_metrics_matrix, + validate_metrics_matrix, + download_database, + ) + + # Download database from Hugging Face Hub + db_uri = download_database() + + # Create metrics matrix + metrics_matrix, target_values, target_info = create_metrics_matrix( + db_uri=db_uri, + time_period=2023, + microsimulation_class=Microsimulation, + dataset="hf://policyengine/policyengine-us-data/cps_2023.h5", + reform_id=0, + ) + + # Validate the matrix (it will raise an error if matrix creation failed) + validation_results = validate_metrics_matrix( + metrics_matrix, target_values, target_info=target_info + ) + + +def test_parse_constraint_value(): + """Test parsing constraint values from strings.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + parse_constraint_value, + ) + + # Test boolean values + assert parse_constraint_value("true", "equals") == True + assert parse_constraint_value("false", "equals") == False + assert parse_constraint_value("True", "equals") == True + assert parse_constraint_value("FALSE", "equals") == False + + # Test integer values + assert parse_constraint_value("42", "equals") == 42 + assert parse_constraint_value("0", "equals") == 0 + assert parse_constraint_value("-10", "greater_than") == -10 + + # Test float values + assert parse_constraint_value("3.14", "less_than") == 3.14 + assert parse_constraint_value("0.0", "equals") == 0 + + # Test list values for "in" operation + result = parse_constraint_value("apple,banana,cherry", "in") + assert result == ["apple", "banana", "cherry"] + + # Test string values + assert parse_constraint_value("hello", "equals") == "hello" + assert parse_constraint_value("test_string", "not_equals") == "test_string" + + +def test_apply_single_constraint(): + """Test applying single constraints to create boolean 
masks.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + apply_single_constraint, + ) + + # Test data + values = np.array([1, 2, 3, 4, 5]) + + # Test equals operation + mask = apply_single_constraint(values, "equals", 3) + expected = np.array([False, False, True, False, False]) + np.testing.assert_array_equal(mask, expected) + + # Test greater_than operation + mask = apply_single_constraint(values, "greater_than", 3) + expected = np.array([False, False, False, True, True]) + np.testing.assert_array_equal(mask, expected) + + # Test less_than_or_equal operation + mask = apply_single_constraint(values, "less_than_or_equal", 3) + expected = np.array([True, True, True, False, False]) + np.testing.assert_array_equal(mask, expected) + + # Test not_equals operation + mask = apply_single_constraint(values, "not_equals", 3) + expected = np.array([True, True, False, True, True]) + np.testing.assert_array_equal(mask, expected) + + # Test "in" operation with string values + str_values = np.array(["apple", "banana", "cherry", "date"]) + mask = apply_single_constraint(str_values, "in", "an") + expected = np.array([False, True, False, False]) # "an" is in "banana" + np.testing.assert_array_equal(mask, expected) + + # Test "in" operation with list + mask = apply_single_constraint(str_values, "in", ["app", "che"]) + expected = np.array( + [True, False, True, False] + ) # "app" in "apple", "che" in "cherry" + np.testing.assert_array_equal(mask, expected) + + # Test invalid operation + with pytest.raises(ValueError, match="Unknown operation"): + apply_single_constraint(values, "invalid_op", 3) + + +def test_parse_constraint_for_name(): + """Test parsing constraints into human-readable names.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + parse_constraint_for_name, + ) + + # Test different operations + constraint_data = [ + ( + { + "constraint_variable": "age", + "operation": "equals", + "value": "30", + }, + "age=30", + ), + ( + { + 
"constraint_variable": "income", + "operation": "greater_than", + "value": "50000", + }, + "income>50000", + ), + ( + { + "constraint_variable": "score", + "operation": "less_than_or_equal", + "value": "100", + }, + "score<=100", + ), + ( + { + "constraint_variable": "status", + "operation": "not_equals", + "value": "active", + }, + "status!=active", + ), + ( + { + "constraint_variable": "category", + "operation": "in", + "value": "A,B,C", + }, + "category_in_A_B_C", + ), + ] + + for constraint_dict, expected in constraint_data: + constraint = pd.Series(constraint_dict) + result = parse_constraint_for_name(constraint) + assert result == expected, f"Expected {expected}, got {result}" + + +def test_build_target_name(): + """Test building descriptive target names.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + build_target_name, + ) + + # Test with no constraints + assert build_target_name("population", pd.DataFrame()) == "population" + + # Test with single constraint + constraints = pd.DataFrame( + [{"constraint_variable": "age", "operation": "equals", "value": "30"}] + ) + result = build_target_name("income", constraints) + assert result == "income_age=30" + + # Test with multiple constraints (should be sorted) + constraints = pd.DataFrame( + [ + { + "constraint_variable": "state", + "operation": "equals", + "value": "CA", + }, + { + "constraint_variable": "age", + "operation": "greater_than", + "value": "18", + }, + ] + ) + result = build_target_name("population", constraints) + assert result == "population_age>18_state=CA" + + # Test with ucgid constraint (should come first) + constraints = pd.DataFrame( + [ + { + "constraint_variable": "age", + "operation": "equals", + "value": "25", + }, + { + "constraint_variable": "ucgid", + "operation": "equals", + "value": "123456", + }, + ] + ) + result = build_target_name("count", constraints) + assert result == "count_ucgid=123456_age=25" + + +def test_validate_metrics_matrix(): + """Test 
validating metrics matrix with synthetic data.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + validate_metrics_matrix, + ) + + # Create synthetic metrics matrix + np.random.seed(42) # For reproducible results + n_households = 100 + n_targets = 5 + + # Create metrics matrix with known values + metrics_data = np.random.rand(n_households, n_targets) * 1000 + household_ids = range(1000, 1000 + n_households) + target_ids = range(1, n_targets + 1) + + metrics_matrix = pd.DataFrame( + data=metrics_data, index=household_ids, columns=target_ids + ) + + # Create target values + target_values = np.array([5000, 3000, 8000, 1500, 6000]) + + # Create target info + target_info = { + 1: {"name": "target_1", "active": True, "tolerance": 0.1}, + 2: {"name": "target_2", "active": True, "tolerance": 0.05}, + 3: {"name": "target_3", "active": False, "tolerance": None}, + 4: {"name": "target_4", "active": True, "tolerance": 0.2}, + 5: {"name": "target_5", "active": True, "tolerance": 0.15}, + } + + # Test with uniform weights + validation_results = validate_metrics_matrix( + metrics_matrix, target_values, target_info=target_info + ) + + # Check structure + assert len(validation_results) == n_targets + assert set(validation_results.columns) == { + "target_id", + "target_value", + "estimate", + "absolute_error", + "relative_error", + "name", + "active", + "tolerance", + } + + # Check values + assert ( + validation_results["target_value"].tolist() == target_values.tolist() + ) + assert validation_results["name"].tolist() == [ + "target_1", + "target_2", + "target_3", + "target_4", + "target_5", + ] + assert validation_results["active"].tolist() == [ + True, + True, + False, + True, + True, + ] + + # Test with custom weights + custom_weights = np.random.rand(n_households) + custom_weights /= custom_weights.sum() # Normalize + + validation_results_weighted = validate_metrics_matrix( + metrics_matrix, + target_values, + weights=custom_weights, + 
target_info=target_info, + ) + + # Should have same structure but different estimates + assert len(validation_results_weighted) == n_targets + # Estimates should be different with different weights + assert not np.allclose( + validation_results["estimate"].values, + validation_results_weighted["estimate"].values, + ) + + # Test error raising with all-zero matrix + zero_matrix = pd.DataFrame( + data=np.zeros((n_households, n_targets)), + index=household_ids, + columns=target_ids, + ) + + with pytest.raises(ValueError, match="Record.*has all zero estimates"): + validate_metrics_matrix(zero_matrix, target_values, raise_error=True) + + # Test with household that meets no constraints (all zeros in row) + mixed_matrix = metrics_matrix.copy() + mixed_matrix.iloc[0, :] = 0 # First household meets no constraints + + with pytest.raises(ValueError, match="Record.*has all zero estimates"): + validate_metrics_matrix(mixed_matrix, target_values, raise_error=True) + + +def test_validate_metrics_matrix_zero_estimates(): + """Test validate_metrics_matrix with zero estimate columns.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + validate_metrics_matrix, + ) + + # Create a matrix where one column will have zero estimates + # but no individual records are all zero + metrics_matrix = pd.DataFrame( + data=[[1, 0], [2, 0], [3, 0]], # Second column is all zeros + index=[1, 2, 3], + columns=[101, 102], + ) + target_values = np.array([6, 1]) # Second target will have zero estimate + + with pytest.raises(ValueError, match="estimate.*contain zero values"): + validate_metrics_matrix( + metrics_matrix, target_values, raise_error=True + ) + + +def test_validate_metrics_matrix_without_target_info(): + """Test validate_metrics_matrix without target_info parameter.""" + from policyengine_data.calibration.metrics_matrix_creation import ( + validate_metrics_matrix, + ) + + # Simple test case + metrics_matrix = pd.DataFrame( + data=[[10, 20], [30, 40], [50, 60]], + 
index=[1, 2, 3], + columns=[101, 102], + ) + target_values = np.array([90, 120]) # Sum of columns: [90, 120] + + validation_results = validate_metrics_matrix(metrics_matrix, target_values) + + # Check basic columns exist + expected_cols = { + "target_id", + "target_value", + "estimate", + "absolute_error", + "relative_error", + } + assert set(validation_results.columns) == expected_cols + + # Check estimates (uniform weights: 1/3 each) + expected_estimates = np.array([30, 40]) # (10+30+50)/3, (20+40+60)/3 + np.testing.assert_array_almost_equal( + validation_results["estimate"].values, expected_estimates + ) + + # Check errors + expected_abs_errors = np.abs(expected_estimates - target_values) + np.testing.assert_array_almost_equal( + validation_results["absolute_error"].values, expected_abs_errors + ) diff --git a/tests/test_calibration/test_normalization_factor.py b/tests/test_calibration/test_normalization_factor.py new file mode 100644 index 0000000..8f78766 --- /dev/null +++ b/tests/test_calibration/test_normalization_factor.py @@ -0,0 +1,100 @@ +""" +Test the logic for creating a normalization factor per target to balance targets represented at different concepts or geographic levels. +""" + +import torch +from policyengine_data.calibration.utils import ( + create_geographic_normalization_factor, +) + + +def test_multiple_geo_levels_normalization() -> None: + """ + Test normalization factors with multiple geographic levels. 
+ """ + geo_hierarchy = ["0100000US", "0400000US", "0500000US"] + target_info = { + 1: {"name": "0100000US_population", "active": True}, + 2: {"name": "0400000US01_population", "active": True}, + 3: {"name": "0400000US02_population", "active": True}, + 4: {"name": "0500000US0101_population", "active": True}, + 5: {"name": "0500000US0102_population", "active": False}, + } + + normalization_factor = create_geographic_normalization_factor( + geo_hierarchy, target_info + ) + active_targets = sum( + [1 for info in target_info.values() if info["active"]] + ) + + # Should return factors for active targets + assert ( + len(normalization_factor) == active_targets + ), "Normalization factor length does not match number of active targets." + + # Active factors should have mean = 1.0 (due to mean normalization) + mean_factor = normalization_factor.mean().item() + assert ( + abs(mean_factor - 1.0) < 1e-6 + ), "Normalization factor does not have a mean of 1." + + +def test_single_geo_level_returns_ones() -> None: + """ + Test that single geographic level returns tensor of ones for active targets. + """ + geo_hierarchy = ["0100000US", "0400000US", "0500000US"] + target_info = { + 1: {"name": "0400000US01_population", "active": True}, + 2: {"name": "0400000US02_population", "active": True}, + 3: {"name": "0400000US03_population", "active": False}, + } + + normalization_factor = create_geographic_normalization_factor( + geo_hierarchy, target_info + ) + + # Active targets should have factor = 1.0 + assert normalization_factor[0] == 1.0 + assert normalization_factor[1] == 1.0 + + +def test_all_inactive_targets() -> None: + """ + Test behavior when all targets are inactive. 
+ """ + geo_hierarchy = ["0100000US", "0400000US"] + target_info = { + 1: {"name": "0100000US_population", "active": False}, + 2: {"name": "0400000US01_population", "active": False}, + } + + normalization_factor = create_geographic_normalization_factor( + geo_hierarchy, target_info + ) + + # All factors should be zero + assert ( + len(normalization_factor) == 0 + ), "Normalization factor length should be 0 as there are no active targets." + + +def test_no_matching_geo_codes() -> None: + """ + Test behavior when target names don't match geographic codes. + """ + geo_hierarchy = ["0100000US", "0400000US"] + target_info = { + 1: {"name": "some_other_target", "active": True}, + 2: {"name": "another_target", "active": True}, + } + + normalization_factor = create_geographic_normalization_factor( + geo_hierarchy, target_info + ) + + # All factors should be zero since no geo codes match + assert torch.all( + normalization_factor == 0 + ), "Normalization factors should be zero since no geo codes match." diff --git a/tests/test_target_rescaling.py b/tests/test_calibration/test_target_rescaling.py similarity index 98% rename from tests/test_target_rescaling.py rename to tests/test_calibration/test_target_rescaling.py index fc6cc82..9885890 100644 --- a/tests/test_target_rescaling.py +++ b/tests/test_calibration/test_target_rescaling.py @@ -3,7 +3,7 @@ """ -def setup_test_database(): +def setup_test_database() -> str: """ Creates an in-memory SQLite database for testing. Populates it with a geographic hierarchy where children do not sum to the parent. @@ -146,7 +146,7 @@ def setup_test_database(): return db_uri -def test_rescale_with_geographic_scaling(): +def test_rescale_with_geographic_scaling() -> None: """ Tests that child strata (states) are correctly scaled to match the parent stratum (nation) total. 
diff --git a/tests/test_calibration/test_target_uprating.py b/tests/test_calibration/test_target_uprating.py
new file mode 100644
index 0000000..975ae5b
--- /dev/null
+++ b/tests/test_calibration/test_target_uprating.py
@@ -0,0 +1,788 @@
+"""
+Test the logic for uprating calibration targets from one period to another.
+"""
+
+import os
+import tempfile
+import pandas as pd
+import pytest
+from unittest.mock import Mock, patch
+from sqlalchemy import create_engine
+
+from policyengine_data.calibration.target_uprating import (
+    uprate_calibration_targets,
+    get_uprating_factors,
+    calculate_uprating_factor,
+    uprate_targets_for_period,
+    fetch_targets,
+    prepare_insert_data,
+    insert_uprated_targets_in_db,
+)
+
+
+def setup_test_database() -> tuple[str, str]:
+    """
+    Creates a temporary file-backed SQLite database for testing.
+    Populates it with test targets for multiple years and variables.
+    """
+    # Create temporary database file
+    db_fd, db_path = tempfile.mkstemp(suffix=".db")
+    os.close(db_fd)
+    db_uri = f"sqlite:///{db_path}"
+    engine = create_engine(db_uri)
+
+    # Define schema
+    strata_schema = """
+    CREATE TABLE strata (
+        stratum_id INTEGER PRIMARY KEY,
+        stratum_group_id INTEGER,
+        parent_stratum_id INTEGER,
+        notes TEXT,
+        definition_hash TEXT
+    )
+    """
+    targets_schema = """
+    CREATE TABLE targets (
+        target_id INTEGER PRIMARY KEY,
+        stratum_id INTEGER,
+        variable TEXT,
+        period INTEGER,
+        reform_id INTEGER,
+        value REAL,
+        active BOOLEAN,
+        tolerance REAL,
+        FOREIGN KEY(stratum_id) REFERENCES strata(stratum_id)
+    )
+    """
+
+    with engine.connect() as conn:
+        conn.exec_driver_sql("DROP TABLE IF EXISTS targets")
+        conn.exec_driver_sql("DROP TABLE IF EXISTS strata")
+        conn.exec_driver_sql(strata_schema)
+        conn.exec_driver_sql(targets_schema)
+
+    # Create test strata data
+    strata_data = pd.DataFrame(
+        [
+            {
+                "stratum_id": 1,
+                "stratum_group_id": 10,
+                "parent_stratum_id": None,
+                "notes": "Total Population",
+            },
+            {
+                "stratum_id": 2,
+                "stratum_group_id": 11,
+                
"parent_stratum_id": 1, + "notes": "Income < 50k", + }, + { + "stratum_id": 3, + "stratum_group_id": 11, + "parent_stratum_id": 1, + "notes": "Income >= 50k", + }, + { + "stratum_id": 4, + "stratum_group_id": 20, + "parent_stratum_id": None, + "notes": "All States", + }, + ] + ) + + # Create test targets data for multiple years + targets_data = pd.DataFrame( + [ + # 2021 data (baseline year) + { + "target_id": 101, + "stratum_id": 1, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 1000000.0, + "active": True, + "tolerance": 0.01, + }, + { + "target_id": 102, + "stratum_id": 2, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 300000.0, + "active": True, + "tolerance": 0.05, + }, + { + "target_id": 103, + "stratum_id": 3, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 700000.0, + "active": True, + "tolerance": 0.05, + }, + # 2021 SNAP data + { + "target_id": 104, + "stratum_id": 1, + "variable": "snap", + "period": 2021, + "reform_id": 0, + "value": 80000.0, + "active": True, + "tolerance": 0.02, + }, + { + "target_id": 105, + "stratum_id": 2, + "variable": "snap", + "period": 2021, + "reform_id": 0, + "value": 60000.0, + "active": True, + "tolerance": 0.05, + }, + { + "target_id": 106, + "stratum_id": 3, + "variable": "snap", + "period": 2021, + "reform_id": 0, + "value": 20000.0, + "active": True, + "tolerance": 0.05, + }, + # 2021 Count data (should use population factor) + { + "target_id": 110, + "stratum_id": 1, + "variable": "household_count", + "period": 2021, + "reform_id": 0, + "value": 130000.0, + "active": True, + "tolerance": 0.02, + }, + { + "target_id": 111, + "stratum_id": 2, + "variable": "person_count", + "period": 2021, + "reform_id": 0, + "value": 320000.0, + "active": True, + "tolerance": 0.03, + }, + # 2022 data (for testing different periods) + { + "target_id": 107, + "stratum_id": 1, + "variable": "income_tax", + "period": 2022, + "reform_id": 0, + "value": 
1050000.0, + "active": True, + "tolerance": 0.01, + }, + # Inactive target (should be ignored) + { + "target_id": 108, + "stratum_id": 4, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 999999.0, + "active": False, + "tolerance": 0.01, + }, + # Different reform_id (should be isolated) + { + "target_id": 109, + "stratum_id": 1, + "variable": "income_tax", + "period": 2021, + "reform_id": 1, + "value": 1100000.0, + "active": True, + "tolerance": 0.01, + }, + ] + ) + + # Load data into the database + strata_data.to_sql("strata", engine, if_exists="append", index=False) + targets_data.to_sql("targets", engine, if_exists="append", index=False) + + return db_uri, db_path + + +class TestUpratingFactors: + """Test the uprating factors calculation.""" + + def test_get_uprating_factors_structure(self): + """Test that get_uprating_factors returns the expected structure.""" + mock_system = Mock() + # Mock the system parameters + mock_pop = Mock() + mock_cpi = Mock() + mock_system.parameters.get_child.side_effect = lambda path: ( + mock_pop if "population" in path else mock_cpi + ) + mock_pop.side_effect = ( + lambda year: 100000 + (year - 2023) * 1000 + ) # Growing population + mock_cpi.side_effect = ( + lambda year: 250 + (year - 2023) * 10 + ) # Growing inflation + + factors_df = get_uprating_factors( + mock_system, current_year=2023, start_year=2021, end_year=2025 + ) + + assert isinstance(factors_df, pd.DataFrame) + assert list(factors_df.columns) == [ + "Year", + "Population_factor", + "Inflation_factor", + ] + assert len(factors_df) == 5 # 2021-2025 + assert ( + factors_df[factors_df["Year"] == 2023]["Population_factor"].iloc[0] + == 1.0 + ) # Base year + assert ( + factors_df[factors_df["Year"] == 2023]["Inflation_factor"].iloc[0] + == 1.0 + ) # Base year + + def test_calculate_uprating_factor(self): + """Test uprating factor calculation between years.""" + # Create test uprating factors DataFrame + uprating_df = pd.DataFrame( + { + "Year": 
[2021, 2022, 2023, 2024], + "Population_factor": [0.95, 0.97, 1.00, 1.03], + "Inflation_factor": [0.90, 0.95, 1.00, 1.05], + } + ) + + # Test inflation uprating from 2021 to 2024 + inflation_factor = calculate_uprating_factor( + uprating_df, 2021, 2024, "inflation" + ) + expected = 1.05 / 0.90 # ≈ 1.1667 + assert inflation_factor == pytest.approx(expected, rel=1e-4) + + # Test population uprating from 2022 to 2023 + pop_factor = calculate_uprating_factor( + uprating_df, 2022, 2023, "population" + ) + expected = 1.00 / 0.97 # ≈ 1.0309 + assert pop_factor == pytest.approx(expected, rel=1e-4) + + # Test missing year (should return 1.0) + missing_factor = calculate_uprating_factor( + uprating_df, 2025, 2026, "inflation" + ) + assert missing_factor == 1.0 + + +class TestTargetUprating: + """Test the target uprating functionality.""" + + def setup_method(self): + """Set up test database for each test.""" + self.db_uri, self.db_path = setup_test_database() + + def teardown_method(self): + """Clean up test database after each test.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + + def test_fetch_targets(self): + """Test fetching targets from database.""" + engine = create_engine(self.db_uri) + + # Fetch 2021 baseline targets + targets_df = fetch_targets(engine, period=2021, reform_id=0) + + assert not targets_df.empty + assert ( + len(targets_df) == 8 + ) # 3 income_tax + 3 snap + 2 count targets for 2021, reform_id=0 + assert all(targets_df["period"] == 2021) + assert all(targets_df["reform_id"] == 0) + assert all(targets_df["active"] == True) + + # Test variable filtering + income_targets = targets_df[targets_df["variable"] == "income_tax"] + assert len(income_targets) == 3 + + # Test no results for non-existent period + empty_df = fetch_targets(engine, period=2030, reform_id=0) + assert empty_df.empty + + def test_uprate_targets_for_period(self): + """Test uprating targets from one period to another.""" + # Create mock uprating factors + uprating_df = 
pd.DataFrame( + { + "Year": [2021, 2022, 2023, 2024], + "Population_factor": [0.95, 0.97, 1.00, 1.03], + "Inflation_factor": [0.90, 0.95, 1.00, 1.05], + } + ) + + # Create sample targets + targets_df = pd.DataFrame( + [ + { + "target_id": 1, + "stratum_id": 1, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 1000.0, + "tolerance": 0.01, + }, + { + "target_id": 2, + "stratum_id": 2, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 500.0, + "tolerance": 0.05, + }, + ] + ) + + # Uprate from 2021 to 2024 using inflation + uprated_df = uprate_targets_for_period( + targets_df, + uprating_df, + from_period=2021, + to_period=2024, + factor_type="inflation", + ) + + expected_factor = 1.05 / 0.90 # ≈ 1.1667 + + assert len(uprated_df) == 2 + for _, row in uprated_df.iterrows(): + assert row["uprating_factor"] == pytest.approx( + expected_factor, rel=1e-3 + ) + assert all(uprated_df["original_period"] == 2021) + assert all(uprated_df["uprated_period"] == 2024) + assert all(uprated_df["period"] == 2024) + assert all(uprated_df["factor_type"] == "inflation") + + # Check uprated values + assert uprated_df.iloc[0]["uprated_value"] == pytest.approx( + 1000.0 * expected_factor, rel=1e-3 + ) + assert uprated_df.iloc[1]["uprated_value"] == pytest.approx( + 500.0 * expected_factor, rel=1e-3 + ) + + @patch( + "policyengine_data.calibration.target_uprating.get_uprating_factors" + ) + def test_uprate_calibration_targets_basic(self, mock_get_factors): + """Test the main uprating function with basic functionality.""" + # Mock uprating factors + mock_get_factors.return_value = pd.DataFrame( + { + "Year": [2021, 2024], + "Population_factor": [0.95, 1.03], + "Inflation_factor": [0.90, 1.05], + } + ) + mock_system = Mock() + + # Run uprating for income_tax (should use inflation factor) + results_df = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2024, + variable="income_tax", + 
reform_id=0, + update_database=False, + ) + + assert not results_df.empty + assert len(results_df) == 3 # 3 income_tax targets for 2021 + assert all(results_df["original_period"] == 2021) + assert all(results_df["uprated_period"] == 2024) + assert all(results_df["factor_type"] == "inflation") + + # Check that uprating factor was calculated correctly + expected_factor = 1.05 / 0.90 + for _, row in results_df.iterrows(): + assert row["uprating_factor"] == pytest.approx( + expected_factor, rel=1e-3 + ) + + @patch( + "policyengine_data.calibration.target_uprating.get_uprating_factors" + ) + def test_uprate_calibration_targets_automatic_factor_selection( + self, mock_get_factors + ): + """Test automatic factor selection based on variable name.""" + mock_get_factors.return_value = pd.DataFrame( + { + "Year": [2021, 2024], + "Population_factor": [0.95, 1.03], + "Inflation_factor": [0.90, 1.05], + } + ) + mock_system = Mock() + + # Test count variables (should use population factor) + results_df = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2024, + variable="household_count", + reform_id=0, + update_database=False, + ) + + assert not results_df.empty + assert len(results_df) == 1 # 1 household_count target for 2021 + assert all(results_df["factor_type"] == "population") + + # Check that population factor was used + expected_pop_factor = 1.03 / 0.95 + for _, row in results_df.iterrows(): + assert row["uprating_factor"] == pytest.approx( + expected_pop_factor, rel=1e-3 + ) + + # Test non-count variables (should use inflation factor) + results_df_snap = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2024, + variable="snap", + reform_id=0, + update_database=False, + ) + + assert not results_df_snap.empty + assert len(results_df_snap) == 3 # 3 SNAP targets for 2021 + assert all(results_df_snap["factor_type"] == "inflation") + + # Check that inflation factor was 
used + expected_inflation_factor = 1.05 / 0.90 + for _, row in results_df_snap.iterrows(): + assert row["uprating_factor"] == pytest.approx( + expected_inflation_factor, rel=1e-3 + ) + + @patch( + "policyengine_data.calibration.target_uprating.get_uprating_factors" + ) + def test_uprate_calibration_targets_mixed_factor_types( + self, mock_get_factors + ): + """Test uprating all variables with mixed factor types.""" + mock_get_factors.return_value = pd.DataFrame( + { + "Year": [2021, 2022], + "Population_factor": [0.95, 0.97], + "Inflation_factor": [0.90, 0.95], + } + ) + mock_system = Mock() + + results_df = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2022, + variable=None, # All variables + reform_id=0, + update_database=False, + ) + + assert not results_df.empty + assert len(results_df) == 8 # 3 income_tax + 3 snap + 2 count targets + + # Check all variables are present + variables = results_df["variable"].unique() + assert set(variables) == { + "income_tax", + "snap", + "household_count", + "person_count", + } + + # Check factor types are correctly assigned + count_variables = results_df[ + results_df["variable"].str.contains("_count") + ] + non_count_variables = results_df[ + ~results_df["variable"].str.contains("_count") + ] + + assert all(count_variables["factor_type"] == "population") + assert all(non_count_variables["factor_type"] == "inflation") + + # Verify different uprating factors were applied + pop_factor = 0.97 / 0.95 + inflation_factor = 0.95 / 0.90 + + # Check each count variable has the population factor + for _, row in count_variables.iterrows(): + assert row["uprating_factor"] == pytest.approx( + pop_factor, rel=1e-3 + ) + + # Check each non-count variable has the inflation factor + for _, row in non_count_variables.iterrows(): + assert row["uprating_factor"] == pytest.approx( + inflation_factor, rel=1e-3 + ) + + def test_uprate_calibration_targets_no_data(self): + """Test uprating when 
no targets exist for the specified period.""" + with patch( + "policyengine_data.calibration.target_uprating.get_uprating_factors" + ) as mock_get_factors: + mock_get_factors.return_value = pd.DataFrame( + { + "Year": [2030, 2031], + "Population_factor": [1.0, 1.1], + "Inflation_factor": [1.0, 1.05], + } + ) + + mock_system = Mock() + results_df = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2030, # Non-existent period + to_period=2031, + update_database=False, + ) + + assert results_df.empty + + def test_prepare_insert_data(self): + """Test preparation of data for database insertion.""" + uprated_df = pd.DataFrame( + [ + { + "stratum_id": 1, + "variable": "income_tax", + "uprated_period": 2024, + "reform_id": 0, + "uprated_value": 1200.0, + "tolerance": 0.01, + }, + { + "stratum_id": 2, + "variable": "income_tax", + "uprated_period": 2024, + "reform_id": 0, + "uprated_value": 600.0, + "tolerance": 0.05, + }, + ] + ) + + inserts = prepare_insert_data(uprated_df) + + assert len(inserts) == 2 + assert inserts[0]["stratum_id"] == 1 + assert inserts[0]["variable"] == "income_tax" + assert inserts[0]["period"] == 2024 + assert inserts[0]["value"] == 1200.0 + assert inserts[0]["active"] == True + + def test_insert_uprated_targets_in_db(self): + """Test inserting uprated targets into database.""" + engine = create_engine(self.db_uri) + + # Prepare test data + inserts = [ + { + "stratum_id": 1, + "variable": "new_variable", + "period": 2025, + "reform_id": 0, + "value": 5000.0, + "active": True, + "tolerance": 0.01, + } + ] + + # Insert data + inserted_count = insert_uprated_targets_in_db(engine, inserts) + assert inserted_count == 1 + + # Verify insertion + result_df = pd.read_sql( + "SELECT * FROM targets WHERE variable = 'new_variable' AND period = 2025", + engine, + ) + assert len(result_df) == 1 + assert result_df.iloc[0]["value"] == 5000.0 + + def test_insert_uprated_targets_update_existing(self): + """Test updating existing 
targets during insertion.""" + engine = create_engine(self.db_uri) + + # Update existing target (stratum_id=1, income_tax, 2021, reform_id=0) + inserts = [ + { + "stratum_id": 1, + "variable": "income_tax", + "period": 2021, + "reform_id": 0, + "value": 9999999.0, + "active": True, + "tolerance": 0.001, + } + ] + + # This should update the existing record + updated_count = insert_uprated_targets_in_db(engine, inserts) + assert updated_count == 1 + + # Verify update + result_df = pd.read_sql( + "SELECT * FROM targets WHERE stratum_id = 1 AND variable = 'income_tax' AND period = 2021 AND reform_id = 0", + engine, + ) + assert len(result_df) == 1 + assert result_df.iloc[0]["value"] == 9999999.0 + assert result_df.iloc[0]["tolerance"] == 0.001 + + +class TestIntegration: + """Integration tests for the complete uprating workflow.""" + + def setup_method(self): + """Set up test database for integration tests.""" + self.db_uri, self.db_path = setup_test_database() + + def teardown_method(self): + """Clean up test database after integration tests.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + + @patch( + "policyengine_data.calibration.target_uprating.get_uprating_factors" + ) + def test_complete_uprating_workflow_with_database_update( + self, mock_get_factors + ): + """Test complete workflow including database update.""" + mock_get_factors.return_value = pd.DataFrame( + { + "Year": [2021, 2025], + "Population_factor": [0.9, 1.1], + "Inflation_factor": [0.85, 1.15], + } + ) + mock_system = Mock() + + # Run complete uprating with database update + results_df = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2025, + variable="income_tax", + update_database=True, # This will insert new records + ) + + # Verify results + assert not results_df.empty + assert len(results_df) == 3 + + # Verify database was updated + engine = create_engine(self.db_uri) + new_targets = pd.read_sql( + "SELECT * FROM targets 
WHERE period = 2025 AND variable = 'income_tax'", + engine, + ) + assert len(new_targets) == 3 + + # Check that values were correctly uprated + expected_factor = 1.15 / 0.85 + original_total = 1000000.0 # Original total value for stratum_id=1 + expected_uprated = original_total * expected_factor + + uprated_total = new_targets[new_targets["stratum_id"] == 1][ + "value" + ].iloc[0] + assert uprated_total == pytest.approx(expected_uprated, rel=1e-4) + + @patch( + "policyengine_data.calibration.target_uprating.get_uprating_factors" + ) + def test_different_reform_scenarios_isolated(self, mock_get_factors): + """Test that different reform scenarios are processed independently.""" + mock_get_factors.return_value = pd.DataFrame( + { + "Year": [2021, 2024], + "Population_factor": [1.0, 1.05], + "Inflation_factor": [1.0, 1.10], + } + ) + mock_system = Mock() + + # Uprate baseline scenario (reform_id=0) + baseline_results = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2024, + reform_id=0, + update_database=False, + ) + + # Uprate reform scenario (reform_id=1) + reform_results = uprate_calibration_targets( + system=mock_system, + db_uri=self.db_uri, + from_period=2021, + to_period=2024, + reform_id=1, + update_database=False, + ) + + # Baseline should have more targets (8 vs 1) + assert len(baseline_results) > len(reform_results) + assert len(reform_results) == 1 # Only one target with reform_id=1 + + # But uprating factors should be the same + if not reform_results.empty: + baseline_factor = baseline_results["uprating_factor"].iloc[0] + reform_factor = reform_results["uprating_factor"].iloc[0] + assert baseline_factor == pytest.approx(reform_factor) + + +def test_error_handling(): + """Test error handling in various scenarios.""" + # Test with invalid database URI + with pytest.raises(Exception): + mock_system = Mock() + uprate_calibration_targets( + system=mock_system, + db_uri="invalid://database/path", + 
from_period=2021, + to_period=2024, + ) diff --git a/tests/test_conversion.py b/tests/test_conversion.py new file mode 100644 index 0000000..30a730c --- /dev/null +++ b/tests/test_conversion.py @@ -0,0 +1,110 @@ +""" +Test SingleYearDataset to legacy Dataset conversion functions. +""" + +import sys + +sys.path.insert(0, "src") + +from policyengine_data.single_year_dataset import SingleYearDataset +from policyengine_data.tools.legacy_class_conversions import ( + SingleYearDataset_to_Dataset, +) +import numpy as np +import h5py +from pathlib import Path + + +def test_conversion(): + """Test the conversion functions""" + from policyengine_us import Microsimulation + + start_year = 2023 + dataset = "hf://policyengine/policyengine-us-data/cps_2023.h5" + + # Load original CPS data + sim = Microsimulation(dataset=dataset) + single_year_dataset = SingleYearDataset.from_simulation( + sim, time_period=start_year + ) + single_year_dataset.time_period = start_year + + # Assert we have expected entities + assert ( + len(single_year_dataset.entities) == 6 + ), f"Expected 6 entities, got {len(single_year_dataset.entities)}" + expected_entities = { + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", + } + actual_entities = set(single_year_dataset.entities.keys()) + assert ( + actual_entities == expected_entities + ), f"Entity mismatch: {actual_entities} vs {expected_entities}" + + # Test conversion to legacy format + output_path = Path("test_legacy_dataset.h5") + SingleYearDataset_to_Dataset( + single_year_dataset, output_path, time_period=2024 + ) + + # Assert output file was created + assert output_path.exists(), f"Output file {output_path} was not created" + + # Verify h5 file structure + with h5py.File(output_path, "r") as f: + variables = list(f.keys()) + assert ( + len(variables) > 100 + ), f"Expected >100 variables, got {len(variables)}" + + # Check that important variables exist + important_vars = [ + "person_id", + "household_id", + "age", 
+ "employment_income_last_year", + ] + for var in important_vars: + assert ( + var in variables + ), f"Important variable {var} missing from saved file" + + # Loading back to SingleYearDataset + sim_loaded = Microsimulation(dataset=str(output_path)) + loaded_single_year_dataset = SingleYearDataset.from_simulation( + sim_loaded, time_period=2024 + ) + + # Assert loaded dataset has same entities + assert len(loaded_single_year_dataset.entities) == len( + single_year_dataset.entities + ), f"Loaded dataset has {len(loaded_single_year_dataset.entities)} entities, expected {len(single_year_dataset.entities)}" + + # Compare original and loaded data + for entity_name in single_year_dataset.entities.keys(): + assert ( + entity_name in loaded_single_year_dataset.entities + ), f"Entity {entity_name} missing in loaded dataset" + + original = single_year_dataset.entities[entity_name] + loaded = loaded_single_year_dataset.entities[entity_name] + + assert len(original) == len( + loaded + ), f"{entity_name}: Record count mismatch - original {len(original)}, loaded {len(loaded)}" + assert len(original.columns) == len( + loaded.columns + ), f"{entity_name}: Column count mismatch - original {len(original.columns)}, loaded {len(loaded.columns)}" + + common_cols = set(original.columns) & set(loaded.columns) + assert len(common_cols) == len( + original.columns + ), f"{entity_name}: Not all columns preserved. 
Missing: {set(original.columns) - common_cols}" + + # Clean up + output_path.unlink(missing_ok=True) diff --git a/tests/test_data_download_upload_tools.py b/tests/test_data_download_upload_tools.py index 9013fc5..cbdbd24 100644 --- a/tests/test_data_download_upload_tools.py +++ b/tests/test_data_download_upload_tools.py @@ -6,7 +6,7 @@ from tempfile import NamedTemporaryFile import sys import threading -from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager +from policyengine_data.tools.win_file_manager import WindowsAtomicFileManager import tempfile from pathlib import Path import uuid diff --git a/tests/test_normalise_keys.py b/tests/test_normalise_keys.py index 61badd9..23abb84 100644 --- a/tests/test_normalise_keys.py +++ b/tests/test_normalise_keys.py @@ -17,140 +17,161 @@ class TestNormaliseTableKeys: def test_simple_single_table(self): """Test normalisation of a single table with no foreign keys.""" - users = pd.DataFrame( - {"user_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} + persons = pd.DataFrame( + {"person_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} ) - tables = {"users": users} - primary_keys = {"users": "user_id"} + tables = {"persons": persons} + primary_keys = {"persons": "person_id"} result = normalise_table_keys(tables, primary_keys) assert len(result) == 1 - assert "users" in result + assert "persons" in result - normalised_users = result["users"] - assert list(normalised_users["user_id"]) == [0, 1, 2] - assert list(normalised_users["name"]) == ["Alice", "Bob", "Carol"] + normalised_persons = result["persons"] + assert list(normalised_persons["person_id"]) == [0, 1, 2] + assert list(normalised_persons["name"]) == ["Alice", "Bob", "Carol"] def test_custom_start_index(self): """Test normalisation with custom start index.""" - users = pd.DataFrame( - {"user_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} + persons = pd.DataFrame( + {"person_id": [101, 105, 103], "name": ["Alice", "Bob", 
"Carol"]} ) - tables = {"users": users} - primary_keys = {"users": "user_id"} + households = pd.DataFrame( + { + "household_id": [201, 205, 207], + "person_id": [105, 101, 105], + "income": [25000, 15000, 42000], + } + ) - result = normalise_table_keys(tables, primary_keys, start_index=10) + tables = {"persons": persons, "households": households} + primary_keys = {"persons": "person_id", "households": "household_id"} + foreign_keys = {"households": {"person_id": "persons"}} - assert len(result) == 1 - assert "users" in result + result = normalise_table_keys( + tables, + primary_keys, + foreign_keys, + start_index={"persons": 10, "households": 20}, + ) + + assert len(result) == 2 + assert "persons" in result + assert "households" in result - normalised_users = result["users"] - assert list(normalised_users["user_id"]) == [10, 11, 12] - assert list(normalised_users["name"]) == ["Alice", "Bob", "Carol"] + normalised_persons = result["persons"] + assert list(normalised_persons["person_id"]) == [10, 11, 12] + assert list(normalised_persons["name"]) == ["Alice", "Bob", "Carol"] + normalised_households = result["households"] + assert list(normalised_households["household_id"]) == [20, 21, 22] def test_two_tables_with_foreign_keys(self): """Test normalisation with explicit foreign key relationships.""" - users = pd.DataFrame( - {"user_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} + persons = pd.DataFrame( + {"person_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} ) - orders = pd.DataFrame( + households = pd.DataFrame( { - "order_id": [201, 205, 207], - "user_id": [105, 101, 105], - "amount": [25.99, 15.50, 42.00], + "household_id": [201, 205, 207], + "person_id": [105, 101, 105], + "income": [25000, 15000, 42000], } ) - tables = {"users": users, "orders": orders} - primary_keys = {"users": "user_id", "orders": "order_id"} - foreign_keys = {"orders": {"user_id": "users"}} + tables = {"persons": persons, "households": households} + primary_keys = 
{"persons": "person_id", "households": "household_id"} + foreign_keys = {"households": {"person_id": "persons"}} result = normalise_table_keys(tables, primary_keys, foreign_keys) - # Check users table - normalised_users = result["users"] - assert set(normalised_users["user_id"]) == {0, 1, 2} + # Check persons table + normalised_persons = result["persons"] + assert set(normalised_persons["person_id"]) == {0, 1, 2} - # Check orders table - normalised_orders = result["orders"] - assert set(normalised_orders["order_id"]) == {0, 1, 2} + # Check households table + normalised_households = result["households"] + assert set(normalised_households["household_id"]) == {0, 1, 2} # Check foreign key relationships are preserved - # Original: user 105 had orders 201, 207 + # Original: person 105 had households 201, 207 # After normalisation: find which index 105 became - user_105_new_id = normalised_users[normalised_users["name"] == "Bob"][ - "user_id" - ].iloc[0] - bob_orders = normalised_orders[ - normalised_orders["user_id"] == user_105_new_id + person_105_new_id = normalised_persons[ + normalised_persons["name"] == "Bob" + ]["person_id"].iloc[0] + bob_households = normalised_households[ + normalised_households["person_id"] == person_105_new_id ] - assert len(bob_orders) == 2 - assert set(bob_orders["amount"]) == {25.99, 42.00} + assert len(bob_households) == 2 + assert set(bob_households["income"]) == {25000, 42000} def test_auto_detect_foreign_keys(self): """Test automatic detection of foreign key relationships.""" - users = pd.DataFrame( - {"user_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} + persons = pd.DataFrame( + {"person_id": [101, 105, 103], "name": ["Alice", "Bob", "Carol"]} ) - orders = pd.DataFrame( + households = pd.DataFrame( { - "order_id": [201, 205, 207], - "user_id": [105, 101, 105], - "amount": [25.99, 15.50, 42.00], + "household_id": [201, 205, 207], + "person_id": [105, 101, 105], + "income": [25000, 15000, 42000], } ) - tables = {"users": 
users, "orders": orders} - primary_keys = {"users": "user_id", "orders": "order_id"} + tables = {"persons": persons, "households": households} + primary_keys = {"persons": "person_id", "households": "household_id"} # Test without explicit foreign keys - should auto-detect result = normalise_table_keys(tables, primary_keys) # Verify relationships are still preserved - normalised_users = result["users"] - normalised_orders = result["orders"] - - # Bob should still have his two orders - user_105_new_id = normalised_users[normalised_users["name"] == "Bob"][ - "user_id" - ].iloc[0] - bob_orders = normalised_orders[ - normalised_orders["user_id"] == user_105_new_id + normalised_persons = result["persons"] + normalised_households = result["households"] + + # Bob should still have his two households + person_105_new_id = normalised_persons[ + normalised_persons["name"] == "Bob" + ]["person_id"].iloc[0] + bob_households = normalised_households[ + normalised_households["person_id"] == person_105_new_id ] - assert len(bob_orders) == 2 + assert len(bob_households) == 2 def test_multiple_foreign_keys(self): """Test table with multiple foreign key relationships.""" - users = pd.DataFrame( - {"user_id": [1, 2, 3], "name": ["Alice", "Bob", "Carol"]} + persons = pd.DataFrame( + {"person_id": [1, 2, 3], "name": ["Alice", "Bob", "Carol"]} ) - categories = pd.DataFrame( + benefit_units = pd.DataFrame( { - "category_id": [10, 20, 30], - "category_name": ["Electronics", "Books", "Clothing"], + "benefit_unit_id": [10, 20, 30], + "benefit_type": ["Disability", "Unemployment", "Family"], } ) - orders = pd.DataFrame( + households = pd.DataFrame( { - "order_id": [100, 200, 300], - "user_id": [2, 1, 2], - "category_id": [20, 10, 30], - "amount": [25.99, 15.50, 42.00], + "household_id": [100, 200, 300], + "person_id": [2, 1, 2], + "benefit_unit_id": [20, 10, 30], + "income": [25000, 15000, 42000], } ) - tables = {"users": users, "categories": categories, "orders": orders} + tables = { + 
"persons": persons, + "benefit_units": benefit_units, + "households": households, + } primary_keys = { - "users": "user_id", - "categories": "category_id", - "orders": "order_id", + "persons": "person_id", + "benefit_units": "benefit_unit_id", + "households": "household_id", } result = normalise_table_keys(tables, primary_keys) @@ -161,17 +182,17 @@ def test_multiple_foreign_keys(self): assert set(df[pk_col]) == {0, 1, 2} # Verify relationships preserved - normalised_orders = result["orders"] - normalised_users = result["users"] + normalised_households = result["households"] + normalised_persons = result["persons"] - # Bob (original user_id=2) should have 2 orders - bob_new_id = normalised_users[normalised_users["name"] == "Bob"][ - "user_id" + # Bob (original person_id=2) should have 2 households + bob_new_id = normalised_persons[normalised_persons["name"] == "Bob"][ + "person_id" ].iloc[0] - bob_orders = normalised_orders[ - normalised_orders["user_id"] == bob_new_id + bob_households = normalised_households[ + normalised_households["person_id"] == bob_new_id ] - assert len(bob_orders) == 2 + assert len(bob_households) == 2 def test_empty_tables(self): """Test with empty input.""" @@ -181,8 +202,8 @@ def test_empty_tables(self): def test_missing_primary_key_column(self): """Test error handling for missing primary key column.""" df = pd.DataFrame({"name": ["Alice", "Bob"]}) - tables = {"users": df} - primary_keys = {"users": "missing_id"} + tables = {"persons": df} + primary_keys = {"persons": "missing_id"} with pytest.raises( ValueError, match="Primary key column 'missing_id' not found" @@ -191,36 +212,37 @@ def test_missing_primary_key_column(self): def test_missing_foreign_key_column(self): """Test error handling for missing foreign key column.""" - users = pd.DataFrame({"user_id": [1, 2], "name": ["Alice", "Bob"]}) - orders = pd.DataFrame( - {"order_id": [100, 200], "amount": [25.99, 15.50]} + persons = pd.DataFrame({"person_id": [1, 2], "name": ["Alice", 
"Bob"]}) + households = pd.DataFrame( + {"household_id": [100, 200], "income": [25000, 15000]} ) - tables = {"users": users, "orders": orders} - primary_keys = {"users": "user_id", "orders": "order_id"} - foreign_keys = {"orders": {"missing_user_id": "users"}} + tables = {"persons": persons, "households": households} + primary_keys = {"persons": "person_id", "households": "household_id"} + foreign_keys = {"households": {"missing_person_id": "persons"}} with pytest.raises( - ValueError, match="Foreign key column 'missing_user_id' not found" + ValueError, + match="Foreign key column 'missing_person_id' not found", ): normalise_table_keys(tables, primary_keys, foreign_keys) def test_missing_referenced_table(self): """Test error handling for missing referenced table.""" - orders = pd.DataFrame( + households = pd.DataFrame( { - "order_id": [100, 200], - "user_id": [1, 2], - "amount": [25.99, 15.50], + "household_id": [100, 200], + "person_id": [1, 2], + "income": [25000, 15000], } ) - tables = {"orders": orders} - primary_keys = {"orders": "order_id"} - foreign_keys = {"orders": {"user_id": "missing_users"}} + tables = {"households": households} + primary_keys = {"households": "household_id"} + foreign_keys = {"households": {"person_id": "missing_persons"}} with pytest.raises( - ValueError, match="Referenced table 'missing_users' not found" + ValueError, match="Referenced table 'missing_persons' not found" ): normalise_table_keys(tables, primary_keys, foreign_keys) @@ -281,26 +303,31 @@ class TestAutoDetectForeignKeys: def test_simple_detection(self): """Test basic foreign key detection.""" - users = pd.DataFrame({"user_id": [1, 2], "name": ["Alice", "Bob"]}) - orders = pd.DataFrame({"order_id": [100, 200], "user_id": [1, 2]}) + persons = pd.DataFrame({"person_id": [1, 2], "name": ["Alice", "Bob"]}) + households = pd.DataFrame( + {"household_id": [100, 200], "person_id": [1, 2]} + ) - tables = {"users": users, "orders": orders} - primary_keys = {"users": "user_id", 
"orders": "order_id"} + tables = {"persons": persons, "households": households} + primary_keys = {"persons": "person_id", "households": "household_id"} result = _auto_detect_foreign_keys(tables, primary_keys) - expected = {"orders": {"user_id": "users"}} + expected = {"households": {"person_id": "persons"}} assert result == expected def test_no_foreign_keys(self): """Test when no foreign keys are detected.""" - users = pd.DataFrame({"user_id": [1, 2], "name": ["Alice", "Bob"]}) - products = pd.DataFrame( - {"product_id": [100, 200], "name": ["Widget", "Gadget"]} + persons = pd.DataFrame({"person_id": [1, 2], "name": ["Alice", "Bob"]}) + benefit_units = pd.DataFrame( + {"benefit_unit_id": [100, 200], "name": ["Disability", "Family"]} ) - tables = {"users": users, "products": products} - primary_keys = {"users": "user_id", "products": "product_id"} + tables = {"persons": persons, "benefit_units": benefit_units} + primary_keys = { + "persons": "person_id", + "benefit_units": "benefit_unit_id", + } result = _auto_detect_foreign_keys(tables, primary_keys) @@ -308,28 +335,35 @@ def test_no_foreign_keys(self): def test_multiple_foreign_keys_detection(self): """Test detection of multiple foreign keys in one table.""" - users = pd.DataFrame({"user_id": [1, 2], "name": ["Alice", "Bob"]}) - categories = pd.DataFrame( - {"category_id": [10, 20], "name": ["Electronics", "Books"]} + persons = pd.DataFrame({"person_id": [1, 2], "name": ["Alice", "Bob"]}) + benefit_units = pd.DataFrame( + {"benefit_unit_id": [10, 20], "name": ["Disability", "Family"]} ) - orders = pd.DataFrame( + households = pd.DataFrame( { - "order_id": [100, 200], - "user_id": [1, 2], - "category_id": [10, 20], + "household_id": [100, 200], + "person_id": [1, 2], + "benefit_unit_id": [10, 20], } ) - tables = {"users": users, "categories": categories, "orders": orders} + tables = { + "persons": persons, + "benefit_units": benefit_units, + "households": households, + } primary_keys = { - "users": "user_id", - 
"categories": "category_id", - "orders": "order_id", + "persons": "person_id", + "benefit_units": "benefit_unit_id", + "households": "household_id", } result = _auto_detect_foreign_keys(tables, primary_keys) expected = { - "orders": {"user_id": "users", "category_id": "categories"} + "households": { + "person_id": "persons", + "benefit_unit_id": "benefit_units", + } } assert result == expected