callsys / GenPromp

[ICCV 2023] Generative Prompt Model for Weakly Supervised Object Localization
Apache License 2.0

When running `train_unet`, why aren't the pretrained token weights used? #4

Closed seilk closed 1 year ago

seilk commented 1 year ago

First and foremost, I'd like to express my profound gratitude for the outstanding paper and the code implementation. I have one point of curiosity.

When running train_unet, isn't it the case that the weights of each category token pretrained in train_token are not used?

In the code, train_unet is executed with split="train". Given this, just as when running train_token, wouldn't the concept_token in the text_encoder be initialized to the same initial weights as the meta_token?

Since all parameters of the text_encoder are frozen during train_unet, wouldn't this mean that the unet is fine-tuned while the meta_token and concept_token still share the same initial weights?
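
To make my reading of the code concrete, here is a minimal sketch of what I believe happens; the token ids, embedding dimension, and file name are made up for illustration and are not taken from the repo:

    import torch
    import torch.nn as nn

    vocab_size, dim = 49410, 768          # CLIP-like vocab extended with two new tokens (hypothetical)
    meta_id, concept_id = 49408, 49409    # ids of meta_token and concept_token (hypothetical)

    token_embedding = nn.Embedding(vocab_size, dim)  # stand-in for the text encoder's token table

    # train_token: start the concept token from the meta token, optimize it, then save it.
    with torch.no_grad():
        token_embedding.weight[concept_id] = token_embedding.weight[meta_id]
    torch.save({concept_id: token_embedding.weight[concept_id].detach().clone()}, "tokens.pt")

    # train_unet (the behavior I am asking about): the text encoder is frozen, so if the
    # saved embeddings are NOT reloaded, concept_token == meta_token throughout UNet training.
    token_embedding.weight.requires_grad_(False)

    # What I would expect instead: restore the pretrained concept embedding first.
    saved = torch.load("tokens.pt")
    with torch.no_grad():
        for tid, emb in saved.items():
            token_embedding.weight[tid] = emb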

Loss formula (5) in the paper, shown in the linked image, appears to use f* (the pretrained initial weights), hence my question. [image: Eq. (5) from the paper]

Thank you always for your hard work.

callsys commented 1 year ago

Thanks for your correction; we introduced this bug when reimplementing the code. The initial weights of both the meta_token and the concept_token are needed during training. To fix this bug, we change line 199 in base.py from

if self.test_mode and self.load_token_path is not None:
    text_encoder = self.load_embeddings(text_encoder)

into

if self.test_mode or self.load_token_path is not None:
    text_encoder = self.load_embeddings(text_encoder)

so that self.load_embeddings is also called during training whenever load_token_path is not None.
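
For clarity, a small stand-alone illustration of how the two conditions behave during train_unet (test_mode=False with a valid load_token_path); this is just a sketch of the boolean logic, not code from the repo:

    # Stand-alone illustration of the condition change; not code from the repo.
    def should_load(test_mode, load_token_path, fixed):
        if fixed:
            return test_mode or load_token_path is not None   # corrected condition
        return test_mode and load_token_path is not None      # original (buggy) condition

    # train_unet: test_mode=False, load_token_path points to the saved token checkpoint
    print(should_load(False, "ckpt/tokens", fixed=False))  # False -> pretrained tokens skipped (bug)
    print(should_load(False, "ckpt/tokens", fixed=True))   # True  -> pretrained tokens loaded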