huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Feature Request: Context manager for downcasting on TPUs #587

Open muellerzr opened 2 years ago

muellerzr commented 2 years ago

Since downcasting on TPUs is better for calculating metrics and logging values, it might be nice to have an API that lets you toggle downcasting via a context manager.

Under the hood this should use patch_environment to patch the right environment variable.

Proposed API design:

```python
with accelerator.tpu_downcast():
    # Run through evaluation loop
```
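A minimal sketch of how such a context manager could behave, using `os.environ` directly to mimic what `accelerate.utils.patch_environment` does (temporarily set a variable, then restore the prior state on exit). The variable name `XLA_DOWNCAST_BF16` and the helper name `tpu_downcast` are assumptions for illustration, not the final API:

```python
import os
from contextlib import contextmanager

@contextmanager
def tpu_downcast(var="XLA_DOWNCAST_BF16"):
    # Hypothetical sketch: the env var name is an assumption; the real
    # implementation would delegate to accelerate's patch_environment.
    previous = os.environ.get(var)
    os.environ[var] = "1"
    try:
        yield
    finally:
        # Restore the prior state so the patch does not leak out of the block
        if previous is None:
            os.environ.pop(var, None)
        else:
            os.environ[var] = previous
```

Inside the `with` block the variable is set to `"1"`; on exit it reverts to whatever it was before, which is the behavior `patch_environment` provides for arbitrary keyword arguments.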
muellerzr commented 2 years ago

On hold (TBD) until such a method is supported by torch_xla.