atulshanbhag / Layerwise-Relevance-Propagation

Implementation of Layerwise Relevance Propagation for heatmapping "deep" layers
97 stars 25 forks source link

Changing neural network architecture doesn't improve heatmaps #3

Open sudhasubramaniam opened 5 years ago

sudhasubramaniam commented 5 years ago

In mnist program, I modified it to feed images and changed neural network architecture by including more convolutional layers and tried for cat/Dog images instead of mnist data. I got heat maps which include features other than cat and dog also. please let me know what has to be done for getting proper heat maps.

sudhasubramaniam commented 5 years ago

model.py

import tensorflow as tf

class MNIST_CNN:

def init(self, name='MNIST_CNN'): self.name = name

def convlayer(self, input, shape, name): w_conv = tf.Variable(tf.truncatednormal(shape=shape, dtype=tf.float32, stddev=0.1), name='w{0}'.format(name)) bconv = tf.Variable(tf.constant(0.0, shape=shape[-1:], dtype=tf.float32), name='b{0}'.format(name)) conv = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(input, w_conv, [1, 1, 1, 1], padding='SAME'), b_conv), name=name) return w_conv, b_conv, conv

def fclayer(self, input, shape, name, prop=True): w_fc = tf.Variable(tf.truncatednormal(shape=shape, dtype=tf.float32, stddev=0.1), name='w{0}'.format(name)) bfc = tf.Variable(tf.constant(0.0, shape=shape[-1:], dtype=tf.float32), name='b{0}'.format(name)) if prop: fc = tf.nn.relu(tf.nn.bias_add(tf.matmul(input, w_fc), b_fc), name=name) return w_fc, b_fc, fc else: return w_fc, b_fc

