-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathglobaleval.py
More file actions
235 lines (193 loc) · 9.88 KB
/
globaleval.py
File metadata and controls
235 lines (193 loc) · 9.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import pandas as pd
import os
from sklearn.metrics import f1_score, recall_score, precision_score, roc_auc_score
import numpy as np
import argparse
# Command-line interface: the dataset root (--dir) and the sub-folder of it that
# holds the detection output to evaluate (--pred is joined onto --dir below).
parser = argparse.ArgumentParser(
    description='Evaluate detection results (output from the detection script)'
)
parser.add_argument(
    '--dir',
    type=str,
    default='example/metadata',
    help='Path to bumblebee dataset folder',
)
parser.add_argument(
    '--pred',
    type=str,
    default='detection',
    help='Path to predictions folder',
)
args = parser.parse_args()
## list the sites from the pred folder
# Each prediction file is named 'indices_<site>.csv'; the site name is exactly the
# text between that prefix and suffix. Slicing (rather than str.replace) only
# strips the prefix/suffix, never occurrences elsewhere in the name. Sorting makes
# the processing order, and hence the row order of every output CSV, independent
# of the filesystem's directory listing order.
PRED_PREFIX = 'indices_'
PRED_SUFFIX = '.csv'
siteslist = sorted(
    fname[len(PRED_PREFIX):-len(PRED_SUFFIX)]
    for fname in os.listdir(os.path.join(args.dir, args.pred))
    if fname.startswith(PRED_PREFIX) and fname.endswith(PRED_SUFFIX)
)
print(f"Name of folder (model): {args.dir}")
print(f"Sites found: {siteslist}")
# Duration in seconds of one prediction chunk. A chunk is labelled as buzzing when
# it starts at or after an annotated event's start and start + CHUNK_SECONDS is no
# later than the event's end.
# NOTE(review): assumed to match the window length used by the detection script — confirm.
CHUNK_SECONDS = 5

# Per-site frames are collected here and concatenated once after the loop, so a
# skipped site (including the first one) cannot leave the global frames undefined.
all_gt_frames = []
all_pred_frames = []

for currentsite in siteslist:
    # Load the predictions (output from process.py)
    predictions = os.path.join(args.dir, args.pred, f"indices_{currentsite}.csv")
    site_preds = pd.read_csv(predictions)
    ## add a column with the site name
    site_preds['site'] = currentsite

    # Ground-truth label files are all *.txt files below <dir>/<site>
    gt_files = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(os.path.join(args.dir, currentsite))
        for file in files
        if os.path.splitext(file)[1] == ".txt"
    ]
    if not gt_files:
        print(f"No ground truth files found for site {currentsite}, skipping.")
        continue

    dfs = []          # annotated events of this site, one frame per label file
    df_preds_gt = []  # matching prediction chunks, each with a 0/1 'buzzlabel'
    for gt_file in gt_files:
        # Tab-separated, no header: column 0 = event start, column 1 = event end
        # (the third column is later named 'event').
        df = pd.read_csv(gt_file, sep="\t", header=None)
        df['file'] = os.path.basename(gt_file)

        # '<name>_labels.txt' annotates the recording '<name>.wav'
        gtbasename = os.path.splitext(os.path.basename(gt_file))[0]
        gtbasename = gtbasename.replace('_labels', '') + '.wav'

        # .copy(): columns are added/modified below and must not alias site_preds
        df_pred = site_preds[site_preds['name'] == gtbasename].copy()
        if df_pred.empty:
            print(f"No corresponding wave file found for {gtbasename} ground truth, skipping.")
            continue
        df_pred['buzzlabel'] = 0

        # Per-event aggregated buzz score and the chunk file it was taken from
        df['pred'] = 0.
        df['flacfile'] = 'null'
        for index, row in df.iterrows():
            start = row.iloc[0]
            end = row.iloc[1]
            starts_after = df_pred['start'] >= start
            fully_inside = starts_after & (df_pred['start'] + CHUNK_SECONDS <= end)
            from_start = df_pred[starts_after]
            inside = df_pred[fully_inside]
            if from_start.empty:
                print(f"No prediction found for ground truth event starting at {start} and ending at {end} in file {gtbasename}, skipping.")
                continue
            df_pred.loc[fully_inside, 'buzzlabel'] = 1
            # "First" is by row order; NOTE(review): assumes df_pred is sorted by 'start' — confirm.
            df.at[index, 'flacfile'] = from_start.iloc[0]['flacfile']
            if len(inside) > 0:
                # Event covers one or more whole chunks: average their scores
                df.at[index, 'pred'] = np.mean(inside['buzz'])
            else:
                # Event shorter than a chunk: use the first chunk starting at/after it
                df.at[index, 'pred'] = from_start.iloc[0]['buzz']
        df['site'] = currentsite
        df_preds_gt.append(df_pred)
        dfs.append(df)

    # pd.concat raises on an empty list, e.g. when no label file matched a recording
    if not dfs:
        print(f"No ground truth could be matched to predictions for site {currentsite}, skipping.")
        continue

    df_gt = pd.concat(dfs, ignore_index=True)
    ## rename the columns
    df_gt.columns = ['start', 'end', 'event', 'file', 'pred', 'flacfile', 'site']
    all_gt_frames.append(df_gt)
    all_pred_frames.append(pd.concat(df_preds_gt, ignore_index=True))
## end of loop over sites

if not all_gt_frames:
    raise SystemExit("No ground truth events could be matched to predictions; nothing to evaluate.")
df_gt_global = pd.concat(all_gt_frames, ignore_index=True)
df_final_preds_gt_global = pd.concat(all_pred_frames, ignore_index=True)
## here starts the global evaluation
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import seaborn as sns

