google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

Causal Mask leads to NaN gradients (Paligemma) #129

Closed sanjayss34 closed 2 months ago

sanjayss34 commented 2 months ago

On a TPU v3-8, adding these two lines at the end of `embed_image_and_text()` (before the return statement) in `models/proj/paligemma/paligemma.py` gives me NaN gradients on the first training step:

```python
mask_ar = jnp.full(text.shape, 1)
mask_ar = jnp.concatenate([jnp.full((zimg.shape[0], zimg.shape[1]), 1), mask_ar], axis=1)
```

I'm able to avoid this issue by setting `big_neg` (https://github.com/google-research/big_vision/blob/46b2456f54b9d4f829d1925b78943372b376153d/big_vision/models/ppp/gemma.py#L246) to a less extreme value like -10, but that obviously weakens the masking, so it's undesirable. Without the two lines above (i.e. with the default masking), I do not get the error. I've hit it when starting from both the pt_224.npz and mix_224.npz checkpoints.
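For what it's worth, the magnitude of `big_neg` (about -2.38e38, close to the most negative representable float32) leaves almost no headroom: a single addition with another large-negative term overflows to -inf, and a softmax over a row that is -inf everywhere produces NaN, which then propagates through the backward pass. A minimal sketch of this failure mode (not necessarily the exact path inside Gemma's attention):

```python
import jax
import jax.numpy as jnp

# The constant used in gemma.py, very close to the float32 minimum.
big_neg = jnp.float32(-2.3819763e38)

# Adding another large-negative term overflows float32 to -inf.
overflowed = big_neg + big_neg
print(overflowed)  # -inf

# A softmax row that is -inf everywhere (e.g. a fully masked query) is NaN,
# because the max-subtraction inside softmax computes -inf - (-inf) = nan.
row = jnp.array([-jnp.inf, -jnp.inf])
print(jax.nn.softmax(row))  # [nan nan]
```

This is why nudging `big_neg` toward a tamer value like -10 makes the NaNs disappear even though it is the wrong fix: it restores numerical headroom rather than addressing what produced the degenerate logits in the first place.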

sanjayss34 commented 2 months ago

This issue was caused by initializing the projector (from SigLIP to Gemma) with all-zero weights (https://github.com/google-research/big_vision/blob/46b2456f54b9d4f829d1925b78943372b376153d/big_vision/models/vit.py#L197). Setting that variable to False resolves the issue. (In my experiments I did not want to initialize the projector from pre-trained weights, so it started at exactly zero.)
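For anyone hitting the same thing: a minimal sketch of how the fix might look in a big_vision-style config, assuming the flag in question is the ViT `head_zeroinit` option and that your config exposes the vision tower under `model.img` (both are assumptions here; verify the exact key names against your own config file):

```python
# Hypothetical config fragment in the ml_collections style used by big_vision.
# `head_zeroinit` controls the zero-initialized projection head referenced
# above; setting it False makes the SigLIP->Gemma projector randomly
# initialized instead of all-zero.
import ml_collections


def get_config():
  config = ml_collections.ConfigDict()
  config.model = ml_collections.ConfigDict()
  config.model.img = ml_collections.ConfigDict({
      "variant": "So400m/14",    # the SigLIP vision tower (assumed variant)
      "head_zeroinit": False,    # avoid all-zero projector weights
  })
  return config
```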