-
Couldn't load subscription status.
- Fork 88
feat(diffusers): add examples/diffusers/tests #1372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(diffusers): add examples/diffusers/tests #1372
Conversation
Summary of ChangesHello @Cui-yshoho, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the testing infrastructure for diffusers examples by adding dedicated test files for ControlNet, DreamBooth, Text-to-Image, Textual Inversion, and Unconditional Image Generation. It also refines the checkpointing mechanism in training scripts by adopting the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a comprehensive test suite for various diffusers examples, enhancing the project's robustness. It also refactors the checkpointing mechanism in several training scripts to consistently use the .safetensors format and the save_pretrained method, which is a good improvement for model serialization. My review identifies a few critical issues in the new test files and one of the updated training scripts that need to be addressed.
| with ms.load_checkpoint(state_dict_file, format="safetensors") as f: | ||
| metadata = f.metadata() or {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The usage of ms.load_checkpoint as a context manager is incorrect, as it returns a dictionary and does not support the with statement. To read metadata from a .safetensors file, safetensors.safe_open should be used. You will need to add from safetensors import safe_open to the file imports.
| with ms.load_checkpoint(state_dict_file, format="safetensors") as f: | |
| metadata = f.metadata() or {} | |
| from safetensors import safe_open | |
| with safe_open(state_dict_file, framework="np") as f: | |
| metadata = f.metadata() or {} |
| # TODO: load optimizer & grad scaler etc. like accelerator.load_state | ||
| input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt") | ||
| input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors") | ||
| ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ms.load_checkpoint function is being called to load a .safetensors file, but the format argument is missing. Without format="safetensors", MindSpore will try to load it as a standard checkpoint file, which will fail.
| ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True) | |
| ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True) |
|
|
||
|
|
||
| class ControlNetSD35(ExamplesTests): | ||
| def test_controlnet_sd3(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test method test_controlnet_sd3 in class ControlNetSD35 has a duplicated name with the test method in class ControlNetSD3. Test method names within a test class and its inherited test cases should be unique. This will cause one of the tests to be skipped by the test runner. Please rename it to reflect the test case, for example, test_controlnet_sd35.
| def test_controlnet_sd3(self): | |
| def test_controlnet_sd35(self): |
3b81766 to
ada9cd1
Compare
fb61a3d to
58c81ef
Compare
58c81ef to
e8afa2a
Compare
845cf22 to
18234b8
Compare
18234b8 to
494debb
Compare
What does this PR do?
Fixes # (issue)
Adds # (feature)
Before submitting
What's New. Here are thedocumentation guidelines
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx