Closed: yongjer closed this issue 1 year ago
cc @SunMarc
Hello @yongjer @LysandreJik @SunMarc
This seems like a tricky bug. I would like to try to fix it, but I may need some help on how to approach it.
The issue is: when you use `device_map="auto"`, transformers internally creates a context manager from accelerate (https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/src/transformers/modeling_utils.py#L3081 and https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/src/transformers/modeling_utils.py#L3086). You can see that this context manager basically sets the default device to "meta" (https://github.com/huggingface/accelerate/blob/dab62832de44c84e80045e4db53e087b71d0fd85/src/accelerate/big_modeling.py#L51-L81).
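To illustrate what that context manager does, here is a minimal sketch (using PyTorch's own `torch.device` context manager rather than accelerate's wrapper, which behaves analogously): modules constructed under it land on the "meta" device, so their parameters have shapes and dtypes but no actual storage.

```python
import torch
import torch.nn as nn

# Sketch of the meta-device behavior: any module created inside this
# context manager gets "meta" parameters with no backing data.
# (Requires PyTorch >= 2.0 for torch.device as a context manager.)
with torch.device("meta"):
    layer = nn.BatchNorm2d(64)

print(layer.weight.device)   # meta
print(layer.weight.is_meta)  # True
```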
During the instantiation of the DETR model, there is a step where we want to freeze the batch norm (https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/src/transformers/models/detr/modeling_detr.py#L307-L327), but the backbone, which was created with timm, is on the meta device, i.e., its weights are not materialized, so we can't copy them.
As a workaround we could guarantee that the backbone is created on a physical device, but that breaks the idea of `device_map` a bit.
Any thoughts on how to solve this issue?
If I'm not wrong (I usually am), we could solve it by not trying to copy weights into the `DetrFrozenBatchNorm2d` when the device is meta, something like:
```python
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            new_module = DetrFrozenBatchNorm2d(module.num_features)

            # Skip the copy when the source module lives on the meta
            # device: meta tensors have no data to copy from.
            if not module.weight.device == torch.device("meta"):
                new_module.weight.data.copy_(module.weight)
                new_module.bias.data.copy_(module.bias)
                new_module.running_mean.data.copy_(module.running_mean)
                new_module.running_var.data.copy_(module.running_var)

            model._modules[name] = new_module

        if len(list(module.children())) > 0:
            replace_batch_norm(module)
```
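To see why the guard is needed, here is a small self-contained demonstration (standalone `nn.BatchNorm2d` modules stand in for the timm backbone and `DetrFrozenBatchNorm2d`): copying *out of* a meta tensor fails because there is no data behind it.

```python
import torch
import torch.nn as nn

# Source module created on the meta device, as the timm backbone is
# when device_map="auto" is active.
with torch.device("meta"):
    source = nn.BatchNorm2d(8)

target = nn.BatchNorm2d(8)  # materialized normally on CPU

try:
    target.weight.data.copy_(source.weight)
except (RuntimeError, NotImplementedError) as e:
    # PyTorch refuses to copy out of a meta tensor; this is the error
    # the unguarded replace_batch_norm runs into.
    print("copy from meta failed:", type(e).__name__)
```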
And then add something like

```python
self._no_split_modules = ["DetrModel", "DetrMLPPredictionHead", "nn.Linear"]
```

to the `DetrForObjectDetection` constructor method.
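For reference, in transformers `_no_split_modules` is usually declared as a class attribute on the `PreTrainedModel` subclass rather than set in the constructor; accelerate reads it when inferring the device map so the listed modules are never split across devices. A minimal sketch (the base class is stubbed here so the snippet is self-contained, and the module names are the suggestion from this thread, not the merged fix):

```python
class DetrPreTrainedModel:  # stand-in for transformers' real base class
    _no_split_modules = None


class DetrForObjectDetection(DetrPreTrainedModel):
    # Class-level declaration, as done for other transformers models.
    _no_split_modules = ["DetrModel", "DetrMLPPredictionHead", "nn.Linear"]


print(DetrForObjectDetection._no_split_modules)
```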
This should be solved once the PR is merged!
System Info

transformers version: 4.34.0

Who can help?

@Narsil

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction
Here is my code below:

Setting `pipeline(device_map="auto")` raises an error:
Expected behavior

When `device=0` is set rather than `device_map="auto"`, it works.