# Reconstructed from commit 9cfcf56 ("Removing unnecessary conditions from
# rules"), lad/rulegenerator/eager.py -- post-patch state of the method.
def __adjust(self):
    """Rewrite every rule in terms of real cutpoints, dropping redundant bounds.

    For each rule, the selected-feature indices stored in ``r['attributes']``
    are mapped to ``(attribute, value)`` cutpoints.  Conditions that test the
    same attribute in the same direction are merged into the tightest one:

    * ``<=`` conditions (flag ``True``) keep the smallest value;
    * ``>`` conditions (flag ``False``) keep the largest value.

    The rule's ``attributes``, ``conditions`` and ``values`` lists are then
    rebuilt with one entry per surviving (attribute, direction) pair, and the
    rule list is sorted in place by ``label``.

    NOTE(review): ``self.__selected`` is indexed with a list, so it is
    presumably a NumPy array (or similar) -- confirm against ``fit``.
    """
    for r in self.__rules:
        # attribute -> {direction flag -> tightest cutpoint value}
        bounds = {}
        cutpoints = [
            self.__cutpoints[i] for i in self.__selected[r['attributes']]
        ]

        for i, (att, value) in enumerate(cutpoints):
            per_attribute = bounds.setdefault(att, {})
            symbol = r['conditions'][i]  # True: <=, False: >

            # x <= a and x <= b  ==>  x <= min(a, b)
            # x >  a and x >  b  ==>  x >  max(a, b)
            tightest = min if symbol else max
            per_attribute[symbol] = tightest(
                value, per_attribute.get(symbol, value)
            )

        # attributes/conditions are emptied in place; values is rebound,
        # exactly as the original code did.
        r['attributes'].clear()
        r['conditions'].clear()
        r['values'] = []

        for att, per_attribute in bounds.items():
            for symbol, value in per_attribute.items():
                r['attributes'].append(att)
                # The key already *is* the boolean direction flag.  The
                # previous `symbol == '<='` compared a bool with a string and
                # therefore turned every condition into False (">").
                r['conditions'].append(symbol)
                r['values'].append(value)

    self.__rules.sort(key=lambda x: x['label'])