Background
Previous PR #3204 introduced an unintended behavior change where a module being aligned would also attempt to align parameters belonging to its submodules. This is a problem for functions like `get_state_dict_offloaded_model`, which call `align_module_device` on non-leaf modules.
```
tests/test_modeling_utils.py:808
    state_dict = get_state_dict_offloaded_model(model)
src/accelerate/utils/modeling.py:1532: in get_state_dict_offloaded_model
    with align_module_device(module, "cpu"):
/usr/lib/python3.10/contextlib.py:135: in __enter__
    return next(self.gen)
src/accelerate/utils/modeling.py:1929: in align_module_device
    set_module_tensor_to_device(module, name, execution_device)
ValueError: weight is on the meta device, we need a `value` to put in on cpu.
```
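For reference, a minimal sketch of the failure mode, assuming a standard `cpu_offload` setup (the module and its sizes here are illustrative, not taken from the test suite):

```python
import torch.nn as nn
from accelerate import cpu_offload
from accelerate.utils.modeling import align_module_device

# Illustrative non-leaf module: the parent Sequential has no parameters of
# its own; after cpu_offload, the child's weights live on the meta device
# and are managed by per-module hooks.
model = nn.Sequential(nn.Linear(4, 4))
cpu_offload(model)

# Before the fix, aligning the non-leaf parent also tried to align the
# child's parameters, hitting meta tensors with no value to materialize:
# ValueError: weight is on the meta device, we need a `value` to put in on cpu.
with align_module_device(model, "cpu"):
    pass
```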
Purpose
- Fix an `align_module_device` bug where the function attempts to align meta tensors belonging to submodule parameters
- Fix `get_state_dict_offloaded_model(model)` behavior to match `model.state_dict()`
- Introduce tests for `get_state_dict_offloaded_model`
Changes
- `align_module_device` now only aligns parameters directly attached to the parent module (see the sketch after this list)
- Move all tensors in the module state dict to CPU before returning, including both parameters and buffers
- Add usage tests for `get_state_dict_offloaded_model`
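A simplified sketch of the first change, assuming the fix amounts to iterating with `recurse=False`; this omits the hook-managed offload handling in the real `align_module_device`, and the function name is hypothetical:

```python
import contextlib

from accelerate.utils.modeling import set_module_tensor_to_device

@contextlib.contextmanager
def align_module_device_sketch(module, execution_device):
    # Record only parameters directly attached to `module` (recurse=False),
    # leaving submodule parameters, which may be offloaded meta tensors
    # owned by their own hooks, untouched.
    devices = {name: param.device for name, param in module.named_parameters(recurse=False)}
    try:
        for name in devices:
            set_module_tensor_to_device(module, name, execution_device)
        yield
    finally:
        # Restore each direct parameter to the device it started on.
        for name, device in devices.items():
            set_module_tensor_to_device(module, name, device)
```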
Testing
- Added tests fail without the changes and pass with them
- Use the script below to test end-to-end
test_e2e.py
```python
from transformers import AutoModelForCausalLM
from accelerate import cpu_offload
from accelerate.utils.modeling import get_state_dict_offloaded_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Offload the weights to CPU; parameters are replaced by meta tensors
# managed by per-module hooks.
cpu_offload(model)

# Recover the full state dict from the offloaded model.
state_dict = get_state_dict_offloaded_model(model)
```
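A couple of hypothetical sanity checks, not part of the original script, based on the behavior this PR describes: every returned tensor should live on CPU, and the keys should match those of `model.state_dict()`.

```python
# Hypothetical checks reflecting the PR's stated behavior.
assert all(t.device.type == "cpu" for t in state_dict.values())
assert state_dict.keys() == model.state_dict().keys()
```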