06_run_best_config.py
"""Get best val and eval on test run script
Finds the parameters with the best result on the validation data
Solves and evaluates with those on the test data
"""
from __future__ import absolute_import
import argparse
import logging
import os
import sys
import time

import attr
import pandas as pd
import toml

from linajea.config import (SolveParametersConfig,
                            TrackingConfig)
import linajea.evaluation
from linajea.process_blockwise import solve_blockwise
from linajea.utils import (print_time,
                           getNextInferenceData)

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str,
                        help='path to config file')
    parser.add_argument('--checkpoint', type=int, default=-1,
                        help='checkpoint to process')
    parser.add_argument('--swap_val_test', action="store_true",
                        help='swap validation and test data?')
    parser.add_argument('--sort_by', type=str, default="sum_errors",
                        help=('which metric to use to select the best '
                              'parameters/weights'))
    args = parser.parse_args()

    config = TrackingConfig.from_file(args.config)

    logging.basicConfig(
        level=config.general.logging,
        handlers=[
            logging.FileHandler('run.log', mode='a'),
            logging.StreamHandler(sys.stdout),
        ],
        format='%(asctime)s %(name)s %(levelname)-8s %(message)s')
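
    # Error counts used to rank the parameter sets; false positive edges are
    # only included if the ground truth annotation is dense (not sparse).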
    score_columns = ['fn_edges', 'identity_switches',
                     'fp_divisions', 'fn_divisions']
    if not config.general.sparse:
        score_columns = ['fp_edges'] + score_columns

    results = {}
    samples = set()
    args.validation = not args.swap_val_test
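
    # Collect the stored evaluation results for every validation sample;
    # getNextInferenceData yields one inference config per sample (and per
    # checkpoint/cell score threshold combination).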
    for sample_idx, inf_config in enumerate(getNextInferenceData(
            args, is_evaluate=True)):
        sample = inf_config.inference_data.data_source.datafile.filename
        checkpoint = inf_config.inference_data.checkpoint
        cell_score_threshold = inf_config.inference_data.cell_score_threshold
        samples.add(sample)
        logger.debug("getting results for: %s %s %s",
                     sample, checkpoint, cell_score_threshold)
        res = linajea.evaluation.get_results_sorted(
            inf_config,
            filter_params={"val": True},
            score_columns=score_columns,
            sort_by=args.sort_by)
        res = res.assign(
            checkpoint=checkpoint).assign(
                cell_score_threshold=cell_score_threshold)
        results[(
            os.path.basename(sample), checkpoint, cell_score_threshold)] = \
            res.reset_index()
    # switch back to the test data for the final solve and evaluate steps
    args.validation = not args.validation
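
    # Merge the per-sample tables; the database id columns are meaningless
    # after merging and are dropped.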
    results = pd.concat(list(results.values())).reset_index()
    del results['_id']
    del results['param_id']

    solve_params = attr.fields_dict(SolveParametersConfig)
    by = [
        "matching_threshold",
        "weight_node_score",
        "selection_constant",
        "track_cost",
        "weight_division",
        "division_constant",
        "weight_child",
        "weight_continuation",
        "weight_edge_score",
        "checkpoint",
        "cell_score_threshold",
    ]
if "cell_state_key" in results:
by.append("cell_state_key")
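
    # Aggregate each parameter set over all samples: sum the numeric error
    # counts, keep the first value of list/dict/string columns, and mark
    # parameter sets that were not evaluated on every sample with -1 so
    # that they can be filtered out.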
    results = results.groupby(by, dropna=False, as_index=False).agg(
        lambda x: -1
        if len(x) != len(samples)
        else sum(x)
        if (not isinstance(x.iloc[0], list) and
            not isinstance(x.iloc[0], dict) and
            not isinstance(x.iloc[0], str))
        else x.iloc[0])
    results = results[results.sum_errors != -1]
    results.sort_values(args.sort_by, ascending=True, inplace=True)
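
    # The first row now holds the best parameters on validation; copy its
    # values into a dict with the fields of SolveParametersConfig.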
    for k in solve_params.keys():
        # optional fields that are missing from the results default to None
        if k in ("tag", "cell_state_key") and k not in results.iloc[0]:
            solve_params[k] = None
            continue
        solve_params[k] = results.iloc[0][k]
    solve_params['val'] = False
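
    # Write a config with only the selected parameter set (and parameter
    # search disabled) to a temporary file and run the remaining steps on it.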
    os.makedirs("tmp_configs", exist_ok=True)
    config.path = os.path.join("tmp_configs", "config_{}.toml".format(
        time.time()))
    config_dict = attr.asdict(config)
    config_dict['solve']['parameters'] = [solve_params]
    config_dict['solve']['grid_search'] = False
    config_dict['solve']['random_search'] = False
    with open(config.path, 'w') as f:
        toml.dump(config_dict, f, encoder=toml.TomlNumpyEncoder())
    args.config = config.path
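
    # Solve (track) on the test data with the selected parameters...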
    start_time = time.time()
    for inf_config in getNextInferenceData(args, is_solve=True):
        solve_blockwise(inf_config)
    end_time = time.time()
    print_time(end_time - start_time)
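
    # ...then evaluate the resulting tracks, timing both steps.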
    start_time = time.time()
    for inf_config in getNextInferenceData(args, is_evaluate=True):
        linajea.evaluation.evaluate_setup(inf_config)
    end_time = time.time()
    print_time(end_time - start_time)