**Open** — unlugi opened this issue 1 year ago
You can refer to the following code if you are interested.
```python
################################################################################
################################################################################
class StructureGuidedRankingLoss(nn.Module):
    """Structure-guided ranking loss for monocular depth estimation.

    Samples point pairs (edge-guided, random, or instance-guided) and applies
    a pairwise ranking loss on their ordinal depth relations, plus a
    multi-scale gradient-matching regularization term.
    """

    def __init__(self, point_pairs=3000, sigma=0.03, alpha=1.0, mask_value=-1e-1):
        # NOTE: the pasted snippet showed `init` — markdown stripped the double
        # underscores; nn.Module requires __init__ / super().__init__().
        super(StructureGuidedRankingLoss, self).__init__()
        self.point_pairs = point_pairs  # number of point pairs sampled per image
        self.sigma = sigma  # threshold deciding the ordinal relation of a pair
        self.alpha = alpha  # balances the "equal" term against the "<,>" term
        self.mask_value = mask_value  # sentinel marking invalid pixels
        # Multi-scale gradient matching loss (defined elsewhere in this file).
        self.regularization_loss = GradientLoss(scales=4)
def getEdge(self, images):
    """Compute Sobel edge magnitudes and orientations for a batch of images.

    Args:
        images: (n, c, h, w) tensor. If c == 3, only the first channel is
            used for edge detection; otherwise the tensor is convolved as-is
            (which requires c == 1 for the 1-input-channel Sobel kernels).

    Returns:
        edges:  (n, 1, h, w) gradient magnitude, zero-padded at the border.
        thetas: (n, 1, h, w) gradient orientation (atan2), zero-padded.
    """
    n, c, h, w = images.size()
    # Sobel kernels, created on the input's device and dtype so the loss also
    # works on CPU. (The original hard-coded .cuda() and carried a no-op
    # .repeat(1, 1, 1, 1).)
    sobel_x = torch.tensor([[-1.0, 0.0, 1.0],
                            [-2.0, 0.0, 2.0],
                            [-1.0, 0.0, 1.0]],
                           dtype=images.dtype, device=images.device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[1.0, 2.0, 1.0],
                            [0.0, 0.0, 0.0],
                            [-1.0, -2.0, -1.0]],
                           dtype=images.dtype, device=images.device).view(1, 1, 3, 3)
    if c == 3:
        # Use only the first channel for edge detection, as in the original.
        channel = images[:, 0, :, :].unsqueeze(1)
    else:
        channel = images
    gradient_x = F.conv2d(channel, sobel_x)
    gradient_y = F.conv2d(channel, sobel_y)
    edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
    # Valid convolution shrinks h and w by 2; zero-pad to restore (h, w).
    edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
    thetas = torch.atan2(gradient_y, gradient_x)
    thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
    return edges, thetas
def ind2sub(self, idx, cols):
    """Convert a flat row-major index into a (row, col) pair.

    Bug fix: the original used `idx / cols`, which under Python 3 is true
    division and yields a float (or float tensor) row index, corrupting the
    subsequent `c = idx - r * cols`. Floor division restores integer rows.
    """
    r = idx // cols
    c = idx - r * cols
    return r, c
def sub2ind(self, r, c, cols):
    """Flatten a (row, col) pair into a linear row-major index."""
    return r * cols + c
def forward(self, inputs, targets, images, masks, instance_masks):
    """Compute the structure-guided ranking loss for a batch.

    Args:
        inputs: predicted depth — presumably (n, 1, h, w); TODO confirm.
        targets: ground-truth depth, (n, c, h, w) per the .size() unpack below.
        images: RGB (or single-channel) images used for edge-guided sampling.
        masks: per-pixel validity masks; a pair only contributes where both
            of its samples are valid (consistency_mask below).
        instance_masks: per-pixel instance ids; 0 means background.

    Returns:
        Scalar float tensor: ranking loss averaged over the batch plus the
        gradient-matching regularization term.
    """
    # Multi-scale gradient matching regularizer, computed once for the batch.
    regularization_loss = self.regularization_loss(inputs.squeeze(1), targets.squeeze(1), masks.squeeze(1))
    # RGB domain: Sobel magnitudes/orientations guide the edge-based sampling.
    edges_img, thetas_img = self.getEdge(images)
    #=============================
    n,c,h,w = targets.size()
    # Flatten all per-pixel tensors to (n, h*w) in double precision.
    if n != 1:
        inputs = inputs.view(n, -1).double()
        targets = targets.view(n, -1).double()
        masks = masks.view(n, -1).double()
        instance_masks = instance_masks.view(n, -1).double()
        edges_img = edges_img.view(n, -1).double()
        thetas_img = thetas_img.view(n, -1).double()
    else:
        inputs = inputs.contiguous().view(1, -1).double()
        targets = targets.contiguous().view(1, -1).double()
        masks = masks.contiguous().view(1, -1).double()
        instance_masks = instance_masks.contiguous().view(1, -1).double()
        edges_img = edges_img.contiguous().view(1, -1).double()
        thetas_img = thetas_img.contiguous().view(1, -1).double()
    # NOTE(review): hard-coded .cuda() — this forward requires a GPU as written.
    loss = torch.DoubleTensor([0.]).cuda()
    inputs_edge = []  # NOTE(review): never used below
    minlen = []  # NOTE(review): never used below
    min_samples = self.point_pairs
    img_pixels = h*w
    for i in range(n):
        # Accumulators for the A/B endpoints of the sampled point pairs.
        inputs_A = torch.Tensor([]).double().cuda()
        inputs_B = torch.Tensor([]).double().cuda()
        targets_A = torch.Tensor([]).double().cuda()
        targets_B = torch.Tensor([]).double().cuda()
        masks_A = torch.Tensor([]).double().cuda()
        masks_B = torch.Tensor([]).double().cuda()
        # if exists instances
        instance_num = torch.max(instance_masks[i])
        #print('For this image, instance number is:', instance_num)
        unique_instance_id = torch.unique(instance_masks[i])
        # NOTE(review): len() of a boolean tensor is just h*w, not the count of
        # instance pixels — likely .sum() was intended; the value is unused.
        total_instance_pixels = len(instance_masks[i, :].gt(0))
        if instance_num == 0:
            # No instances in this image: use edge-guided + random sampling.
            #
            # Edge-Guided sampling (external helper defined elsewhere in the file).
            inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num = edgeGuidedSampling(inputs[i,:], targets[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w)
            # Random Sampling: match the number of edge-guided pairs.
            random_sample_num = sample_num
            random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i,:], targets[i, :], masks[i, :], self.mask_value, random_sample_num)
            # Combine edge-guided + random pairs.
            inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
            inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
            targets_A = torch.cat((targets_A, random_targets_A), 0)
            targets_B = torch.cat((targets_B, random_targets_B), 0)
            masks_A = torch.cat((masks_A, random_masks_A), 0)
            masks_B = torch.cat((masks_B, random_masks_B), 0)
        else:
            # Instance-guided sampling: pair pixels within/across each instance
            # and the background.
            #print('Instance-guided sampling: ', unique_instance_id) #[0,1,2,3]
            for step, instance_id in enumerate(unique_instance_id):
                if instance_id == 0:
                    # id 0 is background, not an instance.
                    continue
                # Split predictions/targets/masks into instance vs. background.
                find_instance = instance_masks[i, :].eq(instance_id)
                instance_id_index = torch.masked_select(inputs[i, :], find_instance)
                background_index = torch.masked_select(inputs[i, :], ~find_instance)
                targets_instance_index = torch.masked_select(targets[i, :], find_instance)
                targets_background_index = torch.masked_select(targets[i, :], ~find_instance)
                masks_instance_index = torch.masked_select(masks[i, :], find_instance)
                masks_background_index = torch.masked_select(masks[i, :], ~find_instance)
                num_effect_pixels = len(instance_id_index)
                # NOTE(review): counts pixels outside THIS instance, which may
                # include other instances, not only true background.
                num_background_pixels = img_pixels - num_effect_pixels
                shuffle_effect_pixels = torch.randperm(num_effect_pixels).cuda()
                shuffle_background_pixels = torch.randperm(num_background_pixels).cuda()
                if num_effect_pixels < 20:
                    # Too few pixels to form meaningful pairs; skip instance.
                    continue
                elif (num_effect_pixels < min_samples) and (num_effect_pixels >= 20):
                    # Small instance: split its pixels in half and pair them
                    # with an equal number of background samples.
                    instance_instance = instance_id_index[shuffle_effect_pixels[:(num_effect_pixels+1)//2]]
                    instance_background = instance_id_index[shuffle_effect_pixels[(num_effect_pixels+1)//2:]]
                    background_samples = background_index[shuffle_background_pixels[:num_effect_pixels]]
                    background_instance = background_samples[:(num_effect_pixels+1)//2]
                    background_background = background_samples[(num_effect_pixels+1)//2:]
                    targets_instance_instance = targets_instance_index[shuffle_effect_pixels[:(num_effect_pixels+1)//2]]
                    targets_instance_background = targets_instance_index[shuffle_effect_pixels[(num_effect_pixels+1)//2:]]
                    targets_background_samples = targets_background_index[shuffle_background_pixels[:num_effect_pixels]]
                    targets_background_instance = targets_background_samples[:(num_effect_pixels+1)//2]
                    targets_background_background = targets_background_samples[(num_effect_pixels+1)//2:]
                    masks_instance_instance = masks_instance_index[shuffle_effect_pixels[:(num_effect_pixels+1)//2]]
                    masks_instance_background = masks_instance_index[shuffle_effect_pixels[(num_effect_pixels+1)//2:]]
                    masks_background_samples = masks_background_index[shuffle_background_pixels[:num_effect_pixels]]
                    masks_background_instance = masks_background_samples[:(num_effect_pixels+1)//2]
                    masks_background_background = masks_background_samples[(num_effect_pixels+1)//2:]
                    if num_effect_pixels % 2 == 0:
                        # Even count: halves align, concatenate directly.
                        inputs_instance_A = torch.cat((instance_instance, instance_background, background_instance), 0)
                        inputs_instance_B = torch.cat((instance_background, background_instance, background_background), 0)
                        targets_instance_A = torch.cat((targets_instance_instance, targets_instance_background, targets_background_instance), 0)
                        targets_instance_B = torch.cat((targets_instance_background, targets_background_instance, targets_background_background), 0)
                        masks_instance_A = torch.cat((masks_instance_instance, masks_instance_background, masks_background_instance), 0)
                        masks_instance_B = torch.cat((masks_instance_background, masks_background_instance, masks_background_background), 0)
                    else:
                        # Odd count: drop the last element of the longer halves
                        # so A and B stay the same length.
                        inputs_instance_A = torch.cat((instance_instance[:-1], instance_background, background_instance[:-1]), 0)
                        inputs_instance_B = torch.cat((instance_background, background_instance[:-1], background_background), 0)
                        targets_instance_A = torch.cat((targets_instance_instance[:-1], targets_instance_background, targets_background_instance[:-1]), 0)
                        targets_instance_B = torch.cat((targets_instance_background, targets_background_instance[:-1], targets_background_background), 0)
                        masks_instance_A = torch.cat((masks_instance_instance[:-1], masks_instance_background, masks_background_instance[:-1]), 0)
                        masks_instance_B = torch.cat((masks_instance_background, masks_background_instance[:-1], masks_background_background), 0)
                    inputs_A = torch.cat((inputs_A, inputs_instance_A), 0)
                    inputs_B = torch.cat((inputs_B, inputs_instance_B), 0)
                    targets_A = torch.cat((targets_A, targets_instance_A), 0)
                    targets_B = torch.cat((targets_B, targets_instance_B), 0)
                    masks_A = torch.cat((masks_A, masks_instance_A), 0)
                    masks_B = torch.cat((masks_B, masks_instance_B), 0)
                elif num_effect_pixels == img_pixels:
                    # If the whole image belongs to one instance, there is no
                    # background to contrast with — fall back to random sampling.
                    random_sample_num = min_samples
                    random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i,:], targets[i, :], masks[i, :], self.mask_value, random_sample_num)
                    inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
                    inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
                    targets_A = torch.cat((targets_A, random_targets_A), 0)
                    targets_B = torch.cat((targets_B, random_targets_B), 0)
                    masks_A = torch.cat((masks_A, random_masks_A), 0)
                    masks_B = torch.cat((masks_B, random_masks_B), 0)
                elif num_effect_pixels >= min_samples:
                    # Large instance: sample min_samples pairs, capped by the
                    # available background pixels (kept even for halving).
                    if num_background_pixels < min_samples:
                        effect_min_samples = (num_background_pixels//2)*2
                    else:
                        effect_min_samples = min_samples
                    instance_instance = instance_id_index[shuffle_effect_pixels[:effect_min_samples//2]]
                    instance_background = instance_id_index[shuffle_effect_pixels[effect_min_samples//2:effect_min_samples]]
                    background_samples = background_index[shuffle_background_pixels[:effect_min_samples]]
                    background_instance = background_samples[:effect_min_samples//2]
                    background_background = background_samples[effect_min_samples//2:]
                    inputs_instance_A = torch.cat((instance_instance, instance_background, background_instance), 0)
                    inputs_instance_B = torch.cat((instance_background, background_instance, background_background), 0)
                    inputs_A = torch.cat((inputs_A, inputs_instance_A), 0)
                    inputs_B = torch.cat((inputs_B, inputs_instance_B), 0)
                    targets_instance_instance = targets_instance_index[shuffle_effect_pixels[:effect_min_samples//2]]
                    targets_instance_background = targets_instance_index[shuffle_effect_pixels[effect_min_samples//2:effect_min_samples]]
                    targets_background_samples = targets_background_index[shuffle_background_pixels[:effect_min_samples]]
                    targets_background_instance = targets_background_samples[:effect_min_samples//2]
                    targets_background_background = targets_background_samples[effect_min_samples//2:]
                    targets_instance_A = torch.cat((targets_instance_instance, targets_instance_background, targets_background_instance), 0)
                    targets_instance_B = torch.cat((targets_instance_background, targets_background_instance, targets_background_background), 0)
                    targets_A = torch.cat((targets_A, targets_instance_A), 0)
                    targets_B = torch.cat((targets_B, targets_instance_B), 0)
                    masks_instance_instance = masks_instance_index[shuffle_effect_pixels[:effect_min_samples//2]]
                    masks_instance_background = masks_instance_index[shuffle_effect_pixels[effect_min_samples//2:effect_min_samples]]
                    masks_background_samples = masks_background_index[shuffle_background_pixels[:effect_min_samples]]
                    masks_background_instance = masks_background_samples[:effect_min_samples//2]
                    masks_background_background = masks_background_samples[effect_min_samples//2:]
                    masks_instance_A = torch.cat((masks_instance_instance, masks_instance_background, masks_background_instance), 0)
                    masks_instance_B = torch.cat((masks_instance_background, masks_background_instance, masks_background_background), 0)
                    masks_A = torch.cat((masks_A, masks_instance_A), 0)
                    masks_B = torch.cat((masks_B, masks_instance_B), 0)
            # NOTE(review): indentation reconstructed from a flattened paste —
            # this top-up is assumed to run after the per-instance loop; confirm
            # against the original source.
            if len(targets_A) < 1000:
                # If fewer than 1000 point pairs were selected, top up with
                # random sampling.
                random_sample_num = 1000
                random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i,:], targets[i, :], masks[i, :], self.mask_value, random_sample_num)
                inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
                inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
                targets_A = torch.cat((targets_A, random_targets_A), 0)
                targets_B = torch.cat((targets_B, random_targets_B), 0)
                masks_A = torch.cat((masks_A, random_masks_A), 0)
                masks_B = torch.cat((masks_B, random_masks_B), 0)
        # GT ordinal relationship: ratio near 1 => "equal" pair; otherwise the
        # sign of the label encodes which sample is farther (epsilon avoids 0/0).
        target_ratio = torch.div(targets_A+1e-6, targets_B+1e-6)
        mask_eq = target_ratio.lt(1.0 + self.sigma) * target_ratio.gt(1.0/(1.0+self.sigma))
        labels = torch.zeros_like(target_ratio)
        labels[target_ratio.ge(1.0 + self.sigma)] = 1
        labels[target_ratio.le(1.0/(1.0+self.sigma))] = -1
        # Consider forward-backward consistency checking: a pair counts only
        # when both endpoints are valid.
        consistency_mask = masks_A * masks_B
        # "Equal" pairs: squared difference; ordered pairs: logistic ranking loss.
        equal_loss = (inputs_A - inputs_B).pow(2) * mask_eq.double() * consistency_mask
        unequal_loss = torch.log(1 + torch.exp((-inputs_A + inputs_B) * labels)) * (~mask_eq).double() * consistency_mask
        # Please comment the regularization term if you don't want to use the
        # multi-scale gradient matching loss !!!
        # NOTE(review): regularization_loss is added once per image; the final
        # division by n restores a single 0.2 * regularization contribution.
        loss = loss + self.alpha * equal_loss.mean() + 1.0 * unequal_loss.mean() + 0.2 * regularization_loss.double()
    # Average over the batch and return in single precision.
    return loss.float()/n
```
Dear author,
Do you plan to provide the instance-guided sampling code? Thank you.