Skip to content

Commit c797881

Browse files
committed
adjust for multiple h_0 and rename to null_hypothesis
1 parent 59945ec commit c797881

File tree

6 files changed

+65
-51
lines changed

6 files changed

+65
-51
lines changed

doubleml/double_ml.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,27 +1609,27 @@ def _calc_sensitivity_analysis(self, cf_y, cf_d, rho, level):
16091609

16101610
return res_dict
16111611

1612-
def _calc_robustness_value(self, theta, level, rho, idx_treatment):
1613-
_check_float(theta, "theta")
1612+
def _calc_robustness_value(self, null_hypothesis, level, rho, idx_treatment):
1613+
_check_float(null_hypothesis, "null_hypothesis")
16141614
_check_integer(idx_treatment, "idx_treatment", lower_bound=0, upper_bound=self._dml_data.n_treat-1)
16151615

16161616
# check which side is relvant
1617-
bound = 'upper' if (theta > self.coef[idx_treatment]) else 'lower'
1617+
bound = 'upper' if (null_hypothesis > self.coef[idx_treatment]) else 'lower'
16181618

16191619
# minimize the square to find boundary solutions
16201620
def rv_fct(value, param):
16211621
res = self._calc_sensitivity_analysis(cf_y=value,
16221622
cf_d=value,
16231623
rho=rho,
1624-
level=level)[param][bound][idx_treatment] - theta
1624+
level=level)[param][bound][idx_treatment] - null_hypothesis
16251625
return np.square(res)
16261626

16271627
rv = minimize_scalar(rv_fct, bounds=(0, 0.9999), method='bounded', args=('theta', )).x
16281628
rva = minimize_scalar(rv_fct, bounds=(0, 0.9999), method='bounded', args=('ci', )).x
16291629

16301630
return rv, rva
16311631

