Different results in bf16/f32 mixed precision in conv_general_dilated in jax2tf
#25873
Labels: bug (Something isn't working)
Description
Hi,
I noticed Flax's Conv module was giving me different results between Python and C++ via TensorFlow Serving, which I could track down and isolate to the following example. The problem is most acute when I use `bfloat16` for calculations (which happens to be the type I use most). The code below calculates a simple convolution in 5 different ways, and only jax2tf with `native_serialization=True` stands out. It almost seems that `native_serialization=True` actually does the calculations in f32, because the results match all the other methods if I set `dtype` to `f32`. I am not sure this is a bug in earnest; I am just trying to understand the discrepancy between research (JAX) and production (TF Serving). In a real-life full-sized model the discrepancy is quite significant because it accumulates across layers.
It seems to be somehow related to #17151 and #17152.
System info (python version, jaxlib version, accelerator, etc.)