Do you know if this is specific to interpax? It's likely a more general JAX issue (or really a CUDA/XLA issue) where things get compiled differently for different hardware; see https://github.com/google/jax/issues/20371 and https://github.com/google/jax/discussions/10674#discussioncomment-7214817
Also, is the error uniformly bad for all points being interpolated, or is it localized in some way?
Hi! Thank you for the reply.
I believe it's a more general JAX issue, since I could not find any machine-specific code in interpax. But I don't know where to start fixing it 😅
Here I attached two images, which present the relative pointwise error `abs(a - b) / abs(b)`.
The first is for the 4090 with single precision, and the second is for the 4090 with double precision. Numbers on the axes are just indices. Along the left vertical edge of the figures, `xq` is monotonically increasing (from 0 to 1).
So I would say that the error is uniformly bad.
In fact, my query points `xq` were loaded from Excel files using `pandas.read_excel`, so I thought that could have been the cause. But after switching my query points `xq` to `numpy.linspace(0, 1, 1000)`, I observed the same issue again.
The relative $L^1$ error is $\approx 10^{-6}$ on the Titan but $\approx 10^{-2}$ on the 4090, where the baseline is the CPU result with x64 arithmetic.
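For reference, this is roughly how I produce the CPU x64 baseline. It's a minimal sketch: the knot data `xp, yp` is a synthetic stand-in for my real inputs, the file name `cpu_x64_reference.npy` is hypothetical, and I'm assuming interpax's `interp1d(xq, x, f, ...)` signature.

```python
import numpy as np
import jax

# Enable double precision before any JAX work.
jax.config.update("jax_enable_x64", True)

from interpax import interp1d

xp = np.linspace(0, 1, 100)   # synthetic stand-in for my knots
yp = np.sin(2 * np.pi * xp)   # synthetic stand-in for my knot values
xq = np.linspace(0, 1, 1000)  # query points

# Pin execution to the CPU regardless of which GPUs are visible.
with jax.default_device(jax.devices("cpu")[0]):
    b = interp1d(xq, xp, yp, method="monotonic", extrap=True)

np.save("cpu_x64_reference.npy", np.asarray(b))  # reference for later comparison
```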
Can you share some code/data that seems to reproduce the issue? I don't have access to either of those GPUs, but I can try some others and see if it's a more general issue.
Hi, I think I found the cause.
When I ran the test with `NVIDIA_TF32_OVERRIDE=0`, I got the correct result.
It might be related to JAX's default TF32 override: https://github.com/google/jax/issues/7010#issue-924358639 and https://github.com/patrick-kidger/diffrax/issues/213#issuecomment-1382731229
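In case it helps others: besides the `NVIDIA_TF32_OVERRIDE=0` environment variable, I believe the in-process equivalent is JAX's default matmul precision setting. A sketch of both knobs:

```python
# Option 1: disable TF32 at the CUDA level before launching the process:
#   NVIDIA_TF32_OVERRIDE=0 python script.py

# Option 2: ask JAX for full float32 precision in matmul-like ops.
import jax

jax.config.update("jax_default_matmul_precision", "highest")

# Or scoped to a region of code:
with jax.default_matmul_precision("highest"):
    ...  # interpolation evaluated here avoids TF32 matmuls
```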
Hi f0uriest,
I encountered an issue where interpolation results vary across different machines.
I used a 1D interpolator with the monotonic method, allowing `extrap=True`.

Test machines: CPU, RTX Titan, RTX 4090. Reference machine: CPU with double precision (x64).
The table below presents the relative $L^1$ error, `abs(a - b).sum() / abs(b).sum()`:
Since I used the same `(xq, xp, yp)`, the errors in each row should coincide. However, as you can see, interpolation on the RTX 4090 with single precision produced quite an inaccurate result.
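A minimal sketch of the comparison (synthetic stand-ins for my actual `(xq, xp, yp)`; the reference file name is hypothetical, and I'm assuming interpax's `interp1d` signature):

```python
import numpy as np
from interpax import interp1d

xp = np.linspace(0, 1, 100)   # interpolation knots (stand-in)
yp = np.sin(2 * np.pi * xp)   # values at the knots (stand-in)
xq = np.linspace(0, 1, 1000)  # query points

# Evaluate on the current machine.
a = np.asarray(interp1d(xq, xp, yp, method="monotonic", extrap=True))

# Reference computed once on the CPU with x64 enabled and saved to disk.
b = np.load("cpu_x64_reference.npy")

rel_l1 = np.abs(a - b).sum() / np.abs(b).sum()  # relative L^1 error
print(rel_l1)
```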
Do you have any ideas on this?