lopuhin / transformer-lm

Transformer language model (GPT-2) with sentencepiece tokenizer

Plans to add gradient checkpointing? #11

Closed gooofy closed 5 years ago

gooofy commented 5 years ago

Do you have plans to add gradient checkpointing to enable training of larger models on consumer GPUs?

nshepperd seems to have a working implementation for tensorflow:

https://github.com/nshepperd/gpt-2/commit/47df6da611716b4826e3397cd68d711c6951c8e5

I could look into applying this to the GPT-2 tensorflow code here, but it might be even nicer to have this feature on the pytorch side of things. I fear that is a bit beyond what I could do, though :o)
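For reference, a minimal sketch of what gradient checkpointing could look like on the pytorch side, using `torch.utils.checkpoint`. The `Block` and `CheckpointedStack` modules below are hypothetical stand-ins for illustration, not the actual transformer-lm layers or the implementation that ended up in the pull request.

```python
# Sketch: gradient checkpointing over a stack of transformer-style blocks.
# Activations inside each checkpointed block are discarded during forward
# and recomputed during backward, trading compute for memory.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Placeholder block: a residual feed-forward layer, for illustration only."""
    def __init__(self, d_model: int):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        return x + self.ff(x)


class CheckpointedStack(nn.Module):
    """Runs each block under checkpoint() during training."""
    def __init__(self, n_layers: int, d_model: int):
        super().__init__()
        self.blocks = nn.ModuleList(Block(d_model) for _ in range(n_layers))

    def forward(self, x):
        for block in self.blocks:
            if self.training:
                # use_reentrant=False is the recommended mode in recent PyTorch;
                # older versions only support the default reentrant implementation.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


if __name__ == '__main__':
    model = CheckpointedStack(n_layers=4, d_model=64)
    x = torch.randn(2, 16, 64, requires_grad=True)
    model.train()
    model(x).sum().backward()  # gradients flow despite recomputation
```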

gooofy commented 5 years ago

I have managed to implement gradient checkpointing now; I will send you a pull request.

lopuhin commented 5 years ago

Done in #12, thanks to @gooofy 👍