KexianHust / Structure-Guided-Ranking-Loss

Structure-Guided Ranking Loss for Single Image Depth Prediction

Instance-Guided Sampling #18

Open unlugi opened 1 year ago

unlugi commented 1 year ago

Dear author,

Do you plan on providing instance-guided sampling code? Thank you.

KexianHust commented 9 months ago

You can refer to the following code if you are interested.

```python
################################################################################
# Structure-Guided Ranking Loss
################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

# GradientLoss, edgeGuidedSampling and randomSampling are referenced below but
# not included in this snippet; they are assumed to come from this repo's loss
# code and must be in scope before this class is used.

class StructureGuidedRankingLoss(nn.Module):
    def __init__(self, point_pairs=3000, sigma=0.03, alpha=1.0, mask_value=-1e-1):
        super(StructureGuidedRankingLoss, self).__init__()
        self.point_pairs = point_pairs  # number of point pairs
        self.sigma = sigma  # used for determining the ordinal relationship of a selected pair
        self.alpha = alpha  # used for balancing the effect of = and (<, >)
        self.mask_value = mask_value
        self.regularization_loss = GradientLoss(scales=4)

    def getEdge(self, images):
        n, c, h, w = images.size()
        # Sobel kernels for horizontal and vertical image gradients
        a = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).cuda().view((1, 1, 3, 3))
        b = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda().view((1, 1, 3, 3))
        if c == 3:
            gradient_x = F.conv2d(images[:, 0, :, :].unsqueeze(1), a)
            gradient_y = F.conv2d(images[:, 0, :, :].unsqueeze(1), b)
        else:
            gradient_x = F.conv2d(images, a)
            gradient_y = F.conv2d(images, b)
        edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
        edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
        thetas = torch.atan2(gradient_y, gradient_x)
        thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)

        return edges, thetas

    def ind2sub(self, idx, cols):
        r = idx // cols  # integer division; `/` would produce floats in Python 3
        c = idx - r * cols
        return r, c

    def sub2ind(self, r, c, cols):
        idx = r * cols + c
        return idx
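    # forward() arguments (shapes inferred from the usage below):
    #   inputs         - predicted depth, (N, 1, H, W)
    #   targets        - ground-truth depth, (N, 1, H, W)
    #   images         - RGB image used for edge guidance, (N, 3, H, W)
    #   masks          - validity mask, (N, 1, H, W)
    #   instance_masks - per-pixel instance ids with 0 = background, (N, 1, H, W)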

    def forward(self, inputs, targets, images, masks, instance_masks):
        regularization_loss = self.regularization_loss(inputs.squeeze(1), targets.squeeze(1), masks.squeeze(1))

        # RGB domain
        edges_img, thetas_img = self.getEdge(images)

        # flatten everything to per-image row vectors in double precision
        n, c, h, w = targets.size()
        inputs = inputs.contiguous().view(n, -1).double()
        targets = targets.contiguous().view(n, -1).double()
        masks = masks.contiguous().view(n, -1).double()
        instance_masks = instance_masks.contiguous().view(n, -1).double()
        edges_img = edges_img.contiguous().view(n, -1).double()
        thetas_img = thetas_img.contiguous().view(n, -1).double()

        loss = torch.DoubleTensor([0.]).cuda()

        min_samples = self.point_pairs
        img_pixels = h * w
        for i in range(n):
            inputs_A = torch.Tensor([]).double().cuda()
            inputs_B = torch.Tensor([]).double().cuda()
            targets_A = torch.Tensor([]).double().cuda()
            targets_B = torch.Tensor([]).double().cuda()
            masks_A = torch.Tensor([]).double().cuda()
            masks_B = torch.Tensor([]).double().cuda()
            # check whether this image contains any instances
            instance_num = torch.max(instance_masks[i])
            unique_instance_id = torch.unique(instance_masks[i])

            if instance_num == 0:
                # no instances: use edge-guided sampling (EGS) + random sampling (RS)
                inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num = edgeGuidedSampling(inputs[i, :], targets[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w)
                # Random Sampling
                random_sample_num = sample_num
                random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i, :], targets[i, :], masks[i, :], self.mask_value, random_sample_num)

                # Combine EGS + RS
                inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
                inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
                targets_A = torch.cat((targets_A, random_targets_A), 0)
                targets_B = torch.cat((targets_B, random_targets_B), 0)
                masks_A = torch.cat((masks_A, random_masks_A), 0)
                masks_B = torch.cat((masks_B, random_masks_B), 0)

            else:
                # instance-guided sampling: pair pixels within an instance,
                # across the instance/background boundary, and within the background
                for step, instance_id in enumerate(unique_instance_id):
                    if instance_id == 0:
                        continue

                    find_instance = instance_masks[i, :].eq(instance_id)
                    # note: these hold the selected *values*, not indices
                    instance_id_index = torch.masked_select(inputs[i, :], find_instance)
                    background_index = torch.masked_select(inputs[i, :], ~find_instance)

                    targets_instance_index = torch.masked_select(targets[i, :], find_instance)
                    targets_background_index = torch.masked_select(targets[i, :], ~find_instance)

                    masks_instance_index = torch.masked_select(masks[i, :], find_instance)
                    masks_background_index = torch.masked_select(masks[i, :], ~find_instance)

                    num_effect_pixels = len(instance_id_index)
                    num_background_pixels = img_pixels - num_effect_pixels
                    shuffle_effect_pixels = torch.randperm(num_effect_pixels).cuda()
                    shuffle_background_pixels = torch.randperm(num_background_pixels).cuda()

                    # too few pixels in this instance to form reliable pairs
                    if num_effect_pixels < 20:
                        continue

                    elif (num_effect_pixels < min_samples) and (num_effect_pixels >= 20):
                        # fewer instance pixels than point_pairs: use all of them,
                        # split half/half between the A and B sides of the pairs
                        instance_instance = instance_id_index[shuffle_effect_pixels[:(num_effect_pixels + 1) // 2]]
                        instance_background = instance_id_index[shuffle_effect_pixels[(num_effect_pixels + 1) // 2:]]
                        background_samples = background_index[shuffle_background_pixels[:num_effect_pixels]]
                        background_instance = background_samples[:(num_effect_pixels + 1) // 2]
                        background_background = background_samples[(num_effect_pixels + 1) // 2:]

                        targets_instance_instance = targets_instance_index[shuffle_effect_pixels[:(num_effect_pixels + 1) // 2]]
                        targets_instance_background = targets_instance_index[shuffle_effect_pixels[(num_effect_pixels + 1) // 2:]]
                        targets_background_samples = targets_background_index[shuffle_background_pixels[:num_effect_pixels]]
                        targets_background_instance = targets_background_samples[:(num_effect_pixels + 1) // 2]
                        targets_background_background = targets_background_samples[(num_effect_pixels + 1) // 2:]

                        masks_instance_instance = masks_instance_index[shuffle_effect_pixels[:(num_effect_pixels + 1) // 2]]
                        masks_instance_background = masks_instance_index[shuffle_effect_pixels[(num_effect_pixels + 1) // 2:]]
                        masks_background_samples = masks_background_index[shuffle_background_pixels[:num_effect_pixels]]
                        masks_background_instance = masks_background_samples[:(num_effect_pixels + 1) // 2]
                        masks_background_background = masks_background_samples[(num_effect_pixels + 1) // 2:]

                        # for odd counts, drop one element from the longer halves so
                        # all segments concatenate into equal-length A and B vectors
                        if num_effect_pixels % 2 == 0:
                            inputs_instance_A = torch.cat((instance_instance, instance_background, background_instance), 0)
                            inputs_instance_B = torch.cat((instance_background, background_instance, background_background), 0)
                            targets_instance_A = torch.cat((targets_instance_instance, targets_instance_background, targets_background_instance), 0)
                            targets_instance_B = torch.cat((targets_instance_background, targets_background_instance, targets_background_background), 0)
                            masks_instance_A = torch.cat((masks_instance_instance, masks_instance_background, masks_background_instance), 0)
                            masks_instance_B = torch.cat((masks_instance_background, masks_background_instance, masks_background_background), 0)
                        else:
                            inputs_instance_A = torch.cat((instance_instance[:-1], instance_background, background_instance[:-1]), 0)
                            inputs_instance_B = torch.cat((instance_background, background_instance[:-1], background_background), 0)
                            targets_instance_A = torch.cat((targets_instance_instance[:-1], targets_instance_background, targets_background_instance[:-1]), 0)
                            targets_instance_B = torch.cat((targets_instance_background, targets_background_instance[:-1], targets_background_background), 0)
                            masks_instance_A = torch.cat((masks_instance_instance[:-1], masks_instance_background, masks_background_instance[:-1]), 0)
                            masks_instance_B = torch.cat((masks_instance_background, masks_background_instance[:-1], masks_background_background), 0)

                        inputs_A = torch.cat((inputs_A, inputs_instance_A), 0)
                        inputs_B = torch.cat((inputs_B, inputs_instance_B), 0)
                        targets_A = torch.cat((targets_A, targets_instance_A), 0)
                        targets_B = torch.cat((targets_B, targets_instance_B), 0)
                        masks_A = torch.cat((masks_A, masks_instance_A), 0)
                        masks_B = torch.cat((masks_B, masks_instance_B), 0)

                    elif num_effect_pixels == img_pixels:
                        # if the whole image belongs to one instance, use random sampling
                        random_sample_num = min_samples
                        random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i, :], targets[i, :], masks[i, :], self.mask_value, random_sample_num)

                        inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
                        inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
                        targets_A = torch.cat((targets_A, random_targets_A), 0)
                        targets_B = torch.cat((targets_B, random_targets_B), 0)
                        masks_A = torch.cat((masks_A, random_masks_A), 0)
                        masks_B = torch.cat((masks_B, random_masks_B), 0)

                    elif num_effect_pixels >= min_samples:
                        # sample min_samples pairs for this instance; if the background
                        # has fewer pixels than min_samples, shrink the sample size to
                        # the largest even number that fits
                        if num_background_pixels < min_samples:
                            effect_min_samples = (num_background_pixels // 2) * 2
                        else:
                            effect_min_samples = min_samples

                        instance_instance = instance_id_index[shuffle_effect_pixels[:effect_min_samples // 2]]
                        instance_background = instance_id_index[shuffle_effect_pixels[effect_min_samples // 2:effect_min_samples]]
                        background_samples = background_index[shuffle_background_pixels[:effect_min_samples]]
                        background_instance = background_samples[:effect_min_samples // 2]
                        background_background = background_samples[effect_min_samples // 2:]
                        inputs_instance_A = torch.cat((instance_instance, instance_background, background_instance), 0)
                        inputs_instance_B = torch.cat((instance_background, background_instance, background_background), 0)
                        inputs_A = torch.cat((inputs_A, inputs_instance_A), 0)
                        inputs_B = torch.cat((inputs_B, inputs_instance_B), 0)

                        targets_instance_instance = targets_instance_index[shuffle_effect_pixels[:effect_min_samples // 2]]
                        targets_instance_background = targets_instance_index[shuffle_effect_pixels[effect_min_samples // 2:effect_min_samples]]
                        targets_background_samples = targets_background_index[shuffle_background_pixels[:effect_min_samples]]
                        targets_background_instance = targets_background_samples[:effect_min_samples // 2]
                        targets_background_background = targets_background_samples[effect_min_samples // 2:]
                        targets_instance_A = torch.cat((targets_instance_instance, targets_instance_background, targets_background_instance), 0)
                        targets_instance_B = torch.cat((targets_instance_background, targets_background_instance, targets_background_background), 0)
                        targets_A = torch.cat((targets_A, targets_instance_A), 0)
                        targets_B = torch.cat((targets_B, targets_instance_B), 0)

                        masks_instance_instance = masks_instance_index[shuffle_effect_pixels[:effect_min_samples // 2]]
                        masks_instance_background = masks_instance_index[shuffle_effect_pixels[effect_min_samples // 2:effect_min_samples]]
                        masks_background_samples = masks_background_index[shuffle_background_pixels[:effect_min_samples]]
                        masks_background_instance = masks_background_samples[:effect_min_samples // 2]
                        masks_background_background = masks_background_samples[effect_min_samples // 2:]
                        masks_instance_A = torch.cat((masks_instance_instance, masks_instance_background, masks_background_instance), 0)
                        masks_instance_B = torch.cat((masks_instance_background, masks_background_instance, masks_background_background), 0)
                        masks_A = torch.cat((masks_A, masks_instance_A), 0)
                        masks_B = torch.cat((masks_B, masks_instance_B), 0)

                if len(targets_A) < 1000:
                    # if fewer than 1000 point pairs were selected, top up with random sampling
                    random_sample_num = 1000
                    random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i, :], targets[i, :], masks[i, :], self.mask_value, random_sample_num)

                    inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
                    inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
                    targets_A = torch.cat((targets_A, random_targets_A), 0)
                    targets_B = torch.cat((targets_B, random_targets_B), 0)
                    masks_A = torch.cat((masks_A, random_masks_A), 0)
                    masks_B = torch.cat((masks_B, random_masks_B), 0)

            # GT ordinal relationship: label each pair as =, >, or <
            target_ratio = torch.div(targets_A + 1e-6, targets_B + 1e-6)
            mask_eq = target_ratio.lt(1.0 + self.sigma) * target_ratio.gt(1.0 / (1.0 + self.sigma))
            labels = torch.zeros_like(target_ratio)
            labels[target_ratio.ge(1.0 + self.sigma)] = 1
            labels[target_ratio.le(1.0 / (1.0 + self.sigma))] = -1

            # consider forward-backward consistency checking
            consistency_mask = masks_A * masks_B

            # squared difference for "equal" pairs, ranking loss for ordered pairs
            equal_loss = (inputs_A - inputs_B).pow(2) * mask_eq.double() * consistency_mask
            unequal_loss = torch.log(1 + torch.exp((-inputs_A + inputs_B) * labels)) * (~mask_eq).double() * consistency_mask

            # Comment out the regularization term if you don't want to use the
            # multi-scale gradient matching loss!
            loss = loss + self.alpha * equal_loss.mean() + 1.0 * unequal_loss.mean() + 0.2 * regularization_loss.double()

        return loss.float() / n

```
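For intuition: each sampled pair contributes `(z_A - z_B)^2` when the ground-truth depth ratio lies within `1 ± sigma` (the "equal" case), and `log(1 + exp(-(z_A - z_B) * l))` with ordinal label `l ∈ {-1, 1}` otherwise. Below is a minimal, untested usage sketch; it assumes `GradientLoss`, `edgeGuidedSampling` and `randomSampling` from the repo's loss code are in scope, and that a CUDA device is available (the loss allocates tensors with `.cuda()`). The shapes and the instance-mask layout follow the code above.

```python
# Minimal usage sketch (assumptions: the three repo helpers are in scope,
# and a CUDA device is available).
import torch

loss_fn = StructureGuidedRankingLoss(point_pairs=3000, sigma=0.03, alpha=1.0)

n, h, w = 2, 64, 64
inputs = torch.rand(n, 1, h, w).cuda()           # predicted depth
targets = torch.rand(n, 1, h, w).cuda()          # ground-truth depth
images = torch.rand(n, 3, h, w).cuda()           # RGB images
masks = torch.ones(n, 1, h, w).cuda()            # all pixels valid
instance_masks = torch.zeros(n, 1, h, w).cuda()  # 0 = background
instance_masks[:, :, 16:48, 16:48] = 1           # one instance per image -> instance-guided path

loss = loss_fn(inputs, targets, images, masks, instance_masks)
print(loss.item())
```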