[chronos-2] add support for SDPA #331
Changes from all commits
Configuration (`Chronos2CoreConfig`):

```diff
@@ -4,7 +4,7 @@
 # Authors: Abdul Fatir Ansari <[email protected]>
 
 from dataclasses import dataclass
-from typing import List
+from typing import List, Literal
 
 from transformers.configuration_utils import PretrainedConfig
@@ -39,6 +39,8 @@ class Chronos2CoreConfig(PretrainedConfig):
         Token ID for padding/missing value token, by default 0
     rope_theta
         The base theta for rotary position embedding (RoPE), by default 10000.0
+    attn_implementation
+        The attention implementation to use. Options: "eager" or "sdpa", by default None (uses "sdpa")
     """
 
     model_type = "t5"
@@ -63,6 +65,7 @@ def __init__(
         vocab_size: int = 2,
         pad_token_id: int = 0,
         rope_theta: float = 10000.0,
+        attn_implementation: Literal["eager", "sdpa"] | None = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -83,11 +86,17 @@ def __init__(
         assert not self.is_gated_act, "gated activation is not supported"
 
+        # Attention implementation - default to "sdpa" if not specified
+        attn_implementation = attn_implementation or "sdpa"
+        assert attn_implementation in ["eager", "sdpa"], f"attn_implementation {attn_implementation} not supported"
+
         # unused
         kwargs.pop("is_encoder_decoder", None)
         kwargs.pop("eos_token_id", None)
 
-        super().__init__(pad_token_id=pad_token_id, is_encoder_decoder=False, **kwargs)
+        super().__init__(
+            pad_token_id=pad_token_id, is_encoder_decoder=False, attn_implementation=attn_implementation, **kwargs
+        )
 
 
 @dataclass
```
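A minimal usage sketch of the new option (the import path is an assumption, not taken from the diff):

```python
from chronos2.config import Chronos2CoreConfig  # assumed import path; adjust to the actual module layout

# Omitting the argument (or passing None) falls back to "sdpa".
config = Chronos2CoreConfig()
print(config._attn_implementation)  # expected: "sdpa"

# The original eager path can still be requested explicitly.
eager_config = Chronos2CoreConfig(attn_implementation="eager")
print(eager_config._attn_implementation)  # expected: "eager"
```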
Attention layer:

```diff
@@ -155,6 +155,7 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True):
         self.n_heads: int = config.num_heads
         self.dropout: float = config.dropout_rate
         self.inner_dim: int = self.n_heads * self.kv_proj_dim
+        self.config = config
 
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -165,6 +166,64 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True):
         if use_rope:
             self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
 
+    def _eager_attention(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Eager attention implementation using manual matmul.
+
+        Args:
+            query_states: [batch, n_heads, seq_len, kv_proj_dim]
+            key_states: [batch, n_heads, seq_len, kv_proj_dim]
+            value_states: [batch, n_heads, seq_len, kv_proj_dim]
+            mask: [batch, n_heads, q_len, kv_len]
+
+        Returns:
+            attn_output: [batch, n_heads, seq_len, kv_proj_dim]
+            attn_weights: [batch, n_heads, q_len, kv_len]
+        """
+        # Compute attention weights (no scaling - this is the original Chronos-2 implementation)
+        scores = torch.matmul(query_states, key_states.transpose(3, 2))  # "bnqd,bnkd->bnqk"
+        scores += mask
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        return attn_output, attn_weights
+
+    def _sdpa_attention(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, None]:
+        """SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention.
+
+        Args:
+            query_states: [batch, n_heads, seq_len, kv_proj_dim]
+            key_states: [batch, n_heads, seq_len, kv_proj_dim]
+            value_states: [batch, n_heads, seq_len, kv_proj_dim]
+            mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid)
+
+        Returns:
+            attn_output: [batch, n_heads, seq_len, kv_proj_dim]
+            attn_weights: None (SDPA doesn't return weights)
+        """
+        attn_output = nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            scale=1.0,  # Match eager implementation (no scaling)
+        )
+
+        return attn_output, None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -190,6 +249,11 @@ def forward(
         if self.use_rope:
             assert position_ids is not None, "position_ids must be provided when self.use_rope=True"
 
+        # Force eager attention if output_attentions is True (only eager returns weights)
+        attn_implementation = self.config._attn_implementation
+        if output_attentions:
+            attn_implementation = "eager"
+
         seq_length = hidden_states.shape[1]
 
         def shape(states: torch.Tensor) -> torch.Tensor:
@@ -215,12 +279,10 @@ def unshape(states: torch.Tensor) -> torch.Tensor:
             cos, sin = self.rope_embed(value_states, position_ids)
             query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # Compute attention weights
-        scores = torch.matmul(query_states, key_states.transpose(3, 2))  # "bnqd,bnkd->bnqk"
-        scores += mask
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+        if attn_implementation == "sdpa":
+            attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
+        else:  # eager
+            attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask)
 
         # Project attention output
         attn_output = unshape(attn_output)
```

Review comment (on `attn_implementation = self.config._attn_implementation`): Does this need to access the private `_attn_implementation` attribute?

Reply: Seems for now this is the convention, see e.g. https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L277-L278
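The PR relies on `scaled_dot_product_attention` reproducing the unscaled eager computation. Below is a standalone sketch (not the PR's code) of that equivalence, using an additive mask and `scale=1.0` (the `scale` argument requires a recent PyTorch version):

```python
import torch
import torch.nn.functional as F

batch, n_heads, seq_len, head_dim = 2, 4, 16, 8
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_heads, seq_len, head_dim)
v = torch.randn(batch, n_heads, seq_len, head_dim)

# Additive mask: 0.0 for valid positions, a large negative value for masked ones.
mask = torch.zeros(batch, n_heads, seq_len, seq_len)
mask[..., seq_len // 2:] = torch.finfo(torch.float32).min

# Eager path: unscaled scores, additive mask, softmax, weighted sum over values.
scores = torch.matmul(q, k.transpose(3, 2)) + mask
eager_out = torch.matmul(F.softmax(scores, dim=-1), v)

# SDPA path: scale=1.0 turns off the default 1/sqrt(head_dim) scaling.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True
```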
Training setup (`fit`):

```diff
@@ -211,7 +211,6 @@ def fit(
             lr_scheduler_type="linear",
             warmup_ratio=0.0,
             optim="adamw_torch_fused",
-            logging_dir=str(output_dir / "logs"),
             logging_strategy="steps",
             logging_steps=100,
             disable_tqdm=False,
```

Review comment (on the removed `logging_dir` argument): Why is this removed?

Reply: `logging_dir` has been removed from the training arguments, and the different `report_to` backends handle the logging within the output dir.

Review comment: Will this work fine for older `transformers` versions?

Reply: Yes, we can safely not set this argument and it will work for older transformers versions (as well as newer).
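A sketch of what the resulting call looks like without `logging_dir` (values are illustrative and the `report_to` backend is an assumption, not part of the diff):

```python
from transformers import TrainingArguments

# Illustrative sketch: logging_dir is simply not passed; the selected
# report_to backend writes its logs under output_dir.
args = TrainingArguments(
    output_dir="output",
    report_to=["tensorboard"],  # assumed backend choice
    logging_strategy="steps",
    logging_steps=100,
)
```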
Review comment (on the `attn_implementation or "sdpa"` fallback in the config): Could we just set the default value in the function signature to "sdpa", or is there some case where the current logic is required?

Reply: Indeed we need this, as `from_pretrained` will get the `config.json` from S3, and since it does not have it, the `attn_implementation` will be `None`.

Reply: If the default is set to `sdpa` as in my suggestion above, this should not be needed, right?
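For reference, a neutral sketch of the distinction being discussed (hypothetical helper functions, not the PR's code): a signature default and the `or "sdpa"` fallback agree whenever the key is simply absent, and differ only if the value arrives as an explicit `None`.

```python
# Hypothetical helpers illustrating the two approaches discussed above.
def with_signature_default(attn_implementation="sdpa"):
    return attn_implementation

def with_runtime_fallback(attn_implementation=None):
    return attn_implementation or "sdpa"

print(with_signature_default())                          # "sdpa" (key absent)
print(with_runtime_fallback())                           # "sdpa" (key absent)
print(with_signature_default(attn_implementation=None))  # None   (explicit None passed through)
print(with_runtime_fallback(attn_implementation=None))   # "sdpa" (explicit None still falls back)
```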