facebookresearch / jepa

PyTorch code and models for V-JEPA self-supervised learning from video.

Training Loss Increasing After Initial Decrease with Custom Video Dataset #68

Open DaBihy opened 3 weeks ago

DaBihy commented 3 weeks ago

Hello Everyone,

I've been working with the V-JEPA model for a self-supervised learning project using a custom video dataset. Initially, the training loss decreases as expected, but starts to increase significantly after reaching a minimum. This behavior persists across multiple training sessions with different hyperparameters.

[Figure: jepa_loss_small_collapse, training loss curve]
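
One way to check whether a rising loss like this reflects representation collapse is to track how much the encoder's output features vary across a batch. The sketch below is a hypothetical diagnostic, not part of the jepa repository. `feature_std` is an illustrative name, and the sketch assumes features shaped `[batch, tokens, dim]`. A mean per-dimension standard deviation near zero suggests collapse. A value that stays clearly above zero suggests the features are still diverse.

```python
import torch


@torch.no_grad()
def feature_std(features: torch.Tensor) -> float:
    """Mean per-dimension std over all tokens in a batch (hypothetical diagnostic).

    Assumes `features` has shape [batch, tokens, dim]. Values near 0 mean every
    token maps to (almost) the same vector, i.e. the representation collapsed.
    """
    flat = features.reshape(-1, features.shape[-1])  # [batch * tokens, dim]
    return flat.std(dim=0).mean().item()


# Usage with synthetic tensors standing in for encoder outputs.
torch.manual_seed(0)
collapsed = torch.ones(4, 16, 32)  # every token identical
healthy = torch.randn(4, 16, 32)   # diverse features
print(feature_std(collapsed))      # 0.0
print(feature_std(healthy))        # close to 1.0
```

Logging this value next to the loss can help tell "loss rising because features collapsed" apart from "loss rising while features stay diverse".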

Configuration:

  - Data Setup
  - Data Augmentation
  - Loss Configuration
  - Mask Settings
  - Meta Configuration
  - Model Configuration
  - Optimization

Questions:

  1. Has anyone else encountered similar issues when training on custom datasets, particularly with video data?
  2. Are there recommended strategies for adjusting the training regimen or model configuration that might stabilize the loss?
  3. Could this be related to the specific characteristics of video data in the custom dataset that might require different handling or preprocessing?

Any insights or suggestions would be greatly appreciated. Thank you for your support!

Best regards,

@MidoAssran

icekang commented 2 weeks ago

Hi,

I have the same problem, although I resumed from the latest V-JEPA checkpoint at epoch 300. [Figure: plot of the JEPA loss]

However, the regression regularization loss seems to decrease continually over time. [Figure: plot of the regularization loss]

DaBihy commented 2 weeks ago

@icekang Thank you for your comment. I can confirm that I see the same thing for the reg loss: [Figure: Screenshot 2024-06-17 at 15 43 01, reg loss curve]

The model is still learning even though the JEPA loss is increasing. It's counterintuitive, but I think it's normal behavior for such frameworks, as I observe the same thing when training BYOL.
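
The comparison with BYOL rests on a mechanism both methods share. The regression targets come from a target encoder that is updated as an exponential moving average (EMA) of the online encoder, so the JEPA loss is measured against a moving target and need not decrease monotonically. The sketch below illustrates that update under simple assumptions. `ema_update`, the toy `nn.Linear` encoders, the fake "+1.0" optimizer step and the momentum value are all illustrative, not the repository's code.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, momentum: float) -> None:
    """In place: target <- momentum * target + (1 - momentum) * online."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(momentum).add_(p_o, alpha=1.0 - momentum)


torch.manual_seed(0)
online = nn.Linear(4, 4)           # toy stand-in for the context encoder
target = copy.deepcopy(online)     # target encoder starts as a copy
for p in target.parameters():
    p.requires_grad_(False)        # no gradients flow into the target

x = torch.randn(8, 4)
with torch.no_grad():
    target_before = target(x).clone()
    for p in online.parameters():  # stand-in for one optimizer step
        p.add_(1.0)
    ema_update(target, online, momentum=0.99)
    target_after = target(x)

# Same input, different regression target: the target itself moved.
print((target_after - target_before).abs().max().item())
```

Because the target drifts every step, a rising prediction loss can coexist with useful learning. It can also signal a real problem, so it is worth checking alongside other metrics.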

icekang commented 2 weeks ago

Sorry, it was not a regression loss; it is the regularization loss on the variance of the predicted vectors. Anyway, I think they should all be decreasing, especially the JEPA loss, which measures how close the predicted feature vectors are to the actual (target) feature vectors. [Figure: loss plot]
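
To make the difference between the two curves concrete, here is a minimal sketch of a prediction loss and a variance regularizer of the kind discussed above. It is loosely modeled on the V-JEPA training objective. The function names, the `loss_exp` exponent, the `1e-4` epsilon and the hinge at a standard deviation of 1.0 are assumptions for illustration, not a copy of the repository's code. `preds` and `targets` are lists holding one `[batch, tokens, dim]` tensor per mask.

```python
import torch
import torch.nn.functional as F


def jepa_loss(preds, targets, loss_exp=1.0):
    """Prediction loss: mean |pred - target|^p / p, averaged over masks."""
    loss = 0.0
    for z, h in zip(preds, targets):
        loss = loss + torch.mean(torch.abs(z - h) ** loss_exp) / loss_exp
    return loss / len(preds)


def variance_reg(preds, eps=1e-4):
    """Hinge penalty when predicted features barely vary across tokens (dim=1)."""
    reg = 0.0
    for z in preds:
        std = torch.sqrt(z.var(dim=1) + eps)  # [batch, dim]
        reg = reg + torch.mean(F.relu(1.0 - std))
    return reg / len(preds)


torch.manual_seed(0)
targets = [torch.randn(2, 8, 4)]
collapsed = [torch.zeros(2, 8, 4)]       # predictor outputs one constant vector
spread = [torch.randn(2, 8, 4) * 10.0]   # predictor outputs vary a lot

print(jepa_loss(collapsed, targets).item())  # moderate: targets are ~N(0, 1)
print(variance_reg(collapsed).item())        # ~0.99: std is only sqrt(eps)
print(variance_reg(spread).item())           # near 0: std is well above 1
print(jepa_loss(spread, targets).item())     # large: diverse but inaccurate
```

The two terms measure different things. The regularizer can fall because predictions become more spread out while the prediction loss rises, which matches the pattern reported in this thread. On its own, that pattern does not show whether the learned features are useful.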