TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

no speedup seen from using TorchEnsemble's implementation #137

Closed lorenzozanisi closed 1 year ago

lorenzozanisi commented 1 year ago

Hi

I want to train an ensemble of NNs in parallel on a single GPU. At the moment I am simply doing:

for model in ensemble.models:
  model.to('cuda')
  train_model(model, ...)

However this does not work in parallel as there are CPU overheads that prevent the spawning of multiple kernels.

TorchEnsemble should deal with these overheads by using joblib's Parallel and delayed - that is, I should be able to start one kernel for each NN training and thus it is possible to parallelise the ensemble on the same GPU.
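The one-job-per-model pattern described above can be sketched with the standard library alone. This is not joblib or TorchEnsemble code: `ThreadPoolExecutor` stands in for `Parallel`/`delayed`, and `fake_train` is a hypothetical placeholder that sleeps instead of training. It only shows how to check that running one job per model gives a measurable speedup over a plain loop.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_train(model_id, seconds=0.2):
    """Hypothetical stand-in for training one model; sleeping releases the GIL."""
    time.sleep(seconds)
    return model_id


def run_sequential(model_ids):
    # Naive loop: one model after another.
    start = time.perf_counter()
    results = [fake_train(m) for m in model_ids]
    return results, time.perf_counter() - start


def run_parallel(model_ids):
    start = time.perf_counter()
    # One worker per model, mirroring n_jobs=num_models in the issue.
    with ThreadPoolExecutor(max_workers=len(model_ids)) as pool:
        results = list(pool.map(fake_train, model_ids))
    return results, time.perf_counter() - start


model_ids = list(range(5))
seq_results, seq_time = run_sequential(model_ids)
par_results, par_time = run_parallel(model_ids)
print(f"sequential {seq_time:.2f}s, parallel {par_time:.2f}s, "
      f"speedup {seq_time / par_time:.1f}x")
```

If the two timings come out roughly equal for a real training function, the jobs are effectively being serialized somewhere (for example by worker start-up, argument pickling, or the interpreter lock), which is the symptom reported here.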

However, this is not what I'm seeing. The code below is a slight re-implementation of your BaggingRegressor, and I am seeing the same training times as with my naive implementation above.

The dataset is just a standard PyTorch Dataset object. Each model is quite small, and so is each batch of data, so I can definitely fit the whole ensemble on the GPU.

Do you have any insight as to why I cannot parallelise my ensemble efficiently with the code below? Many thanks!

def _parallel_fit_per_epoch(
    train_loader,
    estimator,
    optimizer,
    criterion,
    device,
):
    for batch_idx, elem in enumerate(train_loader):
        data, target = elem[0].to(device), elem[1].to(device)
        optimizer.zero_grad()
        output = estimator(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    return estimator, optimizer, loss

class Regressor(nn.Module):
    # ...
    def forward(self,...):
        # operations

class Ensemble(nn.Module):

    def __init__(self, num_models, ...):
        super().__init__()  # required before using an nn.Module subclass
        self.num_models = num_models
        self.loss = nn.MSELoss()
        self.device = 'cuda'
        self.models = [Regressor(...).to(self.device) for _ in range(self.num_models)]
        # etc

    def fit(self, train_dataset,valid_dataset,epochs,batch_size, learning_rate):
        optimizers = []
        train_loaders = []
        for model in self.models:
            opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
            optimizers.append(opt)
            # One shuffled loader per model over the same dataset
            train_loaders.append(torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True))

        with Parallel(n_jobs=self.num_models) as parallel:
            # Training loop
            for epoch in range(epochs):
                rets = parallel(
                    delayed(_parallel_fit_per_epoch)(
                        dataloader,
                        estimator,
                        optimizer,
                        self.loss,
                        self.device,
                    )
                    for idx, (estimator, optimizer, dataloader) in enumerate(
                        zip(self.models, optimizers, train_loaders)
                    )
                )

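                estimators, optimizers, losses = [], [], []
                for estimator, optimizer, loss in rets:
                    estimators.append(estimator)
                    optimizers.append(optimizer)
                    losses.append(loss)
                # Process-based workers train copies, so carry the returned
                # estimators into the next epoch (assumed to be the intent).
                self.models = estimators

xuyxu commented 1 year ago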

Hi @lorenzozanisi, did you set the n_jobs parameter to an integer larger than 1?

lorenzozanisi commented 1 year ago

Thanks for the fast reply @xuyxu. Yes, as you can see, I use with Parallel(n_jobs=self.num_models) as parallel, where self.num_models=5.

xuyxu commented 1 year ago

That is strange, since the parallelism feature has been well tested. Could you provide the versions of joblib and torch you are using? I will take a closer look.

lorenzozanisi commented 1 year ago

joblib = 1.1.1, torch = 1.8.1+cu111, python = 3.7.5

Note, given how my environment is set up I need that version of python. With this, joblib will throw an error as it calls pickle internally with protocol=5 which is supported only in python>=3.8. I substituted import pickle with import pickle5 as pickle in all the relevant places in joblib and it runs without errors. I don't think this is enough for the parallelisation to break though.
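The version constraint mentioned above can be checked directly: pickle protocol 5 was added in Python 3.8, and `pickle.HIGHEST_PROTOCOL` reports what the running interpreter supports. The sketch below is only an illustration of clamping a requested protocol to a supported one; `safe_protocol` and `payload` are hypothetical names, and this is not how joblib selects its protocol.

```python
import pickle
import sys


def safe_protocol(requested=5):
    """Clamp the requested pickle protocol to what this interpreter supports."""
    return min(requested, pickle.HIGHEST_PROTOCOL)


payload = {"weights": [0.1, 0.2, 0.3]}  # hypothetical stand-in for model state
protocol = safe_protocol(5)
restored = pickle.loads(pickle.dumps(payload, protocol=protocol))

print(f"Python {sys.version_info.major}.{sys.version_info.minor} supports up to "
      f"protocol {pickle.HIGHEST_PROTOCOL}; using {protocol}")
```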

xuyxu commented 1 year ago

Here is my result when training the VotingClassifier in examples/classification_cifar10_cnn with joblib = 1.1.0, torch = 1.13.0, and python=3.9:

The speedup should be acceptable considering the large cost of pickling the model and copying data.
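To get a feel for the pickling cost mentioned above, one can time a serialization round trip. The sketch below uses a plain dictionary of float lists as a hypothetical stand-in for a model's state (the names and sizes are assumptions); real models and process-based workers will show different numbers, so treat it as a way to measure rather than a benchmark.

```python
import pickle
import time

# Hypothetical stand-in for a model's state: named lists of float weights.
state = {f"layer{i}.weight": [0.0] * 100_000 for i in range(4)}

start = time.perf_counter()
blob = pickle.dumps(state, protocol=pickle.HIGHEST_PROTOCOL)
restored = pickle.loads(blob)  # what a worker process would do on receipt
elapsed = time.perf_counter() - start

print(f"{len(blob) / 1e6:.2f} MB serialized, round trip took {elapsed * 1e3:.1f} ms")
```

When this round trip is comparable to the time spent training for one epoch, dispatching a fresh job per epoch can cancel out most of the gain from running models concurrently.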

Could you further provide the following information:

lorenzozanisi commented 1 year ago

Hi @xuyxu, I rebuilt my env with Python 3.9 and now it works.