-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathtrain.py
137 lines (120 loc) · 5.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import time
import logging
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
from apex.fp16_utils import *
from apex import amp, optimizers
from model import ft_net
from cross_entropy import ModelParallelCrossEntropy
from utils import get_class_split, get_sparse_onehot_label, compute_batch_acc
def get_data_loader(data_path, batch_size, *, num_workers=8):
    """Build a shuffled training DataLoader over an ImageFolder directory.

    Args:
        data_path: Root directory read by ``torchvision.datasets.ImageFolder``
            (one sub-directory per class).
        batch_size: Number of images per batch.
        num_workers: Number of DataLoader worker processes. Defaults to 8,
            the value this function previously hard-coded; lower it on
            machines with fewer CPU cores.

    Returns:
        ``(num_classes, dataloader)``, where ``num_classes`` is
        ``len(image_dataset.classes)`` (the number of class folders found).
    """
    train_transform = transforms.Compose([
        # Size is (height, width); interpolation=3 is PIL's BICUBIC code.
        transforms.Resize((384, 192), interpolation=3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_dataset = datasets.ImageFolder(data_path, train_transform)
    dataloader = torch.utils.data.DataLoader(
        image_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory allows asynchronous host-to-GPU copies.
        pin_memory=True,
    )
    return len(image_dataset.classes), dataloader
def train_model(opt, data_loader, model, criterion, optimizer, class_split):
    """Train ``model`` for ``opt.num_epochs`` epochs, logging every 10 steps.

    Args:
        opt: Parsed options. Reads ``num_epochs``, ``num_gpus``,
            ``num_classes``, ``model_parallel``, ``fp16`` and ``batch_size``.
        data_loader: Iterable yielding ``(images, labels)`` tensor batches.
        model: Network called as ``model(images, labels=onehot_labels)``.
        criterion: Loss module. In model-parallel mode it is called as
            ``criterion(compute_loss, fp16, onehot_labels, *logits)``;
            otherwise as ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters (the one returned by
            apex ``amp.initialize`` when ``opt.fp16`` is set).
        class_split: Per-GPU class partition from ``get_class_split``, or
            ``None`` when model parallelism is off.
    """
    logging.info("Start training...")
    for epoch in range(opt.num_epochs):
        # The timer spans data loading plus the full step; it is reset at the
        # end of every iteration.
        start_time = time.time()
        # Iterate the loader directly: DataLoader iterators only implement
        # ``__next__``, so the old ``iterator.next()`` call fails on current
        # PyTorch releases.
        for step, (images, labels) in enumerate(data_loader):
            # Loss value, loss scale and accuracy are only needed when logging.
            log_step = step > 0 and step % 10 == 0

            # pin_memory=True on the loader makes non-blocking copies safe.
            images = images.cuda(0, non_blocking=True)
            labels = labels.cuda(0, non_blocking=True)
            onehot_labels = get_sparse_onehot_label(
                labels, opt.num_gpus, opt.num_classes, opt.model_parallel, class_split
            )

            # Forward
            optimizer.zero_grad()
            logits = model(images, labels=onehot_labels)

            # Loss calculation
            if opt.model_parallel:
                # First argument: presumably whether the criterion should
                # materialise the loss value (requested only on logging
                # steps) — confirm in cross_entropy.py.
                loss = criterion(log_step, opt.fp16, onehot_labels, *logits)
            else:
                loss = criterion(logits, labels)

            # Backward
            scale = 1.0
            if opt.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    if log_step:
                        # .item() synchronises with the GPU, so the effective
                        # loss scale is read only when it will be logged.
                        unscaled = loss.item()
                        scale = scaled_loss.item() / unscaled if unscaled else float("nan")
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # Log training progress
            if log_step:
                # loss.item() synchronises with the GPU, so the elapsed time
                # below includes the queued backward/step work.
                loss_value = loss.item()
                elapsed = time.time() - start_time
                # The last batch of an epoch may be smaller than opt.batch_size.
                example_per_second = images.size(0) / float(elapsed)
                # NOTE(review): compute_batch_acc still receives opt.batch_size;
                # confirm whether it should get the actual batch size instead.
                batch_acc = compute_batch_acc(logits, labels, opt.batch_size, opt.model_parallel, step)
                logging.info(
                    "epoch [%.3d] iter = %d loss = %.3f scale = %.3f acc = %.4f example/sec = %.2f",
                    epoch + 1, step, loss_value, scale, batch_acc, example_per_second,
                )
            start_time = time.time()
def _parse_args():
    """Parse the training command-line options."""
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--gpus', default='0,1,2,3', type=str, help='0,1,2,3')
    parser.add_argument('--data_path', default='/your/data/path/Market-1501/train', type=str, help='training data path')
    parser.add_argument('--num_epochs', default=15, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--num_classes', default=0, type=int, help='number of classes')
    parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
    parser.add_argument('--am', action="store_true", help='use am-softmax')
    parser.add_argument('--model_parallel', action="store_true", help='use model parallel')
    parser.add_argument('--fp16', action="store_true", help='use mixed-precision')
    return parser.parse_args()


def main():
    """Configure GPUs, build data/model/optimizer/loss, then run training."""
    logging.basicConfig(level=logging.INFO, format="[%(asctime)s %(filename)s] %(message)s")
    opt = _parse_args()

    # Restricting visible GPUs only takes effect if CUDA has not been
    # initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    cudnn.benchmark = True
    opt.num_gpus = len(opt.gpus.split(","))

    num_classes, data_loader = get_data_loader(opt.data_path, opt.batch_size)
    # --num_classes may raise, but never lower, the class count found on disk.
    opt.num_classes = max(opt.num_classes, num_classes)

    class_split = None
    if opt.model_parallel:
        # If using model parallel, split the number of classes
        # according to the number of GPUs.
        class_split = get_class_split(opt.num_classes, opt.num_gpus)

    model = ft_net(
        feature_dim=256,
        num_classes=opt.num_classes,
        num_gpus=opt.num_gpus,
        am=opt.am,
        model_parallel=opt.model_parallel,
        class_split=class_split,
    )

    # Move parameters to the GPU *before* building the optimizer and calling
    # amp.initialize: apex requires CUDA-resident parameters, and PyTorch
    # recommends moving a model before constructing its optimizer. Placement
    # matches the original, which moved these same modules to cuda:0 via
    # DataParallel(...).cuda().
    if opt.model_parallel:
        # Only the shared trunk is moved here; the classifier is left where
        # ft_net put it (presumably sharded per class_split — confirm in model.py).
        model.backbone.cuda()
        model.features.cuda()
    else:
        model.cuda()

    optimizer_ft = optim.SGD(
        model.parameters(),
        lr=opt.lr,
        weight_decay=5e-4,
        momentum=0.9,
        nesterov=True,
    )

    if opt.fp16:
        # apex: amp.initialize must run before any DataParallel wrapping.
        model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level="O1")

    if opt.model_parallel:
        # When using model parallel, wrap everything except the classifier
        # in DataParallel.
        model.backbone = nn.DataParallel(model.backbone)
        model.features = nn.DataParallel(model.features)
        criterion = ModelParallelCrossEntropy().cuda()
    else:
        # When not using model parallel, use DataParallel on the whole model.
        model = nn.DataParallel(model)
        criterion = nn.CrossEntropyLoss().cuda()

    train_model(opt, data_loader, model, criterion, optimizer_ft, class_split)


if __name__ == "__main__":
    main()