getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Passing LazyTensor to MLP #270

Closed · aabbas90 closed this issue 1 year ago

aabbas90 commented 1 year ago

Hi, I would like to pass a LazyTensor to a PyTorch MLP to compute 'distances' between elements. I have created a small example:

import torch
from torch import nn

M, N, D = 1000000, 2000000, 3
x = torch.randn(M, D, requires_grad=True).cuda()
y = torch.randn(N, D).cuda()

from pykeops.torch import LazyTensor
x_i = LazyTensor(x.view(M, 1, D))
y_j = LazyTensor(y.view(1, N, D))

# Computing distance w.r.t a MLP:
comparison_mlp = nn.Sequential(nn.Linear(D, D), nn.ReLU(), nn.Linear(D, 1))

#>>>>>> Error here:
K_ij = comparison_mlp((x_i - y_j).abs())
a_i = K_ij.sum(dim=1)

I guess this may be too much to ask from KeOps, though. How would you recommend implementing such operations? Thanks for creating the library!

joanglaunes commented 1 year ago

Hello @aabbas90 ,

Sorry for being very late to reply. It is indeed not possible to define an MLP using the PyTorch nn module and pass LazyTensors to it. However, you can implement the MLP by hand using LazyTensor operations. Here is how to do it with your example:

import torch
from pykeops.torch import LazyTensor

M, N, D = 1000000, 2000000, 3
x = torch.randn(M, D, device="cuda", requires_grad=True)
y = torch.randn(N, D, device="cuda")

x_i = LazyTensor(x.view(M, 1, D))  # (M, 1, D) symbolic tensor
y_j = LazyTensor(y.view(1, N, D))  # (1, N, D) symbolic tensor

# Weights of the two linear layers, stored as flattened matrices.
# Creating them directly on the GPU (rather than calling .cuda() on a
# leaf tensor, which returns a non-leaf copy) keeps them as leaves,
# so .grad is populated by backward().
w1 = torch.rand(D * D, device="cuda", requires_grad=True)  # (D, D) weight, flattened
w2 = torch.rand(D, device="cuda", requires_grad=True)      # (D, 1) weight, flattened

# Hand-rolled MLP: vector-matrix product -> ReLU -> vector-matrix product.
comparison_mlp = lambda x: x.vecmatmult(w1).relu().vecmatmult(w2)

K_ij = comparison_mlp((x_i - y_j).abs())  # (M, N) symbolic "distance" matrix
a_i = K_ij.sum(dim=1)                     # reduction over j -> dense (M, 1) result

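Note that nn.Linear includes bias terms by default, which the hand-rolled version above drops. If you need them, arithmetic between LazyTensors and plain PyTorch tensors is supported, so the affine parts can be restored along these lines (b1 and b2 are illustrative names, not part of the snippet above, and the broadcasting behaviour is assumed from the LazyTensor API):

# Illustrative bias terms for the two layers; pykeops should treat small
# plain tensors appearing in LazyTensor arithmetic as parameter variables.
b1 = torch.rand(D, device="cuda", requires_grad=True)
b2 = torch.rand(1, device="cuda", requires_grad=True)

# Same hand-rolled MLP as above, with the affine parts restored.
comparison_mlp = lambda x: (x.vecmatmult(w1) + b1).relu().vecmatmult(w2) + b2
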
L-Reichardt commented 1 year ago

Thanks, this answer helped me a lot.

A small addition, based on my limited knowledge of PyTorch gradient-related machinery (correct me if I am wrong): the weights should be wrapped in nn.Parameter so that a PyTorch optimizer can see them. nn.Parameter itself is not compatible with KeOps, so the underlying tensor should be accessed via .data.

w1 = nn.Parameter(w1, requires_grad=True)
w2 = nn.Parameter(w2, requires_grad=True)
...
comparison_mlp = lambda x: x.vecmatmult(w1.data).relu().vecmatmult(w2.data)

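For completeness, here is a minimal sketch of wrapping this in an nn.Module, so that model.parameters() exposes the weights to an optimizer. KeOpsMLP is a hypothetical name, and the .data passthrough is assumed to behave as described above:

class KeOpsMLP(nn.Module):
    def __init__(self, D):
        super().__init__()
        # Registered as parameters, so optimizers can find them
        # via self.parameters().
        self.w1 = nn.Parameter(torch.rand(D * D))
        self.w2 = nn.Parameter(torch.rand(D))

    def forward(self, x):
        # x is a LazyTensor; hand KeOps the raw tensors via .data,
        # following the suggestion above.
        return x.vecmatmult(self.w1.data).relu().vecmatmult(self.w2.data)
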
L-Reichardt commented 1 year ago

@joanglaunes Is there no need to call torch.autograd.grad() anywhere?
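
For what it is worth, no explicit torch.autograd.grad() call should be needed: KeOps reductions are differentiable, so an ordinary backward() populates .grad on leaf tensors. One caveat with the .data approach above, though: .data detaches a tensor from the autograd graph, so a parameter used only through .data receives no gradient. A small check of both points, reusing the names from joanglaunes's snippet:

# With comparison_mlp using w1 and w2 directly (not via .data),
# a plain backward() through the KeOps reduction is enough:
a_i = comparison_mlp((x_i - y_j).abs()).sum(dim=1)
a_i.sum().backward()
print(w1.grad.shape)  # torch.Size([9]) for D = 3: the gradient arrived

# .data detaches: this computation never enters the autograd graph,
# so backward() could never reach p through it.
p = nn.Parameter(torch.rand(D))
out = (p.data * 2.0).sum()
print(out.requires_grad)  # False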