facebookresearch / jepa

PyTorch code and models for V-JEPA self-supervised learning from video.

Training Loss Increasing After Initial Decrease with Custom Video Dataset #68

Closed. DaBihy closed this issue 2 months ago

DaBihy commented 5 months ago

Hello Everyone,

I've been working with the V-JEPA model for a self-supervised learning project using a custom video dataset. Initially, the training loss decreases as expected, but starts to increase significantly after reaching a minimum. This behavior persists across multiple training sessions with different hyperparameters.

[Plot: jepa_loss_small_collapse (JEPA training loss on the custom dataset)]

Configuration (a hedged sketch of the config layout follows this list):

- Data Setup
- Data Augmentation
- Loss Configuration
- Mask Settings
- Meta Configuration
- Model Configuration
- Optimization
Questions:

  1. Has anyone else encountered similar issues when training on custom datasets, particularly with video data?
  2. Are there recommended strategies for adjusting the training regimen or model configuration that might stabilize the loss?
  3. Could this be related to the specific characteristics of video data in the custom dataset that might require different handling or preprocessing?

Any insights or suggestions would be greatly appreciated. Thank you for your support!

Best regards,

@MidoAssran

icekang commented 5 months ago

Hi,

I have the same problem, although I resumed from the latest V-JEPA checkpoint at epoch 300. [Plot: jepa-loss]

However, looking at the ~~regression~~ regularization loss, it seems to be continually optimized over time. [Plot: regularization loss]

DaBihy commented 5 months ago

@icekang thank you for your comment. I can confirm that I see the same thing for the reg loss: [Screenshot: reg loss curve]

The model is learning even though the JEPA loss is increasing. It's counterintuitive, but I think it's normal behavior for such frameworks, as I observe the same thing when training BYOL.

icekang commented 5 months ago

Sorry, it was not the regression loss; it is a regularization loss on the variance of the predicted vectors. Anyway, I think they should all be decreasing, especially the JEPA loss, since a decreasing JEPA loss indicates that the predicted feature vectors are getting closer to the target feature vectors. [Plot: loss curves]
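
For readers following along, below is a minimal sketch of the kind of variance-based regularizer being described: a VICReg-style hinge on the per-dimension standard deviation of the predicted features, alongside an L1-style prediction loss. It is illustrative only and not necessarily the repository's exact implementation; the tensor shapes, `loss_exp`, `reg_coeff`, and `target_std` are assumptions.

```python
import torch
import torch.nn.functional as F

def variance_regularizer(pred, eps=1e-4, target_std=1.0):
    """Hinge penalty on the std of predicted features (discourages collapse).

    pred: (batch, num_tokens, dim) predicted feature vectors.
    Illustrative VICReg-style term, not necessarily the repo's exact code.
    """
    std = torch.sqrt(pred.var(dim=1) + eps)       # per-dimension std across tokens
    return torch.mean(F.relu(target_std - std))   # penalize dims whose std falls below target

def combined_objective(pred, target, loss_exp=1.0, reg_coeff=0.0):
    """Prediction (jepa) loss plus the regularization term, returned separately."""
    loss_jepa = torch.mean(torch.abs(pred - target) ** loss_exp) / loss_exp
    loss_reg = variance_regularizer(pred)
    return loss_jepa + reg_coeff * loss_reg, loss_jepa, loss_reg
```

Logging the prediction and regularization terms separately, as in the plots above, makes it easier to see which term is actually driving the total loss.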

zetaSaahil commented 2 months ago

I had very similar behavior with my custom dataset. However, I found that with JEPA (while training on a small subset of the training data), the loss first increases and then decreases after some time. I also realised that the learning rate should be high enough for the loss to overcome this local minimum (for me, an lr of 1e-3 for the small dataset and 6e-4 for the big dataset worked best).

[Plot: jepa loss]
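
A hedged sketch of how one might try a higher peak learning rate along these lines, using a generic AdamW optimizer with linear warmup and cosine decay. The warmup length, step counts, weight decay, and dummy parameters are placeholders, and this is not the repository's scheduler code.

```python
import math
import torch

# Dummy parameters purely for illustration; use the encoder/predictor params in practice.
params = [torch.nn.Parameter(torch.randn(8, 8))]

peak_lr = 1e-3         # value reported above to work for the small dataset
# peak_lr = 6e-4       # value reported above for the larger dataset
warmup_steps = 1_000   # placeholder
total_steps = 100_000  # placeholder

optimizer = torch.optim.AdamW(params, lr=peak_lr, weight_decay=0.04)

def lr_at(step):
    """Linear warmup to peak_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))

# LambdaLR expects a multiplier on the base lr, so divide by peak_lr.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_at(s) / peak_lr)
```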