
Commit 46870a7

rename dataset_root to data_root
1 parent: 0ce9a43

2 files changed: +9 -9 lines changed

make_hdf5.py (+4 -4)

@@ -25,7 +25,7 @@ def prepare_parser():
     help='Which Dataset to train on, out of I128, I256, C10, C100;'
     'Append "_hdf5" to use the hdf5 version for ISLVRC (default: %(default)s)')
   parser.add_argument(
-    '--dataset_root', type=str, default='data',
+    '--data_root', type=str, default='data',
     help='Default location where data is stored (default: %(default)s)')
   parser.add_argument(
     '--batch_size', type=int, default=256,
@@ -58,7 +58,7 @@ def run(config):
   train_loader = utils.get_data_loaders(dataset=config['dataset'],
                                         batch_size=config['batch_size'],
                                         shuffle=False,
-                                        dataset_root=config['dataset_root'],
+                                        data_root=config['data_root'],
                                         use_multiepoch_sampler=False,
                                         **kwargs)[0]
 
@@ -81,7 +81,7 @@ def run(config):
     y = y.numpy()
     # If we're on the first batch, prepare the hdf5
     if i==0:
-      with h5.File(config['dataset_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f:
+      with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f:
         print('Producing dataset of len %d' % len(train_loader.dataset))
         imgs_dset = f.create_dataset('imgs', x.shape,dtype='uint8', maxshape=(len(train_loader.dataset), 3, config['image_size'], config['image_size']),
                                      chunks=(config['chunk_size'], 3, config['image_size'], config['image_size']), compression=config['compression'])
@@ -92,7 +92,7 @@ def run(config):
         labels_dset[...] = y
     # Else append to the hdf5
     else:
-      with h5.File(config['dataset_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f:
+      with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f:
         f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
         f['imgs'][-x.shape[0]:] = x
         f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0)
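For context, the hunks above rely on h5py's resizable-dataset pattern: create the dataset once with an extendable maxshape on the first batch, then resize and write into the new tail on every later batch. Below is a minimal standalone sketch of that pattern; the shapes, batch contents, and chunk size are made up for illustration and are not taken from the repo.

# Minimal sketch of the resizable-HDF5 append pattern used in make_hdf5.py.
# data_root mirrors the new --data_root default; image_size is hypothetical.
import h5py as h5
import numpy as np

data_root = 'data'
image_size = 128
fname = data_root + '/ILSVRC%i.hdf5' % image_size

# One fake uint8 image batch of shape (N, 3, H, W).
x = np.zeros((256, 3, image_size, image_size), dtype='uint8')

# First batch: create the dataset with an extendable first axis.
# maxshape=None on axis 0 allows unlimited growth (the repo caps it at
# len(train_loader.dataset) instead).
with h5.File(fname, 'w') as f:
    f.create_dataset('imgs', data=x,
                     maxshape=(None, 3, image_size, image_size),
                     chunks=(64, 3, image_size, image_size))

# Later batches: grow axis 0, then write into the freshly added rows.
with h5.File(fname, 'a') as f:
    f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
    f['imgs'][-x.shape[0]:] = x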

utils.py (+5 -5)

@@ -248,7 +248,7 @@ def prepare_parser():
     help='Default location to store all weights, samples, data, and logs '
     ' (default: %(default)s)')
   parser.add_argument(
-    '--dataset_root', type=str, default='data',
+    '--data_root', type=str, default='data',
     help='Default location where data is stored (default: %(default)s)')
   parser.add_argument(
     '--weights_root', type=str, default='weights',
@@ -521,15 +521,15 @@ def __len__(self):
 
 
 # Convenience function to centralize all data loaders
-def get_data_loaders(dataset, dataset_root=None, augment=False, batch_size=64,
+def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64,
                      num_workers=8, shuffle=True, load_in_mem=False, hdf5=False,
                      pin_memory=True, drop_last=True, start_itr=0,
                      num_epochs=500, use_multiepoch_sampler=False,
                      **kwargs):
 
   # Append /FILENAME.hdf5 to root if using hdf5
-  dataset_root += '/%s' % root_dict[dataset]
-  print('Using dataset root location %s' % dataset_root)
+  data_root += '/%s' % root_dict[dataset]
+  print('Using dataset root location %s' % data_root)
 
   which_dataset = dset_dict[dataset]
   norm_mean = [0.5,0.5,0.5]
@@ -562,7 +562,7 @@ def get_data_loaders(dataset, dataset_root=None, augment=False, batch_size=64,
     train_transform = transforms.Compose(train_transform + [
       transforms.ToTensor(),
       transforms.Normalize(norm_mean, norm_std)])
-    train_set = which_dataset(root=dataset_root, transform=train_transform,
+    train_set = which_dataset(root=data_root, transform=train_transform,
                               load_in_mem=load_in_mem, **dataset_kwargs)
 
   # Prepare loader; the loaders list is for forward compatibility with
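Callers only need to swap the keyword; the make_hdf5.py hunk earlier shows the real in-repo usage. A sketch of the updated call, with illustrative dataset name and batch size rather than values asserted by this commit:

# Mirrors the call site in make_hdf5.py's run(); 'I128' and 256 are
# illustrative arguments, not defaults introduced by this commit.
import utils

train_loader = utils.get_data_loaders(dataset='I128',
                                      batch_size=256,
                                      shuffle=False,
                                      data_root='data',  # renamed from dataset_root
                                      use_multiepoch_sampler=False)[0]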
