getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Float vs. int issue with computations? #370

Open pratikrathore8 opened 2 months ago

pratikrathore8 commented 2 months ago

I've been trying to use pykeops to generate rows of a kernel matrix. However, when I change the kernel hyperparameter sigma from an int to a float (5 to 5.), the outputs become NaN. Do the developers have any idea why this is happening?

I'm happy to share the details of my Python environment if needed. Thanks!

from pykeops.torch import LazyTensor
import torch

def check_kernel_computations(x, idx, sigma):
    x_idx = x[idx]
    x_idx_lazy = LazyTensor(x_idx)
    x_lazy = LazyTensor(x[None, :, :])

    D = (x_idx_lazy - x_lazy).abs().sum(dim=2)

    K = (-D / sigma).exp()

    print("sigma = ", sigma)
    print(-D.sum(dim=0))
    print((-D / sigma).sum(dim=0))
    print(K.sum(dim=0))
    print()

seed = 0
device = 'cuda:2'
idx = 0

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Generate random normal data x
x = torch.randn(100000, 1000).to(device)

check_kernel_computations(x, idx, 5)

check_kernel_computations(x, idx, 5.)

Output:

sigma =  5
tensor([[   -0.0000],
        [-1147.4860],
        [-1084.8994],
        ...,
        [-1111.5668],
        [-1137.5164],
        [-1145.2468]], device='cuda:2')
tensor([[   0.0000],
        [-229.4972],
        [-216.9799],
        ...,
        [-222.3134],
        [-227.5033],
        [-229.0494]], device='cuda:2')
tensor([[1.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]], device='cuda:2')

sigma =  5.0
tensor([[   -0.0000],
        [-1147.4860],
        [-1084.8994],
        ...,
        [-1111.5668],
        [-1137.5164],
        [-1145.2468]], device='cuda:2')
tensor([[nan],
        [nan],
        [nan],
        ...,
        [nan],
        [nan],
        [nan]], device='cuda:2')
tensor([[nan],
        [nan],
        [nan],
        ...,
        [nan],
        [nan],
        [nan]], device='cuda:2')

jeanfeydy commented 2 months ago

Hi @pratikrathore8 ,

Thanks for your report! The bug comes from the fact that you use a single query point x[idx], and that our engine is faulty when dealing with degenerate LazyTensors that have a single row or column. By chance, the issue does not show up when sigma is an integer, i.e. a constant that is inlined at compile time. But when sigma is a float, it is handled as an extra variable in the formula, and this evidently triggers the error.

We are currently rewriting the section of our code that deals with input dimensions, so the problem should hopefully be fixed soon. Until then, a simple workaround is to use two query points instead of one, e.g. by duplicating x[idx]:

from pykeops.torch import LazyTensor
import torch

def check_kernel_computations(x, idx, sigma):
    x_idx = x[idx]
    x_idx = x_idx.view(1, 1, -1)
    # Create an artificial duplicate
    x_idx = torch.cat((x_idx, x_idx), dim=0)
    assert x_idx.shape == (2, 1, x.shape[-1])
    x_idx_lazy = LazyTensor(x_idx)
    x_lazy = LazyTensor(x[None, :, :])

    D = (x_idx_lazy - x_lazy).abs().sum(dim=2)

    K = (-D / sigma).exp()

    print("sigma = ", sigma)
    print(-D.sum(dim=0))
    print((-D / sigma).sum(dim=0))
    print(K.sum(dim=0))
    print()

seed = 0
device = 'cuda'
idx = 0

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Generate random normal data x
x = torch.randn(100000, 1000).to(device)

check_kernel_computations(x, idx, 5)

check_kernel_computations(x, idx, 5.)

Best regards, Jean

pratikrathore8 commented 2 months ago

@jeanfeydy Thanks for clarifying the source of this error and providing the workaround!
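
A note for readers applying the workaround: judging from the shapes, the duplicated query point adds a second row along dim=0, so each `.sum(dim=0)` adds two identical terms and the printed values should come out at twice the single-query values. The sketch below is a dependency-free, plain-Python reference (no KeOps, no GPU) for the same Laplacian kernel row, useful for checking small cases by hand. The helper names `laplacian_kernel_row` and `duplicated_query_sum` are hypothetical and not part of pykeops; the halving step is an assumption based on how the workaround is written, not something stated in the thread.

```python
import math


def laplacian_kernel_row(x, idx, sigma):
    """Dense reference: K[j] = exp(-||x[idx] - x[j]||_1 / sigma)."""
    query = x[idx]
    row = []
    for point in x:
        # L1 distance between the query point and the current point
        l1 = sum(abs(a - b) for a, b in zip(query, point))
        row.append(math.exp(-l1 / sigma))
    return row


def duplicated_query_sum(x, idx, sigma):
    """Mimic the workaround: two copies of the query, summed over the copies."""
    single = laplacian_kernel_row(x, idx, sigma)
    return [k + k for k in single]


# Tiny hand-checkable data set (3 points in 2 dimensions)
x = [[0.0, 1.0], [1.0, 3.0], [-2.0, 0.5]]

row_int = laplacian_kernel_row(x, 0, 5)      # sigma as an int
row_float = laplacian_kernel_row(x, 0, 5.0)  # sigma as a float

# Summing over the duplicated query doubles every entry; halve to recover it
doubled = duplicated_query_sum(x, 0, 5.0)
recovered = [v / 2 for v in doubled]

print(row_int)
print(row_float)
print(recovered)
```

In exact arithmetic the int and float versions of sigma define the same kernel, so any difference between them in the KeOps output points to the engine rather than to the formula.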