Skip to content

Commit b8e31b5

Browse files
committed
Implement a new callback HuggingFaceCheckpoint
1 parent 572445e commit b8e31b5

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) 2024, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import re
19+
import shutil
20+
import time
21+
import warnings
22+
from copy import deepcopy
23+
from datetime import timedelta
24+
from pathlib import Path
25+
from typing import Any, Dict, Literal, Optional, Set, Union
26+
from weakref import proxy
27+
28+
import lightning.pytorch as pl
29+
from lightning.pytorch.callbacks import ModelCheckpoint
30+
from lightning.pytorch.utilities.rank_zero import (
31+
WarningCache,
32+
rank_zero_info,
33+
rank_zero_warn,
34+
)
35+
36+
# Module-level logger named after this module.
log = logging.getLogger(__name__)
37+
# Lightning WarningCache instance; not referenced in the visible part of this file.
warning_cache = WarningCache()
38+
39+
40+
# Alias for path-like arguments: either a plain string or a pathlib.Path.
_PATH = Union[str, Path]
41+
42+
43+
class HuggingFaceCheckpoint(ModelCheckpoint):
    r"""
    Checkpoint callback that saves the model in the Hugging Face format.

    Inherits from ``lightning.pytorch.callbacks.ModelCheckpoint`` and overrides
    ``_save_checkpoint`` and ``_remove_checkpoint`` so that each checkpoint is
    a directory written by the wrapped model's ``save_pretrained`` method
    rather than a ``.ckpt`` file.
    """

    # Suffix that Lightning usually appends to checkpoint file paths.
    _CKPT_SUFFIX = ".ckpt"

    def __init__(
        self,
        dirpath: Optional[_PATH] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        save_top_k: int = 1,
        mode: str = "min",
        every_n_epochs: Optional[int] = None,
    ):
        """Forward the supported subset of options to ``ModelCheckpoint``."""
        super().__init__(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            mode=mode,
            save_top_k=save_top_k,
            every_n_epochs=every_n_epochs,
        )

    @classmethod
    def _checkpoint_dir(cls, filepath: str) -> str:
        """Return the directory path that corresponds to ``filepath``.

        Only a *trailing* ``.ckpt`` is stripped. Splitting on the first
        occurrence of ``.ckpt`` would truncate paths whose parent directories
        happen to contain that substring, which is dangerous because the
        result is later handed to ``shutil.rmtree``.
        """
        if filepath.endswith(cls._CKPT_SUFFIX):
            return filepath[: -len(cls._CKPT_SUFFIX)]
        return filepath

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Save the wrapped model to a directory derived from ``filepath``.

        Does nothing on processes other than global rank zero. On success the
        last-saved bookkeeping attributes are updated and every logger of the
        trainer is notified.
        """
        # Only save the checkpoint if it is in the main process
        if not trainer.is_global_zero:
            return

        # Extract the model from the Lightning module
        pl_module = trainer.model
        pretrain_module = pl_module.module

        if not hasattr(pretrain_module, "module"):
            # There is no inner model to export; report it instead of failing
            # on an unbound name or skipping the save without any trace.
            warnings.warn(
                "HuggingFaceCheckpoint: the training module has no 'module' "
                f"attribute; skipping checkpoint: {filepath}"
            )
            return
        moirai_module = pretrain_module.module

        save_dir = self._checkpoint_dir(filepath)
        moirai_module.save_pretrained(save_dir)

        self._last_global_step_saved = trainer.global_step
        self._last_checkpoint_saved = save_dir

        # notify loggers (this point is only reached on global rank zero)
        for logger in trainer.loggers:
            logger.after_save_checkpoint(proxy(self))

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Delete the checkpoint directory derived from ``filepath``.

        Does nothing on processes other than global rank zero. Emits a warning
        when the directory does not exist.
        """
        # Only remove the checkpoint if it is in the main process
        if not trainer.is_global_zero:
            return

        save_dir = self._checkpoint_dir(filepath)
        if os.path.exists(save_dir):
            shutil.rmtree(save_dir)
        else:
            warnings.warn(f"Checkpoint not found: {save_dir}")

src/uni2ts/callbacks/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2024, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.

0 commit comments

Comments
 (0)