huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
132.75k stars 26.45k forks source link

CUDA out of memory #25499

Closed andysingal closed 9 months ago

andysingal commented 1 year ago

System Info

Kaggle notebook

Who can help?

@pacman100 @sgugger

Information

Tasks

Reproduction

# Reproduction script: configure a Trainer, patch the model's state_dict,
# run training, then save the model.
# MICRO_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, LEARNING_RATE, OUTPUT_DIR,
# model, data, tokenizer and wandb are defined outside this snippet.
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=1,
    fp16=True,
    save_total_limit=4,
    logging_steps=25,
    output_dir="./outputs",
    save_strategy='epoch',
    optim="paged_adamw_8bit",
    lr_scheduler_type = 'cosine',
    warmup_ratio = 0.05,
    # Report to Weights & Biases only when `wandb` is truthy; otherwise
    # pass an empty list of reporting integrations.
    report_to="wandb" if wandb else []
)

# NOTE(review): no eval dataset is passed, so this run is training-only.
trainer = transformers.Trainer(
    model=model,
    train_dataset=data,
    args=training_args,
    # mlm=False: presumably selects causal-LM style collation rather than
    # masked-LM masking -- confirm against the data collator documentation.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Set the use_cache flag on the model config to False before training.
model.config.use_cache = False

# Monkey-patch model.state_dict: keep a reference to the original bound
# method, then install a replacement (bound to `model` via __get__) that
# ignores its own arguments and returns
# get_peft_model_state_dict(model, <original state dict>).
# Presumably this makes saved checkpoints contain only the PEFT adapter
# weights -- TODO confirm against the PEFT version in use.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))

# if torch.__version__ >= "2" and sys.platform != "win32":
#     model = torch.compile(model)

print("\n If there's a warning about missing keys above, please disregard :)")

# The OutOfMemoryError reported below is raised from this call (cell line 36
# in the traceback), so the cleanup and save steps after it are never reached.
trainer.train()
# Post-training cleanup: run the Python garbage collector and ask PyTorch to
# release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Save the model to OUTPUT_DIR.
model.save_pretrained(OUTPUT_DIR)

Running the code above produces the following error:

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[20], line 36
     31 # if torch.__version__ >= "2" and sys.platform != "win32":
     32 #     model = torch.compile(model)
     34 print("\n If there's a warning about missing keys above, please disregard :)")
---> 36 trainer.train()
     37 gc.collect()
     38 torch.cuda.empty_cache()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1661, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1656     self.model_wrapped = self.model
   1658 inner_training_loop = find_executable_batch_size(
   1659     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1660 )
-> 1661 return inner_training_loop(
   1662     args=args,
   1663     resume_from_checkpoint=resume_from_checkpoint,
   1664     trial=trial,
   1665     ignore_keys_for_eval=ignore_keys_for_eval,
   1666 )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1946, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1943     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1945 with self.accelerator.accumulate(model):
-> 1946     tr_loss_step = self.training_step(model, inputs)
   1948 if (
   1949     args.logging_nan_inf_filter
   1950     and not is_torch_tpu_available()
   1951     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1952 ):
   1953     # if loss is nan or inf simply add the average of previous logged losses
   1954     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2753, in Trainer.training_step(self, model, inputs)
   2750     return loss_mb.reduce_mean().detach().to(self.args.device)
   2752 with self.compute_loss_context_manager():
-> 2753     loss = self.compute_loss(model, inputs)
   2755 if self.args.n_gpu > 1:
   2756     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2778, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2776 else:
   2777     labels = None
-> 2778 outputs = model(**inputs)
   2779 # Save past state if it exists
   2780 # TODO: this needs to be fixed and made cleaner later.
   2781 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:581, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    580 def forward(*args, **kwargs):
--> 581     return model_forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:569, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    568 def __call__(self, *args, **kwargs):
--> 569     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py:14, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     11 @functools.wraps(func)
     12 def decorate_autocast(*args, **kwargs):
     13     with autocast_instance:
---> 14         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:968, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, **kwargs)
    966 prompts = prompts.to(inputs_embeds.dtype)
    967 inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
--> 968 return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:688, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    685 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    687 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 688 outputs = self.model(
    689     input_ids=input_ids,
    690     attention_mask=attention_mask,
    691     position_ids=position_ids,
    692     past_key_values=past_key_values,
    693     inputs_embeds=inputs_embeds,
    694     use_cache=use_cache,
    695     output_attentions=output_attentions,
    696     output_hidden_states=output_hidden_states,
    697     return_dict=return_dict,
    698 )
    700 hidden_states = outputs[0]
    701 logits = self.lm_head(hidden_states)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:578, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    570     layer_outputs = torch.utils.checkpoint.checkpoint(
    571         create_custom_forward(decoder_layer),
    572         hidden_states,
   (...)
    575         None,
    576     )
    577 else:
