google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

Tokenization error when using mSigLIP #126

Open simran-khanuja opened 1 month ago

simran-khanuja commented 1 month ago

Hi, I get this error when preprocessing text with the mSigLIP model. Any idea what may be wrong? I didn't change anything in the demo Colab.

Traceback (most recent call last):
  File "/home/${USER}/babelnet/labels/msiglip.py", line 131, in <module>
    _, ztxt, out = model.apply({'params': params}, None, txts)
  File "/home/${USER}/babelnet/big_vision/big_vision/models/proj/image_text/two_towers.py", line 55, in __call__
    ztxt, out_txt = text_model(text, **kw)
  File "/home/${USER}/babelnet/big_vision/big_vision/models/proj/image_text/text_transformer.py", line 64, in __call__
    x = out["embedded"] = embedding(text)
  File "/home/${USER}/miniconda3/envs/msiglip/lib/python3.10/site-packages/flax/linen/linear.py", line 1106, in setup
    self.embedding = self.param(
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (256000, 1152) but got shape (250000, 1152) instead for parameter "embedding" in "/txt/Embed_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
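The error means the text tower's config asked `nn.Embed` for a `(256000, 1152)` table, but the loaded checkpoint holds a `(250000, 1152)` one. Below is a minimal sketch for catching this before `model.apply` is called. It assumes `params` is the nested dict loaded from the checkpoint and that the text embedding sits at the `txt/Embed_0/embedding` path shown in the traceback. The helper name `check_text_vocab` and the toy arrays are hypothetical.

```python
import numpy as np


def check_text_vocab(params, config_vocab_size):
    """Compare the checkpoint's text-embedding rows with the configured vocab size.

    Assumes the path from the traceback: params["txt"]["Embed_0"]["embedding"].
    Returns the vocab size stored in the checkpoint, or raises on a mismatch.
    """
    embedding = params["txt"]["Embed_0"]["embedding"]
    ckpt_vocab, width = embedding.shape
    if ckpt_vocab != config_vocab_size:
        raise ValueError(
            f"config vocab_size={config_vocab_size}, but checkpoint embedding "
            f"has shape ({ckpt_vocab}, {width})"
        )
    return ckpt_vocab


# Toy stand-in for a loaded checkpoint (narrow width keeps memory small).
params = {"txt": {"Embed_0": {"embedding": np.zeros((250_000, 8), np.float32)}}}
print(check_text_vocab(params, 250_000))  # matches, prints 250000
```

Calling `check_text_vocab(params, 256_000)` with the same toy params raises the `ValueError`, which mirrors the Flax error above.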

merveenoyan commented 3 weeks ago

@simran-khanuja I'm not very sure, but I ran into the same issue with the latest released SigLIP (So400m, patch16). In my case the tokenizer was different, and the vocab dim should have been 256k (I fixed it during initialization).
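One way to "fix it during initialization" is to derive the text vocab size from the checkpoint instead of hard-coding it. This is a minimal sketch, not the comment author's actual fix. It assumes the model config is a plain nested dict with a `"text"` sub-dict holding `vocab_size` (a hypothetical stand-in for the demo's config object), and that the embedding lives at `params["txt"]["Embed_0"]["embedding"]` as in the traceback. The helper `with_checkpoint_vocab` is hypothetical.

```python
import copy

import numpy as np


def with_checkpoint_vocab(model_cfg, params):
    """Return a copy of model_cfg whose text vocab_size matches the checkpoint.

    Assumes model_cfg is a nested dict with a "text" sub-dict (hypothetical
    stand-in for the demo config) and the embedding path from the traceback.
    """
    vocab_size = params["txt"]["Embed_0"]["embedding"].shape[0]
    fixed = copy.deepcopy(model_cfg)  # leave the caller's config untouched
    fixed["text"]["vocab_size"] = int(vocab_size)
    return fixed


cfg = {"text": {"variant": "So400m", "vocab_size": 256_000}}
params = {"txt": {"Embed_0": {"embedding": np.zeros((250_000, 4), np.float32)}}}
print(with_checkpoint_vocab(cfg, params)["text"]["vocab_size"])  # 250000
```

Matching the table size only removes the shape error. The tokenizer still has to produce ids that index the same vocabulary, which is why the tokenizer also matters here.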

Maybe try a different tokenizer. I confirmed with Google folks that there seems to be a mistake in the config in my case; the cause might be different for you. The mSigLIP tokenizer's SentencePiece (spiece) model exists here, so swapping the tokenizer should work. Also, if you're OK with using PyTorch, mSigLIP is implemented in transformers, and you can use that for the time being here.
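After swapping tokenizers, a cheap sanity check is that every token id fits inside the embedding table. JAX generally does not raise on out-of-bounds indexing, so a tokenizer/vocab mismatch can otherwise go unnoticed. The sketch below assumes you already have the tokenized ids as an integer array (for example from the SentencePiece model) and the number of embedding rows; `check_token_ids` is a hypothetical helper, not part of big_vision.

```python
import numpy as np


def check_token_ids(token_ids, vocab_size):
    """Raise if any token id falls outside an embedding table with vocab_size rows."""
    ids = np.asarray(token_ids)
    bad = ids[(ids < 0) | (ids >= vocab_size)]  # flattened out-of-range ids
    if bad.size:
        raise ValueError(
            f"{bad.size} token id(s) outside [0, {vocab_size}), e.g. {int(bad[0])}"
        )


# A batch of already-tokenized, padded texts (toy values).
check_token_ids([[2, 17, 249_999], [5, 0, 0]], vocab_size=250_000)  # passes silently
```

Running the same check with an id such as `255_000` raises a `ValueError` naming the offending id.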

Edit: the new notebook seems to be fixed.