From 332898e7e285394c9a68517a192bc6d99fa9923b Mon Sep 17 00:00:00 2001
From: a-kore <37000693+a-kore@users.noreply.github.com>
Date: Wed, 24 Jul 2024 15:13:56 -0400
Subject: [PATCH] add sample_size visualization and update notebooks (#667)
* add sample_size visualization and update notebooks
* fix mypy error
* rename the "Group Size" key in fairness results to "sample_size"
---
benchmarks/mimiciv/discharge_prediction.ipynb | 4 +-
.../mimiciv/icu_mortality_prediction.ipynb | 4 +-
cyclops/evaluate/evaluator.py | 1 +
cyclops/evaluate/fairness/evaluator.py | 4 +-
cyclops/report/model_card/fields.py | 10 ++
cyclops/report/report.py | 92 ++++++----
cyclops/report/templates/model_report/plot.js | 160 +++++++++++++++---
cyclops/report/utils.py | 36 ++++
.../diabetes_130/readmission_prediction.ipynb | 65 +++----
.../kaggle/heart_failure_prediction.ipynb | 65 +++----
.../mimiciv/mortality_prediction.ipynb | 39 ++---
.../tutorials/nihcxr/cxr_classification.ipynb | 71 +++++---
.../nihcxr/generate_nihcxr_report.py | 69 +++++---
.../tutorials/synthea/los_prediction.ipynb | 51 ++----
tests/cyclops/report/test_report.py | 114 +++++++++++++
tests/cyclops/report/test_utils.py | 8 +
16 files changed, 540 insertions(+), 253 deletions(-)
diff --git a/benchmarks/mimiciv/discharge_prediction.ipynb b/benchmarks/mimiciv/discharge_prediction.ipynb
index 0de32ccd7..ffc69ae9e 100644
--- a/benchmarks/mimiciv/discharge_prediction.ipynb
+++ b/benchmarks/mimiciv/discharge_prediction.ipynb
@@ -1182,9 +1182,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
- "# remove the group size from the fairness results and add it to the slice name\n",
+ "# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
- " group_size = slice_results.pop(\"Group Size\")\n",
+ " group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
diff --git a/benchmarks/mimiciv/icu_mortality_prediction.ipynb b/benchmarks/mimiciv/icu_mortality_prediction.ipynb
index 5009131d6..53d4d0f75 100644
--- a/benchmarks/mimiciv/icu_mortality_prediction.ipynb
+++ b/benchmarks/mimiciv/icu_mortality_prediction.ipynb
@@ -1159,9 +1159,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
- "# remove the group size from the fairness results and add it to the slice name\n",
+ "# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
- " group_size = slice_results.pop(\"Group Size\")\n",
+ " group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py
index bfeb54bb7..cf4d01611 100644
--- a/cyclops/evaluate/evaluator.py
+++ b/cyclops/evaluate/evaluator.py
@@ -311,6 +311,7 @@ def _compute_metrics(
model_name: str = "model_for_%s" % prediction_column
results.setdefault(model_name, {})
results[model_name][slice_name] = metric_output
+ results[model_name][slice_name]["sample_size"] = len(sliced_dataset)
set_decode(dataset, True) # restore decoding features
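With this one-line addition, every slice entry returned by the evaluator carries the number of rows in that slice alongside its metric values. A minimal sketch of the resulting structure (model and metric names are illustrative):

    results = {
        "model_for_preds.xgb": {
            "overall": {"BinaryAccuracy": 0.85, "sample_size": 1000},
            "sex:F": {"BinaryAccuracy": 0.82, "sample_size": 480},
        },
    }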
diff --git a/cyclops/evaluate/fairness/evaluator.py b/cyclops/evaluate/fairness/evaluator.py
index 9b25773bf..22b39d175 100644
--- a/cyclops/evaluate/fairness/evaluator.py
+++ b/cyclops/evaluate/fairness/evaluator.py
@@ -260,7 +260,7 @@ def evaluate_fairness( # noqa: PLR0912
for prediction_column in fmt_prediction_columns:
results.setdefault(prediction_column, {})
results[prediction_column].setdefault(slice_name, {}).update(
- {"Group Size": len(sliced_dataset)},
+ {"sample_size": len(sliced_dataset)},
)
pred_result = _get_metric_results_for_prediction_and_slice(
@@ -966,7 +966,7 @@ def _compute_parity_metrics(
parity_results[key] = {}
for slice_name, slice_result in prediction_result.items():
for metric_name, metric_value in slice_result.items():
- if metric_name == "Group Size":
+ if metric_name == "sample_size":
continue
# add 'Parity' to the metric name before @threshold, if specified
diff --git a/cyclops/report/model_card/fields.py b/cyclops/report/model_card/fields.py
index 12dd29497..f4f8162b0 100644
--- a/cyclops/report/model_card/fields.py
+++ b/cyclops/report/model_card/fields.py
@@ -380,6 +380,11 @@ class PerformanceMetric(
default_factory=list,
)
+ sample_size: Optional[StrictInt] = Field(
+ None,
+ description="The sample size used to compute this metric.",
+ )
+
class User(
BaseModelCardField,
@@ -599,6 +604,11 @@ class MetricCard(
description="Timestamps for each point in the history.",
)
+ sample_sizes: Optional[List[int]] = Field(
+ None,
+ description="Sample sizes for each point in the history.",
+ )
+
class MetricCardCollection(BaseModelCardField, composable_with="Overview"):
"""A collection of metric cards to be displayed in the model card."""
diff --git a/cyclops/report/report.py b/cyclops/report/report.py
index 903a7b682..786dae53c 100644
--- a/cyclops/report/report.py
+++ b/cyclops/report/report.py
@@ -48,6 +48,7 @@
get_histories,
get_names,
get_passed,
+ get_sample_sizes,
get_slices,
get_thresholds,
get_timestamps,
@@ -855,6 +856,7 @@ def log_quantitative_analysis(
pass_fail_threshold_fns: Optional[
Union[Callable[[Any, float], bool], List[Callable[[Any, float], bool]]]
] = None,
+ sample_size: Optional[int] = None,
**extra: Any,
) -> None:
"""Add a quantitative analysis to the report.
@@ -921,6 +923,7 @@ def log_quantitative_analysis(
"slice": metric_slice,
"decision_threshold": decision_threshold,
"description": description,
+ "sample_size": sample_size,
**extra,
}
@@ -958,42 +961,70 @@ def log_quantitative_analysis(
field_type=field_type,
)
- def log_performance_metrics(self, metrics: Dict[str, Any]) -> None:
- """Add a performance metric to the `Quantitative Analysis` section.
+ def log_performance_metrics(
+ self,
+ results: Dict[str, Any],
+ metric_descriptions: Dict[str, str],
+ pass_fail_thresholds: Union[float, Dict[str, float]] = 0.7,
+ pass_fail_threshold_fn: Callable[[float, float], bool] = lambda x,
+ threshold: bool(x >= threshold),
+ ) -> None:
+ """
+ Log all performance metrics to the model card report.
Parameters
----------
- metrics : Dict[str, Any]
- A dictionary of performance metrics. The keys should be the name of the
- metric, and the values should be the value of the metric. If the metric
- is a slice metric, the key should be the slice name followed by a slash
- and then the metric name (e.g. "slice_name/metric_name"). If no slice
- name is provided, the slice name will be "overall".
-
- Raises
- ------
- TypeError
- If the given metrics are not a dictionary with string keys.
+ results : Dict[str, Any]
+ Dictionary containing the results,
+ with keys in the format "split/metric_name".
+ metric_descriptions : Dict[str, str]
+ Dictionary mapping metric names to their descriptions.
+ pass_fail_thresholds : Union[float, Dict[str, float]], optional
+ The threshold(s) for pass/fail tests.
+ Can be a single float applied to all metrics,
+ or a dictionary mapping "split/metric_name" to individual thresholds.
+ Default is 0.7.
+ pass_fail_threshold_fn : Callable[[float, float], bool], optional
+ Function to determine if a metric passes or fails.
+ Default is lambda x, threshold: bool(x >= threshold).
+ Returns
+ -------
+ None
"""
- _raise_if_not_dict_with_str_keys(metrics)
- for metric_name, metric_value in metrics.items():
- name_split = metric_name.split("/")
- if len(name_split) == 1:
- slice_name = "overall"
- metric_name = name_split[0] # noqa: PLW2901
- else: # everything before the last slash is the slice name
- slice_name = "/".join(name_split[:-1])
- metric_name = name_split[-1] # noqa: PLW2901
-
- # TODO: create plot
+ # Extract sample sizes
+ sample_sizes = {
+ key.split("/")[0]: value
+ for key, value in results.items()
+ if "sample_size" in key.split("/")[1]
+ }
- self._log_field(
- data={"type": metric_name, "value": metric_value, "slice": slice_name},
- section_name="quantitative_analysis",
- field_name="performance_metrics",
- field_type=PerformanceMetric,
- )
+ # Log metrics
+ for name, metric in results.items():
+ split, metric_name = name.split("/")
+ if metric_name != "sample_size":
+ metric_value = metric.tolist() if hasattr(metric, "tolist") else metric
+
+ # Determine the threshold for this specific metric
+ if isinstance(pass_fail_thresholds, dict):
+ threshold = pass_fail_thresholds.get(
+ name, 0.7
+ ) # Default to 0.7 if not specified
+ else:
+ threshold = pass_fail_thresholds
+
+ self.log_quantitative_analysis(
+ "performance",
+ name=metric_name,
+ value=metric_value,
+ description=metric_descriptions.get(
+ metric_name, "No description provided."
+ ),
+ metric_slice=split,
+ pass_fail_thresholds=threshold,
+ pass_fail_threshold_fns=pass_fail_threshold_fn,
+ sample_size=sample_sizes.get(split),
+ )
# TODO: MERGE/COMPARE MODEL CARDS
@@ -1162,6 +1193,7 @@ def export(
"get_names": get_names,
"get_histories": get_histories,
"get_timestamps": get_timestamps,
+ "get_sample_sizes": get_sample_sizes,
}
template.globals.update(func_dict)
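The tutorial notebooks below replace their per-metric logging loops with a single call to the new helper. A condensed example of the API, assuming a flattened results dict of the kind produced by flatten_results_dict (threshold values illustrative):

    from cyclops.report import ModelCardReport

    report = ModelCardReport()
    results_flat = {
        "overall/BinaryAccuracy": 0.85,
        "overall/sample_size": 1000,
    }
    report.log_performance_metrics(
        results=results_flat,
        metric_descriptions={
            "BinaryAccuracy": "The proportion of all instances that are correctly predicted.",
        },
        pass_fail_thresholds={"overall/BinaryAccuracy": 0.8},  # or a single float for every metric
    )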
diff --git a/cyclops/report/templates/model_report/plot.js b/cyclops/report/templates/model_report/plot.js
index 8f94071aa..2cf936b27 100644
--- a/cyclops/report/templates/model_report/plot.js
+++ b/cyclops/report/templates/model_report/plot.js
@@ -1,5 +1,6 @@
// Javascript code for plots in the model card template
-
+const MAX_SIZE = 20;
+let maxSampleSize = 0;
// Define a function to update the plot based on selected filters
function updatePlot() {
@@ -140,6 +141,7 @@ function updatePlot() {
var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }});
var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }});
var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }});
+ var sample_sizes_all = JSON.parse({{ get_sample_sizes(model_card)|safe|tojson }});
for (let i = 0; i < selection.length; i++) {
// use selection to set label_slice_selection background color
@@ -223,6 +225,21 @@ function updatePlot() {
}
}
+ // Find the maximum sample size across all selections
+ for (let i = 0; i < selections.length; i++) {
+ if (selections[i] === null) {
+ continue;
+ }
+ selection = selections[i]
+ // get idx of slices where all elements match
+ var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection));
+ var sample_size_data = [];
+ for (let i = 0; i < sample_sizes_all[idx].length; i++) {
+ sample_size_data.push(sample_sizes_all[idx][i]);
+ }
+        maxSampleSize = Math.max(maxSampleSize, ...sample_size_data);
+ }
+
traces = [];
for (let i = 0; i < selections.length; i++) {
if (selections[i] === null) {
@@ -237,11 +254,17 @@ function updatePlot() {
}
var timestamp_data = [];
for (let i = 0; i < timestamps_all[idx].length; i++) {
- timestamp_data.push(timestamps_all[idx][i]);
+ // timestamp_data.push(timestamps_all[idx][i]);
+ timestamp_data.push(formatDate(timestamps_all[idx][i]));
+ }
+ var sample_size_data = [];
+ for (let i = 0; i < sample_sizes_all[idx].length; i++) {
+ sample_size_data.push(sample_sizes_all[idx][i]);
}
var last_n_evals = document.getElementById("n_evals_slider_pot").value;
history_data = history_data.slice(-last_n_evals);
timestamp_data = timestamp_data.slice(-last_n_evals);
+ sample_size_data = sample_size_data.slice(-last_n_evals);
// get slope of line of best fit, if >0.01 then trending up, if <0.01 then trending down, else flat
var slope = lineOfBestFit(history_data)[0];
if (slope > 0.01) {
@@ -323,17 +346,41 @@ function updatePlot() {
};
traces.push(threshold_trace);
}
- var trace = {
- // range of x is the length of the list of floats
- x: timestamp_data,
- y: history_data,
- mode: 'lines+markers',
- type: 'scatter',
- marker: {color: plot_colors[i+1]},
- line: {color: plot_colors[i+1]},
- name: name,
- legendgroup: name + i,
- //name: selection.toString(),
+ // Add sample size circles
+ var sample_size_trace = {
+ x: timestamp_data,
+ y: history_data,
+ mode: 'markers',
+ marker: {
+ sizemode: 'area',
+ size: sample_size_data,
+ sizeref: maxSampleSize / MAX_SIZE ** 2,
+ color: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`,
+ line: {width: 0},
+ },
+ text: sample_size_data.map((s, index) =>
+ `Date: ${timestamp_data[index]}
+Value: ${history_data[index].toFixed(2)}
+Sample Size: ${s}`
+ ),
+ hoverinfo: 'text',
+ hovertemplate: '%{text}',
+ name: name + ' (Sample Size)',
+ legendgroup: name + i,
+ };
+
+ // Add main data points and line
+ var main_trace = {
+ x: timestamp_data,
+ y: history_data,
+ mode: 'lines+markers',
+ type: 'scatter',
+ marker: {
+ color: plot_colors[i+1],
+ symbol: 'circle',
+ },
+ line: {color: plot_colors[i+1]},
+ name: name,
+ legendgroup: name + i,
+ hoverinfo: 'skip'
};
// check if length of history_data is >= mean_std_min_evals and if so get rolling mean and std if mean_plot_selection or std_plot_selection is checked
@@ -384,7 +431,8 @@ function updatePlot() {
};
traces.push(trace_mean);
}
- traces.push(trace);
+ traces.push(sample_size_trace);
+ traces.push(main_trace);
}
@@ -744,6 +792,7 @@ function updatePlotSelection() {
var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }});
var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }});
var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }});
+ var sample_sizes_all = JSON.parse({{ get_sample_sizes(model_card)|safe|tojson }});
var radioGroups = {};
var labelGroups = {};
@@ -786,6 +835,21 @@ function updatePlotSelection() {
}
}
+ // Find the maximum sample size across all selections
+ for (let i = 0; i < selections.length; i++) {
+ if (selections[i] === null) {
+ continue;
+ }
+ selection = selections[i]
+ // get idx of slices where all elements match
+ var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection));
+ var sample_size_data = [];
+ for (let i = 0; i < sample_sizes_all[idx].length; i++) {
+ sample_size_data.push(sample_sizes_all[idx][i]);
+ }
+        maxSampleSize = Math.max(maxSampleSize, ...sample_size_data);
+    }
+
traces = [];
var plot_number = parseInt(plot_selected.value.split(" ")[1]-1);
for (let i = 0; i < selections.length; i++) {
@@ -802,11 +866,17 @@ function updatePlotSelection() {
}
var timestamp_data = [];
for (let i = 0; i < timestamps_all[idx].length; i++) {
- timestamp_data.push(timestamps_all[idx][i]);
+ // timestamp_data.push(timestamps_all[idx][i]);
+ timestamp_data.push(formatDate(timestamps_all[idx][i]));
+ }
+ var sample_size_data = [];
+ for (let i = 0; i < sample_sizes_all[idx].length; i++) {
+ sample_size_data.push(sample_sizes_all[idx][i]);
}
var last_n_evals = document.getElementById("n_evals_slider_pot").value;
history_data = history_data.slice(-last_n_evals);
timestamp_data = timestamp_data.slice(-last_n_evals);
+ sample_size_data = sample_size_data.slice(-last_n_evals);
// get slope of line of best fit, if >0.01 then trending up, if <0.01 then trending down, else flat
var slope = lineOfBestFit(history_data)[0];
@@ -891,16 +961,41 @@ function updatePlotSelection() {
traces.push(threshold_trace);
}
- var trace = {
- // range of x is the length of the list of floats
- x: timestamp_data,
- y: history_data,
- mode: 'lines+markers',
- type: 'scatter',
- marker: {color: plot_colors[i+1]},
- line: {color: plot_colors[i+1]},
- name: name,
- legendgroup: name + i,
+ // Add sample size circles
+ var sample_size_trace = {
+ x: timestamp_data,
+ y: history_data,
+ mode: 'markers',
+ marker: {
+ sizemode: 'area',
+ size: sample_size_data,
+ sizeref: maxSampleSize / MAX_SIZE ** 2,
+ color: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`,
+ line: {width: 0},
+ },
+ text: sample_size_data.map((s, index) =>
+ `Date: ${timestamp_data[index]}
+Value: ${history_data[index].toFixed(2)}
+Sample Size: ${s}`
+ ),
+ hoverinfo: 'text',
+ hovertemplate: '%{text}',
+ name: name + ' (Sample Size)',
+ legendgroup: name + i,
+ };
+
+ // Add main data points and line
+ var main_trace = {
+ x: timestamp_data,
+ y: history_data,
+ mode: 'lines+markers',
+ type: 'scatter',
+ marker: {
+ color: plot_colors[i+1],
+ symbol: 'circle',
+ },
+ line: {color: plot_colors[i+1]},
+ name: name,
+ legendgroup: name + i,
+ hoverinfo: 'skip'
};
// check if length of history_data is >= mean_std_min_evals and if so get rolling mean and std if mean_plot_selection or std_plot_selection is checked
@@ -914,7 +1009,6 @@ function updatePlotSelection() {
var trace_std_upper = {
x: timestamp_data.slice(-history_std_data.length),
y: history_mean_data.map((x, i) => x + history_std_data[i]),
- // fill: 'tonexty',
fillcolor: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`,
mode: 'lines',
line: {width: 0, color: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`},
@@ -951,7 +1045,8 @@ function updatePlotSelection() {
traces.push(trace_mean);
}
- traces.push(trace);
+ traces.push(main_trace);
+ traces.push(sample_size_trace);
}
@@ -1145,3 +1240,14 @@ function rollingStd(data, window) {
}
return std;
}
+
+ function formatDate(dateString) {
+ const date = new Date(dateString);
+ const year = date.getFullYear();
+ const month = (date.getMonth() + 1).toString().padStart(2, '0');
+ const day = date.getDate().toString().padStart(2, '0');
+ const hours = date.getHours().toString().padStart(2, '0');
+ const minutes = date.getMinutes().toString().padStart(2, '0');
+
+ return `${year}-${month}-${day} ${hours}:${minutes}`;
+ }
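The sample-size trace relies on Plotly's area-based marker sizing, so bubble area rather than radius scales with the number of evaluated rows, normalized against the largest sample via sizeref. A rough plotly.py sketch of the same scaling, for reference only (data values illustrative):

    import plotly.graph_objects as go

    MAX_SIZE = 20  # mirrors the constant added at the top of plot.js
    timestamps = ["2023-10-16 00:00", "2023-10-19 00:00", "2023-10-22 00:00"]
    values = [0.82, 0.85, 0.88]
    sample_sizes = [480, 950, 1000]

    fig = go.Figure(
        go.Scatter(
            x=timestamps,
            y=values,
            mode="markers",
            marker=dict(
                sizemode="area",
                size=sample_sizes,
                sizeref=max(sample_sizes) / MAX_SIZE**2,  # same normalization as plot.js
                opacity=0.3,
            ),
            text=[f"Sample Size: {s}" for s in sample_sizes],
            hoverinfo="text",
            name="metric (Sample Size)",
        )
    )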
diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py
index 99e9c4500..94c8a556b 100644
--- a/cyclops/report/utils.py
+++ b/cyclops/report/utils.py
@@ -585,6 +585,25 @@ def get_timestamps(model_card: ModelCard) -> str:
return json.dumps(timestamps)
+def get_sample_sizes(model_card: ModelCard) -> str:
+ """Get all sample sizes from a model card."""
+ sample_sizes = {}
+ if (
+ (model_card.overview is None)
+ or (model_card.overview.metric_cards is None)
+ or (model_card.overview.metric_cards.collection is None)
+ ):
+ pass
+ else:
+ for itr, metric_card in enumerate(model_card.overview.metric_cards.collection):
+ sample_sizes[itr] = (
+ [str(sample_size) for sample_size in metric_card.sample_sizes]
+ if metric_card.sample_sizes is not None
+ else None
+ )
+ return json.dumps(sample_sizes)
+
+
def _extract_slices_and_values(
current_metrics: List[PerformanceMetric],
) -> Tuple[List[str], List[List[str]]]:
@@ -727,6 +746,7 @@ def _create_metric_card(
name: str,
history: List[float],
timestamps: List[str],
+ sample_sizes: List[int],
threshold: Union[float, None],
passed: Union[bool, None],
) -> MetricCard:
@@ -744,6 +764,8 @@ def _create_metric_card(
The timestamps for the metric card.
threshold : Union[float, None]
The threshold for the metric card.
+ sample_sizes : List[int]
+ The sample sizes for the metric card.
passed : Union[bool, None]
Whether or not the metric card passed.
@@ -772,6 +794,7 @@ def _create_metric_card(
passed=passed,
history=history,
timestamps=timestamps,
+ sample_sizes=sample_sizes,
)
@@ -811,7 +834,11 @@ def _get_metric_card(
timestamps = metric["last_metric_card"].timestamps
if timestamps is not None:
timestamps.append(timestamp)
+ sample_sizes = metric["last_metric_card"].sample_sizes
+ if sample_sizes is not None:
+ sample_sizes.append(0) # Append 0 for missing data
metric["last_metric_card"].timestamps = timestamps
+ metric["last_metric_card"].sample_sizes = sample_sizes
metric_card = metric["last_metric_card"]
elif (
metric["current_metric"] is not None
@@ -829,6 +856,9 @@ def _get_metric_card(
timestamps = metric["last_metric_card"].timestamps
if timestamps is not None:
timestamps.append(timestamp)
+ sample_sizes = metric["last_metric_card"].sample_sizes
+ if sample_sizes is not None:
+ sample_sizes.append(metric["current_metric"].sample_size)
else:
history = [
metric["current_metric"].value
@@ -840,12 +870,18 @@ def _get_metric_card(
else 0,
]
timestamps = [timestamp]
+ sample_sizes = (
+ [metric["current_metric"].sample_size]
+ if isinstance(metric["current_metric"], PerformanceMetric)
+ else [0]
+ )
if metric_card is None:
metric_card = _create_metric_card(
metric,
name,
history,
timestamps,
+ sample_sizes,
_get_threshold(metric),
_get_passed(metric),
)
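get_sample_sizes follows the same pattern as get_histories and get_timestamps: one entry per metric card, keyed by its position in the collection and serialized to JSON for the template, with sample sizes stored as strings and a 0 appended by _get_metric_card for evaluations where the metric was missing. A sketch of the returned string (values illustrative):

    # hypothetical output of get_sample_sizes(model_card) for two metric cards,
    # the second of which has no recorded sample sizes
    sample_sizes_json = '{"0": ["950", "1000", "0"], "1": null}'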
diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb
index 4e50a97d1..bb7bdb3dd 100644
--- a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb
+++ b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb
@@ -957,56 +957,36 @@
"metadata": {},
"outputs": [],
"source": [
+ "descriptions = {\n",
+ " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
+ " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
+ " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
+ " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
+ "}\n",
"model_name = f\"model_for_preds.{model_name}\"\n",
+ "\n",
"results_flat = flatten_results_dict(\n",
" results=results,\n",
" remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n",
" model_name=model_name,\n",
")\n",
+ "\n",
"results_female_flat = flatten_results_dict(\n",
" results=results_female,\n",
" model_name=model_name,\n",
")\n",
- "# ruff: noqa: W505\n",
- "for name, metric in results_female_flat.items():\n",
- " split, name = name.split(\"/\") # noqa: PLW2901\n",
- " descriptions = {\n",
- " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
- " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
- " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
- " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
- " }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist(),\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )\n",
"\n",
- "for name, metric in results_flat.items():\n",
- " split, name = name.split(\"/\") # noqa: PLW2901\n",
- " descriptions = {\n",
- " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
- " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
- " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
- " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
- " }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist(),\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )"
+ "report.log_performance_metrics(\n",
+ " results=results_flat,\n",
+ " metric_descriptions=descriptions,\n",
+ ")\n",
+ "\n",
+ "report.log_performance_metrics(\n",
+ " results=results_female_flat,\n",
+ " metric_descriptions=descriptions,\n",
+ ")"
]
},
{
@@ -1066,7 +1046,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# extracting the precision-recall curves and average precision results for all the slices\n",
+ "# extracting the precision-recall curves and\n",
+ "# average precision results for all the slices\n",
"pr_curves = {\n",
" slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
@@ -1228,9 +1209,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
- "# remove the group size from the fairness results and add it to the slice name\n",
+ "# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
- " group_size = slice_results.pop(\"Group Size\")\n",
+ " group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb
index 1f987623f..744ca43c7 100644
--- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb
+++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb
@@ -903,56 +903,36 @@
"metadata": {},
"outputs": [],
"source": [
+ "descriptions = {\n",
+ " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
+ " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
+ " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
+ " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
+ "}\n",
"model_name = f\"model_for_preds.{model_name}\"\n",
+ "\n",
"results_flat = flatten_results_dict(\n",
" results=results,\n",
" remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n",
" model_name=model_name,\n",
")\n",
+ "\n",
"results_female_flat = flatten_results_dict(\n",
" results=results_female,\n",
" model_name=model_name,\n",
")\n",
- "# ruff: noqa: W505\n",
- "for name, metric in results_female_flat.items():\n",
- " split, name = name.split(\"/\") # noqa: PLW2901\n",
- " descriptions = {\n",
- " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
- " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
- " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
- " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
- " }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist(),\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )\n",
"\n",
- "for name, metric in results_flat.items():\n",
- " split, name = name.split(\"/\") # noqa: PLW2901\n",
- " descriptions = {\n",
- " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
- " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
- " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
- " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
- " }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist(),\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )"
+ "report.log_performance_metrics(\n",
+ " results=results_flat,\n",
+ " metric_descriptions=descriptions,\n",
+ ")\n",
+ "\n",
+ "report.log_performance_metrics(\n",
+ " results=results_female_flat,\n",
+ " metric_descriptions=descriptions,\n",
+ ")"
]
},
{
@@ -1012,7 +992,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# extracting the precision-recall curves and average precision results for all the slices\n",
+ "# extracting the precision-recall curves and\n",
+ "# average precision results for all the slices\n",
"pr_curves = {\n",
" slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
@@ -1174,9 +1155,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
- "# remove the group size from the fairness results and add it to the slice name\n",
+ "# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
- " group_size = slice_results.pop(\"Group Size\")\n",
+ " group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb
index 5ad66b0c6..78a6c7218 100644
--- a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb
+++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb
@@ -941,26 +941,18 @@
"metadata": {},
"outputs": [],
"source": [
- "# ruff: noqa: W505\n",
- "for name, metric in results_flat.items():\n",
- " split, name = name.split(\"/\") # noqa: PLW2901\n",
- " descriptions = {\n",
- " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
- " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
- " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
- " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
- " }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist(),\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )"
+ "descriptions = {\n",
+ " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
+ " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
+ " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
+ " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
+ "}\n",
+ "report.log_performance_metrics(\n",
+ " results=results_flat,\n",
+ " metric_descriptions=descriptions,\n",
+ ")"
]
},
{
@@ -1004,7 +996,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# extracting the precision-recall curves and average precision results for all the slices\n",
+ "# extracting the precision-recall curves and\n",
+ "# average precision results for all the slices\n",
"pr_curves = {\n",
" slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
@@ -1126,9 +1119,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
- "# remove the group size from the fairness results and add it to the slice name\n",
+ "# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
- " group_size = slice_results.pop(\"Group Size\")\n",
+ " group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
diff --git a/docs/source/tutorials/nihcxr/cxr_classification.ipynb b/docs/source/tutorials/nihcxr/cxr_classification.ipynb
index e2cbed78e..053c6db18 100644
--- a/docs/source/tutorials/nihcxr/cxr_classification.ipynb
+++ b/docs/source/tutorials/nihcxr/cxr_classification.ipynb
@@ -68,8 +68,8 @@
"\"\"\"Generate historical reports with validation data\n",
"for comparison with periodic report on test data.\"\"\"\n",
"\n",
- "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-19\" --seed 43\n",
- "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-16\" --seed 44\n",
+ "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-16\" --seed 43\n",
+ "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-19\" --seed 44\n",
"!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-22\" --seed 45\n",
"!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-30\" --seed 46"
]
@@ -127,7 +127,7 @@
"metadata": {},
"outputs": [],
"source": [
- "data_dir = \"/mnt/data/clinical_datasets/NIHCXR\"\n",
+ "data_dir = \"/mnt/data2/clinical_datasets/NIHCXR\"\n",
"nih_ds = load_nihcxr(data_dir)[\"test\"]\n",
"nih_ds = nih_ds.select(range(1000))\n",
"\n",
@@ -411,23 +411,33 @@
"results_flat = {}\n",
"for slice_, metrics in nih_eval_results_age[\"model_for_predictions.densenet\"].items():\n",
" for name, metric in metrics.items():\n",
- " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n",
- " for itr, m in enumerate(metric):\n",
- " if slice_ == \"overall\":\n",
- " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n",
- " else:\n",
- " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n",
+ " if \"sample_size\" in name:\n",
+ " results_flat[f\"{slice_}/{name}\"] = metric\n",
+ " else:\n",
+ " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n",
+ " for itr, m in enumerate(metric):\n",
+ " if slice_ == \"overall\":\n",
+ " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n",
+ " else:\n",
+ " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n",
"for slice_, metrics in nih_eval_results_gender[\n",
" \"model_for_predictions.densenet\"\n",
"].items():\n",
" for name, metric in metrics.items():\n",
- " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n",
- " for itr, m in enumerate(metric):\n",
- " if slice_ == \"overall\":\n",
- " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n",
- " else:\n",
- " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n",
- "\n",
+ " if \"sample_size\" in name:\n",
+ " results_flat[f\"{slice_}/{name}\"] = metric\n",
+ " else:\n",
+ " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n",
+ " for itr, m in enumerate(metric):\n",
+ " if slice_ == \"overall\":\n",
+ " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n",
+ " else:\n",
+ " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n",
+ "sample_sizes = {\n",
+ " key.split(\"/\")[0]: value\n",
+ " for key, value in results_flat.items()\n",
+ " if \"sample_size\" in key.split(\"/\")[1]\n",
+ "}\n",
"for name, metric in results_flat.items():\n",
" split, name = name.split(\"/\") # noqa: PLW2901\n",
" descriptions = {\n",
@@ -436,15 +446,26 @@
" \"MultilabelSensitivity\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
" \"MultilabelSpecificity\": \"The proportion of actual negative instances that are correctly predicted.\",\n",
" }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist() if isinstance(metric, np.generic) else metric,\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )"
+    "    # strip the \"pathology:{...}\" part from the split name so the\n",
+    "    # sample size of the parent slice can be looked up below\n",
+ " if \"pathology\" in split:\n",
+ " if len(split.split(\"&\")) == 1:\n",
+ " sample_size_split = \"overall\"\n",
+ " else:\n",
+ " sample_size_split = \"&\".join(split.split(\"&\")[:-1])\n",
+ " else:\n",
+ " sample_size_split = split\n",
+ " if name != \"sample_size\":\n",
+ " report.log_quantitative_analysis(\n",
+ " \"performance\",\n",
+ " name=name,\n",
+ " value=metric.tolist() if isinstance(metric, np.generic) else metric,\n",
+ " description=descriptions.get(name),\n",
+ " metric_slice=split,\n",
+ " pass_fail_thresholds=0.7,\n",
+ " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
+ " sample_size=sample_sizes[sample_size_split],\n",
+ " )"
]
},
{
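Per-pathology entries reuse the sample size of the slice they were expanded from, so the split name is trimmed back to that slice before the lookup. A short sketch of the mapping, with a hypothetical slice name:

    split = "Patient Gender:M&pathology:Atelectasis"  # hypothetical flattened split
    if "pathology" in split:
        parts = split.split("&")
        sample_size_split = "overall" if len(parts) == 1 else "&".join(parts[:-1])
    else:
        sample_size_split = split
    assert sample_size_split == "Patient Gender:M"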
diff --git a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py
index 584bf8be8..7d0f13bc1 100644
--- a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py
+++ b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py
@@ -34,7 +34,7 @@
report = ModelCardReport()
-data_dir = "/mnt/data/clinical_datasets/NIHCXR"
+data_dir = "/mnt/data2/clinical_datasets/NIHCXR"
nih_ds = load_nihcxr(data_dir)[args.split]
# select a subset of the data
@@ -224,23 +224,33 @@
results_flat = {}
for slice_, metrics in nih_eval_results_age["model_for_predictions.densenet"].items():
for name, metric in metrics.items():
- results_flat[f"{slice_}/{name}"] = metric.mean()
- for itr, m in enumerate(metric):
- if slice_ == "overall":
- results_flat[f"pathology:{pathologies[itr]}/{name}"] = m
- else:
- results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m
+ if "sample_size" in name:
+ results_flat[f"{slice_}/{name}"] = metric
+ else:
+ results_flat[f"{slice_}/{name}"] = metric.mean()
+ for itr, m in enumerate(metric):
+ if slice_ == "overall":
+ results_flat[f"pathology:{pathologies[itr]}/{name}"] = m
+ else:
+ results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m
for slice_, metrics in nih_eval_results_gender[
"model_for_predictions.densenet"
].items():
for name, metric in metrics.items():
- results_flat[f"{slice_}/{name}"] = metric.mean()
- for itr, m in enumerate(metric):
- if slice_ == "overall":
- results_flat[f"pathology:{pathologies[itr]}/{name}"] = m
- else:
- results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m
-
+ if "sample_size" in name:
+ results_flat[f"{slice_}/{name}"] = metric
+ else:
+ results_flat[f"{slice_}/{name}"] = metric.mean()
+ for itr, m in enumerate(metric):
+ if slice_ == "overall":
+ results_flat[f"pathology:{pathologies[itr]}/{name}"] = m
+ else:
+ results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m
+sample_sizes = {
+ key.split("/")[0]: value
+ for key, value in results_flat.items()
+ if "sample_size" in key.split("/")[1]
+}
for name, metric in results_flat.items():
split, name = name.split("/") # noqa: PLW2901
descriptions = {
@@ -249,15 +259,26 @@
"MultilabelSensitivity": "The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.",
"MultilabelSpecificity": "The proportion of actual negative instances that are correctly predicted.",
}
- report.log_quantitative_analysis(
- "performance",
- name=name,
- value=metric.tolist() if isinstance(metric, np.generic) else metric,
- description=descriptions[name],
- metric_slice=split,
- pass_fail_thresholds=0.7,
- pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),
- )
+    # strip the "pathology:{...}" part from the split name so the
+    # sample size of the parent slice can be looked up below
+ if "pathology" in split:
+ if len(split.split("&")) == 1:
+ sample_size_split = "overall"
+ else:
+ sample_size_split = "&".join(split.split("&")[:-1])
+ else:
+ sample_size_split = split
+ if name != "sample_size":
+ report.log_quantitative_analysis(
+ "performance",
+ name=name,
+ value=metric.tolist() if isinstance(metric, np.generic) else metric,
+ description=descriptions.get(name),
+ metric_slice=split,
+ pass_fail_thresholds=0.7,
+ pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),
+ sample_size=sample_sizes[sample_size_split],
+ )
# model details for NIH Chest X-Ray model
@@ -362,6 +383,6 @@
)
report_path = report.export(
- output_filename=f"nihcxr_report_periodic_{args.synthetic_timestamp}.html",
+ output_filename="nihcxr_report_periodic.html",
synthetic_timestamp=args.synthetic_timestamp,
)
diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb
index 53d4e4304..a8c89da24 100644
--- a/docs/source/tutorials/synthea/los_prediction.ipynb
+++ b/docs/source/tutorials/synthea/los_prediction.ipynb
@@ -1085,51 +1085,34 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d322a86f-1f7c-42f6-8a97-8a18ea8622e2",
+ "id": "d33a171c-02ef-4bc9-a3bf-87320c7c83d6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
+ "descriptions = {\n",
+ " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
+ " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
+ " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
+ " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
+ " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
+ "}\n",
"model_name = f\"model_for_preds.{model_name}\"\n",
+ "\n",
"results_flat = flatten_results_dict(\n",
" results=results,\n",
" remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n",
" model_name=model_name,\n",
+ ")\n",
+ "\n",
+ "report.log_performance_metrics(\n",
+ " results=results_flat,\n",
+ " metric_descriptions=descriptions,\n",
")"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d33a171c-02ef-4bc9-a3bf-87320c7c83d6",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "for name, metric in results_flat.items():\n",
- " split, name = name.split(\"/\") # noqa: PLW2901\n",
- " if name == \"BinaryConfusionMatrix\":\n",
- " continue\n",
- " descriptions = {\n",
- " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
- " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
- " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
- " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
- " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
- " }\n",
- " report.log_quantitative_analysis(\n",
- " \"performance\",\n",
- " name=name,\n",
- " value=metric.tolist(),\n",
- " description=descriptions[name],\n",
- " metric_slice=split,\n",
- " pass_fail_thresholds=0.7,\n",
- " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
- " )"
- ]
- },
{
"cell_type": "markdown",
"id": "0ee464d3-7246-4c6f-ac2d-8cca2a985f22",
@@ -1293,9 +1276,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
- "# remove the group size from the fairness results and add it to the slice name\n",
+ "# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
- " group_size = slice_results.pop(\"Group Size\")\n",
+ " group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
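The fairness cell keeps its shape across all tutorials; only the popped key changes from "Group Size" to "sample_size", so the slice labels still embed the group size. A before/after sketch for one slice (names and numbers illustrative):

    fairness_results = {"age:[20 - 50)": {"sample_size": 120, "BinaryAccuracy": 0.81}}
    fairness_metrics = {}
    for slice_name, slice_results in fairness_results.items():
        group_size = slice_results.pop("sample_size")
        fairness_metrics[f"{slice_name} (Size={group_size})"] = slice_results
    # fairness_metrics == {"age:[20 - 50) (Size=120)": {"BinaryAccuracy": 0.81}}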
diff --git a/tests/cyclops/report/test_report.py b/tests/cyclops/report/test_report.py
index 33dd71358..2c5e72965 100644
--- a/tests/cyclops/report/test_report.py
+++ b/tests/cyclops/report/test_report.py
@@ -2,6 +2,8 @@
from unittest import TestCase
+import numpy as np
+
from cyclops.report import ModelCardReport
from cyclops.report.model_card.sections import ModelDetails
@@ -349,6 +351,7 @@ def test_export(self):
decision_threshold=0.7,
pass_fail_thresholds=[0.6, 0.65, 0.7],
pass_fail_threshold_fns=[lambda x, t: x >= t for _ in range(3)],
+ sample_size=100,
)
self.model_card_report.log_quantitative_analysis(
analysis_type="performance",
@@ -359,8 +362,119 @@ def test_export(self):
description="F1 score of the model on the test set",
pass_fail_thresholds=[0.9, 0.85, 0.8],
pass_fail_threshold_fns=[lambda x, t: x >= t for _ in range(3)],
+ sample_size=100,
)
self.model_card_report.log_owner(name="John Doe")
report_path = self.model_card_report.export(interactive=False, save_json=False)
assert isinstance(report_path, str)
+
+
+def test_log_performance_metrics():
+ """Test log_performance_metrics."""
+ report = ModelCardReport()
+
+ # Mock results
+ results = {
+ "overall/BinaryAccuracy": np.array(0.85),
+ "overall/BinaryPrecision": np.array(0.78),
+ "overall/BinaryRecall": np.array(0.92),
+ "overall/BinaryF1Score": np.array(0.84),
+ "overall/BinaryAUROC": np.array(0.91),
+ "overall/BinaryAveragePrecision": np.array(0.88),
+ "overall/sample_size": 1000,
+ "slice1/BinaryAccuracy": np.array(0.82),
+ "slice1/BinaryPrecision": np.array(0.75),
+ "slice1/BinaryRecall": np.array(0.89),
+ "slice1/BinaryF1Score": np.array(0.81),
+ "slice1/BinaryAUROC": np.array(0.88),
+ "slice1/BinaryAveragePrecision": np.array(0.85),
+ "slice1/sample_size": 500,
+ }
+
+ # Mock metric descriptions
+ metric_descriptions = {
+ "BinaryAccuracy": "The proportion of all instances that are correctly predicted.",
+ "BinaryPrecision": "The proportion of predicted positive instances that are correctly predicted.",
+ "BinaryRecall": "The proportion of actual positive instances that are correctly predicted.",
+ "BinaryF1Score": "The harmonic mean of precision and recall.",
+ "BinaryAUROC": "The area under the ROC curve.",
+ "BinaryAveragePrecision": "The area under the precision-recall curve.",
+ }
+
+ # Test with a single threshold
+ report.log_performance_metrics(
+ results, metric_descriptions, pass_fail_thresholds=0.8
+ )
+
+ # Check if metrics were logged correctly
+ assert report._model_card.quantitative_analysis is not None
+ assert (
+ len(report._model_card.quantitative_analysis.performance_metrics) == 12
+ ) # 6 metrics * 2 slices
+
+ # Check a few specific metrics
+ metrics = report._model_card.quantitative_analysis.performance_metrics
+
+ overall_accuracy = next(
+ m for m in metrics if m.type == "BinaryAccuracy" and m.slice == "overall"
+ )
+ assert overall_accuracy.value == 0.85
+ assert (
+ overall_accuracy.description
+ == "The proportion of all instances that are correctly predicted."
+ )
+ assert overall_accuracy.sample_size == 1000
+ assert overall_accuracy.tests[0].threshold == 0.8
+ assert overall_accuracy.tests[0].passed
+
+ slice1_precision = next(
+ m for m in metrics if m.type == "BinaryPrecision" and m.slice == "slice1"
+ )
+ assert slice1_precision.value == 0.75
+ assert (
+ slice1_precision.description
+ == "The proportion of predicted positive instances that are correctly predicted."
+ )
+ assert slice1_precision.sample_size == 500
+ assert slice1_precision.tests[0].threshold == 0.8
+ assert not slice1_precision.tests[0].passed
+
+ # Reset the report
+ report = ModelCardReport()
+
+ # Test with per-metric thresholds
+ pass_fail_thresholds = {
+ "overall/BinaryAccuracy": 0.9,
+ "overall/BinaryPrecision": 0.75,
+ "slice1/BinaryRecall": 0.85,
+ }
+ report.log_performance_metrics(
+ results, metric_descriptions, pass_fail_thresholds=pass_fail_thresholds
+ )
+
+ metrics = report._model_card.quantitative_analysis.performance_metrics
+
+ overall_accuracy = next(
+ m for m in metrics if m.type == "BinaryAccuracy" and m.slice == "overall"
+ )
+ assert overall_accuracy.tests[0].threshold == 0.9
+ assert not overall_accuracy.tests[0].passed
+
+ overall_precision = next(
+ m for m in metrics if m.type == "BinaryPrecision" and m.slice == "overall"
+ )
+ assert overall_precision.tests[0].threshold == 0.75
+ assert overall_precision.tests[0].passed
+
+ slice1_recall = next(
+ m for m in metrics if m.type == "BinaryRecall" and m.slice == "slice1"
+ )
+ assert slice1_recall.tests[0].threshold == 0.85
+ assert slice1_recall.tests[0].passed
+
+ slice1_f1 = next(
+ m for m in metrics if m.type == "BinaryF1Score" and m.slice == "slice1"
+ )
+ assert slice1_f1.tests[0].threshold == 0.7 # Default threshold
+ assert slice1_f1.tests[0].passed
diff --git a/tests/cyclops/report/test_utils.py b/tests/cyclops/report/test_utils.py
index c07c26e52..200a022bd 100644
--- a/tests/cyclops/report/test_utils.py
+++ b/tests/cyclops/report/test_utils.py
@@ -267,6 +267,7 @@ def model_card():
history=[0.8, 0.85, 0.9],
trend="positive",
plot=GraphicsCollection(collection=[Graphic(name="Accuracy")]),
+ sample_sizes=[100, 200, 300],
),
MetricCard(
name="Precision",
@@ -279,6 +280,7 @@ def model_card():
history=[0.7, 0.8, 0.9],
trend="positive",
plot=GraphicsCollection(collection=[Graphic(name="Precision")]),
+ sample_sizes=[100, 200, 300],
),
],
),
@@ -290,12 +292,14 @@ def model_card():
value=0.85,
slice="overall",
tests=[Test()],
+ sample_size=100,
),
PerformanceMetric(
type="BinaryPrecision",
value=0.8,
slice="overall",
tests=[Test()],
+ sample_size=100,
),
]
return model_card
@@ -395,6 +399,7 @@ def test_create_metric_cards(model_card):
description="Accuracy of binary classification",
graphics=None,
tests=None,
+ sample_size=100,
),
PerformanceMetric(
type="MulticlassPrecision",
@@ -403,6 +408,7 @@ def test_create_metric_cards(model_card):
description="Precision of multiclass classification",
graphics=None,
tests=None,
+ sample_size=100,
),
]
timestamp = "2022-01-01"
@@ -417,6 +423,7 @@ def test_create_metric_cards(model_card):
passed=False,
history=[0.75, 0.8, 0.85],
timestamps=["2021-01-01", "2021-02-01", "2021-03-01"],
+ sample_sizes=[100, 200, 300],
),
MetricCard(
name="MulticlassPrecision",
@@ -428,6 +435,7 @@ def test_create_metric_cards(model_card):
passed=True,
history=[0.8, 0.85, 0.9],
timestamps=["2021-01-01", "2021-02-01", "2021-03-01"],
+ sample_sizes=[100, 200, 300],
),
]
metrics, tooltips, slices, values, metric_cards = create_metric_cards(