Open amyeroberts opened 7 months ago
I'm working on Resnet
edit: I'm running into a strange issue where the tests pass on one system and fail on another. I'm going to close the PR for now and investigate further.
BERT is not included in the above list of models. Does that mean device_map='auto' will be available for BERT models in an upcoming version of HF transformers? With transformers 4.39.3 I still see the message "BertForSequenceClassification does not support device_map='auto'".
Hi @tnnandi, the list above is just for vision models, which I got from a simple grep and filtering. device_map="auto" isn't yet enabled for BERT, cf. #25296. If you or anyone in the community would like to add it, we'd be happy to review a PR.
Hi @amyeroberts, hope you are well :) I'm not sure why, but it looks like the unit tests are passing even without defining _no_split_modules. I'm testing on systems with two GPUs. Any idea why this is happening?
$ pytest tests/models/align/test_modeling_align.py -k "parallel or offload"
=================================================================== test session starts ===================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 262 items / 241 deselected / 21 selected
tests/models/align/test_modeling_align.py ....................s [100%]
<warnings redacted>
================================================ 20 passed, 1 skipped, 241 deselected, 2 warnings in 5.33s ================================================
$ pytest tests/models/bert/test_modeling_bert.py -k "parallel or offload"
=================================================================== test session starts ===================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 156 items / 148 deselected / 8 selected
tests/models/bert/test_modeling_bert.py ........ [100%]
<warnings redacted>
====================================================== 8 passed, 148 deselected, 2 warnings in 4.59s ======================================================
@jla524 It's because these tests are skipped if _no_split_modules isn't defined (the model default), e.g. here for test_disk_offload_bin. This is admittedly confusing, and should really be done with self.skipTest.
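To illustrate why the runs above report "passed" rather than "skipped", here is a minimal sketch using only `unittest`. It assumes the current check is an early exit from the test body; `DummyModel`, `OffloadTests` and `run_suite` are hypothetical stand-ins, not the actual `ModelTesterMixin` code.

```python
import unittest


class DummyModel:
    # Hypothetical stand-in for a model class that has not defined the attribute.
    _no_split_modules = None


class OffloadTests(unittest.TestCase):
    model_class = DummyModel

    def test_offload_silent_return(self):
        # Early return: the runner records a PASS even though nothing was exercised.
        if self.model_class._no_split_modules is None:
            return
        self.fail("the real offload checks would go here")

    def test_offload_explicit_skip(self):
        # skipTest raises unittest.SkipTest, so the runner records a SKIP with a reason.
        if self.model_class._no_split_modules is None:
            self.skipTest("model does not define _no_split_modules")
        self.fail("the real offload checks would go here")


def run_suite():
    """Run both tests and return the unittest.TestResult for inspection."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(OffloadTests)
    result = unittest.TestResult()
    suite.run(result)
    return result


if __name__ == "__main__":
    result = run_suite()
    print(f"run={result.testsRun}, skipped={len(result.skipped)}, failures={len(result.failures)}")
```

With the explicit skip, the reason string shows up in the test report, which makes it obvious that the model has not opted in yet.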
Models updated so far:
Models remaining:
Hi! I would love to take the following models and give them a try: Videomae, Vision_encoder_decoder, Vision_text_dual_encoder, Vit_mae, X_clip. Thanks!
Hi! I encountered an issue while running tests for some models, specifically Vision_text_dual_encoder. Even though I set _no_split_modules, the unit test run still skips these tests. Does this mean that the test cases for these models are not implemented?
Additionally, I would like to know how to mark certain models to be skipped in the tests. For example, I have ViTMAEForPreTraining, which should not be split across different GPUs IMO. However, this causes the test case to fail because the test expects the model to be split across different devices.
@WenheLI Ah, I should take the vision text dual encoder off the list. We can theoretically load any encoder and decoder there, so it's not possible to know which modules can be split. The same applies to vision encoder-decoder.
Hey @amyeroberts, I was experimenting with defining _no_split_modules for the Segformer model, and I encountered some unexpected results when running the tests.
When I set _no_split_modules = [] for the Segformer model, all tests failed because no weights were loaded to the GPUs. Here are the error messages I received:
FAILED tests/models/segformer/test_modeling_segformer.py::SegformerModelTest::test_cpu_offload - AssertionError: Items in the second set but not the first:
FAILED tests/models/segformer/test_modeling_segformer.py::SegformerModelTest::test_disk_offload_bin - ValueError: You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead.
FAILED tests/models/segformer/test_modeling_segformer.py::SegformerModelTest::test_disk_offload_safetensors - ValueError: You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead.
FAILED tests/models/segformer/test_modeling_segformer.py::SegformerModelTest::test_model_parallelism - AssertionError: Items in the second set but not the first:
To investigate further, I ran infer_auto_device_map directly in a Jupyter notebook with the GPU memory allocations used in those tests:
With 70% of the model size allocated to GPU:
# Assumes `model` is an already-instantiated Segformer model.
from accelerate.utils import compute_module_sizes, infer_auto_device_map

compute_module_sizes(model)
total_size = compute_module_sizes(model)[""]
max_memory = {0: int(0.7 * total_size), "cpu": total_size * 2}
print(f"Total model size: {total_size}, max memory: {max_memory}")
print(
    infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=[],
    )
)
Output:
Total model size: 1195632, max memory: {0: 836942, 'cpu': 2391264}
OrderedDict([('', 'cpu')])
With 90% of the model size allocated to GPU:
max_memory = {0: int(0.9 * total_size), "cpu": total_size * 2}
print(f"Total model size: {total_size}, max memory: {max_memory}")
print(
    infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=[],
    )
)
Output:
Total model size: 1195632, max memory: {0: 1076068, 'cpu': 2391264}
OrderedDict([('segformer.encoder.patch_embeddings', 0),
('segformer.encoder.block.0.0.layer_norm_1', 0),
('segformer.encoder.block.0.0.attention.self.query', 0),
('segformer.encoder.block.0.0.attention.self.key', 0),
('segformer.encoder.block.0.0.attention.self.value', 0),
('segformer.encoder.block.0.0.attention.self.dropout', 0),
('segformer.encoder.block.0.0.attention.self.sr', 'cpu'),
('segformer.encoder.block.0.0.attention.self.layer_norm', 'cpu'),
('segformer.encoder.block.0.0.attention.output', 'cpu'),
('segformer.encoder.block.0.0.drop_path', 'cpu'),
('segformer.encoder.block.0.0.layer_norm_2', 'cpu'),
('segformer.encoder.block.0.0.mlp', 'cpu'),
('segformer.encoder.block.1', 'cpu'),
('segformer.encoder.block.2', 'cpu'),
('segformer.encoder.block.3', 'cpu'),
('segformer.encoder.layer_norm', 'cpu'),
('decode_head', 'cpu')])
The model can definitely be split into smaller modules, as the 90% case shows. The 70% case does not fail because of the smaller max_memory assigned to the GPU: the modules placed on the GPU in the 90% case account for only 21,408 bytes of the 1,195,632-byte model (about 1.8%), far less than either the 70% (836,942 bytes) or the 90% (1,076,068 bytes) budget. The problem therefore seems to lie in the allocation strategy of infer_auto_device_map rather than in the max_memory value.
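The figures quoted above can be re-derived with a few lines of arithmetic. This is only a check of the numbers already reported in the thread; the variable names are illustrative.

```python
total_size = 1_195_632      # bytes, total model size from the output above
gpu_bytes_used = 21_408     # bytes placed on GPU 0 in the 90% run, as reported above

budget_70 = int(0.7 * total_size)   # GPU budget used in the 70% run
budget_90 = int(0.9 * total_size)   # GPU budget used in the 90% run
fraction_on_gpu = gpu_bytes_used / total_size

print(budget_70, budget_90, f"{fraction_on_gpu:.1%}")
```

Both budgets are dozens of times larger than what actually ended up on the GPU, which is what points away from max_memory as the cause.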
After looking into the infer_auto_device_map function, I believe the logic might not be working as intended for models with highly imbalanced module sizes like Segformer:
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)

while len(modules_to_treat) > 0:
    name, module = modules_to_treat.pop(0)
    module_size = module_sizes[name]
    device = devices[current_device]
    current_max_size = max_memory[device] if device != "disk" else None
    current_memory_reserved = 0
    if devices[current_device] in main_devices:
        current_max_size = current_max_size - max_layer_size
        current_memory_reserved = max_layer_size
This code reserves space for the largest layer on each main device. For Segformer, where the decode_head (1,107,984 bytes) is significantly larger than other layers, this approach may be too conservative, leaving little room for other layers on the GPU.
if current_max_size is not None and current_memory_used + module_size > current_max_size:
    # Split or not split?
    modules_children = (
        []
        if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
        else list(module.named_children())
    )
    if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
        # -> no split, we go to the next device
        device_memory_used[device] = current_memory_used + current_memory_reserved
        current_device += 1
        modules_to_treat = [(name, module)] + modules_to_treat
        current_memory_used = 0
    else:
        # -> split, we replace the module studied by its children + parameters
        modules_children = list(module.named_parameters(recurse=False)) + modules_children
        modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
This part of the function decides whether to split a module or move to the next device. However, once it moves to the next device (i.e., CPU), it never goes back to the GPU, even if there's available space. This could explain why smaller modules aren't being allocated to the GPU after the decode_head is moved to the CPU.
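The "never goes back" behaviour can be reproduced with a toy model. The sketch below is a deliberately simplified stand-in, not accelerate's implementation: it ignores module splitting and the largest-layer reservation, and the function name, module names and sizes are all made up for illustration.

```python
def greedy_monotone(sizes, budgets):
    """Assign modules in order; when one does not fit, advance to the next
    device and never revisit earlier ones (toy model, not accelerate's code)."""
    devices = list(budgets)
    current = 0          # index of the device currently being filled
    used = 0             # bytes used on the current device
    placement = {}
    for name, size in sizes.items():
        while used + size > budgets[devices[current]]:
            current += 1
            used = 0
            if current == len(devices):
                raise ValueError(f"{name} does not fit on any remaining device")
        placement[name] = devices[current]
        used += size
    return placement


# Hypothetical sizes: one module dominates, as decode_head does for Segformer.
sizes = {"encoder.block.0": 100, "decode_head": 900, "classifier": 50}
budgets = {0: 500, "cpu": 2000}

placement = greedy_monotone(sizes, budgets)
print(placement)
# {'encoder.block.0': 0, 'decode_head': 'cpu', 'classifier': 'cpu'}
```

In this toy run `classifier` lands on the CPU even though 400 bytes are still free on device 0, because the device index only ever moves forward.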
@amyeroberts, should we wait until infer_auto_device_map is modified to work for models with uneven weight distributions before working on them? Or is it better to override those offload test functions from ModelTesterMixin in each model's test script so that we can enable device_map="auto" for models with this problem sooner?
Hi @Nech-C, thanks for writing all of this up!
I don't know infer_auto_device_map intimately, so I'm not sure about the overall logic. Just from the code snippet, I'm assuming it takes a greedy approach to memory allocation, which won't be optimal in all cases (like this one). Rather than incrementing the device index as you describe, we might want to keep a running count of available memory and go through the devices in decreasing priority; however, this would be slower. If you'd like to open a PR to address this, I'd be happy to take a look, although I'm not a maintainer of accelerate, so I don't make decisions on what should or could get added there :)
Regarding the order of things, I think having to update the tests is a sign that we should wait: infer_auto_device_map is also what will be called when users pass device_map="auto". If the model isn't being allocated well across the available devices, then enabling this for Segformer doesn't make sense.
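The idea of keeping a running count of available memory and walking the devices in priority order could look roughly like the sketch below. It is a first-fit toy under stated assumptions, not a proposal for accelerate's actual code: it ignores module splitting and the largest-layer reservation, and `first_fit` plus the module names and sizes are hypothetical.

```python
def first_fit(sizes, budgets):
    """Place each module on the highest-priority device that still has room.

    `budgets` is ordered by priority (e.g. GPUs first, then "cpu").
    Returns the placement and the remaining free bytes per device.
    """
    remaining = dict(budgets)   # running count of free bytes per device
    placement = {}
    for name, size in sizes.items():
        for device, free in remaining.items():
            if size <= free:
                placement[name] = device
                remaining[device] = free - size
                break
        else:
            raise ValueError(f"{name} does not fit on any device")
    return placement, remaining


# Hypothetical sizes: one module dominates, as decode_head does for Segformer.
sizes = {"encoder.block.0": 100, "decode_head": 900, "classifier": 50}
budgets = {0: 500, "cpu": 2000}

placement, remaining = first_fit(sizes, budgets)
print(placement)   # {'encoder.block.0': 0, 'decode_head': 'cpu', 'classifier': 0}
print(remaining)   # {0: 350, 'cpu': 1100}
```

The trade-off is the one mentioned above: every module is checked against every device, and consecutive modules can end up on different devices, which may add transfers at inference time.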
Feature request
transformers models can be easily loaded across multiple devices using device_map="auto". This will automatically allocate weights across available devices, e.g. GPUs, and offload any remaining weights onto CPU, then disk, as necessary. This is useful when doing inference with large models.
To enable this, _no_split_modules has to be defined in the model's pretrained model class, e.g. like here for LLaMa. This defines layers which should not be split across devices, and should contain as few layers as possible.
Steps to add
Define _no_split_modules in the PreTrainedModel subclass. Try with _no_split_modules = [] first.
Tests to run: test_disk_offload_bin, test_disk_offload_safetensors, test_cpu_offload, test_model_parallelism, test_model_parallel_beam_search
pytest tests/models/{MODEL_NAME}/test_modeling_{MODEL_NAME}.py -vv -k "offload or parallel"
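As a rough picture of what the first step amounts to, the sketch below uses plain Python stand-ins rather than the real transformers classes. Every class and function name in it is hypothetical; the only behaviour borrowed from the thread is that the attribute defaults to undefined and that modules are matched by class name.

```python
class PreTrainedModelStub:
    # Hypothetical stand-in for the base class: nothing is declared by default.
    _no_split_modules = None


class ToyEncoderLayer:
    """Stand-in for a layer that must stay on a single device."""


class ToyPatchEmbeddings:
    """Stand-in for a layer that may be placed independently."""


class ToyPreTrainedModel(PreTrainedModelStub):
    # Class names (strings) of layers that must not be split across devices.
    # Start with an empty list and add names only if the tests require it.
    _no_split_modules = ["ToyEncoderLayer"]


def supports_auto_device_map(model_class):
    # Assumption: defining the attribute, even as [], is what opts a model in.
    return model_class._no_split_modules is not None


def can_split(module, no_split_module_classes):
    # Mirrors only the class-name comparison; real code also checks for children.
    return module.__class__.__name__ not in no_split_module_classes


print(supports_auto_device_map(PreTrainedModelStub))   # False
print(supports_auto_device_map(ToyPreTrainedModel))    # True
print(can_split(ToyEncoderLayer(), ToyPreTrainedModel._no_split_modules))      # False
print(can_split(ToyPatchEmbeddings(), ToyPreTrainedModel._no_split_modules))   # True
```

Keeping the list as short as possible leaves the device-map inference more freedom to spread the model over the available devices.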
Models
Motivation
Enable a powerful HF feature for all of our vision models
Your contribution
Ping me for review 🤗