Closed · sanjayss34 closed this issue 2 months ago
This issue was caused by initializing the projector (from SigLIP to Gemma) with all-zero weights (https://github.com/google-research/big_vision/blob/46b2456f54b9d4f829d1925b78943372b376153d/big_vision/models/vit.py#L197). Setting that variable to False seems to resolve the issue. (I did not want to initialize the projector with pre-trained weights in my experiments.)
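For reference, here is a minimal sketch of what disabling that flag might look like in a big_vision-style config. The flag name `head_zeroinit` matches the parameter at the linked `vit.py` line, but treat the exact name and the surrounding config keys as assumptions to check against your revision:

```python
# Hypothetical big_vision-style config fragment (plain dict for illustration).
# `head_zeroinit` is the flag at the linked vit.py line: when True, the final
# projection ("head") kernel is initialized to zeros; when False, it gets the
# default random initializer instead.
model_config = dict(
    image=dict(
        variant='So400m/14',   # assumed SigLIP image-tower variant
        head_zeroinit=False,   # do NOT zero-init the projector weights
    ),
)
print(model_config['image']['head_zeroinit'])
```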
On a TPU v3-8, adding these two lines at the end of `embed_image_and_text()` (before the return statement) in `models/proj/paligemma/paligemma.py` gives me NaN gradients on the first training step:

I'm able to avoid this issue by setting `big_neg` (https://github.com/google-research/big_vision/blob/46b2456f54b9d4f829d1925b78943372b376153d/big_vision/models/ppp/gemma.py#L246) to a less negative number like -10, but obviously this is undesirable. Also, without the two lines (i.e., with the default masking), I do not get the error. I've gotten this error when starting from both the `pt_224.npz` and `mix_224.npz` checkpoints.
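To illustrate the numerics involved (a sketch of one plausible failure mode, not a diagnosis confirmed in this thread): `big_neg` is roughly -2.38e38, close to the float32 minimum of about -3.4e38, so a logit that picks up the mask value twice overflows to -inf, and a softmax row containing only -inf entries evaluates to NaN. The demo below emulates float32 with the standard library; it also shows why -10 is an unsatisfying replacement, since masked positions then retain a small but nonzero probability.

```python
import math
import struct

BIG_NEG = -2.3819763e38  # mask value used in big_vision's gemma.py

def to_f32(x):
    """Round a Python float to float32; overflow saturates to +/-inf (as on TPU)."""
    try:
        return struct.unpack('f', struct.pack('f', x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def softmax(row):
    """Reference softmax with the usual max-subtraction trick."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

# Applying the mask value to an already-masked logit overflows float32:
twice_masked = to_f32(BIG_NEG + BIG_NEG)
print(twice_masked)                            # -inf

# A row that is all -inf produces NaN, because -inf - (-inf) is NaN:
print(softmax([twice_masked, twice_masked]))   # [nan, nan]

# With big_neg = -10, nothing overflows, but masking is no longer exact:
leak = softmax([2.0, -10.0])[1]
print(leak)                                    # ~6.1e-06 leaks through the mask
```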