Closed dvrogozh closed 2 months ago
There was an assumption that PR https://github.com/pytorch/pytorch/pull/129119 would address this issue. Unfortunately it does not - I still see the same issue. I added printouts to the toDevice()
function and can tell that it is called and that the PR 129119 change is in play - I see the patched code branch getting hit. However, I also see that we ultimately get cuda:0
here:
https://github.com/pytorch/pytorch/blob/95046c86e3547e46ef5733925e02278e46c5c6d4/torch/csrc/utils/python_arg_parser.h#L823. Thus, I think there is one more place somewhere hardcoding cuda.
The root cause is that the safetensors library hardcodes the returned device to cuda when only a device index is provided, i.e. here: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L297
The fix probably should be to return the device by calling torch.device(N).
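As a plain-Python illustration of the proposed direction (the function names below are hypothetical, not the actual safetensors API):

```python
def buggy_device_from_index(index):
    # What the binding effectively does today: assume CUDA whenever
    # only an integer device index is stored.
    return f"cuda:{index}"

def fixed_device_from_index(index, backend):
    # Proposed direction: resolve the index against the active backend.
    # In real code, torch.device(index) would perform this resolution.
    return f"{backend}:{index}"

print(buggy_device_from_index(0))         # cuda:0 even on an XPU-only system
print(fixed_device_from_index(0, "xpu"))  # xpu:0
```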
Filed https://github.com/huggingface/safetensors/issues/499.
Hi @dvrogozh, thanks for raising this issue, for the deep dive, and for taking the time to write up such a detailed explanation of the problem and what you've tried - it's incredibly appreciated!
If I've understood correctly, there isn't anything to do on the transformers side and it's pending a resolution in the safetensors library?
cc @muellerzr @sun for reference as this touches accelerate and offloading logic
> If I've understood correctly, there isn't anything to do on the transformers side and it's pending a resolution in the safetensors library?
That is right. I initially filed this issue here since the transformers use case was affected and the root cause was not clear at the time of filing.
@dvrogozh OK, no worries. Just so we know whether there's an action point for the team. We can leave this open until the safetensors issue is resolved and we can confirm that the pipeline
will run as expected.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Found on these code versions: https://github.com/huggingface/transformers/commit/7f79a97399bb52aad8460e1da2f36577d5dccfed, https://github.com/huggingface/accelerate/commit/e1247de01e0733c5d21075cb6f39b2605f4be123, https://github.com/pytorch/pytorch/commit/3477ee38e4dd1429ecfd7e6f20a30cce0f4f78e7. This is an issue with XPU support in stock pytorch (i.e. without using IPEX).
Assume the system has an XPU GPU and no CUDA device. Also assume that the model does not fit in device memory. In this case execution fails because some tensors are wrongly sent to a CUDA device. I noted this issue trying to run:
The following example script reproduces the issue with the LLAMA 3 8B model by creating a memory constraint on the XPU device:
Log output:
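The script and log are not included above. The offloading behavior it exercises can be sketched in plain Python (illustrative only, not accelerate's actual logic): with a per-device memory cap, weights that do not fit are offloaded, and a device spec that later hardcodes "cuda" sends them to the wrong backend on an XPU-only system.

```python
# Illustrative sketch: place weights on the accelerator until a memory
# cap is hit, offloading the rest to CPU. Sizes and names are made up.
def place_weights(weights, cap_bytes, device="xpu:0"):
    placed, offloaded, used = {}, {}, 0
    for name, size in weights.items():
        if used + size <= cap_bytes:
            placed[name] = device  # fits on the accelerator
            used += size
        else:
            offloaded[name] = "cpu"  # must be offloaded
    return placed, offloaded

placed, offloaded = place_weights(
    {"layer0": 4, "layer1": 4, "layer2": 4}, cap_bytes=8
)
print(placed)     # {'layer0': 'xpu:0', 'layer1': 'xpu:0'}
print(offloaded)  # {'layer2': 'cpu'}
```

The bug described above hits when offloaded weights are reloaded: a device reconstructed only from an index lands on cuda:0 instead of xpu:0.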
CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 @sywangyi @yao-matrix