[ENH] Weighted regression #1210

Merged · 19 commits · Mar 18, 2024
5 changes: 5 additions & 0 deletions lightwood/api/json_ai.py
@@ -91,6 +91,11 @@ def lookup_encoder(
"positive_domain"
] = "$statistical_analysis.positive_domain"

if problem_defintion.target_weights is not None:
encoder_dict["args"][
"target_weights"
] = problem_defintion.target_weights

# Time-series representations require more advanced flags
if tss.is_timeseries:
gby = tss.group_by if tss.group_by is not None else []
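For context (not part of this diff): a minimal usage sketch of how `target_weights` reaches the numeric target encoder. The user supplies the mapping on the `ProblemDefinition`, and the change above forwards it into the encoder args. Column names, bin edges, and weight values below are illustrative; the flow mirrors the new `tests/unit_tests/mixer/test_lgbm.py` test further down.

```python
# Hypothetical columns/values; mirrors the flow exercised by test_lgbm.py below.
import numpy as np
import pandas as pd
from lightwood.api.types import ProblemDefinition
from lightwood.api.high_level import json_ai_from_problem, code_from_json_ai, predictor_from_code

df = pd.DataFrame({'feature': np.random.normal(size=1000)})
df['target'] = df['feature'] * 2 + np.random.normal(scale=0.1, size=1000)

# Keys are target values (bin edges), values are the per-sample weights to apply.
target_weights = {-3.0: 1.0, 0.0: 2.0, 3.0: 4.0}

pdef = ProblemDefinition.from_dict({'target': 'target', 'target_weights': target_weights})
jai = json_ai_from_problem(df, pdef)  # target_weights is wired into the target encoder args here
predictor = predictor_from_code(code_from_json_ai(jai))
predictor.learn(df)
```
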
2 changes: 1 addition & 1 deletion lightwood/data/encoded_ds.py
@@ -14,7 +14,7 @@ def __init__(self, encoders: Dict[str, BaseEncoder], data_frame: pd.DataFrame, t

Note: normal behavior is to cache encoded representations to avoid duplicated computations. If you want an option to disable this, please open an issue.

:param encoders: list of Lightwood encoders used to encode the data per each column.
:param encoders: dictionary of Lightwood encoders used to encode the data per each column.
:param data_frame: original dataframe.
:param target: name of the target column to predict.
""" # noqa
44 changes: 38 additions & 6 deletions lightwood/encoder/numeric/numeric.py
@@ -1,5 +1,6 @@
import math
from typing import Union
from typing import Union, Dict
from copy import deepcopy as dc

import torch
import numpy as np
@@ -20,11 +21,15 @@ class NumericEncoder(BaseEncoder):
The ``absolute_mean`` is computed in the ``prepare`` method and is just the mean of the absolute values of all numbers fed to ``prepare`` (which are not none)

``none`` stands for any number that is an actual python ``None`` value or any sort of non-numeric value (a string, nan, inf)
""" # noqa
""" # noqa

def __init__(self, data_type: dtype = None, is_target: bool = False, positive_domain: bool = False):
def __init__(self, data_type: dtype = None,
target_weights: Dict[float, float] = None,
is_target: bool = False,
positive_domain: bool = False):
"""
:param data_type: The data type of the number (integer, float, quantity)
:param target_weights: a dictionary mapping target values (interval edges) to the weights applied to the corresponding examples; only used when the encoder represents the target.
:param is_target: Indicates whether the encoder refers to a target column or feature column (True==target)
:param positive_domain: Forces the encoder to always output positive values
"""
@@ -34,12 +39,19 @@ def __init__(self, data_type: dtype = None, is_target: bool = False, positive_do
self.decode_log = False
self.output_size = 4 if not self.is_target else 3

# Weight-balance info if encoder represents target
self.target_weights = None
self.index_weights = None
if self.is_target and target_weights is not None:
self.target_weights = dc(target_weights)
self.index_weights = torch.tensor(list(self.target_weights.values()))

def prepare(self, priming_data: pd.Series):
"""
"NumericalEncoder" uses a rule-based form to prepare results on training (priming) data. The averages etc. are taken from this distribution.

:param priming_data: an iterable data structure containing numbers which will be used to compute the values used for normalizing the encoded representations
""" # noqa
""" # noqa
if self.is_prepared:
raise Exception('You can only call "prepare" once for a given encoder.')

Expand All @@ -57,7 +69,8 @@ def encode(self, data: Union[np.ndarray, pd.Series]):
if isinstance(data, pd.Series):
data = data.values

inp_data = np.nan_to_num(data.astype(float), nan=0, posinf=np.finfo(np.float32).max, neginf=np.finfo(np.float32).min) # noqa
inp_data = np.nan_to_num(data.astype(float), nan=0, posinf=np.finfo(np.float32).max,
neginf=np.finfo(np.float32).min) # noqa
if not self.positive_domain:
sign = np.vectorize(self._sign_fn, otypes=[float])(inp_data)
else:
@@ -97,7 +110,7 @@ def decode(self, encoded_values: torch.Tensor, decode_log: bool = None) -> list:
:param decode_log: Whether to decode the ``log`` or ``linear`` part of the representation, since the encoded vector contains both a log and a linear part

:returns: The decoded array
""" # noqa
""" # noqa

if not self.is_prepared:
raise Exception('You need to call "prepare" before calling "encode" or "decode".')
@@ -145,3 +158,22 @@ def decode(self, encoded_values: torch.Tensor, decode_log: bool = None) -> list:
ret[mask_none] = None

return ret.tolist() # TODO: update signature on BaseEncoder and replace all encs to return ndarrays

def get_weights(self, label_data):
# get a sorted list of intervals to assign weights. Keys are the interval edges.
target_weight_keys = np.array(list(self.target_weights.keys()))
target_weight_values = np.array(list(self.target_weights.values()))
sorted_indices = np.argsort(target_weight_keys)

# get sorted arrays for vector numpy operations
target_weight_keys = target_weight_keys[sorted_indices]
target_weight_values = target_weight_values[sorted_indices]

# find the indices of the bins according to the keys; clip to the length of the weight values
# (np.searchsorted returns indices from 0 to N, with N = len(target_weight_keys)).
assigned_target_weight_indices = np.clip(a=np.searchsorted(target_weight_keys, label_data),
a_min=0,
a_max=len(target_weight_keys) - 1).astype(np.int32)

return target_weight_values[assigned_target_weight_indices]

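A small illustration (not part of this diff) of what `get_weights` returns: each label is assigned the weight of the smallest bin-edge key that is greater than or equal to it, labels below the smallest key get the first weight, and labels above the largest key are clipped into the last bin. The import path is assumed to match the one used by the unit tests; the edges and weights are made-up values.

```python
import numpy as np
from lightwood.encoder.numeric import NumericEncoder  # import path assumed, as in the tests

# Keys are bin edges, values are weights; get_weights maps each label to the weight
# of the first edge >= label (clipped into the last bin above the largest edge).
target_weights = {0.0: 1.0, 10.0: 2.0, 20.0: 5.0}
encoder = NumericEncoder(is_target=True, target_weights=target_weights)

labels = np.array([-3.0, 4.0, 10.0, 15.0, 99.0])
print(encoder.get_weights(labels))  # -> [1. 2. 2. 5. 5.]
```
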
30 changes: 20 additions & 10 deletions lightwood/mixer/lightgbm.py
@@ -16,7 +16,6 @@
from lightwood.api.types import PredictionArguments
from lightwood.data.encoded_ds import EncodedDs


optuna.logging.set_verbosity(optuna.logging.CRITICAL)


@@ -95,7 +94,8 @@ def __init__(
if not gpu_works:
self.device = torch.device('cpu')
self.device_str = 'cpu'
log.warning('LightGBM running on CPU, this somewhat slower than the GPU version, consider using a GPU instead') # noqa
log.warning(
'LightGBM running on CPU, this somewhat slower than the GPU version, consider using a GPU instead') # noqa
else:
self.device = torch.device('cuda')
self.device_str = 'gpu'
@@ -137,10 +137,17 @@ def _to_dataset(self, data: Dict[str, Dict], output_dtype: str):
if weight_map is not None:
data[subset_name]['weights'] = [weight_map[x] for x in label_data]
label_data = self.ordinal_encoder.transform(np.array(label_data).reshape(-1, 1)).flatten()
elif output_dtype == dtype.integer:
label_data = label_data.clip(-pow(2, 63), pow(2, 63)).astype(int)
elif output_dtype in self.float_dtypes:
label_data = label_data.astype(float)
elif output_dtype in self.num_dtypes:
if weight_map is not None:
target_encoder = data[subset_name]['ds'].encoders[self.target]

# get the weights from the numeric target encoder
data[subset_name]['weights'] = target_encoder.get_weights(label_data)

if output_dtype in self.float_dtypes:
label_data = label_data.astype(float)
elif output_dtype == dtype.integer:
label_data = label_data.clip(-pow(2, 63), pow(2, 63)).astype(int)

data[subset_name]['label_data'] = label_data

@@ -206,12 +213,15 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
Only happens sometimes and I can find no pattern as to when, happens for multiple input and target types.

Why does the following crash happen and what does it mean? No idea, closest relationships I can find is /w optuna modifying parameters after the dataset is create: https://github.com/microsoft/LightGBM/issues/4019 | But why this would apply here makes no sense. Could have to do with the `train` process of lightgbm itself setting a "set only once" property on a dataset when it starts. Dunno, if you find out replace this comment with the real reason.
''' # noqa
''' # noqa
kwargs = {}
if 'verbose_eval' in inspect.getfullargspec(lightgbm.train).args:
kwargs['verbose_eval'] = False
self.model = lightgbm.train(self.params, lightgbm.Dataset(data['train']['data'], label=data['train']
['label_data'], weight=data['train']['weights']), **kwargs)
self.model = lightgbm.train(self.params,
lightgbm.Dataset(data['train']['data'],
label=data['train']['label_data'],
weight=data['train']['weights']),
**kwargs)
end = time.time()
seconds_for_one_iteration = max(0.1, end - start)

@@ -232,7 +242,7 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:

# Train the models
log.info(
f'Training GBM ({model_generator}) with {self.num_iterations} iterations given {self.stop_after} seconds constraint') # noqa
f'Training GBM ({model_generator}) with {self.num_iterations} iterations given {self.stop_after} seconds constraint') # noqa
if self.num_iterations < 1:
self.num_iterations = 1
self.params['num_iterations'] = int(self.num_iterations)
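For reference (not part of this diff), a standalone sketch of the LightGBM mechanism the mixer now relies on: per-sample weights attached to the training `Dataset` via `weight=`. The synthetic data and the rule used to build `weights` are illustrative stand-ins for the encoder-derived weights.

```python
import numpy as np
import lightgbm

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = X @ np.array([1.0, 2.0, 3.0]) + rng.normal(scale=0.1, size=500)

# Stand-in for NumericEncoder.get_weights(y): up-weight the upper half of the target range.
weights = np.where(y > np.median(y), 2.0, 1.0)

params = {'objective': 'regression', 'verbosity': -1, 'num_iterations': 20}
train_set = lightgbm.Dataset(X, label=y, weight=weights)
model = lightgbm.train(params, train_set)
print(model.predict(X[:3]))
```
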
45 changes: 35 additions & 10 deletions lightwood/mixer/xgboost.py
@@ -119,9 +119,10 @@ def _to_dataset(self, ds: EncodedDs, output_dtype: str, mode='train'):
data = data.cpu().numpy()

if mode in ('train', 'dev'):
weights = []
label_data = ds.get_column_original_data(self.target)
if output_dtype in self.cls_dtypes:
if mode == 'train': # TODO weight maps?
if mode == 'train':
self.ordinal_encoder = OrdinalEncoder()
self.label_set = list(set(label_data))
self.ordinal_encoder.fit(np.array(list(self.label_set)).reshape(-1, 1))
@@ -131,14 +132,26 @@ def _to_dataset(self, ds: EncodedDs, output_dtype: str, mode='train'):
if x in self.label_set:
filtered_label_data.append(x)

weight_map = getattr(self.target_encoder, 'target_weights', None)
if weight_map is not None:
weights = [weight_map[x] for x in label_data]

label_data = self.ordinal_encoder.transform(np.array(filtered_label_data).reshape(-1, 1)).flatten()

elif output_dtype == dtype.integer:
label_data = label_data.clip(-pow(2, 63), pow(2, 63)).astype(int)
elif output_dtype in self.float_dtypes:
label_data = label_data.astype(float)
elif output_dtype in self.num_dtypes:
weight_map = getattr(self.target_encoder, 'target_weights', None)
if weight_map is not None:
target_encoder = ds.encoders[self.target]

# get the weights from the numeric target encoder
weights = target_encoder.get_weights(label_data)

if output_dtype in self.float_dtypes:
label_data = label_data.astype(float)
elif output_dtype == dtype.integer:
label_data = label_data.clip(-pow(2, 63), pow(2, 63)).astype(int)

return data, label_data
return data, label_data, weights

else:
return data
@@ -175,8 +188,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
}

# Prepare the data
train_dataset, train_labels = self._to_dataset(train_data, output_dtype, mode='train')
dev_dataset, dev_labels = self._to_dataset(dev_data, output_dtype, mode='dev')
train_dataset, train_labels, train_weights = self._to_dataset(train_data, output_dtype, mode='train')
dev_dataset, dev_labels, dev_weights = self._to_dataset(dev_data, output_dtype, mode='dev')

if output_dtype not in self.num_dtypes:
self.all_classes = self.ordinal_encoder.categories_[0]
@@ -191,7 +204,13 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:

with xgb.config_context(verbosity=0):
self.model = model_class(**self.params)
self.model.fit(train_dataset, train_labels, eval_set=[(dev_dataset, dev_labels)])
if train_weights is not None and dev_weights is not None:
self.model.fit(train_dataset, train_labels, sample_weight=train_weights,
eval_set=[(dev_dataset, dev_labels)],
sample_weight_eval_set=[dev_weights])
else:
self.model.fit(train_dataset, train_labels,
eval_set=[(dev_dataset, dev_labels)])

end = time.time()
seconds_for_one_iteration = max(0.1, end - start)
@@ -224,7 +243,13 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:

with xgb.config_context(verbosity=0):
self.model = model_class(**self.params)
self.model.fit(train_dataset, train_labels, eval_set=[(dev_dataset, dev_labels)])
if train_weights is not None and dev_weights is not None:
self.model.fit(train_dataset, train_labels, sample_weight=train_weights,
eval_set=[(dev_dataset, dev_labels)],
sample_weight_eval_set=[dev_weights])
else:
self.model.fit(train_dataset, train_labels,
eval_set=[(dev_dataset, dev_labels)])

if self.fit_on_dev:
self.partial_fit(dev_data, train_data)
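Similarly for XGBoost (not part of this diff): a standalone sketch of passing per-sample weights through the scikit-learn wrapper via `sample_weight` and `sample_weight_eval_set`, which is what the weighted branch above does with the encoder-derived weights. The data and weighting rule are illustrative only.

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X_train, X_dev = rng.normal(size=(500, 3)), rng.normal(size=(100, 3))
y_train, y_dev = X_train.sum(axis=1), X_dev.sum(axis=1)

# Stand-ins for the weights returned by the numeric target encoder.
w_train = np.where(y_train > 0, 2.0, 1.0)
w_dev = np.where(y_dev > 0, 2.0, 1.0)

model = xgb.XGBRegressor(n_estimators=20)
model.fit(X_train, y_train, sample_weight=w_train,
          eval_set=[(X_dev, y_dev)],
          sample_weight_eval_set=[w_dev],
          verbose=False)
print(model.predict(X_dev[:3]))
```
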
26 changes: 25 additions & 1 deletion tests/unit_tests/encoder/numeric/test_numeric.py
@@ -50,7 +50,7 @@ def test_encode_and_decode(self):
def test_positive_domain(self):
data = pd.Series([-1, -2, -100, 5, 10, 15])
for encoder in [NumericEncoder(), TsNumericEncoder()]:
encoder.is_target = True # only affects target values
encoder.is_target = True # only affects target values
encoder.positive_domain = True
encoder.prepare(data)
decoded_vals = encoder.decode(encoder.encode(data))
@@ -110,3 +110,27 @@ def test_nan_encoding(self):
assert is_none(dec)
else:
assert not is_none(x) or x != 0.0

def test_weights(self):
num_bins = 10
data = np.random.normal(loc=0.0, scale=1.0, size=1000)
hist, bin_edges = np.histogram(data, bins=num_bins, density=False)

# constrict bins so that the final histograms align; throw out the minimum bin, as np.searchsorted is left-justified
# and this always leads to a singleton bin that contains the lowest value.
bin_edges = bin_edges[1:]

# construct target weight mapping. This mapping will round each entry to the lower bin edge.
target_weights = {bin_edge: bin_edge for bin_edge in bin_edges}
self.assertTrue(type(target_weights) is dict)

# apply weight mapping
encoder = NumericEncoder(is_target=True, target_weights=target_weights)
generated_weights = encoder.get_weights(label_data=data)

self.assertTrue(type(generated_weights) is np.ndarray)

# distributions should match
gen_hist, _ = np.histogram(generated_weights, bins=num_bins, density=False)

self.assertTrue(np.all(np.equal(hist, gen_hist)))
87 changes: 87 additions & 0 deletions tests/unit_tests/mixer/test_lgbm.py
@@ -0,0 +1,87 @@
import unittest
import numpy as np
import pandas as pd
from lightwood.api.types import ProblemDefinition
from lightwood.api.high_level import json_ai_from_problem, code_from_json_ai, predictor_from_code
import importlib

np.random.seed(42)


@unittest.skipIf(importlib.util.find_spec('lightgbm') is None, "LightGBM is not available, skipping LightGBM tests.")
class TestBasic(unittest.TestCase):

def get_submodels(self):
submodels = [
{
'module': 'LightGBM',
'args': {
'stop_after': '$problem_definition.seconds_per_mixer',
'fit_on_dev': True,
'target': '$target',
'dtype_dict': '$dtype_dict',
'target_encoder': '$encoders[self.target]',
'use_optuna': True
}
},
]
return submodels

def test_0_regression(self):
"""
This test mocks a dataset intended to demonstrate the efficacy of weighting. It does not rigorously verify that
the weighting procedure works as intended, but it does exercise the weighted-regression code path for bugs.
"""

# generate data that mocks an observational skew by adding a linear selection to data
data_size = 100000
loc = 100.0
scale = 10.0
eps = .1
target_data = np.random.normal(loc=loc, scale=scale, size=data_size)
epsilon = np.random.normal(loc=0.0, scale=loc * eps, size=len(target_data))
feature_data = target_data + epsilon
df = pd.DataFrame({'feature': feature_data, 'target': target_data})

hist, bin_edges = np.histogram(target_data, bins=10, density=False)
fracs = np.linspace(1, 100, len(hist))
fracs = fracs / fracs.sum()
target_size = 10000
skewed_arr_list = []
for i in range(len(hist)):
frac = fracs[i]
low_edge = bin_edges[i]
high_edge = bin_edges[i + 1]

bin_array = target_data[target_data <= high_edge]
bin_array = bin_array[bin_array >= low_edge]

# select only a fraction of the elements in this bin
bin_array = bin_array[:int(target_size * frac)]

skewed_arr_list.append(bin_array)

skewed_arr = np.concatenate(skewed_arr_list)
epsilon = np.random.normal(loc=0.0, scale=loc * eps, size=len(skewed_arr))
skewed_feat = skewed_arr + epsilon
skew_df = pd.DataFrame({'feature': skewed_feat, 'target': skewed_arr})

# generate data set weights to remove bias.
hist, bin_edges = np.histogram(skew_df['target'].to_numpy(), bins=10, density=False)
hist = 1 - hist / hist.sum()
target_weights = {bin_edge: bin_frac for bin_edge, bin_frac in zip(bin_edges, hist)}

pdef = ProblemDefinition.from_dict({'target': 'target', 'target_weights': target_weights, 'time_aim': 80})
jai = json_ai_from_problem(skew_df, pdef)

jai.model['args']['submodels'] = self.get_submodels()
code = code_from_json_ai(jai)
predictor = predictor_from_code(code)

predictor.learn(skew_df)
output_df = predictor.predict(df)

output_mean = output_df['prediction'].mean()

self.assertTrue(np.all(np.isclose(output_mean, loc, atol=0., rtol=.03)),
msg=f"the output mean {output_mean} is not close to {loc}")