Closed · jasonrute · 1 year ago
Actually, I decided to just do things right and fix up `_hidden_state_sequences` to use the original graph instead of the stacked graph. Then we can avoid stacking graphs just to get out one tensor. We can always have the batch dimension first, avoiding the need for the ragged transpose function. Can you look this over again, @mirefek, now that the code is a bit different?
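To illustrate the batch-dimension-first layout, here is a minimal sketch (not the actual `_hidden_state_sequences` code); `node_embeddings`, `context_node_ids`, and `row_lengths` are hypothetical names standing in for the flat GNN output and the per-proof-state context bookkeeping:

```python
import tensorflow as tf

hdim = 8
# Flat [total_nodes, hdim] output of the GNN over the original (unstacked) graphs.
node_embeddings = tf.random.normal([10, hdim])
# Hypothetical bookkeeping: which nodes form each local context, and how many
# context nodes each of the 2 proof states has.
context_node_ids = tf.constant([0, 2, 3, 5, 6, 7, 9])
row_lengths = tf.constant([3, 4])

context_embeddings = tf.gather(node_embeddings, context_node_ids)

# Build the ragged tensor with the batch dimension first, shape [batch, None(cxt), hdim],
# so no ragged transpose is needed downstream.
hidden_state_sequences = tf.RaggedTensor.from_row_lengths(context_embeddings, row_lengths)
print(hidden_state_sequences.shape)  # (2, None, 8)
```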
This PR significantly speeds up the neural network part of inference. (It still doesn't address beam search, which is coming in another PR.) The idea is a continuation of the previous PR, as follows. After computing the graph embeddings, we choose `k` tactics and compute arguments for those tactics. However, the available local and global context is independent of the chosen tactics. So this allows for three optimizations:

- The keys (the embeddings of the local and global context) have ragged shape `[batch, None(cxt), hdim]`, where `batch` is the original batch size (1 in our case).
- All `k` tactics are combined under the same batch element. The queries then have ragged shape `[batch, None(tactic-args), hdim]`, where `batch` is just the original batch size (in our case 1).

These three optimizations, when combined, significantly speed up the model. Previously, most of the time was spent multiplying queries and keys to get the logits (or transforming tensors to get into and out of the right shape to multiply). This eliminates all that. Moreover, surprisingly to me, by reducing the shape of the key tensor we also significantly speed up the time spent in the einsum operator itself (even though we still multiply the same number of query-key pairs together).
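To make the shapes concrete, here is a minimal sketch of the query/key logit computation with these ragged shapes (not the exact code in this PR); `keys`, `queries`, and the padding via `.to_tensor()` are illustrative assumptions, with batch size 1 as in our case:

```python
import tensorflow as tf

batch, hdim = 1, 8

# Keys: one set of context embeddings per proof state, shape [batch, None(cxt), hdim].
keys = tf.RaggedTensor.from_row_lengths(tf.random.normal([5, hdim]), [5])
# Queries: the argument slots of all k chosen tactics under the same batch element,
# shape [batch, None(tactic-args), hdim].
queries = tf.RaggedTensor.from_row_lengths(tf.random.normal([7, hdim]), [7])

# One batched einsum over the (padded) tensors produces all argument logits at once,
# instead of tiling the keys once per tactic and reshaping back and forth.
logits = tf.einsum('bqh,bkh->bqk', queries.to_tensor(), keys.to_tensor())
print(logits.shape)  # (1, 7, 5): one logit per (argument slot, context element) pair
```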