Add a sharding rule for reduce_precision_p and properly thread eqn.ctx in loops.py where we call pe.new_jaxpr_eqn
#2018