@@ -218,12 +218,26 @@ def _initialize_weights(self, weights):
218218 assert isinstance (weights , dict )
219219 self ._weights = weights
220220
221- def _get_weights (self ):
222- weights = self ._weights ['weights' ]
223- if 'weights_bar' not in self ._weights .keys ():
224- weights_bar = self ._weights ['weights' ]
221+ def _get_weights (self , m_hat = None ):
222+ # standard case for ATE
223+ if self .score == 'ATE' :
224+ weights = self ._weights ['weights' ]
225+ if 'weights_bar' not in self ._weights .keys ():
226+ weights_bar = self ._weights ['weights' ]
227+ else :
228+ weights_bar = self ._weights ['weights_bar' ][:, self ._i_rep ]
225229 else :
226- weights_bar = self ._weights ['weights_bar' ][:, self ._i_rep ]
230+ # special case for ATTE
231+ assert self .score == 'ATTE'
232+ assert m_hat is not None
233+ subgroup = self ._weights ['weights' ] * self ._dml_data .d
234+ subgroup_probability = np .mean (subgroup )
235+ weights = np .divide (subgroup , subgroup_probability )
236+
237+ weights_bar = np .divide (
238+ np .multiply (m_hat , self ._weights ['weights' ]),
239+ subgroup_probability )
240+
227241 return weights , weights_bar
228242
229243 def _check_data (self , obj_dml_data ):
@@ -277,8 +291,13 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
277291 f'predictions obtained with the ml_g learner { str (self ._learner ["ml_g" ])} are also '
278292 'observed to be binary with values 0 and 1. Make sure that for classifiers '
279293 'probabilities and not labels are predicted.' )
294+ if self .score == 'ATTE' :
295+ # skip g_hat1 estimation
296+ g_hat1 = {'preds' : None ,
297+ 'targets' : None ,
298+ 'models' : None }
280299
281- if g1_external :
300+ elif g1_external :
282301 # use external predictions
283302 g_hat1 = {'preds' : external_predictions ['ml_g1' ],
284303 'targets' : None ,
@@ -291,7 +310,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
291310 # adjust target values to consider only compatible subsamples
292311 g_hat1 ['targets' ] = _cond_targets (g_hat1 ['targets' ], cond_sample = (d == 1 ))
293312
294- if self ._dml_data .binary_outcome :
313+ if self ._dml_data .binary_outcome & ( self . score != 'ATTE' ) :
295314 binary_preds = (type_of_target (g_hat1 ['preds' ]) == 'binary' )
296315 zero_one_preds = np .all ((np .power (g_hat1 ['preds' ], 2 ) - g_hat1 ['preds' ]) == 0 )
297316 if binary_preds & zero_one_preds :
@@ -334,11 +353,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
334353
335354 def _score_elements (self , y , d , g_hat0 , g_hat1 , m_hat , smpls ):
336355
337- # fraction of treated for ATTE
338- p_hat = None
339- if self .score == 'ATTE' :
340- p_hat = np .mean (d )
341-
342356 m_hat_adj = np .full_like (m_hat , np .nan , dtype = 'float64' )
343357 if self .normalize_ipw :
344358 if self .dml_procedure == 'dml1' :
@@ -351,24 +365,21 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
351365
352366 # compute residuals
353367 u_hat0 = y - g_hat0
354- u_hat1 = None
355- if self .score == 'ATE' :
356- u_hat1 = y - g_hat1
357-
358- if isinstance (self .score , str ):
368+ if self .score == 'ATTE' :
369+ g_hat1 = y
370+ u_hat1 = y - g_hat1
371+
372+ if (self .score == 'ATE' ) or (self .score == 'ATTE' ):
373+ weights , weights_bar = self ._get_weights (m_hat = m_hat_adj )
374+ psi_b = weights * (g_hat1 - g_hat0 ) \
375+ + weights_bar * (
376+ np .divide (np .multiply (d , u_hat1 ), m_hat_adj )
377+ - np .divide (np .multiply (1.0 - d , u_hat0 ), 1.0 - m_hat_adj ))
359378 if self .score == 'ATE' :
360- weights , weights_bar = self ._get_weights ()
361- psi_b = weights * (g_hat1 - g_hat0 ) \
362- + weights_bar * (
363- np .divide (np .multiply (d , u_hat1 ), m_hat_adj )
364- - np .divide (np .multiply (1.0 - d , u_hat0 ), 1.0 - m_hat_adj ))
365379 psi_a = np .full_like (m_hat_adj , - 1.0 )
366380 else :
367381 assert self .score == 'ATTE'
368- psi_b = np .divide (np .multiply (d , u_hat0 ), p_hat ) \
369- - np .divide (np .multiply (m_hat_adj , np .multiply (1.0 - d , u_hat0 )),
370- np .multiply (p_hat , (1.0 - m_hat_adj )))
371- psi_a = - np .divide (d , p_hat )
382+ psi_a = - 1.0 * weights
372383 else :
373384 assert callable (self .score )
374385 psi_a , psi_b = self .score (y = y , d = d ,
@@ -384,15 +395,14 @@ def _sensitivity_element_est(self, preds):
384395
385396 m_hat = preds ['predictions' ]['ml_m' ]
386397 g_hat0 = preds ['predictions' ]['ml_g0' ]
387- g_hat1 = preds ['predictions' ]['ml_g1' ]
388-
389- # use weights make this extendable
390398 if self .score == 'ATE' :
391- weights , weights_bar = self . _get_weights ()
399+ g_hat1 = preds [ 'predictions' ][ 'ml_g1' ]
392400 else :
393401 assert self .score == 'ATTE'
394- weights = np .divide (d , np .mean (d ))
395- weights_bar = np .divide (m_hat , np .mean (d ))
402+ g_hat1 = y
403+
404+ # use weights make this extendable
405+ weights , weights_bar = self ._get_weights (m_hat = m_hat )
396406
397407 sigma2_score_element = np .square (y - np .multiply (d , g_hat1 ) - np .multiply (1.0 - d , g_hat0 ))
398408 sigma2 = np .mean (sigma2_score_element )
0 commit comments