Closed kylesayrs closed 1 month ago
For the quality check, `pip install -e .[quality]; make style; make quality` should fix it right up!
@SunMarc @muellerzr Thanks for the reviews! I've made some final changes; this is good to go from my end.
I saw that you added some tests in the PR description. Could you also add them to accelerate? Thanks a lot!
@SunMarc Added!
Purpose
`has_offloaded_params`: utility function which returns `True` iff there is an `AlignDevicesHook` with offloading enabled

Changes
`has_offloaded_params` lives in `accelerate.utils.modeling`, exposed through `accelerate.utils`
`accelerate.utils.modeling` and `accelerate.accelerator` updated with the newly added function

Testing
`test_has_offloaded_params.py`
```python3
import torch

from accelerate.utils import has_offloaded_params
from accelerate.hooks import attach_align_device_hook
from accelerate.big_modeling import cpu_offload_with_hook


class ModelForTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 4)
        self.batchnorm = torch.nn.BatchNorm1d(4)
        self.linear2 = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


# No hooks attached: no submodule should report offloaded params
model = ModelForTest()
assert not has_offloaded_params(model.linear1)
assert not has_offloaded_params(model.batchnorm)
assert not has_offloaded_params(model.linear2)

# cpu_offload_with_hook on the whole model: submodules should still report False
model = ModelForTest()
model, hook = cpu_offload_with_hook(model, execution_device="cuda:0")
assert not has_offloaded_params(model.linear1)
assert not has_offloaded_params(model.batchnorm)
assert not has_offloaded_params(model.linear2)

# Align-device hooks attached with offloading disabled: still False
model = ModelForTest()
attach_align_device_hook(model, offload=False)
assert not has_offloaded_params(model.linear1)
assert not has_offloaded_params(model.batchnorm)
assert not has_offloaded_params(model.linear2)

# Align-device hooks attached with offloading enabled: True
model = ModelForTest()
attach_align_device_hook(model, offload=True)
assert has_offloaded_params(model.linear1)
assert has_offloaded_params(model.batchnorm)
assert has_offloaded_params(model.linear2)
```

Who can review?
@SunMarc @LysandreJik @mgoin
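
For reviewers skimming the description, the check stated under Purpose (an `AlignDevicesHook` is present and has offloading enabled) can be sketched roughly as below. This is only an illustration, not the code in this PR: the classes are stdlib stand-ins rather than the real `torch`/`accelerate` types, the function name `has_offloaded_params_sketch` is hypothetical, and the sketch assumes the hook is stored on the module as `_hf_hook` and exposes a boolean `offload` attribute.

```python
class AlignDevicesHook:
    """Stand-in for accelerate.hooks.AlignDevicesHook (illustration only)."""

    def __init__(self, offload: bool = False):
        self.offload = offload


class CpuOffload:
    """Stand-in for a different hook type, which should not count."""


class Module:
    """Stand-in for torch.nn.Module."""


def has_offloaded_params_sketch(module) -> bool:
    # Assumption: an attached hook is stored on the module as `_hf_hook`.
    hook = getattr(module, "_hf_hook", None)
    # True only for an AlignDevicesHook whose offloading is enabled.
    return isinstance(hook, AlignDevicesHook) and bool(hook.offload)


# Usage: mirrors the cases exercised in the PR's test script.
no_hook = Module()

aligned_only = Module()
aligned_only._hf_hook = AlignDevicesHook(offload=False)

offloaded = Module()
offloaded._hf_hook = AlignDevicesHook(offload=True)

other_hook = Module()
other_hook._hf_hook = CpuOffload()

results = [
    has_offloaded_params_sketch(m)
    for m in (no_hook, aligned_only, offloaded, other_hook)
]
```

Checking both the hook type and its `offload` flag is what separates the `offload=False` and `cpu_offload_with_hook` cases from the `offload=True` case in the test script.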