f0uriest / interpax

Interpolation and function approximation with JAX
MIT License

Different results on different GPUs #28

Closed (jaeminoh closed this issue 8 months ago)

jaeminoh commented 8 months ago

Hi f0uriest,

I've encountered an issue where interpolation results vary across different machines.

I used a 1D interpolator with the monotonic method and extrap=True.

Test machines: CPU, RTX Titan, and RTX 4090. Reference: CPU with double precision (x64).

The table below shows the relative $L^1$ error, abs(a - b).sum() / abs(b).sum():

| precision | CPU | RTX Titan | RTX 4090 |
|---|---|---|---|
| x32 | 5.87719e-08 | 5.89367e-08 | 1.78212e-04 |
| x64 | reference | 4.16375e-17 | 4.16375e-17 |
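For anyone reproducing the table, here is a minimal NumPy sketch of the metric. `relative_l1_error` is a hypothetical helper name, with `a` the result under test and `b` the x64 CPU reference, as in the formula above. As a sanity check, merely rounding a float64 reference to float32 keeps this error below $2^{-24} \approx 6 \times 10^{-8}$, the same order as the x32 CPU and RTX Titan entries.

```python
import numpy as np

def relative_l1_error(a, b):
    """abs(a - b).sum() / abs(b).sum(), evaluated in float64."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.abs(a - b).sum() / np.abs(b).sum()

# Reference values in float64, and the same values rounded to float32.
ref = np.linspace(0.0, 1.0, 1000) ** 3 + 0.1
approx = ref.astype(np.float32)

# Pure float32 rounding: each element is off by at most 2**-24 relative,
# so the weighted sum is bounded by 2**-24 as well.
err = relative_l1_error(approx, ref)
```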

Since I used the same (xq, xp, yp) on every machine, the errors within each row should coincide.

However, as you can see, interpolation on the RTX 4090 in single precision produced an error more than three orders of magnitude larger than on the other machines.

Do you have any ideas on this?

f0uriest commented 8 months ago

Do you know if this is specific to interpax? It's more likely a general JAX issue (or really a CUDA/XLA issue), where things get compiled differently for different hardware; see https://github.com/google/jax/issues/20371 and https://github.com/google/jax/discussions/10674#discussioncomment-7214817

Also, is the error uniformly bad for all points being interpolated, or is it localized in some way?

jaeminoh commented 8 months ago

Hi! Thank you for the reply.

I believe it's a general JAX issue, since I couldn't find any machine-specific code in interpax. But I don't know where to start to fix it 😅

Here are two images showing the relative pointwise error abs(a - b) / abs(b).

[Image: 4090_x32]

This is for the 4090 in single precision,

[Image: 4090_x64]

and this is for the 4090 in double precision. The numbers on the axes are just indices.

Along the left vertical edge of the figures, xq increases monotonically from 0 to 1, so I would say the error is uniformly bad rather than localized.
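To tell uniform from localized error without plots, one option is to summarize the pointwise error abs(a - b) / abs(b) by quantiles: if the median is close to the maximum, the error is spread evenly. This is a minimal NumPy sketch with hypothetical helper names, assuming the reference `b` has no zeros (the pointwise ratio is undefined there).

```python
import numpy as np

def pointwise_relative_error(a, b):
    """Elementwise abs(a - b) / abs(b) in float64; b must be nonzero."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.abs(a - b) / np.abs(b)

def summarize(err):
    """Min / median / 99th percentile / max of the pointwise error."""
    q = np.quantile(err, [0.0, 0.5, 0.99, 1.0])
    return dict(zip(["min", "median", "p99", "max"], q))

b = np.linspace(0.1, 1.0, 1000)

uniform = b * (1 + 1e-4)      # every point off by the same relative amount
localized = b.copy()
localized[500] *= 1 + 1e-2    # a single bad point

print(summarize(pointwise_relative_error(uniform, b)))    # median close to max
print(summarize(pointwise_relative_error(localized, b)))  # median 0, max ≈ 1e-2
```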

My query points xq were originally loaded from Excel files with pandas.read_excel, which I suspected could be the cause. So I switched xq to numpy.linspace(0, 1, 1000) and observed the same issue: the relative $L^1$ error is $\approx 10^{-6}$ on the Titan but $\approx 10^{-2}$ on the 4090, with CPU results in x64 arithmetic as the baseline.

f0uriest commented 8 months ago

Can you share some code/data that reproduces the issue? I don't have access to either of those GPUs, but I can try some others and see if it's a more general issue.

jaeminoh commented 8 months ago

Hi, I think I found the cause.

When I ran the test with NVIDIA_TF32_OVERRIDE=0, I got the correct result:

[Image: overpotential_rtx_4090]

The discrepancy might be related to JAX's default use of TF32 for float32 matmuls on recent NVIDIA GPUs: https://github.com/google/jax/issues/7010#issue-924358639 https://github.com/patrick-kidger/diffrax/issues/213#issuecomment-1382731229
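A rough check of why TF32 fits the numbers: TF32 keeps only 10 explicit mantissa bits versus float32's 23, so its unit roundoff is $2^{-11} \approx 4.9 \times 10^{-4}$ instead of $2^{-24} \approx 6 \times 10^{-8}$. This is consistent with the orders of magnitude of the 4090 and CPU x32 entries in the table. TF32 tensor cores arrived with Ampere, which is also consistent with the Turing-based RTX Titan being unaffected. The sketch below emulates TF32 in NumPy by rounding float32 matmul inputs to 10 mantissa bits. It assumes the affected part of the interpolation goes through a matrix multiply, since TF32 applies only to matmuls and convolutions. It approximates the hardware path and does not model its exact rounding mode; `round_to_tf32` is a hypothetical helper.

```python
import numpy as np

def round_to_tf32(x):
    """Approximate TF32 rounding of a float32 array (10 explicit mantissa bits).

    Rounds the magnitude to nearest (ties away from zero) by adding half a
    TF32 ulp to the bit pattern and clearing the low 13 of float32's 23
    mantissa bits. Overflow, inf, and NaN are not handled.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = (bits + np.uint32(1 << 12)) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)

def relative_l1_error(a, b):
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.abs(a - b).sum() / np.abs(b).sum()

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64)).astype(np.float32)
v = rng.standard_normal(64).astype(np.float32)

ref = A.astype(np.float64) @ v.astype(np.float64)  # float64 reference
fp32 = A @ v                                       # plain float32 matmul
tf32 = round_to_tf32(A) @ round_to_tf32(v)         # TF32 inputs, float32 accumulation

err_fp32 = relative_l1_error(fp32, ref)
err_tf32 = relative_l1_error(tf32, ref)
print(f"float32: {err_fp32:.2e}  emulated TF32: {err_tf32:.2e}")
```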