kohjingyu / fromage

🧀 Code and models for the ICML 2023 paper "Grounding Language Models to Images for Multimodal Inputs and Outputs".
https://jykoh.com/fromage
Apache License 2.0

Question about the frozen language model #6

Closed sijeh closed 1 year ago

sijeh commented 1 year ago

https://github.com/kohjingyu/fromage/blob/964bd157ba3ac24824d525d7f90037eef42aa7c2/fromage/models.py#L61

Hello kohjingyu, thanks for your great work! I'm a little confused about the frozen LLM. It seems all of the parameters in the LLM are frozen. Shouldn't the input_embedding.weight at the [RET] position be learnable? I could not find code such as self.input_embedding.requires_grad_(True) or self.input_embedding.weight.requires_grad = True. On the other hand, I see the gradients of the input embeddings are adjusted via param.grad[mask, :] = 0. Please point out anything important I may have overlooked.

Best regards.

kohjingyu commented 1 year ago

Thanks for bringing this up, as it is a subtle point. When the token embeddings are resized (to include the extra [RET] token): https://github.com/kohjingyu/fromage/blob/964bd157ba3ac24824d525d7f90037eef42aa7c2/fromage/models.py#L71

requires_grad is automatically set to True for them (you can verify this during training, when it prints out the params and whether they are trainable). This is why we have to zero out the gradients of the non-[RET] tokens in the training loop, to prevent them from changing. It's quite complicated to make just the [RET] embedding row trainable, so I elected to do this instead.

Hope that answers your question!
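
For readers who want to see the pattern concretely, here is a minimal sketch of the resize-then-mask-gradients approach (not the repo's exact code: the checkpoint name and variable names are placeholders, and the ordering of the freeze vs. the resize is an assumption for illustration):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Freeze the language model weights.
for param in model.parameters():
    param.requires_grad = False

# Add the extra [RET] token and resize the embedding matrix.
# resize_token_embeddings builds a fresh nn.Embedding whose weight is a new
# Parameter, so (with the freeze applied before the resize, as here) that
# weight comes back with requires_grad=True.
tokenizer.add_special_tokens({"additional_special_tokens": ["[RET]"]})
model.resize_token_embeddings(len(tokenizer))
ret_idx = tokenizer.convert_tokens_to_ids("[RET]")

embed_weight = model.get_input_embeddings().weight

# ... forward pass and loss.backward() go here ...

# After backward, zero the gradient rows of every non-[RET] token so the
# optimizer effectively only updates the [RET] embedding.
mask = torch.ones(embed_weight.shape[0], dtype=torch.bool)
mask[ret_idx] = False
if embed_weight.grad is not None:
    embed_weight.grad[mask, :] = 0
```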

sijeh commented 1 year ago

That's exactly the point I had overlooked. Thanks for your kind reply.

GuangtaoLyu commented 3 months ago


I printed requires_grad and it's False:

for param in self.input_embeddings.parameters():
    print(param.requires_grad)

And this assertion fails because param.grad is None:

assert param.grad.shape[0] == len(tokenizer)

AttributeError: 'NoneType' object has no attribute 'shape'
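
As a generic PyTorch note that may explain the traceback (independent of the fromage code itself): a parameter's .grad attribute stays None until a backward pass has populated it, so inspecting param.grad before loss.backward() raises exactly this kind of AttributeError; and if the freezing loop runs after resize_token_embeddings, the resized embedding weight will also report requires_grad=False. A tiny standalone illustration:

```python
import torch

# A fresh Parameter is trainable by default, but has no gradient yet.
w = torch.nn.Embedding(10, 4).weight
print(w.requires_grad)  # True
print(w.grad)           # None -> w.grad.shape would raise AttributeError here

# Only after a backward pass does .grad exist.
loss = w.sum()
loss.backward()
print(w.grad.shape)     # torch.Size([10, 4])
```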