Different results in bf16/f32 mixed precision in conv_general_dilated in jax2tf #25873

bondquant opened this issue Jan 14, 2025 · 0 comments
bondquant commented Jan 14, 2025

Description

Hi,

I noticed that Flax's Conv module was giving me different results between Python and C++ (via TensorFlow Serving), which I was able to track down and isolate to the example below. The problem is most acute when I use bfloat16 for the calculations (which happens to be the dtype I use most). The code below computes a simple convolution in five different ways, and only jax2tf with native_serialization=True stands out. It almost looks as if native_serialization=True actually does the calculation in f32, because its results match all of the other methods when I set dtype to f32.

I am not sure this is a bug in earnest; I am just trying to understand the discrepancy between research (JAX) and production (TF Serving). In a real-life, full-sized model the discrepancy is quite significant because it accumulates across layers.

It seems to be somehow related to #17151 and #17152.
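
For context on why a discrepancy can show up at all, here is a minimal standalone sketch (with illustrative values of my own, not taken from the repro below) of the effect I suspect: the same product comes out slightly differently depending on whether the multiply is rounded to bf16 or carried out in f32, even when the inputs have already been cast to bf16.

# Standalone illustration (illustrative values, not from the repro below):
# the same product of two bf16 inputs differs depending on whether the
# multiply itself is rounded to bf16 or performed in f32.
import jax.numpy as jnp

a = jnp.asarray(0.8453059, dtype=jnp.float32)
b = jnp.asarray(-2.7, dtype=jnp.float32)
a16 = a.astype(jnp.bfloat16)
b16 = b.astype(jnp.bfloat16)

prod_bf16 = (a16 * b16).astype(jnp.float32)                    # product rounded to bf16
prod_f32 = a16.astype(jnp.float32) * b16.astype(jnp.float32)   # same bf16 inputs, f32 multiply
print(prod_bf16, prod_f32)  # the two values typically differ in the low-order bits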

import jax
from jax import lax, random as jrnd, numpy as jnp
from flax import linen as nn
from jax.experimental import jax2tf

dtype = jnp.bfloat16
# dtype = jnp.float16  # diff is smaller
# dtype = jnp.float32  # all methods agree

class Foo(nn.Module):
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, h):
        residual = h
        y = nn.Conv(2, kernel_size=(1,)
                     , strides=1
                     , padding=0
                     , feature_group_count=2
                     , use_bias=False
                     , dtype=self.dtype)(h) 
        return y + residual

layer = Foo(dtype=dtype)
model_vars = {'params': {'Conv_0': {'kernel': jnp.asarray([[[0.25436434, 0.8453059 ]]])}}}
x = jrnd.normal(jrnd.key(1729), (2,))

from jax.lax import ConvDimensionNumbers

# Compute the same width-1 grouped convolution directly with
# lax.conv_general_dilated, casting both operands to the test dtype first.
def jaxpr_fn(kernel, x):
    kernel_bf16 = kernel.astype(dtype)
    x_bf16 = x.astype(dtype)
    dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 2, 1)
                                           , rhs_spec=(2, 1, 0)
                                           , out_spec=(0, 2, 1))
    g = lax.conv_general_dilated(kernel_bf16, x_bf16
                                            , batch_group_count=1
                                            , dimension_numbers=dimension_numbers
                                            , feature_group_count=2
                                            , lhs_dilation=(1,)
                                            , padding=((0, 0),)
                                            , precision=None
                                            , preferred_element_type=None
                                            , rhs_dilation=(1,)
                                            , window_strides=(1,))
    return g + x

y_flax = layer.apply(model_vars, x.reshape(1, 1, -1))
print(y_flax.ravel())
# prints [-1.3705847 -2.27708  ]

y_native = jax2tf.convert(layer.apply
                          , native_serialization=True)(model_vars
                                                       , x.reshape(1, 1, -1))
y_native = jnp.asarray(y_native, dtype=jnp.float32)
print(y_native.ravel())
# prints [-1.3713225 -2.2845213], which is the same as other methods give with f32

y_nonnative = jax2tf.convert(layer.apply
                             , native_serialization=False)(model_vars, x.reshape(1, 1, -1))
y_nonnative = jnp.asarray(y_nonnative, dtype=jnp.float32)
print(y_nonnative.ravel())
# prints [-1.3705847 -2.27708  ]

y_direct = (model_vars['params']['Conv_0']['kernel'][0][0].astype(dtype) * x.astype(dtype)).astype(jnp.float32) + x
print(y_direct.ravel())
# prints [-1.3705847 -2.27708  ]

y_jaxpr = jaxpr_fn(model_vars['params']['Conv_0']['kernel'], x.reshape(1, 1, -1))
print(y_jaxpr.ravel())
# prints [-1.3705847 -2.27708  ]

diff = jnp.max(jnp.abs(y_nonnative - y_native))
print(diff)
# prints 0.0074412823
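
As a further diagnostic, one can dump the StableHLO that JAX lowers layer.apply to; as far as I understand, native_serialization=True embeds a serialized form of this module in the TF graph. This is only a sketch reusing layer, model_vars and x from above: it shows the element types on the stablehlo.convolution op, but any internal upcasting done later by the backend compiler would not be visible here.

# Diagnostic sketch: inspect the StableHLO module that JAX lowers the function to.
lowered = jax.jit(layer.apply).lower(model_vars, x.reshape(1, 1, -1))
print(lowered.as_text())  # check the operand/result element types on stablehlo.convolution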

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  1.26.4
python: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='2cf1ec796e3e', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')