y_true_global = df_final_preds_gt_global['buzzlabel']
y_score_global = df_final_preds_gt_global['buzz']

# Threshold-independent ranking quality of the pooled buzz scores
aucscore = roc_auc_score(y_true_global, y_score_global)
print(f"AUC: {aucscore:.4f} for site global")

# ROC operating points; the thresholds are also used to search for the best F1
fpr, tpr, thresholds = roc_curve(y_true_global, y_score_global)

# Plot the ROC curve against the chance diagonal and save it in the predictions folder
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(fpr, tpr, color='blue', label=f'ROC curve (area = {aucscore:0.2f})')
ax.plot([0, 1], [0, 1], color='red', linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc='lower right')
fig.savefig(os.path.join(args.dir, args.pred, "global_roc_curve.png"))
plt.close(fig)
# Evaluate F1 at every ROC threshold and keep the threshold that maximises it
y_true = df_final_preds_gt_global['buzzlabel']
y_score = df_final_preds_gt_global['buzz']
f1_scores = [f1_score(y_true, (y_score >= thr).astype(int)) for thr in thresholds]

best_idx = np.argmax(f1_scores)  # first index on ties
best_threshold = thresholds[best_idx]
tpr_best = tpr[best_idx]
fpr_best = fpr[best_idx]
print(f"Best F1 score: {max(f1_scores):.4f} for site global")
print(f"Best TPR: {tpr_best:.4f} for site global")
print(f"Best FPR: {fpr_best:.4f} for site global")
print(f"Best threshold: {best_threshold:.4f} for site global")

# Hard 0/1 predictions at the best threshold, reused for every metric below
y_pred_best = (y_score >= best_threshold).astype(int)
precision = precision_score(y_true, y_pred_best)
recall = recall_score(y_true, y_pred_best)
print(f"Precision: {precision:.4f} for site global")
print(f"Recall: {recall:.4f} for site global")
f1 = f1_score(y_true, y_pred_best)
print(f"F1 score: {f1:.4f} for site global")
# Keep the best-threshold decision alongside the raw scores in the output table
df_final_preds_gt_global['pred_best'] = y_pred_best
## Operating points at TPR 0.9, TPR 0.95 and the TPR of the best-F1 threshold
tpr_values = [0.9, 0.95, tpr_best]
results = []
labels_global = df_final_preds_gt_global['buzzlabel']
scores_global = df_final_preds_gt_global['buzz']
for tpr_target in tpr_values:
    # First ROC point whose TPR reaches the target; roc_curve returns decreasing
    # thresholds, so this is the highest threshold achieving that TPR.
    point = np.argmax(tpr >= tpr_target)
    fpr_target = fpr[point]
    threshold_target = thresholds[point]
    print(f"FPR at TPR {tpr_target}: {fpr_target:.4f} for site global")
    print(f"Threshold at TPR {tpr_target}: {threshold_target:.4f} for site global")

    hard_pred = (scores_global >= threshold_target).astype(int)
    precision_target = precision_score(labels_global, hard_pred)
    recall_target = recall_score(labels_global, hard_pred)
    print(f"Precision at TPR {tpr_target}: {precision_target:.4f} for site global")
    print(f"Recall at TPR {tpr_target}: {recall_target:.4f} for site global")
    f1_target = f1_score(labels_global, hard_pred)
    print(f"F1 score at TPR {tpr_target}: {f1_target:.4f} for site global")

    # Store the thresholded decision as column pred_<tpr_target>
    df_final_preds_gt_global[f'pred_{tpr_target}'] = hard_pred
    results.append({
        'TPR': tpr_target,
        'FPR': fpr_target,
        'Threshold': threshold_target,
        'Precision': precision_target,
        'Recall': recall_target,
        'F1 score': f1_target,
    })
# One-row summary of the global metrics. results[0], results[1] and results[2]
# are the operating points for TPR 0.9, TPR 0.95 and the best-F1 TPR, in the
# order of tpr_values. Every key must be unique: a repeated key in a dict literal
# silently keeps only its last value and drops the earlier column.
df_results = pd.DataFrame({
    'AUC': [aucscore],
    'Best threshold': [best_threshold],
    'Precision': [precision],
    'Recall': [recall],
    'F1 score': [f1],
    'FPR at TPR 0.9': [results[0]['FPR']],
    'Threshold at TPR 0.9': [results[0]['Threshold']],
    'Precision at TPR 0.9': [results[0]['Precision']],
    'Recall at TPR 0.9': [results[0]['Recall']],
    'F1 score at TPR 0.9': [results[0]['F1 score']],
    'FPR at TPR 0.95': [results[1]['FPR']],
    'Threshold at TPR 0.95': [results[1]['Threshold']],
    'Precision at TPR 0.95': [results[1]['Precision']],
    'Recall at TPR 0.95': [results[1]['Recall']],
    'F1 score at TPR 0.95': [results[1]['F1 score']],
    'Precision at best threshold': [results[2]['Precision']],
    'Recall at best threshold': [results[2]['Recall']],
    'F1 score at best threshold': [results[2]['F1 score']],
    'FPR at best threshold': [results[2]['FPR']],
})

# All global outputs are written into the predictions folder
out_dir = os.path.join(args.dir, args.pred)
df_results.to_csv(os.path.join(out_dir, "global_metrics.csv"), index=False)
print(df_results)
df_gt_global.to_csv(os.path.join(out_dir, "global_gtonly.csv"), index=False)
df_final_preds_gt_global.to_csv(os.path.join(out_dir, "global_gt_preds.csv"), index=False)