pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Gradient clipping doesn't work with FSDP CPU offloading #1977

Open · acisseJZhong opened this issue 2 weeks ago

acisseJZhong commented 2 weeks ago

I am running the full finetune distributed recipe. When setting clip_grad_norm: 1.0 and fsdp_cpu_offload: True, it raises RuntimeError: No backend type associated with device type cpu.

Full error stack trace:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 847, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:              ^^^^^^^^^^^^^
[rank2]:   File "/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:              ^^^^^^^^^^^^^^^^^
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 842, in recipe_main
[rank2]:     recipe.train()
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 740, in train
[rank2]:     grad_norm = torch.nn.utils.clip_grad_norm_(
[rank2]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
[rank2]:     clip_coef = max_norm / (total_norm + 1e-6)
[rank2]:                 ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 39, in wrapped
[rank2]:     return f(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 1075, in __rdiv__
[rank2]:     return self.reciprocal() * other
[rank2]:            ^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
[rank2]:     self.redistribute_local_args(
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
[rank2]:     resharded_local_tensor = redistribute_local_tensor(
[rank2]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
[rank2]:     new_local_tensor = partial_spec._reduce_value(
[rank2]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
[rank2]:     reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
[rank2]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
[rank2]:     return funcol.all_reduce(
[rank2]:            ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 176, in all_reduce
[rank2]:     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
[rank2]:     return self._op(*args, **(kwargs or {}))
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: No backend type associated with device type cpu

Wondering how we should fix this error?

felipemello1 commented 2 weeks ago

@ebsmothers, do you think it would make sense to ping someone from FSDP?

RdoubleA commented 2 weeks ago

Could you try modifying the init_process_group call to use the gloo backend for CPU? Perhaps it should initialize both nccl for GPU and gloo for CPU?

https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py#L903
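
For reference, init_process_group accepts a per-device backend mapping, so the suggestion could look roughly like the sketch below (illustrative only, not something the recipe currently does). With both backends registered, collectives issued on CPU tensors would be routed to gloo while GPU collectives keep using nccl.

import torch.distributed as dist

# Sketch of the suggestion above: register gloo for CPU alongside nccl for CUDA so that
# a collective issued on a CPU tensor (e.g. an offloaded gradient) has a backend to use.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")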

ebsmothers commented 2 weeks ago

I don’t think we want to modify init_process_group here. To me that error indicates that we are trying to call some comms primitive on a tensor that’s already on CPU, which we shouldn’t be doing. Initializing process group on CPU would only be helpful if we actually want distributed training on CPU, which we don’t. Let’s debug a bit more and then we can loop in distributed folks if needed.

gau-nernst commented 2 weeks ago

I believe that when CPU offload is used in FSDP, gradients are transferred to CPU during the backward pass (to free gradient memory, similar to optimizer-in-backward) so that the optimizer step can run on CPU. That's probably why you see the cpu device there: the gradients are now on CPU. They are DTensors, so when you run gradient clipping, which calls .sum() or something similar, it tries to do an all-reduce, hence the error.

It's probably faster to check with the distributed folks whether FSDP with CPU offload supports gradient clipping in general. Even if it is technically possible (e.g., doing the clipping on CPU), I think it would be too slow and would possibly require changes to FSDP internals.
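
If the combination turns out to be unsupported, one low-effort option is to fail fast with a clear message instead of the opaque backend error. Below is a hypothetical sketch that reuses the config field names from the issue (clip_grad_norm, fsdp_cpu_offload); it is not an existing torchtune check.

def _validate_grad_clipping(cfg) -> None:
    # Hypothetical guard, not part of torchtune: reject the unsupported combination up front.
    if cfg.get("clip_grad_norm") is not None and cfg.get("fsdp_cpu_offload", False):
        raise ValueError(
            "clip_grad_norm is currently not supported with fsdp_cpu_offload: gradients "
            "are offloaded to CPU as DTensors, and clipping them triggers a CPU all-reduce "
            "for which no backend is initialized."
        )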

vancoykendall commented 1 week ago

Looks like the torchtitan repo ran into the same issue and someone created a quick workaround on a separate branch: https://github.com/pytorch/torchtitan/pull/622/files
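
For context, the general shape of such a workaround might look like the rough sketch below (written under assumptions, not the torchtitan patch itself): operate on the local shards of the offloaded DTensor gradients, perform the single norm reduction on GPU over the default NCCL group, and scale the CPU-resident gradients in place. It assumes fully sharded (non-replicated) gradients, one CUDA device per rank, and a finite norm_type; clip_grad_norm_cpu_offload_ is a made-up name.

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor

@torch.no_grad()
def clip_grad_norm_cpu_offload_(parameters, max_norm: float, norm_type: float = 2.0):
    # Rough sketch only: assumes every gradient is sharded across ranks (no replication),
    # so summing the p-th powers of the local norms and all-reducing gives the global norm.
    grads = [p.grad for p in parameters if p.grad is not None]
    # sum of |g|^p over this rank's local shards, accumulated on GPU so NCCL can reduce it
    pow_sum = torch.zeros((), device="cuda")
    for g in grads:
        local = g.to_local() if isinstance(g, DTensor) else g
        pow_sum += local.float().norm(norm_type).to("cuda") ** norm_type
    dist.all_reduce(pow_sum)  # SUM over the default (NCCL) process group, on GPU
    total_norm = pow_sum ** (1.0 / norm_type)
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0).cpu()
    for g in grads:
        local = g.to_local() if isinstance(g, DTensor) else g
        local.mul_(clip_coef)  # scale the CPU-resident shard in place
    return total_norm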