Skip to content

Commit 04c3550

Browse files
committed
Fixed bugs. Refactored code. Added more descriptions.
1 parent 78ce2fc commit 04c3550

File tree

1 file changed

+85
-56
lines changed

1 file changed

+85
-56
lines changed

silnlp/nmt/exp_summary.py

Lines changed: 85 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -11,18 +11,19 @@
1111
from .config import get_mt_exp_dir
1212

1313
chap_num = 0
14-
trained_books = []
15-
target_book = ""
16-
all_books = []
17-
metrics = []
18-
key_word = ""
1914

2015

21-
def read_data(file_path: str, data: dict, chapters: set) -> None:
16+
def read_group_results(
17+
file_path: str,
18+
target_book: str,
19+
all_books: list[str],
20+
metrics: list[str],
21+
key_word: str,
22+
) -> tuple[dict[str, dict[int, list[str]]], set[int]]:
2223
global chap_num
23-
global all_books
24-
global key_word
2524

25+
data = {}
26+
chapter_groups = set()
2627
for lang_pair in os.listdir(file_path):
2728
lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
2829
if not lang_pattern.match(lang_pair):
@@ -33,24 +34,22 @@ def read_data(file_path: str, data: dict, chapters: set) -> None:
3334
pattern = re.compile(rf"^{re.escape(prefix)}_{key_word}_order_(\d+)_ch$")
3435

3536
for groups in os.listdir(os.path.join(file_path, lang_pair)):
36-
m = pattern.match(os.path.basename(groups))
37-
if m:
37+
if m := pattern.match(os.path.basename(groups)):
3838
folder_path = os.path.join(file_path, lang_pair, os.path.basename(groups))
3939
diff_pred_file = glob.glob(os.path.join(folder_path, "diff_predictions*"))
4040
if diff_pred_file:
41-
r = extract_data(diff_pred_file[0])
41+
r = extract_diff_pred_data(diff_pred_file[0], metrics, target_book)
4242
data[lang_pair][int(m.group(1))] = r
43-
chapters.add(int(m.group(1)))
44-
if int(m.group(1)) > chap_num:
45-
chap_num = int(m.group(1))
4643
else:
44+
data[lang_pair][int(m.group(1))] = {}
4745
print(folder_path + " has no diff_predictions file.")
46+
chapter_groups.add(int(m.group(1)))
47+
chap_num = max(chap_num, int(m.group(1)))
48+
return data, chapter_groups
4849

4950

50-
def extract_data(filename: str, header_row=5) -> dict:
51+
def extract_diff_pred_data(filename: str, metrics: list[str], target_book: str, header_row=5) -> dict[int, list[str]]:
5152
global chap_num
52-
global metrics
53-
global target_book
5453

5554
metrics = [m.lower() for m in metrics]
5655
try:
@@ -67,47 +66,49 @@ def extract_data(filename: str, header_row=5) -> dict:
6766
for _, row in df.iterrows():
6867
vref = row["vref"]
6968
m = re.match(r"(\d?[A-Z]{2,3}) (\d+)", str(vref))
69+
if not m:
70+
print(f"Invalid VREF format: {str(vref)}")
71+
return {}
7072

7173
book_name, chap = m.groups()
7274
if book_name != target_book:
7375
continue
7476

75-
if int(chap) > chap_num:
76-
chap_num = int(chap)
77-
77+
chap_num = max(chap_num, int(chap))
7878
values = []
7979
for metric in metrics:
8080
if metric in row:
8181
values.append(row[metric])
8282
else:
83-
metric = True
83+
metric_warning = True
8484
values.append(None)
8585

8686
result[int(chap)] = values
8787

8888
if metric_warning:
89-
print("Warning: {metric} is not calculated in {filename}")
89+
print("Warning: {metric} was not calculated in {filename}")
9090

9191
return result
9292

9393

94-
def flatten_dict(data: dict, chapters: list, baseline={}) -> list:
94+
def flatten_dict(data: dict, chapter_groups: list[int], metrics: list[str], baseline={}) -> list[str]:
9595
global chap_num
96-
global metrics
9796

