pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

torch.jit.script doesn't work with pytorch_geometric custom ModuleDict #9401

Closed: ItamarKanter closed this issue 2 weeks ago

ItamarKanter commented 3 weeks ago

🐛 Describe the bug

Trying to convert PyG code into TorchScript fails:

```python
import torch
import torch.nn as nn
import torch_geometric.nn as tgnn

class Net(nn.Module):
    def __init__(self, kernel_size):
        super(Net, self).__init__()
        modules = {}
        modules["layer1"] = nn.Conv2d(3, 16,
            kernel_size=kernel_size, stride=1, padding=2)
        # self.layers = nn.ModuleDict(modules)  # works
        self.layers = tgnn.module_dict.ModuleDict(modules)  # doesn't work

    def forward(self, x):
        x = self.layers["layer1"](x)
        return x

torch.jit.script(Net(3))
```
```
TypeError:
'set' object in attribute 'ModuleDict.CLASS_ATTRS' is not a valid constant.
Valid constants are:
1. a nn.ModuleList
2. a value of type {bool, float, int, str, NoneType, torch.device, torch.layout, torch.dtype}
3. a list or tuple of (2)
```

In PyG 2.3 this works (possibly related to #8363).
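For context, TorchScript treats a module attribute as a compile-time constant when it is listed in `__constants__` or annotated with `typing.Final`. It then rejects any value that is not one of the types printed in the error, and a `set` is not among them. The sketch below assumes that `CLASS_ATTRS` is such a `Final`-annotated, class-level `set`. That is an assumption, so check the actual definition in `torch_geometric.nn.module_dict`. The classes `SetConstant` and `TupleConstant` and the helper `scripts_ok` are hypothetical and are not PyG code. The two modules differ only in whether the constant is stored as a `set` or a `tuple`.

```python
from typing import Final, Set, Tuple

import torch


class SetConstant(torch.nn.Module):
    # Hypothetical module: `Final` makes TorchScript treat NAMES as a
    # constant, but a `set` is not a valid TorchScript constant type.
    NAMES: Final[Set[str]] = {"a", "b"}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class TupleConstant(torch.nn.Module):
    # Same attribute stored as a tuple of str, which is a valid constant.
    NAMES: Final[Tuple[str, ...]] = ("a", "b")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def scripts_ok(module: torch.nn.Module) -> bool:
    """Return True if torch.jit.script accepts the module, False on TypeError."""
    try:
        torch.jit.script(module)
        return True
    except TypeError as err:
        # Prints the same "... is not a valid constant." message as above.
        print(f"{type(module).__name__}: {err}")
        return False


if __name__ == "__main__":
    print(scripts_ok(SetConstant()))
    print(scripts_ok(TupleConstant()))
```

If the assumption holds, storing the names as a tuple would avoid the error. The linked PR may take a different approach.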

Versions

Versions of relevant libraries:

```
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==1.9.1
[pip3] torch==2.0.0
[pip3] torch_geometric==2.5.3
[pip3] torcheval==0.0.7
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==0.11.0
[pip3] torchvision==0.15.2a0
[conda] cudatoolkit 11.8.0 h4ba93d1_13 defaults
[conda] mkl 2023.2.0 h84fe81f_50496 defaults
[conda] mkl-devel 2023.2.0 ha770c72_50496 defaults
[conda] mkl-include 2023.2.0 h84fe81f_50496 defaults
[conda] nomkl 3.0 0 anaconda
[conda] numpy 1.26.3 py310hb13e2d6_0 defaults
[conda] pytorch 2.0.0 cpu_generic_py310h7ffd2bf_1 defaults
[conda] pytorch-cuda 11.7 h778d358_5 pytorch
[conda] pytorch-lightning 1.9.1 pyhd8ed1ab_0 defaults
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch_geometric 2.5.3 pyhd8ed1ab_0 defaults
[conda] torcheval 0.0.7 pypi_0 pypi
[conda] torchinfo 1.8.0 pyhd8ed1ab_0 defaults
[conda] torchmetrics 0.11.0 pyhd8ed1ab_0 defaults
[conda] torchvision 0.15.2 cuda118py310h196c800_0 anaconda
```

rusty1s commented 2 weeks ago

Thanks. Will be fixed with https://github.com/pyg-team/pytorch_geometric/pull/9424