
Commit bb475ea

added pytorch training
moved tensorflow to subdirectory
1 parent c505153 commit bb475ea

File tree

13 files changed: +378 −97 lines changed

rt_bene_model_training/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@ For pip users: `pip install tensorflow-gpu numpy tqdm opencv-python scikit-learn
 This code was used to train the blink estimator for RT-BENE. The labels for the RT-BENE blink dataset are contained in the [rt_bene_dataset](../rt_bene_dataset) directory. The images corresponding to the labels can be downloaded from the RT-GENE dataset (labels are only available for the "noglasses" part): [download](https://zenodo.org/record/2529036) [(alternative link)](https://goo.gl/tfUaDm). Please run `python train_blink_model.py --help` to see the required arguments to train the model.

 ## Model testing code
-Evaluation code for a 3-fold evaluation is provided in the [evaluate_blink_model.py](./evaluate_blink_model.py) file. An example to train and evaluate an ensemble of models can be found in [train_and_evaluate.py](./train_and_evaluate.py). Please run `python train_and_evaluate.py --help` to see the required arguments.
+Evaluation code for a 3-fold evaluation is provided in the [evaluate_blink_model.py](tensorflow/evaluate_blink_model.py) file. An example to train and evaluate an ensemble of models can be found in [train_and_evaluate.py](tensorflow/train_and_evaluate.py). Please run `python train_and_evaluate.py --help` to see the required arguments.

 ![Results](../assets/rt_bene_precision_recall.png)
```

rt_bene_model_training/__init__.py

Whitespace-only changes.

rt_bene_model_training/pytorch/__init__.py

Whitespace-only changes.
Lines changed: 72 additions & 0 deletions

```python
import numpy as np
from PIL import Image
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm


class RTBENEH5Dataset(data.Dataset):

    def __init__(self, h5_file, subject_list=None, transform=None, loader_desc="train"):
        self._h5_file = h5_file
        self._transform = transform
        self._subject_labels = []
        self._positive_labels = 0
        self._total_labels = 0

        assert subject_list is not None, "Must pass a list of subjects to load the data for"

        # default to the standard ImageNet preprocessing if no transform is given
        if self._transform is None:
            self._transform = transforms.Compose([transforms.Resize((224, 224), Image.BICUBIC),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                       std=[0.229, 0.224, 0.225])])

        _wanted_subjects = ["s{:03d}".format(_i) for _i in subject_list]

        for grp_s_n in tqdm(_wanted_subjects, desc="Loading ({}) subject metadata...".format(loader_desc)):  # subjects
            for grp_i_n, grp_i in h5_file[grp_s_n].items():  # images
                if "image" in grp_i.keys() and "label" in grp_i.keys():
                    image_dataset = grp_i["image"]
                    label = grp_i["label"][()][0]  # read the scalar from the 1-element label dataset
                    if label == 1.0:
                        self._positive_labels = self._positive_labels + 1
                    self._total_labels = self._total_labels + 1

                    # one entry per (possibly augmented) image stored under this group
                    for _i in range(len(image_dataset)):
                        self._subject_labels.append(["/" + grp_s_n + "/" + grp_i_n, _i])

    @staticmethod
    def get_class_weights(h5_file, subject_list):
        positive = 0
        total = 0
        _wanted_subjects = ["s{:03d}".format(_i) for _i in subject_list]

        for grp_s_n in tqdm(_wanted_subjects, desc="Loading class weights..."):
            for grp_i_n, grp_i in h5_file[grp_s_n].items():  # images
                if "image" in grp_i.keys() and "label" in grp_i.keys():
                    label = grp_i["label"][()][0]
                    if label == 1.0:
                        positive = positive + 1
                    total = total + 1

        # inverse-frequency weights to counter the class imbalance (blinks are rare)
        negative = total - positive
        weight_for_0 = (negative + positive) / negative
        weight_for_1 = (negative + positive) / positive
        return {0: weight_for_0, 1: weight_for_1}

    def __len__(self):
        return len(self._subject_labels)

    def __getitem__(self, index):
        _sample = self._subject_labels[index]
        assert type(_sample[0]) == str, "Sample not found at index {}".format(index)
        _img = self._h5_file[_sample[0] + "/image"][_sample[1]][()]
        label_data = self._h5_file[_sample[0] + "/label"][()].astype(np.float32)

        # Load data and get label
        _transformed_img = self._transform(Image.fromarray(_img, 'RGB'))

        return _transformed_img, label_data
```
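A minimal usage sketch for the dataset class and its class-weight helper, assuming `RTBENEH5Dataset` from above is in scope and `rtbene_dataset.hdf5` was produced by the generator script further down; the subject indices and DataLoader settings are illustrative, not part of this commit:

```python
import h5py
from torch.utils.data import DataLoader

# illustrative subject split; any subset of s000..s016 works
h5_file = h5py.File("rtbene_dataset.hdf5", mode="r")
train_subjects = [0, 1, 2, 8, 10]

dataset = RTBENEH5Dataset(h5_file, subject_list=train_subjects, loader_desc="train")
weights = RTBENEH5Dataset.get_class_weights(h5_file, train_subjects)
# inverse-frequency weighting: with 900 negatives and 100 positives this gives
# {0: 1000/900 ≈ 1.11, 1: 1000/100 = 10.0}, so the rare blink class is upweighted

# num_workers=0 because the open h5py handle cannot be shared across worker processes
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 224, 224]) torch.Size([64, 1])
```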
Lines changed: 75 additions & 0 deletions
Lines changed: 75 additions & 0 deletions

```python
from __future__ import print_function, division, absolute_import

import argparse
import os

import h5py
import numpy as np
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms
from tqdm import tqdm

script_path = os.path.dirname(os.path.realpath(__file__))

# Augmentations following `prepare_dataset.m`: randomly crop and resize the image 10 times,
# alongside two blurring stages, grayscaling and histogram normalisation
_required_size = (224, 224)
_transforms_list = [transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),  # equivalent to random 5px from each edge
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)),
                    transforms.Grayscale(num_output_channels=3),
                    lambda x: x.filter(ImageFilter.GaussianBlur(radius=1)),
                    lambda x: x.filter(ImageFilter.GaussianBlur(radius=3)),
                    lambda x: ImageOps.equalize(x)]  # histogram equalisation


def load_and_augment(file_path, augment=False):
    image = Image.open(file_path).resize(_required_size)
    # apply every augmentation transform when requested, then always keep the original
    augmented_images = [np.array(trans(image)) for trans in _transforms_list if augment is True]
    augmented_images.append(np.array(image))

    return np.array(augmented_images, dtype=np.uint8)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Estimate gaze from images')
    parser.add_argument('--rt_bene_root', type=str, required=True, nargs='?', help='Path to the base directory of RT_GENE')
    # note: argparse's type=bool treats any non-empty string as True, so only `--augment_dataset ''` yields False
    parser.add_argument('--augment_dataset', type=bool, required=False, default=False, help="Whether to augment the dataset with predefined transforms")
    parser.add_argument('--compress', action='store_true', dest="compress")
    parser.add_argument('--no-compress', action='store_false', dest="compress")
    parser.set_defaults(compress=False)
    args = parser.parse_args()

    _compression = "lzf" if args.compress is True else None

    subject_path = [os.path.join(args.rt_bene_root, "s{:03d}_noglasses/".format(_i)) for _i in range(0, 17)]

    hdf_file = h5py.File(os.path.abspath(os.path.join(args.rt_bene_root, "rtbene_dataset.hdf5")), mode='w')
    for subject_id, subject_data in enumerate(subject_path):
        subject_id = str("s{:03d}".format(subject_id))
        subject_grp = hdf_file.create_group(subject_id)
        with open(os.path.join(args.rt_bene_root, "{}_blink_labels.csv".format(subject_id)), "r") as f:
            _lines = f.readlines()

        for line in tqdm(_lines, desc="Subject {}".format(subject_id)):
            split = line.split(",")
            image_name = split[0]
            image_grp = subject_grp.create_group(image_name)
            image_path = os.path.join(subject_data, "natural/left/", "{}".format(split[0]))
            if os.path.exists(image_path):
                label = float(split[1].strip("\n"))
                if label != 0.5:  # paper removed 0.5s
                    image_data = load_and_augment(image_path, augment=args.augment_dataset)
                    image_grp.create_dataset("image", data=image_data, compression=_compression)
                    image_grp.create_dataset("label", data=[label])

    hdf_file.flush()
    hdf_file.close()
```
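As a quick check of `load_and_augment` (the image path below is hypothetical): with `augment=True`, the 14 transforms in `_transforms_list` plus the always-kept original give a stack of 15 images per labelled frame.

```python
# hypothetical path inside an RT-BENE subject directory
path = "s000_noglasses/natural/left/left_000001_rgb.png"

stack = load_and_augment(path, augment=True)
print(stack.shape, stack.dtype)  # (15, 224, 224, 3) uint8: 10 crops + grayscale + 2 blurs + equalised + original

stack = load_and_augment(path, augment=False)
print(stack.shape)               # (1, 224, 224, 3): the resized original only
```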

rt_bene_model_training/pytorch/util/__init__.py

Whitespace-only changes.

rt_bene_model_training/tensorflow/__init__.py

Whitespace-only changes.
Lines changed: 96 additions & 96 deletions

The 96 deleted and 96 added lines are identical: per the commit message, this evaluation script was moved into the `tensorflow/` subdirectory unchanged. Its content, shown once:

```python
#!/usr/bin/env python

import gc

import tensorflow as tf
from tensorflow.keras.models import load_model

from sklearn.metrics import confusion_matrix, roc_curve, auc, average_precision_score

import numpy as np

tf.compat.v1.disable_eager_execution()

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))


# each fold is evaluated on the subject group that was held out during training
fold_infos = {
    'fold1': [2],
    'fold2': [1],
    'fold3': [0],
    'all': [2, 1, 0]
}

model_metrics = [tf.keras.metrics.BinaryAccuracy()]


def estimate_metrics(testing_fold, model_instance):
    threshold = 0.5
    p = model_instance.predict(x=testing_fold['x'], verbose=0)
    p = p >= threshold  # binarise the predicted blink probabilities
    matrix = confusion_matrix(testing_fold['y'], p)
    ap = average_precision_score(testing_fold['y'], p)
    fpr, tpr, thresholds = roc_curve(testing_fold['y'], p)
    roc = auc(fpr, tpr)
    return matrix, ap, roc


def get_metrics_from_matrix(matrix):
    tp, tn, fp, fn = matrix[1, 1], matrix[0, 0], matrix[0, 1], matrix[1, 0]
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1score = 2. * (precision * recall) / (precision + recall)
    return precision, recall, f1score


def threefold_evaluation(dataset, model_paths_fold1, model_paths_fold2, model_paths_fold3, input_size):
    folds = ['fold1', 'fold2', 'fold3']
    aps = []
    rocs = []
    recalls = []
    precisions = []
    f1scores = []
    models = []

    for fold_to_eval_on, model_paths in zip(folds, [model_paths_fold1, model_paths_fold2, model_paths_fold3]):
        if len(model_paths_fold1) > 1:
            # ensemble: average the outputs of the individual models
            models = [load_model(model_path, compile=False) for model_path in model_paths]
            img_input_l = tf.keras.Input(shape=input_size, name='img_input_L')
            img_input_r = tf.keras.Input(shape=input_size, name='img_input_R')
            tensors = [model([img_input_r, img_input_l]) for model in models]
            output_layer = tf.keras.layers.average(tensors)
            model_instance = tf.keras.Model(inputs=[img_input_r, img_input_l], outputs=output_layer)
        else:
            model_instance = load_model(model_paths[0])
            model_instance.compile()

        testing_fold = dataset.get_training_data(fold_infos[fold_to_eval_on])  # get the testing fold subjects

        matrix, ap, roc = estimate_metrics(testing_fold, model_instance)
        aps.append(ap)
        rocs.append(roc)
        precision, recall, f1score = get_metrics_from_matrix(matrix)
        recalls.append(recall)
        precisions.append(precision)
        f1scores.append(f1score)

        del model_instance, testing_fold
        # noinspection PyUnusedLocal
        for model in models:
            del model
        gc.collect()

    evaluation = {'AP': {}, 'ROC': {}, 'precision': {}, 'recall': {}, 'f1score': {}}
    evaluation['AP']['avg'] = np.mean(np.array(aps))
    evaluation['AP']['std'] = np.std(np.array(aps))
    evaluation['ROC']['avg'] = np.mean(np.array(rocs))
    evaluation['ROC']['std'] = np.std(np.array(rocs))
    evaluation['precision']['avg'] = np.mean(np.array(precisions))
    evaluation['precision']['std'] = np.std(np.array(precisions))
    evaluation['recall']['avg'] = np.mean(np.array(recalls))
    evaluation['recall']['std'] = np.std(np.array(recalls))
    evaluation['f1score']['avg'] = np.mean(np.array(f1scores))
    evaluation['f1score']['std'] = np.std(np.array(f1scores))
    return evaluation
```
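A hedged usage sketch of `threefold_evaluation`. The model paths, input size, and the stand-in dataset object below are hypothetical; the only real requirement, visible in the code above, is that `dataset.get_training_data(subject_ids)` returns a dict with `'x'` (model inputs) and `'y'` (binary labels):

```python
import numpy as np

class RandomBlinkData:
    """Stand-in dataset for illustration only: random eye patches and labels."""

    def get_training_data(self, subject_ids):
        n = 32
        right = np.random.rand(n, 96, 96, 3).astype(np.float32)
        left = np.random.rand(n, 96, 96, 3).astype(np.float32)
        return {'x': [right, left], 'y': np.random.randint(0, 2, size=n)}

# single model per fold, so the ensemble-averaging branch is skipped;
# the .h5 paths are placeholders for trained blink models
evaluation = threefold_evaluation(RandomBlinkData(),
                                  model_paths_fold1=["fold1_model.h5"],
                                  model_paths_fold2=["fold2_model.h5"],
                                  model_paths_fold3=["fold3_model.h5"],
                                  input_size=(96, 96, 3))
print("AP: {:.3f} ± {:.3f}".format(evaluation['AP']['avg'], evaluation['AP']['std']))
```

Note that the ensemble check `len(model_paths_fold1) > 1` looks at fold 1's list on every iteration, so it assumes all three folds are given the same number of models.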
