forked from ombhojane/explainableai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
80 lines (67 loc) · 2.75 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# main.py
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.neural_network import MLPClassifier
from explainableai import XAIWrapper
import argparse
def main(file_path, target_column, *, test_size=0.2, random_state=42, interactive=True):
    """Run the end-to-end XAI workflow on a CSV dataset.

    Loads the CSV at *file_path* and runs ``XAIWrapper.perform_eda`` on it.
    It then splits the rows with ``train_test_split`` and fits four
    classifiers on the training split through an ``XAIWrapper`` instance.
    After that it prints the ``'llm_explanation'`` entry of ``analyze()``'s
    result and attempts ``generate_report()``. Optionally, it finally starts
    an interactive stdin prompt for single-row predictions.

    Args:
        file_path: Path of the CSV file, read with ``pandas.read_csv``.
        target_column: Name of the column to predict. Every other column is
            used as a feature.
        test_size: Fraction of rows held out by ``train_test_split``.
            Defaults to 0.2, the previously hard-coded value.
        random_state: Seed passed to ``train_test_split`` and to the seeded
            models. Defaults to 42, the previously hard-coded value.
        interactive: When True (the default), prompt on stdin for feature
            values and print predictions until the user types ``q`` or stdin
            is closed. Pass False for non-interactive (batch) runs.

    Returns:
        None.

    Raises:
        KeyError: If *target_column* is not a column of the loaded CSV.
    """
    print("Importing dataset...")
    df = pd.read_csv(file_path)

    # Fail fast with a readable message. Previously a bad column name only
    # surfaced after EDA, as an opaque KeyError from ``df.drop``.
    if target_column not in df.columns:
        raise KeyError(
            f"Target column {target_column!r} not found; "
            f"available columns: {list(df.columns)}"
        )

    XAIWrapper.perform_eda(df)

    X = df.drop(columns=[target_column])
    y = df[target_column]

    # NOTE(review): only the training split is handed to the wrapper; the
    # held-out rows are not used anywhere in this function.
    X_train, _X_test, y_train, _y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    xai = XAIWrapper()
    xai.fit(_build_models(random_state), X_train, y_train)
    results = xai.analyze()

    print("\nLLM Explanation of Results:")
    print(results['llm_explanation'])

    # Report generation is best-effort: a failure is printed and the run
    # continues on to the prediction prompt.
    try:
        xai.generate_report()
    except Exception as e:
        print(f"An error occurred while generating the report: {e}")

    if interactive:
        _prediction_loop(xai, X.columns)


def _build_models(random_state=42):
    """Return the candidate classifiers, keyed by display name."""
    return {
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=random_state),
        'Logistic Regression': LogisticRegression(max_iter=1000),
        'XGBoost': XGBClassifier(n_estimators=100, random_state=random_state),
        'Neural Network': MLPClassifier(
            hidden_layer_sizes=(100, 50), max_iter=1000, random_state=random_state
        ),
    }


def _coerce_value(value):
    """Return *value* as a float if it parses as one, else the raw string."""
    try:
        return float(value)
    except ValueError:
        return value


def _read_user_input(columns):
    """Prompt once per feature column and collect the answers.

    Returns a ``{column: value}`` dict, where each value has been through
    ``_coerce_value``. Returns ``None`` if the user enters ``q`` (any case,
    surrounding whitespace ignored) or stdin reaches end-of-file.
    Consequently, a literal ``q`` cannot be entered as a feature value.
    """
    print("\nEnter values for prediction (or 'q' to quit):")
    user_input = {}
    for feature in columns:
        try:
            value = input(f"{feature}: ")
        except EOFError:
            # stdin closed (Ctrl-D or an exhausted pipe): quit cleanly
            # instead of crashing with a traceback.
            return None
        if value.strip().lower() == 'q':
            return None
        user_input[feature] = _coerce_value(value)
    return user_input


def _prediction_loop(xai, columns):
    """Repeatedly read one feature row from stdin and print its prediction."""
    while True:
        user_input = _read_user_input(columns)
        if user_input is None:
            return
        try:
            prediction, probabilities, explanation = xai.explain_prediction(user_input)
        except Exception as e:
            # Report the failure and keep prompting, so that one bad row
            # does not end the session.
            print(f"An error occurred during prediction: {e}")
            continue
        print("\nPrediction Results:")
        print(f"Prediction: {prediction}")
        print(f"Probabilities: {probabilities}")
        print("\nLLM Explanation of Prediction:")
        print(explanation)
if __name__ == "__main__":
    # Command-line entry point:  python main.py <file_path> <target_column>
    cli = argparse.ArgumentParser(description="Run XAI analysis on a dataset")
    cli.add_argument(
        "file_path",
        help="Path to the CSV file containing the dataset",
    )
    cli.add_argument(
        "target_column",
        help="Name of the target column in the dataset",
    )
    parsed = cli.parse_args()
    main(file_path=parsed.file_path, target_column=parsed.target_column)