RAM_translated.py
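"""
Trains a Recurrent Attention Model (RAM, Mnih et al., 2014) on translated
MNIST: each 28x28 digit is pasted at a random position on a 60x60 canvas,
and the model classifies it from a sequence of T glimpses. The architecture
(MODEL), the loss (LOSS; presumably the RAM hybrid of a classification term
and REINFORCE terms, given the logpi/baseline structure below), and the
learning-rate schedule are imported from the companion RAM module.
"""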
from __future__ import print_function
import torch
import torch.utils.data
from torch import optim
from torchvision import datasets, transforms
import numpy as np
from RAM import MODEL, LOSS, adjust_learning_rate
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == 'cuda' else {}
# MNIST is single-channel, so Normalize takes one mean/std per tensor;
# mapping [0, 1] to [-1, 1] puts the black background at -1.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transform),
    batch_size=batch_size, shuffle=False, **kwargs)
T = 4               # number of glimpses per image
lr = 0.0001         # initial learning rate
std = 0.25          # std of the stochastic glimpse-location policy
scale = 3           # number of glimpse scales
decay = 0.975       # per-epoch learning-rate decay factor
im_sz = 60          # side length of the translated-MNIST canvas
glimpse_width = 12  # side length of the smallest glimpse patch
model = MODEL(im_sz=im_sz, channel=1, glimps_width=glimpse_width, scale=scale, std=std).to(device)
loss_fn = LOSS(T=T, gamma=1, device=device).to(device)
optimizer = optim.Adam(list(model.parameters()) + list(loss_fn.parameters()), lr=lr)
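# Note: LOSS is an nn.Module with trainable parameters of its own (presumably
# a learned reward baseline for the REINFORCE term), which is why its
# parameters are handed to the optimizer alongside the model's.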
def translate_img(x, to_sz):
    # Paste each digit at a uniformly random offset on a to_sz x to_sz
    # canvas filled with -1 (the normalized value of the MNIST background).
    B, C, H, W = x.size()
    x_t = -torch.ones(B, C, to_sz, to_sz).to(device)
    for i in range(B):
        loch = np.random.randint(0, to_sz - H + 1)  # 0..32 for 28 -> 60
        locw = np.random.randint(0, to_sz - W + 1)
        x_t[i, :, loch:loch + H, locw:locw + W] = x[i]
    return x_t
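# Quick sanity check: with x of shape (B, 1, 28, 28), translate_img(x, 60)
# yields a (B, 1, 60, 60) tensor; offsets are drawn from [0, 32], so the
# digit always fits entirely inside the canvas.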
for epoch in range(1, 201):
    '''
    Training
    '''
    adjust_learning_rate(optimizer, epoch, lr, decay)
    model.train()
    train_aloss, train_lloss, train_bloss, train_reward = 0, 0, 0, 0
    for batch_idx, (data, label) in enumerate(train_loader):
        data = translate_img(data.to(device), im_sz)
        label = label.to(device)
        optimizer.zero_grad()
        model.initialize(data.size(0), device)
        loss_fn.initialize(data.size(0))
        for _ in range(T):
            logpi, action = model(data)
            # loss_fn stores logpi at intermediate time steps and returns the
            # actual losses only at the final time step
            aloss, lloss, bloss, reward = loss_fn(action, label, logpi)
        loss = aloss + lloss + bloss
        loss.backward()
        optimizer.step()
        train_aloss += aloss.item()
        train_lloss += lloss.item()
        train_bloss += bloss.item()
        train_reward += reward.item()
    print('====> Epoch: {} Average loss: a {:.4f} l {:.4f} b {:.4f} Reward: {:.1f}'.format(
        epoch, train_aloss / len(train_loader.dataset),
        train_lloss / len(train_loader.dataset),
        train_bloss / len(train_loader.dataset),
        train_reward * 100 / len(train_loader.dataset)))
    # uncomment the line below to save a checkpoint after each epoch
    # torch.save([model.state_dict(), loss_fn.state_dict(), optimizer.state_dict()], 'results/final' + str(epoch) + '.pth')
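    # To resume from such a checkpoint (hypothetical path, matching the
    # list-of-state-dicts save format above):
    # m_sd, l_sd, o_sd = torch.load('results/final200.pth')
    # model.load_state_dict(m_sd)
    # loss_fn.load_state_dict(l_sd)
    # optimizer.load_state_dict(o_sd)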
    '''
    Evaluation
    '''
    model.eval()
    test_aloss, test_lloss, test_bloss, test_reward = 0, 0, 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for batch_idx, (data, label) in enumerate(test_loader):
            data = translate_img(data.to(device), im_sz)
            label = label.to(device)
            model.initialize(data.size(0), device)
            loss_fn.initialize(data.size(0))
            for _ in range(T):
                logpi, action = model(data)
                aloss, lloss, bloss, reward = loss_fn(action, label, logpi)
            test_aloss += aloss.item()
            test_lloss += lloss.item()
            test_bloss += bloss.item()
            test_reward += reward.item()
    print('====> Epoch: {} Test loss: a {:.4f} l {:.4f} b {:.4f} Reward: {:.1f}'.format(
        epoch, test_aloss / len(test_loader.dataset),
        test_lloss / len(test_loader.dataset),
        test_bloss / len(test_loader.dataset),
        test_reward * 100 / len(test_loader.dataset)))