
Commit 8cb5d2b

Merge pull request #186 from automl/predictor_evaluator_fix
2 parents 0b51bc1 + f60a56a commit 8cb5d2b


2 files changed: +28 −7 lines changed


naslib/defaults/predictor_evaluator.py

Lines changed: 16 additions & 5 deletions
@@ -10,9 +10,12 @@
 from sklearn import metrics
 import math
 
+from naslib.predictors.zerocost import ZeroCost
 from naslib.search_spaces.core.query_metrics import Metric
 from naslib.utils import generate_kfold, cross_validation
 
+from naslib import utils
+
 logger = logging.getLogger(__name__)
 
 
@@ -47,6 +50,9 @@ def __init__(self, predictor, config=None):
         self.num_arches_to_mutate = 5
         self.max_mutation_rate = 3
 
+        # For ZeroCost proxies
+        self.dataloader = None
+
     def adapt_search_space(
         self, search_space, load_labeled, scope=None, dataset_api=None
     ):
@@ -70,6 +76,9 @@ def adapt_search_space(
                 "This search space is not yet implemented in PredictorEvaluator."
             )
 
+        if isinstance(self.predictor, ZeroCost):
+            self.dataloader, _, _, _, _ = utils.get_train_val_loaders(self.config)
+
     def get_full_arch_info(self, arch):
         """
         Given an arch, return the accuracy, train_time,
@@ -139,10 +148,8 @@ def load_dataset(self, load_labeled=False, data_size=10, arch_hash_map={}):
                 arch.load_labeled_architecture(dataset_api=self.dataset_api)
 
             arch_hash = arch.get_hash()
-            if False: # removing this for consistency, for now
-                continue
-            else:
-                arch_hash_map[arch_hash] = True
+
+            arch_hash_map[arch_hash] = True
 
             accuracy, train_time, info_dict = self.get_full_arch_info(arch)
             xdata.append(arch)
@@ -295,7 +302,11 @@ def single_evaluate(self, train_data, test_data, fidelity):
         hyperparams = self.predictor.get_hyperparams()
 
         fit_time_end = time.time()
-        test_pred = self.predictor.query(xtest, test_info)
+        if isinstance(self.predictor, ZeroCost):
+            [g.parse() for g in xtest]  # parse the graphs because they will be used
+            test_pred = self.predictor.query_batch(xtest, self.dataloader)
+        else:
+            test_pred = self.predictor.query(xtest, self.dataloader)
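
Read as plain code, the evaluator change is a small dispatch: a ZeroCost predictor is scored from a dataloader rather than from test_info, and the test graphs must be parsed before scoring. A minimal sketch of that dispatch, assuming predictor, xtest, test_info, and dataloader are already prepared the way PredictorEvaluator prepares them (the helper name predict_test_set is hypothetical, not part of this commit):

from naslib.predictors.zerocost import ZeroCost

def predict_test_set(predictor, xtest, test_info, dataloader):
    # Zero-cost proxies compute scores directly from a dataloader, so the
    # architecture graphs must be parsed (instantiated as models) first.
    if isinstance(predictor, ZeroCost):
        for g in xtest:
            g.parse()
        return predictor.query_batch(xtest, dataloader)
    # Every other predictor keeps the original query(xtest, test_info) path.
    return predictor.query(xtest, test_info)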

naslib/predictors/zerocost.py

Lines changed: 12 additions & 2 deletions
@@ -4,6 +4,7 @@
 based on https://github.com/mohsaied/zero-cost-nas
 """
 import torch
+import numpy as np
 import logging
 import math
 
@@ -24,12 +25,21 @@ def __init__(self, method_type="jacov"):
         self.num_imgs_or_batches = 1
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    def query(self, graph, dataloader=None, info=None):
+    def query_batch(self, graphs, dataloader):
+        scores = []
+
+        for graph in graphs:
+            score = self.query(graph, dataloader)
+            scores.append(score)
+
+        return np.array(scores)
+
+    def query(self, graph, dataloader):
         loss_fn = graph.get_loss_fn()
 
         n_classes = graph.num_classes
         score = predictive.find_measures(
-            net_orig=graph,
+            net_orig=graph.to(self.device),
             dataloader=dataloader,
             dataload_info=(self.dataload, self.num_imgs_or_batches, n_classes),
             device=self.device,
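
For context, a minimal usage sketch of the new query_batch API, assuming a NASLib config object and a list of already-parsed architecture graphs; the helper name score_architectures is hypothetical, and the dataloader is built the same way the evaluator now builds it:

from naslib import utils
from naslib.predictors.zerocost import ZeroCost

def score_architectures(graphs, config, method_type="jacov"):
    """Score a list of parsed architecture graphs with a zero-cost proxy."""
    predictor = ZeroCost(method_type=method_type)
    # Only the train loader is used; the remaining return values are discarded,
    # mirroring the call in PredictorEvaluator.adapt_search_space above.
    train_loader, _, _, _, _ = utils.get_train_val_loaders(config)
    # query_batch loops query over the graphs and returns one score per graph
    # as a numpy array; per-graph behavior (including moving each graph to the
    # predictor's device) is unchanged.
    return predictor.query_batch(graphs, train_loader)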
