Skip to content

Commit c033a52

Browse files
Merge pull request #14 from chcwww/sep_ovr_thread
Use python-level parallel OVR training instead of multi-core LIBLINEAR
2 parents c473d48 + faa7293 commit c033a52

File tree

1 file changed

+125
-21
lines changed

1 file changed

+125
-21
lines changed

libmultilabel/linear/linear.py

Lines changed: 125 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22

33
import logging
44
import os
5+
import psutil
6+
import threading
7+
import queue
8+
import re
59

610
import numpy as np
711
import scipy.sparse as sparse
812
from liblinear.liblinearutil import train, problem, parameter, solver_names
913
from tqdm import tqdm
1014

15+
from ctypes import c_double
16+
1117
__all__ = [
1218
"train_1vsrest",
1319
"train_thresholding",
@@ -86,14 +92,114 @@ def _to_dense_array(self, matrix: np.matrix | sparse.csr_matrix) -> np.ndarray:
8692
return np.asarray(matrix)
8793

8894

95+
class ParallelOVRTrainer(threading.Thread):
96+
"""A trainer for parallel 1vsrest training."""
97+
98+
y: sparse.csc_matrix
99+
x: sparse.csr_matrix
100+
bias: float
101+
prob: problem
102+
param: parameter
103+
weights: np.ndarray
104+
pbar: tqdm
105+
queue: queue.SimpleQueue
106+
107+
def __init__(self):
108+
threading.Thread.__init__(self)
109+
110+
@classmethod
111+
def init_trainer(
112+
cls,
113+
y: sparse.csr_matrix,
114+
x: sparse.csr_matrix,
115+
options: str,
116+
verbose: bool,
117+
):
118+
"""Initialize the parallel trainer by setting y, x, parameter and threading related
119+
variables as class variables of ParallelOVRTrainer.
120+
121+
Args:
122+
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
123+
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
124+
options (str): The option string passed to liblinear.
125+
verbose (bool): Output extra progress information.
126+
"""
127+
x, options, bias = _prepare_options(x, options)
128+
cls.y = y.tocsc()
129+
cls.x = x
130+
cls.bias = bias
131+
num_instances, num_classes = cls.y.shape
132+
num_features = cls.x.shape[1]
133+
cls.prob = problem(np.ones((num_instances,)), cls.x)
134+
135+
# remove "-m nr_thread" from options to prevent nested multi-threading
136+
cls.param = parameter(re.sub(r"-m\s+\d+", "", options))
137+
if cls.param.solver_type in [solver_names.L2R_L1LOSS_SVC_DUAL, solver_names.L2R_L2LOSS_SVC_DUAL]:
138+
cls.param.w_recalc = True # only works for solving L1/L2-SVM dual
139+
cls.weights = np.zeros((num_features, num_classes), order="F")
140+
cls.queue = queue.SimpleQueue()
141+
142+
if verbose:
143+
logging.info(f"Training a one-vs-rest model on {num_classes} labels")
144+
for i in range(num_classes):
145+
cls.queue.put(i)
146+
cls.pbar = tqdm(total=num_classes, disable=not verbose)
147+
148+
@classmethod
149+
def del_trainer(cls):
150+
cls.pbar.close()
151+
for key in list(cls.__annotations__):
152+
delattr(cls, key)
153+
154+
def _do_parallel_train(self, y: np.ndarray) -> np.matrix:
155+
"""Wrap around liblinear.liblinearutil.train.
156+
157+
Args:
158+
y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.
159+
160+
Returns:
161+
np.matrix: The weights.
162+
"""
163+
if y.shape[0] == 0:
164+
return np.matrix(np.zeros((self.prob.n, 1)))
165+
166+
prob = self.prob.copy()
167+
prob.y = (c_double * prob.l)(*y)
168+
model = train(prob, self.param)
169+
170+
w = np.ctypeslib.as_array(model.w, (self.prob.n, 1))
171+
w = np.asmatrix(w)
172+
# When all labels are -1, we must flip the sign of the weights
173+
# because LIBLINEAR treats the first label as positive, which
174+
# is -1 in this case. But for our usage we need them to be negative.
175+
# For data with both +1 and -1 for labels, LIBLINEAR guarantees
176+
# that +1 is always the first label.
177+
if model.get_labels()[0] == -1:
178+
return -w
179+
else:
180+
# The memory is freed on model deletion so we make a copy.
181+
return w.copy()
182+
183+
def run(self):
184+
while True:
185+
try:
186+
label_idx = self.queue.get_nowait()
187+
except queue.Empty:
188+
break
189+
yi = self.y[:, label_idx].toarray().reshape(-1)
190+
self.weights[:, label_idx] = self._do_parallel_train(2 * yi - 1).ravel()
191+
192+
self.pbar.update()
193+
194+
89195
def train_1vsrest(
90196
y: sparse.csr_matrix,
91197
x: sparse.csr_matrix,
92198
multiclass: bool = False,
93199
options: str = "",
94200
verbose: bool = True,
95201
) -> FlatModel:
96-
"""Train a linear model for multi-label data using a one-vs-rest strategy.
202+
"""Train a linear model parallel on labels for multi-label data using a one-vs-rest strategy.
97203
98204
Args:
99205
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
@@ -106,18 +212,15 @@ def train_1vsrest(
106212
A model which can be used in predict_values.
107213
"""
108214
# Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
109-
x, options, bias = _prepare_options(x, options)
110-
111-
y = y.tocsc()
112-
num_class = y.shape[1]
113-
num_feature = x.shape[1]
114-
weights = np.zeros((num_feature, num_class), order="F")
115-
116-
if verbose:
117-
logging.info(f"Training one-vs-rest model on {num_class} labels")
118-
for i in tqdm(range(num_class), disable=not verbose):
119-
yi = y[:, i].toarray().reshape(-1)
120-
weights[:, i] = _do_train(2 * yi - 1, x, options).ravel()
215+
ParallelOVRTrainer.init_trainer(y, x, options, verbose)
216+
num_threads = psutil.cpu_count(logical=False)
217+
trainers = [ParallelOVRTrainer() for _ in range(num_threads)]
218+
for trainer in trainers:
219+
trainer.start()
220+
for trainer in trainers:
221+
trainer.join()
222+
weights, bias = ParallelOVRTrainer.weights, ParallelOVRTrainer.bias
223+
ParallelOVRTrainer.del_trainer()
121224

122225
return FlatModel(
123226
name="1vsrest",
@@ -170,7 +273,7 @@ def _prepare_options(x: sparse.csr_matrix, options: str) -> tuple[sparse.csr_mat
170273
if not "-q" in options_split:
171274
options_split.append("-q")
172275
if not "-m" in options:
173-
options_split.append(f"-m {int(os.cpu_count() / 2)}")
276+
options_split.append(f"-m {psutil.cpu_count(logical=False)}")
174277

175278
options = " ".join(options_split)
176279
return x, options, bias
@@ -212,7 +315,7 @@ def train_thresholding(
212315
thresholds = np.zeros(num_class)
213316

214317
if verbose:
215-
logging.info("Training thresholding model on %s labels", num_class)
318+
logging.info("Training a thresholding model on %s labels", num_class)
216319

217320
num_positives = np.sum(y, 2)
218321
label_order = np.flip(np.argsort(num_positives)).flat
@@ -356,10 +459,11 @@ def _do_train(y: np.ndarray, x: sparse.csr_matrix, options: str) -> np.matrix:
356459

357460
w = np.ctypeslib.as_array(model.w, (x.shape[1], 1))
358461
w = np.asmatrix(w)
359-
# Liblinear flips +1/-1 labels so +1 is always the first label,
360-
# but not if all labels are -1.
361-
# For our usage, we need +1 to always be the first label,
362-
# so the check is necessary.
462+
# When all labels are -1, we must flip the sign of the weights
463+
# because LIBLINEAR treats the first label as positive, which
464+
# is -1 in this case. But for our usage we need them to be negative.
465+
# For data with both +1 and -1, LIBLINEAR guarantees that +1
466+
# is always the first label.
363467
if model.get_labels()[0] == -1:
364468
return -w
365469
else:
@@ -440,7 +544,7 @@ def train_cost_sensitive(
440544
weights = np.zeros((num_feature, num_class), order="F")
441545

442546
if verbose:
443-
logging.info(f"Training cost-sensitive model for Macro-F1 on {num_class} labels")
547+
logging.info(f"Training a cost-sensitive model for Macro-F1 on {num_class} labels")
444548
for i in tqdm(range(num_class), disable=not verbose):
445549
yi = y[:, i].toarray().reshape(-1)
446550
w = _cost_sensitive_one_label(2 * yi - 1, x, options)
@@ -549,7 +653,7 @@ def train_cost_sensitive_micro(
549653
bestScore = -np.inf
550654

551655
if verbose:
552-
logging.info(f"Training cost-sensitive model for Micro-F1 on {num_class} labels")
656+
logging.info(f"Training a cost-sensitive model for Micro-F1 on {num_class} labels")
553657
for a in param_space:
554658
tp = fn = fp = 0
555659
for i in tqdm(range(num_class), disable=not verbose):

0 commit comments

Comments
 (0)