Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.1k stars 3.37k forks

Inconsistent memory usage compared to Hugging Face Trainer when using DeepSpeed #20299

Open mickeysun0104 opened 1 week ago

mickeysun0104 commented 1 week ago

Bug description

I was able to fine-tune an 8B LLM using the Hugging Face training framework with PEFT + DeepSpeed stage 2 under fp16 precision (mixed-precision training). Recently I wanted to move my codebase to Lightning due to our team's decision. However, I was not able to get the code to work because of an OOM issue, even though the settings on both sides are nearly the same. Here's the code: lightning-deepspeed.zip

lightning module

```python
import lightning as L
import torch
import os
from pathlib import Path
from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, PeftModel
from lightning.pytorch.callbacks import Callback
from typing import Optional

LORA_CONFIG = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    lora_dropout=0.1,
    bias="none",
    task_type="CASUAL_LM",
    use_dora=False,
)


class BoringModule(L.LightningModule):
    def __init__(self,
                 model_name: str,
                 precision=torch.float16,
                 peft_cfg: LoraConfig = None,
                 token: str = None,
                 is_deepspeed_enabled: bool = True,
                 ):
        super().__init__()
        self.model_name = model_name
        self.precision = precision
        self.token = token
        self.peft_cfg = peft_cfg
        self.model = None
        self.deepspeed = is_deepspeed_enabled

    def configure_model(self):
        if self.model is not None:
            return
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name,
                                                          torch_dtype=torch.float16,
                                                          device_map={"": torch.cuda.current_device()},
                                                          trust_remote_code=True,
                                                          token=self.token)
        self.model.gradient_checkpointing_enable()
        self.model = get_peft_model(self.model, self.peft_cfg)

    def configure_optimizers(self):
        if self.deepspeed:
            from deepspeed.ops.adam import FusedAdam
            optimizer = FusedAdam(self.model.parameters(), lr=2e-4)
        else:
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        return [optimizer], [scheduler]

    def forward(self, input_ids, attention_mask, label):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=label, use_cache=False)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch["input_ids"], batch["attention_mask"], batch["labels"])
        loss = output.loss
        self.log_dict({"train_loss": loss}, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self.forward(batch["input_ids"], batch["attention_mask"], batch["labels"])
        loss = output.loss
        self.log_dict({"val_loss": loss}, on_step=True, sync_dist=True)
        return loss


class PeftCheckpoint(Callback):
    def __init__(self, dirpath: Optional[str] = None):
        super().__init__()
        self.dirpath = dirpath
        self.ckpt_dir = None
        self.current_ckpt = {}

    def on_validation_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        current_step = trainer.global_step
        if current_step != 0:
            if not trainer.default_root_dir and not self.dirpath:
                output_dir = os.getcwd()
            elif not self.dirpath or not trainer.default_root_dir:
                output_dir = self.dirpath if self.dirpath else trainer.default_root_dir
            else:
                raise ValueError("Get output path from both trainer and callback, please provide the path from either one of them")
            self.ckpt_dir = os.path.join(output_dir, f"checkpoint-{current_step}")
            if not os.path.exists(self.ckpt_dir):
                Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True)
            self.current_ckpt["dir"] = self.ckpt_dir

    def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        if isinstance(pl_module.model, PeftModel) and self.ckpt_dir:
            pl_module.model.save_pretrained(self.ckpt_dir)
```
lightning training pipeline

```python
import lightning as L
from transformers import DataCollatorForSeq2Seq, AutoTokenizer
from pl_modules import BoringModule, LORA_CONFIG, PeftCheckpoint
from datasets import load_dataset
from torch.utils.data import DataLoader
from lightning.pytorch.strategies import DeepSpeedStrategy


def main():
    model_name = "meta-llama/Meta-Llama-3-8B"
    token = None

    # load data and keep necessary columns
    data = load_dataset("json",
                        data_files={"train": "./train_data.json", "val": "./val_data.json"},
                        split=["train[:100]", "val[:100]"])
    train_data, val_data = data[0], data[1]

    # init pl module
    peft_llm = BoringModule(model_name=model_name,
                            is_deepspeed_enabled=True,
                            peft_cfg=LORA_CONFIG,
                            token=token)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side="left", max_length=8192)

    # put them in the dataloaders
    data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)
    train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=data_collator, num_workers=8)
    val_dataloader = DataLoader(val_data, batch_size=2, collate_fn=data_collator, num_workers=8)

    # init trainer and set the args
    peft_ckpt = PeftCheckpoint()
    trainer = L.Trainer(default_root_dir="./codetest",
                        accelerator="cuda",
                        callbacks=[peft_ckpt],
                        log_every_n_steps=5,
                        val_check_interval=5,
                        devices=2,
                        max_epochs=1,
                        precision="16-mixed",
                        num_sanity_val_steps=0,
                        enable_checkpointing=True,
                        strategy=DeepSpeedStrategy(config="./ds_config.json"))

    trainer.fit(model=peft_llm, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)


if __name__ == "__main__":
    main()
```
huggingface training pipeline

```python
import torch
from transformers import AutoModelForCausalLM, DataCollatorForSeq2Seq, Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser
from peft import get_peft_model
from pl_modules import LORA_CONFIG
from datasets import load_dataset

MODEL = "meta-llama/Meta-Llama-3-8B"
TOKEN = None


def main():
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]

    # load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(MODEL,
                                                 token=TOKEN,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 device_map={"": torch.cuda.current_device()})
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
        model.config.use_cache = False
    peft_model = get_peft_model(model, LORA_CONFIG)
    tokenizer = AutoTokenizer.from_pretrained(MODEL, token=TOKEN, max_length=8192, padding_side="left")

    # load data
    data = load_dataset("json",
                        data_files={"train": "./train_data.json", "val": "./val_data.json"},
                        split=["train[:100]", "val[:100]"])
    train_data, val_data = data[0], data[1]
    data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

    # init trainer
    trainer = Trainer(model=peft_model,
                      args=training_args,
                      train_dataset=train_data,
                      eval_dataset=val_data,
                      data_collator=data_collator,
                      compute_metrics=None)
    trainer.train()


if __name__ == "__main__":
    main()
```
command

- lightning

```bash
python pipeline.py > codetest.log 2>&1
```

- huggingface

```bash
deepspeed --num_gpus=2 hf-pipeline.py --output_dir ./hf_codetest --num_train_epochs 1 --per_device_train_batch_size 2 --per_device_eval_batch_size 2 --label_names labels --learning_rate 2e-4 --optim adamw_torch --lr_scheduler_type constant_with_warmup --fp16 True --evaluation_strategy steps --logging_steps 10 --save_steps 10 --eval_steps 10 --gradient_checkpointing True --gradient_accumulation_steps 1 --report_to none --deepspeed ./ds_config_hf.json > hf_codetest.log 2>&1
```

* If the code has trouble saving checkpoints, modify trainer.py L2401 to `logs["grad_norm"] = grad_norm.item()`, as described in this [issue](https://github.com/huggingface/transformers/issues/29207)
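For context, here is a minimal sketch of the kind of ZeRO stage 2 + fp16 settings I'm using. The exact `ds_config.json` is inside the attached zip; the values below are illustrative assumptions, not the real file, and show that the same config can also be passed to `DeepSpeedStrategy` as a dict:

```python
from lightning.pytorch.strategies import DeepSpeedStrategy

# Illustrative ZeRO-2 fp16 settings (assumed values, not the attached ds_config.json)
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

# Equivalent to DeepSpeedStrategy(config="./ds_config.json") in the pipeline above
strategy = DeepSpeedStrategy(config=ds_config)
```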

I've seen some issues discussing problems with using Hugging Face models in the Lightning framework, and I also tried some of the suggestions. However, none of them worked :(

#17878 -> conflict about device settings

#17043 -> loading the model properly in the configure_model hook should be fine

and some issues about using ZeRO 3 with HF pretrained models. I'm not listing all of them here since I'm trying to use ZeRO 2, which should be less complicated.

The weird part I observed during Lightning training is shown below: the run starts training with 4 processes even though I only have two GPUs (see screenshot). When I use the Hugging Face Trainer, it starts training with only 2 processes, which makes sense, and the GPU utilization is balanced (see screenshot).
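To confirm which process lands on which GPU, a minimal debug sketch could be added to the module. The `setup` override below is a hypothetical addition for diagnosis, not part of the attached code:

```python
import os
import torch
import lightning as L


class BoringModule(L.LightningModule):
    ...  # rest of the module as posted above

    def setup(self, stage: str) -> None:
        # Hypothetical debug hook: report PID, rank and CUDA device per process,
        # so the number of processes per GPU can be counted from the log.
        print(f"pid={os.getpid()} "
              f"local_rank={os.environ.get('LOCAL_RANK', 'n/a')} "
              f"global_rank={self.trainer.global_rank} "
              f"cuda_device={torch.cuda.current_device()}")
```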

What version are you seeing the problem on?

v2.4

How to reproduce the bug

1. download the scripts and install the requirements
2. use the commands above to start training

Error messages and logs

lightning log ```text /home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. [2024-09-24 18:24:45,506] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/2 /home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. [2024-09-24 18:24:52,443] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) initializing deepspeed distributed: GLOBAL_RANK: 1, MEMBER: 2/2 Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`. /home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. /home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. 
current process device: 0 current process: 136236 current process: 0 current process: 0 Loading checkpoint shards: 0%| | 0/4 [00:00 main() File "/home/ubuntu/lightning-llm/pipline.py", line 83, in main trainer.fit(model=peft_llm, File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit call._call_and_handle_interrupt( File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch return function(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl self._run(model, ckpt_path=ckpt_path) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run results = self._run_stage() ^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage self.fit_loop.run() File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run self.advance() File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance self.epoch_loop.run(self._data_fetcher) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run self.advance(data_fetcher) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run self._optimizer_step(batch_idx, closure) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step call._call_lightning_module_hook( File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook output = fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step optimizer.step(closure=optimizer_closure) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 238, 
in optimizer_step return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 129, in optimizer_step closure_result = closure() ^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__ self._result = self.closure(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 138, in closure self._backward_fn(step_output.closure_loss) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 239, in backward_fn call._call_strategy_hook(self.trainer, "backward", loss, optimizer) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook output = fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 212, in backward self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 117, in backward deepspeed_engine.backward(tensor, *args, **kwargs) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn ret_val = func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2020, in backward self.optimizer.backward(loss, retain_graph=retain_graph) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2064, in backward self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward scaled_loss.backward(retain_graph=retain_graph) File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/torch/_tensor.py", line 522, in backward torch.autograd.backward( File "/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/torch/autograd/__init__.py", line 266, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 948.00 MiB. GPU 0 has a total capacity of 31.74 GiB of which 474.12 MiB is free. Including non-PyTorch memory, this process has 30.97 GiB memory in use. Process 136372 has 306.00 MiB memory in use. Of the allocated memory 29.07 GiB is allocated by PyTorch, and 1.27 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) ```
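The traceback above suggests trying the expandable-segments allocator. A quick way to apply it to the Lightning command would be the line below; note this only mitigates fragmentation and would not explain a roughly doubled memory footprint:

```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python pipeline.py > codetest.log 2>&1
```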
huggingface trainer log ```text [2024-09-24 18:45:17,208] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) [2024-09-24 18:45:20,256] [WARNING] [runner.py:212:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only. [2024-09-24 18:45:20,256] [INFO] [runner.py:585:main] cmd = /home/ubuntu/lightning-llm/.venv/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None hf-pipeline.py --output_dir ./hf_codetest --num_train_epochs 1 --per_device_train_batch_size 2 --per_device_eval_batch_size 2 --label_names labels --learning_rate 2e-4 --optim adamw_torch --lr_scheduler_type constant_with_warmup --fp16 True --evaluation_strategy steps --logging_steps 10 --save_steps 10 --eval_steps 10 --gradient_checkpointing True --gradient_accumulation_steps 1 --report_to none --deepspeed /home/ubuntu/lightning-llm/ds_config_hf.json [2024-09-24 18:45:21,510] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) [2024-09-24 18:45:24,502] [INFO] [launch.py:146:main] WORLD INFO DICT: {'localhost': [0, 1]} [2024-09-24 18:45:24,502] [INFO] [launch.py:152:main] nnodes=1, num_local_procs=2, node_rank=0 [2024-09-24 18:45:24,502] [INFO] [launch.py:163:main] global_rank_mapping=defaultdict(, {'localhost': [0, 1]}) [2024-09-24 18:45:24,502] [INFO] [launch.py:164:main] dist_world_size=2 [2024-09-24 18:45:24,502] [INFO] [launch.py:168:main] Setting CUDA_VISIBLE_DEVICES=0,1 [2024-09-24 18:45:24,503] [INFO] [launch.py:256:main] process 144709 spawned with command: ['/home/ubuntu/lightning-llm/.venv/bin/python', '-u', 'hf-pipeline.py', '--local_rank=0', '--output_dir', './hf_codetest', '--num_train_epochs', '1', '--per_device_train_batch_size', '2', '--per_device_eval_batch_size', '2', '--label_names', 'labels', '--learning_rate', '2e-4', '--optim', 'adamw_torch', '--lr_scheduler_type', 'constant_with_warmup', '--fp16', 'True', '--evaluation_strategy', 'steps', '--logging_steps', '10', '--save_steps', '10', '--eval_steps', '10', '--gradient_checkpointing', 'True', '--gradient_accumulation_steps', '1', '--report_to', 'none', '--deepspeed', '/home/ubuntu/lightning-llm/ds_config_hf.json'] [2024-09-24 18:45:24,503] [INFO] [launch.py:256:main] process 144710 spawned with command: ['/home/ubuntu/lightning-llm/.venv/bin/python', '-u', 'hf-pipeline.py', '--local_rank=1', '--output_dir', './hf_codetest', '--num_train_epochs', '1', '--per_device_train_batch_size', '2', '--per_device_eval_batch_size', '2', '--label_names', 'labels', '--learning_rate', '2e-4', '--optim', 'adamw_torch', '--lr_scheduler_type', 'constant_with_warmup', '--fp16', 'True', '--evaluation_strategy', 'steps', '--logging_steps', '10', '--save_steps', '10', '--eval_steps', '10', '--gradient_checkpointing', 'True', '--gradient_accumulation_steps', '1', '--report_to', 'none', '--deepspeed', '/home/ubuntu/lightning-llm/ds_config_hf.json'] [2024-09-24 18:45:28,915] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) [2024-09-24 18:45:29,317] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) [2024-09-24 18:45:29,724] [INFO] [comm.py:652:init_distributed] cdb=None [2024-09-24 18:45:30,142] [INFO] [comm.py:652:init_distributed] cdb=None [2024-09-24 18:45:30,142] [INFO] [comm.py:683:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl Number of GPUs: 2 
/home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. Number of GPUs: 2 /home/ubuntu/lightning-llm/.venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. Loading checkpoint shards: 0%| | 0/4 [00:00

Environment

Current environment

```
#- PyTorch Lightning Version: 2.4.0
#- PyTorch Version: 2.2.1
#- Python version: 3.12.3
#- OS (e.g., Linux): Ubuntu 24.04
#- CUDA/cuDNN version: 12.0
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source): pip
```

More info

Hardware information: NVIDIA Corporation GV100GL [Tesla V100 SXM2 32GB] x2

yznlp commented 6 days ago

Same issue here on a single Titan V GPU (12GB). With the Hugging Face Trainer I can comfortably fit a batch of 4, but with Lightning I get OOM even with a single sample. Really not sure what the difference is...

mickeysun0104 commented 5 days ago

@yznlp Did you also use DeepSpeed with PEFT and Hugging Face functions and try to train in the Lightning training framework?

yznlp commented 5 days ago

@mickeysun0104 Sorry, I should have specified. I'm using the 7B LLaVA model with PEFT and Hugging Face functions in the Lightning training framework with precision set to 16-mixed. I haven't tried DeepSpeed but will have a look, thanks :)

ChiShiang commented 4 days ago

I had the same issue with PyTorch Lightning 2.4.0. After several trials, the DeepSpeed strategy worked when I downgraded PyTorch Lightning to 2.1.0. It also worked in 2.0.5 but failed in 2.2.0.
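For anyone who wants to try the same workaround, the downgrade should be as simple as the command below (assuming the `lightning` pip package is what's installed; use `pytorch-lightning` instead if that's the package in your environment):

```bash
pip install "lightning==2.1.0"
```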

mickeysun0104 commented 4 days ago

@yznlp I'm not sure it's a problem with the DeepSpeed integration, since I also observed the same behavior (4 processes with 2 GPUs) with the DDP strategy. Thanks to @ChiShiang for testing and finding a temporary workaround for this issue. I won't close the issue now because I believe there's still a root cause that needs to be figured out. I'm fairly new to Lightning, so I don't think I can find the key difference between Lightning 2.1.0 and Lightning >= 2.2.0. Kindly tagging @awaelchli to dig further into the issue (I'm not sure who the main maintainer is now).