
Commit cf43343

Update benchmark and validate scripts to output results in JSON with a fixed delimiter for use in multi-process launcher
1 parent 1331c14 commit cf43343

File tree

benchmark.py
validate.py

2 files changed: +48 -32 lines changed

benchmark.py

Lines changed: 16 additions & 12 deletions
@@ -473,20 +473,21 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
 def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
     batch_size = initial_batch_size
     results = dict()
+    error_str = 'Unknown'
     while batch_size >= 1:
         torch.cuda.empty_cache()
         try:
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
-            e_str = str(e)
-            print(e_str)
-            if 'channels_last' in e_str:
-                print(f'Error: {model_name} not supported in channels_last, skipping.')
+            error_str = str(e)
+            if 'channels_last' in error_str:
+                _logger.error(f'{model_name} not supported in channels_last, skipping.')
                 break
-            print(f'Error: "{e_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
+            _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
             batch_size = decay_batch_exp(batch_size)
+    results['error'] = error_str
     return results
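The retry loop above delegates the shrinking schedule to decay_batch_exp, whose signature appears in the hunk header (factor=0.5, divisor=16). Its body isn't part of this diff, so the following is only a sketch of what such an exponential decay might look like, assuming it rounds retries to a divisor multiple while the batch is still large:

def decay_batch_exp(batch_size, factor=0.5, divisor=16):
    # Sketch only -- the real implementation lives above this hunk.
    out_batch_size = batch_size * factor
    if out_batch_size > divisor:
        # Keep retries aligned to a multiple of `divisor` (e.g. 256 -> 128).
        out_batch_size = (out_batch_size + 1) // divisor * divisor
    else:
        # Once below the divisor, step down by one until the loop gives up.
        out_batch_size = batch_size - 1
    return max(0, int(out_batch_size))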

@@ -528,13 +529,14 @@ def benchmark(args):
     model_results = OrderedDict(model=model)
     for prefix, bench_fn in zip(prefixes, bench_fns):
         run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
-        if prefix:
+        if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
-    param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
-    model_results.setdefault('param_count', param_count)
-    model_results.pop('train_param_count', 0)
-    return model_results if model_results['param_count'] else dict()
+    if 'error' not in model_results:
+        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
+        model_results.setdefault('param_count', param_count)
+        model_results.pop('train_param_count', 0)
+    return model_results


 def main():
@@ -578,13 +580,15 @@ def main():
             sort_key = 'train_samples_per_sec'
         elif 'profile' in args.bench:
             sort_key = 'infer_gmacs'
+        results = filter(lambda x: sort_key in x, results)
         results = sorted(results, key=lambda x: x[sort_key], reverse=True)
         if len(results):
             write_results(results_file, results)
     else:
         results = benchmark(args)
-        json_str = json.dumps(results, indent=4)
-        print(json_str)
+
+    # output results in JSON to stdout w/ delimiter for runner script
+    print(f'--result\n{json.dumps(results, indent=4)}')


 def write_results(results_file, results):
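The new --result delimiter is what makes this output machine-consumable: a multi-process launcher can ignore everything a worker logs and parse only the JSON that follows the marker. The launcher itself is not part of this commit, so the consumer side below is a hypothetical sketch (run_and_collect and the example command are assumptions):

import json
import subprocess


def run_and_collect(cmd):
    # Run one worker and parse the JSON that follows the fixed
    # '--result' delimiter on its stdout.
    out = subprocess.run(cmd, capture_output=True, text=True).stdout
    delim = '--result\n'
    if delim not in out:
        return {'error': 'no result delimiter in output'}
    # Everything after the last occurrence of the delimiter is the payload.
    return json.loads(out.rsplit(delim, 1)[-1])


# e.g. run_and_collect(['python', 'benchmark.py', '--model', 'resnet50'])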

validate.py

Lines changed: 32 additions & 20 deletions
@@ -11,6 +11,7 @@
 import os
 import csv
 import glob
+import json
 import time
 import logging
 import torch
@@ -263,6 +264,7 @@ def validate(args):
     else:
         top1a, top5a = top1.avg, top5.avg
     results = OrderedDict(
+        model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
         param_count=round(param_count / 1e6, 2),
@@ -276,6 +278,27 @@
     return results


+def _try_run(args, initial_batch_size):
+    batch_size = initial_batch_size
+    results = OrderedDict()
+    error_str = 'Unknown'
+    while batch_size >= 1:
+        args.batch_size = batch_size
+        torch.cuda.empty_cache()
+        try:
+            results = validate(args)
+            return results
+        except RuntimeError as e:
+            error_str = str(e)
+            if 'channels_last' in error_str:
+                break
+            _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
+            batch_size = batch_size // 2
+    results['error'] = error_str
+    _logger.error(f'{args.model} failed to validate ({error_str}).')
+    return results
+
+
 def main():
     setup_default_logging()
     args = parser.parse_args()
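Unlike benchmark.py's decay_batch_exp, the validate-side _try_run simply halves: with an initial batch size of 256, persistent recoverable RuntimeErrors retry at 128, 64, 32, 16, 8, 4, 2 and finally 1 before the loop exits, the error string is attached to the result dict, and the failure is logged.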
@@ -308,36 +331,25 @@ def main():
         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
         results = []
         try:
-            start_batch_size = args.batch_size
+            initial_batch_size = args.batch_size
             for m, c in model_cfgs:
-                batch_size = start_batch_size
                 args.model = m
                 args.checkpoint = c
-                result = OrderedDict(model=args.model)
-                r = {}
-                while not r and batch_size >= args.num_gpu:
-                    torch.cuda.empty_cache()
-                    try:
-                        args.batch_size = batch_size
-                        print('Validating with batch size: %d' % args.batch_size)
-                        r = validate(args)
-                    except RuntimeError as e:
-                        if batch_size <= args.num_gpu:
-                            print("Validation failed with no ability to reduce batch size. Exiting.")
-                            raise e
-                        batch_size = max(batch_size // 2, args.num_gpu)
-                        print("Validation failed, reducing batch size by 50%")
-                result.update(r)
+                r = _try_run(args, initial_batch_size)
+                if 'error' in r:
+                    continue
                 if args.checkpoint:
-                    result['checkpoint'] = args.checkpoint
-                results.append(result)
+                    r['checkpoint'] = args.checkpoint
+                results.append(r)
         except KeyboardInterrupt as e:
             pass
         results = sorted(results, key=lambda x: x['top1'], reverse=True)
         if len(results):
             write_results(results_file, results)
     else:
-        validate(args)
+        results = validate(args)
+    # output results in JSON to stdout w/ delimiter for runner script
+    print(f'--result\n{json.dumps(results, indent=4)}')


 def write_results(results_file, results):
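Taken together, a single-model run of either script now finishes its stdout with the fixed delimiter plus a pretty-printed JSON blob, roughly shaped like the following (keys from the OrderedDict above; values elided):

--result
{
    "model": "...",
    "top1": ...,
    "top5": ...,
    "param_count": ...
}

A launcher splits on the --result line and feeds everything after it to json.loads.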
