Just realized TORCH_VERSION_AFTER_2_4 will return False in 2.4.0. Still got that old problem 🤣. So the low-bit optim FSDP2 test will not run in 2.4.0 CI.
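For reference, a minimal sketch of an inclusive check (not the torchao helper itself; it assumes torch.__version__ follows PEP 440 and that the packaging package is available):

import torch
from packaging.version import parse

# Treat the 2.4.0 release itself as ">= 2.4", unlike a strict "after 2.4" check.
# Note: 2.4 pre-releases/nightlies (e.g. "2.4.0a0+git...") also pass, since
# .release drops the pre-release suffix.
TORCH_AT_LEAST_2_4 = parse(torch.__version__).release >= (2, 4)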
The error message is very cryptic. AdamW8bit doesn't use dynamic shapes though, so I don't know why it pops up. And this error only happens in the FSDP2 test, not the normal single-GPU test. Would you know who can take a look at this?
Yeah, the version problem is kinda getting out of hand, I'll fix that asap.
Regarding the error, @awgu and @weifengpy are usually my go-tos for FSDP2 issues.
I think low-bit optimizer + FSDP2 is actually low-bit optimizer + DTensor + torch.compile, for which @bdhirsh is probably the best person to ask.
(taking a look)
The problem is that we have a pretty complicated input to the compiled region: our input is a DTensor that has a _local_tensor._base, and also has a populated .grad field that is also a DTensor, whose _local_tensor._base has a different number of dims compared to the original ._base.
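To make that concrete, here is a plain-tensor sketch of that input shape (ordinary CPU tensors standing in for the DTensor local shards; illustrative only, not the FSDP2 code):

import torch

sharded_storage = torch.zeros(4, 8)                         # 2-D buffer backing the parameter shard
param = sharded_storage[:3]                                 # param._base is the 2-D buffer
flat_reduce_output = torch.zeros(32)                        # 1-D flat buffer backing the gradient
param.grad = flat_reduce_output.as_strided((3, 8), (8, 1))  # grad._base is the 1-D buffer

print(param._base.ndim)       # 2
print(param.grad._base.ndim)  # 1 -- the two bases disagree on ndim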
I have a min repro here https://github.com/pytorch/pytorch/issues/133274.
In the meantime, I also found that this tweak gets me past the error, although I'm not sure we actually want to land it in eager FSDP2 (cc @awgu):
diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index d739ffbcf96..c512ea7c37f 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -324,7 +324,7 @@ def foreach_reduce(
             size=fsdp_param.sharded_size,
             stride=fsdp_param.contiguous_sharded_stride,
             storage_offset=flat_grad_offset,
-        )
+        ).detach()
         to_accumulate_grad = fsdp_param.sharded_param.grad is not None
         if fsdp_param.offload_to_cpu:
             # Only overlap the D2H copy (copying to pinned memory) if not
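For context, a standalone sketch of the patched line's shape (illustrative names and sizes, not the FSDP2 internals): the sharded grad is carved out of the flat reduce-scatter output with as_strided, and the patch appends .detach() to that view.

import torch

reduce_output = torch.zeros(64)                 # flat 1-D reduce-scatter output
sharded_size, sharded_stride, flat_grad_offset = (4, 8), (8, 1), 16
new_sharded_grad = torch.as_strided(
    reduce_output,
    size=sharded_size,
    stride=sharded_stride,
    storage_offset=flat_grad_offset,
).detach()                                      # detached from autograd, still shares storage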
Thank you for the quick debug. May I ask about this part:

"our input is a DTensor that has a _local_tensor._base, and also has a populated .grad field that is also a DTensor, whose _local_tensor._base has a different number of dims compared to the original ._base"

You mentioned that the input in question has a .grad field, indicating that it is a parameter. In the low-bit optim test, only the optimizer states are tensor subclasses, so they shouldn't have a .grad field. I think something is not quite right here?
I was looking at the values of param_groups, which are the inputs to your torch.compile region, here: https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adam.py#L110

And empirically, param_groups contains DTensor parameters with the above properties. Are you saying you don't expect the parameters themselves to be DTensors? Maybe @awgu would know better?
(Pdb) p type(param_groups[0][0][0][0])
<class 'torch.distributed._tensor.api.DTensor'>
(Pdb) p param_groups[0][0][0][0]._local_tensor._base.ndim
2
(Pdb) p param_groups[0][0][0][0].grad._local_tensor._base.ndim
1
(Pdb) p isinstance(param_groups[0][0][0][0], torch.nn.Parameter)
True
@bdhirsh I see, thank you for the clarification. The subclass you were referring to is DTensor, not my custom subclass for the quantized optimizer state. That makes sense.
But it also raises another question: how come the other FSDP2 tests in torchao did not fail 😅. Then I remembered NF4 is not trainable, so it won't have a .grad field. Not sure about the other FSDP2 tests in torchao.
In the end, is it correct to say that the bug is more about FSDP2 + torch.compile(optim_step)? If it is not isolated to my custom optimizer, perhaps we can add some tests for this scenario in PyTorch core or other repos too.
"In the end, is it correct to say that the bug is more about FSDP2 + torch.compile(optim_step)? If it is not isolated to my custom optimizer, perhaps we can add some tests for this scenario in PyTorch core or other repos too."
Yeah, I could definitely believe that this is true (I don't have bandwidth to add those tests, but if someone wants to try making a smaller repro that doesn't use your low-bit optimizer, they are welcome to 😄).
Then again, I think this is a pretty one-off bug that we just expected to be very rarely hit (we haven't had to exercise a lot of code in compile where our tensor inputs to the graph also have .grad fields that are subclasses), and it should have a relatively straightforward fix.
I can confirm normal optimizers have this bug too
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

batch_size = 3
vocab_size = 1024
seq_len = 64
model_args = ModelArgs(
    n_layers=3,
    n_heads=4,
    dim=1024,
    vocab_size=vocab_size,
    max_seq_len=seq_len,
)
model = Transformer(model_args).cuda()

# apply FSDP2 to each transformer block, then to the root module
for m in model.layers:
    fully_shard(m)
fully_shard(model)

optim = torch.optim.AdamW(model.parameters(), lr=1e-2, foreach=False, fused=False)
# compile optimizer
optim.step = torch.compile(optim.step)

for iter_idx in range(5):
    inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    model(inp).mean().backward()
    optim.step()
    optim.zero_grad()
Run with torchrun --nnodes 1 --nproc_per_node 1 debug.py
Agree with your last point 😄! Hopefully the fix in PyTorch core is coming soon! Thank you for the help!
@bdhirsh I noticed that the FSDP test for low-bit optim now passes with torch nightly. Was it fixed in core recently? I didn't see any updates in https://github.com/pytorch/pytorch/issues/133274
Hmm, that's strange - I ran the non-subclass repro you posted above locally and it still fails for me.
Hmm, I think it might be because I changed the way I compile the optim step. Now I compile a static-shape optim step for each param, instead of one optim step for all params (#812). In that case the issue in PyTorch core is still there, but we can probably close this issue?
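For reference, a minimal sketch of the difference (hypothetical function, not the actual #812 change): compile a single-parameter Adam update with static shapes and call it once per param, instead of compiling one step over all of param_groups.

import torch

@torch.compile(fullgraph=True, dynamic=False)  # static shapes: one graph per distinct param shape
def single_param_adam(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps):
    # plain Adam math for a single parameter tensor
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

With dynamic=False, recompilation is bounded by the number of distinct parameter shapes rather than by guards on dynamic dims.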
ah yeah, great - this is definitely just a bug at the intersection of subclasses + dynamic shapes + optimizer/gradient, so if you're ok with static shapes only for now (which might be better for perf anyway), closing this issue sounds fine to me
To repro:
python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2