getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License
1.04k stars 63 forks source link

batch grad shape inconsistent #145

Closed hbgtjxzbbx closed 3 years ago

hbgtjxzbbx commented 3 years ago

Hi, thanks for the great package! I ran into a shape-inconsistency issue that may be caused by LazyTensor autograd.
Could you take a look? Thanks!

Here is an example. If I use the same kernel written in plain PyTorch, there is no error, and with batch size 1 there is no error either. The error only occurs when using KeOps with batch size > 1.


Traceback (most recent call last):
  File "/playpen-raid1/zyshen/proj/shapmagn/shapmagn/grad_debug.py", line 68, in <module>
    loss.backward()
  File "/playpen-raid1/zyshen/anaconda3/envs/pr/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/playpen-raid1/zyshen/anaconda3/envs/pr/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/playpen-raid1/zyshen/anaconda3/envs/pr/lib/python3.7/site-packages/torch/autograd/function.py", line 89, in apply
    return self._forward_cls.backward(self, *args)  # type: ignore
  File "/playpen-raid1/zyshen/anaconda3/envs/pr/lib/python3.7/site-packages/pykeops/torch/generic/generic_red.py", line 161, in backward
    grad = grad.reshape(arg_ind.shape)  # The gradient should have the same shape as the input!
RuntimeError: shape '[2, 2000, 3]' is invalid for input of size 6000

import torch
from pykeops.torch import LazyTensor
from torch.autograd import grad

def torch_kernel(sigma=0.1):
    """
    Build a dense (plain PyTorch) Gaussian kernel convolution.

    :param sigma: scalar bandwidth; point coordinates are divided by it
        before squared distances are taken
    :return: a callable ``conv(x, y, b)`` that returns, for every batch and
        every point x_n, ``sum_k exp(-|x_n/sigma - y_k/sigma|^2) * b_k``
    """

    def conv(x, y, b):
        """
        :param x: torch.Tensor, BxNxD, input position1
        :param y: torch.Tensor, BxKxD, input position2
        :param b: torch.Tensor, BxKxd, input values
        :return: torch.Tensor, BxNxd, output
        """
        # Broadcasting materializes a full BxNxKxD difference tensor,
        # so memory grows with N*K.
        x_i = x[:, :, None]  # BxNx1xD
        y_j = y[:, None]  # Bx1xKxD
        b_j = b[:, None]  # Bx1xKxd
        sq_dist = ((x_i / sigma - y_j / sigma) ** 2).sum(-1, keepdim=True)  # BxNxKx1
        weights = (-sq_dist).exp()
        return (weights * b_j).sum(dim=2)

    return conv

def keops_kernel(sigma=0.1):
    """
    Build a Gaussian kernel convolution backed by KeOps LazyTensors.

    Intended to compute the same quantity as ``torch_kernel`` without
    materializing the NxK kernel matrix.

    NOTE: per this report, differentiating twice through this kernel
    (``create_graph=True`` followed by ``backward``) with a batch size B > 1
    failed inside pykeops' ``generic_red`` backward with a reshape error;
    B == 1 and the plain-PyTorch kernel both worked.

    :param sigma: scalar bandwidth; point coordinates are divided by it
    :return: a callable ``conv(x, y, b)``
    """

    def conv(x, y, b):
        """
        :param x: torch.Tensor, BxNxD, input position1
        :param y: torch.Tensor, BxKxD, input position2
        :param b: torch.Tensor, BxKxd, input values
        :return: torch.Tensor, BxNxd, output
        """
        x_i = LazyTensor(x[:, :, None] / sigma)  # BxNx1xD, "i" variable
        y_j = LazyTensor(y[:, None] / sigma)  # Bx1xKxD, "j" variable
        b_j = LazyTensor(b[:, None])  # Bx1xKxd
        k_ij = (-x_i.sqdist(y_j)).exp()  # symbolic BxNxK kernel matrix
        # Reduce over the K ("j") axis.
        return (k_ij * b_j).sum_reduction(axis=2)

    return conv

def hamiltonian_evolve(mom, control_points, kernel=None):
    """
    Evaluate the Hamiltonian vector field for H(p, q) = 0.5 * <p, K(q, q) p>.

    Autograd is forced on while H and its gradients are computed, so this also
    works when called under ``torch.no_grad()``. The caller's grad mode is
    restored afterwards, even if the kernel or ``grad`` raises.

    :param mom: torch.Tensor, momentum p (BxNxd)
    :param control_points: torch.Tensor, positions q (BxNxD)
    :param kernel: optional callable ``kernel(x, y, b)``; defaults to
        ``keops_kernel(sigma=0.1)`` (the previously hard-coded choice)
    :return: tuple ``(-dH/dq, dH/dp)``, i.e. ``(dp/dt, dq/dt)``. Gradients are
        built with ``create_graph=True`` so the outputs can be differentiated
        again (e.g. ``loss.backward()`` on them).
    """
    if kernel is None:
        kernel = keops_kernel(sigma=0.1)
    # A context manager (instead of set_grad_enabled(True) ... restore) ensures
    # the caller's grad mode is restored on exceptions too.
    with torch.enable_grad():
        # clone() stays connected to the caller's tensors in the autograd graph,
        # so gradients of the outputs still reach them when they require grad.
        control_points = control_points.clone().requires_grad_()
        mom = mom.clone().requires_grad_()
        energy = (mom * kernel(control_points, control_points, mom)).sum() * 0.5
        grad_mom, grad_control = grad(energy, (mom, control_points), create_graph=True)
    # Negate after the caller's grad mode is restored, matching prior behavior.
    return -grad_control, grad_mom

# Reproduction settings: batch size, points per batch, spatial dimension.
# The reported failure needs B > 1.
B, N, D = 2, 2000, 3
device = torch.device("cpu")  # switch to "cuda:0" to run on GPU

# Sampling order matters for reproducibility: control points first, then momentum.
control_points = torch.rand(B, N, D, device=device)
momentum = torch.rand(B, N, D, device=device, requires_grad=True)

# Second output is dq/dt; backpropagating through it differentiates the kernel twice.
_, dcontrol_points = hamiltonian_evolve(momentum, control_points)
loss = dcontrol_points.mean()
loss.backward()
hbgtjxzbbx commented 3 years ago

Sorry, I just found that this is a duplicate of issue #103.