lopuhin / transformer-lm

Transformer language model (GPT-2) with sentencepiece tokenizer

implement gradient checkpointing #12

Closed: gooofy closed this 5 years ago

gooofy commented 5 years ago

Thanks for your feedback :) I wasn't aware that you have automated tests in place, very cool! I have moved my TensorFlow-related changes to a separate branch and will focus on PyTorch. I have also added the missing argument, so the tests should run cleanly.

However, please be aware that my changes seem to have introduced a new bug, which I am trying to hunt down right now:

save_every=4000, validate_every=4000
{
    "argv": "/home/bofh/projects/ai/torch/bin/gpt-2 de345-root data/encoded-de sp-model.model --n_embed=1024 --n_head=16 --n_layer=24 --batch_size=3 --gradient_checkpointing --save_every=4000",
    "batch_size": 3,
    "epochs": 10,
    "g_accum_gradients": 1,
    "hparams": {
        "gradient_checkpointing": true,
        "n_ctx": 1024,
        "n_embed": 1024,
        "n_head": 16,
        "n_hidden": 1024,
        "n_layer": 24,
        "n_vocab": 50000
    },
    "lr": 0.00025
}
Loading dataset from data/encoded-de
Train dataset has 1,196,206,177 tokens
Validation dataset has 12,355,824 tokens
Resuming from seen_tokens 89,094,144
epochs:   0%|          | 0/10 [00:00<?, ?it/s]
/home/bofh/projects/ai/torch/lib/python3.6/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
…on:   0%|          | 0/12066 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/bofh/projects/ai/torch/bin/gpt-2", line 11, in <module>
    load_entry_point('lm', 'console_scripts', 'gpt-2')()
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/main.py", line 324, in fire_main
    fire.Fire(only_allow_defined_args(main))
  File "/home/bofh/projects/ai/torch/lib/python3.6/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/bofh/projects/ai/torch/lib/python3.6/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/bofh/projects/ai/torch/lib/python3.6/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/fire_utils.py", line 30, in _return_wrapped
    return function_to_decorate(*args, **kwargs)
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/main.py", line 261, in main
    train()
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/main.py", line 210, in train
    validate()
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/main.py", line 223, in validate
    valid_loss=get_valid_loss())
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/main.py", line 235, in get_valid_loss
    logits = model(ctx)['logits']
  File "/home/bofh/projects/ai/torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 539, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/bofh/projects/ai/torch/transformer-lm/lm/model.py", line 43, in forward
    batch_size, n_ctx = x.shape
ValueError: not enough values to unpack (expected 2, got 1)

So it is probably a good idea to delay the merge until I have figured out what is going wrong there.
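The `UserWarning` in the log comes from the original (reentrant) implementation of `torch.utils.checkpoint`, which warns when none of the tensors passed to a checkpointed function require gradients, for example when the model runs under `torch.no_grad()` during validation. A common pattern is to checkpoint only when a backward pass can actually follow. The sketch below assumes a recent PyTorch (the `use_reentrant` argument does not exist in the 1.x releases shown in the log), and `Block`/`Stack` are hypothetical stand-ins, not the classes in the project's `lm/model.py`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical stand-in for one transformer layer (not the project's code)."""

    def __init__(self, n_embed: int):
        super().__init__()
        self.ln = nn.LayerNorm(n_embed)
        self.mlp = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.GELU(),
            nn.Linear(4 * n_embed, n_embed),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.ln(x))


class Stack(nn.Module):
    """Hypothetical layer stack with optional per-block gradient checkpointing."""

    def __init__(self, n_layer: int, n_embed: int, gradient_checkpointing: bool):
        super().__init__()
        self.blocks = nn.ModuleList(Block(n_embed) for _ in range(n_layer))
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Checkpoint only in training mode with autograd enabled; in eval mode or
        # under torch.no_grad() there is no backward pass to save memory for.
        use_ckpt = (
            self.gradient_checkpointing and self.training and torch.is_grad_enabled()
        )
        for block in self.blocks:
            if use_ckpt:
                # Intermediate activations inside `block` are freed after the
                # forward pass and recomputed during backward.
                h = checkpoint(block, h, use_reentrant=False)
            else:
                h = block(h)
        return h
```

With this guard, validation under `torch.no_grad()` simply runs the blocks directly, while training results (outputs and gradients) are the same as without checkpointing.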

gooofy commented 5 years ago

OK, it looks like this issue is unrelated to my gradient checkpointing changes: it seems that _valid_batch_iter can return empty batches. I have added a check against that and will see whether training is stable now.
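This failure mode is consistent with the traceback: a batch built from an empty list becomes a 1-D tensor (for example, `torch.tensor([])` has shape `(0,)`), so `batch_size, n_ctx = x.shape` has only one value to unpack. The actual check added to `_valid_batch_iter` is not shown in this thread; the sketch below is a hypothetical batch iterator (`iter_batches` and its arguments are illustrative names) that keeps only full-length contexts and never yields an empty batch.

```python
from typing import Iterator, List, Sequence


def iter_batches(
    tokens: Sequence[int], batch_size: int, n_ctx: int
) -> Iterator[List[List[int]]]:
    """Yield batches of complete n_ctx-token contexts, skipping empty ones.

    Hypothetical helper for illustration, not the project's _valid_batch_iter.
    """
    step = batch_size * n_ctx
    for start in range(0, len(tokens), step):
        chunk = tokens[start:start + step]
        # Keep only contexts that have exactly n_ctx tokens.
        batch = [
            list(chunk[i:i + n_ctx])
            for i in range(0, len(chunk) - n_ctx + 1, n_ctx)
        ]
        if not batch:
            # An empty batch would turn into a tensor of shape (0,) and break
            # `batch_size, n_ctx = x.shape` in the model's forward().
            continue
        yield batch


if __name__ == "__main__":
    for b in iter_batches(list(range(10)), batch_size=2, n_ctx=3):
        print(b)
```

A trailing remainder shorter than `n_ctx` (the last token of a 10-token stream with `n_ctx=3`) is dropped instead of producing an empty batch.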

gooofy commented 5 years ago

No worries! I am planning to focus on PyTorch for my ML work anyway :)

I have not done any serious benchmarking. I did experiment a little with the small model and found that checkpointing lets me increase the batch size significantly, but training still remains slower than it is without checkpointing.
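The speed-for-memory trade-off can be seen on a toy example. The sketch below is a hypothetical pure-Python illustration, not the project's code: a chain of scalar layers `tanh(w * x)` is differentiated once with ordinary backprop (every activation kept) and once with segment checkpointing (only every `k`-th activation kept, the rest recomputed during the backward pass). Both functions count stored activations and forward layer evaluations.

```python
import math
from typing import Dict, List, Tuple


def layer(w: float, x: float) -> float:
    return math.tanh(w * x)


def layer_grad(w: float, x: float) -> float:
    # d/dx tanh(w * x) = w * (1 - tanh(w * x)^2)
    t = math.tanh(w * x)
    return w * (1.0 - t * t)


def grad_full(ws: List[float], x: float) -> Tuple[float, int, int]:
    """Plain backprop: keep every activation. Returns (dy/dx, kept, layer calls)."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for i in reversed(range(len(ws))):
        g *= layer_grad(ws[i], acts[i])
    return g, len(acts), len(ws)


def grad_checkpointed(ws: List[float], x: float, k: int) -> Tuple[float, int, int]:
    """Keep every k-th activation only; recompute the others during backward."""
    n = len(ws)
    calls = 0
    ckpts: Dict[int, float] = {0: x}
    h = x
    for i, w in enumerate(ws):
        h = layer(w, h)
        calls += 1
        if (i + 1) % k == 0 and i + 1 < n:
            ckpts[i + 1] = h  # input of the next segment
    g = 1.0
    for start in sorted(ckpts, reverse=True):
        end = min(start + k, n)
        # Recompute this segment's activations from its checkpoint.
        seg = [ckpts[start]]
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
            calls += 1
        # Backprop through the segment using the recomputed activations.
        for i in reversed(range(start, end)):
            g *= layer_grad(ws[i], seg[i - start])
    return g, len(ckpts), calls


if __name__ == "__main__":
    ws = [0.5 + 0.1 * i for i in range(12)]
    print(grad_full(ws, 0.3))
    print(grad_checkpointed(ws, 0.3, k=4))
```

With 12 layers and `k=4` the gradient is identical, but only 3 activations are kept after the forward pass (at most 7 are alive during backward, versus 13) at the cost of 21 layer evaluations instead of 12. In a transformer, each checkpointed block plays the role of one segment: the recomputation is where the extra time goes, and the freed activation memory is what makes the larger batch size possible.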

However, checkpointing does make it possible to train a 345M model on my 1080 Ti, which I never managed to do without it.
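A back-of-the-envelope estimate shows why a 345M model is tight on an 11 GB GTX 1080 Ti. The arithmetic below uses the hyperparameters from the log (`batch_size=3`, `n_ctx=1024`, `n_embed=1024`, `n_head=16`, `n_layer=24`) and assumes fp32 values and an Adam-style optimizer with two moment buffers per weight; the optimizer choice is an assumption, and 345M is the nominal parameter count. The function names are hypothetical.

```python
GIB = 1024 ** 3


def weights_and_optimizer_bytes(n_params: int, bytes_per_value: int = 4) -> int:
    # Assumption: weights + gradients + two Adam moment buffers, all fp32.
    copies = 4
    return n_params * bytes_per_value * copies


def attention_probs_bytes(
    batch: int, n_head: int, n_ctx: int, n_layer: int, bytes_per_value: int = 4
) -> int:
    # One [batch, n_head, n_ctx, n_ctx] attention-probability tensor per layer,
    # kept for the backward pass when checkpointing is off.
    return batch * n_head * n_ctx * n_ctx * n_layer * bytes_per_value


def layer_inputs_bytes(
    batch: int, n_ctx: int, n_embed: int, n_layer: int, bytes_per_value: int = 4
) -> int:
    # Per-layer checkpointing keeps only each layer's [batch, n_ctx, n_embed] input.
    return batch * n_ctx * n_embed * n_layer * bytes_per_value


if __name__ == "__main__":
    print(f"weights+grads+Adam: {weights_and_optimizer_bytes(345_000_000) / GIB:.2f} GiB")
    print(f"attention probs:    {attention_probs_bytes(3, 16, 1024, 24) / GIB:.2f} GiB")
    print(f"checkpointed input: {layer_inputs_bytes(3, 1024, 1024, 24) / GIB:.2f} GiB")
```

Under these assumptions, weights, gradients and optimizer state take about 5.1 GiB, and a single attention-probability tensor per layer adds 4.5 GiB before counting pre-softmax scores, MLP activations and other intermediates, which exceeds what an 11 GB card can hold. Per-layer checkpointing instead keeps the 24 layer inputs (about 0.28 GiB) plus the activations of the one layer currently being recomputed.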

lopuhin commented 5 years ago

Thanks @gooofy 👍

lopuhin commented 4 years ago

I finally got to use this and it works great: very small performance overhead, and much larger models become possible. Thank you @gooofy