rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
https://pytorch-scatter.readthedocs.io
MIT License

scatter or scatter_min fails when using torch.compile #440

Open gardiens opened 5 months ago

gardiens commented 5 months ago

Hello,

I can't torch.compile any model that includes scatter or scatter_min from torch_scatter. For example, this minimal script:

import torch
import torch_geometric
from torch_scatter import scatter_min

print("the version of torch", torch.__version__)
print("torch_geometric version", torch_geometric.__version__)

def get_x(n_points=100):
    # Random points in the box [0, 10]^3.
    x_min = [0, 10]
    y_min = [0, 10]
    z_min = [0, 10]

    x = torch.rand((n_points, 3))
    x[:, 0] = x[:, 0] * (x_min[1] - x_min[0]) + x_min[0]
    x[:, 1] = x[:, 1] * (y_min[1] - y_min[0]) + y_min[0]
    x[:, 2] = x[:, 2] * (z_min[1] - z_min[0]) + z_min[0]

    return x

device = "cuda" if torch.cuda.is_available() else "cpu"
x = get_x(n_points=10).to(device)
se = torch.randint(low=0, high=10, size=(10,), device=device)

model = scatter_min
compiled_model = torch.compile(model)

# scatter_min returns a (values, argmin) tuple.
expected, expected_arg = model(x, se, dim=0)
out, out_arg = compiled_model(x, se, dim=0)
assert torch.allclose(out, expected, atol=1e-6)

The code fails with:

 torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
 line 65, in scatter_min
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)

My versions are torch 2.2.0, torch_geometric 2.5.2, and torch_scatter 2.1.2.

rusty1s commented 4 months ago

This is currently expected, since the custom ops provided by torch-scatter are not supported in torch.compile. There exist two options:

rusty1s commented 4 months ago

For this, we added utils.scatter, which will pick up the best computation path depending on your input arguments. Also works with torch.compile.
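
A minimal sketch of that suggestion, assuming torch_geometric.utils.scatter with its reduce="min" mode (present in the torch_geometric 2.5.2 install mentioned above); note that, unlike torch_scatter.scatter_min, it returns only the reduced values and no argmin tensor:

import torch
from torch_geometric.utils import scatter

def reduce_min(src, index):
    # utils.scatter picks a computation path based on the inputs and,
    # per the comment above, traces under torch.compile. With
    # reduce="min" it returns only the reduced values, not the
    # (values, argmin) tuple that torch_scatter.scatter_min returns.
    return scatter(src, index, dim=0, reduce="min")

x = torch.rand(10, 3)
se = torch.randint(low=0, high=10, size=(10,))

compiled = torch.compile(reduce_min)
expected = reduce_min(x, se)
out = compiled(x, se)
assert torch.allclose(out, expected, atol=1e-6)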

gardiens commented 4 months ago

> For this, we added utils.scatter, which will pick up the best computation path depending on your input arguments. Also works with torch.compile.

If I understand correctly, you suggest that instead of directly calling scatter_min or scatter_max from torch_scatter, we should use utils.scatter by default?

rusty1s commented 4 months ago

Yes, if you want torch.compile support, then this is the recommended way.
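
For the "any model that includes scatter" case from the original report, the same substitution works inside an nn.Module; a sketch under the same assumptions (MinPool is a hypothetical module for illustration, not part of either library):

import torch
from torch_geometric.utils import scatter

class MinPool(torch.nn.Module):
    # Pools row features into per-index features by taking the
    # element-wise minimum within each index group.
    def forward(self, x, index):
        return scatter(x, index, dim=0, reduce="min")

model = MinPool()
compiled_model = torch.compile(model)

x = torch.rand(10, 3)
index = torch.randint(low=0, high=10, size=(10,))
assert torch.allclose(compiled_model(x, index), model(x, index), atol=1e-6)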