facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

[data2vec] Loss goes down and up again #4177

Open treasan opened 2 years ago

treasan commented 2 years ago

I am trying to train data2vec on music data (the FMA dataset). I've made some modifications to the feature extractor ConvNet (essentially turned it into a small ResNet) and reduced the size of the transformer encoder (8 layers, 8 attention heads, 512 embedding dim, 2048 FFN dim). Since I am new to training transformer models, I've tried to reproduce the paper's hyperparameters as closely as possible. My batch size is much smaller, so I've adjusted the learning rate accordingly.
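For concreteness, here is a rough sketch of the reduced encoder settings and of the linear-scaling heuristic used for the learning rate (config field names are assumed to mirror the wav2vec2/data2vec audio config; reference values are illustrative, not the paper's actual numbers):

```python
# Rough sketch, not the exact setup: reduced transformer encoder from the comment
# above, with field names assumed to mirror the wav2vec2/data2vec audio config.
reduced_encoder = dict(
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_embed_dim=512,
    encoder_ffn_embed_dim=2048,
)

def scaled_peak_lr(reference_lr: float, reference_batch: int, my_batch: int) -> float:
    """Linear scaling heuristic: shrink the peak LR proportionally to batch size."""
    return reference_lr * my_batch / reference_batch

# Illustrative numbers only (not the paper's actual values):
print(scaled_peak_lr(reference_lr=5e-4, reference_batch=64, my_batch=8))  # 6.25e-05
```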

The first few epochs seem promising. The loss goes down quite quickly and the embedding variances of the teacher and student models increase over time (i.e. the model is not collapsing to a constant representation). However, after a while the behaviour abruptly changes: the loss goes up again and stays high for the remainder of training (the variances stay high, though). I've also experienced similar issues with another exponential moving average (EMA) based model, DINO (the ResNet variant), a while ago.

When reducing the learning rate, the effect just happens later.

There seems to be a related issue reported here.

Does anyone have an idea why this might be happening?

treasan commented 2 years ago

Okay, so I've plotted the teacher and student variances from two test runs with a small dataset and only for a few epochs:

https://imgur.com/a/dynF7P9

The student variance converges towards the teacher variance up to a certain point. After that the loss quickly increases and a relatively constant gap remains between the two variances.

Not sure how to interpret this. Maybe this is the desired behaviour, since the teacher should always be "better" in order to lead the student.
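For context, the teacher in data2vec/DINO-style models lags behind the student because its weights are an exponential moving average of the student's weights. A minimal sketch of that update (the decay value and schedule are illustrative):

```python
import copy
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, decay: float = 0.999) -> None:
    """Move each teacher parameter a small step towards the corresponding student parameter."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(decay).add_(s_param, alpha=1.0 - decay)

student = torch.nn.Linear(512, 512)
teacher = copy.deepcopy(student)   # teacher starts as a copy, then trails the student
ema_update(teacher, student, decay=0.999)
```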

Ramlinbird commented 2 years ago

> Okay, so I've plotted the teacher and student variances from two test runs with a small dataset and only for a few epochs: https://imgur.com/a/dynF7P9 […] Maybe this is the desired behaviour, since the teacher should always be "better" in order to lead the student.

Hello @Tomsen1410! Have you found a reasonable explanation? Following is my strange loss plot: [image]

treasan commented 2 years ago

> Hello @Tomsen1410! Have you found a reasonable explanation? Following is my strange loss plot: [image]

Hey! Unfortunately not. Your loss seems to be OK though? I am not sure what the "correct" loss curve should look like; it is strange that my loss values are two orders of magnitude smaller. It would also be great to see a training plot from the authors. Could you tell me which hyperparameters you have used (batch size, ...)?

Ramlinbird commented 2 years ago

> Hey! Unfortunately not. […] Could you tell me which hyperparameters you have used (batch size, ...)?

Thanks. I didn't change any hyperparameters except max_tokens (due to a memory error) in base_librispeech.yaml. @alexeib, could you share your training plot with us and help us figure this out? Thanks so much.

alexeib commented 2 years ago

Hey, if variances are jumping up and down, that looks like a collapse and you may want to lower your learning rate. I don't have variance plots for the NLP models (wasn't logging it when they were trained), but here is the loss curve:

image

Here are a couple of example plots from a reduced speech setup with variance etc. (it should look somewhat similar for NLP, but not exactly the same). This run also uses a tri-stage LR scheduler that holds the learning rate at its peak for 90% of training (a sketch of such a schedule follows the plots below):

loss:

image

pred var:

image

target var:

image
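As referenced above, here is a rough sketch of such a tri-stage schedule (ratios and scales are illustrative, not necessarily the exact values used by fairseq's tri_stage scheduler in this setup):

```python
import math

def tri_stage_lr(step: int, total_steps: int, peak_lr: float,
                 phase_ratio=(0.05, 0.90, 0.05),
                 init_scale: float = 0.01, final_scale: float = 0.05) -> float:
    """Warm up to peak_lr, hold it for ~90% of training, then decay exponentially."""
    warmup = int(phase_ratio[0] * total_steps)
    hold = int(phase_ratio[1] * total_steps)
    if step < warmup:
        frac = step / max(1, warmup)
        return peak_lr * (init_scale + (1.0 - init_scale) * frac)
    if step < warmup + hold:
        return peak_lr
    decay_steps = max(1, total_steps - warmup - hold)
    frac = min(1.0, (step - warmup - hold) / decay_steps)
    return peak_lr * math.exp(math.log(final_scale) * frac)
```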
Ramlinbird commented 2 years ago

@alexeib Thanks a lot for your quick reply! According to your loss plot, my training seems all right? I also checked my variance plot: [image] However, the pred var and target var are not flat at the end; they are still dropping. Is this OK? What is the actual meaning of these values, and what relation should they have? (From reading the source code, they are just the standard variance of the network's outputs.)
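For what it's worth, the logged values appear to be computed roughly like this (a simplified, single-process sketch based on reading the model code; distributed all-reduce details omitted):

```python
import torch

def compute_var(y: torch.Tensor) -> torch.Tensor:
    """Per-dimension std over all frames in the batch, averaged over dimensions."""
    y = y.view(-1, y.size(-1))                     # (frames, dim)
    return torch.sqrt(y.var(dim=0) + 1e-6).mean()

# pred_var   = compute_var(student_predictions)
# target_var = compute_var(teacher_targets)
```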

treasan commented 2 years ago

Was instance_norm also applied to the targets in the reduced speech setup, @alexeib?

Orlllem commented 2 years ago

> Hey, if variances are jumping up and down, that looks like a collapse and you may want to lower your learning rate. […] This run also uses a tri-stage LR scheduler that holds the learning rate at its peak for 90% of training.

@alexeib Thank you for the logs. Is this behavior of the loss (going down, then up, then down again) related to the near-optimality of the predictor in BYOL (Bootstrap Your Own Latent)? The BYOL work demonstrates the importance of a near-optimal predictor for preventing collapse.
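Schematically, the BYOL asymmetry being referred to looks like this (an illustrative sketch of BYOL's setup, not data2vec's exact heads): the student branch gets an extra predictor, and the EMA teacher's output is treated as a constant via stop-gradient.

```python
import torch
import torch.nn.functional as F

dim = 512
# Extra predictor head on the student branch only (BYOL-style asymmetry).
predictor = torch.nn.Sequential(
    torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim)
)

def byol_style_loss(student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
    pred = predictor(student_out)
    target = teacher_out.detach()   # stop-gradient: no gradients flow into the teacher
    return F.mse_loss(pred, target)
```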

a43992899 commented 2 years ago

> Hey, if variances are jumping up and down, that looks like a collapse and you may want to lower your learning rate. […] This run also uses a tri-stage LR scheduler that holds the learning rate at its peak for 90% of training.

@alexeib Hi, why does the NLP curve look so different from the speech curve, and why does it seem not to converge? Actually, my audio-modality training curve is similar to your NLP curve, and my pred and target variances do not seem to be collapsing. How do I check whether my audio model has converged?

[images]
saurabh-kataria commented 1 year ago

FWIW, I was getting a target_var < 0.1 error while training data2vec 2.0 on my data. I lowered the learning rate from the default 0.00075 to 0.00050 and the error went away.
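For anyone hitting the same thing: the error appears to come from a guard along these lines in the model code (paraphrased; exact names and threshold may differ), which aborts training when the target variance drops low enough to suggest collapse. Lowering the peak learning rate is one way to stay above the threshold.

```python
MIN_TARGET_VAR = 0.1  # assumed configured minimum

def check_collapse(target_var: float, min_target_var: float = MIN_TARGET_VAR) -> None:
    """Abort training if the teacher-target variance suggests representation collapse."""
    if target_var < min_target_var:
        raise RuntimeError(f"target var is {target_var:.3f} < {min_target_var}, exiting")
```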