openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
MIT License

CLIP Training Code #83

Open vinson2233 opened 3 years ago

vinson2233 commented 3 years ago

Not really an issue, I just want to share my training code, since some people still have difficulty writing it. Just modify the code to suit your usage. Feel free to ask questions or point out any mistakes in my code.

# Latest Update : 18 July 2022, 09:55 GMT+7

# TO ADD :
# Gradient Checkpointing
# Filter out bias from weight decay
# Decaying learning rate with cosine schedule
# Half-precision Adam statistics
# Half-precision stochastically rounded text encoder weights were used

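import clip
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

BATCH_SIZE = 32  # must be larger than 1 (hypothetical value; tune for your GPU memory)
EPOCH = 10  # hypothetical number of epochs

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # If using GPU then use mixed precision training.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)  # Must set jit=False for training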

class image_title_dataset(Dataset):
    def __init__(self, list_image_path,list_txt):

        self.image_path = list_image_path
        self.title = clip.tokenize(list_txt)  # You can tokenize everything at once here (slow at the beginning), or tokenize in the training loop.

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        image = preprocess(Image.open(self.image_path[idx])) # Image from PIL module
        title = self.title[idx]
        return image,title

# use your own data
list_image_path = ['folder/image1.jpg','folder2/image2.jpg'] 
list_txt = ['description for image1.jpg' , 'description for image2.jpg']
dataset = image_title_dataset(list_image_path,list_txt)
train_dataloader = DataLoader(dataset,batch_size = BATCH_SIZE) #Define your own dataloader

#https://github.com/openai/CLIP/issues/57
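def convert_models_to_fp32(model):
    # Cast parameters (and their gradients) back to fp32 so the optimizer step runs in full precision.
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:  # frozen or unused parameters have no gradient
            p.grad.data = p.grad.data.float()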

if device == "cpu":
    model.float()
else:
    clip.model.convert_weights(model)  # Actually this line is unnecessary since CLIP is already in float16 by default

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)  # Params from the paper; the lr is smaller, which is safer for fine-tuning on a new dataset

# Add your own code to track the training progress.
for epoch in range(EPOCH):
    for batch in train_dataloader:
        optimizer.zero_grad()

        images, texts = batch

        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)

        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()
        if device == "cpu":
            optimizer.step()
        else:
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

Code to save the model :

torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss,
        }, f"model_checkpoint/model_10.pt") #just change to your preferred folder/filename

Code to load the saved model :

model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training
checkpoint = torch.load("model_checkpoint/model_10.pt")

# Use these 3 lines if you use the default model settings (not training settings) of CLIP. For example, if you set context_length to 100 during training because your strings are very long, assign 100 to checkpoint['model_state_dict']["context_length"]
checkpoint['model_state_dict']["input_resolution"] = model.input_resolution #default is 224
checkpoint['model_state_dict']["context_length"] = model.context_length # default is 77
checkpoint['model_state_dict']["vocab_size"] = model.vocab_size 

model.load_state_dict(checkpoint['model_state_dict'])

Alternative training code :

lonngxiang commented 3 years ago

What is it that doesn't work? Does it raise any error? I set jit=False when loading the model with clip.load. Again, it didn't work out well.

lonngxiang commented 3 years ago

There were no errors, but the results were bad.

wilderrodrigues commented 3 years ago

Hi @vinson2233 ,

First of all, thanks for the contribution.

I'm implementing a package using part of the code you published here. My add-ons are mostly about creating the custom dataset, adding unit tests, dockerising the whole thing, and offering a service so other people can quickly get inference running for their own data, plus some plots.

However, when trying to train on my RTX 2080, I'm getting this:

Traceback (most recent call last):
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/wilderrodrigues/clip-mania/clip_mania/application/train.py", line 37, in <module>
    app.run(main)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/wilderrodrigues/clip-mania/clip_mania/application/train.py", line 28, in main
    model, preprocess = executor.train(dataset_path, epochs=epochs)
  File "/home/wilderrodrigues/clip-mania/clip_mania/core/executor.py", line 67, in train
    logits_per_image, logits_per_text = clip_model(images, prompts)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/clip/model.py", line 355, in forward
    image_features = self.encode_image(image)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/clip/model.py", line 337, in encode_image
    return self.visual(image.type(self.dtype))
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/clip/model.py", line 220, in forward
    x = self.conv1(x)  # shape = [*, width, grid, grid]
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/home/wilderrodrigues/.conda/envs/clip-mania/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 419, in _conv_forward
    return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _thnn_conv2d_forward

This is very weird! The code is basically the same. I was using the latest PyTorch, but just now decided to downgrade it and use the same version as the CLIP repo. It still fails. I also tried this:

clip_model, preprocess = clip.load(name=model_name, device=device, jit=False)
clip_model = clip_model.cuda()

No effect at all.

And yes, the GPU is available. I can see it when I try this:

>>> import torch
>>> torch.cuda.is_available()
True

Thanks in advance!

wilderrodrigues commented 3 years ago

Just another point:

I checked the parameters of both the clip.visual.transformer and clip.transformer blocks, and all the tensors are already on the right device.

...
-7.6884e-01,  2.5686e-01, -6.0447e-01], device='cuda:0',
       requires_grad=True)
...

I also tried moving the Transformer's conv1 to the CUDA device explicitly, but that hasn't helped yet.

Forward pass in the model.py of clip is not happy:

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]

Will keep digging.

wilderrodrigues commented 3 years ago

I found the problem, finally! :) For example, this:

      images= torch.stack([preprocess(Image.fromarray(img)) for img in list_image],dim=0)
      texts = clip.tokenize(list_txt)

Should be like this:

      images= torch.stack([preprocess(Image.fromarray(img)) for img in list_image],dim=0).cuda()
      texts = clip.tokenize(list_txt).cuda()

Or, of course, if one prefers, one can have .to(device) instead.

wilderrodrigues commented 3 years ago

Just another quick note: the way ground_truth is taken as a range over BATCH_SIZE is not right. If one has 10 classes and uses a batch size of 2, it will always give classes (ground_truth) of 0 and 1. So it biases the training toward saying that all the images are 0s and 1s.

I changed it, and I'm now mapping my prompts (the sentences) to classes in a dictionary, with the prompt as the key and an integer as the class. It works as expected with BATCH_SIZE=1 on both my MacBook and my GPU. I took 3 novel pictures of airplanes (not in the original dataset) and ran the training for 2 epochs, with this result:

test_executor.py::TestModelExecutor::test_instance 
test_executor.py::TestModelExecutor::test_train 

============================== 2 passed in 19.44s ==============================

Process finished with exit code 0
PASSED                [ 50%]PASSED                   [100%]
Expected 'an airplane' and  got 'an airplane'
Probability for the expected prompt was '0.9883'
Highest probability was '0.9883'
100%|██████████| 2/2 [00:13<00:00,  6.93s/it]

Before my changes were applied, the loss was always the same and the classifier was not working.

DRSY commented 3 years ago

Has anyone tried fine-tuning CLIP on the MS COCO image-text retrieval task? How does its performance compare with other state-of-the-art models?

lonngxiang commented 3 years ago

Everything gets worse when I fine-tune CLIP.

DRSY commented 3 years ago

From my experiment, zero-shot image retrieval performance on the MS COCO 5k test set is R@1 25.4, R@5 48.7 and R@10 59.9. After fine-tuning, it improves to R@1 33.6, R@5 62.2 and R@10 73.8, which still lags behind SOTA non-transformer-based models (e.g., VSRN).
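Retrieval numbers like these are Recall@K: the fraction of text queries whose matching image appears among the K highest-scoring images. A minimal NumPy sketch, assuming you already have a caption-by-image similarity matrix (for example from normalized encode_text / encode_image features) and the index of the correct image for each caption; `recall_at_k` is a hypothetical helper, not part of the clip package.

```python
import numpy as np


def recall_at_k(similarity, gt_index, ks=(1, 5, 10)):
    """Recall@K for retrieval.

    similarity: (num_queries, num_candidates) scores, e.g. caption x image cosine similarity.
    gt_index:   (num_queries,) index of the correct candidate for each query.
    Returns {k: fraction of queries whose correct candidate ranks in the top k}.
    """
    similarity = np.asarray(similarity, dtype=float)
    gt_index = np.asarray(gt_index)
    gt_scores = similarity[np.arange(len(gt_index)), gt_index]
    # 0-based rank of the correct candidate = how many candidates score strictly higher
    # (ties are resolved in favour of the correct candidate).
    ranks = (similarity > gt_scores[:, None]).sum(axis=1)
    return {k: float((ranks < k).mean()) for k in ks}
```

On the COCO 5k split each image typically has five captions, so for text-to-image retrieval `gt_index` would repeat each image index once per caption.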

lonngxiang commented 3 years ago

Did you use the code from this issue (https://github.com/openai/CLIP/issues/83) to fine-tune? Did you change anything?

vinson2233 commented 3 years ago

@wilderrodrigues I forgot to include to(device) in my code, thanks for the catch. Regarding the ground truth: this ground_truth is designed for image-title embedding, using the concept of n-pair loss, not for image-title classification. I have mentioned this somewhere in this long thread, in a comment about modifying CLIP for a classification task (https://github.com/openai/CLIP/issues/83#issuecomment-826262221).
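To make the distinction concrete: with this n-pair / InfoNCE-style objective, `ground_truth = torch.arange(len(images), dtype=torch.long, device=device)` does not encode class ids. It says that row i of `logits_per_image` should peak at column i, because the i-th image and the i-th text in the batch are the matching pair. The NumPy sketch below mirrors the symmetric loss from the training loop (`cross_entropy` reimplements `nn.CrossEntropyLoss` with its default mean reduction). One consequence worth knowing when the captions are class prompts like 'an airplane': if two items in the same batch share an identical caption, the loss still treats the off-diagonal copy as a negative.

```python
import numpy as np


def cross_entropy(logits, targets):
    """Mean softmax cross-entropy over rows, like nn.CrossEntropyLoss(reduction="mean")."""
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


def clip_contrastive_loss(logits_per_image):
    """Symmetric image->text and text->image loss for a batch of matched pairs."""
    n = logits_per_image.shape[0]
    ground_truth = np.arange(n)  # pair i is (image i, text i): a position in the batch, not a class id
    loss_img = cross_entropy(logits_per_image, ground_truth)
    loss_txt = cross_entropy(logits_per_image.T, ground_truth)  # logits_per_text is the transpose
    return (loss_img + loss_txt) / 2
```

With a confident, correct diagonal the loss approaches 0; with uninformative (all-equal) logits it equals log(batch size).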

DRSY commented 3 years ago

I used the PyTorch Lightning code, but changed the learning rate for ViT-B/32 from 5e-4 to 5e-5. With a 5e-4 learning rate, the model degraded to near-zero accuracy after the first several steps. With 5e-5, the model improved steadily to the results I posted.
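Given how sensitive fine-tuning is to the learning rate here, the "Decaying learning rate with cosine schedule" item from the first post's TO ADD list may also be worth trying. A minimal sketch of linear warmup followed by cosine decay; the step counts and `min_lr` are hypothetical defaults, not values from the paper or this thread.

```python
import math


def lr_at_step(step, base_lr=5e-5, warmup_steps=500, total_steps=10_000, min_lr=0.0):
    """Learning rate for a given optimizer step: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Ramp up linearly, intended to keep the first updates small for the pretrained weights.
        return base_lr * (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

To apply it, set `group["lr"] = lr_at_step(step)` for each `group` in `optimizer.param_groups` before every `optimizer.step()`, or wrap it as `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_at_step(s) / 5e-5)` and call `scheduler.step()` once per batch.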

yangydeng commented 3 years ago

Not really an issue, I just want to share my training code, since some people still have difficulty writing it. Just modify the code to suit your usage. Feel free to ask questions or point out any mistakes in my code.

# Latest Update : 04 June 2021, 09:59 GMT+7

# TO ADD :
# Gradient Checkpointing
# Filter out bias from weight decay
# Decaying learning rate with cosine schedule
# Half-precision Adam statistics
# Half-precision stochastically rounded text encoder weights were used

# BATCH_SIZE must be larger than 1
train_dataloader = DataLoader(..., batch_size=BATCH_SIZE)  # Define your own dataloader

# https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:  # frozen or unused parameters have no gradient
            p.grad.data = p.grad.data.float()

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # If using GPU then use mixed precision training.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)  # Must set jit=False for training
if device == "cpu":
    model.float()
else:
    clip.model.convert_weights(model)  # Actually this line is unnecessary since CLIP is already in float16 by default

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)  # Params from the paper; the lr is smaller, which is safer for fine-tuning on a new dataset

for epoch in range(EPOCH):
    for batch in train_dataloader:
        optimizer.zero_grad()

        list_image, list_txt = batch  # list_image is a list of images as numpy arrays (np.uint8), or a list of PIL images

        # Omit Image.fromarray if the images are already PIL images; change this line to images = list_image if preprocess is applied inside the dataset class
        images = torch.stack([preprocess(Image.fromarray(img)) for img in list_image], dim=0).to(device)
        texts = clip.tokenize(list_txt).to(device)

        logits_per_image, logits_per_text = model(images, texts)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)  # targets must be Long on every device

        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()
        if device == "cpu":
            optimizer.step()
        else:
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)
NOTE:

  • For inference, the conversion step from fp16 to fp32 is not needed; just use the model in full fp16.

  • For multi-GPU training, see my comment on #111 (comment)

Notes: @Zasder3 has created a PyTorch Lightning version to train CLIP: https://github.com/Zasder3/train-CLIP

Hi, thanks for sharing. I was wondering what the difference is between your approach (casting to fp16 and back to fp32) and automatic mixed precision (e.g. torch.cuda.amp). In my code I adopt torch.cuda.amp, but NaNs sometimes occur during training.

vinson2233 commented 3 years ago

@yangydeng good question. Your question was discussed in https://github.com/openai/CLIP/issues/57#issuecomment-794864178. The authors manually specify which operations are done in FP16 and which in FP32; this can be seen in this part of the code: https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L371-L392

That's why torch.cuda.amp gives different results: amp converts everything (I think), while the authors choose which operations are done in FP16.

yangydeng commented 3 years ago

Thanks for your reply, I'll try that later. I also noticed that you adopt quite a large weight decay (0.2) as the default. Have you tried any other value?

vinson2233 commented 3 years ago

@yangydeng That's the value used by the authors (see the paper, Appendix F, Table 18 on page 48). I haven't tried a weight decay lower than that, but I have tried AdamW with weight decay 0.1 and it gave a worse result.
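Related to the "Filter out bias from weight decay" item in the first post's TO ADD list: a common way to do that is to give the optimizer two parameter groups. A sketch, assuming the heuristic that every parameter with fewer than 2 dimensions (biases, LayerNorm weights, and in CLIP the 1-D `class_embedding` and scalar `logit_scale`) is excluded from decay; `param_groups_for_weight_decay` is a hypothetical helper, and this heuristic is one convention, not necessarily the CLIP authors' exact rule. `torch.optim.AdamW` is used because it applies decoupled weight decay.

```python
import torch
from torch import nn


def param_groups_for_weight_decay(model: nn.Module, weight_decay: float = 0.2):
    """Split trainable parameters into a decayed group and a non-decayed group."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are left out of the optimizer entirely
        # Heuristic: weight matrices/kernels are decayed; 0-D and 1-D tensors
        # (biases, LayerNorm weights, vector embeddings, logit_scale) are not.
        if param.ndim < 2:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Usage sketch on a toy model; with CLIP, pass the CLIP model instead.
toy = nn.Sequential(nn.Linear(4, 3), nn.LayerNorm(3))
optimizer = torch.optim.AdamW(param_groups_for_weight_decay(toy), lr=5e-5, betas=(0.9, 0.98), eps=1e-6)
```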

yangydeng commented 3 years ago

Cool, many details can be found on this table.

ThakurRajAnand commented 3 years ago

@vinson2233 Thanks for sharing the code. I am getting the following error while using it. Have you seen this during your training?

RuntimeError: Expected object of scalar type Long but got scalar type Half for argument #2 'target' in call to _thnn_nll_loss_forward

vinson2233 commented 3 years ago

@ThakurRajAnand Is the error raised in the loss_img(logits_per_image,ground_truth) part? Are you working on CPU or GPU? The error you are showing indicates that the target is Half (so your device is not "cpu"), but the logits are Long type, which indicates that your device is indeed CPU. It might help to print the data types at each step of your training.

uahsan3 commented 3 years ago

Hi, I am facing the same issue. I verified that my device = 'cuda:0', so I am not sure where I am going wrong. Can you please point out why it's giving this error:

RuntimeError: Expected object of scalar type Long but got scalar type Half for argument #2 'target' in call to _thnn_nll_loss_forward

vinson2233 commented 3 years ago

@ThakurRajAnand @uahsan3 Thanks for the catch. The mistake is in my code: the ground truth should always be Long, regardless of the device. I've already edited the code. This is the new definition of the ground truth:

ground_truth = torch.arange(BATCH_SIZE,dtype=torch.long,device=device)

ThakurRajAnand commented 3 years ago

@vinson2233 Thanks. I got busy and didn't get a chance to reply that I was able to fix it the same day by making the same change in your code.

I have another question. Can I use the model I trained on my own data to predict text? I tried, but it asked for both an image and a text.

vinson2233 commented 3 years ago

@ThakurRajAnand the forward method defined on the CLIP model is basically a composition of model.encode_text(text), model.encode_image(img), and a cosine-similarity calculation between the 2 embeddings. If you want to use just the text, encode_text is the method you are looking for. See this part of the code for details: https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L354-L356
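Put differently, once you have the two sets of features, the rest of `forward` is just normalization, a matrix product, and scaling. A NumPy sketch of that last part, assuming the features were already computed (in practice with `model.encode_image(images)` and `model.encode_text(texts)`, under `torch.no_grad()` for inference) and that `logit_scale` is the already-exponentiated value, i.e. `model.logit_scale.exp()`, which is about 100 for the released checkpoints.

```python
import numpy as np


def clip_logits(image_features, text_features, logit_scale=100.0):
    """Sketch of what CLIP's forward does with the encoder outputs."""
    # L2-normalize each row so the dot product below is a cosine similarity
    image_features = image_features / np.linalg.norm(image_features, axis=1, keepdims=True)
    text_features = text_features / np.linalg.norm(text_features, axis=1, keepdims=True)
    logits_per_image = logit_scale * image_features @ text_features.T  # (num_images, num_texts)
    logits_per_text = logits_per_image.T  # (num_texts, num_images)
    return logits_per_image, logits_per_text
```

For text-only use, you can stop after `encode_text` and keep the normalized text features, for example to compare them against image features later.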

ThakurRajAnand commented 3 years ago

@vinson2233 I think I didn't explain clearly. I am interested in predicting text given only a new image. Is that possible? The texts I trained on are long, e.g. "this is a geometry question and we have circles in it".

vinson2233 commented 3 years ago

@ThakurRajAnand Ohh I see. Regarding text generation, there are 2 types of techniques. The first is to have a set of predefined texts and fetch the most suitable one (text-retrieval style). The second is to generate new text from the image (generative). I think CLIP cannot do the latter, but you can use CLIP in the text-retrieval style. I think what you are looking for is something like https://github.com/pzzhang/VinVL or https://github.com/microsoft/Oscar
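For the text-retrieval style, a sketch of ranking a fixed list of candidate texts for one image. It assumes you have already encoded the candidates once with `model.encode_text` and the new image with `model.encode_image`, and converted both to NumPy arrays; `retrieve_texts` is a hypothetical helper.

```python
import numpy as np


def retrieve_texts(image_feature, text_features, candidate_texts, top_k=3):
    """Return the top_k (text, cosine similarity) pairs for a single image feature vector."""
    image_feature = image_feature / np.linalg.norm(image_feature)
    text_features = text_features / np.linalg.norm(text_features, axis=1, keepdims=True)
    scores = text_features @ image_feature  # one cosine similarity per candidate text
    best = np.argsort(-scores)[:top_k]  # indices sorted from most to least similar
    return [(candidate_texts[i], float(scores[i])) for i in best]
```

Because the candidates are fixed in advance, this can only ever return texts you already have.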

lonngxiang commented 3 years ago

If I only want to save the encode_text part of the model, does anyone know how to do it?

vinson2233 commented 3 years ago

@lonngxiang someone has posted this question in this thread: https://github.com/openai/CLIP/issues/113. I also want to know the answer.

lonngxiang commented 3 years ago

Thanks.

lr19960813 commented 3 years ago

@vinson2233 Hey vinson, Thank you for sharing, it was a great help. But I encountered a problem during training.

Traceback (most recent call last):
  File "/CLIP-main/fine_tune.py", line 88
    logits_per_image, logits_per_text = model(images, texts)
  File "/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/CLIP-main/clip/model.py", line 357, in forward
    image_features = self.encode_image(image)
  File "/CLIP-main/clip/model.py", line 339, in encode_image
    return self.visual(image.type(self.dtype))
  File "/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/CLIP-main/clip/model.py", line 226, in forward
    x = x + self.positional_embedding.to(x.dtype)
RuntimeError: The size of tensor a (1711) must match the size of tensor b (50) at non-singleton dimension 1

I tested further and finally found that the error occurs in model.py, line 226:

def forward(self, x: torch.Tensor):
    x = self.conv1(x)  # shape = [*, width, grid, grid]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + self.positional_embedding.to(x.dtype)  # The size of tensor x is 1711 but the other is 50.
    x = self.ln_pre(x)

I’m not sure what went wrong. I used pandas to load my CSV dataset, and some texts are very long. Maybe the texts are too long to be preprocessed, but I am not sure.

vinson2233 commented 3 years ago

@lr19960813 First, your error occurred during the image_features = self.encode_image(image) step, so it is caused by the image. Have you made sure to apply preprocess to your images before feeding them to the network? You can put it inside the dataloader when fetching the images.
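A quick way to see why the image is the culprit: the ViT-B/32 encoder turns an image into one token per 32×32 patch plus a class token, and its positional embedding was built for 224×224 inputs, i.e. 7 × 7 + 1 = 50 positions. A sketch of that arithmetic (`vit_sequence_length` is a hypothetical helper). An image that skipped `preprocess` (which resizes and center-crops to 224×224) produces a different length; for example, a 1216×1440 input would give 38 × 45 + 1 = 1711, one size consistent with the traceback above.

```python
def vit_sequence_length(height, width, patch_size=32):
    """Tokens seen by a ViT image encoder: one per non-overlapping patch, plus the class token."""
    # conv1 uses kernel_size == stride == patch_size, so each spatial dim yields dim // patch_size patches
    return (height // patch_size) * (width // patch_size) + 1
```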

Second, regarding the long text: yes, CLIP's text encoder only accepts a token length of 77 (1 token represents roughly 1-3 characters). It's better to trim your text first.
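For the text side, a pure-Python sketch of what fitting text into the 77-token context involves, assuming the input is already a list of BPE token ids: wrap the ids in start/end-of-text tokens, truncate, and zero-pad. The two special-token ids are assumptions about CLIP's vocabulary, and `pad_or_truncate` is hypothetical; in practice, recent versions of `clip.tokenize` accept `truncate=True`, which performs this truncation for you.

```python
SOT_TOKEN = 49406  # assumed id of <|startoftext|>
EOT_TOKEN = 49407  # assumed id of <|endoftext|>


def pad_or_truncate(token_ids, context_length=77, truncate=True):
    """Wrap BPE ids with start/end tokens and fit them into exactly context_length slots."""
    tokens = [SOT_TOKEN] + list(token_ids) + [EOT_TOKEN]
    if len(tokens) > context_length:
        if not truncate:
            raise ValueError(f"Input is too long for context length {context_length}")
        tokens = tokens[:context_length]
        # Keep an end-of-text token in the last slot: encode_text reads the text feature
        # at the position of the largest token id, which is meant to be EOT.
        tokens[-1] = EOT_TOKEN
    return tokens + [0] * (context_length - len(tokens))  # zero-pad to a fixed length
```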

lr19960813 commented 3 years ago

Thank you for your reply! It is very helpful! I have finished that experiment. By the way, how can I train a CLIP model on my own dataset from scratch? It seems that this code just fine-tunes the CLIP model starting from its pretrained parameters.

vinson2233 commented 3 years ago

@lr19960813 Then you need to create the model from scratch without loading the trained weights. Here's the code; you can also set context_length to a greater number if you have really long texts.

from clip.model import CLIP
model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

If you are confused about how to fill in the values (for example, embed_dim), you can refer to an existing model by inspecting its state_dict.

model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Best model use ViT-B/32
model_state_dict = model.state_dict()
embed_dim = model_state_dict['text_projection'].shape[1]

To see the other params, you can refer to this code: https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L395-L424
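The embed_dim shape-reading extends to all ten constructor arguments. The sketch below is modeled on the way `build_model` in the linked clip/model.py reads a ViT checkpoint, but it works on a plain mapping from state_dict key to shape, so it can be tested without downloading weights. `infer_vit_clip_config` is hypothetical, and ResNet-based checkpoints use different keys that it does not handle.

```python
def infer_vit_clip_config(shapes):
    """Infer CLIP constructor arguments from {state_dict key: shape tuple} of a ViT checkpoint."""
    vision_width = shapes["visual.conv1.weight"][0]
    vision_patch_size = shapes["visual.conv1.weight"][-1]
    vision_layers = sum(
        1 for k in shapes if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
    )
    # The positional embedding has grid_size**2 patch positions plus one class-token position.
    grid_size = round((shapes["visual.positional_embedding"][0] - 1) ** 0.5)
    transformer_width = shapes["ln_final.weight"][0]
    return {
        "embed_dim": shapes["text_projection"][1],
        "image_resolution": vision_patch_size * grid_size,
        "vision_layers": vision_layers,
        "vision_width": vision_width,
        "vision_patch_size": vision_patch_size,
        "context_length": shapes["positional_embedding"][0],
        "vocab_size": shapes["token_embedding.weight"][0],
        "transformer_width": transformer_width,
        "transformer_heads": transformer_width // 64,  # assumes 64-dim attention heads
        "transformer_layers": len(
            {k.split(".")[2] for k in shapes if k.startswith("transformer.resblocks.")}
        ),
    }
```

Usage sketch: `config = infer_vit_clip_config({k: tuple(v.shape) for k, v in model.state_dict().items()})`, optionally override `config["context_length"]` for longer texts, then `CLIP(**config)`, assuming the constructor's parameter names match those in the snippet above.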

gunwooYong commented 3 years ago

Hello, this is Yong, a researcher from South Korea.

I appreciate your work; it has been very helpful to me as I have been trying to fine-tune. From what I have studied, fine-tuning CLIP on a small amount of data carries a potential risk of overfitting.

So I wonder how much data is needed, at minimum, for fine-tuning. In my case, I have 4 classes, each with at most 10 pictures. In such a scenario, can I fine-tune, or should I just use the original CLIP?

I am looking forward to your answer and I appreciate your work again. Thank you.

lr19960813 commented 3 years ago

Dear Yong

This is Li Rui, a doctor from Japan. I think you are right; I have the same problem. I tried another fine-tuning task with 80,000+ samples and it would not converge. I am trying a smaller learning rate and a bigger batch size. It performs better than before but still not very well. I think you can try it.

Li Rui

gunwooYong commented 3 years ago

Dear Rui,

I really appreciate your help. As described in CLIP-Art (https://github.com/KeremTurgutlu/clip_art), they used 140,000 images for fine-tuning. So I think you may be able to solve your problem through data augmentation.

However, I cannot acquire more data, as it relates to a specific domain. For this reason, I am trying to fine-tune with a tiny number of images.

I hope you overcome this issue.

Best regards,

vinson2233 commented 3 years ago

@gunwooYong I might not be the person you are looking for, since I'm not the author and have no affiliation with OpenAI. I think this question is better discussed in a new GitHub issue, so the real authors have a better chance of seeing it.

But regardless, I did train CLIP as an image classification model on 19 categories with around 100k images, modified so that the loss uses only the first logits (dim: num_images x num_classes), since each image has only 1 category but 1 category can be shared by multiple images in the batch. The performance is still way behind a CNN classification model.
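A sketch of that classification variant as described: each image is scored against one text embedding per class (for example, the encoded class prompts), and only the image-side logits (num_images × num_classes) go into the cross-entropy, with real class ids as targets instead of `arange`. NumPy stands in for torch here; the features, `classification_loss`, and the `logit_scale` default of 100 are assumptions for illustration.

```python
import numpy as np


def classification_loss(image_features, class_text_features, labels, logit_scale=100.0):
    """Cross-entropy of images against per-class text embeddings; labels are class ids."""
    image_features = image_features / np.linalg.norm(image_features, axis=1, keepdims=True)
    class_text_features = class_text_features / np.linalg.norm(class_text_features, axis=1, keepdims=True)
    logits = logit_scale * image_features @ class_text_features.T  # (num_images, num_classes)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    labels = np.asarray(labels)
    # Several images in a batch may share a label; there is no text-side loss term here.
    return float(-log_probs[np.arange(len(labels)), labels].mean())
```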

mitchellnw commented 3 years ago

In case it is helpful, posting a link to our implementation of CLIP-training code https://github.com/mlfoundations/open_clip

vinson2233 commented 3 years ago

@mitchellnw Nice, thanks for sharing, I'll put it on https://github.com/openai/CLIP/issues/83#issue-853114174 so it is easier for people to reach your work.

sarahESL commented 3 years ago

Thanks for providing this, @vinson2233. I was wondering why the loss does not include the cosine similarity of the encoded image and text, as in Figure 3 of the original CLIP paper. Or is it included and I am missing something?

vinson2233 commented 3 years ago

@sarahESL Actually, the code already includes that, but it does not show it explicitly.

logits_per_image, logits_per_text = model(images, texts)

This line produces the scaled cosine similarities between the encoded images and texts, not the embeddings. I'm using the term 'logits' just to follow the paper and keep it consistent with the example in the README (actually, I don't like the word logits).

So if we have 10 texts and 5 images, logits_per_image has dimension 5 x 10 and logits_per_text is 10 x 5, where each entry is the cosine similarity multiplied by the learned logit scale (the inverse of the temperature). You can divide by 100, or by the logit scale value itself, to get the usual cosine similarity.
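A small NumPy illustration of those shapes, with made-up cosine similarities for 5 images and 10 texts (`logit_scale = 100.0` stands in for `model.logit_scale.exp()`). It also shows the softmax over texts that turns each image's row into probabilities, which is what the README example does with `logits_per_image.softmax(dim=-1)`.

```python
import numpy as np

rng = np.random.default_rng(0)
logit_scale = 100.0  # stand-in for model.logit_scale.exp()
cosine = rng.uniform(-1.0, 1.0, size=(5, 10))  # hypothetical image-text cosine similarities

logits_per_image = logit_scale * cosine  # shape (5, 10): one row per image
logits_per_text = logits_per_image.T  # shape (10, 5): one row per text

# Dividing by the scale recovers the plain cosine similarity.
recovered = logits_per_image / logit_scale

# Softmax over the texts: each image's row becomes a probability distribution.
shifted = logits_per_image - logits_per_image.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
```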

DRSY commented 3 years ago

Hi, I just want to share my work on deploying CLIP on iOS for cross-modal text-to-image retrieval/search: https://github.com/DRSY/MTIS. I exported the CLIP image encoder and text encoder to TorchScript, which can then be loaded with libtorch, PyTorch's C++ frontend.
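Exporting the two encoders separately suits retrieval well: gallery image embeddings can be computed once offline, and only the text encoder has to run per query. Here is a minimal NumPy sketch of the ranking step, assuming the embeddings were already produced by the exported encoders; top_k_images is a hypothetical helper, not part of the linked project.

```python
import numpy as np

def top_k_images(text_embedding, image_embeddings, k=3):
    """Return indices of the k gallery images most similar to one text query."""
    q = text_embedding / np.linalg.norm(text_embedding)
    gallery = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    scores = gallery @ q                      # cosine similarity of every image to the query
    return np.argsort(-scores)[:k]            # highest similarity first

rng = np.random.default_rng(0)
gallery = rng.normal(size=(100, 512))                  # precomputed once, offline
query = gallery[42] + 0.1 * rng.normal(size=512)       # toy query that should match image 42
print(top_k_images(query, gallery))
```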

NingYuanxiang commented 3 years ago

Hi, I want to modify the structure of CLIP's model. How can I load specific layers and keep them frozen while training on a new dataset? I would be very grateful for an early reply.

vinson2233 commented 3 years ago

I'm not sure what you mean by "revise", but have you seen this thread? https://discuss.pytorch.org/t/how-to-replace-a-layer-or-module-in-a-pretrained-network/60068 Basically, you access a specific part of CLIP and replace it.
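A minimal sketch of the "access it and replace it" idea, using a hypothetical helper that swaps a submodule by its dotted name. It relies on the fact that child modules of a torch.nn.Module (including nn.Sequential entries named "0", "1", ...) are reachable with getattr and re-registered by setattr. A toy stand-in is used below so the sketch runs without PyTorch.

```python
from functools import reduce
from types import SimpleNamespace

def replace_submodule(model, dotted_name, new_module):
    """Swap the submodule at a dotted path (e.g. "transformer.resblocks.11") for new_module.

    Returns the module that was replaced.
    """
    parent_path, _, child_name = dotted_name.rpartition(".")
    # Walk down to the parent, one attribute at a time.
    parent = reduce(getattr, parent_path.split("."), model) if parent_path else model
    old_module = getattr(parent, child_name)
    setattr(parent, child_name, new_module)
    return old_module

# With CLIP (not run here; MyResidualBlock is a hypothetical replacement that must accept
# the same inputs and return the same shape as the block it replaces):
#   replace_submodule(model, "transformer.resblocks.11", MyResidualBlock())
toy = SimpleNamespace(encoder=SimpleNamespace(head="old head"))
previous = replace_submodule(toy, "encoder.head", "new head")
```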

By "constant", I assume you mean not updating that specific layer during training. This forum thread may be what you need: https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088. Basically, set requires_grad = False on the parameters whose names match your custom layer.
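A minimal sketch of freezing by parameter name, assuming a PyTorch-style model that exposes named_parameters(). freeze_by_prefix is a hypothetical helper; the parameter names in the toy stand-in follow CLIP's naming (image encoder under "visual.", text encoder under "transformer."), and the learning rate in the comment is only an example.

```python
from types import SimpleNamespace

def freeze_by_prefix(model, prefixes):
    """Set requires_grad = False on every parameter whose name starts with one of prefixes.

    Returns the names of the parameters that were frozen.
    """
    frozen = []
    for name, param in model.named_parameters():
        if name.startswith(tuple(prefixes)):
            param.requires_grad = False
            frozen.append(name)
    return frozen

# Typical use with CLIP (not run here): freeze the image encoder and give the
# optimizer only the parameters that are still trainable.
#   freeze_by_prefix(model, ["visual."])
#   optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=5e-5)

class ToyModel:
    # Minimal stand-in exposing named_parameters() like torch.nn.Module.
    def __init__(self):
        names = ["visual.conv1.weight", "transformer.resblocks.0.mlp.c_fc.weight", "logit_scale"]
        self.params = {n: SimpleNamespace(requires_grad=True) for n in names}

    def named_parameters(self):
        return iter(self.params.items())

toy = ToyModel()
frozen = freeze_by_prefix(toy, ["visual."])
```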

NingYuanxiang commented 3 years ago

Thanks for providing this. I'm a bit confused about how to verify that the revised network is an improvement. Which dataset is best for training and validation?

vinson2233 commented 3 years ago

@NingYuanxiang If I were in your position, I would create a new image-title dataset that CLIP has not been trained on (so don't use COCO). The dataset should be entirely new and never seen by CLIP before. Once you have such a dataset, fine-tune both the original CLIP and the revised CLIP on it, then compare their performance. Of course, this approach is purely empirical, which means you need extensive experiments to validate your modification of CLIP.
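One common way to compare the two fine-tuned models is image-to-text Recall@K on the held-out pairs. The sketch below assumes image i is paired with text i in the test set and that each model provides an image-by-text similarity matrix; the metric choice and the recall_at_k helper are suggestions, not something prescribed in this thread.

```python
import numpy as np

def recall_at_k(similarity, k=1):
    """Fraction of images whose paired text (same index) ranks in the top k.

    similarity: (N, N) array, row i = image i scored against all N texts.
    """
    n = similarity.shape[0]
    top_k = np.argsort(-similarity, axis=1)[:, :k]          # best k text indices per image
    hits = (top_k == np.arange(n)[:, None]).any(axis=1)     # did the true text make the cut?
    return hits.mean()

# Toy 3x3 example: image 1 ranks its own text last, the others rank it first.
toy_similarity = np.array([
    [0.9, 0.1, 0.2],
    [0.3, 0.2, 0.8],
    [0.1, 0.4, 0.7],
])
print(recall_at_k(toy_similarity, k=1), recall_at_k(toy_similarity, k=3))
```

Computing the same metric for the original and the revised model on the same held-out set gives a like-for-like comparison.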

4sunshine commented 3 years ago

Dear @vinson2233, am I correct that at least the train split of MS-COCO (COCO Captions) was used to train the checkpoints provided in this repo?

vinson2233 commented 3 years ago

@4sunshine From reading the paper, I assume the checkpoints in this repo were trained on every dataset mentioned in the paper, for example MS-COCO (Lin et al., 2014), Visual Genome (Krishna et al., 2017), and YFCC100M (Thomee et al., 2016). I don't know whether they used only the train split of COCO, or whether they used COCO only for evaluation. We would need to confirm this with the authors (and again, I'm not one of them).

4sunshine commented 3 years ago

Thank you @vinson2233 for the quick reply. I'm asking because of the evaluation setup used in https://arxiv.org/abs/2012.04329. That paper is about scene-text-aware text-to-image retrieval, and part of its test set comes from the MS-COCO train split.

minmummax commented 3 years ago

Is this demo code meant for fine-tuning on a small dataset?

liuhui0401 commented 3 years ago

What loss values do you usually see during training with your code? My loss drops from above 2 to just above 1. Is this normal?
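Whether "above 1" is good depends heavily on the batch size: when all logits are equal (the model cannot tell any pair apart), the symmetric cross-entropy for a batch of N pairs is exactly ln(N). The NumPy sketch below checks that chance-level value for a few batch sizes; it is a reference point for reading your numbers, not an expected final loss, and symmetric_clip_loss is a hypothetical helper.

```python
import math
import numpy as np

def symmetric_clip_loss(logits_per_image):
    # Average of image->text and text->image cross-entropy, targets on the diagonal.
    def ce(logits):
        shifted = logits - logits.max(axis=1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        return -np.mean(np.diag(log_probs))
    return (ce(logits_per_image) + ce(logits_per_image.T)) / 2

for n in (8, 32, 256):
    # Equal logits: every softmax row is uniform, i.e. chance level.
    chance = symmetric_clip_loss(np.zeros((n, n)))
    print(f"batch {n:>3}: chance-level loss {chance:.3f} (ln N = {math.log(n):.3f})")
```

So a loss that falls from about 2 to about 1 means something quite different with a batch of 8 than with a batch of 256.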