Skip to content

Commit 96a2c99

Browse files
committed
update database link to enable calibration in ci
1 parent f9c5cfa commit 96a2c99

File tree

9 files changed

+210
-61
lines changed

9 files changed

+210
-61
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies = [
1919
"microcalibrate",
2020
"sqlalchemy",
2121
"huggingface_hub",
22+
"torch",
2223
]
2324

2425
[project.optional-dependencies]

src/policyengine_data/calibration/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
validate_metrics_matrix,
99
)
1010
from .target_rescaling import download_database, rescale_calibration_targets
11+
from .utils import create_geographic_normalization_factor

src/policyengine_data/calibration/calibrate.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import logging
6-
from typing import Dict, Optional
6+
from typing import Dict, List, Optional
77

88
import numpy as np
99
import pandas as pd
@@ -21,6 +21,9 @@
2121
download_database,
2222
rescale_calibration_targets,
2323
)
24+
from policyengine_data.calibration.utils import (
25+
create_geographic_normalization_factor,
26+
)
2427
from policyengine_data.tools.legacy_class_conversions import (
2528
SingleYearDataset_to_Dataset,
2629
)
@@ -99,7 +102,7 @@ def calibrate_single_geography_level(
99102
use_dataset_weights: Optional[bool] = True,
100103
regularize_with_l0: Optional[bool] = False,
101104
raise_error: Optional[bool] = True,
102-
):
105+
) -> "SingleYearDataset":
103106
"""
104107
This function will calibrate the dataset for a specific geography level, defaulting to stacking the base dataset per area within it.
105108
It will handle conversion between dataset classes to enable:
@@ -291,17 +294,19 @@ def calibrate_single_geography_level(
291294
return geography_level_calibrated_dataset
292295

293296

297+
# TODO: create normalization factor to pass into Calibrator balancing targets at different levels
294298
def calibrate_all_levels(
295299
database_stacking_areas: Dict[str, str],
296300
dataset: str,
297301
dataset_subsample_size: Optional[int] = None,
298302
geo_sim_filter_variable: Optional[str] = "ucgid",
303+
geo_hierarchy: Optional[List[str]] = None,
299304
year: Optional[int] = 2023,
300305
db_uri: Optional[str] = None,
301306
noise_level: Optional[float] = 10.0,
302307
regularize_with_l0: Optional[bool] = False,
303308
raise_error: Optional[bool] = True,
304-
):
309+
) -> "SingleYearDataset":
305310
"""
306311
This function will calibrate the dataset for all geography levels in the database, defaulting to stacking the base dataset per area within the specified level (it is recommended to use the lowest in the hierarchy for stacking). (Eg. when calibrating for district, state and national levels in the US, this function will stack the CPS dataset for each district and calibrate the stacked dataset for the three levels' targets.)
307312
It will handle conversion between dataset classes to enable:
@@ -318,6 +323,7 @@ def calibrate_all_levels(
318323
dataset (str): Path to the base dataset to stack.
319324
dataset_subsample_size (Optional[int]): The size of the subsample to use for calibration.
320325
geo_sim_filter_variable (Optional[str]): The variable to use for geographic similarity filtering. Default in the US: "ucgid".
326+
geo_hierarchy (Optional[List[str]]): The geographic hierarchy to use for calibration.
321327
year (Optional[int]): The year to use for calibration. Default: 2023.
322328
db_uri (Optional[str]): The database URI to use for calibration. If None, it will download the database from the default URI.
323329
noise_level (Optional[float]): The noise level to use for calibration. Default: 10.0.
@@ -438,6 +444,10 @@ def calibrate_all_levels(
438444
raise_error=raise_error,
439445
)
440446

447+
normalization_factor = create_geographic_normalization_factor(
448+
geo_hierarchy=geo_hierarchy, target_info=target_info
449+
)
450+
441451
target_names = []
442452
excluded_targets = []
443453
for target_id, info in target_info.items():
@@ -462,6 +472,7 @@ def calibrate_all_levels(
462472
excluded_targets=(
463473
excluded_targets if len(excluded_targets) > 0 else None
464474
),
475+
normalization_factor=normalization_factor,
465476
sparse_learning_rate=0.1,
466477
regularize_with_l0=regularize_with_l0,
467478
csv_path=f"full_calibration.csv",
@@ -494,7 +505,6 @@ def calibrate_all_levels(
494505
state_level_calibrated_dataset = calibrate_single_geography_level(
495506
areas_in_state_level,
496507
"hf://policyengine/policyengine-us-data/cps_2023.h5",
497-
db_uri="sqlite:///policy_data.db",
498508
use_dataset_weights=False,
499509
regularize_with_l0=True,
500510
)
@@ -504,7 +514,8 @@ def calibrate_all_levels(
504514
].values
505515

506516
SingleYearDataset_to_Dataset(
507-
state_level_calibrated_dataset, output_path="Dataset_state_level.h5"
517+
state_level_calibrated_dataset,
518+
output_path="Dataset_state_level_age_medicaid_snap_eitc_agi_targets.h5",
508519
)
509520

510521
print("Completed calibration for state level dataset.")
@@ -516,9 +527,8 @@ def calibrate_all_levels(
516527

517528
national_level_calibrated_dataset = calibrate_single_geography_level(
518529
areas_in_national_level,
519-
dataset="Dataset_state_level.h5",
530+
dataset="Dataset_state_level_age_medicaid_snap_eitc_agi_targets.h5",
520531
stack_datasets=False,
521-
db_uri="sqlite:///policy_data.db",
522532
noise_level=0.0,
523533
use_dataset_weights=True,
524534
regularize_with_l0=False,
@@ -530,7 +540,7 @@ def calibrate_all_levels(
530540

531541
SingleYearDataset_to_Dataset(
532542
national_level_calibrated_dataset,
533-
output_path="Dataset_national_level.h5",
543+
output_path="Dataset_national_level_age_medicaid_snap_eitc_agi_targets.h5",
534544
)
535545

536546
print("Completed calibration for national level dataset.")

src/policyengine_data/calibration/metrics_matrix_creation.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,9 @@
66
from policyengine_us import Microsimulation
77
from sqlalchemy import create_engine
88

9-
logger = logging.getLogger(__name__)
10-
11-
12-
def download_database(
13-
filename: Optional[str] = "policy_data.db",
14-
repo_id: Optional[str] = "policyengine/test",
15-
) -> create_engine:
16-
"""
17-
Download the SQLite database from Hugging Face Hub and return the connection string.
18-
19-
Args:
20-
filename: Optional name of the database file to download
21-
repo_id: Optional Hugging Face repository ID where the database is stored
22-
23-
Returns:
24-
Connection string for the SQLite database
25-
"""
26-
import os
9+
from .target_rescaling import download_database
2710

28-
from huggingface_hub import hf_hub_download
29-
30-
# Download the file to the current working directory
31-
try:
32-
downloaded_path = hf_hub_download(
33-
repo_id=repo_id,
34-
filename=filename,
35-
local_dir=".", # Use "." for the current working directory
36-
local_dir_use_symlinks=False, # Recommended to avoid symlinks
37-
)
38-
path = os.path.abspath(downloaded_path)
39-
logger.info(f"File downloaded successfully to: {path}")
40-
return f"sqlite:///{path}"
41-
42-
except Exception as e:
43-
raise ValueError(f"An error occurred: {e}")
11+
logger = logging.getLogger(__name__)
4412

4513

4614
# NOTE (juaristi22): This could fail if trying to filter by more than one stratum constraint if there are mismatches between the filtering variable, value and operation.

src/policyengine_data/calibration/target_rescaling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def download_database(
1616
filename: Optional[str] = "policy_data.db",
17-
repo_id: Optional[str] = "policyengine/test",
17+
repo_id: Optional[str] = "policyengine/policyengine-us-data",
1818
) -> create_engine:
1919
"""
2020
Download the SQLite database from Hugging Face Hub and return the connection string.
@@ -37,6 +37,7 @@ def download_database(
3737
filename=filename,
3838
local_dir="download/",
3939
local_dir_use_symlinks=False, # Recommended to avoid symlinks
40+
force_download=True, # Always download, ignore cache
4041
)
4142
path = os.path.abspath(downloaded_path)
4243
logger.info(f"File downloaded successfully to: {path}")
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Additional utilities for the calibration process.
3+
"""
4+
5+
from typing import Dict, List
6+
7+
import numpy as np
8+
import torch
9+
10+
11+
def create_geographic_normalization_factor(
12+
geo_hierarchy: List[str],
13+
target_info: Dict[int, Dict[str, any]],
14+
) -> torch.Tensor:
15+
"""
16+
Create a normalization factor for the calibration process to balance targets that belong to different geographic areas or concepts.
17+
18+
Args:
19+
geo_hierarchy (List[str]): Geographic hierarchy levels' codes (e.g., ["0100000US", "0400000US", "0500000US"]). Make sure to pass the part of the code general to all areas within a given level.
20+
target_info (Dict[int, Dict[str, any]]): A dictionary containing information about each target, including its name which denotes geographic area and its active status.
21+
22+
Returns:
23+
normalization_factor (torch.Tensor): Normalization factor for each active target.
24+
"""
25+
is_active = []
26+
geo_codes = []
27+
geo_level_sum = {}
28+
29+
for code in geo_hierarchy:
30+
geo_level_sum[code] = 0
31+
32+
# First pass: collect active status and geo codes for all targets
33+
for target_id, info in target_info.items():
34+
is_active.append(info["active"])
35+
target_name = info["name"]
36+
matched_geo = None
37+
38+
for code in geo_hierarchy:
39+
if code in target_name:
40+
matched_geo = code
41+
if info["active"]:
42+
geo_level_sum[code] += 1
43+
break
44+
45+
geo_codes.append(matched_geo)
46+
47+
is_active = torch.tensor(is_active, dtype=torch.float32)
48+
normalization_factor = torch.zeros_like(is_active)
49+
50+
# Assign normalization factors based on geo level for each target
51+
for i, (is_target_active, geo_code) in enumerate(
52+
zip(is_active, geo_codes)
53+
):
54+
if (
55+
is_target_active
56+
and geo_code is not None
57+
and geo_level_sum[geo_code] > 0
58+
):
59+
normalization_factor[i] = 1.0 / geo_level_sum[geo_code]
60+
61+
# Check if only one geographic level is represented among active targets
62+
active_geo_levels = set()
63+
for i, is_target_active in enumerate(is_active):
64+
if is_target_active and geo_codes[i] is not None:
65+
active_geo_levels.add(geo_codes[i])
66+
67+
# If no matching geo codes for active targets, return zeros for active targets
68+
if len(active_geo_levels) == 0:
69+
active_factors = torch.zeros(sum(is_active.bool()))
70+
return active_factors
71+
72+
# If only one geographic level is present, return tensor of ones for active targets
73+
if len(active_geo_levels) <= 1:
74+
normalization_factor = torch.where(
75+
is_active.bool(), torch.tensor(1.0), torch.tensor(0.0)
76+
)
77+
else:
78+
# Apply mean normalization for multiple geographic levels
79+
active_factors = normalization_factor[is_active.bool()]
80+
if len(active_factors) > 0 and active_factors.sum() > 0:
81+
inv_mean_norm = 1.0 / active_factors.mean()
82+
normalization_factor = normalization_factor * inv_mean_norm
83+
84+
return normalization_factor[is_active.bool()]

tests/test_calibration/test_calibration.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,6 @@
6363
}
6464

6565

66-
@pytest.mark.skip(
67-
reason="Online database is not yet updated with necessary format."
68-
)
6966
def test_calibration_per_geographic_level_iteration():
7067
"""
7168
Test and example of the calibration routine involving calibrating one geographic level at a time from lowest to highest in the hierarchy and generating sparsity in all but the last levels.
@@ -126,9 +123,6 @@ def test_calibration_per_geographic_level_iteration():
126123
).sum() > 0, "Household weights do not differ between state and national levels, suggesting national calibration was unsucessful."
127124

128125

129-
@pytest.mark.skip(
130-
reason="Online database is not yet updated with necessary format."
131-
)
132126
def test_calibration_combining_all_levels_at_once():
133127
"""
134128
Test and example of the calibration routine involving stacking datasets at a single (most often lowest) geographic level for increased data richness and then calibrating said stacked dataset for all geographic levels at once.
@@ -147,6 +141,7 @@ def test_calibration_combining_all_levels_at_once():
147141
areas_in_state_level,
148142
"hf://policyengine/policyengine-us-data/cps_2023.h5",
149143
db_uri="sqlite:///policy_data.db", # remove once online database is updated
144+
geo_hierarchy=["0100000US", "0400000US"],
150145
dataset_subsample_size=2000,
151146
regularize_with_l0=True,
152147
raise_error=False, # this will avoid raising an error if some targets have no records contributing to them (given sampling)

tests/test_calibration/test_matrix_creation.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
import pytest
88

99

10-
@pytest.mark.skip(
11-
reason="Online database is not yet updated with necessary format."
12-
)
1310
def test_matrix_creation() -> None:
1411
from policyengine_data.calibration import (
1512
create_metrics_matrix,
@@ -28,19 +25,11 @@ def test_matrix_creation() -> None:
2825
reform_id=0,
2926
)
3027

31-
# Validate the matrix
28+
# Validate the matrix (it will raise an error if matrix creation failed)
3229
validation_results = validate_metrics_matrix(
3330
metrics_matrix, target_values, target_info=target_info
3431
)
3532

36-
assert metrics_matrix.columns.tolist() == [
37-
i for i in range(1, 937)
38-
], "Metrics matrix columns do not match expected target ids"
39-
assert all(
40-
validation_results[validation_results["target_id"] < 19]["estimate"]
41-
!= 0
42-
), "Metrics matrix should have all estimates non-zero for federal age targets"
43-
4433

4534
def test_parse_constraint_value():
4635
"""Test parsing constraint values from strings."""

0 commit comments

Comments
 (0)