9897
rows = []
9998
if len(data) > 0:
10099
for lang_pair in data:
101100
for chap in range(1, chap_num + 1):
102101
row = [lang_pair, chap]
103102
row.extend([None, None, None] * len(metrics) * len(data[lang_pair]))
104-
row.extend([None] * len(chapters))
103+
row.extend([None] * len(chapter_groups))
105104
row.extend([None] * (1 + len(metrics)))
106105

107106
for res_chap in data[lang_pair]:
108107
if chap in data[lang_pair][res_chap]:
109108
for m in range(len(metrics)):
110-
index_m = 3 + 1 + len(metrics) + chapters.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
109+
index_m = (
110+
3 + 1 + len(metrics) + chapter_groups.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
111+
)
111112
row[index_m] = data[lang_pair][res_chap][chap][m]
112113
if len(baseline) > 0:
113114
for m in range(len(metrics)):
@@ -126,16 +127,15 @@ def flatten_dict(data: dict, chapters: list, baseline={}) -> list:
126127
return rows
127128

128129

129-
def create_xlsx(rows: list, chapters: list, output_path: str) -> None:
130+
def create_xlsx(rows: list[str], chapter_groups: list[str], output_path: str, metrics: list[str]) -> None:
130131
global chap_num
131-
global metrics
132132

133133
wb = Workbook()
134134
ws = wb.active
135135

136136
num_col = len(metrics) * 3 + 1
137137
groups = [("language pair", 1), ("Chapter", 1), ("Baseline", (1 + len(metrics)))]
138-
for chap in chapters:
138+
for chap in chapter_groups:
139139
groups.append((chap, num_col))
140140

141141
col = 1
@@ -239,16 +239,28 @@ def create_xlsx(rows: list, chapters: list, output_path: str) -> None:
239239
# --trained-books MRK --target-book MAT --metrics chrf3 confidence --key-word conf --baseline Catapult_Reloaded/2nd_book/MRK
240240
def main() -> None:
241241
global chap_num
242-
global trained_books
243-
global target_book
244-
global all_books
245-
global metrics
246-
global key_word
247242

248243
parser = argparse.ArgumentParser(
249-
description="Pull results. At least one --exp or --baseline needs to be specified."
244+
description="Pulling results from a single experiment and/or multiple experiment groups."
245+
"A valid experiment should have the following format:"
246+
"baseline/lang_pair/exp_group/diff_predictions or baseline/lang_pair/diff_predictions for a single experiment"
247+
"or "
248+
"exp/lang_pair/exp_groups/diff_predictions for multiple experiment groups"
249+
"More information in --exp and --baseline."
250+
"Use --exp for multiple experiment groups and --baseline for a single experiment."
251+
"At least one --exp or --baseline needs to be specified."
252+
)
253+
parser.add_argument(
254+
"--exp",
255+
type=str,
256+
help="Experiment folder with progression results. "
257+
"A valid experiment groups should have the following format:"
258+
"exp/lang_pair/exp_groups/diff_predictions"
259+
"where there should be at least one exp_groups that naming in the following format:"
260+
"*book*+*book*_*key-word*_order_*number*_ch"
261+
"where *book*+*book*... are the combination of all --trained-books with the last one being --target-book."
262+
"More information in --key-word.",
250263
)
251-
parser.add_argument("--exp", type=str, help="Experiment folder with progression results")
252264
parser.add_argument(
253265
"--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp"
254266
)
@@ -261,8 +273,25 @@ def main() -> None:
261273
type=str.lower,
262274
help="Metrics that will be analyzed with",
263275
)
264-
parser.add_argument("--key-word", type=str, default="conf", help="Key word in the filename for the exp group")
265-
parser.add_argument("--baseline", type=str, help="Baseline or non-progression result for the exp group")
276+
parser.add_argument(
277+
"--key-word",
278+
type=str,
279+
default="conf",
280+
help="Key word in the filename for the exp group to distinguish between the experiment purpose."
281+
"For example, in LUK+ACT_conf_order_12_ch, the key-word should be conf."
282+
"Another example, in LUK+ACT_standard_order_12_ch, the key-word should be standard.",
283+
)
284+
parser.add_argument(
285+
"--baseline",
286+
type=str,
287+
help="A non-progression folder for a single experiment."
288+
"A valid single experiment should have the following format:"
289+
"baseline/lang_pair/exp_group/diff_predictions where exp_group will be in the following format:"
290+
"*book*+*book*... as the combination of all --trained-books."
291+
"or"
292+
"baseline/lang_pair/diff_predictions "
293+
"where the information of --trained-books should have already been indicated in the baseline name.",
294+
)
266295
args = parser.parse_args()
267296

