divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Error in running DimeNetPP #115

Closed: LanceKnight closed this issue 2 years ago.

LanceKnight commented 2 years ago

Hello, I've written a simple script to test DimeNetPP. However, I got the following error: [error screenshot not available]

from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import torch
from torch_geometric.nn.acts import swish
from dig.threedgraph.method import DimeNetPP
from dig.threedgraph.evaluation import ThreeDEvaluator

# Download (on first run) and load QM9 under /tmp/.
dataset = QM9(root='/tmp/')

split_idx = {'train':list(range(1,1000)), 'valid':list(range(1000,1200)), 'test':list(range(2000,2200))}
train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]

model = DimeNetPP(energy_and_force=False, cutoff=5.0, num_layers=4, 
                  hidden_channels=128, out_channels=1, int_emb_size=64, 
                  basis_emb_size=8, out_emb_channels=256, num_spherical=7, 
                  num_radial=6, envelope_exponent=5, num_before_skip=1, 
                  num_after_skip=2, num_output_layers=3, act=swish, 
                  output_init='GlorotOrthogonal')

loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

# Use the first GPU if present, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)  # move the model once, not on every iteration

# Note: this iterates over the full dataset with batch_size=1; the splits above are unused here.
loader = DataLoader(dataset)
print(f'size of loader: {len(loader)}')
for i, data_batch in enumerate(loader):
    data_batch = data_batch.to(device)
    pred = model(data_batch)  # forward pass only; no loss or backward step yet
    if i % 500 == 0:
        print(data_batch.x.device)
        print(pred.shape)

I deliberately used QM9 from PyG because my real input data is not in .csv format and does not have attributes such as 'z'. So I wanted to check whether this model works with a simpler data input first.
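Since the real data lacks a 'z' attribute, one option is to derive it from per-atom element symbols. Below is a minimal sketch, assuming the raw data provides element symbols and 3D coordinates for each molecule; the helper name to_z_and_pos and the ATOMIC_NUMBERS table are hypothetical, and the table only covers the elements that appear in QM9.

```python
# Hypothetical helper: turn element symbols + 3D coordinates into the
# atomic numbers ('z') and positions ('pos') that a QM9-style input carries.
# Assumption: only QM9's elements (H, C, N, O, F) need to be supported.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}


def to_z_and_pos(symbols, coords):
    """Return (z, pos) for one molecule.

    symbols: sequence of element symbols, e.g. ["C", "O"]
    coords:  sequence of (x, y, z) coordinates in the same atom order
    """
    if len(symbols) != len(coords):
        raise ValueError("symbols and coords must have the same length")
    # Raises KeyError for elements missing from the table.
    z = [ATOMIC_NUMBERS[s] for s in symbols]
    pos = [tuple(float(c) for c in xyz) for xyz in coords]
    for p in pos:
        if len(p) != 3:
            raise ValueError("each coordinate must have exactly 3 components")
    return z, pos


z, pos = to_z_and_pos(["C", "O", "H"], [(0, 0, 0), (1.2, 0, 0), (-0.5, 0.9, 0)])
print(z)  # [6, 8, 1]
```

The resulting lists could then be wrapped as torch_geometric.data.Data(z=torch.tensor(z, dtype=torch.long), pos=torch.tensor(pos, dtype=torch.float)), assuming that, as with the QM9 input, DimeNetPP reads the 'z', 'pos' and 'batch' attributes of each batch.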

Also, pyg-nightly has changed the activation import: 'from torch_geometric.nn.acts import swish' has been replaced by 'from torch_geometric.nn.resolver import activation_resolver'.
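To keep the test script working across both older and newer PyG versions, the swish import can be guarded. This is a minimal sketch that falls back to a local definition instead of using activation_resolver; it assumes swish is the plain function x * sigmoid(x), which for tensors is the same as torch.nn.functional.silu.

```python
try:
    # Older PyG releases export swish from torch_geometric.nn.acts.
    from torch_geometric.nn.acts import swish
except ImportError:
    # Newer PyG releases no longer ship torch_geometric.nn.acts, so define
    # the same function locally: swish(x) = x * sigmoid(x).
    # For torch tensors this is equivalent to torch.nn.functional.silu.
    def swish(x):
        return x * x.sigmoid()
```

The resulting swish can be passed unchanged as act=swish to DimeNetPP in the script above.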

limei0307 commented 2 years ago

Hi @LanceKnight,

This issue is caused by the PyTorch version. Please see #77 for details. I have updated the code here; please check whether it works for you.

Thanks, Limei
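Because the error depends on the installed PyTorch version, it helps to report exact versions when comparing against #77. A small sketch using the standard library; the distribution names are assumptions ("torch" for PyTorch, "torch_geometric" for PyG, and "dive-into-graphs" for DIG's PyPI package).

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed distribution names for PyTorch, PyG and DIG.
for pkg, ver in installed_versions(["torch", "torch_geometric", "dive-into-graphs"]).items():
    print(f"{pkg}: {ver or 'not installed'}")
```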

LanceKnight commented 2 years ago

Hi @limei0307, thanks for the answer. But how do I get this update from git? I used pip to upgrade to the latest release, but it does not include this fix. Do I need to change the library code manually?

limei0307 commented 2 years ago

Hi @LanceKnight, you can try to install from source with

git clone https://github.com/divelab/DIG.git
cd DIG
pip install .

Alternatively, you can update the code manually. Thanks.
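After installing from source, it can be worth confirming which copy of the package Python actually imports, since an older pip-installed version or a local folder can shadow the new one (for example, when running python from inside the cloned DIG directory, the local source tree may be picked up first). A minimal sketch, assuming the package imports as "dig"; the helper name module_origin is hypothetical.

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if not found."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


# Expect a path under site-packages after 'pip install .'; None means not installed.
print(module_origin("dig"))
```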

LanceKnight commented 2 years ago

I see. Thanks!