Nech-C opened this issue 2 weeks ago
I'm sorry if my original issue was unclear. This is my first bug report. I have modified it; hopefully, it's easier to understand now. I will explain the problem in a few sentences.
When calling `infer_auto_device_map`, if `max_memory` assigns a CUDA device or the CPU less memory than the size of the model's largest layer, no layer will be allocated to that device. More generally, the memory actually used on a main device is at most the defined `max_memory` minus the largest layer size. While this may not be an issue in most cases, it causes the offload/parallelism tests in `test_modeling_common.py` from the Transformers library to fail, since they use the model's size as a reference.
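To make the numbers concrete, here is a quick illustration using the sizes from the reproduction below (split ratio 0.9, where some modules do fit on the GPU):

```python
# Sizes taken from the reproduction below; the largest layer,
# decode_head.linear_fuse, cannot be split.
model_size = 1195632
max_memory_gpu = 1076068   # int(0.9 * model_size)
largest_layer = 1048576

# Room for the largest layer is reserved on each main device, so the
# budget actually available for assigning modules is much smaller:
budget = max_memory_gpu - largest_layer
print(budget, budget / model_size)  # 27492 -> only ~2.3% of the model fits on the GPU
```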
Can you guys tell me whether this is the expected behavior, so that we can start enabling `device_map="auto"` for more models? Thank you so much!
Thanks for reporting this issue. I agree that there is room for improvement here, to allow as many modules as possible to be loaded on the fastest device. To me, this looks like a knapsack problem, so finding an optimal solution could become quite interesting. But I'll let @muellerzr and @SunMarc, who have more background knowledge, comment on this.
Hey @Nech-C, thanks for the detailed report! You have a very good understanding of the situation. This could indeed be improved, as `infer_auto_device_map` doesn't perform well for unbalanced models. This is something I'm aware of.
As you said, there are two points that can be improved; here are my thoughts:
1) The current device needs to have more memory than the size of the module plus the size of the largest layer just for the module to be allocated on it.
This is required in case we perform CPU/disk offloading, as we need to be able to bring the largest offloaded layer back onto the GPU. One way to solve this is to create the `device_map` without the hypothesis that we will have offloaded layers; if we end up with offloaded layers, we redo the calculation with that hypothesis (a minimal sketch follows below). Another solution would be to check whether the model's memory footprint is smaller than the total GPU memory: if it is, no offloading can happen, so we do the calculation without the hypothesis. However, we might still face issues with unbalanced models.
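A minimal sketch of the two-pass idea (a hypothetical wrapper; `infer_once` stands in for the existing single-pass allocation logic and is not an Accelerate API):

```python
from typing import Callable, Dict

def infer_device_map_two_pass(
    infer_once: Callable[[bool], Dict[str, str]],
) -> Dict[str, str]:
    """Hypothetical sketch: infer_once(reserve_largest_layer) runs the
    existing allocation with or without reserving room for the largest layer."""
    # Pass 1: optimistically assume nothing gets offloaded, so no space is
    # reserved on the GPUs for bringing an offloaded layer back.
    device_map = infer_once(False)
    if any(d in ("cpu", "disk") for d in device_map.values()):
        # Offloading happened after all: redo the calculation, this time
        # reserving space for the largest offloaded layer.
        device_map = infer_once(True)
    return device_map
```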
2) However, once it moves to the next device (i.e., CPU), it never goes back to the GPU, even if there's available space.
We can indeed improve that part. Since most Transformers models have balanced modules, it was working fine so far. For example, we could still consider coming back to the previous device if it has at least 10% of its space available. The reasoning behind only moving to the next device was to limit movement across devices, as that makes inference slower: 1->2->3 rather than 1->2->1->2->3. A rough sketch of this heuristic follows.
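A rough sketch of that heuristic (names are illustrative, not the actual Accelerate implementation):

```python
def pick_device(module_size, devices, used, max_memory, current_idx):
    """Hypothetical device picker: before settling on the current device,
    revisit earlier (faster) devices that still have room for the module
    and at least 10% of their capacity free."""
    for idx in range(current_idx):
        d = devices[idx]
        free = max_memory[d] - used[d]
        if free >= module_size and free >= 0.10 * max_memory[d]:
            return d  # going back is worth the extra cross-device hop
    return devices[current_idx]
```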
If you are up to the challenge, feel free to open a PR to fix those two points! I can have a look later!
@SunMarc Thank you so much for your detailed response! I appreciate your insights into the problem. I'd love to take on this challenge. It might take me a little time to get up to speed with the library, but I'm excited to give it a try. Can I reach out with any questions as I work on this?
Nice, thanks for helping! Yes, feel free to ask any questions!
Hi @SunMarc,
I've been digging into the code, and this is more complicated than I first thought. I agree that conditionally calculating the `device_map` may not fully solve the problem. So I think we can address this in two separate PRs:
PR no. 1 (quick fix): Add a new parameter to `infer_auto_device_map`, `fallback_allocation`. When set to `True`, it will attempt an alternative assignment if `max_memory` is sufficient for some (non-splittable module) + (largest layer) but insufficient for the default assignment attempt. This makes sure at least one module is assigned to the potential execution device and likely won't break other code. (A rough sketch of the check is below.)

PR no. 2 (optimization): Work on your idea about utilizing space on main devices more efficiently. We'd add a new parameter so users can choose to maximize main device use.
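A rough sketch of the check behind `fallback_allocation` (helper and variable names are illustrative, not existing Accelerate code):

```python
def find_fallback_module(candidates, module_sizes, max_layer_size, remaining_memory):
    """Hypothetical fallback: the default assignment failed, so look for any
    non-splittable module that still fits next to the largest layer, ensuring
    the execution device receives at least one module instead of none."""
    for name in candidates:
        if module_sizes[name] + max_layer_size <= remaining_memory:
            return name
    return None  # nothing fits; keep the default (offload) behavior
```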
Does this approach sound good to you? If you agree, I can start working on the first PR soon. Let me know if you have any suggestions or concerns about this plan!
System Info

Information

Tasks

`no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)

Reproduction
Steps to Reproduce

1. Create a model with a huge single layer that cannot be split (the Segformer config below).
2. Define `max_memory` so the main device's budget is smaller than that layer.
3. Call `infer_auto_device_map` with `no_split_module_classes=[]` and the defined `max_memory`.

Code example
```python
from accelerate.utils import compute_module_sizes, infer_auto_device_map
from transformers import SegformerConfig, SegformerForSemanticSegmentation

config = SegformerConfig(
    image_size=64,
    num_channels=3,
    num_encoder_blocks=4,
    depths=[1, 1, 1, 1],
    sr_ratios=[8, 4, 2, 1],
    hidden_sizes=[8, 8, 16, 16],
    num_attention_heads=[1, 1, 2, 2],
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
)

# Segformer has a huge single layer that cannot be split
model = SegformerForSemanticSegmentation(config)
model_sizes = compute_module_sizes(model)
model_size = model_sizes[""]

# 0.7 is one of the split ratios defined for device offload tests in the Transformers library
split_ratio = 0.7
max_memory = {0: int(split_ratio * model_size), "cpu": model_size * 2}

print(f"model size: {model_size}, max memory: {max_memory}")
print(
    infer_auto_device_map(
        model, max_memory=max_memory, no_split_module_classes=[]
    )
)
```
Output:

```
model size: 1195632, max memory: {0: 836942, 'cpu': 2391264}
OrderedDict([('', 'cpu')])
```
Module sizes reported by `compute_module_sizes`:

```
"": 1195632
segformer: 87648
segformer.encoder: 87648
decode_head: 1107984
decode_head.linear_c: 53248
decode_head.linear_fuse: 1048576
decode_head.batch_norm: 4104
decode_head.classifier: 2056
```
For comparison, the same run with a split ratio of 0.9:

```
Total model size: 1195632, max memory: {0: 1076068, 'cpu': 2391264}
OrderedDict([('segformer.encoder.patch_embeddings', 0),
             ('segformer.encoder.block.0.0.layer_norm_1', 0),
             ('segformer.encoder.block.0.0.attention.self.query', 0),
             ('segformer.encoder.block.0.0.attention.self.key', 0),
             ('segformer.encoder.block.0.0.attention.self.value', 0),
             ('segformer.encoder.block.0.0.attention.self.dropout', 0),
             ('segformer.encoder.block.0.0.attention.self.sr', 'cpu'),
             ('segformer.encoder.block.0.0.attention.self.layer_norm', 'cpu'),
             ('segformer.encoder.block.0.0.attention.output', 'cpu'),
             ('segformer.encoder.block.0.0.drop_path', 'cpu'),
             ('segformer.encoder.block.0.0.layer_norm_2', 'cpu'),
             ('segformer.encoder.block.0.0.mlp', 'cpu'),
             ('segformer.encoder.block.1', 'cpu'),
             ('segformer.encoder.block.2', 'cpu'),
             ('segformer.encoder.block.3', 'cpu'),
             ('segformer.encoder.layer_norm', 'cpu'),
             ('decode_head', 'cpu')])
```
The relevant logic in `infer_auto_device_map` (excerpt):

```python
max_layer_size, max_layer_names = get_max_layer_size(
    modules_to_treat, module_sizes, no_split_module_classes
)
# ...
while len(modules_to_treat) > 0:
    name, module = modules_to_treat.pop(0)
    module_size = module_sizes[name]
    # ...
    # A module is rejected when it does not fit in the remaining budget,
    # which for a main device already has the largest layer size subtracted.
    if current_max_size is not None and current_memory_used + module_size > current_max_size:
        # Split or not split?
```
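This check is where the reproduction falls over: for a main device the budget already has the largest layer size subtracted, and with the split ratio 0.7 numbers that budget is negative, so every module fails the fit check:

```python
# Numbers from the reproduction above (split ratio 0.7).
max_memory_gpu = 836942    # int(0.7 * 1195632)
max_layer_size = 1048576   # decode_head.linear_fuse, not splittable
budget = max_memory_gpu - max_layer_size
print(budget)  # -211634 -> no module ever fits, so the whole model goes to "cpu"
```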