-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.lua
108 lines (79 loc) · 3.46 KB
/
data.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
--[[
Functions to load data from various datasets (cifar/mnist/ilsvrc2012).
]]
require 'image'
local string_ascii = require 'dbcollection.utils.string_ascii'
local ascii2str = string_ascii.convert_ascii_to_str
--- Build a random sampler over the MNIST dataset.
-- Each call of the returned closure draws a class uniformly in 1..num_classes,
-- then draws one image uniformly from that class's image list.
-- @param dbloader dbcollection loader for the 'mnist' dataset
-- @tparam number num_classes number of classes (class IDs are 1..num_classes)
-- @tparam string mode name of the set to sample from (e.g. 'train', 'test')
-- @treturn function fetch_data() -> {image, torch.LongTensor{classID}}
-- @treturn number first entry of dbloader:size(mode)
-- @raise error from fetch_data() when the drawn class has no valid image indexes
local function loader_mnist(dbloader, num_classes, mode)
    local function fetch_data()
        -- draw a class uniformly at random (1-based, used directly as the label)
        local classID = torch.random(1, num_classes)
        local imgs_class = dbloader:get(mode, 'list_images_per_class', classID)
        -- remove negative indexes (they are not valid image indexes)
        imgs_class = imgs_class[imgs_class:ge(0)]
        -- guard: an empty list would otherwise fail with an opaque tensor/random error
        local num_imgs = imgs_class:nElement()
        if num_imgs == 0 then
            error(string.format("class %d has no images in the '%s' set", classID, tostring(mode)))
        end
        -- draw an image of that class; stored indexes are 0-based, Lua is 1-indexed
        local img_idx = imgs_class[torch.random(1, num_imgs)] + 1
        -- fetch the image data (not a filename): scale by 1/255 (assumes 8-bit
        -- pixels) and replicate into 3 channels; assumes a single-channel
        -- image (1xHxW or HxW) -- TODO confirm against dbcollection
        local img = dbloader:get(mode, 'images', img_idx):float():div(255):repeatTensor(3, 1, 1)
        return {img, torch.LongTensor{classID}}
    end
    return fetch_data, dbloader:size(mode)[1]
end
--- Build a random sampler over a CIFAR dataset (cifar10 / cifar100).
-- Each call of the returned closure draws a class uniformly in 1..num_classes,
-- then draws one image uniformly from that class's image list.
-- @param dbloader dbcollection loader for the CIFAR dataset
-- @tparam number num_classes number of classes (class IDs are 1..num_classes)
-- @tparam string mode name of the set to sample from (e.g. 'train', 'test')
-- @treturn function fetch_data() -> {image, torch.LongTensor{classID}}
-- @treturn number first entry of dbloader:size(mode)
-- @raise error from fetch_data() when the drawn class has no valid image indexes
local function loader_cifar(dbloader, num_classes, mode)
    local function fetch_data()
        -- draw a class uniformly at random (1-based, used directly as the label)
        local classID = torch.random(1, num_classes)
        local imgs_class = dbloader:get(mode, 'list_images_per_class', classID)
        -- remove negative indexes (they are not valid image indexes)
        imgs_class = imgs_class[imgs_class:ge(0)]
        -- guard: an empty list would otherwise fail with an opaque tensor/random error
        local num_imgs = imgs_class:nElement()
        if num_imgs == 0 then
            error(string.format("class %d has no images in the '%s' set", classID, tostring(mode)))
        end
        -- draw an image of that class; stored indexes are 0-based, Lua is 1-indexed
        local img_idx = imgs_class[torch.random(1, num_imgs)] + 1
        -- fetch the image data (not a filename)
        local img = dbloader:get(mode, 'images', img_idx)
            :transpose(3, 4):transpose(2, 3) -- 1xHxWxC -> 1xCxHxW (assumed layout, TODO confirm)
            :float():div(255)                 -- scale, assuming 8-bit pixel values
            :squeeze()                        -- drop singleton dimensions
        return {img, torch.LongTensor{classID}}
    end
    return fetch_data, dbloader:size(mode)[1]
end
--- Build a random sampler over the ILSVRC2012 dataset.
-- A requested 'test' set is redirected to 'val', both for sampling and for
-- the returned size. Each call of the returned closure draws a class uniformly
-- in 1..num_classes, then one image filename from that class, and loads the
-- image from disk under dbloader.data_dir.
-- @param dbloader dbcollection loader for the 'ilsvrc2012' dataset
-- @tparam number num_classes number of classes (class IDs are 1..num_classes)
-- @tparam string mode name of the set to sample from ('test' is mapped to 'val')
-- @treturn function fetch_data() -> {image, torch.LongTensor{classID}}
-- @treturn number first entry of dbloader:size(mode) for the (remapped) set
-- @raise error from fetch_data() when the drawn class has no valid image indexes
local function loader_ilsvrc2012(dbloader, num_classes, mode)
    if mode == 'test' then
        mode = 'val' -- only has the 'val' set in the ilsvrc2012
    end

    local function fetch_data()
        -- draw a class uniformly at random (1-based, used directly as the label)
        local classID = torch.random(1, num_classes)
        local imgs_class = dbloader:get(mode, 'list_image_filenames_per_class', classID)
        -- remove negative indexes (they are not valid image indexes)
        imgs_class = imgs_class[imgs_class:ge(0)]
        -- guard: an empty list would otherwise fail with an opaque tensor/random error
        local num_imgs = imgs_class:nElement()
        if num_imgs == 0 then
            error(string.format("class %d has no images in the '%s' set", classID, tostring(mode)))
        end
        -- draw an image of that class; stored indexes are 0-based, Lua is 1-indexed
        local img_idx = imgs_class[torch.random(1, num_imgs)] + 1
        -- filenames are stored as ASCII codes; decode, then make them absolute
        local filename = ascii2str(dbloader:get(mode, 'image_filenames', img_idx))
        filename = paths.concat(dbloader.data_dir, filename)
        -- load with 3 channels as a float tensor
        return {image.load(filename, 3, 'float'), torch.LongTensor{classID}}
    end

    return fetch_data, dbloader:size(mode)[1]
end
--- Create a random-sampling data loader for one of the supported datasets.
-- The dataset name is validated before the dataset is loaded, so an
-- unsupported name fails fast without touching dbcollection.
-- @tparam string name dataset name: 'ilsvrc2012', 'cifar10', 'cifar100' or 'mnist'
-- @tparam string mode name of the set to sample from; passed to the dataset loader
-- @treturn function fetch_data() -> {image, torch.LongTensor{classID}}
-- @treturn number first entry of the dataset's size(mode)
-- @raise error if `name` is not one of the supported datasets
local function data_loader(name, mode)
    -- dataset name -> loader constructor (both CIFAR variants share one loader)
    local loaders = {
        ilsvrc2012 = loader_ilsvrc2012,
        cifar10 = loader_cifar,
        cifar100 = loader_cifar,
        mnist = loader_mnist,
    }
    local loader = loaders[name]
    if not loader then
        -- level 2: report the error at the caller, who passed the bad name
        error('Undefined dataset for this example code. Please choose one dataset from the following: ilsvrc2012 | cifar10 | cifar100 | mnist', 2)
    end

    local dbc = require 'dbcollection'
    local dbdataset = dbc.load(name)
    -- the class count is always read from the 'train' set, whatever `mode` is
    local num_classes = dbdataset:size('train', 'classes')[1]
    return loader(dbdataset, num_classes, mode)
end
-- Module value: the dataset dispatcher `data_loader(name, mode)` defined above.
return data_loader