Skip to content

Commit

Permalink
Rules can be printted now. However it is necessary postprocess the co…
Browse files Browse the repository at this point in the history
…nditions to simplify the rules. Maybe in a next commit
  • Loading branch information
vauxgomes committed Apr 1, 2021
1 parent d427565 commit f1d6ce7
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
10 changes: 8 additions & 2 deletions examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=1)

# Clasisfier
clf = LADClassifier(mode='lazy')
clf = LADClassifier(mode='eager')
clf.fit(X_train, y_train)

y_hat = clf.predict(X_test)

print(classification_report(y_test, y_hat))
print(classification_report(y_test, y_hat))

print(clf)

# scores = cross_validate(LADClassifier(mode='eager'), X, y, scoring=['accuracy'])

# print(np.mean(scores['test_accuracy']))
7 changes: 5 additions & 2 deletions lad/lad.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def fit(self, X, y):

elif self.mode == 'lazy':
self.model = LazyPatterns(cpb, gsc)

self.model.fit(Xbin, y)

return self # `fit` should always return `self`
Expand All @@ -89,4 +89,7 @@ def predict_proba(self, X):
X = check_array(X, accept_sparse=True)
check_is_fitted(self, 'is_fitted_')

return self.model.predict_proba(X)
return self.model.predict_proba(X)

def __str__(self):
return self.model.__str__()
29 changes: 25 additions & 4 deletions lad/rulegenerator/eager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np


class MaxPatterns():

def __init__(self, binarizer, selector, purity):
Expand All @@ -10,9 +9,6 @@ def __init__(self, binarizer, selector, purity):
self.__cutpoints = binarizer.get_cutpoints()
self.__selected = selector.get_selected()

def get_rules(self):
return self.__rules

def predict(self, X):
weights = {}

Expand Down Expand Up @@ -146,6 +142,8 @@ def __adjust(self):
for c in __cutpoints:
r['attributes'].append(c[0])
r['values'].append(c[1])

self.__rules.sort(key=lambda x: x['label'])

def __get_stats(self, Xbin, y, instance, attributes):
covered = np.where(
Expand Down Expand Up @@ -174,3 +172,26 @@ def __get_stats(self, Xbin, y, instance, attributes):
max(1.0, distance_other)/max(1.0, len(uncovered_other)))

return len(covered), counts[argmax], purity, label, discrepancy

def __str__(self):
s = f'MaxPatterns Set of Rules [{len(self.__rules)}]:\n'

for r in self.__rules:
label = r['label']
# weight = r['weight']
conditions = []

for i, condition in enumerate(r['conditions']):
att = r['attributes'][i]
val = r['values'][i]

if (condition):
conditions.append(f'att{att} <= {val:.4}')
else:
conditions.append(f'att{att} > {val:.4}')

# Label -> CONDITION_1 AND CONDITION_2 AND CONDITION_n
s += f'{label} \u2192 {" AND ".join(conditions)}\n'

return s

3 changes: 3 additions & 0 deletions lad/rulegenerator/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,6 @@ def __get_stats(self, instance, attributes, label):
lift = (counts[argmax]/len(covered[0]))/(self.__labels[unique[argmax]]/self.__y.shape[0])

return label, confidence, support, lift

def __str__(self):
print(f'LazyPatterns Set of Rules [None]:')

0 comments on commit f1d6ce7

Please sign in to comment.