Commit f9002d5

Code works well in PyTorch 1.6.0; add code for Automatic Mixed Precision training
1 parent ac055f1 commit f9002d5

4 files changed: +117 -10 lines changed

README.md

Lines changed: 17 additions & 6 deletions
@@ -11,9 +11,9 @@
 - If there is something wrong in my code, please contact me, thanks!
 
 ## Environment
-- python 3.7.3
-- pytorch 1.2.0
-- opencv 3.4.2
+- python 3.7.7
+- pytorch 1.4.0 (>=1.2.0, 1.6.0 works too)
+- opencv 4.2.0.34 (others work too)
 
 ## Visualization
 1. In the **first** Non-local Layer.
@@ -32,17 +32,25 @@
 from lib.non_local_gaussian import NONLocalBlock2D
 from lib.non_local_embedded_gaussian import NONLocalBlock2D
 from lib.non_local_dot_product import NONLocalBlock2D
+```
+
 2. Run **demo_MNIST_train.py** with one GPU or multi GPU to train the Network. Then the weights will be save in **weights/**.
 ```
-CUDA_VISIBLE_DEVICES=0,1 python demo_MNIST.py
-
+CUDA_VISIBLE_DEVICES=0,1 python demo_MNIST_train.py
+
+# Or train with Automatic Mixed Precision based on pytorch 1.6.0
+CUDA_VISIBLE_DEVICES=0 python demo_MNIST_AMP_train_with_single_gpu.py
+```
+
 3. Run **nl_map_save.py** to save NL_MAP of one test sample in **nl_map_vis**.
 ```
 CUDA_VISIBLE_DEVICES=0,1 python nl_map_save.py
-
+```
+
 4. Come into **nl_map_vis/** and run **nl_map_vis.py** to visualize the NL_MAP. (tips: if the Non-local type you select is **non_local_concatenation** or **non_local_dot_product** (without Softmax operation), you may need to normalize NL_MAP in the visualize code)
 ```
 python nl_map_save.py
+```
 
 ## Update Records
 1. Figure out how to implement the **concatenation** type, and add the code to **lib/**.
@@ -66,6 +74,9 @@ to **Non-Local_pytorch_0.3.1/**.
 
 8. In order to visualize NL_MAP, some code have been slightly modified. The code **nl_map_save.py** is added to save NL_MAP (two Non-local Layer) of one test sample. The code **Non-local_pytorch/nl_map_vis.py** is added to visualize NL_MAP. Besieds, the code is support pytorch 1.2.0.
 
+9. The code also works well in **pytorch 1.4.0**.
+
+10. The code also works well in **pytorch 1.6.0**. Add **demo_MNIST_AMP_train_with_single_gpu.py** with Automatic Mixed Precision Training (FP16), supported by **pytorch 1.6.0**. It can reduce GPU memory during training. What's more, if you use GPU 2080Ti (tensor cores), training speed can be increased. More details (such as how to train with multiple GPUs) can be found in [here](https://pytorch.org/docs/stable/notes/amp_examples.html#typical-mixed-precision-training)
 
 ## Todo
 - Experiments on Charades dataset.
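
Note on the tip in README step 4: for the **non_local_concatenation** and **non_local_dot_product** types the map values are not passed through Softmax, so they are not confined to [0, 1]. One possible way to normalize such a map before plotting is min-max scaling. The sketch below is an illustration only and is not taken from the repository; the function name `normalize_nl_map` and the list-of-rows representation are assumptions, and the actual visualization code may normalize differently.

```python
def normalize_nl_map(nl_map):
    """Min-max scale a 2D map (list of rows) into [0, 1] for display.

    Hypothetical helper for illustration; not part of the repository.
    """
    flat = [v for row in nl_map for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # Constant map: there is no contrast to stretch, so return zeros.
        return [[0.0 for _ in row] for row in nl_map]
    return [[(v - lo) / span for v in row] for row in nl_map]


# Unbounded scores such as a dot-product map without Softmax might produce.
raw_map = [[-2.0, 0.0], [1.0, 6.0]]
scaled_map = normalize_nl_map(raw_map)
print(scaled_map)
```

Min-max scaling keeps the relative ordering of the responses, which is usually what a heat-map visualization needs.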
demo_MNIST_AMP_train_with_single_gpu.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+import torch
+import torch.utils.data as Data
+import torchvision
+from lib.network import Network
+from torch import nn
+from torch.cuda import amp
+import time
+
+
+train_data = torchvision.datasets.MNIST(root='./mnist', train=True,
+                                        transform=torchvision.transforms.ToTensor(),
+                                        download=True)
+test_data = torchvision.datasets.MNIST(root='./mnist/',
+                                       transform=torchvision.transforms.ToTensor(),
+                                       train=False)
+
+train_loader = Data.DataLoader(dataset=train_data, batch_size=128 * 50, shuffle=True)
+test_loader = Data.DataLoader(dataset=test_data, batch_size=128 * 50, shuffle=False)
+
+train_batch_num = len(train_loader)
+test_batch_num = len(test_loader)
+
+net = Network()
+if torch.cuda.is_available():
+    # net = nn.DataParallel(net)
+    net.cuda()
+
+# +++++++++++++++++++++++++++++++
+scaler = amp.GradScaler()
+# +++++++++++++++++++++++++++++++
+
+opt = torch.optim.Adam(net.parameters(), lr=0.001)
+loss_func = nn.CrossEntropyLoss()
+
+for epoch_index in range(10):
+    st = time.time()
+
+    torch.set_grad_enabled(True)
+    net.train()
+    for train_batch_index, (img_batch, label_batch) in enumerate(train_loader):
+        if torch.cuda.is_available():
+            img_batch = img_batch.cuda()
+            label_batch = label_batch.cuda()
+
+        # ++++++++++++++++++++++++++++++++++++++++++++++
+        # predict = net(img_batch)
+        # loss = loss_func(predict, label_batch)
+        with amp.autocast():
+            predict = net(img_batch)
+            loss = loss_func(predict, label_batch)
+        # ++++++++++++++++++++++++++++++++++++++++++++++
+
+        net.zero_grad()
+        # ++++++++++++++++++++++++++++++++++++++++++++++
+        # loss.backward()
+        # opt.step()
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        # ++++++++++++++++++++++++++++++++++++++++++++++
+
+    print('(LR:%f) Time of a epoch:%.4fs' % (opt.param_groups[0]['lr'], time.time()-st))
+
+    torch.set_grad_enabled(False)
+    net.eval()
+    total_loss = []
+    total_acc = 0
+    total_sample = 0
+
+    for test_batch_index, (img_batch, label_batch) in enumerate(test_loader):
+        if torch.cuda.is_available():
+            img_batch = img_batch.cuda()
+            label_batch = label_batch.cuda()
+
+        predict = net(img_batch)
+        loss = loss_func(predict, label_batch)
+
+        predict = predict.argmax(dim=1)
+        acc = (predict == label_batch).sum()
+
+        total_loss.append(loss)
+        total_acc += acc
+        total_sample += img_batch.size(0)
+
+    net.train()
+
+    mean_acc = total_acc.item() * 1.0 / total_sample
+    mean_loss = sum(total_loss) / total_loss.__len__()
+
+    print('[Test] epoch[%d/%d] acc:%.4f%% loss:%.4f\n'
+          % (epoch_index, 10, mean_acc * 100, mean_loss.item()))
+
+# weight_path = 'weights/net.pth'
+# print('Save Net weights to', weight_path)
+# net.cpu()
+# torch.save(net.state_dict(), weight_path)
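
Note on `scaler.scale(loss).backward()`: the reason the script scales the loss before the backward pass is that very small gradient values can underflow to zero in FP16. The sketch below illustrates that effect with the standard library only; it does not use PyTorch and is not how `GradScaler` is implemented. The helper `to_fp16`, the example gradient value, and the fixed scale factor are assumptions chosen for illustration (2**16 matches the documented default `init_scale` of `GradScaler`, which then adjusts the factor dynamically).

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


LOSS_SCALE = 65536.0  # assumed fixed scale (2**16), for illustration only

tiny_grad = 1e-8  # below the smallest positive FP16 value (about 6e-8)

# Stored directly in FP16, the gradient is lost.
unscaled = to_fp16(tiny_grad)

# Scaled up first, the value fits comfortably in the FP16 range.
scaled = to_fp16(tiny_grad * LOSS_SCALE)

# Dividing by the scale in full precision recovers the gradient approximately.
recovered = scaled / LOSS_SCALE

print(unscaled, scaled, recovered)
```

In the script, `scaler.step(opt)` takes care of unscaling the gradients before the optimizer update and `scaler.update()` adjusts the scale factor for the next iteration.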

lib/non_local_embedded_gaussian.py

Lines changed: 2 additions & 2 deletions
@@ -130,12 +130,12 @@ def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=T
 print(out.size())
 
 img = torch.zeros(2, 3, 20, 20)
-net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_, store_last_batch_nl_map=True)
+net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
 out = net(img)
 print(out.size())
 
 img = torch.randn(2, 3, 8, 20, 20)
-net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_, store_last_batch_nl_map=True)
+net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
 out = net(img)
 print(out.size())
 

lib/non_local_gaussian.py

Lines changed: 2 additions & 2 deletions
@@ -78,8 +78,8 @@ def forward(self, x, return_nl_map=False):
 f = torch.matmul(theta_x, phi_x)
 f_div_C = F.softmax(f, dim=-1)
 
-if self.store_last_batch_nl_map:
-    self.nl_map = f_div_C
+# if self.store_last_batch_nl_map:
+# self.nl_map = f_div_C
 
 y = torch.matmul(f_div_C, g_x)
 y = y.permute(0, 2, 1).contiguous()
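
Note on the context lines of this hunk: `f` holds pairwise products between positions, `F.softmax(f, dim=-1)` turns each row into weights that sum to 1, and `y` is the weighted sum of `g_x`. A dependency-free sketch of that computation for a single sample follows. It is an illustration under assumptions, not the repository's code: the function names are hypothetical, and `phi` is passed as one feature row per position (the transpose of the layout `phi_x` has in `torch.matmul(theta_x, phi_x)`).

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def nl_map_and_output(theta, phi, g):
    """theta, phi, g: lists of N feature rows (one row per position).

    Returns (f_div_C, y), where f_div_C[i][j] is the softmax-normalized
    dot product of theta[i] and phi[j], and y[i] is the sum over j of
    f_div_C[i][j] * g[j].
    """
    f = [[sum(a * b for a, b in zip(t, p)) for p in phi] for t in theta]
    f_div_C = [softmax(row) for row in f]
    channels = len(g[0])
    y = [[sum(w * g_j[c] for w, g_j in zip(weights, g)) for c in range(channels)]
         for weights in f_div_C]
    return f_div_C, y


theta = [[1.0, 0.0], [0.0, 1.0]]
phi = [[1.0, 0.0], [0.0, 1.0]]
g = [[10.0], [20.0]]
nl_map, y = nl_map_and_output(theta, phi, g)
print(nl_map)
print(y)
```

Because every row of the map sums to 1, each output row is a convex combination of the rows of `g`, which is why the Softmax variants need no extra normalization before visualization.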
