allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
Apache License 2.0

OOM on summarization example #12

Open gabrielhuang opened 1 year ago

gabrielhuang commented 1 year ago

Hi there, I'm having OOM errors when running the summarization example on an 80GB A100 (CUDA 11.8).

I'm also getting some TensorFlow/TensorRT warnings, and I'm wondering whether they're related:

2022-11-08 22:44:46.878785: I tensorflow/core/platform/] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-08 22:44:47.016183: E tensorflow/stream_executor/cuda/] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-08 22:44:47.979748: W tensorflow/stream_executor/platform/default/] Could not load dynamic library ''; dlerror: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979824: W tensorflow/stream_executor/platform/default/] Could not load dynamic library ''; dlerror: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979834: W tensorflow/compiler/tf2tensorrt/utils/] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

OOM error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│                                                                              │
│ /mnt/home/code/RL4LMs/scripts/training/ in        │
│ <module>                                                                     │
│                                                                              │
│   63 │   │   │   │   │   │   help="Whether to use wandb logging")            │
│   64 │   args = parser.parse_args()                                          │
│   65 │                                                                       │
│ ❱ 66 │   main(args.config_path,                                              │
│   67 │   │    args.project_name,                                             │
│   68 │   │    args.experiment_name,                                          │
│   69 │   │    args.base_path_to_store_results,                               │
│ /mnt/home/code/RL4LMs/scripts/training/ in main   │
│                                                                              │
│   39 │   │   │   │   │   │   │   │     on_policy_alg_config=config["alg"],   │
│   40 │   │   │   │   │   │   │   │     train_eval_config=config["train_evalu │
│   41 │   │   │   │   │   │   │   │     tracker=tracker)                      │
│ ❱ 42 │   trainer.train_and_eval()                                            │
│   43                                                                         │
│   44                                                                         │
│   45 if __name__ == "__main__":                                              │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/ in   │
│ train_and_eval                                                               │
│                                                                              │
│   202 │   │   │   self._trainer_state["current_iter"] = epoch                │
│   203 │   │   │                                                              │
│   204 │   │   │   # inner rollout and learn loop for on-policy algorithm     │
│ ❱ 205 │   │   │   self._alg.learn(self._n_steps_per_iter)                    │
│   206 │   │   │                                                              │
│   207 │   │   │   # save the policy checkpoint                               │
│   208 │   │   │   if (epoch + 1) % self._train_eval_config.get("save_every", │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ in learn              │
│                                                                              │
│   344 │   │   reset_num_timesteps: bool = True,                              │
│   345 │   ) -> "PPO":                                                        │
│   346 │   │                                                                  │
│ ❱ 347 │   │   return super().learn(                                          │
│   348 │   │   │   total_timesteps=total_timesteps,                           │
│   349 │   │   │   callback=callback,                                         │
│   350 │   │   │   log_interval=log_interval,                                 │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/stable_baselines3/common/on │
│ in learn                                            │
│                                                                              │
│   264 │   │   │   │   self.logger.record("time/total_timesteps", self.num_ti │
│   265 │   │   │   │   self.logger.dump(step=self.num_timesteps)              │
│   266 │   │   │                                                              │
│ ❱ 267 │   │   │   self.train()                                               │
│   268 │   │                                                                  │
│   269 │   │   callback.on_training_end()                                     │
│   270                                                                        │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ in train              │
│                                                                              │
│   221 │   │   │   │   if self.use_sde:                                       │
│   222 │   │   │   │   │   self.policy.reset_noise(self.batch_size)           │
│   223 │   │   │   │                                                          │
│ ❱ 224 │   │   │   │   values, log_prob, entropy = self.policy.evaluate_actio │
│   225 │   │   │   │   │   rollout_data.observations, actions)                │
│   226 │   │   │   │   values = values.flatten()                              │
│   227 │   │   │   │   # Normalize advantage                                  │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/ in           │
│ evaluate_actions                                                             │
│                                                                              │
│    208 │   │                                                                 │
│    209 │   │   _, log_prob, entropy, _, _ = self.forward_policy(obs=obs,     │
│    210 │   │   │   │   │   │   │   │   │   │   │   │   │   │    actions=acti │
│ ❱  211 │   │   values, _ = self.forward_value(obs)                           │
│    212 │   │                                                                 │
│    213 │   │   return values, log_prob, entropy                              │
│    214                                                                       │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/ in           │
│ forward_value                                                                │
│                                                                              │
│    444 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │     │
│    445 │   │                                                                 │
│    446 │   │   # and forrward pass to get hidden states                      │
│ ❱  447 │   │   outputs = self._value_model(                                  │
│    448 │   │   │   **model_inputs,                                           │
│    449 │   │   │   output_hidden_states=True,                                │
│    450 │   │   │   decoder_attention_mask=decoder_attn_mask,                 │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/ │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ in forward                                                   │
│                                                                              │
│   1645 │   │   │   │   decoder_attention_mask = │
│   1646 │   │                                                                 │
│   1647 │   │   # Decode                                                      │
│ ❱ 1648 │   │   decoder_outputs = self.decoder(                               │
│   1649 │   │   │   input_ids=decoder_input_ids,                              │
│   1650 │   │   │   attention_mask=decoder_attention_mask,                    │
│   1651 │   │   │   inputs_embeds=decoder_inputs_embeds,                      │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/ │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ in forward                                                   │
│                                                                              │
│   1037 │   │   │   │   │   None,  # past_key_value is always None with gradi │
│   1038 │   │   │   │   )                                                     │
│   1039 │   │   │   else:                                                     │
│ ❱ 1040 │   │   │   │   layer_outputs = layer_module(                         │
│   1041 │   │   │   │   │   hidden_states,                                    │
│   1042 │   │   │   │   │   attention_mask=extended_attention_mask,           │
│   1043 │   │   │   │   │   position_bias=position_bias,                      │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/ │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ in forward                                                    │
│                                                                              │
│    696 │   │   │   else:                                                     │
│    697 │   │   │   │   query_length = None                                   │
│    698 │   │   │                                                             │
│ ❱  699 │   │   │   cross_attention_outputs = self.layer[1](                  │
│    700 │   │   │   │   hidden_states,                                        │
│    701 │   │   │   │   key_value_states=encoder_hidden_states,               │
│    702 │   │   │   │   attention_mask=encoder_attention_mask,                │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/ │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ in forward                                                    │
│                                                                              │
│    610 │   │   output_attentions=False,                                      │
│    611 │   ):                                                                │
│    612 │   │   normed_hidden_states = self.layer_norm(hidden_states)         │
│ ❱  613 │   │   attention_output = self.EncDecAttention(                      │
│    614 │   │   │   normed_hidden_states,                                     │
│    615 │   │   │   mask=attention_mask,                                      │
│    616 │   │   │   key_value_states=key_value_states,                        │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/ │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ in forward                                                    │
│                                                                              │
│    506 │   │   )                                                             │
│    507 │   │                                                                 │
│    508 │   │   # compute scores                                              │
│ ❱  509 │   │   scores = torch.matmul(                                        │
│    510 │   │   │   query_states, key_states.transpose(3, 2)                  │
│    511 │   │   )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_stat │
│    512                                                                       │
RuntimeError: CUDA out of memory. Tried to allocate 150.00 MiB (GPU 0; 79.35 GiB
total capacity; 76.08 GiB already allocated; 108.19 MiB free; 77.20 GiB reserved
in total by PyTorch) If reserved memory is >> allocated memory try setting
max_split_size_mb to avoid fragmentation.  See documentation for Memory

Any clue what the issue is? 80GB seems like a lot for just a T5-base model.

rajcscw commented 1 year ago

Hey, I think this is expected. Keep in mind that we have three T5 models loaded into memory: the policy, the value model, and the reference policy. So far, we have used 4 GPUs to run the summarization tasks. If you have just one GPU, try reducing n_envs to a lower value, and also reduce the PPO batch size. Otherwise, I would suggest running with more GPUs if possible.
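
To see why the PPO batch size matters, here is a back-of-the-envelope sketch of the size of the single cross-attention score tensor that the traceback fails on (`scores = torch.matmul(query_states, key_states.transpose(3, 2))`, shape batch x heads x query length x key length). The numbers are assumptions for illustration, not values read from this run: a batch of 64, the 12 attention heads of T5-base, 100 decoder tokens, 512 encoder tokens, and float32. With those assumed values the tensor comes to exactly 150 MiB, the size of the failed allocation, and it shrinks linearly with the batch size.

```python
def attention_scores_bytes(batch_size, n_heads, query_len, key_len, bytes_per_element=4):
    """Bytes needed for one attention score tensor of shape (b, n, q, k)."""
    return batch_size * n_heads * query_len * key_len * bytes_per_element


MIB = 1024 ** 2

# Assumed values for illustration (not read from the run above).
N_HEADS = 12      # T5-base
QUERY_LEN = 100   # decoder (summary) tokens
KEY_LEN = 512     # encoder (article) tokens

for batch_size in (64, 32, 16, 8):
    size = attention_scores_bytes(batch_size, N_HEADS, QUERY_LEN, KEY_LEN)
    print(f"batch_size={batch_size:>2}: {size / MIB:.2f} MiB per score tensor")
```

This is one tensor in one layer. During a training forward pass autograd keeps intermediates like it for every layer, which is why activation memory grows quickly with the batch size.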

gabrielhuang commented 1 year ago

I see, thanks. I didn't expect there to be three full models. Are there any advantages over plugging three heads onto one language model trunk?

Also, I'm just curious: does the n_envs parameter scale up the batch size? Why does it influence GPU memory?

Many thanks.

JulesGM commented 1 year ago

Yes, people usually seem to just use different heads.

rajcscw commented 1 year ago

@gabrielhuang Sure, shared layers would be more parameter-efficient. I am not sure how much of a performance change it would bring.
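
As a rough sketch of the parameter-efficiency point, the snippet below compares static memory (weights, gradients, optimizer state) for separate policy and value models against a shared trunk. Everything in it is an assumption: about 220M parameters for T5-base, float32 everywhere, an Adam-style optimizer with two state tensors per trainable parameter, a frozen reference model kept as a separate copy in both setups, and head parameters ignored. Activations are not counted.

```python
GIB = 1024 ** 3


def static_bytes(n_params, trainable, bytes_per_param=4):
    """Weights, plus gradients and two Adam moment tensors if trainable."""
    copies = 4 if trainable else 1  # weights + grad + exp_avg + exp_avg_sq
    return n_params * bytes_per_param * copies


N_PARAMS = 220_000_000  # assumed approximate size of T5-base

# Setup described in this thread: policy and value models (both trained)
# plus a frozen reference policy.
separate = 2 * static_bytes(N_PARAMS, trainable=True) + static_bytes(N_PARAMS, trainable=False)

# Hypothetical alternative: one trained trunk shared by the policy and value
# heads, plus the same frozen reference policy.
shared = static_bytes(N_PARAMS, trainable=True) + static_bytes(N_PARAMS, trainable=False)

print(f"separate models: {separate / GIB:.2f} GiB")
print(f"shared trunk:    {shared / GIB:.2f} GiB")
```

Under these assumptions both figures are a small fraction of 80 GB, which suggests that activations, not parameters, dominate memory in a run like the one above.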

Regarding n_envs, it controls the batch size for generating rollouts. You can think of it as batched generation.
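
A toy sketch of what batched generation means here, in plain Python with no model involved. `toy_next_tokens` is a hypothetical stand-in for one policy forward pass, not an RL4LMs function. The point is only that each generation step processes `n_envs` sequences at once, so the tensors in a real forward pass have a leading dimension of `n_envs`.

```python
def toy_next_tokens(batch):
    """Hypothetical stand-in for one forward pass: one new token per sequence."""
    return [len(seq) for seq in batch]  # dummy "token": the current sequence length


def batched_rollout(prompts, max_new_tokens):
    """Extend all sequences in lock-step, one batched call per generation step."""
    sequences = [list(p) for p in prompts]  # copy so the prompts stay untouched
    for _ in range(max_new_tokens):
        new_tokens = toy_next_tokens(sequences)  # batch dimension == n_envs
        for seq, tok in zip(sequences, new_tokens):
            seq.append(tok)
    return sequences


n_envs = 4
prompts = [[0, 1, 2]] * n_envs
rollouts = batched_rollout(prompts, max_new_tokens=3)
print(len(rollouts), "sequences of length", len(rollouts[0]))
```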

JulesGM commented 1 year ago

I'm trying to get google/flan-t5-xxl to run on a single A100 80GB GPU, with the seq2seq policy.

Is there already a way to set the precision to bfloat16? (I don't see one, but just to be sure.) If not, I'll write a policy for that.

Also, I will try to add model sharing between the policy and value models, and an option to freeze parts of the model.

JulesGM commented 1 year ago

Enabling offloading a model from GPU memory to CPU memory when it's not in use would likely be helpful too.
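
A minimal sketch of that offloading idea, assuming plain PyTorch modules. `on_device` is a hypothetical helper, not something RL4LMs provides, and the `Linear` layer only stands in for a model such as the frozen reference policy.

```python
import contextlib

import torch


@contextlib.contextmanager
def on_device(module, device, offload_to="cpu"):
    """Move `module` to `device` for the duration of the block, then offload it."""
    module.to(device)
    try:
        yield module
    finally:
        module.to(offload_to)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached, now-unused blocks back to the driver


device = "cuda" if torch.cuda.is_available() else "cpu"
ref_model = torch.nn.Linear(8, 8)  # stand-in for a frozen reference model

with on_device(ref_model, device) as model, torch.no_grad():
    out = model(torch.randn(2, 8, device=device))

print(next(ref_model.parameters()).device)  # back on the offload device
```

Each switch costs a host-device transfer of the full model, so this trades speed for memory.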

JulesGM commented 1 year ago

@gabrielhuang have you started doing work like this? (I'm also at Mila)

rajcscw commented 1 year ago

@JulesGM We don't have support for setting the precision yet. You can implement this in a new policy, and we can possibly try to merge it into the existing classes (by making it configurable through some args).

JulesGM commented 1 year ago

This is my current approach, indeed: just allowing the user to pass kwargs for from_pretrained and Linear. Passing torch_dtype to from_pretrained and dtype to Linear works.

I suppose adding AMP mixed-precision auto-casting and gradient scaling would be a good / important idea though.

import copy
import sys

import torch
import transformers

import rl4lms.envs.text_generation.registry as rl4lms_registry
import rl4lms.envs.text_generation.policy.seq2seq_policy as rl4lms_seq2seq_policy
from rl4lms.envs.text_generation import hf_generation_utils


class PrecisionControlSeq2SeqLMActorCriticPolicy(rl4lms_seq2seq_policy.Seq2SeqLMActorCriticPolicy):
    def __init__(self, *args, from_pretrained_kwargs, head_kwargs, **kwargs):
        # Stored before super().__init__() so that _build_model_heads can read them.
        self._from_pretrained_kwargs = from_pretrained_kwargs
        self._head_kwargs = head_kwargs

        super().__init__(*args, **kwargs)

    def _build_model_heads(self, model_name: str):
        self._policy_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self._policy_model.__class__ = hf_generation_utils.override_generation_routines(
            type(self._policy_model))

        # Only the value model receives the extra from_pretrained kwargs here.
        self._value_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            model_name, **self._from_pretrained_kwargs)
        self._ref_model = copy.deepcopy(self._policy_model).eval()

        self._value_head = torch.nn.Linear(
            self._value_model.config.hidden_size, 1, bias=False, **self._head_kwargs)

        # apply model parallel
        if torch.cuda.is_available():
            if self._apply_model_parallel and self._policy_model.is_parallelizable:
                self._value_head = self._value_head.to(self.device)
            else:  # else defaults to data parallel
                self._policy_model = torch.nn.DataParallel(self._policy_model.to(self.device))
                self._ref_model = torch.nn.DataParallel(self._ref_model.to(self.device))
                self._value_model = torch.nn.DataParallel(self._value_model.to(self.device))
                self._value_head = torch.nn.DataParallel(self._value_head.to(self.device))
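
For the AMP auto-casting and gradient scaling mentioned above, a generic PyTorch training step could look like the sketch below. It is not wired into RL4LMs or Stable Baselines3; the toy `Linear` model, the MSE loss, and the CPU fallback to bfloat16 are assumptions made only so the snippet is self-contained.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 needs loss scaling; on CPU this sketch falls back to bfloat16 autocast.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = torch.nn.Linear(8, 1).to(device)  # parameters stay in float32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

inputs = torch.randn(4, 8, device=device)
targets = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, dtype=amp_dtype):
    predictions = model(inputs)  # the matmul runs in reduced precision
    loss = torch.nn.functional.mse_loss(predictions.float(), targets)

scaler.scale(loss).backward()  # scale the loss so small float16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients and skips the step on inf/NaN
scaler.update()
```

Unlike loading the weights in half precision with torch_dtype, this keeps float32 master weights and only casts inside the autocast region.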

JulesGM commented 1 year ago

It looks like Stable Baselines3 doesn't support bfloat16, because of all the a_tensor_name.cpu().numpy() calls. Doing that with a bfloat16 tensor raises an exception, because torch tries to build a NumPy array with the bfloat16 dtype, which NumPy does not support.

JulesGM commented 1 year ago

For Stable Baselines3 (and then RL4LMs) to support bfloat16, it would suffice to change a_tensor_name.cpu().numpy() to a_tensor_name.cpu().float().numpy().
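
A small sketch of that conversion, assuming only PyTorch and NumPy. `to_numpy` is a hypothetical helper, not part of Stable Baselines3 or RL4LMs.

```python
import torch


def to_numpy(tensor):
    """Convert a tensor to a NumPy array, upcasting bfloat16 to float32 first."""
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()  # NumPy has no native bfloat16 dtype
    return tensor.numpy()


x = torch.ones(3, dtype=torch.bfloat16)

try:
    x.cpu().numpy()  # the pattern described above
except (TypeError, RuntimeError) as err:
    print("direct conversion failed:", err)

arr = to_numpy(x)
print(arr.dtype)
```

The upcast copies the data, so the resulting array takes twice the memory of the bfloat16 tensor.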

lovodkin93 commented 1 year ago

Hey, thanks for your solution! I do have one question: is there a reason you didn't pass from_pretrained_kwargs to the initialization of _policy_model?

self._policy_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)


lovodkin93 commented 1 year ago

@JulesGM Hey, so I tried what you suggested, passing **{"torch_dtype": torch.float16} to from_pretrained and **{"dtype": torch.float16} to Linear, and I got the following error:

Traceback (most recent call last):
  File "/home/nlp/sloboda1/anaconda3/lib/python3.9/", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/nlp/sloboda1/anaconda3/lib/python3.9/", line 87, in _run_code
    exec(code, run_globals)
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/", line 39, in <module>
    cli.main()
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/", line 430, in main
    run()
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/", line 124, in _run_code
    exec(code, run_globals)
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/scripts/training/", line 93, in <module>
    main(
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/scripts/training/", line 64, in main
    trainer.train_and_eval()
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/rl4lms/envs/text_generation/", line 214, in train_and_eval
    self._alg.learn(self._n_steps_per_iter)
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/rl4lms/algorithms/ppo/", line 341, in learn
    return super().learn(
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/venvs/RL4LMs_venv/lib/python3.9/site-packages/stable_baselines3/common/", line 267, in learn
    self.train()
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/rl4lms/algorithms/ppo/", line 288, in train
    loss.backward()
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/venvs/RL4LMs_venv/lib/python3.9/site-packages/torch/", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/venvs/RL4LMs_venv/lib/python3.9/site-packages/torch/autograd/", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Half

I even tried passing **{"torch_dtype": torch.float16} to the from_pretrained call for self._policy_model and still got that error.

Is there anything else you converted to FP16, by any chance?

JulesGM commented 1 year ago

Yes, I made a bunch of other changes in the end.

CathyKitten commented 8 months ago

May I ask whether there is a complete version of the code with these changes that I could learn from?