muziyongshixin opened this issue 1 year ago
What version of transformers are you using? If you installed it with the usual pip install, you have the wrong version, since the argument is not yet implemented in a pip release. The README states that you should install from GitHub (so you need the main branch); that worked for me:
pip install git+https://github.com/huggingface/transformers
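If you want to check whether your installed version is new enough before reinstalling, a small sketch like the following may help. `supports_load_in_4bit` is a hypothetical helper, and the 4.30 cutoff is an assumption based on when the main branch gained the flag at the time of this issue; treat the threshold as approximate, not authoritative.

```python
def supports_load_in_4bit(version: str) -> bool:
    """Return True if this transformers version should accept load_in_4bit."""
    # Assumption: 4-bit loading first ships in the 4.30 release line, so any
    # 4.29.x (or older) pip release would raise the TypeError in this issue.
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (4, 30)

print(supports_load_in_4bit("4.29.2"))       # pip release at the time -> False
print(supports_load_in_4bit("4.30.0.dev0"))  # main-branch install -> True
```

You can see your installed version with `python -c "import transformers; print(transformers.__version__)"`.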
When I run the command below, I get this error:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/tiger/qlora/qlora/qlora.py:758 in <module>                                                 │
│ │
│ 755 │ │ │ fout.write(json.dumps(all_metrics)) │
│ 756 │
│   757 if __name__ == "__main__":                                                                │
│ ❱ 758 │ train() │
│ 759 │
│ │
│ /home/tiger/qlora/qlora/qlora.py:590 in train │
│ │
│ 587 │ if completed_training: │
│ 588 │ │ print('Detected that training was already completed!') │
│ 589 │ │
│ ❱ 590 │ model = get_accelerate_model(args, checkpoint_dir) │
│ 591 │ training_args.skip_loading_checkpoint_weights=True │
│ 592 │ │
│ 593 │ model.config.use_cache = False │
│ │
│ /home/tiger/qlora/qlora/qlora.py:263 in get_accelerate_model │
│ │
│ 260 │ │
│ 261 │ print(f'loading base model {args.model_name_or_path}...') │
│ 262 │ compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else t │
│ ❱ 263 │ model = AutoModelForCausalLM.from_pretrained( │
│ 264 │ │ args.model_name_or_path, │
│ 265 │ │ load_in_4bit=args.bits == 4, │
│ 266 │ │ load_in_8bit=args.bits == 8, │
│ │
│ /home/tiger/.local/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py:467 in │
│ from_pretrained │
│ │
│ 464 │ │ │ ) │
│ 465 │ │ elif type(config) in cls._model_mapping.keys(): │
│ 466 │ │ │ model_class = _get_model_class(config, cls._model_mapping) │
│ ❱ 467 │ │ │ return model_class.from_pretrained( │
│   468 │   │   │   │   pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs,  │
│ 469 │ │ │ ) │
│ 470 │ │ raise ValueError( │
│ │
│ /home/tiger/.local/lib/python3.9/site-packages/transformers/modeling_utils.py:2611 in │
│ from_pretrained │
│ │
│ 2608 │ │ │ init_contexts.append(init_empty_weights()) │
│ 2609 │ │ │
│ 2610 │ │ with ContextManagers(init_contexts): │
│ ❱ 2611 │   │   │   model = cls(config, *model_args, **model_kwargs)                             │
│ 2612 │ │ │
│   2613 │   │   # Check first if we are `from_pt`                                                │
│   2614 │   │   if use_keep_in_fp32_modules:                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: __init__() got an unexpected keyword argument 'load_in_4bit'
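For context on why the traceback ends inside the model constructor: `from_pretrained` forwards any keyword arguments it does not itself recognize into the model's `__init__`, so on a transformers version that predates `load_in_4bit`, the flag falls through and the constructor rejects it. A minimal sketch of that mechanism, using a made-up `ToyModel` class (not actual transformers code):

```python
class ToyModel:
    """Stand-in for a model class whose __init__ predates 4-bit support."""
    def __init__(self, config):
        self.config = config

# An older from_pretrained does not strip load_in_4bit from its kwargs,
# so the call effectively becomes this, which raises TypeError:
try:
    ToyModel(config=None, load_in_4bit=True)
except TypeError as e:
    # Message wording varies by Python version, but it names the stray kwarg,
    # e.g. "... got an unexpected keyword argument 'load_in_4bit'"
    print(e)
```

Upgrading to a version where `from_pretrained` consumes the flag (the main-branch install above) makes the error go away.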