InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Distributed Data Parallel support given that DataParallel may be deprecated in the next release of PyTorch #104

Open yoshitomo-matsubara opened 2 years ago

yoshitomo-matsubara commented 2 years ago

Hi @jbegaint @fracape I'm still waiting for CompressAI's DDP support mentioned here. Could you please reconsider this option?

I think it's a great time to consider it, given that the PyTorch team is thinking about deprecating DataParallel with their upcoming v1.11 release of PyTorch, and many projects are already moving to DistributedDataParallel (see the following issue):

https://github.com/pytorch/pytorch/issues/65936

Feature

Distributed Data Parallel support for faster model training

Motivation

DataParallel may be deprecated in an upcoming PyTorch release (see the issue linked above), and DistributedDataParallel is the recommended, faster option for multi-GPU training.

Thank you!

YodaEmbedding commented 1 year ago

Isn't this already possible by wrapping the model via:

model = DistributedDataParallel(model)

If accessing .compress/etc is an issue, we can probably cheat and just forward those queries to the model.module instance:

from torch.nn.parallel import DistributedDataParallel

class DistributedDataParallelCompressionModel(DistributedDataParallel):
    def __getattr__(self, name):
        try:
            # nn.Module attributes (parameters, buffers, submodules) resolve here.
            return super().__getattr__(name)
        except AttributeError:
            # Anything DDP itself does not expose (e.g. compress, decompress,
            # update, aux_loss) is forwarded to the wrapped CompressionModel.
            return getattr(self.module, name)

model = DistributedDataParallelCompressionModel(model)
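
With that fallback in place, calls like model.compress(x), model.update(force=True), or model.aux_loss() should resolve on the wrapped module, while the ordinary training forward pass still goes through DDP's gradient synchronization. (Presumably you would only run compress/decompress on a single rank, e.g. during validation.)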

EDIT: We released CompressAI-Trainer, which should by default use all available GPUs. (This can be restricted to fewer GPU devices, e.g. only devices 0 and 1 via export CUDA_VISIBLE_DEVICES=0,1.) Give it a try! See Installation and Walkthrough.

yoshitomo-matsubara commented 1 year ago

@YodaEmbedding Did you confirm it worked in distributed training mode? (I do not have the resources to test it out right now.)

That is similar to the approach this repository uses for DataParallel, and I tried using DistributedDataParallel that way when I opened this issue. It did not work at that time, and the problem was not that simple.

danishnazir commented 1 year ago

Hi @YodaEmbedding @yoshitomo-matsubara ,

As suggested by @YodaEmbedding, I tried to train bmshj2018-hyperprior in a DDP setting from scratch. I used two V100-16GB GPUs with a batch size of 16 each (32 in total). I kept every other setting at its default, but of course adapted examples/train.py to accommodate DDP training (pretty straightforward to do so; a rough sketch is at the end of this comment). I was able to train, and my takeaways/results are as follows.


- For a fair comparison, I compared this with the pre-trained `CompressAI` models. I noticed that in order to reach this `BPP`, I had to increase the `quality` parameter by one (I used `quality=3` in DDP mode and `quality=4` for the `CompressAI` pre-trained models).

Using the trained model bmshj2018-hyperprior-mse-4-ans:

{
  "name": "bmshj2018-hyperprior-mse",
  "description": "Inference (ans)",
  "results": {
    "psnr-rgb": [32.826677878697716],
    "ms-ssim-rgb": [0.9747917304436365],
    "bpp": [0.47835625542534715],
    "encoding_time": [0.6658469438552856],
    "decoding_time": [0.915264755487442]
  }
}



- As you can see, the DDP results are a bit better than the pre-trained `CompressAI` model. Is this a bug, did the bigger batch size help, or is it something else?

Please share your ideas; I am happy to share the DDP training code as well if you would like to see it. Thanks!
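
For reference, a rough sketch of the kind of adaptation described above (hypothetical, not the exact code that was run; it assumes the script is launched with torchrun or torch.distributed.launch so that LOCAL_RANK is available in the environment):

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def setup_ddp(model, train_dataset, batch_size):
    # The launcher sets RANK, WORLD_SIZE and (with torchrun) LOCAL_RANK.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = model.to(local_rank)
    # find_unused_parameters may be needed depending on the model's forward().
    model = DistributedDataParallel(model, device_ids=[local_rank])

    # Each process sees a distinct shard of the training set.
    sampler = DistributedSampler(train_dataset, shuffle=True)
    loader = DataLoader(
        train_dataset,
        batch_size=batch_size,  # per-GPU batch size
        sampler=sampler,        # replaces shuffle=True in the DataLoader
        num_workers=4,
        pin_memory=True,
    )
    return model, loader, sampler


# In the training loop, call sampler.set_epoch(epoch) at the start of each
# epoch so that shuffling differs across epochs.

The rest of examples/train.py (optimizer setup, loss, checkpointing) can stay largely unchanged.
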
yoshitomo-matsubara commented 1 year ago

Hi @danishnazir

Thank you for testing it out. Could you explain how you executed your script in distributed training mode? It should be something like `torchrun --nproc_per_node=2 ...` or `python3 -m torch.distributed.launch --nproc_per_node=2 ...`

This is the second issue I opened for DDP support. In the first issue, multiple users were waiting for DDP support, and at that time we could not resolve it with a simple DistributedDataParallel wrapper like the one suggested above (though I forgot to share the exact errors there).

> Generally, I have seen that the AUX loss remains very high in DDP mode. For instance, in non-DDP training mode it was in the range of [20, 30], whereas in DDP mode it was in the range of [100, 120]. Is it because of the batch size, since in non-DDP mode the batch size was of course smaller than in DDP mode? However, I am reporting the AUX loss per GPU (only the local rank) in the DDP case, so I am not sure if this is the issue. Is it because of some bug in the code, or is it considered normal?

Probably because the default reduction for MSE is mean (https://github.com/InterDigitalInc/CompressAI/blob/53275cf5e03f83b0ab8ab01372849bfdc9ef5f1c/compressai/losses/rate_distortion.py#L47), while aux_loss takes a sum (https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/models/base.py#L117-L146). That may be why only the AUX loss turned out relatively high with a large batch size.
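
To illustrate the mean-vs-sum point, a simplified sketch (not the exact CompressAI code; see the links above for the real implementation) of the two reductions being compared:

import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck

# Distortion term: nn.MSELoss() defaults to reduction="mean", so it is
# averaged over the batch (and pixels).
mse = nn.MSELoss()

# Aux loss: a plain sum of the quantile losses over all EntropyBottleneck
# modules, with no averaging.
def aux_loss(model: nn.Module):
    return sum(
        m.loss() for m in model.modules() if isinstance(m, EntropyBottleneck)
    )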

> Since I was using a very large batch size, the network converged at around 200 epochs (Vimeo90K dataset) and training was very fast; it took me only 4 days (it can probably be made faster as well, I guess).

It is a little bit surprising to me that it still takes 4 days to train a model even in distributed training mode. Did it take more than 4 days with DP instead of DDP?

danishnazir commented 1 year ago

Hi @yoshitomo-matsubara

Thank you for your response.

> Could you explain how you executed your script in distributed training mode?

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --node_rank=0 examples/train.py -d /path/to/dataset/ --epochs 300 -lr 1e-4 --batch-size 16 --cuda --save
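
(For reference, the roughly equivalent launch with the newer torchrun entry point, which replaces the now-deprecated torch.distributed.launch, would presumably be the following, assuming the adapted train.py reads LOCAL_RANK from the environment rather than a --local_rank argument:)

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --node_rank=0 examples/train.py -d /path/to/dataset/ --epochs 300 -lr 1e-4 --batch-size 16 --cuda --save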

> but aux_loss takes a sum. That may be why only the AUX loss turned out relatively high with a large batch size.

Yeah, that makes sense. But I don't understand: since it's the sum per GPU in DDP mode (at least as I understand it), why is it so high? Or am I missing something here? Moreover, does this affect performance in general?

> It is a little bit surprising to me that it still takes 4 days to train a model even in distributed training mode. Did it take more than 4 days with DP instead of DDP?

You are right, it's a bit surprising to me as well that it took this long. I don't remember exactly how much time it took in DP mode; maybe I need to test that, or perhaps someone from CompressAI can confirm the time. I would also note that it didn't take exactly 4 days; that was more or less an estimate. I trained for 270 epochs, but looking at the logs it was already converging at around 200 epochs.

YodaEmbedding commented 1 year ago

We released CompressAI-Trainer, which should by default use all available GPUs. (This can be restricted to fewer GPU devices, e.g. only devices 0 and 1 via export CUDA_VISIBLE_DEVICES=0,1.) Give it a try! See Installation and Walkthrough.

https://interdigitalinc.github.io/CompressAI-Trainer/tutorials/full.html#single-gpu-and-multi-gpu-training

yoshitomo-matsubara commented 1 year ago

Hi @danishnazir

> Yeah, that makes sense. But I don't understand: since it's the sum per GPU in DDP mode (at least as I understand it), why is it so high? Or am I missing something here? Moreover, does this affect performance in general?

Probably we need to see the actual code to understand it better, as the command you provided looks like the right way to launch distributed training. If the code requires only minimal changes to use DDP, perhaps you could submit a PR (mentioning this issue) and request a code review from someone at InterDigitalInc (I am not affiliated).

yoshitomo-matsubara commented 1 year ago

Hi @YodaEmbedding

> We released CompressAI-Trainer, which should by default use all available GPUs. (This can be restricted to fewer GPU devices, e.g. only devices 0 and 1 via export CUDA_VISIBLE_DEVICES=0,1.) Give it a try! See Installation and Walkthrough.

Thank you for sharing that. From this line, I assume the trainer supports DDP. It would also be awesome if examples/train.py in this repo could support DDP, if @danishnazir can submit the PR.

danishnazir commented 1 year ago

Hi @yoshitomo-matsubara, yes, I will submit the PR as soon as possible. Thanks.