Enable PaliGemma Training Pipeline #44

Closed · wants to merge 10 commits

Commit: first draft
SangbumChoi committed Sep 15, 2024

commit 8ab77b5180940be9e5bbf6158ecbb4e1ce09c1b5
561 changes: 561 additions & 0 deletions cookbooks/maestro_paligemma_object_detection.ipynb

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions maestro/trainer/common/utils/peft.py
@@ -0,0 +1,38 @@
from typing import Literal, Union

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

# NOTE: assumed import; the diff references this constant without importing it.
from maestro.trainer.models.florence_2.checkpoints import DEFAULT_FLORENCE2_MODEL_REVISION

LoraInitLiteral = Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]


def prepare_peft_model(
    model: AutoModelForCausalLM,
    r: int = 8,
    lora_alpha: int = 8,
    lora_dropout: float = 0.05,
    bias: Literal["none", "all", "lora_only"] = "none",
    inference_mode: bool = False,
    use_rslora: bool = True,
    init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
    revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
) -> PeftModel:
    """Wraps the base model with a LoRA adapter and prints its trainable parameter count."""
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
        task_type="CAUSAL_LM",
        lora_dropout=lora_dropout,
        bias=bias,
        inference_mode=inference_mode,
        use_rslora=use_rslora,
        init_lora_weights=init_lora_weights,
        revision=revision,
    )
    peft_model = get_peft_model(model, config)
    peft_model.print_trainable_parameters()
    return peft_model.to(model.device)
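
For context in review: a minimal usage sketch of the new helper, not part of this PR's diff. It assumes the PaliGemma loader added later in this PR (maestro/trainer/models/paligemma/checkpoints.py); the parameter values are illustrative.

# Hypothetical usage sketch (not part of the diff).
from maestro.trainer.common.utils.peft import prepare_peft_model
from maestro.trainer.models.paligemma.checkpoints import load_model

processor, model = load_model()  # defaults to google/paligemma-3b-pt-224, "float16" revision
peft_model = prepare_peft_model(
    model,
    r=8,                 # LoRA rank
    lora_alpha=8,        # scaling; with use_rslora=True the effective scale is alpha / sqrt(r)
    lora_dropout=0.05,
    revision="float16",  # override the Florence-2 default to match the loaded base checkpoint
)
# prepare_peft_model prints the trainable-parameter summary; the returned
# PeftModel replaces the base model in the training loop.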
30 changes: 1 addition & 29 deletions maestro/trainer/models/florence_2/core.py
@@ -18,6 +18,7 @@
    save_metric_plots,
    MeanAveragePrecisionMetric
)
from maestro.trainer.common.utils.peft import prepare_peft_model, LoraInitLiteral
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.florence_2.checkpoints import (
CheckpointManager,
@@ -32,7 +33,6 @@
    postprocess_florence2_output_for_mean_average_precision,
    run_predictions,
)
from maestro.trainer.models.paligemma.training import LoraInitLiteral


@dataclass(frozen=True)
@@ -159,34 +159,6 @@ def train(config: TrainingConfiguration) -> None:
    print(f"Best checkpoint saved at: {checkpoint_manager.best_checkpoint_dir}")


def prepare_peft_model(
    model: AutoModelForCausalLM,
    r: int = 8,
    lora_alpha: int = 8,
    lora_dropout: float = 0.05,
    bias: Literal["none", "all", "lora_only"] = "none",
    inference_mode: bool = False,
    use_rslora: bool = True,
    init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian",
    revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
) -> PeftModel:
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
        task_type="CAUSAL_LM",
        lora_dropout=lora_dropout,
        bias=bias,
        inference_mode=inference_mode,
        use_rslora=use_rslora,
        init_lora_weights=init_lora_weights,
        revision=revision,
    )
    peft_model = get_peft_model(model, config)
    peft_model.print_trainable_parameters()
    return peft_model.to(model.device)


def run_training_loop(
    processor: AutoProcessor,
    model: PeftModel,
1 change: 1 addition & 0 deletions maestro/trainer/models/paligemma/__init__.py
@@ -0,0 +1 @@
from maestro.trainer.models.paligemma.core import TrainingConfiguration, train
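
This export makes the trainer callable from the package root. A hedged sketch of the intended entry point follows; TrainingConfiguration is defined in paligemma/core.py, which is not shown in this excerpt, so the field names below are assumptions, not the PR's API.

# Hypothetical entry point; field names are illustrative placeholders.
from maestro.trainer.models.paligemma import TrainingConfiguration, train

config = TrainingConfiguration(
    dataset_location="path/to/dataset",  # assumed field name
    epochs=10,                           # assumed field name
)
train(config)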
119 changes: 119 additions & 0 deletions maestro/trainer/models/paligemma/checkpoints.py
@@ -0,0 +1,119 @@
import os
from typing import Optional, Tuple

import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor

from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE

DEFAULT_PALIGEMMA_MODEL_ID = "google/paligemma-3b-pt-224"
DEFAULT_PALIGEMMA_MODEL_REVISION = "float16"
# Wrapping in torch.device keeps the type consistent whether the value comes
# from the environment variable or the CPU fallback.
DEVICE = torch.device(
    os.getenv(CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE)
    if torch.cuda.is_available()
    else "cpu"
)


class CheckpointManager:
    """Manages checkpoints for model training.
    This class handles saving and retrieving model checkpoints during training.
    Attributes:
        training_dir (str): Directory where checkpoints will be saved.
        best_val_loss (float): Best validation loss achieved so far.
        latest_checkpoint_dir (str): Directory for the latest checkpoint.
        best_checkpoint_dir (str): Directory for the best checkpoint.
    """

    def __init__(self, training_dir: str):
        """Initializes the CheckpointManager.
        Args:
            training_dir (str): Directory where checkpoints will be saved.
        """
        self.training_dir = training_dir
        self.best_val_loss = float('inf')
        self.latest_checkpoint_dir = os.path.join(training_dir, "checkpoints", "latest")
        self.best_checkpoint_dir = os.path.join(training_dir, "checkpoints", "best")

    def save_latest(self, processor: AutoProcessor, model: PaliGemmaForConditionalGeneration):
        """Saves the latest model checkpoint.
        Args:
            processor (AutoProcessor): The processor to save.
            model (PaliGemmaForConditionalGeneration): The model to save.
        """
        save_model(self.latest_checkpoint_dir, processor, model)

    def save_best(self, processor: AutoProcessor, model: PaliGemmaForConditionalGeneration, val_loss: float):
        """Saves the best model checkpoint if the validation loss improves.
        Args:
            processor (AutoProcessor): The processor to save.
            model (PaliGemmaForConditionalGeneration): The model to save.
            val_loss (float): The current validation loss.
        """
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            save_model(self.best_checkpoint_dir, processor, model)
            print(f"New best model saved with validation loss: {self.best_val_loss}")

    def get_best_model_path(self):
        """Returns the path to the best model checkpoint.
        Returns:
            str: Path to the best model checkpoint.
        """
        return self.best_checkpoint_dir


def save_model(
    target_dir: str,
    processor: AutoProcessor,
    model: PaliGemmaForConditionalGeneration,
) -> None:
    """Saves the model and processor to the specified directory.
    Args:
        target_dir (str): Directory where the model and processor will be saved.
        processor (AutoProcessor): The processor to save.
        model (PaliGemmaForConditionalGeneration): The model to save.
    """
    os.makedirs(target_dir, exist_ok=True)
    processor.save_pretrained(target_dir)
    model.save_pretrained(target_dir)


def load_model(
    model_id_or_path: str = DEFAULT_PALIGEMMA_MODEL_ID,
    revision: str = DEFAULT_PALIGEMMA_MODEL_REVISION,
    device: torch.device = DEVICE,
    cache_dir: Optional[str] = None,
) -> Tuple[AutoProcessor, PaliGemmaForConditionalGeneration]:
    """Loads a PaliGemma model and its associated processor.
    Args:
        model_id_or_path: The identifier or path of the model to load.
        revision: The specific model revision to use.
        device: The device to load the model onto.
        cache_dir: Directory to cache the downloaded model files.
    Returns:
        A tuple containing the loaded processor and model.
    Raises:
        ValueError: If the model or processor cannot be loaded.
    """
    processor = AutoProcessor.from_pretrained(
        model_id_or_path,
        trust_remote_code=True,
        revision=revision,
    )
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id_or_path,
        trust_remote_code=True,
        revision=revision,
        cache_dir=cache_dir,
    ).to(device)
    return processor, model
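
For reviewers: a minimal sketch showing how load_model and CheckpointManager compose; the epoch count and validation loss below are stand-ins for the real training loop in paligemma/core.py.

# Hypothetical wiring (not part of the diff).
from maestro.trainer.models.paligemma.checkpoints import CheckpointManager, load_model

processor, model = load_model()  # google/paligemma-3b-pt-224 at the "float16" revision
manager = CheckpointManager(training_dir="./training/paligemma")

for epoch in range(3):  # illustrative epoch count
    val_loss = 1.0 / (epoch + 1)  # stand-in for the real validation loss
    manager.save_latest(processor, model)          # always overwrites checkpoints/latest
    manager.save_best(processor, model, val_loss)  # overwrites checkpoints/best only on improvement

best_path = manager.get_best_model_path()  # <training_dir>/checkpoints/best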