Speed up inference with two simple changes:
Don't unit-normalize the keys (definition embeddings) during inference. The embeddings are already normalized, so the extra normalization is expensive and redundant.
Gather the keys before duplicating the graph for each tactic. This greatly speeds up the gather operation, at the cost of later tiling the keys into the correct shape. (The tiling can be sped up by being more careful with the multiplication; I'll do that in a future PR.)
(This isn't the bottleneck yet; beam search is. I'm still working on that. It is going well, but I want to clean it up and test it more, whereas this is an easy fix.)
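For reviewers, here is a minimal NumPy sketch of the idea behind both changes. It is not the code in this PR: the names `table`, `key_ids`, `num_tactics`, `gather_then_tile`, and `tile_then_gather` are hypothetical, and the real model works on its own graph tensors rather than plain arrays. The sketch only shows that gathering once and then tiling gives the same result as gathering after duplication, and that re-normalizing unit-norm embeddings changes nothing.

```python
import numpy as np

def tile_then_gather(table, key_ids, num_tactics):
    # Old order (assumed): duplicate the indices per tactic, then gather.
    # The gather touches num_tactics * num_keys rows.
    tiled_ids = np.tile(key_ids[None, :], (num_tactics, 1))
    return table[tiled_ids]

def gather_then_tile(table, key_ids, num_tactics):
    # New order: gather num_keys rows once, then tile the result
    # to shape (num_tactics, num_keys, dim).
    keys = table[key_ids]
    return np.tile(keys[None, :, :], (num_tactics, 1, 1))

# Hypothetical embedding table whose rows are already unit-normalized.
rng = np.random.default_rng(0)
table = rng.normal(size=(10, 4))
table /= np.linalg.norm(table, axis=-1, keepdims=True)

key_ids = np.array([3, 1, 7])
num_tactics = 5

old_keys = tile_then_gather(table, key_ids, num_tactics)
new_keys = gather_then_tile(table, key_ids, num_tactics)

# Normalizing again is a no-op up to floating-point error.
renormalized = table / np.linalg.norm(table, axis=-1, keepdims=True)
```

Both functions return an array of shape `(num_tactics, num_keys, dim)`; the difference is only in how many rows the gather has to fetch.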