10
10
from sklearn import metrics
11
11
import math
12
12
13
+ from naslib .predictors .zerocost import ZeroCost
13
14
from naslib .search_spaces .core .query_metrics import Metric
14
15
from naslib .utils import generate_kfold , cross_validation
15
16
17
+ from naslib import utils
18
+
16
19
logger = logging .getLogger (__name__ )
17
20
18
21
@@ -47,6 +50,9 @@ def __init__(self, predictor, config=None):
47
50
self .num_arches_to_mutate = 5
48
51
self .max_mutation_rate = 3
49
52
53
+ # For ZeroCost proxies
54
+ self .dataloader = None
55
+
50
56
def adapt_search_space (
51
57
self , search_space , load_labeled , scope = None , dataset_api = None
52
58
):
@@ -70,6 +76,9 @@ def adapt_search_space(
70
76
"This search space is not yet implemented in PredictorEvaluator."
71
77
)
72
78
79
+ if isinstance (self .predictor , ZeroCost ):
80
+ self .dataloader , _ , _ , _ , _ = utils .get_train_val_loaders (self .config )
81
+
73
82
def get_full_arch_info (self , arch ):
74
83
"""
75
84
Given an arch, return the accuracy, train_time,
@@ -139,10 +148,8 @@ def load_dataset(self, load_labeled=False, data_size=10, arch_hash_map={}):
139
148
arch .load_labeled_architecture (dataset_api = self .dataset_api )
140
149
141
150
arch_hash = arch .get_hash ()
142
- if False : # removing this for consistency, for now
143
- continue
144
- else :
145
- arch_hash_map [arch_hash ] = True
151
+
152
+ arch_hash_map [arch_hash ] = True
146
153
147
154
accuracy , train_time , info_dict = self .get_full_arch_info (arch )
148
155
xdata .append (arch )
@@ -295,7 +302,11 @@ def single_evaluate(self, train_data, test_data, fidelity):
295
302
hyperparams = self .predictor .get_hyperparams ()
296
303
297
304
fit_time_end = time .time ()
298
- test_pred = self .predictor .query (xtest , test_info )
305
+ if isinstance (self .predictor , ZeroCost ):
306
+ [g .parse () for g in xtest ] # parse the graphs because they will be used
307
+ test_pred = self .predictor .query_batch (xtest , self .dataloader )
308
+ else :
309
+ test_pred = self .predictor .query (xtest , test_info )
299
310
query_time_end = time .time ()
300
311
301
312
# If the predictor is an ensemble, take the mean
0 commit comments