SungFeng-Huang / SSL-pretraining-separation

Official repository of our paper: https://arxiv.org/abs/2010.15366

--load_path parameter not being used: How exactly fine-tuning is carried out? #3

Open jvel07 opened 1 year ago

jvel07 commented 1 year ago

Hey there, @SungFeng-Huang! Hoping you are all good. :)

We were trying your repo for fine-tuning ConvTasNet, which can be achieved by setting --strategy to 'pretrained'. Looking at train_general.py, there is the following snippet of code:

```python
if known_args.strategy == "pretrained":
    parser.add_argument("--load_path", default=None, required=True,
                        help="Checkpoint path to load for fine-tuning.")
```

  1. Here, --load_path is not used anywhere else in the code. Hence, the question is: how exactly is fine-tuning carried out? Ultimately, we couldn't tell from the train_general.py code whether there is any difference between from_scratch and pretrained.

  2. We also tried to find out whether you freeze some layers of the pre-trained ConvTasNet model before fine-tuning, but there is no such thing. Based on the paper, it seems that the trained ConvTasNet model is loaded and then *all* of these weights are learned again during fine-tuning. Is this correct?

Hope you can clarify our doubts, please! Thanks for your attention.

SungFeng-Huang commented 1 year ago

Hi,

  1. The --load_path option is one of the keyword arguments of the pytorch_lightning.Trainer class for resuming training. In some newer versions of pytorch_lightning, it is instead a keyword argument of the Trainer.fit() function. Since this repository is pretty old, you should check the PyTorch Lightning documentation for the package version you've installed and adjust the code accordingly.
  2. Yes, all weights are fine-tuned, no parameters are frozen.
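The version-dependent wiring described above can be sketched as follows. This is only a sketch: the checkpoint path is a made-up placeholder, and the exact Lightning keyword (shown in the comments) depends on the installed pytorch_lightning version, so the Trainer calls are left commented out:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--strategy", default="from_scratch")

# First pass: only look at --strategy, ignore anything else on the CLI.
known_args, _ = parser.parse_known_args(["--strategy", "pretrained"])

if known_args.strategy == "pretrained":
    parser.add_argument("--load_path", default=None, required=True,
                        help="Checkpoint path to load for fine-tuning.")

# "exp/pretrained.ckpt" is a hypothetical path, just for illustration.
args = parser.parse_args(["--strategy", "pretrained",
                          "--load_path", "exp/pretrained.ckpt"])

# Hand the checkpoint to Lightning; the keyword depends on the installed
# pytorch_lightning version:
#   older PL (< 2.0): Trainer(resume_from_checkpoint=args.load_path)
#   newer PL:         trainer.fit(model, datamodule, ckpt_path=args.load_path)
```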

Hope these answers help! Feel free to contact me if you have more questions.

jvel07 commented 1 year ago

Thanks for the reply, @SungFeng-Huang!

So, in this case for the repo, fine-tuning would mean just resuming the training (of a model that was trained on train-100) now with train-360?

Are there any other configurations/parameters that were changed for fine-tuning besides the train set? (We can't seem to find these details in the paper).

SungFeng-Huang commented 1 year ago

Ummm, this work was done a pretty long time ago, so I can't fully remember the details, but I think "fine-tuning" means resuming (continuing) training on the separation task from an SSL-trained model checkpoint. For example, we can "fine-tune" by training the separation task on Libri2Mix with parameters initialized from a "pre-trained" model that was trained on a speech denoising task on Libri1Mix. In this case, comparing the "pre-training" stage and the "fine-tuning" stage, only the dataset and the loss function change (Libri1Mix -> Libri2Mix / denoising loss -> separation loss); everything else is unchanged.
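The two-stage hand-off described above can be illustrated with a minimal stand-in sketch (plain dicts instead of torch state_dicts, and a dummy training step; none of these names come from the repo). The point is that every parameter is copied from the pre-trained checkpoint and every parameter stays trainable, with only the dataset and loss changing between stages:

```python
def init_model():
    # Toy "parameters": every value would be a tensor in the real model.
    return {"encoder.weight": 0.0, "masker.weight": 0.0, "decoder.weight": 0.0}

def train(params, dataset, loss_name):
    # Stand-in training step: record which dataset/loss touched each weight.
    return {k: f"{v}+{dataset}/{loss_name}" for k, v in params.items()}

# Stage 1: "pre-train" a denoising model on Libri1Mix.
pretrained = train(init_model(), "Libri1Mix", "denoising")

# Stage 2: "fine-tune" separation on Libri2Mix, initialized from the
# pre-trained checkpoint. ALL keys are copied over and NONE are frozen.
finetuned = train(dict(pretrained), "Libri2Mix", "separation")
```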

jvel07 commented 1 year ago

@SungFeng-Huang, yes. The fact is that we don't see the loss being changed between pre-training and fine-tuning. In both cases, the loss is PITLossWrapper. It only changes to MultiTaskLossWrapper when choosing "multitask", see below: https://github.com/SungFeng-Huang/SSL-pretraining-separation/blob/d7ec4cf6a99f33f38f50b09619b838f51ac456da/train_general.py#L100

What are the corresponding loss functions to both pre-training and fine-tuning?

SungFeng-Huang commented 1 year ago

I finally recovered part of my memory. For the pre-training part, specifically with the speech enhancement task, we used a checkpoint found on Hugging Face, something like this (I'm not sure whether this is the exact one used at that time). Since both the uploaded checkpoint and my scripts are modified from the asteroid example code, I'm pretty sure the settings are basically the same, such as sampling rate, STFT configs, learning rate, and other hyperparameters.

As for fine-tuning, the loss function is the separation loss, which is an SI-SDR loss wrapped in a PITLossWrapper: https://github.com/SungFeng-Huang/SSL-pretraining-separation/blob/d7ec4cf6a99f33f38f50b09619b838f51ac456da/train_general.py#L100-L101

Now back to the pre-training loss. Since speech enhancement (also called denoising) achieves the best result, let's first talk about the enhancement case. The loss can be found in MultiTaskLossWrapper, which decides between the enhancement loss and the separation loss depending on the target's shape; here's the enhancement part of the loss: https://github.com/SungFeng-Huang/SSL-pretraining-separation/blob/d7ec4cf6a99f33f38f50b09619b838f51ac456da/src/losses/multi_task_wrapper.py#L12-L15 So you could directly use MultiTaskLossWrapper if you want to train an enhancement model for pre-training. If you want a "pure" enhancement loss, the asteroid repo might be a better place to find it; you can find example code for training a speech enhancement model there.
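The shape-based dispatch plus the PIT-wrapped SI-SDR can be sketched without any dependencies like this (plain Python lists instead of tensors; the function names are mine, not asteroid's, so treat this as an illustration of the idea rather than the repo's actual implementation):

```python
import math
from itertools import permutations

def si_sdr(est, ref):
    """Scale-invariant SDR (in dB) for two equal-length lists of floats."""
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref)
    scaled = [dot / ref_energy * r for r in ref]      # projection onto ref
    noise = [e - t for e, t in zip(est, scaled)]
    return 10 * math.log10(
        sum(t * t for t in scaled) / sum(n * n for n in noise)
    )

def pit_si_sdr_loss(est_sources, ref_sources):
    """Permutation-invariant negative SI-SDR, averaged over sources."""
    best = -math.inf
    for perm in permutations(range(len(ref_sources))):
        score = sum(si_sdr(est_sources[i], ref_sources[p])
                    for i, p in enumerate(perm)) / len(ref_sources)
        best = max(best, score)
    return -best

def multi_task_loss(est_sources, ref_sources):
    """Dispatch on the number of target sources, in the spirit of
    MultiTaskLossWrapper: one target -> enhancement (plain SI-SDR),
    several targets -> separation (PIT over permutations)."""
    if len(ref_sources) == 1:
        return -si_sdr(est_sources[0], ref_sources[0])
    return pit_si_sdr_loss(est_sources, ref_sources)
```

Because the PIT wrapper takes the best permutation, swapping the order of the estimated sources leaves the separation loss unchanged, which is exactly the property that makes it usable when the model has no canonical speaker ordering.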

As for the other self-supervised losses, they are basically borrowed from the code of the s3prl package. I believe I didn't upload the related code simply because I wrote the experiment code elsewhere, and those SSL pre-training approaches did not outperform speech enhancement pre-training.

SungFeng-Huang commented 1 year ago

To correct myself: we used a checkpoint downloaded from Zenodo, which was the main space for asteroid to store checkpoints at that time.