togethercomputer / OpenChatKit


An error occurred while fine-tuning the model. #167

Closed ChengYen-Tang closed 1 year ago

ChengYen-Tang commented 1 year ago

Describe the bug
An error occurred while fine-tuning the model, following the steps in https://github.com/togethercomputer/OpenChatKit#fine-tuning-the-model.

Screenshots

data_utils: get train_data_loader
Found cached dataset json (/home/kenneth/.cache/huggingface/datasets/json/default-e27cbaeaf7c8b572/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)
Running  gpipe  without data parallel.
=======Initialize Gpipe.
=======Gpipe use FP16
=======Gradient accumulate step:  1
=======Current micro-batch send/recv size: 8 MB (fp16)
=======Number of micro-batches: 24.

(Both pipeline-stage processes print the same initialization messages; their interleaved tracebacks are untangled below.)

First pipeline stage:

Traceback (most recent call last):
  File "/mnt/sda/kenneth/OpenChatKit/training/dist_clm_train.py", line 478, in <module>
    main()
  File "/mnt/sda/kenneth/OpenChatKit/training/dist_clm_train.py", line 443, in main
    pipe = get_pp_module(args, config, device, use_dp)
  File "/mnt/sda/kenneth/OpenChatKit/training/pipeline_parallel/dist_pp_utils.py", line 7, in get_pp_module
    return GpipeAsync(args, config, device, use_dp)
  File "/mnt/sda/kenneth/OpenChatKit/training/pipeline_parallel/dist_gpipe_pipeline_async.py", line 193, in __init__
    self.model = _StageFirst(args, config, device)
  File "/mnt/sda/kenneth/OpenChatKit/training/modules/dist_gpt_pp_module.py", line 113, in __init__
    super(GPTStageFirst, self).__init__(args, config)
  File "/mnt/sda/kenneth/OpenChatKit/training/modules/dist_gpt_pp_module.py", line 33, in __init__
    from .hf_gptneox_modules import GPTEmbeddings, GPTBlock, GPTLMHead
  File "/mnt/sda/kenneth/OpenChatKit/training/modules/hf_gptneox_modules.py", line 16, in <module>
    from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding
ImportError: cannot import name 'RotaryEmbedding' from 'transformers.models.gpt_neox.modeling_gpt_neox' (/home/kenneth/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py)

Last pipeline stage:

Traceback (most recent call last):
  File "/mnt/sda/kenneth/OpenChatKit/training/dist_clm_train.py", line 478, in <module>
    main()
  File "/mnt/sda/kenneth/OpenChatKit/training/dist_clm_train.py", line 443, in main
    pipe = get_pp_module(args, config, device, use_dp)
  File "/mnt/sda/kenneth/OpenChatKit/training/pipeline_parallel/dist_pp_utils.py", line 7, in get_pp_module
    return GpipeAsync(args, config, device, use_dp)
  File "/mnt/sda/kenneth/OpenChatKit/training/pipeline_parallel/dist_gpipe_pipeline_async.py", line 195, in __init__
    self.model = _StageLast(args, config, device)
  File "/mnt/sda/kenneth/OpenChatKit/training/modules/dist_gpt_pp_module.py", line 147, in __init__
    super(GPTStageLast, self).__init__(args, config)
  File "/mnt/sda/kenneth/OpenChatKit/training/modules/dist_gpt_pp_module.py", line 33, in __init__
    from .hf_gptneox_modules import GPTEmbeddings, GPTBlock, GPTLMHead
  File "/mnt/sda/kenneth/OpenChatKit/training/modules/hf_gptneox_modules.py", line 16, in <module>
    from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding
ImportError: cannot import name 'RotaryEmbedding' from 'transformers.models.gpt_neox.modeling_gpt_neox' (/home/kenneth/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py)
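
A minimal check (not from the original report) of which rotary-embedding names the installed transformers build actually exports at the module path shown in the traceback:

    # Hypothetical diagnostic: print the installed transformers version and any
    # rotary-embedding class names that modeling_gpt_neox exposes in that version.
    import transformers
    from transformers.models.gpt_neox import modeling_gpt_neox

    print(transformers.__version__)
    print([name for name in dir(modeling_gpt_neox) if "Rotary" in name])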


darrinh commented 1 year ago

getting the same issue.

csris commented 1 year ago

@azahed98 Can you please take a look?

azahed98 commented 1 year ago

This was caused by the renaming of RotaryEmbedding in transformers 4.31.0. The issue is fixed by PR #168.
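
For context, one possible shape for such a fix is a guarded import in training/modules/hf_gptneox_modules.py. This is only a sketch: the post-rename class name GPTNeoXRotaryEmbedding is an assumption about what the 4.31.0 change introduced, and it is not necessarily what #168 merges.

    # Sketch of a version-tolerant import; GPTNeoXRotaryEmbedding as the
    # post-rename name is an assumption, not confirmed from PR #168.
    try:
        # transformers < 4.31.0
        from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding
    except ImportError:
        # transformers >= 4.31.0, where the class was renamed
        from transformers.models.gpt_neox.modeling_gpt_neox import (
            GPTNeoXRotaryEmbedding as RotaryEmbedding,
        )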