1632-
def sensitivity_analysis(self, cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95, theta=0.0):
1632+
def sensitivity_analysis(self, cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95, null_hypothesis=0.0):
16331633
"""
16341634
Performs a sensitivity analysis to account for unobserved confounders.
16351635
@@ -1655,8 +1655,9 @@ def sensitivity_analysis(self, cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95, theta=
16551655
The confidence level.
16561656
Default is ``0.95``.
16571657
1658-
theta : float
1658+
null_hypothesis : float or numpy.ndarray
16591659
Null hypothesis for the effect. Determines the robustness values.
1660+
If it is a single float uses the same null hypothesis for all estimated parameters. Else the array has to be of shape (n_coefs,).
16601661
Default is ``0.0``.
16611662
16621663
Returns
@@ -1666,11 +1667,24 @@ def sensitivity_analysis(self, cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95, theta=
16661667
# compute sensitivity analysis
16671668
sensitivity_dict = self._calc_sensitivity_analysis(cf_y=cf_y, cf_d=cf_d, rho=rho, level=level)
16681669

1669-
# compute robustess values with respect to theta
1670+
if isinstance(null_hypothesis, float):
1671+
null_hypothesis_vec = np.full(shape=self._dml_data.n_treat, fill_value=null_hypothesis)
1672+
elif isinstance(null_hypothesis, np.ndarray):
1673+
if null_hypothesis.shape == (self._dml_data.n_treat,):
1674+
null_hypothesis_vec = null_hypothesis
1675+
else:
1676+
raise ValueError(f"null_hypothesis is numpy.ndarray but does not have the required shape ({self._dml_data.n_treat},). "
1677+
f'Array of shape {str(null_hypothesis.shape)} was passed.')
1678+
else:
1679+
raise TypeError("null_hypothesis has to be of type float or np.ndarry. "
1680+
f"{str(null_hypothesis)} of type {str(type(null_hypothesis))} was passed.")
1681+
1682+
# compute robustess values with respect to null_hypothesis
16701683
rv = np.full(shape=self._dml_data.n_treat, fill_value=np.nan)
16711684
rva = np.full(shape=self._dml_data.n_treat, fill_value=np.nan)
1685+
16721686
for i_treat in range(self._dml_data.n_treat):
1673-
rv[i_treat], rva[i_treat] = self._calc_robustness_value(theta=theta, level=level, rho=rho, idx_treatment=i_treat)
1687+
rv[i_treat], rva[i_treat] = self._calc_robustness_value(null_hypothesis=null_hypothesis_vec[i_treat], level=level, rho=rho, idx_treatment=i_treat)
16741688

16751689
sensitivity_dict['rv'] = rv
16761690
sensitivity_dict['rva'] = rva
@@ -1680,7 +1694,7 @@ def sensitivity_analysis(self, cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95, theta=
16801694
'cf_d': cf_d,
16811695
'rho': rho,
16821696
'level': level,
1683-
'theta': theta}
1697+
'null_hypothesis': null_hypothesis_vec}
16841698
sensitivity_dict['input'] = input_params
16851699

16861700
self._sensitivity_params = sensitivity_dict
@@ -1700,20 +1714,11 @@ def sensitivity_summary(self):
17001714
if self.sensitivity_params is None:
17011715
res = header + 'Apply sensitivity_analysis() to generate sensitivity_summary.'
17021716
else:
1703-
hypothesis = f'Null Hypothesis: theta={self._sensitivity_params["input"]["theta"]}\n'
17041717
sig_level = f'Significance Level: level={self.sensitivity_params["input"]["level"]}\n'
17051718
scenario_params = f'Sensitivity parameters: cf_y={self.sensitivity_params["input"]["cf_y"]}; ' \
17061719
f'cf_d={self.sensitivity_params["input"]["cf_d"]}, ' \
17071720
f'rho={self.sensitivity_params["input"]["rho"]}'
17081721

1709-
rvs_col_names = ['RV (%)', 'RVa (%)']
1710-
rvs = np.transpose(np.vstack((self._sensitivity_params['rv'],
1711-
self._sensitivity_params['rva']))) * 100
1712-
df_rvs = pd.DataFrame(rvs,
1713-
columns=rvs_col_names,
1714-
index=self._dml_data.d_cols)
1715-
rvs_summary = str(df_rvs)
1716-
17171722
theta_and_ci_col_names = ['CI lower', 'theta lower', ' theta', 'theta upper', 'CI upper']
17181723
theta_and_ci = np.transpose(np.vstack((self._sensitivity_params['ci']['lower'],
17191724
self._sensitivity_params['theta']['lower'],
@@ -1725,17 +1730,26 @@ def sensitivity_summary(self):
17251730
index=self._dml_data.d_cols)
17261731
theta_and_ci_summary = str(df_theta_and_ci)
17271732

1733+
rvs_col_names = ['H_0', 'RV (%)', 'RVa (%)']
1734+
rvs = np.transpose(np.vstack((self._sensitivity_params['rv'],
1735+
self._sensitivity_params['rva']))) * 100
1736+
1737+
df_rvs = pd.DataFrame(np.column_stack((self.sensitivity_params["input"]["null_hypothesis"], rvs)),
1738+
columns=rvs_col_names,
1739+
index=self._dml_data.d_cols)
1740+
rvs_summary = str(df_rvs)
1741+
17281742
res = header + \
17291743
'\n------------------ Scenario ------------------\n' + \
1730-
hypothesis + sig_level + scenario_params + '\n' + \
1731-
'\n------------------ Robustness Values ------------------\n' + \
1732-
rvs_summary + '\n' + \
1744+
sig_level + scenario_params + '\n' + \
17331745
'\n------------------ Bounds with CI ------------------\n' + \
1734-
theta_and_ci_summary
1746+
theta_and_ci_summary + '\n' + \
1747+
'\n------------------ Robustness Values ------------------\n' + \
1748+
rvs_summary
17351749

17361750
return res
17371751

1738-
def sensitivity_plot(self, idx_treatment=0, theta=0.0, value='theta', include_scenario=True,
1752+
def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True,
17391753
fill=True, grid_bounds=(0.15, 0.15), grid_size=100):
17401754
"""
17411755
Contour plot of the sensivity with respect to latent/confounding variables.
@@ -1746,11 +1760,6 @@ def sensitivity_plot(self, idx_treatment=0, theta=0.0, value='theta', include_sc
17461760
Index of the treatment to perform the sensitivity analysis.
17471761
Default is ``0``.
17481762
1749-
theta : float
1750-
Null hypothesis for the effect. Determines whether the upper or lower bound of the estimates has to be considered.
1751-
If the null hypothesis is smaller than the treatment effect estimate the lower bounds are used and vice versa.
1752-
Default is ``0.0``.
1753-
17541763
value : str
17551764
Determines which contours to plot. Valid values are ``'theta'`` (refers to the bounds)
17561765
and ``'ci'`` (refers to the bounds including statistical uncertainty).
@@ -1781,7 +1790,6 @@ def sensitivity_plot(self, idx_treatment=0, theta=0.0, value='theta', include_sc
17811790
raise ValueError('Apply sensitivity_analysis() to include senario in sensitivity_plot. '
17821791
'The values of rho and the level are used for the scenario.')
17831792
_check_integer(idx_treatment, "idx_treatment", lower_bound=0, upper_bound=self._dml_data.n_treat-1)
1784-
_check_float(theta, "theta")
17851793
if not isinstance(value, str):
17861794
raise TypeError('value must be a string. '
17871795
f'{str(value)} of type {type(value)} was passed.')
@@ -1795,9 +1803,10 @@ def sensitivity_plot(self, idx_treatment=0, theta=0.0, value='theta', include_sc
17951803
_check_in_zero_one(grid_bounds[1], "grid_bounds", include_zero=False, include_one=False)
17961804
_check_integer(grid_size, "grid_size", lower_bound=10)
17971805

1806+
null_hypothesis = self.sensitivity_params['input']['null_hypothesis'][idx_treatment]
17981807
unadjusted_theta = self.coef[idx_treatment]
17991808
# check which side is relvant
1800-
bound = 'upper' if (theta > unadjusted_theta) else 'lower'
1809+
bound = 'upper' if (null_hypothesis > unadjusted_theta) else 'lower'
18011810

18021811
# create evaluation grid
18031812
cf_d_vec = np.linspace(0, grid_bounds[0], grid_size)

doubleml/tests/test_doubleml_exceptions.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,23 +1011,23 @@ def test_doubleml_sensitivity_inputs():
10111011
with pytest.raises(TypeError, match=msg):
10121012
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1, level=0.95)
10131013
with pytest.raises(TypeError, match=msg):
1014-
_ = dml_irm._calc_robustness_value(rho=1, theta=0.0, level=0.95, idx_treatment=0)
1014+
_ = dml_irm._calc_robustness_value(rho=1, null_hypothesis=0.0, level=0.95, idx_treatment=0)
10151015

10161016
msg = "rho must be of float type. 1 of type <class 'str'> was passed."
10171017
with pytest.raises(TypeError, match=msg):
10181018
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho="1")
10191019
with pytest.raises(TypeError, match=msg):
10201020
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho="1", level=0.95)
10211021
with pytest.raises(TypeError, match=msg):
1022-
_ = dml_irm._calc_robustness_value(rho="1", theta=0.0, level=0.95, idx_treatment=0)
1022+
_ = dml_irm._calc_robustness_value(rho="1", null_hypothesis=0.0, level=0.95, idx_treatment=0)
10231023

