axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method #1750

Open RishabhMaheshwary opened 1 month ago

RishabhMaheshwary commented 1 month ago

Please check that this issue hasn't been reported before.

Expected Behavior

It should run without any errors.

Current behaviour

Throws the error:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Steps to reproduce

On the latest pull (commit 219cd0d), running the following command with the config shown further down produces the error below:

accelerate launch --use_deepspeed -m axolotl.cli.train ../examples/mistral/config.yml

Error:

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_length, max_prompt_length, generate_during_eval, dataset_num_proc. Will not be supported from version '1.0.0'.

Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:311: UserWarning: You passed `generate_during_eval` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:389: UserWarning: You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:402: UserWarning: You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:519: UserWarning: You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
Loading checkpoint shards:  67%|█████████████████████████████████████████████████████████▎                            | 2/3 [01:12<00:36, 36.33s/it]
Process ForkPoolWorker-1:
Traceback (most recent call last):
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/multiprocess/process.py", line 314, in _bootstrap
    self.run()
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/multiprocess/pool.py", line 114, in worker
    task = get()
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/multiprocess/queues.py", line 370, in get
    return _ForkingPickler.loads(res)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/dill/_dill.py", line 303, in loads
    return load(file, ignore, **kwds)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/dill/_dill.py", line 289, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/dill/_dill.py", line 444, in load
    obj = StockUnpickler.load(self)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/storage.py", line 381, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/serialization.py", line 1040, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/serialization.py", line 1272, in _legacy_load
    result = unpickler.load()
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/serialization.py", line 1205, in persistent_load
    obj = restore_location(obj, location)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/serialization.py", line 390, in default_restore_location
    result = fn(storage, location)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/serialization.py", line 267, in _cuda_deserialize
    with torch.cuda.device(device):
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/cuda/__init__.py", line 365, in __enter__
    self.prev_idx = torch.cuda._exchange_device(self.idx)
  File "/mnt/rishabh/anaconda3/envs/axolotl1/lib/python3.10/site-packages/torch/cuda/__init__.py", line 279, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
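For context on the message itself: the failing worker in the traceback is a ForkPoolWorker from the multiprocess package, i.e. a child process created with the fork start method. PyTorch raises this error when the parent had already initialized CUDA before forking, and the child then tries to use CUDA again; here that happens while the child unpickles a tensor whose storage lives on a GPU (the _cuda_deserialize frame). The error text points at the spawn start method instead. The minimal standard-library sketch below shows how such a choice is made; worker_context is a hypothetical helper for illustration, not part of axolotl, trl, or datasets, whose pools are created internally.

```python
import multiprocessing as mp


def worker_context(parent_uses_cuda: bool) -> mp.context.BaseContext:
    """Return a multiprocessing context suited to the parent's CUDA usage.

    Hypothetical helper for illustration only.
    """
    if parent_uses_cuda:
        # 'spawn' starts each worker as a fresh interpreter, so a worker can
        # initialize CUDA itself instead of inheriting the parent's CUDA state.
        return mp.get_context("spawn")
    # CPU-only parent: keep the platform's default start method.
    return mp.get_context()


if __name__ == "__main__":
    print(worker_context(True).get_start_method())  # spawn
```

The trade-off is that spawned workers start more slowly than forked ones and can only run functions importable from a module.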

Config yaml

base_model: /mnt/rishabh/models/Mistral-7B-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

rl: dpo
rl_beta: 0.1
remove_unused_columns: false

datasets:
  - ds_type: json
    data_files:
      - /mnt/rishabh/data/dpo/m2lingual.json
    split: train
    type: chatml.mistral
  - ds_type: json
    data_files:
      - /mnt/rishabh/data/dpo/m2lingual1.json
    split: train
    type: chatml.mistral
  - ds_type: json
    data_files:
      - /mnt/rishabh/data/dpo/m2lingual2.json
    split: train
    type: chatml.mistral
  - ds_type: json
    data_files:
      - /mnt/rishabh/data/dpo/m2lingual3.json
    split: train
    type: chatml.mistral
  - ds_type: json
    data_files:
      - /mnt/rishabh/data/dpo/m2lingual4.json
    split: train
    type: chatml.mistral

dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: /mnt/rishabh/checkpoints/axolotl/mistral_instrcut_dpo_m2lingual

sequence_len: 8192
sample_packing: false
pad_to_sequence_len: true
eval_sample_packing: false

use_tensorboard: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-7

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 100
evals_per_epoch: 1
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
deepspeed: ../deepspeed_configs/zero1.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:


Possible solution

No response

Which Operating Systems are you using?

- [X] Linux
- [ ] macOS
- [ ] Windows

Python Version

3.10

axolotl branch-commit

main

Acknowledgements

- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.

winglian commented 1 month ago

@RishabhMaheshwary how many GPUs, and what GPU type? Thanks

RishabhMaheshwary commented 1 month ago

@winglian 8 GPUs, A100 80GB.

RishabhMaheshwary commented 1 month ago

I can run examples/mistral/config.yml without any errors. The error only appears when I switch the dataset and the training method to DPO, as shown below:

rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel

Could this be related to trl?

winglian commented 1 month ago

@RishabhMaheshwary I've narrowed this down to an issue with DPO full finetuning. DPO LoRA doesn't exhibit the same error.

winglian commented 1 month ago

@RishabhMaheshwary a workaround for now is to append --dataset_processes=1 to your launch command, e.g.:

accelerate launch --use_deepspeed -m axolotl.cli.train ../examples/mistral/config.yml --dataset_processes=1
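Why a single dataset process avoids the crash: when a dataset map runs with one process, the mapping function executes in the parent, so no ForkPoolWorker is created and nothing has to be unpickled onto a GPU inside a forked child. The warnings in the log show axolotl passing dataset_num_proc to DPOTrainer; it is an assumption here, not confirmed in this thread, that --dataset_processes=1 is what ends up in that setting. The sketch below uses a hypothetical map_records helper built on the standard library to show the single-process path; it is not the code of axolotl, trl, or datasets.

```python
import os
from multiprocessing import get_context
from typing import Callable, Iterable, List, Tuple, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def map_records(fn: Callable[[T], R], records: Iterable[T], num_proc: int = 1,
                start_method: str = "spawn") -> List[R]:
    """Hypothetical stand-in for a dataset map with a num_proc setting.

    num_proc <= 1: run fn in the current process (no workers are started).
    num_proc > 1: fan out to a worker pool created with an explicit start method.
    """
    items = list(records)
    if num_proc <= 1:
        # Inline execution: nothing is forked or pickled, so no child process
        # ever has to re-initialize CUDA.
        return [fn(item) for item in items]
    # With "spawn", each worker is a fresh interpreter that may initialize CUDA
    # itself; fn must then be importable (defined at module level).
    with get_context(start_method).Pool(num_proc) as pool:
        return pool.map(fn, items)


def tag_with_pid(x: int) -> Tuple[int, int]:
    # Record which process handled each item.
    return x, os.getpid()


if __name__ == "__main__":
    results = map_records(tag_with_pid, range(4), num_proc=1)
    # Every item was handled by this process, i.e. no worker was created.
    print(all(pid == os.getpid() for _, pid in results))  # True
```

Raising num_proc above 1 brings worker processes back; in this sketch the start_method argument then decides whether those workers are forked or started fresh.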

RishabhMaheshwary commented 1 month ago

Thanks a lot! Will give it a try and let you know.

winglian commented 1 month ago

There should also be an upstream fix in trl for this soon.