+ import torch
+ from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100, FashionMNIST, ImageFolder, ImageNet
from torchvision.transforms import Compose, Normalize, ToTensor, RandomCrop, RandomHorizontalFlip, RandomResizedCrop, \
    Resize, CenterCrop
import os
from datasets.tiny_imagenet import TinyImageNet
+ from datasets.imagenet_hdf5 import ImageNetHDF5
from utils import split, EqualSplitter, auto_augment, _fa_reduced_cifar10
from datasets.toxic import toxic_ds
+ from datasets.toxic_bert import toxic_bert
+ from datasets.bengali import BengaliConsonantDiacritic, BengaliGraphemeRoot, BengaliVowelDiacritic, BengaliGraphemeWhole


@auto_augment(_fa_reduced_cifar10)
@@ -50,6 +55,13 @@ def imagenet_transforms(args):
    return transform_train, transform_test


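+ # ImageNet-A is used for evaluation only, so a single transform is returned:
+ # the standard ImageNet eval pipeline (resize to 256, center crop to 224,
+ # ImageNet channel statistics).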
+ def imagenet_a_transforms(args):
+     normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     transform = Compose([Resize(256), CenterCrop(224), ToTensor(), normalize])
+
+     return transform
+
+
def tinyimagenet_transforms(args):
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    base = [ToTensor(), normalize, ]
@@ -99,6 +111,32 @@ def modelnet_transforms(args):
    return transform, test_transform


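+ # Bengali.AI grapheme images: crop each character to its ink bounding box,
+ # resize to 64x64, then normalize with the dataset mean/std.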
+ def bengali_transforms(args):
+     import numpy as np
+     from PIL import Image
+
+     def crop_char_image(image, threshold=5. / 255.):
+         assert image.ndim == 2
+         is_black = image > threshold
+
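+         # Rows/columns that contain any ink. Naming quirk: 'left'/'right'
+         # come from the per-row mask and index rows, 'top'/'bottom' from the
+         # per-column mask and index columns, matching image[rows, cols] below.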
+         is_black_vertical = np.sum(is_black, axis=0) > 0
+         is_black_horizontal = np.sum(is_black, axis=1) > 0
+         left = np.argmax(is_black_horizontal)
+         right = np.argmax(is_black_horizontal[::-1])
+         top = np.argmax(is_black_vertical)
+         bottom = np.argmax(is_black_vertical[::-1])
+         height, width = image.shape
+         cropped_image = image[left:height - right, top:width - bottom]
+         return Image.fromarray(cropped_image)
+
+     return Compose([
+         crop_char_image,
+         Resize((64, 64)),
+         ToTensor(),
+         Normalize((0.0692,), (0.2051,))
+     ])
+
+
dstransforms = {
    'cifar10': cifar_transforms,
    'cifar100': cifar_transforms,
@@ -107,8 +145,14 @@ def modelnet_transforms(args):
    'fashion': fashion_transforms,
    'tinyimagenet': tinyimagenet_transforms,
    'imagenet': imagenet_transforms,
+     'imagenet_hdf5': imagenet_transforms,
+     'imagenet_a': imagenet_a_transforms,
    'commands': commands_transforms,
    'modelnet': modelnet_transforms,
+     'bengali_r': bengali_transforms,
+     'bengali_c': bengali_transforms,
+     'bengali_v': bengali_transforms,
+     'bengali': bengali_transforms
}


@@ -156,12 +200,34 @@ def imagenet(args):
    root = '/ssd/ILSVRC2012' if args.dataset_path is None else args.dataset_path

-     trainset = data(root=root, transform=transform_train)
-     testset = data(root=root, transform=transform_test)
+     trainset = data(root=f'{root}/train', transform=transform_train)
+     testset = data(root=f'{root}/val', transform=transform_test)

    return trainset, testset


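+ # Variant of `imagenet` that reads from pre-packed HDF5 files instead of the
+ # raw JPEG directory tree.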
+ @split
+ def imagenet_hdf5(args):
+     data = ImageNetHDF5
+     transform_train, transform_test = dstransforms[args.dataset](args)
+
+     root = '/ssd/ILSVRC2012' if args.dataset_path is None else args.dataset_path
+
+     trainset = data(root=f'{root}/train', transform=transform_train)
+     testset = data(root=f'{root}/val', transform=transform_test)
+
+     return trainset, testset
+
+
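+ # ImageNet-A (natural adversarial examples) ships only a test split, hence
+ # the (None, None, testset) return.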
+ def imagenet_a(args):
+     data = ImageFolder
+     transform = dstransforms[args.dataset](args)
+     root = args.dataset_path
+     testset = data(root=f'{root}/', transform=transform)
+
+     return None, None, testset
+
+
@split
def tinyimagenet(args):
    data = TinyImageNet
@@ -199,17 +265,146 @@ def modelnet(args):
    return trainset, valset


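+ # Bengali grapheme classification: one loader per target component (grapheme
+ # root, consonant diacritic, vowel diacritic) plus a combined loader for all
+ # three targets.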
+ @split
+ def bengali_r(args):
+     transform_train = dstransforms[args.dataset](args)
+
+     root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+     trainset = BengaliGraphemeRoot(root=root, transform=transform_train)
+     return trainset
+
+
+ @split
+ def bengali_c(args):
+     transform_train = dstransforms[args.dataset](args)
+
+     root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+     trainset = BengaliConsonantDiacritic(root=root, transform=transform_train)
+     return trainset
+
+
+ @split
+ def bengali_v(args):
+     transform_train = dstransforms[args.dataset](args)
+
+     root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+     trainset = BengaliVowelDiacritic(root=root, transform=transform_train)
+     return trainset
+
+
+ @split
+ def bengali(args):
+     transform_train = dstransforms[args.dataset](args)
+
+     root = '/ssd/bengali' if args.dataset_path is None else args.dataset_path
+
+     trainset = BengaliGraphemeWhole(root=root, transform=transform_train)
+     return trainset
+
+
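+ # IMDB sentiment via torchtext: spaCy tokenization, fastText embeddings, and
+ # length-bucketed batches. The single float label matches 'classes': 1 in
+ # dsmeta below.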
+ def imdb(args):
+     from torchtext import data, datasets
+
+     TEXT = data.Field(tokenize='spacy', batch_first=True)
+     LABEL = data.LabelField(dtype=torch.float)
+
+     train_data, test_data = datasets.IMDB.splits(TEXT, LABEL, root=args.dataset_path)
+
+     TEXT.build_vocab(train_data, vectors='fasttext.simple.300d')
+     LABEL.build_vocab(train_data)
+
+     train_iterator, test_iterator = data.BucketIterator.splits(
+         (train_data, test_data),
+         batch_size=args.batch_size,
+         sort_within_batch=True,
+         device=args.device)
+
+     train_iterator.vectors = TEXT.vocab.vectors.to(args.device)
+     train_iterator.ntokens = len(TEXT.vocab)
+     return train_iterator, None, test_iterator
+
+
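+ # torchtext's raw text-classification datasets yield (label, text) pairs;
+ # this wrapper flips them to (text, label) and converts the label to long.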
+ class ReverseOrder(Dataset):
+     def __init__(self, dataset):
+         self.dataset = dataset
+
+     def __getitem__(self, i):
+         res = self.dataset[i]
+         return res[1], torch.tensor(res[0]).long()
+
+     def __len__(self):
+         return len(self.dataset)
+
+
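+ # Yelp review polarity (binary) and full (five-star) classification, batched
+ # by sequence length with the bucket iterator from the toxic_bert module.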
+ def yelp_2(args):
+     from torchtext import datasets
+     from .toxic_bert import NoBatchBucketIterator
+
+     train_data, test_data = datasets.YelpReviewPolarity(root=args.dataset_path)
+     train_data, test_data = ReverseOrder(train_data), ReverseOrder(test_data)
+
+     train_iterator = NoBatchBucketIterator(dataset=train_data, batch_size=args.batch_size,
+                                            sort_key=lambda x: x[0].size(0),
+                                            device=torch.device(args.device), sort_within_batch=True)
+     test_iterator = NoBatchBucketIterator(dataset=test_data, batch_size=args.batch_size,
+                                           sort_key=lambda x: x[0].size(0),
+                                           device=torch.device(args.device), sort_within_batch=True)
+
+     vocab = train_data.dataset.get_vocab()
+     vocab.load_vectors('fasttext.simple.300d')
+
+     train_iterator.vectors = vocab.vectors.to(args.device)
+     train_iterator.ntokens = len(vocab)
+     return train_iterator, None, test_iterator
+
+
+ def yelp_5(args):
+     from torchtext import datasets
+     from .toxic_bert import NoBatchBucketIterator
+
+     train_data, test_data = datasets.YelpReviewFull(root=args.dataset_path)
+     train_data, test_data = ReverseOrder(train_data), ReverseOrder(test_data)
+
+     train_iterator = NoBatchBucketIterator(dataset=train_data, batch_size=args.batch_size,
+                                            sort_key=lambda x: x[0].size(0),
+                                            device=torch.device(args.device), sort_within_batch=True)
+     test_iterator = NoBatchBucketIterator(dataset=test_data, batch_size=args.batch_size,
+                                           sort_key=lambda x: x[0].size(0),
+                                           device=torch.device(args.device), sort_within_batch=True)
+
+     vocab = train_data.dataset.get_vocab()
+     vocab.load_vectors('fasttext.simple.300d')
+
+     train_iterator.vectors = vocab.vectors.to(args.device)
+     train_iterator.ntokens = len(vocab)
+     return train_iterator, None, test_iterator
+
+
ds = {
    'cifar10': cifar,
    'cifar100': cifar,
    'fashion': fashion,
    'fashion_old': fashion,
    'imagenet': imagenet,
+     'imagenet_hdf5': imagenet_hdf5,
+     'imagenet_a': imagenet_a,
    'commands': commands,
    'tinyimagenet': tinyimagenet,
    'reduced_cifar': reduced_cifar,
    'modelnet': modelnet,
    'toxic': toxic_ds,
+     'toxic_bert': toxic_bert,
+     'bengali_r': bengali_r,
+     'bengali_c': bengali_c,
+     'bengali_v': bengali_v,
+     'bengali': bengali,
+     'imdb': imdb,
+     'yelp_2': yelp_2,
+     'yelp_5': yelp_5
}


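+ # Per-dataset metadata: output classes, input channels ('nc'), and spatial
+ # size. For the text datasets 'nc' is the 300-d fastText embedding size and
+ # a 'size' of [-1] marks variable-length input.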
dsmeta = {
@@ -218,10 +413,21 @@ def modelnet(args):
    'fashion': {'classes': 10, 'nc': 1, 'size': (28, 28)},
    'fashion_old': {'classes': 10, 'nc': 1, 'size': (28, 28)},
    'imagenet': {'classes': 1000, 'nc': 3, 'size': (224, 224)},
+     'imagenet_hdf5': {'classes': 1000, 'nc': 3, 'size': (224, 224)},
+     'imagenet_a': {'classes': 1000, 'nc': 3, 'size': (224, 224)},
    'commands': {'classes': 12, 'nc': 1, 'size': (32, 32)},
    'tinyimagenet': {'classes': 200, 'nc': 3, 'size': (64, 64)},
    'reduced_cifar': {'classes': 10, 'nc': 3, 'size': (32, 32)},
    'modelnet': {'classes': 10, 'nc': None, 'size': None},
-     'toxic': {'classes': None, 'nc': None, 'size': (-1, 1)},
+     'toxic': {'classes': 6, 'nc': 300, 'size': [-1]},
+     'toxic_bert': {'classes': 6, 'nc': None, 'size': [-1]},
+     'bengali_r': {'classes': 168, 'nc': 1, 'size': (64, 64)},
+     'bengali_c': {'classes': 7, 'nc': 1, 'size': (64, 64)},
+     'bengali_v': {'classes': 11, 'nc': 1, 'size': (64, 64)},
+     'bengali': {'classes': (168, 11, 7), 'nc': 1, 'size': (64, 64)},
+     'imdb': {'classes': 1, 'nc': 300, 'size': [-1]},
+     'yelp_2': {'classes': 1, 'nc': 300, 'size': [-1]},
+     'yelp_5': {'classes': 5, 'nc': 300, 'size': [-1]},
}

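+ # Text datasets whose loaders return iterators rather than map-style Datasets.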
+ nlp_data = ['toxic', 'toxic_bert', 'imdb', 'yelp_2', 'yelp_5']