ykwon0407 / UQ_BNN

Uncertainty quantification using Bayesian neural networks in classification (MIDL 2018, CSDA)

about p_hat #4

Closed ShellingFord221 closed 5 years ago

ShellingFord221 commented 5 years ago

Hi, according to your code, p_hat is the matrix of softmax probability vectors for all test data, e.g. [[0.3, 0.2, 0.5], [0.1, 0.8, 0.1], [0.4, 0.2, 0.4], [0.7, 0.1, 0.2]] (assuming there are 3 classes, 2 test data, and 2 stochastic dropout passes), and the aleatoric and epistemic uncertainties are calculated from p_hat. Am I correct? Thanks!

ykwon0407 commented 5 years ago

@ShellingFord221 Hello! My code implicitly assumes that the problem to be solved is binary classification. As in this link, the model outputs the probability of being Class 1.
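For reference, the binary case can be sketched as below. This is a minimal illustration, not the repository's code: the function name `binary_uncertainty` and the sample values are hypothetical, and it assumes `p_hat` stacks the Class 1 probabilities from T stochastic passes along the first axis.

```python
import numpy as np

def binary_uncertainty(p_hat):
    """Hypothetical helper: p_hat has shape (T, ...) and holds P(Class 1)
    from T stochastic dropout passes."""
    p_hat = np.asarray(p_hat, dtype=float)
    # mean prediction over the T passes
    prediction = p_hat.mean(axis=0)
    # aleatoric: average Bernoulli variance p * (1 - p) over the passes
    aleatoric = np.mean(p_hat * (1.0 - p_hat), axis=0)
    # epistemic: variance of the predicted probabilities across the passes
    epistemic = np.mean(p_hat ** 2, axis=0) - prediction ** 2
    return prediction, aleatoric, epistemic

# made-up values: T = 2 passes (rows), 2 test data (columns)
p_hat = np.array([[0.5, 0.8],
                  [0.1, 0.2]])
prediction, aleatoric, epistemic = binary_uncertainty(p_hat)
```

Under these assumptions the two terms add up to `prediction * (1 - prediction)`, the variance of a Bernoulli variable with the mean probability.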

If you want to solve a multi-class classification problem, you need to change a few lines of my code. After the changes, a test datum's probability vectors after softmax should look like

    [[0.3, 0.2, 0.5],   #-> p_hat_stochastic_1
     [0.1, 0.8, 0.1]]   #-> p_hat_stochastic_2

Then the aleatoric and epistemic uncertainties, each of which should be a 3 by 3 matrix, follow from these vectors.
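One way the 3 by 3 matrices could be computed for a single test datum is sketched below. It assumes the decomposition in which the aleatoric part is the average of diag(p_t) - p_t p_t^T over the T passes and the epistemic part is the average of (p_t - p_bar)(p_t - p_bar)^T; the function name `multiclass_uncertainty` is hypothetical and not part of the repository.

```python
import numpy as np

def multiclass_uncertainty(p_hat):
    """Hypothetical helper: p_hat has shape (T, K), one softmax vector
    per stochastic dropout pass for a single test datum."""
    p_hat = np.asarray(p_hat, dtype=float)
    T = p_hat.shape[0]
    # mean probability vector over the T passes
    p_bar = p_hat.mean(axis=0)
    # aleatoric: mean over t of diag(p_t) - p_t p_t^T  (K x K)
    aleatoric = np.diag(p_bar) - p_hat.T @ p_hat / T
    # epistemic: mean over t of (p_t - p_bar)(p_t - p_bar)^T  (K x K)
    centered = p_hat - p_bar
    epistemic = centered.T @ centered / T
    return p_bar, aleatoric, epistemic

# the two stochastic passes from the example above (3 classes)
p_hat = np.array([[0.3, 0.2, 0.5],
                  [0.1, 0.8, 0.1]])
p_bar, aleatoric, epistemic = multiclass_uncertainty(p_hat)
```

With this sketch, the diagonal entries reduce to the per-class binary formulas (mean of p(1 - p) and the variance of p across passes), and the off-diagonal entries capture how the class probabilities co-vary.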

ShellingFord221 commented 5 years ago

Hi, According to the settings above, i.e. there are 3 classes, should the code be changed as:

    def predict(model, image, gt, T=10):
        ...

        # predict stochastic dropout model T times
        p_hat_l = []
        for t in range(T):
            p_hat_l.append( model.predict(image) )
        p_hat = np.array(p_hat_l)

        # mean prediction in single class
        prediction = np.mean(p_hat, axis=0)

        # storage sample probabilities for each class
        class_1 = np.array([i[0] for i in p_hat_l])   # [0.3, 0.1]
        class_2 = np.array([i[1] for i in p_hat_l])   # [0.2, 0.8]
        class_3 = np.array([i[2] for i in p_hat_l])   # [0.5, 0.1]

        aleatoric = np.mean(class_1*class_2*class_3, axis=0)
        epistemic = np.mean(p_hat**2, axis=0) - np.mean(p_hat, axis=0)**2

        ...
        return ...

Am I correct? Thanks!

ykwon0407 commented 5 years ago

@ShellingFord221 Hello~ I guess this answer may help you!