NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

Fix the bug where the optimizer doesn't actually use multi_tensor_applier under float16. #846

Closed Gstdioh closed 1 month ago

Gstdioh commented 1 month ago

Fix the bug where the optimizer doesn't actually use multi_tensor_applier under float16, because overflow_buf always evaluates to False.

Specifically, overflow_buf is set to self._dummy_overflow_buf, which is initialized as torch.tensor([0], dtype=torch.int, device='cuda') under float16.

However, bool(torch.tensor([0], dtype=torch.int, device='cuda')) is False, so the `if overflow_buf:` check always takes the fallback branch and multi_tensor_applier is never called.
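
For reference, a quick standalone demonstration of this tensor truthiness behavior (shown on CPU; the device argument does not change the result):

import torch

# A one-element tensor is truthy or falsy according to its value, so a
# zero-initialized overflow buffer evaluates to False in an `if` check.
buf = torch.tensor([0], dtype=torch.int)
print(bool(buf))  # False -> `if buf:` skips the multi-tensor path
buf.fill_(1)
print(bool(buf))  # True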

The original code is similar to the following:

# Supporting imports (multi_tensor_applier and amp_C come from NVIDIA Apex).
from typing import List, Optional

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

if self.config.bf16:
    self._dummy_overflow_buf = None
else:
    self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

...

_multi_tensor_copy_this_to_that(
    this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
)

...

def _multi_tensor_copy_this_to_that(
    this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None
):
    """
    Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16.
    """
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

Therefore, the optimizer never actually uses multi_tensor_applier under float16. The fix is to check directly whether overflow_buf is None instead of relying on its truth value.
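
A minimal sketch of the change, keeping the rest of _multi_tensor_copy_this_to_that unchanged:

    if overflow_buf is not None:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

With this check, the bf16 path (where the buffer is None) still falls back to the loop copy, while the float16 path now reaches multi_tensor_applier as intended.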