huggingface / optimum-quanto

A pytorch quantization backend for optimum
Apache License 2.0

Inference from a reloaded quantized open clip model (via .load_state_dict) resulted in IndexError #217

Open kechan opened 1 week ago

kechan commented 1 week ago

transformers 4.41.2
optimum-quanto 0.2.1
torch 2.3.1
Python 3.10.14

I ran this on a recent Google Cloud (GCP) VM with the Nvidia driver set up and a basic torch sanity test passing.

I quantized the HF CLIPModel "laion/CLIP-ViT-L-14-laion2B-s32B-b82K" and saved its state_dict() to disk. I then used the "meta" device to instantiate the model without the weights (this helps with memory conservation on a very weak box). Then I used the following calls to load the saved quantized weights into this "lighter" model:

import torch
from transformers import CLIPModel
from optimum.quanto import quantize, qint8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model skeleton on the "meta" device so no real weights are allocated
with torch.device("meta"):
    reload_model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")

# Apply the same quantization scheme that was used before saving
quantize(reload_model, weights=qint8, activations=None)

# Materialize (uninitialized) storage on CPU, then load the saved quantized weights
reload_model.to_empty(device=torch.device("cpu"))
reload_model.load_state_dict(torch.load('quantized_clip_model.pth'), strict=True, assign=True)

reload_model.to(device)
reload_model.eval()
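For reference, the quantize-and-save step I ran beforehand was roughly the following (a sketch, not my exact script; the essential part is saving the quantized state_dict, and I believe I called quanto's freeze() before saving):

# Sketch of the quantize-and-save step (illustrative, not the exact script)
import torch
from transformers import CLIPModel
from optimum.quanto import quantize, freeze, qint8

model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
quantize(model, weights=qint8, activations=None)
freeze(model)  # materialize the quantized weights before saving
torch.save(model.state_dict(), 'quantized_clip_model.pth')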

Unfortunately, when I tried to run inference on an image input to obtain its vector representation, I got this error:

File ".../vss_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 191, in forward embeddings = embeddings + self.position_embedding(self.position_ids) File ".../vss_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File ".../vss_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(args, **kwargs) File ".../vss_env/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward return F.embedding( File ".../vss_env/lib/python3.10/site-packages/torch/nn/functional.py", line 2264, in embedding return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) IndexError: index out of range in self

I have attached a full standalone test.

kechan commented 1 week ago

I found the root cause: the implementation of CLIPVisionEmbeddings does this in its init:

self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

and because the buffer is registered with persistent=False, it is not saved to the state dict, so when I re-instantiate the model on "meta" and load the state dict in, position_ids ends up with uninitialized (random) values. To fix it, I manually assigned the right values to position_ids (0, 1, 2, ...) after load_state_dict:

num_positions = reload_model.vision_model.embeddings.num_positions
reload_model.vision_model.embeddings.register_buffer('position_ids', torch.arange(num_positions).expand((1, -1)), persistent=False)

and inference now goes through without an error.

This feels like a hacky fix, and I remember seeing examples where this pattern apparently worked without such an issue (a transformer inside whatever that model was), so I'm not sure why I'm hitting it with CLIPModel. I'm also not sure whether this is actually a bug in the library, or just a fact of life.

kechan commented 4 days ago

I found this to be a general issue. If I use the open_clip python library to do the same (instead of huggingface CLIPModel), the model's attn_mask suffers the same problem due to register_buffer. If one "casually" uses this pattern of calls to sidestep loading all of the pre-quantized float32 weights, and then uses load_state_dict to load only the quantized weights, some debugging may be needed to find all of the register_buffer variables and make sure they are initialized properly; a sketch of the check I use is below.
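A minimal sketch of that check (assuming reload_model is already instantiated and quantized as above): buffers registered with persistent=False show up in named_buffers() but not in the state_dict, and that is exactly the set that stays uninitialized after to_empty() + load_state_dict().

# List buffers that are NOT part of the state_dict (i.e. registered with
# persistent=False); these stay uninitialized after to_empty() + load_state_dict()
# on a meta-instantiated model and must be re-initialized by hand.
state_dict_keys = set(reload_model.state_dict().keys())
for name, buf in reload_model.named_buffers():
    if name not in state_dict_keys:
        print(f"non-persistent buffer: {name} shape={tuple(buf.shape)} dtype={buf.dtype}")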

I think this is a valid use case to consider, since one big advantage is being able to run in extremely constrained environments (disk, RAM, etc.).