huzi96 / Coarse2Fine-PyTorch


Question about "x_tilde = self.side_recon_model(pf, h2, h1)" during inference #8

Closed: achel-x closed this issue 2 years ago

achel-x commented 2 years ago

Last time I took a crop of the input image to show the ground truth and x_hat, but now I want to see the complete image. So I tried to adapt networks.py for inference. In the encode part of class NetHigh(), I ran into some obstacles: x_hat is wrong.

In `def compresslow()`:

```python
import torch
import tqdm

# Encode each block of the input image separately
for block in tqdm.tqdm(blocks):
    with torch.no_grad():
        block = block.to(device)
        ret = net.encode(block)
```
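Since compresslow() processes the image as a sequence of blocks, seeing the complete image means putting the reconstructed blocks back together in the same order. A minimal sketch, assuming non-overlapping square blocks taken in row-major order from an image whose height and width are multiples of the block size. `split_into_blocks` and `stitch_blocks` are hypothetical helpers, not functions from the repository, and the block layout compresslow() actually uses may differ.

```python
import numpy as np


def split_into_blocks(img, block):
    """Split a (C, H, W) array into non-overlapping (C, block, block) tiles, row-major.

    Hypothetical helper; assumes H and W are multiples of `block`.
    """
    c, h, w = img.shape
    if h % block or w % block:
        raise ValueError("H and W must be multiples of the block size")
    tiles = []
    for top in range(0, h, block):
        for left in range(0, w, block):
            tiles.append(img[:, top:top + block, left:left + block])
    return tiles


def stitch_blocks(tiles, h, w):
    """Reassemble row-major (C, block, block) tiles into one (C, H, W) array."""
    c, block, _ = tiles[0].shape
    out = np.empty((c, h, w), dtype=tiles[0].dtype)
    cols = w // block  # number of tiles per row
    for i, tile in enumerate(tiles):
        top, left = (i // cols) * block, (i % cols) * block
        out[:, top:top + block, left:left + block] = tile
    return out


# Round trip: splitting and stitching without any processing recovers the image.
img = np.arange(3 * 8 * 12, dtype=np.float32).reshape(3, 8, 12)
tiles = split_into_blocks(img, 4)
assert len(tiles) == 6
assert np.array_equal(stitch_blocks(tiles, 8, 12), img)
```

In the inference loop, each block's x_hat (moved to NumPy with `.cpu().numpy()` and with its batch dimension removed) would be collected in order and passed to `stitch_blocks`.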

```python
# NetHigh.encode() in networks.py, adapted for inference
# (assumes numpy is imported as np and torch is imported at module level)
def encode(self, inputs):
    b, c, h, w = inputs.shape
    print("inputs shape is ", inputs.shape)
    print("inputs is ", inputs)

    # Analysis and hyper-analysis transforms, each output quantized by rounding
    z3 = self.a_model(inputs)
    z3_rounded = bypass_round(z3)
    # print("z3_rounded : ", z3_rounded)

    z2 = self.ha_model_2(z3_rounded)
    z2_rounded = bypass_round(z2)

    z1 = self.ha_model_1(z2_rounded)
    z1_rounded = bypass_round(z1)

    z1_sigma = torch.abs(self.get_h1_sigma)
    z1_mu = torch.zeros_like(z1_sigma)

    h1 = self.hs_model_1(washed(z1_rounded))
    h2 = self.hs_model_2(washed(z2_rounded))
    # The shape of h1 and h2 is (1, 256, 24, 24)

    z1_likelihoods = self.entropy_bottleneck_z1(z1_rounded, z1_sigma, z1_mu)

    # z2 is predicted at 1/32 of the input resolution, conditioned on h1
    z2_mu, z2_sigma = self.prediction_model_2(
        (b, 64 * 4, h // 2 // 16, w // 2 // 16), h1, self.sampler_2)

    z2_likelihoods = self.entropy_bottleneck_z2(
        z2_rounded, z2_sigma, z2_mu)

    # z3 is predicted at 1/16 of the input resolution, conditioned on h2
    z3_mu, z3_sigma = self.prediction_model_3(
        (b, 384, h // 16, w // 16), h2, self.sampler_3)

    z3_likelihoods = self.entropy_bottleneck_z3(
        z3_rounded, z3_sigma, z3_mu)

    pf = self.s_model(washed(z3_rounded))
    # The shape of pf is (1, 384, 384, 384)
    x_tilde = self.side_recon_model(pf, h2, h1)

    test_num_pixels = inputs.size()[0] * h * w

    # Bits per pixel: -log2 of the likelihoods of all three latents, per pixel
    eval_bpp = (
        torch.sum(torch.log(z3_likelihoods), [0, 1, 2, 3]) / (-np.log(2) * test_num_pixels)
        + torch.sum(torch.log(z2_likelihoods), [0, 1, 2, 3]) / (-np.log(2) * test_num_pixels)
        + torch.sum(torch.log(z1_likelihoods), [0, 1, 2, 3]) / (-np.log(2) * test_num_pixels)
    )

    # Map values in [-1, 1] back to [0, 255] and compute PSNR
    gt = torch.round((inputs + 1) * 127.5)
    x_hat = torch.clamp((x_tilde + 1) * 127.5, 0, 255)
    x_hat = torch.round(x_hat).float()
    v_mse = torch.mean((x_hat - gt) ** 2, [1, 2, 3])
    v_psnr = torch.mean(20 * torch.log10(255 / torch.sqrt(v_mse)), 0)

    ret = {}
    ret['z1_mu'] = z1_mu.detach().cpu().numpy()
    ret['z1_sigma'] = z1_sigma.detach().cpu().numpy()
    ret['z2_mu'] = z2_mu.detach().cpu().numpy()
    ret['z2_sigma'] = z2_sigma.detach().cpu().numpy()
    ret['z3_mu'] = z3_mu.detach().cpu().numpy()
    ret['z3_sigma'] = z3_sigma.detach().cpu().numpy()
    ret['z1_rounded'] = z1_rounded.detach().cpu().numpy()
    ret['z2_rounded'] = z2_rounded.detach().cpu().numpy()
    ret['z3_rounded'] = z3_rounded.detach().cpu().numpy()
    ret['v_psnr'] = v_psnr.detach().cpu().numpy()
    ret['eval_bpp'] = eval_bpp.detach().cpu().numpy()

    return ret
```

I also return gt, x_hat, x_tilde, pf, h2 and h1 so that I can display them.
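The metric part of encode() computes bits per pixel as the negative base-2 log-likelihood of all quantized latents divided by the number of pixels, and PSNR after mapping values in [-1, 1] back to [0, 255]. The NumPy sketch below mirrors those two formulas outside the network so they can be checked on plain arrays. The function names are hypothetical, and the likelihood arrays stand in for z1_likelihoods, z2_likelihoods and z3_likelihoods.

```python
import numpy as np


def bits_per_pixel(likelihoods, num_pixels):
    """Sum of -log2(p) over every latent likelihood array, divided by the pixel count."""
    total_bits = sum(-np.sum(np.log2(p)) for p in likelihoods)
    return total_bits / num_pixels


def psnr_from_unit_range(x_tilde, inputs):
    """PSNR in dB, mapping [-1, 1] values to [0, 255] the same way encode() does."""
    gt = np.round((inputs + 1) * 127.5)
    x_hat = np.round(np.clip((x_tilde + 1) * 127.5, 0, 255))
    # Per-image MSE over (C, H, W), then PSNR averaged over the batch
    mse = np.mean((x_hat - gt) ** 2, axis=(1, 2, 3))
    return np.mean(20 * np.log10(255 / np.sqrt(mse)))


# 16 latent values at p = 0.5 (16 bits) plus 4 at p = 0.25 (8 bits) over 64 pixels
p3 = np.full((1, 4, 2, 2), 0.5)
p2 = np.full((1, 4, 1, 1), 0.25)
print(bits_per_pixel([p3, p2], 1 * 8 * 8))  # 24 bits / 64 pixels = 0.375
```

Note that an identical x_hat and gt give zero MSE and therefore an infinite PSNR, so a very high or infinite value is itself a hint that the comparison is not what was intended.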

The gt looks normal, like this:

[image: gt_1011_1]

But x_hat and x_tilde are wrong. The first image below is x_tilde and the second is x_hat:

[image: x_tilde_0]

[image: x_hat_1011_1]

For `x_tilde = self.side_recon_model(pf, h2, h1)`, I printed some information about pf, h2 and h1:

[image]

The shape of pf is (1, 384, 384, 384).

[image]
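The shape tuples passed to prediction_model_3 and prediction_model_2 in encode() imply that z3 lives at 1/16 and z2 at 1/32 of the input resolution. A small sketch of that arithmetic can help check the printed shapes against the input size. `expected_latent_shapes` is a hypothetical helper, and the divisibility check is an assumption: floor division drops any remainder, so for a side that is not a multiple of 32 the tuple may not match the size of the latent the network actually produces.

```python
def expected_latent_shapes(b, h, w):
    """Shapes implied by the tuples in NetHigh.encode() (hypothetical helper).

    z3: (b, 384, h // 16, w // 16), passed to prediction_model_3
    z2: (b, 64 * 4, h // 2 // 16, w // 2 // 16), passed to prediction_model_2
    """
    if h % 32 or w % 32:
        # Assumption: with floor division the remainder would be silently dropped
        raise ValueError(f"input {h}x{w} is not a multiple of 32 on each side")
    return {
        "z3": (b, 384, h // 16, w // 16),
        "z2": (b, 64 * 4, h // 2 // 16, w // 2 // 16),
    }


print(expected_latent_shapes(1, 384, 384))
# {'z3': (1, 384, 24, 24), 'z2': (1, 256, 12, 12)}
```

For a 384x384 block, z3 comes out at 24x24, which agrees with the (1, 256, 24, 24) spatial size reported for h1 and h2.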

I am confused about why x_tilde and x_hat can't reconstruct the input. Could you give me some advice?

huzi96 commented 2 years ago

Please check the value of x_hat and x_tilde for clues.
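One way to follow that suggestion is to print simple statistics of x_tilde and x_hat. Because encode() maps x_tilde to pixels with (x_tilde + 1) * 127.5 followed by clamping to [0, 255], x_tilde is expected to lie roughly in [-1, 1]; values far outside that range saturate x_hat at 0 or 255. A minimal NumPy sketch, where `describe_range` and `to_pixels` are hypothetical helpers and the arrays would come from `.detach().cpu().numpy()`. The example input is made up purely for illustration.

```python
import numpy as np


def describe_range(name, arr):
    """Print and return min/max/mean and the fraction of values outside [-1, 1]."""
    arr = np.asarray(arr, dtype=np.float64)
    stats = {
        "min": float(arr.min()),
        "max": float(arr.max()),
        "mean": float(arr.mean()),
        "frac_outside_unit": float(np.mean(np.abs(arr) > 1.0)),
    }
    print(f"{name}: {stats}")
    return stats


def to_pixels(x_tilde):
    """Same mapping as encode(): [-1, 1] -> [0, 255], clamped and rounded."""
    return np.round(np.clip((np.asarray(x_tilde) + 1) * 127.5, 0, 255))


# Illustrative only: values on a [0, 255] scale instead of [-1, 1] saturate to 255.
x_tilde = np.linspace(0.0, 255.0, 12).reshape(1, 3, 2, 2)
describe_range("x_tilde", x_tilde)
describe_range("x_hat", to_pixels(x_tilde))
```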