pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

delayed scaling safety logic currently doesn't work with activation checkpointing #267

Closed: vkuzo closed this issue 2 months ago

vkuzo commented 4 months ago

Our current delayed scaling API asks the user to call sync_float8_amax_and_scale_history after each backward pass and before the optimizer step. This does not work on the first iteration when activation checkpointing is on: the backward pass re-runs the forward to recompute activations, the pre-forward safety check sees that the sync has not happened yet, and an exception is thrown.
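
A minimal sketch of the expected loop, assuming the README-style float8_experimental API (swap_linear_with_float8_linear and sync_float8_amax_and_scale_history); the toy model and data here are illustrative only, and real fp8 execution needs a CUDA device with fp8 support:

import torch
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# toy model; real usage would wrap transformer blocks with activation checkpointing
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
swap_linear_with_float8_linear(model, Float8Linear)  # enable float8 with delayed scaling
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(16, 64)).sum()
    loss.backward()                            # with AC, backward re-runs the forward here
    sync_float8_amax_and_scale_history(model)  # required after backward, before the step
    optimizer.step()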

For now we can work around this with the config override, but we need a better API design since we need to support AC.
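
For reference, the config-override workaround looks roughly like this; the flag name (enable_pre_and_post_forward) is an assumption about float8_experimental.config and may differ:

import float8_experimental.config as fp8_config

# Assumed flag name: disables the pre/post-forward hooks, including the
# "amaxes and scales not synced" assertion that fires when activation
# checkpointing re-runs the forward inside the backward pass.
fp8_config.enable_pre_and_post_forward = False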

Example trace:

Root Cause (first observed failure):
[0]:
  time      : 2024-05-28_10:24:53
  host      : devgpu003.cco3.facebook.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3789795)
  error_file: /tmp/torchelastic_9ib8xc3f/none_wmccvyr7/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/data/users/vasiliy/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/train.py", line 315, in main
      loss.backward()
    File "/data/users/vasiliy/pytorch/torch/_tensor.py", line 523, in backward
      torch.autograd.backward(
    File "/data/users/vasiliy/pytorch/torch/autograd/__init__.py", line 284, in backward
      _engine_run_backward(
    File "/data/users/vasiliy/pytorch/torch/autograd/graph.py", line 767, in _engine_run_backward
      return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/autograd/function.py", line 302, in apply
      return user_fn(self, *args)
             ^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 122, in backward
      fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY = ctx.saved_tensors
                                                                ^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/utils/checkpoint.py", line 1115, in unpack_hook
      frame.recompute_fn(*args)
    File "/data/users/vasiliy/pytorch/torch/utils/checkpoint.py", line 1399, in recompute_fn
      fn(*args, **kwargs)
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/torchtitan/models/llama/model.py", line 317, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/torchtitan/models/llama/model.py", line 186, in forward
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
                   ^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 383, in forward
      self.float8_pre_forward(x)
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 361, in float8_pre_forward
      raise AssertionError(
  AssertionError: amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward
vkuzo commented 2 months ago

https://github.com/pytorch/ao/issues/570