google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Inconsistent network behaviour when using different batch sizes for `model.apply` on CPU #1755

Closed mohamad-amin closed 2 years ago

mohamad-amin commented 2 years ago

Hey, thanks for the great work!

I'm using BatchNorm in my network, but I have set the use_running_average parameter of the BatchNorm layers to True. This means the layers don't compute running means/stds from the data passing through the network and instead use the precomputed statistics. Thus, the network's behaviour should not change across different batches (ideally, at least).
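To illustrate why fixed statistics should make BatchNorm batch-independent, here is a minimal plain-JAX sketch of inference-mode batch normalization. It is not Flax's implementation; the function name `batchnorm_inference` and the random statistics are hypothetical. Each example is normalized with the stored mean and variance, so no information is shared across the batch.

```python
import jax
import jax.numpy as jnp
import numpy as np


def batchnorm_inference(x, mean, var, scale, bias, eps=1e-5):
  """Normalizes x of shape (batch, features) with fixed (running) statistics.

  Every operation is elementwise per feature, so row i of the output
  depends only on row i of the input.
  """
  return (x - mean) / jnp.sqrt(var + eps) * scale + bias


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (8, 4))
mean = jax.random.normal(k2, (4,))
var = jnp.abs(jax.random.normal(k3, (4,))) + 0.5  # keep variances positive
scale = jnp.ones((4,))
bias = jnp.zeros((4,))

# Whole batch at once vs. one example at a time.
batched = batchnorm_inference(x, mean, var, scale, bias)
per_example = jnp.concatenate(
    [batchnorm_inference(x[i:i + 1], mean, var, scale, bias) for i in range(8)])
print(np.allclose(batched, per_example))  # True
```

Since this step cannot mix examples, a batched-vs-looped mismatch in such a model has to come from other layers.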

I've put together a minimal Colab notebook that reproduces the problem. It needs two files to run properly, which are:

psd_data.pkl is the pickled version of a dict containing three things:

The problem that I have is:

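import numpy as np
import jax.numpy as jnp

# Forward pass one example at a time.
ys = []
for i in range(10):
  ys.append(apply_fn(params, X_train[i:i+1]))
ys = jnp.stack(ys).squeeze()

# Forward pass over the same 10 examples as one batch.
vs = apply_fn(params, X_train[:10])
np.allclose(ys, vs)
# Outputs False!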

which shows that the network's output depends on the batch size. I expect this to output True, since I have fixed the parameters and the BatchNorm statistics. Am I doing something wrong?

https://colab.research.google.com/drive/1a_SheAt9RH9tPRJ1DC60yccsbYaFssDx?usp=sharing

mohamad-amin commented 2 years ago

Update:

This problem happens even in Flax's ImageNet example. In the Inference section, add this code after computing the logits:

import numpy as onp

# Evaluate using model trained on imagenet: one batch of 128 images.
logits = model.apply({'params': state.params, 'batch_stats': state.batch_stats}, batch['image'][:128], train=False)

# Evaluate the same 128 images one at a time.
logits_loop = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = model.apply({'params': state.params, 'batch_stats': state.batch_stats}, batch['image'][i:i+1], train=False)
  logits_loop[i] = logit[0]  # drop the size-1 batch dimension

onp.allclose(logits_loop, logits)
# Outputs False!

Moreover, the difference is rather big IMO:

(logits_loop - logits).mean()
# Outputs -3.496e-06!
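A side note on the metric: the mean of signed differences can understate a discrepancy because positive and negative errors cancel. A minimal NumPy sketch (the helper name `compare` is hypothetical) that also reports the maximum absolute and relative differences:

```python
import numpy as np


def compare(a, b):
  """Returns (mean signed diff, max abs diff, max relative diff) of a vs. b."""
  a = np.asarray(a, dtype=np.float64)
  b = np.asarray(b, dtype=np.float64)
  diff = a - b
  max_abs = np.max(np.abs(diff))
  # Guard against division by zero where b is exactly 0.
  max_rel = np.max(np.abs(diff) / np.maximum(np.abs(b), np.finfo(np.float64).tiny))
  return diff.mean(), max_abs, max_rel


# Errors of +1e-3 and -1e-3 cancel in the mean but not in the max.
a = np.array([1.0 + 1e-3, 2.0 - 1e-3])
b = np.array([1.0, 2.0])
mean_diff, max_abs, max_rel = compare(a, b)
print(mean_diff, max_abs, max_rel)
```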
mohamad-amin commented 2 years ago

Another update:

It's probably not a BatchNorm issue! I tried the same thing with Flax's MNIST example, which doesn't have BatchNorm, and the issue still persists!

import numpy as onp

# Batched forward pass over 100 test images.
logits = train.CNN().apply({'params': state.params}, test_ds['image'][:100])

# The same 100 images, one at a time.
logits_loop = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = train.CNN().apply({'params': state.params}, test_ds['image'][i:i+1])
  logits_loop[i] = logit[0]  # drop the size-1 batch dimension

onp.allclose(logits_loop, logits)
# Outputs False!

This time the difference is smaller, which I think is probably due to the smaller number of parameters:

(logits - logits_loop).mean()
# Outputs 5.602e-09

I believe this issue is not related to GPU/TPU inconsistencies, as both the loop-based and the batch-based forward passes run on CPU (no jit, pmap or xmap involved, and I set the Colab runtime type to None so that it doesn't use a GPU/TPU). So I'm not sure why this difference exists!

Note: Although these small fluctuations are probably not significant in a basic {Initialize --> Train --> Test} pipeline, they matter a lot when computing gradients, Jacobians and similar quantities (my use case: https://github.com/google/neural-tangents), where they cause bigger problems. In my use case, I compute the Jacobian of the network with respect to its parameters on a fixed dataset. The dataset is large, so the Jacobian does not fit on the GPU in a single pass and I compute it in batches. The Jacobians of different batches are therefore inconsistent and lead to wrong results: in some cases the resulting NTK has a minimum eigenvalue of -1.5, whereas a correct computation should give a positive definite matrix with eigenvalues greater than 0 (my personal guess is that they would be at least 1).
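For readers unfamiliar with this use case, here is a minimal pure-JAX sketch of computing the parameter Jacobian in batches and forming the empirical NTK as J Jᵀ. It assumes a hypothetical two-layer MLP; `init_params`, `mlp` and `flat_jacobian` are illustrative names, not neural-tangents APIs. In exact arithmetic this matrix is symmetric positive semi-definite, which is why negative eigenvalues signal numerical trouble.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_params(key, d_in=4, d_hidden=16):
  """Parameters of a hypothetical two-layer MLP with a scalar output."""
  k1, k2 = jax.random.split(key)
  return {
      "w1": jax.random.normal(k1, (d_in, d_hidden)) / jnp.sqrt(d_in),
      "w2": jax.random.normal(k2, (d_hidden, 1)) / jnp.sqrt(d_hidden),
  }


def mlp(params, x):
  """Forward pass; returns outputs of shape (batch, 1)."""
  return jnp.tanh(x @ params["w1"]) @ params["w2"]


def flat_jacobian(params, x):
  """Jacobian of the outputs w.r.t. all parameters as a (batch, n_params) matrix."""
  flat, unravel = ravel_pytree(params)
  return jax.jacrev(lambda p: mlp(unravel(p), x).reshape(-1))(flat)


params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (32, 4))

# Build J from 4 batches of 8 examples, as one would when the full J does not fit in memory.
jac = jnp.concatenate([flat_jacobian(params, x[i:i + 8]) for i in range(0, 32, 8)])
ntk = jac @ jac.T  # empirical NTK, shape (32, 32)

# In exact arithmetic the NTK is symmetric positive semi-definite.
eigs = jnp.linalg.eigvalsh(ntk)
print(eigs.min(), eigs.max())
```

Comparing `jac` with `flat_jacobian(params, x)` computed in a single call is a direct way to measure the batch-size dependence discussed in this thread.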

Moreover, if this were purely a numerical issue, I would expect running the same batch or loop twice to give slightly different results, but that's not the case:

import numpy as onp

# Two batched forward passes over the same 100 images.
logits_2 = train.CNN().apply({'params': state.params}, test_ds['image'][:100])
logits = train.CNN().apply({'params': state.params}, test_ds['image'][:100])

# Two per-example passes over the same images.
logits_loop = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = train.CNN().apply({'params': state.params}, test_ds['image'][i:i+1])
  logits_loop[i] = logit[0]

logits_loop_2 = onp.zeros_like(logits)
for i in range(len(logits_loop_2)):
  logit = train.CNN().apply({'params': state.params}, test_ds['image'][i:i+1])
  logits_loop_2[i] = logit[0]

print(onp.allclose(logits_2, logits))            # Outputs True
print(onp.allclose(logits_loop, logits_loop_2))  # Outputs True
marcvanzee commented 2 years ago

Thanks for your analysis!

This difference originates in XLA, probably because the batched and single-example implementations of some of the JAX primitives used in the CNN module differ. Below is a minimal example that uses only conv_general_dilated from JAX but still shows the discrepancy you observed.

Can you please file an issue in the JAX repo containing the example below? Perhaps this is expected, but it would be good to know why it happens. Please let me know when you have filed the issue, and then I can close this one.

import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))

lhs = random.normal(rng1, shape=((2, 1, 28, 28)))
rhs = random.normal(rng2, shape=((32, 1, 3, 3)))

logits = lax.conv_general_dilated(
    lhs=lhs,
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

logit = lax.conv_general_dilated(
    lhs=lhs[0:1],
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

print(np.allclose(logits, logit))  # outputs False
jheek commented 2 years ago

My guess is that this indeed happens in the conv/dense layers. Passing in a different shape causes the additions to be done in a different order, leading to pretty big discrepancies. To make matters worse, it behaves slightly differently on each device:

TPUs tend to operate in higher precision for the batch-size-1 case. GPUs tend to do stochastic reductions, so even running the same forward pass twice gives different results. CPU is slightly better behaved but still suffers from re-ordering.
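The reordering effect can be seen with just three float32 numbers: floating-point addition is not associative, so summing the same values in a different order can change the result. A minimal NumPy sketch:

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

left = (a + b) + c   # 0 + 1 = 1
right = a + (b + c)  # b + c rounds back to -1e8 (float32 spacing near 1e8 is 8), so the sum is 0
print(left, right)   # 1.0 0.0
```

A matrix multiply or convolution performs many such sums, and which order is used presumably depends on how the kernel splits up the work for a given input shape, so results can differ in the last bits between batch sizes.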

mohamad-amin commented 2 years ago

I also thought about a different order of summation, but I checked PyTorch's ResNet18, and the difference between evaluating a batch of 128 at once and evaluating it in a loop was on the order of 2.5e-11. There is still some difference, but it's about 1e5 times smaller, which is a huge gap in my opinion. Is there any difference between how JAX and PyTorch use CUDA? And is there any workaround for this use case?

marcvanzee commented 2 years ago

Thanks for filing the issue in the JAX repo! Closing this one since there doesn't seem to be anything Flax related (but feel free to re-open if you think I am wrong).

marcvanzee commented 2 years ago

@mohamad-amin I recommend copying your last message in the JAX issue.

mohamad-amin commented 2 years ago

@marcvanzee

I think there is a small problem with your example. logits and logit don't have the correct shape. Did you mean:

import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))

lhs = random.normal(rng1, shape=((128, 1, 28, 28)))
rhs = random.normal(rng2, shape=((32, 1, 3, 3)))

logits = lax.conv_general_dilated(
    lhs=lhs,
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = lax.conv_general_dilated(
    lhs=lhs[i:i+1],
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

print(np.allclose(logits_loop, logits))

(I'm not familiar with this operator, so I'm not sure if they should result in the same thing. TMI but I'm also super sick rn so I prefer to not rely on my brain at the moment.)

marcvanzee commented 2 years ago

Oh yes, I'm sorry! Indeed, your code is correct.

mohamad-amin commented 2 years ago

@marcvanzee I just ran that code and it outputs True! 😿

marcvanzee commented 2 years ago

Hmm, I guess one conv is not enough to produce the difference. If you chain several of them, it does show up. I tested it on a v3-8 TPU and on a public Colab with a single GPU, and both output False:

import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))

lhs = random.normal(rng1, shape=((128, 1, 28, 28)))
rhs1 = random.normal(rng2, shape=((32, 1, 3, 3)))
rhs2 = random.normal(rng2, shape=((32, 32, 3, 3)))

def conv_general_dilated(lhs, rhs):
  return lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=(1, 1),
      padding="SAME")

def conv(x):
  x = conv_general_dilated(x, rhs1)
  for _ in range(30):
    x = conv_general_dilated(x, rhs2)
  return x

logits = conv(lhs)

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = conv(lhs[i:i+1])

print(np.allclose(logits_loop, logits))  # outputs False
mohamad-amin commented 2 years ago

Are you sure everything is alright with this code? It gives me NaNs, but I'm not sure if it's supposed to do so.

np.isnan(logits_loop).sum()
# Outputs 30057

I'm not sure what to expect from np.allclose when its inputs contain NaNs.
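On the NaN question: `np.allclose` treats NaN as unequal to everything, including another NaN, unless `equal_nan=True` is passed. Arrays containing NaNs therefore compare as not close by default, so a `False` from such arrays could come from the NaNs alone. A quick check:

```python
import numpy as np

a = np.array([1.0, np.nan])
b = np.array([1.0, np.nan])

print(np.allclose(a, b))                  # False: NaN != NaN by default
print(np.allclose(a, b, equal_nan=True))  # True: NaNs in matching positions count as equal
print(np.isnan(a).sum())                  # 1
```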

marcvanzee commented 2 years ago

Argh you are right, it is also quite late here. I will go to bed now and investigate it in more detail tomorrow. I've re-opened this issue.

jheek commented 2 years ago

Here is a slightly simpler example that I think reproduces your issue:

import numpy as np
from jax import lax, random, jit
from jax import nn
import jax

init_fn = nn.initializers.lecun_normal()
lhs = random.normal(random.PRNGKey(0), (128, 256))

def conv(x):
  for i in range(10):
    rhs = init_fn(random.PRNGKey(i), (256, 256))
    x = x @ rhs
  return jax.device_get(x)

def np_conv(x):
  x = jax.device_get(x)
  for i in range(10):
    rhs = jax.device_get(init_fn(random.PRNGKey(i), (256, 256)))
    x = x @ rhs
  return x

logits = conv(lhs)
logits_np = np_conv(lhs)

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = conv(lhs[i:i+1])

print(np.allclose(logits_loop, logits))  # outputs False

The previous example uses stddev=1 weights for the kernels, which gives infinities/NaNs once you stack enough of them. In this example you get a relative error of approximately 10^-6; stacking more layers makes the errors larger.

I also added a check against NumPy, which gives RMS errors of roughly 10^-6 for all pairs among logits, logits_np, and logits_loop:
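A self-contained variant of this comparison that prints the relative RMS error for each pair might look like the sketch below. The names `matmul_chain`, `np_matmul_chain` and `rel_rms` are hypothetical, and the exact error values will depend on hardware and library versions.

```python
import itertools

import jax
import numpy as np
from jax import random

init_fn = jax.nn.initializers.lecun_normal()
kernels = [init_fn(random.PRNGKey(i), (256, 256)) for i in range(10)]
lhs = random.normal(random.PRNGKey(0), (128, 256))


def matmul_chain(x):
  """Applies 10 stacked 256x256 matmuls in JAX and returns a NumPy array."""
  for rhs in kernels:
    x = x @ rhs
  return np.asarray(jax.device_get(x))


def np_matmul_chain(x):
  """The same chain computed by NumPy in float32."""
  x = np.asarray(x)
  for rhs in kernels:
    x = x @ np.asarray(rhs)
  return x


def rel_rms(a, b):
  """Root-mean-square difference, relative to the RMS of b."""
  return np.sqrt(np.mean((a - b) ** 2)) / np.sqrt(np.mean(b ** 2))


results = {
    "logits": matmul_chain(lhs),
    "logits_np": np_matmul_chain(lhs),
    "logits_loop": np.concatenate([matmul_chain(lhs[i:i + 1]) for i in range(128)]),
}
for (name_a, a), (name_b, b) in itertools.combinations(results.items(), 2):
  print(f"{name_a} vs {name_b}: {rel_rms(a, b):.2e}")
```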


mohamad-amin commented 2 years ago

Thanks! I think this is now reportable to the JAX team, right? @jheek

jheek commented 2 years ago

Yeah, this is using vanilla JAX. But I am not sure what there is to report to JAX. From this I would conclude that this is just "implementation variance".

My guess is that your use case is very sensitive to numerical precision (not uncommon for eigenvalue decompositions). Perhaps you could try running your code in float64 to check whether the lack of precision is the issue? I think that is the most likely explanation at this point.

mohamad-amin commented 2 years ago

I actually have been using float64 as my default config, but it doesn't help. I understand that there might be some numerical error, but 1e-6 looks way bigger than numerical error to me. It might be that I'm not that experienced with numerical computation, but np.nextafter(0, 1) gives me 5e-324, which is, let's say, somewhat smaller than 1e-6. Also, the same computation in PyTorch results in ~1e-11 numerical error, which sounds quite acceptable, considering that it doesn't make the eigenvalues of the kernel I want to compute negative. However, PyTorch doesn't provide tools like these for easy Jacobians and parallelism integration. I'm grateful for the effort on JAX, Flax, etc., but I think it's definitely worth looking into why this happens and why there is such a huge discrepancy between PyTorch and JAX, as they use the same hardware.
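For scale: `np.nextafter(0, 1)` is the smallest positive (subnormal) float64, while rounding error in arithmetic is governed by machine epsilon, the relative spacing of floats near 1. For float32 that spacing is about 1.2e-7, so a relative error around 1e-6 after several stacked layers corresponds to only a handful of units of epsilon. The values can be checked directly:

```python
import numpy as np

# Smallest positive (subnormal) float64: not a measure of rounding error.
print(np.nextafter(0.0, 1.0))    # 5e-324

# Machine epsilon: the gap between 1.0 and the next representable float.
print(np.finfo(np.float32).eps)  # about 1.19e-07
print(np.finfo(np.float64).eps)  # about 2.22e-16

# A 1e-6 relative error in float32 is roughly 8 units of epsilon.
print(1e-6 / np.finfo(np.float32).eps)
```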

jheek commented 2 years ago

"the same computation in PyTorch results in ~1e-11 numerical error, which sounds quite acceptable"

Is that comparing to NumPy or comparing PyTorch batched vs unbatched?

"I actually have been using float64 as my default config, but it doesn't help."

I tried the repro above again with float64, but in that case I get an error of 1e-15. Are you sure the code is really running in float64? Note that you need to pass a flag to JAX to enable it; just using jnp.float64 is not enough.
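A minimal sketch of the flag in question, assuming a fresh Python process: `jax.config.update("jax_enable_x64", True)` (or the `JAX_ENABLE_X64` environment variable) should run before any arrays are created. Without it, a request for `jnp.float64` falls back to float32 with a warning.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3, dtype=jnp.float64)
print(x.dtype)            # float64
print(jnp.ones(3).dtype)  # the default float dtype is also float64 once x64 is enabled
print((x @ x).dtype)      # float64
```

Note that this only changes JAX's defaults; library code that hard-codes float32 (as discussed below for Flax modules) still produces float32 values.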

mohamad-amin commented 2 years ago

"Is that comparing to NumPy or comparing PyTorch batched vs unbatched?"

Just PyTorch batched vs unbatched.

"I tried the repro above again with float64, but in that case I get an error of 1e-15. Are you sure the code is really running in float64? Note that you need to pass a flag to JAX to enable it; just using jnp.float64 is not enough."

Hmm, interesting. If you look at this notebook, I've used from jax import config; config.update('jax_enable_x64', True) as the first line of code, but the results of apply_fn or model.apply still seem to have float32 dtype. I also checked the codebase where I originally hit this problem, and even though I run it with the environment variable JAX_ENABLE_X64=True, the result of Flax's apply_fn = model.apply still has float32 dtype. Do you know why this is happening?

More info: [screenshot, 2022-01-07 11:30:35 AM]

Update: [screenshot, 2022-01-07 12:19:13 PM]

Another update: even the model.init function creates some float32 parameters, which is a bit surprising to me, as I have explicitly asked JAX to use x64 everywhere.

Is there anything that I'm missing here?

mohamad-amin commented 2 years ago

Well, now I see that at least BatchNorm and Conv have two or three parameters for controlling precision (dtype, param_dtype and precision), but is this working as expected? And is there no way to set these parameters globally for all Flax modules? I don't see anything about setting them in the getting-started documentation, so this could be a bit misleading?

jheek commented 2 years ago

Unlike NumPy, both JAX and Flax are very conservative about using double precision by default. We also don't allow setting these defaults globally because we believe it's bad practice (it breaks when libraries or user code implicitly depend on the global being set to a certain value). We are changing the default dtype argument to take input argument dtypes into account, just like NumPy.

jheek commented 2 years ago

Closing, because dtype behavior is now consistent since we dropped the default float32 dtype.