Open · CliffordWilmot opened this issue 1 year ago
Nice, thank you for the detailed description! I don't think we should aim for option 3; it would make lots of things more complicated (like maintainability and installation). I agree with you that it's fine for now, and we can just wait for sparse masks to improve in torch.
More possible solutions:

- Instead of building a dense `y_true_logits`, we just directly use the indices of the contributing `(i, j)` pairs, by building a length-`B` nested tensor where the `i`th entry is a tensor of the contributing `j`s (a rough sketch of the idea is below). The nested tensor API is also in the prototype stage, and right now there isn't enough support for the sort of indexing we need to do.

Note: I've been running experiments with gradient-based optimization on the new model code. The preliminary results from this are moderately encouraging, so this issue may become redundant. However, more training runs are needed to see how well the gradient-based optimization fares overall.
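As a rough illustration of the nested-tensor idea (the shapes and values here are made up, not from the repo):

```python
import torch

B, C = 4, 10  # hypothetical minibatch size and number of classes
y_true = torch.nn.functional.one_hot(torch.tensor([2, 7, 2, 5]), num_classes=C)  # (B, C) one-hot

# For each datapoint i, keep only the class indices j with y_true[i, j] == 1.
contributing_j = [row.nonzero(as_tuple=True)[0] for row in y_true]

# Prototype nested tensor API: a length-B "tensor of tensors" with ragged entries.
nt = torch.nested.nested_tensor(contributing_j)
print(nt.unbind()[0])  # tensor([2]) -- the contributing class index for datapoint 0
```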
Problem
The leaf optimization implements roughly the following update for a given leaf (with a few extra things like exponential smoothing and softmaxes).
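A schematic sketch (not the actual code; `per_example_term` is a stand-in for the per-datapoint, per-class factor that the real update derives from the model's predictions and per-leaf quantities):

```python
import torch

B, C = 8, 3  # hypothetical minibatch size and number of classes
y_true = torch.nn.functional.one_hot(torch.tensor([0, 1, 2, 2, 1, 0, 2, 1]), num_classes=C).double()  # (B, C)

# Stand-in for the per-datapoint, per-class factor used by the real update (exact form omitted).
per_example_term = torch.rand(B, C, dtype=torch.double)

# Per-class contribution for one leaf: only the entries with y_true[i, j] == 1 contribute to the sum.
contribution = (y_true * per_example_term).sum(dim=0)  # (C,)
new_leaf_dist = contribution / contribution.sum()      # smoothing/softmax details of the real code omitted
```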
Here `i` indexes the `B` datapoints in the minibatch and `j` indexes the `C` classes. In log probabilities (which is what is currently in the code), the `sum` over `i` becomes a `logsumexp` over `i`, with `y_true_logits` (the elementwise log of `y_true`) appearing inside it. Note that `y_true` is a one-hot encoding, and so for most `j` there will only be a small number of `i`s for which `y_true[i, j] == 1` (i.e. non-zero) and hence anything is contributed to the `sum` over `i`. Similarly, for the log probabilities, most values of `y_true_logits[i, j]` will be non-contributing `-Inf`s, and only a few `0`s will contribute to the `logsumexp` over `i`.

The current code does not take advantage of this structure: `y_true_logits` is a dense float64 `(B, C)` tensor where most of the entries are `-Inf`, and the overall result is computed by broadcasting everything up to `(B, C)`. This is not efficient; `y_true_logits` has only `B` entries that are `0` (the rest are `-Inf`), and `y_predicted_logits` is already calculated and shared between all the leaves (we only need to access the `[i, j]` where `y_true_logits == 0`), so we should only need `O(B + C)` calculations. These redundant calculations and tensors full of `-Inf`s also make the code a bit confusing.
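To make the access pattern concrete, here is a small self-contained sketch; `other_logits` is a stand-in for whatever `(B, C)` term the real update combines with `y_true_logits`, and the point is only which entries actually matter:

```python
import torch

B, C = 8, 3  # hypothetical sizes
targets = torch.tensor([0, 1, 2, 2, 1, 0, 2, 1])
y_true = torch.nn.functional.one_hot(targets, num_classes=C).double()  # (B, C) one-hot

# Stand-in for the (B, C) term the real update adds to y_true_logits
# (built from y_predicted_logits and per-leaf quantities; exact form omitted).
other_logits = torch.randn(B, C, dtype=torch.double)

# Current pattern: a dense y_true_logits that is 0 where y_true == 1 and -Inf elsewhere,
# broadcast against everything and reduced with a logsumexp over i.
y_true_logits = torch.log(y_true)
dense = torch.logsumexp(y_true_logits + other_logits, dim=0)  # touches all B * C entries

# The same per-class result only needs the B entries where y_true_logits == 0.
# (The Python loop is only to show which entries matter, not an efficient implementation.)
selected = other_logits[torch.arange(B), targets]  # (B,)
indexed = torch.full((C,), float("-inf"), dtype=torch.double)
for j in range(C):
    vals = selected[targets == j]
    if vals.numel() > 0:
        indexed[j] = torch.logsumexp(vals, dim=0)

assert torch.allclose(dense, indexed)
```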
Possible Solutions
There are a few ways we could write the efficient/cleaner algorithm.
One option is to use sparse or masked tensors. Unfortunately there is currently no masked/sparse `logsumexp` implementation (and so we have to use `exp` and `sum` and risk floating point inaccuracies), and due to the lower efficiency of the masked dense tensors it's roughly 5x slower than the current non-masked dense version. An example of what we're trying to achieve is in https://github.com/appliedAI-Initiative/ProtoTree/pull/14. We could of course try adding the desired functionality to PyTorch, but this seems like it would require a huge amount of work. Sparse and masked tensors are still in beta and prototype status in PyTorch, respectively, so it's possible that this option will become much easier in the future as the PyTorch developers add more functionality.

Overall, given that the current code is still moderately readable and that the leaf optimization is only 1-2% of training time, it doesn't seem worthwhile to spend time trying to optimize this code now, with all of the issues these approaches have.
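For illustration, this is roughly the kind of masked `exp`/`sum` workaround meant here (names are made up; the real attempt is in the linked PR):

```python
import torch

B, C = 8, 3  # hypothetical sizes
targets = torch.tensor([0, 1, 2, 2, 1, 0, 2, 1])
y_predicted_logits = torch.log_softmax(torch.randn(B, C, dtype=torch.double), dim=-1)

mask = torch.nn.functional.one_hot(targets, num_classes=C).bool()  # True where y_true == 1

# Without a masked logsumexp we fall back to exp + multiply-by-mask + sum + log,
# which loses the numerical stability that logsumexp provides.
contribution = torch.log((y_predicted_logits.exp() * mask).sum(dim=0))  # (C,)

# Reference: the numerically stable dense version from the Problem section above.
reference = torch.logsumexp(torch.log(mask.double()) + y_predicted_logits, dim=0)
assert torch.allclose(contribution, reference)
```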
Out of scope
Note: Any further optimization or vectorization between leaves is not considered here, i.e. each leaf has its own separate tensor.