MC-E / Deep-Generalized-Unfolding-Networks-for-Image-Restoration

Accepted by CVPR 2022

Issue about trainable measurement matrix A in compressed sensing mode #8

Open XiChen-97 opened 1 year ago

XiChen-97 commented 1 year ago

Hi there @MC-E ,

Thank you for your great work and for sharing the code! I have a question about the compressed sensing case:

1) As you mentioned in the paper: "Note that in the task of compressive sensing, the degradation matrix A is exactly known, i.e., the sampling matrix Φ. Thus, we directly use Φ to calculate the gradient." However, in the code, you instead set A to be a trainable parameter:

    PhiTPhi = torch.mm(torch.transpose(self.Phi, 0, 1), self.Phi)
    Phix = torch.mm(img.view(b, -1), torch.transpose(self.Phi, 0, 1))  # measurement y = x Phi^T
    PhiTb = torch.mm(Phix, self.Phi)                                   # back-projection Phi^T y
    # compute r_0
    x_0 = PhiTb.view(b, -1)
    x = x_0 - self.r0 * torch.mm(x_0, PhiTPhi)
    r_0 = x + self.r0 * PhiTb
    r_0 = r_0.view(b, c, w, h)

https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration/blob/bae845c2612d0df56a479020d59896441168d07a/Compressive-Sensing/DGUNet.py#L384C1-L391C30

where self.Phi (referred to as A) is a learnable parameter during training. This confuses me, because in compressed sensing we assume the only available information is the measurement y and the matrix A; we have no access to the raw image X_0. But since A is a learnable parameter here, y is computed inside the model as a linear transformation of the real X_0, which means all the information of X_0 is available as model input. In effect, the model is learning to produce \hat{X} (output) given the real image X_0 (input), which amounts to recovering X_0 given X_0 itself.

Conversely, if you treat A as unknown and learnable, then the process that produces the measurement y (the computation of Phix in the code) should not involve the learnable parameter A.
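To make the distinction concrete, here is a minimal sketch (hypothetical names and dimensions, not the repo's code) of the fixed-Phi protocol described above, where Phi is frozen and the solver only ever sees the measurement y, never the ground-truth signal:

```python
import torch

# Minimal sketch of the standard CS protocol: Phi is fixed and
# non-trainable; the only data the solver receives is y = Phi x0.
n, m, b = 64, 16, 4                      # signal dim, measurement dim, batch
torch.manual_seed(0)
Phi = torch.randn(m, n) / m ** 0.5       # fixed (non-trainable) sampling matrix

x0 = torch.randn(b, n)                   # ground truth, unknown to the solver
y = x0 @ Phi.t()                         # y = Phi x0: the only observed data

# One gradient step on the data-fidelity term, mirroring the quoted snippet:
# start from x = Phi^T y, then r0 = x - r * Phi^T Phi x + r * Phi^T y.
r = 0.1
x_init = y @ Phi                         # Phi^T y, flattened to shape (b, n)
r0 = x_init - r * (x_init @ Phi.t() @ Phi) + r * (y @ Phi)
```

With a frozen Phi, the network can only exploit y; if Phi is instead a trainable parameter applied inside the model, the ground-truth x0 itself becomes the model's input.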

2) Moreover, at test time the input to the model is the real image (say X_0):

        batch_x = torch.from_numpy(Img_output)
        batch_x = batch_x.type(torch.FloatTensor)
        batch_x = batch_x.to(device)
        # Phix = torch.mm(batch_x, torch.transpose(Phi, 0, 1))  # compression result
        # PhixPhiT = torch.mm(Phix, Phi)
        batch_x = batch_x.view(batch_x.shape[0], 1, args.patch_size, args.patch_size)
        x_output = model(batch_x)[0]

https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration/blob/bae845c2612d0df56a479020d59896441168d07a/Compressive-Sensing/train.py#L231C1-L238C62

and the measurement y is obtained by y = AX_0, where A is the trainable parameter. I cannot understand this setting: at test time we should assume the only available information is y and A (if we know the degradation model), yet here the input to the model is the real raw test image itself.
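For contrast, here is a hedged sketch (illustrative names, not the repo's code) of the test-time protocol this comment argues for: measure once with a fixed Phi, then reconstruct from (y, Phi) alone, shown here with a classical pseudo-inverse baseline:

```python
import torch

# Hypothetical test-time protocol: the evaluator measures with the fixed
# Phi, and the reconstruction method sees only y, never the raw image.
torch.manual_seed(0)
n, m = 64, 16
Phi = torch.randn(m, n) / m ** 0.5       # fixed sampling matrix (m < n)
x_true = torch.randn(1, n)               # held-out test image (flattened patch)
y = x_true @ Phi.t()                     # the only test-time observation

# Classical baseline from (y, Phi) alone: the minimum-norm least-squares
# solution x_hat = Phi^+ y via the Moore-Penrose pseudo-inverse.
x_hat = y @ torch.linalg.pinv(Phi).t()

# Sanity check: Phi x_hat reproduces the measurement (Phi has full row rank).
consistent = torch.allclose(x_hat @ Phi.t(), y, atol=1e-4)
```

A learned reconstruction network would replace the pseudo-inverse step, but the input contract is the same: only y (and Phi) are available at test time.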

Please correct me if I have misunderstood anything, and I apologize in advance if I missed something that is already explained clearly in the paper or code. Thank you so much, and I look forward to your reply!

MC-E commented 9 months ago

Sorry for the late reply. A is learnable: it captures the degradation matrix from the training data, as in [1] and [2].

[1] Deep Memory-Augmented Proximal Unrolling Network for Compressive Sensing
[2] Optimization-Inspired Compact Deep Compressive Sensing
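The learnable-sampling setup those papers describe can be sketched as follows (illustrative names, assuming Phi is an nn.Parameter so that sampling and recovery are trained jointly end to end):

```python
import torch
import torch.nn as nn

# Hedged sketch of jointly learned sampling + recovery (cf. [1], [2]):
# Phi is an nn.Parameter, so gradients from the reconstruction loss
# update the sampling matrix together with the rest of the network.
class LearnedSamplingCS(nn.Module):
    def __init__(self, n=64, m=16):
        super().__init__()
        self.Phi = nn.Parameter(torch.randn(m, n) / m ** 0.5)

    def forward(self, x0):
        y = x0 @ self.Phi.t()            # sampling happens inside the model
        return y @ self.Phi              # linear initial reconstruction Phi^T y

model = LearnedSamplingCS()
x0 = torch.randn(4, 64)
loss = (model(x0) - x0).pow(2).mean()
loss.backward()                          # gradients also flow into Phi
```

Under this view, feeding the raw image to the model at train and test time is how the sampling operator itself is exercised; whether that matches a deployment where Phi must be fixed in hardware is a separate design question.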