-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_green_score.py
65 lines (55 loc) · 2.3 KB
/
generate_green_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import pandas as pd
from tqdm import tqdm
from green_score import GREEN
import json
class GenerateGreenScore:
    """Score generated reports against ground-truth reports with GREEN.

    The CSV at ``csv_path`` is loaded into ``self.df``. For every organ in
    ``organs`` the code reads the columns ``gt-<organ>`` and
    ``generated-<organ>`` and writes the columns ``green_<organ>`` (score,
    ``-1`` meaning "not scored yet") and ``explanation_<organ>``. Results are
    written back to the same CSV, so an interrupted run can be resumed by
    constructing the class again on the same file.
    """

    def __init__(self, csv_path, cache_dir=None, save_every=10, organs=None):
        """Load the CSV, build the GREEN model and prepare result columns.

        Args:
            csv_path: Path of the CSV to read and to overwrite with results.
            cache_dir: Forwarded to ``GREEN`` as ``cache_dir``.
            save_every: Checkpoint the CSV whenever the row index is a
                multiple of this value; a falsy value disables checkpoints.
            organs: Organ names to score, e.g. ``["abdomen", "chest",
                "pelvis"]``. Defaults to ``["chest"]``.
        """
        self.csv_path = csv_path
        self.save_every = save_every
        # ``None`` sentinel + copy: avoids sharing one default list between
        # instances and decouples us from later mutation by the caller.
        self.organs = ["chest"] if organs is None else list(organs)
        self.model = GREEN(
            model_id_or_path="StanfordAIMI/GREEN-radllama2-7b",
            do_sample=False,  # should be always False
            batch_size=16,
            return_0_if_no_green_score=True,
            cuda=True,
            cache_dir=cache_dir,
            max_len=400,
        )
        self.df = pd.read_csv(self.csv_path)
        # Empty cells become '' so the truthiness checks in generate_scores
        # treat missing reports as "nothing to score".
        self.df.fillna('', inplace=True)

        # Decide per organ and per column whether results already exist, so a
        # resumed run keeps earlier scores for whichever organs were requested
        # instead of keying the decision on one hard-coded column name.
        for organ in self.organs:
            green_col = "green_" + organ
            explanation_col = "explanation_" + organ
            if green_col in self.df.columns:
                # Keep the score column float so float scores can be stored
                # without a dtype change; blank/unparseable cells count as
                # "not scored yet".
                self.df[green_col] = (
                    pd.to_numeric(self.df[green_col], errors="coerce")
                    .fillna(-1.0)
                    .astype(float)
                )
            else:
                self.df[green_col] = -1.0
            if explanation_col not in self.df.columns:
                self.df[explanation_col] = ""

    def run(self):
        """Score all pending rows, persist CSV and summary, return the frame."""
        self.generate_scores()
        self.save_df()
        self.save_summary()
        return self.df

    def save_df(self):
        """Overwrite the source CSV with the current state of ``self.df``."""
        self.df.to_csv(self.csv_path, index=False)

    def save_summary(self):
        """Write the per-organ mean score to ``summary.json`` next to the CSV.

        Rows still marked ``-1`` are excluded from the mean. An organ with no
        scored rows is written as ``null`` rather than raising
        ``ZeroDivisionError``.
        """
        summary = {}
        for organ in self.organs:
            scored = [n for n in self.df[f"green_{organ}"].to_list() if n != -1]
            summary[f"{organ}"] = sum(scored) / len(scored) if scored else None
        # dirname() yields '' for a bare filename, so the summary lands in the
        # working directory next to the CSV instead of the filesystem root.
        summary_path = os.path.join(os.path.dirname(self.csv_path), "summary.json")
        with open(summary_path, 'w') as json_file:
            json.dump(summary, json_file, indent=4)

    def generate_scores(self):
        """Fill in scores and explanations for every row not yet scored.

        A row/organ pair is skipped when it already has a score (anything
        other than ``-1``) or when either the ground-truth or the generated
        report is empty. The CSV is checkpointed every ``save_every`` rows.
        """
        for indx in tqdm(self.df.index):
            for organ in self.organs:
                green_col = f"green_{organ}"
                if self.df.at[indx, green_col] != -1:
                    continue
                reference = self.df.at[indx, f"gt-{organ}"]
                hypothesis = self.df.at[indx, f"generated-{organ}"]
                # Both texts are required; a pair with a missing side stays
                # at -1 and is therefore left out of the summary.
                if not (reference and hypothesis):
                    continue
                # NOTE(review): assumes the model returns a 3-tuple whose
                # second item holds tensor-like scores and whose third holds
                # explanation strings - confirm against the installed GREEN.
                _, green, explanation = self.model(
                    refs=[reference], hyps=[hypothesis]
                )
                # Write through .at on the frame itself; chained
                # ``df[col].iloc[i] = ...`` may update a temporary copy.
                self.df.at[indx, green_col] = green[0].item()
                self.df.at[indx, f"explanation_{organ}"] = explanation[0]
            if self.save_every and indx % self.save_every == 0:
                print(f"Saving at indx {indx}")
                self.save_df()