
Commit 9b94cef

Updates for NeurIPS submission (#16)

* Add bengali experiment
* Add hdf5 conversion
* Add image hdf5
* Add imagenet_hdf5
* Add macro recall
* Add NLP stuff
* Add conversion script
* Add IMDB
* Update CNN
* Add yelp
* Add DPCNN
* Add bengali evaluate script
* Add yelp_5
* Assorted fixes and updates

1 parent cdf8e69

25 files changed: +1344 −96 lines

.gitignore (+1)

@@ -137,3 +137,4 @@ saved_models/
 logs/
 notebooks/*
 !notebooks/*.ipynb
+.vector_cache

datasets/bengali.py (new file, +69 lines)

"""
Adapted from https://www.kaggle.com/corochann/bengali-seresnext-training-with-pytorch
"""
from torch.utils.data import Dataset
# import os
# from torchvision.datasets.folder import default_loader
import numpy as np
import pandas as pd
import gc


def prepare_image(root, indices=[0, 1, 2, 3]):
    # assert data_type in ['train', 'test']
    # if submission:
    #     image_df_list = [pd.read_parquet(datadir / f'{data_type}_image_data_{i}.parquet')
    #                      for i in indices]
    # else:
    image_df_list = [pd.read_feather(f'{root}/train_image_data_{i}.feather') for i in indices]

    HEIGHT = 137
    WIDTH = 236
    images = [df.iloc[:, 1:].values.reshape(-1, HEIGHT, WIDTH) for df in image_df_list]
    del image_df_list
    gc.collect()
    images = np.concatenate(images, axis=0)
    return images


class Bengali(Dataset):
    def __init__(self, root, targets, transform=None):
        self.transform = transform

        if isinstance(targets, list):
            self.labels = list(pd.read_csv(f'{root}/train.csv')[targets].itertuples(index=False, name=None))
        else:
            self.labels = pd.read_csv(f'{root}/train.csv')[targets]
        self.images = prepare_image(root)

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = (255 - image).astype(np.float32) / 255.

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self) -> int:
        return len(self.labels)


class BengaliGraphemeWhole(Bengali):
    def __init__(self, root, transform=None):
        super().__init__(root, ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'], transform=transform)


class BengaliGraphemeRoot(Bengali):
    def __init__(self, root, transform=None):
        super().__init__(root, 'grapheme_root', transform=transform)


class BengaliVowelDiacritic(Bengali):
    def __init__(self, root, transform=None):
        super().__init__(root, 'vowel_diacritic', transform=transform)


class BengaliConsonantDiacritic(Bengali):
    def __init__(self, root, transform=None):
        super().__init__(root, 'consonant_diacritic', transform=transform)
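For orientation, a minimal usage sketch follows (not part of the commit). The /ssd/bengali root is an assumption borrowed from the defaults added to datasets/datasets.py below; the loader expects the Kaggle train_image_data_{i}.feather files and train.csv under that directory.

# Hypothetical usage sketch, not part of this commit. Assumes the Kaggle
# feather files and train.csv live under /ssd/bengali.
from torch.utils.data import DataLoader

from datasets.bengali import BengaliGraphemeWhole

dataset = BengaliGraphemeWhole(root='/ssd/bengali')
loader = DataLoader(dataset, batch_size=64, shuffle=True)

images, labels = next(iter(loader))
# images: batch of inverted grapheme images scaled to [0, 1]
# labels: three target tensors (grapheme_root, vowel_diacritic, consonant_diacritic)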

datasets/datasets.py (+209 −3)

@@ -1,10 +1,15 @@
+import torch
+from torch.utils.data import Dataset
 from torchvision.datasets import CIFAR10, CIFAR100, FashionMNIST, ImageFolder, ImageNet
 from torchvision.transforms import Compose, Normalize, ToTensor, RandomCrop, RandomHorizontalFlip, RandomResizedCrop, \
     Resize, CenterCrop
 import os
 from datasets.tiny_imagenet import TinyImageNet
+from datasets.imagenet_hdf5 import ImageNetHDF5
 from utils import split, EqualSplitter, auto_augment, _fa_reduced_cifar10
 from datasets.toxic import toxic_ds
+from datasets.toxic_bert import toxic_bert
+from datasets.bengali import BengaliConsonantDiacritic, BengaliGraphemeRoot, BengaliVowelDiacritic, BengaliGraphemeWhole


 @auto_augment(_fa_reduced_cifar10)
@@ -50,6 +55,13 @@ def imagenet_transforms(args):
     return transform_train, transform_test


+def imagenet_a_transforms(args):
+    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    transform = Compose([Resize(256), CenterCrop(224), ToTensor(), normalize])
+
+    return transform
+
+
 def tinyimagenet_transforms(args):
     normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     base = [ToTensor(), normalize, ]
@@ -99,6 +111,32 @@ def modelnet_transforms(args):
     return transform, test_transform


+def bengali_transforms(args):
+    import numpy as np
+    from PIL import Image
+
+    def crop_char_image(image, threshold=5. / 255.):
+        assert image.ndim == 2
+        is_black = image > threshold
+
+        is_black_vertical = np.sum(is_black, axis=0) > 0
+        is_black_horizontal = np.sum(is_black, axis=1) > 0
+        left = np.argmax(is_black_horizontal)
+        right = np.argmax(is_black_horizontal[::-1])
+        top = np.argmax(is_black_vertical)
+        bottom = np.argmax(is_black_vertical[::-1])
+        height, width = image.shape
+        cropped_image = image[left:height - right, top:width - bottom]
+        return Image.fromarray(cropped_image)
+
+    return Compose([
+        crop_char_image,
+        Resize((64, 64)),
+        ToTensor(),
+        Normalize((0.0692,), (0.2051,))
+    ])
+
+
 dstransforms = {
     'cifar10': cifar_transforms,
     'cifar100': cifar_transforms,
@@ -107,8 +145,14 @@ def modelnet_transforms(args):
     'fashion': fashion_transforms,
     'tinyimagenet': tinyimagenet_transforms,
     'imagenet': imagenet_transforms,
+    'imagenet_hdf5': imagenet_transforms,
+    'imagenet_a': imagenet_a_transforms,
     'commands': commands_transforms,
     'modelnet': modelnet_transforms,
+    'bengali_r': bengali_transforms,
+    'bengali_c': bengali_transforms,
+    'bengali_v': bengali_transforms,
+    'bengali': bengali_transforms
 }
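A note on crop_char_image in bengali_transforms above: the names are misleading, but the logic is sound. is_black_horizontal sums over axis=1 and so flags rows, meaning left and right are really the top and bottom margins, while is_black_vertical flags columns. A small sanity check (a sketch, not part of the commit):

# Sanity-check sketch for the crop_char_image logic (not part of the commit).
import numpy as np

image = np.zeros((8, 10), dtype=np.float32)
image[2:5, 3:7] = 1.0  # ink occupies rows 2-4, columns 3-6

is_black = image > 5. / 255.
is_black_vertical = np.sum(is_black, axis=0) > 0    # per-column ink presence
is_black_horizontal = np.sum(is_black, axis=1) > 0  # per-row ink presence
left = np.argmax(is_black_horizontal)               # first ink row: 2
right = np.argmax(is_black_horizontal[::-1])        # blank rows below: 3
top = np.argmax(is_black_vertical)                  # first ink column: 3
bottom = np.argmax(is_black_vertical[::-1])         # blank columns right: 3

height, width = image.shape
cropped = image[left:height - right, top:width - bottom]
assert cropped.shape == (3, 4)  # rows 2-4 by columns 3-6 survive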

@@ -156,12 +200,34 @@ def imagenet(args):

     root = '/ssd/ILSVRC2012' if args.dataset_path is None else args.dataset_path

-    trainset = data(root=root, transform=transform_train)
-    testset = data(root=root, transform=transform_test)
+    trainset = data(root=f'{root}/train', transform=transform_train)
+    testset = data(root=f'{root}/val', transform=transform_test)

     return trainset, testset


+@split
+def imagenet_hdf5(args):
+    data = ImageNetHDF5
+    transform_train, transform_test = dstransforms[args.dataset](args)
+
+    root = '/ssd/ILSVRC2012' if args.dataset_path is None else args.dataset_path
+
+    trainset = data(root=f'{root}/train', transform=transform_train)
+    testset = data(root=f'{root}/val', transform=transform_test)
+
+    return trainset, testset
+
+
+def imagenet_a(args):
+    data = ImageFolder
+    transform = dstransforms[args.dataset](args)
+    root = args.dataset_path
+    testset = data(root=f'{root}/', transform=transform)
+
+    return None, None, testset
+
+
 @split
 def tinyimagenet(args):
     data = TinyImageNet
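datasets/imagenet_hdf5.py itself is outside this diff, so the ImageNetHDF5 interface can only be inferred from the call above: a root/transform constructor per split. Purely as a hypothetical sketch of such an HDF5-backed dataset, where the file layout and dataset names are assumptions, not the repo's actual implementation:

# Hypothetical sketch of an HDF5-backed ImageNet dataset; the real
# datasets/imagenet_hdf5.py is not shown in this diff. Assumes one data.h5
# per split containing 'images' and 'labels' datasets.
import h5py
import torch
from torch.utils.data import Dataset


class ImageNetHDF5Sketch(Dataset):
    def __init__(self, root, transform=None):
        self.path = f'{root}/data.h5'
        self.transform = transform
        self.file = None  # opened lazily so each DataLoader worker gets its own handle
        with h5py.File(self.path, 'r') as f:
            self.length = len(f['labels'])

    def __getitem__(self, index):
        if self.file is None:
            self.file = h5py.File(self.path, 'r')
        image = self.file['images'][index]
        label = int(self.file['labels'][index])
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.tensor(label)

    def __len__(self):
        return self.length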
@@ -199,17 +265,146 @@ def modelnet(args):
     return trainset, valset


+@split
+def bengali_r(args):
+    transform_train = dstransforms[args.dataset](args)
+
+    root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+    trainset = BengaliGraphemeRoot(root=root, transform=transform_train)
+    return trainset
+
+
+@split
+def bengali_c(args):
+    transform_train = dstransforms[args.dataset](args)
+
+    root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+    trainset = BengaliConsonantDiacritic(root=root, transform=transform_train)
+    return trainset
+
+
+@split
+def bengali_v(args):
+    transform_train = dstransforms[args.dataset](args)
+
+    root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+    trainset = BengaliVowelDiacritic(root=root, transform=transform_train)
+    return trainset
+
+
+@split
+def bengali(args):
+    transform_train = dstransforms[args.dataset](args)
+
+    root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+    trainset = BengaliGraphemeWhole(root=root, transform=transform_train)
+    return trainset
+
+
+def imdb(args):
+    from torchtext import data, datasets
+
+    TEXT = data.Field(tokenize='spacy', batch_first=True)
+    LABEL = data.LabelField(dtype=torch.float)
+
+    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL, root=args.dataset_path)
+
+    TEXT.build_vocab(train_data, vectors='fasttext.simple.300d')
+    LABEL.build_vocab(train_data)
+
+    train_iterator, test_iterator = data.BucketIterator.splits(
+        (train_data, test_data),
+        batch_size=args.batch_size,
+        sort_within_batch=True,
+        device=args.device)
+
+    train_iterator.vectors = TEXT.vocab.vectors.to(args.device)
+    train_iterator.ntokens = len(TEXT.vocab)
+    return train_iterator, None, test_iterator
+
+
+class ReverseOrder(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, i):
+        res = self.dataset[i]
+        return res[1], torch.tensor(res[0]).long()
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+def yelp_2(args):
+    from torchtext import datasets
+    from .toxic_bert import NoBatchBucketIterator
+
+    train_data, test_data = datasets.YelpReviewPolarity(root=args.dataset_path)
+    train_data, test_data = ReverseOrder(train_data), ReverseOrder(test_data)
+
+    train_iterator = NoBatchBucketIterator(dataset=train_data, batch_size=args.batch_size,
+                                           sort_key=lambda x: x[0].size(0),
+                                           device=torch.device(args.device), sort_within_batch=True)
+    test_iterator = NoBatchBucketIterator(dataset=test_data, batch_size=args.batch_size,
+                                          sort_key=lambda x: x[0].size(0),
+                                          device=torch.device(args.device), sort_within_batch=True)
+
+    vocab = train_data.dataset.get_vocab()
+    vocab.load_vectors('fasttext.simple.300d')
+
+    train_iterator.vectors = vocab.vectors.to(args.device)
+    train_iterator.ntokens = len(vocab)
+    return train_iterator, None, test_iterator
+
+
+def yelp_5(args):
+    from torchtext import datasets
+    from .toxic_bert import NoBatchBucketIterator
+
+    train_data, test_data = datasets.YelpReviewFull(root=args.dataset_path)
+    train_data, test_data = ReverseOrder(train_data), ReverseOrder(test_data)
+
+    train_iterator = NoBatchBucketIterator(dataset=train_data, batch_size=args.batch_size,
+                                           sort_key=lambda x: x[0].size(0),
+                                           device=torch.device(args.device), sort_within_batch=True)
+    test_iterator = NoBatchBucketIterator(dataset=test_data, batch_size=args.batch_size,
+                                          sort_key=lambda x: x[0].size(0),
+                                          device=torch.device(args.device), sort_within_batch=True)
+
+    vocab = train_data.dataset.get_vocab()
+    vocab.load_vectors('fasttext.simple.300d')
+
+    train_iterator.vectors = vocab.vectors.to(args.device)
+    train_iterator.ntokens = len(vocab)
+    return train_iterator, None, test_iterator
+
+
 ds = {
     'cifar10': cifar,
     'cifar100': cifar,
     'fashion': fashion,
     'fashion_old': fashion,
     'imagenet': imagenet,
+    'imagenet_hdf5': imagenet_hdf5,
+    'imagenet_a': imagenet_a,
     'commands': commands,
     'tinyimagenet': tinyimagenet,
     'reduced_cifar': reduced_cifar,
     'modelnet': modelnet,
     'toxic': toxic_ds,
+    'toxic_bert': toxic_bert,
+    'bengali_r': bengali_r,
+    'bengali_c': bengali_c,
+    'bengali_v': bengali_v,
+    'bengali': bengali,
+    'imdb': imdb,
+    'yelp_2': yelp_2,
+    'yelp_5': yelp_5
 }

 dsmeta = {
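About the ReverseOrder wrapper above: the torchtext 0.5-era text-classification datasets (YelpReviewPolarity, YelpReviewFull) yield (label, token_ids) pairs, while the training side wants (input, target), so the wrapper swaps the tuple and casts the label to a long tensor. A toy check (a sketch, not part of the commit):

# Toy check of ReverseOrder (sketch, not part of the commit). Any object
# with __getitem__/__len__ works as the wrapped dataset.
import torch

from datasets.datasets import ReverseOrder

dummy = [(1, torch.tensor([5, 9, 2])), (0, torch.tensor([7, 7]))]
wrapped = ReverseOrder(dummy)

tokens, label = wrapped[0]
assert torch.equal(tokens, torch.tensor([5, 9, 2]))
assert label.item() == 1 and label.dtype == torch.int64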
@@ -218,10 +413,21 @@ def modelnet(args):
     'fashion': {'classes': 10, 'nc': 1, 'size': (28, 28)},
     'fashion_old': {'classes': 10, 'nc': 1, 'size': (28, 28)},
     'imagenet': {'classes': 1000, 'nc': 3, 'size': (224, 224)},
+    'imagenet_hdf5': {'classes': 1000, 'nc': 3, 'size': (224, 224)},
+    'imagenet_a': {'classes': 1000, 'nc': 3, 'size': (224, 224)},
     'commands': {'classes': 12, 'nc': 1, 'size': (32, 32)},
     'tinyimagenet': {'classes': 200, 'nc': 3, 'size': (64, 64)},
     'reduced_cifar': {'classes': 10, 'nc': 3, 'size': (32, 32)},
     'modelnet': {'classes': 10, 'nc': None, 'size': None},
-    'toxic': {'classes': None, 'nc': None, 'size': (-1, 1)},
+    'toxic': {'classes': 6, 'nc': 300, 'size': [-1]},
+    'toxic_bert': {'classes': 6, 'nc': None, 'size': [-1]},
+    'bengali_r': {'classes': 168, 'nc': 1, 'size': (64, 64)},
+    'bengali_c': {'classes': 7, 'nc': 1, 'size': (64, 64)},
+    'bengali_v': {'classes': 11, 'nc': 1, 'size': (64, 64)},
+    'bengali': {'classes': (168, 11, 7), 'nc': 1, 'size': (64, 64)},
+    'imdb': {'classes': 1, 'nc': 300, 'size': [-1]},
+    'yelp_2': {'classes': 1, 'nc': 300, 'size': [-1]},
+    'yelp_5': {'classes': 5, 'nc': 300, 'size': [-1]},
 }

+nlp_data = ['toxic', 'toxic_bert', 'imdb', 'yelp_2', 'yelp_5']
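Taken together, ds, dstransforms, dsmeta, and nlp_data form a name-keyed registry: a caller looks up the loader and its metadata by args.dataset. A hedged sketch of that flow (the args fields shown mirror the attributes the loaders above read; the @split decorator may expect further fields, so treat this as an assumption):

# Hedged sketch of consuming the registry (not part of the commit).
from types import SimpleNamespace

from datasets.datasets import ds, dsmeta, nlp_data

args = SimpleNamespace(dataset='bengali', dataset_path=None,
                       batch_size=64, device='cpu')

meta = dsmeta[args.dataset]         # {'classes': (168, 11, 7), 'nc': 1, 'size': (64, 64)}
loaders = ds[args.dataset](args)    # train/val(/test) splits or iterators
is_text = args.dataset in nlp_data  # NLP entries return torchtext iterators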
