-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogger.lua
69 lines (51 loc) · 2.02 KB
/
logger.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
--[[
Logging helpers: save the confusion-matrix image, redraw the train/test logger plots, and store model snapshots.
]]
local utils = paths.dofile('utils.lua')
------------------------------------
--- Save a model snapshot together with its optimizer state and meters.
-- All files go into the directory `opt.save` (read from the global `opt`).
-- When `flag` is truthy the files are tagged with the epoch number
-- (model_<epoch>.t7, optim_<epoch>.t7, meters_<epoch>.t7), so every snapshot
-- is kept; otherwise one untagged set (model.t7, optim.t7, meters.t7) is
-- overwritten each time. In both cases last_epoch.t7 records `epoch`.
-- @param model network to save; model:clearState() is called on it first
--   (presumably to drop cached buffers, per the nn convention — confirm).
-- @param optimState optimizer state saved to optim*.t7.
-- @param epoch epoch number written to last_epoch.t7 and, when `flag` is
--   set, embedded in the file names.
-- @param flag truthy to keep per-epoch files, falsy to overwrite one set.
-- @param meters optional meters object saved to meters*.t7. When omitted,
--   the global `meters` is used, which is what the old 4-argument form read.
local function storeModel(model, optimState, epoch, flag, meters)
   if meters == nil then
      -- Backward compatibility: 4-argument callers keep the previous
      -- behaviour of saving whatever the global `meters` holds.
      meters = _G.meters
   end

   -- '' for the overwrite mode, '_<epoch>' for the per-epoch mode.
   local suffix = flag and ('_' .. epoch) or ''
   local function target(name)
      return paths.concat(opt.save, name .. suffix .. '.t7')
   end

   local modelPath = target('model')
   print('Saving model snapshot to: ' .. modelPath)
   -- utils.saveDataParallel is project-specific; presumably it unwraps a
   -- DataParallel container before saving — TODO confirm in utils.lua.
   utils.saveDataParallel(modelPath, model:clearState())
   torch.save(paths.concat(opt.save, 'last_epoch.t7'), epoch)
   -- Previously saved the misspelled global `optimstate` (always nil).
   torch.save(target('optim'), optimState)
   torch.save(target('meters'), meters)
end
------------------------------------
--- Per-epoch logging: confusion-matrix image, logger plots and snapshots.
-- Writes confusion_<epoch>.jpg (scaled to 1000x1000) into `opt.save`,
-- redraws the test / train / full_train loggers, and stores a model snapshot
-- according to the global `opt.snapshot`:
--   opt.snapshot > 0  -> every opt.snapshot epochs, one file set per epoch;
--   opt.snapshot < 0  -> every |opt.snapshot| epochs, overwriting one set;
--   opt.snapshot == 0 -> only when epoch == opt.nEpochs, overwriting.
-- @param model container whose first child (model.modules[1]) is saved;
--   presumably the wrapper around the real network — TODO confirm w/ caller.
-- @param optimState optimizer state saved with the snapshot.
-- @param meters table whose `conf` field has a `normalized` flag and a
--   value() method returning a tensor; also forwarded to the snapshot.
-- @param loggers table with `test`, `train` and `full_train` objects
--   exposing style() and plot().
-- @param epoch current epoch number.
local function logging(model, optimState, meters, loggers, epoch)
   require 'image'

   -- Render the confusion matrix with normalization enabled, then reset the
   -- flag to false for subsequent use.
   meters.conf.normalized = true
   image.save(paths.concat(opt.save, 'confusion_' .. epoch .. '.jpg'),
              image.scale(meters.conf:value():float(), 1000, 1000, 'simple'))
   meters.conf.normalized = false

   -- One style entry per plotted series (assumed to match each logger's
   -- number of logged columns — confirm where the loggers are created).
   loggers.test:style{'+-', '+-', '+-'}
   loggers.test:plot()
   loggers.train:style{'+-', '+-'}
   loggers.train:plot()
   loggers.full_train:style{'+-'}
   loggers.full_train:plot()

   -- Decide whether a snapshot is due and whether to keep per-epoch files.
   local period = opt.snapshot
   local due, keepEach
   if period > 0 then
      due, keepEach = epoch % period == 0, true
   elseif period < 0 then
      due, keepEach = epoch % -period == 0, false
   else
      -- snapshot == 0: save only at the last epoch.
      due, keepEach = epoch == opt.nEpochs, false
   end

   if due then
      -- Forward `meters` so the snapshot saves the real meters rather than a
      -- global lookup.
      storeModel(model.modules[1], optimState, epoch, keepEach, meters)
   end
end
------------------------------------
-- Module value: the logging function (storeModel stays private to this file).
return logging