[FEATURE] Add Donut & Flava model #1271
base: master
  cm = silence_mindspore_logger() if is_sharded else nullcontext()
  with cm:
-     ms.load_param_into_net(model_to_load, state_dict, strict_load=True)
+     model_to_load.load_state_dict(state_dict, strict=False)
why change strict to False?
In the case of tied weights, there may be extra or missing parameters in the Hugging Face transformer checkpoint. Using strict=True will raise an error, so we follow the same design as the Transformers repo by setting strict=False.
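To illustrate the point about strict loading, here is a minimal, framework-free sketch. It is not the MindSpore or Transformers implementation: the `load_state_dict` function below is a toy stand-in, and the parameter names (`embed.weight`, `lm_head.weight`) and the checkpoint contents are hypothetical. It assumes a checkpoint that stores a tied tensor under only one of its names, which is one way a key can end up "missing".

```python
from typing import Dict, List, Tuple


def load_state_dict(
    params: Dict[str, list],
    state_dict: Dict[str, list],
    strict: bool = True,
) -> Tuple[List[str], List[str]]:
    """Toy loader (illustrative only): copy checkpoint values into params by name."""
    missing = sorted(k for k in params if k not in state_dict)
    unexpected = sorted(k for k in state_dict if k not in params)
    # In strict mode, any key mismatch aborts before anything is copied.
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    for name, value in state_dict.items():
        if name in params:
            # In-place copy, so names that share one object stay tied.
            params[name][:] = value
    return missing, unexpected


# Tied weights: both names refer to the same underlying list object.
shared = [0.0, 0.0]
params = {"embed.weight": shared, "lm_head.weight": shared}

# Hypothetical checkpoint that stores the tied tensor only once.
checkpoint = {"embed.weight": [1.0, 2.0]}

try:
    load_state_dict(params, checkpoint, strict=True)
    strict_failed = False
except KeyError:
    strict_failed = True

# Non-strict loading reports the mismatch instead of raising.
missing, unexpected = load_state_dict(params, checkpoint, strict=False)
```

In this sketch the strict call fails because `lm_head.weight` is absent from the checkpoint, while the non-strict call succeeds and the tied name still sees the loaded values through the shared object. The returned `missing` and `unexpected` lists can then be logged or checked explicitly.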
/gemini review
Code Review
This pull request introduces support for the Donut and Flava models, a significant and valuable addition to the library. The change to use load_state_dict for handling tied weights is also a crucial bug fix. The implementation of the new models appears to be a faithful port from the Hugging Face transformers library. However, I've identified several issues, including two critical bugs in the FlavaForPreTraining model that affect the model's output when return_dict=False. Additionally, there are a few minor typos in docstrings and log messages that should be corrected to improve clarity and usability. Addressing these points will greatly enhance the quality of this contribution.
Relies on Mbart #1195
What does this PR do?
Fixes # (issue)
In MindSpore (MS) 2.6/2.7, when weights are tied, the load_param_into_net API may produce unexpected results during the weight loading stage. To address this, we switched to the load_state_dict API, which more closely aligns with PyTorch's behavior. Here is an example that demonstrates the buggy result caused by using the load_param_into_net API with tied weights.
Adds # (feature)
Add Donut & Flava model
Donut
Flava
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx