nathan-az closed this 7 months ago
Awesome, thanks @lewtun! I like your suggestion. It's a bit cleaner and more general, as it should work with FSDP (as far as I know, neither FSDP nor DeepSpeed's model parallelism is compatible with BNB quantization). Should be implemented now.

I'll make a separate PR with instructions on using ZeRO-3. I'm using a separate DeepSpeed config file, but to avoid committing new files to your repo (and complicating things), I'll first confirm I can get the CPU offloading working nicely directly in the accelerate config, and if I can I'll add some instructions :)

Let me know if any other changes are needed. Style and format checks passed :)
Thanks for iterating @nathan-az - this looks sweet as!
> I'll just confirm I can get the CPU offloading working nicely directly in the accelerate config, and if I can I'll add some instructions

I know one can override the accelerate config args by passing flags like `--num_processes=2`, but I've never checked if this is also possible for the deepspeed subset, which is a nested dict. If you find it's not easy to override at the command line, feel free to add a dedicated config like `deepspeed_zero3_offload.yaml` :)
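For reference, a dedicated accelerate config along those lines could look roughly like the sketch below. This is an illustrative fragment only; the key names are standard accelerate DeepSpeed-plugin fields, but the specific values (`num_processes`, precision, etc.) are placeholder assumptions, not a recommendation:

```yaml
# Illustrative accelerate config with ZeRO-3 + CPU offloading.
# Values below are example assumptions; adjust for your hardware.
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 3
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
mixed_precision: bf16
num_machines: 1
num_processes: 8
```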
I'm using the (awesome) script setups and data utilities from this repo for work. However, I am not in an HPC environment (so I'm launching multi-node without Slurm), and I'm using custom configs (both accelerate and DeepSpeed).
I noticed that when using ZeRO-3 with CPU offloading, I am getting CUDA OOMs when setting batch sizes to the same values I usually would with the native `deepspeed` multi-node launcher and a custom script. I believe this is because the model parameters are currently always placed on the GPUs if available, consuming significant VRAM.

My understanding is that `from_pretrained` is built to cater to a DeepSpeed setup and handles device placement and model parallelism in the offloading case. This PR adds a quick check confirming whether `offload_param_device` has been set (I don't explicitly check for `cpu`, since `nvme` is also valid, though less common). If it is set, the device map is kept `None` in the init kwargs. From testing on my use case, I've been able to increase the batch sizes back to what I expect to work, and VRAM usage seems to be spread evenly across the nodes.
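The check above can be sketched roughly as follows. This is a minimal illustration, not the exact code in the diff: `get_device_map` is a hypothetical helper, and the dict here stands in for the `zero_optimization` section of a DeepSpeed config, whereas the real change reads the equivalent setting from the accelerate/DeepSpeed plugin state:

```python
from typing import Optional


def get_device_map(ds_config: Optional[dict]):
    """Decide the `device_map` to pass into `from_pretrained`.

    Hypothetical sketch: `ds_config` mimics a DeepSpeed config dict
    with a `zero_optimization` section.
    """
    if ds_config is not None:
        # `offload_param.device` may be "cpu" or "nvme"; either means
        # DeepSpeed ZeRO-3 manages parameter placement itself, so we
        # must not pin the weights to GPUs ourselves.
        device = (
            ds_config.get("zero_optimization", {})
            .get("offload_param", {})
            .get("device")
        )
        if device not in (None, "none"):
            return None  # let ZeRO-3 offloading handle placement
    # Default: let the model be spread across available GPUs.
    return "auto"
```

The result would then be passed as the `device_map` init kwarg when loading the model.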
Let me know if there are problems with the PR. I ran `style` and `quality` from the Makefile before pushing.