huzi96 / Coarse2Fine-PyTorch


How to use the model trained and some other problems? #5

Open achel-x opened 3 years ago

achel-x commented 3 years ago

After running train.py, I got some .ckpt files, and I want to know how to load one and check its effect, like AppEncDec.py does to compress and decompress an image. Should I write a similar script?

Moreover, I want to know how to choose the model type (PSNR or MS-SSIM) and how to control the QP values. I am still very new to this area, so please forgive me if my questions are too naive.

huzi96 commented 3 years ago

https://github.com/huzi96/Coarse2Fine-PyTorch/blob/93f4e8cb7575b1db3bfd57553fbf10f0ac04bf39/networks.py#L636 This is where the script loads the ckpt. In my implementation, the filename is determined by the given model type (PSNR: 0, SSIM: 1) and the compression level (QP). You can specify the filename of the checkpoint here.

huzi96 commented 3 years ago

In this case, you don't need the load_tf_weights function anymore. You can just load it the PyTorch way.
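For example, a minimal sketch of loading a train.py checkpoint directly (the checkpoint path here is a placeholder, and the 'module.' handling assumes the checkpoint was saved from an nn.DataParallel-wrapped model, which comes up again further down in this thread):

import torch

net = Net().eval()  # the Net class defined in train.py
td = torch.load('checkpoint/your_model.ckpt', map_location='cpu')  # placeholder path
# Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
ntd = {(k[7:] if k.startswith('module.') else k): v for k, v in td.items()}
print(net.load_state_dict(ntd, strict=False))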

achel-x commented 3 years ago

In this case, you don't need the load_tf_weights function anymore. You can just load it the PyTorch way.

Yes, I have read that line in networks.py. So I wonder how I can control the QP and PSNR/SSIM in train.py. For each QP and PSNR/SSIM setting, do I have to run the complete training stage, i.e. one model per bpp?

achel-x commented 3 years ago

There is another question about the training process. I have only one GPU (GeForce 3080, 10 GB). When I ran 'python train.py train --batchsize 16 --train_glob "IMAGE_PATH/*.png" --checkpoint_dir checkpoint --lambda 0.004', I got a CUDA out-of-memory RuntimeError, so I set batchsize to 8. Should I increase the number of training epochs, e.g. double it? Besides, in stage 4 each epoch takes about 133 s, which is very time-consuming. Are the 5500 epochs necessary? Can I reduce the number of epochs in the last stage? It originally runs from epoch 1300 to 5500; could I stop at, say, 2000?

huzi96 commented 3 years ago
  1. Yes. One model per QP per type (PSNR/SSIM).
  2. Setting batchsize = 8 is OK, and you don't need to double the number of epochs.
  3. The model should converge to acceptable performance by around 3000 epochs, but this is not guaranteed.

achel-x commented 3 years ago

Thank you very much for your kind answers, and excuse me for asking so many questions; maybe this is not the last one... In my understanding, QP here means the compression ratio and is not the QP concept from traditional codecs ( $\lambda = f(qp)$ ), right? So I wonder how the network controls the compression ratio. Is it set before training, or do I only know the exact compression ratio after training by comparing the original image with the reconstructed one? If so, how can I target a specific compression rate?

huzi96 commented 3 years ago

QP is not the Quantization Parameter in the context of this code, but it is related to lambda: it controls the compression rate. We train different models with different lambdas in the loss function R + \lambda D, and each model serves a specific compression rate (or, we say, rate-distortion tradeoff). The current implementation is like an image codec without rate control: we can compress an image with a specific QP, but we don't know the bit-rate of the output binary in advance. Compressing an image to a specific bit-rate is a rate control problem, and no rate-control mechanism is implemented here. You may want to take a look at different rate control algorithms for this.
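For intuition only, here is a hypothetical per-image rate-control wrapper on top of a set of trained models. It assumes bpp increases with QP for these models and that an encode_with_qp(image, qp) helper returning the measured bpp exists; neither is part of this repo.

def pick_qp_for_target_bpp(image, target_bpp, encode_with_qp, qps=range(8)):
    # Return the highest-rate QP whose measured bpp still fits the budget,
    # or None if even the lowest-rate model exceeds it.
    best = None
    for qp in qps:  # assumed ordered from lowest to highest rate
        if encode_with_qp(image, qp) <= target_bpp:
            best = qp
    return best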

achel-x commented 3 years ago

I probably see what you mean.
So the loss function is R + \lambda D, and lambda is set manually, as in python train.py train --batchsize 16 --train_glob "IMAGE_PATH/*.png" --checkpoint_dir checkpoint --lambda 0.004. I can change the value of lambda to get a different tradeoff between rate and distortion, is that right? Could you please suggest some typical values of lambda? The example uses 0.004; are there other values? Is my understanding correct?

huzi96 commented 3 years ago

I would recommend 0.002, 0.004, 0.008, 0.016, 0.032.
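If it helps, a small sketch of sweeping those values, one full training run per lambda (the flags follow the command quoted earlier in this thread; the per-lambda checkpoint directory naming is just illustrative):

import subprocess

for lam in [0.002, 0.004, 0.008, 0.016, 0.032]:
    # One model per lambda; each run gets its own checkpoint directory.
    subprocess.run([
        'python', 'train.py', 'train',
        '--batchsize', '8',
        '--train_glob', 'IMAGE_PATH/*.png',
        '--checkpoint_dir', f'checkpoint_l{lam}',
        '--lambda', str(lam),
    ], check=True)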

achel-x commented 3 years ago

Thank you very, very much for sparing your valuable time.

achel-x commented 2 years ago

Excuse me for bothering you again. According to your reply, I should load the trained model the PyTorch way in order to use it. I tried this recently, but I found that the reconstructed picture is incorrect. I read your code in networks.py and gdn_v3.py: the network in train.py differs from the one in networks.py, e.g. GDN/IGDN and some other code in class Network, so I thought the parameters may have changed and a model trained with the PyTorch code differs from the .pk model, which is why loading it does not work. I still want to reconstruct the image; what can I do? Should I adjust the code in compress_low() and decompress_low() to align with these differences between train.py and networks.py? I am not familiar with the specific code of learned image compression; I have only read some papers around this topic, and this is my first attempt to learn how it works, so please forgive my naive questions.

huzi96 commented 2 years ago

Sorry about that. Indeed, if you would like to use the trained network for inference, you need some modifications. The easiest way is to replace the definition of class Net and its related modules in networks.py with the definitions in train.py. Besides, you need two extra functions to make it work. I would like to share the code for this, but I currently do not have time to test it. If you encounter any problems, please reach back to me.

def encode(self, inputs):
    b, c, h, w = inputs.shape
    tb, tc, th, tw = inputs.shape

    z3 = self.a_model(inputs)
    z3_rounded = bypass_round(z3)

    z2 = self.ha_model_2(z3_rounded)
    z2_rounded = bypass_round(z2)

    z1 = self.ha_model_1(z2_rounded)
    z1_rounded = bypass_round(z1)

    h1 = self.hs_model_1(washed(z1_rounded))
    h2 = self.hs_model_2(washed(z2_rounded))

    z1_sigma = torch.abs(self.get_h1_sigma)
    z1_mu = torch.zeros_like(z1_sigma)

    z1_likelihoods = self.entropy_bottleneck_z1(
        z1_rounded, z1_sigma, z1_mu)

    z2_mu, z2_sigma = self.prediction_model_2(
        (b, 64*4, h//2//16, w//2//16), h1, self.sampler_2)

    z2_likelihoods = self.entropy_bottleneck_z2(
        z2_rounded, z2_sigma, z2_mu)

    z3_mu, z3_sigma = self.prediction_model_3(
        (b, 192, h//16, w//16), h2, self.sampler_3)

    z3_likelihoods = self.entropy_bottleneck_z3(
        z3_rounded, z3_sigma, z3_mu)

    test_num_pixels = inputs.size()[0] * inputs.size()[2] * inputs.size()[3]

    eval_bpp = torch.sum(torch.log(z3_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels) + torch.sum(torch.log(z2_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels) + torch.sum(torch.log(z1_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels)

    ret = {}
    ret['z1_mu'] = z1_mu.detach().cpu().numpy()
    ret['z1_sigma'] = z1_sigma.detach().cpu().numpy()
    ret['z2_mu'] = z2_mu.detach().cpu().numpy()
    ret['z2_sigma'] = z2_sigma.detach().cpu().numpy()
    ret['z3_mu'] = z3_mu.detach().cpu().numpy()
    ret['z3_sigma'] = z3_sigma.detach().cpu().numpy()
    ret['z1_rounded'] = z1_rounded.detach().cpu().numpy()
    ret['z2_rounded'] = z2_rounded.detach().cpu().numpy()
    ret['z3_rounded'] = z3_rounded.detach().cpu().numpy()
    ret['eval_bpp'] = eval_bpp.detach().cpu().numpy()
    return ret

def decode(self, inputs, stage):
    if stage == 0:
        z1_sigma = torch.abs(self.get_h1_sigma)
        z1_mu = torch.zeros_like(z1_sigma)

        ret = {}
        ret['z1_sigma'] = z1_sigma.detach().cpu().numpy()
        ret['z1_mu'] = z1_mu.detach().cpu().numpy()
        return ret

    elif stage == 1:
        z1_rounded = inputs['z1_rounded']
        h1 = self.hs_model_1(z1_rounded)
        self.h1 = h1
        z2_mu, z2_sigma = self.prediction_model_2((h1.shape[0],64*4,h1.shape[2],h1.shape[3]), h1, self.sampler_2)
        ret = {}
        ret['z2_sigma'] = z2_sigma.detach().cpu().numpy()
        ret['z2_mu'] = z2_mu.detach().cpu().numpy()
        return ret

    elif stage == 2:
        z2_rounded = inputs['z2_rounded']
        h2 = self.hs_model_2(z2_rounded)
        self.h2 = h2
        z3_mu, z3_sigma = self.prediction_model_3((h2.shape[0],192,h2.shape[2],h2.shape[3]), h2, self.sampler_3)
        ret = {}
        ret['z3_sigma'] = z3_sigma.detach().cpu().numpy()
        ret['z3_mu'] = z3_mu.detach().cpu().numpy()
        return ret

    elif stage == 3:
        z3_rounded = inputs['z3_rounded']
        pf = self.s_model(z3_rounded)
        x_tilde = self.side_recon_model(pf, self.h2, self.h1)
        x_tilde = torch.round(torch.clamp((x_tilde + 1) * 127.5, 0, 255))
        return x_tilde.detach().cpu().numpy()
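
As a rough idea of how these two functions chain together, here is a minimal sketch only: net is assumed to be the modified Net in eval mode on device, x a padded, normalized (1, 3, H, W) tensor in [-1, 1], and the rounded symbols are handed across directly instead of going through the actual arithmetic coder.

with torch.no_grad():
    code = net.encode(x)
    net.decode(None, stage=0)  # z1 prior (sigma/mu) only
    net.decode({'z1_rounded': torch.from_numpy(code['z1_rounded']).to(device)}, stage=1)
    net.decode({'z2_rounded': torch.from_numpy(code['z2_rounded']).to(device)}, stage=2)
    recon = net.decode({'z3_rounded': torch.from_numpy(code['z3_rounded']).to(device)}, stage=3)
print('estimated bpp:', code['eval_bpp'], 'reconstruction shape:', recon.shape)
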
achel-x commented 2 years ago

Thank you for your reply. I am trying to do this, but only moving the definition of class Net() from train.py into networks.py isn't enough. Maybe I should also modify the code in compress_low() and decompress_low() to adapt to these differences. Is my understanding right?

Plus, I have another question. I found that in train.py, class Network() with mode == "test" does the following:

    # Bring both images back to 0..255 range.
    gt = torch.round((inputs + 1) * 127.5)
    x_hat = torch.clamp((x_tilde + 1) * 127.5, 0, 255)
    x_hat = torch.round(x_hat).float()
    v_mse = torch.mean((x_hat - gt) ** 2, [1, 2, 3])
    v_psnr = torch.mean(20 * torch.log10(255 / torch.sqrt(v_mse)), 0)
    return eval_bpp, v_psnr, x_hat, bpp1, bpp2, bpp3

Originally, I thought x_hat was the reconstructed image (tensor), and it shows like this: (screenshot: x_hat). The shape of x_hat is [1, 3, 256, 256]. The training code takes a RandomCrop of the input image to get 256x256 patches, and the preprocessing computes x * 2 / 255.0 - 1. What is the purpose of these operations?

From that screenshot, I thought x_hat was the latent representation, but it is used to compute the PSNR between the original and reconstructed pictures. That would mean x_hat is the decoded image, yet it doesn't look like one. I want to know why.

huzi96 commented 2 years ago

Does the PSNR value look good? If so, please check the following.

  1. You should transpose the tensor to (256, 256, 3) in order to save it as an image with most libraries.
  2. Check the pixel values before you save it as an image. Do all the values lie in [0, 255]? (See the sketch below.)
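A minimal sketch of both steps, assuming x_hat is the (1, 3, 256, 256) tensor returned by the test pass (already scaled back to the 0..255 range):

import numpy as np
from PIL import Image

img = x_hat[0].detach().cpu()      # (3, 256, 256), values expected to be in 0..255
img = img.permute(1, 2, 0)         # CHW -> HWC, i.e. (256, 256, 3)
img = img.round().clamp(0, 255)    # make sure every value is a valid pixel
Image.fromarray(img.numpy().astype(np.uint8)).save('x_hat_recon.png')
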
achel-x commented 2 years ago

(screenshot: test0914_result)

This is my first test result; the PSNR looks OK. The x_hat in return eval_bpp, v_psnr, x_hat, bpp1, bpp2, bpp3 corresponds to the last line of that screenshot. It is a 0-dim tensor; when I print it, it looks like a single pixel value, as shown in this result.

I saved x_tilde as 'x_tilde.pth' and loaded it separately, then I ran this code:

(screenshot of the code)

The result looks great. However, when I use torchvision.transforms.ToPILImage to convert the x_hat tensor to an image, the result looks like the one from the earlier issue. So I thought x_hat was a latent representation with the statistical redundancy removed, like in the variational hyperprior paper:

https://img-blog.csdnimg.cn/20201209144447566.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQyMjgxNDI1,size_16,color_FFFFFF,t_70

But it can still be used to compute PSNR, which confuses me. After reading your latest reply, I tried to reshape x_hat. First, I still used ToPILImage(), and got this: (screenshot)

Second, I tried converting it to numpy and then to a PIL image: (screenshot)

The result looks the same.

This is the x_hat tensor converted to a PIL image: (screenshot: x_hat_recon)

This is the numpy array converted to a PIL image: (screenshot: x_hat_recon0927)

huzi96 commented 2 years ago

Comments for your second trial: You should use transpose, not reshape.
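To illustrate the difference (this snippet is not from the repo; shapes follow the x_hat discussed above): reshape reinterprets the flat memory and scrambles the pixel layout, while permute/transpose moves whole axes and keeps each pixel's RGB triplet together.

chw = x_hat[0]                    # (3, 256, 256), channels first
wrong = chw.reshape(256, 256, 3)  # wrong: memory is reinterpreted, image is garbled
right = chw.permute(1, 2, 0)      # right: (256, 256, 3), channels last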

achel-x commented 2 years ago

I transposed it to the shape (256, 256, 3), but the image is still weird.

(screenshot: x_hat_recon0928)

Besides, I read the code in train.py and networks.py:

(screenshot)

(screenshot)

As far as I can see, when mode == "test", train.py only performs what corresponds to the encode part of networks.py, not the decode part. Why can it still be used to calculate v_psnr, with values that look good, while x_hat can't be converted to a normal picture?

achel-x commented 2 years ago

I also looked at gt; it looks fine, just a RandomCrop patch of the original picture. (screenshot: 0807 gt)

achel-x commented 2 years ago

Sorry to bother you again. I found that I hadn't rounded x_hat; after rounding, I get a correct output image.

This is gt: (screenshot: gt_0928_2)

And this is x_hat: (screenshot: x_hat_recon0928_2)

xisi789 commented 2 years ago

Sorry about that. Indeed, if you would like to use the trained network for inference, you need some modifications. The easiest way is to replace the definition of class Net and its related modules in networks.py with the definitions in train.py. Besides, you need two extra functions to make it work. I would like to share the code for this, but I currently do not have time to test it. If you encounter any problems, please reach back to me.

(The encode and decode functions from the reply above were quoted here in full.)

@huzi96 @achel-x I followed these suggestions to modify networks.py, but the decoded image is weird.

This is gt: (screenshot: 10000001290)

And this is the decoded image: (screenshot: example_dec)

This is my code: (screenshots)

achel-x commented 2 years ago

The modification is not located in class Net but in the weight-loading part. In def train() or def encode(), wherever the weights are loaded, you need to make a small change. Here is the relevant part of the successfully tested code provided by the author; you can refer to it for your own work.

sd = net.state_dict()
# td = load_tf_weights(sd, f'models/model{model_type}_qp{qp}.pk')
# Changed to directly load the weights.
# Here we only manually load one fixed model for testing;
# you can write some code to read the path from command line arguments.
td = torch.load('trained_fat_l004.ckpt', map_location=torch.device(device))
ntd = {}
# Naive way to remove 'module.' before the name of each tensor.
for k in td:
    if k[:6] == 'module':
        ntd[k[7:]] = td[k]
    else:
        ntd[k] = td[k]
ret = net.load_state_dict(ntd, strict=False)
print(ret)

arr = np.array([fshape[0], fshape[1]], dtype=np.uint16)
arr.tofile(fileobj)
fileobj.close()

GitHub comments can't attach .py files, so I pasted the modification here.

xisi789 commented 2 years ago

@achel-x Thank you for your reply. I loaded my self-trained model with the same loading method as yours, but the decoded result is the same as above. Load-weight part: (screenshot)

In addition, where should this code be placed? (screenshot)

achel-x commented 2 years ago

The code in the red bounding box is not the modification itself; I copied it to show where it sits in the code. It goes in the same place as in networks.py, inside def compress_low() and def decompress_low(). For example:

def compress_low(args):
    """Compresses an image."""
    global device
    mode = 'low'
    if args.qp > 3:
        mode = 'high'
    device = torch.device(args.device)
    from PIL import Image

    # Load input image and add batch dimension.
    f = Image.open(args.input)
    fshape = [f.size[1], f.size[0], 3]
    x = np.array(f).reshape([1,] + fshape)

    compressed_file_path = args.output
    fileobj = open(compressed_file_path, mode='wb')

    qp = args.qp
    model_type = args.model_type
    print(f'model_type: {model_type}, qp: {qp}')

    buf = qp << 1
    buf = buf + model_type
    arr = np.array([0], dtype=np.uint8)
    arr[0] = buf
    arr.tofile(fileobj)

    h, w = (fshape[0]//64)*64, (fshape[1]//64)*64
    if h < fshape[0]:
        h += 64
    if w < fshape[1]:
        w += 64

    pad_up = (h - fshape[0]) // 2
    pad_down = (h - fshape[0]) - pad_up
    pad_left = (w - fshape[1]) // 2
    pad_right = (w - fshape[1]) - pad_left

    x = np.pad(x, [[0, 0], [pad_up, pad_down], [pad_left, pad_right], [0, 0]])
    torch_x = (torch.tensor(x).permute(0,3,1,2).float() / 127.5) - 1

    if mode == 'low':
        net = NetLow().eval()
    else:
        net = NetHigh().eval()
    net = net.to(device)

    sd = net.state_dict()
    # td = load_tf_weights(sd, f'models/model{model_type}_qp{qp}.pk')
    # Changed to directly load the weights.
    # Here we only manually load one fixed model for testing;
    # you can write some code to read the path from command line arguments.
    td = torch.load('trained_fat_l004.ckpt', map_location=torch.device(device))
    ntd = {}
    # Naive way to remove 'module.' before the name of each tensor.
    for k in td:
        if k[:6] == 'module':
            ntd[k[7:]] = td[k]
        else:
            ntd[k] = td[k]
    ret = net.load_state_dict(ntd, strict=False)
    print(ret)

    arr = np.array([fshape[0], fshape[1]], dtype=np.uint16)
    arr.tofile(fileobj)
    fileobj.close()

xisi789 commented 2 years ago

@achel-x Sorry to bother you again, but I did not find any difference between my code and yours. This is my code: (screenshot). Could you provide the complete networks.py? Thank you very much!

huzi96 commented 2 years ago

@xisi789 I uploaded the code for testing on a trained checkpoint. Take a look~

xisi789 commented 2 years ago

@xisi789 I uploaded the code for testing on a trained checkpoint. Take a look~

Now I can get the correct output image! Thank you very much!

balabengba commented 2 years ago

(screenshot: example_dec 2 (1)) I tested the picture using the recently uploaded code for testing on a trained checkpoint, but the result is not right.

huzi96 commented 2 years ago

Check if you can correctly encode and decode images with this weight to find out what went wrong (you may want to set qp>3) https://drive.google.com/file/d/1x2ZePCM8VaqGNWteKS55jd9LWpaPfK_O/view?usp=sharing

balabengba commented 2 years ago

Check if you can correctly encode and decode images with this weight to find out what went wrong (you may want to set qp>3) https://drive.google.com/file/d/1x2ZePCM8VaqGNWteKS55jd9LWpaPfK_O/view?usp=sharing

Thank you very much! I get the correct output image!