
Commit 9e6b819

IvyZX authored and Flax Authors committed
Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful.
Fixes #4008

PiperOrigin-RevId: 646679331
1 parent 3b21870 commit 9e6b819

File tree

2 files changed: +22 -29 lines changed


flax/linen/attention.py (+21 -25)
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import functools
+import inspect
 import warnings
 from typing import Any, Callable, Optional, Union, overload
 
@@ -574,33 +575,28 @@ def __call__(
       m_deterministic = True
 
     # apply attention
+    attn_args = (query, key, value)
+    # This kwargs list match the default nn.dot_product_attention.
+    # For custom `attention_fn`s, invalid kwargs will be filtered.
+    attn_kwargs = dict(
+      mask=mask,
+      dropout_rng=dropout_rng,
+      dropout_rate=self.dropout_rate,
+      broadcast_dropout=self.broadcast_dropout,
+      deterministic=m_deterministic,
+      dtype=self.dtype,
+      precision=self.precision,
+      force_fp32_for_softmax=self.force_fp32_for_softmax,
+    )
+    attn_kwargs = {
+      k: v
+      for k, v in attn_kwargs.items()
+      if k in inspect.signature(self.attention_fn).parameters
+    }
     if sow_weights:
-      x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-        module=self,
-      )  # pytype: disable=wrong-keyword-args
+      x = self.attention_fn(*attn_args, **attn_kwargs, module=self)
     else:
-      x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-      )
+      x = self.attention_fn(*attn_args, **attn_kwargs)
     # back to the original inputs dimensions
     out = DenseGeneral(
       features=features,
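
The kwargs filtering above relies on inspect.signature, so a keyword argument is only forwarded when the attention_fn actually declares it. Below is a minimal standalone sketch of the same pattern; my_attention_fn is a hypothetical custom function and not part of this commit.

import inspect


def my_attention_fn(query, key, value, mask=None, dtype=None):
  """Hypothetical custom attention_fn that only declares a few kwargs."""
  del query, key, value, mask, dtype  # body omitted in this sketch


# Kwargs mirroring the defaults that MultiHeadDotProductAttention forwards.
attn_kwargs = dict(
  mask=None,
  dropout_rng=None,
  dropout_rate=0.0,
  broadcast_dropout=True,
  deterministic=True,
  dtype=None,
  precision=None,
  force_fp32_for_softmax=False,
)

# Keep only the kwargs that my_attention_fn actually declares; the rest
# (e.g. dropout_rate) are filtered out instead of raising a TypeError.
accepted = inspect.signature(my_attention_fn).parameters
filtered = {k: v for k, v in attn_kwargs.items() if k in accepted}
assert set(filtered) == {'mask', 'dtype'}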

tests/linen/linen_attention_test.py (+1 -4)
@@ -14,7 +14,6 @@
 
 """Tests for flax.linen.attention."""
 
-import functools
 from absl.testing import absltest, parameterized
 from flax import errors, jax_utils
 from flax import linen as nn
@@ -565,9 +564,7 @@ def test_mixed_precision_multihead_attention(
       qkv_features=4,
       kernel_init=initializers.lecun_normal(),
       bias_init=initializers.uniform(),
-      attention_fn=functools.partial(
-        nn.dot_product_attention, force_fp32_for_softmax=force_fp32
-      ),
+      force_fp32_for_softmax=force_fp32,
       deterministic=False,
       dtype=jnp.bfloat16,
     )
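
As the updated test shows, the flag can now be passed straight to the module instead of wrapping nn.dot_product_attention with functools.partial. A minimal usage sketch, assuming a Flax version that includes this commit; shapes and hyperparameters are illustrative:

import jax
import jax.numpy as jnp
from flax import linen as nn

mha = nn.MultiHeadDotProductAttention(
  num_heads=2,
  qkv_features=4,
  dtype=jnp.bfloat16,
  force_fp32_for_softmax=True,  # softmax inside attention runs in float32
  deterministic=True,
)
x = jnp.ones((1, 8, 4), dtype=jnp.bfloat16)  # (batch, seq_len, features)
variables = mha.init(jax.random.PRNGKey(0), x)
y = mha.apply(variables, x)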
