Hi @SunHaozhe, thanks for raising this insightful problem and for the detailed explanation. I think the best way to figure this out is to look at how bagging is implemented in other ML libraries, such as scikit-learn (code here). It turns out that sklearn uses numpy.random.randint to generate the sampling indices, which can contain duplicated samples.
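For example, a bootstrap sample of size n is simply n indices drawn uniformly at random with replacement; a minimal sketch of the idea (not sklearn's exact code):

```python
import numpy as np

n_samples = 10
rng = np.random.RandomState(0)

# Draw n_samples indices uniformly at random *with replacement*:
# repeated indices are expected, and that is exactly what makes it a bootstrap sample.
indices = rng.randint(0, n_samples, n_samples)
print(indices)  # duplicates are likely to appear
```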
Therefore, the implementation in torchensemble is slightly wrong and needs to be fixed. However, I don't think it is a good idea to maintain N independent dataloaders during the training stage (a huge waste of memory).
I would appreciate it very much if you could share some ideas on how to solve the memory problem. :-)
Hi @xuyxu, thanks for your reply.
I provide a new implementation of Bagging in this pull request. Indeed, we do not need to store N copies of dataloaders/datasets in memory or on disk; we only store N lists of indices, so we do not waste memory. The details and implementation can be found in that pull request. Please let me know what you think :)
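To give a rough idea of the approach (a minimal sketch of the general idea, not the actual code in the pull request): the dataset is shared, and each base estimator only keeps its own list of bootstrap indices.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset standing in for the real training set (for illustration only).
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

n_estimators = 5
n_samples = len(dataset)

# One list of bootstrap indices per base estimator: the only extra state kept in
# memory is n_estimators * n_samples integers, not n_estimators copies of the data.
bootstrap_indices = [
    torch.randint(0, n_samples, (n_samples,)).tolist()
    for _ in range(n_estimators)
]

# Every DataLoader shares the same underlying dataset; they only differ in the
# sampler, which yields the pre-drawn (possibly duplicated) indices.
loaders = [
    DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(indices))
    for indices in bootstrap_indices
]
```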
Thanks for your PR, will take a look when I get a moment.
I have a doubt about the current implementation of sampling with replacement for bagging: is this implementation grounded or justified?
The sampling with replacement for bagging is performed on each (training) batch, as shown in https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/torchensemble/bagging.py#L48 (lines 48 to 66):
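In substance, the per-batch logic looks like the following (a paraphrased, simplified sketch rather than the verbatim snippet from those lines):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy loader standing in for the real train_loader (for illustration only).
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
train_loader = DataLoader(dataset, batch_size=8)

for batch_idx, (data, target) in enumerate(train_loader):
    batch_size = data.size(0)

    # Draw batch_size indices from the current batch *with replacement* ...
    sampling_mask = torch.randint(0, batch_size, (batch_size,), dtype=torch.int64)
    # ... then drop the duplicates, so the effective sub-batch is smaller than
    # batch_size (on average only about 63% of the batch survives).
    sampling_mask = torch.unique(sampling_mask)  # remove duplicates

    sampling_data = data[sampling_mask]
    sampling_target = target[sampling_mask]
    # the base estimator is then trained on (sampling_data, sampling_target)
```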
Why do we need to remove duplicates? What would be the problem if we removed the line `sampling_mask = torch.unique(sampling_mask)  # remove duplicates`?
Furthermore, the problems I see with this implementation are:
An alternative implementation I am thinking of is to do the sampling with replacement only once, at the very beginning of the fit method: use sampling with replacement to create N dataloaders/datasets (assuming there are N base models to learn), each of which can contain duplicates. Then, in the function _parallel_fit_per_epoch, the data batches are used in the classic way without further subsampling (a rough sketch is given below).
What do you think of this alternative implementation? Please let me know if I missed anything.
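To make the alternative concrete, here is a rough sketch of what I have in mind (names and numbers are illustrative, not actual library code): the bootstrap indices are drawn once at the beginning of fit, and each base estimator trains on its own DataLoader built from a Subset view of the shared dataset.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy dataset standing in for the real training set (for illustration only).
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

n_estimators = 5        # N base models
n_samples = len(dataset)

bootstrap_loaders = []
for _ in range(n_estimators):
    # One bootstrap sample per base estimator: indices drawn with replacement,
    # duplicates kept on purpose.
    indices = torch.randint(0, n_samples, (n_samples,)).tolist()
    # Subset only stores the index list and shares the underlying dataset,
    # so no data is copied.
    bootstrap_loaders.append(
        DataLoader(Subset(dataset, indices), batch_size=32, shuffle=True)
    )

# Each base estimator then iterates over its own loader in the usual way,
# with no further per-batch subsampling inside _parallel_fit_per_epoch.
```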