facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

FSDP unnecessarily clones buffers in state_dict()? #966

Open rohan-varma opened 2 years ago

rohan-varma commented 2 years ago

My understanding is that FSDP does not shard the model buffers. As a result, buffers are not freed and returned to a sharded state after state_dict()/summon_full_params() the way parameters are. However, the buffers still seem to be cloned, which may be unnecessary, so a small optimization could be made here: https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2516
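To make the cost concrete, here is a minimal plain-PyTorch sketch. It does not reproduce the FSDP code at the linked line. It assumes a state_dict path that either returns buffers as detached views (which share storage with the module) or clones them (which allocates a full second copy). The helpers `collect_buffers` and `extra_bytes` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn


def collect_buffers(module: nn.Module, clone: bool) -> dict:
    # Hypothetical helper (not FSDP's implementation): gather a module's
    # buffers the way a state_dict path might, with or without a copy.
    out = {}
    for name, buf in module.named_buffers():
        t = buf.detach()  # detach() returns a view sharing the same storage
        out[name] = t.clone() if clone else t  # clone() allocates new memory
    return out


def extra_bytes(module: nn.Module, collected: dict) -> int:
    # Bytes allocated beyond the module's own buffers, i.e. entries whose
    # storage does not alias the original buffer.
    own = dict(module.named_buffers())
    return sum(
        t.numel() * t.element_size()
        for name, t in collected.items()
        if t.data_ptr() != own[name].data_ptr()
    )


# BatchNorm1d(1024) holds running_mean and running_var (1024 float32 each)
# plus num_batches_tracked (one int64 scalar).
bn = nn.BatchNorm1d(1024)
aliased = collect_buffers(bn, clone=False)
copied = collect_buffers(bn, clone=True)
print(extra_bytes(bn, aliased))  # 0
print(extra_bytes(bn, copied))   # 8200
```

One trade-off to keep in mind: the aliased tensors reflect any later in-place update to the buffers (for example, BatchNorm running statistics during training), which is one reason code might prefer a copy. Skipping the clone and letting callers that need a snapshot copy it themselves avoids paying that extra memory on every call.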

anj-s commented 2 years ago

Are you talking about buffers that are separate from the model parameters, or just the state_dict itself, which we end up calling clone on?

Yes, I do think we need to optimize the part where we call clone, since it runs into OOM errors.

Do you have any suggestions for what we can do to improve this?