-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckPrecisionRecall.py
128 lines (113 loc) · 5.03 KB
/
checkPrecisionRecall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) 2023-2024 G. Fan, J. Wang, Y. Li, D. Zhang, and R. J. Miller
#
# This file is derived from Starmie hosted at https://github.com/megagonlabs/starmie
# and originally licensed under the BSD-3 Clause License. Modifications to the original
# source code have been made by I. Taha, M. Lissandrini, A. Simitsis, and Y. Ioannidis, and
# can be viewed at https://github.com/athenarc/table-search.
#
# This work is licensed under the GNU Affero General Public License v3.0,
# unless otherwise explicitly stated. See the https://github.com/athenarc/table-search/blob/main/LICENSE
# for more details.
#
# You may use, modify, and distribute this file in accordance with the terms of the
# GNU Affero General Public License v3.0.
# This file has been modified
import pickle
import mlflow
import numpy as np
import pandas as pd
import pickle5 as p
from matplotlib import *
from matplotlib import pyplot as plt
def loadDictionaryFromPickleFile(dictionaryPath):
    ''' Load the pickle file as a dictionary
    Args:
        dictionaryPath: path to the pickle file
    Return: dictionary from the pickle file
    '''
    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    # The context manager closes the file even when unpickling raises.
    with open(dictionaryPath, 'rb') as filePointer:
        return p.load(filePointer)
def saveDictionaryAsPickleFile(dictionary, dictionaryPath):
    ''' Save dictionary as a pickle file
    Args:
        dictionary: object to be saved
        dictionaryPath: filepath to which the dictionary will be saved
    '''
    # The context manager flushes and closes the file even when pickling raises.
    with open(dictionaryPath, 'wb') as filePointer:
        pickle.dump(dictionary, filePointer, protocol=pickle.HIGHEST_PROTOCOL)
def calcMetrics(max_k, k_range, resultFile, gtPath=None, resPath=None, record=True):
    ''' Calculate and log the performance metrics: MAP, Precision@k, Recall@k
    Args:
        max_k: the maximum K value (e.g. for SANTOS benchmark, max_k = 10. For TUS benchmark, max_k = 60)
        k_range: step size for the K's up to max_k (kept for backward compatibility;
            it does not influence the returned values)
        resultFile: dictionary mapping each query table name to the list of table
            names returned for it; the first k entries are used as the top-k results
        gtPath: file path to the groundtruth pickle file (query table name ->
            collection of relevant table names)
        resPath: file path to the raw results from the model (unused; the results
            are passed in directly through resultFile)
        record (boolean): to log in MLFlow or not
    Return: tuple (MAP, precision_array, recall_array, ideal_recall_arr). Each array
        holds one value per k = 1..max_k, stored at index k-1. A value that has no
        evaluable query behind it is reported as 0.0 instead of raising
        ZeroDivisionError.
    '''
    groundtruth = loadDictionaryFromPickleFile(gtPath)

    # Everything that does not depend on k is prepared once per query table
    # instead of being rebuilt for every k.
    evaluated = []
    for table in resultFile:
        # t28 tables have less than 60 results. So, skipping them in the analysis.
        if table.split("____", 1)[0] == "t_28dc8f7610402ea7":
            continue
        if table not in groundtruth:
            continue
        raw_truth = set(groundtruth[table])
        if not raw_truth:
            # Recall is undefined without any relevant table (it would divide
            # by zero), so such a query is left out of the evaluation.
            continue
        # Names are compared on the text before their first "." only.
        truth = {x.split(".")[0] for x in raw_truth}
        ranked = [x.split(".")[0] for x in resultFile[table][:max_k]]
        evaluated.append((truth, ranked, len(raw_truth)))

    # NOTE: recall is averaged over every query in resultFile, including the
    # skipped ones; kept as-is so reported numbers stay comparable.
    num_queries = len(resultFile)

    # =============================================================================
    # Precision and recall
    # =============================================================================
    precision_array = []
    recall_array = []
    ideal_recall_arr = []
    for k in range(1, max_k + 1):
        true_positive = 0
        false_positive = 0
        rec = 0
        ideal_recall = []
        for truth, ranked, raw_truth_size in evaluated:
            tp = len(truth.intersection(ranked[:k]))
            # Precision@k only counts queries with at least k relevant tables.
            if len(truth) >= k:
                true_positive += tp
                false_positive += k - tp
            # tp + fn equals the number of relevant tables.
            rec += tp / len(truth)
            # Ideal recall uses the raw (un-normalized) groundtruth size.
            ideal_recall.append(k / raw_truth_size)
        retrieved = true_positive + false_positive
        precision_array.append(true_positive / retrieved if retrieved else 0.0)
        recall_array.append(rec / num_queries if num_queries else 0.0)
        ideal_recall_arr.append(
            sum(ideal_recall) / len(ideal_recall) if ideal_recall else 0.0
        )

    # MAP here is the mean of Precision@k over k = 1..max_k.
    mean_avg_pr = sum(precision_array) / max_k

    # logging to mlflow
    if record:  # if the user would like to log to MLFlow
        mlflow.log_metric("mean_avg_precision", mean_avg_pr)
        mlflow.log_metric("prec_k", precision_array[max_k - 1])
        mlflow.log_metric("recall_k", recall_array[max_k - 1])
    return mean_avg_pr, precision_array, recall_array, ideal_recall_arr