lmartak / distill-nn-tree

Distillation of Neural Network Into a Soft Decision Tree
https://vgg.fiit.stuba.sk/people/martak/distill-nn-tree
MIT License

General Question #1

Closed benediktwagner closed 6 years ago

benediktwagner commented 6 years ago

Thanks for sharing! How exactly do you label the paths, and especially the leaves? Are you storing the probabilities to find the likeliest label at each node, and checking which leaves the data points from your training set ended up in?

Thanks

lmartak commented 6 years ago

If I understand your question correctly, by labels you are referring to numerical labels in the visualization of the decision path in a trained decision tree, right?

Then, the answer can be split to 3 parts:

  1. The numbers above inner nodes (shown as kernels) of the decision tree are calculated bottom-up. For a given node, a set of unique numbers (MNIST classes) is built from the labels of all leaves in its sub-tree, where each leaf label is the argmax of the distribution learned by that leaf (bigot). Sub-tree membership is checked through the common prefix in the node title: if k.startswith(key). All of this logic is here: https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L76-L79
  2. The numbers below leaves are simply the argmaxes of their learned (bigot) distributions, denoting the class predicted by the given leaf: https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L90
  3. The numbers below inner nodes (shown only for nodes on the maximum probability path) appear only if an input image is provided. In that case:
     3.1 The correlation of input_img with the current kernel is calculated and aggregated into a logit (the bias is added as well, to conform with the model's inference). The logit determines the decision at the current node according to the decision rule described in the paper (the sigmoid activation is omitted here, but since sigmoid(0) = 0.5, the decision boundary is at logit 0). The logit quantifies the correlation of the input with the kernel of the current node, so we plot it below the kernel as an indication of how strong the decision was (the further from 0, the stronger): https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L66-L73
     3.2 Nodes on the maximum probability path are collected here as well: https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L69
     3.3 The maximum probability path determines the color of the arrows (green along this path, red otherwise): https://github.com/lmartak/distill-nn-tree/blob/7ea4632eb497808b1a70623c48e4aceffb15fcac/models/utils.py#L99-L101
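For readers who want the three parts above in one place, here is a minimal sketch of the labelling logic. It is not the code from models/utils.py. It assumes that nodes are keyed by binary path strings ("" for the root, "0" and "1" for its children), that each inner node stores a kernel and a bias, and that a positive logit means "go right". All function names and the toy values are hypothetical.

```python
import numpy as np


def leaf_labels(leaf_dists):
    """Label each leaf with the argmax of its learned distribution."""
    return {key: int(np.argmax(dist)) for key, dist in leaf_dists.items()}


def inner_node_labels(inner_keys, leaf_label_map):
    """Collect, per inner node, the unique labels of the leaves in its sub-tree.

    A leaf is in the sub-tree of a node when the leaf key starts with the node key.
    """
    return {
        key: sorted({lab for k, lab in leaf_label_map.items() if k.startswith(key)})
        for key in inner_keys
    }


def max_probability_path(x, inner_nodes, depth):
    """Follow the sign of each logit from the root down to a leaf.

    Returns the visited (node_key, logit) pairs and the key of the reached leaf.
    """
    key, visited = "", []
    for _ in range(depth):
        w, b = inner_nodes[key]
        logit = float(np.dot(w, x) + b)  # correlation with the kernel plus bias
        visited.append((key, logit))
        key += "1" if logit > 0 else "0"  # assumed convention: positive logit goes right
    return visited, key


# Toy tree of depth 2 with 3 classes (made-up numbers, not trained values).
leaf_dists = {
    "00": [0.7, 0.2, 0.1],
    "01": [0.1, 0.8, 0.1],
    "10": [0.2, 0.1, 0.7],
    "11": [0.6, 0.3, 0.1],
}
inner_nodes = {
    "": (np.array([1.0, -1.0]), 0.0),
    "0": (np.array([0.0, 1.0]), -0.5),
    "1": (np.array([1.0, 0.0]), -1.0),
}

labels_below_leaves = leaf_labels(leaf_dists)
labels_above_inner = inner_node_labels(inner_nodes.keys(), labels_below_leaves)
visited, reached_leaf = max_probability_path(np.array([2.0, 1.0]), inner_nodes, depth=2)
```

In this toy setup the logits collected in visited are the numbers that would be drawn below the kernels, and the keys in visited are the nodes whose outgoing arrows on the chosen side would be green.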

@benbetze I hope that this answers your intended question. Please close this issue if it does, otherwise let me know but take care to formulate it properly.

benediktwagner commented 6 years ago

First of all: thank you for providing such a detailed answer about your implementation. I really appreciate it.

I was indeed talking about the numerical labels in the visualisation, and more specifically the bigot distribution (including the handling of sub-trees), which is clearer to me now. However, one last point of confusion remains: I have trouble interpreting the bigot distribution itself. I understand this is a learned probability distribution over all classes. Do we tune these distributions as well as the kernels on the inner nodes? Or are these distributions our "observations" from the training set, obtained by storing where the input training images ended up? (I struggle to connect Equation 2 and Figure 1 in the paper.)

Thanks so much again for your help and patience.

lmartak commented 6 years ago

> I understand this is a learned probability distribution over all classes.

Correct.

> Do we tune these distributions as well as the kernels on the inner nodes?

Yes we do, but not directly. From Equation 2: we tune the parameters phi of those distributions and apply a softmax activation to these parameters to get the actual distributions, which are used in 2 ways:

In the inference illustration below you can see how the input image is used to determine the maximum probability path (green arrows), which leads to a single leaf / bigot / distribution that corresponds to the model's prediction. The distribution parameters phi themselves are stored as parameters of the model and act as constants during inference (the calculation of the model's prediction).
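A small sketch of the relationship between phi and the leaf distribution may help here. It is an illustration only, not the repository's code: phi_leaf holds made-up values for a single leaf, and softmax is a hypothetical helper written out in NumPy.

```python
import numpy as np


def softmax(phi):
    """Turn a leaf's free parameters phi into a probability distribution."""
    phi = np.asarray(phi, dtype=float)
    z = np.exp(phi - phi.max())  # subtracting the max keeps exp() numerically stable
    return z / z.sum()


# Hypothetical parameters of one leaf over 3 classes. During training these
# values are what gets updated; during inference they are fixed constants.
phi_leaf = np.array([2.0, 0.5, -1.0])

q_leaf = softmax(phi_leaf)                # the leaf's (bigot) distribution
predicted_class = int(np.argmax(q_leaf))  # class shown below the leaf
```

Because softmax preserves the ordering of its inputs, the argmax of the distribution equals the argmax of phi, so the label of a leaf depends only on which of its parameters is largest.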

benediktwagner commented 6 years ago

Thanks so much! You have been really helpful.