huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Finetuning doesn't initialize microsoft/resnet classifier weights with _fast_init #31841

Closed: williford closed this issue 2 months ago

williford commented 3 months ago

System Info

It seems that the changes in https://github.com/huggingface/transformers/pull/11471 broke fine-tuning of ResNet when the number of classes is changed.

Most models seem to handle this by also initializing Linear modules in their _init_weights method. For ResNet, that would mean adding Linear here: https://github.com/huggingface/transformers/blob/ae9dd02ee1a8627d26be32202202b8081e9855a4/src/transformers/models/resnet/modeling_resnet.py#L274
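A minimal sketch of that per-model approach follows. The function name init_weights_with_linear is hypothetical; in transformers this logic would live in ResNetPreTrainedModel._init_weights. The Conv2d and normalization branches are assumptions about what the existing method does. Only the nn.Linear branch is the proposed addition, and it mirrors what PyTorch's own nn.Linear.reset_parameters does.

```python
import math

import torch
from torch import nn


def init_weights_with_linear(module: nn.Module) -> None:
    """Hypothetical stand-in for ResNetPreTrainedModel._init_weights plus a Linear branch."""
    if isinstance(module, nn.Conv2d):
        # Assumed existing behavior for convolutions.
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        # Assumed existing behavior for normalization layers.
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Linear):
        # Proposed branch: reproduce torch.nn.Linear's default initialization so the
        # classifier head is initialized even when fast init skipped its constructor init.
        nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        if module.bias is not None:
            fan_in = module.weight.shape[1]  # in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(module.bias, -bound, bound)


# Usage: simulate an uninitialized 10-class head, then re-initialize it.
head = nn.Linear(2048, 10)
with torch.no_grad():
    head.weight.fill_(float("nan"))
    head.bias.fill_(float("nan"))
init_weights_with_linear(head)
print(head.weight.abs().max())
```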

However, it might be better to handle this in modeling_utils.py, at the point where the size mismatch is detected: https://github.com/huggingface/transformers/blob/ae9dd02ee1a8627d26be32202202b8081e9855a4/src/transformers/modeling_utils.py#L4282
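One possible shape for the generic approach is sketched below, assuming the loader has the list of mismatched parameter names (such as the classifier.1.weight and classifier.1.bias keys in the log further down). The helper reinit_mismatched_modules and the ToyClassifier model are hypothetical and are not transformers APIs. The helper simply re-runs reset_parameters() on each module that owns a mismatched parameter. Placing a Flatten layer at classifier.0 is also an assumption, chosen so that classifier.1 is the Linear head as in the log.

```python
import torch
from torch import nn


def reinit_mismatched_modules(model: nn.Module, mismatched_keys: list) -> None:
    """Hypothetical helper: re-run the default init of every module whose
    parameters were not loaded because of a shape mismatch."""
    seen = set()
    for key in mismatched_keys:
        # "classifier.1.weight" -> owning module "classifier.1"
        module_name = key.rsplit(".", 1)[0]
        if module_name in seen:
            continue
        seen.add(module_name)
        module = model.get_submodule(module_name)
        # Built-in layers such as nn.Linear and nn.Conv2d expose reset_parameters().
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()


class ToyClassifier(nn.Module):
    """Toy model whose Linear head is named classifier.1, matching the log."""

    def __init__(self, num_labels: int) -> None:
        super().__init__()
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(2048, num_labels))


model = ToyClassifier(num_labels=10)
with torch.no_grad():
    model.classifier[1].weight.fill_(float("nan"))  # simulate uninitialized memory
reinit_mismatched_modules(model, ["classifier.1.bias", "classifier.1.weight"])
print(model.classifier[1].weight.abs().max())
```

Keying on the owning module rather than individual tensors keeps weight and bias initialized consistently, since reset_parameters() sets both.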

Who can help?

@amyeroberts

Reproduction

For example:

> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(1.7014e+38, grad_fn=<MaxBackward1>)

# Sometimes the same command gives NaN:
> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(nan, grad_fn=<MaxBackward1>)

# no change in the number of labels
> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=1000, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
tensor(4.7245, grad_fn=<MaxBackward1>)

# change the number of labels by one
> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=1001, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1001]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([1001, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(1.8520e-40, grad_fn=<MaxBackward1>)

> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10000, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10000]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10000, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(0., grad_fn=<MaxBackward1>)

Disabling _fast_init fixes the issue:

> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10000, ignore_mismatched_sizes=True, _fast_init=False).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10000]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10000, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(0.0221, grad_fn=<MaxBackward1>)
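The 0.0221 value matches PyTorch's default nn.Linear initialization. reset_parameters calls kaiming_uniform_ with a=sqrt(5), which samples from U(-b, b) with b = sqrt(3) * gain / sqrt(fan_in) and gain = sqrt(2 / (1 + a^2)). This simplifies to b = 1 / sqrt(fan_in). For this head, fan_in = 2048 (from the [*, 2048] weight shape in the log). With 10000 × 2048 samples, the observed maximum should sit essentially at that bound. A quick check of the arithmetic:

```python
import math

fan_in = 2048  # input features of the classifier head, from the [*, 2048] shape in the log
a = math.sqrt(5)  # negative slope passed to kaiming_uniform_ by nn.Linear.reset_parameters

# Kaiming-uniform bound: gain = sqrt(2 / (1 + a^2)), std = gain / sqrt(fan_in), bound = sqrt(3) * std
gain = math.sqrt(2.0 / (1 + a**2))
std = gain / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std

print(f"{bound:.4f}")  # 0.0221
```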

Expected behavior

The statistics of the initialized weights should be similar with and without _fast_init. In particular, the weights should not contain NaNs, and the maximum absolute value should be neither 0 nor extremely large (e.g. > 1e20).
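The criteria above can be expressed as a small check. The function looks_sanely_initialized and its thresholds are hypothetical. Treating a maximum below 1e-20 as "effectively zero" goes beyond the stated criteria. It was added so the 1.8520e-40 result from the num_labels=1001 run is also flagged. A PyTorch weight can be passed in as weight.flatten().tolist().

```python
import math
from typing import Iterable


def looks_sanely_initialized(values: Iterable[float], tiny: float = 1e-20, huge: float = 1e20) -> bool:
    """Hypothetical check for freshly initialized weights.

    Rejects any NaN, a maximum absolute value at or below `tiny` (treated as
    zero, an assumption), and a maximum absolute value above `huge`.
    """
    max_abs = 0.0
    for v in values:
        if math.isnan(v):
            return False
        max_abs = max(max_abs, abs(v))
    return tiny < max_abs <= huge


print(looks_sanely_initialized([1.7014e38, 0.5]))  # False: too large
print(looks_sanely_initialized([float("nan")]))  # False: NaN
print(looks_sanely_initialized([0.0, 0.0]))  # False: all zero
print(looks_sanely_initialized([1.8520e-40]))  # False: effectively zero
print(looks_sanely_initialized([0.0221, -0.01]))  # True
```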

NielsRogge commented 3 months ago

cc @ydshieh, who worked on a similar issue that was fixed by https://github.com/huggingface/transformers/pull/28122

ydshieh commented 3 months ago

Hi @williford

Could you share your system info with us? You can run the command transformers-cli env and copy-paste its output below.

williford commented 3 months ago

For the reproduction I installed transformers with pip install git+https://github.com/huggingface/transformers:

williford commented 3 months ago

@ydshieh If I understand the code correctly, your change makes sure that model._initialize_weights is called. ResNetForImageClassification inherits from ResNetPreTrainedModel, which overrides _init_weights. However, ResNetPreTrainedModel._init_weights does nothing when the module is a torch.nn.modules.linear.Linear (i.e. torch.nn.Linear).

When _fast_init is disabled, the Linear module initializes its own weights in its reset_parameters method, which nn.Linear calls from its constructor.
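The effect can be reproduced outside transformers with torch.nn.utils.skip_init. This function builds a module without running its initialization, so its parameters hold whatever was in memory. That is similar in spirit to a head that fast init never initialized. It is only an analogy, not the mechanism transformers itself uses. Calling reset_parameters afterwards restores PyTorch's default initialization:

```python
import torch
from torch import nn

# Build a 2048 -> 10 head without running any initialization; its weights are
# uninitialized memory (they may be NaN, 0, or huge values).
head = torch.nn.utils.skip_init(nn.Linear, 2048, 10)
print("uninitialized max abs:", head.weight.abs().max())

# Run the layer's own initializer (what the constructor does when nothing is
# skipped); the weights are then bounded by 1 / sqrt(2048).
head.reset_parameters()
print("after reset_parameters:", head.weight.abs().max())
```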

ydshieh commented 2 months ago

@williford Thank you for diving into this issue. Yes, you are correct! I opened a PR to fix it, and it works now.