--[[
Utility functions for initializing networks and for wrapping, saving, and
loading models with nn.DataParallelTable (multi-GPU training).
]]
require 'nn'
require 'cunn' -- provides nn.DataParallelTable; pulls in torch and cutorch
------------------------------------
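-- MSR ("He") initialization for convolution layers: weights ~ N(0, 2/n)
-- with n = kW*kH*nOutputPlane, and biases zeroed
-- (He et al., "Delving Deep into Rectifiers", 2015).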
local function MSRinit(model)
   for k,v in pairs(model:findModules('nn.SpatialConvolution')) do
      local n = v.kW*v.kH*v.nOutputPlane
      v.weight:normal(0, math.sqrt(2/n))
      if v.bias then v.bias:zero() end
   end
end
------------------------------------
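-- Zero the biases of all fully-connected (nn.Linear) layers.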
local function FCinit(model)
   for k,v in pairs(model:findModules'nn.Linear') do
      v.bias:zero()
   end
end
------------------------------------
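-- Remove the bias terms of all convolution layers, e.g. when each
-- convolution is followed by batch normalization and the bias is redundant.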
local function DisableBias(model)
   for i,v in ipairs(model:findModules'nn.SpatialConvolution') do
      v.bias = nil
      v.gradBias = nil
   end
end
------------------------------------
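-- Wrap the model in an nn.DataParallelTable so each mini-batch is split
-- along dimension 1 and processed on nGPU GPUs; returns the model
-- unchanged when nGPU <= 1.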
local function makeDataParallelTable(model, nGPU)
   if nGPU > 1 then
      local gpus = torch.range(1, nGPU):totable()
      local fastest, benchmark
      if pcall(require, 'cudnn') then
         fastest, benchmark = cudnn.fastest, cudnn.benchmark
      end
      -- nccl should be installed for faster data transfer rates between GPUs
      -- (see https://github.com/torch/cunn/blob/master/DataParallelTable.lua#L12)
      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
            require 'nngraph'
            if pcall(require, 'cudnn') then
               local cudnn = require 'cudnn'
               cudnn.fastest, cudnn.benchmark = fastest, benchmark
            end
         end)
      dpt.gradInput = nil
      model = dpt:cuda()
   end
   return model
end
------------------------------------
local function cleanDPT(module)
   -- This assumes the DPT was created by the function above: all of
   -- module.modules are clones of the same network on different GPUs,
   -- hence we only need to keep one when saving the model to disk.
   -- Note: relies on a global `opt` table with a `GPU` field.
   local newDPT = nn.DataParallelTable(1)
   cutorch.setDevice(opt.GPU)
   newDPT:add(module:get(1), opt.GPU)
   return newDPT
end
------------------------------------
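-- Save a model to disk with any DataParallelTable replaced by a single
-- replica, so the checkpoint does not contain one clone per GPU.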
local function saveDataParallel(filename, model)
   if torch.type(model) == 'nn.DataParallelTable' then
      torch.save(filename, model.modules[1])
   elseif torch.type(model) == 'nn.Sequential' then
      local temp_model = nn.Sequential()
      for i, module in ipairs(model.modules) do
         if torch.type(module) == 'nn.DataParallelTable' then
            temp_model:add(module.modules[1])
         else
            temp_model:add(module)
         end
      end
      torch.save(filename, temp_model)
   else
      error('This saving function only works with Sequential or DataParallelTable modules.')
   end
end
------------------------------------
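-- Load a checkpoint saved by saveDataParallel and re-wrap it in a
-- DataParallelTable for nGPU GPUs.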
local function loadDataParallel(filename, nGPU)
   -- Note: relies on a global `opt` table with a `backend` field.
   if opt.backend == 'cudnn' then
      require 'cudnn'
   end
   local model = torch.load(filename)
   if torch.type(model) == 'nn.DataParallelTable' then
      return makeDataParallelTable(model:get(1):float(), nGPU)
   elseif torch.type(model) == 'nn.Sequential' or torch.type(model) == 'nn.gModule' then
      for i,module in ipairs(model.modules) do
         if torch.type(module) == 'nn.DataParallelTable' then
            model.modules[i] = makeDataParallelTable(module:get(1):float(), nGPU)
         end
      end
      return model
   else
      error('The loaded model is not a Sequential or DataParallelTable module.')
   end
end
------------------------------------
return {
   MSRinit = MSRinit,
   FCinit = FCinit,
   DisableBias = DisableBias,
   makeDataParallelTable = makeDataParallelTable,
   saveDataParallel = saveDataParallel,
   loadDataParallel = loadDataParallel
}
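--[[
Example usage (a sketch: assumes this file is on the Lua path as 'utils',
a CUDA-capable Torch install, and a hypothetical two-GPU machine):

   local utils = require 'utils'

   local model = nn.Sequential()
      :add(nn.SpatialConvolution(3, 64, 3,3, 1,1, 1,1))
      :add(nn.ReLU(true))
      :add(nn.View(-1):setNumInputDims(3))
      :add(nn.Linear(64*32*32, 10))  -- assumes 32x32 inputs

   utils.MSRinit(model)                           -- He-style conv init
   utils.FCinit(model)                            -- zero Linear biases
   model = utils.makeDataParallelTable(model, 2)  -- split batches over 2 GPUs
   utils.saveDataParallel('model.t7', model)      -- save a single replica
]]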