Skip to content

Commit a6504e8

Browse files
committed
finished model trainer without tracking MLFlow
1 parent 0f54a49 commit a6504e8

File tree

8 files changed

+339
-5
lines changed

8 files changed

+339
-5
lines changed

main.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from network_security.components.data_ingestion import DataIngestion
22
from network_security.components.data_validation import DataValidation
33
from network_security.components.data_transformation import DataTransformation
4+
from network_security.components.model_trainer import ModelTrainer
45
from network_security.exceptions.exception import NetworkSecurityException
56
from network_security.logging.logger import logging
67
from network_security.entity.config_entity import (
78
TrainingPipelineConfig,
89
DataIngestionConfig,
910
DataValidationConfig,
1011
DataTransformationConfig,
12+
ModelTrainerConfig
1113
)
1214
import sys
1315

@@ -53,6 +55,18 @@
5355
data_transformation_artifact = data_transformation.initiate_data_transformation()
5456
print(f"Data Transformation Artifact: \n{data_transformation_artifact} \n")
5557

58+
# model trainer configuration
59+
model_trainer_config = ModelTrainerConfig(training_pipeline_config=training_pipeline_config)
60+
model_trainer = ModelTrainer(
61+
data_transformation_artifact = data_transformation_artifact,
62+
model_trainer_config = model_trainer_config
63+
)
64+
65+
# initiating model trainer
66+
logging.info("Initiating model trainer")
67+
model_trainer_artifact = model_trainer.initiate_model_trainer()
68+
print(f"Model Trainer Artifact: \n{model_trainer_artifact} \n")
69+
5670

