pronobis / libspn

Library for learning and inference with Sum-product Networks

log probability for all classes given an observation #90

Closed: anicolson closed this issue 4 years ago

anicolson commented 4 years ago

Hi, for tutorial 7, I was wondering which node from the graph I should pass to sess.run() to get the log probability of each class given an observation.

E.g. for 32 observations from the MNIST dataset, batch_x has size [32, 784]. I want the output of sess.run() to have size [32, 10], i.e. the log probability of each of the 10 classes for each observation.

Thank you very much!

jostosh commented 4 years ago

Hi, the tutorial has this part:

```python
class_roots = spn.ParallelSums(full_scope_prod, num_sums=num_classes)
root = spn.Sum(class_roots)
```

You can add something along these lines:

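```python
import tensorflow as tf
import libspn as spn

# Make another root that doesn't take in latent indicators, thereby always marginalizing over the class.
# It shares the weights of `root`, so both roots use the same class prior p(y).
marginalizing_root = spn.Sum(class_roots, weights=root.weights)

log_p_x_given_y = class_roots.get_log_value()              # log p(x|y), one column per class
log_p_y = marginalizing_root.weights.node.get_log_value()  # log p(y), the root's class weights
log_p_x = marginalizing_root.get_log_value()               # log p(x), marginalized over classes

# Good ol' Bayes rule: p(y|x) = p(x|y)p(y)/p(x), computed in log space.
# Broadcasting gives one row per observation and one column per class, e.g. [32, 10].
log_p_y_given_x = (log_p_x_given_y + log_p_y) - log_p_x
p_y_given_x = tf.exp(log_p_y_given_x)
# Pass log_p_y_given_x to sess.run() (with the tutorial's feed_dict) to get the per-class log probabilities.
```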
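To check the Bayes-rule step without building the SPN, here is a minimal NumPy sketch of the same arithmetic. It assumes the class-root log-values have shape [batch, num_classes], the log prior has shape [1, num_classes], and log p(x) is the log-sum-exp of log p(x|y) + log p(y) over classes (what a sum node with normalized weights computes). The log-likelihoods are random stand-ins, not real model output, and `log_posterior` is a hypothetical helper, not part of libspn.

```python
import numpy as np

def log_posterior(log_p_x_given_y, log_p_y):
    """Hypothetical helper: Bayes' rule in log space.

    log_p_x_given_y: [batch, num_classes] per-class log-likelihoods
    log_p_y:         [1, num_classes] log class priors (normalized)
    returns:         [batch, num_classes] log p(y|x)
    """
    joint = log_p_x_given_y + log_p_y  # log p(x, y), broadcast to [batch, num_classes]
    # log p(x) = log sum_y p(x, y), computed stably with the log-sum-exp trick
    m = joint.max(axis=1, keepdims=True)  # [batch, 1]
    log_p_x = m + np.log(np.exp(joint - m).sum(axis=1, keepdims=True))  # [batch, 1]
    return joint - log_p_x  # [batch, 1] broadcasts over the class columns

rng = np.random.default_rng(0)
batch, num_classes = 32, 10
log_lik = rng.normal(size=(batch, num_classes)) - 5.0         # stand-in log p(x|y)
log_prior = np.log(np.full((1, num_classes), 1.0 / num_classes))  # uniform p(y)

log_post = log_posterior(log_lik, log_prior)
print(log_post.shape)                                   # (32, 10)
print(np.allclose(np.exp(log_post).sum(axis=1), 1.0))   # True
```

Each row of the exponentiated result sums to 1, which is also a quick sanity check to run on the real p_y_given_x.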
anicolson commented 4 years ago

Thank you! I'll give it a try