diff --git a/requirements.txt b/requirements.txt index cbf30d9..7895d3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ seaborn matplotlib joblib numpy +xgboost diff --git a/src/model.py b/src/model.py index a0c986e..b5967b4 100644 --- a/src/model.py +++ b/src/model.py @@ -13,6 +13,7 @@ from sklearn.metrics import precision_recall_curve from sklearn.metrics import roc_curve from sklearn.metrics import roc_auc_score +from xgboost import XGBClassifier import matplotlib.pyplot as plt @@ -29,20 +30,8 @@ def train(data, num_estimators, isDataFrame=False): X, y, test_size=0.3, random_state=0 ) - pipe = Pipeline( - [ - ("scaler", StandardScaler()), - ( - "RFC", - RandomForestClassifier( - criterion="gini", - max_depth=10, - max_features="auto", - n_estimators=num_estimators, - ), - ), - ] - ) + pipe = Pipeline([("scaler", StandardScaler()), ("XGB", XGBClassifier(n_estimators=num_estimators))]) + training_logs = pipe.fit(X_train, y_train)