Evaluation of estimates being within tolerance levels from target values #74
Merged
Commits (7)
- 16341ef: evaluation of estimates within tolerance levels (juaristi22)
- 4257a5a: fix dependency conflicts (juaristi22)
- 4a6051b: move publishing to after versioning (juaristi22)
- d741e7a: switch condition (juaristi22)
- 83284a3: Merge branch 'main' of https://github.com/PolicyEngine/microcalibrate… (juaristi22)
- baeb94e: add test for when all estimates are within tolerance (juaristi22)
- d0bf4e8: minor changes (juaristi22)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| - bump: minor | ||
|   changes: | ||
|     added: | ||
|     - A function to evaluate whether estimates are within desired tolerance levels. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| from .calibration import Calibration | ||
| from .evaluation import evaluate_estimate_distance_to_targets |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| import logging | ||
| from typing import List, Optional | ||
| | ||
| import numpy as np | ||
| import pandas as pd | ||
| | ||
| logger = logging.getLogger(__name__) | ||
| | ||
| | ||
| def evaluate_estimate_distance_to_targets( | ||
|     targets: np.ndarray, | ||
|     estimates: np.ndarray, | ||
|     tolerances: np.ndarray, | ||
|     target_names: Optional[List[str]] = None, | ||
|     raise_on_error: Optional[bool] = False, | ||
| ): | ||
|     """ | ||
|     Evaluate the distance between estimates and targets against tolerances. | ||
| | ||
|     Args: | ||
|         targets (np.ndarray): The ground truth target values. | ||
|         estimates (np.ndarray): The estimated values to compare against the targets. | ||
|         tolerances (np.ndarray): The acceptable tolerance levels for each target. | ||
|         target_names (Optional[List[str]]): The names of the targets for reporting. | ||
|         raise_on_error (Optional[bool]): If True, raises an error if any estimate is outside its tolerance. Default is False. | ||
| | ||
|     Returns: | ||
|         evals (pd.DataFrame): A DataFrame containing the evaluation results, including: | ||
|             - target_names: Names of the targets (if provided). | ||
|             - distances: The absolute differences between estimates and targets. | ||
|             - tolerances: The tolerance levels for each target. | ||
|             - within_tolerance: Boolean array indicating if each estimate is within its tolerance. | ||
|     """ | ||
|     if targets.shape != estimates.shape or targets.shape != tolerances.shape: | ||
|         raise ValueError( | ||
|             "Targets, estimates, and tolerances must have the same shape." | ||
|         ) | ||
| | ||
|     distances = np.abs(estimates - targets) | ||
|     within_tolerance = distances <= tolerances | ||
| | ||
|     evals = { | ||
|         "target_names": ( | ||
|             target_names | ||
|             if target_names is not None | ||
|             else list(np.nan for _ in targets) | ||
|         ), | ||
|         "distances": distances, | ||
|         "tolerances": tolerances, | ||
|         "within_tolerance": within_tolerance, | ||
|     } | ||
| | ||
|     num_outside_tolerance = (~within_tolerance).sum() | ||
|     if raise_on_error and num_outside_tolerance > 0: | ||
|         raise ValueError( | ||
|             f"{num_outside_tolerance} target(s) are outside their tolerance levels." | ||
|         ) | ||
| | ||
|     return pd.DataFrame(evals) |
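
As a rough illustration of the check this function performs, the sketch below re-implements only the core comparison in plain Python, without NumPy or pandas. The helper name `check_within_tolerance` is hypothetical and is not part of `microcalibrate`; the input values are the ones used in `test_all_within_tolerance`.

```python
from typing import List, Sequence


def check_within_tolerance(
    targets: Sequence[float],
    estimates: Sequence[float],
    tolerances: Sequence[float],
) -> List[bool]:
    """Hypothetical plain-Python stand-in for the element-wise tolerance check."""
    if not (len(targets) == len(estimates) == len(tolerances)):
        raise ValueError(
            "Targets, estimates, and tolerances must have the same length."
        )
    # Absolute distance per target, compared against that target's own tolerance.
    return [
        abs(estimate - target) <= tolerance
        for target, estimate, tolerance in zip(targets, estimates, tolerances)
    ]


flags = check_within_tolerance(
    targets=[10, 20, 30],
    estimates=[10.1, 19.8, 30.0],
    tolerances=[0.2, 0.3, 0.1],
)
# Count of targets that miss their tolerance, analogous to num_outside_tolerance.
num_outside = flags.count(False)
```

The comparison is inclusive (`<=`), matching the diff, so an estimate whose distance equals its tolerance still counts as within tolerance.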
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| """ | ||
| Test the evaluation functionality for the calibration process. | ||
| """ | ||
| | ||
| import pytest | ||
| from src.microcalibrate.calibration import Calibration | ||
| from microcalibrate.evaluation import ( | ||
|     evaluate_estimate_distance_to_targets, | ||
| ) | ||
| import numpy as np | ||
| import pandas as pd | ||
| | ||
| | ||
| def test_evaluate_estimate_distance_to_targets() -> None: | ||
|     """Test the evaluation of estimates against targets with tolerances, for a case in which estimates are not within tolerance.""" | ||
| | ||
|     # Create a mock dataset with age and income | ||
|     random_generator = np.random.default_rng(0) | ||
|     data = pd.DataFrame( | ||
|         { | ||
|             "age": random_generator.integers(18, 70, size=100), | ||
|             "income": random_generator.normal(40000, 50000, size=100), | ||
|         } | ||
|     ) | ||
|     weights = np.ones(len(data)) | ||
|     targets_matrix = pd.DataFrame( | ||
|         { | ||
|             "income_aged_20_30": ( | ||
|                 (data["age"] >= 20) & (data["age"] <= 30) | ||
|             ).astype(float) | ||
|             * data["income"], | ||
|             "income_aged_40_50": ( | ||
|                 (data["age"] >= 40) & (data["age"] <= 50) | ||
|             ).astype(float) | ||
|             * data["income"], | ||
|         } | ||
|     ) | ||
|     targets = np.array( | ||
|         [ | ||
|             (targets_matrix["income_aged_20_30"] * weights).sum() * 50, | ||
|             (targets_matrix["income_aged_40_50"] * weights).sum() * 50, | ||
|         ] | ||
|     ) | ||
| | ||
|     calibrator = Calibration( | ||
|         estimate_matrix=targets_matrix, | ||
|         weights=weights, | ||
|         targets=targets, | ||
|         noise_level=0.05, | ||
|         epochs=50, | ||
|         learning_rate=0.01, | ||
|         dropout_rate=0, | ||
|     ) | ||
| | ||
|     performance_df = calibrator.calibrate() | ||
|     final_estimates = calibrator.estimate() | ||
|     tolerances = np.array([0.001, 0.005]) | ||
| | ||
|     # Evaluate the estimates against the targets without raising an error | ||
|     evals_df = evaluate_estimate_distance_to_targets( | ||
|         targets=targets, | ||
|         estimates=final_estimates, | ||
|         tolerances=tolerances, | ||
|         target_names=["Income Aged 20-30", "Income Aged 40-50"], | ||
|         raise_on_error=False, | ||
|     ) | ||
| | ||
|     # Check that the evaluation DataFrame has the expected structure | ||
|     assert set(evals_df.columns) == { | ||
|         "target_names", | ||
|         "distances", | ||
|         "tolerances", | ||
|         "within_tolerance", | ||
|     } | ||
| | ||
|     # Evaluate the estimates against the targets raising an error | ||
|     with pytest.raises(ValueError) as exc_info: | ||
|         evals_df = evaluate_estimate_distance_to_targets( | ||
|             targets=targets, | ||
|             estimates=final_estimates, | ||
|             tolerances=tolerances, | ||
|             target_names=["Income Aged 20-30", "Income Aged 40-50"], | ||
|             raise_on_error=True, | ||
|         ) | ||
| | ||
|     assert "target(s) are outside their tolerance levels" in str( | ||
|         exc_info.value | ||
|     ) | ||
| | ||
| | ||
| def test_all_within_tolerance(): | ||
|     """Tests a simple case where all estimates are within their tolerances.""" | ||
|     targets = np.array([10, 20, 30]) | ||
|     estimates = np.array([10.1, 19.8, 30.0]) | ||
|     tolerances = np.array([0.2, 0.3, 0.1]) | ||
|     target_names = ["A", "B", "C"] | ||
| | ||
|     result_df = evaluate_estimate_distance_to_targets( | ||
|         targets, estimates, tolerances, target_names | ||
|     ) | ||
| | ||
|     assert result_df["within_tolerance"].all() | ||
|     assert result_df.shape == (3, 4) | ||
|     np.testing.assert_array_almost_equal( | ||
|         result_df["distances"], [0.1, 0.2, 0.0] | ||
|     ) |
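
The first test passes tolerances of 0.001 and 0.005, and the function compares them directly with absolute differences, so they are expressed in the same units as the targets. If a relative band were wanted instead, one possible approach (an assumption, not something this PR implements) would be to convert a relative fraction into per-target absolute tolerances before calling the function. The helper name `relative_to_absolute_tolerances` and the target values below are hypothetical.

```python
from typing import List, Sequence


def relative_to_absolute_tolerances(
    targets: Sequence[float], relative_tolerance: float
) -> List[float]:
    """Hypothetical helper: scale one relative fraction by each target's magnitude."""
    if relative_tolerance < 0:
        raise ValueError("relative_tolerance must be non-negative.")
    return [abs(target) * relative_tolerance for target in targets]


# Made-up targets: a 1% band gives each target its own absolute tolerance.
absolute_tolerances = relative_to_absolute_tolerances(
    [2_000_000.0, 500_000.0], 0.01
)
```

The resulting list could then be wrapped in `np.array(...)` and passed as `tolerances`, since the function expects an array with the same shape as `targets`.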