fholstege / JSE

BSD 3-Clause "New" or "Revised" License

The meaning of "fine-tune" referred to in the paper #1

Open LuoyaoChen opened 1 month ago

LuoyaoChen commented 1 month ago

Hi,

I am wondering about the meaning of "fine-tune" in the paper (page 41, Section I.2):

> For CelebA, this means using a learning rate of 10^-3, a weight decay of 10^-4, a batch size of 128, and for 50 epochs without early stopping. We use stochastic gradient descent (SGD) with a momentum parameter of 0.9. After this, PCA is applied to the embeddings of the layer of the architecture, to reduce the dimensionality from 2048 to 300.

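To make sure I am reading this correctly, here is roughly how I would translate that paragraph into code (the model and the embeddings below are just placeholders I made up, not from your repo):

```python
import numpy as np
import torch
import torchvision
from sklearn.decomposition import PCA

# Placeholder backbone; in practice this would be the pretrained model used in the paper.
model = torchvision.models.resnet50(weights=None)

# Hyper-parameters as quoted above (Section I.2, CelebA).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,            # learning rate 10^-3
    momentum=0.9,       # SGD momentum 0.9
    weight_decay=1e-4,  # weight decay 10^-4
)

# ... train for 50 epochs with batch size 128, no early stopping ...

# Afterwards, PCA reduces the 2048-d ResNet-50 embeddings to 300 dimensions.
embeddings_2048 = np.random.randn(1000, 2048)  # stand-in for the real embeddings
pca = PCA(n_components=300)
embeddings_300 = pca.fit_transform(embeddings_2048)
print(embeddings_300.shape)  # (1000, 300)
```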
Does this only apply to the code in generate_result_sim.py? I.e., does "fine-tune" refer only to the linear layer?

Specifically, when using the code create_embeddings.py (which corresponds to the results in Appendix Tables 1-5): even when we create the embeddings, should we fine-tune the image/language embeddings? If the answer is yes, could you share the fine-tuning arguments? Or, rather, when the paper says "fine-tune", does it mean fine-tuning only JSE's last linear layer, not the embeddings?

thank you!

Best,

LuoyaoChen commented 1 month ago

Also, to replicate the numbers in Table 2, should I run create_celebA_embeddings.py and then finetune_celebA.py, or the reverse?

Thank you!

fholstege commented 1 month ago

Hey Luoyao,

Thanks for taking the time to check out our code!

In the paper, when we say 'fine-tune', we mean that the entire neural network was fine-tuned for the task.
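
Schematically, that looks like the sketch below rather than a linear probe; this is just an illustration, not the exact code from the repo:

```python
import torch
import torchvision

model = torchvision.models.resnet50(weights=None)  # pretrained weights in practice

# Full fine-tuning: every parameter of the network receives gradients.
for param in model.parameters():
    param.requires_grad = True

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-3, momentum=0.9, weight_decay=1e-4,
)

# A linear probe, by contrast, would freeze the backbone and train only the last layer:
# for param in model.parameters():
#     param.requires_grad = False
# for param in model.fc.parameters():
#     param.requires_grad = True
```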

In order to replicate this for the Waterbirds task (the other datasets follow a similar structure), you can follow these steps:

  1. Run datasets/Waterbirds/finetune_Waterbirds.py. We use the learning rate, weight decay, etc. in line with the paper by Kirichenko et al. (https://arxiv.org/abs/2204.02937), as referred to in the appendix of our paper (you can also find the hyper-parameters there).
  2. Once you have fine-tuned the model, you can use it to create embeddings with datasets/Waterbirds/create_waterbird_embeddings.py.
  3. Once you have the embeddings, you can then run the different methods for comparison using the generate_result files.

Let me know if you have any further questions!

LuoyaoChen commented 3 weeks ago

Thank you very much for your timely and clear clarifications!

I do have one further question, though, which I would like to confirm regarding the first step you listed, i.e.

  1. Run datasets/Waterbirds/finetune_Waterbirds.py. We use the learning rate, weight decay, etc. in line with the paper by Kirichenko et al. (https://arxiv.org/abs/2204.02937), as referred to in the appendix of our paper (you can also find the hyper-parameters there).

I noticed in file datasets/Waterbirds/finetune_Waterbirds.py, line 97,

```python
resnet50_features = embedding_creator(resnet50)
class embedding_creator(nn.Module):
  """
  Create version of the model, specifically for creating embeddings
  """

  def __init__(self, original_resnet, prevent_finetuning = True):
    super(embedding_creator, self).__init__()

    # set require_grad = False to all the original resnet layers
    if prevent_finetuning:
      for param in original_resnet.parameters(): # this ensures these layers are not trained subsequently
        param.requires_grad = False

    # select everything but the last layer
    self.list_feature_modules = list(original_resnet.children())[:-1]

    # define the feature extractor 
    self.feature_extractor =  nn.Sequential(
                    # stop at conv4
                    *self.list_feature_modules
                )

    self.flatten = nn.Flatten(start_dim =1, end_dim= -1)

  def forward(self, x):

    # get embedding
    embedding = self.feature_extractor(x)
    embedding_flat = self.flatten(embedding)

    return embedding_flat
```

Why I am asking: I noticed that in line 120 of finetune_Waterbirds.py, we put both resnet50_features and W in the optimizer, but only W is actually optimized (since requires_grad is set to False for the original ResNet); however, when we torch.save, we only save resnet50_features, not W. On the other hand, I assume the purpose of fine-tuning is to update the parameters, so that when we later run create_embedding.py, the embeddings can be computed using the fine-tuned ResNet, the fine-tuned W, or both.
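
To make my expectation concrete, I had imagined the fine-tuning step doing something like this (the definitions below are just placeholders for illustration, not your actual code):

```python
import torch
import torch.nn as nn
import torchvision

# Placeholder stand-ins for the objects in finetune_Waterbirds.py (made up for illustration).
resnet50_features = torchvision.models.resnet50(weights=None)
resnet50_features.fc = nn.Identity()   # backbone producing 2048-d embeddings
W = nn.Linear(2048, 2)                 # linear layer on top (output size hypothetical)

# Both the backbone and W are optimized...
optimizer = torch.optim.SGD(
    list(resnet50_features.parameters()) + list(W.parameters()),
    lr=1e-3, momentum=0.9, weight_decay=1e-4,
)

# ... training loop updating both ...

# ... and both are saved, so that create_embedding.py can load the fine-tuned weights.
torch.save(resnet50_features.state_dict(), "resnet50_features_finetuned.pt")
torch.save(W.state_dict(), "W_finetuned.pt")
```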

My questions are:

  1. Why is the ResNet-50 backbone not updated, i.e. why is param.requires_grad = False set for the original ResNet? Is this intended? Shouldn't it be fine-tuned?
  2. Why do we not save W's state_dict() at the end of fine-tuning? I am assuming we should be using the updated W in create_embedding.py?
  3. Lastly, is the performance reported in Table 1 of the paper based on the fine-tuned ResNet or on W?

Thank you so much! I would love to hear whether I have interpreted your code and intentions correctly.

Best,