Skip to content

Commit

Permalink
Added sharding parameter: fix for JAX 0.4.31 (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger authored Aug 1, 2024
1 parent dcbadb6 commit ef28ceb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions quax/examples/zero/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ def _(value: Zero, *, broadcast_dimensions, shape) -> Zero:


@quax.register(lax.convert_element_type_p)
def _(value: Zero, *, new_dtype, weak_type) -> Zero:
del weak_type
def _(value: Zero, *, new_dtype, weak_type, sharding=None) -> Zero:
# sharding was added around JAX 0.4.31, it seems.
del weak_type, sharding
return Zero(value.shape, new_dtype)


Expand Down

0 comments on commit ef28ceb

Please sign in to comment.