5771
except Exception as e:
5872
raise NetworkSecurityException(e, sys)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from network_security.exceptions.exception import NetworkSecurityException
2+
from network_security.logging.logger import logging
3+
from network_security.utils.ml_utils.model.estimator import NetworkModel
4+
from network_security.utils.ml_utils.metric.classification_metric import get_classification_score
5+
import network_security.constants.training_pipeline as tp
6+
from network_security.entity.config_entity import (
7+
DataTransformationConfig,
8+
ModelTrainerConfig
9+
)
10+
from network_security.entity.artifact_entity import (
11+
DataTransformationArtifact,
12+
ModelTrainerArtifact
13+
)
14+
from network_security.utils.main_utils.utils import (
15+
save_object,
16+
load_object,
17+
load_numpy_array,
18+
evaluate_models
19+
)
20+
21+
from sklearn.linear_model import LogisticRegression
22+
from sklearn.neighbors import KNeighborsClassifier
23+
from sklearn.tree import DecisionTreeClassifier
24+
from sklearn.ensemble import (
25+
AdaBoostClassifier,
26+
GradientBoostingClassifier,
27+
RandomForestClassifier,
28+
)
29+
from xgboost import XGBClassifier
30+
from sklearn.metrics import r2_score
31+
32+
import pandas as pd
33+
import numpy as np
34+
import os, sys
35+
36+
37+
class ModelTrainer:
    """Trains, selects, and persists the best classification model.

    Consumes the transformed train/test arrays produced by the data
    transformation stage, grid-searches several candidate classifiers,
    picks the one with the best reported score, and saves it (bundled with
    the fitted preprocessor) to the path given by the model-trainer config.
    """

    def __init__(self, model_trainer_config: ModelTrainerConfig,
                 data_transformation_artifact: DataTransformationArtifact):
        """Store the stage config and the upstream transformation artifact.

        Raises:
            NetworkSecurityException: wrapping any error raised here.
        """
        try:
            self.model_trainer_config = model_trainer_config
            self.data_transformation_artifact = data_transformation_artifact
        except Exception as e:
            raise NetworkSecurityException(e, sys)

    def train_model(self, X_train, y_train, X_test, y_test) -> ModelTrainerArtifact:
        """Tune candidate classifiers, persist the best one, return its artifact.

        Args:
            X_train, y_train: training features / labels.
            X_test, y_test: test features / labels.

        Returns:
            ModelTrainerArtifact holding the saved model path and the
            train/test classification metrics.

        Raises:
            NetworkSecurityException: wrapping any error raised during training.
        """
        try:
            logging.info("Training the model")

            # Initializing candidate models
            models = {
                "LogisticRegression": LogisticRegression(max_iter=1000),
                "KNeighborsClassifier": KNeighborsClassifier(),
                "DecisionTreeClassifier": DecisionTreeClassifier(),
                "RandomForestClassifier": RandomForestClassifier(),
                "AdaBoostClassifier": AdaBoostClassifier(),
                "GradientBoostingClassifier": GradientBoostingClassifier(),
                "XGBClassifier": XGBClassifier()
            }

            # Hyperparameter grids; deliberately small to keep the 3-fold
            # grid search in evaluate_models tractable.
            params = {
                "DecisionTreeClassifier": {
                    "criterion": ['gini', 'entropy'],
                },
                "RandomForestClassifier": {
                    "n_estimators": [50, 100, 200],
                    "max_depth": [3, 5, 10, 15, 20, None]
                },
                "GradientBoostingClassifier": {
                    "loss": ['log_loss', 'exponential'],
                    "n_estimators": [50, 100, 200],
                },
                "LogisticRegression": {
                    "max_iter": [100, 200, 500]
                },
                "KNeighborsClassifier": {
                    "n_neighbors": [3, 5, 7, 9, 11],
                },
                "AdaBoostClassifier": {
                    "n_estimators": [50, 100, 200],
                    "learning_rate": [0.1, 0.01, 0.001, 0.05, 1.0],
                },
                "XGBClassifier": {
                    "n_estimators": [50, 100, 200],
                }
            }

            # Tune every candidate and get {model_name: test score}.
            model_report: dict = evaluate_models(
                X_train=X_train,
                y_train=y_train,
                X_test=X_test,
                y_test=y_test,
                models=models,
                params=params
            )

            # Pick the model with the highest reported score. Looking the
            # score up by name (instead of max(sorted(values()))) guarantees
            # it is the score of the selected model.
            best_model_name = max(model_report, key=model_report.get)
            best_model_score = model_report[best_model_name]
            best_model = models[best_model_name]
            logging.info("Best model: %s (score=%s)", best_model_name, best_model_score)

            # Classification metrics on the training split.
            y_train_pred = best_model.predict(X_train)
            train_classification_metric = get_classification_score(y_true=y_train, y_pred=y_train_pred)

            # TODO: track the experiment with MLflow (not implemented yet).

            # Classification metrics on the test split.
            y_test_pred = best_model.predict(X_test)
            test_classification_metric = get_classification_score(y_true=y_test, y_pred=y_test_pred)

            # Load the fitted preprocessor produced by the transformation
            # stage and make sure the output directory exists.
            preprocessor = load_object(file_path=self.data_transformation_artifact.transformation_object_path)
            model_dir_path = os.path.dirname(self.model_trainer_config.trained_model_file_path)
            os.makedirs(model_dir_path, exist_ok=True)

            # Persist preprocessor + model as a single deployable object.
            network_model = NetworkModel(preprocessor=preprocessor, model=best_model)
            save_object(
                file_path=self.model_trainer_config.trained_model_file_path,
                obj=network_model
            )

            # Assemble and return the stage artifact.
            model_trainer_artifact = ModelTrainerArtifact(
                trained_model_file_path=self.model_trainer_config.trained_model_file_path,
                train_metric_artifact=train_classification_metric,
                test_metric_artifact=test_classification_metric
            )
            return model_trainer_artifact
        except Exception as e:
            raise NetworkSecurityException(e, sys)

    def initiate_model_trainer(self) -> ModelTrainerArtifact:
        """Load the transformed arrays, split features/target, and train.

        The last column of each transformed array is treated as the target;
        everything before it is the feature matrix.

        Returns:
            ModelTrainerArtifact produced by train_model.

        Raises:
            NetworkSecurityException: wrapping any error raised here.
        """
        try:
            logging.info("Initiating model trainer")
            train_file_path = self.data_transformation_artifact.transformed_train_file_path
            test_file_path = self.data_transformation_artifact.transformed_test_file_path

            # Loading the arrays written by the data-transformation stage.
            training_array = load_numpy_array(file_path=train_file_path)
            testing_array = load_numpy_array(file_path=test_file_path)

            # Splitting into input features and target (last column).
            X_train, y_train = training_array[:, :-1], training_array[:, -1]
            X_test, y_test = testing_array[:, :-1], testing_array[:, -1]

            model_trainer_artifact = self.train_model(X_train, y_train, X_test, y_test)

            logging.info("Model training completed")
            return model_trainer_artifact
        except Exception as e:
            raise NetworkSecurityException(e, sys)