def call(self, images, reuse=False): with tf.variable_scope(self.name):

  if reuse:
    scope.reuse_variables()

  activations = []

  with tf.variable_scope('input'):
    images = tf.reshape(images, [-1, 128, 128, 1], name='input')
    activations += [images, ]

  with tf.variable_scope('conv1'):
    w_conv1, b_conv1, conv1 = self.convlayer(images, [3, 3, 1, 64], 'conv1')
    activations += [conv1, ]

  with tf.variable_scope('conv2'):
    w_conv2, b_conv2, conv2 = self.convlayer(conv1, [3, 3, 64, 64], 'conv2')
    activations += [conv2, ]

  with tf.variable_scope('max_pool1'):
    max_pool1 = tf.nn.max_pool(conv2, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool1')
    activations += [max_pool1, ]

  with tf.variable_scope('conv3'):
    w_conv3, b_conv3, conv3 = self.convlayer(max_pool1, [3, 3, 64, 128], 'conv3')
    activations += [conv3, ]

  with tf.variable_scope('conv4'):
    w_conv4, b_conv4, conv4 = self.convlayer(conv3, [3, 3, 128, 128], 'conv4')
    activations += [conv4, ]

  with tf.variable_scope('max_pool2'):
    max_pool2 = tf.nn.max_pool(conv4, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool2')
    activations += [max_pool2, ]

  with tf.variable_scope('conv5'):
    w_conv5, b_conv5, conv5 = self.convlayer(max_pool2, [3, 3, 128, 256], 'conv5')
    activations += [conv5, ]

  with tf.variable_scope('conv6'):
    w_conv6, b_conv6, conv6 = self.convlayer(conv5, [3, 3, 256, 256], 'conv6')
    activations += [conv6, ]

  with tf.variable_scope('max_pool3'):
    max_pool3 = tf.nn.max_pool(conv6, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool3')
    activations += [max_pool3, ] 

  with tf.variable_scope('conv7'):
    w_conv7, b_conv7, conv7 = self.convlayer(max_pool3, [3, 3, 256, 512], 'conv7')
    activations += [conv7, ]

  with tf.variable_scope('conv8'):
    w_conv8, b_conv8, conv8 = self.convlayer(conv7, [3, 3, 512, 512], 'conv8')
    activations += [conv8, ]

  with tf.variable_scope('max_pool4'):
    max_pool4 = tf.nn.max_pool(conv8, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool4')
    activations += [max_pool4, ]      

  with tf.variable_scope('flatten'):
    flatten = tf.contrib.layers.flatten(max_pool4)
    activations += [flatten, ]

  with tf.variable_scope('fc1'):
    n_in = int(flatten.get_shape()[1])
    w_fc1, b_fc1, fc1 = self.fclayer(flatten, [n_in, 4096], 'fc1')
    activations += [fc1, ]

  with tf.variable_scope('fc2'):
    n_in = int(fc1.get_shape()[1])
    w_fc2, b_fc2, fc2 = self.fclayer(fc1, [n_in, 4096], 'fc2')
    activations += [fc2, ]

  with tf.variable_scope('dropout2'):
    dropout2 = tf.nn.dropout(fc2, keep_prob=0.5, name='dropout2')

  with tf.variable_scope('output'):
    w_fc3, b_fc3 = self.fclayer(dropout2, [4096, 2], 'fc3', prop=False)
    logits = tf.nn.bias_add(tf.matmul(dropout2, w_fc3), b_fc3, name='logits')
    preds = tf.nn.softmax(logits, name='output')
    activations += [preds, ]

  return activations, logits

@property def params(self): return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

sudhasubramaniam commented 5 years ago

train.py program

from utils import DataGenerator, MNISTLoader from model import MNIST_CNN

import tensorflow as tf

logdir = './logs/' chkpt = './logs/model.ckpt' n_epochs = 20 batch_size = 20

class Trainer:

def __init__(self):
    self.dataloader = MNISTLoader()

    self.x_train, self.y_train = self.dataloader.train
    #print("Train shape")
    #print(self.x_train.shape)
    self.x_validation, self.y_validation = self.dataloader.validation

    with tf.variable_scope('MNIST_CNN'):
        self.model = MNIST_CNN()

        self.X = tf.placeholder(tf.float32, [None,128,128], name='X')
        self.y = tf.placeholder(tf.float32, [None, 2], name='y')

        self.activations, self.logits = self.model(self.X)

        tf.add_to_collection('LayerwiseRelevancePropagation', self.X)
        for act in self.activations:
            tf.add_to_collection('LayerwiseRelevancePropagation', act)

        self.l2_loss = tf.add_n([tf.nn.l2_loss(p) for p in self.model.params if 'b' not in p.name]) * 0.001
        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)) + self.l2_loss
        self.optimizer = tf.train.AdamOptimizer().minimize(self.cost, var_list=self.model.params)

        self.preds = tf.equal(tf.argmax(self.logits, axis=1), tf.argmax(self.y, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.preds, tf.float32))

    self.cost_summary = tf.summary.scalar(name='Cost', tensor=self.cost)
    self.accuracy_summary = tf.summary.scalar(name='Accuracy', tensor=self.accuracy)

    self.summary = tf.summary.merge_all()

def run(self):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()
        self.file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())

        self.train_batch = DataGenerator(self.x_train, self.y_train, batch_size)
        self.validation_batch = DataGenerator(self.x_validation, self.y_validation, batch_size)

        for epoch in range(n_epochs):
            self.train(sess, epoch)
            self.validate(sess)
            self.saver.save(sess, chkpt)

def train(self, sess, epoch):
    n_batches = self.x_train.shape[0] // batch_size
    if self.x_train.shape[0] % batch_size != 0:
        n_batches += 1

    avg_cost = 0
    avg_accuracy = 0    
    for batch in range(n_batches):
        x_batch, y_batch = next(self.train_batch)
        _, batch_cost, batch_accuracy, summ = sess.run([self.optimizer, self.cost, self.accuracy, self.summary], 
                                                                                                        feed_dict={self.X: x_batch, self.y: y_batch})
        avg_cost += batch_cost
        avg_accuracy += batch_accuracy
        self.file_writer.add_summary(summ, epoch * n_batches + batch)

        completion = batch / n_batches
        print_str = '|'+int(completion * 20)*'#'+ (19 - int(completion * 20))  * ' ' + '|'
        print('\rEpoch {0:>3} {1} {2:3.0f}% Cost {3:6.4f} Accuracy {4:6.4f}'.format('#' + str(epoch + 1), print_str, completion * 100, avg_cost / (batch + 1), avg_accuracy / (batch + 1)), end='')
        #print("end="' ')

def validate(self, sess):
    n_batches = self.x_validation.shape[0] // batch_size
    if self.x_validation.shape[0] % batch_size != 0:
        n_batches += 1

    avg_accuracy = 0
    for batch in range(n_batches):
        x_batch, y_batch = next(self.validation_batch)
        avg_accuracy += sess.run([self.accuracy, ], feed_dict={self.X: x_batch, self.y: y_batch})[0]

    avg_accuracy /= n_batches
    print('Validation Accuracy {0:6.4f}'.format(avg_accuracy))

if name == 'main': Trainer().run()

sudhasubramaniam commented 5 years ago

utils.py program

import gzip import pickle import os import glob import cv2 import numpy as np import matplotlib.cm as cm import matplotlib.pyplot as plt

DATA_PATH = './mnist_png/training' TEST_PATH = './mnist_png/testing'

class DataGenerator:

def __init__(self, X, y, batch_size):
    assert(X.shape[0] == y.shape[0])
    self.X = X
    self.y = y
    self.batch_size = batch_size
    self.num_samples = X.shape[0]
    self.num_batches = X.shape[0] // self.batch_size
    if X.shape[0] % self.batch_size != 0:
        self.num_batches += 1
    self.batch_index = 0

def __iter__(self):
    return self

def __next__(self, shuffle=True):
    if self.batch_index == self.num_batches:
        self.batch_index = 0
        if shuffle:
            indices = np.random.permutation(self.num_samples)
            self.X = self.X[indices]
            self.y = self.y[indices]
    start = self.batch_index * self.batch_size
    end = min(self.num_samples, start + self.batch_size)
    self.batch_index += 1
    return self.X[start: end], self.y[start: end]   

class MNISTLoader:

def __init__(self, loc=DATA_PATH):
    self.loc = loc
    self.run()

def run(self):

 classes = ['cats','dogs']
 images = []
 labels = []
 ids = []
 cls = []

 for fld in classes:   # assuming data directory has a separate folder for each class, and that each folder is named after the class
        index = classes.index(fld)
        #print('Loading {} files (Index: {})'.format(fld, index))
        path = os.path.join(DATA_PATH, fld, '*g')
        files = glob.glob(path)
        for fl in files:
                image = cv2.imread(fl)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                #print(image.shape)
                image = cv2.resize(image, (128, 128), interpolation = cv2.INTER_LINEAR)
                image = image.astype(np.float32)
                image = np.multiply(image, 1.0 / 255.0)
                images.append(image)
                label = np.zeros(len(classes))
                label[index] = 1.0
                labels.append(label)
                flbase = os.path.basename(fl)
                ids.append(flbase)
                cls.append(fld)

 images = np.array(images)
 print('Total number of images {}'.format(len(images)))
 labels = np.array(labels)
 validation_size = int(0.2 * images.shape[0])

 '''try:
        with gzip.open(DATA_PATH, 'rb') as f:
            data = pickle.load(f, encoding='bytes')
    except FileNotFoundError:
        print('Dataset not found!')
        exit()'''

 self.x_validation = images[:validation_size]
 self.y_validation = labels[:validation_size]
 self.x_train = images[validation_size:]
 self.y_train = labels[validation_size:]
 print('Total number of training images {}'.format(len(self.x_train)))
 print('Total number of Validation images {}'.format(len(self.x_validation)))

 '''train_set, validation_set, test_set = data

 self.x_train, self.y_train = train_set
 self.x_validation, self.y_validation = validation_set
 self.x_test, self.y_test = test_set'''

 '''print(self.x_train[0].shape)
 plt.imshow(self.x_train[0])
 plt.show()'''

 test_images=[]
 test_labels=[]
 for fld in classes:   # assuming data directory has a separate folder for each class, and that each folder is named after the class
        index = classes.index(fld)
        #print('Loading {} files (Index: {})'.format(fld, index))
        path = os.path.join(TEST_PATH, fld, '*g')
        files = glob.glob(path)
        for fl in files:

                test_image = cv2.imread(fl)
                test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2GRAY)
                test_image = cv2.resize(test_image, (128, 128), interpolation = cv2.INTER_LINEAR)
                test_images.append(test_image)
                test_label = np.zeros(len(classes))
                test_label[index] = 1.0
                test_labels.append(test_label)
                #flbase = os.path.basename(fl)
                #ids.append(flbase)
                #cls.append(fld)

 test_images = np.array(test_images)

 test_labels = np.array(test_labels)
 self.x_test = test_images
 self.y_test = test_labels

 '''I = np.eye(10)
 self.y_train = I[self.y_train]
 self.y_validation = I[self.y_validation]
 self.y_test = I[self.y_test]'''

