layumi / Person_reID_baseline_pytorch

:bouncing_ball_person: Pytorch ReID: A tiny, friendly, strong pytorch implement of person re-id / vehicle re-id baseline. Tutorial 👉https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
https://www.zdzheng.xyz
MIT License

Class activation heat map. #149

Open Rajat-Mehta opened 5 years ago

Rajat-Mehta commented 5 years ago

Is there a way to compute and visualize class activation heatmaps on the query image (or the resulting re-identified image), so that we can see which parts of the image the model focused on to produce the final result?

layumi commented 5 years ago

This is my code. I used a different model loader. You need to modify the model loader part to use it.

##################################
# Visualize Heatmap by sum
# Zheng, Zhedong, Liang Zheng, and Yi Yang. "A discriminatively learned cnn embedding for person reidentification." ACM Transactions on Multimedia Computing, Communications, and Applications (TOMM) 14, no. 1 (2018): 13.
###################################

import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
from model import ft_net, ft_net_dense, ft_net_NAS, PCB, PCB_test
from utils import load_network
import yaml
import argparse
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image

parser = argparse.ArgumentParser(description='Heatmap visualization')

parser.add_argument('--data_dir',default='../Market/pytorch',type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=1, type=int, help='batchsize')

opt = parser.parse_args()

config_path = os.path.join('./model',opt.name,'opts.yaml')
with open(config_path, 'r') as stream:
    # Recent PyYAML versions require an explicit Loader argument.
    config = yaml.load(stream, Loader=yaml.FullLoader)
opt.fp16 = config['fp16']
opt.PCB = config['PCB']
opt.use_dense = config['use_dense']
opt.use_NAS = config['use_NAS']
opt.stride = config['stride']

if 'h' in config:
    opt.h = config['h']
    opt.w = config['w']

if 'nclasses' in config: # to be compatible with old config files
    opt.nclasses = config['nclasses']
else:
    opt.nclasses = 751

def heatmap2d(img, arr):
    fig = plt.figure()
    ax0 = fig.add_subplot(121, title="Image")
    ax1 = fig.add_subplot(122, title="Heatmap")

    ax0.imshow(Image.open(img))
    heatmap = ax1.imshow(arr, cmap='viridis')
    fig.colorbar(heatmap)
    #plt.show()
    fig.savefig('heatmap')

data_transforms = transforms.Compose([
        transforms.Resize((opt.h, opt.w), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

image_datasets = {x: datasets.ImageFolder( os.path.join(opt.data_dir,x) ,data_transforms) for x in ['train']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=1) for x in ['train']}

imgpath = image_datasets['train'].imgs
model, _, epoch = load_network(opt.name, opt)
model.classifier.classifier = nn.Sequential()
model = model.eval().cuda()

data = next(iter(dataloaders['train']))
img, label = data
with torch.no_grad():
    x = model.model.conv1(img.cuda())
    x = model.model.bn1(x)
    x = model.model.relu(x)
    x = model.model.maxpool(x)
    x = model.model.layer1(x)
    x = model.model.layer2(x)
    output = model.model.layer3(x)
    #output = model.model.layer4(x)

print(output.shape)
heatmap = output.squeeze().sum(dim=0).cpu().numpy()
print(heatmap.shape)
#test_array = np.arange(100 * 100).reshape(100, 100)
# Result is saved as `heatmap.png`
heatmap2d(imgpath[0][0],heatmap)

layumi commented 5 years ago

https://github.com/layumi/Person-reID-verification/issues/4

Rajat-Mehta commented 5 years ago

Thanks @layumi, it worked for me. I am able to generate the heatmaps.

Please find below a snapshot of the heatmap that I generated for my dataset:

[Image: heatmap]

But what if I need the heatmap overlaid on the original image, as in this image taken from one of your projects? I think these heatmaps look more realistic:

[Image: Screenshot from 2019-07-04 18-05-00]

layumi commented 5 years ago

@Rajat-Mehta One simple way is 0.5 * original image + 0.5 * heatmap, and then imshow the combined result.

Rajat-Mehta commented 5 years ago

@layumi What if the dimensions of the original image and the heatmap are not the same? In my case, the original image is 256x128 and the heatmap is 16x8. The addition you suggested won't work in that case.

layumi commented 5 years ago

Please resize the heatmap to match the size of the original image.
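
A minimal sketch of that resize step, assuming the heatmap is a 2-D NumPy array (16x8 here) and the target is the 256x128 input size mentioned above. The helper name `upsample_bilinear` is hypothetical and is written in plain NumPy so it runs without extra dependencies; a library routine such as `torch.nn.functional.interpolate` could be used instead.

```python
import numpy as np

def upsample_bilinear(heatmap, out_h, out_w):
    """Bilinearly resize a 2-D array to (out_h, out_w). Hypothetical helper."""
    in_h, in_w = heatmap.shape
    # Map each output pixel centre back to fractional input coordinates.
    ys = np.clip((np.arange(out_h) + 0.5) * in_h / out_h - 0.5, 0, in_h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) * in_w / out_w - 0.5, 0, in_w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None]   # vertical weights, shape (out_h, 1)
    wx = (xs - x0)[None, :]   # horizontal weights, shape (1, out_w)
    # Interpolate along the width first, then along the height.
    top = heatmap[y0][:, x0] * (1 - wx) + heatmap[y0][:, x1] * wx
    bottom = heatmap[y1][:, x0] * (1 - wx) + heatmap[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy

# Random stand-in for the 16x8 activation map produced by the script.
small = np.random.rand(16, 8)
large = upsample_bilinear(small, 256, 128)
print(large.shape)  # (256, 128)
```

Because every output value is a weighted average of neighbouring inputs, the resized map stays within the value range of the original heatmap.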

lynnw123 commented 5 years ago

@layumi How do I combine the two, e.g. 0.5 * original image + 0.5 * heatmap, after resizing the heatmap to the same size as the original image?

layumi commented 5 years ago

@lynnw123 Just add them together and clip the values (if a value is > 255, reset it to 255). For example,

combined_result = np.uint8(0.5 * x + 0.5 * y)
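
The blend can be sketched a little more fully as follows, assuming `image` is an HxWx3 `uint8` array and `heatmap` is an HxW float array already resized to the image size. Both arrays are random stand-ins here, the helper name `overlay` is hypothetical, and putting the heat into the red channel is an illustrative choice, not the colour map used in the project.

```python
import numpy as np

def overlay(image, heatmap, alpha=0.5):
    """Blend an HxWx3 uint8 image with an HxW heatmap. Hypothetical helper."""
    # Scale the heatmap to [0, 255] so it shares the image's value range.
    h = heatmap.astype(np.float64)
    h = (h - h.min()) / (h.max() - h.min() + 1e-12) * 255.0
    # Put the heat in the red channel only; a real colour map could be used instead.
    colored = np.zeros(image.shape, dtype=np.float64)
    colored[..., 0] = h
    blended = (1 - alpha) * image.astype(np.float64) + alpha * colored
    # Clip before casting, since a plain cast does not saturate out-of-range values.
    return np.clip(blended, 0, 255).astype(np.uint8)

image = np.random.randint(0, 256, size=(256, 128, 3), dtype=np.uint8)
heatmap = np.random.rand(256, 128)
combined = overlay(image, heatmap)
print(combined.shape, combined.dtype)  # (256, 128, 3) uint8
```

Raw activation sums are usually far outside 0 to 255, which is why the sketch normalizes the heatmap before blending.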

lynnw123 commented 5 years ago

The combined image did not look correct:

[Image: 1]

I made the following changes:

heatmap = np.resize(heatmap, (128,64))

Inside the heatmap2d function:

img = np.resize(Image.open(img), (128,64))
combined = np.uint8(0.5 * img + 0.5 * arr)
heatmap = ax1.imshow(combined)
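
One likely cause, offered as a reading of the snippet and not a confirmed diagnosis: `np.resize` does not interpolate. It flattens the array and repeats or truncates the values to fill the new shape, so the spatial layout is lost. A tiny demonstration:

```python
import numpy as np

small = np.array([[0, 1],
                  [2, 3]])

# np.resize fills the new shape by repeating the flattened input.
tiled = np.resize(small, (4, 4))
print(tiled)
# [[0 1 2 3]
#  [0 1 2 3]
#  [0 1 2 3]
#  [0 1 2 3]]

# A nearest-neighbour upsample keeps each value in its own block.
upsampled = np.repeat(np.repeat(small, 2, axis=0), 2, axis=1)
print(upsampled)
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]
```

For real images, an interpolating resize (for example `PIL.Image.resize` or `torch.nn.functional.interpolate`) is the usual choice. A colour image also has a channel axis, so a 2-D heatmap needs a matching axis or a colour map before the two can be blended.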

bmiftah commented 4 years ago

@lynnw123 Did you solve this issue? I mean, were you able to display the heatmap on top of the image and see which regions were activated?

@layumi I was wondering whether I can display the 2048x16x8 activation from layer4 on top of the 256x128x3 input image. When I display the activation as 16x8, the heatmap is not clear (as shown below): it looks stretched out, and the activated regions are hard to make out.

cam_sample_market_1

milliema commented 4 years ago

After we have the overall heatmap, how can we tell which part each local branch is focusing on?

layumi commented 4 years ago

Hi @milliema You may modify the code by replacing the sum with the index to visualize the heatmap of any specific layer.
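
One reading of this suggestion, sketched with NumPy on a random stand-in for the feature map: the script collapses the channel axis with `sum(dim=0)`, and indexing that axis instead gives the map of a single channel. The shape (1024, 16, 8) and the channel index are assumptions for illustration.

```python
import numpy as np

# Stand-in for `output.squeeze()` in the script: (channels, height, width).
features = np.random.rand(1024, 16, 8)

overall = features.sum(axis=0)   # what the script visualizes
channel = 42                     # arbitrary channel chosen for illustration
single = features[channel]       # heatmap of one channel only

print(overall.shape, single.shape)  # (16, 8) (16, 8)
```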

milliema commented 4 years ago

> Hi @milliema You may modify the code by replacing the sum with the index to visualize the heatmap of any specific layer.

Thanks a lot for the quick reply. As I understand it, the feature map after the backbone (layer3/4) has size 24x8x1024 (HxWxC), and summing along the C dimension gives the overall 24x8 heatmap. As for the local branches, since the partition is applied to the backbone feature map along the H dimension, does this mean the top 4x8 region of the overall heatmap corresponds to the heatmap of the 1st local branch?

layumi commented 4 years ago

Hi @milliema Yes. If you use PCB, you evenly split the 24x8 map into 6 parts of 4x8. Otherwise, if the part pooling uses 5 parts (24 is not divisible by 5), PyTorch may split it into overlapping parts.
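
The two cases can be checked with a few lines of Python. The window rule below, start = floor(i * H / parts) and end = ceil((i + 1) * H / parts), is the one PyTorch's adaptive pooling uses to choose its bins; the helper name `part_rows` is hypothetical.

```python
def part_rows(height, parts):
    """Row range [start, end) covered by each part under adaptive pooling."""
    ranges = []
    for i in range(parts):
        start = (i * height) // parts            # floor(i * H / parts)
        end = -((-(i + 1) * height) // parts)    # ceil((i + 1) * H / parts)
        ranges.append((start, end))
    return ranges

print(part_rows(24, 6))
# [(0, 4), (4, 8), (8, 12), (12, 16), (16, 20), (20, 24)]
print(part_rows(24, 5))
# [(0, 5), (4, 10), (9, 15), (14, 20), (19, 24)]
```

With 6 parts the ranges tile the 24 rows exactly, so the top 4x8 block of the overall heatmap lines up with the first branch. With 5 parts, rows 4, 9, 14 and 19 each fall into two neighbouring parts.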