r-three / t-few

Code for T-Few from "Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning"
MIT License

Can't run 11 billion model on A100 with 80GB #14

Closed: danielkorat closed this issue 2 years ago

danielkorat commented 2 years ago

Hi @craffel @muqeeth @HaokunLiu,

We're trying to reproduce T-Few results for a paper, but we're getting 'CUDA out of memory' using an A100 with 80GB (your recommended setup).

This is what we're running:

python -m src.pl_train -c t011b.json+ia3.json+rte.json -k load_weight="pretrained_checkpoints/t011b_ia3_finish.pt" exp_name=t011b_rte_seed42_ia3_pretrained few_shot_random_seed=42 seed=42

We installed according to the README instructions and are using the default settings in the config files. We are able to run the 3-billion-parameter model using the command above, just not the 11-billion-parameter one. Is there anything we are doing wrong?

This is the exception:

CUDA out of memory

Thank you

HaokunLiu commented 2 years ago

Thanks for your interest in our work!

It's hard to tell from this alone. Could you share the full log with me? And if you are familiar with PyTorch Lightning, would you mind adding something like print("Memory usage at line [add something here]", torch.cuda.memory_allocated(device=None)) at the start and end of training_step in EncoderDecoder.py?
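A minimal sketch of that instrumentation, assuming the LightningModule in src/models/EncoderDecoder.py has a `training_step(self, batch, batch_idx)` method (its body is not shown in this thread). Instead of pasting the prints in by hand, the method can be wrapped in a decorator. `log_cuda_memory` and `trace_memory` are hypothetical helper names. The only real API used is `torch.cuda.memory_allocated`, and it is imported lazily so the helpers can be exercised with a fake memory counter on a machine without a GPU.

```python
import functools


def log_cuda_memory(tag, memory_fn=None):
    """Print and return the currently allocated CUDA memory (in GiB) under a label."""
    if memory_fn is None:
        import torch  # lazy import: only needed when measuring a real GPU
        memory_fn = torch.cuda.memory_allocated
    allocated_bytes = memory_fn()
    msg = f"Memory usage at {tag}: {allocated_bytes / 2**30:.2f} GiB"
    print(msg)
    return msg


def trace_memory(memory_fn=None):
    """Hypothetical decorator: log memory at the start and end of a step method."""
    def decorator(step_fn):
        @functools.wraps(step_fn)
        def wrapper(self, *args, **kwargs):
            log_cuda_memory(f"start of {step_fn.__name__}", memory_fn)
            output = step_fn(self, *args, **kwargs)
            log_cuda_memory(f"end of {step_fn.__name__}", memory_fn)
            return output
        return wrapper
    return decorator


# Usage inside a LightningModule (hypothetical; adapt to the real class):
# class EncoderDecoder(LightningModule):
#     @trace_memory()
#     def training_step(self, batch, batch_idx):
#         ...
```

If the process dies inside the forward pass, only the "start" line is printed, and that alone already narrows down where the memory goes.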

danielkorat commented 2 years ago

Hi @HaokunLiu, we added the prints and attached the logs here. It looks like it runs out of memory before training starts.


(tfew3.7) unso@hf-paris-dgx-station-1:~/t-few$ python -m src.pl_train -c t011b.json+ia3.json+rte.json -k load_weight="pretrained_checkpoints/t011b_ia3_finish.pt" exp_name=t011b_rte_seed42_ia3_pretrained few_shot_random_seed=42 seed=42 > logfile
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Reusing dataset super_glue (/home/unso/.cache/huggingface/datasets/super_glue/rte/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: /home/unso/t-few/exp_out/t011b_rte_seed42_ia3_pretrained/log
 | Name | Type            | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 11.1 B
-----------------------------------------------------
1.1 M   Trainable params
11.1 B  Non-trainable params
11.1 B  Total params
44,548.801Total estimated model params size (MB)
Traceback (most recent call last):
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/runpy.py", line 193, in _run_module_as_main
  "__main__", mod_spec)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/runpy.py", line 85, in _run_code
  exec(code, run_globals)
 File "/home/unso/t-few/src/pl_train.py", line 86, in <module>
  main(config)
 File "/home/unso/t-few/src/pl_train.py", line 57, in main
  trainer.fit(model, datamodule)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
  self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
  return trainer_fn(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
  self._run(model, ckpt_path=ckpt_path)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
  self._dispatch()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
  self.training_type_plugin.start_training(self)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
  self._results = trainer.run_stage()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
  return self._run_train()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
  self.fit_loop.run()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
  self.advance(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
  self.epoch_loop.run(data_fetcher)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
  self.advance(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
  batch_output = self.batch_loop.run(batch, batch_idx)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
  self.advance(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
  outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
  self.advance(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 219, in advance
  self.optimizer_idx,
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
  self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 386, in _optimizer_step
  using_lbfgs=is_lbfgs,
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1652, in optimizer_step
  optimizer.step(closure=optimizer_closure)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
  trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
  self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 80, in optimizer_step
  return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 163, in optimizer_step
  optimizer.step(closure=closure, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
  return wrapped(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/optim/optimizer.py", line 109, in wrapper
  return func(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/transformers/optimization.py", line 528, in step
  loss = closure()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 148, in _wrap_closure
  closure_result = closure()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
  self._result = self.closure(*args, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
  step_output = self._step_fn()
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
  training_step_output = self.trainer.accelerator.training_step(step_kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
  return self.training_type_plugin.training_step(*step_kwargs.values())
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 213, in training_step
  return self.model.training_step(*args, **kwargs)
 File "/home/unso/t-few/src/models/EncoderDecoder.py", line 62, in training_step
  decoder_attention_mask=decoder_attention_mask,
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  return forward_call(*input, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 1623, in forward
  return_dict=return_dict,
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  return forward_call(*input, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 1020, in forward
  output_attentions=output_attentions,
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  return forward_call(*input, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 696, in forward
  hidden_states = self.layer[-1](hidden_states)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  return forward_call(*input, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 306, in forward
  forwarded_states = self.DenseReluDense(forwarded_states)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  return forward_call(*input, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 285, in forward
  hidden_states = self.wo(hidden_states)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  return forward_call(*input, **kwargs)
 File "/home/unso/miniconda3/envs/tfew3.7/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 114, in forward
  return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA out of memory. Tried to allocate 80.00 MiB (GPU 0; 79.35 GiB total capacity; 78.13 GiB already allocated; 3.62 MiB free; 78.20 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
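A rough memory budget built from the numbers in this log helps explain the failure. This sketch makes two assumptions, neither of which is a measurement. First, Lightning's "Total estimated model params size (MB)" is the parameter count times 4 bytes (fp32), expressed in decimal megabytes. Second, under bf16 AMP the weights stay in fp32 and autocast only lowers the precision of the computation. The function names are illustrative.

```python
BYTES_PER_MB = 10**6   # assumption: Lightning reports decimal megabytes
BYTES_PER_GIB = 2**30  # the CUDA OOM message reports binary gibibytes
FP32_BYTES = 4         # assumption: weights are held in fp32


def param_count_from_summary(model_size_mb, bytes_per_param=FP32_BYTES):
    """Back out the parameter count from the model-size line of the summary."""
    return model_size_mb * BYTES_PER_MB / bytes_per_param


def weights_gib(model_size_mb):
    """Convert the summary's model size (MB) to GiB to compare with the OOM message."""
    return model_size_mb * BYTES_PER_MB / BYTES_PER_GIB


summary_mb = 44_548.801   # "Total estimated model params size (MB)"
allocated_gib = 78.13     # "78.13 GiB already allocated" at the time of the OOM

params = param_count_from_summary(summary_mb)
weights = weights_gib(summary_mb)
print(f"parameters:             {params / 1e9:.2f} B")
print(f"fp32 weights:           {weights:.2f} GiB")
print(f"everything else in use: {allocated_gib - weights:.2f} GiB")
```

Only 1.1 M parameters are trainable, so gradients and optimizer state should be negligible. On this reading, most of the remaining ~36.6 GiB is presumably activations, and activations grow with the per-step batch size.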
HaokunLiu commented 2 years ago

As I recall, the code prints out all the args at the beginning. Could you share those with me?

dptam commented 2 years ago

Sorry, I think the config might be slightly off, as it was meant for the 3B rather than the 11B version. For the 11B variants, we used a smaller batch size to fit into memory but kept an effective batch size of 8. Our hyperparameters were batch_size=1 grad_accum_factor=8 eval_batch_size=2. Let us know if it still runs out of memory.
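The following framework-free sketch shows why batch_size=1 with grad_accum_factor=8 keeps the effective batch size at 8. It runs gradient accumulation on a toy one-parameter least-squares model. The model, data and function names are illustrative and are not T-Few code. It also assumes that each micro-batch loss is divided by the number of accumulation steps, as accumulation loops commonly do. Under that assumption, summing the eight scaled micro-batch gradients reproduces the gradient of a single batch of 8. A real model would then only need to hold one example's activations in memory at a time.

```python
def mse_grad(w, xs, ys):
    """d/dw of the mean squared error of the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, batch_size, grad_accum_factor):
    """Sum of per-micro-batch gradients, each scaled by 1 / grad_accum_factor."""
    grad = 0.0
    for step in range(grad_accum_factor):
        start = step * batch_size
        micro_xs = xs[start:start + batch_size]
        micro_ys = ys[start:start + batch_size]
        grad += mse_grad(w, micro_xs, micro_ys) / grad_accum_factor
    return grad  # an optimizer step would be taken only here, once per full batch


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.1, 5.9, 9.2, 11.8, 15.1, 18.0, 20.9, 24.2]

full = mse_grad(0.5, xs, ys)                 # one batch of 8
accum = accumulated_grad(0.5, xs, ys, 1, 8)  # 8 micro-batches of 1
print(full, accum)                           # equal up to floating-point rounding
```

The identity relies on every micro-batch having the same size. With uneven micro-batches, the per-batch means would need reweighting.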

eunseojo commented 2 years ago

thanks!