huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

(Willing to PR if this is acceptable) Mimic `adamw_torch_4bit` and have `adamw_torch_8bit` #34893

Open fzyzcjy opened 2 days ago

fzyzcjy commented 2 days ago

Feature request

Hi, thanks for the library! Transformers currently offers `adamw_torch_4bit`, and I would like to mimic it with an `adamw_torch_8bit` option that uses torchao's 8-bit AdamW.
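Here is a minimal sketch of what mimicking the 4-bit option could look like. It assumes the Trainer maps each `optim` string to a torchao class, and that torchao exposes `AdamW4bit` and `AdamW8bit` under `torchao.prototype.low_bit_optim` (newer torchao releases may move them). `OptimizerNames`, `torchao_class_name` and `get_torchao_optimizer_cls` are hypothetical names for illustration, not the actual Transformers code.

```python
import importlib
from enum import Enum


class OptimizerNames(str, Enum):
    """Hypothetical subset of the names an `optim` argument could accept."""

    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
    ADAMW_TORCH_8BIT = "adamw_torch_8bit"  # the option proposed in this issue


# Assumed module path for torchao's low-bit optimizers; check the installed version.
_TORCHAO_MODULE = "torchao.prototype.low_bit_optim"

# The 8-bit entry simply mirrors the existing 4-bit one.
_TORCHAO_CLASS_NAMES = {
    OptimizerNames.ADAMW_TORCH_4BIT: "AdamW4bit",
    OptimizerNames.ADAMW_TORCH_8BIT: "AdamW8bit",
}


def torchao_class_name(optim: str) -> str:
    """Return the torchao class name for `optim`; raises ValueError for unknown names."""
    return _TORCHAO_CLASS_NAMES[OptimizerNames(optim)]


def get_torchao_optimizer_cls(optim: str):
    """Import torchao lazily (only when requested) and return the optimizer class."""
    class_name = torchao_class_name(optim)
    try:
        module = importlib.import_module(_TORCHAO_MODULE)
    except ImportError as err:
        raise ImportError(f"`{optim}` requires torchao to be installed") from err
    return getattr(module, class_name)
```

The resolved class would then be constructed like any PyTorch optimizer, e.g. `get_torchao_optimizer_cls("adamw_torch_8bit")(model.parameters(), lr=1e-4)`. The lazy import keeps torchao an optional dependency, as with the 4-bit option.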

The reason is that I would like to use DeepSpeed CPU offload for the optimizer together with 8-bit AdamW. However, the existing 8-bit AdamW in HF Transformers does not support CPU, so I need to use the torchao one.
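For the DeepSpeed side, here is a hedged sketch of a ZeRO config that offloads optimizer state and computation to CPU while leaving optimizer creation to the Trainer (there is no `optimizer` section). The `"zero_force_ds_cpu_optimizer": False` entry is assumed to be needed because DeepSpeed otherwise expects its own CPU Adam when offloading with a client-provided optimizer. `make_ds_config` is a hypothetical helper, and whether torchao's 8-bit AdamW runs well on offloaded CPU parameters is the premise of this request, not something verified here.

```python
def make_ds_config(zero_stage: int = 2) -> dict:
    """Build a minimal DeepSpeed config that offloads the optimizer to CPU (hypothetical helper)."""
    # Restricted to ZeRO stages 2 and 3 here, the usual ZeRO-Offload setups.
    if zero_stage not in (2, 3):
        raise ValueError("this sketch only supports ZeRO stage 2 or 3")
    return {
        # "auto" values are filled in by the Transformers DeepSpeed integration.
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "zero_optimization": {
            "stage": zero_stage,
            # Keep optimizer state in pinned CPU memory and run the step on CPU.
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
        },
        # Allow a non-DeepSpeed optimizer (e.g. torchao AdamW8bit) with offload.
        "zero_force_ds_cpu_optimizer": False,
    }
```

It could then be passed as `TrainingArguments(output_dir="out", optim="adamw_torch_8bit", deepspeed=make_ds_config())`, where `adamw_torch_8bit` is the value proposed in this issue.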

Motivation

-

Your contribution

Yes, willing to open a PR.

Rocketknight1 commented 21 hours ago

cc @muellerzr for deepspeed/accelerate!

muellerzr commented 20 hours ago

A PR for this would be great 🤗 cc @SunMarc

fzyzcjy commented 20 hours ago

Thanks! I will do that later.

SunMarc commented 18 hours ago

Feel free to add it! Let me know if you need any help.

fzyzcjy commented 10 hours ago

Thanks! I will first mimic the 4-bit one and check whether it works.