Hi, could you give more details on the CUDA errors? The message says "Look at the errors above for more details", so I guess I'm missing the error it refers to. From my experience, it's usually because there is a pre-existing PyTorch installation that JAX is not happy with. As for the errors in Google Colab, I'm not sure, but there is probably a versioning problem there.
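As a first sanity check (a generic diagnostic, not something specific to this repo), you can confirm which backend JAX actually initialized:

```python
import jax

print(jax.__version__)
# With a working CUDA install this lists GPU devices; if the DNN/cuDNN
# initialization failed, this call is typically where the error surfaces.
print(jax.devices())
```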
I'm not sure if this will help, but in my own experience, TF32 (TensorFloat-32) is a frequent source of large numerical error even for very trivial einsums, and the error can differ across JAX versions, which I suspect is due to different XLA optimizations.
You may try setting `NVIDIA_TF32_OVERRIDE=0`.
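For reference, a minimal sketch of two ways to force full float32 precision; the environment variable has to be set before JAX initializes the GPU backend, and `jax_default_matmul_precision` is the JAX-level knob:

```python
import os
# Must be set before JAX initializes its GPU backend.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import jax
# JAX-level alternative: force full float32 matmul/einsum precision.
jax.config.update("jax_default_matmul_precision", "float32")
```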
I just played around with this and I can confirm that the error is too big and the result is different from what I had when writing this code. If you change `nh` to 7, the error is about 1e-7, but when `nh = 8`, the error suddenly increases. I suspect there is some change in the computation when the Jacobian matrix changes from 7x7 to 8x8. The jaxpr for both cases is the same, though.
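For illustration, a minimal sketch of the kind of comparison described, with a placeholder `scan_fn` and shapes (this is not the repo's actual model code):

```python
import jax
import jax.numpy as jnp

def scan_fn(elems):
    # Stand-in for the model's associative scan over the sequence axis.
    return jax.lax.associative_scan(jnp.add, elems, axis=0)

for nh in (7, 8):
    x = jax.random.normal(jax.random.PRNGKey(0), (128, nh, nh))
    eager = scan_fn(x)              # un-jitted reference
    compiled = jax.jit(scan_fn)(x)  # XLA-compiled version
    err = float(jnp.max(jnp.abs(eager - compiled)))
    print(f"nh={nh}: max abs diff = {err:.2e}")
```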
I did a workaround in the latest commit to make it replicable. I still don't know what the source of the problem is, but setting the batch size to 1 makes the results differ. With a larger batch size, we can replicate the results using a for-loop.
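A hedged sketch of that workaround, with `model_step` standing in for the actual per-example computation:

```python
import jax
import jax.numpy as jnp

def model_step(x):
    # Placeholder for the actual per-example computation.
    return jax.lax.associative_scan(jnp.add, x, axis=0)

batch = jax.random.normal(jax.random.PRNGKey(0), (4, 128, 8))
batched = jax.vmap(model_step)(batch)               # single batched call
looped = jnp.stack([model_step(b) for b in batch])  # explicit for-loop
print(float(jnp.max(jnp.abs(batched - looped))))
```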
@xaviergonzalez I found the culprit: `jax.lax.slice` is buggy when jitted (see https://github.com/google/jax/issues/21637). `jax.lax.slice` is used by `jax.lax.slice_in_dim`, which is used by `jax.lax.associative_scan`. So what I did was copy-paste the `jax.lax.associative_scan` code and change `slice_in_dim` into direct indexing. Now it works even with `batch_size = 1` in the example above.
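For reference, a sketch of the kind of substitution, using an illustrative stride-2 slice (the real `associative_scan` internals differ in detail):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

# slice_in_dim form of the kind used inside associative_scan:
a = lax.slice_in_dim(x, 0, 9, stride=2, axis=0)

# Equivalent direct indexing, which sidesteps the jitted lax.slice path:
b = x[0:9:2]

assert jnp.array_equal(a, b)
```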
Hi! I am trying to replicate Figure 3 from your paper but have been running into difficulties.
If I run the below commands in a Google Colab instance (with both an A100 and a V100), I am for some reason getting errors of magnitude around 10^6, in contrast with your Figure 3.
On the other hand, if I run your recommended installation code on either Colab or the compute cluster I have access to, I get CUDA errors like: "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details."
I have not yet been able to resolve these CUDA errors on my end.
Are you experiencing similar problems? What would you recommend running if someone wanted to replicate your results?