Skip to content

Commit 0f9911b

Browse files
committed
fix_gridsearch
1 parent 7106704 commit 0f9911b

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

libmultilabel/linear/utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,32 +109,31 @@ class GridSearchCV(sklearn.model_selection.GridSearchCV):
109109
The usage is similar to sklearn's, except that the parameter ``scoring`` is unavailable. Instead, specify ``scoring_metric`` in ``MultiLabelEstimator`` in the Pipeline.
110110
111111
Args:
112-
pipeline (sklearn.pipeline.Pipeline): A sklearn Pipeline for grid search.
112+
estimator (estimator object): A estimator for grid search.
113113
param_grid (dict): Search space for a grid search containing a dictionary of
114114
parameters and their corresponding list of candidate values.
115115
n_jobs (int, optional): Number of CPU cores run in parallel. Defaults to None.
116116
"""
117117

118-
_required_parameters = ["pipeline", "param_grid"]
118+
_required_parameters = ["estimator", "param_grid"]
119119

120-
def __init__(self, pipeline: sklearn.pipeline.Pipeline, param_grid: dict, n_jobs=None, **kwargs):
121-
assert isinstance(pipeline, sklearn.pipeline.Pipeline)
120+
def __init__(self, estimator, param_grid: dict, n_jobs=None, **kwargs):
122121
if n_jobs is not None and n_jobs > 1:
123-
param_grid = self._set_singlecore_options(pipeline, param_grid)
122+
param_grid = self._set_singlecore_options(estimator, param_grid)
124123
if "scoring" in kwargs.keys():
125124
raise ValueError(
126125
"Please specify the validation metric with `MultiLabelEstimator.scoring_metric` in the Pipeline instead of using the parameter `scoring`."
127126
)
128127

129-
super().__init__(estimator=pipeline, n_jobs=n_jobs, param_grid=param_grid, **kwargs)
128+
super().__init__(estimator=estimator, n_jobs=n_jobs, param_grid=param_grid, **kwargs)
130129

131-
def _set_singlecore_options(self, pipeline: sklearn.pipeline.Pipeline, param_grid: dict):
130+
def _set_singlecore_options(self, estimator, param_grid: dict):
132131
"""Set liblinear options to `-m 1`. The grid search option `n_jobs`
133132
runs multiple processes in parallel. Using multithreaded liblinear
134133
in conjunction with grid search oversubscribes the CPU and deteriorates
135134
the performance significantly.
136135
"""
137-
params = pipeline.get_params()
136+
params = estimator.get_params()
138137
for name, transform in params.items():
139138
if isinstance(transform, MultiLabelEstimator):
140139
regex = r"-m \d+"

0 commit comments

Comments
 (0)