Skip to content

Commit

Permalink
fix print stmts
Browse files Browse the repository at this point in the history
  • Loading branch information
crwhite14 committed Jul 12, 2020
1 parent e624ad1 commit 2cd286a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions post_hoc_celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def main(config):

_, best_thresh = val_model(net, valloader, get_best_accuracy, protected_index, prediction_index)

print('val_results')
print('val_results, thresh', best_thresh.item())
print_objective_results(valloader, net, best_thresh, protected_index, prediction_index)
print()
print('test_results')
Expand All @@ -340,7 +340,7 @@ def main(config):
if 'random' in config['models']:

best_obj, best_thresh = val_model(net, valloader, get_best_objective, protected_index, prediction_index)
print('val best thresh results')
print('val best thresh results, thresh', best_thresh.item())
print_objective_results(valloader, net, best_thresh, protected_index, prediction_index)
print()
print('test best thresh results')
Expand All @@ -357,7 +357,7 @@ def main(config):

rand_model.eval()
best_obj, best_thresh = val_model(rand_model, valloader, get_best_objective, protected_index, prediction_index)
print('iteration', iteration, 'obj', best_obj.item(), 'thresh', best_thresh)
print('iteration', iteration, 'obj', best_obj.item(), 'thresh', best_thresh.item())

if best_obj < rand_result[0]:
print('found new best')
Expand All @@ -376,8 +376,10 @@ def main(config):

print('val_results')
print_objective_results(valloader, best_model, best_thresh, protected_index, prediction_index)
print()
print('test_results')
result_dict = print_objective_results(testloader, best_model, best_thresh, protected_index, prediction_index)
print()
results['random'] = result_dict

torch.save(best_model.state_dict(), config['random']['checkpoint'])
Expand Down Expand Up @@ -465,8 +467,10 @@ def main(config):

print('val_results')
print_objective_results(valloader, actor, best_thresh, protected_index, prediction_index)
print()
print('test_results')
result_dict = print_objective_results(testloader, actor, best_thresh, protected_index, prediction_index)
print()
results['adversarial'] = result_dict

torch.save(actor.state_dict(), config['adversarial']['checkpoint'])
Expand Down

0 comments on commit 2cd286a

Please sign in to comment.