TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Ensemble for PyTorch Geometric #105

Open ParasKoundal opened 2 years ago

ParasKoundal commented 2 years ago

Hi, I want to use Ensemble-PyTorch with PyTorch-Geometric. However, it doesn't recognize the dataloaders.

Is this under development, or is it a bug?

xuyxu commented 2 years ago

Hi @ParasKoundal, could you provide a code snippet showing how you use dataloaders with graph data, so that we can take a closer look?

ParasKoundal commented 2 years ago

@xuyxu It is simple.

.....
from torch_geometric.loader import DataLoader
.......

train_loader = DataLoader(train_dataset, batch_size=batch_s, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=2, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=2, drop_last=True)
.......

I have created a custom class to preprocess the dataset before loading it into the dataloader.

After that, I followed the steps in https://ensemble-pytorch.readthedocs.io/en/latest/quick_start.html. For regression, I first tried VotingRegressor, which doesn't work (the error is given in the initial issue). The other ensembles fail in the same way.

xuyxu commented 2 years ago

Could you also provide the full exception traceback? Thanks!

ParasKoundal commented 2 years ago

@xuyxu

Here it is:

test_loader=test_loader
  File "/cr/data02/koundal/applications/gpu-project/lib/python3.7/site-packages/torchensemble/bagging.py", line 329, in fit
    self.n_outputs = self._decide_n_outputs(train_loader)
  File "/cr/data02/koundal/applications/gpu-project/lib/python3.7/site-packages/torchensemble/_base.py", line 267, in _decide_n_outputs
    _, target = split_data_target(elem, self.device)
  File "/cr/data02/koundal/applications/gpu-project/lib/python3.7/site-packages/torchensemble/utils/io.py", line 84, in split_data_target
    raise ValueError(msg)
ValueError: Invalid dataloader, please check if the input dataloder is valid.
xuyxu commented 2 years ago

This could be a side effect of the commit for issue #75. I will see whether it can be fixed in a few days. Thanks for reporting, @ParasKoundal!

ParasKoundal commented 2 years ago

@xuyxu Any update on this?

xuyxu commented 2 years ago

Hi @ParasKoundal, sorry, I am rather busy these days; I will take a look next weekend.

xuyxu commented 2 years ago

In torchensemble, at each iteration the input loader is expected to return a list in one of the following forms:

  [data, target]
  [data_1, data_2, ..., data_n, target]

The first form is the most widely used form of dataloader (i.e., for batch_idx, (data, target) in enumerate(loader)), while the second comes from the feature request in #75 to support multiple input tensors.
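The rule above can be sketched in a few lines of plain Python. `split_batch` is a hypothetical helper written only for illustration: it mimics the described behavior (each batch is a list or tuple, every item but the last is an input, the last item is the target) and is not torchensemble's own source code.

```python
from typing import Any, List, Tuple


def split_batch(element: Any) -> Tuple[List[Any], Any]:
    """Split one loader batch into (inputs, target).

    Hypothetical helper that mimics the rule described above;
    it is not torchensemble's implementation.
    """
    # Only lists and tuples are accepted as batch containers.
    if not isinstance(element, (list, tuple)):
        raise ValueError("Invalid dataloader: each batch must be a list or tuple.")
    # At least one input and one target are required.
    if len(element) < 2:
        raise ValueError("Each batch must contain at least one input and one target.")
    # Every item except the last is an input; the last item is the target.
    inputs, target = list(element[:-1]), element[-1]
    return inputs, target


# First form: a single input tensor plus a target.
print(split_batch(("x", "y")))
# Second form: several input tensors plus a target.
print(split_batch(["x1", "x2", "x3", "y"]))
```

Under this rule, a batch object that is not a list or tuple is rejected outright, which is the same category of failure as the `Invalid dataloader` error in the traceback above.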

However, the dataloader in PyTorch Geometric conforms to neither of them:

  (positive_batch, negative_batch)

which does not contain a target tensor, since the label is given only by each batch's position in the returned tuple.

Here is a simple solution; please let me know whether it solves your problem with using torchensemble models in PyTorch Geometric. The general idea is to override the _sample method. Taking MetaPath2Vec as an example, we could declare a new class like this:

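from typing import List

import torch
from torch import Tensor
from torch_geometric.nn import MetaPath2Vec


class CustomMetaPath2Vec(MetaPath2Vec):

    def _sample(self, batch: List[int]) -> List[Tensor]:
        # The loader may pass a plain list of start-node indices.
        if not isinstance(batch, Tensor):
            batch = torch.tensor(batch, dtype=torch.long)

        # Positive walks along the metapath and randomly drawn negative walks.
        pos_sample = self._pos_sample(batch)
        neg_sample = self._neg_sample(batch)

        # Stack both groups into one input tensor...
        data = torch.cat((pos_sample, neg_sample), dim=0)
        # ...and label each row: 1 for positive walks, 0 for negative ones.
        target = torch.cat(
            (torch.ones(pos_sample.size(0)), torch.zeros(neg_sample.size(0))),
            dim=0
        )

        # [data, target] matches the (data, target) form torchensemble accepts.
        return [data, target]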

With this new class, positive_batch and negative_batch are concatenated into a single tensor, data, and you can tell them apart using the target tensor.
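A small sketch of how the target tensor identifies the two groups. Random integer tensors stand in for the outputs of _pos_sample and _neg_sample; their shapes are arbitrary assumptions chosen for illustration.

```python
import torch

# Stand-ins for the outputs of _pos_sample / _neg_sample: each row is one
# walk of node indices (shapes are arbitrary, for illustration only).
pos_sample = torch.randint(0, 100, (4, 3))
neg_sample = torch.randint(0, 100, (6, 3))

# Same concatenation as in CustomMetaPath2Vec._sample.
data = torch.cat((pos_sample, neg_sample), dim=0)
target = torch.cat(
    (torch.ones(pos_sample.size(0)), torch.zeros(neg_sample.size(0))), dim=0
)

# Boolean masks on target recover the two original groups of rows.
recovered_pos = data[target == 1]
recovered_neg = data[target == 0]
print(data.shape, recovered_pos.shape, recovered_neg.shape)
```

Because the rows keep their order inside each group, masking with `target == 1` and `target == 0` returns exactly the positive and negative walks.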

In addition, some extra steps are required in the forward function of the downstream base estimators, since each one now receives positive and negative walks stacked in a single tensor.
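One possible shape for such a forward function, as a hedged sketch: `WalkScorer` is a hypothetical base estimator, not part of torchensemble or PyTorch Geometric. It assumes the single `data` tensor is what reaches `forward`, treats column 0 of each walk as the start node, and returns one logit per walk so that the 0/1 target can be used with a binary cross-entropy loss. How this plugs into a particular ensemble class (for example, its criterion) is not covered here.

```python
import torch
from torch import Tensor, nn


class WalkScorer(nn.Module):
    """Hypothetical base estimator for [data, target] batches.

    Each row of `data` is one walk of node indices; column 0 is treated
    as the start node and the remaining columns as its context.
    """

    def __init__(self, num_nodes: int, embedding_dim: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, embedding_dim)

    def forward(self, data: Tensor) -> Tensor:
        start = self.embedding(data[:, :1])    # (N, 1, D)
        context = self.embedding(data[:, 1:])  # (N, C - 1, D)
        # Dot product between the start node and each context node,
        # averaged over the context, gives one logit per walk.
        return (start * context).sum(dim=-1).mean(dim=-1)  # (N,)


# Positive walks (target 1) should score high and negative walks
# (target 0) low, so the target works with BCEWithLogitsLoss.
model = WalkScorer(num_nodes=100)
data = torch.randint(0, 100, (10, 3))
target = torch.cat((torch.ones(4), torch.zeros(6)))
logits = model(data)
loss = nn.BCEWithLogitsLoss()(logits, target)
print(logits.shape, loss.item())
```

The averaged dot product is a simplification; MetaPath2Vec's own loss scores start/context pairs individually, so a real estimator may want to follow that more closely.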

Looking forward to your reply, @ParasKoundal.