NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[bug] Failed to load pretrained model with huggingface transformers #1317

Open kehuanfeng opened 1 week ago

kehuanfeng commented 1 week ago

torch 2.4.0a0+3bcc3cddb5.nv24.7
transformer-engine 1.11.0+c27ee60
transformers 4.45.0

[rank5]: Traceback (most recent call last):
[rank5]:   File "/data/kehuan/LLaMA-Factory/src/train.py", line 28, in <module>
[rank5]:     main()
[rank5]:   File "/data/kehuan/LLaMA-Factory/src/train.py", line 19, in main
[rank5]:     run_exp()
[rank5]:   File "/data/kehuan/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
[rank5]:     run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
[rank5]:   File "/data/kehuan/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 48, in run_sft
[rank5]:     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
[rank5]:   File "/data/kehuan/LLaMA-Factory/src/llamafactory/model/loader.py", line 162, in load_model
[rank5]:     model = load_class.from_pretrained(**init_kwargs)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 559, in from_pretrained
[rank5]:     return model_class.from_pretrained(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 4008, in from_pretrained
[rank5]:     ) = cls._load_pretrained_model(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 4272, in _load_pretrained_model
[rank5]:     if param.device == torch.device("meta"):
[rank5]: AttributeError: '_io.BytesIO' object has no attribute 'device'

I know it's due to the missing _extra_state entries related to FP8, but I have no idea how to fix this kind of issue.
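(For anyone hitting the same error before the fix lands: a hypothetical workaround sketch. Transformer Engine modules store FP8 metadata in checkpoint keys ending in `._extra_state` as serialized byte blobs rather than tensors, and a loader that assumes every entry is a tensor will fail on them, as in the traceback above. The helper below, with an illustrative toy checkpoint dict, simply drops those entries before loading; names here are invented for the example.)

```python
import io

def strip_extra_state(state_dict):
    """Return a copy of state_dict without TE '_extra_state' entries.

    These entries hold serialized FP8 metadata (byte blobs, not tensors),
    which a tensor-only loader cannot handle. Dropping them means the FP8
    scaling state is re-initialized instead of restored.
    """
    return {k: v for k, v in state_dict.items()
            if not k.endswith("._extra_state")}

# Toy stand-in for a real checkpoint state dict:
ckpt = {
    "model.layers.0.self_attn.q_proj.weight": "tensor placeholder",
    "model.layers.0.self_attn._extra_state": io.BytesIO(b"fp8-metadata"),
}
clean = strip_extra_state(ckpt)
```

One could then pass the filtered dict to `model.load_state_dict(clean, strict=False)`; the trade-off is that any saved FP8 calibration state is lost.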

timmoon10 commented 5 days ago

Can you try running with https://github.com/NVIDIA/TransformerEngine/pull/1335?

kehuanfeng commented 2 hours ago

> Can you try running with #1335?

@timmoon10 It's working now with your pull request. Thank you.