utils.py
import sys
import logging
import collections.abc
from copy import deepcopy

from torch import distributed as dist


def is_dist() -> bool:
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
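

# Usage sketch (not part of the original file): these helpers are typically
# used to restrict side effects such as logging or checkpointing to the
# rank-0 process, e.g.
#
#     if get_rank() == 0:
#         print('world size:', get_world_size())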


def merge(dict1, dict2):
    '''Return a new dictionary by merging
    two dictionaries recursively.
    '''
    result = deepcopy(dict1)
    for key, value in dict2.items():
        if isinstance(value, collections.abc.Mapping):
            result[key] = merge(result.get(key, {}), value)
        else:
            result[key] = deepcopy(value)
    return result
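

# Illustrative example (assumed, not from the original file); values from
# `dict2` take precedence and nested mappings are merged recursively:
#
#     merge({'lr': 0.1, 'opt': {'type': 'SGD'}},
#           {'opt': {'momentum': 0.9}})
#     # -> {'lr': 0.1, 'opt': {'type': 'SGD', 'momentum': 0.9}}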


def fill_config(config):
    '''Fill every sub-config with the defaults from the `base` entry;
    values set in a sub-config take precedence over the base values.
    '''
    # config = copy.deepcopy(config)
    base_cfg = config.pop('base', {})
    for sub, sub_cfg in config.items():
        if isinstance(sub_cfg, dict):
            config[sub] = merge(base_cfg, sub_cfg)
        elif isinstance(sub_cfg, list):
            config[sub] = [merge(base_cfg, c) for c in sub_cfg]
    return config
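

# Illustrative example (assumed, not from the original file); note that the
# input dict is modified in place and the `base` entry is removed:
#
#     cfg = {'base': {'lr': 0.1}, 'train': {'epochs': 90}}
#     fill_config(cfg)
#     # -> {'train': {'lr': 0.1, 'epochs': 90}}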


class IterLoader:
    '''Wrap a dataloader so it can be iterated indefinitely, restarting
    (and advancing the epoch counter) whenever it is exhausted.
    '''

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            # keep distributed samplers shuffling correctly across epochs
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        return len(self._dataloader)
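

# Minimal usage sketch (assumed; `dataloader` stands for any
# torch.utils.data.DataLoader and `max_iters` is hypothetical):
#
#     loader = IterLoader(dataloader)
#     for step in range(max_iters):
#         batch = next(loader)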


class LoggerBuffer:
    def __init__(self, name, path, headers, screen_intvl=1):
        self.logger = self.get_logger(name, path)
        self.history = []
        self.headers = headers
        self.screen_intvl = screen_intvl

    def get_logger(self, name, path):
        logger = logging.getLogger(name)
        logger.setLevel(logging.DEBUG)
        # message and timestamp formats
        msg_fmt = '[%(levelname)s] %(asctime)s, %(message)s'
        time_fmt = '%Y-%m-%d_%H-%M-%S'
        formatter = logging.Formatter(msg_fmt, time_fmt)
        # file handler: everything from DEBUG upwards goes to the log file
        file_handler = logging.FileHandler(path, 'w')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.DEBUG)
        logger.addHandler(file_handler)
        # stream handler: only INFO and above is printed to stdout
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(formatter)
        stream_handler.setLevel(logging.INFO)
        logger.addHandler(stream_handler)
        # to avoid duplicated logging info in PyTorch >1.9
        if len(logger.root.handlers) == 0:
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setFormatter(formatter)
            logger.root.addHandler(stream_handler)
        # to avoid duplicated logging info in PyTorch >1.8
        for handler in logger.root.handlers:
            handler.setLevel(logging.WARNING)
        return logger

    def clean(self):
        self.history = []

    def update(self, msg):
        # get the iteration
        n = msg.pop('Iter')
        self.history.append(msg)
        # warn about items that are not declared in the headers
        novel_heads = [k for k in msg if k not in self.headers]
        if len(novel_heads) > 0:
            self.logger.warning(
                'Items {} are not defined.'.format(novel_heads))
        # warn about declared items that are missing from this update
        missing_heads = [k for k in self.headers if k not in msg]
        if len(missing_heads) > 0:
            self.logger.warning(
                'Items {} are missing.'.format(missing_heads))
        # when messages are not printed every iteration, still write the full
        # record of this iteration to the log file at DEBUG level
        if self.screen_intvl != 1:
            doc_msg = ['Iter: {:5d}'.format(n)]
            for k, fmt in self.headers.items():
                v = self.history[-1][k]
                doc_msg.append(('{}: {' + fmt + '}').format(k, v))
            doc_msg = ', '.join(doc_msg)
            self.logger.debug(doc_msg)
        # every `screen_intvl` iterations, print the values averaged over the
        # last `screen_intvl` updates to the screen
        if n % self.screen_intvl == 0:
            screen_msg = ['Iter: {:5d}'.format(n)]
            for k, fmt in self.headers.items():
                vals = [m[k] for m in self.history[-self.screen_intvl:]
                        if k in m]
                v = sum(vals) / len(vals)
                screen_msg.append(('{}: {' + fmt + '}').format(k, v))
            screen_msg = ', '.join(screen_msg)
            self.logger.info(screen_msg)
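

# Minimal usage sketch (assumed; the header names, format specs, and values
# are hypothetical, not from the original file):
#
#     logger = LoggerBuffer('train', 'train.log',
#                           headers={'Loss': ':.4f', 'Acc': ':.2%'},
#                           screen_intvl=10)
#     logger.update({'Iter': 1, 'Loss': 0.693, 'Acc': 0.5})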