Skip to content

Commit b729a95

Browse files
committed
unify dataframe structure for single/multi model runs
1 parent 102b6c8 commit b729a95

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

omni-reader/steps/evaluate_models.py

+11 -7
Original file line number | Diff line number | Diff line change
@@ -46,16 +46,20 @@ def evaluate_models(
4646
Returns:
4747
HTML visualization of the evaluation results
4848
"""
49-
if model_results is None or len(model_results.columns) == 0:
50-
raise ValueError("At least one model is required for evaluation")
49+
if model_results is None or model_results.is_empty():
50+
raise ValueError("Model results are required for evaluation")
5151

5252
if ground_truth_df is None or ground_truth_df.is_empty():
5353
raise ValueError("Ground truth data is required for evaluation")
5454

5555
gt_df = ground_truth_df
5656

57+
# --- 1. Extract unique model names from the flat DataFrame structure ---
58+
model_keys = model_results["model_name"].unique().to_list()
59+
if not model_keys:
60+
raise ValueError("No model names found in model_results")
61+
5762
# --- 2. Build model info for evaluation models ---
58-
model_keys = list(model_results.columns)
5963
model_info = {}
6064
model_displays = []
6165
model_prefixes = {}
@@ -65,11 +69,11 @@ def evaluate_models(
6569
model_displays.append(display)
6670
model_prefixes[display] = prefix
6771

68-
# --- 3. Convert DataFrame rows to dictionaries ---
72+
# --- 3. Split model results by model ---
6973
model_results_dict = {}
7074
for model_name in model_keys:
71-
model_data = model_results[model_name].to_dicts()
72-
model_results_dict[model_name] = pl.DataFrame(model_data)
75+
model_data = model_results.filter(pl.col("model_name") == model_name)
76+
model_results_dict[model_name] = model_data
7377

7478
# --- 4. Merge evaluation models' results ---
7579
base_model = model_keys[0]
@@ -113,7 +117,7 @@ def evaluate_models(
113117

114118
# Check if we have ground truth data in our joined dataset
115119
if gt_text_col not in merged_results.columns and "raw_text_gt" in merged_results.columns:
116-
gt_text_col = "raw_text_gt" # Fall back to legacy ground truth model format
120+
gt_text_col = "raw_text_gt"
117121

118122
for row in merged_results.iter_rows(named=True):
119123
if gt_text_col not in row:

omni-reader/utils/visualizations.py

+10 -7
Original file line number | Diff line number | Diff line change
@@ -385,10 +385,16 @@ def create_summary_visualization(
385385

386386

387387
def create_ocr_batch_visualization(df: pl.DataFrame) -> HTMLString:
388-
"""Create an HTML visualization of batch OCR processing results."""
389-
# Extract metrics
388+
"""Create an HTML visualization of batch OCR processing results.
389+
390+
Args:
391+
df: DataFrame containing OCR results (flattened for single/multi-model runs)
392+
393+
Returns:
394+
HTMLString: HTML visualization of batch OCR processing results
395+
"""
396+
# Calculate overall metrics
390397
total_results = len(df)
391-
# Ensure all raw_text values are strings
392398
raw_texts = []
393399
for txt in df["raw_text"].to_list():
394400
if isinstance(txt, list):
@@ -400,14 +406,13 @@ def create_ocr_batch_visualization(df: pl.DataFrame) -> HTMLString:
400406
total_proc_time = df["processing_time"].sum() if "processing_time" in df.columns else 0
401407
avg_proc_time = df["processing_time"].mean() if "processing_time" in df.columns else 0
402408

403-
# Get model-specific metrics
409+
# Calculate model-specific metrics
404410
model_metrics = {}
405411
model_displays = []
406412

407413
if "model_name" in df.columns:
408414
for model in df["model_name"].unique().to_list():
409415
mdf = df.filter(pl.col("model_name") == model)
410-
# Ensure all model-specific raw_text values are strings
411416
m_raw_texts = []
412417
for txt in mdf["raw_text"].to_list():
413418
if isinstance(txt, list):
@@ -523,8 +528,6 @@ def create_ocr_batch_visualization(df: pl.DataFrame) -> HTMLString:
523528
if isinstance(raw_text, list):
524529
raw_text = "\n".join(raw_text)
525530
text_preview = str(raw_text)[:100] + ("..." if len(str(raw_text)) > 100 else "")
526-
527-
# Calculate the length properly
528531
text_length = len(str(raw_text)) if raw_text is not None else 0
529532

530533
table_rows += f"""

0 commit comments

Comments
 (0)