b04901014 / FT-w2v2-ser

Official implementation for the paper Exploring Wav2vec 2.0 fine-tuning for improved speech emotion recognition
MIT License

Could you share the pre-trained wav2vec 2.0 models using TAPT and P-TAPT? #4

Closed youcaiSUN closed 2 years ago

youcaiSUN commented 2 years ago

Thanks very much!

b04901014 commented 2 years ago

What specific pre-trained model are you looking for? Is it one already fine-tuned on the emotion classification objective, or one trained with only the pre-training objective (say, from running the further pre-training on IEMOCAP)? Also, I currently have models trained on specific splits, for instance trained on Sessions 1~4 and tested on Session 5, but I don't have a model trained on the whole of IEMOCAP. What is your use case?

youcaiSUN commented 2 years ago

Sorry for the misunderstanding, just the one with only the pre-training objective. I pre-trained one model myself using TAPT, but the average UAR is only ~71 (a ~2% gap). So I want to use your pre-trained models in my downstream tasks. Besides, do I need to pre-train the model for each session and then fine-tune it on that session, or just pre-train once and fine-tune it on all sessions (since the pre-training is time-consuming)?

b04901014 commented 2 years ago

For the purpose of the research experiment, I have to run it for each session respectively to make sure I don't have access to the test set even during the pre-training phase. So the numbers you saw are from running TAPT for each split respectively instead of using a single model.

If you have your own dataset, it would be better to just run the continued-pre-training procedure (TAPT) on your own dataset instead of IEMOCAP. Or simply run TAPT with all the data, including IEMOCAP...

The checkpoints I currently have in hand are the pre-trained models of each session after TAPT. Are you requesting all of them (5 checkpoints)?

youcaiSUN commented 2 years ago

Great, I see!

Yeah, I'd like all of them. Could you please upload all your pre-trained models (using TAPT and P-TAPT) to an online drive (maybe Google Drive or Baidu) so that the community can use them in their own research?

Thanks very much!

b04901014 commented 2 years ago

Here you go. https://cmu.box.com/s/yitaa74udy23vqubm9ox2pyf2kb3sg5n

But I didn't double-check the checkpoints for you (the experiments were done a while ago).

youcaiSUN commented 2 years ago

OK! I'll try it.

sundekai commented 2 years ago

thanks!!!

youcaiSUN commented 2 years ago

> Here you go. https://cmu.box.com/s/yitaa74udy23vqubm9ox2pyf2kb3sg5n
>
> But I didn't double-check the checkpoints for you (the experiments were done a while ago).

Dear Liwei @b04901014,

I fine-tuned your pre-trained TAPT and P-TAPT models on IEMOCAP. Since my GPU has only 32 GB of memory, which cannot fit a batch size of 64, I slightly modified the trainer setting (i.e., I accumulate gradients every 2 batches with a batch size of 32) in _run_downstream_custom_multiple_fold.py_, as shown below:

trainer = Trainer(
    precision=args.precision,
    amp_backend='native',
    callbacks=[checkpoint_callback] if hasattr(model, 'valid_met') else None,
    checkpoint_callback=hasattr(model, 'valid_met'),
    resume_from_checkpoint=None,
    check_val_every_n_epoch=1,
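    # keep the effective batch size at 64 (e.g., batch_size=32 -> accumulate every 2 batches)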
    accumulate_grad_batches=64//args.batch_size,
    max_epochs=hparams.max_epochs,
    num_sanity_val_steps=2 if hasattr(model, 'valid_met') else 0,
    gpus=1,
    logger=False
)

Besides, I modified the checkpoint loading part in downstream/Custom/trainer.py to make sure that for every fold the model is initialized from the corresponding checkpoint.

if self.hp.pretrained_path is not None:
    import glob
    ckpt_files = sorted(glob.glob(os.path.join(self.hp.pretrained_path, '*.ckpt')))
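    # Assumes the folder holds one checkpoint per fold (e.g., session1.ckpt ... session5.ckpt),
    # so sorted() maps fold i to the checkpoint pre-trained without seeing that fold's test session.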
    if len(ckpt_files) == self.hp.nfolds:
        pretrained_path = ckpt_files[self.hp.ifold]
    else:
        pretrained_path = self.hp.pretrained_path
        assert pretrained_path.endswith('.ckpt')
    print(f"==> Fold {self.hp.ifold +1 } use pre-trained model '{pretrained_path}' ")
    self.model = PretrainedRNNHead.load_from_checkpoint(pretrained_path, strict=False,
                                                        n_classes=self.dataset.nemos,
                                                        backend=self.hp.model_type)
else:
    self.model = PretrainedRNNHead(n_classes=self.dataset.nemos,
                                   backend=self.hp.model_type)

The average results (five experiments for each session) are as follows:

Method Stat Func UAR WAR macroF1 microF1
PTAPT mean 71.6207 70.5469 70.9655 72.0825
PTAPT std 0.8163 0.9130 0.9821 0.6793
TAPT mean 71.4809 70.0057 70.3757 71.8254
TAPT std 0.5025 0.4757 0.3672 0.1165
TAPT* mean 71.4191 70.1220 70.4931 71.7437
TAPT* std 1.1267 1.1062 1.2942 0.8730

*: using my own pre-trained model under the default TAPT setting in README.md, except that training_step=16000 with a batch size of 32 (about 1000 epochs on the IEMOCAP Session 1 train set); in the fine-tuning stage it is used for all sessions.

From the above table, the three kinds of models have similar performance. Since the results show a gap compared to those reported in the paper, could you help me figure out what's wrong with my fine-tuning process?

Best, Licai

b04901014 commented 2 years ago

Hmm. I'm a little bit confused. Are you running n folds with a different session as the test set each time? Or is the performance you are reporting on a specific split (say, all tested on Session 1)?

I mean, if you are reporting only the Session 1 performance (running Session 1 five times), you should be using only the Session 1 checkpoint. If you do one run over the 5 splits (testing on Sessions 1~5), the std should be much larger, since the performance on different splits varies a lot.

youcaiSUN commented 2 years ago

Sorry, the results I reported are the mean/std over the five sessions (and I ran each session five times). I use the following command (using TAPT as an example) for fine-tuning:

CUDA_VISIBLE_DEVICES=0 python run_downstream_custom_multiple_fold.py --precision 16 \
                                                                     --batch_size 32 \
                                                                     --num_exps 5 \
                                                                     --saving_path ./saved/downstream/TAPT/author \
                                                                     --pretrained_path ./saved/pretrain/TAPT/author
Besides, the std in the above table is calculated as "np.std(np.mean(metrics, 2), 1)" instead of the default fold std "np.mean(np.std(metric, 1))". The results for the fold std are as below (much larger, as you said):

Method Stat Func UAR WAR macroF1 microF1
PTAPT mean 71.6207 70.5469 70.9655 72.0825
PTAPT std 0.8163 0.9130 0.9821 0.6793
PTAPT fold std 3.0007 2.2742 2.1910 2.2853
TAPT mean 71.4809 70.0057 70.3757 71.8254
TAPT std 0.5025 0.4757 0.3672 0.1165
TAPT fold std 3.9209 3.7091 3.7024 3.2703
TAPT* mean 71.4191 70.1220 70.4931 71.7437
TAPT* std 1.1267 1.1062 1.2942 0.8730
TAPT* fold std 3.6798 3.2672 3.2352 2.9478
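
(For clarity, here is a minimal sketch of the two aggregations with made-up numbers; it assumes a single metric stored as an array of shape (n_exps, n_folds) = (5, 5), which is a simplification of the actual metrics array.)

import numpy as np

# Made-up UAR values: 5 experiments x 5 test sessions (folds).
uar = np.random.default_rng(0).uniform(65, 80, size=(5, 5))

# "std": average over the 5 sessions within each experiment, then std across experiments.
std_across_exps = uar.mean(axis=1).std()

# "fold std": std across the 5 sessions within each experiment, averaged over experiments.
fold_std = uar.std(axis=1).mean()

print(std_across_exps, fold_std)  # the fold std is much larger
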
b04901014 commented 2 years ago

Could you check whether the five label files in your labeldir correspond to the 5 splits? It looks like you modified the code to hardcode the labeldir; originally it is passed as a command-line argument.

Also, if you use accumulate_grad_batches to accumulate gradients every two steps, then you should double max_epochs, otherwise you are using only half the training steps. BTW, you can save memory if you set https://github.com/b04901014/FT-w2v2-ser/blob/main/modules/FeatureFuser.py#L103 to True.

I don't think the performance should differ this much though.

youcaiSUN commented 2 years ago

The std is the std across the five experiments of the mean over the five sessions.

The five label files in labeldir are generated using your preprocessing code, which can be downloaded from labels_sess.zip.

I checked the label files. The test split for each fold only includes the samples from the corresponding session.

The original argument for pretrained_path in downstream fine-tuning is the path to a specific checkpoint file. As I said above, to make sure that each session uses the corresponding pre-trained checkpoint, which does not 'see' the test set of that session, I modified the checkpoint loading part.

b04901014 commented 2 years ago

Hmm. Then could you set the gradient_checkpointing flag in https://github.com/b04901014/FT-w2v2-ser/blob/main/modules/FeatureFuser.py#L103 to True and re-run with batch_size=64 without accumulating gradients? With gradient accumulation every 2 steps, you are essentially running only 7.5 epochs instead of 15.
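
(For reference, enabling gradient checkpointing with the Hugging Face API generally looks like the sketch below; in this repo the flag is set inside modules/FeatureFuser.py, so the snippet is only an illustration, not the repo's code.)

from transformers import Wav2Vec2Model

# Illustration only -- in this repo the flag lives in modules/FeatureFuser.py.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")

# Newer transformers versions expose a helper; older ones read a config flag instead.
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()
else:
    model.config.gradient_checkpointing = True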

Otherwise I cannot currently find anything else odd about your experiment.

youcaiSUN commented 2 years ago

> Could you check whether the five label files in your labeldir correspond to the 5 splits? It looks like you modified the code to hardcode the labeldir; originally it is passed as a command-line argument.
>
> Also, if you use accumulate_grad_batches to accumulate gradients every two steps, then you should double max_epochs, otherwise you are using only half the training steps. BTW, you can save memory if you set https://github.com/b04901014/FT-w2v2-ser/blob/main/modules/FeatureFuser.py#L103 to True.
>
> I don't think the performance should differ this much though.

Thanks very much!

I believe the number of training steps using a batch size of 32 with accumulate_grad_batches=2 is equal to that using a batch size of 64. Nevertheless, I will rerun the code with a batch size of 64 using distributed parallel training to examine the influence of different batch sizes.
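
(A back-of-the-envelope check with a made-up dataset size: both settings take the same number of optimizer steps per epoch, and each step effectively covers 64 samples.)

# Hypothetical number of training utterances; the exact value doesn't matter.
n_samples, max_epochs = 4000, 15

steps_bs64 = (n_samples // 64) * max_epochs            # batch_size=64, no accumulation
steps_bs32_acc2 = (n_samples // 32 // 2) * max_epochs  # batch_size=32, accumulate_grad_batches=2

print(steps_bs64, steps_bs32_acc2)  # same number of optimizer steps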

b04901014 commented 2 years ago

NVM, I think you are right; they should be equivalent. Then I actually don't know where the issue could be. Do you mind uploading your whole code as a zip file without the checkpoints? Maybe I could look into it.

youcaiSUN commented 2 years ago

My pleasure! The code is here FT-w2v2-ser.zip.

I use the shell scripts in the scripts folder to pre-train and fine-tune models on IEMOCAP.

The results are saved in the saved folder (detailed fine-tuning results are also included in the subfolders).

Any questions are welcome.

b04901014 commented 2 years ago

I have seen several numbers below 76% for Session 2 (Fold 2) in individual runs on TAPT and P-TAPT, which I almost never saw throughout my experiments. There is definitely something wrong; I will look into it.

youcaiSUN commented 2 years ago

Thanks very much! I'm looking forward to your message.

b04901014 commented 2 years ago

I just ran your code on Session 2 only (one fold); here is what I got:

image

There is one bad 76 run there, but most of them are around 78~81%. Compare that to what you got in your run:

image

I think the difference is so huge that they are unlikely to be drawn from the same distribution. The only thing I changed in the script is the paths:

CUDA_VISIBLE_DEVICES=3 python run_downstream_custom_multiple_fold.py --precision 16 \
                                                                      --batch_size 32 \
                                                                      --num_exps 5 \
                                                                      --saving_path /tmp \
                                                                      --pretrained_path /usr1/liweiche/exps_new/IEMOCAP/checkpoints/session2.ckpt \
                                                                      --datadir ../FT-w2v2-ser/Dataset/IEMOCAP/Audio_16k/ \
                                                                      --labeldir /usr1/liweiche/exps_new/IEMOCAP/session2/labels/

And on the "/share/home/sunlicai/transformers/wav2vec2-base" to "facebook/wav2vec2-base". I'm not sure how you get your local pretrained model there. But unless the config.json in it is the same as https://huggingface.co/facebook/wav2vec2-base/blob/main/config.json . (The dropout ratio, and some other hyper-parameters depends on that) It should be fine.

I did not change anything else, including the gradient accumulation and the seed. If the pre-trained model config is not a problem, then the only things I can think of are the environment (are you using the same package versions as in the Dockerfile, especially the huggingface version?) or possibly a bug during the transition between folds...

youcaiSUN commented 2 years ago

Many thanks for your helpful debugging.

I re-ran the code with batch_size=64 by following your suggestion to set gradient_checkpointing to True.

The results of PTAPT are as follows:

Exp Fold UAR WAR macroF1 microF1
1 1 69.8167 64.3318 64.7215 69.2565
1 2 77.6739 72.9228 74.4119 76.8706
1 3 69.3841 68.5491 69.1264 69.2956
1 4 73.4489 73.9088 73.5648 73.8406
1 5 75.2707 73.1668 73.5307 73.9309
2 1 73.7616 70.4147 71.3115 72.6207
2 2 80.7606 78.0059 78.9644 79.5461
2 3 69.4304 69.9392 70.2692 71.0953
2 4 69.5929 69.7381 69.5926 71.0474
2 5 73.5148 71.8775 72.7360 73.4933
3 1 72.2615 69.3088 70.2676 71.2854
3 2 74.9120 73.0205 74.0394 75.0084
3 3 70.9914 70.4605 70.7650 70.9058
3 4 73.8345 75.3637 74.4753 74.9603
3 5 72.3662 71.9581 72.1026 72.1454
4 1 70.7526 66.4516 67.6090 70.4535
4 2 75.6213 72.8250 74.1808 76.2389
4 3 67.9802 67.8540 68.6463 69.5036
4 4 72.2317 71.2900 71.3094 72.1324
4 5 72.0067 70.5077 70.3071 71.1857
5 1 71.8092 70.2304 71.0657 71.4374
5 2 79.5007 76.8328 77.2310 77.9573
5 3 71.6541 71.5899 71.9562 72.1773
5 4 74.7210 75.1697 75.1347 75.2602
5 5 73.9711 73.3280 74.1058 74.2059
Stat Func mean 73.0907 71.5618 72.0570 73.0342
Stat Func std 0.8454 1.2663 1.2180 0.7900
Stat Func fold_std 2.7925 2.7066 2.6447 2.5047

The results of TAPT are as follows:

Exp Fold UAR WAR macroF1 microF1
1 1 72.4140 69.2166 69.9198 70.8831
1 2 78.8307 77.1261 77.6667 78.5128
1 3 67.4721 67.0721 67.1881 67.2617
1 4 76.4685 76.3337 76.1041 76.2746
1 5 73.2559 71.3940 71.6633 72.2239
2 1 70.3762 64.3318 65.6699 70.2334
2 2 77.3431 75.8553 77.0228 77.0488
2 3 65.5209 64.5526 64.9570 65.2109
2 4 71.1131 70.9990 71.3480 72.9595
2 5 74.7123 72.2804 73.2582 74.5627
3 1 68.2955 62.0277 62.8583 68.3240
3 2 76.2815 74.6823 75.9064 76.6184
3 3 64.3476 64.3788 64.0973 65.9882
3 4 73.1064 71.8720 71.5928 72.7076
3 5 68.1130 69.1378 69.4952 70.0967
4 1 72.1599 68.6636 69.4446 70.8491
4 2 74.9441 73.4115 74.3779 76.2800
4 3 68.5871 68.2016 68.4900 68.6162
4 4 76.4899 76.6246 76.5327 76.5754
4 5 72.2048 69.7824 70.3533 71.4640
5 1 70.0527 64.8848 65.8348 69.6342
5 2 76.0400 71.0655 71.6846 74.8651
5 3 67.6256 67.6803 67.8072 67.9602
5 4 72.0541 73.9088 73.3376 74.2588
5 5 73.3085 72.2804 72.3810 72.6330
Stat Func mean 72.0447 70.3105 70.7597 72.0817
Stat Func std 1.2300 1.3362 1.3040 0.7987
Stat Func fold_std 3.5281 3.9150 3.8503 3.5022

The results of PTAPT and TAPT are better than the previous ones, especially PTAPT. It seems that using smaller batches with gradient accumulation here is not fully equivalent to using larger batches without it. However, there is still a ~1.5% performance gap.

I manually installed most packages (except fairseq, which is not used currently) with the versions specified in your dockerfile. You can check the env.txt.

My local pre-trained wav2vec 2.0 base model was downloaded from the official website about seven months ago. After checking my local config.json, I found that it is not the same as the latest one (updated six months ago). One noticeable difference is the architectures field (old 'Wav2Vec2Model' vs. new 'Wav2Vec2ForPreTraining'). I also noticed the warning ('Found keys that are not in the model state dict but in the checkpoint') when loading the pre-trained checkpoint:

==> Fold 1 use pre-trained model './saved/pretrain/TAPT/author/session1.ckpt' 
Some weights of Wav2Vec2ForPreTraining were not initialized from the model checkpoint at /share/home/sunlicai/transformers/wav2vec2-base and are newly initialized: ['project_hid.bias', 'project_q.weight', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_hid.weight', 'project_q.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/share/home/sunlicai/anaconda3/envs/pytorch1.9/lib/python3.8/site-packages/pytorch_lightning/core/saving.py:205: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['rnn_head.weight_ih_l0', 'rnn_head.weight_hh_l0', 'rnn_head.bias_ih_l0', 'rnn_head.bias_hh_l0', 'rnn_head.weight_ih_l0_reverse', 'rnn_head.weight_hh_l0_reverse', 'rnn_head.bias_ih_l0_reverse', 'rnn_head.bias_hh_l0_reverse', 'linear_head.1.weight', 'linear_head.1.bias']
  rank_zero_warn(
/share/home/sunlicai/anaconda3/envs/pytorch1.9/lib/python3.8/site-packages/pytorch_lightning/core/saving.py:209: UserWarning: Found keys that are not in the model state dict but in the checkpoint: ['wav2vec2.wav2vec2PT.wav2vec2.masked_spec_embed', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_extractor.conv_layers.6.conv.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_projection.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_projection.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.feature_projection.projection.weight', 'wav2vec2.wav2vec2PT.wav2vec2.feature_projection.projection.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.pos_conv_embed.conv.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.0.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 
'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.1.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.2.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.3.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.layer_norm.bias', 
'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.4.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.5.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.6.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.layer_norm.weight', 
'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.7.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.8.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.9.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.attention.out_proj.bias', 
'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.10.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.k_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.k_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.v_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.v_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.q_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.q_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.out_proj.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.attention.out_proj.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.layer_norm.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.feed_forward.output_dense.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.feed_forward.output_dense.bias', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.final_layer_norm.weight', 'wav2vec2.wav2vec2PT.wav2vec2.encoder.layers.11.final_layer_norm.bias', 'wav2vec2.wav2vec2PT.quantizer.codevectors', 'wav2vec2.wav2vec2PT.quantizer.weight_proj.weight', 'wav2vec2.wav2vec2PT.quantizer.weight_proj.bias', 'wav2vec2.wav2vec2PT.project_q.weight', 'wav2vec2.wav2vec2PT.project_q.bias', 'wav2vec2.wav2vec2PT.project_hid.weight', 'wav2vec2.wav2vec2PT.project_hid.bias']
  rank_zero_warn(
Weigh losses by prior distribution of each class: tensor([0.2906, 0.2043, 0.2069, 0.2983]).
Using native 16bit precision.

Maybe that's the culprit. I'll fix the config issue by using the default facebook/wav2vec2-base instead of my local one and re-run the code. Hope it helps!
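
(Side note, in case it helps others: a quick way to see which sub-modules a checkpoint actually contains before loading it with strict=False. A minimal sketch; the path is the Session 1 TAPT checkpoint from the log above.)

import torch
from collections import Counter

ckpt = torch.load("./saved/pretrain/TAPT/author/session1.ckpt", map_location="cpu")
state = ckpt.get("state_dict", ckpt)

# Group parameter names by their leading prefixes to see which sub-modules are present
# (e.g., the quantizer/projection heads that only exist in the pre-training model).
prefixes = Counter(".".join(key.split(".")[:3]) for key in state)
for prefix, count in prefixes.most_common():
    print(f"{prefix}: {count} tensors")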

b04901014 commented 2 years ago

Glad that you are getting better results on reproducing, though I'm not sure why there is a difference between batch_size=64 and the gradient accumulation setup you used earlier... On my machine they gave about the same results.

youcaiSUN commented 2 years ago

Really weird!

image

> Glad that you are getting better results on reproducing, though I'm not sure why there is a difference between batch_size=64 and the gradient accumulation setup you used earlier... On my machine they gave about the same results.

youcaiSUN commented 2 years ago

I re-ran the code using facebook/wav2vec2-base with batch_size=64. However, the results are even worse than those obtained with my local config. I am totally lost!!!

Method Stat Func UAR WAR macroF1 microF1
PTAPT mean 72.2857 70.5583 71.0144 72.3176
PTAPT std 0.6220 0.4712 0.4265 0.4557
PTAPT fold_std 2.6346 2.2990 2.3078 2.1766
TAPT mean 71.6164 69.7006 70.0368 71.5589
TAPT std 0.8966 0.7048 0.7180 0.3754
TAPT fold_std 4.2064 3.7524 3.9618 3.4980

b04901014 commented 2 years ago

Hmm. It's really hard for me to debug this... I think you fixed your seed, and I don't think there is a real difference between these three models. One could claim there is variance merely across the average of 5 runs, but I don't think it would be this large. And that could not explain why the steady gap between PTAPT and TAPT still exists in every run, nor why I did not observe such high variance in my local runs; all of this seems to indicate that some factor did change across your three experiments. Additionally, since I didn't fix the seed, I would say getting above 73% may be reasonable enough for PTAPT.

Thank you for bringing this up, but I'm afraid I can't help more on this right now. For one thing, I don't really know what's happening there; for another, I've been occupied with another research project recently.

b04901014 commented 2 years ago

This issue is getting too lengthy, so I'll close it. Feel free to open another one if you have questions.