adityassrana / Content-Weighted-Image-Compression

PyTorch implementation of Learning Convolutional Networks for Content-Weighted Image Compression

Gradient for Importance Map Mask #5

Open: amarzullo24 opened this issue 2 years ago

amarzullo24 commented 2 years ago

Hi, thanks for this implementation; it's a good starting point!

I have some concerns about the implementation of the backward pass in the Mask class. From the code, it seems to return a tensor of ones. However, that looks different from Equation (7) in the paper, which I tried to implement below. Am I missing something?
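For reference, here are the forward rule and the surrogate gradient written out as the code in this comment computes them. The symbols are: n = 64 channels, L = 16 levels, k the 0-based channel index, q_{ij} the integer level of the quantized map at pixel (i, j). The gradient is written for the continuous importance value p_{ij} in [0, 1), which assumes that is how Eq. (7) is stated. This is a transcription of the code, not a quotation of the paper.

```latex
m_{kij} =
\begin{cases}
  1, & k < \lceil n/L \rceil \, q_{ij} \\
  0, & \text{otherwise}
\end{cases}
\qquad
\frac{\partial m_{kij}}{\partial p_{ij}} =
\begin{cases}
  L, & L p_{ij} - 1 \le \lceil kL/n \rceil < L p_{ij} + 2 \\
  0, & \text{otherwise}
\end{cases}
```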

```python
import math

import torch


class Mask(torch.autograd.Function):
    """
    Equations (6) and (7) in the paper. This mask will be element-wise
    multiplied with the binary feature map generated by the encoder.
    Input:  A quantized importance map of shape (N,1,h,w)
            with L different integer values from 0 to (L-1)
    Output: A 3-D mask of dimensions (N,64,h,w) filled with
            sequential 1s and 0s
    """
    n = 64  # channels in the binary feature map
    L = 16  # number of importance levels

    @staticmethod
    def forward(ctx, qimp):
        ctx.save_for_backward(qimp)
        n, L = Mask.n, Mask.L
        # Channel index k = 0..n-1, shaped (1, n, 1, 1) to broadcast over (N, 1, H, W)
        k = torch.arange(n, device=qimp.device, dtype=qimp.dtype).view(1, n, 1, 1)
        # Channel k is kept iff k < ceil(n/L) * q
        return (k < math.ceil(n / L) * qimp).to(qimp.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (qimp,) = ctx.saved_tensors
        n, L = Mask.n, Mask.L
        k = torch.arange(n, device=qimp.device, dtype=qimp.dtype).view(1, n, 1, 1)
        c = torch.ceil(k * L / n)  # ceil(kL/n) in Eq. (7)
        # Eq. (7) uses L*p for the continuous map p; qimp already holds the
        # integer level (roughly L*p), so it is used directly here.
        band = (qimp - 1 <= c) & (c < qimp + 2)  # shape (N, n, H, W)
        dmask_dq = band.to(grad_output.dtype) * L
        # Chain rule: weight by the incoming gradient and sum over the n channels,
        # so the result has the input's shape (N, 1, H, W) as autograd requires.
        return (grad_output * dmask_dq).sum(dim=1, keepdim=True)

        # Original code (returns ones of shape (N, 1, H, W), ignoring grad_output):
        # N, _, H, W = grad_output.shape
        # if grad_output.is_cuda:
        #     return torch.ones(N, 1, H, W).cuda()
        # else:
        #     return torch.ones(N, 1, H, W)
```
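To see how Eq. (7) differs from a constant gradient of ones, this sketch evaluates the forward mask and the Eq. (7) derivative across all 64 channels for a single pixel. It is a minimal sketch: the helper `mask_and_surrogate_grad` is hypothetical, and it assumes that L·p can be approximated by the integer level q stored in the quantized map.

```python
import math

import torch

n, L = 64, 16  # channels and importance levels used in this thread


def mask_and_surrogate_grad(q: int):
    """Hypothetical helper: forward mask and Eq. (7) derivative for one pixel.

    Assumes L*p is approximated by the integer level q of the quantized map.
    """
    k = torch.arange(n, dtype=torch.float32)  # 0-based channel index
    mask = (k < math.ceil(n / L) * q).float()  # 1 for the first 4*q channels
    c = torch.ceil(k * L / n)  # ceil(kL/n)
    dmask = ((q - 1 <= c) & (c < q + 2)).float() * L  # L inside the band, else 0
    return mask, dmask


mask, dmask = mask_and_surrogate_grad(5)
print(int(mask.sum()))  # 20
print(torch.nonzero(dmask).flatten().tolist())
# [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
```

For q = 5, the forward pass keeps channels 0 to 19, while the Eq. (7) band covers channels 13 to 24, straddling the cut at channel 20. The original backward, by contrast, returns 1 for every pixel regardless of `grad_output`.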
adityassrana commented 2 years ago

Hi @emmeduz, this was a long time ago, but from what I recall, I had already considered this, and either the model had difficulty learning or I ran into some other gradient issue. It would be super cool if you could test this out and let me know whether it works; I would be happy to receive a PR.

aalex1945 commented 1 year ago

Have you finished the code for training the importance map yet? I plan to work on a related project, but I am having trouble completing this part. Thanks!