[chronos-2] add support for SDPA #331
Conversation
Thanks @kashif! Left some comments. I think Flash attention won't work with Chronos-2 because masking is very important for the model. SDPA should work though. That said, I actually experimented with SDPA and FlexAttention while training these models. However, in the end I still went with manual attention + torch compile because I ran into weird issues. See:
pytorch/pytorch#149857
pytorch/pytorch#149767
Did you benchmark SDPA vs Eager on your end?
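A minimal sketch (not from this PR) of the masking point above: PyTorch's SDPA accepts an arbitrary boolean attn_mask, which is what Chronos-2 needs, while FlashAttention kernels generally only support causal or no masking.

```python
# Sketch only: SDPA with an arbitrary (non-causal) attention mask, True = attend.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)          # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
mask = torch.rand(1, 1, 16, 16) > 0.2  # arbitrary mask, not expressible as "causal"

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```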
lr_scheduler_type="linear",
warmup_ratio=0.0,
optim="adamw_torch_fused",
logging_dir=str(output_dir / "logs"),
Why is this removed?
logging_dir has been removed from the training arguments; the different report_to backends handle logging within the output directory.
Will this work fine for the older transformers versions supported by this package?
Yes, we can safely not set this argument and it will work for older transformers versions (as well as newer ones).
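For reference, a hedged sketch of the call without logging_dir (argument values are illustrative, not the exact ones from this repo): when logging_dir is omitted, the report_to backends write their logs under output_dir.

```python
# Illustrative sketch: with logging_dir omitted, report_to backends
# (e.g. TensorBoard) place their logs under output_dir by default.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    lr_scheduler_type="linear",
    warmup_ratio=0.0,
    optim="adamw_torch_fused",
    report_to=["tensorboard"],  # logs end up under output/ by default
)
```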
@kashif with SDPA, the results stay exactly the same and there's a small improvement in the runtime.
We have two options:
@shchur @lostella I can't decide between the two. What do you think?
assert position_ids is not None, "position_ids must be provided when self.use_rope=True"

# Force eager attention if output_attentions is True (only eager returns weights)
attn_implementation = self.config._attn_implementation
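A self-contained sketch of the fallback described in the hunk's comment (the helper name is hypothetical, not the PR's code):

```python
# Hypothetical helper illustrating the logic above: SDPA/FA2 do not return
# attention weights, so the eager path is forced when output_attentions is True.
def resolve_attn_implementation(configured: str, output_attentions: bool) -> str:
    return "eager" if output_attentions else configured

print(resolve_attn_implementation("sdpa", output_attentions=False))  # sdpa
print(resolve_attn_implementation("sdpa", output_attentions=True))   # eager
```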
Does this need to access the private _attn_implementation attribute?
It seems for now this is the convention, see e.g. https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L277-L278
Happy to remove fa2 if requested.
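For context, the referenced llama code follows roughly this pattern (paraphrased from memory, not a verbatim copy; the registry name comes from recent transformers versions and may differ in older ones):

```python
# Paraphrased sketch of the transformers convention: the model reads
# config._attn_implementation and looks up the matching attention function.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def pick_attention(config, eager_attention_forward):
    attention_interface = eager_attention_forward
    if config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
    return attention_interface
```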
I am also more inclined toward keeping SDPA only and keeping the codebase simpler. We can revisit this if inference speed becomes a concern.
ok, removing fa2
assert not self.is_gated_act, "gated activation is not supported"

# Attention implementation - default to "sdpa" if not specified
attn_implementation = attn_implementation or "sdpa"
Could we just set the default value in the function signature to "sdpa", or is there some case where the current logic is required?
Indeed, we need this because from_pretrained will get the config.json from S3, and since it does not include this key, attn_implementation will be None.
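A small hypothetical sketch of the point above: when None is passed explicitly (because the key is missing from config.json), a signature default alone is bypassed, so the or-fallback still matters.

```python
# Hypothetical stand-in for the config constructor: an explicit None argument
# overrides the signature default, so the `or "sdpa"` coercion is still needed.
def make_config(attn_implementation="sdpa"):
    attn_implementation = attn_implementation or "sdpa"
    return attn_implementation

print(make_config())         # "sdpa"  (default used)
print(make_config(None))     # "sdpa"  (explicit None coerced by the fallback)
print(make_config("eager"))  # "eager"
```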
If the default is set to "sdpa" as in my suggestion above, this should not be needed, right?
Co-authored-by: Oleksandr Shchur <[email protected]>
Co-authored-by: Abdul Fatir <[email protected]>
Co-authored-by: Abdul Fatir <[email protected]>
Thanks a lot for your contribution @kashif!
This pull request introduces configurable attention backends to the Chronos-2 model, allowing users to select between eager, SDPA, and FlashAttention-2 implementations.
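A hypothetical usage sketch (the pipeline class and model id are assumed, and forwarding of the attn_implementation kwarg is exactly what this PR adds, so verify against the merged API):

```python
# Hypothetical usage: select the attention backend when loading Chronos-2.
# Assumes the attn_implementation kwarg is forwarded to the underlying model.
from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-2",
    attn_implementation="sdpa",  # or "eager"
)
```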