@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import functools
+import inspect
 import warnings
 from typing import Any, Callable, Optional, Union, overload
 
@@ -574,33 +575,28 @@ def __call__(
     m_deterministic = True
 
     # apply attention
+    attn_args = (query, key, value)
+    # This kwargs list matches the default nn.dot_product_attention.
+    # For custom `attention_fn`s, invalid kwargs will be filtered.
+    attn_kwargs = dict(
+      mask=mask,
+      dropout_rng=dropout_rng,
+      dropout_rate=self.dropout_rate,
+      broadcast_dropout=self.broadcast_dropout,
+      deterministic=m_deterministic,
+      dtype=self.dtype,
+      precision=self.precision,
+      force_fp32_for_softmax=self.force_fp32_for_softmax,
+    )
+    attn_kwargs = {
+      k: v
+      for k, v in attn_kwargs.items()
+      if k in inspect.signature(self.attention_fn).parameters
+    }
     if sow_weights:
-      x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-        module=self,
-      )  # pytype: disable=wrong-keyword-args
+      x = self.attention_fn(*attn_args, **attn_kwargs, module=self)
     else:
-      x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-      )
+      x = self.attention_fn(*attn_args, **attn_kwargs)
     # back to the original inputs dimensions
     out = DenseGeneral(
       features=features,
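
The change above makes the keyword arguments sent to `attention_fn` depend on that function's signature: any keyword the function does not declare is dropped before the call, so a custom attention function that accepts only a subset of the defaults (for example, one without `force_fp32_for_softmax`) no longer fails with an unexpected-keyword TypeError. A minimal sketch of the same filtering pattern in isolation, using a hypothetical `simple_attention` stand-in rather than the real Flax call path:

import inspect

# Hypothetical custom attention function that only understands a subset of
# the keywords that nn.dot_product_attention accepts.
def simple_attention(query, key, value, *, mask=None, dtype=None):
  del mask, dtype  # a real implementation would use these
  return query

# Mirrors the attn_kwargs dict built in the diff above (values are dummies).
attn_kwargs = dict(
    mask=None,
    dropout_rng=None,
    dropout_rate=0.0,
    broadcast_dropout=True,
    deterministic=True,
    dtype=None,
    precision=None,
    force_fp32_for_softmax=False,
)

# Keep only the keywords that simple_attention explicitly declares.
accepted = inspect.signature(simple_attention).parameters
filtered = {k: v for k, v in attn_kwargs.items() if k in accepted}

print(sorted(filtered))                       # ['dtype', 'mask']
simple_attention(1.0, 2.0, 3.0, **filtered)   # no unexpected-keyword TypeError

One consequence worth noting: a custom `attention_fn` declared with a catch-all `**kwargs` exposes no named optional parameters through `inspect.signature(...).parameters`, so under this filtering it would receive only `query`, `key`, and `value`.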