pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
20.54k stars 3.57k forks source link

Index out of range in SchNet on a modification of QM9 dataset. #9299

Open CalmScout opened 1 month ago

CalmScout commented 1 month ago

🐛 Describe the bug

Hi!

The idea of the code below is to run a custom version of SchNet on SMILES representations of molecules. Code:

print("Importing packages...")
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
from torch_geometric.nn import SchNet
from tqdm import tqdm
import pickle
import os

print("Defining functions...")
# Define a function to convert SMILES to PyG data objects
def smiles_to_pyg_graph(smiles):
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from torch_geometric.data import Data

    try:
        mol = Chem.MolFromSmiles(smiles)
    except:
        return None

    if mol is None:
        return None

    # Add Hydrogens to the molecule
    mol = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol)

    # Convert the molecule to a graph
    node_features = []
    for atom in mol.GetAtoms():
        node_features.append(atom_feature(atom))
    # node_features = torch.tensor(node_features, dtype=torch.float)
    node_features = torch.tensor(node_features, dtype=torch.long)

    edge_indices = []
    edge_features = []

    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_indices.append((start, end))
        edge_indices.append((end, start))
        edge_features.append(bond_feature(bond))
        edge_features.append(bond_feature(bond))

    edge_indices = torch.tensor(edge_indices).t().to(torch.long)
    # edge_features = torch.tensor(edge_features, dtype=torch.float)
    edge_features = torch.tensor(edge_features, dtype=torch.long)

    return Data(x=node_features, edge_index=edge_indices, edge_attr=edge_features)

# Helper functions for node and edge features
def atom_feature(atom):
    return [atom.GetAtomicNum(), atom.GetFormalCharge()]

def bond_feature(bond):
    return [int(bond.GetBondTypeAsDouble())]

# Load dataset and convert SMILES to PyG data objects
print("Creating dataset...")
# if we have cached data, load it
if os.path.exists('data/qm9_pyg_data.pkl'):
    print("Loading data from cache...")
    with open('data/qm9_pyg_data.pkl', 'rb') as f:
        data_list = pickle.load(f)
else:
    print("Creating dataset from scratch...")
    dataset = QM9(root='data')
    data_list = []
    # for i in tqdm(range(len(dataset))):
    for i in tqdm(range(1000)):
        smiles = dataset[i]['smiles']
        data = smiles_to_pyg_graph(smiles)
        if data is not None:
            data_list.append(data)
    # Save data_list to a pickle file
    with open('data/qm9_pyg_data.pkl', 'wb') as f:
        pickle.dump(data_list, f)

print(f"Example data entry in the data_list: {data_list[0]}")

# Define a SchNet model
class MySchNet(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_targets):
        super(MySchNet, self).__init__()
        self.schnet = SchNet(hidden_channels, num_features)
        self.lin = torch.nn.Linear(hidden_channels, num_targets)

    def forward(self, data):
        print(f'pirnt from forward: data.x.shape: {data.x.shape}')
        print(f'pirnt from forward: data.edge_index.shape: {data.edge_index.shape}')
        print(f'pirnt from forward: data.edge_attr.shape: {data.edge_attr.shape}')
        out = self.schnet(data.x, data.edge_index, data.edge_attr)
        out = self.lin(out)
        return out

# Instantiate the model and define other training parameters
print("Defining model...")
model = MySchNet(num_features=2, hidden_channels=64, num_targets=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

The correspondign output before the Exception:

Training...
Batch size: 32
type(batch.x): <class 'torch.Tensor'>
batch.x.dtype: torch.int64
Batch edge_index shape: torch.Size([2, 834])
Batch edge_index dtype: torch.int64
Batch edge_attr shape: torch.Size([834, 1])
Batch edge_attr dtype: torch.int64
pirnt from forward: data.x.shape: torch.Size([419, 2])
pirnt from forward: data.edge_index.shape: torch.Size([2, 834])
pirnt from forward: data.edge_attr.shape: torch.Size([834, 1])

And an Exception message:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[5], [line 17](vscode-notebook-cell:?execution_count=5&line=17)
     [15](vscode-notebook-cell:?execution_count=5&line=15) print(f'Batch edge_attr dtype: {batch.edge_attr.dtype}')
     [16](vscode-notebook-cell:?execution_count=5&line=16) optimizer.zero_grad()
---> [17](vscode-notebook-cell:?execution_count=5&line=17) output = model(batch)
     [18](vscode-notebook-cell:?execution_count=5&line=18) loss = criterion(output, batch.y.view(-1, 1))  # Assuming targets are stored in batch.y
     [19](vscode-notebook-cell:?execution_count=5&line=19) loss.backward()

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1530](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1530)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1531](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1531) else:
-> [1532](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532)     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   [1536](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1536) # If we don't have any hooks, we want to skip the rest of the logic in
   [1537](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1537) # this function, and just call forward.
   [1538](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1538) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1539](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1539)         or _global_backward_pre_hooks or _global_backward_hooks
   [1540](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1540)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1541](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541)     return forward_call(*args, **kwargs)
   [1543](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1543) try:
   [1544](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1544)     result = None

Cell In[4], [line 14](vscode-notebook-cell:?execution_count=4&line=14)
     [12](vscode-notebook-cell:?execution_count=4&line=12) print(f'pirnt from forward: data.edge_index.shape: {data.edge_index.shape}')
     [13](vscode-notebook-cell:?execution_count=4&line=13) print(f'pirnt from forward: data.edge_attr.shape: {data.edge_attr.shape}')
