Skip to content

Commit 2e01008

Browse files
committed
Further refactoring
1 parent 04c3550 commit 2e01008

File tree

1 file changed

+60
-56
lines changed

1 file changed

+60
-56
lines changed

silnlp/nmt/exp_summary.py

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import glob
33
import os
44
import re
5+
from typing import Dict, List, Set, Tuple
56

67
import pandas as pd
78
from openpyxl import Workbook
@@ -10,18 +11,15 @@
1011

1112
from .config import get_mt_exp_dir
1213

13-
chap_num = 0
14-
1514

1615
def read_group_results(
1716
file_path: str,
1817
target_book: str,
19-
all_books: list[str],
20-
metrics: list[str],
18+
all_books: List[str],
19+
metrics: List[str],
2120
key_word: str,
22-
) -> tuple[dict[str, dict[int, list[str]]], set[int]]:
23-
global chap_num
24-
21+
chap_num: int,
22+
) -> Tuple[Dict[str, Dict[int, List[float]]], Set[int], int]:
2523
data = {}
2624
chapter_groups = set()
2725
for lang_pair in os.listdir(file_path):
@@ -38,37 +36,38 @@ def read_group_results(
3836
folder_path = os.path.join(file_path, lang_pair, os.path.basename(groups))
3937
diff_pred_file = glob.glob(os.path.join(folder_path, "diff_predictions*"))
4038
if diff_pred_file:
41-
r = extract_diff_pred_data(diff_pred_file[0], metrics, target_book)
39+
r, chap_num = extract_diff_pred_data(diff_pred_file[0], metrics, target_book, chap_num)
4240
data[lang_pair][int(m.group(1))] = r
4341
else:
4442
data[lang_pair][int(m.group(1))] = {}
4543
print(folder_path + " has no diff_predictions file.")
4644
chapter_groups.add(int(m.group(1)))
4745
chap_num = max(chap_num, int(m.group(1)))
48-
return data, chapter_groups
49-
46+
return data, chapter_groups, chap_num
5047

51-
def extract_diff_pred_data(filename: str, metrics: list[str], target_book: str, header_row=5) -> dict[int, list[str]]:
52-
global chap_num
5348

49+
def extract_diff_pred_data(
50+
filename: str, metrics: List[str], target_book: str, chap_num: int, header_row=5
51+
) -> Tuple[Dict[int, List[float]], int]:
5452
metrics = [m.lower() for m in metrics]
5553
try:
5654
df = pd.read_excel(filename, header=header_row)
5755
except ValueError as e:
5856
print(f"An error occurs in {filename}")
5957
print(e)
60-
return {}
58+
return {}, chap_num
6159

6260
df.columns = [col.strip().lower() for col in df.columns]
6361

6462
result = {}
6563
metric_warning = False
64+
uncalculated_metric = set()
6665
for _, row in df.iterrows():
6766
vref = row["vref"]
6867
m = re.match(r"(\d?[A-Z]{2,3}) (\d+)", str(vref))
6968
if not m:
7069
print(f"Invalid VREF format: {str(vref)}")
71-
return {}
70+
continue
7271

7372
book_name, chap = m.groups()
7473
if book_name != target_book:
@@ -78,22 +77,21 @@ def extract_diff_pred_data(filename: str, metrics: list[str], target_book: str,
7877
values = []
7978
for metric in metrics:
8079
if metric in row:
81-
values.append(row[metric])
80+
values.append(float(row[metric]))
8281
else:
8382
metric_warning = True
83+
uncalculated_metric.add(metric)
8484
values.append(None)
8585

8686
result[int(chap)] = values
8787

8888
if metric_warning:
89-
print("Warning: {metric} was not calculated in {filename}")
90-
91-
return result
89+
print(f"Warning: {uncalculated_metric} was not calculated in {filename}")
9290

91+
return result, chap_num
9392

94-
def flatten_dict(data: dict, chapter_groups: list[int], metrics: list[str], baseline={}) -> list[str]:
95-
global chap_num
9693

94+
def flatten_dict(data: Dict, chapter_groups: List[int], metrics: List[str], chap_num: int, baseline={}) -> List[str]:
9795
rows = []
9896
if len(data) > 0:
9997
for lang_pair in data:
@@ -127,9 +125,9 @@ def flatten_dict(data: dict, chapter_groups: list[int], metrics: list[str], base
127125
return rows
128126

129127

130-
def create_xlsx(rows: list[str], chapter_groups: list[str], output_path: str, metrics: list[str]) -> None:
131-
global chap_num
132-
128+
def create_xlsx(
129+
rows: List[str], chapter_groups: List[str], output_path: str, metrics: List[str], chap_num: int
130+
) -> None:
133131
wb = Workbook()
134132
ws = wb.active
135133

@@ -238,71 +236,71 @@ def create_xlsx(rows: list[str], chapter_groups: list[str], output_path: str, me
238236
# python -m silnlp.nmt.exp_summary Catapult_Reloaded_Confidences
239237
# --trained-books MRK --target-book MAT --metrics chrf3 confidence --key-word conf --baseline Catapult_Reloaded/2nd_book/MRK
240238
def main() -> None:
241-
global chap_num
242-
243239
parser = argparse.ArgumentParser(
244-
description="Pulling results from a single experiment and/or multiple experiment groups."
245-
"A valid experiment should have the following format:"
246-
"baseline/lang_pair/exp_group/diff_predictions or baseline/lang_pair/diff_predictions for a single experiment"
240+
description="Pulling results from a single experiment and/or multiple experiment groups. "
241+
"A valid experiment should have the following format: "
242+
"baseline/lang_pair/exp_group/diff_predictions or baseline/lang_pair/diff_predictions for a single experiment "
247243
"or "
248-
"exp/lang_pair/exp_groups/diff_predictions for multiple experiment groups"
249-
"More information in --exp and --baseline."
250-
"Use --exp for multiple experiment groups and --baseline for a single experiment."
251-
"At least one --exp or --baseline needs to be specified."
244+
"exp/lang_pair/exp_groups/diff_predictions for multiple experiment groups "
245+
"More information in --exp and --baseline. "
246+
"Use --exp for multiple experiment groups and --baseline for a single experiment. "
247+
"At least one --exp or --baseline needs to be specified. "
252248
)
253249
parser.add_argument(
254250
"--exp",
255251
type=str,
256252
help="Experiment folder with progression results. "
257-
"A valid experiment groups should have the following format:"
258-
"exp/lang_pair/exp_groups/diff_predictions"
259-
"where there should be at least one exp_groups that naming in the following format:"
260-
"*book*+*book*_*key-word*_order_*number*_ch"
261-
"where *book*+*book*... are the combination of all --trained-books with the last one being --target-book."
262-
"More information in --key-word.",
253+
"A valid experiment groups should have the following format: "
254+
"exp/lang_pair/exp_groups/diff_predictions "
255+
"where there should be at least one exp_groups that naming in the following format: "
256+
"*book*+*book*_*key-word*_order_*number*_ch "
257+
"where *book*+*book*... are the combination of all --trained-books with the last one being --target-book. "
258+
"More information in --key-word. ",
263259
)
264260
parser.add_argument(
265-
"--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp"
261+
"--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp "
266262
)
267-
parser.add_argument("--target-book", required=True, type=str.upper, help="Book that is going to be analyzed")
263+
parser.add_argument("--target-book", required=True, type=str.upper, help="Book that is going to be analyzed ")
268264
parser.add_argument(
269265
"--metrics",
270266
nargs="*",
271267
metavar="metrics",
272268
default=["chrf3", "confidence"],
273269
type=str.lower,
274-
help="Metrics that will be analyzed with",
270+
help="Metrics that will be analyzed with ",
275271
)
276272
parser.add_argument(
277273
"--key-word",
278274
type=str,
279275
default="conf",
280-
help="Key word in the filename for the exp group to distinguish between the experiment purpose."
281-
"For example, in LUK+ACT_conf_order_12_ch, the key-word should be conf."
282-
"Another example, in LUK+ACT_standard_order_12_ch, the key-word should be standard.",
276+
help="Key word in the filename for the exp group to distinguish between the experiment purpose. "
277+
"For example, in LUK+ACT_conf_order_12_ch, the key-word should be conf. "
278+
"Another example, in LUK+ACT_standard_order_12_ch, the key-word should be standard. ",
283279
)
284280
parser.add_argument(
285281
"--baseline",
286282
type=str,
287-
help="A non-progression folder for a single experiment."
288-
"A valid single experiment should have the following format:"
289-
"baseline/lang_pair/exp_group/diff_predictions where exp_group will be in the following format:"
290-
"*book*+*book*... as the combination of all --trained-books."
291-
"or"
283+
help="A non-progression folder for a single experiment. "
284+
"A valid single experiment should have the following format: "
285+
"baseline/lang_pair/exp_group/diff_predictions where exp_group will be in the following format: "
286+
"*book*+*book*... as the combination of all --trained-books. "
287+
"or "
292288
"baseline/lang_pair/diff_predictions "
293-
"where the information of --trained-books should have already been indicated in the baseline name.",
289+
"where the information of --trained-books should have already been indicated in the baseline name. ",
294290
)
295291
args = parser.parse_args()
296292

297293
if not (args.exp or args.baseline):
298-
parser.error("At least one --exp or --baseline needs to be specified.")
294+
parser.error("At least one --exp or --baseline needs to be specified. ")
299295

300296
trained_books = args.trained_books
301297
target_book = args.target_book
302298
all_books = trained_books + [target_book]
303299
metrics = args.metrics
304300
key_word = args.key_word
305301

302+
chap_num = 0
303+
306304
multi_group_exp_name = args.exp
307305
multi_group_exp_dir = get_mt_exp_dir(multi_group_exp_name) if multi_group_exp_name else None
308306

@@ -317,7 +315,9 @@ def main() -> None:
317315
data = {}
318316
chapter_groups = set()
319317
if multi_group_exp_dir:
320-
data, chapter_groups = read_group_results(multi_group_exp_dir, target_book, all_books, metrics, key_word)
318+
data, chapter_groups, chap_num = read_group_results(
319+
multi_group_exp_dir, target_book, all_books, metrics, key_word, chap_num
320+
)
321321
chapter_groups = sorted(chapter_groups)
322322

323323
baseline_data = {}
@@ -330,19 +330,23 @@ def main() -> None:
330330
baseline_path = os.path.join(single_group_exp_dir, lang_pair)
331331
baseline_diff_pred = glob.glob(os.path.join(baseline_path, "diff_predictions*"))
332332
if baseline_diff_pred:
333-
baseline_data[lang_pair] = extract_diff_pred_data(baseline_diff_pred[0], metrics, target_book)
333+
baseline_data[lang_pair], chap_num = extract_diff_pred_data(
334+
baseline_diff_pred[0], metrics, target_book, chap_num
335+
)
334336
else:
335337
print(f"Checking experiments under {baseline_path}...")
336338
sub_baseline_path = os.path.join(baseline_path, "+".join(trained_books))
337339
baseline_diff_pred = glob.glob(os.path.join(sub_baseline_path, "diff_predictions*"))
338340
if baseline_diff_pred:
339-
baseline_data[lang_pair] = extract_diff_pred_data(baseline_diff_pred[0], metrics, target_book)
341+
baseline_data[lang_pair], chap_num = extract_diff_pred_data(
342+
baseline_diff_pred[0], metrics, target_book, chap_num
343+
)
340344
else:
341345
print(f"Baseline experiment has no diff_predictions file in {sub_baseline_path}")
342346

343347
print("Writing data...")
344-
rows = flatten_dict(data, chapter_groups, metrics, baseline=baseline_data)
345-
create_xlsx(rows, chapter_groups, output_path, metrics)
348+
rows = flatten_dict(data, chapter_groups, metrics, chap_num, baseline=baseline_data)
349+
create_xlsx(rows, chapter_groups, output_path, metrics, chap_num)
346350
print(f"Result is in {output_path}")
347351

348352

0 commit comments

Comments
 (0)