'''def get_samples(self):
    #data = [self.train, self.validation, self.test][np.random.choice(np.arange(3))]
    #samples_indices = np.random.choice(np.argwhere(np.argmax(data[1], axis=1) == digit).flatten(), size=n_samples)
    #return data[0][samples_indices]
    #print("%%%%")
    #print(self.x_test.shape)
    data = [self.train, self.validation, self.test][np.random.choice(np.arange(3))]
    print("!!!!")
    #print(data.shape)
    samples_indices = (data[2].flatten())
    print(data[0][samples_indices].shape)
    return data[0][samples_indices]

    #return self.x_test'''

def get_samples(self):
    #data = [self.train, self.validation, self.test][np.random.choice(np.arange(3))]
    #samples_indices = np.random.choice(np.argwhere(np.argmax(data[1], axis=1) == digit).flatten(), size=n_samples)
    #return data[0][samples_indices]
    print("%%%%")
    print(self.x_test.shape)
    return self.x_test

@property
def train(self):
    return self.x_train, self.y_train

@property
def validation(self):
    return self.x_validation, self.y_validation

@property
def test(self):
    return self.x_test, self.y_test

if name == 'main': dl = MNISTLoader()

train = dl.train
validation = dl.validation
test = dl.test

dg = DataGenerator(train[0], train[1], 20)
for i in range(5):
    x, y = next(dg)
    #print(i, x.shape, y.shape)

