Open · thadunge2 opened 4 years ago
Also, sorry to say, but I'm still getting incredibly slow generation times on the pytorch branch despite the new code.
Damn, it looks like using past hidden states doesn't help that much. I'm not sure why. I don't know how else to speed it up except distillation.
Anon lied to me! What on Earth is slowing it down so much? It should be just as fast as TF.
A test one could do is stop using past hidden states in TensorFlow and see if it ends up as slow as the current PyTorch implementation. If so, it would strongly suggest that past state usage is significant and possibly not implemented correctly on our side. If not, it would at least tell us to look elsewhere for improvements.
I opened a branch for running AIDungeon on PyTorch: https://github.com/thadunge2/AIDungeon/tree/pytorch-model/generator/gpt2
It's plug-and-play: just run play.py and it should install everything it needs (unless you're on Windows, in which case it will tell you what to do). However, it's unusably slow until we rework the generate method to use past hidden states. This is beyond my ken, so if one of you wants to step up and do it, be my guest.
Here's the generate function we use: https://github.com/huggingface/transformers/blob/ce50305e5b8c8748b81b0c8f5539a337b6a995b9/src/transformers/modeling_utils.py#L699
```python
outputs = self(**model_inputs)
```

needs to take a `past` parameter and change like so:

```python
outputs, pasts = self(**model_inputs)
```
I don't have the time or knowledge to make it do this, since once you start passing `past` you only feed in the newest token, which turns the 3D input into a 2D one and fucks everything up. So drop a PR on the pytorch-model branch fixing that and we can roll this feature out.
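For whoever picks this up, here's a rough sketch of the plumbing change to the generation loop. The stub model is a stand-in (its "cache" is just a token list and its "logits" are dummies, not real key/value tensors), but the loop body shows the two changes the linked `generate` needs: keep the returned `past`, and feed only the newest token afterwards.

```python
class StubModel:
    """Mimics the old transformers call convention: __call__ returns (logits, past)."""
    def __call__(self, input_ids, past=None):
        # Real models cache key/value tensors; here the cache is just the tokens seen.
        past = (past or []) + list(input_ids)
        logits = [len(past)]  # dummy "logits": the cache length
        return logits, past

def generate(model, prompt, steps):
    past = None
    input_ids = list(prompt)
    generated = list(prompt)
    for _ in range(steps):
        # Old loop body:  outputs = self(**model_inputs)
        # New loop body:  outputs, past = self(**model_inputs)  -- keep the cache
        logits, past = model(input_ids, past=past)
        next_token = logits[-1]   # greedy pick from the dummy logits
        generated.append(next_token)
        input_ids = [next_token]  # feed ONLY the new token once past is cached
    return generated

print(generate(StubModel(), [1, 2, 3], 4))
```

The shape headache in the real model comes from that last line: after the first step `input_ids` has sequence length 1, so anything downstream that assumed the full-length sequence dimension has to handle the shrunken input.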