@@ -221,12 +221,26 @@ def _initialize_weights(self, weights):
221221 assert isinstance (weights , dict )
222222 self ._weights = weights
223223
224- def _get_weights (self ):
225- weights = self ._weights ['weights' ]
226- if 'weights_bar' not in self ._weights .keys ():
227- weights_bar = self ._weights ['weights' ]
224+ def _get_weights (self , m_hat = None ):
225+ # standard case for ATE
226+ if self .score == 'ATE' :
227+ weights = self ._weights ['weights' ]
228+ if 'weights_bar' not in self ._weights .keys ():
229+ weights_bar = self ._weights ['weights' ]
230+ else :
231+ weights_bar = self ._weights ['weights_bar' ][:, self ._i_rep ]
228232 else :
229- weights_bar = self ._weights ['weights_bar' ][:, self ._i_rep ]
233+ # special case for ATTE
234+ assert self .score == 'ATTE'
235+ assert m_hat is not None
236+ subgroup = self ._weights ['weights' ] * self ._dml_data .d
237+ subgroup_probability = np .mean (subgroup )
238+ weights = np .divide (subgroup , subgroup_probability )
239+
240+ weights_bar = np .divide (
241+ np .multiply (m_hat , self ._weights ['weights' ]),
242+ subgroup_probability )
243+
230244 return weights , weights_bar
231245
232246 def _check_data (self , obj_dml_data ):
@@ -280,8 +294,13 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
280294 f'predictions obtained with the ml_g learner { str (self ._learner ["ml_g" ])} are also '
281295 'observed to be binary with values 0 and 1. Make sure that for classifiers '
282296 'probabilities and not labels are predicted.' )
297+ if self .score == 'ATTE' :
298+ # skip g_hat1 estimation
299+ g_hat1 = {'preds' : None ,
300+ 'targets' : None ,
301+ 'models' : None }
283302
284- if g1_external :
303+ elif g1_external :
285304 # use external predictions
286305 g_hat1 = {'preds' : external_predictions ['ml_g1' ],
287306 'targets' : None ,
@@ -294,7 +313,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
294313 # adjust target values to consider only compatible subsamples
295314 g_hat1 ['targets' ] = _cond_targets (g_hat1 ['targets' ], cond_sample = (d == 1 ))
296315
297- if self ._dml_data .binary_outcome :
316+ if self ._dml_data .binary_outcome & ( self . score != 'ATTE' ) :
298317 binary_preds = (type_of_target (g_hat1 ['preds' ]) == 'binary' )
299318 zero_one_preds = np .all ((np .power (g_hat1 ['preds' ], 2 ) - g_hat1 ['preds' ]) == 0 )
300319 if binary_preds & zero_one_preds :
@@ -338,11 +357,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
338357
339358 def _score_elements (self , y , d , g_hat0 , g_hat1 , m_hat , smpls ):
340359
341- # fraction of treated for ATTE
342- p_hat = None
343- if self .score == 'ATTE' :
344- p_hat = np .mean (d )
345-
346360 m_hat_adj = np .full_like (m_hat , np .nan , dtype = 'float64' )
347361 if self .normalize_ipw :
348362 if self .dml_procedure == 'dml1' :
@@ -355,24 +369,21 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
355369
356370 # compute residuals
357371 u_hat0 = y - g_hat0
358- u_hat1 = None
359- if self .score == 'ATE' :
360- u_hat1 = y - g_hat1
361-
362- if isinstance (self .score , str ):
372+ if self .score == 'ATTE' :
373+ g_hat1 = y
374+ u_hat1 = y - g_hat1
375+
376+ if (self .score == 'ATE' ) or (self .score == 'ATTE' ):
377+ weights , weights_bar = self ._get_weights (m_hat = m_hat_adj )
378+ psi_b = weights * (g_hat1 - g_hat0 ) \
379+ + weights_bar * (
380+ np .divide (np .multiply (d , u_hat1 ), m_hat_adj )
381+ - np .divide (np .multiply (1.0 - d , u_hat0 ), 1.0 - m_hat_adj ))
363382 if self .score == 'ATE' :
364- weights , weights_bar = self ._get_weights ()
365- psi_b = weights * (g_hat1 - g_hat0 ) \
366- + weights_bar * (
367- np .divide (np .multiply (d , u_hat1 ), m_hat_adj )
368- - np .divide (np .multiply (1.0 - d , u_hat0 ), 1.0 - m_hat_adj ))
369383 psi_a = np .full_like (m_hat_adj , - 1.0 )
370384 else :
371385 assert self .score == 'ATTE'
372- psi_b = np .divide (np .multiply (d , u_hat0 ), p_hat ) \
373- - np .divide (np .multiply (m_hat_adj , np .multiply (1.0 - d , u_hat0 )),
374- np .multiply (p_hat , (1.0 - m_hat_adj )))
375- psi_a = - np .divide (d , p_hat )
386+ psi_a = - 1.0 * weights
376387 else :
377388 assert callable (self .score )
378389 psi_a , psi_b = self .score (y = y , d = d ,
@@ -388,15 +399,14 @@ def _sensitivity_element_est(self, preds):
388399
389400 m_hat = preds ['predictions' ]['ml_m' ]
390401 g_hat0 = preds ['predictions' ]['ml_g0' ]
391- g_hat1 = preds ['predictions' ]['ml_g1' ]
392-
393- # use weights make this extendable
394402 if self .score == 'ATE' :
395- weights , weights_bar = self . _get_weights ()
403+ g_hat1 = preds [ 'predictions' ][ 'ml_g1' ]
396404 else :
397405 assert self .score == 'ATTE'
398- weights = np .divide (d , np .mean (d ))
399- weights_bar = np .divide (m_hat , np .mean (d ))
406+ g_hat1 = y
407+
408+ # use weights make this extendable
409+ weights , weights_bar = self ._get_weights (m_hat = m_hat )
400410
401411 sigma2_score_element = np .square (y - np .multiply (d , g_hat1 ) - np .multiply (1.0 - d , g_hat0 ))
402412 sigma2 = np .mean (sigma2_score_element )
0 commit comments