cvlab-stonybrook / PathLDM

Official Code for PathLDM: Text conditioned Latent Diffusion Model for Histopathology (WACV 2024)

How to get the reported FID score with the provided checkpoint? #21

Closed: mQvQ closed this issue 5 months ago

mQvQ commented 6 months ago

Thanks for your excellent work! I generated 30,000 images from the test split to calculate the FID score, but the score I obtained was over 190. Here is my code. Could you please help me identify the issue? I would greatly appreciate it.

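```python
import numpy as np
import torch
from torch.nn.functional import adaptive_avg_pool2d
from tqdm import tqdm
from pytorch_fid.fid_score import calculate_frechet_distance

# Assumed to be defined elsewhere (not shown in the original snippet):
# model, sampler, fid_model, test_dataloader, get_unconditional_token,
# shape, scale, device, fid_path, SAMPLE_NUM


def gen_baseline_sample(batch, shape, scale):
    batch_size = batch['image'].shape[0]
    with torch.no_grad():
        # Unconditional and caption-conditioned embeddings for classifier-free guidance
        ut = get_unconditional_token(batch_size)
        uc = model.get_learned_conditioning(ut)
        ct = batch['caption']
        cc = model.get_learned_conditioning(ct)
        samples_ddim, _ = sampler.sample(50, batch_size, shape, cc, verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=uc, eta=0)
        x_samples_ddim = model.decode_first_stage(samples_ddim)
        # Map the decoder output from [-1, 1] to [0, 1]
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    return x_samples_ddim


results = []
samples = 0
for batch in tqdm(test_dataloader):
    if samples >= SAMPLE_NUM:  # stop once enough images have been generated for FID
        break
    x_samples_ddim = gen_baseline_sample(batch, shape, scale)
    samples += batch['image'].shape[0]
    # The samples are already a tensor; just move/cast them instead of re-wrapping
    x_samples_ddim = x_samples_ddim.to(device, dtype=torch.float32)
    with torch.no_grad():
        pred = fid_model(x_samples_ddim)[0]
        # Pool to (N, C, 1, 1) if the feature map is not already 1x1
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
    # Keep the features of every batch (this append must be inside the loop)
    results.append(pred)

# Precomputed reference statistics (mu, sigma) of the real images
with np.load(fid_path) as f:
    m1, s1 = f["mu"][:], f["sigma"][:]
act = np.concatenate(results, axis=0)
m2 = np.mean(act, axis=0)
s2 = np.cov(act, rowvar=False)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
print('fid', fid_value)
```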
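To see why it matters that every batch's features end up in `results`, here is a small sketch of the Fréchet distance, FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)), applied to synthetic data. pytorch-fid's `calculate_frechet_distance` computes the matrix square root with `scipy.linalg.sqrtm`; this sketch swaps in a NumPy-only equivalent for the trace term (the sum of square roots of the eigenvalues of S1 S2). The helper names `frechet_distance` and `stats`, the 64-dimensional Gaussian features (standing in for Inception's 2048-d features), and the batch size of 8 are all hypothetical. It compares keeping all generated features against keeping only one batch, which is what happens if `results.append(pred)` runs after the loop instead of inside it.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}) for two Gaussians."""
    diff = mu1 - mu2
    # Tr((S1 S2)^{1/2}) is the sum of square roots of the eigenvalues of S1 S2.
    # For PSD S1, S2 these are real and non-negative; clip round-off noise.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def stats(features):
    """Mean and covariance of an (N, D) feature matrix, as in the snippet above."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


rng = np.random.default_rng(0)
dim = 64  # hypothetical stand-in for Inception's 2048-d pool features
real = rng.standard_normal((20000, dim))
fake_all = rng.standard_normal((20000, dim))  # same distribution, all batches kept
fake_one = fake_all[:8]                       # only a single batch of 8 kept

fid_all = frechet_distance(*stats(real), *stats(fake_all))
fid_one = frechet_distance(*stats(real), *stats(fake_one))
print(f"all batches: {fid_all:.3f}, one batch: {fid_one:.3f}")
```

Even though both "fake" sets come from exactly the real distribution, the single-batch estimate gives a large distance, because 8 samples cannot estimate a 64-dimensional covariance.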
mQvQ commented 6 months ago

I used this checkpoint to calculate the FID: https://drive.google.com/drive/folders/1v3SXkA1D94w7Q1XMPSEA1yrSfpwhXzCr

srikarym commented 6 months ago

What's the input range expected by the fid_model? x_samples_ddim lies in [-1, 1], but pytorch-fid transforms them to [0, 1].
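The range question matters because pytorch-fid's `InceptionV3`, with its default `normalize_input=True`, applies `2 * x - 1` to its input, i.e. it assumes images in [0, 1] and maps them to the [-1, 1] range the Inception weights expect. In the snippet above, `gen_baseline_sample` already applies `(x + 1) / 2` and a clamp before returning. The sketch below assumes `fid_model` is pytorch-fid's `InceptionV3` built with that default; the helper names `decoder_to_unit` and `inception_normalize` are hypothetical and only mimic the two transforms on toy values.

```python
import numpy as np


def decoder_to_unit(x):
    """Map decoder output from [-1, 1] to [0, 1], as gen_baseline_sample does."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)


def inception_normalize(x):
    """Mimics pytorch-fid's normalize_input=True step: [0, 1] -> [-1, 1]."""
    return 2.0 * x - 1.0


decoded = np.linspace(-1.0, 1.0, 5)                     # toy decoder output in [-1, 1]
correct = inception_normalize(decoder_to_unit(decoded))  # rescaled to [0, 1] first
wrong = inception_normalize(decoded)                     # [-1, 1] fed in directly

print("correct:", correct.min(), correct.max())
print("wrong:  ", wrong.min(), wrong.max())
```

Feeding [-1, 1] data directly shifts the network input to [-3, 1]. The same preprocessing also has to be used when computing the reference statistics stored in `fid_path`, otherwise the two Gaussians are not comparable.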