Implement a new callback HugginfaceCheckpoint #109

Merged
2 commits merged into SalesforceAIResearch:main on Aug 22, 2024

Conversation

@liu-jc (Contributor) commented Aug 19, 2024

Created a new callback to save checkpoints in the Hugging Face format. Basically, we inherit ModelCheckpoint from PyTorch Lightning and override its _save_checkpoint and _remove_checkpoint methods.
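
A minimal sketch of that structure (the class name, import path, and hook signatures here are assumptions based on this description, not the merged code):

from lightning.pytorch.callbacks import ModelCheckpoint


class HuggingFaceCheckpoint(ModelCheckpoint):
    """Write/remove checkpoints as Hugging Face directories instead of .ckpt files."""

    def _save_checkpoint(self, trainer, filepath: str) -> None:
        ...  # write a Hugging Face-format directory (see discussion below)

    def _remove_checkpoint(self, trainer, filepath: str) -> None:
        ...  # delete the corresponding directory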

@liu-jc self-assigned this Aug 19, 2024
@liu-jc requested a review from @cuthalionn Aug 19, 2024 06:05
@cuthalionn (Collaborator) left a comment

Overall LGTM. I just mentioned a few minor issues that might require reconsidering.

Comment on lines 76 to 85
if hasattr(pretrain_module, "module"):
    moirai_module = pretrain_module.module

# filepath in pytorch lightning usually ends with .ckpt
# To get the directory to save the model, remove the .ckpt
if filepath.endswith(".ckpt"):
    save_dir = filepath.split(".ckpt")[0]
else:
    save_dir = filepath
moirai_module.save_pretrained(save_dir)
@cuthalionn (Collaborator)
It looks like if the check on line 76 evaluates to false, moirai_module won't be initialized, so line 85 would fail with a NameError. Should we consider raising an error explicitly in this scenario? Alternatively, do we need to revisit whether the check itself is necessary in the first place?

@liu-jc (Contributor, Author)
Thanks a lot for raising this point! That's a good point. If we only use Moirai, it will always have the module attribute, as declared in the Moirai pretrain class:

class MoiraiPretrain(L.LightningModule):

Other classes like MoiraiFinetune and MoiraiForecast also carry this module attribute.
When I implemented this, I was thinking that if we treat this as a framework, we may later add a model class without the module attribute. But that may be overthinking it; we can make sure it is always there. I also agree that we could raise an error when we cannot get moirai_module, or wrap the lookup in a try/except block.

@cuthalionn (Collaborator)
Right, I think both solutions work; having a try-except block sounds more long-term friendly to me, but it's up to you!
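
For illustration, a minimal sketch of that try/except variant; the trainer.lightning_module access and the .module attribute are assumptions carried over from the snippet under review, not the final implementation:

def _save_checkpoint(self, trainer, filepath: str) -> None:
    # Lightning filepaths typically end with .ckpt; strip the suffix so
    # save_pretrained receives a directory path rather than a file path.
    save_dir = filepath[: -len(".ckpt")] if filepath.endswith(".ckpt") else filepath
    try:
        # Assumption: the LightningModule wraps the actual model in `.module`.
        module = trainer.lightning_module.module
    except AttributeError as err:
        raise AttributeError(
            "expected the LightningModule to expose the model via `.module`"
        ) from err
    module.save_pretrained(save_dir)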

Comment on lines +106 to +115
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)
@cuthalionn (Collaborator)
In the original ModelCheckpoint, I believe only the specific checkpoint file is removed. However, in our implementation, we're removing the entire folder where the pretrained model is saved.

I'm unsure if this aligns with the intended use case, but I wanted to highlight the difference between the original behavior and our current approach.

@liu-jc (Contributor, Author)
Yes, the original ModelCheckpoint removes a file. But in the Hugging Face format, a checkpoint is always a folder containing config.json and model.safetensors, which is why I delete the whole folder. Does that make sense?
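
Continuing the sketch above, the removal hook could then mirror the snippet under review (with the same assumed .ckpt-stripping to recover the directory):

import os
import shutil


def _remove_checkpoint(self, trainer, filepath: str) -> None:
    # Map the .ckpt filepath back to the directory that save_pretrained
    # wrote, then delete the whole Hugging Face checkpoint folder
    # (config.json + model.safetensors) rather than a single file.
    save_dir = filepath[: -len(".ckpt")] if filepath.endswith(".ckpt") else filepath
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)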

@cuthalionn (Collaborator)
Right, I got it. Sounds good.

@liu-jc (Contributor, Author) commented Aug 22, 2024

Hi @cuthalionn,

I made some changes following your suggestions. Could you please take another look? If there are no problems, I will merge it soon.

@cuthalionn (Collaborator) left a comment

Looks good to me!

@liu-jc merged commit a4c416a into SalesforceAIResearch:main Aug 22, 2024
4 checks passed
@liu-jc (Contributor, Author) commented Aug 22, 2024

Directly related to #76.
It may also relate to #66: if we have the Hugging Face checkpoint locally, that issue might be solved.
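
For example, loading such a local checkpoint might then look like this (the import path, directory name, and from_pretrained support are assumptions, e.g. via huggingface_hub's PyTorchModelHubMixin, not confirmed by this PR):

# Hypothetical usage: load a locally saved Hugging Face-format checkpoint.
from uni2ts.model.moirai import MoiraiModule

module = MoiraiModule.from_pretrained("outputs/example-run/epoch=0-step=100")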
