Fix a bug where the optimizer never actually uses `multi_tensor_applier` under float16, because `overflow_buf` always evaluates to False.

Specifically, `overflow_buf = self._dummy_overflow_buf`, and under float16 `self._dummy_overflow_buf` is initialized as `torch.tensor([0], dtype=torch.int, device='cuda')`. However, `bool(torch.tensor([0], dtype=torch.int, device='cuda'))` is `False`, because the truth value of a one-element tensor is the truth value of its single element. As a result, the `if overflow_buf:` check always fails and the fused path is skipped.
The original code is similar to the following:
```python
if self.config.bf16:
    self._dummy_overflow_buf = None
else:
    self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

...

_multi_tensor_copy_this_to_that(
    this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
)

...

def _multi_tensor_copy_this_to_that(
    this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None
):
    """
    Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16.
    """
    # Bug: a one-element tensor is truthy only if its value is nonzero, so
    # with overflow_buf == torch.tensor([0]) this branch is never taken.
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)
```
Therefore, the optimizer never uses `multi_tensor_applier` under float16. The fix is to check directly whether `overflow_buf` is `None` (`if overflow_buf is not None:`) instead of relying on its truthiness.
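A minimal, framework-free sketch of the corrected dispatch follows. It does not need a GPU, PyTorch, or Apex, so it uses stand-ins: `OneElementIntFlag` is a hypothetical substitute for `torch.tensor([0], dtype=torch.int)` whose truthiness follows its single value the way a one-element tensor's does, and `fake_multi_tensor_scale` is a hypothetical substitute for `multi_tensor_applier(amp_C.multi_tensor_scale, ...)`. Plain Python lists stand in for tensors. The point is only the `is not None` check.

```python
from typing import List, Optional


class OneElementIntFlag:
    """Hypothetical stand-in for torch.tensor([0], dtype=torch.int).

    Like a one-element tensor, its truth value is that of its single
    element, so a zero-filled flag is falsy.
    """

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def fill_(self, value: int) -> "OneElementIntFlag":
        self.value = value
        return self

    def __bool__(self) -> bool:
        return self.value != 0


def fake_multi_tensor_scale(
    overflow_buf: OneElementIntFlag,
    tensor_lists: List[List[List[float]]],
    scale: float,
) -> None:
    # Hypothetical stand-in for the fused kernel: writes scale * src into
    # dst for every (src, dst) pair. With scale == 1.0 this is a copy.
    src_list, dst_list = tensor_lists
    for src, dst in zip(src_list, dst_list):
        dst[:] = [scale * x for x in src]


def copy_this_to_that(
    this: List[List[float]],
    that: List[List[float]],
    overflow_buf: Optional[OneElementIntFlag] = None,
) -> str:
    """Copy each buffer in `this` into the matching buffer in `that`.

    Returns the name of the path taken so the dispatch can be checked.
    """
    # The fix: test for presence, not truthiness. A zero-valued flag still
    # means "a buffer was provided, use the fused path".
    if overflow_buf is not None:
        overflow_buf.fill_(0)
        fake_multi_tensor_scale(overflow_buf, [this, that], 1.0)
        return "fused"
    # No buffer provided (the bfloat16 case): fall back to a simple loop.
    for this_, that_ in zip(this, that):
        that_[:] = this_
    return "loop"
```

For example, with a zero-filled flag `bool(OneElementIntFlag(0))` is `False`, which is exactly why the old `if overflow_buf:` skipped the fused path; with the `is not None` check, `copy_this_to_that([[1.0, 2.0]], dst, OneElementIntFlag(0))` takes the `"fused"` path, while passing `None` still falls back to `"loop"`.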