From 332898e7e285394c9a68517a192bc6d99fa9923b Mon Sep 17 00:00:00 2001 From: a-kore <37000693+a-kore@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:13:56 -0400 Subject: [PATCH] add sample_size visualization and update notebooks (#667) * add sample_size vis and update notebooks * fix mypy error * change Group Size to sample_size --- benchmarks/mimiciv/discharge_prediction.ipynb | 4 +- .../mimiciv/icu_mortality_prediction.ipynb | 4 +- cyclops/evaluate/evaluator.py | 1 + cyclops/evaluate/fairness/evaluator.py | 4 +- cyclops/report/model_card/fields.py | 10 ++ cyclops/report/report.py | 92 ++++++---- cyclops/report/templates/model_report/plot.js | 160 +++++++++++++++--- cyclops/report/utils.py | 36 ++++ .../diabetes_130/readmission_prediction.ipynb | 65 +++---- .../kaggle/heart_failure_prediction.ipynb | 65 +++---- .../mimiciv/mortality_prediction.ipynb | 39 ++--- .../tutorials/nihcxr/cxr_classification.ipynb | 71 +++++--- .../nihcxr/generate_nihcxr_report.py | 69 +++++--- .../tutorials/synthea/los_prediction.ipynb | 51 ++---- tests/cyclops/report/test_report.py | 114 +++++++++++++ tests/cyclops/report/test_utils.py | 8 + 16 files changed, 540 insertions(+), 253 deletions(-) diff --git a/benchmarks/mimiciv/discharge_prediction.ipynb b/benchmarks/mimiciv/discharge_prediction.ipynb index 0de32ccd7..ffc69ae9e 100644 --- a/benchmarks/mimiciv/discharge_prediction.ipynb +++ b/benchmarks/mimiciv/discharge_prediction.ipynb @@ -1182,9 +1182,9 @@ "# Reformatting the fairness metrics\n", "fairness_results = copy.deepcopy(results[\"fairness\"])\n", "fairness_metrics = {}\n", - "# remove the group size from the fairness results and add it to the slice name\n", + "# remove the sample_size from the fairness results and add it to the slice name\n", "for slice_name, slice_results in fairness_results.items():\n", - " group_size = slice_results.pop(\"Group Size\")\n", + " group_size = slice_results.pop(\"sample_size\")\n", " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" ] }, diff --git a/benchmarks/mimiciv/icu_mortality_prediction.ipynb b/benchmarks/mimiciv/icu_mortality_prediction.ipynb index 5009131d6..53d4d0f75 100644 --- a/benchmarks/mimiciv/icu_mortality_prediction.ipynb +++ b/benchmarks/mimiciv/icu_mortality_prediction.ipynb @@ -1159,9 +1159,9 @@ "# Reformatting the fairness metrics\n", "fairness_results = copy.deepcopy(results[\"fairness\"])\n", "fairness_metrics = {}\n", - "# remove the group size from the fairness results and add it to the slice name\n", + "# remove the sample_size from the fairness results and add it to the slice name\n", "for slice_name, slice_results in fairness_results.items():\n", - " group_size = slice_results.pop(\"Group Size\")\n", + " group_size = slice_results.pop(\"sample_size\")\n", " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" ] }, diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py index bfeb54bb7..cf4d01611 100644 --- a/cyclops/evaluate/evaluator.py +++ b/cyclops/evaluate/evaluator.py @@ -311,6 +311,7 @@ def _compute_metrics( model_name: str = "model_for_%s" % prediction_column results.setdefault(model_name, {}) results[model_name][slice_name] = metric_output + results[model_name][slice_name]["sample_size"] = len(sliced_dataset) set_decode(dataset, True) # restore decoding features diff --git a/cyclops/evaluate/fairness/evaluator.py b/cyclops/evaluate/fairness/evaluator.py index 9b25773bf..22b39d175 100644 --- a/cyclops/evaluate/fairness/evaluator.py +++ 
b/cyclops/evaluate/fairness/evaluator.py @@ -260,7 +260,7 @@ def evaluate_fairness( # noqa: PLR0912 for prediction_column in fmt_prediction_columns: results.setdefault(prediction_column, {}) results[prediction_column].setdefault(slice_name, {}).update( - {"Group Size": len(sliced_dataset)}, + {"sample_size": len(sliced_dataset)}, ) pred_result = _get_metric_results_for_prediction_and_slice( @@ -966,7 +966,7 @@ def _compute_parity_metrics( parity_results[key] = {} for slice_name, slice_result in prediction_result.items(): for metric_name, metric_value in slice_result.items(): - if metric_name == "Group Size": + if metric_name == "sample_size": continue # add 'Parity' to the metric name before @threshold, if specified diff --git a/cyclops/report/model_card/fields.py b/cyclops/report/model_card/fields.py index 12dd29497..f4f8162b0 100644 --- a/cyclops/report/model_card/fields.py +++ b/cyclops/report/model_card/fields.py @@ -380,6 +380,11 @@ class PerformanceMetric( default_factory=list, ) + sample_size: Optional[StrictInt] = Field( + None, + description="The sample size used to compute this metric.", + ) + class User( BaseModelCardField, @@ -599,6 +604,11 @@ class MetricCard( description="Timestamps for each point in the history.", ) + sample_sizes: Optional[List[int]] = Field( + None, + description="Sample sizes for each point in the history.", + ) + class MetricCardCollection(BaseModelCardField, composable_with="Overview"): """A collection of metric cards to be displayed in the model card.""" diff --git a/cyclops/report/report.py b/cyclops/report/report.py index 903a7b682..786dae53c 100644 --- a/cyclops/report/report.py +++ b/cyclops/report/report.py @@ -48,6 +48,7 @@ get_histories, get_names, get_passed, + get_sample_sizes, get_slices, get_thresholds, get_timestamps, @@ -855,6 +856,7 @@ def log_quantitative_analysis( pass_fail_threshold_fns: Optional[ Union[Callable[[Any, float], bool], List[Callable[[Any, float], bool]]] ] = None, + sample_size: Optional[int] = None, **extra: Any, ) -> None: """Add a quantitative analysis to the report. @@ -921,6 +923,7 @@ def log_quantitative_analysis( "slice": metric_slice, "decision_threshold": decision_threshold, "description": description, + "sample_size": sample_size, **extra, } @@ -958,42 +961,70 @@ def log_quantitative_analysis( field_type=field_type, ) - def log_performance_metrics(self, metrics: Dict[str, Any]) -> None: - """Add a performance metric to the `Quantitative Analysis` section. + def log_performance_metrics( + self, + results: Dict[str, Any], + metric_descriptions: Dict[str, str], + pass_fail_thresholds: Union[float, Dict[str, float]] = 0.7, + pass_fail_threshold_fn: Callable[[float, float], bool] = lambda x, + threshold: bool(x >= threshold), + ) -> None: + """ + Log all performance metrics to the model card report. Parameters ---------- - metrics : Dict[str, Any] - A dictionary of performance metrics. The keys should be the name of the - metric, and the values should be the value of the metric. If the metric - is a slice metric, the key should be the slice name followed by a slash - and then the metric name (e.g. "slice_name/metric_name"). If no slice - name is provided, the slice name will be "overall". - - Raises - ------ - TypeError - If the given metrics are not a dictionary with string keys. + results : Dict[str, Any] + Dictionary containing the results, + with keys in the format "split/metric_name". + metric_descriptions : Dict[str, str] + Dictionary mapping metric names to their descriptions. 
+ pass_fail_thresholds : Union[float, Dict[str, float]], optional + The threshold(s) for pass/fail tests. + Can be a single float applied to all metrics, + or a dictionary mapping "split/metric_name" to individual thresholds. + Default is 0.7. + pass_fail_threshold_fn : Callable[[float, float], bool], optional + Function to determine if a metric passes or fails. + Default is lambda x, threshold: bool(x >= threshold). + Returns + ------- + None """ - _raise_if_not_dict_with_str_keys(metrics) - for metric_name, metric_value in metrics.items(): - name_split = metric_name.split("/") - if len(name_split) == 1: - slice_name = "overall" - metric_name = name_split[0] # noqa: PLW2901 - else: # everything before the last slash is the slice name - slice_name = "/".join(name_split[:-1]) - metric_name = name_split[-1] # noqa: PLW2901 - - # TODO: create plot + # Extract sample sizes + sample_sizes = { + key.split("/")[0]: value + for key, value in results.items() + if "sample_size" in key.split("/")[1] + } - self._log_field( - data={"type": metric_name, "value": metric_value, "slice": slice_name}, - section_name="quantitative_analysis", - field_name="performance_metrics", - field_type=PerformanceMetric, - ) + # Log metrics + for name, metric in results.items(): + split, metric_name = name.split("/") + if metric_name != "sample_size": + metric_value = metric.tolist() if hasattr(metric, "tolist") else metric + + # Determine the threshold for this specific metric + if isinstance(pass_fail_thresholds, dict): + threshold = pass_fail_thresholds.get( + name, 0.7 + ) # Default to 0.7 if not specified + else: + threshold = pass_fail_thresholds + + self.log_quantitative_analysis( + "performance", + name=metric_name, + value=metric_value, + description=metric_descriptions.get( + metric_name, "No description provided." 
+ ), + metric_slice=split, + pass_fail_thresholds=threshold, + pass_fail_threshold_fns=pass_fail_threshold_fn, + sample_size=sample_sizes.get(split), + ) # TODO: MERGE/COMPARE MODEL CARDS @@ -1162,6 +1193,7 @@ def export( "get_names": get_names, "get_histories": get_histories, "get_timestamps": get_timestamps, + "get_sample_sizes": get_sample_sizes, } template.globals.update(func_dict) diff --git a/cyclops/report/templates/model_report/plot.js b/cyclops/report/templates/model_report/plot.js index 8f94071aa..2cf936b27 100644 --- a/cyclops/report/templates/model_report/plot.js +++ b/cyclops/report/templates/model_report/plot.js @@ -1,5 +1,6 @@ // Javascript code for plots in the model card template - +const MAX_SIZE = 20; +let maxSampleSize = 0; // Define a function to update the plot based on selected filters function updatePlot() { @@ -140,6 +141,7 @@ function updatePlot() { var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }}); var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }}); var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); + var sample_sizes_all = JSON.parse({{ get_sample_sizes(model_card)|safe|tojson }}); for (let i = 0; i < selection.length; i++) { // use selection to set label_slice_selection background color @@ -223,6 +225,21 @@ function updatePlot() { } } + // Find the maximum sample size across all selections + for (let i = 0; i < selections.length; i++) { + if (selections[i] === null) { + continue; + } + selection = selections[i] + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); + var sample_size_data = []; + for (let i = 0; i < sample_sizes_all[idx].length; i++) { + sample_size_data.push(sample_sizes_all[idx][i]); + } + maxSampleSize = Math.max(...sample_size_data); + } + traces = []; for (let i = 0; i < selections.length; i++) { if (selections[i] === null) { @@ -237,11 +254,17 @@ function updatePlot() { } var timestamp_data = []; for (let i = 0; i < timestamps_all[idx].length; i++) { - timestamp_data.push(timestamps_all[idx][i]); + // timestamp_data.push(timestamps_all[idx][i]); + timestamp_data.push(formatDate(timestamps_all[idx][i])); + } + var sample_size_data = []; + for (let i = 0; i < sample_sizes_all[idx].length; i++) { + sample_size_data.push(sample_sizes_all[idx][i]); } var last_n_evals = document.getElementById("n_evals_slider_pot").value; history_data = history_data.slice(-last_n_evals); timestamp_data = timestamp_data.slice(-last_n_evals); + sample_size_data = sample_size_data.slice(-last_n_evals); // get slope of line of best fit, if >0.01 then trending up, if <0.01 then trending down, else flat var slope = lineOfBestFit(history_data)[0]; if (slope > 0.01) { @@ -323,17 +346,41 @@ function updatePlot() { }; traces.push(threshold_trace); } - var trace = { - // range of x is the length of the list of floats - x: timestamp_data, - y: history_data, - mode: 'lines+markers', - type: 'scatter', - marker: {color: plot_colors[i+1]}, - line: {color: plot_colors[i+1]}, - name: name, - legendgroup: name + i, - //name: selection.toString(), + // Add sample size circles + var sample_size_trace = { + x: timestamp_data, + y: history_data, + mode: 'markers', + marker: { + sizemode: 'area', + size: sample_size_data, + sizeref: maxSampleSize / MAX_SIZE ** 2, + color: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`, + line: {width: 0}, + }, + text: sample_size_data.map((s, index) => + `Date: 
${timestamp_data[index]}<br>Value: ${history_data[index].toFixed(2)}<br>
Sample Size: ${s}` + ), + hoverinfo: 'text', + hovertemplate: '%{text}', + name: name + ' (Sample Size)', + legendgroup: name + i, + }; + + // Add main data points and line + var main_trace = { + x: timestamp_data, + y: history_data, + mode: 'lines+markers', + type: 'scatter', + marker: { + color: plot_colors[i+1], + symbol: 'circle', + }, + line: {color: plot_colors[i+1]}, + name: name, + legendgroup: name + i, + hoverinfo: 'skip' }; // check if length of history_data is >= mean_std_min_evals and if so get rolling mean and std if mean_plot_selection or std_plot_selection is checked @@ -384,7 +431,8 @@ function updatePlot() { }; traces.push(trace_mean); } - traces.push(trace); + traces.push(sample_size_trace); + traces.push(main_trace); } @@ -744,6 +792,7 @@ function updatePlotSelection() { var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }}); var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }}); var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); + var sample_sizes_all = JSON.parse({{ get_sample_sizes(model_card)|safe|tojson }}); var radioGroups = {}; var labelGroups = {}; @@ -786,6 +835,21 @@ function updatePlotSelection() { } } + // Find the maximum sample size across all selections + for (let i = 0; i < selections.length; i++) { + if (selections[i] === null) { + continue; + } + selection = selections[i] + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); + var sample_size_data = []; + for (let i = 0; i < sample_sizes_all[idx].length; i++) { + sample_size_data.push(sample_sizes_all[idx][i]); + } + maxSampleSize = Math.max(...sample_size_data); +} + traces = []; var plot_number = parseInt(plot_selected.value.split(" ")[1]-1); for (let i = 0; i < selections.length; i++) { @@ -802,11 +866,17 @@ function updatePlotSelection() { } var timestamp_data = []; for (let i = 0; i < timestamps_all[idx].length; i++) { - timestamp_data.push(timestamps_all[idx][i]); + // timestamp_data.push(timestamps_all[idx][i]); + timestamp_data.push(formatDate(timestamps_all[idx][i])); + } + var sample_size_data = []; + for (let i = 0; i < sample_sizes_all[idx].length; i++) { + sample_size_data.push(sample_sizes_all[idx][i]); } var last_n_evals = document.getElementById("n_evals_slider_pot").value; history_data = history_data.slice(-last_n_evals); timestamp_data = timestamp_data.slice(-last_n_evals); + sample_size_data = sample_size_data.slice(-last_n_evals); // get slope of line of best fit, if >0.01 then trending up, if <0.01 then trending down, else flat var slope = lineOfBestFit(history_data)[0]; @@ -891,16 +961,41 @@ function updatePlotSelection() { traces.push(threshold_trace); } - var trace = { - // range of x is the length of the list of floats - x: timestamp_data, - y: history_data, - mode: 'lines+markers', - type: 'scatter', - marker: {color: plot_colors[i+1]}, - line: {color: plot_colors[i+1]}, - name: name, - legendgroup: name + i, + // Add sample size circles + var sample_size_trace = { + x: timestamp_data, + y: history_data, + mode: 'markers', + marker: { + sizemode: 'area', + size: sample_size_data, + sizeref: maxSampleSize / MAX_SIZE ** 2, + color: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`, + line: {width: 0}, + }, + text: sample_size_data.map((s, index) => + `Date: ${timestamp_data[index]}
<br>Value: ${history_data[index].toFixed(2)}<br>
Sample Size: ${s}` + ), + hoverinfo: 'text', + hovertemplate: '%{text}', + name: name + ' (Sample Size)', + legendgroup: name + i, + }; + + // Add main data points and line + var main_trace = { + x: timestamp_data, + y: history_data, + mode: 'lines+markers', + type: 'scatter', + marker: { + color: plot_colors[i+1], + symbol: 'circle', + }, + line: {color: plot_colors[i+1]}, + name: name, + legendgroup: name + i, + hoverinfo: 'skip' }; // check if length of history_data is >= mean_std_min_evals and if so get rolling mean and std if mean_plot_selection or std_plot_selection is checked @@ -914,7 +1009,6 @@ function updatePlotSelection() { var trace_std_upper = { x: timestamp_data.slice(-history_std_data.length), y: history_mean_data.map((x, i) => x + history_std_data[i]), - // fill: 'tonexty', fillcolor: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`, mode: 'lines', line: {width: 0, color: `rgba(${plot_colors[i+1].slice(4, -1)}, 0.3)`}, @@ -951,7 +1045,8 @@ function updatePlotSelection() { traces.push(trace_mean); } - traces.push(trace); + traces.push(main_trace); + traces.push(sample_size_trace); } @@ -1145,3 +1240,14 @@ function rollingStd(data, window) { } return std; } + + function formatDate(dateString) { + const date = new Date(dateString); + const year = date.getFullYear(); + const month = (date.getMonth() + 1).toString().padStart(2, '0'); + const day = date.getDate().toString().padStart(2, '0'); + const hours = date.getHours().toString().padStart(2, '0'); + const minutes = date.getMinutes().toString().padStart(2, '0'); + + return `${year}-${month}-${day} ${hours}:${minutes}`; + } diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py index 99e9c4500..94c8a556b 100644 --- a/cyclops/report/utils.py +++ b/cyclops/report/utils.py @@ -585,6 +585,25 @@ def get_timestamps(model_card: ModelCard) -> str: return json.dumps(timestamps) +def get_sample_sizes(model_card: ModelCard) -> str: + """Get all sample sizes from a model card.""" + sample_sizes = {} + if ( + (model_card.overview is None) + or (model_card.overview.metric_cards is None) + or (model_card.overview.metric_cards.collection is None) + ): + pass + else: + for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): + sample_sizes[itr] = ( + [str(sample_size) for sample_size in metric_card.sample_sizes] + if metric_card.sample_sizes is not None + else None + ) + return json.dumps(sample_sizes) + + def _extract_slices_and_values( current_metrics: List[PerformanceMetric], ) -> Tuple[List[str], List[List[str]]]: @@ -727,6 +746,7 @@ def _create_metric_card( name: str, history: List[float], timestamps: List[str], + sample_sizes: List[int], threshold: Union[float, None], passed: Union[bool, None], ) -> MetricCard: @@ -744,6 +764,8 @@ def _create_metric_card( The timestamps for the metric card. threshold : Union[float, None] The threshold for the metric card. + sample_sizes : List[int] + The sample sizes for the metric card. passed : Union[bool, None] Whether or not the metric card passed. 
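# A minimal usage sketch of the new get_sample_sizes helper shown above (the values
# here are hypothetical, not from the commit): for a model card whose overview
# collection holds two metric cards, one with sample_sizes=[500, 1000, 1000] and
# one with sample_sizes=None, the helper returns a JSON string keyed by card index:
#
#     get_sample_sizes(model_card)  # -> '{"0": ["500", "1000", "1000"], "1": null}'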
@@ -772,6 +794,7 @@ def _create_metric_card( passed=passed, history=history, timestamps=timestamps, + sample_sizes=sample_sizes, ) @@ -811,7 +834,11 @@ def _get_metric_card( timestamps = metric["last_metric_card"].timestamps if timestamps is not None: timestamps.append(timestamp) + sample_sizes = metric["last_metric_card"].sample_sizes + if sample_sizes is not None: + sample_sizes.append(0) # Append 0 for missing data metric["last_metric_card"].timestamps = timestamps + metric["last_metric_card"].sample_sizes = sample_sizes metric_card = metric["last_metric_card"] elif ( metric["current_metric"] is not None @@ -829,6 +856,9 @@ def _get_metric_card( timestamps = metric["last_metric_card"].timestamps if timestamps is not None: timestamps.append(timestamp) + sample_sizes = metric["last_metric_card"].sample_sizes + if sample_sizes is not None: + sample_sizes.append(metric["current_metric"].sample_size) else: history = [ metric["current_metric"].value @@ -840,12 +870,18 @@ def _get_metric_card( else 0, ] timestamps = [timestamp] + sample_sizes = ( + [metric["current_metric"].sample_size] + if isinstance(metric["current_metric"], PerformanceMetric) + else [0] + ) if metric_card is None: metric_card = _create_metric_card( metric, name, history, timestamps, + sample_sizes, _get_threshold(metric), _get_passed(metric), ) diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb index 4e50a97d1..bb7bdb3dd 100644 --- a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb +++ b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb @@ -957,56 +957,36 @@ "metadata": {}, "outputs": [], "source": [ + "descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + "}\n", "model_name = f\"model_for_preds.{model_name}\"\n", + "\n", "results_flat = flatten_results_dict(\n", " results=results,\n", " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")\n", + "\n", "results_female_flat = flatten_results_dict(\n", " results=results_female,\n", " model_name=model_name,\n", ")\n", - "# ruff: noqa: W505\n", - "for name, metric in results_female_flat.items():\n", - " split, name = name.split(\"/\") # noqa: PLW2901\n", - " descriptions = {\n", - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", - " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist(),\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )\n", "\n", - "for name, metric in results_flat.items():\n", - " split, name = name.split(\"/\") # noqa: PLW2901\n", - " descriptions = {\n", - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", - " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist(),\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )" + "report.log_performance_metrics(\n", + " results=results_flat,\n", + " metric_descriptions=descriptions,\n", + ")\n", + "\n", + "report.log_performance_metrics(\n", + " results=results_female_flat,\n", + " metric_descriptions=descriptions,\n", + ")" ] }, { @@ -1066,7 +1046,8 @@ "metadata": {}, "outputs": [], "source": [ - "# extracting the precision-recall curves and average precision results for all the slices\n", + "# extracting the precision-recall curves and\n", + "# average precision results for all the slices\n", "pr_curves = {\n", " slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n", " for slice_name, slice_results in results[model_name].items()\n", @@ -1228,9 +1209,9 @@ "# Reformatting the fairness metrics\n", "fairness_results = copy.deepcopy(results[\"fairness\"])\n", "fairness_metrics = {}\n", - "# remove the group size from the fairness results and add it to the slice name\n", + "# remove the sample_size from the fairness results and add it to the slice name\n", "for slice_name, slice_results in fairness_results.items():\n", - " group_size = slice_results.pop(\"Group Size\")\n", + " group_size = slice_results.pop(\"sample_size\")\n", " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" ] }, diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index 1f987623f..744ca43c7 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ 
b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -903,56 +903,36 @@ "metadata": {}, "outputs": [], "source": [ + "descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + "}\n", "model_name = f\"model_for_preds.{model_name}\"\n", + "\n", "results_flat = flatten_results_dict(\n", " results=results,\n", " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")\n", + "\n", "results_female_flat = flatten_results_dict(\n", " results=results_female,\n", " model_name=model_name,\n", ")\n", - "# ruff: noqa: W505\n", - "for name, metric in results_female_flat.items():\n", - " split, name = name.split(\"/\") # noqa: PLW2901\n", - " descriptions = {\n", - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", - " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist(),\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )\n", "\n", - "for name, metric in results_flat.items():\n", - " split, name = name.split(\"/\") # noqa: PLW2901\n", - " descriptions = {\n", - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", - " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist(),\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )" + "report.log_performance_metrics(\n", + " results=results_flat,\n", + " metric_descriptions=descriptions,\n", + ")\n", + "\n", + "report.log_performance_metrics(\n", + " results=results_female_flat,\n", + " metric_descriptions=descriptions,\n", + ")" ] }, { @@ -1012,7 +992,8 @@ "metadata": {}, "outputs": [], "source": [ - "# extracting the precision-recall curves and average precision results for all the slices\n", + "# extracting the precision-recall curves and\n", + "# average precision results for all the slices\n", "pr_curves = {\n", " slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n", " for slice_name, slice_results in results[model_name].items()\n", @@ -1174,9 +1155,9 @@ "# Reformatting the fairness metrics\n", "fairness_results = copy.deepcopy(results[\"fairness\"])\n", "fairness_metrics = {}\n", - "# remove the group size from the fairness results and add it to the slice name\n", + "# remove the sample_size from the fairness results and add it to the slice name\n", "for slice_name, slice_results in fairness_results.items():\n", - " group_size = slice_results.pop(\"Group Size\")\n", + " group_size = slice_results.pop(\"sample_size\")\n", " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" ] }, diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb index 5ad66b0c6..78a6c7218 100644 --- a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb +++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb @@ -941,26 +941,18 @@ "metadata": {}, "outputs": [], "source": [ - "# ruff: noqa: W505\n", - "for name, metric in results_flat.items():\n", - " split, name = name.split(\"/\") # noqa: PLW2901\n", - " descriptions = {\n", - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", - " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist(),\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )" + "descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + "}\n", + "report.log_performance_metrics(\n", + " results=results_flat,\n", + " metric_descriptions=descriptions,\n", + ")" ] }, { @@ -1004,7 +996,8 @@ "metadata": {}, "outputs": [], "source": [ - "# extracting the precision-recall curves and average precision results for all the slices\n", + "# extracting the precision-recall curves and\n", + "# average precision results for all the slices\n", "pr_curves = {\n", " slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n", " for slice_name, slice_results in results[model_name].items()\n", @@ -1126,9 +1119,9 @@ "# Reformatting the fairness metrics\n", "fairness_results = copy.deepcopy(results[\"fairness\"])\n", "fairness_metrics = {}\n", - "# remove the group size from the fairness results and add it to the slice name\n", + "# remove the sample_size from the fairness results and add it to the slice name\n", "for slice_name, slice_results in fairness_results.items():\n", - " group_size = slice_results.pop(\"Group Size\")\n", + " group_size = slice_results.pop(\"sample_size\")\n", " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" ] }, diff --git a/docs/source/tutorials/nihcxr/cxr_classification.ipynb b/docs/source/tutorials/nihcxr/cxr_classification.ipynb index e2cbed78e..053c6db18 100644 --- a/docs/source/tutorials/nihcxr/cxr_classification.ipynb +++ b/docs/source/tutorials/nihcxr/cxr_classification.ipynb @@ -68,8 +68,8 @@ "\"\"\"Generate historical reports with validation data\n", "for comparison with periodic report on test data.\"\"\"\n", "\n", - "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-19\" --seed 43\n", - "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-16\" --seed 44\n", + "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-16\" --seed 43\n", + "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-19\" --seed 44\n", "!python3 
generate_nihcxr_report.py --synthetic_timestamp \"2023-10-22\" --seed 45\n", "!python3 generate_nihcxr_report.py --synthetic_timestamp \"2023-10-30\" --seed 46" ] @@ -127,7 +127,7 @@ "metadata": {}, "outputs": [], "source": [ - "data_dir = \"/mnt/data/clinical_datasets/NIHCXR\"\n", + "data_dir = \"/mnt/data2/clinical_datasets/NIHCXR\"\n", "nih_ds = load_nihcxr(data_dir)[\"test\"]\n", "nih_ds = nih_ds.select(range(1000))\n", "\n", @@ -411,23 +411,33 @@ "results_flat = {}\n", "for slice_, metrics in nih_eval_results_age[\"model_for_predictions.densenet\"].items():\n", " for name, metric in metrics.items():\n", - " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n", - " for itr, m in enumerate(metric):\n", - " if slice_ == \"overall\":\n", - " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n", - " else:\n", - " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n", + " if \"sample_size\" in name:\n", + " results_flat[f\"{slice_}/{name}\"] = metric\n", + " else:\n", + " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n", + " for itr, m in enumerate(metric):\n", + " if slice_ == \"overall\":\n", + " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n", + " else:\n", + " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n", "for slice_, metrics in nih_eval_results_gender[\n", " \"model_for_predictions.densenet\"\n", "].items():\n", " for name, metric in metrics.items():\n", - " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n", - " for itr, m in enumerate(metric):\n", - " if slice_ == \"overall\":\n", - " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n", - " else:\n", - " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n", - "\n", + " if \"sample_size\" in name:\n", + " results_flat[f\"{slice_}/{name}\"] = metric\n", + " else:\n", + " results_flat[f\"{slice_}/{name}\"] = metric.mean()\n", + " for itr, m in enumerate(metric):\n", + " if slice_ == \"overall\":\n", + " results_flat[f\"pathology:{pathologies[itr]}/{name}\"] = m\n", + " else:\n", + " results_flat[f\"{slice_}&pathology:{pathologies[itr]}/{name}\"] = m\n", + "sample_sizes = {\n", + " key.split(\"/\")[0]: value\n", + " for key, value in results_flat.items()\n", + " if \"sample_size\" in key.split(\"/\")[1]\n", + "}\n", "for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", " descriptions = {\n", @@ -436,15 +446,26 @@ " \"MultilabelSensitivity\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", " \"MultilabelSpecificity\": \"The proportion of actual negative instances that are correctly predicted.\",\n", " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist() if isinstance(metric, np.generic) else metric,\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )" + " # remove the \"&pathology\" from the split name\n", + " # check if splits contains \"pathology:{pathology}\" and remove it\n", + " if \"pathology\" in split:\n", + " if len(split.split(\"&\")) == 1:\n", + " sample_size_split = \"overall\"\n", + " else:\n", + " sample_size_split = \"&\".join(split.split(\"&\")[:-1])\n", + " else:\n", + " sample_size_split = split\n", + " if name != \"sample_size\":\n", + " report.log_quantitative_analysis(\n", + " \"performance\",\n", + " name=name,\n", + " value=metric.tolist() if isinstance(metric, np.generic) else metric,\n", + " description=descriptions.get(name),\n", + " metric_slice=split,\n", + " pass_fail_thresholds=0.7,\n", + " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", + " sample_size=sample_sizes[sample_size_split],\n", + " )" ] }, { diff --git a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py index 584bf8be8..7d0f13bc1 100644 --- a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py +++ b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py @@ -34,7 +34,7 @@ report = ModelCardReport() -data_dir = "/mnt/data/clinical_datasets/NIHCXR" +data_dir = "/mnt/data2/clinical_datasets/NIHCXR" nih_ds = load_nihcxr(data_dir)[args.split] # select a subset of the data @@ -224,23 +224,33 @@ results_flat = {} for slice_, metrics in nih_eval_results_age["model_for_predictions.densenet"].items(): for name, metric in metrics.items(): - results_flat[f"{slice_}/{name}"] = metric.mean() - for itr, m in enumerate(metric): - if slice_ == "overall": - results_flat[f"pathology:{pathologies[itr]}/{name}"] = m - else: - results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m + if "sample_size" in name: + results_flat[f"{slice_}/{name}"] = metric + else: + results_flat[f"{slice_}/{name}"] = metric.mean() + for itr, m in enumerate(metric): + if slice_ == "overall": + results_flat[f"pathology:{pathologies[itr]}/{name}"] = m + else: + results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m for slice_, metrics in nih_eval_results_gender[ "model_for_predictions.densenet" ].items(): for name, metric in metrics.items(): - results_flat[f"{slice_}/{name}"] = metric.mean() - for itr, m in enumerate(metric): - if slice_ == "overall": - results_flat[f"pathology:{pathologies[itr]}/{name}"] = m - else: - results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m - + if "sample_size" in name: + results_flat[f"{slice_}/{name}"] = metric + else: + results_flat[f"{slice_}/{name}"] = metric.mean() + for itr, m in enumerate(metric): + if slice_ == "overall": + results_flat[f"pathology:{pathologies[itr]}/{name}"] = m + else: + results_flat[f"{slice_}&pathology:{pathologies[itr]}/{name}"] = m +sample_sizes = { + key.split("/")[0]: value + for key, value in results_flat.items() + if "sample_size" in key.split("/")[1] +} for name, metric in results_flat.items(): split, name = name.split("/") # noqa: PLW2901 descriptions = { @@ -249,15 +259,26 @@ 
"MultilabelSensitivity": "The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.", "MultilabelSpecificity": "The proportion of actual negative instances that are correctly predicted.", } - report.log_quantitative_analysis( - "performance", - name=name, - value=metric.tolist() if isinstance(metric, np.generic) else metric, - description=descriptions[name], - metric_slice=split, - pass_fail_thresholds=0.7, - pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold), - ) + # remove the "&pathology" from the split name + # check if splits contains "pathology:{pathology}" and remove it + if "pathology" in split: + if len(split.split("&")) == 1: + sample_size_split = "overall" + else: + sample_size_split = "&".join(split.split("&")[:-1]) + else: + sample_size_split = split + if name != "sample_size": + report.log_quantitative_analysis( + "performance", + name=name, + value=metric.tolist() if isinstance(metric, np.generic) else metric, + description=descriptions.get(name), + metric_slice=split, + pass_fail_thresholds=0.7, + pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold), + sample_size=sample_sizes[sample_size_split], + ) # model details for NIH Chest X-Ray model @@ -362,6 +383,6 @@ ) report_path = report.export( - output_filename=f"nihcxr_report_periodic_{args.synthetic_timestamp}.html", + output_filename="nihcxr_report_periodic.html", synthetic_timestamp=args.synthetic_timestamp, ) diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb index 53d4e4304..a8c89da24 100644 --- a/docs/source/tutorials/synthea/los_prediction.ipynb +++ b/docs/source/tutorials/synthea/los_prediction.ipynb @@ -1085,51 +1085,34 @@ { "cell_type": "code", "execution_count": null, - "id": "d322a86f-1f7c-42f6-8a97-8a18ea8622e2", + "id": "d33a171c-02ef-4bc9-a3bf-87320c7c83d6", "metadata": { "tags": [] }, "outputs": [], "source": [ + "descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + "}\n", "model_name = f\"model_for_preds.{model_name}\"\n", + "\n", "results_flat = flatten_results_dict(\n", " results=results,\n", " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", + ")\n", + "\n", + "report.log_performance_metrics(\n", + " results=results_flat,\n", + " metric_descriptions=descriptions,\n", ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "d33a171c-02ef-4bc9-a3bf-87320c7c83d6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "for name, metric in results_flat.items():\n", - " split, name = name.split(\"/\") # noqa: PLW2901\n", - " if name == \"BinaryConfusionMatrix\":\n", - " continue\n", - " descriptions = {\n", - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", - " }\n", - " report.log_quantitative_analysis(\n", - " \"performance\",\n", - " name=name,\n", - " value=metric.tolist(),\n", - " description=descriptions[name],\n", - " metric_slice=split,\n", - " pass_fail_thresholds=0.7,\n", - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", - " )" - ] - }, { "cell_type": "markdown", "id": "0ee464d3-7246-4c6f-ac2d-8cca2a985f22", @@ -1293,9 +1276,9 @@ "# Reformatting the fairness metrics\n", "fairness_results = copy.deepcopy(results[\"fairness\"])\n", "fairness_metrics = {}\n", - "# remove the group size from the fairness results and add it to the slice name\n", + "# remove the sample_size from the fairness results and add it to the slice name\n", "for slice_name, slice_results in fairness_results.items():\n", - " group_size = slice_results.pop(\"Group Size\")\n", + " group_size = slice_results.pop(\"sample_size\")\n", " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" ] }, diff --git a/tests/cyclops/report/test_report.py b/tests/cyclops/report/test_report.py index 33dd71358..2c5e72965 100644 --- a/tests/cyclops/report/test_report.py +++ b/tests/cyclops/report/test_report.py @@ -2,6 +2,8 @@ from unittest import TestCase +import numpy as np + from cyclops.report import ModelCardReport from cyclops.report.model_card.sections import ModelDetails @@ -349,6 +351,7 @@ def test_export(self): decision_threshold=0.7, pass_fail_thresholds=[0.6, 0.65, 0.7], pass_fail_threshold_fns=[lambda x, t: x >= t for _ in range(3)], + sample_size=100, ) self.model_card_report.log_quantitative_analysis( analysis_type="performance", @@ -359,8 +362,119 @@ def test_export(self): description="F1 score of 
the model on the test set", pass_fail_thresholds=[0.9, 0.85, 0.8], pass_fail_threshold_fns=[lambda x, t: x >= t for _ in range(3)], + sample_size=100, ) self.model_card_report.log_owner(name="John Doe") report_path = self.model_card_report.export(interactive=False, save_json=False) assert isinstance(report_path, str) + + +def test_log_performance_metrics(): + """Test log_performance_metrics.""" + report = ModelCardReport() + + # Mock results + results = { + "overall/BinaryAccuracy": np.array(0.85), + "overall/BinaryPrecision": np.array(0.78), + "overall/BinaryRecall": np.array(0.92), + "overall/BinaryF1Score": np.array(0.84), + "overall/BinaryAUROC": np.array(0.91), + "overall/BinaryAveragePrecision": np.array(0.88), + "overall/sample_size": 1000, + "slice1/BinaryAccuracy": np.array(0.82), + "slice1/BinaryPrecision": np.array(0.75), + "slice1/BinaryRecall": np.array(0.89), + "slice1/BinaryF1Score": np.array(0.81), + "slice1/BinaryAUROC": np.array(0.88), + "slice1/BinaryAveragePrecision": np.array(0.85), + "slice1/sample_size": 500, + } + + # Mock metric descriptions + metric_descriptions = { + "BinaryAccuracy": "The proportion of all instances that are correctly predicted.", + "BinaryPrecision": "The proportion of predicted positive instances that are correctly predicted.", + "BinaryRecall": "The proportion of actual positive instances that are correctly predicted.", + "BinaryF1Score": "The harmonic mean of precision and recall.", + "BinaryAUROC": "The area under the ROC curve.", + "BinaryAveragePrecision": "The area under the precision-recall curve.", + } + + # Test with a single threshold + report.log_performance_metrics( + results, metric_descriptions, pass_fail_thresholds=0.8 + ) + + # Check if metrics were logged correctly + assert report._model_card.quantitative_analysis is not None + assert ( + len(report._model_card.quantitative_analysis.performance_metrics) == 12 + ) # 6 metrics * 2 slices + + # Check a few specific metrics + metrics = report._model_card.quantitative_analysis.performance_metrics + + overall_accuracy = next( + m for m in metrics if m.type == "BinaryAccuracy" and m.slice == "overall" + ) + assert overall_accuracy.value == 0.85 + assert ( + overall_accuracy.description + == "The proportion of all instances that are correctly predicted." + ) + assert overall_accuracy.sample_size == 1000 + assert overall_accuracy.tests[0].threshold == 0.8 + assert overall_accuracy.tests[0].passed + + slice1_precision = next( + m for m in metrics if m.type == "BinaryPrecision" and m.slice == "slice1" + ) + assert slice1_precision.value == 0.75 + assert ( + slice1_precision.description + == "The proportion of predicted positive instances that are correctly predicted." 
+ ) + assert slice1_precision.sample_size == 500 + assert slice1_precision.tests[0].threshold == 0.8 + assert not slice1_precision.tests[0].passed + + # Reset the report + report = ModelCardReport() + + # Test with per-metric thresholds + pass_fail_thresholds = { + "overall/BinaryAccuracy": 0.9, + "overall/BinaryPrecision": 0.75, + "slice1/BinaryRecall": 0.85, + } + report.log_performance_metrics( + results, metric_descriptions, pass_fail_thresholds=pass_fail_thresholds + ) + + metrics = report._model_card.quantitative_analysis.performance_metrics + + overall_accuracy = next( + m for m in metrics if m.type == "BinaryAccuracy" and m.slice == "overall" + ) + assert overall_accuracy.tests[0].threshold == 0.9 + assert not overall_accuracy.tests[0].passed + + overall_precision = next( + m for m in metrics if m.type == "BinaryPrecision" and m.slice == "overall" + ) + assert overall_precision.tests[0].threshold == 0.75 + assert overall_precision.tests[0].passed + + slice1_recall = next( + m for m in metrics if m.type == "BinaryRecall" and m.slice == "slice1" + ) + assert slice1_recall.tests[0].threshold == 0.85 + assert slice1_recall.tests[0].passed + + slice1_f1 = next( + m for m in metrics if m.type == "BinaryF1Score" and m.slice == "slice1" + ) + assert slice1_f1.tests[0].threshold == 0.7 # Default threshold + assert slice1_f1.tests[0].passed diff --git a/tests/cyclops/report/test_utils.py b/tests/cyclops/report/test_utils.py index c07c26e52..200a022bd 100644 --- a/tests/cyclops/report/test_utils.py +++ b/tests/cyclops/report/test_utils.py @@ -267,6 +267,7 @@ def model_card(): history=[0.8, 0.85, 0.9], trend="positive", plot=GraphicsCollection(collection=[Graphic(name="Accuracy")]), + sample_sizes=[100, 200, 300], ), MetricCard( name="Precision", @@ -279,6 +280,7 @@ def model_card(): history=[0.7, 0.8, 0.9], trend="positive", plot=GraphicsCollection(collection=[Graphic(name="Precision")]), + sample_sizes=[100, 200, 300], ), ], ), @@ -290,12 +292,14 @@ def model_card(): value=0.85, slice="overall", tests=[Test()], + sample_size=100, ), PerformanceMetric( type="BinaryPrecision", value=0.8, slice="overall", tests=[Test()], + sample_size=100, ), ] return model_card @@ -395,6 +399,7 @@ def test_create_metric_cards(model_card): description="Accuracy of binary classification", graphics=None, tests=None, + sample_size=100, ), PerformanceMetric( type="MulticlassPrecision", @@ -403,6 +408,7 @@ def test_create_metric_cards(model_card): description="Precision of multiclass classification", graphics=None, tests=None, + sample_size=100, ), ] timestamp = "2022-01-01" @@ -417,6 +423,7 @@ def test_create_metric_cards(model_card): passed=False, history=[0.75, 0.8, 0.85], timestamps=["2021-01-01", "2021-02-01", "2021-03-01"], + sample_sizes=[100, 200, 300], ), MetricCard( name="MulticlassPrecision", @@ -428,6 +435,7 @@ def test_create_metric_cards(model_card): passed=True, history=[0.8, 0.85, 0.9], timestamps=["2021-01-01", "2021-02-01", "2021-03-01"], + sample_sizes=[100, 200, 300], ), ] metrics, tooltips, slices, values, metric_cards = create_metric_cards(