Closed: benediktwagner closed this issue 6 years ago
If I understand your question correctly, by labels you are referring to the numerical labels in the visualization of the decision path in a trained decision tree, right?
Then the answer can be split into three parts:
1. For inner nodes, a set of unique numbers (MNIST classes) is created from the list of labels of all leaves in the sub-tree of the given node (this is done by checking for a common prefix in the node title, if k.startswith(key)). These labels are retrieved as the argmax of the distributions learned by the corresponding leaves (bigots). All of this logic is encompassed here:
https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L76-L79
2. For leaves, the labels are the argmaxes of their learned (bigot) distributions, denoting the class predicted by the given leaf:
https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L90
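To make 1. and 2. concrete, here is a minimal sketch, assuming a toy 4-class tree where node titles encode the root-to-leaf path ('0' = left, '1' = right); leaf_distros and all its values are invented for illustration, not taken from the repo:

```python
import numpy as np

# Hypothetical example: leaf titles encode the root-to-leaf path,
# values are the learned bigot distributions over the 4 classes.
leaf_distros = {
    '00': np.array([0.1, 0.7, 0.1, 0.1]),    # leaf predicting class 1
    '01': np.array([0.8, 0.1, 0.05, 0.05]),  # leaf predicting class 0
    '10': np.array([0.05, 0.05, 0.8, 0.1]),  # leaf predicting class 2
    '11': np.array([0.1, 0.1, 0.1, 0.7]),    # leaf predicting class 3
}

def leaf_label(distro):
    # 2. A leaf's label is the argmax of its learned (bigot) distribution.
    return int(np.argmax(distro))

def inner_node_labels(key, leaves):
    # 1. An inner node's label set collects the labels of all leaves in its
    # sub-tree; sub-tree membership is a common-prefix check on the title,
    # mirroring `if k.startswith(key)` in the linked utils.py.
    return sorted({leaf_label(d) for k, d in leaves.items() if k.startswith(key)})

print(inner_node_labels('0', leaf_distros))  # [0, 1] -- left sub-tree
print(inner_node_labels('', leaf_distros))   # [0, 1, 2, 3] -- root
```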
3.1 The correlation of input_img with the current kernel is calculated and aggregated into a logit (note that the bias is added as well, to conform with the inference of the model), which is used to determine the decision path at the current node according to the decision rule described in the paper (the sigmoid activation is omitted here, but we know how it looks: the boundary is at 0). This logit now quantifies the correlation of the input with the kernel of the current node, so we plot it below the kernel as an indication of how strong the decision was (the further from 0, the stronger):
https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L66-L73
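As a minimal sketch of that decision rule (hypothetical names; the repo's actual code is in the link above): the logit is just a dot product plus bias, and since sigmoid(logit) > 0.5 exactly when logit > 0, the sigmoid can be dropped:

```python
import numpy as np

def node_logit(input_img, kernel, bias):
    # Correlation of the input with the node's kernel, plus bias,
    # aggregated into a single pre-sigmoid logit.
    return float(np.dot(input_img.ravel(), kernel.ravel()) + bias)

rng = np.random.default_rng(0)
x = rng.random((28, 28))           # e.g. an MNIST-sized input (assumption)
w = rng.standard_normal((28, 28))  # hypothetical learned kernel
logit = node_logit(x, w, bias=-0.5)
# Decision boundary at 0: sigmoid(logit) > 0.5 iff logit > 0.
branch = 'right' if logit > 0 else 'left'
print(logit, branch)  # |logit| indicates how strong the decision was
```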
3.2 Nodes on the maximum probability path are collected here as well:
https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L69
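Collecting that path could look roughly like the following greedy descent (a sketch only; the kernels/biases dicts keyed by node title are an assumed structure, not the repo's exact one):

```python
import numpy as np

def max_probability_path(input_img, kernels, biases):
    # Greedily follow the more probable branch at every inner node and
    # collect the visited node titles; 3.3 below then draws arrows along
    # this path in green and all others in red.
    title, path = '', ['']
    while title in kernels:  # descend until we step past the last inner node
        logit = np.dot(input_img.ravel(), kernels[title].ravel()) + biases[title]
        title += '1' if logit > 0 else '0'  # right child iff sigmoid(logit) > 0.5
        path.append(title)
    return path

rng = np.random.default_rng(1)
kernels = {t: rng.standard_normal(4) for t in ['', '0', '1']}  # depth-2 toy tree
biases = {t: 0.0 for t in kernels}
print(max_probability_path(rng.random(4), kernels, biases))  # e.g. ['', '1', '10']
```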
3.3 The maximum probability path is used to determine the color of the arrows (green along this path, red otherwise):
https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L99-L101
@benbetze I hope that this answers your intended question. Please close this issue if it does; otherwise let me know, but take care to formulate it properly.
First of all: thank you for providing such a detailed answer about your implementation – I really appreciate it.
I was indeed talking about the numerical labels in the visualisation, and more specifically the bigot distribution (including the handling of sub-trees), which is clearer to me now. However, there is one last point of confusion left for me: I have trouble interpreting the bigot distribution itself. I understand this is a learned probability distribution over all classes. Do we tune these distributions as well as the kernels on the inner nodes? Or are these distributions our "observations" from the training set, obtained by storing where the input training images ended up? (I struggle to connect Equation 2 and Figure 1 in the paper.)
Thanks so much again for your help and patience.
I understand this is a learned probability distribution over all classes.
Correct
Do we tune these distributions as well as the kernels on the inner nodes?
Yes we do, but not directly. From Equation 2: we tune the parameters phi of those distributions and apply a softmax activation to these parameters in order to get the actual distributions, which are used in 2 ways:
1. During training, a loss term is computed between softmax(phi) and the corresponding label, weighted by its path probability (this is inherited from the re-used concept of mixture of experts, see this video for a quick intro).
2. During inference, as argmax(softmax(phi)) we retrieve the static predictions of the individual leaves, and the prediction of the tree is chosen to be the prediction of the leaf with maximum path probability.
During inference, the maximum path probability is a function of the input, while the actual distributions of the leaves are independent of the input data, thus they act as Bigots. On the inference illustration below you can see how the input image is used to determine the maximum probability path (green arrows) that leads to a single leaf / bigot / distribution that corresponds to the model prediction. The distribution parameters phi themselves are stored as parameters of the model and act as constants during inference (the calculation of the model prediction).
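A minimal numerical sketch of those two uses, assuming a toy 4-class tree with two leaves (all names and numbers are invented; the weighted cross-entropy below is one common reading of the paper's loss, not the repo's exact implementation):

```python
import numpy as np

def softmax(phi):
    e = np.exp(phi - phi.max())
    return e / e.sum()

# Hypothetical toy tree: two leaves with learned parameters phi, which are
# model weights and thus constant w.r.t. the input at inference time.
phis = {'0': np.array([2.0, 0.1, 0.1, 0.1]),   # bigot leaning to class 0
        '1': np.array([0.1, 0.1, 3.0, 0.1])}   # bigot leaning to class 2
bigots = {l: softmax(phi) for l, phi in phis.items()}  # Q = softmax(phi), Eq. 2

# Path probabilities P^l(x) DO depend on the input; assume the inner nodes
# computed these for some particular image x.
path_probs = {'0': 0.3, '1': 0.7}
target = np.array([0.0, 0.0, 1.0, 0.0])  # one-hot label T (class 2)

# Training use: each leaf's cross-entropy with the target, weighted by its
# path probability (the mixture-of-experts flavour mentioned above).
loss = -sum(p * np.sum(target * np.log(bigots[l]))
            for l, p in path_probs.items())

# Inference use: the tree predicts with the leaf of maximum path
# probability; that leaf's static prediction is argmax(softmax(phi)).
winner = max(path_probs, key=path_probs.get)
prediction = int(np.argmax(bigots[winner]))
print(loss, winner, prediction)  # prediction == 2 here
```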
Thanks so much! You have been really helpful.
Thanks for sharing! How exactly do you label the paths, and especially the leaves? Are you storing the probabilities to look up the likeliest labels at each node, and tracking which leaves the data points from your training set ended up in?
Thanks