huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

accelerate autocast on mps device #1396

Open davidlight2018 opened 1 year ago

davidlight2018 commented 1 year ago

System Info

accelerate==0.18.0
system: Apple M2, macOS 13.3.1

Reproduction

Below is the `autocast` implementation. The fp16 path is hardcoded to CUDA; I wonder whether it could be adapted to the macOS MPS device. Thanks very much.

```python
    # Excerpt from the Accelerator class (accelerate 0.18.0). It relies on contextlib.contextmanager,
    # sys, torch, DistributedType and is_torch_version being imported at module level.
    @contextmanager
    def autocast(self):
        """
        Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing
        different will happen otherwise.

        Example:

        ```python
        >>> from accelerate import Accelerator

        >>> accelerator = Accelerator(mixed_precision="fp16")
        >>> with accelerator.autocast():
        ...     train()
        ```
        """
        if self.native_amp:
            if self.mixed_precision == "fp16" and is_torch_version(">=", "1.10"):
                # fp16 always goes through the CUDA-specific autocast, whatever self.device is
                autocast_context = torch.cuda.amp.autocast(dtype=torch.float16)
            elif self.mixed_precision == "bf16":
                if self.distributed_type in [DistributedType.NO, DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
                    # bf16 uses the device-agnostic torch.autocast with the accelerator's own device type
                    autocast_context = torch.autocast(dtype=torch.bfloat16, device_type=self.device.type)
            else:
                autocast_context = torch.cuda.amp.autocast()

            autocast_context.__enter__()
            yield
            autocast_context.__exit__(*sys.exc_info())
        else:
            yield
```
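
For comparison, the fp16 branch could in principle follow the bf16 branch and pass the accelerator's own device type to the device-agnostic `torch.autocast`, instead of hardcoding `torch.cuda.amp.autocast`. The sketch below illustrates that idea with a hypothetical helper (`device_autocast` is not part of Accelerate); whether `device_type="mps"` is actually accepted still depends on the installed PyTorch version, so the example runs on CPU with bf16.

```python
import contextlib

import torch


def device_autocast(device: torch.device, mixed_precision: str):
    """Hypothetical helper: pick an autocast context for `device` instead of assuming CUDA."""
    if mixed_precision == "no":
        # No mixed precision requested: a context manager that does nothing.
        return contextlib.nullcontext()
    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
    if mixed_precision not in dtypes:
        raise ValueError(f"unsupported mixed_precision: {mixed_precision!r}")
    # torch.autocast takes the device type as a string ("cuda", "cpu", "mps", ...),
    # so one code path serves every backend that PyTorch supports autocast on.
    return torch.autocast(device_type=device.type, dtype=dtypes[mixed_precision])


if __name__ == "__main__":
    a = torch.randn(4, 4)
    b = torch.randn(4, 4)
    with device_autocast(torch.device("cpu"), "bf16"):
        # Matrix multiplication is one of the ops CPU autocast runs in bfloat16.
        print(torch.mm(a, b).dtype)
```

Routing every precision through `torch.autocast(device_type=...)` removes the backend-specific branch; the only remaining question is whether PyTorch itself implements autocast for that device type.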


### Expected behavior

Omitted.
muellerzr commented 1 year ago

When PyTorch adds this feature, we'll make sure it's possible in Accelerate: https://github.com/pytorch/pytorch/issues/88415
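Whether a particular PyTorch build honors autocast on MPS can be checked empirically. The probe below is a hypothetical sketch (`mps_autocast_supported` is not a PyTorch or Accelerate API). It assumes that builds without MPS autocast either raise `RuntimeError` for `device_type="mps"` or silently leave the matmul in float32, and it treats both outcomes as unsupported.

```python
import torch


def mps_autocast_supported() -> bool:
    """Hypothetical probe: does this PyTorch build run fp16 autocast on the MPS device?"""
    if not torch.backends.mps.is_available():
        # No Apple-silicon GPU (or a build without MPS): nothing to probe.
        return False
    try:
        x = torch.randn(8, 8, device="mps")
        with torch.autocast(device_type="mps", dtype=torch.float16):
            y = torch.mm(x, x)
    except RuntimeError:
        # Assumption: builds without MPS autocast reject device_type="mps" with a RuntimeError.
        return False
    # If autocast was honored, the matmul result comes back in float16.
    return y.dtype == torch.float16


if __name__ == "__main__":
    print("MPS autocast usable:", mps_autocast_supported())
```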

davidlight2018 commented 1 year ago

Thanks for the reply; really looking forward to having it.

Datamance commented 9 months ago

Commenting to follow. It'll be great to finally have mixed precision training on M1!

sagargulabani commented 4 months ago

I was trying to get this running. Here is what I did:

I created a fork of the latest PyTorch repository and manually copied the changes from https://github.com/pytorch/pytorch/pull/99272 into it. The result is this commit:

https://github.com/sagargulabani/pytorch/commit/060ccae622a72711af0f8cce30b8617907ecd526

I built this PyTorch locally from source to use in my conda environment.

However, I am now running into the following error:

```
Traceback (most recent call last):
  File "/Users/sagargulabani/.cache/huggingface/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1964, in <module>
    main(args)
  File "/Users/sagargulabani/.cache/huggingface/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1676, in main
    model_pred = unet(
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
  File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1216, in forward
    sample, res_samples = downsample_block(
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 1279, in forward
    hidden_states = attn(
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/transformers/transformer_2d.py", line 397, in forward
    hidden_states = block(
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/attention.py", line 366, in forward
    attn_output = self.attn2(
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/attention_processor.py", line 522, in forward
    return self.processor(
  File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/attention_processor.py", line 1279, in __call__
    hidden_states = F.scaled_dot_product_attention(
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.
Steps:   0%|          | 0/500 [00:01<?, ?it/s]
libc++abi: terminating due to uncaught exception of type std::__1::system_error: recursive_mutex lock failed: Invalid argument
Traceback (most recent call last):
  File "/opt/anaconda3/envs/hf/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1057, in launch_command
    simple_launcher(args)
  File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/commands/launch.py", line 673, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
```

The main error seems to be this one:

RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.

Does anyone know what the fix could be?

I have an M3 Max machine.

Thanks.
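The message says `F.scaled_dot_product_attention` received a float32 query alongside float16 key and value, and the three must share a dtype. One possible stopgap (not a fix proposed in this thread, and it does not explain why the query stayed in float32 under autocast) is to cast all three to a common dtype before the call. The helper below is a hypothetical sketch (`sdpa_same_dtype` is an illustrative name, requires PyTorch 2.0+ for SDPA) run on CPU tensors; upcasting to float32 trades speed and memory for numerical safety.

```python
import torch
import torch.nn.functional as F


def sdpa_same_dtype(query, key, value, **kwargs):
    """Hypothetical workaround: cast q/k/v to one common dtype before calling SDPA."""
    # promote_types picks the wider dtype, e.g. float32 when mixing float32 and float16.
    common = torch.promote_types(query.dtype, torch.promote_types(key.dtype, value.dtype))
    return F.scaled_dot_product_attention(
        query.to(common), key.to(common), value.to(common), **kwargs
    )


if __name__ == "__main__":
    # Shapes are (batch, heads, seq_len, head_dim); dtypes mimic the error above.
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 6, 8).to(torch.float16)
    v = torch.randn(1, 2, 6, 8).to(torch.float16)
    out = sdpa_same_dtype(q, k, v)
    print(out.dtype, tuple(out.shape))
```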

janboeye commented 2 months ago

Any progress? Thanks