machine-discovery / deer

Parallelizing non-linear sequential models over the sequence length
BSD 3-Clause "New" or "Revised" License

Difficulty replicating Figure 3 #16

Closed: xaviergonzalez closed this issue 5 months ago

xaviergonzalez commented 6 months ago

Hi! I am trying to replicate Figure 3 from your paper but have been running into difficulties.

If I run the commands below in a Google Colab instance (with both A100 and V100 GPUs), I am for some reason getting errors with magnitudes around 10^6, in contrast with your Figure 3.

!git clone https://github.com/machine-discovery/deer.git
! cd deer; python setup.py install
! cd deer/experiments/02_compare_outputs; python main.py

(screenshot of the script output showing errors on the order of 10^6)

On the other hand, if I run your recommended installation code on either Colab or the compute cluster I have access to,

pip install --upgrade -e .
pip install --upgrade jax==0.4.11 jaxlib==0.4.11+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install nvidia-cudnn-cu11==8.6.0.163

I get CUDA errors like: "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details."

I have not yet been able to resolve these CUDA errors on my end.

Are you experiencing similar problems? What would you recommend running if someone wanted to replicate your results?

mfkasim1 commented 6 months ago

Hi, could you give more details on the CUDA errors? The message says "Look at the errors above for more details", so I guess I'm missing the error it refers to. In my experience, it's usually because there is a pre-existing PyTorch installation that JAX is not happy with. For the Google Colab runs, I'm not sure, but there is probably a versioning problem there.

mavenlin commented 6 months ago

I'm not sure if this will help, but in my own experience TensorRT is a frequent source of large numerical error even for very trivial einsums, and the numerical error differs across JAX versions, which I suspect is due to different XLA optimizations.

You may try setting NVIDIA_TF32_OVERRIDE=0.
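
For reference, one way to apply this in a Python/Colab session is to set the variable before JAX initializes its CUDA libraries. This is only a minimal sketch of how to set the environment variable, not code from the repo:

import os

# Disable TF32 so float32 matmuls are not silently run at reduced precision
# on Ampere GPUs. Set this before the first JAX GPU computation.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import jax
import jax.numpy as jnp

x = jnp.ones((8, 8))
print(jax.devices(), (x @ x)[0, 0])  # quick sanity check of the backend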

mfkasim1 commented 5 months ago

I just played around with this and I can confirm that the error is too big and the result is different from what I had when writing this code. If you change nh to 7, the error is about 1e-7, but when nh = 8, the error suddenly increases. I suspect there is some change in the computation when the Jacobian matrix changes from 7x7 to 8x8. The jaxprs for both cases are the same, though.
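
(For anyone who wants to inspect the jaxprs themselves, jax.make_jaxpr dumps the traced program. A minimal sketch with a hypothetical stand-in function, not the actual code from experiments/02_compare_outputs/main.py:)

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the nh-dimensional step whose Jacobian is nh x nh.
def step(mat, vec):
    return jnp.tanh(mat @ vec)

for nh in (7, 8):
    jaxpr = jax.make_jaxpr(step)(jnp.zeros((nh, nh)), jnp.zeros(nh))
    print(f"nh = {nh}")
    print(jaxpr)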

mfkasim1 commented 5 months ago

I pushed a workaround in the latest commit to make it replicable. I still don't know what the source of the problem is, but with batch size = 1 the results come out different. If we use a larger batch size, we can replicate the results against the for-loop.
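
(To spell out what "replicate the results against the for-loop" means here: the compiled sequence evaluation is compared against a plain Python loop over the time steps. A generic sketch with made-up names, not the repo's actual experiment code:)

import jax
import jax.numpy as jnp

def cell(h, x):
    # Toy nonlinear recurrence standing in for the model in the experiment.
    return jnp.tanh(x + 0.5 * h)

def loop_reference(h0, xs):
    # Ground-truth sequential evaluation with a Python for-loop.
    h, outs = h0, []
    for x in xs:
        h = cell(h, x)
        outs.append(h)
    return jnp.stack(outs)

@jax.jit
def scanned(h0, xs):
    # Compiled evaluation whose output is checked against the loop above.
    def body(h, x):
        h = cell(h, x)
        return h, h
    return jax.lax.scan(body, h0, xs)[1]

xs = jax.random.normal(jax.random.PRNGKey(0), (1000, 8))
h0 = jnp.zeros(8)
print("max abs diff:", jnp.max(jnp.abs(loop_reference(h0, xs) - scanned(h0, xs))))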

mfkasim1 commented 5 months ago

@xaviergonzalez I found the culprit: jax.lax.slice is buggy when jitted (see https://github.com/google/jax/issues/21637). jax.lax.slice is used in jax.lax.slice_in_dim, which is in turn used in jax.lax.associative_scan. So what I did was copy-paste the jax.lax.associative_scan code and change slice_in_dim into direct indexing. Now it works even with batch_size = 1 in the example above.
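
(For a rough illustration of the substitution, not the exact patch: calls like lax.slice_in_dim along the scanned axis are replaced with NumPy-style indexing, which in the affected JAX version apparently avoids the buggy code path:)

import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(6, 2)

# Pattern used inside jax.lax.associative_scan (schematically):
# take every other element along the scanned axis, starting from index 1.
sliced = lax.slice_in_dim(x, 1, x.shape[0], stride=2, axis=0)

# Workaround (schematically): the same selection expressed as direct indexing.
indexed = x[1::2]

print(jnp.allclose(sliced, indexed))  # True; the two select the same elements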