--> 578     layer_outputs = decoder_layer(
    579         hidden_states,
    580         attention_mask=attention_mask,
    581         position_ids=position_ids,
    582         past_key_value=past_key_value,
    583         output_attentions=output_attentions,
    584         use_cache=use_cache,
    585     )
    587 hidden_states = layer_outputs[0]
    589 if use_cache:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:292, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    289 hidden_states = self.input_layernorm(hidden_states)
    291 # Self Attention
--> 292 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    293     hidden_states=hidden_states,
    294     attention_mask=attention_mask,
    295     position_ids=position_ids,
    296     past_key_value=past_key_value,
    297     output_attentions=output_attentions,
    298     use_cache=use_cache,
    299 )
    300 hidden_states = residual + hidden_states
    302 # Fully Connected

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:212, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    208     value_states = torch.cat([past_key_value[1], value_states], dim=2)
    210 past_key_value = (key_states, value_states) if use_cache else None
--> 212 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    214 if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
    215     raise ValueError(
    216         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
    217         f" {attn_weights.size()}"
    218     )

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.26 GiB (GPU 0; 15.90 GiB total capacity; 13.60 GiB already allocated; 877.75 MiB free; 14.12 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Expected behavior

I would like the model to train, but Vicuna does not support QLoRA. I am using PromptConfig + Vicuna.

sgugger commented 1 year ago

Please provide a full reproducer and a reason why this all should fit in the 16GB GPU you have available.

andysingal commented 1 year ago

Please provide a full reproducer and a reason why this all should fit in the 16GB GPU you have available.

@sgugger Here is the full code (I hope you got the link I shared). While going over the lmsys repo, I found that they are still doing research on stablevicuna + qlora. I tried LoraConfig; however, the LoraConfig target_variables do not work. I tried PromptConfig since I was working on Human/bot. Please let me know if you have any further questions or concerns.

pacman100 commented 1 year ago

Hello, could you please reshare the minimal reproducer: code, command you are using to launch the training, the hardware as well as the versions of PyTorch, Transformers, Accelerate and PEFT?

andysingal commented 1 year ago

Hello, could you please reshare the minimal reproducer: code, command you are using to launch the training, the hardware as well as the versions of PyTorch, Transformers, Accelerate and PEFT?

Thanks for your response. Here is the colab notebook: https://colab.research.google.com/drive/1By1tOO6HE5Oopj2prr3tkDduewDFNpZu?usp=sharing @pacman100 @sgugger

andysingal commented 1 year ago

Hello, could you please reshare the minimal reproducer: code, command you are using to launch the training, the hardware as well as the versions of PyTorch, Transformers, Accelerate and PEFT?

Thanks for your response. Here is the colab notebook: https://colab.research.google.com/drive/1By1tOO6HE5Oopj2prr3tkDduewDFNpZu?usp=sharing @pacman100 @sgugger

Any updates, @pacman100 @sgugger?

pacman100 commented 10 months ago

I think the best one for this issue would be @SunMarc as the user is trying to use AutoGPTQ along with PEFT Prompt Tuning.

When trying it on Colab with a T4 GPU, I am getting the error below, which is probably related to Flash Attention:

RuntimeError                              Traceback (most recent call last)
[<ipython-input-20-e3a673c6a851>](https://localhost:8080/#) in <cell line: 38>()
     36 # print("\n If there's a warning about missing keys above, please disregard :)")
     37 
---> 38 trainer.train()
     39 gc.collect()
     40 torch.cuda.empty_cache()

5 frames
[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1554                 hf_hub_utils.enable_progress_bars()
   1555         else:
-> 1556             return inner_training_loop(
   1557                 args=args,
   1558                 resume_from_checkpoint=resume_from_checkpoint,

[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1870 
   1871                 with self.accelerator.accumulate(model):
-> 1872                     tr_loss_step = self.training_step(model, inputs)
   1873 
   1874                 if (

[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in training_step(self, model, inputs)
   2746                 scaled_loss.backward()
   2747         else:
-> 2748             self.accelerator.backward(loss)
   2749 
   2750         return loss.detach() / self.args.gradient_accumulation_steps

[/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py](https://localhost:8080/#) in backward(self, loss, **kwargs)
   1984             self.scaler.scale(loss).backward(**kwargs)
   1985         else:
-> 1986             loss.backward(**kwargs)
   1987 
   1988     def set_trigger(self):

[/usr/local/lib/python3.10/dist-packages/torch/_tensor.py](https://localhost:8080/#) in backward(self, gradient, retain_graph, create_graph, inputs)
    490                 inputs=inputs,
    491             )
--> 492         torch.autograd.backward(
    493             self, gradient, retain_graph, create_graph, inputs=inputs
    494         )

[/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py](https://localhost:8080/#) in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: Expected is_sm80 || is_sm90 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

The notebook that I am trying out has a few changes on top of what the user shared above: https://colab.research.google.com/drive/1UDoYUoSK-YJoFMwEzClhNxbeyBkv5aza?usp=sharing

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.