1111from .config import get_mt_exp_dir
1212
1313chap_num = 0
14- trained_books = []
15- target_book = ""
16- all_books = []
17- metrics = []
18- key_word = ""
1914
2015
21- def read_data (file_path : str , data : dict , chapters : set ) -> None :
16+ def read_group_results (
17+ file_path : str ,
18+ target_book : str ,
19+ all_books : list [str ],
20+ metrics : list [str ],
21+ key_word : str ,
22+ ) -> tuple [dict [str , dict [int , list [str ]]], set [int ]]:
2223 global chap_num
23- global all_books
24- global key_word
2524
25+ data = {}
26+ chapter_groups = set ()
2627 for lang_pair in os .listdir (file_path ):
2728 lang_pattern = re .compile (r"([\w-]+)\-([\w-]+)" )
2829 if not lang_pattern .match (lang_pair ):
@@ -33,24 +34,22 @@ def read_data(file_path: str, data: dict, chapters: set) -> None:
3334 pattern = re .compile (rf"^{ re .escape (prefix )} _{ key_word } _order_(\d+)_ch$" )
3435
3536 for groups in os .listdir (os .path .join (file_path , lang_pair )):
36- m = pattern .match (os .path .basename (groups ))
37- if m :
37+ if m := pattern .match (os .path .basename (groups )):
3838 folder_path = os .path .join (file_path , lang_pair , os .path .basename (groups ))
3939 diff_pred_file = glob .glob (os .path .join (folder_path , "diff_predictions*" ))
4040 if diff_pred_file :
41- r = extract_data (diff_pred_file [0 ])
41+ r = extract_diff_pred_data (diff_pred_file [0 ], metrics , target_book )
4242 data [lang_pair ][int (m .group (1 ))] = r
43- chapters .add (int (m .group (1 )))
44- if int (m .group (1 )) > chap_num :
45- chap_num = int (m .group (1 ))
4643 else :
44+ data [lang_pair ][int (m .group (1 ))] = {}
4745 print (folder_path + " has no diff_predictions file." )
46+ chapter_groups .add (int (m .group (1 )))
47+ chap_num = max (chap_num , int (m .group (1 )))
48+ return data , chapter_groups
4849
4950
50- def extract_data (filename : str , header_row = 5 ) -> dict :
51+ def extract_diff_pred_data (filename : str , metrics : list [ str ], target_book : str , header_row = 5 ) -> dict [ int , list [ str ]] :
5152 global chap_num
52- global metrics
53- global target_book
5453
5554 metrics = [m .lower () for m in metrics ]
5655 try :
@@ -67,47 +66,49 @@ def extract_data(filename: str, header_row=5) -> dict:
6766 for _ , row in df .iterrows ():
6867 vref = row ["vref" ]
6968 m = re .match (r"(\d?[A-Z]{2,3}) (\d+)" , str (vref ))
69+ if not m :
70+ print (f"Invalid VREF format: { str (vref )} " )
71+ return {}
7072
7173 book_name , chap = m .groups ()
7274 if book_name != target_book :
7375 continue
7476
75- if int (chap ) > chap_num :
76- chap_num = int (chap )
77-
77+ chap_num = max (chap_num , int (chap ))
7878 values = []
7979 for metric in metrics :
8080 if metric in row :
8181 values .append (row [metric ])
8282 else :
83- metric = True
83+ metric_warning = True
8484 values .append (None )
8585
8686 result [int (chap )] = values
8787
8888 if metric_warning :
89- print ("Warning: {metric} is not calculated in {filename}" )
89+ print ("Warning: {metric} was not calculated in {filename}" )
9090
9191 return result
9292
9393
94- def flatten_dict (data : dict , chapters : list , baseline = {}) -> list :
94+ def flatten_dict (data : dict , chapter_groups : list [ int ], metrics : list [ str ], baseline = {}) -> list [ str ] :
9595 global chap_num
96- global metrics
9796
9897 rows = []
9998 if len (data ) > 0 :
10099 for lang_pair in data :
101100 for chap in range (1 , chap_num + 1 ):
102101 row = [lang_pair , chap ]
103102 row .extend ([None , None , None ] * len (metrics ) * len (data [lang_pair ]))
104- row .extend ([None ] * len (chapters ))
103+ row .extend ([None ] * len (chapter_groups ))
105104 row .extend ([None ] * (1 + len (metrics )))
106105
107106 for res_chap in data [lang_pair ]:
108107 if chap in data [lang_pair ][res_chap ]:
109108 for m in range (len (metrics )):
110- index_m = 3 + 1 + len (metrics ) + chapters .index (res_chap ) * (len (metrics ) * 3 + 1 ) + m * 3
109+ index_m = (
110+ 3 + 1 + len (metrics ) + chapter_groups .index (res_chap ) * (len (metrics ) * 3 + 1 ) + m * 3
111+ )
111112 row [index_m ] = data [lang_pair ][res_chap ][chap ][m ]
112113 if len (baseline ) > 0 :
113114 for m in range (len (metrics )):
@@ -126,16 +127,15 @@ def flatten_dict(data: dict, chapters: list, baseline={}) -> list:
126127 return rows
127128
128129
129- def create_xlsx (rows : list , chapters : list , output_path : str ) -> None :
130+ def create_xlsx (rows : list [ str ], chapter_groups : list [ str ] , output_path : str , metrics : list [ str ] ) -> None :
130131 global chap_num
131- global metrics
132132
133133 wb = Workbook ()
134134 ws = wb .active
135135
136136 num_col = len (metrics ) * 3 + 1
137137 groups = [("language pair" , 1 ), ("Chapter" , 1 ), ("Baseline" , (1 + len (metrics )))]
138- for chap in chapters :
138+ for chap in chapter_groups :
139139 groups .append ((chap , num_col ))
140140
141141 col = 1
@@ -239,16 +239,28 @@ def create_xlsx(rows: list, chapters: list, output_path: str) -> None:
239239# --trained-books MRK --target-book MAT --metrics chrf3 confidence --key-word conf --baseline Catapult_Reloaded/2nd_book/MRK
240240def main () -> None :
241241 global chap_num
242- global trained_books
243- global target_book
244- global all_books
245- global metrics
246- global key_word
247242
248243 parser = argparse .ArgumentParser (
249- description = "Pull results. At least one --exp or --baseline needs to be specified."
244+ description = "Pulling results from a single experiment and/or multiple experiment groups."
245+ "A valid experiment should have the following format:"
246+ "baseline/lang_pair/exp_group/diff_predictions or baseline/lang_pair/diff_predictions for a single experiment"
247+ "or "
248+ "exp/lang_pair/exp_groups/diff_predictions for multiple experiment groups"
249+ "More information in --exp and --baseline."
250+ "Use --exp for multiple experiment groups and --baseline for a single experiment."
251+ "At least one --exp or --baseline needs to be specified."
252+ )
253+ parser .add_argument (
254+ "--exp" ,
255+ type = str ,
256+ help = "Experiment folder with progression results. "
257+ "A valid experiment groups should have the following format:"
258+ "exp/lang_pair/exp_groups/diff_predictions"
259+ "where there should be at least one exp_groups that naming in the following format:"
260+ "*book*+*book*_*key-word*_order_*number*_ch"
261+ "where *book*+*book*... are the combination of all --trained-books with the last one being --target-book."
262+ "More information in --key-word." ,
250263 )
251- parser .add_argument ("--exp" , type = str , help = "Experiment folder with progression results" )
252264 parser .add_argument (
253265 "--trained-books" , nargs = "*" , required = True , type = str .upper , help = "Books that are trained in the exp"
254266 )
@@ -261,8 +273,25 @@ def main() -> None:
261273 type = str .lower ,
262274 help = "Metrics that will be analyzed with" ,
263275 )
264- parser .add_argument ("--key-word" , type = str , default = "conf" , help = "Key word in the filename for the exp group" )
265- parser .add_argument ("--baseline" , type = str , help = "Baseline or non-progression result for the exp group" )
276+ parser .add_argument (
277+ "--key-word" ,
278+ type = str ,
279+ default = "conf" ,
280+ help = "Key word in the filename for the exp group to distinguish between the experiment purpose."
281+ "For example, in LUK+ACT_conf_order_12_ch, the key-word should be conf."
282+ "Another example, in LUK+ACT_standard_order_12_ch, the key-word should be standard." ,
283+ )
284+ parser .add_argument (
285+ "--baseline" ,
286+ type = str ,
287+ help = "A non-progression folder for a single experiment."
288+ "A valid single experiment should have the following format:"
289+ "baseline/lang_pair/exp_group/diff_predictions where exp_group will be in the following format:"
290+ "*book*+*book*... as the combination of all --trained-books."
291+ "or"
292+ "baseline/lang_pair/diff_predictions "
293+ "where the information of --trained-books should have already been indicated in the baseline name." ,
294+ )
266295 args = parser .parse_args ()
267296
268297 if not (args .exp or args .baseline ):
@@ -274,46 +303,46 @@ def main() -> None:
274303 metrics = args .metrics
275304 key_word = args .key_word
276305
277- exp1_name = args .exp
278- exp1_dir = get_mt_exp_dir (exp1_name ) if exp1_name else None
306+ multi_group_exp_name = args .exp
307+ multi_group_exp_dir = get_mt_exp_dir (multi_group_exp_name ) if multi_group_exp_name else None
279308
280- exp2_name = args .baseline
281- exp2_dir = get_mt_exp_dir (exp2_name ) if exp2_name else None
309+ single_group_exp_name = args .baseline
310+ single_group_exp_dir = get_mt_exp_dir (single_group_exp_name ) if single_group_exp_name else None
282311
283- folder_name = "+" .join (all_books )
284- result_dir = exp1_dir if exp1_dir else exp2_dir
312+ result_file_name = "+" .join (all_books )
313+ result_dir = multi_group_exp_dir if multi_group_exp_dir else single_group_exp_dir
285314 os .makedirs (os .path .join (result_dir , "a_result_folder" ), exist_ok = True )
286- output_path = os .path .join (result_dir , "a_result_folder" , f"{ folder_name } .xlsx" )
315+ output_path = os .path .join (result_dir , "a_result_folder" , f"{ result_file_name } .xlsx" )
287316
288317 data = {}
289- chapters = set ()
290- if exp1_dir :
291- read_data ( exp1_dir , data , chapters )
292- chapters = sorted (chapters )
318+ chapter_groups = set ()
319+ if multi_group_exp_dir :
320+ data , chapter_groups = read_group_results ( multi_group_exp_dir , target_book , all_books , metrics , key_word )
321+ chapter_groups = sorted (chapter_groups )
293322
294323 baseline_data = {}
295- if exp2_dir :
296- for lang_pair in os .listdir (exp2_dir ):
324+ if single_group_exp_dir :
325+ for lang_pair in os .listdir (single_group_exp_dir ):
297326 lang_pattern = re .compile (r"([\w-]+)\-([\w-]+)" )
298327 if not lang_pattern .match (lang_pair ):
299328 continue
300329
301- baseline_path = os .path .join (exp2_dir , lang_pair )
330+ baseline_path = os .path .join (single_group_exp_dir , lang_pair )
302331 baseline_diff_pred = glob .glob (os .path .join (baseline_path , "diff_predictions*" ))
303332 if baseline_diff_pred :
304- baseline_data [lang_pair ] = extract_data (baseline_diff_pred [0 ])
333+ baseline_data [lang_pair ] = extract_diff_pred_data (baseline_diff_pred [0 ], metrics , target_book )
305334 else :
306335 print (f"Checking experiments under { baseline_path } ..." )
307336 sub_baseline_path = os .path .join (baseline_path , "+" .join (trained_books ))
308337 baseline_diff_pred = glob .glob (os .path .join (sub_baseline_path , "diff_predictions*" ))
309338 if baseline_diff_pred :
310- baseline_data [lang_pair ] = extract_data (baseline_diff_pred [0 ])
339+ baseline_data [lang_pair ] = extract_diff_pred_data (baseline_diff_pred [0 ], metrics , target_book )
311340 else :
312341 print (f"Baseline experiment has no diff_predictions file in { sub_baseline_path } " )
313342
314343 print ("Writing data..." )
315- rows = flatten_dict (data , chapters , baseline = baseline_data )
316- create_xlsx (rows , chapters , output_path )
344+ rows = flatten_dict (data , chapter_groups , metrics , baseline = baseline_data )
345+ create_xlsx (rows , chapter_groups , output_path , metrics )
317346 print (f"Result is in { output_path } " )
318347
319348
0 commit comments