TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License
1.09k stars, 95 forks

DataLoader with multiple inputs caused errors #75

Closed: Runda-Xu closed this issue 3 years ago

Runda-Xu commented 3 years ago

I have a train_loader with multiple inputs:

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(input1, input2, input3, label),
    batch_size=batch_size,
    shuffle=True,
)

These inputs are arrays with different shapes, so it is hard to concatenate them into a single tensor. They work well with plain PyTorch code, but Ensemble-Pytorch cannot handle them. I hope Ensemble-Pytorch can support a DataLoader with multiple inputs in the future. Thank you.

xuyxu commented 3 years ago

Hi @Runda-Xu, thanks for your suggestions! I will take a look at this feature request, and get back to you soon.

xuyxu commented 3 years ago

I am wondering whether the code snippet below meets your requirement. It creates a dataloader with three input tensors, input_1, input_2, and input_3, and passes them into a model whose forward method takes three inputs accordingly. To make the entire workflow run as expected, we can pass multiple inputs into the model as positional (non-keyword) arguments, i.e., *data. There should be no problem as long as the order of the arguments in forward matches the order of the tensors passed when creating the dataset behind the dataloader.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

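class TOY(nn.Module):
    """Toy model whose forward method takes three separate inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, input_1, input_2, input_3):
        # Return the first input unchanged; enough to show the call works.
        return input_1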

nb_samples = 100
input_1 = torch.randn(nb_samples, 10)
input_2 = torch.randn(nb_samples, 5)
input_3 = torch.randn(nb_samples, 7)
target = torch.empty(nb_samples, dtype=torch.long).random_(10)

dataset = TensorDataset(input_1, input_2, input_3, target)
loader = DataLoader(dataset, batch_size=2)

model = TOY()

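for batch_idx, elem in enumerate(loader):
    # All elements except the last are model inputs; the last is the target.
    data, target = elem[:-1], elem[-1]
    # Unpack the inputs as positional arguments, in dataset order.
    print(model(*data))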
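
To illustrate why the ordering matters, here is a minimal sketch in plain Python, with strings standing in for tensors so it runs without PyTorch. The names split_batch and the stand-alone forward function are hypothetical and only for illustration; they are not part of Ensemble-Pytorch. The sketch assumes the target is always the last element of each batch.

```python
def split_batch(elem):
    """Split one batch into (inputs, target), assuming the target comes last."""
    if len(elem) < 2:
        raise ValueError("expected at least one input and one target")
    return tuple(elem[:-1]), elem[-1]


def forward(input_1, input_2, input_3):
    # Stand-in for a model's forward method; it only reports what it received.
    return {"input_1": input_1, "input_2": input_2, "input_3": input_3}


# Placeholder batch: strings stand in for tensors.
batch = ["a", "b", "c", "y"]
data, target = split_batch(batch)
received = forward(*data)

# Same values, but listed in a different order when the dataset was built.
swapped = ["b", "a", "c", "y"]
swapped_received = forward(*split_batch(swapped)[0])

print(received)
print(swapped_received)
```

Because *data is unpacked positionally, the first tensor given to the dataset always lands in the first parameter of forward. Swapping the order in the dataset silently swaps what each parameter receives, and with real tensors this would only raise an error if the shapes happen to be incompatible.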