
Enable PaliGemma Training Pipeline #44

Closed
wants to merge 10 commits
first draft
SangbumChoi committed Sep 15, 2024
commit 8ab77b5180940be9e5bbf6158ecbb4e1ce09c1b5
561 changes: 561 additions & 0 deletions cookbooks/maestro_paligemma_object_detection.ipynb

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions maestro/trainer/common/utils/peft.py
@@ -0,0 +1,30 @@
from typing import Literal, Union

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

# DEFAULT_FLORENCE2_MODEL_REVISION is assumed to live in the Florence-2 checkpoints module
from maestro.trainer.models.florence_2.checkpoints import DEFAULT_FLORENCE2_MODEL_REVISION

LoraInitLiteral = Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]

def prepare_peft_model(
model: AutoModelForCausalLM,
r: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.05,
bias: Literal["none", "all", "lora_only"] = "none",
inference_mode: bool = False,
use_rslora: bool = True,
init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
) -> PeftModel:
config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
task_type="CAUSAL_LM",
lora_dropout=lora_dropout,
bias=bias,
inference_mode=inference_mode,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
revision=revision,
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
return peft_model.to(model.device)
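
A minimal usage sketch for this helper (not part of the diff), assuming the PaliGemma loader introduced below is available; the model id and LoRA hyperparameters are illustrative only:

from maestro.trainer.common.utils.peft import prepare_peft_model
from maestro.trainer.models.paligemma.checkpoints import load_model

processor, model = load_model(model_id_or_path="google/paligemma-3b-pt-224")
peft_model = prepare_peft_model(model, r=8, lora_alpha=8, lora_dropout=0.05)
# the helper calls print_trainable_parameters(), which should report that only
# the small LoRA adapter matrices are trainable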
30 changes: 1 addition & 29 deletions maestro/trainer/models/florence_2/core.py
@@ -18,6 +18,7 @@
save_metric_plots,
MeanAveragePrecisionMetric
)
from maestro.trainer.common.utils.peft import prepare_peft_model, LoraInitLiteral
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.florence_2.checkpoints import (
CheckpointManager,
@@ -32,7 +33,6 @@
postprocess_florence2_output_for_mean_average_precision,
run_predictions,
)
from maestro.trainer.models.paligemma.training import LoraInitLiteral


@dataclass(frozen=True)
@@ -159,34 +159,6 @@ def train(config: TrainingConfiguration) -> None:
print(f"Best checkpoint saved at: {checkpoint_manager.best_checkpoint_dir}")


def prepare_peft_model(
model: AutoModelForCausalLM,
r: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.05,
bias: Literal["none", "all", "lora_only"] = "none",
inference_mode: bool = False,
use_rslora: bool = True,
init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
) -> PeftModel:
config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
task_type="CAUSAL_LM",
lora_dropout=lora_dropout,
bias=bias,
inference_mode=inference_mode,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
revision=revision,
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
return peft_model.to(model.device)


def run_training_loop(
processor: AutoProcessor,
model: PeftModel,
1 change: 1 addition & 0 deletions maestro/trainer/models/paligemma/__init__.py
@@ -0,0 +1 @@
from maestro.trainer.models.paligemma.core import TrainingConfiguration, train
119 changes: 119 additions & 0 deletions maestro/trainer/models/paligemma/checkpoints.py
@@ -0,0 +1,119 @@
import os
from typing import Optional, Tuple

import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor

from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, \
    DEFAULT_CUDA_DEVICE

DEFAULT_PALIGEMMA_MODEL_ID = "google/paligemma-3b-pt-224"
DEFAULT_PALIGEMMA_MODEL_REVISION = "float16"
# resolve both branches to torch.device so downstream .to(device) calls see a consistent type
DEVICE = torch.device("cpu") \
    if not torch.cuda.is_available() \
    else torch.device(os.getenv(CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE))


class CheckpointManager:
"""Manages checkpoints for model training.
This class handles saving and retrieving model checkpoints during training.
Attributes:
training_dir (str): Directory where checkpoints will be saved.
best_val_loss (float): Best validation loss achieved so far.
latest_checkpoint_dir (str): Directory for the latest checkpoint.
best_checkpoint_dir (str): Directory for the best checkpoint.
"""

def __init__(self, training_dir: str):
"""Initializes the CheckpointManager.
Args:
training_dir (str): Directory where checkpoints will be saved.
"""
self.training_dir = training_dir
self.best_val_loss = float('inf')
self.latest_checkpoint_dir = os.path.join(training_dir, "checkpoints", "latest")
self.best_checkpoint_dir = os.path.join(training_dir, "checkpoints", "best")

def save_latest(self, processor: AutoProcessor, model: PaliGemmaForConditionalGeneration):
"""Saves the latest model checkpoint.
Args:
processor (AutoProcessor): The processor to save.
model (PaliGemmaForConditionalGeneration): The model to save.
"""
save_model(self.latest_checkpoint_dir, processor, model)

def save_best(self, processor: AutoProcessor, model: PaliGemmaForConditionalGeneration, val_loss: float):
"""Saves the best model checkpoint if the validation loss improves.
Args:
processor (AutoProcessor): The processor to save.
model (PaliGemmaForConditionalGeneration): The model to save.
val_loss (float): The current validation loss.
"""
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
save_model(self.best_checkpoint_dir, processor, model)
print(f"New best model saved with validation loss: {self.best_val_loss}")

def get_best_model_path(self):
"""Returns the path to the best model checkpoint.
Returns:
str: Path to the best model checkpoint.
"""
return self.best_checkpoint_dir


def save_model(
target_dir: str,
processor: AutoProcessor,
model: PaliGemmaForConditionalGeneration,
) -> None:
"""Saves the model and processor to the specified directory.
Args:
target_dir (str): Directory where the model and processor will be saved.
processor (AutoProcessor): The processor to save.
model (PaliGemmaForConditionalGeneration): The model to save.
"""
os.makedirs(target_dir, exist_ok=True)
processor.save_pretrained(target_dir)
model.save_pretrained(target_dir)


def load_model(
model_id_or_path: str = DEFAULT_PALIGEMMA_MODEL_ID,
revision: str = DEFAULT_PALIGEMMA_MODEL_REVISION,
device: torch.device = DEVICE,
cache_dir: Optional[str] = None,
) -> Tuple[AutoProcessor, PaliGemmaForConditionalGeneration]:
"""Loads a PaliGemma model and its associated processor.
Args:
model_id_or_path: The identifier or path of the model to load.
revision: The specific model revision to use.
device: The device to load the model onto.
cache_dir: Directory to cache the downloaded model files.
Returns:
A tuple containing the loaded processor and model.
"""
processor = AutoProcessor.from_pretrained(
    model_id_or_path,
    trust_remote_code=True,
    revision=revision,
    cache_dir=cache_dir,
)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id_or_path,
trust_remote_code=True,
revision=revision,
cache_dir=cache_dir,
).to(device)
return processor, model
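
A short sketch (not from the diff) of how the loader and CheckpointManager compose; the validation losses are invented for illustration:

processor, model = load_model()  # defaults: google/paligemma-3b-pt-224, float16 revision
manager = CheckpointManager(training_dir="./training/paligemma/run_1")
manager.save_latest(processor, model)                # always overwrites checkpoints/latest
manager.save_best(processor, model, val_loss=0.42)   # improves on float('inf'), so it saves
manager.save_best(processor, model, val_loss=0.57)   # worse loss, so nothing is written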
395 changes: 395 additions & 0 deletions maestro/trainer/models/paligemma/core.py
@@ -0,0 +1,395 @@
import os
from dataclasses import dataclass, field, replace
from typing import List, Literal, Optional, Tuple, Union

import torch
from peft import LoraConfig, PeftModel, get_peft_model
from torch.optim import SGD, Adam, AdamW, Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler

from maestro.trainer.common.utils.file_system import create_new_run_directory
# display_results is assumed to live alongside the other shared metric utilities
from maestro.trainer.common.utils.metrics import (
    BaseMetric, MeanAveragePrecisionMetric, MetricsTracker, display_results, save_metric_plots
)
from maestro.trainer.common.utils.peft import LoraInitLiteral
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.paligemma.checkpoints import (
    CheckpointManager, load_model, DEFAULT_PALIGEMMA_MODEL_ID, DEFAULT_PALIGEMMA_MODEL_REVISION, DEVICE
)
from maestro.trainer.models.paligemma.data_loading import prepare_data_loaders
from maestro.trainer.models.paligemma.metrics import (
    extract_unique_detection_dataset_classes, postprocess_florence2_output_for_mean_average_precision, run_predictions
)

@dataclass(frozen=True)
class TrainingConfiguration:
"""Configuration for training a PaliGemma model.
This class encapsulates all the parameters needed for training a PaliGemma model,
including dataset paths, model specifications, training hyperparameters, and output
settings.
Attributes:
dataset (str): Path to the dataset used for training.
model_id (str): Identifier for the PaliGemma model.
revision (str): Revision of the model to use.
device (torch.device): Device to use for training.
cache_dir (Optional[str]): Directory to cache the model.
epochs (int): Number of training epochs.
optimizer (Literal["sgd", "adamw", "adam"]): Optimizer to use for training.
lr (float): Learning rate for the optimizer.
lr_scheduler (Literal["linear", "cosine", "polynomial"]): Learning rate
scheduler.
batch_size (int): Batch size for training.
val_batch_size (Optional[int]): Batch size for validation.
num_workers (int): Number of workers for data loading.
val_num_workers (Optional[int]): Number of workers for validation data loading.
lora_r (int): Rank of the LoRA update matrices.
lora_alpha (int): Scaling factor for the LoRA update.
lora_dropout (float): Dropout probability for LoRA layers.
bias (Literal["none", "all", "lora_only"]): Which bias to train.
use_rslora (bool): Whether to use RSLoRA.
init_lora_weights (Union[bool, LoraInitLiteral]): How to initialize LoRA
weights.
output_dir (str): Directory to save output files.
metrics (List[BaseMetric]): List of metrics to track during training.
"""
dataset: str
model_id: str = DEFAULT_PALIGEMMA_MODEL_ID
revision: str = DEFAULT_PALIGEMMA_MODEL_REVISION
device: torch.device = DEVICE
cache_dir: Optional[str] = None
epochs: int = 10
optimizer: Literal["sgd", "adamw", "adam"] = "adamw"
lr: float = 1e-5
lr_scheduler: Literal["linear", "cosine", "polynomial"] = "linear"
batch_size: int = 4
val_batch_size: Optional[int] = None
num_workers: int = 0
val_num_workers: Optional[int] = None
lora_r: int = 8
lora_alpha: int = 8
lora_dropout: float = 0.05
bias: Literal["none", "all", "lora_only"] = "none"
use_rslora: bool = True
init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian"
output_dir: str = "./training/florence-2"
metrics: List[BaseMetric] = field(default_factory=list)

def train(config: TrainingConfiguration) -> None:
make_it_reproducible(avoid_non_deterministic_algorithms=False)
run_dir = create_new_run_directory(
base_output_dir=config.output_dir,
)
config = replace(
config,
output_dir=run_dir,
)
checkpoint_manager = CheckpointManager(run_dir)

processor, model = load_model(
model_id_or_path=config.model_id,
revision=config.revision,
device=config.device,
cache_dir=config.cache_dir,
)
train_loader, val_loader, test_loader = prepare_data_loaders(
dataset_location=config.dataset,
train_batch_size=config.batch_size,
processor=processor,
device=config.device,
num_workers=config.num_workers,
test_loaders_workers=config.val_num_workers,
)
peft_model = prepare_peft_model(
model=model,
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
bias=config.bias,
use_rslora=config.use_rslora,
init_lora_weights=config.init_lora_weights,
revision=config.revision,
)
training_metrics_tracker = MetricsTracker.init(metrics=["loss"])
metrics = ["loss"]
for metric in config.metrics:
metrics += metric.describe()
validation_metrics_tracker = MetricsTracker.init(metrics=metrics)

run_training_loop(
processor=processor,
model=peft_model,
data_loaders=(train_loader, val_loader),
config=config,
training_metrics_tracker=training_metrics_tracker,
validation_metrics_tracker=validation_metrics_tracker,
checkpoint_manager=checkpoint_manager
)

save_metric_plots(
training_tracker=training_metrics_tracker,
validation_tracker=validation_metrics_tracker,
output_dir=os.path.join(config.output_dir, "metrics"),
)
training_metrics_tracker.as_json(
output_dir=os.path.join(config.output_dir, "metrics"),
filename="training.json")
validation_metrics_tracker.as_json(
output_dir=os.path.join(config.output_dir, "metrics"),
filename="validation.json")

# Log out paths for latest and best checkpoints
print(f"Latest checkpoint saved at: {checkpoint_manager.latest_checkpoint_dir}")
print(f"Best checkpoint saved at: {checkpoint_manager.best_checkpoint_dir}")


def prepare_peft_model(
model: AutoModelForCausalLM,
r: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.05,
bias: Literal["none", "all", "lora_only"] = "none",
inference_mode: bool = False,
use_rslora: bool = True,
init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
revision: str = DEFAULT_PALIGEMMA_MODEL_REVISION,
) -> PeftModel:
config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
task_type="CAUSAL_LM",
lora_dropout=lora_dropout,
bias=bias,
inference_mode=inference_mode,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
revision=revision,
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
return peft_model.to(model.device)


def run_training_loop(
processor: AutoProcessor,
model: PeftModel,
data_loaders: Tuple[DataLoader, Optional[DataLoader]],
config: TrainingConfiguration,
training_metrics_tracker: MetricsTracker,
validation_metrics_tracker: MetricsTracker,
checkpoint_manager: CheckpointManager,
) -> None:
train_loader, val_loader = data_loaders
optimizer = get_optimizer(model=model, config=config)
total_steps = config.epochs * len(train_loader)
lr_scheduler = get_scheduler(
name=config.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=total_steps,
)
for epoch in range(config.epochs):
run_training_epoch(
processor=processor,
model=model,
train_loader=train_loader,
val_loader=val_loader,
epoch=epoch + 1,
config=config,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
training_metrics_tracker=training_metrics_tracker,
validation_metrics_tracker=validation_metrics_tracker,
checkpoint_manager=checkpoint_manager
)


def run_training_epoch(
processor: AutoProcessor,
model: PeftModel,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
epoch: int,
config: TrainingConfiguration,
optimizer: Optimizer,
lr_scheduler: LRScheduler,
training_metrics_tracker: MetricsTracker,
validation_metrics_tracker: MetricsTracker,
checkpoint_manager: CheckpointManager,
) -> None:
model.train()
training_losses: List[float] = []

with tqdm(total=len(train_loader), desc=f"Epoch {epoch}/{config.epochs}", unit="batch") as pbar:
for step_id, (inputs, answers) in enumerate(train_loader):
input_ids = inputs["input_ids"]
pixel_values = inputs["pixel_values"]
labels = processor.tokenizer(
text=answers,
return_tensors="pt",
padding=True,
return_token_type_ids=False
).input_ids.to(config.device)
outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
loss = loss.item()
training_metrics_tracker.register(
metric="loss",
epoch=epoch,
step=step_id + 1,
value=loss,
)
training_losses.append(loss)

# Update progress bar
last_100_losses = training_losses[-100:]
loss_moving_average = sum(last_100_losses) / len(last_100_losses) if last_100_losses else 0.0
pbar.set_postfix({"Loss": f"{loss_moving_average:.4f}"})
pbar.update(1)

# Save checkpoints based on training loss if no validation loader
if val_loader is None or len(val_loader) == 0:
train_loss = sum(training_losses) / len(training_losses)
checkpoint_manager.save_latest(processor, model)
checkpoint_manager.save_best(processor, model, train_loss)
return

run_validation_epoch(
processor=processor,
model=model,
loader=val_loader,
epoch_number=epoch,
config=config,
metrics_tracker=validation_metrics_tracker,
)

val_loss = validation_metrics_tracker.get_metric_values("loss")[-1][2]
checkpoint_manager.save_latest(processor, model)
checkpoint_manager.save_best(processor, model, val_loss)


def run_validation_epoch(
processor: AutoProcessor,
model: Union[PeftModel, AutoModelForCausalLM],
loader: DataLoader,
config: TrainingConfiguration,
metrics_tracker: MetricsTracker,
epoch_number: int
) -> None:
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in loader:
input_ids = inputs["input_ids"]
pixel_values = inputs["pixel_values"]
labels = processor.tokenizer(
text=targets,
return_tensors="pt",
padding=True,
return_token_type_ids=False
).input_ids.to(config.device)
outputs = model(
input_ids=input_ids,
pixel_values=pixel_values,
labels=labels
)
loss = outputs.loss
val_loss += loss.item()
avg_val_loss = val_loss / len(loader)
metrics_tracker.register(
metric="loss",
epoch=epoch_number,
step=1,
value=avg_val_loss,
)
# Run inference once for all metrics
prompts, expected_responses, generated_texts, images = run_predictions(
dataset=loader.dataset,
processor=processor,
model=model,
device=config.device,
)

metrics_results = {"loss": avg_val_loss}

for metric in config.metrics:
if isinstance(metric, MeanAveragePrecisionMetric):
classes = extract_unique_detection_dataset_classes(loader.dataset)
targets, predictions = postprocess_florence2_output_for_mean_average_precision(
expected_responses=expected_responses,
generated_texts=generated_texts,
images=images,
classes=classes,
processor=processor
)
result = metric.compute(targets=targets, predictions=predictions)
for key, value in result.items():
metrics_tracker.register(
metric=key,
epoch=epoch_number,
step=1,
value=value,
)
metrics_results[key] = value

print("Validation Metrics:", ", ".join([f"{k}: {v:.4f}" for k, v in metrics_results.items()]))

# Display inference results in IPython environments
display_results(prompts, expected_responses, generated_texts, images)


def get_optimizer(model: PeftModel, config: TrainingConfiguration) -> Optimizer:
optimizer_type = config.optimizer.lower()
if optimizer_type == "adamw":
return AdamW(model.parameters(), lr=config.lr)
if optimizer_type == "adam":
return Adam(model.parameters(), lr=config.lr)
if optimizer_type == "sgd":
return SGD(model.parameters(), lr=config.lr)
raise ValueError(f"Unsupported optimizer: {config.optimizer}")


def evaluate(config: TrainingConfiguration) -> None:
processor, model = load_model(
model_id_or_path=config.model_id,
revision=config.revision,
device=config.device,
cache_dir=config.cache_dir,
)
train_loader, val_loader, test_loader = prepare_data_loaders(
dataset_location=config.dataset,
train_batch_size=config.batch_size,
processor=processor,
device=config.device,
num_workers=config.num_workers,
test_loaders_workers=config.val_num_workers,
)
evaluation_loader = test_loader if test_loader is not None else val_loader

metrics = []
for metric in config.metrics:
metrics += metric.describe()
evaluation_metrics_tracker = MetricsTracker.init(metrics=metrics)

# Run inference once for all metrics
_, expected_responses, generated_texts, images = run_predictions(
dataset=evaluation_loader.dataset,
processor=processor,
model=model,
device=config.device,
)

for metric in config.metrics:
if isinstance(metric, MeanAveragePrecisionMetric):
classes = extract_unique_detection_dataset_classes(train_loader.dataset)
targets, predictions = postprocess_florence2_output_for_mean_average_precision(
expected_responses=expected_responses,
generated_texts=generated_texts,
images=images,
classes=classes,
processor=processor
)
result = metric.compute(targets=targets, predictions=predictions)
for key, value in result.items():
evaluation_metrics_tracker.register(
metric=key,
epoch=1,
step=1,
value=value,
)

evaluation_metrics_tracker.as_json(
output_dir=os.path.join(config.output_dir, "metrics"),
filename="evaluation.json")
110 changes: 110 additions & 0 deletions maestro/trainer/models/paligemma/data_loading.py
@@ -0,0 +1,110 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple, List

import torch
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor, BatchEncoding

from maestro.trainer.common.data_loaders.datasets import DetectionDataset


def prepare_data_loaders(
dataset_location: str,
train_batch_size: int,
processor: AutoProcessor,
device: torch.device,
num_workers: int = 0,
test_batch_size: Optional[int] = None,
test_loaders_workers: Optional[int] = None,
) -> Tuple[
DataLoader,
Optional[DataLoader],
Optional[DataLoader],
]:
test_batch_size = test_batch_size or train_batch_size
test_loaders_workers = test_loaders_workers or num_workers
train_data_loader = prepare_detection_data_loader(
dataset_location=dataset_location,
split_name="train",
batch_size=train_batch_size,
processor=processor,
device=device,
num_workers=num_workers,
shuffle=True,
)
if train_data_loader is None:
raise RuntimeError("Could not initialise train data loader")
valid_data_loader = prepare_detection_data_loader(
dataset_location=dataset_location,
split_name="valid",
batch_size=test_batch_size,
processor=processor,
device=device,
num_workers=test_loaders_workers,
shuffle=False,
)
test_data_loader = prepare_detection_data_loader(
dataset_location=dataset_location,
split_name="test",
batch_size=test_batch_size,
processor=processor,
device=device,
num_workers=test_loaders_workers,
shuffle=False,
)
return train_data_loader, valid_data_loader, test_data_loader


def prepare_detection_data_loader(
dataset_location: str,
split_name: str,
batch_size: int,
processor: AutoProcessor,
device: torch.device,
num_workers: int = 0,
shuffle: bool = True,
) -> Optional[DataLoader]:
dataset = prepare_detection_dataset(
dataset_location=dataset_location,
split_name=split_name,
)
if dataset is None:
return None
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=partial(collate_fn, processor=processor, device=device),
num_workers=num_workers,
shuffle=shuffle,
)


def prepare_detection_dataset(
dataset_location: str,
split_name: str,
) -> Optional[DetectionDataset]:
image_directory_path = os.path.join(dataset_location, split_name)
jsonl_file_path = os.path.join(dataset_location, split_name, "annotations.jsonl")
if not os.path.exists(image_directory_path):
logging.warning(f"Could not data directory: {image_directory_path}")
return None
if not os.path.exists(jsonl_file_path):
logging.warning(f"Could not find JSONL file: {jsonl_file_path}")
return None
return DetectionDataset(
jsonl_file_path=jsonl_file_path,
image_directory_path=image_directory_path,
)


def collate_fn(
batch: List[Tuple[str, str, Image.Image]],
processor: AutoProcessor,
device: torch.device,
) -> Tuple[BatchEncoding, Tuple[str, ...]]:
questions, answers, images = zip(*batch)
inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
return inputs, answers
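
A usage sketch (not part of the diff), assuming the layout these loaders expect — dataset_location/{train,valid,test}/ with an annotations.jsonl per split:

import torch
from transformers import AutoProcessor

from maestro.trainer.models.paligemma.data_loading import prepare_data_loaders

processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
train_loader, val_loader, test_loader = prepare_data_loaders(
    dataset_location="./datasets/my-detection-dataset",
    train_batch_size=4,
    processor=processor,
    device=torch.device("cpu"),
)
inputs, answers = next(iter(train_loader))  # inputs carries input_ids and pixel_values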
216 changes: 211 additions & 5 deletions maestro/trainer/models/paligemma/entrypoint.py
@@ -1,13 +1,219 @@
import dataclasses
from typing import Optional, Annotated, List, Dict, Type

import rich
import torch
import typer

from maestro.trainer.models.paligemma.checkpoints import DEFAULT_PALIGEMMA_MODEL_ID, \
DEFAULT_PALIGEMMA_MODEL_REVISION, DEVICE
from maestro.trainer.models.paligemma.core import TrainingConfiguration
from maestro.trainer.models.paligemma.core import train as paligemma_train
from maestro.trainer.models.paligemma.core import evaluate as paligemma_evaluate
from maestro.trainer.common.utils.metrics import BaseMetric, MeanAveragePrecisionMetric

paligemma_app = typer.Typer(help="Fine-tune and evaluate PaliGemma model")


@paligemma_app.command(help="Train PaliGemma model")
def train() -> None:
typer.echo("🚧 Just a placeholder - to be implemented 🚧")
METRIC_CLASSES: Dict[str, Type[BaseMetric]] = {
"mean_average_precision": MeanAveragePrecisionMetric,
}


def parse_metrics(metrics: List[str]) -> List[BaseMetric]:
metric_objects = []
for metric_name in metrics:
metric_class = METRIC_CLASSES.get(metric_name.lower())
if metric_class:
metric_objects.append(metric_class())
else:
raise ValueError(f"Unsupported metric: {metric_name}")
return metric_objects


@paligemma_app.command(
help="Train PaliGemma model",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
def train(
dataset: Annotated[
str,
typer.Option("--dataset", help="Path to the dataset used for training"),
],
model_id: Annotated[
str,
typer.Option("--model_id", help="Identifier for the PaliGemma model"),
] = DEFAULT_PALIGEMMA_MODEL_ID,
revision: Annotated[
str,
typer.Option("--revision", help="Revision of the model to use"),
] = DEFAULT_PALIGEMMA_MODEL_REVISION,
device: Annotated[
str,
typer.Option("--device", help="Device to use for training"),
] = DEVICE,
cache_dir: Annotated[
Optional[str],
typer.Option("--cache_dir", help="Directory to cache the model"),
] = None,
epochs: Annotated[
int,
typer.Option("--epochs", help="Number of training epochs"),
] = 10,
optimizer: Annotated[
str,
typer.Option("--optimizer", help="Optimizer to use for training"),
] = "adamw",
lr: Annotated[
float,
typer.Option("--lr", help="Learning rate for the optimizer"),
] = 1e-5,
lr_scheduler: Annotated[
str,
typer.Option("--lr_scheduler", help="Learning rate scheduler"),
] = "linear",
batch_size: Annotated[
int,
typer.Option("--batch_size", help="Batch size for training"),
] = 4,
val_batch_size: Annotated[
Optional[int],
typer.Option("--val_batch_size", help="Batch size for validation"),
] = None,
num_workers: Annotated[
int,
typer.Option("--num_workers", help="Number of workers for data loading"),
] = 0,
val_num_workers: Annotated[
Optional[int],
typer.Option("--val_num_workers", help="Number of workers for validation data loading"),
] = None,
lora_r: Annotated[
int,
typer.Option("--lora_r", help="Rank of the LoRA update matrices"),
] = 8,
lora_alpha: Annotated[
int,
typer.Option("--lora_alpha", help="Scaling factor for the LoRA update"),
] = 8,
lora_dropout: Annotated[
float,
typer.Option("--lora_dropout", help="Dropout probability for LoRA layers"),
] = 0.05,
bias: Annotated[
str,
typer.Option("--bias", help="Which bias to train"),
] = "none",
use_rslora: Annotated[
bool,
typer.Option("--use_rslora/--no_use_rslora", help="Whether to use RSLoRA"),
] = True,
init_lora_weights: Annotated[
str,
typer.Option("--init_lora_weights", help="How to initialize LoRA weights"),
] = "gaussian",
output_dir: Annotated[
str,
typer.Option("--output_dir", help="Directory to save output files"),
] = "./training/paligemma",
metrics: Annotated[
List[str],
typer.Option("--metrics", help="List of metrics to track during training"),
] = [],
) -> None:
metric_objects = parse_metrics(metrics)
config = TrainingConfiguration(
dataset=dataset,
model_id=model_id,
revision=revision,
device=torch.device(device),
cache_dir=cache_dir,
epochs=epochs,
optimizer=optimizer,
lr=lr,
lr_scheduler=lr_scheduler,
batch_size=batch_size,
val_batch_size=val_batch_size,
num_workers=num_workers,
val_num_workers=val_num_workers,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias=bias,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
output_dir=output_dir,
metrics=metric_objects
)
typer.echo(typer.style(
text="Training configuration",
fg=typer.colors.BRIGHT_GREEN,
bold=True
))
rich.print(dataclasses.asdict(config))
paligemma_train(config=config)


@paligemma_app.command(help="Evaluate PaliGemma model")
def evaluate() -> None:
typer.echo("🚧 Just a placeholder - to be implemented 🚧")
def evaluate(
dataset: Annotated[
str,
typer.Option("--dataset", help="Path to the dataset used for evaluation"),
],
model_id: Annotated[
str,
typer.Option("--model_id", help="Identifier for the PaliGemma model"),
] = DEFAULT_PALIGEMMA_MODEL_ID,
revision: Annotated[
str,
typer.Option("--revision", help="Revision of the model to use"),
] = DEFAULT_PALIGEMMA_MODEL_REVISION,
device: Annotated[
str,
typer.Option("--device", help="Device to use for evaluation"),
] = DEVICE,
cache_dir: Annotated[
Optional[str],
typer.Option("--cache_dir", help="Directory to cache the model"),
] = None,
batch_size: Annotated[
int,
typer.Option("--batch_size", help="Batch size for evaluation"),
] = 4,
num_workers: Annotated[
int,
typer.Option("--num_workers", help="Number of workers for data loading"),
] = 0,
val_num_workers: Annotated[
Optional[int],
typer.Option("--val_num_workers", help="Number of workers for validation data loading"),
] = None,
output_dir: Annotated[
str,
typer.Option("--output_dir", help="Directory to save output files"),
] = "./evaluation/paligemma",
metrics: Annotated[
List[str],
typer.Option("--metrics", help="List of metrics to track during evaluation"),
] = [],
) -> None:
metric_objects = parse_metrics(metrics)
config = TrainingConfiguration(
dataset=dataset,
model_id=model_id,
revision=revision,
device=torch.device(device),
cache_dir=cache_dir,
batch_size=batch_size,
num_workers=num_workers,
val_num_workers=val_num_workers,
output_dir=output_dir,
metrics=metric_objects
)
typer.echo(typer.style(
text="Evaluation configuration",
fg=typer.colors.BRIGHT_GREEN,
bold=True
))
rich.print(dataclasses.asdict(config))
paligemma_evaluate(config=config)
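
Assuming paligemma_app is mounted as a `paligemma` sub-command of the maestro CLI (the mounting is not shown in this diff), the commands above would be invoked roughly as:

maestro paligemma train --dataset ./datasets/my-detection-dataset --epochs 5 --metrics mean_average_precision
maestro paligemma evaluate --dataset ./datasets/my-detection-dataset --metrics mean_average_precision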
87 changes: 87 additions & 0 deletions maestro/trainer/models/paligemma/metrics.py
@@ -0,0 +1,87 @@
import re
from typing import List
from typing import Tuple

import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM

from maestro.trainer.common.data_loaders.datasets import DetectionDataset

DETECTION_CLASS_PATTERN = r"([a-zA-Z0-9 -]+)<loc_\d+>"


def postprocess_florence2_output_for_mean_average_precision(
expected_responses: List[str],
generated_texts: List[str],
images: List[Image.Image],
classes: List[str],
processor: AutoProcessor
) -> Tuple[List[sv.Detections], List[sv.Detections]]:
targets = []
predictions = []

for image, suffix, generated_text in zip(images, expected_responses, generated_texts):
# Postprocess prediction for mean average precision calculation
prediction = processor.post_process_generation(generated_text, task="<OD>", image_size=image.size)
prediction = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, prediction, resolution_wh=image.size)
if len(prediction) == 0:
prediction["class_name"] = []
prediction = prediction[np.isin(prediction["class_name"], classes)]
prediction.class_id = np.array([classes.index(class_name) for class_name in prediction["class_name"]])
# Set confidence for mean average precision calculation
prediction.confidence = np.ones(len(prediction))

# Postprocess target for mean average precision calculation
target = processor.post_process_generation(suffix, task="<OD>", image_size=image.size)
if len(target) == 0:
target["class_name"] = []
target = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, target, resolution_wh=image.size)
target.class_id = np.array([classes.index(class_name) for class_name in target["class_name"]])

targets.append(target)
predictions.append(prediction)

return targets, predictions


def run_predictions(
dataset: DetectionDataset,
processor: AutoProcessor,
model: AutoModelForCausalLM,
device: torch.device,
) -> Tuple[List[str], List[str], List[str], List[Image.Image]]:
prompts = []
expected_responses = []
generated_texts = []
images = []

for idx in tqdm(list(range(len(dataset))), desc="Generating predictions..."):
image, data = dataset.dataset[idx]
prefix = data["prefix"]
suffix = data["suffix"]

inputs = processor(text=prefix, images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

prompts.append(prefix)
expected_responses.append(suffix)
generated_texts.append(generated_text)
images.append(image)

return prompts, expected_responses, generated_texts, images


def extract_unique_detection_dataset_classes(dataset: DetectionDataset) -> List[str]:
class_set = set()
for i in range(len(dataset)):
_, suffix, _ = dataset[i]
classes = re.findall(DETECTION_CLASS_PATTERN, suffix)
class_set.update(classes)
return sorted(class_set)
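
A small demonstration (invented example) of what the class-name pattern extracts from a Florence-2-style suffix; digit runs inside the <loc_...> tokens never match, because '_' and '<' fall outside the character class:

import re

suffix = "car<loc_12><loc_34><loc_56><loc_78>person<loc_90><loc_91><loc_92><loc_93>"
print(re.findall(r"([a-zA-Z0-9 -]+)<loc_\d+>", suffix))  # ['car', 'person']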