Skip to content

Commit

Permalink
Merge pull request #300 from alan-turing-institute/sensitivity_analysis_var_name
Browse files Browse the repository at this point in the history

Sensitivity analysis var name
  • Loading branch information
marjanfamili authored Mar 4, 2025
2 parents ba02872 + 9359041 commit 5d056b0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 47 deletions.
89 changes: 48 additions & 41 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from SALib.analyze.sobol import analyze
from SALib.sample.sobol import sample
from SALib.util import ResultDict

from autoemulate.plotting import _display_figure
from autoemulate.utils import _ensure_2d
Expand Down Expand Up @@ -97,12 +100,14 @@ def _generate_problem(X):

return {
"num_vars": X.shape[1],
"names": [f"x{i+1}" for i in range(X.shape[1])],
"names": [f"X{i+1}" for i in range(X.shape[1])],
"bounds": [[X[:, i].min(), X[:, i].max()] for i in range(X.shape[1])],
}


def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
def _sobol_analysis(
model, problem=None, X=None, N=1024, conf_level=0.95
) -> Dict[str, ResultDict]:
"""
Perform Sobol sensitivity analysis on a fitted emulator.
Expand Down Expand Up @@ -149,56 +154,58 @@ def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
return results


def _sobol_results_to_df(results):
def _sobol_results_to_df(results: Dict[str, ResultDict]) -> pd.DataFrame:
"""
Convert Sobol results to a (long-format)pandas DataFrame.
Convert Sobol results to a (long-format) pandas DataFrame.
Parameters:
-----------
results : dict
The Sobol indices returned by sobol_analysis.
problem : dict, optional
The problem definition, including 'names'.
Returns:
--------
pd.DataFrame
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
"""
rename_dict = {
"variable": "index",
"S1": "value",
"S1_conf": "confidence",
"ST": "value",
"ST_conf": "confidence",
"S2": "value",
"S2_conf": "confidence",
}
rows = []
for output, indices in results.items():
for index_type in ["S1", "ST", "S2"]:
values = indices.get(index_type)
conf_values = indices.get(f"{index_type}_conf")
if values is None or conf_values is None:
continue

if index_type in ["S1", "ST"]:
rows.extend(
{
"output": output,
"parameter": f"X{i+1}",
"index": index_type,
"value": value,
"confidence": conf,
}
for i, (value, conf) in enumerate(zip(values, conf_values))
)

elif index_type == "S2":
n = values.shape[0]
rows.extend(
{
"output": output,
"parameter": f"X{i+1}-X{j+1}",
"index": index_type,
"value": values[i, j],
"confidence": conf_values[i, j],
}
for i in range(n)
for j in range(i + 1, n)
if not np.isnan(values[i, j])
)

return pd.DataFrame(rows)
for output, result in results.items():
s1, st, s2 = result.to_df()
s1 = (
s1.reset_index()
.rename(columns={"index": "parameter"})
.rename(columns=rename_dict)
)
s1["index"] = "S1"
st = (
st.reset_index()
.rename(columns={"index": "parameter"})
.rename(columns=rename_dict)
)
st["index"] = "ST"
s2 = (
s2.reset_index()
.rename(columns={"index": "parameter"})
.rename(columns=rename_dict)
)
s2["index"] = "S2"

df = pd.concat([s1, st, s2])
df["output"] = output
rows.append(df[["output", "parameter", "index", "value", "confidence"]])

return pd.concat(rows)


# plotting --------------------------------------------------------------------
Expand Down Expand Up @@ -242,7 +249,7 @@ def _create_bar_plot(ax, output_data, output_name):
ax.set_title(f"Output: {output_name}")


def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsize=None):
"""
Plot the sensitivity analysis results.
Expand All @@ -264,7 +271,7 @@ def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
"""
with plt.style.context("fast"):
# prepare data
results = _validate_input(results, index)
results = _validate_input(results, problem, index)
unique_outputs = results["output"].unique()
n_outputs = len(unique_outputs)

Expand Down
15 changes: 9 additions & 6 deletions tests/test_sensitivity_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.datasets import make_regression

from autoemulate.emulators import RandomForest
from autoemulate.experimental_design import LatinHypercube
Expand Down Expand Up @@ -150,7 +149,11 @@ def sobol_results_1d(model_1d):

# # test conversion to DataFrame --------------------------------------------------
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sobol_results_to_df(sobol_results_1d):
@pytest.mark.parametrize(
"expected_names",
[["c", "v0", "c", "v0", ["c", "v0"]]],
)
def test_sobol_results_to_df(sobol_results_1d, expected_names):
df = _sobol_results_to_df(sobol_results_1d)
assert isinstance(df, pd.DataFrame)
assert df.columns.tolist() == [
Expand All @@ -160,7 +163,7 @@ def test_sobol_results_to_df(sobol_results_1d):
"value",
"confidence",
]
assert ["X1", "X2", "X1-X2"] in df["parameter"].unique()
assert expected_names == df["parameter"].to_list()
assert all(isinstance(x, float) for x in df["value"])
assert all(isinstance(x, float) for x in df["confidence"])

Expand All @@ -172,12 +175,12 @@ def test_sobol_results_to_df(sobol_results_1d):
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_validate_input(sobol_results_1d):
    # "S3" is not a supported Sobol index name, so validation must reject it.
    with pytest.raises(ValueError):
        _validate_input(sobol_results_1d, index="S3")


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_validate_input_valid(sobol_results_1d):
    # A supported index name ("S1") passes validation and yields a DataFrame.
    Si = _validate_input(sobol_results_1d, index="S1")
    assert isinstance(Si, pd.DataFrame)


def test_generate_problem():
    # _generate_problem derives a SALib problem spec from a sample matrix:
    # one variable per column, bounds from the observed column min/max.
    X = np.array([[0, 0], [1, 1], [2, 2]])
    problem = _generate_problem(X)
    assert problem["num_vars"] == 2
    # Parameter names use upper-case "X<i>", matching _sobol_results_to_df.
    assert problem["names"] == ["X1", "X2"]
    assert problem["bounds"] == [[0, 2], [0, 2]]

0 comments on commit 5d056b0

Please sign in to comment.