Open Corsky opened 1 month ago
Hello @Corsky! This might be resolved if you upgrade to transformers 4.43.2; see https://github.com/huggingface/transformers/issues/32185 for reference.
Thanks @pcuenca. I tried transformers 4.43.2, but the same error occurs. I read through the fix in transformers: it only adds a check for whether torch has the float8_e4m3fn dtype. That fix would work for torch < 2.1 but not for 2.4, I guess, since 2.4.0 should already have the fp8 dtype.
is_param_float8_e4m3fn = hasattr(torch, 'float8_e4m3fn') and param.dtype == torch.float8_e4m3fn
It seems the code passed the hasattr(torch, 'float8_e4m3fn') check, but torch still couldn't find Float8_e4m3fnStorage...
Let me check whether something is wrong with my PyTorch installation.
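One way to separate the two checks described above is a small probe that tests the attribute and the default-dtype call independently. This is only a sketch: `probe_default_dtype` and its return strings are hypothetical names introduced for illustration, and the only torch calls it relies on are `torch.get_default_dtype` and `torch.set_default_dtype`, the latter being the call that fails in the traceback.

```python
import importlib


def probe_default_dtype(dtype_name="float8_e4m3fn"):
    """Report whether torch exposes `dtype_name` and accepts it as the default dtype.

    Hypothetical helper for diagnosis only; not part of torch or transformers.
    """
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return "torch-missing"

    # Same attribute check as the line quoted from the transformers fix.
    dtype = getattr(torch, dtype_name, None)
    if dtype is None:
        return "dtype-missing"

    original = torch.get_default_dtype()
    try:
        # This is the call that raises in the reported traceback.
        torch.set_default_dtype(dtype)
    except (TypeError, RuntimeError) as exc:
        return f"rejected: {exc}"
    finally:
        # Restore whatever default was active before the probe.
        torch.set_default_dtype(original)
    return "accepted"


print(probe_default_dtype())
```

If this prints a "rejected: ..." message while the attribute check passes, the failure comes from setting the default dtype rather than from the dtype being absent, which would match the behaviour reported here.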
I assume you are running on FP8 hardware (H100), right? (I believe you'd get a different error if you are not). Other than that, I'm not sure if you'd need to upgrade to cuda 12.
@pcuenca I'm using an H800 now, which I think does support FP8, and yes, the error was different when I first tried on an A100. Let me check whether it works on CUDA 12.
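To compare setups like the ones mentioned above (H800 vs. A100, CUDA 11.8 vs. 12), a short report of the relevant versions and device capabilities can help. This is a sketch under stated assumptions: `describe_fp8_environment` is a hypothetical helper, and the threshold of compute capability 8.9 for FP8-capable hardware is an assumption written into the code, not something confirmed in this thread.

```python
import importlib


def describe_fp8_environment():
    """Collect version and hardware facts relevant to FP8, or None without torch.

    Hypothetical helper for diagnosis only.
    """
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return None

    info = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # None on CPU-only builds
        "has_float8_e4m3fn": hasattr(torch, "float8_e4m3fn"),
        "devices": [],
    }
    if torch.cuda.is_available():
        for index in range(torch.cuda.device_count()):
            major, minor = torch.cuda.get_device_capability(index)
            info["devices"].append({
                "name": torch.cuda.get_device_name(index),
                "capability": f"{major}.{minor}",
                # Assumption: FP8 hardware support starts at capability 8.9.
                "fp8_capable": (major, minor) >= (8, 9),
            })
    return info


print(describe_fp8_environment())
```

Posting this output alongside the traceback makes it easier to tell a hardware or CUDA mismatch apart from a problem in the library call itself.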
  File "/workdir/user_repository/inference/local_deploy_demo.py", line 41, in load_model
    self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", trust_remote_code=True,
  File "/usr/local/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/usr/local/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3737, in from_pretrained
    dtype_orig = cls._set_default_torch_dtype(torch_dtype)
  File "/usr/local/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1571, in _set_default_torch_dtype
    torch.set_default_dtype(dtype)
  File "/usr/local/lib/python3.9/site-packages/torch/__init__.py", line 796, in set_default_dtype
    _C._set_default_dtype(d)
TypeError: couldn't find storage object Float8_e4m3fnStorage
I set torch_dtype=torch.float8_e4m3fn when loading the Llama 3.1 405B model with transformers, and this error occurs. The CUDA version is 11.8, the torch version is 2.4.0, and the transformers version is 4.43.1. Any idea how to fix this error?