268297
if not (args.exp or args.baseline):
@@ -274,46 +303,46 @@ def main() -> None:
274303
metrics = args.metrics
275304
key_word = args.key_word
276305

277-
exp1_name = args.exp
278-
exp1_dir = get_mt_exp_dir(exp1_name) if exp1_name else None
306+
multi_group_exp_name = args.exp
307+
multi_group_exp_dir = get_mt_exp_dir(multi_group_exp_name) if multi_group_exp_name else None
279308

280-
exp2_name = args.baseline
281-
exp2_dir = get_mt_exp_dir(exp2_name) if exp2_name else None
309+
single_group_exp_name = args.baseline
310+
single_group_exp_dir = get_mt_exp_dir(single_group_exp_name) if single_group_exp_name else None
282311

283-
folder_name = "+".join(all_books)
284-
result_dir = exp1_dir if exp1_dir else exp2_dir
312+
result_file_name = "+".join(all_books)
313+
result_dir = multi_group_exp_dir if multi_group_exp_dir else single_group_exp_dir
285314
os.makedirs(os.path.join(result_dir, "a_result_folder"), exist_ok=True)
286-
output_path = os.path.join(result_dir, "a_result_folder", f"{folder_name}.xlsx")
315+
output_path = os.path.join(result_dir, "a_result_folder", f"{result_file_name}.xlsx")
287316

288317
data = {}
289-
chapters = set()
290-
if exp1_dir:
291-
read_data(exp1_dir, data, chapters)
292-
chapters = sorted(chapters)
318+
chapter_groups = set()
319+
if multi_group_exp_dir:
320+
data, chapter_groups = read_group_results(multi_group_exp_dir, target_book, all_books, metrics, key_word)
321+
chapter_groups = sorted(chapter_groups)
293322

294323
baseline_data = {}
295-
if exp2_dir:
296-
for lang_pair in os.listdir(exp2_dir):
324+
if single_group_exp_dir:
325+
for lang_pair in os.listdir(single_group_exp_dir):
297326
lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
298327
if not lang_pattern.match(lang_pair):
299328
continue
300329

301-
baseline_path = os.path.join(exp2_dir, lang_pair)
330+
baseline_path = os.path.join(single_group_exp_dir, lang_pair)
302331
baseline_diff_pred = glob.glob(os.path.join(baseline_path, "diff_predictions*"))
303332
if baseline_diff_pred:
304-
baseline_data[lang_pair] = extract_data(baseline_diff_pred[0])
333+
baseline_data[lang_pair] = extract_diff_pred_data(baseline_diff_pred[0], metrics, target_book)
305334
else:
306335
print(f"Checking experiments under {baseline_path}...")
307336
sub_baseline_path = os.path.join(baseline_path, "+".join(trained_books))
308337
baseline_diff_pred = glob.glob(os.path.join(sub_baseline_path, "diff_predictions*"))
309338
if baseline_diff_pred:
310-
baseline_data[lang_pair] = extract_data(baseline_diff_pred[0])
339+
baseline_data[lang_pair] = extract_diff_pred_data(baseline_diff_pred[0], metrics, target_book)
311340
else:
312341
print(f"Baseline experiment has no diff_predictions file in {sub_baseline_path}")
313342

314343
print("Writing data...")
315-
rows = flatten_dict(data, chapters, baseline=baseline_data)
316-
create_xlsx(rows, chapters, output_path)
344+
rows = flatten_dict(data, chapter_groups, metrics, baseline=baseline_data)
345+
create_xlsx(rows, chapter_groups, output_path, metrics)
317346
print(f"Result is in {output_path}")
318347

319348

0 commit comments

Comments
 (0)