Enable PaliGemma Training Pipeline #44

Draft
wants to merge 10 commits into develop
9,279 changes: 9,279 additions & 0 deletions cookbooks/maestro_paligemma_object_detection.ipynb

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions maestro/trainer/common/peft.py
@@ -0,0 +1,36 @@
from typing import Literal, Union

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, PaliGemmaForConditionalGeneration

LoraInitLiteral = Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]
# TODO: type `revision` as a RevisionLiteral?


def prepare_peft_model(
model: Union[AutoModelForCausalLM, PaliGemmaForConditionalGeneration],
r: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.05,
bias: Literal["none", "all", "lora_only"] = "none",
inference_mode: bool = False,
use_rslora: bool = True,
init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
revision: str = "",
) -> PeftModel:
config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
task_type="CAUSAL_LM",
lora_dropout=lora_dropout,
bias=bias,
inference_mode=inference_mode,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
revision=revision,
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
return peft_model.to(model.device)
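
For reviewers, a minimal usage sketch of the shared helper. The model id and revision below are the defaults used elsewhere in this PR; loading straight from the Hub here is illustrative, not part of the diff:

```python
# Hedged sketch: wrap a PaliGemma model with LoRA adapters via the shared helper.
import torch
from transformers import PaliGemmaForConditionalGeneration

from maestro.trainer.common.peft import prepare_peft_model

model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-pt-224",
    revision="float16",
    torch_dtype=torch.float16,
)
# Attach LoRA adapters with rank-stabilized scaling and Gaussian init
# (the helper's defaults); prints the trainable-parameter summary.
peft_model = prepare_peft_model(model, r=8, lora_alpha=8, use_rslora=True)
```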
32 changes: 2 additions & 30 deletions maestro/trainer/models/florence_2/core.py
@@ -3,13 +3,14 @@
from typing import Literal, Optional, Union

import torch
from peft import LoraConfig, PeftModel, get_peft_model
from peft import PeftModel
from torch.optim import SGD, Adam, AdamW, Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler

from maestro.trainer.common.peft import LoraInitLiteral, prepare_peft_model
from maestro.trainer.common.utils.file_system import create_new_run_directory
from maestro.trainer.common.utils.metrics import (
BaseMetric,
@@ -33,7 +34,6 @@
process_output_for_detection_metric,
process_output_for_text_metric,
)
from maestro.trainer.models.paligemma.training import LoraInitLiteral


@dataclass(frozen=True)
@@ -174,34 +174,6 @@ def train(config: Configuration) -> None:
print(f"Best checkpoint saved at: {checkpoint_manager.best_checkpoint_dir}")


def prepare_peft_model(
model: AutoModelForCausalLM,
r: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.05,
bias: Literal["none", "all", "lora_only"] = "none",
inference_mode: bool = False,
use_rslora: bool = True,
init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
) -> PeftModel:
config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
task_type="CAUSAL_LM",
lora_dropout=lora_dropout,
bias=bias,
inference_mode=inference_mode,
use_rslora=use_rslora,
init_lora_weights=init_lora_weights,
revision=revision,
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
return peft_model.to(model.device)


def run_training_loop(
processor: AutoProcessor,
model: PeftModel,
1 change: 1 addition & 0 deletions maestro/trainer/models/paligemma/__init__.py
@@ -0,0 +1 @@
from maestro.trainer.models.paligemma.core import Configuration, train
116 changes: 116 additions & 0 deletions maestro/trainer/models/paligemma/checkpoints.py
@@ -0,0 +1,116 @@
import os
from typing import Optional

import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE

DEFAULT_PALIGEMMA_MODEL_ID = "google/paligemma-3b-pt-224"
DEFAULT_PALIGEMMA_MODEL_REVISION = "float16"
DEVICE = torch.device("cpu") if not torch.cuda.is_available() else os.getenv(CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE)


class CheckpointManager:
"""Manages checkpoints for model training.

This class handles saving and retrieving model checkpoints during training.

Attributes:
training_dir (str): Directory where checkpoints will be saved.
best_val_loss (float): Best validation loss achieved so far.
latest_checkpoint_dir (str): Directory for the latest checkpoint.
best_checkpoint_dir (str): Directory for the best checkpoint.
"""

def __init__(self, training_dir: str):
"""Initializes the CheckpointManager.

Args:
training_dir (str): Directory where checkpoints will be saved.
"""
self.training_dir = training_dir
self.best_val_loss = float("inf")
self.latest_checkpoint_dir = os.path.join(training_dir, "checkpoints", "latest")
self.best_checkpoint_dir = os.path.join(training_dir, "checkpoints", "best")

def save_latest(self, processor: AutoProcessor, model: PaliGemmaForConditionalGeneration):
"""Saves the latest model checkpoint.

Args:
processor (AutoProcessor): The processor to save.
model (PaliGemmaForConditionalGeneration): The model to save.
"""
save_model(self.latest_checkpoint_dir, processor, model)

def save_best(self, processor: AutoProcessor, model: PaliGemmaForConditionalGeneration, val_loss: float):
"""Saves the best model checkpoint if the validation loss improves.

Args:
processor (AutoProcessor): The processor to save.
model (PaliGemmaForConditionalGeneration): The model to save.
val_loss (float): The current validation loss.
"""
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
save_model(self.best_checkpoint_dir, processor, model)
print(f"New best model saved with validation loss: {self.best_val_loss}")

def get_best_model_path(self):
"""Returns the path to the best model checkpoint.

Returns:
str: Path to the best model checkpoint.
"""
return self.best_checkpoint_dir


def save_model(
target_dir: str,
processor: AutoProcessor,
model: PaliGemmaForConditionalGeneration,
) -> None:
"""Saves the model and processor to the specified directory.

Args:
target_dir (str): Directory where the model and processor will be saved.
processor (AutoProcessor): The processor to save.
model (PaliGemmaForConditionalGeneration): The model to save.
"""
os.makedirs(target_dir, exist_ok=True)
processor.save_pretrained(target_dir)
model.save_pretrained(target_dir)


def load_model(
model_id_or_path: str = DEFAULT_PALIGEMMA_MODEL_ID,
revision: str = DEFAULT_PALIGEMMA_MODEL_REVISION,
device: torch.device = DEVICE,
cache_dir: Optional[str] = None,
) -> tuple[AutoProcessor, PaliGemmaForConditionalGeneration]:
"""Loads a PaliGemma model and its associated processor.

Args:
model_id_or_path: The identifier or path of the model to load.
revision: The specific model revision to use.
device: The device to load the model onto.
cache_dir: Directory to cache the downloaded model files.

Returns:
A tuple containing the loaded processor and model.

Raises:
ValueError: If the model or processor cannot be loaded.
"""
processor = AutoProcessor.from_pretrained(
model_id_or_path,
trust_remote_code=True,
revision=revision,
)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id_or_path,
trust_remote_code=True,
revision=revision,
cache_dir=cache_dir,
).to(device)
return processor, model
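
A minimal end-to-end sketch of how these pieces compose. The loop body and loss values are placeholders; the real training loop lives in `maestro/trainer/models/paligemma/core.py`, which is not shown in this diff:

```python
# Hedged sketch: load the default PaliGemma checkpoint, then save "latest"
# and (conditionally) "best" checkpoints the way the trainer would.
from maestro.trainer.models.paligemma.checkpoints import (
    CheckpointManager,
    load_model,
)

processor, model = load_model()  # defaults: google/paligemma-3b-pt-224, "float16" revision
manager = CheckpointManager(training_dir="./training/paligemma_run_1")

for epoch, val_loss in enumerate([0.9, 0.7, 0.8], start=1):  # placeholder losses
    manager.save_latest(processor, model)          # always overwrites "latest"
    manager.save_best(processor, model, val_loss)  # only overwrites on improvement

print(f"Best checkpoint: {manager.get_best_model_path()}")
```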