10241024
msg = r'The absolute value of rho must be in \[0,1\]. 1.1 was passed.'
10251025
with pytest.raises(ValueError, match=msg):
10261026
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.1)
10271027
with pytest.raises(ValueError, match=msg):
10281028
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.1, level=0.95)
10291029
with pytest.raises(ValueError, match=msg):
1030-
_ = dml_irm._calc_robustness_value(rho=1.1, theta=0.0, level=0.95, idx_treatment=0)
1030+
_ = dml_irm._calc_robustness_value(rho=1.1, null_hypothesis=0.0, level=0.95, idx_treatment=0)
10311031

10321032
# test level
10331033
msg = "The confidence level must be of float type. 1 of type <class 'int'> was passed."
@@ -1036,50 +1036,55 @@ def test_doubleml_sensitivity_inputs():
10361036
with pytest.raises(TypeError, match=msg):
10371037
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1)
10381038
with pytest.raises(TypeError, match=msg):
1039-
_ = dml_irm._calc_robustness_value(rho=1.0, level=1, theta=0.0, idx_treatment=0)
1039+
_ = dml_irm._calc_robustness_value(rho=1.0, level=1, null_hypothesis=0.0, idx_treatment=0)
10401040

10411041
msg = r'The confidence level must be in \(0,1\). 1.0 was passed.'
10421042
with pytest.raises(ValueError, match=msg):
10431043
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1.0)
10441044
with pytest.raises(ValueError, match=msg):
10451045
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1.0)
10461046
with pytest.raises(ValueError, match=msg):
1047-
_ = dml_irm._calc_robustness_value(rho=1.0, level=1.0, theta=0.0, idx_treatment=0)
1047+
_ = dml_irm._calc_robustness_value(rho=1.0, level=1.0, null_hypothesis=0.0, idx_treatment=0)
10481048

