Commit ef83471: Initial commit

Author: usr. 0 parents. 18 files changed, +41563 -0 lines.

README.md

# Provably Improving Expert Predictions with Conformal Prediction

## Install dependencies

Experiments were run on Python 3.7.3 with a GPU. To install the required libraries, set the torch version in `requirements.txt` according to the available device (CPU or GPU) and run:

`pip install -r requirements.txt`

## Running experiments

For synthetic experiments, run:

`python3 ./run_conf_synthetic.py --n_labels <n> --cal_split <split> --runs <runs>`

where:

* <*n*\> is the number of labels, i.e. $n$.
* <*split*\> is the calibration and estimation split, i.e. $\frac{m}{data\_to\_split}$.
* <*runs*\> is the number of times each experiment is run, with a different random split of the above size each time.

**Note:** the above runs experiments for $\mathbb{P}[\hat Y = Y | \mathcal{Y}]\in\{0.3,0.5,0.7,0.9\}$ and for classifier accuracies also in $\{0.3,0.5,0.7,0.9\}$.
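For instance, with 10 labels, half of the data used for calibration and estimation, and 5 random splits (the values here are illustrative; `--n_labels 10` and `--runs 5` are also the defaults in `config.py`):

`python3 ./run_conf_synthetic.py --n_labels 10 --cal_split 0.5 --runs 5`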
For real-data experiments, run:

`python3 ./run_conf_real.py --cal_split <split> --runs <runs>`

where <*split*\> and <*runs*\> are the same as above.

**Note:** the above runs experiments for all classifiers in the paper.

## Results

* All plots are produced in `plots.ipynb`.
* For the tables we used the functions `print_accuracy_synthetic()` and `print_accuracy_tables_real()` in `plot/plot.py`, for the synthetic and real-data results respectively.
* For results on the relative gain in success probability $\mathbb{P}[\hat{Y}= Y| \mathcal{C}_{\hat{\alpha}}(X)]$ with respect to $\mathbb{P}[\hat{Y}= Y| \mathcal{Y}]$, we used `get_mn()` for the synthetic experiments and `get_m_real()` for the real-data experiments, both in `plot/plot.py`.

config.py

```python
import numpy as np
import torch
import os
import argparse


class Config:
    def __init__(self) -> None:
        pass


parser = argparse.ArgumentParser()
parser.add_argument("--n_labels", type=int, default=10)
parser.add_argument("--cal_split", type=float)
parser.add_argument("--runs", type=int, default=5)
args = parser.parse_args()
conf = Config()

conf.ROOT_DIR = os.path.dirname(__file__)

if torch.cuda.is_available():
    conf.device = torch.cuda.current_device()
else:
    conf.device = 'cpu'

conf.seed = 12345678

conf.torch_rng = torch.Generator(device=conf.device)
conf.torch_rng.manual_seed(conf.seed)
conf.rng = np.random.default_rng(seed=conf.seed)
conf.data_size = 10000
# Parameter that controls the difficulty of the synthetic dataset,
# keyed by the number of labels and then by target accuracy.
conf.class_sep = {10: {0.3: 0.46, 0.5: 1.09, 0.7: 1.72, 0.9: 2.75},
                  50: {0.3: 1.31, 0.5: 2.16, 0.7: 3.19, 0.9: 5.27},
                  100: {0.3: 1.75, 0.5: 2.8, 0.7: 4.4, 0.9: 7.7}}

conf.accuracies = np.arange(3, 10, 2) / 10.
conf.is_oblivious = False

conf.n_labels = args.n_labels
conf.cal_split = args.cal_split

conf.test_split = 0.2  # synthetic test split
conf.n_runs_per_split = args.runs
conf.delta = 0.1
# Synthetic label distribution: a Dirichlet draw, nudged to sum to exactly 1.
distr = conf.rng.dirichlet(np.ones(conf.n_labels), size=1)
sum_distr = distr.sum()
if sum_distr < 1.:
    distr += (1 - sum_distr) / conf.n_labels
conf.class_probabilities = distr

conf.model_names = ['densenet-bc-L190-k40', 'preresnet-110', 'resnet-110']
```
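Importing `config` parses the command-line flags once and exposes a single shared `conf` object. A minimal sketch of how downstream code reads it (the attributes printed here are just examples):

```python
# Minimal sketch: every module that imports config sees the same
# parsed, seeded configuration object.
from config import conf

print(conf.n_labels)                   # from --n_labels (default 10)
print(conf.delta)                      # 0.1
print(conf.class_probabilities.shape)  # (1, n_labels): one Dirichlet draw
```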

conformal_prediction.py

```python
from config import conf
import numpy as np
import torch
import torch.nn.functional as F


class ConformalPrediction:
    """Base interface: calibration/estimation data, a classifier, and delta."""

    def __init__(self, X_cal, y_cal, X_est, y_est, model, delta) -> None:
        self.model = model
        self.X_cal = X_cal
        self.y_cal = y_cal
        self.X_est = X_est
        self.y_est = y_est
        self.calibration_size = len(y_cal)
        self.delta = delta

    def find_all_alpha_values(self):
        pass

    def prediction_sets(self):
        pass

    def find_a_star(self):
        pass


class StandardCPgpu(ConformalPrediction):
    # TODO: rename the class; the name is misleading. It works with or
    # without a GPU and implements the system with both standard and
    # modified conformal prediction.
    """Implementation of the functions of our system."""

    def __init__(self, X_cal, y_cal, X_est, y_est, model, delta) -> None:
        super().__init__(X_cal, y_cal, X_est, y_est, model, delta)

    def epsilon_fn(self, k_a, delta_n_alphas):
        """Estimation error."""
        delta_n_alphas_t = torch.tensor(delta_n_alphas)
        epsilon = torch.sqrt(torch.log(delta_n_alphas_t) / (2 * k_a))
        return epsilon
```
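`epsilon_fn` computes what appears to be a Hoeffding-style estimation error with a union bound over the candidate $\alpha$ values: with $k_\alpha$ estimation samples whose prediction set contains the true label, and `delta_n_alphas` equal to $|A|/\delta$ (or $(m^2 n^2)/\delta$ when searching over all $(\alpha_1, \alpha_2)$ pairs, where $m$ is the calibration size and $n$ the number of labels),

$$\epsilon(\alpha) = \sqrt{\frac{\log\left(|A|/\delta\right)}{2\,k_\alpha}}.$$

The file continues with the construction of the candidate $\alpha$ grid and the search for $\hat\alpha$: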
```python
    def find_all_alpha_values(self):
        """Returns all 0 < alpha < 1 values that can be considered."""
        # Conformal scores of the true labels in the calibration set.
        model_out = self.model.predict_prob(self.X_cal)
        one_hot = np.eye(conf.n_labels)[self.y_cal]
        true_label_logits = model_out * one_hot

        conf_scores = sorted(1 - true_label_logits[true_label_logits > 0])
        self.conf_scores_t = torch.tensor(conf_scores, device=conf.device)
        # Scores of all predicted labels for each sample in the calibration set.
        logits = self.model.predict_prob(self.X_cal)
        logits_scores = 1 - logits

        # All coverages that result in different sets for each sample in the
        # calibration set.
        one_minus_alphas = np.searchsorted(conf_scores, logits_scores, side='left') / self.calibration_size

        alphas = 1 - one_minus_alphas[one_minus_alphas < 1]
        alphas = alphas[(1 - alphas) > 1 / self.calibration_size]
        self.alphas = alphas
        self.n_alphas = conf.n_labels * self.calibration_size
        return alphas

    def find_a_star(self, w_matrix, a1_star_idx=None, all_a1_a2=False):
        # TODO: rename; a_star stands for \hat{alpha}.
        """Returns \hat{alpha}."""
        a_star_idx = -1
        curr_criterion = 0
        alphas1 = self.alphas.flatten()
        qhat_a1 = torch.zeros((1, 1), device=conf.device)
        # Alphas and quantiles for the shifted quantile method given alpha_1.
        if a1_star_idx is not None:
            quant_a1 = np.ceil((1 - alphas1[a1_star_idx]) * (self.calibration_size + 1)) / self.calibration_size
            qhat_a1 = torch.quantile(self.conf_scores_t, quant_a1)
            alphas = self.alphas[self.alphas > self.alphas[a1_star_idx]]
        else:
            alphas = self.alphas

        # All quantiles, one per candidate alpha value.
        quant_unique = (np.ceil((1 - alphas) * (self.calibration_size + 1)) / self.calibration_size).flatten()
        self.epsilon = np.zeros(quant_unique.shape)

        # Output scores for each sample in the estimation set.
        output_scores = 1 - self.model.predict_prob(self.X_est)

        # Move data to the GPU if available.
        quants_t = torch.tensor(quant_unique, device=conf.device)

        qhats_t = torch.quantile(self.conf_scores_t, quants_t, keepdim=True)
        qhats_t = qhats_t.unsqueeze(1)
        y_est_t = torch.tensor(self.y_est, device=conf.device, dtype=torch.int64)
        fill_value_t = torch.tensor(0, dtype=torch.double, device=conf.device)
        output_scores_t = torch.tensor(output_scores, device=conf.device)
        ws_t = torch.tensor(w_matrix[self.y_est], device=conf.device)

        for i, q in enumerate(qhats_t):
            qhats = q.expand(self.calibration_size, conf.n_labels)

            # sets[sample][label] is 1 for the labels in the prediction set of
            # each sample; sets for the shifted quantile method given alpha_1.
            if a1_star_idx is not None:
                qhats_a1 = qhat_a1.expand(self.calibration_size, conf.n_labels)
                sets_upper = torch.where(output_scores_t <= qhats_a1, 1, 0)
                sets_lower = torch.where(qhats <= output_scores_t, 1, 0)
                sets = sets_upper * sets_lower
            else:
                sets = torch.where(output_scores_t <= qhats, 1, 0)
            sets_exp_ws = sets * torch.exp(ws_t)

            # Denominators for all P[\hat Y = Y | C_alpha(X), Y \in C_alpha(X), Y = y].
            denominators = torch.sum(sets_exp_ws, axis=1)
            one_hot_yest = F.one_hot(y_est_t)
            # Mask for prediction sets that include the true label.
            mask = sets * one_hot_yest
            true_label_in_sets_idx = torch.sum(mask, axis=1)

            # Numerators for all P[\hat Y = Y | C_alpha(X), Y \in C_alpha(X), Y = y].
            numerators = torch.sum(sets_exp_ws * one_hot_yest, axis=1)

            # Apply the mask so that Y \in C_alpha(X) is satisfied.
            masked_prob = torch.where(true_label_in_sets_idx == 1, numerators / denominators, fill_value_t)
            # Number of sets for which Y \in C_alpha(X) is satisfied.
            k_a = true_label_in_sets_idx.sum()
            # Non-empty sets and alpha_star > 0.
            if k_a > 0:
                expected_correct_prob = masked_prob.sum() / k_a
                delta_n_alphas = (alphas.shape[0] / self.delta) if not all_a1_a2 else ((self.calibration_size ** 2) * (conf.n_labels ** 2)) / self.delta
                epsilon = self.epsilon_fn(k_a, delta_n_alphas)
                self.epsilon[i] = epsilon

                # `is None` rather than truthiness, so that index 0 is handled correctly.
                coverage = 1 - alphas[i] if a1_star_idx is None else (alphas[i] - alphas1[a1_star_idx] - (1 / (self.calibration_size + 1)))
                criterion = coverage * (expected_correct_prob - epsilon)
                if criterion > curr_criterion:
                    a_star_idx = i
                    curr_criterion = criterion
        if all_a1_a2:
            return a_star_idx, curr_criterion

        return a_star_idx
```
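The last two methods evaluate a fixed $\alpha$ (or an $\alpha_2$ given $\alpha_1$) on a held-out test set. Both simulate the expert by sampling a label from each prediction set with probability proportional to $e^{w}$ (the rows of the confusion-matrix-like tensor `cm`); a sample whose prediction set is empty is counted as an error: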
```python
    def error_given_test_set_per_a(self, X_test, y_test, w_matrix, alphas, a_star_idx=None, a2_star_idx=None):
        """Misprediction probability for each value of alpha, or of alpha_2 given alpha_1."""
        test_size = len(X_test)
        output_scores = 1 - self.model.predict_prob(X_test)

        # Alphas and quantiles for the shifted quantile method.
        if a_star_idx is not None:
            quant_a1 = np.ceil((1 - self.alphas[a_star_idx]) * (self.calibration_size + 1)) / self.calibration_size
            qhat_a1 = torch.quantile(self.conf_scores_t, quant_a1)
            alphas = np.array([alphas]) if a2_star_idx is not None else self.alphas[self.alphas > self.alphas[a_star_idx]]

        qhats_unique = np.ceil((1 - alphas) * (self.calibration_size + 1)) / self.calibration_size
        error_rate_per_a = torch.zeros((len(qhats_unique),), device=conf.device)

        # Move data to the GPU if available.
        qhats_t = torch.tensor(qhats_unique, device=conf.device).unsqueeze(1)
        y_test_t = torch.tensor(y_test, device=conf.device, dtype=torch.int64)
        output_scores_t = torch.tensor(output_scores, device=conf.device)
        ws_t = torch.tensor(w_matrix[y_test], device=conf.device)
        a_empty_sets = 0
        fill_value_t = torch.exp(ws_t) / (torch.exp(ws_t).sum(axis=1).unsqueeze(1).expand(-1, conf.n_labels))

        for i, q in enumerate(qhats_t):
            qhats = q.expand(test_size, conf.n_labels)
            # sets[sample][label] is 1 for the labels in the prediction set of
            # each sample; sets for the shifted quantile method given alpha_1.
            if a_star_idx is not None:
                qhats_a1 = qhat_a1.expand(test_size, conf.n_labels)
                sets_upper = torch.where(output_scores_t <= qhats_a1, 1, 0)
                sets_lower = torch.where(qhats <= output_scores_t, 1, 0)
                sets = sets_upper * sets_lower
            else:
                sets = torch.where(output_scores_t <= qhats, 1, 0)
            non_empty_sets = sets.sum(axis=1).count_nonzero()

            if non_empty_sets == 0:
                a_empty_sets += 1

            # Denominators for P[\hat Y = y | C_alpha(X), y \in C_alpha(X)].
            sets_exp_ws = sets * torch.exp(ws_t)
            denominators_col = torch.sum(sets_exp_ws, axis=1)
            denominators = denominators_col.unsqueeze(1).expand(-1, conf.n_labels)

            # Numerators for P[\hat Y = y | C_alpha(X), y \in C_alpha(X)].
            numerators = sets_exp_ws

            # Confusion matrix for each prediction set.
            cm = torch.where(denominators > 0, numerators / denominators, fill_value_t)

            # Human prediction sampled from the prediction sets.
            y_h = cm.multinomial(num_samples=1, replacement=True, generator=conf.torch_rng).squeeze()

            # Empty sets count as errors.
            y_hats = torch.where(denominators_col > 0, y_h, -1)
            errors = (y_hats != y_test_t).count_nonzero().double()
            error_rate_per_a[i] = errors / test_size

        return error_rate_per_a

    def size_given_test_set_per_a(self, X_test, y_test, w_matrix, alphas, a_star_idx=None, a2_star_idx=None):
        """Average set size for each value of alpha, or of alpha_2 given alpha_1."""
        test_size = len(X_test)
        output_scores = 1 - self.model.predict_prob(X_test)

        # Alphas and quantiles for the shifted quantile method.
        if a_star_idx is not None:
            quant_a1 = np.ceil((1 - self.alphas[a_star_idx]) * (self.calibration_size + 1)) / self.calibration_size
            qhat_a1 = torch.quantile(self.conf_scores_t, quant_a1)
            alphas = np.array([alphas]) if a2_star_idx is not None else self.alphas[self.alphas > self.alphas[a_star_idx]]

        qhats_unique = np.ceil((1 - alphas) * (self.calibration_size + 1)) / self.calibration_size
        set_size_per_a = torch.zeros((len(qhats_unique),), device=conf.device)

        # Move data to the GPU if available.
        qhats_t = torch.tensor(qhats_unique, device=conf.device).unsqueeze(1)
        output_scores_t = torch.tensor(output_scores, device=conf.device)
        ws_t = torch.tensor(w_matrix[y_test], device=conf.device)

        for i, q in enumerate(qhats_t):
            qhats = q.expand(test_size, conf.n_labels)
            # sets[sample][label] is 1 for the labels in the prediction set of
            # each sample; sets for the shifted quantile method.
            if a_star_idx is not None:
                qhats_a1 = qhat_a1.expand(test_size, conf.n_labels)
                sets_upper = torch.where(output_scores_t <= qhats_a1, 1, 0)
                sets_lower = torch.where(qhats <= output_scores_t, 1, 0)
                sets = sets_upper * sets_lower
            else:
                sets = torch.where(output_scores_t <= qhats, 1, 0)
            size_per_set = sets.sum(axis=1)
            set_size_per_a[i] = size_per_set.sum() / size_per_set.numel()

        return set_size_per_a
```
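Stripped of the expert weighting and the search over $\alpha$, the quantile step these methods share is the standard split conformal recipe. A minimal self-contained sketch with synthetic numbers (not the authors' pipeline; all names and values here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_cal, n_labels, alpha = 500, 10, 0.1

# Stand-in for model.predict_prob on the calibration set: rows sum to 1.
probs = rng.dirichlet(np.ones(n_labels), size=n_cal)
y_cal = rng.integers(0, n_labels, size=n_cal)

# Conformal score of the true label: 1 - predicted probability.
scores = 1 - probs[np.arange(n_cal), y_cal]

# Conservative empirical quantile, as in find_a_star:
# ceil((1 - alpha) * (n + 1)) / n, clipped to 1.
q_level = min(np.ceil((1 - alpha) * (n_cal + 1)) / n_cal, 1.0)
qhat = np.quantile(scores, q_level)

# Prediction set for a new sample: every label whose score is <= qhat.
p_new = rng.dirichlet(np.ones(n_labels))
pred_set = np.where(1 - p_new <= qhat)[0]
print(pred_set)  # contains the true label w.p. >= 1 - alpha over the draw
```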
