haofanwang / Score-CAM

Official implementation of Score-CAM in PyTorch
MIT License

Efficiency problem regarding the implementation of Score-CAM #8

Closed: YuhengHuang42 closed this issue 3 years ago

YuhengHuang42 commented 3 years ago

Dear author:

First of all, thanks for your great work!

While reading scorecam.py, I noticed that it computes score_saliency_map for one instance at a time. This is inefficient when you need Score-CAM maps for many instances.

I also read your paper, and as far as I can tell the algorithm can compute Score-CAM for a whole mini-batch at once (correct me if I am wrong). That should be faster than processing the instances one by one.

However, a mini-batch version requires a few changes to the code:

  1. ScoreCAM calls score.backward(retain_graph=retain_graph), but as I understand it Score-CAM is gradient-free, so the backward pass is unnecessary. It has to be removed before the computation can run on a mini-batch input.

  2. ScoreCAM skips a channel whenever saliency_map.max() == saliency_map.min(). In the mini-batch version this check has to be applied separately to each instance.
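
The two changes above can be isolated in a small helper. This is a minimal sketch under assumptions: the name `score_channel` and its argument layout are hypothetical, not part of the repository. It scores one activation channel for a whole batch under `torch.no_grad()` (no backward pass), and replaces the per-instance `continue` with a `torch.where` mask so that a flat channel contributes a score of 0 for that instance only.

```python
import torch
import torch.nn.functional as F


def score_channel(model, x, channel, target):
    """Score one activation channel for every instance of a mini-batch at once.

    Hypothetical helper for illustration.
    x:       (b, c, h, w) input batch
    channel: (b, 1, u, v) activations of a single channel k
    target:  (b,) long tensor with one class index per instance
    Returns (scores of shape (b,), upsampled channel of shape (b, 1, h, w)).
    """
    b, _, h, w = x.shape
    with torch.no_grad():  # only forward passes are needed, so no autograd graph is recorded
        up = F.interpolate(channel, size=(h, w), mode="bilinear", align_corners=False)
        flat = up.reshape(b, -1)
        ch_min = flat.min(dim=1)[0]
        ch_max = flat.max(dim=1)[0]
        # per-instance min-max normalisation; view(b, 1, 1, 1) broadcasts, so no repeat() is needed
        mask = (up - ch_min.view(b, 1, 1, 1)) / (ch_max - ch_min + 1e-7).view(b, 1, 1, 1)
        probs = F.softmax(model(x * mask), dim=1)
        # pick each instance's target-class probability
        score = probs.gather(1, target.view(b, 1)).squeeze(1)
        # batched replacement for `if saliency_map.max() == saliency_map.min(): continue`
        return torch.where(ch_max > ch_min, score, torch.zeros_like(score)), up
```

For example, if one instance's channel is all zeros (a common case after a ReLU), that instance gets a score of exactly 0 while the other instances keep their softmax probability.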

I have tested the code on a few instances and found no problems. On my server it takes about 41 seconds for 32 instances, whereas a single instance takes about 16 seconds. Computing the 32 instances one at a time would therefore take roughly 32 × 16 = 512 seconds, so the batched version is about 12× faster.

Since I am not sure the code is fully correct, I am posting it below for you to review when you have time.

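import torch
import torch.nn.functional as F

def forward(self, input, class_idx=None, retain_graph=False):
    # retain_graph is kept only for API compatibility; no backward pass is run
    b, c, h, w = input.size()
    # prediction on the raw input (no gradients needed; this also fills self.activations via the hook)
    with torch.no_grad():
        logit = self.model_arch(input)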

    if class_idx is None:
        predicted_class = logit.max(1)[-1]
        #score = logit[:, logit.max(1)[-1]].squeeze()
    else:
        predicted_class = class_idx.long()  # assumes class_idx is a tensor with one class index per instance
        #predicted_class = torch.LongTensor([class_idx])
        #score = logit[:, class_idx].squeeze()

    logit = F.softmax(logit, dim=1)

    if torch.cuda.is_available():
      predicted_class= predicted_class.cuda()
      #score = score.cuda()
      logit = logit.cuda()

    #self.model_arch.zero_grad()
    #score.backward(retain_graph=retain_graph)

    predicted_class = predicted_class.reshape(-1, 1)

    activations = self.activations['value']
    b, k, u, v = activations.size()

    score_saliency_map = torch.zeros((b, 1, h, w))

    if torch.cuda.is_available():
      activations = activations.cuda()
      score_saliency_map = score_saliency_map.cuda()

    with torch.no_grad():
      for i in range(k):

          # upsampling
          saliency_map = torch.unsqueeze(activations[:, i, :, :], 1)
          saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

          #if saliency_map.max() == saliency_map.min():
          #  continue

          # normalize to 0-1
          saliency_max = saliency_map.view(b, -1).max(dim=1)[0]
          saliency_max = saliency_max.reshape(b, 1, 1, 1).repeat(1, 1, h, w)
          saliency_min = saliency_map.view(b, -1).min(dim=1)[0]
          saliency_min = saliency_min.reshape(b, 1, 1, 1).repeat(1, 1, h, w)
          norm_saliency_map = (saliency_map - saliency_min) / (saliency_max - saliency_min + 1e-7)

          # how much the target score increases when only the highlighted region is kept:
          # prediction on the masked input
          output = self.model_arch(input * norm_saliency_map)
          output = F.softmax(output, dim=-1)
          #score = output[0][predicted_class]
          score = output[torch.arange(predicted_class.size(0)).unsqueeze(1), predicted_class]
          # Instances whose upsampled channel is constant (max == min) get a score of 0,
          # which mirrors the per-instance `continue` in the single-instance ScoreCAM.
          flat = saliency_map.view(b, -1)
          is_informative = (flat.max(dim=1)[0] > flat.min(dim=1)[0]).reshape(b, 1)
          score = torch.where(is_informative, score, torch.zeros_like(score))

          score = score.reshape(b, 1, 1, 1).repeat(1, 1, h, w)
          score_saliency_map +=  score * saliency_map

    score_saliency_map = F.relu(score_saliency_map)
    score_saliency_map_min = score_saliency_map.view(b, -1).min(dim=1)[0]
    score_saliency_map_min = score_saliency_map_min.reshape(b, 1, 1, 1).repeat(1, 1, h, w)
    score_saliency_map_max = score_saliency_map.view(b, -1).max(dim=1)[0]
    score_saliency_map_max = score_saliency_map_max.reshape(b, 1, 1, 1).repeat(1, 1, h, w)
    #score_saliency_map_min, score_saliency_map_max = score_saliency_map.min(), score_saliency_map.max()

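    # Fail loudly if any instance's map is flat (max == min); the single-instance
    # version returns None in that case. Tensor.any() avoids torch.count_nonzero,
    # which needs PyTorch >= 1.7.0.
    if (score_saliency_map_max == score_saliency_map_min).any():
        raise ValueError("Score-CAM saliency map is constant for at least one instance in the batch")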

    #if score_saliency_map_min == score_saliency_map_max:
    #    return None

    score_saliency_map = (score_saliency_map - score_saliency_map_min).div(score_saliency_map_max - score_saliency_map_min).data
    return score_saliency_map
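
Since the code above was only tested on a few instances, one way to check a batched implementation is to confirm that it gives the same maps as computing each instance separately. The sketch below is self-contained and rests on assumptions: `batched_score_cam` is a hypothetical standalone function (not the repository's ScoreCAM class), it captures activations with a forward hook, the toy model has no batch-dependent layers such as BatchNorm in training mode, and unlike the code above it divides by `max - min + 1e-7` at the end instead of raising on flat maps.

```python
import torch
import torch.nn.functional as F


def batched_score_cam(model, layer, x, target):
    """Minimal batched Score-CAM sketch (hypothetical helper, not the repository's API).

    model:  classifier in eval mode; layer: module whose output is used as activations
    x:      (b, c, h, w) input batch; target: (b,) long tensor of class indices
    """
    store = {}
    handle = layer.register_forward_hook(lambda m, i, o: store.__setitem__("act", o))
    try:
        b, _, h, w = x.shape
        with torch.no_grad():
            model(x)  # fills store["act"] through the hook
            acts = store["act"]  # (b, k, u, v); later hook calls do not modify this tensor
            cam = torch.zeros(b, 1, h, w, device=x.device)
            for i in range(acts.size(1)):
                up = F.interpolate(acts[:, i:i + 1], size=(h, w), mode="bilinear", align_corners=False)
                flat = up.reshape(b, -1)
                lo, hi = flat.min(dim=1)[0], flat.max(dim=1)[0]
                mask = (up - lo.view(b, 1, 1, 1)) / (hi - lo + 1e-7).view(b, 1, 1, 1)
                probs = F.softmax(model(x * mask), dim=1)
                score = probs.gather(1, target.view(b, 1)).squeeze(1)
                score = torch.where(hi > lo, score, torch.zeros_like(score))  # skip flat channels
                cam += score.view(b, 1, 1, 1) * up
            cam = F.relu(cam)
            flat = cam.reshape(b, -1)
            lo = flat.min(dim=1)[0].view(b, 1, 1, 1)
            hi = flat.max(dim=1)[0].view(b, 1, 1, 1)
            return (cam - lo) / (hi - lo + 1e-7)  # per-instance normalisation to [0, 1)
    finally:
        handle.remove()


torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, stride=2, padding=1),  # 8x8 input -> 4x4 activations
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(4, 5),
).eval()
x = torch.rand(3, 3, 8, 8)
target = torch.tensor([0, 2, 4])

batch = batched_score_cam(model, model[1], x, target)
single = torch.cat([batched_score_cam(model, model[1], x[j:j + 1], target[j:j + 1]) for j in range(3)])
print(torch.allclose(batch, single, atol=1e-4))
```

Because every instance in the batch is masked and scored independently, the batched and per-instance maps should agree up to floating-point tolerance; a mismatch usually points to a broadcasting or indexing mistake across the batch dimension.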

haofanwang commented 3 years ago

Hi, @hyhzxhy, thanks for your implementation. It should indeed be possible to compute Score-CAM on a mini-batch.

Feel free to open a pull request that adds it as a new file (for example, scorecam_batch.py); I will review it later.

YuhengHuang42 commented 3 years ago

I am closing this issue because there is another implementation that is potentially faster than this one.