thadunge2 / AIDungeon


Math nerd wanted for PyTorch #53

Open thadunge2 opened 4 years ago

thadunge2 commented 4 years ago

I opened a branch for running AIDungeon on PyTorch: https://github.com/thadunge2/AIDungeon/tree/pytorch-model/generator/gpt2

It's plug-and-play: just run play.py and it should install everything it needs (unless you're on Windows, in which case it will tell you what to do). However, it's unusably slow until we rework the generate method to reuse past hidden states. This is beyond my ken, so if one of you wants to step up and do it, be my guest.

Here's the generate function we use: https://github.com/huggingface/transformers/blob/ce50305e5b8c8748b81b0c8f5539a337b6a995b9/src/transformers/modeling_utils.py#L699

The call outputs = self(**model_inputs) needs to take a "past" parameter and change like so: outputs, pasts = self(**model_inputs). I don't have the time or knowledge to make it do this: once you cache the past you feed only the newest token instead of the whole sequence, which turns the 3D input matrix into a 2D one and fucks everything up downstream. So drop a PR on the pytorch-model branch fixing that and we can roll this feature out.
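To illustrate the shape of the rework being asked for, here's a toy sketch in pure Python. This is not the real transformers API; `TinyModel` and its cache are hypothetical stand-ins for GPT-2 and its per-layer key/value tensors. The point is the pattern: without a cache, every step re-encodes the whole growing sequence; with a cache, each step feeds only the newest token plus the stored past, which is exactly why the per-step input shrinks from a full sequence to a single token.

```python
# Toy sketch of incremental decoding with cached past states.
# TinyModel is hypothetical -- it stands in for a model that returns
# (logits, past) when given an optional "past" cache.

class TinyModel:
    def __call__(self, tokens, past=None):
        # "past" here is just the list of previously processed tokens,
        # a stand-in for the real per-layer key/value tensors.
        past = list(past) if past is not None else []
        past.extend(tokens)
        # Fake "logits": pretend the next token is len(past) % 10.
        logits = len(past) % 10
        return logits, past

def generate_no_cache(model, prompt, steps):
    tokens = list(prompt)
    for _ in range(steps):
        logits, _ = model(tokens)  # re-processes ALL tokens every step
        tokens.append(logits)
    return tokens

def generate_with_cache(model, prompt, steps):
    tokens = list(prompt)
    logits, past = model(tokens)  # process the full prompt once
    tokens.append(logits)
    for _ in range(steps - 1):
        # Feed only the newest token; the cache carries everything else.
        logits, past = model([tokens[-1]], past=past)
        tokens.append(logits)
    return tokens
```

Both loops produce identical output, but the cached one does constant model work per step instead of work proportional to the sequence length.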

thadunge2 commented 4 years ago

Also, sorry to say, but I'm still getting incredibly slow generation times on the pytorch branch despite the new code.

AccidentallyOnPurpose commented 4 years ago

Damn, it looks like using past hidden states doesn't help that much. I'm not sure why, either. I'm not sure how else to speed it up except distillation.

thadunge2 commented 4 years ago

Anon lied to me! What on Earth is slowing it down so much? It should be just as fast as TF.
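One way to find out what's eating the time: profile a single generation call with the stdlib profiler. A minimal sketch, where the profiled entry point would be whatever the pytorch branch actually calls (the `generate_text` name in the usage comment is hypothetical):

```python
import cProfile
import io
import pstats

def profile_call(fn, *args, **kwargs):
    """Run fn under cProfile and return (result, text report of hotspots)."""
    profiler = cProfile.Profile()
    profiler.enable()
    result = fn(*args, **kwargs)
    profiler.disable()
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf).sort_stats("cumulative")
    stats.print_stats(15)  # top 15 entries is usually enough to spot a hotspot
    return result, buf.getvalue()

# Usage (hypothetical entry point):
# result, report = profile_call(generate_text, prompt="You enter the dungeon")
# print(report)
```

If the report shows time dominated by the full forward pass on every step, the past-state path isn't actually being hit; if it's dominated by something like tensor copies or sampling, the problem is elsewhere.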

ShnitzelKiller commented 4 years ago

A test one could do is stop using past hidden states in tensorflow and see if it ends up as slow as the current pytorch implementation. If so, it would strongly suggest that past state usage is significant and possibly not implemented correctly on our side. If not, it would at least tell us to look elsewhere for improvements.
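As a sanity check of that intuition, here's a toy sketch of the expected gap. Everything here is hypothetical: instead of timing real models, it counts "token operations" for a fake model whose per-call cost scales with how many tokens it processes, which keeps the comparison deterministic:

```python
def ops_no_cache(prompt_len, steps):
    """No past states: each step re-processes the whole growing sequence,
    so total work is quadratic in the number of generated tokens."""
    ops = 0
    length = prompt_len
    for _ in range(steps):
        ops += length  # the model sees every token again this step
        length += 1
    return ops

def ops_with_cache(prompt_len, steps):
    """With past states: process the prompt once, then one token per step,
    so total work is linear."""
    return prompt_len + (steps - 1)

print(ops_no_cache(20, 60))    # → 2970
print(ops_with_cache(20, 60))  # → 79
```

If disabling past states in the TensorFlow version produces a slowdown of roughly this magnitude while the PyTorch branch with "caching" enabled still matches the no-cache numbers, that would point at the cache not actually being used on the PyTorch side.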