TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

split_data_target issue #140

Open jfulem opened 1 year ago

jfulem commented 1 year ago

Hi, I'm using a Dataset class (extending torch.utils.data.Dataset) which contains this method:

def __getitem__(self, idx) -> tuple[list[torch.Tensor], torch.Tensor]

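Unfortunately, split_data_target in torchensemble.utils.io can't handle that. Its two-element branch calls .to(device) directly on the first element, which fails when that element is a list of tensors: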


    if len(element) == 2:
        # Dataloader with one input and one target
        data, target = element[0], element[1]
        return [data.to(device)], target.to(device)
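
A minimal sketch of the failure, assuming the batch arrives as a two-element list whose first element is a Python list of tensors. The helper name `move_single_input` is hypothetical and only mirrors the branch quoted above; it is not the library's function.

```python
import torch


def move_single_input(element, device):
    # Hypothetical helper that mirrors the two-element branch quoted above.
    data, target = element[0], element[1]
    return [data.to(device)], target.to(device)


device = torch.device("cpu")
target = torch.zeros(4)

# A single tensor as the input is moved and wrapped in a list.
ok_data, ok_target = move_single_input([torch.ones(4, 3), target], device)

# A list of tensors as the input fails, because a list has no .to method.
try:
    move_single_input([[torch.ones(4, 3), torch.ones(4, 5)], target], device)
    error = None
except AttributeError as exc:
    error = exc
```

In this sketch, `error` ends up holding the `AttributeError` raised by calling `.to` on a list.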

Maybe this modification would be useful:

    if len(element) == 2:
        # Dataloader with one input and one target
        data, target = element[0], element[1]
        if isinstance(data, (list, tuple)):
            # list/tuple of tensors -> list of tensors on device
            return [d.to(device) for d in data], target.to(device)
        # single tensor -> one-element list
        return [data.to(device)], target.to(device)
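
The proposed branch can be tried in isolation with a sketch like the one below. The function name `split_data_target_sketch` is hypothetical: it covers only the two-element case discussed here and is not the actual torchensemble implementation.

```python
import torch


def split_data_target_sketch(element, device):
    """Hypothetical stand-in covering only the (data, target) case."""
    if len(element) != 2:
        raise ValueError("This sketch only covers (data, target) batches.")
    data, target = element[0], element[1]
    if isinstance(data, (list, tuple)):
        # Several inputs: move each tensor to the device.
        return [d.to(device) for d in data], target.to(device)
    # One input: wrap the tensor in a list, as the original branch does.
    return [data.to(device)], target.to(device)


device = torch.device("cpu")
target = torch.zeros(4)

single, _ = split_data_target_sketch([torch.ones(4, 3), target], device)
multi, _ = split_data_target_sketch(
    [[torch.ones(4, 3), torch.ones(4, 5)], target], device
)
```

Both calls return a list of tensors as the data part, so downstream code can treat single-input and multi-input batches the same way.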

xuyxu commented 1 year ago

def __getitem__(self, idx) -> tuple[list[torch.Tensor], torch.Tensor]

Hi @jfulem, is this dataset item format commonly used? I would appreciate your suggestions very much.

jfulem commented 1 year ago

Yes, getting this tuple[list[torch.Tensor], torch.Tensor] as the dataset item is very common.
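
For illustration, a sketch of a dataset of this shape, assuming a model that takes two input tensors per sample. The class name `TwoInputDataset` and the tensor sizes are hypothetical. With the default collate function, the DataLoader yields the data part as a list of batched tensors, which is the case the two-element branch has to handle.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoInputDataset(Dataset):
    """Hypothetical dataset with two input tensors per sample."""

    def __init__(self, n_samples: int = 8):
        self.a = torch.randn(n_samples, 3)
        self.b = torch.randn(n_samples, 5)
        self.y = torch.randint(0, 2, (n_samples,))

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, idx) -> tuple[list[torch.Tensor], torch.Tensor]:
        # Inputs are returned as a list, the target as a single tensor.
        return [self.a[idx], self.b[idx]], self.y[idx]


loader = DataLoader(TwoInputDataset(), batch_size=4)

# Default collation batches each input separately and keeps the list.
data, target = next(iter(loader))
```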

xuyxu commented 1 year ago

Marked as a new feature. Are you willing to work on this, @jfulem?

jfulem commented 1 year ago

@xuyxu, sure, I can make the PR.