Skip to content

Implement a new callback HuggingFaceCheckpoint #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cli/conf/pretrain/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ trainer:
mode: max
save_top_k: -1
every_n_epochs: ${floordiv:${trainer.max_epochs},10}
- _target_: uni2ts.callbacks.HuggingFaceCheckpoint.HuggingFaceCheckpoint
dirpath: ${hydra:runtime.output_dir}/HF_checkpoints
filename: last
monitor: epoch
mode: max
save_top_k: 1
every_n_epochs: 1
# epoch-based training provides averaged metrics
# cannot use max_steps with epoch-based training - resume from checkpoint on wrong epoch
max_epochs: 1_000
Expand Down
117 changes: 117 additions & 0 deletions src/uni2ts/callbacks/HuggingFaceCheckpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2024, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import re
import shutil
import time
import warnings
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from weakref import proxy

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.utilities.rank_zero import (
WarningCache,
rank_zero_info,
rank_zero_warn,
)

log = logging.getLogger(__name__)
warning_cache = WarningCache()

_PATH = Union[str, Path]


class HuggingFaceCheckpoint(ModelCheckpoint):
    r"""
    Checkpoint callback that saves models in the Hugging Face format.

    Subclasses ``lightning.pytorch.callbacks.ModelCheckpoint`` and overrides
    its checkpoint-writing hook so that each checkpoint is a directory in the
    Hugging Face layout rather than a single ``.ckpt`` file.
    """

    def __init__(
        self,
        dirpath: Optional[_PATH] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        save_top_k: int = 1,
        mode: str = "min",
        every_n_epochs: Optional[int] = None,
    ):
        # This subclass only changes HOW checkpoints are serialized, not WHEN
        # they are triggered, so everything is forwarded to the base class.
        base_kwargs = dict(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            mode=mode,
            save_top_k=save_top_k,
            every_n_epochs=every_n_epochs,
        )
        super().__init__(**base_kwargs)

def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
# Only save the checkpoint if it is in the main process
if not trainer.is_global_zero:
return

# Extract the model from the Lightning module
pl_module = trainer.model
pretrain_module = pl_module.module

try:
moirai_module = pretrain_module.module
except AttributeError:
moirai_module = pretrain_module
warnings.warn(
"Warning: no module attribute found in the model. Saving the model directly."
)

# filepath in pytorch lightning usually ends with .ckpt
# To get the directory to save the model, remove the .ckpt
if filepath.endswith(".ckpt"):
save_dir = filepath.split(".ckpt")[0]
else:
save_dir = filepath

try:
moirai_module.save_pretrained(save_dir)
except Exception as e:
warnings.warn(f"An error occurred during model saving: {e}")

self._last_global_step_saved = trainer.global_step
self._last_checkpoint_saved = save_dir

# notify loggers
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))

def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
# Only remove the checkpoint if it is in the main process
if not trainer.is_global_zero:
return

# filepath in pytorch lightning usually ends with .ckpt
# To get the directory to save the model, remove the .ckpt
if filepath.endswith(".ckpt"):
save_dir = filepath.split(".ckpt")[0]
else:
save_dir = filepath
if os.path.exists(save_dir):
shutil.rmtree(save_dir)
Comment on lines +114 to +115
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original ModelCheckpoint, I believe only the specific checkpoint file is removed. However, in our implementation, we're removing the entire folder where the pretrained model is saved.

I'm unsure if this aligns with the intended use case, but I wanted to highlight the difference between the original behavior and our current approach.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. In the original ModelCheckpoint, it removes the file. But in huggingface checkpoint format, a checkpoint is always a folder containing config.json and 'model.safetensors'. That's why I deleted the whole folder. Does it make sense?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I got it. Sounds good.

else:
warnings.warn(f"Checkpoint not found: {save_dir}")
14 changes: 14 additions & 0 deletions src/uni2ts/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2024, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading