from config import conf
import numpy as np
import torch
import torch.nn.functional as F

class ConformalPrediction:
    def __init__(self, X_cal, y_cal, X_est, y_est, model, delta) -> None:
        self.model = model
        self.X_cal = X_cal
        self.y_cal = y_cal
        self.X_est = X_est
        self.y_est = y_est
        self.calibration_size = len(y_cal)
        self.delta = delta

    def find_all_alpha_values(self):
        pass

    def prediction_sets(self):
        pass

    def find_a_star(self):
        pass

class StandardCPgpu(ConformalPrediction):
    """Implementation of the functions of our system."""
    # TODO: rename the class; the name is misleading, since it works with or without a GPU.
    # Implements the system with both the standard and the modified conformal prediction.
    def __init__(self, X_cal, y_cal, X_est, y_est, model, delta) -> None:
        super().__init__(X_cal, y_cal, X_est, y_est, model, delta)

    def epsilon_fn(self, k_a, delta_n_alphas):
        """Estimation error"""
        delta_n_alphas_t = torch.tensor(delta_n_alphas, device=conf.device)
        epsilon = torch.sqrt(torch.log(delta_n_alphas_t) / (2 * k_a))
        return epsilon

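    # Note on epsilon_fn: delta_n_alphas is expected to already be the ratio
    # n_alphas / delta (see how it is computed in find_a_star below), so the
    # returned value is epsilon(k_a) = sqrt(log(n_alphas / delta) / (2 * k_a)),
    # which has the form of a Hoeffding-style deviation bound combined with a
    # union bound over the candidate alpha values.
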
    def find_all_alpha_values(self):
        """Return all 0 < alpha < 1 values that can be considered."""
        # conformal scores of the true labels in the calibration set
        model_out = self.model.predict_prob(self.X_cal)
        one_hot = np.eye(conf.n_labels)[self.y_cal]
        true_label_logits = model_out * one_hot

        conf_scores = sorted(1 - true_label_logits[true_label_logits > 0])
        self.conf_scores_t = torch.tensor(conf_scores, device=conf.device)
        # scores of all predicted labels for each sample in the calibration set
        logits = self.model.predict_prob(self.X_cal)
        logits_scores = 1 - logits

        # all coverages that result in different sets for each sample in the calibration set
        one_minus_alphas = np.searchsorted(conf_scores, logits_scores, side='left') / self.calibration_size

        alphas = 1 - one_minus_alphas[one_minus_alphas < 1]
        alphas = alphas[(1 - alphas) > 1 / self.calibration_size]
        self.alphas = alphas
        self.n_alphas = conf.n_labels * self.calibration_size
        return alphas

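    # Note on find_all_alpha_values: for each calibration sample and label, the
    # score 1 - p(label) is ranked among the calibration scores of the true
    # labels; each rank / n is a coverage level 1 - alpha at which that label
    # crosses the empirical quantile threshold, so the returned alphas are the
    # candidate miscoverage levels that can produce different prediction sets.
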
    def find_a_star(self, w_matrix, a1_star_idx=None, all_a1_a2=False):
        # TODO: change the name; a_star stands for \hat{alpha}
        """Return the index of \hat{alpha} (and the criterion value if all_a1_a2)."""
        a_star_idx = -1
        curr_criterion = 0
        est_size = len(self.y_est)
        alphas1 = self.alphas.flatten()
        qhat_a1 = torch.zeros((1, 1), device=conf.device)
        # alphas and quantiles for the shifted quantile method given a_1
        if a1_star_idx is not None:
            quant_a1 = np.ceil((1 - alphas1[a1_star_idx]) * (self.calibration_size + 1)) / self.calibration_size
            qhat_a1 = torch.quantile(self.conf_scores_t, quant_a1)
            alphas = self.alphas[self.alphas > self.alphas[a1_star_idx]]
        else:
            alphas = self.alphas

        # quantiles for each alpha value
        quant_unique = (np.ceil((1 - alphas) * (self.calibration_size + 1)) / self.calibration_size).flatten()
        self.epsilon = np.zeros(quant_unique.shape)

        # output scores for each sample in the estimation set
        output_scores = 1 - self.model.predict_prob(self.X_est)

        # move data to gpu if available
        quants_t = torch.tensor(quant_unique, device=conf.device)

        qhats_t = torch.quantile(self.conf_scores_t, quants_t, keepdim=True)
        qhats_t = qhats_t.unsqueeze(1)
        y_est_t = torch.tensor(self.y_est, device=conf.device, dtype=torch.int64)
        fill_value_t = torch.tensor(0, dtype=torch.double, device=conf.device)
        output_scores_t = torch.tensor(output_scores, device=conf.device)
        ws_t = torch.tensor(w_matrix[self.y_est], device=conf.device)

        for i, q in enumerate(qhats_t):
            qhats = q.expand(est_size, conf.n_labels)

            # sets[sample][label] is 1 for the labels in the prediction set of each sample
            # sets for the shifted quantile method given a_1
            if a1_star_idx is not None:
                qhats_a1 = qhat_a1.expand(est_size, conf.n_labels)
                sets_upper = torch.where(output_scores_t <= qhats_a1, 1, 0)
                sets_lower = torch.where(qhats <= output_scores_t, 1, 0)
                sets = sets_upper * sets_lower
            else:
                sets = torch.where(output_scores_t <= qhats, 1, 0)
            sets_exp_ws = sets * torch.exp(ws_t)

            # denominators for all P[\hat Y = Y | C_alpha(X), Y \in C_alpha(X), Y=y]
            denominators = torch.sum(sets_exp_ws, axis=1)
            one_hot_yest = F.one_hot(y_est_t, num_classes=conf.n_labels)
            # mask for prediction sets that include the true label
            mask = sets * one_hot_yest
            true_label_in_sets_idx = torch.sum(mask, axis=1)

            # numerators for all P[\hat Y = Y | C_alpha(X), Y \in C_alpha(X), Y=y]
            numerators = torch.sum(sets_exp_ws * one_hot_yest, axis=1)

            # apply the mask so that Y \in C_alpha(X) is satisfied
            masked_prob = torch.where(true_label_in_sets_idx == 1, numerators / denominators, fill_value_t)
            # number of sets for which Y \in C_alpha(X) is satisfied
            k_a = true_label_in_sets_idx.sum()
            # non-empty sets and alpha_star > 0
            if k_a > 0:
                expected_correct_prob = masked_prob.sum() / k_a
                delta_n_alphas = (alphas.shape[0] / self.delta) if not all_a1_a2 else ((self.calibration_size ** 2) * (conf.n_labels ** 2)) / self.delta
                epsilon = self.epsilon_fn(k_a, delta_n_alphas)
                self.epsilon[i] = epsilon

                coverage = 1 - alphas[i] if a1_star_idx is None else (alphas[i] - alphas1[a1_star_idx] - (1 / (self.calibration_size + 1)))
                criterion = coverage * (expected_correct_prob - epsilon)
                if criterion > curr_criterion:
                    a_star_idx = i
                    curr_criterion = criterion
        if all_a1_a2:
            return a_star_idx, curr_criterion

        return a_star_idx

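    # Note on find_a_star: the selection rule implemented above scores every
    # candidate alpha with
    #     criterion(alpha) = coverage(alpha) * (expected_correct_prob - epsilon(k_a)),
    # where coverage is the nominal 1 - alpha (or the shifted coverage when
    # a1_star_idx is given) and expected_correct_prob averages, over estimation
    # samples whose true label lands in the set, the probability of picking the
    # true label under the exp(w) weighting; the index of the maximizing alpha
    # is returned.
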
    def error_given_test_set_per_a(self, X_test, y_test, w_matrix, alphas, a_star_idx=None, a2_star_idx=None):
        """Misprediction probability for each value of alpha, or of alpha_2 given alpha_1."""
        test_size = len(X_test)
        output_scores = 1 - self.model.predict_prob(X_test)

        # alphas and quantiles for the shifted quantile method
        if a_star_idx is not None:
            quant_a1 = np.ceil((1 - self.alphas[a_star_idx]) * (self.calibration_size + 1)) / self.calibration_size
            qhat_a1 = torch.quantile(self.conf_scores_t, quant_a1)
            alphas = np.array([alphas]) if a2_star_idx is not None else self.alphas[self.alphas > self.alphas[a_star_idx]]

        qhats_unique = np.ceil((1 - alphas) * (self.calibration_size + 1)) / self.calibration_size
        error_rate_per_a = torch.zeros((len(qhats_unique),), device=conf.device)

        # move data to gpu if available
        qhats_t = torch.tensor(qhats_unique, device=conf.device).unsqueeze(1)
        y_test_t = torch.tensor(y_test, device=conf.device, dtype=torch.int64)
        output_scores_t = torch.tensor(output_scores, device=conf.device)
        ws_t = torch.tensor(w_matrix[y_test], device=conf.device)
        a_empty_sets = 0
        fill_value_t = torch.exp(ws_t) / (torch.exp(ws_t).sum(axis=1).unsqueeze(1).expand(-1, conf.n_labels))

        for i, q in enumerate(qhats_t):

            qhats = q.expand(test_size, conf.n_labels)
            # sets[sample][label] is 1 for the labels in the prediction set of each sample
            # sets for the shifted quantile method given alpha_1
            if a_star_idx is not None:
                qhats_a1 = qhat_a1.expand(test_size, conf.n_labels)
                sets_upper = torch.where(output_scores_t <= qhats_a1, 1, 0)
                sets_lower = torch.where(qhats <= output_scores_t, 1, 0)
                sets = sets_upper * sets_lower
            else:
                sets = torch.where(output_scores_t <= qhats, 1, 0)
            non_empty_sets = sets.sum(axis=1).count_nonzero()

            if non_empty_sets == 0:
                a_empty_sets += 1

            # denominators for P[\hat Y = y | C_alpha(X), y \in C_alpha(X)]
            sets_exp_ws = sets * torch.exp(ws_t)
            denominators_col = torch.sum(sets_exp_ws, axis=1)
            denominators = denominators_col.unsqueeze(1).expand(-1, conf.n_labels)

            # numerators for P[\hat Y = y | C_alpha(X), y \in C_alpha(X)]
            numerators = sets_exp_ws

            # confusion matrix for each C
            cm = torch.where(denominators > 0, numerators / denominators, fill_value_t)

            # human prediction sampled from the prediction sets
            y_h = cm.multinomial(num_samples=1, replacement=True, generator=conf.torch_rng).squeeze()

            # empty sets count as errors (placeholder label -1 never matches)
            y_hats = torch.where(denominators_col > 0, y_h, -1)
            errors = (y_hats != y_test_t).count_nonzero().double()
            error_rate_per_a[i] = errors / test_size

        return error_rate_per_a

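    # Note on error_given_test_set_per_a: for every test sample the prediction
    # is sampled from the exp(w) weights restricted to the prediction set and
    # renormalized (falling back to the unrestricted softmax of exp(w) when the
    # set is empty), which appears to simulate the human response encoded by
    # w_matrix; samples with empty sets are then counted as errors through the
    # -1 placeholder label.
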
    def size_given_test_set_per_a(self, X_test, y_test, w_matrix, alphas, a_star_idx=None, a2_star_idx=None):
        """Average set size for each value of alpha, or of alpha_2 given alpha_1."""
        test_size = len(X_test)
        output_scores = 1 - self.model.predict_prob(X_test)

        # alphas and quantiles for the shifted quantile method
        if a_star_idx is not None:
            quant_a1 = np.ceil((1 - self.alphas[a_star_idx]) * (self.calibration_size + 1)) / self.calibration_size
            qhat_a1 = torch.quantile(self.conf_scores_t, quant_a1)
            alphas = np.array([alphas]) if a2_star_idx is not None else self.alphas[self.alphas > self.alphas[a_star_idx]]

        qhats_unique = np.ceil((1 - alphas) * (self.calibration_size + 1)) / self.calibration_size
        set_size_per_a = torch.zeros((len(qhats_unique),), device=conf.device)

        # move data to gpu if available
        qhats_t = torch.tensor(qhats_unique, device=conf.device).unsqueeze(1)
        output_scores_t = torch.tensor(output_scores, device=conf.device)
        ws_t = torch.tensor(w_matrix[y_test], device=conf.device)

        for i, q in enumerate(qhats_t):

            qhats = q.expand(test_size, conf.n_labels)
            # sets[sample][label] is 1 for the labels in the prediction set of each sample
            # sets for the shifted quantile method
            if a_star_idx is not None:
                qhats_a1 = qhat_a1.expand(test_size, conf.n_labels)
                sets_upper = torch.where(output_scores_t <= qhats_a1, 1, 0)
                sets_lower = torch.where(qhats <= output_scores_t, 1, 0)
                sets = sets_upper * sets_lower
            else:
                sets = torch.where(output_scores_t <= qhats, 1, 0)
            size_per_set = sets.sum(axis=1)
            set_size_per_a[i] = size_per_set.sum() / size_per_set.numel()

        return set_size_per_a

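if __name__ == "__main__":
    # Minimal usage sketch, for illustration only. It assumes that `conf`
    # provides `n_labels` and `device` (as used above) and that the model
    # exposes `predict_prob(X)` returning an (n_samples, n_labels) array of
    # probabilities. The mock model, the random data, and the all-zero
    # w_matrix below are placeholders, not part of the actual system.
    class _MockModel:
        def predict_prob(self, X):
            rng = np.random.default_rng(0)
            p = rng.random((len(X), conf.n_labels))
            return p / p.sum(axis=1, keepdims=True)

    rng = np.random.default_rng(1)
    X_cal = rng.random((200, 8))
    y_cal = rng.integers(0, conf.n_labels, size=200)
    X_est = rng.random((200, 8))
    y_est = rng.integers(0, conf.n_labels, size=200)

    cp = StandardCPgpu(X_cal, y_cal, X_est, y_est, _MockModel(), delta=0.1)
    cp.find_all_alpha_values()
    w_matrix = np.zeros((conf.n_labels, conf.n_labels))
    a_star_idx = cp.find_a_star(w_matrix)
    if a_star_idx >= 0:
        print("selected alpha:", cp.alphas.flatten()[a_star_idx])
    else:
        print("no alpha satisfied the selection criterion")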