coreweave / tensorizer

Module, Model, and Tensor Serialization/Deserialization
MIT License

utils.get_device() should include 'mps' #136

Open Zihann73 opened 3 months ago

Zihann73 commented 3 months ago

Currently, `utils.get_device()` only returns `cpu` or `cuda`, but PyTorch also supports the `mps` device type (see https://pytorch.org/docs/master/tensor_attributes.html#torch-device). Because of this limitation, I was unable to run some transformer-based models on macOS. I worked around it by calling `model.to('mps')`.
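A minimal sketch of device selection that also considers `mps` might look like the following. It is hypothetical and not taken from tensorizer's source or from the linked pull request. It assumes CUDA should be preferred over MPS when both are reported, and that `torch.backends.mps` may be missing on older PyTorch releases (it was added in 1.12). `pick_device_type` is an illustrative helper name, split out so the selection logic can be checked without PyTorch installed.

```python
# Hypothetical sketch: not tensorizer's actual utils.get_device().


def pick_device_type(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device type: CUDA first, then Apple MPS, then CPU.

    The CUDA-over-MPS ordering is an assumption for illustration.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device():
    """Build a torch.device from what the installed PyTorch reports as available."""
    import torch  # imported lazily so pick_device_type works without PyTorch

    # torch.backends.mps only exists in PyTorch 1.12+; treat its absence as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_device_type(torch.cuda.is_available(), mps_available))
```

With a helper like this, `model.to(get_device())` would select `mps` on an Apple-silicon Mac without a CUDA GPU, which matches the manual `model.to('mps')` workaround above.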

Zihann73 commented 3 months ago

Please check my quick fix: https://github.com/coreweave/tensorizer/pull/137