Algolzw / daclip-uir

[ICLR 2024] Controlling Vision-Language Models for Universal Image Restoration. 5th place in the NTIRE 2024 Restore Any Image Model in the Wild Challenge.
https://algolzw.github.io/daclip-uir
MIT License
582 stars 30 forks

task_name #40

Open wjkbigface opened 3 months ago

wjkbigface commented 3 months ago

Hello! What does `task_name` refer to in `print(f"Task: {task_name}: {degradations[index]} - {text_probs[0][index]}")`?

Algolzw commented 3 months ago

Hello! Where is this code? (I forget why, but I guess I wanted to print whether the predicted degradation type is correct.)

wjkbigface commented 3 months ago

```python
import torch
from PIL import Image
import open_clip

checkpoint = 'pretrained/daclip_ViT-B-32.pt'
model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=checkpoint)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

image = preprocess(Image.open("haze_01.png")).unsqueeze(0)
degradations = ['motion-blurry', 'hazy', 'jpeg-compressed', 'low-light', 'noisy',
                'raindrop', 'rainy', 'shadowed', 'snowy', 'uncompleted']
text = tokenizer(degradations)

with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = model.encode_text(text)
    image_features, degra_features = model.encode_image(image, control=True)
    degra_features /= degra_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * degra_features @ text_features.T).softmax(dim=-1)
    index = torch.argmax(text_probs[0])

print(f"Task: {task_name}: {degradations[index]} - {text_probs[0][index]}")
```

Algolzw commented 3 months ago

Ok, according to the code, I think I just wanted to check whether the predicted degradation type is correct. `task_name` is not defined in that snippet; it should be pre-defined by the user as a label for the task being tested, e.g. image dehazing.
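To illustrate, here is a minimal sketch of how `task_name` could be pre-defined before the print statement. The tensor values below are hypothetical stand-ins for the `text_probs` that the DaCLIP similarity computation above would produce for a hazy image; only the labeling pattern is the point.

```python
import torch

# Same degradation vocabulary as in the snippet above.
degradations = ['motion-blurry', 'hazy', 'jpeg-compressed', 'low-light', 'noisy',
                'raindrop', 'rainy', 'shadowed', 'snowy', 'uncompleted']

# Hypothetical probabilities standing in for the real model output
# (in practice: text_probs = (100.0 * degra_features @ text_features.T).softmax(dim=-1)).
text_probs = torch.tensor([[0.01, 0.90, 0.01, 0.01, 0.01,
                            0.01, 0.02, 0.01, 0.01, 0.01]])

# task_name is simply a user-chosen label for the restoration task being
# evaluated, defined before printing (e.g. "dehazing" for a hazy input image).
task_name = "dehazing"

index = torch.argmax(text_probs[0])
print(f"Task: {task_name}: {degradations[index]} - {text_probs[0][index]}")
```

Comparing the printed prediction (`hazy` here) against the known `task_name` tells you whether the degradation classifier got the image right.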