Nota-NetsPresso / BK-SDM

A Compressed Stable Diffusion for Efficient Text-to-Image Generation [ECCV'24]

Question of Dreambooth evaluation #36

Closed: ofzlo closed this issue 9 months ago

ofzlo commented 11 months ago

Hi, thank you for sharing your awesome work ☺️ How can I reproduce the DreamBooth quantitative results in Table 5 of your paper? Would you be able to provide the evaluation code?

bokyeong1015 commented 11 months ago

Hi, thanks for your interest.

We’ve followed the protocol described in “Sect. 4.1 Dataset and Evaluation” in the DreamBooth paper.

We hope the following information is helpful. Sharing the code requires some cleanup and refactoring; however, due to other projects and our current workload, it may take time for us to provide the full evaluation script. We kindly ask for your understanding.

1. Generate 3000 images (30 subjects × 25 prompts × 4 generated images per prompt); a rough generation sketch is shown after this list.

2. Compute {DINO, CLIP-I, CLIP-T} scores over 3000 images
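
For step 1, a minimal sketch of the generation loop might look as follows. This is only an illustration: the checkpoint/output directory layout, the placeholder prompt helper, and the scheduler settings are assumptions, not our exact setup.

import os
import torch
from diffusers import StableDiffusionPipeline

SUBJECT_LIST = sorted(os.listdir("dreambooth/dataset"))  # 30 subject folders
NUM_IMAGES_PER_PROMPT = 4

def get_prompt_list(subject):
    # Placeholder: return the 25 DreamBooth evaluation prompts for this subject
    # (the prompt below is an example placeholder only).
    return [f"a photo of sks {subject}"]

for subject in SUBJECT_LIST:
    # One DreamBooth-finetuned checkpoint per subject (path is an assumption).
    ckpt_dir = f"outputs/dreambooth/{subject}"
    pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, torch_dtype=torch.float16).to("cuda")
    for prompt in get_prompt_list(subject):
        out_dir = os.path.join("outputs/gen", subject, prompt.replace(" ", "_"))
        os.makedirs(out_dir, exist_ok=True)
        for seed in range(NUM_IMAGES_PER_PROMPT):
            generator = torch.Generator("cuda").manual_seed(seed)
            img = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
            img.save(os.path.join(out_dir, f"{seed}.png"))

For step 2, the scoring pseudocode is as follows: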

score_txt_img_list = [] # CLIP-T scores over 30 subjects and 25 prompts
score_img_img_list = [] # CLIP-I or DINO scores over 30 subjects and 25 prompts

for subject in subject_list:
    # Compute embeddings of real images
    real_img_list = get_real_img_list(subject)        
    real_img_embs = []
    for real_img in real_img_list:
        real_img_emb = get_img_emb(real_img, 'clip') # CLIP or DINO embedding of a single real image
        real_img_embs.append(real_img_emb)

    # Compute embeddings of prompts & generated images
    prompt_list = get_prompt_list(subject)
    for prompt in prompt_list:
        txt_emb = get_txt_emb(prompt, 'clip') # CLIP embedding of a single text prompt        

        score_txt_img = 0.0
        score_img_img = 0.0
        gen_img_list = get_gen_img_list(subject, prompt)
        for gen_img in gen_img_list:
            gen_img_emb = get_img_emb(gen_img, 'clip') # CLIP or DINO embedding of a single generated image
            score_txt_img += (txt_emb @ gen_img_emb.T) # CLIP-T score for a single pair

            for real_img_emb in real_img_embs:
                score_img_img += (real_img_emb @ gen_img_emb.T) # CLIP-I or DINO score for a single pair

        score_txt_img /= len(gen_img_list) 
        score_img_img /= (len(gen_img_list) * len(real_img_list))

        # Collect per-subject-per-prompt CLIP-T and {CLIP-I or DINO} scores
        score_txt_img_list.append(score_txt_img)
        score_img_img_list.append(score_img_img)

# Compute final average CLIP-T and {CLIP-I or DINO} scores
final_score_txt_img = sum(score_txt_img_list) / len(score_txt_img_list)
final_score_img_img = sum(score_img_img_list) / len(score_img_img_list)
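
Note that get_txt_emb is left undefined above. A minimal sketch with open_clip (the CLIP model/tag match what we use for image embeddings; the helper layout itself is illustrative) could be:

import torch
import open_clip

# Minimal sketch of get_txt_emb for CLIP-T; the ViT-g-14 / laion2b_s34b_b88k pair
# matches the image embedder, but the function structure is illustrative.
_clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s34b_b88k')
_clip_model.eval()
_clip_tokenizer = open_clip.get_tokenizer('ViT-g-14')

def get_txt_emb(prompt, model_type='clip'):
    assert model_type == 'clip'  # CLIP-T uses the CLIP text encoder; DINO has none
    with torch.no_grad():
        tokens = _clip_tokenizer([prompt])
        txt_emb = _clip_model.encode_text(tokens)
        txt_emb /= txt_emb.norm(dim=-1, keepdim=True)  # L2-normalize for cosine similarity
    return txt_emb.cpu().numpy()

With all embeddings L2-normalized, the dot products in the pseudocode are cosine similarities.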
ofzlo commented 11 months ago

Hi, thank you for the kind response. I have another question: do you also use CLIP ViT-g/14 when computing CLIP-I and CLIP-T, as in the CLIP-score evaluation for BK-SDM? And do you use DINO ViT-S/16 for DreamBooth? Could you also provide more details about the get_img_emb function?

I found it :) Still, I would appreciate more details about get_img_emb.

bokyeong1015 commented 10 months ago

Hi, apologies for the late response. I was away on a business trip followed by a personal vacation. I've tried to provide a clearer explanation by refactoring our evaluation code. Hope the example below will be helpful.

Reference: For DINO embeddings, the validation part from DINO's eval_linear.py was used.


# ------------------------------------------------------------------------------------
# Copyright (c) 2023 Nota Inc. All Rights Reserved.
# Code modified from 
# [dino] https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/eval_linear.py#L196-L214
# [clip] https://github.com/mlfoundations/open_clip/tree/37b729bc69068daa7e860fb7dbcf1ef1d03a4185#usage
# ------------------------------------------------------------------------------------

import argparse
import torch
from PIL import Image
from torchvision import transforms as pth_transforms
import open_clip

class ImageEmbedder:
    def __init__(self, dino_model, dino_n_last_blocks, dino_avgpool_patchtokens,
                 clip_model, clip_data, device):     
        self.device = device   
        # dino-related part
        self.dino_model = torch.hub.load('facebookresearch/dino:main', dino_model)
        self.dino_model.to(device)
        self.dino_model.eval()
        self.dino_preprocess = pth_transforms.Compose([
            pth_transforms.Resize(256, interpolation=3),
            pth_transforms.CenterCrop(224),
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.dino_n_last_blocks = dino_n_last_blocks
        self.dino_avgpool_patchtokens = dino_avgpool_patchtokens        
        # clip-related part
        self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model,
                                                                                          pretrained=clip_data,
                                                                                          device=device)        

    def get_img_emb(self, pil_img, model_type):
        with torch.no_grad():
            if model_type == 'dino':
                image = self.dino_preprocess(pil_img).unsqueeze(0).to(self.device)
                intermediate_output = self.dino_model.get_intermediate_layers(image, self.dino_n_last_blocks)
                output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                if self.dino_avgpool_patchtokens:
                    output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                    output = output.reshape(output.shape[0], -1)                
                image_feature = output            
            elif model_type == 'clip':
                image = self.clip_preprocess(pil_img).unsqueeze(0).to(self.device)
                image_feature = self.clip_model.encode_image(image)
            else:
                raise NotImplementedError
        # print(image_feature.shape)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)
        image_feature = image_feature.cpu().numpy()        
        return image_feature

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dino_model", type=str, default='dino_vits16', help='see https://github.com/facebookresearch/dino#pretrained-models-on-pytorch-hub')
    parser.add_argument('--dino_n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
        for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    parser.add_argument('--dino_avgpool_patchtokens', action='store_true',
        help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for ViT-Small and to True with ViT-Base.""")
    parser.add_argument("--clip_model", type=str, default='ViT-g-14')
    parser.add_argument("--clip_data", type=str, default='laion2b_s34b_b88k')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use, cuda:gpu_number or cpu')
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    embedder = ImageEmbedder(args.dino_model, args.dino_n_last_blocks, args.dino_avgpool_patchtokens,
                             args.clip_model, args.clip_data, args.device)

    ## test DINO and CLIP image embeddings
    for model_type in ['dino', 'clip']:
        emb_cat0 = embedder.get_img_emb(Image.open('dreambooth/dataset/cat/00.jpg'), model_type)
        emb_cat1 = embedder.get_img_emb(Image.open('dreambooth/dataset/cat/01.jpg'), model_type)
        emb_teapot = embedder.get_img_emb(Image.open('dreambooth/dataset/teapot/00.jpg'), model_type)

        probs_cat0_cat1 = (emb_cat0 @ emb_cat1.T)[0][0]    
        probs_cat0_teapot = (emb_cat0 @ emb_teapot.T)[0][0]
        print(f"[{model_type}] cat0 vs cat1: {probs_cat0_cat1}")
        print(f"[{model_type}] cat0 vs teapot: {probs_cat0_teapot}") 

    # [dino] cat0 vs cat1: 0.8219528794288635
    # [dino] cat0 vs teapot: 0.24008619785308838
    # [clip] cat0 vs cat1: 0.9108940362930298
    # [clip] cat0 vs teapot: 0.41291046142578125
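
To connect this with the scoring pseudocode above, a usage sketch for the per-subject CLIP-I / DINO part could look as follows (the directory layout and file extensions are assumptions):

import glob
from PIL import Image

def subject_img_img_score(embedder, real_dir, gen_dir, model_type):
    # Average pairwise similarity between all real and generated images of one subject.
    # Embeddings returned by ImageEmbedder.get_img_emb are already L2-normalized,
    # so the dot products below are cosine similarities.
    real_embs = [embedder.get_img_emb(Image.open(p), model_type)
                 for p in sorted(glob.glob(f"{real_dir}/*.jpg"))]
    gen_embs = [embedder.get_img_emb(Image.open(p), model_type)
                for p in sorted(glob.glob(f"{gen_dir}/*.png"))]
    total = 0.0
    for gen_emb in gen_embs:
        for real_emb in real_embs:
            total += float((real_emb @ gen_emb.T)[0][0])
    return total / (len(gen_embs) * len(real_embs))

CLIP-T is analogous, replacing the real-image embeddings with the CLIP text embedding of the prompt.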
ofzlo commented 10 months ago

Hi, thanks for your response. I have confirmed that the code works well. However, when I executed scripts/generate_after_full_ft.sh as you instructed, the images were generated differently from what I intended; in other words, I observed some distortion in the images produced through the pipeline. I suspected that the image size argument might be causing the problem, so I removed the image size option and rewrote the pipeline code (please refer to the code below). However, when I evaluated the images generated through this inference pipeline, I obtained DINO 0.696, CLIP-I 0.687, and CLIP-T 0.240, which differ somewhat from the DINO 0.723, CLIP-I 0.717, and CLIP-T 0.260 reported in your paper. Could you help me identify what the issue might be? I would appreciate your response.

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
generator = torch.Generator(args.device).manual_seed(args.seed)
img = pipeline(val_prompt, num_inference_steps=25, generator=generator).images[0]
bokyeong1015 commented 9 months ago

Hi, the images below were obtained by simply running scripts/finetune_full.sh followed by scripts/generate_after_full_ft.sh. For the original model, UNET_TYPE="nota-ai/bk-sdm-base" is changed to "CompVis/stable-diffusion-v1-4" at this line. [image: generated sample comparison]

"when I executed scripts/generate_after_full_ft.sh as you instructed, I noticed that the images were generated differently from my intention"

  • It would be better if you could share your results and the exact procedure you've conducted.

"I suspect that the error might be occurring due to the image size argument, so I removed the image size option"

  • It would be helpful if you could explain why you think the image size option matters.

"rewrote the pipeline code (please refer to the code below)"

  • (Our paper) We use DPM-Solver [36, 37] for DreamBooth results.
  • For this, please adjust your code by referring to our code [here and here]; a minimal scheduler-swap sketch is also shown after this list.

"I obtained values of DINO 0.696, CLIP-I 0.687, and CLIP-T 0.240."

  • It would also be better if you measured the scores of the original SD-v1.4 with your code, to check whether the gap comes from the evaluation setup itself.
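
For reference, a minimal sketch of swapping in a DPM-Solver scheduler within your snippet (the solver settings and step count here are illustrative, not necessarily our exact configuration):

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# Swap the pipeline's default scheduler for the multistep DPM-Solver variant;
# guidance scale and step count should follow whatever the generation script uses.
pipeline = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(args.device).manual_seed(args.seed)
img = pipeline(val_prompt, num_inference_steps=25, generator=generator).images[0]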
ofzlo commented 9 months ago

Thank you for your response. Once I learn more precise details about the issue I mentioned earlier, I will open a new issue :)