jiaosiyu1999 / MAFT

how to use ip-encoder #5

Closed Justhuayu closed 2 months ago

Justhuayu commented 5 months ago

I want to use the IP-Encoder. How can I load the pre-trained model of this module separately?

jiaosiyu1999 commented 5 months ago

Please refer to the following code, where maski.png represents the i-th mask.

import torch, cv2
from torch.nn import functional as F
from PIL import Image
from third_party.CLIP import clip

from freeseg.modeling.clip_adapter.clip import build_clip_model, crop_with_mask, CLIP
from freeseg.modeling.clip_adapter.text_prompt import PredefinedPromptExtractor

# load the plain CLIP backbone; the visual branch acts as the IP-Encoder (IP-CLIP image encoder)
model, preprocess = clip.load("ViT-B/16", device="cpu")
model.visual.transformer.start_layers = 8   # start mask-attention layer (8 for ViT-B/16)

prompt_learner = PredefinedPromptExtractor('')
IPCLIP = model.visual

# load IP-CLIP fine-tuned weights and copy them into the visual branch
saved_state_dict = torch.load("out/MAFT_ViTb.pt", map_location="cpu")
new_params = IPCLIP.state_dict()
upgrade_list = []
for i in saved_state_dict:
    if "IPCLIP." in i:
        i_ = i.replace("IPCLIP.", "")   # strip the "IPCLIP." prefix to match the visual state dict keys
        new_params[i_] = saved_state_dict[i]
        upgrade_list.append(i_)
IPCLIP.load_state_dict(new_params)
print(upgrade_list)                     # names of the parameters that were overwritten

# image: 1 x 3 x 480 x 480
image = preprocess(Image.open("img.png")).unsqueeze(0)
image = F.interpolate(image, size=(480, 480), mode="bilinear", align_corners=False)
text = ["sheep", "grass", "car", "house"]
# mask: 1 x n x 480 x 480, one channel per proposal mask, values in [0, 1]
mask1 = torch.as_tensor(cv2.imread("mask1.png", 0)).unsqueeze(0)
mask2 = torch.as_tensor(cv2.imread("mask2.png", 0)).unsqueeze(0)
mask3 = torch.as_tensor(cv2.imread("mask3.png", 0)).unsqueeze(0)
mask = torch.cat([mask1, mask2, mask3], dim=0).unsqueeze(0) / 255.0

with torch.no_grad():
    # image_features: one embedding per mask; text_features: one embedding per class name
    image_features = IPCLIP(image, mask)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = prompt_learner(text, model)

    logits_per_image = 100.0 * image_features.matmul(text_features.transpose(-1, -2))
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)
Justhuayu commented 4 months ago

I would like to replace ViT-B/16 with ViT-L/14@336px. Do you have a corresponding .pt model?

jiaosiyu1999 commented 3 months ago

The pretrained weights can be found at MAFT_Vitl.pt

You also need to set start_layers = 16 via model.visual.transformer.start_layers = 16.
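
For completeness, a sketch of the ViT-L/14@336px setup, mirroring the ViT-B/16 example above; the checkpoint path out/MAFT_Vitl.pt is an assumption, point it at wherever MAFT_Vitl.pt was saved:

# same procedure as the ViT-B/16 example, only the backbone and start layer change
model, preprocess = clip.load("ViT-L/14@336px", device="cpu")
model.visual.transformer.start_layers = 16   # start mask-attention layer for ViT-L
IPCLIP = model.visual

# path is an assumption; use your local copy of MAFT_Vitl.pt
saved_state_dict = torch.load("out/MAFT_Vitl.pt", map_location="cpu")
new_params = IPCLIP.state_dict()
for i in saved_state_dict:
    if "IPCLIP." in i:
        new_params[i.replace("IPCLIP.", "")] = saved_state_dict[i]
IPCLIP.load_state_dict(new_params)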