Tangshitao / MVDiffusion

MVDiffusion: Enabling Holistic Multi-view Image Generation with Correspondence-Aware Diffusion, NeurIPS 2023 (spotlight)

Any plan for releasing the evaluation code? #17

Closed chengzhag closed 1 year ago

chengzhag commented 1 year ago

Hi Shitao Tang, Thanks for sharing your amazing work!

I've looked into the evaluation protocol for the image generation metrics (FID, IS, CS), but couldn't find whether panorama generation is evaluated in panoramic or perspective format. You mention that for the text2light and SD (pano) baselines, perspective images are obtained by projection. It would be great if you could clarify this, and even better if you could release the evaluation code.
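
For clarity, this is roughly what I understand by "perspective images obtained with projection"; a minimal sketch assuming a pinhole model and cv2.remap (my own helper, not from your repo):

import numpy as np
import cv2

def equirect_to_perspective(pano, fov_deg, theta_deg, phi_deg, out_h, out_w):
    # Project an equirectangular panorama (H, W, 3) into a pinhole view
    # looking at yaw=theta_deg, pitch=phi_deg with the given field of view.
    pano_h, pano_w = pano.shape[:2]
    f = 0.5 * out_w / np.tan(np.radians(fov_deg) / 2)  # focal length in pixels
    xs, ys = np.meshgrid(np.arange(out_w) - out_w / 2,
                         np.arange(out_h) - out_h / 2)
    # Camera-space rays (x right, y down, z forward), normalized.
    dirs = np.stack([xs, ys, np.full_like(xs, f)], axis=-1)
    dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)
    th, ph = np.radians(theta_deg), np.radians(phi_deg)
    R_yaw = np.array([[np.cos(th), 0, np.sin(th)],
                      [0, 1, 0],
                      [-np.sin(th), 0, np.cos(th)]])
    R_pitch = np.array([[1, 0, 0],
                        [0, np.cos(ph), -np.sin(ph)],
                        [0, np.sin(ph), np.cos(ph)]])
    dirs = dirs @ (R_yaw @ R_pitch).T
    # Ray direction -> longitude/latitude -> equirectangular pixel coordinates.
    lon = np.arctan2(dirs[..., 0], dirs[..., 2])
    lat = np.arcsin(np.clip(dirs[..., 1], -1, 1))
    u = (lon / np.pi + 1) / 2 * (pano_w - 1)
    v = (lat / (np.pi / 2) + 1) / 2 * (pano_h - 1)
    return cv2.remap(pano, u.astype(np.float32), v.astype(np.float32),
                     interpolation=cv2.INTER_LINEAR)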

Thank you for your attention to this matter.

Tangshitao commented 1 year ago

The code is simple. I'll post the code snippets here for reference.

import torch
import tqdm
from einops import rearrange
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048).cuda()
data_loader = torch.utils.data.DataLoader(
    gt_dataset, num_workers=num_workers, batch_size=num_workers, prefetch_factor=prefetch_factor)
gen_dataloader = torch.utils.data.DataLoader(
    gen_dataset, num_workers=num_workers, batch_size=num_workers, prefetch_factor=prefetch_factor)

# Accumulate statistics over the generated images, then over the ground-truth images.
for images_gen in tqdm.tqdm(gen_dataloader):
    # (batch, views, H, W, C) uint8 -> (batch*views, C, H, W)
    images_gen = rearrange(images_gen, 'b l h w c -> (b l) c h w')
    fid.update(images_gen.cuda(), real=False)

for images_gt in tqdm.tqdm(data_loader):
    images_gt = rearrange(images_gt.cuda(), 'b l h w c -> (b l) c h w')
    fid.update(images_gt, real=True)

print(fid.compute())
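
IS and CS are accumulated the same way; a rough sketch, assuming a loader that also yields one prompt per perspective image (adapt the prompt handling to your own pipeline):

from torchmetrics.image.inception import InceptionScore
from torchmetrics.multimodal.clip_score import CLIPScore

inception = InceptionScore().cuda()
clip_score = CLIPScore().cuda()  # the exact CLIP model affects the absolute score

# gen_prompt_loader is a hypothetical loader yielding (images, prompts),
# where prompts is a flat list of b*l strings matching the rearranged images.
for images_gen, prompts in tqdm.tqdm(gen_prompt_loader):
    images_gen = rearrange(images_gen, 'b l h w c -> (b l) c h w').cuda()
    inception.update(images_gen)            # IS only needs the generated images
    clip_score.update(images_gen, prompts)  # one prompt string per perspective image

print(inception.compute(), clip_score.compute())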

chengzhag commented 1 year ago

Got it. Thanks for the prompt reply. Can I presume it's evaluated on perspective views, since each item from the dataloader contains 'l' images, which seem to be different views of a single panorama?

Tangshitao commented 1 year ago

yes.

chengzhag commented 1 year ago

Thanks a lot. I'm closing the issue.

chengzhag commented 1 year ago

Hi Shitao, Sorry for reopening the issue. You proposed "overlapping PSNR" to evaluate the consistency between overlapping views. Would you mind releasing the evaluation code, especially this part?

Tangshitao commented 1 year ago

Again, I'll post the code snippets here.

import numpy as np
import cv2
import torch
import tqdm

def compute_psnr_masked(img1, img2, mask):
    # Images are normalized to [0, 1], so the peak signal is 1 and PSNR = -10 * log10(MSE).
    img1_masked = img1[mask]/255
    img2_masked = img2[mask]/255

    mse = np.mean((img1_masked - img2_masked)**2)

    if mse == 0:
        return float('inf')

    psnr = -10*np.log10(mse)

    return psnr

def compute_psnr(img1, img2, K, R):
    # The views share a camera center, so pixels of img2 map into img1 via the homography K @ R @ K^-1.
    im_h, im_w, _ = img1.shape
    homo_matrix = K@R@np.linalg.inv(K)
    mask = np.ones((im_h, im_w))
    img_warp2 = cv2.warpPerspective(img2, homo_matrix, (im_w, im_h))
    mask = cv2.warpPerspective(mask, homo_matrix, (im_w, im_h)) == 1  # valid (overlapping) pixels
    psnr = compute_psnr_masked(img1, img_warp2, mask)

    return psnr

def calculate_overlap_psnr(
    dataset,
    num_workers=2, prefetch_factor=2, device=torch.device('cuda'),
):

    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=num_workers, batch_size=1, prefetch_factor=prefetch_factor)

    psnrs = []
    for _iter, (images, R, K) in tqdm.tqdm(enumerate(data_loader)):
        images = images.numpy()
        R = R.numpy()
        K = K.numpy()
        for b in range(images.shape[0]):
            for i in range(images.shape[1]):
                # Compare each view with its right-hand neighbour (wrapping around at the last view).
                idx1 = i
                idx2 = 0 if i == images.shape[1] - 1 else i + 1
                R_rel = np.linalg.inv(R[b, idx1])@R[b, idx2]

                psnr = compute_psnr(images[b, idx1], images[b, idx2], K[b, idx1], R_rel)
                psnrs.append(psnr)

    return np.mean(psnrs)
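
For completeness, a minimal two-view driver sketch (get_K_R is the helper in src/dataset/utils; img0, img1, im_h, im_w are placeholders):

from src.dataset.utils import get_K_R

# Adjacent views 45 degrees apart with a 90-degree FoV, as in the panorama setup.
K, R0 = get_K_R(90, 0, 0, im_h, im_w)    # (FoV, theta, phi, height, width)
_, R1 = get_K_R(90, 45, 0, im_h, im_w)   # next view, rotated 45 degrees
R_rel = np.linalg.inv(R0) @ R1           # relative rotation from view 0 to view 1
print(compute_psnr(img0, img1, K, R_rel))  # img0, img1: uint8 arrays of shape (im_h, im_w, 3)
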
chengzhag commented 1 year ago

Thanks, Shitao, for the code snippets.

I went through the test process with the pretrained checkpoint and got the following evaluation results: FID: 19.47, IS: 7.40, CS: 24.70, overlapping PSNR: 24.25.

Can you confirm if these results are correct?

The full code I used is attached below:

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.multimodal.clip_score import CLIPScore
from torchmetrics.image import PeakSignalNoiseRatio
import torch
import argparse
import tqdm
import os
import cv2
import numpy as np
from einops import rearrange
from src.dataset.utils import get_K_R
from src.dataset.Matterport3D import warp_img
from src.models.pano.utils import get_correspondences
from collections import defaultdict
import torch.nn.functional as F

torch.manual_seed(0)

def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--result_dir', type=str, required=True,
                        help='Directory where results are saved')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of workers for dataloader')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size')

    return parser.parse_args()

def compute_psnr_masked(img1, img2, mask):
    img1_masked = img1[mask]/255
    img2_masked = img2[mask]/255

    mse = np.mean((img1_masked - img2_masked)**2)

    if mse == 0:
        return float('inf')

    max_pixel_value = 255.0
    psnr = -10*np.log10(mse)

    return psnr

def compute_psnr(img1, img2, K, R):
    im_h, im_w, _ = img1.shape
    homo_matrix = K@R@np.linalg.inv(K)
    mask = np.ones((im_h, im_w))
    img_warp2 = cv2.warpPerspective(img2, homo_matrix, (im_w, im_h))
    mask = cv2.warpPerspective(mask, homo_matrix, (im_w, im_h)) == 1
    psnr = compute_psnr_masked(img1, img_warp2, mask)

    return psnr

class MVResultDataset(torch.utils.data.Dataset):
    def __init__(self, result_dir):
        self.result_dir = result_dir
        self.scenes = os.listdir(result_dir)

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, idx):
        num_views = 8
        images_gt = []
        images_gen = []
        Rs = []
        cameras = defaultdict(list)
        for i in range(num_views):
            for images, ext in zip([images_gt, images_gen], ["_natural.png", ".png"]):
                img = cv2.imread(os.path.join(self.result_dir, self.scenes[idx], f"{i}{ext}"))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                images.append(img)

            theta = (360 / num_views * i) % 360
            K, R = get_K_R(90, theta, 0, *img.shape[:2])

            Rs.append(R)
            cameras['height'].append(img.shape[0])
            cameras['width'].append(img.shape[1])
            cameras['FoV'].append(90)
            cameras['theta'].append(theta)
            cameras['phi'].append(0)

        images_gt = np.stack(images_gt, axis=0)
        images_gen = np.stack(images_gen, axis=0)
        K = np.stack([K]*len(Rs)).astype(np.float32)
        R = np.stack(Rs).astype(np.float32)
        for k, v in cameras.items():
            cameras[k] = np.stack(v)

        prompt_dir = os.path.join(self.result_dir, self.scenes[idx], "prompt.txt")
        prompt = []
        with open(prompt_dir, 'r') as f:
            for line in f:
                prompt.append(line.strip())

        return {
            'images_gt': images_gt,
            'images_gen': images_gen,
            'K': K,
            'R': R,
            'cameras': cameras,
            'prompt': prompt
        }

if __name__ == '__main__':
    args = parse_args()

    fid = FrechetInceptionDistance(feature=2048).cuda()
    inception = InceptionScore().cuda()
    cs = CLIPScore().cuda()
    psnr = PeakSignalNoiseRatio(data_range=1.0)
    psnrs=[]

    dataset = MVResultDataset(args.result_dir)
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_size=args.batch_size)

    for batch in tqdm.tqdm(data_loader):
        images_gt = rearrange(batch['images_gt'].cuda(), 'b l h w c -> (b l) c h w')
        images_gen = rearrange(batch['images_gen'].cuda(), 'b l h w c -> (b l) c h w')
        fid.update(images_gt, real=True)
        fid.update(images_gen, real=False)
        inception.update(images_gen)

        # The default collate turns the per-view prompt list into L tuples of length B;
        # flatten them into (batch * views) order to match the rearranged images.
        prompt_reshape = sum(map(list, zip(*batch['prompt'])), [])
        cs.update(images_gen, prompt_reshape)

        correspondences = get_correspondences(
            batch['R'], batch['K'],
            batch['cameras']['height'][0, 0].item(), batch['cameras']['width'][0, 0].item())
        correspondences[..., 0] = correspondences[..., 0] / (batch['cameras']['width'][0, 0] - 1) * 2 - 1
        correspondences[..., 1] = correspondences[..., 1] / (batch['cameras']['height'][0, 0] - 1) * 2 - 1
        masks = ((correspondences > -1) & (correspondences < 1)).all(-1)
        # overlap_mat = masks.any(-1).any(-1)
        m = correspondences.shape[1]
        for i in range(m):
            j = (i + 1) % m
            xy_l = correspondences[:, j, i]
            mask = masks[:, j, i]
            image_src = batch['images_gen'][:, i].permute(0, 3, 1, 2).float() / 255
            image_warp = F.grid_sample(image_src, xy_l, align_corners=True)
            image_tgt = batch['images_gen'][:, j].permute(0, 3, 1, 2).float() / 255
            mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
            psnr.update(image_warp[mask], image_tgt[mask])

        images = batch['images_gen'].numpy()
        R=batch['R'].numpy()
        K=batch['K'].numpy()
        for b in range(images.shape[0]):
            for i in range(images.shape[1]):
                idx1=i
                if i==images.shape[1]-1:
                    idx2=0
                else:
                    idx2=i+1
                R_rel=np.linalg.inv(R[b, idx1])@R[b, idx2]

                p=compute_psnr(images[b, idx1], images[b, idx2], K[b, i], R_rel)
                psnrs.append(p)

    print(f"FID: {fid.compute()}")
    print(f"IS: {inception.compute()}")
    print(f"CS: {cs.compute()}")
    print(f"PSNR: {psnr.compute()}")
    print(f"PSNR (author's code): {np.mean(psnrs)}")
Tangshitao commented 1 year ago

The CLIP score doesn't look correct here. Do you combine all the texts and then compute the score between the combined text and each perspective image? The PSNR is also a bit lower, but I guess that's due to variance.

chengzhag commented 1 year ago

The CLIP score is currently calculated between each perspective image and its corresponding text prompt. Should I instead combine all the texts and compute the score between the combined text and each perspective image?
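
Concretely, this is what I do now (a minimal sketch of the per-view scoring; view_images and view_prompts are placeholders for one panorama's perspective images and their prompts):

from torchmetrics.multimodal.clip_score import CLIPScore

cs = CLIPScore().cuda()  # default model
for img, txt in zip(view_images, view_prompts):  # one (image, prompt) pair per perspective view
    cs.update(img.cuda(), txt)                   # img: uint8 tensor of shape (3, H, W); txt: its prompt string
print(cs.compute())                              # mean CLIP score over all (image, prompt) pairs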

Tangshitao commented 1 year ago

I'm not sure what happened in your code, but here is mine:

scores = []  # per-panorama average CLIP scores
# 'metric' is the CLIP-score metric instance (e.g. torchmetrics CLIPScore).
for images_gen, prompt, prompts in tqdm.tqdm(gen_dataloader):
    images_gen = images_gen.cuda()

    for b in range(images_gen.shape[0]):
        # Score each perspective view against its own prompt, then average over the views.
        score = 0
        for i in range(images_gen.shape[1]):
            score += metric(images_gen[b, i], prompts[i][b])
        score /= images_gen.shape[1]
        scores.append(score.cpu().detach().numpy())
print(np.mean(scores))

chengzhag commented 1 year ago

Thanks again for the help. I tried the provided code using torchmetrics.multimodal.clip_score.CLIPScore with the default model openai/clip-vit-large-patch14, and got the same result (screenshot attached).

Could the difference be caused by my using a different model or metric implementation? Would you mind sharing the library and settings you use for evaluating CS?

Tangshitao commented 1 year ago

I also use the default model. Can you give me some sample images for a sanity check?

chengzhag commented 1 year ago

Sure. Here is a sample test result for 001499e78f734a4eaf14e727140baf21_skybox0_sami.jpg_001499e78f734a4eaf14e727140baf21: 001499e78f734a4eaf14e727140baf21_skybox0_sami.jpg_001499e78f734a4eaf14e727140baf21.zip

Tangshitao commented 1 year ago

I actually use 'openai/clip-vit-base-patch16', and the CLIP score for this sample is 31.68.
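
In your script that just means constructing the metric with the model name passed explicitly (assuming the torchmetrics model_name_or_path argument):

from torchmetrics.multimodal.clip_score import CLIPScore

cs = CLIPScore(model_name_or_path='openai/clip-vit-base-patch16').cuda()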

chengzhag commented 1 year ago

Thanks a lot for confirming. I tried the same model, and the CS is now 29.93, close to your results.

Thanks again for your time and effort. Closing the issue.

Chjx395 commented 3 months ago

(Quoting @chengzhag's earlier comment and full evaluation script from above.)

Sorry to bother you. I tried your evaluation code, but the FID and CS results look wrong: they come out to 72 and 6.6, respectively.

I followed the test process using scripts/test_depth_two_stage.sh with the pretrained checkpoint depth_gen.ckpt on 100 sequences from the ScanNet test set. Then, I used your evaluation code.

Did I miss something? Could the issue be due to the smaller number of test image sequences? I noticed that the paper used 590 non-overlapping image sequences.
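
For what it's worth, FID estimates are known to be biased upward when the number of samples is small, so I was planning to check how much the sample count alone matters by recomputing FID on random subsets; a rough sketch (gt_images, gen_images, and the subset sizes are placeholders):

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

def fid_on_subset(gt_images, gen_images, subset_size, seed=0):
    # gt_images / gen_images: uint8 tensors of shape (N, 3, H, W) collected from the dataloaders.
    g = torch.Generator().manual_seed(seed)
    idx_gt = torch.randperm(gt_images.shape[0], generator=g)[:subset_size]
    idx_gen = torch.randperm(gen_images.shape[0], generator=g)[:subset_size]
    fid = FrechetInceptionDistance(feature=2048).cuda()
    fid.update(gt_images[idx_gt].cuda(), real=True)
    fid.update(gen_images[idx_gen].cuda(), real=False)
    return fid.compute().item()

# If the estimate keeps dropping as the subset grows (e.g. 200 vs 400 vs 800 images),
# the gap is at least partly explained by the small test set.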