meta-llama / llama3

The official Meta Llama 3 GitHub site

Llama3.1 405B using FP8 TypeError: couldn't find storage object Float8_e4m3fnStorage #285

Status: Open. Corsky opened this issue 1 month ago.

Corsky commented 1 month ago

  File "/workdir/user_repository/inference/local_deploy_demo.py", line 41, in load_model
    self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", trust_remote_code=True,
  File "/usr/local/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/usr/local/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3737, in from_pretrained
    dtype_orig = cls._set_default_torch_dtype(torch_dtype)
  File "/usr/local/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1571, in _set_default_torch_dtype
    torch.set_default_dtype(dtype)
  File "/usr/local/lib/python3.9/site-packages/torch/__init__.py", line 796, in set_default_dtype
    _C._set_default_dtype(d)
TypeError: couldn't find storage object Float8_e4m3fnStorage

I set torch_dtype=torch.float8_e4m3fn when loading the Llama 3.1 405B model with transformers, and this error occurs. Environment: CUDA 11.8, torch 2.4.0, transformers 4.43.1. Any idea how to fix this error?

pcuenca commented 1 month ago

Hello @Corsky! This might be resolved by upgrading to transformers 4.43.2; see https://github.com/huggingface/transformers/issues/32185 for reference.

Corsky commented 1 month ago

Thanks @pcuenca. I tried transformers 4.43.2, but the same error occurs. I read over the fix in transformers: it only adds a check for whether torch has the float8_e4m3fn dtype. I guess that would help with torch < 2.1, which lacks the dtype, but not with 2.4.0, which should already have it:

is_param_float8_e4m3fn = hasattr(torch, 'float8_e4m3fn') and param.dtype == torch.float8_e4m3fn

It seems the hasattr(torch, 'float8_e4m3fn') check passes, but torch still can't find Float8_e4m3fnStorage...
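The observation above can be checked without transformers or the model. The sketch below assumes only that torch is installed; `can_be_default_dtype` is a hypothetical helper name. It compares `hasattr(torch, name)`, the guard quoted above, with whether `torch.set_default_dtype` actually accepts the dtype, which is the call that fails at the bottom of the traceback. The previous default dtype is restored afterwards.

```python
import torch


def can_be_default_dtype(dtype: torch.dtype) -> bool:
    """Return True if torch.set_default_dtype accepts `dtype`.

    Hypothetical helper: the previous default dtype is always restored.
    """
    previous = torch.get_default_dtype()
    try:
        torch.set_default_dtype(dtype)
        return True
    except (TypeError, RuntimeError) as exc:
        # Per the traceback in this issue, torch 2.4.0 raises
        # "TypeError: couldn't find storage object Float8_e4m3fnStorage" here.
        print(f"{dtype}: {type(exc).__name__}: {exc}")
        return False
    finally:
        torch.set_default_dtype(previous)


if __name__ == "__main__":
    print("torch", torch.__version__)
    for name in ("float32", "bfloat16", "float8_e4m3fn"):
        # hasattr() only tells us the dtype object exists in this torch build.
        exists = hasattr(torch, name)
        usable = exists and can_be_default_dtype(getattr(torch, name))
        print(f"{name:15s} hasattr={exists} usable_as_default={usable}")
```

If `float8_e4m3fn` prints `hasattr=True usable_as_default=False`, the failure is in torch itself, independent of the model weights.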

Let me check whether something is wrong with my PyTorch version.
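A short environment report can help when comparing versions in a thread like this. The sketch below assumes torch is installed and treats transformers as optional; `environment_report` is a hypothetical helper. Note that `torch.version.cuda` is the CUDA version the torch build was compiled against, which may differ from the system CUDA toolkit version.

```python
import platform

import torch

try:
    import transformers
    transformers_version = transformers.__version__
except ImportError:  # transformers is optional for this report
    transformers_version = None


def environment_report() -> dict:
    """Collect the versions relevant to the FP8 loading error in this issue."""
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        # CUDA version this torch build was compiled with (None for CPU-only builds).
        "torch_cuda": torch.version.cuda,
        "transformers": transformers_version,
        # float8 dtypes first appeared in torch 2.1.
        "has_float8_e4m3fn": hasattr(torch, "float8_e4m3fn"),
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```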

pcuenca commented 1 month ago

I assume you are running on FP8-capable hardware (H100), right? (I believe you'd get a different error if you weren't.) Other than that, I'm not sure whether you'd need to upgrade to CUDA 12.

Corsky commented 1 month ago

@pcuenca I'm using an H800 now, which I think does support FP8, and yes, the error was different when I first tried on an A100. Let me see whether it works with CUDA 12.
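The H800/A100 distinction above comes down to GPU compute capability: FP8 tensor cores start at capability 8.9 (Ada Lovelace), Hopper GPUs such as H100 and H800 report 9.0, and A100 reports 8.0. The sketch below (hypothetical helper names, torch optional) checks only hardware capability. It does not tell you whether a particular loader or kernel supports FP8.

```python
from typing import Optional, Tuple

# FP8 tensor-core support starts at compute capability 8.9 (Ada Lovelace);
# Hopper GPUs such as H100/H800 are 9.0, while A100 is 8.0.
FP8_MIN_CAPABILITY: Tuple[int, int] = (8, 9)


def supports_fp8(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability has FP8 tensor cores."""
    return tuple(capability) >= FP8_MIN_CAPABILITY


def current_device_capability() -> Optional[Tuple[int, int]]:
    """Capability of CUDA device 0, or None if torch or CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)


if __name__ == "__main__":
    cap = current_device_capability()
    if cap is None:
        print("No CUDA device visible")
    else:
        print(f"compute capability {cap}: FP8 supported = {supports_fp8(cap)}")
```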