deepseek-ai / DeepSeek-Coder

DeepSeek Coder: Let the Code Write Itself
https://coder.deepseek.com/
MIT License
6.83k stars 472 forks

Running finetune_deepseekcoder.py results in return code = -9 and running script directly results in RuntimeError: 'weight' must be 2-D #54

Closed · hobpond closed this issue 11 months ago

hobpond commented 11 months ago

Thank you for the handy fine-tuning guide, but I am not able to get started.

I tried using the default settings as a proof of concept, but it errors out.

This is the output I get when using the sample deepspeed command from the README.md:

[2023-11-27 20:47:43,736] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-11-27 20:47:44,929] [WARNING] [runner.py:203:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2023-11-27 20:47:44,929] [INFO] [runner.py:570:main] cmd = /home/user/DeepSeek-Coder/finetune/.venv/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None finetune_deepseekcoder.py --model_name_or_path deepseek-ai/deepseek-coder-6.7b-instruct --data_path ./data/training-data.json --output_dir ./output/ --num_train_epochs 3 --model_max_length 1024 --per_device_train_batch_size 16 --per_device_eval_batch_size 1 --gradient_accumulation_steps 4 --evaluation_strategy no --save_strategy steps --save_steps 100 --save_total_limit 100 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 1 --lr_scheduler_type cosine --gradient_checkpointing True --report_to tensorboard --deepspeed configs/ds_config_zero3.json --bf16 True
[2023-11-27 20:47:47,291] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-11-27 20:47:48,425] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0]}
[2023-11-27 20:47:48,425] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
[2023-11-27 20:47:48,425] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2023-11-27 20:47:48,425] [INFO] [launch.py:163:main] dist_world_size=1
[2023-11-27 20:47:48,425] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0
[2023-11-27 20:47:52,153] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-11-27 20:47:52,476] [INFO] [comm.py:637:init_distributed] cdb=None
[2023-11-27 20:47:52,476] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
====================================================================================================
TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=True,
bf16_full_eval=False,
cache_dir=None,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=configs/ds_config_zero3.json,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
include_tokens_per_second=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=2e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=0,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=./output/runs/Nov27_20-47-51_dev-llm-finetuning,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=1.0,
logging_strategy=steps,
lr_scheduler_type=cosine,
max_grad_norm=1.0,
max_steps=-1,
metric_for_best_model=None,
model_max_length=1024,
mp_parameters=,
neftune_noise_alpha=None,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_torch,
optim_args=None,
output_dir=./output/,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=1,
per_device_train_batch_size=16,
prediction_loss_only=False,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=['tensorboard'],
resume_from_checkpoint=None,
run_name=./output/,
save_on_each_node=False,
save_safetensors=True,
save_steps=100,
save_strategy=steps,
save_total_limit=100,
seed=42,
skip_memory_metrics=True,
split_batches=False,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_mps_device=False,
warmup_ratio=0.0,
warmup_steps=10,
weight_decay=0.0,
)
PAD Token: <|end▁of▁sentence|> 32014
BOS Token <|begin▁of▁sentence|> 32013
EOS Token <|EOT|> 32021
Load tokenizer from deepseek-ai/deepseek-coder-6.7b-instruct over.
[2023-11-27 20:48:03,930] [INFO] [partition_parameters.py:348:__exit__] finished initializing model - num_params = 291, num_elems = 6.74B
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.72s/it]
Load model from deepseek-ai/deepseek-coder-6.7b-instruct over.
Training dataset samples: 99
...
Using /home/user/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/user/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 2.3529317378997803 seconds
Parameter Offload: Total persistent parameters: 266240 in 65 params
[2023-11-27 20:49:16,555] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 36733
[2023-11-27 20:49:16,557] [ERROR] [launch.py:321:sigkill_handler] ['/home/user/DeepSeek-Coder/finetune/.venv/bin/python', '-u', 'finetune_deepseekcoder.py', '--local_rank=0', '--model_name_or_path', 'deepseek-ai/deepseek-coder-6.7b-instruct', '--data_path', './data/training-data.json', '--output_dir', './output/', '--num_train_epochs', '3', '--model_max_length', '1024', '--per_device_train_batch_size', '16', '--per_device_eval_batch_size', '1', '--gradient_accumulation_steps', '4', '--evaluation_strategy', 'no', '--save_strategy', 'steps', '--save_steps', '100', '--save_total_limit', '100', '--learning_rate', '2e-5', '--warmup_steps', '10', '--logging_steps', '1', '--lr_scheduler_type', 'cosine', '--gradient_checkpointing', 'True', '--report_to', 'tensorboard', '--deepspeed', 'configs/ds_config_zero3.json', '--bf16', 'True'] exits with return code = -9

I tried running the finetune_deepseekcoder.py script directly (without the deepspeed launcher) to see what the actual error is, and it output the following:

Traceback (most recent call last):
  File "/home/user/DeepSeek-Coder/finetune/finetune_deepseekcoder.py", line 193, in <module>
    train()
  File "/home/user/DeepSeek-Coder/finetune/finetune_deepseekcoder.py", line 187, in train
    trainer.train()
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2748, in compute_loss
    outputs = model(**inputs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 879, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/user/DeepSeek-Coder/finetune/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
nstl-zyb commented 11 months ago

Using the sample deepspeed command, I hit the same situation: the process is killed and exits with return code -9. How can I fix this?

DejianYang commented 11 months ago

Please check your package versions, and make sure you have enough memory.
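
For example, something along these lines can be used to check both (run inside the virtualenv you use for fine-tuning):

    # versions of the packages this script depends on
    pip list | grep -iE "torch|transformers|deepspeed|accelerate"

    # available system RAM and swap
    free -h

    # GPU memory
    nvidia-smi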

hobpond commented 11 months ago

Thanks for the quick response!

You are right. It looks like an OOM, but not VRAM: all 170GB of system RAM was used just before the python process died. I changed pin_memory to false in configs/ds_config_zero3.json for both offload_optimizer and offload_param, and that got it to start fine-tuning. (I also changed the deepspeed argument to --per_device_train_batch_size 1 instead of 16.)

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": false
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": false
        },
        ...

I'm not sure if I should tune sub_group_size as well, but it is fine-tuning now, so I will report back if I find a better config for a single A100 80GB.
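
For reference, with those two changes the full launch command is roughly the README sample with only the batch size lowered (flags otherwise identical to the log above):

    deepspeed finetune_deepseekcoder.py \
        --model_name_or_path deepseek-ai/deepseek-coder-6.7b-instruct \
        --data_path ./data/training-data.json \
        --output_dir ./output/ \
        --num_train_epochs 3 \
        --model_max_length 1024 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --gradient_accumulation_steps 4 \
        --evaluation_strategy no \
        --save_strategy steps \
        --save_steps 100 \
        --save_total_limit 100 \
        --learning_rate 2e-5 \
        --warmup_steps 10 \
        --logging_steps 1 \
        --lr_scheduler_type cosine \
        --gradient_checkpointing True \
        --report_to tensorboard \
        --deepspeed configs/ds_config_zero3.json \
        --bf16 True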

Thanks again for the help!

hobpond commented 11 months ago

Using nvidia-smi to monitor the VRAM, I can see that my per_device_train_batch_size of 1 was way too small.
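
Something like this works for keeping an eye on GPU memory while the training run is going (either variant; the second just logs the raw numbers):

    # refresh the full nvidia-smi view every 2 seconds
    watch -n 2 nvidia-smi

    # or print only the memory usage, repeating every 2 seconds
    nvidia-smi --query-gpu=memory.used,memory.total --format=csv -l 2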

A-Janj commented 8 months ago

Hey! Can you tell me the minimum requirements (GPU VRAM, system RAM) for fine-tuning with the edits you made to ds_config_zero3.json? I am having the same -9 error. Have you since found a better config?