fast_bm25.py
import collections
import heapq
import math
import pickle
import sys

PARAM_K1 = 1.5
PARAM_B = 0.75
IDF_CUTOFF = 4

class BM25:
    """Fast implementation of the Best Matching 25 (BM25) ranking function.

    Attributes
    ----------
    t2d : dict of {str: dict of {int: int}}
        Term frequencies per document: maps each token to a dict of
        {document index: frequency of the token in that document}.
    idf : dict of {str: float}
        Precomputed IDF score for every term kept in the index.
    doc_len : list of int
        List of document lengths.
    avgdl : float
        Average document length in `corpus`.
    """
    def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
        """
        Parameters
        ----------
        corpus : list of list of str
            Given corpus.
        k1 : float
            Constant controlling term-frequency saturation. Once saturation
            is reached, additional occurrences of a term add significantly
            less to the score. According to [1]_, experiments suggest that
            1.2 < k1 < 2 yields reasonably good results, although the optimal
            value depends on factors such as the type of documents or queries.
        b : float
            Constant controlling how strongly document length (relative to
            the average) affects the score. The larger b is, the more
            documents longer than average are penalized. According to [1]_,
            experiments suggest that 0.5 < b < 0.8 yields reasonably good
            results, although the optimal value depends on factors such as
            the type of documents or queries.
        alpha : float
            IDF cutoff: terms with an IDF score lower than alpha are dropped
            from the index. A higher alpha lowers the accuracy of BM25 but
            improves performance.
        """
        self.k1 = k1
        self.b = b
        self.alpha = alpha
        self.avgdl = 0
        self.t2d = {}
        self.idf = {}
        self.doc_len = []
        if corpus:
            self._initialize(corpus)
    @property
    def corpus_size(self):
        return len(self.doc_len)
    def _initialize(self, corpus):
        """Calculates term frequencies per document and per corpus, and precomputes inverse document frequencies."""
        for i, document in enumerate(corpus):
            self.doc_len.append(len(document))
            for word in document:
                if word not in self.t2d:
                    self.t2d[word] = {}
                if i not in self.t2d[word]:
                    self.t2d[word][i] = 0
                self.t2d[word][i] += 1

        self.avgdl = sum(self.doc_len) / len(self.doc_len)

        # Drop low-IDF terms: they contribute little to the final score but
        # appear in many documents, so they dominate the cost of scoring.
        to_delete = []
        for word, docs in self.t2d.items():
            idf = math.log(self.corpus_size - len(docs) + 0.5) - math.log(len(docs) + 0.5)
            # only store the idf score if it's above the threshold
            if idf > self.alpha:
                self.idf[word] = idf
            else:
                to_delete.append(word)
        print(f"Dropping {len(to_delete)} terms", file=sys.stderr)
        for word in to_delete:
            del self.t2d[word]

        self.average_idf = sum(self.idf.values()) / len(self.idf)
        if self.average_idf < 0:
            print(
                f'Average inverse document frequency is less than zero. Your corpus of {self.corpus_size} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.',
                file=sys.stderr
            )
    def get_top_n(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to rank and return from; must be aligned with the
            indexed corpus.
        n : int
            The number of documents to return.

        Returns
        -------
        list
            The top n documents, ordered by decreasing BM25 score.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = collections.defaultdict(float)
        for token in query:
            if token in self.t2d:
                for index, freq in self.t2d[token].items():
                    # Length-normalization term of the BM25 denominator.
                    denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
                    scores[index] += self.idf[token] * freq * (self.k1 + 1) / (freq + denom_cst)
        return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]
    def save(self, filename):
        with open(f"{filename}.pkl", "wb") as fsave:
            pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        with open(f"{filename}.pkl", "rb") as fsave:
            return pickle.load(fsave)
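
# A minimal usage sketch. The toy corpus and query below are illustrative
# examples, not part of the original module.
if __name__ == "__main__":
    corpus = [
        ["hello", "there", "good", "man"],
        ["it", "is", "quite", "windy", "in", "london"],
        ["how", "is", "the", "weather", "today"],
    ]
    # With only three documents, every IDF falls below the default cutoff
    # of 4, so a negative alpha is used here to keep all terms in the index.
    bm25 = BM25(corpus, alpha=-10)
    print(bm25.get_top_n(["windy", "london"], corpus, n=1))
    # Expected: [['it', 'is', 'quite', 'windy', 'in', 'london']]

    # The index can be pickled to disk and restored later.
    bm25.save("bm25_index")            # writes bm25_index.pkl
    restored = BM25.load("bm25_index")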