getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License
1.04k stars 63 forks source link

Round operation on LazyTensors #139

Closed haguettaz closed 3 years ago

haguettaz commented 3 years ago

Adds a rounding operation on LazyTensors that rounds to a given number of decimal places.

Test plan:

import pykeops
import torch
import math
from pykeops.torch import LazyTensor

# Device on which both the LazyTensor computation and the PyTorch reference run.
device = 'cpu'

# Each case: (decimal places passed to LazyTensor.round, half-width of the
# uniform sampling interval, plain PyTorch expression used as the reference).
cases = (
    # rounds to the nearest integer (0 decimal)
    (0, 10, lambda t: torch.round(t)),
    # rounds to 3 decimal places
    (3, 1, lambda t: torch.round(t * 1e3) * 1e-3),
)

for decimals, half_width, reference in cases:
    # Build two separate tensors holding the same random values, so the
    # LazyTensor path (x) and the PyTorch path (y) each get their own .grad.
    base = torch.FloatTensor(1000, 1).uniform_(-half_width, half_width)
    y = base.data.clone().to(device)
    x = base.to(device)
    x.requires_grad = True
    y.requires_grad = True

    # Forward pass: sum of rounded values via LazyTensor vs. via PyTorch.
    x_i = LazyTensor(x[:, None])
    s1 = x_i.round(decimals).sum(0)
    s2 = torch.sum(reference(y))
    gap = torch.abs(s1 - s2)
    print("s1 - s2", gap.item())
    assert gap < 1e-3, gap

    # Backward pass: the gradients with respect to the inputs must agree too.
    s1.backward()
    s2.backward()
    grad_gap = torch.max(torch.abs(x.grad - y.grad))
    print("grad_s1 - grad_s2", grad_gap.item())
    assert grad_gap < 1e-3
joanglaunes commented 3 years ago

I have now merged this operation; thank you again!