
Conversation

@kashif (Contributor) commented Oct 21, 2025

This pull request introduces configurable attention backends to the Chronos-2 model, allowing users to select between eager, SDPA, and FlashAttention-2 implementations.
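
For context, a usage sketch of what selecting a backend might look like; the pipeline class and model id below are assumptions for illustration, not taken from this PR:

```python
# Hypothetical usage sketch -- the loader class and model id are assumed,
# not confirmed by this PR.
from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-2",          # assumed model id
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
```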

@abdulfatir (Contributor) left a comment

Thanks @kashif! Left some comments. I think Flash attention won't work with Chronos-2 because masking is very important for the model. SDPA should work though. That said, I actually experimented with SDPA and FlexAttention while training these models. However, in the end I still went with manual attention + torch compile because I ran into weird issues. See:

pytorch/pytorch#149857
pytorch/pytorch#149767

Did you benchmark SDPA vs Eager on your end?
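
For readers following the masking point above: PyTorch's SDPA accepts an arbitrary attention mask, whereas the FlashAttention-2 kernels (as integrated in transformers) only handle causal/padding-style masking. A minimal, self-contained sketch with illustrative shapes and mask:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- illustrative shapes
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# An arbitrary boolean mask (True = attend), OR'd with the diagonal so every
# query attends to at least itself. SDPA honors such masks directly.
mask = (torch.rand(1, 1, 8, 8) > 0.5) | torch.eye(8, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```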

```python
lr_scheduler_type="linear",
warmup_ratio=0.0,
optim="adamw_torch_fused",
logging_dir=str(output_dir / "logs"),
```

Contributor: Why is this removed?

Contributor Author (@kashif): `logging_dir` has been removed from the training arguments; the various `report_to` backends handle logging within the output directory.
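
For illustration, a sketch of the trimmed arguments (values are hypothetical); the point is that the chosen `report_to` backend decides where logs go, e.g. the TensorBoard callback defaults to a run directory under `output_dir`:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    lr_scheduler_type="linear",
    warmup_ratio=0.0,
    optim="adamw_torch_fused",
    # No logging_dir: the selected backend handles log placement, e.g. the
    # TensorBoard callback writes under output_dir by default.
    report_to=["tensorboard"],
)
```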

Contributor: Will this work fine for the older transformers versions supported by this package?

Contributor Author (@kashif): Yes, we can safely not set this argument and it will work for older transformers versions (as well as newer).
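
A sketch of the compatibility pattern being described (the helper below is hypothetical, not the package's actual code): forward the kwarg only when it is set, so older transformers versions that don't recognize it are unaffected:

```python
from typing import Optional

from transformers import AutoModel

def load_model(model_id: str, attn_implementation: Optional[str] = None):
    # Hypothetical helper: only pass attn_implementation when explicitly set,
    # keeping from_pretrained calls valid on older transformers versions.
    kwargs = {}
    if attn_implementation is not None:
        kwargs["attn_implementation"] = attn_implementation
    return AutoModel.from_pretrained(model_id, **kwargs)
```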

@abdulfatir added the run-eval label (Run evaluation CI workflow) Oct 21, 2025

@abdulfatir (Contributor) commented Oct 21, 2025

@kashif with SDPA, the results stay exactly the same and there's a small improvement in the runtime.

[image: benchmark results comparing SDPA and eager]

@abdulfatir (Contributor) commented Oct 22, 2025

Ran full eval with FA2. It leads to a very slight worsening of the scores but is clearly faster overall.

[image: full evaluation results with FA2]

@abdulfatir (Contributor) commented:

We have two options:

  • Keep FA2 in this PR in the hope that someone may benefit from it.

    | Pros | Cons |
    | --- | --- |
    | Faster inference for specific cases | Non-transparent attn mode |
    |  | Probably won't be tested a lot (dead code) |

  • Only keep SDPA. This goes with the minimalist spirit of the package to reduce future maintenance burden.

@shchur @lostella I can't decide between the two. What do you think?

@abdulfatir abdulfatir changed the title [chronos-2] add support for spda and flash attention [chronos-2] add support for SDPA and flash attention 2 Oct 22, 2025

```python
assert position_ids is not None, "position_ids must be provided when self.use_rope=True"

# Force eager attention if output_attentions is True (only eager returns weights)
attn_implementation = self.config._attn_implementation
```
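
The comment in the snippet describes a fallback; distilled into a self-contained sketch (the helper name is hypothetical, the behavior follows the snippet's comment):

```python
def resolve_attn_implementation(configured: str, output_attentions: bool) -> str:
    # Only the eager implementation materializes the attention weight matrix,
    # so it must be forced whenever the caller asks for attentions back.
    return "eager" if output_attentions else configured

assert resolve_attn_implementation("sdpa", output_attentions=True) == "eager"
assert resolve_attn_implementation("sdpa", output_attentions=False) == "sdpa"
```
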
Contributor: Does this need to access the private _attn_implementation attribute?

@kashif (Contributor Author) commented Oct 22, 2025

Happy to remove FA2 if requested.

@lostella (Contributor) commented:

I am also more inclined toward keeping only SDPA and keeping the codebase simpler. We can revisit this should inference speed become a concern.

@kashif (Contributor Author) commented Oct 22, 2025

OK, removing FA2.

@abdulfatir (Contributor) left a comment:

Thanks for this great PR @kashif! Let's wait for @shchur's approval before merging.

@abdulfatir abdulfatir changed the title [chronos-2] add support for SDPA and flash attention 2 [chronos-2] add support for SDPA Oct 22, 2025

```python
assert not self.is_gated_act, "gated activation is not supported"

# Attention implementation - default to "sdpa" if not specified
attn_implementation = attn_implementation or "sdpa"
```

Contributor: Could we just set the default value in the function signature to "sdpa", or is there some case where the current logic is required?

Contributor Author (@kashif): Indeed we need this: from_pretrained fetches the config.json from S3, and since that config does not have the field, attn_implementation will be None.
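
A self-contained sketch of the point (hypothetical names): a signature default alone does not help, because a config whose stored config.json predates the field surfaces it as None and passes it explicitly, overriding any default:

```python
def make_layer(attn_implementation="sdpa"):
    # Coalesce so that an explicit None (e.g. from a config object whose
    # stored config.json lacked the field) still resolves to "sdpa".
    return attn_implementation or "sdpa"

# The missing field arrives as an explicit None, bypassing the signature default:
old_config = {"attn_implementation": None}
assert make_layer(**old_config) == "sdpa"
```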

Contributor: If the default is set to "sdpa" as in my suggestion above, this should not be needed, right?

```python
assert not self.is_gated_act, "gated activation is not supported"

# Attention implementation - default to "sdpa" if not specified
attn_implementation = attn_implementation or "sdpa"
```

Contributor: If the default is set to "sdpa" as in my suggestion above, this should not be needed, right?

@abdulfatir merged commit ca9c327 into amazon-science:main Oct 22, 2025 (6 checks passed)
@kashif kashif deleted the add-attention-implementations branch October 22, 2025 12:10

@shchur (Contributor) commented Oct 22, 2025

Thanks a lot for your contribution @kashif!
