huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[Data2Vec] Incompatibility with the original implementation #17010

Closed. arxyzan closed this issue 2 years ago.

arxyzan commented 2 years ago

Hello dear Hugging Face team!

According to the original paper, data2vec is not a model in itself but a self-distillation training strategy. It takes an encoder as its backbone (RoBERTa for text, BEiT for vision and wav2vec 2.0 for audio, as mentioned in the paper) and pre-trains that encoder (the student) to predict representations produced by an EMA (exponential moving average) copy of the same encoder (the teacher). This means the backbone can be any Transformer-based encoder. After pre-training, only the encoder matters for fine-tuning or inference; the data2vec training machinery is no longer needed (as seen here).

I reviewed the data2vec implementation in HF transformers and noticed that you chose fixed encoders (RoBERTa for text, BEiT for vision and wav2vec2 for audio), so, for example, using Data2VecVisionModel for any task would be the same as using BeitModel. I also noticed that the encoders used for HF Data2Vec are not exactly the same as the models mentioned above; there are some minor differences. I'm asking because I tried to copy the weights from your models into my own models in my own repo and found that I can't, due to those incompatibilities.

So my questions are: what was the purpose behind all this? And did you train all those models yourselves, or did you copy the weights from the original fairseq checkpoints?

Regards, Aryan
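The student/teacher scheme described in this comment can be sketched in a few lines. This is a toy illustration, not the fairseq or transformers code: parameters are plain floats in a dict, and the helper names and the decay value `tau` are assumptions for the example. The target is the average of the teacher's top-k layer outputs, as in the paper, but the paper's per-layer normalisation and exact loss are omitted; a plain squared error stands in for the loss.

```python
from typing import Dict, List

# Toy stand-in for an encoder's parameters: name -> scalar weight (hypothetical)
Params = Dict[str, float]


def ema_update(teacher: Params, student: Params, tau: float) -> Params:
    """Move the teacher towards the student: teacher <- tau * teacher + (1 - tau) * student."""
    return {name: tau * teacher[name] + (1.0 - tau) * student[name] for name in teacher}


def teacher_targets(layer_outputs: List[List[float]], top_k: int) -> List[float]:
    """Average the teacher's last top_k layer outputs (one vector per layer) into a single target."""
    top = layer_outputs[-top_k:]
    return [sum(values) / len(top) for values in zip(*top)]


def regression_loss(prediction: List[float], target: List[float]) -> float:
    """Mean squared error between the student's prediction and the teacher's target."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)
```

In a training step the student would see a masked input and the teacher the unmasked one; gradients from the loss update only the student, and `ema_update` then refreshes the teacher. With `tau` close to 1 (the paper anneals it towards 1 early in training), the teacher changes slowly. After pre-training the teacher is discarded and the encoder alone is kept, which is the point made above.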

LysandreJik commented 2 years ago

cc @patrickvonplaten @NielsRogge

NielsRogge commented 2 years ago

> I also noticed that the encoders used for HF Data2Vec are not exactly the same as the models mentioned above; there are some minor differences. I'm asking because I tried to copy the weights from your models into my own models in my own repo and found that I can't, due to those incompatibilities.

Can you elaborate on this? We converted the weights from the original repo, so they should be equivalent to the original implementation.

arxyzan commented 2 years ago

Hello @NielsRogge, sorry for the delayed response. It seems I was mistaken about the mismatch between the architectures; perhaps I loaded the wrong models with AutoModel. Today I reviewed all three models thoroughly and found no mismatch. But what about my first question? What was your intent behind reimplementing three models for data2vec when they're exactly the same as RoBERTa, BEiT and Wav2Vec2, which are already in the transformers package?

Thanks, Aryan

arxyzan commented 2 years ago

Regarding the minor differences in model architecture, here is what I tried: loading the weights from the data2vec checkpoints directly into the existing encoder models.

  1. Loaded the state dict from the facebook/data2vec-text-base checkpoint into roberta-base; all keys matched.

  2. Loaded the state dict from the facebook/data2vec-vision-base checkpoint into microsoft/beit-base-patch16-224 and got an _IncompatibleKeys result:
     - missing_keys: 'encoder.relative_position_bias.relative_position_bias_table', 'encoder.relative_position_bias.relative_position_index', 'layernorm.weight', 'layernorm.bias'
     - unexpected_keys: 'pooler.layernorm.weight', 'pooler.layernorm.bias', plus 'encoder.layer.N.attention.attention.relative_position_bias.relative_position_bias_table' and 'encoder.layer.N.attention.attention.relative_position_bias.relative_position_index' for every layer N from 0 to 11

  3. Loaded the state dict from the facebook/data2vec-audio-base checkpoint into facebook/wav2vec2-base and got an _IncompatibleKeys result:
     - missing_keys: 'encoder.pos_conv_embed.conv.bias', 'encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v'
     - unexpected_keys: 'feature_extractor.conv_layers.N.layer_norm.weight' and 'feature_extractor.conv_layers.N.layer_norm.bias' for N from 1 to 6, plus 'encoder.pos_conv_embed.layers.N.conv.weight' and 'encoder.pos_conv_embed.layers.N.conv.bias' for N from 0 to 4

@NielsRogge
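The missing/unexpected split in these results is plain set arithmetic over parameter names; it is what `torch.nn.Module.load_state_dict(state_dict, strict=False)` returns in its `missing_keys` and `unexpected_keys` fields. A minimal sketch, with hypothetical helper names, that reproduces the comparison from two key lists and collapses numeric layer indices so long results like the ones above shrink to a handful of patterns:

```python
import re
from typing import Iterable, List, Tuple


def diff_state_dict_keys(model_keys: Iterable[str],
                         checkpoint_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper mirroring load_state_dict(strict=False) bookkeeping.

    missing: names the target model expects but the checkpoint does not provide.
    unexpected: names in the checkpoint that the target model has no slot for.
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


def collapse_layer_indices(keys: Iterable[str]) -> List[str]:
    """Replace every '.<digits>' segment that is followed by '.' with '.*', then deduplicate."""
    return sorted({re.sub(r"\.\d+(?=\.)", ".*", key) for key in keys})
```

With real checkpoints the two inputs would be `model.state_dict().keys()` and the keys of the loaded checkpoint. Passing the vision result's unexpected keys through `collapse_layer_indices` would reduce the 24 per-layer entries to two patterns.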

arxyzan commented 2 years ago

For BEiT, the problem was a difference in configuration. To load the weights without errors, these values must be set in the config:

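```python
from transformers import BeitConfig

...
# Config values that let the data2vec-vision weights load into a BEiT model without key errors
beit_config = BeitConfig(use_relative_position_bias=False,
                         use_mean_pooling=False,
                         use_shared_relative_position_bias=True)
```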
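To see why these three flags decide whether the weights load, here is a hypothetical helper (not part of transformers) that lists the bias- and norm-related parameter names a BEiT-style encoder would expose for a given setting. The naming scheme is taken from the key lists in the vision result earlier in this thread; the sketch assumes a model built with a pooling layer and ignores every other parameter.

```python
from typing import List


def expected_beit_bias_keys(num_layers: int,
                            use_relative_position_bias: bool,
                            use_shared_relative_position_bias: bool,
                            use_mean_pooling: bool) -> List[str]:
    """Hypothetical: relative-position-bias and final-norm keys implied by three BeitConfig flags."""
    keys: List[str] = []
    if use_shared_relative_position_bias:
        # One bias table shared by all layers, owned by the encoder
        keys += ["encoder.relative_position_bias.relative_position_bias_table",
                 "encoder.relative_position_bias.relative_position_index"]
    if use_relative_position_bias:
        # A separate bias table inside every layer's self-attention
        for i in range(num_layers):
            prefix = f"encoder.layer.{i}.attention.attention.relative_position_bias"
            keys += [f"{prefix}.relative_position_bias_table", f"{prefix}.relative_position_index"]
    if use_mean_pooling:
        # Mean-pooled output is normalised inside the pooler
        keys += ["pooler.layernorm.weight", "pooler.layernorm.bias"]
    else:
        # Otherwise a final layernorm sits on the sequence output
        keys += ["layernorm.weight", "layernorm.bias"]
    return keys
```

For 12 layers, the setting shown in the config above yields exactly the four names reported as missing in the vision result, while the opposite setting (per-layer bias plus mean pooling) yields the 26 names reported as unexpected. A model and a checkpoint built from different settings therefore cannot share keys.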

So in terms of architecture, transformers.BeitModel and transformers.Data2VecVisionModel are the same, but Wav2Vec2Model and Data2VecAudioModel are not: their designs actually differ, so I'd have to use another technique to transfer weights from Data2VecAudio to Wav2Vec2. I understand that this is because the same differences exist in fairseq: data2vec-audio and wav2vec2 differ in design there too, so to convert the weights from fairseq you had to make matching changes in the Data2VecAudioModel code.
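One way to see why the audio weights cannot simply be renamed is to compare the positional-convolution parameters that the audio key lists point to. The sketch below assumes default hyperparameters that this thread does not state (hidden size 768, 16 groups, a single weight-normalised conv with kernel size 128 for wav2vec2, and five plain convs with kernel size 19 for data2vec-audio); the function names are hypothetical. Under those assumptions the names, the shapes and the total parameter counts all differ, so there is no one-to-one copy.

```python
from math import prod
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def conv1d_weight_shape(channels: int, groups: int, kernel_size: int) -> Shape:
    # PyTorch stores a grouped Conv1d weight as (out_channels, in_channels // groups, kernel_size)
    return (channels, channels // groups, kernel_size)


def wav2vec2_pos_conv(hidden: int = 768, groups: int = 16, kernel_size: int = 128) -> Dict[str, Shape]:
    # Assumed layout: one conv with weight norm over dim=2, stored as weight_g and weight_v
    return {
        "encoder.pos_conv_embed.conv.weight_g": (1, 1, kernel_size),
        "encoder.pos_conv_embed.conv.weight_v": conv1d_weight_shape(hidden, groups, kernel_size),
        "encoder.pos_conv_embed.conv.bias": (hidden,),
    }


def data2vec_audio_pos_conv(hidden: int = 768, groups: int = 16, kernel_size: int = 19,
                            num_layers: int = 5) -> Dict[str, Shape]:
    # Assumed layout, matching the key names above: a weight and a bias per stacked conv
    shapes: Dict[str, Shape] = {}
    for i in range(num_layers):
        shapes[f"encoder.pos_conv_embed.layers.{i}.conv.weight"] = conv1d_weight_shape(hidden, groups, kernel_size)
        shapes[f"encoder.pos_conv_embed.layers.{i}.conv.bias"] = (hidden,)
    return shapes


def num_params(shapes: Dict[str, Shape]) -> int:
    # Total number of scalars across all tensors
    return sum(prod(shape) for shape in shapes.values())
```

Under these assumptions wav2vec2 holds 4,719,488 positional-conv parameters in three tensors, while data2vec-audio holds 3,505,920 in ten tensors. Transferring weights between them would mean re-training or approximating that block rather than copying it.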

NielsRogge commented 2 years ago

> But what about my first question? What was your intent behind reimplementing three models for data2vec when they're exactly the same as RoBERTa, BEiT and Wav2Vec2, which are already in the transformers package?

We're planning to add Data2VecAudioForPretraining, etc., which is why the implementations were duplicated.

arxyzan commented 2 years ago

Cool, looking forward to that! Thanks for taking the time to reply. I'm closing this issue.