
Commit f57c576

support multi-node training
1 parent 68c190b commit f57c576

5 files changed: +56 -30 lines changed

GETTING_STARTED.md

Lines changed: 20 additions & 0 deletions

````diff
@@ -32,6 +32,26 @@ If you want to train model with 4 GPUs, you can run:
 python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
 ```
 
+If you want to train the model with multiple machines, you can run:
+
+```
+# machine 1
+export GLOO_SOCKET_IFNAME=eth0
+export NCCL_SOCKET_IFNAME=eth0
+
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
+--num-gpus 4 --num-machines 2 --machine-rank 0 --dist-url tcp://ip:port
+
+# machine 2
+export GLOO_SOCKET_IFNAME=eth0
+export NCCL_SOCKET_IFNAME=eth0
+
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
+--num-gpus 4 --num-machines 2 --machine-rank 1 --dist-url tcp://ip:port
+```
+
+Make sure the dataset path and the code are identical on all machines, and that the machines can communicate with each other.
+
 To evaluate a model's performance, use
 
 ```bash
````
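Note: the `GLOO_SOCKET_IFNAME`/`NCCL_SOCKET_IFNAME` exports pin the Gloo and NCCL backends to a specific network interface (`eth0` here), which avoids rendezvous failures on multi-homed hosts. The `--num-machines`, `--machine-rank`, and `--dist-url` flags follow the usual PyTorch multi-node launch pattern; as a hedged illustration (not fastreid's actual launcher, which may differ), a minimal sketch of how such a launcher derives each process's global rank and initializes the process group:

```python
# Hypothetical minimal launcher sketch -- not fastreid's actual code.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(local_rank, num_gpus, num_machines, machine_rank, dist_url):
    # Global rank = machine index * GPUs per machine + local GPU index,
    # so machine 0 hosts ranks 0..3 and machine 1 hosts ranks 4..7.
    global_rank = machine_rank * num_gpus + local_rank
    dist.init_process_group(
        backend="nccl",                      # NCCL for GPU collectives
        init_method=dist_url,                # the tcp://ip:port passed on the CLI
        world_size=num_machines * num_gpus,  # 2 machines x 4 GPUs = 8 processes
        rank=global_rank,
    )
    torch.cuda.set_device(local_rank)
    # ... build the model, wrap it in DistributedDataParallel, train ...


def launch(num_gpus, num_machines, machine_rank, dist_url):
    # Spawn one training process per local GPU; mp.spawn passes the
    # local index as the first argument to _worker.
    mp.spawn(_worker, nprocs=num_gpus,
             args=(num_gpus, num_machines, machine_rank, dist_url))
```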

fastreid/engine/defaults.py

Lines changed: 11 additions & 8 deletions

```diff
@@ -467,15 +467,18 @@ def test(cls, cfg, model):
             results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
             results[dataset_name] = results_i
 
-            if comm.is_main_process():
-                assert isinstance(
-                    results, dict
-                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
-                    results
-                )
-                print_csv_format(results)
+            if comm.is_main_process():
+                assert isinstance(
+                    results, dict
+                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                    results
+                )
+                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+                results_i['dataset'] = dataset_name
+                print_csv_format(results_i)
 
-        if len(results) == 1: results = list(results.values())[0]
+        if len(results) == 1:
+            results = list(results.values())[0]
 
         return results
```
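With this change `print_csv_format` is called once per dataset, on that dataset's own metrics dict tagged with a `'dataset'` key, rather than on the aggregate `{dataset_name -> metrics}` mapping. A sketch of the shape it now receives (metric names and values are illustrative only, not from the source):

```python
from collections import OrderedDict

# Illustrative per-dataset result as an evaluator might return it.
results_i = OrderedDict([("Rank-1", 94.5), ("Rank-5", 98.2), ("mAP", 85.9)])
results_i["dataset"] = "Market1501"  # tag added in test() before printing
# print_csv_format(results_i) then prints a single-row table keyed by
# "Dataset" (see the rewritten fastreid/evaluation/testing.py below).
```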

fastreid/engine/hooks.py

Lines changed: 8 additions & 7 deletions

```diff
@@ -360,19 +360,20 @@ def _do_eval(self):
                 )
         self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
 
-        # Remove extra memory cache of main process due to evaluation
-        torch.cuda.empty_cache()
+        # Evaluation may take different time among workers.
+        # A barrier makes them start the next iteration together.
+        comm.synchronize()
 
     def after_epoch(self):
         next_epoch = self.trainer.epoch + 1
-        is_final = next_epoch == self.trainer.max_epoch
-        if is_final or (self._period > 0 and next_epoch % self._period == 0):
+        if self._period > 0 and next_epoch % self._period == 0:
             self._do_eval()
-            # Evaluation may take different time among workers.
-            # A barrier make them start the next iteration together.
-            comm.synchronize()
 
     def after_train(self):
+        next_epoch = self.trainer.epoch + 1
+        # This condition is to prevent the eval from running after a failed training
+        if next_epoch % self._period != 0 and next_epoch >= self.trainer.max_epoch:
+            self._do_eval()
         # func is likely a closure that holds reference to the trainer
         # therefore we clean it to avoid circular reference in the end
         del self._func
```
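Moving `comm.synchronize()` into `_do_eval` keeps every call site covered: evaluation time can vary across workers, and without a barrier a fast worker would start the next iteration before a straggler finishes. `comm.synchronize()` is conventionally a thin wrapper over a distributed barrier; a minimal sketch under that assumption (fastreid's actual helper in `fastreid/utils/comm.py` may differ in detail):

```python
import torch.distributed as dist


def synchronize():
    # No-op outside distributed mode; otherwise block until every
    # process in the default group reaches this barrier.
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
```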

fastreid/evaluation/evaluator.py

Lines changed: 9 additions & 5 deletions

```diff
@@ -6,6 +6,7 @@
 
 import torch
 
+from fastreid.utils import comm
 from fastreid.utils.logger import log_every_n_seconds
 
 
@@ -96,6 +97,7 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
     Returns:
         The return value of `evaluator.evaluate()`
     """
+    num_devices = comm.get_world_size()
     logger = logging.getLogger(__name__)
     logger.info("Start inference on {} images".format(len(data_loader.dataset)))
 
@@ -118,10 +120,11 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
                 inputs["images"] = inputs["images"].flip(dims=[3])
                 flip_outputs = model(inputs)
                 outputs = (outputs + flip_outputs) / 2
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
             total_compute_time += time.perf_counter() - start_compute_time
             evaluator.process(inputs, outputs)
 
-            idx += 1
             iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
             seconds_per_batch = total_compute_time / iters_after_start
             if idx >= num_warmup * 2 or seconds_per_batch > 30:
@@ -140,17 +143,18 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
     total_time_str = str(datetime.timedelta(seconds=total_time))
     # NOTE this format is parsed by grep
     logger.info(
-        "Total inference time: {} ({:.6f} s / batch per device)".format(
-            total_time_str, total_time / (total - num_warmup)
+        "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
+            total_time_str, total_time / (total - num_warmup), num_devices
         )
     )
     total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
     logger.info(
-        "Total inference pure compute time: {} ({:.6f} s / batch per device)".format(
-            total_compute_time_str, total_compute_time / (total - num_warmup)
+        "Total inference pure compute time: {} ({:.6f} s / batch per device, on {} devices)".format(
+            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
         )
     )
     results = evaluator.evaluate()
+
     # An evaluator may return None when not in main process.
     # Replace it by an empty dict instead to make it easier for downstream code to handle
     if results is None:
```
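The added `torch.cuda.synchronize()` matters for the timing right below it: CUDA kernels launch asynchronously, so without a synchronization `time.perf_counter()` would be read while the forward pass may still be running on the GPU. A minimal sketch of the pattern:

```python
import time

import torch


def timed_forward(model, inputs):
    """Time one forward pass, accounting for asynchronous CUDA execution."""
    start = time.perf_counter()
    outputs = model(inputs)
    if torch.cuda.is_available():
        # Block until all queued kernels finish, so the measured span
        # covers the actual compute rather than just the kernel launches.
        torch.cuda.synchronize()
    return outputs, time.perf_counter() - start
```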

fastreid/evaluation/testing.py

Lines changed: 8 additions & 10 deletions

```diff
@@ -8,23 +8,21 @@
 from tabulate import tabulate
 from termcolor import colored
 
-logger = logging.getLogger(__name__)
-
 
 def print_csv_format(results):
     """
-    Print main metrics in a format similar to Detectron,
+    Print main metrics in a format similar to Detectron2,
     so that they are easy to copypaste into a spreadsheet.
     Args:
-        results (OrderedDict[dict]): task_name -> {metric -> score}
+        results (OrderedDict): {metric -> score}
     """
-    assert isinstance(results, OrderedDict), results  # unordered results cannot be properly printed
-    task = list(results.keys())[0]
-    metrics = ["Datasets"] + [k for k in results[task]]
+    # unordered results cannot be properly printed
+    assert isinstance(results, OrderedDict) or not len(results), results
+    logger = logging.getLogger(__name__)
 
-    csv_results = []
-    for task, res in results.items():
-        csv_results.append((task, *list(res.values())))
+    dataset_name = results.pop('dataset')
+    metrics = ["Dataset"] + [k for k in results]
+    csv_results = [(dataset_name, *list(results.values()))]
 
     # tabulate it
     table = tabulate(
```
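A usage sketch of the rewritten function's core. The sample metrics are made up, and the `'dataset'` key is assumed to have been set by the caller (as `DefaultTrainer.test` now does), since `results.pop('dataset')` raises `KeyError` otherwise; the pipe-style table format is also an assumption, as the `tabulate(` call is truncated in the diff.

```python
from collections import OrderedDict

from tabulate import tabulate

results = OrderedDict([("Rank-1", 94.5), ("Rank-5", 98.2), ("mAP", 85.9)])
results["dataset"] = "Market1501"  # set by the caller before printing

dataset_name = results.pop("dataset")
metrics = ["Dataset"] + [k for k in results]
csv_results = [(dataset_name, *list(results.values()))]
print(tabulate(csv_results, headers=metrics, tablefmt="pipe"))
# Output, approximately:
# | Dataset    |   Rank-1 |   Rank-5 |   mAP |
# |:-----------|---------:|---------:|------:|
# | Market1501 |     94.5 |     98.2 |  85.9 |
```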
