Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

CUDA Out of Memory for the Falcon 7B model on A100 80GB GPU #159

Closed k21993 closed 12 months ago

k21993 commented 1 year ago

I am trying to reproduce the Falcon-7B Lora fine-tuning on the Alpaca dataset. I followed the steps to convert the checkpoints to lightning format, downloaded and tokenized the Alpaca dataset as instructed. When I run:

python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

I get the following traceback:

{'eval_interval': 100, 'save_interval': 100, 'eval_iters': 100, 'log_interval': 1, 'devices': 1, 'learning_rate': 0.0003, 'batch_size': 4, 'micro_batch_size': 4, 'gradient_accumulation_iters': 1, 'max_iters': 50000, 'weight_decay': 0.01, 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05, 'warmup_iters': 100}
Using bfloat16 Automatic Mixed Precision (AMP)
Global seed set to 1337
Loading model 'checkpoints/tiiuae/falcon-7b/lit_model.pth' with {'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 512, 'padded_vocab_size': 65024, 'n_layer': 32, 'n_head': 71, 'n_embd': 4544, 'rotary_percentage': 1.0, 'parallel_residual': True, 'bias': False, 'n_query_groups': 1, 'shared_attention_norm': True}
Number of trainable parameters: 3506176
Validating ...
Recommend a movie for me to watch during the weekend and explain the reason.
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie for me to watch during the weekend and explain the reason.

### Response:
[The Martian](https://www.imdb.com/title/tt1878107) is a really good movie to watch during the weekend. It is set on Mars and is based on the book by Andy Weir. Weir is a retired engineer who won an international writing contest for promising science fiction writers. The movie is funny and at the same time it is thoughtful and inspiring. I will recommend this movie to you because of the following reasons.

1. The movie
Estimated TFLOPs: 384.19
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ finetune/lora.py: │
│ 288 in <module>                                                                                  │
│                                                                                                  │
│   285 │   │   message="Remove `.no_backward_sync()` from your code",                             │
│   286 │   )                                                                                      │
│   287 │                                                                                          │
│ ❱ 288 │   CLI(setup)                                                                             │
│   289                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:85 in CLI                             │
│                                                                                                  │
│    82 │   │   │   return parser                                                                  │
│    83 │   │   cfg = parser.parse_args(args)                                                      │
│    84 │   │   cfg_init = parser.instantiate_classes(cfg)                                         │
│ ❱  85 │   │   return _run_component(component, cfg_init)                                         │
│    86 │                                                                                          │
│    87 │   subcommands = parser.add_subcommands(required=True)                                    │
│    88 │   comp_dict = {c.__name__: c for c in components}                                        │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:147 in _run_component                 │
│                                                                                                  │
│   144 def _run_component(component, cfg):                                                        │
│   145 │   cfg.pop("config", None)                                                                │
│   146 │   if not inspect.isclass(component):                                                     │
│ ❱ 147 │   │   return component(**cfg)                                                            │
│   148 │   subcommand = cfg.pop("subcommand")                                                     │
│   149 │   if not subcommand:                                                                     │
│   150 │   │   return component(**cfg)                                                            │
│                                                                                                  │
│ finetune/lora.py: │
│ 75 in setup                                                                                      │
│                                                                                                  │
│    72 │   print(hparams)                                                                         │
│    73 │                                                                                          │
│    74 │   fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision)      │
│ ❱  75 │   fabric.launch(main, data_dir, checkpoint_dir, out_dir, precision)                      │
│    76                                                                                            │
│    77                                                                                            │
│    78 def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, precisio   │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:759 in launch                  │
│                                                                                                  │
│   756 │   │   │   │   f"To use the `{type(self.strategy).__name__}` strategy, `.launch()` need   │
│   757 │   │   │   │   " that contains the code to launch in processes."                          │
│   758 │   │   │   )                                                                              │
│ ❱ 759 │   │   return self._wrap_and_launch(function, self, *args, **kwargs)                      │
│   760 │                                                                                          │
│   761 │   def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:                     │
│   762 │   │   """Trigger the callback methods with the given name and arguments.                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:841 in _wrap_and_launch        │
│                                                                                                  │
│   838 │   │   to_run = partial(self._wrap_with_setup, to_run)                                    │
│   839 │   │   if (launcher := self._strategy.launcher) is not None:                              │
│   840 │   │   │   return launcher.launch(to_run, *args, **kwargs)                                │
│ ❱ 841 │   │   return to_run(*args, **kwargs)                                                     │
│   842 │                                                                                          │
│   843 │   def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:        │
│   844 │   │   self._strategy.setup_environment()                                                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:846 in _wrap_with_setup        │
│                                                                                                  │
│   843 │   def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:        │
│   844 │   │   self._strategy.setup_environment()                                                 │
│   845 │   │   with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(Bat   │
│ ❱ 846 │   │   │   return to_run(*args, **kwargs)                                                 │
│   847 │                                                                                          │
│   848 │   def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn   │
│   849 │   │   initial_device = next(model.parameters(), torch.tensor(0)).device                  │
│                                                                                                  │
│ finetune/lora.py: │
│ 112 in main                                                                                      │
│                                                                                                  │
│   109 │   │   max_seq_length = json.load(data_config_path).get("max_seq_length", model.config.   │
│   110 │                                                                                          │
│   111 │   train_time = time.time()                                                               │
│ ❱ 112 │   train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir, max_s   │
│   113 │   fabric.print(f"Training time: {(time.time()-train_time):.2f}s")                        │
│   114 │                                                                                          │
│   115 │   # Save the final LoRA checkpoint at the end of training                                │
│                                                                                                  │
│ finetune/lora.py: │
│ 138 in train                                                                                     │
│                                                                                                  │
│   135 │   estimated_flops = estimate_flops(model) * micro_batch_size                             │
│   136 │   fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")    │
│   137 │   if not isinstance(fabric.strategy, DeepSpeedStrategy):  # unsupported                  │
│ ❱ 138 │   │   measured_flops = measure_flops(                                                    │
│   139 │   │   │   model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), devi   │
│   140 │   │   )                                                                                  │
│   141 │   │   fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}"   │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/speed_ │
│ monitor.py:269 in measure_flops                                                                  │
│                                                                                                  │
│   266 │   flop_counter = FlopCounterMode(model, display=False)                                   │
│   267 │   ctx = nullcontext() if model.training else torch.no_grad()                             │
│   268 │   with ctx, flop_counter:                                                                │
│ ❱ 269 │   │   y = model(x)                                                                       │
│   270 │   │   if model.training:                                                                 │
│   271 │   │   │   y.sum().backward()                                                             │
│   272 │   return flop_counter.get_total_flops()                                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/wrappers.py:116 in forward               │
│                                                                                                  │
│   113 │   │   args, kwargs = self._precision.convert_input((args, kwargs))                       │
│   114 │   │                                                                                      │
│   115 │   │   with self._precision.forward_context():                                            │
│ ❱ 116 │   │   │   output = self._forward_module(*args, **kwargs)                                 │
│   117 │   │                                                                                      │
│   118 │   │   output = self._precision.convert_output(output)                                    │
│   119 │   │   return output                                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model. │
│ py:92 in forward                                                                                 │
│                                                                                                  │
│    89 │   │                                                                                      │
│    90 │   │   if input_pos is None:  # proxy for use_cache=False                                 │
│    91 │   │   │   for block in self.transformer.h:                                               │
│ ❱  92 │   │   │   │   x, *_ = block(x, (cos, sin), mask, max_seq_length)                         │
│    93 │   │   else:                                                                              │
│    94 │   │   │   self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, c   │
│    95 │   │   │   for i, block in enumerate(self.transformer.h):                                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model. │
│ py:158 in forward                                                                                │
│                                                                                                  │
│   155 │   │   kv_cache: Optional[KVCache] = None,                                                │
│   156 │   ) -> Tuple[torch.Tensor, Optional[KVCache]]:                                           │
│   157 │   │   n_1 = self.norm_1(x)                                                               │
│ ❱ 158 │   │   h, new_kv_cache = self.attn(n_1, rope, mask, max_seq_length, input_pos, kv_cache   │
│   159 │   │   if self.config.parallel_residual:                                                  │
│   160 │   │   │   n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)             │
│   161 │   │   │   x = x + h + self.mlp(n_2)                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model. │
│ py:233 in forward                                                                                │
│                                                                                                  │
│   230 │   │   │   kv_cache = k, v                                                                │
│   231 │   │                                                                                      │
│   232 │   │   # efficient attention using Flash Attention CUDA kernels                           │
│ ❱ 233 │   │   y = F.scaled_dot_product_attention(                                                │
│   234 │   │   │   q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0 / math.sqrt(self.config.he   │
│   235 │   │   )                                                                                  │
│   236                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/utils/flop_counter.py:395 in __torch_dispatch__     │
│                                                                                                  │
│   392 │                                                                                          │
│   393 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                       │
│   394 │   │   kwargs = kwargs if kwargs else {}                                                  │
│ ❱ 395 │   │   out = func(*args, **kwargs)                                                        │
│   396 │   │   func_packet = func._overloadpacket                                                 │
│   397 │   │   if func_packet in self.flop_mapping:                                               │
│   398 │   │   │   flop_count_func = self.flop_mapping[func_packet]                               │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_ops.py:401 in __call__                             │
│                                                                                                  │
│   398 │   │   )                                                                                  │
│   399 │                                                                                          │
│   400 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 401 │   │   return self._op(*args, **kwargs or {})                                             │
│   402 │                                                                                          │
│   403 │   def __hash__(self):                                                                    │
│   404 │   │   return hash(self._op)                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacty of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process
has 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory 
is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

It also uses just 1 GPU rather than the 8 that I have available. Please help me resolve these issues ASAP. Thanks!

griff4692 commented 1 year ago

I have the same issue on a 48 GB GPU - following to see what the solution is.

griff4692 commented 1 year ago

TL;DR: you can force the strategy to be DeepSpeed and it should run. The default DeepSpeed config is ZeRO stage 2, which is effective even on a single GPU.

k21993 commented 1 year ago

@griff4692 Thanks for the pointer. I hardcoded the strategy as strategy = DeepSpeedStrategy(config=ds_config) here (sketch below) and it runs! Although there are a few issues that I see:

  1. Peak GPU memory is ~30 GB.
  2. It only runs on 1 GPU even if multiple GPUs are available.
  3. Why does it run out of memory on an 80 GB GPU when DeepSpeed is not enabled? The model is ~30 GB, right?

Do you know why this is the case?
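
For anyone else hitting this: a minimal sketch of that hardcoding (not the exact script code), assuming the script still builds the Fabric object in setup() as shown in the traceback above, with values mirroring the printed hparams:

    import lightning as L
    from lightning.fabric.strategies import DeepSpeedStrategy

    micro_batch_size = 4
    gradient_accumulation_iters = 1

    # ds_config mirrors the dict defined in finetune/lora.py (ZeRO stage 2 by default)
    ds_config = {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_iters,
        "zero_optimization": {"stage": 2},
    }

    # Force DeepSpeed even on a single device instead of whatever strategy the script would pick
    strategy = DeepSpeedStrategy(config=ds_config)
    fabric = L.Fabric(devices=1, strategy=strategy, precision="bf16-mixed")
    # fabric.launch(main, data_dir, checkpoint_dir, out_dir, precision)  # as in setup()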

griff4692 commented 1 year ago

The devices constant in lora.py is set to 1. You could try changing it and see what happens.

k21993 commented 1 year ago

Aah I didn't realize it was hardcoded there, thanks!

aniketmaurya commented 1 year ago

The devices constant in lora.py is set to 1. You could try changing it and see what happens.

@awaelchli @lantiga maybe show a warning if more devices are available?

rasbt commented 1 year ago

@k21993 LoRA with Falcon 7B should work on a single GPU with ~16 GB. If not, you can change micro_batch_size = 4 to micro_batch_size = 1 (it only affects the runtime) or try to reduce the LoRA rank.
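
In case it helps, a sketch of those reduced settings, assuming the module-level constants near the top of finetune/lora.py match the hparams printout at the top of this issue and that gradient_accumulation_iters is derived from batch_size // micro_batch_size:

    # Defaults taken from the hparams printout above; lowering micro_batch_size
    # trades speed for memory, and lowering lora_r shrinks the LoRA adapter.
    batch_size = 4
    micro_batch_size = 1  # default was 4
    gradient_accumulation_iters = batch_size // micro_batch_size
    lora_r = 4            # default was 8; optional further reduction
    lora_alpha = 16
    lora_dropout = 0.05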

xy990 commented 1 year ago

What else did you change? Even if I change micro_batch_size = 4 to micro_batch_size = 1, LoRA with Falcon 7B does not work on a single GPU with 24 GB.

rasbt commented 1 year ago

That's weird; here are the complete settings I used: https://github.com/rasbt/LLM-finetuning-scripts/blob/main/lit-benchmarks/falcon-7b/finetune/lora.py

via

python finetune/lora.py  --checkpoint_dir checkpoints/tiiuae/falcon-7b/

The peak memory use was 16.97 GB according to:

print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
k21993 commented 1 year ago

@rasbt @aniketmaurya

  1. I tried running with 8 A100 (80GB) GPUs with the settings:

    batch_size = 64
    micro_batch_size = 4
    lora_r = 8
    devices=8

    It runs for ~15k iterations and eventually fails with:

    OutOfMemoryError: CUDA out of memory. Tried to allocate 632.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 452.44 MiB is free. Process 147633 has 32.01 GiB memory in 
    use. Including non-PyTorch memory, this process has 46.70 GiB memory in use. Of the allocated memory 42.63 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but 
    unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
    PYTORCH_CUDA_ALLOC_CONF
  2. If I set devices=1 and run

    python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

    It fails on startup itself.

  3. If I set devices=1 and hardcode strategy=deepspeed, it still uses a lot of memory:

    (screenshot of GPU memory usage)

rasbt commented 1 year ago

Regarding the 1-GPU setting you have above, you mention micro_batch_size = 4. So if you set this to micro_batch_size = 1, it should theoretically work: 67,775 MiB / 4 ≈ 16,943 MiB.

rasbt commented 1 year ago

Regarding multi-GPU training, it is currently set to DeepSpeed ZeRO stage 2, which is not very memory-efficient (it optimizes for speed). If you set this to stage 3, it is more memory-efficient, but there is currently a bug with stage 3 and multi-GPU (#161). The 1-GPU case should definitely work, though.
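
For anyone who wants to try stage 3, a sketch of the change, assuming the ds_config dict defined in finetune/lora.py (quoted later in this thread); the surrounding values here are placeholders:

    micro_batch_size = 1
    gradient_accumulation_iters = 4

    # ZeRO stage 3 additionally shards the parameters (not just optimizer state
    # and gradients), trading throughput for lower per-GPU memory.
    ds_config = {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_iters,
        "zero_optimization": {"stage": 3},  # the script's default is stage 2
    }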

carmocca commented 12 months ago

I have a fix in #171 that will reduce the memory requirements for fine-tuning and training

k21993 commented 12 months ago

@carmocca Seems like this is a fix for the adapter method but not LoRA, based on the PR. Can you outline the basic steps to make these changes for LoRA?

carmocca commented 12 months ago

@k21993 The fix above also applies to LoRA.

k21993 commented 12 months ago

Hey @carmocca, I tried your fix and the memory requirement seems to be the same, while the iteration time decreases from ~10s to ~7s.

Here's my config:

max_seq_len = 2048
micro_batch_size = 2
batch_size = 64
lora_r = 64
lora_alpha = 128
devices = 1
ds_config = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "gradient_accumulation_steps": gradient_accumulation_iters,
    "zero_optimization": {"stage": 2},
}  

The memory occupied is the same (~73 GB).

(screenshot of GPU memory usage)

peerdavid commented 12 months ago

I did not do a deep analysis, but here is what helped in my case (memory consumption is now constant at ~16 GB with a micro_batch_size of 1). First, I removed the SpeedMonitor, because for some reason it needed a lot of memory. Second, I noticed that more and more memory was consumed over the course of training -- I now call torch.cuda.empty_cache() every n iterations, and the memory consumption is constant over time too.
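
Not the script's actual code, just a sketch of that periodic cache flush; EMPTY_CACHE_EVERY_N is a made-up name and the loop body stands in for the existing train() loop in finetune/lora.py:

    import torch

    EMPTY_CACHE_EVERY_N = 50  # hypothetical interval; tune for your GPU

    def train_loop(fabric, model, optimizer, max_iters: int) -> None:
        for iter_num in range(max_iters):
            # ... forward / backward / optimizer step as in the existing train() ...
            if iter_num > 0 and iter_num % EMPTY_CACHE_EVERY_N == 0:
                # Return cached-but-unused blocks to the driver so reserved memory
                # stays flat; calling this too often slows training noticeably.
                torch.cuda.empty_cache()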

fozziethebeat commented 12 months ago

I'm currently following the instructions for fine-tuning Falcon 7B with adapter V2 and ran into similar issues. I deleted the following lines in train:

    if not isinstance(fabric.strategy, DeepSpeedStrategy):  # unsupported
        measured_flops = measure_flops(
            model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), device=fabric.device)
        )
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
    else:
        measured_flops = None

and just replaced them with measured_flops = None. That seemed to fix everything for me on an NVIDIA RTX A6000 (48GB). That might be why setting the strategy to DeepSpeed seems to fix things.

fozziethebeat commented 12 months ago

I lied, I still ran into an OOM issue about 80 steps in after fixing a NaN problem (solved by using --precision bf16-mixed).

I've tried using adapter_v2.py, adapter.py and lora.py. All quickly OOM on my 48GB GPU (within 80 steps). Not sure what's causing this yet.

EDIT: With some tweaking, changing these settings got me a few more steps (up to about 600) before OOM:

batch_size = 64 / devices
micro_batch_size = 1

Broadly, it'd be nice if the scripts referenced in the guide worked as reported. Even with all these tweaks, the minimum VRAM usage I'm seeing when training starts is ~30 GB, not 16 GB.

k21993 commented 12 months ago

@fozziethebeat What's your micro_batch_size and max_seq_len? Since the sequence length is local to the batch, maybe it finds a batch later in your training that is big enough to cause an OOM.

fozziethebeat commented 12 months ago

I'm using the default max_seq_length as generated by running:

python scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

Looking at the config directly, it looks like it's 1079. That doesn't seem too extreme to me and is lower than the block size (2048) reported by falcon-7b.

peerdavid commented 12 months ago

So I'm having the same issue -- memory consumption is constant in general, but after about 50 steps an OOM is raised. I logged the sequence length, and in my case it's definitely because of the sequence length (thanks for the hint @k21993) -- it happens exactly after the "1079 sample" occurs. All other samples are <= 650 until this point, and exactly after this batch an OOM is raised -- which is fine IMO...

Update: When I restrict the token length it trains without OOMs :) Still, it's worth mentioning that I use a 3090 GPU, so I have only 24 GB of RAM.
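
A sketch of that restriction, assuming the prepared samples are dicts of token tensors as produced by scripts/prepare_alpaca.py (field names and paths here are illustrative, and the 640-token cap is just an example for a 24 GB card):

    import torch

    MAX_TOKENS = 640  # example cap, well below the 1079-token outlier

    def truncate(sample: dict) -> dict:
        """Clip every tensor field so no sample exceeds MAX_TOKENS tokens."""
        return {
            key: value[:MAX_TOKENS] if torch.is_tensor(value) else value
            for key, value in sample.items()
        }

    # e.g. train_data = [truncate(s) for s in torch.load("data/alpaca/train.pt")]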

carmocca commented 12 months ago

I merged #173, which should fix the FLOPs counter issue.

I'll try replicating the sequence length issues you are seeing now.

cipher982 commented 12 months ago

So I'm having the same issue -- memory consumption is constant in general, but after about 50 steps an OOM is raised. I logged the sequence length, and in my case it's definitely because of the sequence length (thanks for the hint @k21993) -- it happens exactly after the "1079 sample" occurs. All other samples are <= 650 until this point, and exactly after this batch an OOM is raised -- which is fine IMO...

Update: When I restrict the token length it trains without OOMs :) Still, it's worth mentioning that I use a 3090 GPU, so I have only 24 GB of RAM.

Noticing the same thing on my end. Specifically, iter 251 gets a token length of around 600 and crashes on my 3090. I modified the script to skip any inputs above 600 tokens, and it trains a little longer but crashes later on an input of around 500 tokens. It appears the memory usage slowly creeps up over a few minutes of training; maybe something is not being released correctly.

carmocca commented 12 months ago

Hey all. Using current main, here's what I'm calling:

python finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true

With micro_batch_size=1 I get a constant ~16 GB of use. It might seem to creep up slowly, but that is just the CUDA allocator keeping more than it needs. As https://github.com/Lightning-AI/lit-parrot/issues/159#issuecomment-1598193614 mentioned, empty_cache() will keep it down, but beware that it slows things down a lot, so don't call it often if you do need it.

In terms of memory requirements, here's what to expect:

Number of trainable parameters: 1365330
Number of non trainable parameters: 7217189760
Sum: 7218555090

Model weights fp32: 7218555090 * 4 / 1e9 = 28.87 GB
AdamW fp32: 2 * 4 * 1365330 / 1e9 = 0.01 GB

This matches the observed 29.02 GB returned by torch.cuda.memory_reserved() with --precision bf16-mixed. Using 16-true or bf16-true, the memory is halved.
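
The same arithmetic as a quick Python check (numbers copied from above; nothing framework-specific):

    trainable = 1_365_330
    non_trainable = 7_217_189_760
    total = trainable + non_trainable          # 7_218_555_090 parameters

    weights_fp32_gb = total * 4 / 1e9          # ~28.87 GB of fp32 weights
    adamw_fp32_gb = 2 * 4 * trainable / 1e9    # ~0.01 GB for the two fp32 AdamW state tensors

    print(f"{weights_fp32_gb + adamw_fp32_gb:.2f} GB")  # ~28.89 GB, close to the 29.02 GB reserved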

All is working as expected so far. Now, if I force all inputs to be of the maximum sequence length for the alpaca dataset (1079), the max memory reserved does jump to 24.5 GB.

I'll open a PR to try to alleviate that jump, as it's caused by an autograd issue in backward. However, you might still need to tweak the max_seq_length depending on your available GPU memory.

fozziethebeat commented 12 months ago

Thank you! This so far seems to be the needed fix.

Trying now at main, and this is working really smoothly so far. Using the exact command you tried, I'm seeing ~29 GB of VRAM usage and no NaNs in my loss function. So far at step 600 with no issues.

I do see small memory increases, but they're much less dramatic than before.

EDIT: Posted too soon. I hit an OOM after iter 1599, step 100.

carmocca commented 12 months ago

I merged #178, which should give a small decrease in memory usage.

I'll also be adding #182, which includes a change so that the longest Alpaca sequence is loaded first, so any OOM happens at the beginning.

For the DeepSpeed issues, I'll be replacing DeepSpeed with FSDP in #118.

Closing this issue. Feel free to open new ones for any new issues. Thank you all!

fozziethebeat commented 12 months ago

Should this stay under 48 GB of VRAM usage when we run the command below at head?

python finetune/adapter.py \
    --data_dir data/alpaca  \
    --checkpoint_dir checkpoints/tiiuae/falcon-7b \
    --out_dir out/adapter/alpaca --precision bf16-true

I've just tried this out and I still see an OOM at iter 1599, step 100.

cipher982 commented 12 months ago

Trying now on an A6000, and it looks like I am basically maxed out at ~48 GB right from the start. So it's possible it moves up/down a bit from there and hits an OOM.

fozziethebeat commented 12 months ago

That's exactly what I noticed. It started at 100% VRAM usage, and then something at iter 1599, step 100 killed it with the tiniest increase in memory.

carmocca commented 12 months ago

@cipher982 @fozziethebeat, with the latest changes you should get a maximum usage of 24.5 GB with true half precision and micro_batch_size=1 at the beginning of training.

ritvikshrivastava commented 12 months ago

I was running into the same OOM errors even after yesterday's merge when using LoRA with Falcon-7B-Instruct for fine-tuning. Going through the comments in this thread, I tried removing the speed monitor code from lora.py, and that helped constrain the memory issues on a single GPU for now. I'm also using true half precision.

EDIT: Although this worked with this config, it's still borderline in terms of GPU memory; a small spike is enough to error out with an OOM.

carmocca commented 12 months ago

tried removing the speed monitor code from lora.py and that helped constrain the memory issues

The speed monitor shouldn't impact memory usage at all. Do you have a way to show that this is the case? It would be considered a bug if so.