Skip to content

Commit

Permalink
feat: pass precedence to plum dispatcher (#24)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Aug 1, 2024
1 parent ef28ceb commit 4aea470
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_rules: dict[core.Primitive, plum.Function] = {}


def register(primitive: core.Primitive) -> Callable[[CT], CT]:
def register(primitive: core.Primitive, *, precedence: int = 0) -> Callable[[CT], CT]:
"""Registers a multiple dispatch implementation for this JAX primitive.
!!! Example
Expand All @@ -50,6 +50,9 @@ def _(x: SomeValue, y: SomeValue):
- `primitive`: The `jax.core.Primitive` to provide a multiple dispatch
implementation for.
- `precedence`: The precedence of this rule.
See `plum.Dispatcher.dispatch` for details.
**Returns:**
A decorator for registering a multiple dispatch rule with the specified primitive.
Expand All @@ -68,7 +71,7 @@ def existing_rule():
existing_rule = plum.Dispatcher().abstract(existing_rule)

_rules[primitive] = existing_rule
existing_rule.dispatch(rule)
existing_rule.dispatch(rule, precedence=precedence)
return rule

return _register
Expand Down

0 comments on commit 4aea470

Please sign in to comment.