Heidelberg-NLP / MM-SHAP

This is the official implementation of the paper "MM-SHAP: A Performance-agnostic Metric for Measuring Multimodal Contributions in Vision and Language Models & Tasks"
https://aclanthology.org/2023.acl-long.223/
MIT License

How is Explainer getting image data in CLIP? #3

Closed: skshvl closed this issue 10 months ago

skshvl commented 10 months ago

As part of my thesis, I am trying to understand the code in mm-shap_clip_dataset.py, and I'm a bit stumped by the following section, which builds the tensor X that is passed to the Explainer instance to generate masks and then SHAP values. I am concerned that, in the code as written here, X ends up containing no image data; or at least, I do not understand how it does.

    # shap values need one sentence for transformer
    for k, sentence in enumerate(test_sentences):

        try:  # image feature extraction can go wrong
            inputs = processor(
                text=sentence, images=image, return_tensors="pt", padding=True
            )
        except:
            continue
        model_prediction = model(**inputs).logits_per_image[0, 0].item()

        text_length_tok = inputs.input_ids.shape[1]
        p = int(math.ceil(np.sqrt(text_length_tok)))
        patch_size = 224 // p
        image_token_ids = torch.tensor(
            range(1, p**2+1)).unsqueeze(0) # (inputs.pixel_values.shape[-1] // patch_size)**2 +1
        # make a combination between tokens and pixel_values (transform to patches first)
        X = torch.cat(
            (inputs.input_ids, image_token_ids), 1).unsqueeze(1)

        # create an explainer with model and image masker
        explainer = shap.Explainer(
            get_model_prediction, custom_masker, silent=True)
        shap_values = explainer(X)
        mm_score = compute_mm_score(text_length_tok, shap_values)
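compute_mm_score itself is not shown in this thread. As a reading aid, here is a minimal sketch of what it presumably computes, following the paper's definition as we understand it: sum the absolute SHAP values over the first text_length_tok positions (text) and over the remaining positions (image patches), then report each modality's share of the total. The name mm_score_sketch and the flat-list input are hypothetical; the repository works on the shap Explanation object rather than a plain list.

```python
def mm_score_sketch(shap_row, text_length):
    """Hypothetical helper: return (T-SHAP, V-SHAP) shares for one sample.

    shap_row: flat list of SHAP values, text positions first, then image patches
    (the same order as X). text_length: number of text positions.
    """
    # Total absolute contribution of the text positions.
    text_contrib = sum(abs(v) for v in shap_row[:text_length])
    # Total absolute contribution of the image-patch positions.
    image_contrib = sum(abs(v) for v in shap_row[text_length:])
    total = text_contrib + image_contrib
    if total == 0:
        raise ValueError("all SHAP values are zero; the shares are undefined")
    t_shap = text_contrib / total
    return t_shap, 1.0 - t_shap
```

For example, mm_score_sketch([0.2, -0.2, 0.1, 0.5], 2) attributes 0.4 of the absolute mass to text and 0.6 to the image. Taking absolute values means that a token pushing the score down still counts as a contribution, which is what makes the measure independent of whether the prediction is right.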

Specifically, X is the concatenation of two things: inputs.input_ids (text) and image_token_ids (image):

    # make a combination between tokens and pixel_values (transform to patches first)
    X = torch.cat(
        (inputs.input_ids, image_token_ids), 1).unsqueeze(1)
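To make the layout of X concrete, here is a minimal pure-Python sketch of the same construction (the torch version ends up with shape (1, 1, n_text + p²) after the unsqueeze). The function name joint_layout and the token ids in the usage line are made up for illustration; only the rule p = ceil(sqrt(#text tokens)) and the 224-pixel side length come from the snippet.

```python
import math


def joint_layout(text_token_ids, image_size=224):
    """Hypothetical helper: flat stand-in for X = [text ids] + [patch indices 1..p**2]."""
    n_text = len(text_token_ids)
    if n_text == 0:
        raise ValueError("need at least one text token to size the patch grid")
    # Same rule as the snippet: grid side p = ceil(sqrt(number of text tokens)).
    p = int(math.ceil(math.sqrt(n_text)))
    # Side length of one square patch in pixels (integer division, as in the snippet).
    patch_size = image_size // p
    # Image "tokens" are only patch indices; no pixel values are copied in.
    image_token_ids = list(range(1, p * p + 1))
    return text_token_ids + image_token_ids, p, patch_size
```

With seven (made-up) text ids, joint_layout([10, 11, 12, 13, 14, 15, 16]) gives p = 3, patch_size = 74, and an X of length 16 whose last nine entries are 1..9. The rule appears to keep the number of image positions close to the number of text positions, so neither modality is represented by far more SHAP players than the other.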

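But while the inputs object contains both text and image data, image_token_ids takes no image data from inputs.pixel_values: it is just the integer range 1..p², and p itself is computed from the text length (the pixel_values-based formula only survives in the trailing comment).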

    image_token_ids = torch.tensor(
        range(1, p**2+1)).unsqueeze(0) # (inputs.pixel_values.shape[-1] // patch_size)**2 +1

So by the time X is built, it combines inputs.input_ids with image_token_ids, and nothing derived from the pixels has been added to image_token_ids.
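What the placeholder integers can stand for is a patch location, not pixel content. Here is a minimal sketch, assuming (the thread does not confirm the ordering) that index k in 1..p² names one square of a row-major p × p grid over the 224 × 224 input, each square patch_size pixels wide. The helper name patch_box is hypothetical.

```python
def patch_box(k, p, patch_size):
    """Hypothetical helper: pixel box (row0, row1, col0, col1) for 1-based patch index k."""
    if not 1 <= k <= p * p:
        raise ValueError(f"patch index {k} outside 1..{p * p}")
    # Row-major order is an assumption made for this sketch.
    row, col = divmod(k - 1, p)
    return (row * patch_size, (row + 1) * patch_size,
            col * patch_size, (col + 1) * patch_size)
```

For p = 3 and patch_size = 224 // 3 = 74, patch 5 is the centre square, rows and columns 74 to 148. Note that 3 × 74 = 222, so under this mapping the last two pixel rows and columns fall outside every patch; that is a side effect of the integer division in the snippet.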

Right after X is assigned, we create an Explainer and pass X to it.


    # create an explainer with model and image masker
    explainer = shap.Explainer(
        get_model_prediction, custom_masker, silent=True)
    shap_values = explainer(X)
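For orientation, shap accepts a plain function as a masker: it is called with a boolean mask (True means the feature is kept) and one input sample, and returns the masked sample(s). Below is a minimal pure-Python sketch of that idea, assuming masked positions are set to 0 as a stand-in for what custom_masker does; the name zero_masker is hypothetical and the repository's version operates on tensors.

```python
def zero_masker(mask, x):
    """Hypothetical masker: keep x[i] where mask[i] is True, write 0 where it is False."""
    if len(mask) != len(x):
        raise ValueError("mask and x must have the same length")
    # Wrapped in an outer list: a batch containing one masked sample.
    return [[v if keep else 0 for v, keep in zip(x, mask)]]
```

Because the image placeholders start at 1 (range(1, p**2+1)), a 0 in the image part can only mean "this patch is masked". That is presumably why the range does not start at 0. The prediction function then decides what a 0 means for each modality.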

So what I am trying to understand is how the explainer gets any access to the image data when X consists only of the text token IDs plus the placeholder image_token_ids. I would appreciate any input, thanks!

LetiP commented 10 months ago

Hi, thanks for the excellent description of your question. I understand it is complicated. The solution is complicated because the shap library is complex: this is the way we managed to implement what we wanted, and yes, I spent a lot of time reading the shap library to get it done. There is surely a better way to do it, so feel free to adapt the code to your needs.

Your question is closely related to my answer to issue #2. Could you please read that first and then tell me what is still confusing? I will then try to give you a more personalized answer. Thanks!

skshvl commented 10 months ago

@LetiP Thank you, I think I understand now, after thinking more about your explanation of get_model_prediction. It seems that inputs is a variable defined outside that function, and get_model_prediction() can read it because it is a global variable in the .py file being run. So the image data does not need to be passed to the Explainer directly: the Explainer calls get_model_prediction(), which reads inputs, which holds the image pixel data that can be masked. Thanks!
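The mechanism described above can be reproduced end to end without shap or CLIP. This is a toy sketch, not the repository's code: a module-level dict named inputs plays the role of the processor output, a 4 × 4 image of ones replaces pixel_values, the mean pixel value replaces the CLIP score, and the function name get_prediction_sketch is hypothetical. The point it shows is that the caller passes only the masked X, while the function fetches the pixels from the global at call time.

```python
import math

# Toy stand-in for the processor output, defined at module level like `inputs`.
IMAGE_SIZE = 4
inputs = {
    "input_ids": [101, 102],  # made-up text ids
    "pixels": [[1.0] * IMAGE_SIZE for _ in range(IMAGE_SIZE)],  # 4x4 image of ones
}
P = int(math.ceil(math.sqrt(len(inputs["input_ids"]))))  # 2 -> 2x2 patch grid
PATCH = IMAGE_SIZE // P  # 2 pixels per patch side


def get_prediction_sketch(masked_x):
    """Receives X only; the pixels are read from the global `inputs`."""
    n_text = len(inputs["input_ids"])
    image_part = masked_x[n_text:]
    # Copy the global image so masking never modifies `inputs` itself.
    pixels = [row[:] for row in inputs["pixels"]]
    for k, value in enumerate(image_part, start=1):
        if value == 0:  # 0 marks a masked patch
            r, c = divmod(k - 1, P)  # row-major patch position (assumed)
            for i in range(r * PATCH, (r + 1) * PATCH):
                for j in range(c * PATCH, (c + 1) * PATCH):
                    pixels[i][j] = 0.0
    # Toy "model score": mean pixel intensity (a real run would call CLIP here).
    return sum(map(sum, pixels)) / (IMAGE_SIZE * IMAGE_SIZE)


x = inputs["input_ids"] + list(range(1, P * P + 1))  # [101, 102, 1, 2, 3, 4]
```

Calling get_prediction_sketch(x) returns 1.0, masking the first patch with [101, 102, 0, 2, 3, 4] returns 0.75, and masking all four patches returns 0.0. The toy ignores text masking to stay short. Reading a global works without any `global` statement because Python resolves the name when the function runs; the trade-off is that the function silently depends on whichever `inputs` the loop assigned last.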

LetiP commented 10 months ago

Exactly. Thanks, you got it! 👏