huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate

Logic Bug: Using Init Handler Kwargs for Grad Scaler in FP8 Training (accelerate/accelerator.py) #3233

Open immortalCO opened 1 week ago

System Info

Any system; this is a logic bug in the code itself, independent of the environment.

Reproduction

This issue reports a bug in the mixed-precision setup of `Accelerator.__init__`, near https://github.com/huggingface/accelerate/blob/main/src/accelerate/accelerator.py#L474C9-L511C60:

        self.scaler = None
        self.native_amp = False
        if (
            self.state.mixed_precision == "fp16"
            and self.device.type != "cpu"
            and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
        ):
            self.native_amp = True
            if self.device.type not in ("xpu", "cuda", "npu", "xla", "mlu", "musa") or is_torch_xla_available(
                check_is_tpu=True
            ):
                raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
            kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
            self.scaler = get_grad_scaler(self.distributed_type, **kwargs)

        elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
            DistributedType.DEEPSPEED,
            DistributedType.MEGATRON_LM,
        ):
            if self.device.type in ["cpu", "xpu"]:
                self.native_amp = True
            else:
                self.native_amp = is_bf16_available(True)
            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
                raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

        elif self.state.mixed_precision == "fp8":
            # We always enable `native_amp` for FP8
            self.native_amp = True
            if self.fp8_backend == "MSAMP":
                if self.distributed_type == DistributedType.FSDP:
                    raise NotImplementedError(
                        "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
                        "Please consider using deepspeed, which is supported."
                    )
                elif self.distributed_type != DistributedType.DEEPSPEED:
                    # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
                    # BUG: `kwargs` is not defined anywhere in this fp8 branch; where does it come from?
                    self.scaler = get_grad_scaler(**kwargs)

The `kwargs` used in the FP8 non-DeepSpeed branch is never initialized in that branch. It silently reuses the last `kwargs` assigned earlier in `__init__`, which comes from `InitProcessGroupKwargs` (`self.init_handler`) at L414:

        kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
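
To make the consequence concrete, here is a minimal standalone sketch (outside of `Accelerator`, assuming `accelerate` and `torch` are installed; the `timeout` and `growth_interval` values are arbitrary examples): the fields carried by `InitProcessGroupKwargs` are not valid grad-scaler arguments, while those from `GradScalerKwargs` are.

    from datetime import timedelta

    import torch
    from accelerate.utils import GradScalerKwargs, InitProcessGroupKwargs

    # Kwargs intended for torch.distributed.init_process_group, e.g. a custom timeout.
    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=300)).to_kwargs()
    # Kwargs intended for the gradient scaler.
    scaler_kwargs = GradScalerKwargs(growth_interval=1000).to_kwargs()

    torch.cuda.amp.GradScaler(**scaler_kwargs)  # OK: growth_interval is a GradScaler argument
    torch.cuda.amp.GradScaler(**init_kwargs)    # TypeError: unexpected keyword argument 'timeout'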

Following the similar pattern at L486~L487 (the fp16 branch above), the equivalent of L486 should be inserted right before the `get_grad_scaler` call:

                elif self.distributed_type != DistributedType.DEEPSPEED:
                    kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {} # insert L486
                    self.scaler = get_grad_scaler(**kwargs)
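
For context, this is the kind of user code that hits the bug. A rough sketch (assuming an environment where the MS-AMP FP8 backend is selected and a non-DeepSpeed single-GPU or DDP launch): with the current code, the `GradScalerKwargs` handler below is ignored, and if an `InitProcessGroupKwargs` handler with non-default fields is also passed, those fields are forwarded to the grad scaler and raise a `TypeError`.

    from datetime import timedelta

    from accelerate import Accelerator
    from accelerate.utils import GradScalerKwargs, InitProcessGroupKwargs

    accelerator = Accelerator(
        mixed_precision="fp8",
        kwargs_handlers=[
            # Intended to configure the grad scaler used by the MS-AMP FP8 path ...
            GradScalerKwargs(growth_interval=1000),
            # ... but the scaler is currently built from these kwargs instead.
            InitProcessGroupKwargs(timeout=timedelta(hours=2)),
        ],
    )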

Expected behavior

The FP8 grad scaler should be built from `GradScalerKwargs` (via `self.scaler_handler`), just like the fp16 scaler, instead of from whatever `kwargs` happens to be left over from `InitProcessGroupKwargs`. Please fix this logic bug.