---> [14](vscode-notebook-cell:?execution_count=4&line=14) out = self.schnet(data.x, data.edge_index, data.edge_attr)
     [15](vscode-notebook-cell:?execution_count=4&line=15) out = self.lin(out)
     [16](vscode-notebook-cell:?execution_count=4&line=16) return out

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1530](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1530)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1531](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1531) else:
-> [1532](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532)     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   [1536](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1536) # If we don't have any hooks, we want to skip the rest of the logic in
   [1537](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1537) # this function, and just call forward.
   [1538](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1538) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1539](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1539)         or _global_backward_pre_hooks or _global_backward_hooks
   [1540](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1540)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1541](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541)     return forward_call(*args, **kwargs)
   [1543](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1543) try:
   [1544](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1544)     result = None

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:284, in SchNet.forward(self, z, pos, batch)
    [271](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:271) r"""Forward pass.
    [272](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:272) 
    [273](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:273) Args:
   (...)
    [280](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:280)         (default: :obj:`None`)
    [281](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:281) """
    [282](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:282) batch = torch.zeros_like(z) if batch is None else batch
--> [284](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:284) h = self.embedding(z)
    [285](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:285) edge_index, edge_weight = self.interaction_graph(pos, batch)
    [286](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:286) edge_attr = self.distance_expansion(edge_weight)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1530](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1530)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1531](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1531) else:
-> [1532](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532)     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   [1536](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1536) # If we don't have any hooks, we want to skip the rest of the logic in
   [1537](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1537) # this function, and just call forward.
   [1538](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1538) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1539](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1539)         or _global_backward_pre_hooks or _global_backward_hooks
   [1540](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1540)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1541](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541)     return forward_call(*args, **kwargs)
   [1543](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1543) try:
   [1544](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1544)     result = None

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
    [162](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/sparse.py:162) def forward(self, input: Tensor) -> Tensor:
--> [163](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/sparse.py:163)     return F.embedding(
    [164](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/sparse.py:164)         input, self.weight, self.padding_idx, self.max_norm,
    [165](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/sparse.py:165)         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   [2258](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2258)     # Note [embedding_renorm set_grad_enabled]
   [2259](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2259)     # XXX: equivalent to
   [2260](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2260)     # with torch.no_grad():
   [2261](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2261)     #   torch.embedding_renorm_
   [2262](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2262)     # remove once script supports set_grad_enabled
   [2263](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2263)     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> [2264](/home/popova/Projects/citre-quantum-chemistry/nbs/~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2264) return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self

Thanks for reading! I appreciate any feedback regarding the issue.

Best regards, Anton.

Versions

Clang version: Could not collect CMake version: Could not collect Libc version: glibc-2.31

Python version: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] (64-bit runtime) Python platform: Linux-5.15.0-1058-aws-x86_64-with-glibc2.31 Is CUDA available: True CUDA runtime version: 10.1.243 CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB Nvidia driver version: 535.171.04 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True

CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian Address sizes: 46 bits physical, 48 bits virtual CPU(s): 8 On-line CPU(s) list: 0-7 Thread(s) per core: 2 Core(s) per socket: 4 Socket(s): 1 NUMA node(s): 1 Vendor ID: GenuineIntel CPU family: 6 Model: 79 Model name: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz Stepping: 1 CPU MHz: 3000.000 CPU max MHz: 3000.0000 CPU min MHz: 1200.0000 BogoMIPS: 4600.02 Hypervisor vendor: Xen Virtualization type: full L1d cache: 128 KiB L1i cache: 128 KiB L2 cache: 1 MiB L3 cache: 45 MiB NUMA node0 CPU(s): 0-7 Vulnerability Gather data sampling: Not affected Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported Vulnerability L1tf: Mitigation; PTE Inversion Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown Vulnerability Meltdown: Mitigation; PTI Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown Vulnerability Retbleed: Not affected Vulnerability Spec rstack overflow: Not affected Vulnerability Spec store bypass: Vulnerable Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt

Versions of relevant libraries: [pip3] numpy==1.26.4 [pip3] pytorch-lightning==2.2.3 [pip3] torch==2.3.0 [pip3] torch_cluster==1.6.3+pt22cu121 [pip3] torch-ema==0.3 [pip3] torch_geometric==2.5.3 [pip3] torch_scatter==2.1.2+pt22cu121 [pip3] torch_sparse==0.6.18+pt22cu121 [pip3] torch_spline_conv==1.2.2+pt22cu121 [pip3] torchaudio==2.3.0 [pip3] torchmetrics==1.0.1 [pip3] torchvision==0.18.0 [conda] numpy 1.26.4 pypi_0 pypi [conda] pytorch-lightning 2.2.3 pypi_0 pypi [conda] torch 2.3.0 pypi_0 pypi [conda] torch-cluster 1.6.3+pt22cu121 pypi_0 pypi [conda] torch-ema 0.3 pypi_0 pypi [conda] torch-geometric 2.5.3 pypi_0 pypi [conda] torch-scatter 2.1.2+pt22cu121 pypi_0 pypi [conda] torch-sparse 0.6.18+pt22cu121 pypi_0 pypi [conda] torch-spline-conv 1.2.2+pt22cu121 pypi_0 pypi [conda] torchaudio 2.3.0 pypi_0 pypi [conda] torchmetrics 1.0.1 pypi_0 pypi [conda] torchvision 0.18.0 pypi_0 pypi

rusty1s commented 1 month ago

Currently, PyG's SchNet expects an input feature vector of shape [num_atoms], while it looks that your input is two-dimensional.