print('x_train shape', train[0].shape)

print('y_train shape', train[1].shape)

print('x_validation shape', validation[0].shape)
print('y_validation shape', validation[1].shape)

print('x_test shape', test[0].shape)
print('y_test shape', test[1].shape)    

print(dl.get_samples())
sudhasubramaniam commented 5 years ago

lrp.py program

from utils import MNISTLoader from tensorflow.python.ops import gen_nn_ops from matplotlib.cm import get_cmap

import numpy as np import tensorflow as tf import matplotlib.pyplot as plt import cv2

logdir = './logs/' chkpt = './logs/model.ckpt' resultsdir = './results/'

class LayerwiseRelevancePropagation: print("Welcome") def init(self): self.dataloader = MNISTLoader() self.epsilon = 1e-10

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
  saver.restore(sess, tf.train.latest_checkpoint(logdir))

  weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='MNIST_CNN')
  self.activations = tf.get_collection('LayerwiseRelevancePropagation')

self.X = self.activations[0]

self.act_weights = {}
for act in self.activations[2:]:
  for wt in weights:
    if len(act.name.split('/'))>2:
      name = act.name.split('/')[2]
    if name == wt.name.split('/')[2]:
      if name not in self.act_weights:
        self.act_weights[name] = wt

self.activations = self.activations[:0:-1]
self.relevances = self.get_relevances()

