piEsposito / blitz-bayesian-deep-learning

A simple and extensible library to create Bayesian Neural Network layers on PyTorch.
GNU General Public License v3.0

Image inference of a BayesianCNN on the MNIST dataset #74

Closed pieterblok closed 3 years ago

pieterblok commented 3 years ago

Hello @piEsposito, thank you very much for this nice Pythonic implementation of Bayesian neural nets!

Excuse me for this massive post; it contains my inference script (which is a bit lengthy).

I have used your training script example (blitz/examples/bayesian_LeNet_mnist.py) to train a Bayesian CNN on the MNIST dataset. Then, I made an inference script using the training weights:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from blitz.modules import BayesianLinear, BayesianConv2d
from blitz.utils import variational_estimator

import matplotlib.pyplot as plt
import numpy as np

np.set_printoptions(formatter={'float_kind': '{:f}'.format})

train_dataset = dsets.MNIST(root="./data",
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True
                            )
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=64,
                                           shuffle=True)

test_dataset = dsets.MNIST(root="./data",
                             train=False,
                             transform=transforms.ToTensor(),
                             download=True
                            )
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=64,
                                           shuffle=True)

def plot_uncertain_images(uncertain_images, uncertain_vars):
    # show the 20 images with the highest predictive variance
    top_idx = np.argsort(uncertain_vars)[-20:]
    fig = plt.figure(figsize=(8, 8))
    columns = 4
    rows = 5
    for i, idx in enumerate(top_idx, start=1):
        fig.add_subplot(rows, columns, i)
        plt.imshow(uncertain_images[idx], cmap='gray')
    plt.show()

@variational_estimator
class BayesianCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = BayesianConv2d(1, 6, (5,5))
        self.conv2 = BayesianConv2d(6, 16, (5,5))
        self.fc1   = BayesianLinear(256, 120)
        self.fc2   = BayesianLinear(120, 84)
        self.fc3   = BayesianLinear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier = BayesianCNN()
classifier.load_state_dict(torch.load("./weights/epoch-66.pt"))
classifier.to(device)
classifier.eval()

samples = 100
correct = 0
predicted = 0
uncertain_vars = []
uncertain_images = []

## do the image inference on the test-dataset
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        batch_size = images.shape[0]

        predictions = np.zeros((batch_size, samples), dtype=np.uint8)

        for i in range(samples):
            outputs = classifier(images.to(device))
            probs = F.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)
            predictions[:, i] = preds.cpu().numpy()

        var = np.var(predictions, axis=1)

        for j in range(batch_size):
            pred_var = var[j]
            if pred_var == 0:
                # I'm sure about the prediction, so I'm going to predict
                predicted += 1
                prediction = int(np.percentile(predictions[j, :], 50))
                correct += int(prediction == labels[j].item())
            else:
                # I'm not sure about the prediction, so I'll skip the prediction
                uncertain_vars.append(pred_var)
                img = (images[j].squeeze(0).numpy() * 255).astype(np.uint8)  # (28, 28), so plt.imshow accepts it
                uncertain_images.append(img)

print('Accuracy of the network on the {0:d} predicted test images: {1:.2f} %'.format(predicted, (100 * correct / predicted)))
plot_uncertain_images(uncertain_images, uncertain_vars)

## do the same trick, but now on 64 randomly generated images ('noise')
batch_size = 64
predicted = 0
images_random = torch.rand(batch_size, 1, 28, 28)
predictions = np.zeros((batch_size, samples), dtype=np.uint8)
certain_labels = []
certain_images = []

with torch.no_grad():
    for i in range(samples):
        outputs = classifier(images_random.to(device))
        probs = F.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)
        predictions[:, i] = preds.cpu().numpy()

var = np.var(predictions, axis=1)
for j in range(batch_size):
    pred_var = var[j]
    if pred_var == 0:
        # I'm sure about the prediction, so I'm going to predict
        predicted += 1
        prediction = int(np.percentile(predictions[j,:], 50))
        img = (images_random[j].squeeze(0).numpy() * 255).astype(np.uint8)  # (28, 28), so plt.imshow accepts it
        certain_images.append(img)
        certain_labels.append(prediction)
    else:
        # I'm not sure about the prediction, so I'll skip the prediction
        pass

print('{0:d} predictions were made on {1:d} images with random noise'.format(predicted, batch_size))
for h in range(len(certain_images)):
    plt.imshow(certain_images[h])
    plt.title('Prediction: {:d}'.format(certain_labels[h]))
    plt.show()

My main question: what is a good "decision rule" to select the MNIST-digits the BayesianCNN is less confident about?

As you can see in my script (pred_var == 0), I sampled the BayesianCNN 100 times and then rejected a digit when the variance of the 100 estimates exceeded 0 (meaning that a prediction is rejected when at least one of the estimates deviates from the rest). I have also done a sanity check by simulating 64 random-noise images and then checking whether the BayesianCNN gives uniform predictions on them...
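A softer variant of this decision rule would be to accept a prediction when at least a fraction of the Monte Carlo samples agree, rather than requiring all of them to. A minimal sketch, reusing the (batch_size, samples) predictions array from the script above; tau is a hypothetical threshold that would need tuning on a validation set:

counts = np.apply_along_axis(np.bincount, 1, predictions, minlength=10)  # per-image class counts, (batch_size, 10)
modal_class = counts.argmax(axis=1)                    # most frequently sampled class per image
agreement = counts.max(axis=1) / predictions.shape[1]  # fraction of samples that agree with it
tau = 0.9
accept = agreement >= tau                              # tau == 1.0 recovers the variance-zero rule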

This is one of the outputs that was generated:

"Accuracy of the network on the 9054 predicted test images: 99.96 %"
"5 predictions were made on 64 images with random noise"

Encouragingly, these 5 predictions (all of digit '8', by the way) are a much better result than the standard LeNet trained on MNIST (without the Bayesian layers), which made predictions on all 64 noise images.

Another question: the prediction of the BayesianCNN was done with the torch.nn.functional.softmax function (followed by selecting the class with the highest softmax probability). However, I'm wondering if this softmax approach is correct... What would you advise? Is there a better probabilistic/Bayesian way?

Thanks in advance!

piEsposito commented 3 years ago

Hello @pieterblok, and sorry for the late reply; work has been keeping me very busy. Anyway, thank you so much for using BLiTZ.

In answer to your main question: there is actually no definitive answer. The literature states that there is no established decision rule yet, but there are some ideas you can follow.

When I use BLiTZ, I check the calibration of the model by measuring, for different values of X, the accuracy (or precision, or F1-score, or whatever metric I am using) the model achieves if I only keep the X% least disperse predictions, getting the distribution of each prediction probability via Monte Carlo sampling.
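A minimal sketch of that calibration check, assuming the per-image dispersion values, predictions, and true labels have already been collected over the whole test set into NumPy arrays (variances, preds, and labels are hypothetical names):

import numpy as np

order = np.argsort(variances)              # least disperse (most confident) first
for keep in (0.5, 0.7, 0.9, 1.0):
    kept = order[:int(keep * len(order))]  # keep only the X% least disperse predictions
    acc = (preds[kept] == labels[kept]).mean()
    print('keep {:.0%} least disperse -> accuracy {:.4f}'.format(keep, acc))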

For the last layer, Monte Carlo sampling with a softmax on top is the approach I see in the literature, in TensorFlow Probability, and in research with my friends and peers.
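Concretely, that means averaging the per-sample softmax outputs to approximate the predictive distribution, instead of taking an argmax per sample. A minimal sketch, assuming the same classifier, images, device, and samples names as in your inference script:

with torch.no_grad():
    mc_probs = torch.stack([F.softmax(classifier(images.to(device)), dim=1)
                            for _ in range(samples)])  # (samples, batch_size, 10)
pred_dist = mc_probs.mean(dim=0)               # Monte Carlo estimate of p(y | x)
confidence, prediction = pred_dist.max(dim=1)  # predictive probability and predicted class

You can then threshold on confidence (or on the entropy of pred_dist) instead of on the variance of the sampled labels.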

Hope that helps. Feel free to reach me by email or here if you have any doubts (I promise I will answer faster).

Best regards, -Pi

piEsposito commented 3 years ago

Closing due to staleness.