IBM / graph2tac

Graph-based neural tactic prediction models for Coq.
Apache License 2.0

Speedup inference model #146

Closed jasonrute closed 1 year ago

jasonrute commented 1 year ago

This PR significantly speeds up the neural network part of inference. (It does not yet address beam search, which is coming in a separate PR.) The idea continues the previous PR as follows. After computing the graph embeddings, we choose k tactics and compute arguments for those tactics. However, the available local and global context is independent of the chosen tactics. This allows for three optimizations:

These three optimizations, combined, significantly speed up the model. Previously, most of the time was spent multiplying queries and keys to get the logits (or reshaping tensors into and out of the right shape to multiply them). This eliminates all of that. Moreover, surprisingly to me, reducing the shape of the key tensor also significantly speeds up the einsum operator itself, even though we still multiply the same number of query-key pairs.
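A minimal numpy sketch of the key idea: since the local/global context does not depend on which of the k tactics was chosen, the argument keys can be computed once and shared, and the einsum can consume the small shared key tensor directly instead of a per-tactic copy. All names, shapes, and the linear key projection here are illustrative assumptions, not the actual graph2tac code.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_ctx, n_tactics, k, n_args = 8, 5, 50, 3, 2  # illustrative sizes

# Hypothetical tactic scoring: pick the k highest-scoring tactics,
# then compute arguments only for those k.
tactic_logits = rng.normal(size=n_tactics)
topk = np.argsort(tactic_logits)[-k:][::-1]

# Context embeddings (local + global) are independent of the chosen
# tactics, so their keys are computed once and shared across all k.
context_emb = rng.normal(size=(n_ctx, d))
W_key = rng.normal(size=(d, d))          # hypothetical key projection
shared_keys = context_emb @ W_key        # (n_ctx, d), computed once

# Per-tactic argument queries (hypothetical), shape (k, n_args, d).
queries = rng.normal(size=(k, n_args, d))

# Naive form: a per-tactic copy of the keys, shape (k, n_ctx, d).
tiled_keys = np.broadcast_to(shared_keys, (k, n_ctx, d))
logits_naive = np.einsum('kad,kcd->kac', queries, tiled_keys)

# Optimized form: identical logits from the small shared key tensor.
logits_shared = np.einsum('kad,cd->kac', queries, shared_keys)
assert np.allclose(logits_naive, logits_shared)
```

Both einsums multiply the same number of query-key pairs, but the second passes a key tensor with one fewer (tiled) dimension, which is the shape reduction the PR credits for part of the speedup.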

jasonrute commented 1 year ago

Actually, I decided to just do things right and fix up _hidden_state_sequences to use the original graph instead of the stacked graph. Then we can avoid stacking graphs just to get out one tensor, and we can always put the batch dimension first, avoiding the need for the ragged transpose function. Can you look this over again, @mirefek, now that the code is a bit different?
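A small numpy sketch of the batch-dimension-first point: when each proof state's hidden-state sequence keeps the batch axis first, uneven-length sequences pad directly into one dense (batch, max_len, dim) tensor plus a mask, with no ragged transpose step in between. The sequence contents and sizes below are made up for illustration.

```python
import numpy as np

# Hypothetical hidden-state sequences, one per proof state in the
# batch, with uneven lengths (3 and 5 steps, feature dim 4).
seqs = [np.arange(12, dtype=float).reshape(3, 4),
        np.ones((5, 4))]

# Batch dimension first: pad into a dense tensor and a validity mask.
batch, dim = len(seqs), seqs[0].shape[1]
max_len = max(s.shape[0] for s in seqs)
dense = np.zeros((batch, max_len, dim))
mask = np.zeros((batch, max_len), dtype=bool)
for i, s in enumerate(seqs):
    dense[i, : s.shape[0]] = s
    mask[i, : s.shape[0]] = True
```

Downstream ops can then run over `dense` with `mask` zeroing out the padding, rather than transposing a ragged length-first layout into batch-first form.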