10491049
msg = r'The confidence level must be in \(0,1\). 0.0 was passed.'
10501050
with pytest.raises(ValueError, match=msg):
10511051
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=0.0)
10521052
with pytest.raises(ValueError, match=msg):
10531053
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=0.0)
10541054
with pytest.raises(ValueError, match=msg):
1055-
_ = dml_irm._calc_robustness_value(rho=1.0, level=0.0, theta=0.0, idx_treatment=0)
1055+
_ = dml_irm._calc_robustness_value(rho=1.0, level=0.0, null_hypothesis=0.0, idx_treatment=0)
10561056

1057-
# test theta
1058-
msg = "theta must be of float type. 1 of type <class 'int'> was passed."
1057+
# test null_hypothesis
1058+
msg = "null_hypothesis has to be of type float or np.ndarry. 1 of type <class 'int'> was passed."
10591059
with pytest.raises(TypeError, match=msg):
1060-
_ = dml_irm.sensitivity_analysis(theta=1)
1060+
_ = dml_irm.sensitivity_analysis(null_hypothesis=1)
1061+
msg = r"null_hypothesis is numpy.ndarray but does not have the required shape \(1,\). Array of shape \(2,\) was passed."
1062+
with pytest.raises(ValueError, match=msg):
1063+
_ = dml_irm.sensitivity_analysis(null_hypothesis=np.array([1, 2]))
1064+
msg = "null_hypothesis must be of float type. 1 of type <class 'int'> was passed."
10611065
with pytest.raises(TypeError, match=msg):
1062-
_ = dml_irm._calc_robustness_value(theta=1, level=0.95, rho=1.0, idx_treatment=0)
1066+
_ = dml_irm._calc_robustness_value(null_hypothesis=1, level=0.95, rho=1.0, idx_treatment=0)
1067+
msg = r"null_hypothesis must be of float type. \[1\] of type <class 'numpy.ndarray'> was passed."
10631068
with pytest.raises(TypeError, match=msg):
1064-
dml_irm.sensitivity_analysis()
1065-
_ = dml_irm.sensitivity_plot(theta=1)
1069+
_ = dml_irm._calc_robustness_value(null_hypothesis=np.array([1]), level=0.95, rho=1.0, idx_treatment=0)
10661070

10671071
# test idx_treatment
1072+
dml_irm.sensitivity_analysis()
10681073
msg = "idx_treatment must be an integer. 0.0 of type <class 'float'> was passed."
10691074
with pytest.raises(TypeError, match=msg):
1070-
_ = dml_irm._calc_robustness_value(idx_treatment=0.0, theta=0.0, level=0.95, rho=1.0)
1075+
_ = dml_irm._calc_robustness_value(idx_treatment=0.0, null_hypothesis=0.0, level=0.95, rho=1.0)
10711076
with pytest.raises(TypeError, match=msg):
10721077
_ = dml_irm.sensitivity_plot(idx_treatment=0.0)
10731078

10741079
msg = "idx_treatment must be larger or equal to 0. -1 was passed."
10751080
with pytest.raises(ValueError, match=msg):
1076-
_ = dml_irm._calc_robustness_value(idx_treatment=-1, theta=0.0, level=0.95, rho=1.0)
1081+
_ = dml_irm._calc_robustness_value(idx_treatment=-1, null_hypothesis=0.0, level=0.95, rho=1.0)
10771082
with pytest.raises(ValueError, match=msg):
10781083
_ = dml_irm.sensitivity_plot(idx_treatment=-1)
10791084

