jiaosiyu1999 / MAFT


assert mask.shape[-1] - xatten.shape[0] == 99 #7

Closed PeterVennerstrom closed 2 months ago

PeterVennerstrom commented 2 months ago

I'm getting an assertion error in third_party/CLIP/clip/model.py, line 285. The code asserts that mask.shape[-1] exceeds xatten.shape[0] by exactly 99.

I'm getting: AssertionError: (903, 901)

Here is my slightly adjusted code based on demo code shared in another issue:

import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F

from freeseg.modeling.clip_adapter.text_prompt import PredefinedPromptExtractor
from third_party.CLIP import clip

model, preprocess = clip.load("ViT-B/16", device="cpu")
model.visual.transformer.start_layers = 8  # start mask-attention layer

prompt_learner = PredefinedPromptExtractor("")
IPCLIP = model.visual

# load IP-CLIP fine-tuned weights
saved_state_dict = torch.load("MAFT_ViTb.pt", map_location="cpu")  # the model is on CPU, so map the checkpoint there too
new_params = IPCLIP.state_dict()
upgrade_list = []
for i in saved_state_dict:
    if "IPCLIP." in i:
        i_ = i.replace("IPCLIP.", "")
        new_params[i_] = saved_state_dict[i]
        upgrade_list.append(i_)
IPCLIP.load_state_dict(new_params)

IMAGE = "./coco_demo/37777/000000037777.jpg"
MASK_1 = "./coco_demo/37777/oven_1.png"
MASK_2 = "./coco_demo/37777/orange_1.png"
MASK_3 = "./coco_demo/37777/refrigerator_1.png"

# image: 1*3*480*480
image = preprocess(Image.open(IMAGE)).unsqueeze(0)
image = F.interpolate(
    image,
    size=(480, 480),
    mode="bilinear",
    align_corners=False,
)
text = ["oven", "orange", "car", "refrigerator"]
# mask: 1*n*480*480
mask1 = torch.as_tensor(cv2.imread(MASK_1, 0)).unsqueeze(0)
mask2 = torch.as_tensor(cv2.imread(MASK_2, 0)).unsqueeze(0)
mask3 = torch.as_tensor(cv2.imread(MASK_3, 0)).unsqueeze(0)
mask = torch.cat([mask1, mask2, mask3], dim=0).unsqueeze(0) / 255.0

with torch.no_grad():
    image_features = IPCLIP(image, mask)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = prompt_learner(text, model)

    logits_per_image = 100.0 * image_features.matmul(text_features.transpose(-1, -2))
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs, np.argmax(probs, axis=2))

Here is a link to the image and masks used.

With the assertion commented out, the code runs and reasonable results for the image and masks are returned.

Label probs: [[[8.8231611e-01 5.0742579e-05 9.3949282e-05 1.1753930e-01]
  [6.2247657e-04 9.9655068e-01 1.4205566e-03 1.4062579e-03]
  [6.3008839e-01 1.8207520e-02 8.3584106e-03 3.4334564e-01]]] 

[[0 1 0]]

Appreciate your thoughts on the assertion error. Thanks!

jiaosiyu1999 commented 2 months ago

Hi, IP-CLIP repeats the cls token once per mask. In the default setting, the Proposal Generator generates 100 masks, so the cls token is repeated 99 times. That is why line 285 asserts mask.shape[-1] - xatten.shape[0] == 99.

Your code passes only 3 masks, so mask.shape[-1] - xatten.shape[0] == 2, which triggers the AssertionError. You can comment out the assertion to avoid the error; doing so does not affect the correctness of the code.
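
For anyone checking the arithmetic, here is a minimal sketch of where the two numbers in the error come from. It does not touch the MAFT code. It assumes, based on the reported (903, 901) and the explanation above, that xatten.shape[0] counts the patch tokens plus one cls token, and that mask.shape[-1] counts the patch tokens plus one cls token per mask. The function name token_counts is hypothetical and used only for illustration.

```python
def token_counts(image_size, patch_size, num_masks):
    """Return (mask_len, xatten_len) under the assumptions stated above."""
    # ViT-B/16 on a 480x480 input: (480 // 16) ** 2 = 900 patch tokens.
    patches = (image_size // patch_size) ** 2
    # Assumption: the mask covers the patch tokens plus one cls token per mask.
    mask_len = patches + num_masks
    # Assumption: the token sequence has the patch tokens plus a single cls token.
    xatten_len = patches + 1
    return mask_len, xatten_len


for n in (3, 100):
    mask_len, xatten_len = token_counts(480, 16, n)
    print(n, (mask_len, xatten_len), mask_len - xatten_len)
# 3 (903, 901) 2
# 100 (1000, 901) 99
```

Under these assumptions the offset is always the number of masks minus one, so the hard-coded 99 only holds for the default 100 proposals.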