ArmanMaesumi / torchrbf

GPU-Accelerated Radial Basis Function (RBF) Interpolation in PyTorch

Numerical precision issues without TF32 #2

Open prutschman-iv opened 1 month ago

prutschman-iv commented 1 month ago

I was seeing what looked like poor numerical accuracy (even with TF32 disabled) compared to the scipy implementation of RBFInterpolator, so I made the following test case. It creates a regularly spaced grid of points, displaces them randomly (but not so much as to create "pinches"), fits an interpolator from the grid to the displaced points, and then evaluates it back at the regular grid points. With infinite precision, a thin plate spline interpolator would reproduce the displaced values exactly.

import numpy as np

import sys
import scipy.interpolate
import torch
import torchrbf

# make sure TF32 is disabled so matmuls run in true float32
assert not torch.backends.cuda.matmul.allow_tf32
torch.set_default_device('cuda')

print('python ', sys.version)
print('torch ', torch.__version__)

for k in range(1, 8):
    print(f"{(2**k)**2} control points")
    # regular 2^k x 2^k grid of integer coordinates
    pts = np.indices((2**k, 2**k)).reshape(2, -1).T.astype(np.float32)
    # randomly displaced copy of the grid (small enough to avoid pinches)
    pts_offset = (pts + np.random.uniform(-0.1, 0.1, size=pts.shape)).astype(np.float32)

    tpts = torch.tensor(pts)
    tpts_offset = torch.tensor(pts_offset)

    # scipy reference: fit grid -> displaced grid, then evaluate at the grid points
    rbf = scipy.interpolate.RBFInterpolator(pts, pts_offset)
    errs = rbf(pts) - pts_offset
    error_mags = np.hypot(*errs.T)
    print('  scipy   \t', error_mags.max())

    # torchrbf under the same setup
    trbf = torchrbf.RBFInterpolator(tpts, tpts_offset, device='cuda')
    errs = trbf(tpts) - tpts_offset
    error_mags = torch.hypot(*errs.T)
    print('  torchrbf\t', error_mags.max().cpu().numpy())

When I run this on my Windows machine with an RTX 4090, I get the following:

python  3.11.9 (tags/v3.11.9:de54cf5, Apr  2 2024, 10:12:12) [MSC v.1938 64 bit (AMD64)]
torch  2.3.1+cu121
4 control points
  scipy      0.0
  torchrbf   2.3964506e-07
16 control points
  scipy      8.95090418262362e-16
  torchrbf   1.0662403e-06
64 control points
  scipy      2.6615730177631208e-14
  torchrbf   1.5168344e-05
256 control points
  scipy      6.629405610737136e-13
  torchrbf   0.00014111983
1024 control points
  scipy      1.3932322615057043e-11
  torchrbf   0.007374375
4096 control points
  scipy      2.662584573260454e-10
  torchrbf   0.15812844
16384 control points
  scipy      3.6122964516631894e-09
  torchrbf   179.04927

Do you have any suggestions or advice on improving the situation? 16k control points is a bit on the excessive side, but I have a practical application that could easily use on the order of 4k points, and at that scale torchrbf's errors are on the same order as the grid displacement itself.

ArmanMaesumi commented 1 month ago

After quite a bit of debugging, it dawned on me that numpy's default floating-point dtype is float64, whereas pytorch's default is of course float32. Admittedly this is quite a big oversight on my part!
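
For reference, a quick illustration of the differing defaults (nothing torchrbf-specific):

import numpy as np
import torch

print(np.array([0.0]).dtype)      # float64 -- numpy defaults to double precision
print(torch.tensor([0.0]).dtype)  # torch.float32 -- pytorch defaults to single
torch.set_default_dtype(torch.float64)  # changes the global pytorch default
print(torch.tensor([0.0]).dtype)  # torch.float64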

To confirm, I manually converted all (relevant) internal tensors to torch.float64, and indeed the precision is now within a reasonable margin of scipy's.

Perhaps I should include an optional argument that forces torchrbf to use higher precision internally. In most cases it is probably not necessary, so the default can stay at the current behavior (float32); and since GPU performance tends to fall drastically at higher precision, you'll want to avoid it when possible.
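
As a rough sketch, the opt-in might look like this (the dtype argument here is hypothetical; it is not part of the current API):

# hypothetical future API -- dtype is not an argument RBFInterpolator accepts today
trbf = torchrbf.RBFInterpolator(
    tpts, tpts_offset,
    device='cuda',
    dtype=torch.float64,  # opt into double precision internally
)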

If you want to quickly hack together a local fix, you can CTRL + F ".float()" in torchrbf/RBFInterpolator.py and change them all to ".double()". You'll also want to make sure your input tensors (data coordinates and data values) are converted as well, as sketched below.
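
Adapting the test script above, the input side of that looks roughly like this (assuming the ".double()" patch has been applied):

# feed float64 tensors into the patched interpolator
tpts = torch.tensor(pts, dtype=torch.float64, device='cuda')
tpts_offset = torch.tensor(pts_offset, dtype=torch.float64, device='cuda')
trbf = torchrbf.RBFInterpolator(tpts, tpts_offset, device='cuda')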

I'm currently a bit occupied so I won't have time to push a patch for now.

prutschman-iv commented 1 month ago

Thank you, I will give your suggestion a try!