Closed sijeh closed 1 year ago
Thanks for bringing this up, as it is a subtle point. When the token embeddings are resized (to include the extra `[RET]` token):

https://github.com/kohjingyu/fromage/blob/964bd157ba3ac24824d525d7f90037eef42aa7c2/fromage/models.py#L71

`requires_grad` is automatically set to `True` for them (you can verify this during training when it prints out the params and whether they are trainable). This is why we have to zero out the gradients of the non-`[RET]` tokens in the training loop, to prevent them from changing. It's quite complicated to make just the `[RET]` embedding row trainable, so I elected to do this instead.

Hope that answers your question!
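The mechanism described above can be sketched in a few lines. This is a minimal toy example, not the actual fromage code: the vocabulary size, token index, and loss are made up for illustration, and a small `nn.Embedding` stands in for the LLM's resized input embeddings.

```python
import torch
import torch.nn as nn

# Toy stand-in for the LLM's input embeddings after resize_token_embeddings.
vocab_size, dim = 6, 4
ret_token_idx = vocab_size - 1  # the appended [RET] token

emb = nn.Embedding(vocab_size, dim)  # requires_grad is True by default

# Forward/backward over a toy batch that uses a few tokens, including [RET].
tokens = torch.tensor([[0, 2, ret_token_idx]])
loss = emb(tokens).sum()
loss.backward()

# Zero out the gradients of every row except the [RET] row, as done in the
# training loop, so only the [RET] embedding actually gets updated.
mask = torch.arange(vocab_size) != ret_token_idx
emb.weight.grad[mask, :] = 0

assert emb.weight.grad[ret_token_idx].abs().sum() > 0  # [RET] row still has gradient
assert emb.weight.grad[mask].abs().sum() == 0          # all other rows are frozen in effect
```

This works because `requires_grad` is a property of the whole weight tensor, not of individual rows, so masking the gradient is the simplest way to train only one row.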
That's the exact point I had overlooked. Thanks for your kind reply.
> Thanks for bringing this up, as it is a subtle point. When the token embeddings are resized (to include the extra `[RET]` token), `requires_grad` is automatically set to `True` for them (you can verify this during training when it prints out the params and whether they are trainable). This is why we have to zero out the gradients of the non-`[RET]` tokens in the training loop, to prevent them from changing. It's quite complicated to make just the `[RET]` embedding row trainable, so I elected to do this instead. Hope that answers your question!
I printed `requires_grad` and it's `False`:

```python
for param in self.input_embeddings.parameters():
    print(param.requires_grad)
```

And this assertion fails:

```
assert param.grad.shape[0] == len(tokenizer)
AttributeError: 'NoneType' object has no attribute 'shape'
```
https://github.com/kohjingyu/fromage/blob/964bd157ba3ac24824d525d7f90037eef42aa7c2/fromage/models.py#L61
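For what it's worth, that `AttributeError` typically means `.grad` is still `None` because no backward pass has run yet: in PyTorch, `param.grad` is only populated after `loss.backward()`. A minimal illustration with a toy module (an assumption about the cause, not a diagnosis of the fromage code):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(4, 2)

# Before any backward pass, .grad is None, which is exactly what produces
# AttributeError: 'NoneType' object has no attribute 'shape'.
assert emb.weight.grad is None

loss = emb(torch.tensor([1])).sum()
loss.backward()

# After backward(), .grad is a tensor and .shape is safe to access.
assert emb.weight.grad is not None
assert emb.weight.grad.shape[0] == 4
```

So the check on `param.grad.shape` is only valid after the first backward pass of the training loop.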
Hello kohjingyu, thanks for your great work! I'm a little confused about the frozen LLM. It seems all the parameters in the LLM are frozen. Shouldn't the `input_embedding.weight` at the `[RET]` position be learnable? I could not find code such as `self.input_embedding.requires_grad_(True)` or `self.input_embedding.weight.requires_grad = True`. On the other hand, I see the gradient of `input_embedding` is adjusted in `param.grad[mask, :] = 0`. Please point out if I have neglected some important information. Best regards.