facebookresearch / jepa

PyTorch code and models for V-JEPA self-supervised learning from video.

Crashes after first Epoch #26

Closed: thomasf1 closed this issue 4 months ago

thomasf1 commented 4 months ago

I'm trying to get jepa to work on Colab, but for some reason it ends (or crashes) after completing the first epoch. The output folder is basically empty (just one empty CSV file).

The pretrained model used is vitl16.pth.tar (https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar).

The dataset is a set of MP4 videos without class labels (all labels set to 0).
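A quick sanity check on the dataset index file can help narrow this down. The sketch below assumes the index uses one `<video path> <integer label>` row per line (space-delimited, which is an assumption about what the repo's `VideoDataset` reads); `summarize_labels` is a hypothetical helper. If every label is 0, the supervised probe is solving a one-class problem, which would be consistent with the near-zero loss and the running percentage (apparently top-1 accuracy) climbing to about 99.8% in the log below, even though the config declares `num_classes: 100`.

```python
import os
import tempfile
from collections import Counter
from pathlib import Path


def summarize_labels(index_csv, num_classes):
    """Count labels in an index file with one `<video path> <int label>` row per line.

    Returns (Counter of valid labels, list of 1-based line numbers that were rejected).
    """
    counts, bad_rows = Counter(), []
    for lineno, line in enumerate(Path(index_csv).read_text().splitlines(), start=1):
        if not line.strip():
            continue  # ignore blank lines
        # Split off only the last field, so the label is always the final token
        parts = line.rsplit(maxsplit=1)
        if len(parts) != 2 or not parts[1].lstrip("-").isdigit():
            bad_rows.append(lineno)  # missing or non-integer label
            continue
        label = int(parts[1])
        if not 0 <= label < num_classes:
            bad_rows.append(lineno)  # label outside [0, num_classes)
            continue
        counts[label] += 1
    return counts, bad_rows


# Usage on a throwaway file (hypothetical paths)
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write("/data/a.mp4 0\n/data/b.mp4 0\n/data/c.mp4\n")
counts, bad = summarize_labels(f.name, num_classes=100)
os.remove(f.name)
print(counts, bad)
if len(counts) < 2:
    print("warning: fewer than two distinct labels, so the probe's task is trivial")
```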

Could you give me some pointers on how to possibly debug this?

Environment: Colab Pro; tried on both an A100 and a V100.

Training was started with:

!python -m evals.main --fname (my modified vith16_k400_16x8x3.yaml with a small dataset) --devices cuda:0

Output:

INFO:root:called-params /content/jepa/xxx-mini.yaml
INFO:root:loaded params...
{   'data': {   'dataset_train': '/content/jepa/xxx-train-mini.csv',
                'dataset_type': 'VideoDataset',
                'dataset_val': '/content/jepa/xxx-val-mini.csv',
                'frame_step': 4,
                'frames_per_clip': 16,
                'num_classes': 100,
                'num_segments': 8,
                'num_views_per_segment': 3},
    'eval_name': 'video_classification_frozen',
    'nodes': 1,
    'optimization': {   'attend_across_segments': True,
                        'batch_size': 4,
                        'final_lr': 0.0,
                        'lr': 0.001,
                        'num_epochs': 20,
                        'resolution': 224,
                        'start_lr': 0.001,
                        'use_bfloat16': True,
                        'warmup': 0.0,
                        'weight_decay': 0.01},
    'pretrain': {   'checkpoint': 'vitl16.pth.tar',
                    'checkpoint_key': 'target_encoder',
                    'clip_duration': None,
                    'folder': '/content/jepa/',
                    'frames_per_clip': 16,
                    'model_name': 'vit_large',
                    'patch_size': 16,
                    'tight_silu': False,
                    'tubelet_size': 2,
                    'uniform_power': True,
                    'use_sdpa': True,
                    'use_silu': False,
                    'write_tag': 'jepa'},
    'resume_checkpoint': False,
    'tag': 'xxx2',
    'tasks_per_node': 8}
INFO:root:Running... (rank: 0/1)
INFO:root:Running evaluation: video_classification_frozen
INFO:root:Initialized (rank/world-size) 0/1
INFO:root:Loading pretrained model from /content/jepa/vitl16.pth.tar
VisionTransformer(
  (patch_embed): PatchEmbed3D(
    (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
  )
  (blocks): ModuleList(
    (0-23): 24 x Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
INFO:root:loaded pretrained model with msg: <All keys matched successfully>
INFO:root:loaded pretrained encoder from epoch: 300
 path: /content/jepa/vitl16.pth.tar
INFO:root:VideoDataset dataset created
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 12 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
INFO:root:VideoDataset unsupervised data loader created
Making EvalVideoTransform, multi-view
INFO:root:VideoDataset dataset created
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 12 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
INFO:root:VideoDataset unsupervised data loader created
INFO:root:Dataloader created... iterations per epoch: 1076
INFO:root:Using AdamW
INFO:root:Epoch 1
INFO:root:[    0] 0.000% (loss: 4.641) [mem: 3.13e+03]
INFO:root:[   20] 90.476% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[   40] 95.122% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[   60] 96.721% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[   80] 97.531% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  100] 98.020% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  120] 98.347% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  140] 98.582% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  160] 98.758% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  180] 98.895% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  200] 99.005% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  220] 99.095% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  240] 99.170% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  260] 99.234% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  280] 99.288% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  300] 99.336% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  320] 99.377% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  340] 99.413% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  360] 99.446% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  380] 99.475% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  400] 99.501% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  420] 99.525% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  440] 99.546% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  460] 99.566% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  480] 99.584% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  500] 99.601% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  520] 99.616% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  540] 99.630% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  560] 99.643% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  580] 99.656% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  600] 99.667% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  620] 99.678% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  640] 99.688% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  660] 99.697% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  680] 99.706% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  700] 99.715% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  720] 99.723% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  740] 99.730% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  760] 99.737% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  780] 99.744% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  800] 99.750% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  820] 99.756% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  840] 99.762% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  860] 99.768% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  880] 99.773% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  900] 99.778% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  920] 99.783% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  940] 99.787% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  960] 99.792% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  980] 99.796% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1000] 99.800% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1020] 99.804% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1040] 99.808% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1060] 99.811% (loss: 0.000) [mem: 3.24e+03]
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 88 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
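The two `UserWarning`s in the log above say each DataLoader will start 12 worker processes while the system suggests at most 8. A minimal sketch of capping the worker count before building a DataLoader is below; it approximates how PyTorch derives that suggestion (CPU affinity on Linux, otherwise `os.cpu_count()`). `suggested_max_workers` and `capped_num_workers` are hypothetical helpers, not part of the jepa code, and you would still need to find where the eval script sets its worker count.

```python
import os


def suggested_max_workers():
    """Approximate the CPU count PyTorch's DataLoader warning compares against."""
    if hasattr(os, "sched_getaffinity"):
        # Linux: only count CPUs this process is allowed to run on
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1  # other platforms; cpu_count() may return None


def capped_num_workers(requested):
    """Never ask for more worker processes than the suggested maximum."""
    return min(requested, suggested_max_workers())


# The run above asked for 12 workers per loader; on an 8-CPU VM this gives 8
n = capped_num_workers(12)
print(f"num_workers={n}")
# then e.g. torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=n)
```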
MidoAssran commented 4 months ago

Hi @thomasf1, I just tried the code locally and didn't have any issues. Could you add some additional logging to help debug? For example, does the code successfully return from this first function call?
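One low-effort way to add that kind of logging is to wrap the suspect function so its entry, return and any exception are recorded. This is a generic sketch using only the standard library; `log_call` and `run_one_epoch` are hypothetical names, and you would apply the decorator to whichever function in the eval code you want to check.

```python
import functools
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("debug_calls")


def log_call(fn):
    """Log when `fn` starts, when it returns (with elapsed time), and any exception."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        logger.info("enter %s", fn.__name__)
        start = time.perf_counter()
        try:
            result = fn(*args, **kwargs)
        except BaseException:
            # logger.exception also records the traceback before re-raising
            logger.exception("%s raised after %.1fs", fn.__name__, time.perf_counter() - start)
            raise
        logger.info("exit %s after %.1fs", fn.__name__, time.perf_counter() - start)
        return result
    return wrapper


@log_call
def run_one_epoch():  # hypothetical stand-in for the function you want to check
    return 0.0


run_one_epoch()
```

If the `enter` line appears but neither `exit` nor a traceback follows, the process was most likely terminated outside normal Python error handling (for example killed by a signal, or a crash in native code), because an ordinary exception would be logged and re-raised.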

thomasf1 commented 4 months ago

@MidoAssran Thanks for the pointer. I've added a bunch of debug output and got it to work with a very small dataset. I'm doing some more work to see where it gets stuck...

What's the recommended train/validation split? Also, is the validation part unsupervised too, or does it require class IDs in the dataset?

thomasf1 commented 4 months ago

One other observation: it seems that the data loading is redone for every epoch, leaving the GPU idle for quite a long time. This might be an area that could be improved considerably.

Did you graph the GPU usage? Or is this perhaps mitigated in the multi-machine training code?
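To quantify how long training sits waiting on data, one option is to time how long each `next()` on the loader blocks, separately from the step itself. A minimal, framework-agnostic sketch; `TimedLoader` and the simulated `slow_batches` generator are hypothetical:

```python
import time


class TimedLoader:
    """Wrap an iterable and accumulate the time spent waiting for each next batch."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.wait_s = 0.0
        self.batches = 0

    def __iter__(self):
        it = iter(self.iterable)
        while True:
            start = time.perf_counter()
            try:
                batch = next(it)
            except StopIteration:
                return
            self.wait_s += time.perf_counter() - start
            self.batches += 1
            yield batch


def slow_batches(n, delay=0.01):
    for i in range(n):
        time.sleep(delay)  # stands in for video decoding in DataLoader workers
        yield i


loader = TimedLoader(slow_batches(5))
compute_s = 0.0
for batch in loader:
    t0 = time.perf_counter()
    _ = sum(range(10_000))  # stands in for the training step
    compute_s += time.perf_counter() - t0
print(f"data wait {loader.wait_s:.2f}s, compute {compute_s:.2f}s, {loader.batches} batches")
```

With a real CUDA model, kernels run asynchronously, so the compute timing is only meaningful if you call `torch.cuda.synchronize()` before reading the clock.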

thomasf1 commented 4 months ago
[Screenshot 2024-02-22 at 23:48:26; image not preserved]
MidoAssran commented 4 months ago

Hi @thomasf1, yes: since you are running the evaluation code (training an attentive probe on top of the frozen encoder), the validation part does need a class_id in the dataset index file, as this is a supervised learning problem.
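For a labeled setup, the train and validation index files could be generated from `(video path, class_id)` pairs roughly as follows. This sketch assumes the same space-delimited `<path> <label>` row format as above (so paths must not contain spaces), and `write_index_files` is a hypothetical helper. The `val_fraction=0.2` default is purely illustrative; no split was recommended in this thread.

```python
import random
import tempfile
from pathlib import Path


def write_index_files(samples, train_csv, val_csv, val_fraction=0.2, seed=0):
    """Write `<path> <label>` index files using a seeded random train/val split.

    samples: iterable of (video_path, class_id) pairs, class_id an int in [0, num_classes).
    """
    samples = list(samples)
    random.Random(seed).shuffle(samples)  # deterministic shuffle for reproducibility
    n_val = int(round(len(samples) * val_fraction))
    val, train = samples[:n_val], samples[n_val:]
    for path, rows in ((train_csv, train), (val_csv, val)):
        Path(path).write_text("".join(f"{p} {int(c)}\n" for p, c in rows))
    return len(train), len(val)


# Usage with hypothetical clip paths and three classes
tmp = Path(tempfile.mkdtemp())
pairs = [(f"/data/clip_{i}.mp4", i % 3) for i in range(10)]
n_train, n_val = write_index_files(pairs, tmp / "train.csv", tmp / "val.csv")
print(n_train, n_val)  # 8 2
```

A stratified split (per class) would keep rare classes represented in validation; the plain shuffle above does not guarantee that.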

As for efficiency, it's true that the data loading is done in each epoch. However, since the evals run reasonably quickly compared to pretraining (they only involve training a small probe), we didn't try optimizing this further. If you want to speed it up, one option would be to compute the embeddings of the videos in your dataset once, and then train a probe on top of those pre-extracted features.
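The pre-extraction approach could look roughly like this. It is a sketch under several assumptions: the encoder is any frozen `nn.Module` returning `(batch, tokens, dim)`, tokens are mean-pooled, and a plain linear probe stands in for the repo's attentive probe. `ToyEncoder`, `extract_features` and `train_linear_probe` are hypothetical names. Because features are computed once, later epochs no longer decode video, at the cost of losing any per-epoch randomness in clip sampling or augmentation.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def extract_features(encoder, loader, device="cpu"):
    """Run the frozen encoder once over `loader`; return pooled features and labels."""
    encoder.eval().to(device)
    feats, labels = [], []
    for clips, y in loader:
        tokens = encoder(clips.to(device))      # assumed shape (B, num_tokens, dim)
        feats.append(tokens.mean(dim=1).cpu())  # mean-pool tokens -> (B, dim)
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)


def train_linear_probe(feats, labels, num_classes, epochs=20, lr=1e-3, batch_size=64):
    """Train a linear classifier on cached features (no video decoding per epoch)."""
    probe = nn.Linear(feats.shape[1], num_classes)
    opt = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    loader = DataLoader(TensorDataset(feats, labels), batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(probe(x), y).backward()
            opt.step()
    return probe


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the frozen ViT: (B, tokens, 8) -> (B, tokens, 16)."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 16)

    def forward(self, x):
        return self.proj(x)


torch.manual_seed(0)
clips, ys = torch.randn(32, 4, 8), torch.randint(0, 3, (32,))
feats, labels = extract_features(ToyEncoder(), DataLoader(TensorDataset(clips, ys), batch_size=8))
probe = train_linear_probe(feats, labels, num_classes=3, epochs=5)
print(feats.shape, probe(feats).argmax(dim=1)[:5])
```

In practice you would save `feats` and `labels` with `torch.save` after one pass and reload them for every probe run.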

Since you seem to have already gotten the eval code working, I'm going to close this task for now, but feel free to comment if you have any other questions!

jackhawa commented 3 months ago

Hello @thomasf1, I got the same error. Could you describe how you resolved it? Thanks.

tomarvimal commented 3 months ago

@thomasf1 I have the same error: the model crashes after the first epoch while fine-tuning the attentive probe on my custom dataset. I am running the task on a single-GPU machine, and the RAM usage exceeds 30 GB. Is there a potential solution to this problem?
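Since the original report also ends without a traceback and with a leaked-semaphore warning, one possibility worth ruling out is that the process is killed for exhausting host RAM; this cannot be confirmed from the logs alone. A minimal way to watch the main process's peak memory per epoch, using only the standard library on Linux or macOS (`peak_rss_mib` and `log_memory` are hypothetical helpers):

```python
import resource
import sys


def peak_rss_mib():
    """Peak resident set size of the current process so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in kilobytes on Linux but in bytes on macOS
    return peak / (1024 * 1024) if sys.platform == "darwin" else peak / 1024


def log_memory(tag):
    print(f"[{tag}] main-process peak RSS: {peak_rss_mib():.0f} MiB")


# Hypothetical placement: call at the start and end of each epoch in the eval loop
for epoch in range(2):
    log_memory(f"epoch {epoch} start")
    # ... training and validation for this epoch ...
    log_memory(f"epoch {epoch} end")
```

`RUSAGE_SELF` does not include DataLoader worker processes, which are separate processes and can hold substantial memory of their own; to see total usage, watch the whole machine (for example with `free -h`) while the job runs.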