diff --git a/lad/rulegenerator/eager.py b/lad/rulegenerator/eager.py index 2072718..81d856e 100644 --- a/lad/rulegenerator/eager.py +++ b/lad/rulegenerator/eager.py @@ -134,15 +134,28 @@ def fit(self, Xbin, y): def __adjust(self): for r in self.__rules: - __cutpoints = [self.__cutpoints[i] for i in self.__selected[r['attributes']]] + conditions = {} + cutpoints = [self.__cutpoints[i] for i in self.__selected[r['attributes']]] + + for i, (att, value) in enumerate(cutpoints): + condition = conditions.get(att, {}) + symbol = r['conditions'][i] # True: <=, False: > + + if symbol: condition[symbol] = min(value, condition.get(symbol, value)) + else: condition[symbol] = max(value, condition.get(symbol, value)) + + conditions[att] = condition r['attributes'].clear() + r['conditions'].clear() r['values'] = [] + + for att in conditions: + for condition in conditions[att]: + r['attributes'].append(att) + r['conditions'].append(condition == '<=') + r['values'].append(conditions[att][condition]) - for c in __cutpoints: - r['attributes'].append(c[0]) - r['values'].append(c[1]) - self.__rules.sort(key=lambda x: x['label']) def __get_stats(self, Xbin, y, instance, attributes):