Skip to content
This repository was archived by the owner on Aug 6, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion main_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ def get_args_parser():
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")

# logging with aim
parser.add_argument("--use_aim", default=True, type=bool, help="whether to use aim for logging.")
parser.add_argument("--aim_repo", default=None, type=str, help="path to Aim repository.")
parser.add_argument("--aim_run_hash", default=None, type=str,
help="Aim run hash. Create a new run if not specified.")

return parser


Expand Down Expand Up @@ -301,7 +308,7 @@ def train_dino(args):
def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch,
fp16_scaler, args):
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger = utils.MetricLogger(args, delimiter=" ")
header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
# update weight decay and learning rate according to their schedule
Expand Down
23 changes: 22 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,18 +309,39 @@ def reduce_dict(input_dict, average=True):
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict

# Optional Aim experiment tracking.  The import is attempted once at module
# load; when aim is not installed, get_aim_run is None and MetricLogger
# silently skips Aim logging.
try:
    import functools

    from aim import Run

    @functools.lru_cache()
    def get_aim_run(repo, run_hash):
        """Return the Aim Run for (repo, run_hash), creating it on first call.

        lru_cache makes this a process-wide singleton per (repo, run_hash)
        pair, so every MetricLogger instance shares one Run object.
        run_hash=None makes Aim create a new run (see the --aim_run_hash
        CLI help).
        """
        # Reuse the module-level Run import; re-importing it inside the
        # function (as the original did) was redundant.
        return Run(run_hash=run_hash, repo=repo)

except ImportError:
    print("Warning: Aim is not installed. Install aim to use metric logging.")
    get_aim_run = None


class MetricLogger(object):
def __init__(self, delimiter="\t"):
def __init__(self, args, delimiter="\t"):
    """Create a metric logger; attach an Aim run when requested and available.

    args: parsed CLI namespace — reads use_aim, aim_repo and aim_run_hash,
    and snapshots all of its entries onto the Aim run.
    delimiter: string placed between meters when printing.
    """
    self.meters = defaultdict(SmoothedValue)
    self.delimiter = delimiter
    # Aim logging is best-effort: it requires both the --use_aim flag and a
    # successful `import aim` (get_aim_run is None otherwise).
    aim_enabled = bool(args.use_aim) and get_aim_run is not None
    self.aim_run = get_aim_run(args.aim_repo, args.aim_run_hash) if aim_enabled else None
    if self.aim_run is not None:
        # Record every CLI argument on the run for reproducibility.
        for name, value in vars(args).items():
            self.aim_run.set(('cli_args', name), value, strict=False)

def update(self, **kwargs):
    """Record scalar metrics into the smoothed meters (and Aim, if active).

    Each keyword is a metric name; each value must be a float/int or a
    torch scalar tensor.
    """
    for name, value in kwargs.items():
        # Accept 0-d tensors for convenience; store plain Python scalars.
        if isinstance(value, torch.Tensor):
            value = value.item()
        assert isinstance(value, (float, int))
        self.meters[name].update(value)
        if self.aim_run:
            # Mirror the raw (unsmoothed) value to the Aim run as well.
            self.aim_run.track(value, name=name)

def __getattr__(self, attr):
if attr in self.meters:
Expand Down