-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathlstm_classifier_predict.py
49 lines (35 loc) · 1.45 KB
/
lstm_classifier_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import print_function

import os

import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.model_selection import train_test_split

from keras_fake_news_detector.library.classifiers.recurrent_networks import LstmClassifier
from keras_fake_news_detector.library.utility.plot_utils import plot_confusion_matrix
def main(data_dir_path='./data', model_dir_path='./models'):
    """Evaluate a saved LstmClassifier on the test portion of the fake-news CSV.

    Reads ``fake_or_real_news.csv`` from *data_dir_path*, encodes the ``label``
    column as 1 for ``'REAL'`` and 0 otherwise, splits the ``text`` column with
    ``train_test_split(test_size=0.2, random_state=42)`` and evaluates only on
    the test portion. The classifier is rebuilt from the config/weights files
    in *model_dir_path*. Prints the predictions and the accuracy, then plots
    the confusion matrix via ``plot_confusion_matrix``.

    NOTE(review): this presumably mirrors the split used by the training
    script, so the test portion is truly held out -- confirm there.

    Args:
        data_dir_path: directory containing ``fake_or_real_news.csv``.
        model_dir_path: directory containing ``<model_name>-config.npy`` and
            ``<model_name>-weights.h5`` for ``LstmClassifier``.
    """
    np.random.seed(42)

    config_file_path = os.path.join(model_dir_path, LstmClassifier.model_name + '-config.npy')
    weight_file_path = os.path.join(model_dir_path, LstmClassifier.model_name + '-weights.h5')

    print('loading csv file ...')
    df = pd.read_csv(os.path.join(data_dir_path, 'fake_or_real_news.csv'))

    # Binary target: 1 for 'REAL' articles, 0 for anything else.
    Y = [1 if label == 'REAL' else 0 for label in df.label]
    # DataFrame.drop returns a new frame rather than mutating in place, so the
    # result must be assigned or the call does nothing.
    df = df.drop('label', axis=1)
    X = df['text']

    # `.item()` unwraps a 0-d object array (presumably a dict written with
    # np.save). NumPy >= 1.16.3 refuses to unpickle object arrays unless
    # allow_pickle=True. This unpickles arbitrary data: only load config files
    # from a trusted source.
    config = np.load(config_file_path, allow_pickle=True).item()
    classifier = LstmClassifier(config)
    classifier.load_weights(weight_file_path)

    # Only the test portion is needed for evaluation.
    _, Xtest, _, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42)

    print('testing size: ', len(Xtest))

    print('start predicting ...')
    pred = classifier.predict(Xtest)
    print(pred)

    score = metrics.accuracy_score(Ytest, pred)
    print("accuracy: %0.3f" % score)

    cm = metrics.confusion_matrix(Ytest, pred, labels=[0, 1])
    plot_confusion_matrix(cm, classes=[0, 1])
if __name__ == '__main__':
    # Execute the evaluation only when run as a script, never on import.
    main()