From 7bf335a19aebeb234153d96d4842c88351aa8db9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 13:55:09 +0200 Subject: [PATCH 01/13] add support for spda and flash attention --- scripts/training/train.py | 1 - src/chronos/chronos2/config.py | 6 ++ src/chronos/chronos2/layers.py | 136 +++++++++++++++++++++++++++++-- src/chronos/chronos2/model.py | 2 + src/chronos/chronos2/pipeline.py | 1 - 5 files changed, 138 insertions(+), 8 deletions(-) diff --git a/scripts/training/train.py b/scripts/training/train.py index c16092e5..09d5d8ef 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -663,7 +663,6 @@ def main( lr_scheduler_type=lr_scheduler_type, warmup_ratio=warmup_ratio, optim=optim, - logging_dir=str(output_dir / "logs"), logging_strategy="steps", logging_steps=log_steps, save_strategy="steps", diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index f73fda5f..0ce511ea 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -39,6 +39,8 @@ class Chronos2CoreConfig(PretrainedConfig): Token ID for padding/missing value token, by default 0 rope_theta The base theta for rotary position embedding (RoPE), by default 10000.0 + attn_implementation + The attention implementation to use. Options: "eager", "sdpa", "flash_attention_2", by default None (uses "sdpa") """ model_type = "t5" @@ -63,6 +65,7 @@ def __init__( vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, + attn_implementation: str | None = None, **kwargs, ): self.vocab_size = vocab_size @@ -83,6 +86,9 @@ def __init__( assert not self.is_gated_act, "gated activation is not supported" + # Attention implementation - default to "sdpa" if not specified + self._attn_implementation = attn_implementation or "sdpa" + # unused kwargs.pop("is_encoder_decoder", None) kwargs.pop("eos_token_id", None) diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 2c4e6b33..04408ea9 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -155,6 +155,7 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True): self.n_heads: int = config.num_heads self.dropout: float = config.dropout_rate self.inner_dim: int = self.n_heads * self.kv_proj_dim + self.config = config self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -165,6 +166,123 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True): if use_rope: self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta) + def _eager_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Eager attention implementation using manual matmul. 
+ + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: [batch, n_heads, q_len, kv_len] + """ + # Compute attention weights (no scaling - this is the original Chronos-2 implementation) + scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" + scores += mask + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + return attn_output, attn_weights + + def _sdpa_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, None]: + """SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention. + + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid) + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: None (SDPA doesn't return weights) + """ + attn_output = nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + scale=1.0, # Match eager implementation (no scaling) + ) + + return attn_output, None + + def _flash_attention_2( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, None]: + """FlashAttention-2 implementation. + + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: None (FlashAttention doesn't return weights) + """ + try: + from flash_attn import flash_attn_func + except ImportError: + raise ImportError( + "FlashAttention-2 is not installed. 
Please install it with: " + "pip install flash-attn --no-build-isolation" + ) + + # FlashAttention expects inputs in shape [batch, seq_len, n_heads, head_dim] + # We have [batch, n_heads, seq_len, head_dim], so we need to transpose + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # FlashAttention only supports fp16 and bf16 + input_dtype = query_states.dtype + if input_dtype not in [torch.float16, torch.bfloat16]: + target_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=self.dropout if self.training else 0.0, + softmax_scale=1.0, # Match eager implementation (no scaling) + causal=False, # Chronos uses bidirectional attention by default + ) + + # Convert back to original dtype if needed + if input_dtype not in [torch.float16, torch.bfloat16]: + attn_output = attn_output.to(input_dtype) + + # Transpose back to [batch, n_heads, seq_len, head_dim] + attn_output = attn_output.transpose(1, 2) + + return attn_output, None + def forward( self, hidden_states: torch.Tensor, @@ -190,6 +308,11 @@ def forward( if self.use_rope: assert position_ids is not None, "position_ids must be provided when self.use_rope=True" + # Force eager attention if output_attentions is True (only eager returns weights) + attn_implementation = self.config._attn_implementation + if output_attentions and attn_implementation != "eager": + attn_implementation = "eager" + seq_length = hidden_states.shape[1] def shape(states: torch.Tensor) -> torch.Tensor: @@ -215,12 +338,13 @@ def unshape(states: torch.Tensor) -> torch.Tensor: cos, sin = self.rope_embed(value_states, position_ids) query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) - # Compute attention weights - scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" - scores += mask - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + # Dispatch to appropriate attention implementation + if attn_implementation == "sdpa": + attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) + elif attn_implementation == "flash_attention_2": + attn_output, attn_weights = self._flash_attention_2(query_states, key_states, value_states, mask) + else: # eager or default + attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) # Project attention output attn_output = unshape(attn_output) diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 06eb7081..ff5a5c01 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -199,6 +199,8 @@ class Chronos2Model(PreTrainedModel): config_class = Chronos2CoreConfig # type: ignore[assignment] _supports_long_horizon: bool = True _supports_future_covariates: bool = True + _supports_sdpa: bool = True + _supports_flash_attn_2: bool = True def __init__(self, config: Chronos2CoreConfig): assert hasattr(config, "chronos_config"), "Not a valid Chronos config" diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index e00d3362..2250fb25 
100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -211,7 +211,6 @@ def fit( lr_scheduler_type="linear", warmup_ratio=0.0, optim="adamw_torch_fused", - logging_dir=str(output_dir / "logs"), logging_strategy="steps", logging_steps=100, disable_tqdm=False, From 1cf40ac6d240783afbd67ccd0640d460b874aba6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 13:58:18 +0200 Subject: [PATCH 02/13] added tests --- test/test_chronos2.py | 191 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 753a9443..df7df1b8 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -936,3 +936,194 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future # Check predictions from the fine-tuned model are different from the original predictions assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy()) + + +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) +def test_pipeline_works_with_different_attention_implementations(attn_implementation): + """Test that the pipeline works with different attention implementations.""" + from chronos.chronos2.config import Chronos2CoreConfig + + # Load the dummy model + model_path = Path(__file__).parent / "dummy-chronos2-model" + + # Load with specified attention implementation + pipeline = BaseChronosPipeline.from_pretrained( + model_path, device_map="cpu", attn_implementation=attn_implementation + ) + + # Verify the config has the correct attention implementation + assert pipeline.model.config._attn_implementation == attn_implementation + + # Test prediction with simple input + inputs = torch.rand(2, 1, 16) + prediction_length = 7 + + outputs = pipeline.predict(inputs, prediction_length=prediction_length) + + # Check outputs are valid + assert isinstance(outputs, list) and len(outputs) == 2 + for out in outputs: + validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 7), dtype=torch.float32) + + +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) +@pytest.mark.parametrize("output_attentions", [False, True]) +def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions): + """Test that attention implementations handle output_attentions correctly.""" + from chronos.chronos2.config import Chronos2CoreConfig + from chronos.chronos2.layers import MHA + + # Create config with specified attention implementation + config = Chronos2CoreConfig( + d_model=128, + d_kv=32, + num_heads=4, + dropout_rate=0.1, + attn_implementation=attn_implementation, + ) + + # Create MHA layer + mha = MHA(config, use_rope=True) + mha.eval() + + # Create dummy inputs + batch_size = 2 + seq_len = 10 + hidden_states = torch.randn(batch_size, seq_len, config.d_model) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) + + # Test forward pass + output = mha( + hidden_states=hidden_states, + mask=mask, + position_ids=position_ids, + output_attentions=output_attentions, + ) + + # Check output shape + assert output.hidden_states.shape == (batch_size, seq_len, config.d_model) + + # Check attention weights - should only be returned when output_attentions=True + if output_attentions: + assert output.attn_weights is not None + assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len) + else: + # SDPA and 
flash_attention_2 don't return weights + if attn_implementation in ["sdpa", "flash_attention_2"]: + assert output.attn_weights is None + + +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) +def test_attention_implementations_produce_consistent_outputs(attn_implementation): + """Test that different attention implementations produce similar outputs.""" + from chronos.chronos2.config import Chronos2CoreConfig + from chronos.chronos2.layers import MHA + + # Create config with specified attention implementation + config = Chronos2CoreConfig( + d_model=128, + d_kv=32, + num_heads=4, + dropout_rate=0.0, # Disable dropout for deterministic comparison + attn_implementation=attn_implementation, + ) + + # Create MHA layer + mha = MHA(config, use_rope=True) + mha.eval() + + # Create dummy inputs + batch_size = 2 + seq_len = 10 + torch.manual_seed(42) + hidden_states = torch.randn(batch_size, seq_len, config.d_model) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) + + # Run forward pass + with torch.no_grad(): + output = mha( + hidden_states=hidden_states, + mask=mask, + position_ids=position_ids, + output_attentions=False, + ) + + # Check output is valid (not NaN or Inf) + assert not torch.isnan(output.hidden_states).any() + assert not torch.isinf(output.hidden_states).any() + + +def test_flash_attention_2_implementation(): + """Test FlashAttention2 implementation if available.""" + pytest.importorskip("flash_attn", reason="flash_attn package not installed") + + from chronos.chronos2.config import Chronos2CoreConfig + from chronos.chronos2.layers import MHA + + # Create config with flash_attention_2 + config = Chronos2CoreConfig( + d_model=128, + d_kv=32, + num_heads=4, + dropout_rate=0.0, + attn_implementation="flash_attention_2", + ) + + # Create MHA layer + mha = MHA(config, use_rope=True) + mha.eval() + + # Create dummy inputs + batch_size = 2 + seq_len = 10 + hidden_states = torch.randn(batch_size, seq_len, config.d_model) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) + + # Test forward pass + with torch.no_grad(): + output = mha( + hidden_states=hidden_states, + mask=mask, + position_ids=position_ids, + output_attentions=False, + ) + + # Check output shape and validity + assert output.hidden_states.shape == (batch_size, seq_len, config.d_model) + assert not torch.isnan(output.hidden_states).any() + assert not torch.isinf(output.hidden_states).any() + # FlashAttention doesn't return weights + assert output.attn_weights is None + + +def test_eager_and_sdpa_produce_identical_outputs(pipeline): + """Test that eager and SDPA implementations produce identical outputs on full pipeline.""" + # Reload pipeline with SDPA + model_path = Path(__file__).parent / "dummy-chronos2-model" + pipeline_sdpa = BaseChronosPipeline.from_pretrained( + model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32 + ) + + # Note: the original pipeline fixture uses default attn_implementation which should be sdpa + # Force eager for comparison + pipeline_eager = BaseChronosPipeline.from_pretrained( + model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32 + ) + + # Test with random input + torch.manual_seed(42) + inputs = torch.rand(2, 1, 16) + prediction_length = 7 + + with torch.no_grad(): + outputs_eager = pipeline_eager.predict(inputs, prediction_length=prediction_length) 
+ outputs_sdpa = pipeline_sdpa.predict(inputs, prediction_length=prediction_length) + + # Verify outputs match exactly + assert len(outputs_eager) == len(outputs_sdpa) + for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa): + # Should match exactly or very close (numerical precision) + assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) From 5bab24d1809011adb6e8e6b3a712bf52818122a8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 14:03:14 +0200 Subject: [PATCH 03/13] fix imports --- test/test_chronos2.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/test/test_chronos2.py b/test/test_chronos2.py index df7df1b8..a1def519 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -14,6 +14,9 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input +from chronos.chronos2.config import Chronos2CoreConfig +from chronos.chronos2.layers import MHA + from test.util import validate_tensor DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model" @@ -941,8 +944,6 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future @pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) def test_pipeline_works_with_different_attention_implementations(attn_implementation): """Test that the pipeline works with different attention implementations.""" - from chronos.chronos2.config import Chronos2CoreConfig - # Load the dummy model model_path = Path(__file__).parent / "dummy-chronos2-model" @@ -970,9 +971,6 @@ def test_pipeline_works_with_different_attention_implementations(attn_implementa @pytest.mark.parametrize("output_attentions", [False, True]) def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions): """Test that attention implementations handle output_attentions correctly.""" - from chronos.chronos2.config import Chronos2CoreConfig - from chronos.chronos2.layers import MHA - # Create config with specified attention implementation config = Chronos2CoreConfig( d_model=128, @@ -1017,9 +1015,6 @@ def test_attention_implementations_with_output_attentions(attn_implementation, o @pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) def test_attention_implementations_produce_consistent_outputs(attn_implementation): """Test that different attention implementations produce similar outputs.""" - from chronos.chronos2.config import Chronos2CoreConfig - from chronos.chronos2.layers import MHA - # Create config with specified attention implementation config = Chronos2CoreConfig( d_model=128, @@ -1059,9 +1054,6 @@ def test_flash_attention_2_implementation(): """Test FlashAttention2 implementation if available.""" pytest.importorskip("flash_attn", reason="flash_attn package not installed") - from chronos.chronos2.config import Chronos2CoreConfig - from chronos.chronos2.layers import MHA - # Create config with flash_attention_2 config = Chronos2CoreConfig( d_model=128, From 362513a4b8fb5df02f6906cd3c9a40166cfd027c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 15:19:41 +0000 Subject: [PATCH 04/13] added group test --- test/test_chronos2.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/test/test_chronos2.py b/test/test_chronos2.py index a1def519..cccd28ae 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1105,17 +1105,45 @@ def 
test_eager_and_sdpa_produce_identical_outputs(pipeline): model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32 ) - # Test with random input + # Test 1: Simple univariate input torch.manual_seed(42) - inputs = torch.rand(2, 1, 16) + inputs_simple = torch.rand(2, 1, 16) prediction_length = 7 with torch.no_grad(): - outputs_eager = pipeline_eager.predict(inputs, prediction_length=prediction_length) - outputs_sdpa = pipeline_sdpa.predict(inputs, prediction_length=prediction_length) + outputs_eager = pipeline_eager.predict(inputs_simple, prediction_length=prediction_length) + outputs_sdpa = pipeline_sdpa.predict(inputs_simple, prediction_length=prediction_length) # Verify outputs match exactly assert len(outputs_eager) == len(outputs_sdpa) for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa): # Should match exactly or very close (numerical precision) assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) + + # Test 2: Multivariate inputs with covariates to test group attention + np.random.seed(42) + torch.manual_seed(42) + inputs_grouped = [ + { + "target": np.random.randn(2, 36), + "past_covariates": { + "temperature": np.random.randn(36), + "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=36), + }, + "future_covariates": { + "temperature": np.random.randn(prediction_length), + "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=prediction_length), + }, + } + for _ in range(5) + ] + + with torch.no_grad(): + outputs_eager_grouped = pipeline_eager.predict(inputs_grouped, prediction_length=prediction_length) + outputs_sdpa_grouped = pipeline_sdpa.predict(inputs_grouped, prediction_length=prediction_length) + + # Verify outputs match for grouped inputs + assert len(outputs_eager_grouped) == len(outputs_sdpa_grouped) + for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped): + # Should match exactly or very close (numerical precision) + assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) From 06a1aad1be864953f267c6c8947204b6c65ead9b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 22:15:07 +0200 Subject: [PATCH 05/13] flash attention doesnt work with 2d mask --- src/chronos/chronos2/config.py | 8 +++++--- src/chronos/chronos2/layers.py | 24 +++++++++++++++++++++--- test/test_chronos2.py | 12 ++++++++---- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index 0ce511ea..0abdeb9a 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -40,7 +40,7 @@ class Chronos2CoreConfig(PretrainedConfig): rope_theta The base theta for rotary position embedding (RoPE), by default 10000.0 attn_implementation - The attention implementation to use. Options: "eager", "sdpa", "flash_attention_2", by default None (uses "sdpa") + The attention implementation to use. 
Options: "eager" or "sdpa", by default None (uses "sdpa") """ model_type = "t5" @@ -87,13 +87,15 @@ def __init__( assert not self.is_gated_act, "gated activation is not supported" # Attention implementation - default to "sdpa" if not specified - self._attn_implementation = attn_implementation or "sdpa" + attn_implementation = attn_implementation or "sdpa" # unused kwargs.pop("is_encoder_decoder", None) kwargs.pop("eos_token_id", None) - super().__init__(pad_token_id=pad_token_id, is_encoder_decoder=False, **kwargs) + super().__init__( + pad_token_id=pad_token_id, is_encoder_decoder=False, attn_implementation=attn_implementation, **kwargs + ) @dataclass diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 04408ea9..259dd60d 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -283,6 +283,22 @@ def _flash_attention_2( return attn_output, None + @staticmethod + def _mask_supports_flash_attention(mask: torch.Tensor | None) -> bool: + """Return True when the additive mask is compatible with FlashAttention-2.""" + if mask is None: + return True + if not torch.is_tensor(mask): + return mask == 0 + # FlashAttention-2 does not support low-rank/2D masks and only handles trivial additive masks. + if mask.ndim <= 2: + return False + if mask.numel() == 0: + return True + if mask.dtype == torch.bool: + return not mask.any() + return not torch.any(mask) + def forward( self, hidden_states: torch.Tensor, @@ -338,11 +354,13 @@ def unshape(states: torch.Tensor) -> torch.Tensor: cos, sin = self.rope_embed(value_states, position_ids) query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) + use_flash_attn = attn_implementation == "flash_attention_2" and self._mask_supports_flash_attention(mask) + # Dispatch to appropriate attention implementation - if attn_implementation == "sdpa": - attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) - elif attn_implementation == "flash_attention_2": + if use_flash_attn: attn_output, attn_weights = self._flash_attention_2(query_states, key_states, value_states, mask) + elif attn_implementation == "sdpa" or attn_implementation == "flash_attention_2": + attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) else: # eager or default attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) diff --git a/test/test_chronos2.py b/test/test_chronos2.py index cccd28ae..61a5e366 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1050,10 +1050,14 @@ def test_attention_implementations_produce_consistent_outputs(attn_implementatio assert not torch.isinf(output.hidden_states).any() +@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention-2 requires a CUDA-capable device") def test_flash_attention_2_implementation(): """Test FlashAttention2 implementation if available.""" pytest.importorskip("flash_attn", reason="flash_attn package not installed") + device = torch.device("cuda") + dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + # Create config with flash_attention_2 config = Chronos2CoreConfig( d_model=128, @@ -1064,15 +1068,15 @@ def test_flash_attention_2_implementation(): ) # Create MHA layer - mha = MHA(config, use_rope=True) + mha = MHA(config, use_rope=True).to(device=device, dtype=dtype) mha.eval() # Create dummy inputs batch_size = 2 seq_len = 10 - hidden_states = torch.randn(batch_size, seq_len, config.d_model) - 
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) - mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) + hidden_states = torch.randn(batch_size, seq_len, config.d_model, device=device, dtype=dtype) + position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len, device=device, dtype=torch.float32) # Test forward pass with torch.no_grad(): From 8fca666e638905937c46710c6f539f2cf30831cc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 11:51:15 +0200 Subject: [PATCH 06/13] remove flash attention implementation --- src/chronos/chronos2/layers.py | 83 +---------------------------- src/chronos/chronos2/model.py | 1 - test/test_chronos2.py | 96 ++-------------------------------- 3 files changed, 7 insertions(+), 173 deletions(-) diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 259dd60d..6613e551 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -224,80 +224,6 @@ def _sdpa_attention( return attn_output, None - def _flash_attention_2( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - mask: torch.Tensor, - ) -> tuple[torch.Tensor, None]: - """FlashAttention-2 implementation. - - Args: - query_states: [batch, n_heads, seq_len, kv_proj_dim] - key_states: [batch, n_heads, seq_len, kv_proj_dim] - value_states: [batch, n_heads, seq_len, kv_proj_dim] - mask: [batch, n_heads, q_len, kv_len] - - Returns: - attn_output: [batch, n_heads, seq_len, kv_proj_dim] - attn_weights: None (FlashAttention doesn't return weights) - """ - try: - from flash_attn import flash_attn_func - except ImportError: - raise ImportError( - "FlashAttention-2 is not installed. Please install it with: " - "pip install flash-attn --no-build-isolation" - ) - - # FlashAttention expects inputs in shape [batch, seq_len, n_heads, head_dim] - # We have [batch, n_heads, seq_len, head_dim], so we need to transpose - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # FlashAttention only supports fp16 and bf16 - input_dtype = query_states.dtype - if input_dtype not in [torch.float16, torch.bfloat16]: - target_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout_p=self.dropout if self.training else 0.0, - softmax_scale=1.0, # Match eager implementation (no scaling) - causal=False, # Chronos uses bidirectional attention by default - ) - - # Convert back to original dtype if needed - if input_dtype not in [torch.float16, torch.bfloat16]: - attn_output = attn_output.to(input_dtype) - - # Transpose back to [batch, n_heads, seq_len, head_dim] - attn_output = attn_output.transpose(1, 2) - - return attn_output, None - - @staticmethod - def _mask_supports_flash_attention(mask: torch.Tensor | None) -> bool: - """Return True when the additive mask is compatible with FlashAttention-2.""" - if mask is None: - return True - if not torch.is_tensor(mask): - return mask == 0 - # FlashAttention-2 does not support low-rank/2D masks and only handles trivial additive masks. 
- if mask.ndim <= 2: - return False - if mask.numel() == 0: - return True - if mask.dtype == torch.bool: - return not mask.any() - return not torch.any(mask) def forward( self, @@ -354,14 +280,9 @@ def unshape(states: torch.Tensor) -> torch.Tensor: cos, sin = self.rope_embed(value_states, position_ids) query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) - use_flash_attn = attn_implementation == "flash_attention_2" and self._mask_supports_flash_attention(mask) - - # Dispatch to appropriate attention implementation - if use_flash_attn: - attn_output, attn_weights = self._flash_attention_2(query_states, key_states, value_states, mask) - elif attn_implementation == "sdpa" or attn_implementation == "flash_attention_2": + if attn_implementation == "sdpa": attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) - else: # eager or default + else: # eager attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) # Project attention output diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index ff5a5c01..9b72f61a 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -200,7 +200,6 @@ class Chronos2Model(PreTrainedModel): _supports_long_horizon: bool = True _supports_future_covariates: bool = True _supports_sdpa: bool = True - _supports_flash_attn_2: bool = True def __init__(self, config: Chronos2CoreConfig): assert hasattr(config, "chronos_config"), "Not a valid Chronos config" diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 61a5e366..ebd4d9ed 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -320,13 +320,13 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs, _ = pipeline.predict(inputs, prediction_length=10) -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) def test_pipeline_predict_can_handle_different_model_and_input_dtypes( - torch_dtype: torch.dtype, input_dtype: torch.dtype + dtype: torch.dtype, input_dtype: torch.dtype ): pipeline = BaseChronosPipeline.from_pretrained( - Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=torch_dtype + Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype ) context = 10 * torch.rand(size=(4, 3, 16)) + 10 context = context.to(dtype=input_dtype) @@ -1007,94 +1007,11 @@ def test_attention_implementations_with_output_attentions(attn_implementation, o assert output.attn_weights is not None assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len) else: - # SDPA and flash_attention_2 don't return weights - if attn_implementation in ["sdpa", "flash_attention_2"]: + # SDPA doesn't return weights + if attn_implementation == "sdpa": assert output.attn_weights is None -@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) -def test_attention_implementations_produce_consistent_outputs(attn_implementation): - """Test that different attention implementations produce similar outputs.""" - # Create config with specified attention implementation - config = Chronos2CoreConfig( - d_model=128, - d_kv=32, - num_heads=4, - dropout_rate=0.0, # Disable dropout for deterministic comparison - attn_implementation=attn_implementation, - ) - - # Create MHA layer - mha = MHA(config, use_rope=True) - 
mha.eval() - - # Create dummy inputs - batch_size = 2 - seq_len = 10 - torch.manual_seed(42) - hidden_states = torch.randn(batch_size, seq_len, config.d_model) - position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) - mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) - - # Run forward pass - with torch.no_grad(): - output = mha( - hidden_states=hidden_states, - mask=mask, - position_ids=position_ids, - output_attentions=False, - ) - - # Check output is valid (not NaN or Inf) - assert not torch.isnan(output.hidden_states).any() - assert not torch.isinf(output.hidden_states).any() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention-2 requires a CUDA-capable device") -def test_flash_attention_2_implementation(): - """Test FlashAttention2 implementation if available.""" - pytest.importorskip("flash_attn", reason="flash_attn package not installed") - - device = torch.device("cuda") - dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - - # Create config with flash_attention_2 - config = Chronos2CoreConfig( - d_model=128, - d_kv=32, - num_heads=4, - dropout_rate=0.0, - attn_implementation="flash_attention_2", - ) - - # Create MHA layer - mha = MHA(config, use_rope=True).to(device=device, dtype=dtype) - mha.eval() - - # Create dummy inputs - batch_size = 2 - seq_len = 10 - hidden_states = torch.randn(batch_size, seq_len, config.d_model, device=device, dtype=dtype) - position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) - mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len, device=device, dtype=torch.float32) - - # Test forward pass - with torch.no_grad(): - output = mha( - hidden_states=hidden_states, - mask=mask, - position_ids=position_ids, - output_attentions=False, - ) - - # Check output shape and validity - assert output.hidden_states.shape == (batch_size, seq_len, config.d_model) - assert not torch.isnan(output.hidden_states).any() - assert not torch.isinf(output.hidden_states).any() - # FlashAttention doesn't return weights - assert output.attn_weights is None - - def test_eager_and_sdpa_produce_identical_outputs(pipeline): """Test that eager and SDPA implementations produce identical outputs on full pipeline.""" # Reload pipeline with SDPA @@ -1110,7 +1027,6 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline): ) # Test 1: Simple univariate input - torch.manual_seed(42) inputs_simple = torch.rand(2, 1, 16) prediction_length = 7 @@ -1125,8 +1041,6 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline): assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) # Test 2: Multivariate inputs with covariates to test group attention - np.random.seed(42) - torch.manual_seed(42) inputs_grouped = [ { "target": np.random.randn(2, 36), From 5b4a90ee52963beb0090b5254aafb62d717bfd6a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 12:03:47 +0200 Subject: [PATCH 07/13] formatting --- src/chronos/chronos2/layers.py | 1 - test/test_chronos2.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 6613e551..3a396ca4 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -224,7 +224,6 @@ def _sdpa_attention( return attn_output, None - def forward( self, hidden_states: torch.Tensor, diff --git a/test/test_chronos2.py b/test/test_chronos2.py index ebd4d9ed..3c4281cb 100644 --- a/test/test_chronos2.py +++ 
b/test/test_chronos2.py @@ -322,9 +322,7 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs, @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) -def test_pipeline_predict_can_handle_different_model_and_input_dtypes( - dtype: torch.dtype, input_dtype: torch.dtype -): +def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype): pipeline = BaseChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype ) From 23ebe353dc0a190a5925d8a62d5b6d5d94824bc4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 12:52:56 +0200 Subject: [PATCH 08/13] Update src/chronos/chronos2/config.py Co-authored-by: Oleksandr Shchur --- src/chronos/chronos2/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index 0abdeb9a..2c5728b7 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -65,7 +65,7 @@ def __init__( vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, - attn_implementation: str | None = None, + attn_implementation: Literal["eager", "sdpa"] = "sdpa", **kwargs, ): self.vocab_size = vocab_size From 70e1b10bd0ac3e717bef2f05bce00ed130c0e401 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 13:07:04 +0200 Subject: [PATCH 09/13] make sure its eager or sdpa --- src/chronos/chronos2/config.py | 5 +++-- src/chronos/chronos2/layers.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index 2c5728b7..4a1baac3 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -4,7 +4,7 @@ # Authors: Abdul Fatir Ansari from dataclasses import dataclass -from typing import List +from typing import List, Literal, Optional from transformers.configuration_utils import PretrainedConfig @@ -65,7 +65,7 @@ def __init__( vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, - attn_implementation: Literal["eager", "sdpa"] = "sdpa", + attn_implementation: Optional[Literal["eager", "sdpa"]] = "sdpa", **kwargs, ): self.vocab_size = vocab_size @@ -88,6 +88,7 @@ def __init__( # Attention implementation - default to "sdpa" if not specified attn_implementation = attn_implementation or "sdpa" + assert attn_implementation in ["eager", "sdpa"], f"attn_implementation {attn_implementation} not supported" # unused kwargs.pop("is_encoder_decoder", None) diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 3a396ca4..b00e8a8c 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -251,7 +251,7 @@ def forward( # Force eager attention if output_attentions is True (only eager returns weights) attn_implementation = self.config._attn_implementation - if output_attentions and attn_implementation != "eager": + if output_attentions: attn_implementation = "eager" seq_length = hidden_states.shape[1] From 63a8a3ce13de852003508e1424766c09bb63845d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 13:10:27 +0200 Subject: [PATCH 10/13] set default to None --- src/chronos/chronos2/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index 4a1baac3..5cf6999a 100644 --- 
a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -65,7 +65,7 @@ def __init__( vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, - attn_implementation: Optional[Literal["eager", "sdpa"]] = "sdpa", + attn_implementation: Optional[Literal["eager", "sdpa"]] = None, **kwargs, ): self.vocab_size = vocab_size From f5540d2722013b411d6fce35f428757b5f0f3fd7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 13:41:42 +0200 Subject: [PATCH 11/13] Update src/chronos/chronos2/config.py Co-authored-by: Abdul Fatir --- src/chronos/chronos2/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index 5cf6999a..c851fd80 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -4,7 +4,7 @@ # Authors: Abdul Fatir Ansari from dataclasses import dataclass -from typing import List, Literal, Optional +from typing import List, Literal from transformers.configuration_utils import PretrainedConfig From 011181a5d2a877b4ce886038199dcba7373d07f4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 13:41:51 +0200 Subject: [PATCH 12/13] Update src/chronos/chronos2/config.py Co-authored-by: Abdul Fatir --- src/chronos/chronos2/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index c851fd80..a7333b49 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -65,7 +65,7 @@ def __init__( vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, - attn_implementation: Optional[Literal["eager", "sdpa"]] = None, + attn_implementation: Literal["eager", "sdpa"] = "sdpa", **kwargs, ): self.vocab_size = vocab_size From ff2515cd46919c551af8f208abc4c6df81edeb93 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Wed, 22 Oct 2025 11:44:58 +0000 Subject: [PATCH 13/13] Remove `Optional` and use `| None` --- src/chronos/chronos2/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index a7333b49..c6e011cb 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -65,7 +65,7 @@ def __init__( vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, - attn_implementation: Literal["eager", "sdpa"] = "sdpa", + attn_implementation: Literal["eager", "sdpa"] | None = None, **kwargs, ): self.vocab_size = vocab_size
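
For reference, a minimal usage sketch distilled from the tests added in this series (the checkpoint path is a placeholder; attn_implementation falls back to "sdpa" when omitted):

    import torch

    from chronos import BaseChronosPipeline

    # Load the pipeline with an explicit attention implementation ("eager" or "sdpa").
    # The path and device are placeholders; the keyword arguments mirror the calls
    # exercised in test/test_chronos2.py within this patch series.
    pipeline = BaseChronosPipeline.from_pretrained(
        "path/to/chronos2-checkpoint",
        device_map="cpu",
        dtype=torch.float32,
        attn_implementation="sdpa",
    )

    # Two univariate series of length 16, forecast 7 steps ahead.
    inputs = torch.rand(2, 1, 16)
    forecasts = pipeline.predict(inputs, prediction_length=7)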