aai-institute / ProtoTree

Fork of ProtoTrees: Neural Prototype Trees for Interpretable Fine-grained Image Recognition, published at CVPR 2021

Exploit sparsity and differently shaped tensors in derivative-free leaf optimization #13

Open CliffordWilmot opened 1 year ago

CliffordWilmot commented 1 year ago

Problem

The leaf optimization implements the following pseudocode for a given leaf (the actual code adds a few extra steps, such as exponential smoothing and softmaxes):

leaf_probabilities[j] = sum[over i] ( p_arrival[i] * leaf_probabilities[j] *  y_true[i, j] / y_predicted[i, j] ),

where i indexes the B datapoints in the minibatch and j indexes the C classes. In log space (which is what the current code uses), this becomes

leaf_log_probabilities[j] = logsumexp[over i] ( log_p_arrival[i] + leaf_logits[j] +  y_true_logits[i, j] - y_predicted_logits[i, j] ).
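For reference, here is a minimal pure-Python sketch of both forms of the update for a single leaf. It assumes plain nested lists in place of tensors and leaves out the exponential smoothing and softmaxes; the function names and the toy minibatch are hypothetical and not taken from the ProtoTree code. It only shows that the two formulas agree and that both visit every (i, j) pair.

```python
import math
from typing import List, Sequence


def leaf_update_probabilities(
    p_arrival: List[float],           # (B,) probability that datapoint i reaches this leaf
    leaf_probabilities: List[float],  # (C,) current class distribution of the leaf
    y_true: List[List[float]],        # (B, C) one-hot targets
    y_predicted: List[List[float]],   # (B, C) class distributions predicted by the whole tree
) -> List[float]:
    """Probability-space update with no smoothing and no normalization."""
    new_probabilities = [0.0] * len(leaf_probabilities)
    for i, arrival in enumerate(p_arrival):
        for j, leaf_p in enumerate(leaf_probabilities):
            new_probabilities[j] += arrival * leaf_p * y_true[i][j] / y_predicted[i][j]
    return new_probabilities


def logsumexp(values: Sequence[float]) -> float:
    """Stable log(sum(exp(v))); returns -inf if every value is -inf."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def leaf_update_log(
    log_p_arrival: List[float],             # (B,)
    leaf_logits: List[float],               # (C,) log of the leaf distribution
    y_true_logits: List[List[float]],       # (B, C) 0.0 at the true class, -inf elsewhere
    y_predicted_logits: List[List[float]],  # (B, C)
) -> List[float]:
    """Log-space update that, like the dense code, touches every (i, j) pair."""
    return [
        logsumexp([
            log_p_arrival[i] + leaf_logits[j] + y_true_logits[i][j] - y_predicted_logits[i][j]
            for i in range(len(log_p_arrival))
        ])
        for j in range(len(leaf_logits))
    ]


# Hypothetical minibatch: B = 2 datapoints, C = 3 classes.
p_arrival = [0.5, 0.25]
leaf_probabilities = [0.2, 0.3, 0.5]
y_true = [[0, 1, 0], [1, 0, 0]]
y_predicted = [[0.2, 0.6, 0.2], [0.5, 0.25, 0.25]]

prob_result = leaf_update_probabilities(p_arrival, leaf_probabilities, y_true, y_predicted)
log_result = leaf_update_log(
    [math.log(p) for p in p_arrival],
    [math.log(p) for p in leaf_probabilities],
    [[0.0 if v == 1 else -math.inf for v in row] for row in y_true],
    [[math.log(p) for p in row] for row in y_predicted],
)
print(prob_result)  # [0.1, 0.25, 0.0]
print(all(math.isclose(math.exp(a), b) for a, b in zip(log_result, prob_result)))  # True
```

Note how class 2, which no datapoint in this minibatch has, still costs B evaluations in both versions.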

Note that y_true is a one-hot encoding, so for each j only the few datapoints i with y_true[i, j] == 1 contribute anything to sum[over i]; every other term is 0. Similarly, in log space most entries of y_true_logits are -Inf and contribute nothing, and only the few 0 entries contribute to logsumexp[over i].
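A small illustration of that structure, assuming integer class labels for a hypothetical minibatch with B = 4 and C = 3 (the helper name is made up for this sketch): the log one-hot encoding has exactly one 0.0 per row and -inf everywhere else, so for each class j only the rows whose label is j can contribute.

```python
import math
from typing import Dict, List


def one_hot_log(labels: List[int], num_classes: int) -> List[List[float]]:
    """Log of a one-hot encoding: 0.0 at each row's true class, -inf everywhere else."""
    return [[0.0 if j == label else -math.inf for j in range(num_classes)] for label in labels]


labels = [1, 0, 1, 2]
y_true_logits = one_hot_log(labels, num_classes=3)

num_zeros = sum(v == 0.0 for row in y_true_logits for v in row)
num_neg_inf = sum(v == -math.inf for row in y_true_logits for v in row)
print(num_zeros, num_neg_inf)  # 4 8  (B zeros and B * (C - 1) -infs)

# For each class j, the only datapoints i that contribute to logsumexp[over i]:
rows_per_class: Dict[int, List[int]] = {
    j: [i for i, row in enumerate(y_true_logits) if row[j] == 0.0] for j in range(3)
}
print(rows_per_class)  # {0: [1], 1: [0, 2], 2: [3]}
```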

The current code does not take advantage of this structure: y_true_logits is a dense float64 (B, C) tensor in which most entries are -Inf, and the result is computed by broadcasting everything up to shape (B, C). This is inefficient. y_true_logits has only B zero entries (one per datapoint), and y_predicted_logits is already computed and shared between all the leaves (we only need to read the entries [i, j] where y_true_logits[i, j] == 0), so O(B + C) operations per leaf should suffice. These redundant calculations and -Inf-filled tensors also make the code harder to follow.
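To make the O(B + C) claim concrete, here is a minimal pure-Python sketch for a single leaf, assuming integer labels are available in place of y_true_logits. The function name and signature are hypothetical. It reads only y_predicted_logits[i][labels[i]] and then performs a grouped logsumexp per class; being plain Python, it illustrates the operation count rather than being a fast implementation.

```python
import math
from typing import List


def leaf_update_sparse_log(
    log_p_arrival: List[float],             # (B,)
    leaf_logits: List[float],               # (C,)
    labels: List[int],                      # (B,) integer labels instead of a one-hot tensor
    y_predicted_logits: List[List[float]],  # (B, C), shared by all leaves; only [i, labels[i]] is read
) -> List[float]:
    num_classes = len(leaf_logits)
    # One term per datapoint, for its single contributing class j = labels[i]: O(B) reads.
    terms = [lp - y_predicted_logits[i][labels[i]] for i, lp in enumerate(log_p_arrival)]

    # Grouped logsumexp over i for each class: first the per-class maxima (O(B) + O(C)) ...
    class_max = [-math.inf] * num_classes
    for term, label in zip(terms, labels):
        class_max[label] = max(class_max[label], term)
    # ... then the max-shifted sums of exponentials (O(B)).
    class_sum = [0.0] * num_classes
    for term, label in zip(terms, labels):
        if term > -math.inf:
            class_sum[label] += math.exp(term - class_max[label])

    # leaf_logits[j] does not depend on i, so it is added once per class, outside the logsumexp.
    # Classes with no (finite) contributing datapoint get -inf.
    return [
        leaf_logits[j] + class_max[j] + math.log(class_sum[j]) if class_sum[j] > 0.0 else -math.inf
        for j in range(num_classes)
    ]


# Hypothetical minibatch: B = 2, C = 3, labels 1 and 0.
result = leaf_update_sparse_log(
    [math.log(0.5), math.log(0.25)],
    [math.log(p) for p in [0.2, 0.3, 0.5]],
    [1, 0],
    [[math.log(p) for p in row] for row in [[0.2, 0.6, 0.2], [0.5, 0.25, 0.25]]],
)
print([math.exp(v) for v in result])  # approximately [0.1, 0.25, 0.0]
```

The design relies on logsumexp over i of (leaf_logits[j] + a_i) being equal to leaf_logits[j] + logsumexp over i of a_i, which is why the leaf term never needs to be broadcast to shape (B, C).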

Possible Solutions

There are a few ways we could write the efficient/cleaner algorithm.

  1. With Python for loops: This would be slow (pure Python, with no GPU acceleration), which defeats the point of trying to optimize this code.
  2. With PyTorch (or other library) sparse masks: Unfortunately, PyTorch (and other libraries) don't seem to support sparse tensors or dense masks well enough, let alone sparse masks; most of the calculations we want to do are very convoluted or infeasible right now. Even a PyTorch implementation with dense masks runs into problems: there is no logsumexp implementation for masked tensors, so we have to use exp and sum and risk floating-point inaccuracies, and because masked dense tensors are less efficient, it is roughly 5x slower than the current unmasked dense version. An example of what we're trying to achieve is in https://github.com/appliedAI-Initiative/ProtoTree/pull/14. We could of course try adding the desired functionality to PyTorch, but this would likely require a huge amount of work. Sparse and masked tensors are still in beta and prototype status in PyTorch, respectively, so this option may become much easier in the future as the PyTorch developers add more functionality.
  3. With custom C++/Rust/CUDA code for our specific algorithm: This seems feasible, but could still require a fair bit of work. We'd need to come up with a fairly efficient implementation for both CPU and GPU devices.
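The floating-point concern with falling back to exp and sum can be seen in a tiny example. The sketch below (pure Python, hypothetical helper names, inputs chosen only to trigger underflow) compares computing log(sum(exp(v))) directly against the usual max-shifted logsumexp.

```python
import math
from typing import Sequence


def log_of_sum_of_exp_naive(values: Sequence[float]) -> float:
    """exp, then sum, then log: the fallback when no logsumexp is available."""
    total = sum(math.exp(v) for v in values)
    return math.log(total) if total > 0.0 else -math.inf


def logsumexp_stable(values: Sequence[float]) -> float:
    """Subtract the maximum before exponentiating so the largest term becomes exp(0) = 1."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


values = [-800.0, -801.0]
print(log_of_sum_of_exp_naive(values))  # -inf: both exponentials underflow to 0.0
print(logsumexp_stable(values))         # about -799.687, i.e. -800 + log(1 + e^-1)
```

Python's math.exp returns 0.0 on underflow rather than raising, so the naive version silently loses the whole class instead of failing loudly.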

Overall, given that the current code is still moderately readable, that the leaf optimization takes only 1-2% of training time, and that all of these approaches have issues, it doesn't seem worthwhile to spend time optimizing this code now.

Out of scope

Note: Any further optimization or vectorization between leaves is not considered here, i.e. each leaf has its own separate tensor.

MischaPanch commented 1 year ago

Nice, thank you for the detailed description! I don't think we should aim for option 3; it would make lots of things more complicated (like maintainability and installation). I agree with you that it's fine for now, and we can just wait for sparse masks to improve in torch.

CliffordWilmot commented 1 year ago

More possible solutions:

  1. Use JIT to accelerate Python for loops: An initial attempt with TorchScript hasn't produced enough of a speedup to counteract Python's slowness. Also, the code with this approach is a bit clunkier than envisaged. It might be helpful to try other Python JIT options such as Numba.
  2. Use PyTorch nested tensors: With this approach we don't build a one-hot representation of y_true_logits at all; instead, we directly use the indices of the contributing (i, j) pairs by building a length-B nested tensor whose ith entry is a tensor of the contributing js. The nested tensor API is also still in the prototype stage, and right now it doesn't support the sort of indexing we need.
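A sketch of the ragged layout described for the nested-tensor approach, using plain Python nested lists as a stand-in for a nested tensor (no torch.nested API is assumed, and the helper names and toy values are hypothetical). Entry i lists the contributing classes j, and y_predicted_logits[i][j] is read only for those pairs, which is the kind of ragged indexing this approach needs.

```python
from typing import List


def contributing_classes(y_true: List[List[int]]) -> List[List[int]]:
    """Ragged, length-B structure: entry i holds every class j with y_true[i][j] == 1."""
    return [[j for j, v in enumerate(row) if v == 1] for row in y_true]


def gather_predicted(
    y_predicted_logits: List[List[float]], contributing_js: List[List[int]]
) -> List[List[float]]:
    """Read y_predicted_logits[i][j] only for the contributing (i, j) pairs."""
    return [[y_predicted_logits[i][j] for j in js] for i, js in enumerate(contributing_js)]


y_true = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]  # hypothetical one-hot minibatch, B = 3, C = 3
y_predicted_logits = [[-1.6, -0.5, -1.6], [-0.7, -1.4, -1.4], [-2.0, -2.0, -0.3]]

js = contributing_classes(y_true)
print(js)                                        # [[1], [0], [2]]
print(gather_predicted(y_predicted_logits, js))  # [[-0.5], [-0.7], [-0.3]]
```

With a one-hot y_true every entry has length 1; the ragged form would also accommodate rows with several contributing classes.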

CliffordWilmot commented 1 year ago

Note: I've been running experiments with gradient-based optimization on the new model code. The preliminary results are moderately encouraging, so this issue may become redundant. However, more training runs are needed to see how well the gradient-based optimization fares overall.