wyhuai / DDNM

[ICLR 2023 Oral] Zero-Shot Image Restoration Using Denoising Diffusion Null-Space Model

How to deblur with a mask? Is it possible? #33

Open ferdifdi opened 1 year ago

wyhuai commented 1 year ago

Hi, sorry for the late reply. It is possible for any combination of linear operators, as described in Section 3.2.
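For context, a rough sketch (not from this thread) of how a composed linear operator enters the DDNM null-space decomposition, with $\mathbf{A}_{\text{blur}}$ the blur and $\mathbf{A}_{\text{mask}}$ the mask:

$$
\mathbf{A} = \mathbf{A}_{\text{mask}}\,\mathbf{A}_{\text{blur}},
\qquad
\hat{\mathbf{x}}_{0|t} = \mathbf{A}^{\dagger}\mathbf{y} + \left(\mathbf{I} - \mathbf{A}^{\dagger}\mathbf{A}\right)\mathbf{x}_{0|t},
$$

where in practice $\mathbf{A}^{\dagger}$ can be approximated by $\mathbf{A}_{\text{blur}}^{\dagger}\,\mathbf{A}_{\text{mask}}^{\dagger}$, the composition suggested in the replies below.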

ferdifdi commented 1 year ago

I've tried to edit the deblurring SVD operator (`functions/svd_operators.Deblurring`) by adding a mask, but the result fails. Could you review my code or give me a clue / additional material? Thank you.

```python
elif deg == 'deblur_gauss':
    from functions.svd_operators import Deblurring
    # load the binary mask and collect the flattened indices of the missing (masked-out) pixels
    loaded = np.load("exp/inp_masks/mask.npy")
    mask = torch.from_numpy(loaded).to(self.device).reshape(-1)
    missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3
    missing_g = missing_r + 1
    missing_b = missing_g + 1
    missing = torch.cat([missing_r, missing_g, missing_b], dim=0)
    # 5-tap Gaussian blur kernel
    sigma = 10
    pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
    kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(self.device)
    A_funcs = Deblurring(kernel / kernel.sum(), config.data.channels,
                         self.config.data.image_size, self.device, missing)
```
```python
class Deblurring(A_functions):
    def mat_by_img(self, M, v):
        return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
                        self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)

    def img_by_mat(self, v, M):
        return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
                        self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])

    def __init__(self, kernel, channels, img_dim, device, missing_indices=None, ZERO = 3e-2):
        self.img_dim = img_dim
        self.channels = channels
        self.missing_indices = missing_indices
        self.kept_indices = torch.Tensor([i for i in range(channels * img_dim**2) if i not in missing_indices]).to(device).long()

        #build 1D conv matrix
        A_small = torch.zeros(img_dim, img_dim, device=device)
        for i in range(img_dim):
            for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
                if j < 0 or j >= img_dim: continue
                A_small[i, j] = kernel[j - i + kernel.shape[0]//2]
        #get the svd of the 1D conv
        self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False)
        #ZERO = 3e-2
        self.singulars_small_orig = self.singulars_small.clone()
        self.singulars_small[self.singulars_small < ZERO] = 0
        #calculate the singular values of the big matrix
        self._singulars_orig = torch.matmul(self.singulars_small_orig.reshape(img_dim, 1), self.singulars_small_orig.reshape(1, img_dim)).reshape(img_dim**2)
        self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2)
        # sort the big matrix singulars and save the permutation

        self._singulars, self._perm = self._singulars.sort(descending=True)  # , stable=True)
        self._singulars_orig = self._singulars_orig[self._perm]
        print("init_akhir")

    def V(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
        temp = temp.reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp)
        out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
        out = out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)

        #multiply the image by V from the left and by V^T from the right
        out = self.mat_by_img(self.V_small, out)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Vt(self, vec):        
        #multiply the image by V^T from the left and by V from the right
        temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)

        #permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        temp = temp.reshape(vec.shape[0], -1)

        out = torch.zeros_like(temp)
        out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
        out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]

        return out 

    def U(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
        temp = temp.permute(0, 2, 1)

        #multiply the image by U from the left and by U^T from the right
        out = self.mat_by_img(self.U_small, temp)
        out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Ut(self, vec):
        #multiply the image by U^T from the left and by U from the right
        temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)

        #permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def singulars(self):
        return self._singulars.repeat(1, 3).reshape(-1)

    def add_zeros(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def A_pinv(self, vec):
        temp = self.Ut(vec)
        singulars = self._singulars.repeat(1, 3).reshape(-1)

        factors = 1. / singulars
        factors[singulars == 0] = 0.

        temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
        return self.V(self.add_zeros(temp))

    def Lambda(self, vec, a, sigma_y, sigma_t, eta):
        temp_vec = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
        temp_vec = self.img_by_mat(temp_vec, self.V_small).reshape(vec.shape[0], self.channels, -1)
        temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)

        singulars = self._singulars_orig
        lambda_t = torch.ones(self.img_dim ** 2, device=vec.device)
        temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
        temp_singulars[:singulars.size(0)] = singulars
        singulars = temp_singulars
        inverse_singulars = 1. / singulars
        inverse_singulars[singulars == 0] = 0.

        if a != 0 and sigma_y != 0:
            change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
            lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
                    singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)

        lambda_t = lambda_t.reshape(1, -1, 1)
        temp_vec = temp_vec * lambda_t

        temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp[:, self._perm, :] = temp_vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
        temp = temp.reshape(vec.shape[0], -1)

        out = torch.zeros_like(temp)
        out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
        out = out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)

        out = self.mat_by_img(self.V_small, out)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
        temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1)
        temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)

        temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1)
        temp_eps = temp_eps[:, :, self._perm].permute(0, 2, 1)

        singulars = self._singulars_orig
        d1_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * eta
        d2_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5

        temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
        temp_singulars[:singulars.size(0)] = singulars
        singulars = temp_singulars
        inverse_singulars = 1. / singulars
        inverse_singulars[singulars == 0] = 0.

        if a != 0 and sigma_y != 0:
            change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
            d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
            d2_t = d2_t * (-change_index + 1.0)

            change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
            d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
                change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
            d2_t = d2_t * (-change_index + 1.0)

            change_index = (singulars == 0) * 1.0
            d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
            d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5

        d1_t = d1_t.reshape(1, -1, 1)
        d2_t = d2_t.reshape(1, -1, 1)

        temp_vec = temp_vec * d1_t
        temp_eps = temp_eps * d2_t

        temp_vec_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp_vec_new[:, self._perm, :] = temp_vec
        temp_vec_new = temp_vec_new.reshape(vec.shape[0], -1)

        out = torch.zeros_like(temp_vec_new)
        out[:, self.kept_indices] = temp_vec_new[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp_vec_new[:, self.kept_indices.shape[0]:]
        temp_vec_new = out.reshape(vec.shape[0], -1, self.channels)

        out_vec = self.mat_by_img(self.V_small, temp_vec_new.permute(0, 2, 1))
        out_vec = self.img_by_mat(out_vec, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)

        temp_eps_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp_eps_new[:, self._perm, :] = temp_eps
        temp_eps_new = temp_eps_new.reshape(vec.shape[0], -1)

        out = torch.zeros_like(temp_eps_new)
        out[:, self.kept_indices] = temp_eps_new[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp_eps_new[:, self.kept_indices.shape[0]:]
        temp_eps_new = out.reshape(vec.shape[0], -1, self.channels)

        out_eps = self.mat_by_img(self.V_small, temp_eps_new.permute(0, 2, 1))
        out_eps = self.img_by_mat(out_eps, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out_vec + out_eps
```
wyhuai commented 1 year ago

If you do not need denoising, it will be easier to separate these two operators first.

Then define the whole blur operation as a function `A_blur()`, and its pseudo-inverse as `Ap_blur()`.

Write the simplified mask operation as `A_mask()`, and its pseudo-inverse as `Ap_mask()`: `A_mask = lambda z: z*mask` and `Ap_mask = A_mask` (a binary mask is its own pseudo-inverse).

Your desired operator will be `A = lambda z: A_mask(A_blur(z))`, and its pseudo-inverse will be `Ap = lambda z: Ap_blur(Ap_mask(z))`.
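A minimal sketch of the mask operator, assuming the mask from `exp/inp_masks/mask.npy` above is a per-pixel {0,1} array of shape (H, W) and that vectors follow the (batch, channels·H·W) layout used by the SVD operators; the variable names `channels`, `img_size`, and `device` are illustrative, not from the repo:

```python
import numpy as np
import torch

# illustrative values; in the repo these come from the config
channels, img_size, device = 3, 256, "cuda"

loaded = np.load("exp/inp_masks/mask.npy")      # binary {0,1} mask of shape (H, W)
mask = torch.from_numpy(loaded).to(device).float()

def A_mask(z):
    # reshape the flattened vector to an image, zero out the masked pixels, flatten again
    b = z.shape[0]
    img = z.reshape(b, channels, img_size, img_size)
    return (img * mask).reshape(b, -1)

Ap_mask = A_mask  # a binary mask is symmetric and idempotent, so it is its own pseudo-inverse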

ferdifdi commented 1 year ago

Thank you. Do you have any clue or material for building `A_blur` and `Ap_blur` for `deblur_gauss`? The simplified path has no deblurring degradation; it only exists in the SVD version, and it is hard for me to add the mask there.

wyhuai commented 1 year ago

> Thank you. Do you have any clue or material for building `A_blur` and `Ap_blur` for `deblur_gauss`? The simplified path has no deblurring degradation; it only exists in the SVD version, and it is hard for me to add the mask there.

You can write the SVD blur degradation as a function `A_blur()`.

wyhuai commented 1 year ago

For example:

```python
A_funcs = Deblurring(torch.Tensor([1 / 9] * 9).to(self.device), config.data.channels,
                     self.config.data.image_size, self.device)
A_blur = lambda z: A_funcs.A(z)
Ap_blur = lambda z: A_funcs.A_pinv(z)
```
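Putting the pieces of this thread together, a rough sketch under the same illustrative assumptions as the `A_mask` snippet above, using the Gaussian kernel from the `deblur_gauss` branch instead of the uniform 9-tap example:

```python
# Gaussian blur operator via the unmodified SVD Deblurring class
sigma = 10
pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(device)

A_funcs = Deblurring(kernel / kernel.sum(), channels, img_size, device)
A_blur = lambda z: A_funcs.A(z)          # flattened (batch, C*H*W) blurred image
Ap_blur = lambda z: A_funcs.A_pinv(z)

# composed degradation and its (approximate) pseudo-inverse
A = lambda z: A_mask(A_blur(z))          # blur, then drop the masked pixels
Ap = lambda z: Ap_blur(Ap_mask(z))       # zero-fill the masked pixels, then "un-blur"
```

These `A`/`Ap` lambdas can then be used wherever the simplified (non-SVD) DDNM path expects its `A` and `Ap` functions.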