huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
129.55k stars · 25.72k forks

Add TensorFlow Wav2Vec2 for sequence classification #21778

Open sanchit-gandhi opened 1 year ago

sanchit-gandhi commented 1 year ago

Feature request

Wav2Vec2 is one of the most popular speech recognition models, used over 2 million times monthly. In the PyTorch modelling code, we have Wav2Vec2 for speech recognition and Wav2Vec2 for audio classification. However, in TensorFlow, we only have Wav2Vec2 for speech recognition. It would be great to add Wav2Vec2 for audio classification to the TensorFlow modelling code for cross-framework equivalence!

Motivation

The audio classification class for PyTorch Wav2Vec2 lives under Wav2Vec2ForSequenceClassification: https://github.com/huggingface/transformers/blob/13489248fa8f2cda7503628204f8f43b108797a2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1745

For this feature request, we'll need to port this PyTorch code into TensorFlow to create an equivalent TensorFlow class, TFWav2Vec2ForSequenceClassification.

This means adding a projection layer and classification layer on top of the base TFWav2Vec2Model. See the PyTorch code for reference: https://github.com/huggingface/transformers/blob/13489248fa8f2cda7503628204f8f43b108797a2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1753-L1758
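The added head is small: a frame-wise projection, a pool over the time axis, then a linear classifier. As a framework-agnostic sketch of that computation (plain numpy with placeholder random weights; the layer names `projector` and `classifier` mirror the PyTorch class, and the PyTorch model's optional weighted sum over encoder layers is omitted here):

```python
import numpy as np

def sequence_classification_head(hidden_states, proj_w, proj_b, cls_w, cls_b):
    """Sketch of the Wav2Vec2 sequence-classification head:
    frame-wise projection -> mean pool over time -> linear classifier."""
    projected = hidden_states @ proj_w + proj_b  # (batch, time, classifier_proj_size)
    pooled = projected.mean(axis=1)              # (batch, classifier_proj_size)
    logits = pooled @ cls_w + cls_b              # (batch, num_labels)
    return logits

# Toy shapes: hidden_size=4, classifier_proj_size=3, num_labels=2
rng = np.random.default_rng(0)
hidden = rng.normal(size=(1, 5, 4))  # (batch, time, hidden_size)
logits = sequence_classification_head(
    hidden,
    rng.normal(size=(4, 3)), np.zeros(3),
    rng.normal(size=(3, 2)), np.zeros(2),
)
print(logits.shape)  # (1, 2)
```

The shape contract to preserve in the TF port is (batch, time, hidden_size) in, (batch, num_labels) out, with the time axis removed by the pool.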

To check that our implementation is correct, we can do one forward pass of the PyTorch model and one forward pass of the TensorFlow model with the same inputs. If the output logits agree to within 1e-5, we know that our TensorFlow model is correct ✅. We can then enable the PT-TF cross tests in the modelling file so that these checks are performed by the CI.
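The comparison itself is simple once both sets of logits are numpy arrays (logits_close below is a hypothetical helper, not a transformers API; the real cross-tests live in the model's test file):

```python
import numpy as np

def logits_close(pt_logits, tf_logits, atol=1e-5):
    """Compare PyTorch and TensorFlow logits to an absolute tolerance."""
    return bool(np.allclose(np.asarray(pt_logits), np.asarray(tf_logits), atol=atol))

# With real models this would be fed e.g.:
#   pt = pt_model(**pt_inputs).logits.detach().numpy()
#   tf = tf_model(**tf_inputs).logits.numpy()
a = np.array([[0.1, 0.2]])
print(logits_close(a, a + 5e-6))  # True
print(logits_close(a, a + 1e-3))  # False
```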

Your contribution

Opening this one up to the community! If you're interested in tackling this, feel free to drop a comment in this thread and open a PR when you're ready. More than happy to answer any questions / queries about this integration!

sanchit-gandhi commented 1 year ago

This feature request is closely related to #21777! Once we have the TF Wav2Vec2 model for sequence classification added, we can copy the projection and classification layers across to Whisper in order to add TFWhisperForAudioClassification. Two birds with one stone ⚡️

nandwalritik commented 1 year ago

Hi @sanchit-gandhi I would love to take this up.

sanchit-gandhi commented 1 year ago

Very cool @nandwalritik! The first thing to do would be to add the equivalent TensorFlow code for the projection layer and classification layer on top of the base TFWav2Vec2Model. Do you want to have a go at adding this in a new PR? Happy to help with any questions / guidance! There's a bit of info as to where the PyTorch code lives in the original post ^

nandwalritik commented 1 year ago

Hi @sanchit-gandhi, I have added some initial changes in PR #22073. But when initializing the model with PyTorch weights like this, model_tf = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True), it gives: Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2ForSequenceClassification. Can you guide me with this?

nandwalritik commented 1 year ago

Hi @sanchit-gandhi, can you guide me with the above error, so that I can make all the required changes and finish the PR?

vimarshc commented 1 year ago

Hey, Can you share the complete stack trace?

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2ForSequenceClassification:

The important part of the error is Some. Most likely the classification head is not being loaded correctly.

Questions:

  1. Is it a warning or is it an error?
  2. Did you try running the model after this?
  3. Did you try the same checkpoint in PyTorch to see if you get the same warning?

cc: @nandwalritik
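One quick way to narrow down which parameters fail to cross-load is to diff the PyTorch and TensorFlow weight-name sets after stripping the TF scope prefix. A stdlib-only sketch (the names and the tf_model. prefix below are illustrative, not the real checkpoint):

```python
def diff_weight_names(pt_names, tf_names, strip_prefix=""):
    """Return (missing_in_tf, unexpected_in_tf) after stripping a TF scope prefix."""
    tf_norm = {n[len(strip_prefix):] if n.startswith(strip_prefix) else n
               for n in tf_names}
    pt_set = set(pt_names)
    return sorted(pt_set - tf_norm), sorted(tf_norm - pt_set)

# Illustrative names only:
pt = ["wav2vec2.encoder.layers.0.attention.q_proj.weight", "projector.weight"]
tf = ["tf_model.wav2vec2.encoder.layers.0.attention.q_proj.weight", "dense.weight"]
missing, unexpected = diff_weight_names(pt, tf, strip_prefix="tf_model.")
print(missing)     # ['projector.weight']
print(unexpected)  # ['dense.weight']
```

A mismatched pair like projector.weight vs dense.weight usually means the TF layer was not given the name the cross-loading code expects.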

nandwalritik commented 1 year ago


Stacktrace
>>> tf_model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks",from_pt=True)
/home/nandwalritik/nandwalritik/transformers/src/transformers/configuration_utils.py:379: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.
  warnings.warn(

TFWav2Vec2ForSequenceClassification has backpropagation operations that are NOT supported on CPU. If you wish to train/fine-tine this model, you need a GPU or a TPU

TFWav2Vec2Model has backpropagation operations that are NOT supported on CPU. If you wish to train/fine-tine this model, you need a GPU or a TPU
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2ForSequenceClassification: ['wav2vec2.encoder.layers.10.attention.q_proj.weight', 'wav2vec2.encoder.layers.1.attention.k_proj.bias', 'wav2vec2.encoder.layers.1.attention.q_proj.bias', 'wav2vec2.encoder.layers.0.attention.v_proj.bias', 'wav2vec2.encoder.layers.6.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.attention.v_proj.weight', 'wav2vec2.encoder.layers.1.attention.out_proj.bias', 'wav2vec2.encoder.layers.0.layer_norm.weight', 'wav2vec2.encoder.layers.3.layer_norm.weight', 'wav2vec2.encoder.layers.10.attention.out_proj.weight', 'wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.bias', 'wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'wav2vec2.encoder.layers.8.attention.out_proj.bias', 'wav2vec2.encoder.layers.9.layer_norm.weight', 'wav2vec2.encoder.layers.0.attention.k_proj.bias', 'wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.attention.v_proj.weight', 'wav2vec2.encoder.layers.5.attention.k_proj.weight', 'wav2vec2.encoder.layers.6.final_layer_norm.weight', 'wav2vec2.encoder.layers.9.feed_forward.output_dense.weight', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.layers.6.attention.q_proj.weight', 'wav2vec2.encoder.layers.4.attention.v_proj.bias', 'wav2vec2.encoder.layers.11.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.6.attention.q_proj.bias', 'wav2vec2.encoder.layers.0.attention.q_proj.bias', 'wav2vec2.encoder.layers.4.final_layer_norm.weight', 'wav2vec2.encoder.layers.5.attention.k_proj.bias', 'wav2vec2.encoder.layers.7.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.3.attention.k_proj.bias', 'wav2vec2.encoder.layers.8.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.attention.out_proj.weight', 
'wav2vec2.encoder.layers.7.attention.out_proj.bias', 'wav2vec2.encoder.layers.8.attention.q_proj.bias', 'wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'wav2vec2.encoder.layers.11.feed_forward.output_dense.weight', 'wav2vec2.encoder.pos_conv_embed.conv.bias', 'wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.final_layer_norm.weight', 'wav2vec2.encoder.layers.5.feed_forward.output_dense.bias', 'wav2vec2.feature_projection.projection.weight', 'wav2vec2.encoder.layers.5.attention.v_proj.weight', 'wav2vec2.encoder.layers.10.attention.out_proj.bias', 'wav2vec2.encoder.layers.4.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.0.attention.k_proj.weight', 'wav2vec2.encoder.layers.7.layer_norm.bias', 'wav2vec2.encoder.layers.1.attention.q_proj.weight', 'wav2vec2.encoder.layers.7.layer_norm.weight', 'wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'wav2vec2.encoder.layers.8.attention.v_proj.bias', 'projector.bias', 'wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.attention.q_proj.weight', 'wav2vec2.encoder.layers.8.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.10.attention.k_proj.bias', 'wav2vec2.encoder.layers.4.attention.out_proj.bias', 'wav2vec2.encoder.layers.6.final_layer_norm.bias', 'layer_weights', 'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.attention.k_proj.bias', 'wav2vec2.encoder.layers.7.attention.v_proj.weight', 'wav2vec2.encoder.layers.2.attention.out_proj.bias', 'wav2vec2.encoder.layers.4.attention.out_proj.weight', 'wav2vec2.encoder.layers.0.final_layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.q_proj.weight', 'wav2vec2.encoder.layers.3.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.feed_forward.output_dense.weight', 'wav2vec2.feature_projection.layer_norm.bias', 
'wav2vec2.encoder.layers.6.attention.k_proj.weight', 'wav2vec2.encoder.layers.7.attention.v_proj.bias', 'wav2vec2.encoder.layers.4.attention.k_proj.bias', 'wav2vec2.encoder.layers.4.layer_norm.weight', 'wav2vec2.encoder.layers.9.attention.q_proj.bias', 'wav2vec2.encoder.layers.4.attention.q_proj.bias', 'wav2vec2.encoder.layers.8.layer_norm.weight', 'wav2vec2.encoder.layers.2.final_layer_norm.weight', 'wav2vec2.feature_projection.projection.bias', 'wav2vec2.encoder.layers.3.final_layer_norm.bias', 'wav2vec2.encoder.layers.8.layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.k_proj.bias', 'wav2vec2.encoder.layers.5.layer_norm.weight', 'wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.6.attention.v_proj.bias', 'wav2vec2.encoder.layers.8.attention.v_proj.weight', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.1.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.5.attention.out_proj.bias', 'wav2vec2.encoder.layers.10.layer_norm.weight', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.9.attention.q_proj.weight', 'wav2vec2.encoder.layers.5.attention.v_proj.bias', 'wav2vec2.encoder.layers.6.attention.out_proj.weight', 'wav2vec2.encoder.layers.3.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.attention.q_proj.bias', 'wav2vec2.feature_projection.layer_norm.weight', 'wav2vec2.encoder.layers.1.layer_norm.bias', 'wav2vec2.feature_extractor.conv_layers.6.conv.weight', 'wav2vec2.encoder.layers.7.attention.q_proj.bias', 'wav2vec2.encoder.layers.9.attention.k_proj.bias', 'wav2vec2.encoder.layers.3.attention.q_proj.weight', 'wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.3.final_layer_norm.weight', 'wav2vec2.encoder.layers.2.attention.v_proj.weight', 'wav2vec2.encoder.layers.0.attention.out_proj.bias', 
'wav2vec2.encoder.layers.3.layer_norm.bias', 'wav2vec2.encoder.layers.6.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.0.attention.out_proj.weight', 'wav2vec2.encoder.layers.4.layer_norm.bias', 'wav2vec2.encoder.layers.5.attention.q_proj.bias', 'wav2vec2.encoder.layers.5.attention.q_proj.weight', 'wav2vec2.encoder.layers.9.final_layer_norm.bias', 'wav2vec2.encoder.layers.5.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.4.attention.q_proj.weight', 'wav2vec2.encoder.layers.2.attention.out_proj.weight', 'wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'wav2vec2.encoder.layers.5.final_layer_norm.weight', 'wav2vec2.encoder.layers.2.attention.q_proj.bias', 'wav2vec2.encoder.layer_norm.weight', 'wav2vec2.encoder.layers.3.attention.v_proj.bias', 'wav2vec2.encoder.layers.7.final_layer_norm.weight', 'wav2vec2.encoder.layers.6.attention.out_proj.bias', 'wav2vec2.encoder.layers.9.attention.k_proj.weight', 'wav2vec2.encoder.layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.out_proj.weight', 'wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.weight', 'classifier.weight', 'wav2vec2.encoder.layers.1.attention.v_proj.bias', 'wav2vec2.encoder.layers.1.attention.out_proj.weight', 'wav2vec2.encoder.layers.2.attention.q_proj.weight', 'wav2vec2.encoder.layers.11.attention.k_proj.weight', 'wav2vec2.encoder.layers.4.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.7.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.final_layer_norm.weight', 'wav2vec2.encoder.layers.11.attention.out_proj.weight', 'wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.layers.10.final_layer_norm.bias', 'projector.weight', 'wav2vec2.encoder.layers.0.attention.q_proj.weight', 'wav2vec2.encoder.layers.6.attention.v_proj.weight', 'wav2vec2.encoder.layers.11.attention.v_proj.bias', 
'wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'wav2vec2.encoder.layers.10.attention.k_proj.weight', 'wav2vec2.encoder.layers.10.feed_forward.output_dense.bias', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'wav2vec2.encoder.layers.2.attention.v_proj.bias', 'wav2vec2.encoder.layers.1.layer_norm.weight', 'wav2vec2.encoder.layers.7.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.1.final_layer_norm.weight', 'wav2vec2.encoder.layers.3.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.4.attention.k_proj.weight', 'wav2vec2.encoder.layers.0.layer_norm.bias', 'wav2vec2.encoder.layers.11.final_layer_norm.bias', 'wav2vec2.encoder.layers.9.attention.out_proj.bias', 'wav2vec2.encoder.layers.8.final_layer_norm.bias', 'wav2vec2.encoder.layers.10.final_layer_norm.weight', 'wav2vec2.encoder.layers.1.final_layer_norm.bias', 'wav2vec2.encoder.layers.1.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.attention.v_proj.bias', 'wav2vec2.encoder.layers.3.attention.out_proj.weight', 'wav2vec2.encoder.layers.3.attention.out_proj.bias', 'wav2vec2.encoder.layers.9.attention.v_proj.bias', 'wav2vec2.encoder.layers.4.attention.v_proj.weight', 'wav2vec2.encoder.layers.1.attention.v_proj.weight', 'wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.attention.out_proj.bias', 'wav2vec2.encoder.layers.5.final_layer_norm.bias', 'wav2vec2.encoder.layers.5.attention.out_proj.weight', 'wav2vec2.encoder.layers.10.attention.q_proj.bias', 'wav2vec2.encoder.layers.6.layer_norm.bias', 'wav2vec2.encoder.layers.7.final_layer_norm.bias', 'classifier.bias', 'wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.6.attention.k_proj.bias', 'wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.0.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.2.feed_forward.output_dense.weight', 
'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.2.attention.k_proj.weight', 'wav2vec2.encoder.layers.2.layer_norm.weight', 'wav2vec2.encoder.layers.3.attention.v_proj.weight', 'wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.0.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.layer_norm.bias', 'wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.9.attention.v_proj.weight', 'wav2vec2.encoder.layers.9.final_layer_norm.weight', 'wav2vec2.encoder.layers.11.layer_norm.weight', 'wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.1.attention.k_proj.weight', 'wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'wav2vec2.encoder.layers.2.layer_norm.bias', 'wav2vec2.encoder.layers.2.final_layer_norm.bias', 'wav2vec2.encoder.layers.2.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.3.attention.q_proj.bias', 'wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.bias', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'wav2vec2.encoder.layers.0.attention.v_proj.weight', 'wav2vec2.encoder.layers.2.attention.k_proj.bias', 'wav2vec2.encoder.layers.9.layer_norm.bias', 'wav2vec2.encoder.layers.8.attention.k_proj.bias', 'wav2vec2.encoder.layers.11.attention.q_proj.weight', 'wav2vec2.encoder.layers.4.final_layer_norm.bias', 'wav2vec2.encoder.layers.6.layer_norm.weight', 'wav2vec2.encoder.layers.8.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.layer_norm.bias', 'wav2vec2.encoder.layers.9.attention.out_proj.weight', 'wav2vec2.encoder.layers.0.final_layer_norm.weight', 'wav2vec2.encoder.layers.5.layer_norm.bias', 'wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.9.feed_forward.output_dense.bias']
- This IS expected if you are initializing TFWav2Vec2ForSequenceClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFWav2Vec2ForSequenceClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFWav2Vec2ForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['tf_wav2_vec2_model_1.wav2vec2.masked_spec_embed', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.6.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.projection.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.projection.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'tf_wav2_vec2_model_1.wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'tf_wav2_vec2_model_1.wav2vec2.encoder.pos_conv_embed.conv.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.v_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.final_layer_norm.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.layer_norm.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.q_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.output_dense.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.out_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.k_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.final_layer_norm.bias', 'dense_2.weight', 'dense_2.bias', 'dense_3.weight', 'dense_3.bias', 'Variable']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

  1. It's a warning.
  2. I tried running on sample_inputs same as here
    >>> inputs_tf = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits
    ...
    >>> logits_tf = tf_model(**inputs_tf).logits
    >>> logits
    tensor([[-0.0732, -0.5845, -3.5185, -1.4014, -0.1823, -2.9616, -3.1919, -1.3804,
         -1.1895,  0.4006,  6.4601, -6.2880]])
    >>> logits_tf
    <tf.Tensor: shape=(1, 12), dtype=float32, numpy=
    array([[-1.310684  ,  0.13441604,  0.6363504 , -0.5188892 ,  0.46565807,
        -0.25152174, -0.45716044, -0.14784068,  0.176272  ,  1.4507922 ,
        -1.9966551 , -0.5963241 ]], dtype=float32)>
    >>> equal = torch.allclose(logits,torch.tensor(logits_tf.numpy()), rtol=1e-5)
    >>> equal
    False    
  3. The PyTorch model doesn't give any error/warning like that.

vimarshc commented 1 year ago

Ok. You have enough to go on here. The outputs are not equal because you're not using all the weights of the pretrained model.

  1. The warning states that some layers were initialized with the pretrained weights and some weren't.
  2. This usually happens when the model architectures don't match perfectly.
  3. If the model has N layers and only the first M match exactly, then only those first M will be loaded from the pretrained model.

So, print the dimensions of all the layers of both models and verify layer by layer if everything matches perfectly. cc: @nandwalritik
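That layer-by-layer check can be sketched as a plain-Python diff over {name: shape} dicts (built e.g. from the two models' parameter lists; the names and shapes here are illustrative). One thing to keep in mind: PyTorch Linear weights are stored as (out_features, in_features) while Keras Dense kernels are (in_features, out_features), so a transposed pair like the one below may just need a transpose during conversion rather than an architecture fix:

```python
def compare_layer_shapes(pt_shapes, tf_shapes):
    """Report layers whose shapes differ, or that exist in only one model.
    Inputs are {name: shape_tuple} dicts with names already aligned."""
    mismatches = {}
    for name in sorted(set(pt_shapes) | set(tf_shapes)):
        pt, tf = pt_shapes.get(name), tf_shapes.get(name)
        if pt != tf:
            mismatches[name] = (pt, tf)
    return mismatches

# Illustrative shapes only:
pt = {"projector.weight": (256, 768), "classifier.weight": (12, 256)}
tf = {"projector.weight": (768, 256), "classifier.weight": (12, 256)}
print(compare_layer_shapes(pt, tf))
# {'projector.weight': ((256, 768), (768, 256))}
```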

sanchit-gandhi commented 1 year ago

Thanks for helping out here @vimarshc! Your tips were spot on ✅ @nandwalritik has nearly finished the PR and achieved equality with the PyTorch model