Open sanchit-gandhi opened 1 year ago
This feature request is closely related to #21777! Once we have the TF Wav2Vec2 model for sequence classification added, we can copy across the projection and classification layers to Whisper in order to add TFWhisperForAudioClassification. Two birds with one stone ⚡️
Hi @sanchit-gandhi I would love to take this up.
Very cool @nandwalritik! The first thing to do would be to add the equivalent TensorFlow code for the projection layer and classification layer on top of the base TFWav2Vec2Model. Do you want to have a go at adding this in a new PR? Happy to help with any questions / guidance! There's a bit of info as to where the PyTorch code lives in the original post ^
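For orientation, the PyTorch head being ported combines a (possibly weighted) sum over the encoder layer outputs, a projection, mean pooling over time, and a linear classifier. Below is a framework-agnostic numpy sketch of that logic; the function name and shapes are illustrative, not the actual module API:

```python
import numpy as np

def sequence_classification_head(hidden_states, layer_weights, W_proj, b_proj, W_cls, b_cls):
    """Sketch of the Wav2Vec2 sequence-classification head logic.

    hidden_states: (num_layers, seq_len, hidden_size) encoder layer outputs
    layer_weights: (num_layers,) learned per-layer weights (weighted layer sum)
    """
    # softmax over layers, then a weighted sum of the layer outputs
    w = np.exp(layer_weights - layer_weights.max())
    w = w / w.sum()
    x = np.tensordot(w, hidden_states, axes=1)   # (seq_len, hidden_size)
    x = x @ W_proj + b_proj                      # project to classifier_proj_size
    pooled = x.mean(axis=0)                      # mean-pool over time
    return pooled @ W_cls + b_cls                # (num_labels,) logits

# toy shapes: 13 layers, 50 frames, hidden 768, projection 256, 12 labels
L, T, H, P, C = 13, 50, 768, 256, 12
rng = np.random.default_rng(0)
logits = sequence_classification_head(
    rng.standard_normal((L, T, H)).astype(np.float32),
    np.zeros(L, dtype=np.float32),
    rng.standard_normal((H, P)).astype(np.float32) * 0.02,
    np.zeros(P, dtype=np.float32),
    rng.standard_normal((P, C)).astype(np.float32) * 0.02,
    np.zeros(C, dtype=np.float32),
)
print(logits.shape)
```

In the TF port, the projection and classifier would each become a Dense layer, and the layer weights a trainable variable.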
Hi @sanchit-gandhi, I have added some initial changes in PR #22073, but while initializing it with PyTorch weights like this:
model_tf = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
it gives Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2ForSequenceClassification:
Can you guide me with this?
hidden_states and pooled_output in the PyTorch and TF implementations are both matching. Hi @sanchit-gandhi, can you guide me with the above error, so that I can make all the required changes and close the PR?
Hey, can you share the complete stack trace?
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2ForSequenceClassification:
The important part of the error is Some. Most likely the classification head is not being loaded correctly.
Questions:
- Is it a warning, or is it an error?
- Did you try running the model after this?
- Have you tried the same checkpoint with the PyTorch model to see if you get the same warning?
cc: @nandwalritik
Stack trace:
>>> tf_model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks",from_pt=True)
/home/nandwalritik/nandwalritik/transformers/src/transformers/configuration_utils.py:379: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.
warnings.warn(
TFWav2Vec2ForSequenceClassification has backpropagation operations that are NOT supported on CPU. If you wish to train/fine-tine this model, you need a GPU or a TPU
TFWav2Vec2Model has backpropagation operations that are NOT supported on CPU. If you wish to train/fine-tine this model, you need a GPU or a TPU
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2ForSequenceClassification: ['wav2vec2.encoder.layers.10.attention.q_proj.weight', 'wav2vec2.encoder.layers.1.attention.k_proj.bias', 'wav2vec2.encoder.layers.1.attention.q_proj.bias', 'wav2vec2.encoder.layers.0.attention.v_proj.bias', 'wav2vec2.encoder.layers.6.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.attention.v_proj.weight', 'wav2vec2.encoder.layers.1.attention.out_proj.bias', 'wav2vec2.encoder.layers.0.layer_norm.weight', 'wav2vec2.encoder.layers.3.layer_norm.weight', 'wav2vec2.encoder.layers.10.attention.out_proj.weight', 'wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.bias', 'wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'wav2vec2.encoder.layers.8.attention.out_proj.bias', 'wav2vec2.encoder.layers.9.layer_norm.weight', 'wav2vec2.encoder.layers.0.attention.k_proj.bias', 'wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.attention.v_proj.weight', 'wav2vec2.encoder.layers.5.attention.k_proj.weight', 'wav2vec2.encoder.layers.6.final_layer_norm.weight', 'wav2vec2.encoder.layers.9.feed_forward.output_dense.weight', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.layers.6.attention.q_proj.weight', 'wav2vec2.encoder.layers.4.attention.v_proj.bias', 'wav2vec2.encoder.layers.11.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.6.attention.q_proj.bias', 'wav2vec2.encoder.layers.0.attention.q_proj.bias', 'wav2vec2.encoder.layers.4.final_layer_norm.weight', 'wav2vec2.encoder.layers.5.attention.k_proj.bias', 'wav2vec2.encoder.layers.7.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.3.attention.k_proj.bias', 'wav2vec2.encoder.layers.8.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.attention.out_proj.weight', 
'wav2vec2.encoder.layers.7.attention.out_proj.bias', 'wav2vec2.encoder.layers.8.attention.q_proj.bias', 'wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'wav2vec2.encoder.layers.11.feed_forward.output_dense.weight', 'wav2vec2.encoder.pos_conv_embed.conv.bias', 'wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.final_layer_norm.weight', 'wav2vec2.encoder.layers.5.feed_forward.output_dense.bias', 'wav2vec2.feature_projection.projection.weight', 'wav2vec2.encoder.layers.5.attention.v_proj.weight', 'wav2vec2.encoder.layers.10.attention.out_proj.bias', 'wav2vec2.encoder.layers.4.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.0.attention.k_proj.weight', 'wav2vec2.encoder.layers.7.layer_norm.bias', 'wav2vec2.encoder.layers.1.attention.q_proj.weight', 'wav2vec2.encoder.layers.7.layer_norm.weight', 'wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'wav2vec2.encoder.layers.8.attention.v_proj.bias', 'projector.bias', 'wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.attention.q_proj.weight', 'wav2vec2.encoder.layers.8.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.10.attention.k_proj.bias', 'wav2vec2.encoder.layers.4.attention.out_proj.bias', 'wav2vec2.encoder.layers.6.final_layer_norm.bias', 'layer_weights', 'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.attention.k_proj.bias', 'wav2vec2.encoder.layers.7.attention.v_proj.weight', 'wav2vec2.encoder.layers.2.attention.out_proj.bias', 'wav2vec2.encoder.layers.4.attention.out_proj.weight', 'wav2vec2.encoder.layers.0.final_layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.q_proj.weight', 'wav2vec2.encoder.layers.3.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.feed_forward.output_dense.weight', 'wav2vec2.feature_projection.layer_norm.bias', 
'wav2vec2.encoder.layers.6.attention.k_proj.weight', 'wav2vec2.encoder.layers.7.attention.v_proj.bias', 'wav2vec2.encoder.layers.4.attention.k_proj.bias', 'wav2vec2.encoder.layers.4.layer_norm.weight', 'wav2vec2.encoder.layers.9.attention.q_proj.bias', 'wav2vec2.encoder.layers.4.attention.q_proj.bias', 'wav2vec2.encoder.layers.8.layer_norm.weight', 'wav2vec2.encoder.layers.2.final_layer_norm.weight', 'wav2vec2.feature_projection.projection.bias', 'wav2vec2.encoder.layers.3.final_layer_norm.bias', 'wav2vec2.encoder.layers.8.layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.k_proj.bias', 'wav2vec2.encoder.layers.5.layer_norm.weight', 'wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.6.attention.v_proj.bias', 'wav2vec2.encoder.layers.8.attention.v_proj.weight', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.1.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.5.attention.out_proj.bias', 'wav2vec2.encoder.layers.10.layer_norm.weight', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.9.attention.q_proj.weight', 'wav2vec2.encoder.layers.5.attention.v_proj.bias', 'wav2vec2.encoder.layers.6.attention.out_proj.weight', 'wav2vec2.encoder.layers.3.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.attention.q_proj.bias', 'wav2vec2.feature_projection.layer_norm.weight', 'wav2vec2.encoder.layers.1.layer_norm.bias', 'wav2vec2.feature_extractor.conv_layers.6.conv.weight', 'wav2vec2.encoder.layers.7.attention.q_proj.bias', 'wav2vec2.encoder.layers.9.attention.k_proj.bias', 'wav2vec2.encoder.layers.3.attention.q_proj.weight', 'wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.3.final_layer_norm.weight', 'wav2vec2.encoder.layers.2.attention.v_proj.weight', 'wav2vec2.encoder.layers.0.attention.out_proj.bias', 
'wav2vec2.encoder.layers.3.layer_norm.bias', 'wav2vec2.encoder.layers.6.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.0.attention.out_proj.weight', 'wav2vec2.encoder.layers.4.layer_norm.bias', 'wav2vec2.encoder.layers.5.attention.q_proj.bias', 'wav2vec2.encoder.layers.5.attention.q_proj.weight', 'wav2vec2.encoder.layers.9.final_layer_norm.bias', 'wav2vec2.encoder.layers.5.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.4.attention.q_proj.weight', 'wav2vec2.encoder.layers.2.attention.out_proj.weight', 'wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'wav2vec2.encoder.layers.5.final_layer_norm.weight', 'wav2vec2.encoder.layers.2.attention.q_proj.bias', 'wav2vec2.encoder.layer_norm.weight', 'wav2vec2.encoder.layers.3.attention.v_proj.bias', 'wav2vec2.encoder.layers.7.final_layer_norm.weight', 'wav2vec2.encoder.layers.6.attention.out_proj.bias', 'wav2vec2.encoder.layers.9.attention.k_proj.weight', 'wav2vec2.encoder.layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.out_proj.weight', 'wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.weight', 'classifier.weight', 'wav2vec2.encoder.layers.1.attention.v_proj.bias', 'wav2vec2.encoder.layers.1.attention.out_proj.weight', 'wav2vec2.encoder.layers.2.attention.q_proj.weight', 'wav2vec2.encoder.layers.11.attention.k_proj.weight', 'wav2vec2.encoder.layers.4.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.7.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.final_layer_norm.weight', 'wav2vec2.encoder.layers.11.attention.out_proj.weight', 'wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.layers.10.final_layer_norm.bias', 'projector.weight', 'wav2vec2.encoder.layers.0.attention.q_proj.weight', 'wav2vec2.encoder.layers.6.attention.v_proj.weight', 'wav2vec2.encoder.layers.11.attention.v_proj.bias', 
'wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'wav2vec2.encoder.layers.10.attention.k_proj.weight', 'wav2vec2.encoder.layers.10.feed_forward.output_dense.bias', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'wav2vec2.encoder.layers.2.attention.v_proj.bias', 'wav2vec2.encoder.layers.1.layer_norm.weight', 'wav2vec2.encoder.layers.7.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.1.final_layer_norm.weight', 'wav2vec2.encoder.layers.3.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.4.attention.k_proj.weight', 'wav2vec2.encoder.layers.0.layer_norm.bias', 'wav2vec2.encoder.layers.11.final_layer_norm.bias', 'wav2vec2.encoder.layers.9.attention.out_proj.bias', 'wav2vec2.encoder.layers.8.final_layer_norm.bias', 'wav2vec2.encoder.layers.10.final_layer_norm.weight', 'wav2vec2.encoder.layers.1.final_layer_norm.bias', 'wav2vec2.encoder.layers.1.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.attention.v_proj.bias', 'wav2vec2.encoder.layers.3.attention.out_proj.weight', 'wav2vec2.encoder.layers.3.attention.out_proj.bias', 'wav2vec2.encoder.layers.9.attention.v_proj.bias', 'wav2vec2.encoder.layers.4.attention.v_proj.weight', 'wav2vec2.encoder.layers.1.attention.v_proj.weight', 'wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.11.attention.out_proj.bias', 'wav2vec2.encoder.layers.5.final_layer_norm.bias', 'wav2vec2.encoder.layers.5.attention.out_proj.weight', 'wav2vec2.encoder.layers.10.attention.q_proj.bias', 'wav2vec2.encoder.layers.6.layer_norm.bias', 'wav2vec2.encoder.layers.7.final_layer_norm.bias', 'classifier.bias', 'wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.6.attention.k_proj.bias', 'wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.0.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.2.feed_forward.output_dense.weight', 
'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.2.attention.k_proj.weight', 'wav2vec2.encoder.layers.2.layer_norm.weight', 'wav2vec2.encoder.layers.3.attention.v_proj.weight', 'wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.0.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.layer_norm.bias', 'wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.9.attention.v_proj.weight', 'wav2vec2.encoder.layers.9.final_layer_norm.weight', 'wav2vec2.encoder.layers.11.layer_norm.weight', 'wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.1.attention.k_proj.weight', 'wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'wav2vec2.encoder.layers.2.layer_norm.bias', 'wav2vec2.encoder.layers.2.final_layer_norm.bias', 'wav2vec2.encoder.layers.2.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.3.attention.q_proj.bias', 'wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.bias', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'wav2vec2.encoder.layers.0.attention.v_proj.weight', 'wav2vec2.encoder.layers.2.attention.k_proj.bias', 'wav2vec2.encoder.layers.9.layer_norm.bias', 'wav2vec2.encoder.layers.8.attention.k_proj.bias', 'wav2vec2.encoder.layers.11.attention.q_proj.weight', 'wav2vec2.encoder.layers.4.final_layer_norm.bias', 'wav2vec2.encoder.layers.6.layer_norm.weight', 'wav2vec2.encoder.layers.8.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.layer_norm.bias', 'wav2vec2.encoder.layers.9.attention.out_proj.weight', 'wav2vec2.encoder.layers.0.final_layer_norm.weight', 'wav2vec2.encoder.layers.5.layer_norm.bias', 'wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.9.feed_forward.output_dense.bias']
- This IS expected if you are initializing TFWav2Vec2ForSequenceClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFWav2Vec2ForSequenceClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFWav2Vec2ForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['tf_wav2_vec2_model_1.wav2vec2.masked_spec_embed', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_extractor.conv_layers.6.conv.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.projection.weight', 'tf_wav2_vec2_model_1.wav2vec2.feature_projection.projection.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'tf_wav2_vec2_model_1.wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'tf_wav2_vec2_model_1.wav2vec2.encoder.pos_conv_embed.conv.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.v_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.0.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.1.final_layer_norm.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.2.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.layer_norm.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.3.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.4.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.q_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.5.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.feed_forward.output_dense.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.6.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.7.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.attention.out_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.8.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.9.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.k_proj.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.10.final_layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.k_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.k_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.q_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.q_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.v_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.v_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.out_proj.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.attention.out_proj.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.layer_norm.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.bias', 
'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.output_dense.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.feed_forward.output_dense.bias', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.final_layer_norm.weight', 'tf_wav2_vec2_model_1.wav2vec2.encoder.layers.11.final_layer_norm.bias', 'dense_2.weight', 'dense_2.bias', 'dense_3.weight', 'dense_3.bias', 'Variable']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
>>> inputs_tf = feature_extractor(dataset[0]["audio"]["array"],sampling_rate=sampling_rate,return_tensors="tf")
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
...
>>> logits_tf = tf_model(**inputs_tf).logits
>>> logits
tensor([[-0.0732, -0.5845, -3.5185, -1.4014, -0.1823, -2.9616, -3.1919, -1.3804,
-1.1895, 0.4006, 6.4601, -6.2880]])
>>> logits_tf
<tf.Tensor: shape=(1, 12), dtype=float32, numpy=
array([[-1.310684 , 0.13441604, 0.6363504 , -0.5188892 , 0.46565807,
-0.25152174, -0.45716044, -0.14784068, 0.176272 , 1.4507922 ,
-1.9966551 , -0.5963241 ]], dtype=float32)>
>>> equal = torch.allclose(logits,torch.tensor(logits_tf.numpy()), rtol=1e-5)
>>> equal
False
Ok, you have enough to go on here. The output is not equal because you're not using all the weights in the pretrained model.
So, print the dimensions of all the layers of both models and verify layer by layer that everything matches. cc: @nandwalritik
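A minimal way to script that layer-by-layer comparison is to diff the two sets of parameter names and shapes. This is a sketch with toy dicts; in practice the PT side would come from `pt_model.state_dict()` and the TF side from the model's weights, and the exact PT↔TF name mapping here is an assumption:

```python
import numpy as np

def compare_state_dicts(pt_weights, tf_weights):
    """Report which parameter names and shapes line up between two weight dicts."""
    only_pt = sorted(set(pt_weights) - set(tf_weights))
    only_tf = sorted(set(tf_weights) - set(pt_weights))
    shape_mismatch = sorted(
        k for k in set(pt_weights) & set(tf_weights)
        if pt_weights[k].shape != tf_weights[k].shape
    )
    print("only in PT:", only_pt)
    print("only in TF:", only_tf)
    print("shape mismatch:", shape_mismatch)
    return only_pt, only_tf, shape_mismatch

# toy example: the classifier head exists in PT but was never built in TF
pt = {"projector.weight": np.zeros((256, 768)), "classifier.weight": np.zeros((12, 256))}
tf_ = {"projector.weight": np.zeros((256, 768))}
only_pt, only_tf, bad = compare_state_dicts(pt, tf_)
```

Anything appearing in "only in PT" is a weight the TF model silently skipped, which matches the warning in the stack trace above.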
Thanks for helping out here @vimarshc! Your tips were spot on ✅ @nandwalritik has the PR nearly finished, with equality against the PyTorch model
Feature request
Wav2Vec2 is one of the most popular speech recognition models, used over 2 million times monthly. In the PyTorch modelling code, we have Wav2Vec2 for speech recognition and Wav2Vec2 for audio classification. However, in TensorFlow, we only have Wav2Vec2 for speech recognition. It would be great to add Wav2Vec2 for audio classification to the TensorFlow modelling code for cross-framework equivalence!
Motivation
The audio classification class for PyTorch Wav2Vec2 lives under Wav2Vec2ForSequenceClassification: https://github.com/huggingface/transformers/blob/13489248fa8f2cda7503628204f8f43b108797a2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1745
For this feature request, we'll need to port this PyTorch code into TensorFlow to create an equivalent TensorFlow class, TFWav2Vec2ForSequenceClassification. This means adding a projection layer and classification layer on top of the base TFWav2Vec2Model. See the PyTorch code for reference: https://github.com/huggingface/transformers/blob/13489248fa8f2cda7503628204f8f43b108797a2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1753-L1758
To check that our implementation is correct, we can do one forward pass of the PyTorch model and one forward pass of the TensorFlow model with the same inputs. If the output logits agree to within 1e-5, we know that our TensorFlow model is correct ✅. We can then enable the PT-TF cross-tests in the modelling file so that these checks are performed by the CI.
Your contribution
Opening this one up to the community! If you're interested in tackling this, feel free to drop a comment in this thread and open a PR when you're ready. More than happy to answer any questions / queries about this integration!