def get_relevances(self): relevances = [self.activations[0], ]

for i in range(1, len(self.activations)):
  if len(self.activations[i - 1].name.split('/'))>2:
     name = self.activations[i - 1].name.split('/')[2]
     #print(name)
  if 'output' in name or 'fc' in name:
    relevances.append(self.backprop_fc(name, self.activations[i], relevances[-1]))
  elif 'flatten' in name:
    relevances.append(self.backprop_flatten(self.activations[i], relevances[-1]))
  elif 'max_pool' in name:
    relevances.append(self.backprop_max_pool2d(self.activations[i], relevances[-1]))
  elif 'conv' in name:
    relevances.append(self.backprop_conv2d(name, self.activations[i], relevances[-1]))
  else:
    #raise 'Error parsing layer!' 
    print("Error parsing layer!")

return relevances

def backprop_fc(self, name, activation, relevance): w = self.act_weights[name] w_pos = tf.maximum(0.0, w) z = tf.matmul(activation, w_pos) + self.epsilon s = relevance / z

print("!!!!")

#print(name,s.shape,w_pos.shape)
c = tf.matmul(s, tf.transpose(w_pos))
return c * activation

def backprop_flatten(self, activation, relevance): shape = activation.get_shape().as_list() shape[0] = -1

print("flatten")

return tf.reshape(relevance, shape)

def backprop_max_pool2d(self, activation, relevance, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1]): z = tf.nn.max_pool(activation, ksize, strides, padding='SAME') + self.epsilon s = relevance / z c = gen_nn_ops.max_pool_grad_v2(activation, z, s, ksize, strides, padding='SAME')

print("Max pool")

return c * activation

def backprop_conv2d(self, name, activation, relevance, strides=[1, 1, 1, 1]): w = self.act_weights[name] w_pos = tf.maximum(0.0, w) z = tf.nn.conv2d(activation, w_pos, strides, padding='SAME') + self.epsilon s = relevance / z c = tf.nn.conv2d_backprop_input(tf.shape(activation), w_pos, s, strides, padding='SAME')

print("Conv")

return c * activation

def get_heatmap(self,i): samples = self.dataloader.get_samples()

with tf.Session() as sess:    
  saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
  saver.restore(sess, tf.train.latest_checkpoint(logdir))
  #samples[i]=samples[i].reshape(1,128,128)
  print(samples[i].shape,self.relevances[-1])
  cmap_type='rainbow'
  shape = list(samples[i].shape)
  cmap = get_cmap(name='rainbow')
  heatmap = cmap(samples[i].flatten())[:, :1]
  heatmap = heatmap
  print(heatmap.shape)
  shape[-1] = 3

return heatmap.reshape(128,128)

def test(self):

samples = self.dataloader.get_samples(n_samples=1, digit=np.random.choice(10))

samples=self.dataloader.get_samples()

#print(len(samples))
leng=len(samples)

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
  saver.restore(sess, tf.train.latest_checkpoint(logdir))
  R = sess.run(self.relevances, feed_dict={self.X: samples})
  #for r in R:
    #print(r.sum())
return leng

if name == 'main':

lent=LayerwiseRelevancePropagation().test()
print(lent)

for i in range(lent):
  heatmap = LayerwiseRelevancePropagation().get_heatmap(i)

  fig = plt.figure()
  ax = fig.add_subplot(111)
  ax.axis('off')
  ax.imshow(heatmap, cmap='Reds', interpolation='bilinear')

  fig.savefig('{0}{1}.jpg'.format(resultsdir, i))