pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.12k stars 3.63k forks source link

torch_geometric.transforms.RandomLinkSplit is not interoperable with torch_geometric.loader.DataLoader #8872

Closed aaronwtr closed 8 months ago

aaronwtr commented 8 months ago

🐛 Describe the bug

I am trying to build a GAT model with PyG and PyTorch Lightning. The problem I am trying to solve is a link prediction task and to that end, I need to split my edges into a train and val set. Since we have an independent held-out test graph, we don't need to get that with RandomLinkSplit. After my preprocessing, my graph object looks as follows: Data(x=[507, 4], edge_index=[2, 39607], edge_label=[39607, 1])

Now, when I perform the train and val splitting, the two resulting graphs look as follows: train: Data(x=[507, 4], edge_index=[2, 33668], edge_label=[16834, 1], edge_label_index=[2, 16834]) val: Data(x=[507, 4], edge_index=[2, 33668], edge_label=[5940, 1], edge_label_index=[2, 5940])

My model is as follows:

import torch

from torch_geometric.nn import GATv2Conv, Linear, HGTConv, to_hetero
from torch_geometric.loader import DataLoader

import lightning as pl
import torch.nn.functional as F
import torch_geometric.transforms as T

class HGSLNetDataModule(pl.LightningDataModule):
    def __init__(self, graph, batch_size=1):
        super().__init__()
        self.graph = graph
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_graph, self.val_graph, _ = self.link_split_transform()

    def train_dataloader(self):
        return DataLoader(self.train_graph, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_graph, batch_size=self.batch_size, shuffle=False)

    def link_split_transform(self):
        transform = T.RandomLinkSplit(
            num_val=0.15,
            num_test=0,
            is_undirected=True,
            add_negative_train_samples=False
        )
        return transform(self.graph)

class GATEncoder(torch.nn.Module):
    def __init__(self, num_layers, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GATv2Conv(-1, hidden_channels, heads=1, dropout=dropout))
            print('break')
        self.conv_out = GATv2Conv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        z = self.conv_out(x, edge_index)
        return z

class GATDecoder(torch.nn.Module):
    def __init__(self, num_layers, hidden_channels):
        super().__init__()
        self.lin_in = Linear(-1, hidden_channels)

        self.lins_hidden = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.lins_hidden.append(Linear(hidden_channels, hidden_channels))

        self.lin_out = Linear(hidden_channels, 1)

    def forward(self, z, edge_index):
        z = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1)
        z = self.lin_in(z).relu()
        for lin in self.lins_hidden:
            z = lin(z).relu()
        z = self.lin_out(z)
        return z.view(-1)

