artidoro / qlora

QLoRA: Efficient Finetuning of Quantized LLMs
https://arxiv.org/abs/2305.14314
MIT License

Unable to generate predictions #275

Open SamarthMM opened 11 months ago

SamarthMM commented 11 months ago

Hello!

I am trying to generate predictions using the qlora.py script in this repo on my custom dataset. However I face two issues:

Short version:

  1. Firstly, the script throws the following error in the `tokenizer.batch_decode()` call:
    `TypeError: argument 'ids': 'list' object cannot be interpreted as an integer`
    I tried this with the base llama2 model, both with and without adapters.

  2. Secondly, when I tried to 'fix' the issue on my end (more details below), I get mangled output.

Detailed version:

Full error:

  File "/mnt/nobackup/samarth/anaconda3/envs/nov/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nobackup/samarth/anaconda3/envs/nov/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/mathur9/.vscode-server/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/mathur9/.vscode-server/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/mathur9/.vscode-server/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/mathur9/.vscode-server/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mathur9/.vscode-server/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/mathur9/.vscode-server/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/mathur9/samarth/qlora/qlora.py", line 948, in <module>
    train()
  File "/home/mathur9/samarth/qlora/qlora.py", line 930, in train
    predictions = tokenizer.batch_decode(
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nobackup/samarth/anaconda3/envs/nov/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3485, in batch_decode
    return [
           ^
  File "/mnt/nobackup/samarth/anaconda3/envs/nov/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3486, in <listcomp>
    self.decode(
  File "/mnt/nobackup/samarth/anaconda3/envs/nov/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3525, in decode
    return self._decode(
           ^^^^^^^^^^^^^
  File "/mnt/nobackup/samarth/anaconda3/envs/nov/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py", line 546, in _decode
    text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

Config to reproduce error

    "args": [
        "--model_name_or_path", "/path/to/lama2-13b",
        "--output_dir", "./outputs/debug",
        "--logging_steps", "10",
        "--save_strategy", "steps",
        "--data_seed", "42",
        "--save_steps", "500",
        "--save_total_limit", "40",
        "--evaluation_strategy", "steps",
        "--eval_dataset_size", "1",
        "--max_eval_samples", "700",
        "--per_device_eval_batch_size", "1",
        "--max_new_tokens", "32",
        "--dataloader_num_workers", "3",
        "--group_by_length",
        "--logging_strategy", "steps",
        "--remove_unused_columns", "False",
        "--do_train", "False",
        "--do_eval", "False",
        "--do_predict",
        "--lora_r", "64",
        "--lora_alpha", "16",
        "--lora_modules", "all",
        "--double_quant",
        "--quant_type", "nf4",
        "--fp16",
        "--bits", "4",
        "--warmup_ratio", "0.03",
        "--lr_scheduler_type", "constant",
        "--gradient_checkpointing",
        "--dataset", "/path/to/dataset",
        "--dataset_format", "samarth",
        "--source_max_len", "512",
        "--target_max_len", "512",
        "--per_device_train_batch_size", "1",
        "--gradient_accumulation_steps", "16",
        "--max_steps", "1875",
        "--eval_steps", "187",
        "--learning_rate", "0.0002",
        "--adam_beta2", "0.999",
        "--max_grad_norm", "0.3",
        "--lora_dropout", "0.05",
        "--weight_decay", "0.0",
        "--seed", "0",
        "--cache_dir", "./cache"
    ]

Offending code block

The `tokenizer.batch_decode` call around line 823 of qlora.py:

    if args.do_predict:
        logger.info("*** Predict ***")
        prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'],metric_key_prefix="predict")
        prediction_metrics = prediction_output.metrics
        predictions = prediction_output.predictions
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        predictions = tokenizer.batch_decode(
            predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
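For context on the TypeError: `batch_decode` expects a 2-D batch of integer token ids, but when generation is not enabled `trainer.predict` can return raw per-position logits, i.e. a 3-D float array, so each "row" passed to the tokenizer is itself a list of lists. A minimal numpy sketch of the mismatch (shapes copied from this issue; the tokenizer call itself is omitted):

```python
import numpy as np

# Shapes mirroring this issue: 1 example, 55 positions, 32001-token vocab.
# batch_decode wants something like [[id, id, ...]] (2-D integer ids);
# raw logits are 3-D floats, hence the
# "argument 'ids': 'list' object cannot be interpreted as an integer" error.
logits = np.random.randn(1, 55, 32001).astype(np.float32)
assert logits.ndim == 3             # what predict() returned here

token_ids = logits.argmax(axis=-1)  # greedy pick per position
assert token_ids.shape == (1, 55)   # now 2-D integer ids batch_decode accepts
```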

Dataset: I have a small dataset of 2 examples, where the first example is:

data_module['predict_dataset']['input'][0]
'Input:\nHi I am a causal language model and I like to predict things in a sentence such as \n\n### Category:\n'

Now, `predictions = prediction_output.predictions` is a numpy array of shape (1, 55, 32001), which makes me believe it contains the logits of each token in the sequence rather than a list of token ids. So I tried replacing the detokenizing call as follows:

    tokenizer.batch_decode(
        np.argmax(predictions, axis=2), skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
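One likely reason argmax-decoding looks mangled: under teacher forcing, the logits at position i score the model's guess for token i+1 given the ground-truth prefix, so the decoded ids are the model's per-step next-token guesses, shifted one place relative to the real sequence, not the input plus a generated continuation. A toy sketch (hypothetical 5-token vocab, idealized model) of that shift:

```python
import numpy as np

# Toy illustration: pretend the model perfectly predicts "previous token + 1"
# at every position, as a stand-in for next-token prediction over the input.
vocab = 5
input_ids = np.array([0, 1, 2, 3])

logits = np.zeros((len(input_ids), vocab))
for i, tok in enumerate(input_ids):
    logits[i, (tok + 1) % vocab] = 1.0  # peak on the *next* token

decoded = logits.argmax(axis=-1)
print(decoded)  # [1 2 3 4] -- shifted next-token guesses, not [0 1 2 3]
assert not np.array_equal(decoded, input_ids)
```

So even a well-behaved model would not reproduce its own input this way, which matches the garbled "input" seen below.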

But I get the weird mangled output shown below, which indicates that even the input portion has been affected:

's A\n, am a newal reader lear. I am to learn the. the fun. as "\n\n```# Input\n\n\namil f,e,r,,e,x,t,,s,e,r,d,'

Please help!
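For reference: `Seq2SeqTrainingArguments` in transformers (which this script's `TrainingArguments` appears to extend) has a `predict_with_generate` flag that makes `trainer.predict` run generation and return token ids instead of logits. If that applies here, the fix may be as simple as one extra entry in the launch config above (assumption, untested against this repo):

    "args": [
        ...,
        "--predict_with_generate"
    ]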