axolotl-ai-cloud / axolotl

Error Finetuning CodeQwen 1.5 7B: `Column 1 named token_type_ids expected length 113 but got length 112` #1632

Open artemdinaburg opened 3 months ago

artemdinaburg commented 3 months ago

### Please check that this issue hasn't been reported before.

### Expected Behavior

I am attempting to fine-tune CodeQwen 1.5 7B on a custom dataset and expect fine-tuning to start. The exact same dataset works for CodeLlama and DeepSeek Coder.

### Current behaviour

```
(setup pid=13913) 888c8f8c4e81: Pull complete
(setup pid=13913) Digest: sha256:b7c4ce815c7d2e77aae54390747c06ec99ac0409a8d989e9b3abf2aa9023c9d1
(setup pid=13913) Status: Downloaded newer image for winglian/axolotl:main-latest
(setup pid=13913) docker.io/winglian/axolotl:main-latest
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) ==========
(train-codeqwen, pid=13913) == CUDA ==
(train-codeqwen, pid=13913) ==========
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) CUDA Version 11.8.0
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) This container image and its contents are governed by the NVIDIA Deep Learning Container License.
(train-codeqwen, pid=13913) By pulling and using the container, you accept the terms and conditions of this license:
(train-codeqwen, pid=13913) https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
(train-codeqwen, pid=13913) Token is valid (permission: read).
(train-codeqwen, pid=13913) Your token has been saved to /root/.cache/huggingface/token
(train-codeqwen, pid=13913) Login successful
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) ==========
(train-codeqwen, pid=13913) == CUDA ==
(train-codeqwen, pid=13913) ==========
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) CUDA Version 11.8.0
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) This container image and its contents are governed by the NVIDIA Deep Learning Container License.
(train-codeqwen, pid=13913) By pulling and using the container, you accept the terms and conditions of this license:
(train-codeqwen, pid=13913) https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) The following values were not passed to `accelerate launch` and had defaults used instead:
(train-codeqwen, pid=13913)     `--num_processes` was set to a value of `1`
(train-codeqwen, pid=13913)     `--num_machines` was set to a value of `1`
(train-codeqwen, pid=13913)     `--mixed_precision` was set to a value of `'no'`
(train-codeqwen, pid=13913)     `--dynamo_backend` was set to a value of `'no'`
(train-codeqwen, pid=13913) To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
(train-codeqwen, pid=13913) WARNING: BNB_CUDA_VERSION=118 environment variable detected; loading libbitsandbytes_cuda118.so.
(train-codeqwen, pid=13913) This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
(train-codeqwen, pid=13913) If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
(train-codeqwen, pid=13913) If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
(train-codeqwen, pid=13913) For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) [2024-05-16 20:55:49,103] [INFO] [datasets.<module>:58] [PID:38] PyTorch version 2.1.2+cu118 available.
(train-codeqwen, pid=13913) [2024-05-16 20:55:50,255] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(train-codeqwen, pid=13913) df: /root/.triton/autotune: No such file or directory
(train-codeqwen, pid=13913) [2024-05-16 20:55:50,338] [INFO] [root.spawn:38] [PID:38] gcc -pthread -B /root/miniconda3/envs/py3.10/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /root/miniconda3/envs/py3.10/include -fPIC -O2 -isystem /root/miniconda3/envs/py3.10/include -fPIC -c /tmp/tmpirov19_4/test.c -o /tmp/tmpirov19_4/test.o
(train-codeqwen, pid=13913) [2024-05-16 20:55:50,358] [INFO] [root.spawn:38] [PID:38] gcc -pthread -B /root/miniconda3/envs/py3.10/compiler_compat /tmp/tmpirov19_4/test.o -laio -o /tmp/tmpirov19_4/a.out
(train-codeqwen, pid=13913)  [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
(train-codeqwen, pid=13913)  [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
(train-codeqwen, pid=13913)  [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
(train-codeqwen, pid=13913) [2024-05-16 20:55:52,063] [WARNING] [axolotl.utils.config.models.input.hint_trust_remote_code:276] [PID:38] [RANK:0] `trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.
(train-codeqwen, pid=13913) [2024-05-16 20:55:52,063] [DEBUG] [axolotl.normalize_config:79] [PID:38] [RANK:0] bf16 support detected, enabling for this configuration.
(train-codeqwen, pid=13913) /root/miniconda3/envs/py3.10/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
(train-codeqwen, pid=13913)   warnings.warn(
(train-codeqwen, pid=13913) [2024-05-16 20:55:52,240] [INFO] [axolotl.normalize_config:182] [PID:38] [RANK:0] GPU memory usage baseline: 0.000GB (+0.614GB misc)
(train-codeqwen, pid=13913)                                  dP            dP   dP
(train-codeqwen, pid=13913)                                  88            88   88
(train-codeqwen, pid=13913)       .d8888b. dP.  .dP .d8888b. 88 .d8888b. d8888P 88
(train-codeqwen, pid=13913)       88'  `88  `8bd8'  88'  `88 88 88'  `88   88   88
(train-codeqwen, pid=13913)       88.  .88  .d88b.  88.  .88 88 88.  .88   88   88
(train-codeqwen, pid=13913)       `88888P8 dP'  `dP `88888P' dP `88888P'   dP   dP
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913) ****************************************
(train-codeqwen, pid=13913) **** Axolotl Dependency Versions *****
(train-codeqwen, pid=13913)   accelerate: 0.30.1
(train-codeqwen, pid=13913)         peft: 0.10.0
(train-codeqwen, pid=13913) transformers: 4.40.2
(train-codeqwen, pid=13913)          trl: 0.8.5
(train-codeqwen, pid=13913)        torch: 2.1.2+cu118
(train-codeqwen, pid=13913) bitsandbytes: 0.43.1
(train-codeqwen, pid=13913) ****************************************
(train-codeqwen, pid=13913) /root/miniconda3/envs/py3.10/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
(train-codeqwen, pid=13913)   warnings.warn(
(train-codeqwen, pid=13913) [2024-05-16 20:55:53,957] [DEBUG] [axolotl.load_tokenizer:280] [PID:38] [RANK:0] EOS: 2 / <|endoftext|>
(train-codeqwen, pid=13913) [2024-05-16 20:55:53,958] [DEBUG] [axolotl.load_tokenizer:281] [PID:38] [RANK:0] BOS: 2 / <|endoftext|>
(train-codeqwen, pid=13913) [2024-05-16 20:55:53,958] [DEBUG] [axolotl.load_tokenizer:282] [PID:38] [RANK:0] PAD: 2 / <|endoftext|>
(train-codeqwen, pid=13913) [2024-05-16 20:55:53,958] [DEBUG] [axolotl.load_tokenizer:283] [PID:38] [RANK:0] UNK: 0 / <unk>
(train-codeqwen, pid=13913) [2024-05-16 20:55:53,958] [INFO] [axolotl.load_tokenizer:294] [PID:38] [RANK:0] No Chat template selected. Consider adding a chat template for easier inference.
(train-codeqwen, pid=13913) [2024-05-16 20:55:54,035] [INFO] [axolotl.load_tokenized_prepared_datasets:183] [PID:38] [RANK:0] Unable to find prepared dataset in /sky-notebook/codeqwen-7b-lora/last_run_prepared/aecff90c4b2de7ce4995cf990640ff10
(train-codeqwen, pid=13913) [2024-05-16 20:55:54,035] [INFO] [axolotl.load_tokenized_prepared_datasets:184] [PID:38] [RANK:0] Loading raw datasets...
(train-codeqwen, pid=13913) [2024-05-16 20:55:54,036] [WARNING] [axolotl.load_tokenized_prepared_datasets:186] [PID:38] [RANK:0] Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.
(train-codeqwen, pid=13913) [2024-05-16 20:55:54,036] [INFO] [axolotl.load_tokenized_prepared_datasets:193] [PID:38] [RANK:0] No seed provided, using default seed of 42
Generating train split: 274036 examples [00:05, 53118.96 examples/s][00:04, 53696.76 examples/s]
(train-codeqwen, pid=13913) Tokenizing Prompts (num_proc=12):   2%|▏         | 5600/274036 [00:02<01:15, 3565.15 examples/s]
(train-codeqwen, pid=13913) Tokenizing Prompts (nuShared connection to 34.42.205.109 closed.[00:06<01:12, 3512.56 examples/s]
Tokenizing Prompts (num_proc=12):  21%|██        | 58200/274036 [00:36<02:15, 1589.99 examples/s]6<05:37, 639.33 examples/s]
(train-codeqwen, pid=13913) multiprocess.pool.RemoteTraceback:
(train-codeqwen, pid=13913) """
(train-codeqwen, pid=13913) Traceback (most recent call last):
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/multiprocess/pool.py", line 125, in worker
(train-codeqwen, pid=13913)     result = (True, func(*args, **kwds))
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 678, in _write_generator_to_queue
(train-codeqwen, pid=13913)     for i, result in enumerate(func(**kwargs)):
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3570, in _map_single
(train-codeqwen, pid=13913)     writer.write_batch(batch)
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/datasets/arrow_writer.py", line 571, in write_batch
(train-codeqwen, pid=13913)     pa_table = pa.Table.from_arrays(arrays, schema=schema)
(train-codeqwen, pid=13913)   File "pyarrow/table.pxi", line 4642, in pyarrow.lib.Table.from_arrays
(train-codeqwen, pid=13913)   File "pyarrow/table.pxi", line 3922, in pyarrow.lib.Table.validate
(train-codeqwen, pid=13913)   File "pyarrow/error.pxi", line 91, in pyarrow.lib.check_status
(train-codeqwen, pid=13913) pyarrow.lib.ArrowInvalid: Column 1 named token_type_ids expected length 113 but got length 112
(train-codeqwen, pid=13913) """
(train-codeqwen, pid=13913)
(train-codeqwen, pid=13913)     raise self._value
(train-codeqwen, pid=13913) pyarrow.lib.ArrowInvalid: Column 1 named token_type_ids expected length 113 but got length 112
(train-codeqwen, pid=13913) Traceback (most recent call last):
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/bin/accelerate", line 8, in <module>
(train-codeqwen, pid=13913)     sys.exit(main())
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
(train-codeqwen, pid=13913)     args.func(args)
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1082, in launch_command
(train-codeqwen, pid=13913)     simple_launcher(args)
(train-codeqwen, pid=13913)   File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 688, in simple_launcher
(train-codeqwen, pid=13913)     raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
(train-codeqwen, pid=13913) subprocess.CalledProcessError: Command '['/root/miniconda3/envs/py3.10/bin/python', '-m', 'axolotl.cli.train', '/sky_workdir/codeqwen-7b-lora.yml']' returned non-zero exit status 1.
```

### Steps to reproduce

I run axolotl via SkyPilot using the latest Docker container. I can reliably reproduce the current behavior with a dataset that works for training CodeLlama and DeepSeek Coder.

### Config yaml

```yaml
base_model: Qwen/CodeQwen1.5-7B

trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: /sky_workdir/solidity_fim_80_txt.jsonl
    type: completion
dataset_prepared_path: /sky-notebook/codeqwen-7b-lora/last_run_prepared
output_dir: /sky-notebook/codeqwen-7b-lora/
val_set_size: 0.05

sequence_len: 1024
sample_packing: false

adapter: lora
lora_model_dir:
lora_r: 256
lora_alpha: 512
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project: solidity
wandb_name: codeqwen-7b-solidity-lora
wandb_watch:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0003

train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
eval_step: 0.10
save_strategy: steps
save_steps: 100
save_total_limit: 5
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|endoftext|>"
```

### Possible solution

My gut tells me this has something to do with Qwen's weird tokenizer.

### Which Operating Systems are you using?

### Python Version

3.10

### axolotl branch-commit

Docker image `winglian/axolotl:main-latest`, digest `sha256:b7c4ce815c7d2e77aae54390747c06ec99ac0409a8d989e9b3abf2aa9023c9d1`

### Acknowledgements

winglian commented 3 months ago

@artemdinaburg is there a dataset you can share to reproduce the issue? I tried codellama with the alpaca dataset and it tokenized without any issues, so it's likely something specific to the dataset you are using.

artemdinaburg commented 3 months ago

Hi Wing,

Thanks for taking a look! I did some binary searching and I have a three-line JSONL that reproduces a very similar error (the lengths are different).
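The binary search mentioned above can be sketched like this. It is a minimal, hypothetical sketch, not the exact procedure used here: `fails` stands for any check you supply (for example, writing the lines to a temporary JSONL file and running `python -m axolotl.cli.preprocess` on it), and the search assumes that once a prefix of the file fails, every longer prefix fails as well.

```python
from typing import Callable, Sequence


def shortest_failing_prefix(
    lines: Sequence[str], fails: Callable[[Sequence[str]], bool]
) -> int:
    """Return the smallest n such that fails(lines[:n]) is True.

    Hypothetical helper. Assumes failure is monotonic in prefix length
    and that an empty prefix does not fail.
    """
    if not fails(lines):
        raise ValueError("the full input does not reproduce the failure")
    lo, hi = 1, len(lines)  # invariant: fails(lines[:hi]) is True
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(lines[:mid]):
            hi = mid  # a shorter prefix still fails
        else:
            lo = mid + 1  # every prefix up to mid passes
    return hi


# Stand-in predicate for illustration: "fails" once a line containing BAD is included.
lines = ['{"text": "A"}'] * 10 + ['{"text": "BAD"}'] + ['{"text": "B"}'] * 5
n = shortest_failing_prefix(lines, lambda ls: any("BAD" in line for line in ls))
print(n, lines[n - 1])  # 11 {"text": "BAD"}
```

The last line of the returned prefix is the one that triggers the failure; the earlier lines can then be trimmed or replaced separately.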

Invocation:

```
python -m axolotl.cli.preprocess debug.yml --debug
[2024-05-20 11:37:56,299] [INFO] [datasets.<module>:58] [PID:1195709] PyTorch version 2.1.2 available.
                                 dP            dP   dP
                                 88            88   88
      .d8888b. dP.  .dP .d8888b. 88 .d8888b. d8888P 88
      88'  `88  `8bd8'  88'  `88 88 88'  `88   88   88
      88.  .88  .d88b.  88.  .88 88 88.  .88   88   88
      `88888P8 dP'  `dP `88888P' dP `88888P'   dP   dP

****************************************
**** Axolotl Dependency Versions *****
  accelerate: 0.30.1
        peft: 0.10.0
transformers: 4.40.2
         trl: 0.8.5
       torch: 2.1.2
bitsandbytes: 0.43.1
****************************************
[2024-05-20 11:37:57,317] [WARNING] [axolotl.utils.config.models.input.hint_trust_remote_code:276] [PID:1195709] [RANK:0] `trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.
[2024-05-20 11:37:57,317] [DEBUG] [axolotl.normalize_config:79] [PID:1195709] [RANK:0] bf16 support detected, enabling for this configuration.
/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
[2024-05-20 11:37:57,523] [INFO] [axolotl.normalize_config:182] [PID:1195709] [RANK:0] GPU memory usage baseline: 0.000GB (+0.216GB misc)
[2024-05-20 11:37:58,318] [DEBUG] [axolotl.load_tokenizer:280] [PID:1195709] [RANK:0] EOS: 2 / <|endoftext|>
[2024-05-20 11:37:58,318] [DEBUG] [axolotl.load_tokenizer:281] [PID:1195709] [RANK:0] BOS: 2 / <|endoftext|>
[2024-05-20 11:37:58,318] [DEBUG] [axolotl.load_tokenizer:282] [PID:1195709] [RANK:0] PAD: 2 / <|endoftext|>
[2024-05-20 11:37:58,318] [DEBUG] [axolotl.load_tokenizer:283] [PID:1195709] [RANK:0] UNK: 0 / <unk>
[2024-05-20 11:37:58,318] [INFO] [axolotl.load_tokenizer:294] [PID:1195709] [RANK:0] No Chat template selected. Consider adding a chat template for easier inference.
[2024-05-20 11:37:58,319] [INFO] [axolotl.load_tokenized_prepared_datasets:183] [PID:1195709] [RANK:0] Unable to find prepared dataset in last_run_prepared/884cc22e60a6d1b0682e05f4340d9a4d
[2024-05-20 11:37:58,319] [INFO] [axolotl.load_tokenized_prepared_datasets:184] [PID:1195709] [RANK:0] Loading raw datasets...
[2024-05-20 11:37:58,319] [INFO] [axolotl.load_tokenized_prepared_datasets:193] [PID:1195709] [RANK:0] No seed provided, using default seed of 42
Generating train split: 3 examples [00:00, 1508.56 examples/s]
num_proc must be <= 3. Reducing num_proc to 3 for dataset of size 3.
[2024-05-20 11:37:58,672] [WARNING] [datasets.arrow_dataset.map:3087] [PID:1195709] num_proc must be <= 3. Reducing num_proc to 3 for dataset of size 3.
Tokenizing Prompts (num_proc=3):  67%|████████████████████████████████████████                    | 2/3 [00:00<00:00,  7.60 examples/s]
multiprocess.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 678, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3570, in _map_single
    writer.write_batch(batch)
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/arrow_writer.py", line 571, in write_batch
    pa_table = pa.Table.from_arrays(arrays, schema=schema)
  File "pyarrow/table.pxi", line 4642, in pyarrow.lib.Table.from_arrays
  File "pyarrow/table.pxi", line 3922, in pyarrow.lib.Table.validate
  File "pyarrow/error.pxi", line 91, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Column 1 named token_type_ids expected length 2 but got length 1
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/artem/git/axolotl/src/axolotl/cli/preprocess.py", line 82, in <module>
    fire.Fire(do_cli)
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/artem/git/axolotl/src/axolotl/cli/preprocess.py", line 72, in do_cli
    load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
  File "/home/artem/git/axolotl/src/axolotl/cli/__init__.py", line 403, in load_datasets
    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
  File "/home/artem/git/axolotl/src/axolotl/utils/data/sft.py", line 66, in prepare_dataset
    train_dataset, eval_dataset, prompters = load_prepare_datasets(
  File "/home/artem/git/axolotl/src/axolotl/utils/data/sft.py", line 460, in load_prepare_datasets
    dataset, prompters = load_tokenized_prepared_datasets(
  File "/home/artem/git/axolotl/src/axolotl/utils/data/sft.py", line 399, in load_tokenized_prepared_datasets
    dataset_wrapper, dataset_prompter = get_dataset_wrapper(
  File "/home/artem/git/axolotl/src/axolotl/utils/data/sft.py", line 553, in get_dataset_wrapper
    dataset_wrapper = TokenizedPromptDataset(
  File "/home/artem/git/axolotl/src/axolotl/datasets.py", line 43, in __init__
    self.process(dataset).data,
  File "/home/artem/git/axolotl/src/axolotl/datasets.py", line 55, in process
    return dataset.map(
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3248, in map
    for rank, done, content in iflatmap_unordered(
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 718, in iflatmap_unordered
    [async_result.get(timeout=0.05) for async_result in async_results]
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 718, in <listcomp>
    [async_result.get(timeout=0.05) for async_result in async_results]
  File "/home/artem/anaconda3/envs/sol-eval/lib/python3.10/site-packages/multiprocess/pool.py", line 774, in get
    raise self._value
pyarrow.lib.ArrowInvalid: Column 1 named token_type_ids expected length 2 but got length 1
```

My config (debug.yml):

```yaml
base_model: Qwen/CodeQwen1.5-7B

trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: ./minimum_error_repro.jsonl
    type: completion
dataset_prepared_path: ./last_run_prepared
output_dir: ./codeqwen-7b-lora/
val_set_size: 0.05

sequence_len: 1024
sample_packing: false

adapter: lora
lora_model_dir:
lora_r: 256
lora_alpha: 512
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project: solidity
wandb_name: codeqwen-7b-solidity-lora
wandb_watch:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0003

train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
eval_step: 0.10
save_strategy: steps
save_steps: 100
save_total_limit: 5
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|endoftext|>"
```

minimum_error_repro.jsonl.gz

artemdinaburg commented 3 months ago

As an update: the first two lines seem irrelevant, and each can be replaced with `{"text": "A"}`, but the third line is extremely sensitive to any kind of modification; adding or removing tokens makes the error disappear.
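One possible explanation for this sensitivity, offered as a hypothesis rather than a confirmed diagnosis: suppose the `completion` strategy appends a token to `input_ids` and `attention_mask` (but not to `token_type_ids`) and then splits each tokenized field into `sequence_len`-sized rows. The fields then disagree on their row count only when the original token count is an exact multiple of `sequence_len`, so adding or removing a single token hides the bug. The simulation below uses plain lists; `tokenize_and_chunk` and `check_columns` are hypothetical stand-ins, not axolotl's or pyarrow's code, and `check_columns` only imitates the wording of the Arrow error.

```python
from typing import Dict, List

# Assumption: the appended token is EOS (id 2 in the log above); illustrative only.
APPENDED_TOKEN_ID = 2


def tokenize_and_chunk(token_ids: List[int], sequence_len: int) -> Dict[str, List[List[int]]]:
    """Hypothetical simulation of the suspected behaviour for one example."""
    # Assumption: the tokenizer returns token_type_ids alongside the other fields.
    result = {
        "input_ids": list(token_ids),
        "token_type_ids": [0] * len(token_ids),
        "attention_mask": [1] * len(token_ids),
    }
    # The extra token goes into two of the three fields only (the suspected bug).
    # Simplification: this sketch always appends.
    result["input_ids"].append(APPENDED_TOKEN_ID)
    result["attention_mask"].append(1)
    # Assumption: every field is then split into rows of at most sequence_len tokens.
    return {
        key: [values[i : i + sequence_len] for i in range(0, len(values), sequence_len)]
        for key, values in result.items()
    }


def check_columns(columns: Dict[str, List[List[int]]]) -> None:
    """Imitate the Arrow rule that every column in a table has the same row count."""
    names = list(columns)
    expected = len(columns[names[0]])
    for index, name in enumerate(names):
        got = len(columns[name])
        if got != expected:
            raise ValueError(
                f"Column {index} named {name} expected length {expected} but got length {got}"
            )


for n_tokens in (1023, 1024, 1025):
    try:
        check_columns(tokenize_and_chunk(list(range(n_tokens)), sequence_len=1024))
        print(n_tokens, "ok")
    except ValueError as err:
        print(n_tokens, err)
# 1023 ok
# 1024 Column 1 named token_type_ids expected length 2 but got length 1
# 1025 ok
```

With `sequence_len: 1024` from the config, only the 1,024-token input produces a mismatch in this simulation, which resembles the "expected length 2 but got length 1" error from the three-line reproduction. Whether the third line really tokenizes to an exact multiple of 1,024 tokens would still need to be checked.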

artemdinaburg commented 3 months ago

@winglian I tracked it down to here: https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/prompt_tokenizers.py#L80-L86

There, a token is appended to `input_ids` and `attention_mask`, but not to `token_type_ids`.

I believe the Qwen tokenizer is one of the few that return `token_type_ids`.

The fix would be either to pass `return_token_type_ids=False` in the call to the tokenizer, or to append a matching entry to `token_type_ids`. I am not sure which is better.

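I validated that adding the `token_type_ids` lines below, right after the existing `attention_mask` append, seems to fix it for me:

```python
result["attention_mask"].append(1)  # existing line
# Added: keep token_type_ids the same length as input_ids and attention_mask.
if "token_type_ids" in result:
    result["token_type_ids"].append(result["token_type_ids"][-1])
```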
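A slightly more general variant of the same idea is to extend every per-token field whenever a token is appended, so that any extra field a tokenizer returns stays aligned. This is sketched below as a hypothetical helper, not axolotl's actual code: `append_token` is an illustrative name, and the values it appends (`1` for `attention_mask`, the previous entry for any other per-token field) simply mirror the patch above.

```python
from typing import Dict, List


def append_token(result: Dict[str, List[int]], token_id: int) -> None:
    """Hypothetical helper: append token_id and keep per-token fields aligned.

    attention_mask gets 1; any other field with one entry per token
    (for example token_type_ids) repeats its last value, or 0 if empty.
    """
    length = len(result["input_ids"])
    result["input_ids"].append(token_id)
    for key, values in result.items():
        # Skip input_ids itself and any field that is not one-entry-per-token.
        if key == "input_ids" or len(values) != length:
            continue
        if key == "attention_mask":
            values.append(1)
        else:
            values.append(values[-1] if values else 0)


result = {"input_ids": [10, 11], "token_type_ids": [0, 0], "attention_mask": [1, 1]}
append_token(result, 2)
print(result)
# {'input_ids': [10, 11, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}
```

The other option mentioned above, passing `return_token_type_ids=False` to the tokenizer call, avoids creating the extra field in the first place; the helper approach instead keeps whatever fields the tokenizer returns.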