TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Question of the implementation of sampling with replacement for bagging #119

Closed SunHaozhe closed 1 year ago

SunHaozhe commented 1 year ago

I have a doubt about the current implementation of sampling with replacement for bagging: is this implementation grounded or justified?

The sampling with replacement for bagging is performed on each (training) batch, as shown in https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/torchensemble/bagging.py#L48 (lines 48 to 66 of torchensemble/bagging.py):

for batch_idx, elem in enumerate(train_loader):
    data, target = io.split_data_target(elem, device)
    batch_size = data[0].size(0)

    # Sampling with replacement
    sampling_mask = torch.randint(
        high=batch_size, size=(int(batch_size),), dtype=torch.int64
    )
    sampling_mask = torch.unique(sampling_mask)  # remove duplicates
    subsample_size = sampling_mask.size(0)
    sampling_data = [tensor[sampling_mask] for tensor in data]
    sampling_target = target[sampling_mask]

    optimizer.zero_grad()
    sampling_output = estimator(*sampling_data)
    loss = criterion(sampling_output, sampling_target)
    loss.backward()
    optimizer.step()

Why do we need to remove duplicates? What would be the problem if we removed the line sampling_mask = torch.unique(sampling_mask)  # remove duplicates?
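As a side note, here is a quick illustration (my own snippet, not from the library) of what the torch.unique call does to the bootstrap mask: on average only about 63% of the drawn indices survive deduplication, so the effective batch size both shrinks and varies from batch to batch.

import torch

batch_size = 128
# Same sampling scheme as in bagging.py.
sampling_mask = torch.randint(high=batch_size, size=(batch_size,), dtype=torch.int64)
print(sampling_mask.numel())                # 128 indices drawn with replacement
print(torch.unique(sampling_mask).numel())  # about 0.632 * 128 ≈ 81 on average, varies per batch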

Furthermore, the problems I see with this implementation are:

  1. Each training batch is subsampled and its duplicates are removed, so each actual training batch ends up with a different (random) size. Since batch size is known to affect neural network performance, this random batch size can become a confounding factor in the learning algorithm.
  2. Because the sampling with replacement is performed independently on every training batch, a model trained for more than one epoch does not stick to a fixed bootstrap sample; after several epochs each base model will eventually have seen all the examples, which reduces the diversity between base models that bagging relies on.
  3. The data coming out of the PyTorch dataloader has already gone through the data augmentation stage.

An alternative implementation I have in mind is to do the sampling with replacement only once, at the very beginning of the fit method: use sampling with replacement to create N dataloaders/datasets (assuming there are N base models to learn), each of which may contain duplicates. Then, in the function _parallel_fit_per_epoch, the data batches are used in the classic way, without further subsampling.
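For illustration, here is a minimal sketch of that idea (my own code, not the library's or the pull request's; the helper name make_bootstrap_loaders is hypothetical). Each base estimator gets its own bootstrap DataLoader, built from indices drawn with replacement over the full dataset:

import torch
from torch.utils.data import DataLoader, Subset

def make_bootstrap_loaders(dataset, n_estimators, batch_size):
    # One bootstrap DataLoader per base estimator; the underlying dataset
    # object is shared, only the index lists differ.
    loaders = []
    n = len(dataset)
    for _ in range(n_estimators):
        # Sampling with replacement: duplicates are allowed and kept.
        indices = torch.randint(high=n, size=(n,), dtype=torch.int64).tolist()
        loaders.append(
            DataLoader(Subset(dataset, indices), batch_size=batch_size, shuffle=True)
        )
    return loaders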

What do you think of this alternative implementation? Please let me know if I am missing anything.

xuyxu commented 1 year ago

Hi @SunHaozhe, thanks for your insightful problem and detailed explanation. I think the best way to figure this out is to look at how bagging is implemented in other ML libraries, such as scikit-learn (code here). It turns out that sklearn uses numpy.random.randint to generate the sampling indices, which can contain duplicated samples.
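For reference, the sampling step there boils down to something like the following (paraphrased, not the exact scikit-learn source):

import numpy as np

n_samples = 1000
rng = np.random.RandomState(0)
# Bootstrap indices drawn with replacement; duplicates are kept on purpose.
sample_indices = rng.randint(0, n_samples, n_samples)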

Therefore, the implementation in torchensemble is slightly wrong and needs to be fixed. However, I don't think it is a good idea to maintain N independent dataloaders during the training stage (a huge waste of memory).

I would appreciate it very much if you could share some ideas on how to solve the memory problem. :-)

SunHaozhe commented 1 year ago

Hi @xuyxu , thanks for your reply.

I have provided a new implementation of Bagging in this pull request.

Indeed, we do not need to store N copies of the dataloaders/datasets in memory or on disk; we only need to store N lists of indices, so no memory space is wasted. The details and implementation can be found in that pull request. Please let me know what you think :)
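To illustrate the memory argument with a hypothetical sketch (not the code from the pull request): the only per-estimator state is a list of len(dataset) integers, while the dataset object itself is shared by all N dataloaders.

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

def bootstrap_indices(n_samples):
    # Sampling with replacement; duplicates are kept so the bootstrap is faithful.
    return torch.randint(high=n_samples, size=(n_samples,)).tolist()

# For the k-th base estimator (dataset and batch size are placeholders):
# loader_k = DataLoader(dataset, batch_size=32,
#                       sampler=SubsetRandomSampler(bootstrap_indices(len(dataset))))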

xuyxu commented 1 year ago

Thanks for your PR, will take a look when I get a moment.