Update:
This problem happens even in Flax's ImageNet example. In the Inference section, add this code after computing the logits:
# Evaluate using the model trained on ImageNet, on a whole batch at once.
logits = model.apply({'params': state.params, 'batch_stats': state.batch_stats},
                     batch['image'][:128], train=False)

import numpy as onp

# Evaluate the same examples one at a time.
logits_loop = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = model.apply({'params': state.params, 'batch_stats': state.batch_stats},
                      batch['image'][i:i+1], train=False)
  logits_loop[i] = logit

onp.allclose(logits_loop, logits)
# Outputs False!
Moreover, the difference is rather big IMO:
(logits_loop - logits).mean()
# Outputs -3.496e-06!
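For a fuller picture, one can also look at the maximum absolute and relative differences rather than the mean of the signed differences (a quick sketch, reusing the logits and logits_loop arrays from above):

# Largest absolute and relative deviation between the batched and looped evaluations.
abs_diff = onp.abs(logits_loop - logits)
print(abs_diff.max())
print((abs_diff / (onp.abs(logits) + 1e-12)).max())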
Another update:
It's probably not a BatchNorm issue! I tried the same thing with Flax's MNIST example, which doesn't have BatchNorm, and the issue still persists!
logits = train.CNN().apply({'params': state.params}, test_ds['image'][:100])

import numpy as onp

logits_loop = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = train.CNN().apply({'params': state.params}, test_ds['image'][i:i+1])
  logits_loop[i] = logit

onp.allclose(logits_loop, logits)
# Outputs False!
This time, though, the difference is smaller, which is probably due to the smaller number of parameters:
(logits - logits_loop).mean()
5.602e-09
I believe this issue is not related to GPU/TPU inconsistencies, as both the loop-based and the batch-based forward passes are done on CPU (no jit, pmap or xmap involved, and I set the Colab runtime to None so that it doesn't use any GPU/TPU). So I'm not exactly sure why this difference exists!
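One quick way to double-check which backend is actually being used (a small sketch):

import jax
print(jax.default_backend())  # 'cpu' when no accelerator is attached
print(jax.devices())          # the devices JAX will dispatch to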
Note: although these small fluctuations are probably not significant in a basic {Initialize --> Train --> Test} pipeline, they matter a lot when computing gradients, Jacobians, and similar quantities (my use case: https://github.com/google/neural-tangents), where they introduce bigger problems. As an example, in my use case I'm computing the Jacobian of the network with respect to its parameters on a fixed set of data, and I'm doing so in batches because I'm using a big chunk of data, so the Jacobians don't fit in GPU memory in a single pass. The Jacobians calculated for different batches are then not consistent with each other and lead to wrong results (if you're interested in the details: the resulting NTK ends up with a minimum eigenvalue of -1.5 in some cases, whereas if everything were computed correctly it should be a positive definite matrix with eigenvalues greater than 0 -- my personal guess is that they would be at least greater than 1).
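To make the setup concrete, here is a rough sketch of the batched Jacobian computation I mean (apply_fn, params, and batch_size are placeholders here, not the actual neural-tangents code):

import jax

# apply_fn(params, x) -> outputs of shape (batch, out_dim); params is a pytree.
def batch_jacobian(apply_fn, params, x_batch):
  # Jacobian of this batch's outputs with respect to the parameters.
  return jax.jacrev(lambda p: apply_fn(p, x_batch))(params)

# Compute the Jacobian batch by batch because the full one doesn't fit in memory;
# the worry above is that these per-batch results are not consistent with a single
# full-batch evaluation, which then corrupts the NTK built from them.
def jacobian_in_batches(apply_fn, params, x, batch_size):
  return [batch_jacobian(apply_fn, params, x[i:i + batch_size])
          for i in range(0, len(x), batch_size)]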
Moreover, if this were due to some non-deterministic numerical issue, we would expect that running the same batch or loop twice would give slightly different results, but that's not the case:
logits_2 = train.CNN().apply({'params': state.params}, test_ds['image'][:100])
logits = train.CNN().apply({'params': state.params}, test_ds['image'][:100])

logits_loop = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = train.CNN().apply({'params': state.params}, test_ds['image'][i:i+1])
  logits_loop[i] = logit

logits_loop_2 = onp.zeros_like(logits)
for i in range(len(logits_loop)):
  logit = train.CNN().apply({'params': state.params}, test_ds['image'][i:i+1])
  logits_loop_2[i] = logit

print(onp.allclose(logits_2, logits))
print(onp.allclose(logits_loop, logits_loop_2))
# Outputs:
# True
# True
Thanks for your analysis!
This difference originates from XLA, and it is probably due to the fact that the batched and single-instance implementations of some of the JAX primitives used in the CNN Module are different. Below is a minimal example that only uses conv_general_dilated from JAX but still has the discrepancy you observed.
Can you please file an issue in the JAX repo containing the example below? Perhaps it is expected, but it would be good to know why this happens. Please let me know once you have filed the issue, then I can close this one.
import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))
lhs = random.normal(rng1, shape=(2, 1, 28, 28))
rhs = random.normal(rng2, shape=(32, 1, 3, 3))

logits = lax.conv_general_dilated(
    lhs=lhs,
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

logit = lax.conv_general_dilated(
    lhs=lhs[0:1],
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

print(np.allclose(logits, logit))  # outputs False
My guess is this happens in conv/dense layers indeed. Passing in a different shape will cause the additions to be done in a different order, leading to pretty big discrepancies. To make matters worse, it behaves slightly differently on each device: TPUs tend to operate in higher precision for the batch-size-1 case, and GPUs tend to do stochastic reductions, so even running the same forward pass twice gives different results. CPU is slightly better behaved but still suffers from re-ordering.
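A tiny framework-independent illustration of how summation order alone changes float32 results:

import numpy as np

x, y, z = np.float32(1e8), np.float32(1.0), np.float32(-1e8)
# The same three numbers, summed in two different orders.
print((x + y) + z)  # 0.0 -- the 1.0 is absorbed when added to 1e8 first
print((x + z) + y)  # 1.0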
I also thought about a different order of sums, but I checked PyTorch's ResNet18 and the difference between evaluating a 128-example batch at once and in a loop was on the order of 2.5e-11. There is still some difference, but it's 1e5 times smaller, which is a huge factor in my opinion. Is there any difference between JAX's CUDA usage and Torch's CUDA usage? And is there any work-around for this use case?
Thanks for filing the issue in the JAX repo! Closing this one since there doesn't seem to be anything Flax related (but feel free to re-open if you think I am wrong).
@mohamad-amin I recommend copying your last message in the JAX issue.
@marcvanzee I think there is a small problem with your example: logits and logit don't have the correct shape. Did you mean:
import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))
lhs = random.normal(rng1, shape=(128, 1, 28, 28))
rhs = random.normal(rng2, shape=(32, 1, 3, 3))

logits = lax.conv_general_dilated(
    lhs=lhs,
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = lax.conv_general_dilated(
      lhs=lhs[i:i+1],
      rhs=rhs,
      window_strides=(1, 1),
      padding="SAME")

print(np.allclose(logits_loop, logits))
(I'm not familiar with this operator, so I'm not sure if they should result in the same thing. TMI, but I'm also super sick right now, so I'd prefer not to rely on my brain at the moment.)
Oh yes, I'm sorry! Indeed, your code is correct.
@marcvanzee I just ran that code and it outputs True! 😿
Hmm, I guess one conv is not sufficient to create the difference, but if you chain them after each other it does work. I tested it on a v3-8 TPU and on a public Colab with a single GPU; both output False:
import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))
lhs = random.normal(rng1, shape=(128, 1, 28, 28))
rhs1 = random.normal(rng2, shape=(32, 1, 3, 3))
rhs2 = random.normal(rng2, shape=(32, 32, 3, 3))

def conv_general_dilated(lhs, rhs):
  return lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=(1, 1),
      padding="SAME")

def conv(x):
  x = conv_general_dilated(x, rhs1)
  for _ in range(30):
    x = conv_general_dilated(x, rhs2)
  return x

logits = conv(lhs)

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = conv(lhs[i:i+1])

print(np.allclose(logits_loop, logits))  # outputs False
Are you sure everything is alright with this code? It gives me NaNs, but I'm not sure if it's supposed to do so.
np.isnan(logits_loop).sum()
30057
I'm not sure what to expect from np.allclose when its inputs contain NaNs.
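From a quick check, np.allclose seems to treat NaNs as unequal unless equal_nan=True is passed, so a single NaN makes the whole comparison False:

import numpy as np

a = np.array([1.0, np.nan])
print(np.allclose(a, a))                  # False: NaN is not close to NaN by default
print(np.allclose(a, a, equal_nan=True))  # True: NaNs are treated as equal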
Argh you are right, it is also quite late here. I will go to bed now and investigate it in more detail tomorrow. I've re-opened this issue.
Here is a slightly simpler example that I think reproduces your issue:
import numpy as np
import jax
from jax import nn, random

init_fn = nn.initializers.lecun_normal()
lhs = random.normal(random.PRNGKey(0), (128, 256))

def conv(x):
  for i in range(10):
    rhs = init_fn(random.PRNGKey(i), (256, 256))
    x = x @ rhs
  return jax.device_get(x)

def np_conv(x):
  x = jax.device_get(x)
  for i in range(10):
    rhs = jax.device_get(init_fn(random.PRNGKey(i), (256, 256)))
    x = x @ rhs
  return x

logits = conv(lhs)
logits_np = np_conv(lhs)

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = conv(lhs[i:i+1])

print(np.allclose(logits_loop, logits))  # outputs False
The previous example uses stddev=1 weights for the kernels, which gives you infinities/NaNs if you stack a bunch of them. In this example you get a relative error of approximately 1e-6; stacking more layers will make the errors larger. I added a check against NumPy as well, which gives RMS errors of roughly 1e-6 for all pairs among logits, logits_np, and logits_loop.
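Something along these lines (a sketch, reusing logits, logits_np, and logits_loop from the example above):

import itertools

# Pairwise root-mean-square error between the three evaluations.
results = {'logits': logits, 'logits_np': logits_np, 'logits_loop': logits_loop}
for (name_a, a), (name_b, b) in itertools.combinations(results.items(), 2):
  print(name_a, 'vs', name_b, ':', np.sqrt(np.mean((a - b) ** 2)))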
Thanks! I think this is now reportable to the JAX team, right? @jheek
Yeah, this is using vanilla JAX. But I am not sure what there is to report to JAX; from this I would conclude that this is just "implementation variance".
My guess is that your use case might be very sensitive to numerical precision (not uncommon for eigenvalue decomposition). Perhaps you could try to run your code in float64 to check whether the lack of precision is the issue? I think that's the most likely explanation at this point.
I actually have been using float64 as my default config, but it doesn't help. I understand that there might be some numerical error, but 1e-6 looks way bigger than numerical error to me. It might be that I'm not that experienced with numerical computations, but np.nextafter(0, 1) gives me 5e-324, which is, let's say, somewhat smaller than 1e-6. Also, the same computation in PyTorch results in ~1e-11 numerical error, which sounds quite acceptable, considering the fact that it doesn't make the eigenvalues of the kernel that I want to compute negative. However, PyTorch doesn't provide such tools for easy Jacobian computation and parallelism integration. I'm grateful for the effort on JAX, Flax, etc., but I think it's definitely worth looking into why this happens and why there is such a huge discrepancy between PyTorch and JAX, as they use the same hardware.
The same computation in PyTorch results in ~1e-11 numerical error, which sounds quite acceptable
Is that comparing to NumPy or comparing PyTorch batched vs unbatched?
I actually have been using float64 as my default config but it doesn't help.
I tried the repro above again with float64, but in that case I get an error of 1e-15. Are you sure the code is really running in float64? Note that you need to pass a flag to JAX to enable it; just using jnp.float64 is not enough.
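For reference, enabling it looks like this (it should be done at startup, before any JAX computations):

# Either set the environment variable JAX_ENABLE_X64=True before starting Python,
# or update the config flag at the very beginning of the program:
from jax import config
config.update('jax_enable_x64', True)

import jax.numpy as jnp
print(jnp.ones(3).dtype)  # float64 once the flag is enabled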
Is that comparing to NumPy or comparing PyTorch batched vs unbatched?
Just PyTorch batched vs unbatched.
I tried the repro above again with float64, but in that case I get an error of 1e-15. Are you sure the code is really running in float64? Note that you need to pass a flag to JAX to enable it; just using jnp.float64 is not enough.
Hmm, interesting. If you look at this notebook, I've used from jax import config; config.update('jax_enable_x64', True) as the first line of the code, but it seems like the results of apply_fn or model.apply still have float32 dtype. I just checked my codebase, in which I encountered this problem too, and even though I run my stuff with the environment flag JAX_ENABLE_X64=True, the result of Flax's apply_fn = model.apply still has float32 dtype. Do you know why this is happening?
Another update:
Even the model.init function constructs some float32 parameters, which is a bit weird to me, as I have explicitly asked JAX to use x64 all the time. Is there anything that I'm missing here?
Well, now I see that at least BatchNorm and Conv have two to three parameters for controlling accuracy (dtype, param_dtype and precision), but is this working as expected? And is there no way to set these parameters globally for all Flax modules? I don't see anything about setting these parameters in the getting-started documents, so this could be a bit misleading, I guess?
Unlike NumPy, both JAX and Flax are very conservative about using double precision by default. We also don't allow setting these defaults globally, because we believe it's bad practice (it doesn't work when libraries or user code implicitly depend on the global being set to a certain value). We are changing the default dtype argument to take the input argument dtypes into account, just like NumPy.
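For per-module control, it looks roughly like this (a sketch; exact argument names and defaults vary between Flax versions):

import flax.linen as nn
import jax.numpy as jnp

# Request float64 computation and float64 parameters explicitly, per layer.
dense = nn.Dense(features=16, dtype=jnp.float64, param_dtype=jnp.float64)
conv = nn.Conv(features=32, kernel_size=(3, 3),
               dtype=jnp.float64, param_dtype=jnp.float64)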
Closing, since dtype behavior is now consistent after dropping the default float32 dtype.
Hey, thanks for the great work!
I'm using BatchNorm in my network, but I have set the use_running_average parameter of the BatchNorm layers to True, which means it will not compute any running means/stds from the input data passing through the network and will instead use the pre-computed statistics. Thus, the network's behaviour shouldn't change across different batches (ideally, I guess, but it should be true). I've provided a simple reproducible Colab notebook that reproduces the example. The Colab needs two files to run properly:
- wide_resnet_jax.py: the Python file containing the shallow WideResNet module implemented using Flax. You can download it from this gist: https://gist.github.com/mohamad-amin/5334109dba81b9c26e7b4d1ded7fd9ad
- psd_data.pkl, which can be downloaded from: https://drive.google.com/file/d/18eb93M34vaWjFyNzIq-vnOfll0T6HCjT/view?usp=sharing
psd_data.pkl is the pickled version of a dict containing three things:
- data: the train and test data used for training the model.
- params: the trained parameters of the WideResNet module that we're using, such that it achieves 1.0 train accuracy and 0.89 test accuracy.
- labels: the labels of the datapoints in data, to double-check the accuracies.
The problem that I have is that the network's behaviour varies depending on how the inputs are batched (see the check in the notebook): I expect it to output True, as I have fixed the parameters and the BatchNorm layers. Am I doing something wrong?
https://colab.research.google.com/drive/1a_SheAt9RH9tPRJ1DC60yccsbYaFssDx?usp=sharing