Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.
MIT License
653 stars, 78 forks

model checkpointing #10

Closed: sour4bh closed 3 years ago

sour4bh commented 3 years ago

Hey, thank you for the Lightning implementation, just what I needed at the moment! However, I'm a little confused about model checkpointing. I assumed checkpoints would be saved automatically to lightning_logs/checkpoints/, but after a full training run I found nothing in the checkpoints folder. I'm taking a deeper look at the repo, and at first glance I can see you didn't override that hook. I'm guessing the default checkpointing hook doesn't work because this is self-distillation (I'm using train_finetune.py, by the way). Let me know if this is not the expected behaviour.

Zasder3 commented 3 years ago

This is odd behavior. In my training runs, it saved weights at the end of every epoch into the directory lightning_logs/version_N/checkpoints. Could you share the command you used to start the training run and how long the run lasted?
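One quick way to check whether anything was written to the location described above is to scan for checkpoint files from the directory where training was launched. The following is only a minimal sketch: the helper name `find_checkpoints` is hypothetical, and it assumes the checkpoints use Lightning's default `.ckpt` extension and live under `lightning_logs/version_N/checkpoints` relative to the working directory.

```python
from pathlib import Path


def find_checkpoints(log_root="lightning_logs"):
    """Return all *.ckpt files under log_root/version_*/checkpoints, sorted by path.

    `find_checkpoints` is a hypothetical helper for illustration; it is not
    part of train-CLIP or PyTorch Lightning.
    """
    root = Path(log_root)
    if not root.is_dir():
        # No log directory at all: nothing was logged relative to this path.
        return []
    return sorted(root.glob("version_*/checkpoints/*.ckpt"))


if __name__ == "__main__":
    found = find_checkpoints()
    if not found:
        print("No checkpoints found under", Path("lightning_logs").resolve())
    for ckpt in found:
        # Print each checkpoint path with its size in megabytes.
        print(ckpt, f"{ckpt.stat().st_size / 1e6:.1f} MB")
```

If the list comes back empty, it may be worth confirming that the script was run from the same working directory as the training command, since the log directory is resolved relative to it.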

sour4bh commented 3 years ago

Yes, it was certainly odd behaviour, and I wanted to get your thoughts on it.

I used the following command to invoke train_finetune.py:

python train_finetune.py --folder dataset --batch_size 256 --gpu 1 --num_workers 4

Extra info: I'm running this on Google Colab. These are the commands I run after cloning your repo to set up my training environment:

!pip install ftfy regex
!pip install transformers
!pip install git+https://github.com/openai/CLIP.git

!pip install torch==1.8.1 pytorch-lightning

import pytorch_lightning as pl
print(pl.__version__) ## 1.3.5

!pip install torchtext==0.9.1

I chose the dependency versions above to get the pl library working in Colab!

Zasder3 commented 3 years ago

I followed your setup and was unable to replicate this bug. Does the issue still persist?

Slightly unrelated: I noticed that your fork uses a BERT-based model. I updated the library to support those types of models more naturally (it doesn't average word embeddings to get the sentence embedding).