network_security/constants/training_pipeline/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,13 @@
5555
"missing_values": np.nan,
5656
"n_neighbors": 3,
5757
"weights": "uniform",
58-
}
58+
}
59+
60+
""""
61+
Defining constants for model trainer
62+
"""
63+
# Sub-directory (under the pipeline artifact dir) for all model-trainer outputs
MODEL_TRAINER_DIR_NAME: str = "model_trainer"
# Directory holding the serialized trained model
MODEL_TRAINER_TRAINED_MODEL_DIR: str = "trained_model"
# File name used when pickling the trained model
MODEL_TRAINER_TRAINED_MODEL_FILE_NAME: str = "model.pkl"
# Minimum acceptable model score — read by ModelTrainerConfig; enforcement
# not visible in this commit (TODO confirm where it is checked)
MODEL_TRAINER_EXPECTED_ACCURACY: float = 0.7
# Max allowed train/test score gap before over-/underfitting is flagged —
# read by ModelTrainerConfig; enforcement not visible in this commit
MODEL_TRAINER_OVERFITTING_UNDERFITTING_THRESHOLD: float = 0.1

network_security/entity/artifact_entity.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from dataclasses import dataclass
22

3-
import os
4-
53
@dataclass
64
class DataIngestionArtifact:
75
train_file_path: str
@@ -20,4 +18,16 @@ class DataValidationArtifact:
2018
class DataTransformationArtifact:
2119
transformed_train_file_path: str
2220
transformed_test_file_path: str
23-
transformation_object_path: str
21+
transformation_object_path: str
22+
23+
@dataclass
class ClassificationMetricArtifact:
    """Bundle of classification metrics for one evaluation run."""
    # F1 score (harmonic mean of precision and recall)
    f1_score: float
    # Precision score
    precision_score: float
    # Recall score
    recall_score: float


@dataclass
class ModelTrainerArtifact:
    """Output of the model-trainer stage: saved model path plus its metrics."""
    # Path to the pickled trained model (preprocessor + estimator bundle)
    trained_model_file_path: str
    # Metrics computed on the training split
    train_metric_artifact: ClassificationMetricArtifact
    # Metrics computed on the test split
    test_metric_artifact: ClassificationMetricArtifact

network_security/entity/config_entity.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,18 @@ def __init__(self, training_pipeline_config:TrainingPipelineConfig):
9696
self.data_transformation_dir,
9797
training_pipeline.DATA_TRANSFORMATION_TRANSFORMED_OBJECT_DIR,
9898
training_pipeline.DATA_TRANSFORMATION_TRANSFORMED_OBJECT_FILE_NAME,
99-
)
99+
)
100+
101+
102+
class ModelTrainerConfig:
    """Path and threshold configuration for the model-trainer stage."""

    def __init__(self, training_pipeline_config: TrainingPipelineConfig):
        # Root directory for all model-trainer artifacts of this run
        self.model_trainer_dir: str = os.path.join(
            training_pipeline_config.artifact_dir, training_pipeline.MODEL_TRAINER_DIR_NAME
        )
        # Full path where the trained model pickle will be written
        self.trained_model_file_path: str = os.path.join(
            self.model_trainer_dir,
            training_pipeline.MODEL_TRAINER_TRAINED_MODEL_DIR,
            training_pipeline.MODEL_TRAINER_TRAINED_MODEL_FILE_NAME
        )
        # Minimum score the selected model is expected to reach
        self.expected_accuracy: float = training_pipeline.MODEL_TRAINER_EXPECTED_ACCURACY
        # Allowed train/test score gap before over-/underfitting is flagged
        self.overfitting_underfitting_threshold: float = training_pipeline.MODEL_TRAINER_OVERFITTING_UNDERFITTING_THRESHOLD

