-
Notifications
You must be signed in to change notification settings - Fork 15
/
multiscorer.py
134 lines (96 loc) · 4.46 KB
/
multiscorer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class MultiScorer():
    '''
    Use this class to encapsulate and/or aggregate multiple scoring functions so that it can be passed as an argument for scoring in scikit's cross_val_score function.
    Instances of this class are also callables, with signature as needed by `cross_val_score`.

    Every call of the instance is counted as one fold: each configured metric is evaluated on the estimator's predictions and its score is stored, so that the per-fold scores of all metrics can be retrieved afterwards with `get_results`.
    '''

    def __init__(self, metrics):
        '''
        Create a new instance of MultiScorer.

        Parameters
        ----------
        metrics: dict
            The metrics to be used by the scorer.
            The dictionary must have as key a name (str) for the metric and as value a tuple containing the metric function itself and a dict literal of the additional named arguments to be passed to the function.
            The metric function should be one of the `sklearn.metrics` function or any other callable with the same signature: `metric(y_real, y, **kwargs)`.
        '''
        self.metrics = metrics
        # One list of per-fold scores for every metric name.
        self.results = {name: [] for name in metrics}
        self._called = False
        self.n_folds = 0

    def __call__(self, estimator, X, y):
        '''
        To be called for evaluation from sklearn's GridSearchCV or cross_val_score.
        Parameters are as they are defined in the respective documentation.

        The fold is recorded only if the prediction and every metric succeed; if any of them raises, the exception propagates and the stored results and the fold counter are left untouched.

        Returns
        -------
        A dummy value of 0.5 just for compatibility reasons.
        '''
        y_pred = estimator.predict(X)

        # Evaluate every metric before touching any state, so that a failing
        # metric cannot leave the result lists with different lengths or
        # `n_folds` out of sync with them.
        fold_scores = [
            (name, metric(y, y_pred, **kwargs))
            for name, (metric, kwargs) in self.metrics.items()
        ]

        for name, score in fold_scores:
            self.results[name].append(score)

        self.n_folds += 1
        self._called = True

        return 0.5

    def get_metric_names(self):
        '''
        Get all the metric names as given when initialized.

        Returns
        -------
        A list containing the given names (str) of the metrics
        '''
        # `dict.keys()` is a view (not a list) on Python 3; build a real list
        # so that the documented return type holds.
        return list(self.metrics)

    def get_results(self, metric=None, fold='all'):
        '''
        Get the results of a specific or all the metrics.
        This method should be called after the object itself has been called so that the metrics are applied.

        Parameters
        ----------
        metric: str or None (default)
            The given name of a metric to return its result(s). If omitted the results of all metrics will be returned.
        fold: int in range [1, number_of_folds] or 'all' (Default)
            Get the metric(s) results for the specific fold.
            The number of folds corresponds to the number of times the instance is called.
            If its value is a number, either the score of a single metric for that fold or a dictionary of the (single) scores for that fold will be returned, depending on the value of `metric` parameter.
            If its value is 'all', either a list of a single metric or a dictionary containing the lists of scores for all folds will be returned, depending on the value of `metric` parameter.

        Returns
        -------
        metric_result_for_one_fold
            The result of the designated metric function for the specific fold, if `metric` parameter was not omitted and an integer value was given to `fold` parameter.
            If the value of `metric` does not correspond to a metric name, `None` will be returned.
        all_metric_results_for_one_fold: dict
            A dict having as keys the names of the metrics and as values their results for the specific fold.
            This will be returned only if `metric` parameter was omitted and an integer value was given to `fold` parameter.
        metric_results_for_all_folds: list
            A list of length number_of_folds containing the results of all folds for the specific metric, if `metric` parameter was not omitted and value 'all' was given to `fold`.
            If the value of `metric` does not correspond to a metric name, `None` will be returned.
        all_metric_results_for_all_folds: dict of lists
            A dict having as keys the names of the metrics and as values lists (of length number_of_folds) of their results for all folds.
            This will be returned only if `metric` parameter was omitted and 'all' value was given to `fold` parameter.

        Raises
        ------
        UserWarning
            If this method is called before the instance is called for evaluation.
        ValueError
            If the value for `fold` parameter is not appropriate.
        '''
        if not self._called:
            raise UserWarning('Evaluation has not been performed yet.')

        if isinstance(fold, str) and fold == 'all':
            if metric is None:
                return self.results
            # Unknown metric names yield None, as documented, rather than
            # raising KeyError.
            return self.results.get(metric)

        if isinstance(fold, int):
            # Fold numbers are 1-based; anything outside [1, n_folds] is
            # rejected (this also rules out negative-index wrap-around).
            if not 1 <= fold <= self.n_folds:
                raise ValueError('Invalid fold index: ' + str(fold))

            if metric is None:
                return {
                    name: scores[fold - 1]
                    for name, scores in self.results.items()
                }

            scores = self.results.get(metric)
            return None if scores is None else scores[fold - 1]

        raise ValueError('Unexpected fold value: %s' % (str(fold)))