10801085
msg = "idx_treatment must be smaller or equal to 0. 1 was passed."
10811086
with pytest.raises(ValueError, match=msg):
1082-
_ = dml_irm._calc_robustness_value(idx_treatment=1, theta=0.0, level=0.95, rho=1.0)
1087+
_ = dml_irm._calc_robustness_value(idx_treatment=1, null_hypothesis=0.0, level=0.95, rho=1.0)
10831088
with pytest.raises(ValueError, match=msg):
10841089
_ = dml_irm.sensitivity_plot(idx_treatment=1)
10851090

doubleml/tests/test_doubleml_model_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_sensitivity_defaults():
184184
'cf_d': 0.03,
185185
'rho': 1.0,
186186
'level': 0.95,
187-
'theta': 0.0}
187+
'null_hypothesis': np.array([0.])}
188188

189189
dml_plr.sensitivity_analysis()
190190
assert dml_plr._sensitivity_params['input'] == input_dict

doubleml/tests/test_doubleml_return_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,12 +351,12 @@ def test_sensitivity():
351351
assert isinstance(plr_dml1.sensitivity_plot(), plotly.graph_objs._figure.Figure)
352352
assert isinstance(plr_dml1.sensitivity_plot(value='ci'), plotly.graph_objs._figure.Figure)
353353
assert isinstance(plr_dml1._calc_sensitivity_analysis(cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95), dict)
354-
assert isinstance(plr_dml1._calc_robustness_value(theta=0.0, level=0.95, rho=1.0, idx_treatment=0), tuple)
354+
assert isinstance(plr_dml1._calc_robustness_value(null_hypothesis=0.0, level=0.95, rho=1.0, idx_treatment=0), tuple)
355355

356356
assert isinstance(irm_dml1.sensitivity_summary, str)
357357
irm_dml1.sensitivity_analysis()
358358
assert isinstance(irm_dml1.sensitivity_summary, str)
359359
assert isinstance(irm_dml1.sensitivity_plot(), plotly.graph_objs._figure.Figure)
360360
assert isinstance(irm_dml1.sensitivity_plot(value='ci'), plotly.graph_objs._figure.Figure)
361361
assert isinstance(irm_dml1._calc_sensitivity_analysis(cf_y=0.03, cf_d=0.03, rho=1.0, level=0.95), dict)
362-
assert isinstance(irm_dml1._calc_robustness_value(theta=0.0, level=0.95, rho=1.0, idx_treatment=0), tuple)
362+
assert isinstance(irm_dml1._calc_robustness_value(null_hypothesis=0.0, level=0.95, rho=1.0, idx_treatment=0), tuple)

doubleml/tests/test_doubleml_sensitivity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def dml_sensitivity_multitreat_fixture(generate_data_bivariate, dml_procedure, n
6666
dml_procedure=dml_procedure)
6767

6868
dml_plr_obj.fit()
69-
dml_plr_obj.sensitivity_analysis(cf_y=cf_y, cf_d=cf_d, rho=rho, level=level, theta=0.0)
69+
dml_plr_obj.sensitivity_analysis(cf_y=cf_y, cf_d=cf_d, rho=rho, level=level, null_hypothesis=0.0)
7070
res_manual = doubleml_sensitivity_manual(sensitivity_elements=dml_plr_obj.sensitivity_elements,
7171
all_coefs=dml_plr_obj.all_coef,
7272
psi=dml_plr_obj.psi,

doubleml/tests/test_doubleml_sensitivity_cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def dml_plr_multiway_cluster_sensitivity_rho0(dml_procedure, score):
6969

7070
dml_plr_obj.fit()
7171
dml_plr_obj.sensitivity_analysis(cf_y=cf_y, cf_d=cf_d,
72-
rho=0.0, level=level, theta=0.0)
72+
rho=0.0, level=level, null_hypothesis=0.0)
7373

7474
res_dict = {'coef': dml_plr_obj.coef,
7575
'se': dml_plr_obj.se,
@@ -108,7 +108,7 @@ def dml_plr_multiway_cluster_sensitivity_rho0_se(dml_procedure):
108108

109109
dml_plr_obj.fit()
110110
dml_plr_obj.sensitivity_analysis(cf_y=cf_y, cf_d=cf_d,
111-
rho=0.0, level=level, theta=0.0)
111+
rho=0.0, level=level, null_hypothesis=0.0)
112112

113113
res_dict = {'coef': dml_plr_obj.coef,
114114
'se': dml_plr_obj.se,

0 commit comments

Comments
 (0)