kdexd / virtex

[CVPR 2021] VirTex: Learning Visual Representations from Textual Annotations
http://kdexd.xyz/virtex
MIT License

Training loss acts strangely after resuming #18

Closed BaohaoLiao closed 3 years ago

BaohaoLiao commented 3 years ago

Hi,

I want to reproduce your pre-training results. An accident interrupted my training, so I resumed it with the "--resume-from" flag, and it behaves strangely: the training and validation losses jumped dramatically at the beginning and then decreased, which suggests a problem with restoring the checkpoint. Could you help me with this?

kdexd commented 3 years ago

Hi @BaohaoLiao,

This is a known issue. I have not been able to pinpoint it exactly (I save states for the model, optimizer, and scheduler). My best guess is a mismatch between epochs and iterations. I prefer saving checkpoints at fixed iteration intervals rather than every epoch (this gives more control, as iterations are a finer-grained scale than epochs). I save checkpoints every 2000 iterations by default, while each epoch takes 462 iterations (for batch size 256), so resuming training may restart an epoch abruptly. If you train for long enough after resuming (~100K iterations or more), the final performance converges to within 0.1% of the reported results. As a quick hack, could you try (also) saving checkpoints every epoch:

if iteration % 462 == 0:
    checkpoint_manager.step(iteration)

... or simply changing --checkpoint-every argument to 462.

This should not cause abrupt behavior. Please let me know if it does, and I will investigate this issue further. In case you are curious, rest assured about the numbers reported in the paper: I ran all experiments in the paper (500K iterations) without any resuming. So even if there really is a bug here, it would not have affected the final results :-)
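A quick way to see the epoch/iteration mismatch described above is to compute where each checkpoint falls inside an epoch. This is an illustrative sketch only: the constants 462 and 2000 come from the comment above, while the helper names (`iterations_per_epoch`, `position_in_epoch`, `checkpoint_every_aligned`) are hypothetical and not part of VirTex. It assumes the last partial batch of each epoch is dropped, which matches computing the interval as int(num_samples / batch_size).

```python
from typing import Tuple

# Values taken from the comment above; the helpers below are hypothetical.
ITERS_PER_EPOCH = 462           # batch size 256
DEFAULT_CHECKPOINT_EVERY = 2000


def iterations_per_epoch(num_samples: int, batch_size: int) -> int:
    """Full batches per epoch, assuming the last partial batch is dropped."""
    return num_samples // batch_size


def position_in_epoch(iteration: int, iters_per_epoch: int = ITERS_PER_EPOCH) -> Tuple[int, int]:
    """Return (completed_epochs, batches_already_consumed_in_current_epoch)."""
    return divmod(iteration, iters_per_epoch)


def checkpoint_every_aligned(checkpoint_every: int, iters_per_epoch: int = ITERS_PER_EPOCH) -> int:
    """Round an interval up to a whole number of epochs so checkpoints land on epoch boundaries."""
    epochs = max(1, -(-checkpoint_every // iters_per_epoch))  # ceiling division
    return epochs * iters_per_epoch


for it in (2000, 4000, 6000):
    epoch, offset = position_in_epoch(it)
    print(f"iteration {it}: {epoch} full epochs done, {offset} of {ITERS_PER_EPOCH} batches into the next")
```

With the defaults, the checkpoint at iteration 2000 lands 152 batches into the fifth epoch. Assuming the resumed run builds a fresh data iterator, it starts a new pass over the data instead of finishing the remaining 310 batches of that epoch. Rounding the interval up to whole epochs (2310 here) or using 462 directly keeps every checkpoint on an epoch boundary.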

Similar Experience with this Issue

I received this issue via email a few weeks ago and suggested this fix. They tried it and responded:

I am running VirTex on another dataset, and I set the checkpoint-every argument to int(num_videos/batch_size). Then I restarted from a certain checkpoint. If the checkpoint is too early, I still see some discrepancies, but the difference seems smaller than before! I also see the model recover quickly from the checkpoint.

kdexd commented 3 years ago

If you find any obvious bug, please let me know; I will be more than happy to accept your fix!

BaohaoLiao commented 3 years ago

Thank you for your answer. I think this problem is also caused by mixed precision training. When I resume training, amp always warns about gradient overflow and rescales the loss. According to https://nvidia.github.io/apex/amp.html, you might also need to save amp's state.
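The linked apex page recommends saving `amp.state_dict()` alongside the model and optimizer state and restoring it with `amp.load_state_dict()`; PyTorch's native `torch.cuda.amp.GradScaler` has the same `state_dict()`/`load_state_dict()` pair. To show why that state matters without a GPU, here is a hypothetical pure-Python dynamic loss scaler (`ToyLossScaler` is not a real API). It assumes, purely for illustration, that scaled gradients overflow whenever the scale exceeds 2**10. A scaler rebuilt from scratch on resume starts again at 2**16 and skips several steps while it backs off, whereas a restored one does not. This only illustrates the overflow warnings reported above; it does not show that loss scaling caused the loss spike.

```python
class ToyLossScaler:
    """Hypothetical dynamic loss scaler (not apex.amp and not torch.cuda.amp.GradScaler)."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow: bool) -> bool:
        """Adjust the scale after one step; return True if the optimizer step should run."""
        if overflow:
            self.scale /= 2.0  # back off and skip this step
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2.0  # periodically try a larger scale again
        return True

    def state_dict(self) -> dict:
        return {"scale": self.scale, "good_steps": self.good_steps}

    def load_state_dict(self, state: dict) -> None:
        self.scale = state["scale"]
        self.good_steps = state["good_steps"]


MAX_SAFE_SCALE = 2.0 ** 10  # assumption: scaled gradients overflow above this scale


def train(scaler: ToyLossScaler, steps: int) -> int:
    """Run `steps` fake iterations and return how many were skipped due to overflow."""
    skipped = 0
    for _ in range(steps):
        if not scaler.update(overflow=scaler.scale > MAX_SAFE_SCALE):
            skipped += 1
    return skipped


# First run: the scaler backs off to a safe scale, then we checkpoint it.
scaler = ToyLossScaler()
skipped_first = train(scaler, 1000)
checkpoint = {"amp": scaler.state_dict()}

# Resume A restores the scaler state; resume B starts from a fresh scaler.
restored = ToyLossScaler()
restored.load_state_dict(checkpoint["amp"])
fresh = ToyLossScaler()
skipped_restored = train(restored, 100)
skipped_fresh = train(fresh, 100)
print("skipped after restoring:", skipped_restored)
print("skipped without restoring:", skipped_fresh)
```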

kdexd commented 3 years ago

I think that might be it, thank you for spotting this! I will work on a fix in a few days. If you wish to make a PR with this change, I would be happy to merge it!

Jeff-LiangF commented 3 years ago

Hi @kdexd,

Thank you for your great codebase. The whole community will benefit from your awesome code!

I also ran into the resuming issue: the training loss jumped dramatically at the beginning. I would like to know whether this affects the final results. As you posted,

If you train for long enough after resuming (~100K iterations or more), then the final performance converges to within 0.1% of the reported results.

May I assume it is just noise and does not have much effect on the final validation results?

Thanks!

kdexd commented 3 years ago

Hi @Jeff-LiangF, that's true. And while you're at it, I would recommend trying to plug in the "StatefulDistributedSampler", which takes care of resuming data loading from a partial epoch.

Check out this doc: https://vissl.readthedocs.io/en/v0.1.5/large_scale/stateful_sampler.html and the code here: https://vissl.readthedocs.io/en/master/_modules/vissl/data/data_helper.html

If it works well for you, do consider opening a Pull Request, I will be more than happy to merge it in the codebase!
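The idea behind a stateful sampler can be sketched in a few lines: if the shuffle order depends only on a seed and the epoch number, it can be regenerated after a restart, and the samples from batches consumed before the checkpoint can be skipped. `ToyStatefulSampler` below is a hypothetical, single-process illustration, not VISSL's `StatefulDistributedSampler`; it ignores distributed sharding, and its method names are illustrative. See the linked VISSL docs and code for the real class.

```python
import random
from typing import Iterator


class ToyStatefulSampler:
    """Hypothetical single-process sampler that can resume from the middle of an epoch."""

    def __init__(self, num_samples: int, batch_size: int, seed: int = 0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0
        self.start_iter = 0  # batches of the current epoch already consumed

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def set_start_iter(self, start_iter: int) -> None:
        self.start_iter = start_iter

    def __iter__(self) -> Iterator[int]:
        # The permutation depends only on (seed, epoch), so it is reproducible after a restart.
        order = list(range(self.num_samples))
        random.Random(self.seed + self.epoch).shuffle(order)
        offset = self.start_iter * self.batch_size
        self.start_iter = 0  # the offset applies only to the first (resumed) pass
        return iter(order[offset:])


# Uninterrupted epoch 4 versus epoch 4 resumed after 3 batches of 8 samples.
uninterrupted = ToyStatefulSampler(num_samples=100, batch_size=8, seed=1)
uninterrupted.set_epoch(4)
full_epoch = list(uninterrupted)

resumed = ToyStatefulSampler(num_samples=100, batch_size=8, seed=1)
resumed.set_epoch(4)
resumed.set_start_iter(3)
rest_of_epoch = list(resumed)
print(rest_of_epoch == full_epoch[3 * 8:])
```

On resume, one would set the epoch and the number of batches already consumed in it (for example 152 for a checkpoint at iteration 2000 with 462 iterations per epoch) before building the data loader, so the resumed run finishes exactly the samples the interrupted epoch had not yet seen.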

Jeff-LiangF commented 3 years ago

Hi @kdexd,

Many thanks for your prompt help! I will try StatefulDistributedSampler soon. If it works well, I will open a PR. :-)

4m4n5 commented 3 years ago

Hey @kdexd,

I have also been using this repo's framework for another implementation, with a completely different model architecture and task. I tried plugging in a "StatefulDistributedSampler", and the loss is still unstable on resuming. I also tried checkpointing at the end of each epoch and resuming from those checkpoints, as you suggested above.

But that does not solve the problem either, which leads me to believe the problem is not related to the epoch-iteration mismatch. I have not been able to solve the issue, but I hope these observations help.

kdexd commented 3 years ago

Thanks for this information @4m4n5, it helped me a lot to narrow down the potential causes of this issue. The real culprit is the state dict of the Lookahead optimizer not being restored properly. https://github.com/kdexd/virtex/commit/1c91943260498b44b095055550bbda5565fa649e fixes it. I tried resuming training from arbitrary iterations, and the loss spikes no longer occur! Please pull from master and let me know if you face any issues.
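Lookahead wraps an inner optimizer and keeps its own state: a copy of the "slow" weights and a counter of inner steps, synchronizing every k steps. If only the model and inner-optimizer state come back on resume, the run restarts with slow weights equal to the current fast weights and a reset counter, so it follows a different trajectory from an uninterrupted run. The sketch below shows that effect on a one-parameter quadratic. `ToyLookahead` is hypothetical (plain SGD inner steps, k=5, alpha=0.5) and is not the code changed in the linked commit.

```python
from typing import Dict, List


class ToyLookahead:
    """Hypothetical Lookahead wrapper around plain SGD on a list of float parameters."""

    def __init__(self, params: List[float], lr: float = 0.1, k: int = 5, alpha: float = 0.5):
        self.params = params      # "fast" weights, updated in place (the model's own state)
        self.lr, self.k, self.alpha = lr, k, alpha
        self.slow = list(params)  # "slow" weights: Lookahead's own state
        self.step_count = 0       # inner steps taken so far: also Lookahead's own state

    def step(self, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g  # inner SGD step
        self.step_count += 1
        if self.step_count % self.k == 0:  # every k steps, pull slow weights toward fast ones
            for i, p in enumerate(self.params):
                self.slow[i] += self.alpha * (p - self.slow[i])
                self.params[i] = self.slow[i]  # then reset fast weights to slow weights

    def state_dict(self) -> Dict:
        return {"slow": list(self.slow), "step_count": self.step_count}

    def load_state_dict(self, state: Dict) -> None:
        self.slow = list(state["slow"])
        self.step_count = state["step_count"]


def grad(params: List[float]) -> List[float]:
    return [2.0 * (params[0] - 3.0)]  # gradient of (w - 3)^2


def run(total_steps: int, resume_at: int = -1, restore_lookahead: bool = True) -> float:
    params = [0.0]
    opt = ToyLookahead(params)
    for step in range(total_steps):
        if step == resume_at:
            # Simulate checkpoint + restart: rebuild params and optimizer from saved state.
            checkpoint = {"model": list(params), "optimizer": opt.state_dict()}
            params = list(checkpoint["model"])
            opt = ToyLookahead(params)
            if restore_lookahead:
                opt.load_state_dict(checkpoint["optimizer"])
        opt.step(grad(params))
    return params[0]


baseline = run(20)
print(run(20, resume_at=7) == baseline)                           # True: same trajectory
print(run(20, resume_at=7, restore_lookahead=False) == baseline)  # False: state was lost
```

In this toy, resuming exactly on a synchronization step (a multiple of k) happens to hide the bug, because the slow and fast weights coincide there and the counter is at a multiple of k; resuming at any other iteration exposes it.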

4m4n5 commented 3 years ago

That completely solves the issue on my end. You can go ahead and close this issue now. This helps a lot :)

kdexd commented 3 years ago

That's great to hear, closing this issue now!