Skip to content

Commit e8cd022

Browse files
committed
add eszsl
1 parent 9685207 commit e8cd022

File tree

2 files changed

+142
-1
lines changed

2 files changed

+142
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
__pycache__/
1+
__pycache__/
2+
.ipynb_checkpoints/

ESZSL/eszsl.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import numpy as np
2+
import argparse
3+
from scipy import io
4+
from sklearn.metrics import confusion_matrix
5+
6+
parser = argparse.ArgumentParser(description="ESZSL")
7+
8+
parser.add_argument('-data', '--dataset', help='choose between APY, AWA2, CUB, SUN', default='AWA2', type=str)
9+
parser.add_argument('-mode', '--mode', help='train/test, if test set alpha, gamma to best values below', default='train', type=str)
10+
parser.add_argument('-alpha', '--alpha', default=0, type=int)
11+
parser.add_argument('-gamma', '--gamma', default=0, type=int)
12+
13+
"""
14+
15+
Best Values of (Alpha, Gamma) found by validation & corr. test accuracies:
16+
17+
AWA2 -> (3, 0) -> Test Acc : 0.5482
18+
CUB -> (3, -1) -> Test Acc : 0.5394
19+
SUN -> (3, 2) -> Test Acc : 0.5569
20+
APY -> (3, -1) -> Test Acc : 0.3856
21+
22+
"""
23+
24+
class ESZSL():
25+
26+
def __init__(self, args):
27+
28+
self.args = args
29+
30+
data_folder = '../datasets/'+args.dataset+'/'
31+
res101 = io.loadmat(data_folder+'res101.mat')
32+
att_splits=io.loadmat(data_folder+'att_splits.mat')
33+
34+
train_loc = 'train_loc'
35+
val_loc = 'val_loc'
36+
test_loc = 'test_unseen_loc'
37+
38+
feat = res101['features']
39+
self.X_train = feat[:,np.squeeze(att_splits[train_loc]-1)]
40+
self.X_val = feat[:,np.squeeze(att_splits[val_loc]-1)]
41+
self.X_trainval = np.concatenate((self.X_train, self.X_val), axis=1)
42+
self.X_test = feat[:,np.squeeze(att_splits[test_loc]-1)]
43+
44+
labels = res101['labels']
45+
labels_train = labels[np.squeeze(att_splits[train_loc]-1)]
46+
self.labels_val = labels[np.squeeze(att_splits[val_loc]-1)]
47+
labels_trainval = np.concatenate((labels_train, self.labels_val), axis=0)
48+
self.labels_test = labels[np.squeeze(att_splits[test_loc]-1)]
49+
50+
train_labels_seen = np.unique(labels_train)
51+
val_labels_unseen = np.unique(self.labels_val)
52+
trainval_labels_seen = np.unique(labels_trainval)
53+
test_labels_unseen = np.unique(self.labels_test)
54+
55+
i=0
56+
for labels in train_labels_seen:
57+
labels_train[labels_train == labels] = i
58+
i+=1
59+
60+
j=0
61+
for labels in val_labels_unseen:
62+
self.labels_val[self.labels_val == labels] = j
63+
j+=1
64+
65+
k=0
66+
for labels in trainval_labels_seen:
67+
labels_trainval[labels_trainval == labels] = k
68+
k+=1
69+
70+
l=0
71+
for labels in test_labels_unseen:
72+
self.labels_test[self.labels_test == labels] = l
73+
l+=1
74+
75+
self.gt_train = np.zeros((labels_train.shape[0], len(train_labels_seen)))
76+
self.gt_train[np.arange(labels_train.shape[0]), np.squeeze(labels_train)] = 1
77+
78+
self.gt_trainval = np.zeros((labels_trainval.shape[0], len(trainval_labels_seen)))
79+
self.gt_trainval[np.arange(labels_trainval.shape[0]), np.squeeze(labels_trainval)] = 1
80+
81+
sig = att_splits['att']
82+
self.train_sig = sig[:, train_labels_seen-1]
83+
self.val_sig = sig[:, val_labels_unseen-1]
84+
self.trainval_sig = sig[:, trainval_labels_seen-1]
85+
self.test_sig = sig[:, test_labels_unseen-1]
86+
87+
def find_W(self, X, y, sig, alpha, gamma):
88+
89+
part_0 = np.linalg.pinv(np.matmul(X, X.transpose()) + (10**alpha)*np.eye(X.shape[0]))
90+
part_1 = np.matmul(np.matmul(X, y), sig.transpose())
91+
part_2 = np.linalg.pinv(np.matmul(sig, sig.transpose()) + (10**gamma)*np.eye(sig.shape[0]))
92+
93+
W = np.matmul(np.matmul(part_0, part_1), part_2)
94+
95+
return W
96+
97+
def find_hyperparamters(self):
98+
99+
print('Training...\n')
100+
101+
best_acc = 0.0
102+
103+
for alph in range(-3, 4):
104+
for gamm in range(-3, 4):
105+
W = self.find_W(self.X_train, self.gt_train, self.train_sig, alph, gamm)
106+
acc = self.zsl_acc(self.X_val, W, self.labels_val, self.val_sig)
107+
print('Val Acc:{}; Alpha:{}; Gamma:{}\n'.format(acc, alph, gamm))
108+
if acc>best_acc:
109+
best_acc = acc
110+
alpha = alph
111+
gamma = gamm
112+
113+
print('\nBest Val Acc:{} with Alpha:{} & Gamma:{}\n'.format(best_acc, alpha, gamma))
114+
115+
return alpha, gamma
116+
117+
def zsl_acc(self, X, W, y_true, sig): # Class Averaged Top-1 Accuarcy
118+
119+
class_scores = np.matmul(np.matmul(X.transpose(), W), sig)
120+
predicted_classes = np.array([np.argmax(output) for output in class_scores])
121+
cm = confusion_matrix(y_true, predicted_classes)
122+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
123+
acc = sum(cm.diagonal())/sig.shape[1]
124+
125+
return acc
126+
127+
def evaluate(self):
128+
129+
if self.args.mode=='train': alpha, gamma = self.find_hyperparamters()
130+
else: alpha, gamma = self.args.alpha, self.args.gamma
131+
132+
best_W = self.find_W(self.X_trainval, self.gt_trainval, self.trainval_sig, alpha, gamma) # combine train and val
133+
134+
test_acc = self.zsl_acc(self.X_test, best_W, self.labels_test, self.test_sig)
135+
136+
print('Test Acc:{}'.format(test_acc))
137+
138+
args = parser.parse_args()
139+
model = ESZSL(args)
140+
model.evaluate()

0 commit comments

Comments
 (0)