"""
A simple Python script to:
* Load the medBERT model
* Produce and store vectors for 15k public clinical trial article abstracts
* Stand up a web service that accepts queries, converts them to vector embeddings,
and returns the 20 most closely matching clinical trials
"""
import socketserver
import http.server
import math
import sqlite3 as db
import urllib.parse
import codecs
import pickle
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
def loadTrialsTableFromTestData(dbConnection):
"""
Load the training database file into a proper database format
PubMed 200k RCT dataset
https://github.com/Franck-Dernoncourt/pubmed-rct
15ktrain.txt (from the 20k folder)
"""
    # Open the file and read its contents
    filename = '15ktrain.txt'
    with open(filename, "r") as file:
        text = file.read()
    # Split it on the separator used by the published dataset
    separator = "###"
    trials = text.split(separator)
# Use pandas to push them into a SQL lite file
pd.DataFrame(trials, columns=['abstract_text']).to_sql(
name="clinical_trials", # Table name
con=dbConnection, # Open DB connection
if_exists="replace", # Fail if table is already present
index=True, # Write DF frame index as column
index_label="id", # Gives index column a columnname
chunksize=1000 # Write this many rows at a time to db
)
# Add the serialized vector column for later
cursor = dbConnection.cursor()
sql = """
ALTER TABLE clinical_trials
ADD COLUMN serialized_vectors text;
"""
cursor.execute(sql)
dbConnection.commit()
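# Rough sketch of the expected 15ktrain.txt layout (per the PubMed RCT repo linked above;
# recalled here, not verified against the file):
#   ###24293578
#   OBJECTIVE\tTo investigate ...
#   METHODS\tA total of ...
# so splitting on "###" yields one "<pmid>\n<label>\t<sentence>..." chunk per abstract,
# plus an empty first element, which the abstract_text != '' filters below exclude.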
def generateTrialVectors(dbConnection, tokenizer, model):
"""
Generate embeddings for each of the trial db rows
"""
# Load DB to pandas dataframe
# trials = pd.read_sql_table(table_name='clinical_trials', con=dbConnection) # Actually this is SQLalchemy
trials = pd.read_sql_query(
"SELECT * from clinical_trials WHERE abstract_text!='' LIMIT 1000", dbConnection)
# Get the text column and transform it from a panda into a simple py list
trialsList = trials.abstract_text.values.tolist()
# Transform the text to vector embeddings using model
    # Process the abstracts in manageable chunks of 100 at a time
    allTrialVectors = []
    for i in range(math.ceil(len(trialsList)/100)):
        # Mark the loop
        print("TRIAL VECTORIZE LOOP "+str(i))
        # Chunk into 100s (slicing safely handles a shorter final chunk)
        start = i*100
        end = start+100
        chunk = trialsList[start:end]
# Transform list of trials to list of trialVectors
trialVectorsChunk = transform(chunk, tokenizer, model)
# Add to output
allTrialVectors += trialVectorsChunk
return allTrialVectors
def pushVectorsToTrialTable(dbConnection, allTrialVectors):
"""
Push the trial vectors back to DB
"""
# Prep db cursor
cursor = dbConnection.cursor()
    # Prep a parameterized SQL UPDATE (avoids the quoting/injection problems of string formatting)
    sql = """
    UPDATE clinical_trials
    SET serialized_vectors = ?
    WHERE id = ?
    """
    # Serialize every item in allTrialVectors and push an update to the db
    # NOTE: this assumes the row ids line up with the order the vectors were generated in
    for i in range(len(allTrialVectors)):
        # Pickle and encode bytes to base64 string
        pickled = codecs.encode(pickle.dumps(
            allTrialVectors[i]), "base64").decode()
        # Execute the UPDATE with the vector and the id of the row we're updating
        cursor.execute(sql, (pickled, i))
# Commit the transactions
dbConnection.commit()
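# Serialization round-trip sketch (illustrative only; mirrors the encode above and the decode
# used by the web server further down):
#   pickled = codecs.encode(pickle.dumps(torch.ones(3)), "base64").decode()   # tensor -> base64 str
#   restored = pickle.loads(codecs.decode(pickled.encode(), "base64"))        # base64 str -> tensor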
def findClosestDotProduct(text, tokenizer, model, dbConnection, trialIDs, trialVectors, trialAbstracts):
"""
    Calculate the dot product of a query text against the
stored vector embeddings of clinical trials
to find the most similar
"""
# Transform text to vector embedding
vectorOutputToCompare = transform(text, tokenizer, model)
# Turn the list of tensors into a stacked tensor (required for torch.mm)
trialVectors = torch.stack(trialVectors)
# Compute dot score between query and all trial vectors
scores = torch.mm(vectorOutputToCompare, trialVectors.transpose(0, 1))[
0].cpu().tolist()
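    # (Shape sketch: the query vector is (1, d) and the stacked trial vectors are (n, d),
    #  so the matrix product above yields a (1, n) row of dot-product similarity scores.)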
# Combine trialIDs & scores
IDscorePairs = list(zip(trialIDs, scores))
# Sort by decreasing score
IDscorePairs = sorted(IDscorePairs, key=lambda x: x[1], reverse=True)
# Output 20 best matching passages as a py list
results = []
i = 0
for ID, score in IDscorePairs:
# return ID
# return scores[]
results.append(str(score)+"||||"+trialAbstracts[ID])
i += 1
if i > 20:
break
return results
def meanPooling(output, attentionMask):
"""
    Mean Pooling - take the attention-mask-weighted average of all token embeddings
    (the standard pooling recipe used by sentence-embedding frameworks)
"""
tokenVectors = output.last_hidden_state
expandedMask = attentionMask.unsqueeze(
-1).expand(tokenVectors.size()).float()
return (
torch.sum(tokenVectors * expandedMask, 1) /
torch.clamp(expandedMask.sum(1), min=1e-9)
)
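# Toy illustration of the pooling math above (not executed; assumed shapes for a batch of 1):
#   hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]])  # (batch=1, tokens=3, dim=2)
#   mask = torch.tensor([[1, 1, 0]])                               # last token is padding
#   pooled = (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
#   # pooled == tensor([[2.0, 3.0]]) - the average of the two non-padding token vectors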
def transform(text, tokenizer, model):
"""
Compute vector embeddings from text
"""
# Tokenize
    # NOTE: BERT accepts at most 512 tokens, including the special [CLS] and [SEP]
    tokenized = tokenizer(text, padding=True,
                          # stay a little under the 512-token limit
                          max_length=510,
truncation='longest_first',
return_tensors='pt')
# Transform, no_grad for memory saving
with torch.no_grad():
output = model(**tokenized, return_dict=True)
# Pool
vectors = meanPooling(output, tokenized['attention_mask'])
# Normalize
vectors = F.normalize(vectors, p=2, dim=1)
# Return vectors (serialization later)
return vectors
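# Usage sketch (assumes the MedBERT tokenizer/model loaded below; hypothetical input text):
#   vecs = transform(["aspirin reduced stroke risk in the treatment arm"], tokenizer, model)
#   vecs.shape  # -> torch.Size([1, hidden_size]); rows are L2-normalized, so the dot products
#               #    computed in findClosestDotProduct are equivalent to cosine similarities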
def terminal_main():
"""
Load, calculate and insert back
"""
print("BEGIN")
    # Get the MedBERT model and tokenizer (from the Hugging Face Hub)
    # The downloaded .bin model file is cached locally, so hot reloading (e.g. with nodemon) doesn't re-fetch it
tokenizer = AutoTokenizer.from_pretrained("Charangan/MedBERT")
model = AutoModel.from_pretrained("Charangan/MedBERT")
# Connect to DB
dbConnection = db.connect("15ktrain.sqlite")
# If the trials table doesn't exist...
cursor = dbConnection.cursor()
trialTableExists = cursor.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table'
AND name='clinical_trials'
""").fetchall()
    if not trialTableExists:
        # Precompute the trial vector embeddings and store them in the database (this only needs to run once)
        loadTrialsTableFromTestData(dbConnection)
        allTrialVectors = generateTrialVectors(dbConnection, tokenizer, model)
        pushVectorsToTrialTable(dbConnection, allTrialVectors)
    else:
        print("Table already precomputed")
    # Load the ids, abstracts and serialized vectors back out of the database
    trialTable = pd.read_sql_query(
        "SELECT * from clinical_trials WHERE abstract_text!='' LIMIT 1000", dbConnection)
    trialIDs = trialTable.id.values.tolist()
    trialAbstracts = trialTable.abstract_text.values.tolist()
    trialVectors = [pickle.loads(codecs.decode(v.encode(), "base64"))
                    for v in trialTable.serialized_vectors.values.tolist()]
    query = input("Text to match to?: ")
    result = findClosestDotProduct(
        query, tokenizer, model, dbConnection, trialIDs, trialVectors, trialAbstracts)
    print(result)
    print("END")
# Run program
# terminal_main()
"""
Below is the web server implementation
"""
hostName = "localhost"
serverPort = 80
# Do the basics outside the web server to prevent reloading every time
tokenizer = AutoTokenizer.from_pretrained("Charangan/MedBERT")
model = AutoModel.from_pretrained("Charangan/MedBERT")
dbConnection = db.connect("./15ktrain.sqlite")
# Load trialVectors from database
trialTable = pd.read_sql_query(
"SELECT * from clinical_trials WHERE abstract_text!='' LIMIT 1000", dbConnection)
# Panda=>list for trialIDs and serialized_vectors
trialIDs = trialTable.id.values.tolist()
trialVectors = trialTable.serialized_vectors.values.tolist()
trialAbstracts = trialTable.abstract_text.values.tolist()
# Deserialize (unpickle and base 64 decode) the vector embeddings in place
for i in range(len(trialVectors)):
trialVectors[i] = pickle.loads(codecs.decode(
trialVectors[i].encode(), "base64"))
class Handler(http.server.SimpleHTTPRequestHandler):
""""
Handle get requests with ?query=TEXT by returning the 20 most similar trials to $query
"""
def do_GET(self):
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
        # Parse ?query=TEXT from the request path (URL-decoded; tolerates a missing query string)
        params = urllib.parse.parse_qs(urllib.parse.urlparse(self.path).query)
        query = params.get("query", [""])[0]
result = findClosestDotProduct(
query, tokenizer, model, dbConnection, trialIDs, trialVectors, trialAbstracts)
result = '####'.join(result)
self.wfile.write(bytes(result, "utf-8"))
httpd = socketserver.TCPServer(('', serverPort), Handler)
httpd.serve_forever()
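# Example client call once the server is running (hypothetical query text):
#   import urllib.request
#   resp = urllib.request.urlopen("http://localhost/?query=statin%20therapy%20outcomes")
#   matches = resp.read().decode("utf-8").split("####")   # each item is "<score>||||<abstract>"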