@@ -46,16 +46,20 @@ def evaluate_models(
46
46
Returns:
47
47
HTML visualization of the evaluation results
48
48
"""
49
- if model_results is None or len ( model_results .columns ) == 0 :
50
- raise ValueError ("At least one model is required for evaluation" )
49
+ if model_results is None or model_results .is_empty () :
50
+ raise ValueError ("Model results are required for evaluation" )
51
51
52
52
if ground_truth_df is None or ground_truth_df .is_empty ():
53
53
raise ValueError ("Ground truth data is required for evaluation" )
54
54
55
55
gt_df = ground_truth_df
56
56
57
+ # --- 1. Extract unique model names from the flat DataFrame structure ---
58
+ model_keys = model_results ["model_name" ].unique ().to_list ()
59
+ if not model_keys :
60
+ raise ValueError ("No model names found in model_results" )
61
+
57
62
# --- 2. Build model info for evaluation models ---
58
- model_keys = list (model_results .columns )
59
63
model_info = {}
60
64
model_displays = []
61
65
model_prefixes = {}
@@ -65,11 +69,11 @@ def evaluate_models(
65
69
model_displays .append (display )
66
70
model_prefixes [display ] = prefix
67
71
68
- # --- 3. Convert DataFrame rows to dictionaries ---
72
+ # --- 3. Split model results by model ---
69
73
model_results_dict = {}
70
74
for model_name in model_keys :
71
- model_data = model_results [ model_name ]. to_dicts ( )
72
- model_results_dict [model_name ] = pl . DataFrame ( model_data )
75
+ model_data = model_results . filter ( pl . col ( "model_name" ) == model_name )
76
+ model_results_dict [model_name ] = model_data
73
77
74
78
# --- 4. Merge evaluation models' results ---
75
79
base_model = model_keys [0 ]
@@ -113,7 +117,7 @@ def evaluate_models(
113
117
114
118
# Check if we have ground truth data in our joined dataset
115
119
if gt_text_col not in merged_results .columns and "raw_text_gt" in merged_results .columns :
116
- gt_text_col = "raw_text_gt" # Fall back to legacy ground truth model format
120
+ gt_text_col = "raw_text_gt"
117
121
118
122
for row in merged_results .iter_rows (named = True ):
119
123
if gt_text_col not in row :
0 commit comments