Add a sharding rule for reduce_precision_p and properly thread eqn.ctx in loops.py where new eqns are created with pe.new_jaxpr_eqn #2018