pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

meta device issue with float8 delayed scale #654

Open weifengpy opened 1 month ago

weifengpy commented 1 month ago

repro:

CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.scaling_type_weight "delayed" --metrics.log_freq 1 --training.steps 3 --checkpoint.enable_checkpoint --checkpoint.interval 2
  traceback : Traceback (most recent call last):
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper   
      return f(*args, **kwargs)
    File "/data/users/weif/torchtitan/train.py", line 301, in main
      pred = model(input_ids)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
      return inner()
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
      result = forward_call(*args, **kwargs)
    File "/data/users/weif/torchtitan/torchtitan/models/llama/model.py", line 439, in forward
      h = layer(h, self.freqs_cis)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
      return inner()
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1769, in inner
      args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py", line 67, in fsdp_hook_wrapper
      return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 629, in _fn
      return fn(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py", line 234, in _pre_forward
      args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 314, in pre_forward
      self.unshard(self.unshard_async_op)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 243, in unshard
      self._all_gather_result = foreach_all_gather(
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py", line 139, in foreach_all_gather
      param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)    
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py", line 217, in _get_param_all_gather_inputs
      param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 702, in all_gather_inputs
      ) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
    File "/data/users/weif/ao/torchao/float8/fsdp_utils.py", line 408, in fsdp_pre_all_gather
      float8_tensor = hp_tensor_to_float8_delayed(
    File "/data/users/weif/ao/torchao/float8/float8_scaling_utils.py", line 105, in hp_tensor_to_float8_delayed
      return hp_tensor_and_scale_to_float8(
    File "/data/users/weif/ao/torchao/float8/float8_tensor.py", line 254, in hp_tensor_and_scale_to_float8
      return _ToFloat8ConstrFunc.apply(
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
    File "/data/users/weif/ao/torchao/float8/float8_tensor.py", line 170, in forward
      tensor_scaled = tensor.to(torch.float32) * scale
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 289, in _fn
      result = fn(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
      result = fn(**bound.arguments)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1064, in _ref
      output = prim(a, b)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1671, in mul
      return prims.mul(a, b)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
      return self._op(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 95, in meta_kernel
      return fake_impl_holder.kernel(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_library/utils.py", line 20, in __call__
      return self.func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/library.py", line 1190, in inner
      return func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 614, in fake_impl
      return self._abstract_fn(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims/__init__.py", line 402, in _prim_elementwise_meta
      utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 742, in check_same_device
      raise RuntimeError(msg)
  RuntimeError: Tensor on device meta is not on the expected device cuda:1!

weifengpy commented 1 month ago

cc @vkuzo

vkuzo commented 1 month ago

Without debugging, my guess would be something like this:

  1. The model is created on device meta.
  2. The checkpoint is loaded with device cuda, but it does not have the extra buffers for delayed scaling.

I can take a look next week, unless someone gets to it faster

weifengpy commented 1 month ago

> checkpoint is loaded with device cuda, but it does not have the extra buffers for delayed scaling

If you run the repo for the first time, the torchtitan/output/checkpoint folder is empty, so the model doesn't load any checkpoints, but the error still occurs. We do meta init and call init_weights to move the model from meta to cuda. The buffers for delayed scaling might need some special treatment.
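One way a buffer can stay on meta even without a checkpoint is sketched below, assuming materialization goes through `nn.Module.to_empty` followed by `init_weights`. `to_empty` replaces each registered buffer with a new tensor object, so any other object still holding a reference to the old buffer (loosely, a weight wrapper that carries its scale into the float8 all-gather) keeps pointing at a meta tensor, which would match the `Tensor on device meta` error in the traceback. The class, the `scale_weight`/`cached_scale` names, and the re-binding step are hypothetical illustrations, not torchao's code or the change in pytorch/ao#1292; CPU stands in for CUDA:

```python
import torch
import torch.nn as nn


class ToyDelayedScaleLinear(nn.Module):
    """Hypothetical stand-in for a float8 linear with delayed-scaling state."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Delayed-scaling state kept as a registered buffer (illustrative name).
        self.register_buffer("scale_weight", torch.ones(1))
        # A second, plain-attribute reference to the same buffer object.
        # nn.Module.to_empty() does not know about it, so it is not moved.
        self.cached_scale = self.scale_weight

    def init_weights(self) -> None:
        # Re-initialize parameters *and* buffers after materialization,
        # since to_empty() leaves them uninitialized.
        nn.init.normal_(self.weight, std=0.02)
        with torch.no_grad():
            self.scale_weight.fill_(1.0)
        # Hypothetical fix: re-bind the cached reference to the new buffer.
        self.cached_scale = self.scale_weight


# 1. Meta init: no memory is allocated for parameters or buffers.
with torch.device("meta"):
    m = ToyDelayedScaleLinear(4, 4)

# 2. Materialize on a real device; registered buffers are replaced.
m.to_empty(device="cpu")
stale_device = m.cached_scale.device  # still the old meta tensor

# 3. Initialize weights and buffers, re-binding the stale reference.
m.init_weights()
print(stale_device, m.cached_scale.device)  # meta cpu
```

Re-binding after materialization works here only because the module knows every holder of the buffer reference; whether that can be done automatically, without a new API, is the open question in this thread.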

> I can take a look next week, unless someone gets to it faster

Thanks!

vkuzo commented 1 month ago

I see, then this line is relevant: https://github.com/pytorch/ao/blob/e85c1a318b06bbdb3b8c7f92f3257999864446b0/torchao/float8/float8_linear.py#L648

We'll have to think about whether we can figure out how to do this automatically without introducing one more API. If not, we'll have to design such an API.

weifengpy commented 1 month ago

> I see, then this line is relevant: https://github.com/pytorch/ao/blob/e85c1a318b06bbdb3b8c7f92f3257999864446b0/torchao/float8/float8_linear.py#L648
>
> We'll have to think about whether we can figure out how to do this automatically without introducing one more API. If not, we'll have to design such an API.

I see. It sounds plausible.

vkuzo commented 1 week ago

I really don't love this solution, but we could do something like this: https://github.com/pytorch/ao/pull/1292. Thoughts?

weifengpy commented 1 week ago

> I really don't love this solution, but we could do something like this: pytorch/ao#1292. Thoughts?

Thanks for the fix!

vkuzo commented 1 week ago

Reopening, since the fix hasn't landed yet :)