Merge pull request #754 from mindsdb/dielectron_problems
Dielectron problems
paxcema authored Nov 17, 2021
2 parents 671cd59 + 5298804 commit c9bbf68
Showing 7 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion lightwood/api/types.py
@@ -355,7 +355,7 @@ def from_dict(obj: Dict):
target_weights = obj.get('target_weights', None)
positive_domain = obj.get('positive_domain', False)
timeseries_settings = TimeseriesSettings.from_dict(obj.get('timeseries_settings', {}))
anomaly_detection = obj.get('anomaly_detection', True)
anomaly_detection = obj.get('anomaly_detection', False)
ignore_features = obj.get('ignore_features', [])
fit_on_all = obj.get('fit_on_all', True)
strict_mode = obj.get('strict_mode', True)
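With this change, anomaly detection is opt-in rather than enabled by default when a ProblemDefinition is built from a dict. A minimal sketch of how a caller would now request it explicitly (assuming `target` is the only required key and using a hypothetical target name):

    from lightwood.api.types import ProblemDefinition

    # anomaly_detection now defaults to False, so opt in to keep the old behaviour
    pdef = ProblemDefinition.from_dict({
        'target': 'Power',                 # hypothetical target column
        'anomaly_detection': True,
        'timeseries_settings': {'order_by': ['Time'], 'window': 5},
    })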
2 changes: 1 addition & 1 deletion lightwood/data/cleaner.py
@@ -289,7 +289,7 @@ def _remove_columns(data: pd.DataFrame, identifiers: Dict[str, object], target:
if mode == "predict":
if (
target in data.columns
and not timeseries_settings.use_previous_target
and (not timeseries_settings.is_timeseries or not timeseries_settings.use_previous_target)
and not anomaly_detection
):
data = data.drop(columns=[target])
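The extra `is_timeseries` check means the target column is still dropped at predict time for plain tabular problems; only genuine time series runs that reuse previous target values (or anomaly detection runs) keep it. A paraphrased sketch of the updated guard, not the library's exact code:

    def should_drop_target(columns, target, tss, anomaly_detection) -> bool:
        """Target is dropped at predict time only when nothing downstream needs it."""
        return (
            target in columns
            and (not tss.is_timeseries or not tss.use_previous_target)
            and not anomaly_detection
        )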
3 changes: 2 additions & 1 deletion lightwood/data/encoded_ds.py
@@ -74,7 +74,8 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:

encoded_tensor = self.encoders[col].encode(data, **kwargs)[0]
if torch.isnan(encoded_tensor).any() or torch.isinf(encoded_tensor).any():
raise Exception(f'Encoded tensor: {encoded_tensor} contains nan or inf values')
raise Exception(f'Encoded tensor: {encoded_tensor} contains nan or inf values, this tensor is \
the encoding of column {col} using {self.encoders[col].__class__}')
if col != self.target:
X = torch.cat([X, encoded_tensor])
else:
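The exception raised on NaN/inf encodings now names the offending column and encoder class, which makes the failure traceable to a specific input. A standalone illustration with made-up values:

    import torch

    encoded_tensor = torch.tensor([float('nan'), 1.0])   # hypothetical bad encoding
    col, encoder_cls = 'Power', 'NumericEncoder'          # assumed names for the example
    if torch.isnan(encoded_tensor).any() or torch.isinf(encoded_tensor).any():
        print(f'Encoded tensor: {encoded_tensor} contains nan or inf values, '
              f'this tensor is the encoding of column {col} using {encoder_cls}')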
4 changes: 4 additions & 0 deletions lightwood/data/timeseries_transform.py
@@ -43,6 +43,10 @@ def transform_timeseries(
if tss.use_previous_target and target not in data.columns:
raise Exception(f"Cannot transform. Missing historical values for target column {target} (`use_previous_target` is set to True).") # noqa

for hcol in tss.historical_columns:
if hcol not in data.columns or data[hcol].isna().any():
raise Exception(f"Cannot transform. Missing values in historical column {hcol}.")

if '__mdb_make_predictions' in original_df.columns:
index = original_df[original_df['__mdb_make_predictions'].map(
{'True': True, 'False': False, True: True, False: False}).isin([True])]
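The new loop refuses to transform data when any declared historical column is missing or contains NaNs, instead of failing later with a less obvious error. The same check applied to a hypothetical frame:

    import numpy as np
    import pandas as pd

    data = pd.DataFrame({'Time': [1, 2, 3],
                         'Power': [0.1, 0.2, 0.3],
                         'Power_2x': [0.2, np.nan, 0.6]})   # hypothetical columns
    for hcol in ['Power_2x']:                                # declared historical columns
        if hcol not in data.columns or data[hcol].isna().any():
            print(f"Cannot transform. Missing values in historical column {hcol}.")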
14 changes: 14 additions & 0 deletions lightwood/helpers/accuracy.py
@@ -0,0 +1,14 @@
from sklearn.metrics import r2_score as sk_r2_score


def r2_score(y_true, y_pred) -> float:
""" Wrapper for sklearn R2 score, lower capped between 0 and 1"""
acc = sk_r2_score(y_true, y_pred)
# Cap at 0
if acc < 0:
acc = 0
# Guard against overflow (> 1 means overflow of negative score)
if acc > 1:
acc = 0

return acc
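The wrapper reports any raw sklearn R2 outside [0, 1] as 0, so downstream code can treat the score as a bounded accuracy. Example of the clipping behaviour:

    from lightwood.helpers.accuracy import r2_score

    y_true = [1.0, 2.0, 3.0, 4.0]
    print(r2_score(y_true, [1.1, 2.1, 2.9, 4.0]))   # close fit, stays near 1
    print(r2_score(y_true, [4.0, 3.0, 2.0, 1.0]))   # raw sklearn R2 is -3.0, reported as 0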
6 changes: 5 additions & 1 deletion lightwood/helpers/general.py
@@ -65,7 +65,11 @@ def evaluate_accuracy(data: pd.DataFrame,
ts_analysis=ts_analysis)
else:
true_values = data[target].tolist()
accuracy_function = getattr(importlib.import_module('sklearn.metrics'), accuracy_function_str)
if hasattr(importlib.import_module('lightwood.helpers.accuracy'), accuracy_function_str):
accuracy_function = getattr(importlib.import_module('lightwood.helpers.accuracy'),
accuracy_function_str)
else:
accuracy_function = getattr(importlib.import_module('sklearn.metrics'), accuracy_function_str)
score_dict[accuracy_function_str] = accuracy_function(list(true_values), list(predictions))

return score_dict
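evaluate_accuracy now checks lightwood.helpers.accuracy before sklearn.metrics, so lightwood's capped r2_score shadows sklearn's while every other metric name still falls through. A minimal sketch of that lookup order (resolve_metric is an illustrative helper, not a library function):

    import importlib

    def resolve_metric(name: str):
        local = importlib.import_module('lightwood.helpers.accuracy')
        if hasattr(local, name):
            return getattr(local, name)        # e.g. the capped r2_score above
        return getattr(importlib.import_module('sklearn.metrics'), name)

    r2 = resolve_metric('r2_score')                 # lightwood's capped version
    mae = resolve_metric('mean_absolute_error')     # falls back to sklearn.metrics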
12 changes: 10 additions & 2 deletions tests/integration/advanced/test_timeseries.py
@@ -184,6 +184,7 @@ def test_3_time_series_sktime_mixer(self):
df = pd.DataFrame(columns=['Time', target])
df['Time'] = t
df[target] = ts
df[f'{target}_2x'] = 2 * ts

train = df[:int(len(df) * 0.8)]
test = df[int(len(df) * 0.8):]
@@ -193,7 +194,8 @@
'timeseries_settings': {
'order_by': ['Time'],
'window': 5,
'nr_predictions': 20
'nr_predictions': 20,
'historical_columns': [f'{target}_2x']
}})

json_ai = json_ai_from_problem(df, problem_definition=pdef)
@@ -209,5 +211,11 @@

predictor.learn(train)
ps = predictor.predict(test)

assert r2_score(ps['truth'].values, ps['prediction'].iloc[0]) >= 0.95

# test historical columns asserts
test[f'{target}_2x'].iloc[0] = np.nan
self.assertRaises(Exception, predictor.predict, test)

test.pop(f'{target}_2x')
self.assertRaises(Exception, predictor.predict, test)
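The added asserts pin the new behaviour down end to end: prediction raises when the declared historical column holds NaNs, and again once it is removed. A hypothetical pre-flight check a caller could run on their own predict-time frame (reusing the test's names):

    # mirrors what the new asserts expect to fail inside transform_timeseries
    historical_cols = [f'{target}_2x']
    ready = all(c in test.columns and not test[c].isna().any() for c in historical_cols)
    if ready:
        ps = predictor.predict(test)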