class HGSLNet(pl.LightningModule):
    def __init__(self, num_layers, hidden_channels):
        super().__init__()
        self.encoder = GATEncoder(num_layers, hidden_channels, hidden_channels)
        self.decoder = GATDecoder(num_layers, hidden_channels)

    def forward(self, x, edge_index):
        z = self.encoder(x, edge_index)
        logit = self.decoder(z, edge_index).reshape(-1, 1)

        proba = torch.sigmoid(logit)
        y = torch.where(proba > 0.5, torch.tensor(1), torch.tensor(0)).long()

        return logit, proba, y

    def training_step(self, batch, batch_idx):
        x, edge_index, y = batch.x, batch.edge_index, batch.edge_label
        y = y.float()
        logit, _, _ = self(x, edge_index)
        loss = F.binary_cross_entropy_with_logits(logit, y)
        self.log('train_loss', loss)
        print(f"Training loss: {loss}")
        return loss

    def validation_step(self, batch, batch_idx):
        x, edge_index, y = batch.x, batch.edge_index, batch.edge_label
        # get size of edge_label
        edge_label_size = batch.edge_label.size(0)
        logit, _, _ = self(x, edge_index)
        y = y.float()
        loss = F.binary_cross_entropy_with_logits(logit, y)
        self.log('val_loss', loss)
        print(f"Validation loss: {loss}")
        return loss

    # def validation_epoch_end(self, outputs):
    #     avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    #     self.log('avg_val_loss', avg_loss)

    def test_step(self, batch, batch_idx):
        x, edge_index, y = batch.x, batch.edge_index, batch.y
        y_hat, _, _ = self(x, edge_index)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        self.log('test_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

Then, I proceed to construct my datamodule as follows:

from lightning import Trainer
from torch_geometric.data.lightning import LightningLinkData

from src.preprocessing import CellLineGraphData, MultiOmicsLoader
from src.model import HGSLNet, HGSLNetDataModule

def homotrain(graph):
    datamodule = HGSLNetDataModule(graph)

    # Initialize your model
    model = HGSLNet(num_layers=4, hidden_channels=128)

    # Initialize the trainer
    trainer = Trainer(max_epochs=100, accelerator='cpu', enable_progress_bar=False)

    # Train the model
    trainer.fit(model, datamodule)

Running this leads to the following error:

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:452: A layer with UninitializedParameter was found. Thus, the total number of parameters detected may be inaccurate.

  | Name    | Type       | Params
---------------------------------------
0 | encoder | GATEncoder | 35.3 K
1 | decoder | GATDecoder | 66.3 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Traceback (most recent call last):
  File "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/synthetic-lethality-prediction/synthetic-lethality-prediction/src/main.py", line 42, in <module>
    main()
  File "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/synthetic-lethality-prediction/synthetic-lethality-prediction/src/main.py", line 37, in main
    homotrain(graph)
  File "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/synthetic-lethality-prediction/synthetic-lethality-prediction/src/main.py", line 20, in homotrain
    trainer.fit(model, datamodule)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self._run_sanity_check()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1062, in _run_sanity_check
    val_loop.run()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 127, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 127, in __next__
    batch = super().__next__()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 56, in __next__
    batch = next(self.iterator)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 326, in __next__
    out = next(self._iterator)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 132, in __next__
    out = next(self.iterators[0])
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 53, in fetch
    data = self.dataset[possibly_batched_index]
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/data/data.py", line 457, in __getitem__
    return self._store[key]
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/data/storage.py", line 104, in __getitem__
    return self._mapping[key]
KeyError: 0

It suggests that somehow PyG can not index over my graph, i.e. Data object from torch_geometric. I reached out to folks over at PyTorch Lightning, and Lightning developer Justin Goheen was so kind to step through my code with me. We traced back the problems to the torch_geometric.loader.DataLoader, which if we feed it any of our {train, val}_graphs after RandomLinkSplit, returns the exact error seen above. Note, that we get the same error if we use the standard torch DataLoader from torch.utils.DataLoader

Versions

Collecting environment information... PyTorch version: 2.0.1 Is debug build: False CUDA used to build PyTorch: None ROCM used to build PyTorch: N/A

OS: macOS 14.1.1 (arm64) GCC version: Could not collect Clang version: 15.0.0 (clang-1500.0.40.1) CMake version: Could not collect Libc version: N/A

Python version: 3.8.17 | packaged by conda-forge | (default, Jun 16 2023, 07:11:32) [Clang 14.0.6 ] (64-bit runtime) Python platform: macOS-14.1.1-arm64-arm-64bit Is CUDA available: False CUDA runtime version: No CUDA CUDA_MODULE_LOADING set to: N/A GPU models and configuration: No CUDA Nvidia driver version: No CUDA cuDNN version: No CUDA HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True

CPU: Apple M2

Versions of relevant libraries: [pip3] numpy==1.24.4 [pip3] torch==2.0.1 [pip3] torch-geometric==2.3.1 [pip3] torchmetrics==1.3.0.post0 [conda] numpy 1.24.4 pypi_0 pypi [conda] torch 2.0.1 pypi_0 pypi [conda] torch-geometric 2.3.1 pypi_0 pypi [conda] torchmetrics 1.3.0.post0 pypi_0 pypi

rusty1s commented 8 months ago

I think what you want to use here is LinkNeighborLoader instead of DataLoader.

aaronwtr commented 8 months ago

@rusty1s trying to install pyg-lib according to docs to get LinkNeighborLoader to run with pip install pyg-lib -f https://data.pyg.org/whl/torch-2.0.0+cpu.html results in:

Looking in links: https://data.pyg.org/whl/torch-2.0.0+cpu.html
ERROR: Could not find a version that satisfies the requirement pyg-lib (from versions: none)
ERROR: No matching distribution found for pyg-lib

Am I somehow not installing it correctly?

rusty1s commented 8 months ago

Why OS are you on?

aaronwtr commented 8 months ago

@rusty1s MacOS 14.1.1

rusty1s commented 8 months ago

I am assuming you are on M1/M2 then? We don't provide custom wheels for that, so you would need to install pyg-lib from source for this. Sorry :(

aaronwtr commented 8 months ago

@rusty1s thank you for your help with this. Was able to install pyg-lib from the git main branch. However, I am running into other weird/unexpected behaviour. When I run the LinkNeighbourLoader as follows:

class HGSLNetDataModule(pl.LightningDataModule):
    def __init__(self, graph, batch_size=32):
        super().__init__()
        self.graph = graph
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_graph, self.val_graph, _ = self.link_split_transform()

    def train_dataloader(self):
        return LinkNeighborLoader(
            self.train_graph,
            batch_size=self.batch_size,
            num_neighbors=[10],
            shuffle=True
        )

    def val_dataloader(self):
        return LinkNeighborLoader(
            self.val_graph,
            batch_size=self.batch_size,
            num_neighbors=[10],
            shuffle=False
        )

    def link_split_transform(self):
        transform = T.RandomLinkSplit(
            num_val=0.15,
            num_test=0,
            is_undirected=True,
            add_negative_train_samples=False
        )
        return transform(self.graph)

I get the following error:

Traceback (most recent call last):
  File "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/synthetic-lethality-prediction/synthetic-lethality-prediction/src/main.py", line 41, in <module>
    main()
  File "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/synthetic-lethality-prediction/synthetic-lethality-prediction/src/main.py", line 36, in main
    homotrain(graph)
  File "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/synthetic-lethality-prediction/synthetic-lethality-prediction/src/main.py", line 19, in homotrain
    trainer.fit(model, datamodule)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self._run_sanity_check()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1062, in _run_sanity_check
    val_loop.run()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 127, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 127, in __next__
    batch = super().__next__()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 56, in __next__
    batch = next(self.iterator)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 326, in __next__
    out = next(self._iterator)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 132, in __next__
    out = next(self.iterators[0])
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/loader/base.py", line 36, in __next__
    return self.transform_fn(next(self.iterator))
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/loader/link_loader.py", line 182, in collate_fn
    out = self.link_sampler.sample_from_edges(
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 182, in sample_from_edges
    return edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 550, in edge_sample
    out = sample_fn(seed, seed_time)
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 282, in _sample
    out = torch.ops.pyg.neighbor_sample(
  File "/Users/aaronw/miniconda3/envs/sl-prediction/lib/python3.8/site-packages/torch/_ops.py", line 502, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: pyg::neighbor_sample() Expected a value of type 'Optional[Tensor]' for argument 'seed_time' but instead found type 'bool'.
Position: 6
Value: True
Declaration: pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] num_neighbors, Tensor? node_time=None, Tensor? edge_time=None, Tensor? seed_time=None, Tensor? edge_weight=None, bool csc=False, bool replace=False, bool directed=True, bool disjoint=False, str temporal_strategy="uniform", bool return_edge_id=True) -> (Tensor, Tensor, Tensor, Tensor?, int[], int[])
Cast error details: Unable to cast True to Tensor

I have not set any seed_time argument anywhere so I can't quite follow where this arg gets assigned a boolean.

akihironitta commented 8 months ago

@aaronwtr The error is likely due to version mismatch between torch_geometric and pyg_lib. Which version of each have you installed?

aaronwtr commented 8 months ago

@akihironitta I did

pip install ninja wheel
pip install git+https://github.com/pyg-team/pyg-lib.git

to install. This installed version 0.4.0 of pyg-lib

akihironitta commented 8 months ago

If you have pyg_lib 0.4.0, you need torch_geometric master afaik.

aaronwtr commented 8 months ago

@rusty1s @akihironitta thank you for your help! Indeed had to use LinkNeighborLoader + pyg_lib 0.4.0 + torch_geometric from main.