network_security/utils/main_utils/utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from network_security.exceptions.exception import NetworkSecurityException
22
from network_security.logging.logger import logging
3+
4+
from sklearn.model_selection import GridSearchCV
5+
from sklearn.metrics import r2_score
6+
37
import pandas as pd
48
import numpy as np
59
import os, sys
@@ -39,6 +43,16 @@ def save_numpy_array(file_path: str, array: np.array) -> None:
3943
np.save(file, array)
4044
except Exception as e:
4145
raise NetworkSecurityException(e, sys)
46+
47+
48+
def load_numpy_array(file_path: str) -> np.array:
    """Load and return the numpy array previously saved at *file_path*.

    Raises:
        NetworkSecurityException: wrapping the original error when the file
        is missing or cannot be read.
    """
    try:
        if not os.path.exists(file_path):
            raise Exception(f"The file: {file_path} does not exist")
        with open(file_path, "rb") as fp:
            array = np.load(fp)
        return array
    except Exception as e:
        raise NetworkSecurityException(e, sys)
4256

4357

4458
def save_object(file_path: str, obj: object) -> None:
@@ -48,3 +62,51 @@ def save_object(file_path: str, obj: object) -> None:
4862
pickle.dump(obj, file)
4963
except Exception as e:
5064
raise NetworkSecurityException(e, sys)
65+
66+
67+
def load_object(file_path: str) -> object:
    """Unpickle and return the object stored at *file_path*.

    Raises:
        NetworkSecurityException: wrapping the original error when the file
        is missing or unpickling fails.
    """
    try:
        if not os.path.exists(file_path):
            raise Exception(f"The file: {file_path} does not exist")
        with open(file_path, "rb") as fp:
            loaded = pickle.load(fp)
        return loaded
    except Exception as e:
        raise NetworkSecurityException(e, sys)
75+
76+
77+
def evaluate_models(X_train, y_train,
                    X_test, y_test,
                    models: dict,
                    params: dict
                    ) -> dict:
    """Tune and evaluate each candidate classifier, returning test scores.

    For every entry in *models*, runs a 3-fold GridSearchCV over the
    matching grid in *params*, refits the estimator on the full training
    split with the best parameters found, and records its accuracy on the
    test split. The estimators in *models* are mutated in place (refit), so
    the caller can use them directly after selection.

    Args:
        X_train, y_train: training features / labels.
        X_test, y_test: test features / labels.
        models: mapping of model name -> unfitted estimator.
        params: mapping of model name -> hyperparameter grid (must contain
            an entry for every key in *models*).

    Returns:
        dict mapping model name -> test accuracy (higher is better).

    Raises:
        NetworkSecurityException: wrapping any error raised during tuning.
    """
    try:
        report = {}

        for model_name, model in models.items():
            # Grid for this model; a KeyError here means the two dicts are
            # out of sync.
            param_grid = params[model_name]

            # Search for the best hyperparameters with 3-fold CV.
            gs = GridSearchCV(model, param_grid, cv=3)
            gs.fit(X_train, y_train)

            # Retrain the estimator on the full training split using the
            # best parameters found.
            model.set_params(**gs.best_params_)
            model.fit(X_train, y_train)

            # Score on the held-out test split. These are classifiers, so
            # use accuracy rather than r2_score (a regression metric); for
            # binary labels both rank models identically, but accuracy is
            # the interpretable quantity.
            y_test_pred = model.predict(X_test)
            report[model_name] = float(np.mean(y_test_pred == y_test))

        return report

    except Exception as e:
        raise NetworkSecurityException(e, sys)
112+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from network_security.entity.artifact_entity import ClassificationMetricArtifact
2+
from network_security.exceptions.exception import NetworkSecurityException
3+
from sklearn.metrics import f1_score, precision_score, recall_score
4+
import numpy as np
5+
import sys
6+
7+
8+
def get_classification_score(y_true: np.array, y_pred: np.array) -> ClassificationMetricArtifact:
    """Compute f1, precision, and recall for the given predictions.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.

    Returns:
        ClassificationMetricArtifact bundling the three scores.

    Raises:
        NetworkSecurityException: wrapping any error raised by the metric calls.
    """
    try:
        return ClassificationMetricArtifact(
            f1_score=f1_score(y_true, y_pred),
            precision_score=precision_score(y_true, y_pred),
            recall_score=recall_score(y_true, y_pred),
        )
    except Exception as e:
        raise NetworkSecurityException(e, sys)

0 commit comments

Comments
 (0)