nilesh2797 / DEXML

Dual-encoders for Extreme Multi-label Learning
Apache License 2.0

Support for MPS #7

Open thekop69 opened 4 months ago

thekop69 commented 4 months ago

Hi @nilesh2797 - As you might have noticed from my recent PR, I've been playing with the notebook in this repo to better understand DE in XMC. I've been running it on my MacBook M1 Pro and had to make some changes to support MPS. If you are interested, I'd be happy to contribute these changes back to your repo via another PR. The changes themselves are tiny, but a key one is that MPS does not support float64, which is the dtype of the label embeddings, so they have to be downcast (e.g. to float32) before being moved to the device. This doesn't affect CUDA, though I'd guess some floating-point precision differences will show up when comparing a run on Metal against one on NVIDIA.
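For illustration, the float64 workaround could look roughly like the sketch below. It is not the actual change from the PR: the helper names (`pick_device_name`, `current_device_name`, `to_device_safe`) are hypothetical, and float32 is assumed as the downcast target. Only standard PyTorch calls are used (`torch.cuda.is_available`, `torch.backends.mps.is_available`, `Tensor.to`).

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then MPS, then CPU (hypothetical helper)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def current_device_name() -> str:
    """Query PyTorch for the best available backend."""
    import torch  # imported lazily so pick_device_name stays dependency-free

    return pick_device_name(
        torch.cuda.is_available(),
        torch.backends.mps.is_available(),
    )


def to_device_safe(tensor, device_name: str):
    """Move a tensor to device_name, downcasting float64 to float32 on MPS.

    Assumption: float32 is an acceptable precision for the label embeddings.
    On CUDA and CPU the tensor keeps its original dtype.
    """
    import torch

    if device_name == "mps" and tensor.dtype == torch.float64:
        tensor = tensor.to(torch.float32)
    return tensor.to(device_name)


# Usage sketch (label_emb is a hypothetical float64 NumPy array):
#   device = current_device_name()
#   label_embeddings = to_device_safe(torch.from_numpy(label_emb), device)
```

Casting only on MPS leaves CUDA runs unchanged, which matches the observation that the float64 limitation does not affect CUDA.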

thekop69 commented 4 months ago

The other thing worth noting is that autocast isn't supported on MPS (yet). There's a PR for it that has been ongoing for quite some time: https://github.com/pytorch/pytorch/pull/127063/commits/428a37c05a8b09ca4a2d551ea8488fb0d99710cb#diff-0ca6f7e7a58e1b11315cda35388d83cc646f76f14eb1e5ec2d5fd520f80d5a9b

But I am able to train and get the embeddings.
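One way to keep a single training loop working across backends is to fall back to a no-op context manager where autocast is unavailable. The sketch below is an assumption about how that could be done, not the code used here: `maybe_autocast` is a hypothetical helper, and it simply treats `"mps"` as unsupported, as described above.

```python
import contextlib


def maybe_autocast(device_type: str, enabled: bool = True):
    """Return an autocast context where supported, else a no-op context.

    Assumption: autocast is skipped on "mps"; "cuda" and "cpu" are passed
    through to torch.autocast with its default dtype for that device.
    """
    if not enabled or device_type == "mps":
        return contextlib.nullcontext()

    import torch  # imported lazily; only needed when autocast is used

    return torch.autocast(device_type=device_type)


# Usage sketch (model and batch are hypothetical):
#   with maybe_autocast(device.type):
#       loss = model(batch)
```

On MPS this runs the forward pass in full precision, so training still works but without the memory and speed benefits of mixed precision.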