Implement a new callback HugginfaceCheckpoint #109
Conversation
Force-pushed from b8e31b5 to 778080d.
Overall LGTM. I just mentioned a few minor issues that might be worth reconsidering.
```python
if hasattr(pretrain_module, "module"):
    moirai_module = pretrain_module.module

# filepath in pytorch lightning usually ends with .ckpt
# To get the directory to save the model, remove the .ckpt
if filepath.endswith(".ckpt"):
    save_dir = filepath.split(".ckpt")[0]
else:
    save_dir = filepath
moirai_module.save_pretrained(save_dir)
```
It looks like if the check on line 76 evaluates to false, `moirai_module` is never assigned, so the call on line 85 would fail with a NameError. Should we consider raising an error in this scenario to handle it explicitly? Alternatively, do we need to revisit whether the check itself is necessary in the first place?
Thanks a lot for raising this point! That's a good point. If we only use Moirai, then it will always have the `module` attribute, as declared in the Moirai pretrain class:

```python
class MoiraiPretrain(L.LightningModule):
```

Other classes like MoiraiFinetune and MoiraiForecast also have the `moirai_module`. When I implemented this, I was thinking that if we consider this a framework, we may later have another model class without this `module` attribute. But maybe I'm overthinking it? We can make sure it always has that attribute. I also agree that we could raise an error when we cannot get the `moirai_module`, or wrap the lookup in a try/except block.
Right, I think both solutions work. Having a try/except block sounds more long-term friendly to me, but up to you!
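For illustration, a minimal sketch of the try/except approach discussed here. The helper name `get_moirai_module` is hypothetical and not part of the PR; it just makes the failure mode explicit instead of leaving `moirai_module` unbound:

```python
def get_moirai_module(pretrain_module):
    """Resolve the underlying Moirai module, failing loudly if absent."""
    try:
        # Works for MoiraiPretrain and similar wrappers exposing `module`.
        return pretrain_module.module
    except AttributeError as e:
        raise AttributeError(
            f"{type(pretrain_module).__name__} does not expose a `module` "
            "attribute holding the underlying Moirai model."
        ) from e
```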
```python
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)
```
In the original ModelCheckpoint, I believe only the specific checkpoint file is removed. However, in our implementation, we're removing the entire folder where the pretrained model is saved.
I'm unsure if this aligns with the intended use case, but I wanted to highlight the difference between the original behavior and our current approach.
Yes. The original ModelCheckpoint removes a file. But in the Hugging Face checkpoint format, a checkpoint is always a folder containing `config.json` and `model.safetensors`. That's why I deleted the whole folder. Does it make sense?
Right, I got it. Sounds good.
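To make the directory-vs-file distinction concrete, here is a sketch of the removal behavior agreed on in this thread. The function name `remove_hf_checkpoint` is a hypothetical stand-in for illustration, not the PR's actual code:

```python
import os
import shutil


def remove_hf_checkpoint(filepath: str) -> None:
    # A Hugging Face checkpoint is a directory holding config.json and
    # model.safetensors, so delete the folder rather than a single file.
    save_dir = filepath.split(".ckpt")[0] if filepath.endswith(".ckpt") else filepath
    if os.path.isdir(save_dir):
        shutil.rmtree(save_dir)
```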
Force-pushed from 5b953de to d38b388.
Hi @cuthalionn, I made some changes following your suggestions. Could you please take another look? If there are no problems, I will merge it soon.
Looks good to me!
Force-pushed from d38b388 to 5d11e59.
Created a new callback to save checkpoints in the Hugging Face format. Basically, we inherit `ModelCheckpoint` from PyTorch Lightning and override its `_save_checkpoint` and `_remove_checkpoint` methods.
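A minimal end-to-end sketch of what such a callback could look like, based only on the snippets discussed above. The class name `HuggingFaceCheckpoint`, the Lightning >= 2.0 import path, and the assumption that the LightningModule wraps the model in a `module` attribute are all illustrative; the actual PR may differ:

```python
import os
import shutil

from lightning.pytorch.callbacks import ModelCheckpoint


def _to_save_dir(filepath: str) -> str:
    # Lightning checkpoint paths usually end with .ckpt; strip the suffix
    # to get a directory name usable by save_pretrained.
    return filepath.split(".ckpt")[0] if filepath.endswith(".ckpt") else filepath


class HuggingFaceCheckpoint(ModelCheckpoint):
    """Save checkpoints as Hugging Face directories instead of .ckpt files."""

    def _save_checkpoint(self, trainer, filepath: str) -> None:
        # Unwrap the inner Moirai model from the LightningModule wrapper,
        # mirroring the `pretrain_module.module` access in the diff.
        moirai_module = trainer.lightning_module.module
        moirai_module.save_pretrained(_to_save_dir(filepath))

    def _remove_checkpoint(self, trainer, filepath: str) -> None:
        # A Hugging Face checkpoint is a folder (config.json plus
        # model.safetensors), so remove the directory, not a file.
        save_dir = _to_save_dir(filepath)
        if os.path.exists(save_dir):
            shutil.rmtree(save_dir)
```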