ferdifdi opened 1 year ago
I've tried to edit the deblurring SVD operator by adding a mask, but the result fails. Could you review my code or give me a clue or some additional material? Thank you.
elif deg == 'deblur_gauss':
    from functions.svd_operators import Deblurring
    loaded = np.load("exp/inp_masks/mask.npy")
    mask = torch.from_numpy(loaded).to(self.device).reshape(-1)
    missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3
    missing_g = missing_r + 1
    missing_b = missing_g + 1
    missing = torch.cat([missing_r, missing_g, missing_b], dim=0)
    sigma = 10
    pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
    kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(self.device)
    A_funcs = Deblurring(kernel / kernel.sum(), config.data.channels,
                         self.config.data.image_size, self.device, missing)
class Deblurring(A_functions):
    def mat_by_img(self, M, v):
        return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
                            self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)

    def img_by_mat(self, v, M):
        return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
                            self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])

    def __init__(self, kernel, channels, img_dim, device, missing_indices=None, ZERO=3e-2):
        self.img_dim = img_dim
        self.channels = channels
        # default to "no missing pixels" so the class still works without a mask
        if missing_indices is None:
            missing_indices = torch.tensor([], device=device).long()
        self.missing_indices = missing_indices
        self.kept_indices = torch.Tensor([i for i in range(channels * img_dim**2)
                                          if i not in missing_indices]).to(device).long()
        # build 1D conv matrix
        A_small = torch.zeros(img_dim, img_dim, device=device)
        for i in range(img_dim):
            for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
                if j < 0 or j >= img_dim: continue
                A_small[i, j] = kernel[j - i + kernel.shape[0]//2]
        # get the SVD of the 1D conv
        self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False)
        # ZERO = 3e-2
        self.singulars_small_orig = self.singulars_small.clone()
        self.singulars_small[self.singulars_small < ZERO] = 0
        # calculate the singular values of the big matrix
        self._singulars_orig = torch.matmul(self.singulars_small_orig.reshape(img_dim, 1),
                                            self.singulars_small_orig.reshape(1, img_dim)).reshape(img_dim**2)
        self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1),
                                       self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2)
        # sort the big matrix singulars and save the permutation
        self._singulars, self._perm = self._singulars.sort(descending=True)  # , stable=True)
        self._singulars_orig = self._singulars_orig[self._perm]
print("init_akhir")
    def V(self, vec):
        # invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
        temp = temp.reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp)
        out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
        out = out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
        # multiply the image by V from the left and by V^T from the right
        out = self.mat_by_img(self.V_small, out)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out
    def Vt(self, vec):
        # multiply the image by V^T from the left and by V from the right
        temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)
        # permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        temp = temp.reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp)
        out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
        out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]
        return out
    def U(self, vec):
        # invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
        temp = temp.permute(0, 2, 1)
        # multiply the image by U from the left and by U^T from the right
        out = self.mat_by_img(self.U_small, temp)
        out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out
    def Ut(self, vec):
        # multiply the image by U^T from the left and by U from the right
        temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)
        # permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)
    def singulars(self):
        return self._singulars.repeat(1, 3).reshape(-1)

    def add_zeros(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def A_pinv(self, vec):
        temp = self.Ut(vec)
        singulars = self._singulars.repeat(1, 3).reshape(-1)
        factors = 1. / singulars
        factors[singulars == 0] = 0.
        temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
        return self.V(self.add_zeros(temp))
    def Lambda(self, vec, a, sigma_y, sigma_t, eta):
        temp_vec = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
        temp_vec = self.img_by_mat(temp_vec, self.V_small).reshape(vec.shape[0], self.channels, -1)
        temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)

        singulars = self._singulars_orig
        lambda_t = torch.ones(self.img_dim ** 2, device=vec.device)
        temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
        temp_singulars[:singulars.size(0)] = singulars
        singulars = temp_singulars
        inverse_singulars = 1. / singulars
        inverse_singulars[singulars == 0] = 0.

        if a != 0 and sigma_y != 0:
            change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
            lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
                    singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)

        lambda_t = lambda_t.reshape(1, -1, 1)
        temp_vec = temp_vec * lambda_t

        temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp[:, self._perm, :] = temp_vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
        temp = temp.reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp)
        out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
        out = out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
        out = self.mat_by_img(self.V_small, out)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out
    def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
        temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1)
        temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)
        temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1)
        temp_eps = temp_eps[:, :, self._perm].permute(0, 2, 1)

        singulars = self._singulars_orig
        d1_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * eta
        d2_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
        temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
        temp_singulars[:singulars.size(0)] = singulars
        singulars = temp_singulars
        inverse_singulars = 1. / singulars
        inverse_singulars[singulars == 0] = 0.

        if a != 0 and sigma_y != 0:
            change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
            d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
            d2_t = d2_t * (-change_index + 1.0)

            change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
            d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
                change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
            d2_t = d2_t * (-change_index + 1.0)

            change_index = (singulars == 0) * 1.0
            d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
            d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5

        d1_t = d1_t.reshape(1, -1, 1)
        d2_t = d2_t.reshape(1, -1, 1)
        temp_vec = temp_vec * d1_t
        temp_eps = temp_eps * d2_t

        temp_vec_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp_vec_new[:, self._perm, :] = temp_vec
        temp_vec_new = temp_vec_new.reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp_vec_new)
        out[:, self.kept_indices] = temp_vec_new[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp_vec_new[:, self.kept_indices.shape[0]:]
        temp_vec_new = out.reshape(vec.shape[0], -1, self.channels)
        out_vec = self.mat_by_img(self.V_small, temp_vec_new.permute(0, 2, 1))
        out_vec = self.img_by_mat(out_vec, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)

        temp_eps_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp_eps_new[:, self._perm, :] = temp_eps
        temp_eps_new = temp_eps_new.reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp_eps_new)
        out[:, self.kept_indices] = temp_eps_new[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp_eps_new[:, self.kept_indices.shape[0]:]
        temp_eps_new = out.reshape(vec.shape[0], -1, self.channels)
        out_eps = self.mat_by_img(self.V_small, temp_eps_new.permute(0, 2, 1))
        out_eps = self.img_by_mat(out_eps, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)

        return out_vec + out_eps
If you do not need denoising, it will be easier to separate the two degradations (the blur and the mask) into individual operators first.
Then, define the whole blur operation as a function A_blur( ), and its pseudo-inverse operator as Ap_blur( ).
Write the simplified mask operation as A_mask( ), and its pseudo-inverse as Ap_mask( ):
A_mask = lambda z: z*mask
Ap_mask = A_mask
Your desired operator will be
A = lambda z: A_mask(A_blur(z))
and the pseudo-inverse will be
Ap = lambda z: Ap_blur(Ap_mask(z))
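Putting those pieces together, a minimal sketch of the composed pipeline might look like the following (this assumes mask has already been loaded and flattened to match the operator's layout, and that A_blur / Ap_blur wrap the blur operator as in the reply below; the names x and y are illustrative, not from the repo):
# assumed already defined: mask (binary, broadcastable to the flattened image),
# and A_blur / Ap_blur wrapping the blur operator and its pseudo-inverse
A_mask = lambda z: z * mask           # a binary mask acts as its own pseudo-inverse
Ap_mask = A_mask

A = lambda z: A_mask(A_blur(z))       # degrade: blur, then drop the masked pixels
Ap = lambda z: Ap_blur(Ap_mask(z))    # restore: zero the missing pixels, then un-blur

y = A(x)          # simulated measurement from a clean image x
x_hat = Ap(y)     # pseudo-inverse estimate used inside the sampler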
Thank you. Do you have any clue or material on how to build A_blur and Ap_blur for deblur_gauss? The simplified operators have no deblurring degradation; deblurring is only implemented in the SVD version, and it's hard for me to add the mask to it.
You can write the SVD blur degradation as a function A_blur()
For example,
A_funcs = Deblurring(torch.Tensor([1 / 9] * 9).to(self.device), config.data.channels, self.config.data.image_size, self.device)
A_blur = lambda z: A_funcs.A(z)
Ap_blur = lambda z: A_funcs.A_pinv(z)
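Combining that with the Gaussian kernel and mask from the question, a hedged end-to-end sketch could look like the following (variable names are illustrative and assume the same self/config context and imports as the snippets above):
# blur operator and its pseudo-inverse from the unmodified SVD class
sigma = 10
pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(self.device)
A_funcs = Deblurring(kernel / kernel.sum(), config.data.channels,
                     self.config.data.image_size, self.device)
A_blur = lambda z: A_funcs.A(z)
Ap_blur = lambda z: A_funcs.A_pinv(z)

# mask operator: broadcast the per-pixel mask over the channels and flatten it
# in the same channel-major layout that A_blur's output uses
loaded = np.load("exp/inp_masks/mask.npy")  # assumed to hold one 0/1 value per pixel
mask = torch.from_numpy(loaded).to(self.device).float().reshape(1, 1, -1)
mask = mask.repeat(1, config.data.channels, 1).reshape(1, -1)
A_mask = lambda z: z * mask
Ap_mask = A_mask

# combined degradation and its (approximate) pseudo-inverse
A = lambda z: A_mask(A_blur(z))
Ap = lambda z: Ap_blur(Ap_mask(z))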
Hi, sorry for the late reply. It is possible for any combination of linear operators, as described in Section 3.2.
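As a quick sanity check of this kind of composition, one can verify on toy matrices that the composed pseudo-inverse satisfies the property the sampler relies on, A(Ap(A(x))) ≈ A(x). The sketch below is purely illustrative (the toy blur and mask are made up, not taken from the repo); with an invertible toy blur the property holds exactly, while with the truncated-SVD blur used in the repo it is only approximate:
import torch

torch.manual_seed(0)
n = 16
# toy 1D "blur" (invertible tridiagonal smoothing) and a random binary mask
B = (0.5 * torch.eye(n)
     + 0.25 * torch.diag(torch.ones(n - 1), 1)
     + 0.25 * torch.diag(torch.ones(n - 1), -1))
m = (torch.rand(n) > 0.3).float()

A_blur = lambda z: z @ B.T
Ap_blur = lambda z: z @ torch.linalg.pinv(B).T
A_mask = lambda z: z * m
Ap_mask = A_mask

A = lambda z: A_mask(A_blur(z))
Ap = lambda z: Ap_blur(Ap_mask(z))

x = torch.randn(4, n)
# the sampler only needs A(Ap(A(x))) == A(x); exact here, approximate
# when the blur's small singular values are truncated
print(torch.dist(A(Ap(A(x))), A(x)))  # ~0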