huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0
7.98k stars, 975 forks

[Utils] `has_offloaded_params` #3188

Closed: kylesayrs closed this 1 month ago

kylesayrs commented 1 month ago

Purpose

Changes

Testing

test_has_offloaded_params.py

```python3
import torch

from accelerate.big_modeling import cpu_offload_with_hook
from accelerate.hooks import attach_align_device_hook
from accelerate.utils import has_offloaded_params


class ModelForTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 4)
        self.batchnorm = torch.nn.BatchNorm1d(4)
        self.linear2 = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


# No hooks attached: nothing is offloaded
model = ModelForTest()
assert not has_offloaded_params(model.linear1)
assert not has_offloaded_params(model.batchnorm)
assert not has_offloaded_params(model.linear2)

# CpuOffload hook on the top-level model only (requires a CUDA device)
model = ModelForTest()
model, hook = cpu_offload_with_hook(model, execution_device="cuda:0")
assert not has_offloaded_params(model.linear1)
assert not has_offloaded_params(model.batchnorm)
assert not has_offloaded_params(model.linear2)

# AlignDevicesHook attached to the submodules, but with offloading disabled
model = ModelForTest()
attach_align_device_hook(model, offload=False)
assert not has_offloaded_params(model.linear1)
assert not has_offloaded_params(model.batchnorm)
assert not has_offloaded_params(model.linear2)

# AlignDevicesHook attached to the submodules with offloading enabled
model = ModelForTest()
attach_align_device_hook(model, offload=True)
assert has_offloaded_params(model.linear1)
assert has_offloaded_params(model.batchnorm)
assert has_offloaded_params(model.linear2)
```
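The assertions hinge on telling apart a module that merely has a hook from one whose hook actually offloads its weights. Below is a minimal sketch of that kind of check, written in plain Python so it runs without torch or accelerate. It assumes that the hook is stored on the module as `_hf_hook` (the attribute accelerate's `add_hook_to_module` uses) and that the relevant hook exposes an `offload` flag. `AlignDevicesHookStub`, `CpuOffloadStub`, `ModuleStub`, `attach`, and `has_offloaded_params_sketch` are hypothetical stand-ins, not accelerate's classes or its actual implementation.

```python
class AlignDevicesHookStub:
    """Hypothetical stand-in for an align-devices hook carrying an `offload` flag."""

    def __init__(self, offload: bool):
        self.offload = offload


class CpuOffloadStub:
    """Hypothetical stand-in for a whole-model CPU offload hook (no per-module offload flag)."""


class ModuleStub:
    """Hypothetical stand-in for a torch.nn.Module."""


def attach(module, hook):
    # Assumption: hooks are stored on the module's `_hf_hook` attribute.
    module._hf_hook = hook
    return module


def has_offloaded_params_sketch(module) -> bool:
    # Offloaded only if an align-devices hook is attached AND offloading is enabled on it.
    hook = getattr(module, "_hf_hook", None)
    return isinstance(hook, AlignDevicesHookStub) and hook.offload


plain = ModuleStub()
cpu_hooked = attach(ModuleStub(), CpuOffloadStub())
aligned = attach(ModuleStub(), AlignDevicesHookStub(offload=False))
offloaded = attach(ModuleStub(), AlignDevicesHookStub(offload=True))

print([has_offloaded_params_sketch(m) for m in (plain, cpu_hooked, aligned, offloaded)])
# prints [False, False, False, True]
```

The `isinstance` check runs before `offload` is read, so a module with no hook, or with a hook of another type, returns `False` instead of raising `AttributeError`.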

Who can review?

@SunMarc @LysandreJik @mgoin

HuggingFaceDocBuilderDev commented 1 month ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

muellerzr commented 1 month ago

For the quality check, running `pip install -e .[quality]`, then `make style` and `make quality`, should fix it right up!

kylesayrs commented 1 month ago

@SunMarc @muellerzr Thanks for the reviews! I've made some final changes; this is good to go from my end.

SunMarc commented 1 month ago

I saw that you added some tests in the PR description. Could you also add them to accelerate? Thanks a lot!

kylesayrs commented 1 month ago

@SunMarc Added!