lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License
11.15k stars, 1.09k forks

Add static graph param #226

Closed: rom1504 closed this 2 years ago

rom1504 commented 2 years ago

depends on https://github.com/huggingface/accelerate/pull/637

This is needed for gradient checkpointing to work

rom1504 commented 2 years ago

Accelerate merged the PR, but we need to wait for a release of accelerate before merging here.

nousr commented 2 years ago

https://github.com/huggingface/accelerate/releases/tag/v0.13.0

looks like your PR finally got released by huggingface @rom1504

rom1504 commented 2 years ago

When set to True, DDP knows the trained graph is static. Static graph means:
1) The set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set find_unused_parameters = True or not.
2) How the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations).

When static_graph is set to True, DDP will support cases that could not be supported in the past:
1) Reentrant backwards.
2) Activation checkpointing multiple times.
3) Activation checkpointing when the model has unused parameters.
4) There are model parameters that are outside of the forward function.
5) Potentially improved performance when there are unused parameters, as DDP will not search the graph in each iteration to detect unused parameters when static_graph is set to True.

To check whether you can set static_graph to True, one way is to check the DDP logging data at the end of your previous model training: if ddp_logging_data.get("can_set_static_graph") == True, you can mostly set static_graph = True as well.
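To make "activation checkpointing multiple times" concrete, here is a minimal single-process sketch on CPU (`gloo` backend, `world_size=1`), so it runs without GPUs. The toy model `TwiceCheckpointed`, the helper `train_steps`, and the port number are hypothetical. The model runs the same block through reentrant checkpointing twice per forward pass, which is case 2) in the quoted docs, and wraps it in `DistributedDataParallel(..., static_graph=True)`.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class TwiceCheckpointed(nn.Module):
    """Hypothetical toy model that reuses one checkpointed block twice."""

    def __init__(self, dim: int = 16):
        super().__init__()
        # Non-checkpointed input layer so the checkpoint inputs require grad;
        # with reentrant checkpointing the block would get no grads otherwise.
        self.inp = nn.Linear(dim, dim)
        self.block = nn.Sequential(nn.Linear(dim, dim), nn.Tanh())
        self.head = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.inp(x)
        # The same parameters go through reentrant checkpointing twice.
        h = checkpoint(self.block, h, use_reentrant=True)
        h = checkpoint(self.block, h, use_reentrant=True)
        return self.head(h)


def train_steps(steps: int = 3) -> tuple[list[float], bool]:
    # One-process "distributed" run on CPU so the sketch needs no GPUs.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")  # arbitrary port (assumption)
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        model = DDP(TwiceCheckpointed(), static_graph=True)
        opt = torch.optim.SGD(model.parameters(), lr=0.05)
        x, y = torch.randn(8, 16), torch.randn(8, 1)
        losses = []
        for _ in range(steps):
            opt.zero_grad()
            loss = F.mse_loss(model(x), y)
            # static_graph lets DDP cope with the block's params
            # receiving gradients twice in one backward pass.
            loss.backward()
            opt.step()
            losses.append(loss.item())
        block_has_grad = all(
            p.grad is not None for p in model.module.block.parameters()
        )
        return losses, block_has_grad
    finally:
        dist.destroy_process_group()
```

Calling `train_steps()` should return three finite loss values and `True` for the check that the checkpointed block received gradients.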

https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

It seems to have only benefits.

Let's merge. We can set it to False by default if anyone complains.
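If the default ever needs flipping, one low-friction option is to expose the flag in the training config with a default of True. This sketch assumes a JSON config and a hypothetical `DistributedConfig` dataclass. It is not DALLE2-pytorch's actual config schema; only the field name mirrors DDP's `static_graph` keyword.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class DistributedConfig:
    """Hypothetical config section; only the field name mirrors DDP's kwarg."""

    # Defaults to True as proposed in the thread; users can opt out per run.
    static_graph: bool = True

    @classmethod
    def from_json(cls, text: str) -> "DistributedConfig":
        raw = json.loads(text)
        # Keep only keys this dataclass knows about; missing keys use defaults.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})


default_cfg = DistributedConfig.from_json("{}")
opt_out_cfg = DistributedConfig.from_json('{"static_graph": false}')
print(default_cfg.static_graph, opt_out_cfg.static_graph)  # True False
```

The parsed value could then be passed on as `DistributedDataParallelKwargs(static_graph=cfg.static_graph)` in accelerate, or directly to `DistributedDataParallel`.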

It unlocks gradient checkpointing

lucidrains commented 2 years ago

@rom1504 thanks Romain!