pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

FSDP 2 low bit optim broken on pytorch nightlies #652

Closed msaroufim closed 2 months ago

msaroufim commented 3 months ago

To repro: python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2

Logs

- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
  =========================== short test summary info ============================
  FAILED test/prototype/test_low_bit_optim.py::TestFSDP2::test_fsdp2 - RuntimeError: Process 0 exited with error code 10 and exception:
  Traceback (most recent call last):
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 664, in run_test
      getattr(self, test_name)()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 543, in wrapper
      fn()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2918, in wrapper
      method(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 184, in wrapper
      return func(*args, **kwargs)
    File "/pytorch/ao/test/prototype/test_low_bit_optim.py", line 239, in test_fsdp2
      self.run_subtests(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_fsdp.py", line 1141, in run_subtests
      return run_subtests(self, *args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 882, in run_subtests
      test_fn(*test_args, **test_kwargs, **subtest_kwargs)
    File "/pytorch/ao/test/prototype/test_low_bit_optim.py", line 284, in _test_fsdp2
      fsdp_optim.step()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/optim/optimizer.py", line 479, in wrapper
      out = func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torchao/prototype/low_bit_optim/adam.py", line 110, in step
      torch.compile(param_groups_adam, fullgraph=True)(param_groups)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
      return fn(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
      return self._torchdynamo_orig_callable(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
      return _compile(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
      guarded_code = compile_inner(code, one_graph, hooks, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
      return _compile_inner(code, one_graph, hooks, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
      return StrobelightCompileTimeProfiler.profile_compile_time(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
      return func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
      out_code = transform_code_object(code, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
      transformations(instructions, code_options)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
      return fn(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
      tracer.run()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
      super().run()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
      while self.step():
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
      self.dispatch_table[inst.opcode](self, inst)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1048, in STORE_FAST
      self._store_fast(inst.argval)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1044, in _store_fast
      loaded_vt.set_name_hint(name)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
      return getattr(self.realize(), name)(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 63, in realize
      self._cache.realize()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 29, in realize
      self.vt = VariableBuilder(tx, self.source)(self.value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 337, in __call__
      vt = self._wrap(value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 516, in _wrap
      return self.wrap_tensor(value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1474, in wrap_tensor
      tensor_variable = wrap_fx_proxy(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1910, in wrap_fx_proxy
      return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2022, in wrap_fx_proxy_cls
      example_value = wrap_to_fake_tensor_and_record(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2589, in wrap_to_fake_tensor_and_record
      fake_e = wrap_fake_exception(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1459, in wrap_fake_exception
      return fn()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2590, in <lambda>
      lambda: tx.fake_mode.from_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2171, in from_tensor
      return self.fake_tensor_converter.from_real_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 374, in from_real_tensor
      out = self.meta_converter(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1642, in __call__
      r = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1543, in meta_tensor
      r.grad = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1434, in meta_tensor
      r = empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 847, in empty_create_subclass
      sub = _empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 833, in _empty_create_subclass
      new_empty_tensor = _empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 818, in _empty_create_subclass
      return self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1322, in meta_tensor
      base = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1429, in meta_tensor
      ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 752, in sym_sizes_strides_storage_offset
      return shape_env._create_symbolic_sizes_strides_storage_offset(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
      return retlog(fn(*args, **kwargs))
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3222, in _create_symbolic_sizes_strides_storage_offset
      assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}"
  AssertionError: 2 != 1

  from user code:
     File "/opt/conda/envs/venv/lib/python3.9/site-packages/torchao/prototype/low_bit_optim/adam.py", line 116, in param_groups_adam
      for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:

  Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

  You can suppress this exception and fall back to eager by setting:
      import torch._dynamo
      torch._dynamo.config.suppress_errors = True

  To execute this test, run the following from the base repo dir:
      python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2

  This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

  Process 1 exited with error code 10 and exception:
  Traceback (most recent call last):
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 664, in run_test
      getattr(self, test_name)()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 543, in wrapper
      fn()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2918, in wrapper
      method(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 184, in wrapper
      return func(*args, **kwargs)
    File "/pytorch/ao/test/prototype/test_low_bit_optim.py", line 239, in test_fsdp2
      self.run_subtests(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_fsdp.py", line 1141, in run_subtests
      return run_subtests(self, *args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 882, in run_subtests
      test_fn(*test_args, **test_kwargs, **subtest_kwargs)
    File "/pytorch/ao/test/prototype/test_low_bit_optim.py", line 284, in _test_fsdp2
      fsdp_optim.step()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/optim/optimizer.py", line 479, in wrapper
      out = func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torchao/prototype/low_bit_optim/adam.py", line 110, in step
      torch.compile(param_groups_adam, fullgraph=True)(param_groups)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
      return fn(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
      return self._torchdynamo_orig_callable(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
      return _compile(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
      guarded_code = compile_inner(code, one_graph, hooks, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
      return _compile_inner(code, one_graph, hooks, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
      return StrobelightCompileTimeProfiler.profile_compile_time(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
      return func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
      out_code = transform_code_object(code, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
      transformations(instructions, code_options)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
      return fn(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
      tracer.run()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
      super().run()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
      while self.step():
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
      self.dispatch_table[inst.opcode](self, inst)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1048, in STORE_FAST
      self._store_fast(inst.argval)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1044, in _store_fast
      loaded_vt.set_name_hint(name)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
      return getattr(self.realize(), name)(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 63, in realize
      self._cache.realize()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 29, in realize
      self.vt = VariableBuilder(tx, self.source)(self.value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 337, in __call__
      vt = self._wrap(value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 516, in _wrap
      return self.wrap_tensor(value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1474, in wrap_tensor
      tensor_variable = wrap_fx_proxy(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1910, in wrap_fx_proxy
      return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2022, in wrap_fx_proxy_cls
      example_value = wrap_to_fake_tensor_and_record(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2589, in wrap_to_fake_tensor_and_record
      fake_e = wrap_fake_exception(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1459, in wrap_fake_exception
      return fn()
  Traceback (most recent call last):
    File "/home/ec2-user/actions-runner/_work/ao/ao/test-infra/.github/scripts/run_with_env_secrets.py", line 102, in <module>
      main()
    File "/home/ec2-user/actions-runner/_work/ao/ao/test-infra/.github/scripts/run_with_env_secrets.py", line 98, in main
      run_cmd_or_die(f"docker exec -t {container_name} /exec")
    File "/home/ec2-user/actions-runner/_work/ao/ao/test-infra/.github/scripts/run_with_env_secrets.py", line 39, in run_cmd_or_die
      raise RuntimeError(f"Command {cmd} failed with exit code {exit_code}")
  RuntimeError: Command docker exec -t 6a284e024d9e5fa50319dc78d9852c462a2c247de64ee9d0a3d6d326ab401309 /exec failed with exit code 1
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2590, in <lambda>
      lambda: tx.fake_mode.from_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2171, in from_tensor
      return self.fake_tensor_converter.from_real_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 374, in from_real_tensor
      out = self.meta_converter(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1642, in __call__
      r = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1543, in meta_tensor
      r.grad = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1434, in meta_tensor
      r = empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 847, in empty_create_subclass
      sub = _empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 833, in _empty_create_subclass
      new_empty_tensor = _empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 818, in _empty_create_subclass
      return self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1322, in meta_tensor
      base = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1429, in meta_tensor
      ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 752, in sym_sizes_strides_storage_offset
      return shape_env._create_symbolic_sizes_strides_storage_offset(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
      return retlog(fn(*args, **kwargs))
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3222, in _create_symbolic_sizes_strides_storage_offset
      assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}"
  AssertionError: 2 != 1

  from user code:
     File "/opt/conda/envs/venv/lib/python3.9/site-packages/torchao/prototype/low_bit_optim/adam.py", line 116, in param_groups_adam
      for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:

  Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

  You can suppress this exception and fall back to eager by setting:
      import torch._dynamo
      torch._dynamo.config.suppress_errors = True

  To execute this test, run the following from the base repo dir:
      python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2

  This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
  ==== 1 failed, 1135 passed, 267 skipped, 55 warnings in 1560.72s (0:26:00) =====
  Error: Process completed with exit code 1.
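The log suggests setting TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more detail. A minimal sketch of re-running the failing test with those variables set, assuming it is launched from the base repo dir; the helper name `build_repro` is hypothetical and not part of torchao:

```python
import os
import subprocess
import sys

# Test file and test id taken from the "To repro" line of this issue.
TEST_FILE = "test/prototype/test_low_bit_optim.py"
TEST_NAME = "TestFSDP2.test_fsdp2"


def build_repro(base_env=None):
    """Return (cmd, env) for re-running the failing test with verbose Dynamo logs."""
    env = dict(os.environ if base_env is None else base_env)
    # Both variables are the ones named in the error message above.
    env["TORCH_LOGS"] = "+dynamo"
    env["TORCHDYNAMO_VERBOSE"] = "1"
    cmd = [sys.executable, TEST_FILE, TEST_NAME]
    return cmd, env


# Only launch when the test file is actually present (i.e. run from the repo root).
if __name__ == "__main__" and os.path.exists(TEST_FILE):
    cmd, env = build_repro()
    # check=False: the test is expected to fail; the goal is only the extra logs.
    subprocess.run(cmd, env=env, check=False)
```

Building the environment as a copy keeps the settings local to the child process instead of changing the current shell.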
gau-nernst commented 3 months ago

Just realized TORCH_VERSION_AFTER_2_4 returns False on 2.4.0, so we still have that old problem 🤣. That means the low-bit optim FSDP2 test does not run in the 2.4.0 CI.

https://github.com/pytorch/ao/blob/261d0a4feff1731163140d469ae69edb8a1da34b/test/prototype/test_low_bit_optim.py#L232
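A flag named "after" can turn out False on the exact release it names. The sketch below illustrates one way that can happen, a strict comparison versus an inclusive one. It is not torchao's actual implementation, and the helper names are hypothetical. Pre-release and local suffixes are simply ignored here, which is a simplification.

```python
import re


def release_tuple(version):
    """Extract the leading numeric release, e.g. '2.4.0+cu121' -> (2, 4, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def is_after(installed, target):
    """Strict check: False when installed equals target."""
    return release_tuple(installed) > release_tuple(target)


def is_at_least(installed, target):
    """Inclusive check: True when installed equals target."""
    return release_tuple(installed) >= release_tuple(target)


if __name__ == "__main__":
    print(is_after("2.4.0", "2.4.0"))
    print(is_at_least("2.4.0", "2.4.0"))
```

A test gated on the strict form would be skipped on 2.4.0 itself, which matches the symptom described above.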

gau-nernst commented 3 months ago

The error message is very cryptic. AdamW8bit doesn't use dynamic shapes, so I don't know why it pops up. The error also only happens in the FSDP2 test, not the normal single-GPU test. Would you know who could take a look at this?

msaroufim commented 3 months ago

Yeah, the version problem is kinda getting out of hand. I'll fix that ASAP.

Regarding the error, @awgu and @weifengpy are usually my go-tos for FSDP2 issues.

awgu commented 3 months ago

I think low-bit optimizer + FSDP2 is actually low-bit optimizer + DTensor + torch.compile, for which @bdhirsh is probably the best person to ask.

bdhirsh commented 3 months ago

(taking a look)

bdhirsh commented 3 months ago

The problem is that we have a pretty complicated input to the compiled region: our input is a DTensor that has a _local_tensor._base, and it also has a populated .grad field that is itself a DTensor, whose _local_tensor._base has a different number of dims compared to the original ._base.

I have a min repro here https://github.com/pytorch/pytorch/issues/133274.

In the meantime, I also found that this tweak gets me past the error, although I'm not sure that we actually want to land it to eager FSDP2 (cc @awgu ):

diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index d739ffbcf96..c512ea7c37f 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -324,7 +324,7 @@ def foreach_reduce(
                 size=fsdp_param.sharded_size,
                 stride=fsdp_param.contiguous_sharded_stride,
                 storage_offset=flat_grad_offset,
-            )
+            ).detach()
             to_accumulate_grad = fsdp_param.sharded_param.grad is not None
             if fsdp_param.offload_to_cpu:
                 # Only overlap the D2H copy (copying to pinned memory) if not
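A small CPU-only sketch of the `._base` mismatch described above, using plain tensors instead of DTensors (an assumption made for brevity; the shapes and variable names are made up). It also shows why adding `.detach()` changes what the compiler sees: the detached tensor shares storage but no longer reports a `_base`.

```python
import torch

# A 2-D tensor and a slice of it: the slice is a view whose _base is 2-D.
padded_param = torch.zeros(4, 4)
param_view = padded_param[:3]

# A 1-D flat buffer and a strided 2-D view into it: here _base is 1-D,
# even though the view has the same shape as param_view.
flat_buffer = torch.zeros(16)
grad_view = torch.as_strided(
    flat_buffer, size=(3, 4), stride=(4, 1), storage_offset=0
)

# Same shapes, but the bases differ in number of dims.
print(tuple(param_view.shape), param_view._base.ndim)
print(tuple(grad_view.shape), grad_view._base.ndim)

# detach() shares the underlying storage but drops the autograd view metadata.
detached = grad_view.detach()
print(detached._base is None, detached.data_ptr() == flat_buffer.data_ptr())
```

`_base` is a private attribute, so this is for inspection only and may change between PyTorch versions.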
gau-nernst commented 3 months ago

Thank you for the quick debugging. May I ask about this part:

our input is a DTensor that has a _local_tensor._base, and it also has a populated .grad field that is itself a DTensor, whose _local_tensor._base has a different number of dims compared to the original ._base

You mentioned that the input in question has a .grad field, indicating that it is a parameter. In the low-bit optim test, only the optimizer states are tensor subclasses, so they shouldn't have a .grad field. I think something is not quite right here?

bdhirsh commented 3 months ago

I was looking at the values of param_groups, which are the inputs to your torch.compile region, here: https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adam.py#L110

And empirically, param_groups contains DTensor parameters with the above properties. Are you saying you don't expect the parameters themselves to be DTensors? Maybe @awgu would know better?

(Pdb) p type(param_groups[0][0][0][0])
<class 'torch.distributed._tensor.api.DTensor'>
(Pdb) p param_groups[0][0][0][0]._local_tensor._base.ndim
2
(Pdb) p param_groups[0][0][0][0].grad._local_tensor._base.ndim
1
(Pdb) p isinstance(param_groups[0][0][0][0], torch.nn.Parameter)
True
gau-nernst commented 3 months ago

@bdhirsh I see, thank you for the clarification. The subclass you were referring to is DTensor, not my custom subclass for quantized optimizer state. That makes sense.

But it also raises another question: how come the other FSDP2 tests in torchao did not fail 😅? Then I remembered that NF4 is not trainable, so it won't have a .grad field. Not sure about the other FSDP2 tests in torchao.

In the end, is it correct to say that the bug is more about FSDP2 + torch.compile(optim_step)? If it is not isolated to custom optimizers, perhaps we can add some tests for this scenario in PyTorch core or other repos too.

bdhirsh commented 3 months ago

In the end, is it correct to say that the bug is more about FSDP2 + torch.compile(optim_step)? If it is not isolated to custom optimizers, perhaps we can add some tests for this scenario in PyTorch core or other repos too.

yeah, I could definitely believe that this is true (I don't have bandwidth to add those tests, but if someone wants to try making a smaller repro that doesn't use your low bit optimizer they are welcome to 😄 )

then again, I think this is a pretty one-off bug that we expect to be hit very rarely (we haven't had to exercise a lot of code in compile where tensor inputs to the graph also have .grad fields that are subclasses), and it should have a relatively straightforward fix.

gau-nernst commented 3 months ago

I can confirm normal optimizers have this bug too

import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

batch_size = 3
vocab_size = 1024
seq_len = 64
model_args = ModelArgs(
    n_layers=3,
    n_heads=4,
    dim=1024,
    vocab_size=vocab_size,
    max_seq_len=seq_len,
)
model = Transformer(model_args).cuda()

for m in model.layers:
    fully_shard(m)
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2, foreach=False, fused=False)

# compile optimizer
optim.step = torch.compile(optim.step)

for iter_idx in range(5):
    inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    model(inp).mean().backward()
    optim.step()
    optim.zero_grad()

Run with torchrun --nnodes 1 --nproc_per_node 1 debug.py

Agree with your last point 😄! Hopefully the fix in PyTorch core is coming soon! Thank you for the help!

gau-nernst commented 2 months ago

@bdhirsh I noticed that the FSDP test for low-bit optim now passes with torch nightly. Was it fixed in core recently? I didn't see any updates in https://github.com/pytorch/pytorch/issues/133274

bdhirsh commented 2 months ago

hmm that's strange - i ran the non-subclass repro you put above locally and it still fails for me.
gau-nernst commented 2 months ago

Hmm, I think it might be because I changed the way I compile the optim step. I now compile the optim step with static shapes for each param, instead of compiling one optim step for all params (#812). In that case the issue in PyTorch core is still there, but we can probably close this issue?
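A rough sketch of the per-parameter, static-shape approach, for illustration only. It is not the code from #812: the function names are hypothetical, and a plain SGD update stands in for the real Adam math.

```python
import torch


def single_param_step(p, grad, lr):
    """Plain SGD update for one parameter (a stand-in for the real Adam math)."""
    p.add_(grad, alpha=-lr)


def make_step_fn(use_compile=True):
    """Return the per-parameter step, optionally compiled with static shapes."""
    if not use_compile:
        return single_param_step
    # dynamic=False asks torch.compile to specialize on each input shape
    # instead of tracing with symbolic (dynamic) sizes.
    return torch.compile(single_param_step, fullgraph=True, dynamic=False)


@torch.no_grad()
def step_all(params, lr, step_fn):
    """Call the step once per parameter instead of once for the whole list."""
    for p in params:
        if p.grad is not None:
            step_fn(p, p.grad, lr)
```

Usage would look like `step_all(model.parameters(), lr=1e-2, step_fn=make_step_fn())`. With static shapes, each distinct parameter shape gets its own specialization, so the trade-off is more compilations in exchange for avoiding dynamic-shape tracing.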

bdhirsh commented 2 months ago

ah yeah, great - this is definitely just a bug at the intersection of subclasses + dynamic shapes + optimizer/gradient, so if you're ok with static shapes only for now (which might be better for perf anyway), closing this issue sounds fine to me