diff --git a/evaluation/baseline/common_finetune_sft.py b/evaluation/baseline/common_finetune_sft.py new file mode 100644 index 00000000..74189bd5 --- /dev/null +++ b/evaluation/baseline/common_finetune_sft.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +""" +Common SFT (LoRA) fine-tuning helper for HF Image-Text-to-Text models. +Exposes run_sft(train_examples, **kwargs) where each example is a dict with a +"messages" field (chat template) and optional images embedded in the messages. + +Dependencies: + transformers, datasets, peft, trl, accelerate (and bitsandbytes if you want QLoRA) +""" +from __future__ import annotations +import io +import os +import sys +from typing import List + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +import wfdb +from datasets import Dataset +from transformers import AutoModelForImageTextToText, AutoProcessor +import torch +from peft import LoraConfig +from trl import SFTTrainer +from trl.trainer.sft_config import SFTConfig +from PIL import Image + +# Ensure project src/ is on sys.path so we can import time_series_datasets +PROJECT_SRC = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "src")) +if PROJECT_SRC not in sys.path: + sys.path.insert(0, PROJECT_SRC) + +from time_series_datasets.ecg_qa.plot_example import get_ptbxl_ecg_path + +LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"] + + +def _downsample_to_100hz(ecg_data: np.ndarray, original_freq: int) -> np.ndarray: + """Downsample ECG data to 100Hz""" + if original_freq == 100: + return ecg_data + + # Calculate downsampling factor + downsample_factor = original_freq // 100 + + # Downsample by taking every nth sample + downsampled_data = ecg_data[::downsample_factor] + + return downsampled_data + + +def _load_ecg_data(ecg_id: int) -> np.ndarray: + """Load ECG data for a given ECG ID using wfdb.""" + ecg_path = get_ptbxl_ecg_path(ecg_id) + + if 
not os.path.exists(ecg_path + ".dat"): + raise FileNotFoundError(f"ECG file not found: {ecg_path}.dat") + + # Read ECG data using wfdb - returns (samples, leads) shape + ecg_data, meta = wfdb.rdsamp(ecg_path) + + return ecg_data + + +def _ecg_to_image(ecg_data: np.ndarray) -> Image.Image: + """Render ECG data to PIL Image (for lazy loading in collate_fn).""" + + # Downsample to 100Hz if needed + if ecg_data.shape[0] > 1000: # Likely 500Hz data + ecg_data = _downsample_to_100hz(ecg_data, 500) + + n = min(ecg_data.shape[1], 12) # Up to 12 leads + fig, axes = plt.subplots(n, 1, figsize=(10.5, 1.5 * n), dpi=80) + if n == 1: + axes = [axes] + + # Create time array for 100Hz sampling (10 seconds) + time_points = np.arange(0, 10, 0.01) # 100Hz for 10 seconds + + for i in range(n): + ax = axes[i] + lead_name = LEAD_NAMES[i] if i < len(LEAD_NAMES) else f"Lead {i+1}" + + # Plot the ECG signal for this lead - ecg_data is (samples, leads) + ax.plot(time_points, ecg_data[:, i], linewidth=2, color="k", alpha=1.0) + + # Add grid lines (millimeter paper style) + # Major grid lines (every 0.2s and 0.5mV) + ax.vlines( + np.arange(0, 10, 0.2), -2.5, 2.5, colors="r", alpha=0.3, linewidth=0.5 + ) + ax.hlines( + np.arange(-2.5, 2.5, 0.5), 0, 10, colors="r", alpha=0.3, linewidth=0.5 + ) + + # Minor grid lines (every 0.04s and 0.1mV) + ax.vlines( + np.arange(0, 10, 0.04), -2.5, 2.5, colors="r", alpha=0.1, linewidth=0.3 + ) + ax.hlines( + np.arange(-2.5, 2.5, 0.1), 0, 10, colors="r", alpha=0.1, linewidth=0.3 + ) + + ax.set_xticks(np.arange(0, 11, 1.0)) + ax.set_ylabel(f"Lead {lead_name} (mV)", fontweight="bold") + ax.margins(0.0) + ax.set_ylim(-2.5, 2.5) + ax.set_title(f"Lead {lead_name}", fontweight="bold", pad=10) + + plt.tight_layout() + + canvas = FigureCanvas(fig) + buf = io.BytesIO() + canvas.print_png(buf) + plt.close(fig) + + # Return PIL Image directly + buf.seek(0) + return Image.open(buf).convert("RGB") + + +def process_vision_info(messages: list[dict]) -> list[Image.Image]: + 
"""Extract PIL images from chat messages. Handles lazy ecg_id, bytes, and PIL Images.""" + image_inputs = [] + for msg in messages: + content = msg.get("content", []) + if not isinstance(content, list): + content = [content] + + for element in content: + if not isinstance(element, dict): + continue + + # Handle lazy ecg_id reference (render on-demand) + ecg_id = element.get("ecg_id") + if ecg_id is not None: + ecg_data = _load_ecg_data(ecg_id) + image = _ecg_to_image(ecg_data) + image_inputs.append(image) + continue + + # Handle pre-rendered bytes (backwards compatibility) + image = element.get("image") + if image is None: + continue # Text elements get "image": None from Dataset serialization + + # Handle bytes (PNG data) - convert to PIL Image + if isinstance(image, bytes): + image = Image.open(io.BytesIO(image)) + + image_inputs.append(image.convert("RGB")) + return image_inputs + + +def run_sft( + train_examples: List[dict], + *, + output_dir: str, + llm_id: str = "google/gemma-3-4b-pt", + epochs: int = 1, + learning_rate: float = 2e-4, + per_device_train_batch_size: int = 1, + gradient_accumulation_steps: int = 4, + max_seq_len: int = 4096, + logging_steps: int = 10, + save_steps: int = 500, # Save less frequently to save disk space + bf16: bool = True, +) -> None: + """Run LoRA SFT on chat-style examples (with images) and save adapters. + + Args: + train_examples: List of dicts, each containing a "messages" list compatible + with the processor's chat template. Image elements should be PIL Images + placed as dicts with {"type": "image", "image": PIL.Image}. 
+ output_dir: Where to save adapters and processor + llm_id: HF model id (e.g., google/gemma-3-4b-pt) + epochs, learning_rate, per_device_train_batch_size, gradient_accumulation_steps, + max_seq_len: Usual training hyperparameters + logging_steps, save_steps, bf16: Trainer settings + """ + if not train_examples: + raise ValueError("train_examples is empty; provide at least one training example") + + os.makedirs(output_dir, exist_ok=True) + + ds = Dataset.from_list(train_examples) + + processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it") + + model = AutoModelForImageTextToText.from_pretrained( + llm_id, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + device_map="auto", + low_cpu_mem_usage=True, + ) + + lora_cfg = LoraConfig( + lora_alpha=8, + lora_dropout=0.05, + r=8, + bias="none", + target_modules="all-linear", + task_type="CAUSAL_LM", + modules_to_save=["lm_head", "embed_tokens"], + ) + + training_args = SFTConfig( + output_dir=output_dir, + num_train_epochs=epochs, + per_device_train_batch_size=per_device_train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + learning_rate=learning_rate, + logging_steps=logging_steps, + save_strategy="steps", + save_steps=save_steps, + bf16=bf16, + report_to=[], + dataset_text_field="", + dataset_kwargs={"skip_prepare_dataset": True}, + max_seq_length=max_seq_len, + packing=False, + remove_unused_columns=False, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + optim="adamw_8bit", + max_grad_norm=0.3, + ) + + def collate_fn(examples: List[dict]): + """Collate chat examples into a batch with masked labels.""" + texts = [] + images = [] + for ex in examples: + msgs = ex["messages"] + text = processor.apply_chat_template( + msgs, add_generation_prompt=False, tokenize=False + ) + texts.append(text.strip()) + images.append(process_vision_info(msgs)) + + batch = processor(text=texts, images=images, return_tensors="pt", padding=True) 
+ + labels = batch["input_ids"].clone() + + pad_token_id = processor.tokenizer.pad_token_id + if pad_token_id is not None: + labels[labels == pad_token_id] = -100 + + special_map = processor.tokenizer.special_tokens_map + boi_id = None + if isinstance(special_map, dict) and "boi_token" in special_map: + boi_id = processor.tokenizer.convert_tokens_to_ids(special_map["boi_token"]) + if boi_id is not None: + labels[labels == boi_id] = -100 + labels[labels == 262144] = -100 + + batch["labels"] = labels + return batch + + trainer = SFTTrainer( + model=model, + peft_config=lora_cfg, + processing_class=processor, + train_dataset=ds, + args=training_args, + data_collator=collate_fn, + ) + + resume_from_checkpoint = None + import glob + import os as os_module + checkpoints = glob.glob(f"{output_dir}/checkpoint-*") + if checkpoints: + latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[-1])) + if os_module.path.exists(os_module.path.join(latest_checkpoint, "trainer_state.json")): + resume_from_checkpoint = latest_checkpoint + print(f"Resuming from checkpoint: {resume_from_checkpoint}") + else: + print(f"Checkpoint {latest_checkpoint} is incomplete, starting from scratch") + + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + trainer.model.save_pretrained(output_dir) + processor.save_pretrained(output_dir) + print(f"Saved LoRA adapters and processor to: {output_dir}") + + del model + del trainer + torch.cuda.empty_cache() + diff --git a/evaluation/baseline/evaluate_ecg_qa.py b/evaluation/baseline/evaluate_ecg_qa.py index 5c657d96..32c06e9b 100644 --- a/evaluation/baseline/evaluate_ecg_qa.py +++ b/evaluation/baseline/evaluate_ecg_qa.py @@ -6,13 +6,56 @@ # SPDX-License-Identifier: MIT # +import ast +import os import re import sys -from typing import Dict, Any, List, Tuple +from typing import Dict, Any, List, Optional, Tuple +import pandas as pd -from common_evaluator import CommonEvaluator -from time_series_datasets.ecg_qa.ECGQACoTQADataset import 
# Add src directory to path for time_series_datasets imports
_SRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "src")
sys.path.insert(0, _SRC_DIR)

# Template answers cache - loaded from CSV on first access; None until loaded
_template_answers_cache: Optional[Dict[int, List[str]]] = None


def _load_template_answers_cache() -> Dict[int, List[str]]:
    """Load template answers from CSV file (cached after the first call).

    Returns:
        Mapping of template_id -> list of possible answer strings.

    Raises:
        FileNotFoundError: If the answers CSV is not present on disk.
    """
    global _template_answers_cache
    if _template_answers_cache is not None:
        return _template_answers_cache

    # Path relative to this file: ../../data/ecg_qa/ecgqa/mimic-iv-ecg/answers_for_each_template.csv
    base_dir = os.path.dirname(os.path.abspath(__file__))
    template_answers_path = os.path.join(
        base_dir, "..", "..", "data", "ecg_qa", "ecgqa", "mimic-iv-ecg", "answers_for_each_template.csv"
    )

    if not os.path.exists(template_answers_path):
        raise FileNotFoundError(
            f"Template answers file not found at {template_answers_path}. "
            "Please ensure the ECG-QA dataset is downloaded."
        )

    template_df = pd.read_csv(template_answers_path)

    # FIX: build into a local dict and publish only on success. Previously the
    # (empty) cache was assigned before the loop, so an uncaught exception
    # mid-load (e.g. a malformed template_id) left a partially populated cache
    # that later calls silently returned.
    loaded: Dict[int, List[str]] = {}
    for _, row in template_df.iterrows():
        template_id = int(row['template_id'])
        answers_str = row['classes']
        try:
            # "classes" holds a Python-literal list of answer strings
            loaded[template_id] = ast.literal_eval(answers_str)
        except Exception as e:
            print(f"Warning: Failed to parse answers for template {template_id}: {e}")
            loaded[template_id] = []

    _template_answers_cache = loaded
    return _template_answers_cache


def get_possible_answers_for_template(template_id: int) -> List[str]:
    """Get possible answers for a specific template ID (empty list if unknown)."""
    return _load_template_answers_cache().get(template_id, [])
using per-template answers from CSV. @@ -68,11 +111,22 @@ def evaluate_ecg_metrics( print(f"DEBUG: Sample content: {sample}") raise ValueError("Missing 'template_id' in sample for ECG-QA evaluation") - possible_answers = ECGQACoTQADataset.get_possible_answers_for_template( - int(template_id) - ) + possible_answers = get_possible_answers_for_template(int(template_id)) if not possible_answers: - raise ValueError(f"No possible answers found for template_id={template_id}") + # Template not found in answers file - return metrics indicating this + return { + "accuracy": 0, + "f1_score": 0.0, + "precision": 0.0, + "recall": 0.0, + "prediction_normalized": pred_norm, + "ground_truth_normalized": gt_norm, + "prediction_supported": False, + "ground_truth_supported": False, + "template_id": template_id, + "possible_answers": [], + "template_missing": True, + } possible_answers_lower = [a.lower().strip() for a in possible_answers] @@ -121,13 +175,16 @@ def _calculate_template_f1_stats(data_points: List[Dict[str, Any]]) -> Dict[str, total_correct = 0 total_f1_sum = 0.0 + skipped_templates = [] for template_id, points in template_groups.items(): if not points: continue possible_answers = points[0].get("possible_answers", []) if not possible_answers: - raise ValueError(f"No possible answers found for template {template_id}") + # Template missing from answers file - skip but track for warning + skipped_templates.append((template_id, len(points))) + continue # Initialize per-class counts class_predictions: Dict[str, Dict[str, int]] = {} @@ -205,6 +262,9 @@ def _calculate_template_f1_stats(data_points: List[Dict[str, Any]]) -> Dict[str, overall_accuracy = total_correct / total_samples if total_samples > 0 else 0.0 overall_avg_f1 = total_f1_sum / total_samples if total_samples > 0 else 0.0 + # Report skipped templates + skipped_samples = sum(count for _, count in skipped_templates) + return { "overall": { "total_samples": total_samples, @@ -212,8 +272,11 @@ def 
_calculate_template_f1_stats(data_points: List[Dict[str, Any]]) -> Dict[str, "accuracy": overall_accuracy, "average_f1": overall_avg_f1, "macro_f1": overall_macro_f1, + "skipped_templates": len(skipped_templates), + "skipped_samples": skipped_samples, }, "per_template": template_stats, + "skipped_template_details": skipped_templates, } @@ -248,6 +311,9 @@ def _build_data_points_from_results( def main(): """Main function to run ECG-QA CoT evaluation with parser-matching F1 aggregation.""" + from common_evaluator import CommonEvaluator + from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset + if len(sys.argv) != 2: print("Usage: python evaluate_ecg_qa.py ") print("Example: python evaluate_ecg_qa.py meta-llama/Llama-3.2-1B") diff --git a/evaluation/baseline/evaluate_ecg_qa_results.py b/evaluation/baseline/evaluate_ecg_qa_results.py new file mode 100644 index 00000000..978179e3 --- /dev/null +++ b/evaluation/baseline/evaluate_ecg_qa_results.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +Evaluate ECG-QA inference results from a CSV file. +Reuses evaluation logic from evaluate_ecg_qa.py. 
+""" +import argparse +import pandas as pd + +from evaluate_ecg_qa import ( + evaluate_ecg_metrics, + _calculate_template_f1_stats, +) + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate ECG-QA inference results from CSV") + parser.add_argument("csv_path", type=str, help="Path to inference results CSV") + parser.add_argument("--verbose", "-v", action="store_true", help="Show per-sample details") + args = parser.parse_args() + + # Load CSV + df = pd.read_csv(args.csv_path) + print(f"Loaded {len(df)} samples from {args.csv_path}\n") + + # Evaluate each sample using imported function + data_points = [] + for _, row in df.iterrows(): + sample = {"template_id": row["template_id"]} + metrics = evaluate_ecg_metrics( + row["target_answer"], + row["generated_answer"], + sample + ) + data_points.append(metrics) + + # Aggregate using imported function + f1_stats = _calculate_template_f1_stats(data_points) + + # Print results + overall = f1_stats.get("overall", {}) + print("=" * 80) + print("EVALUATION RESULTS") + print("=" * 80) + print(f"Total samples evaluated: {overall.get('total_samples', 0)}") + print(f"Accuracy: {overall.get('accuracy', 0):.4f}") + print(f"Average F1 Score: {overall.get('average_f1', 0):.4f}") + print(f"Macro-F1 Score: {overall.get('macro_f1', 0):.4f}") + + # Report skipped templates + skipped_templates = overall.get('skipped_templates', 0) + skipped_samples = overall.get('skipped_samples', 0) + if skipped_templates > 0: + print(f"\nWarning: Skipped {skipped_templates} templates ({skipped_samples} samples) due to missing answers") + skipped_details = f1_stats.get("skipped_template_details", []) + for template_id, count in skipped_details: + print(f" Template {template_id}: {count} samples skipped") + + # Per-template stats + per_template = f1_stats.get("per_template", {}) + if per_template: + print(f"\nPer-Template Statistics:") + for template_id, stats in sorted(per_template.items()): + print(f" Template {template_id}:") + 
print(f" Samples: {stats['num_samples']}") + print(f" Accuracy: {stats['accuracy']:.4f}") + print(f" Macro-F1: {stats['macro_f1']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/baseline/evaluate_har_plot.py b/evaluation/baseline/evaluate_har_plot.py index 05cecd13..e2dd5bf0 100644 --- a/evaluation/baseline/evaluate_har_plot.py +++ b/evaluation/baseline/evaluate_har_plot.py @@ -8,7 +8,6 @@ import re import sys -import argparse import io import base64 from typing import Dict, Any @@ -16,46 +15,29 @@ import matplotlib.pyplot as plt from common_evaluator_plot import CommonEvaluatorPlot -from time_series_datasets.pamap2.PAMAP2AccQADataset import PAMAP2AccQADataset from time_series_datasets.har_cot.HARAccQADataset import HARAccQADataset def extract_label_from_prediction(prediction: str) -> str: - """ - Extract the label from the model's prediction. - - If 'Answer:' is present, take everything after the last 'Answer:' - - Otherwise, take the last word - - Strips whitespace and punctuation - """ + """Extract the label from the model's prediction.""" pred = prediction.strip() - # Find the last occurrence of 'Answer:' (case-insensitive) match = list(re.finditer(r'answer:\s*', pred, re.IGNORECASE)) if match: - # Take everything after the last 'Answer:' - start = match[-1].end() - label = pred[start:].strip() + label = pred[match[-1].end():].strip() else: - # Take the last word label = pred.split()[-1] if pred.split() else '' - # Remove trailing punctuation (e.g., period, comma) label = re.sub(r'[\.,;:!?]+$', '', label) return label.lower() def evaluate_har_acc(ground_truth: str, prediction: str) -> Dict[str, Any]: - """ - Evaluate HARAccQADataset predictions against ground truth. - Extracts the label from the end of the model's output and compares to ground truth. 
- """ + """Evaluate HARAccQADataset predictions against ground truth.""" gt_clean = ground_truth.lower().strip() pred_label = extract_label_from_prediction(prediction) - accuracy = int(gt_clean == pred_label) - return {"accuracy": accuracy} + return {"accuracy": int(gt_clean == pred_label)} def generate_time_series_plot(time_series) -> str: - """ - Create a base64 PNG plot from a list/tuple of 1D numpy arrays (e.g., [x, y, z]). - """ + """Create a base64 PNG plot from accelerometer data [x, y, z].""" if time_series is None: return None ts_list = list(time_series) @@ -69,7 +51,9 @@ def generate_time_series_plot(time_series) -> str: for i, series in enumerate(ts_list): axes[i].plot(series, marker='o', linestyle='-', markersize=0) axes[i].grid(True, alpha=0.3) - axes[i].set_title(f"Accelerometer - {axis_names.get(i, f'Axis {i+1}')}" ) + axes[i].set_title(f"Accelerometer - {axis_names.get(i, f'Axis {i+1}')}") + axes[i].set_ylabel("Acceleration (g)") + axes[-1].set_xlabel("Time (samples)") plt.tight_layout() @@ -103,7 +87,7 @@ def main(): dataset_classes=dataset_classes, evaluation_functions=evaluation_functions, plot_functions=plot_functions, - max_samples=None, # Limit for faster testing, set to None for full evaluation, + max_samples=None, max_new_tokens=400, ) print("\n" + "="*80) diff --git a/evaluation/baseline/evaluate_sleep_plot.py b/evaluation/baseline/evaluate_sleep_plot.py index ca71f0c8..8751598b 100644 --- a/evaluation/baseline/evaluate_sleep_plot.py +++ b/evaluation/baseline/evaluate_sleep_plot.py @@ -13,67 +13,49 @@ from typing import Dict, Any import matplotlib.pyplot as plt -import numpy as np from common_evaluator_plot import CommonEvaluatorPlot from time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset def extract_label_from_text(text: str) -> str: - """ - Extract the label from a free-form rationale or prediction text. 
- - If 'Answer:' is present (case-insensitive), take everything after the last 'Answer:' - - Otherwise, take the last word - - Strip whitespace and trailing punctuation - - Lowercase for comparison - """ + """Extract the label from a prediction or rationale text.""" if text is None: return "" pred = text.strip() matches = list(re.finditer(r"answer:\s*", pred, re.IGNORECASE)) if matches: - start = matches[-1].end() - label = pred[start:].strip() + label = pred[matches[-1].end():].strip() else: label = pred.split()[-1] if pred.split() else "" label = re.sub(r"[\.,;:!?]+$", "", label) return label.lower() -def evaluate_sleep_stage( - ground_truth_text: str, prediction_text: str -) -> Dict[str, Any]: - """ - Evaluate SleepEDFCoTQADataset predictions against ground truth. - For SleepEDF, the dataset's "answer" is a rationale ending with 'Answer: