Skip to content

Commit

Permalink
Document mask convention
Browse files Browse the repository at this point in the history
  • Loading branch information
colehaus authored and patrick-kidger committed Aug 28, 2023
1 parent e771ec8 commit 6fe42a5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion equinox/nn/_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ def __call__(
`(kv_seq_length, value_size)`.
- `mask`: Optional mask preventing attention to certain positions. Should either
be a JAX array of shape `(query_seq_length, kv_seq_length)`, or (for custom
per-head masking) `(num_heads, query_seq_length, kv_seq_length)`.
per-head masking) `(num_heads, query_seq_length, kv_seq_length)`. A value of
`False` at a position indicates that position should be ignored.
- `key`: A `jax.random.PRNGKey` used for dropout. Unused if `dropout = 0`.
(Keyword only argument.)
- `inference`: As [`equinox.nn.Dropout.__call__`][]. (Keyword only
Expand Down

0 comments on commit 6fe42a5

Please sign in to comment.