lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in PyTorch
MIT License

Running with MPS as device (for Apple M1 chips) #136

Open JoseEliel opened 2 years ago

JoseEliel commented 2 years ago

Hi!

I've been trying to set this up on a MacBook with an M1 chip. However, simply changing `.cuda()` to `.to(device)`, with `device` set to `"mps"`, does not work.

As far as I can tell, the issue comes down to the use of `torch.cuda.amp` (the gradient scaler and the autocast context) in `trainer.py`. I tried setting `enabled=False` by hand, but that does not work either, and I don't yet know enough to work out how to modify the code to get around it. It does work if I set the device to `"cpu"`, though.
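A minimal sketch of the device-agnostic pattern being attempted here, assuming the CUDA-only AMP path is the blocker. It is not DALLE2-pytorch's actual trainer code: `pick_device`, `train_step`, the `nn.Linear` model and the MSE loss are hypothetical placeholders. Mixed precision is enabled only on CUDA; on MPS or CPU the autocast block is replaced by `contextlib.nullcontext()` and the `GradScaler` is constructed with `enabled=False`, so it passes the loss through unchanged. `torch.backends.mps` requires PyTorch 1.12 or newer, and other CUDA-specific calls or ops unsupported by MPS could still cause failures.

```python
import contextlib

import torch
from torch import nn


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in PyTorch >= 1.12, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def train_step(model, optimizer, scaler, batch, target, device, use_amp):
    """One optimisation step (hypothetical helper); AMP is used only if use_amp."""
    batch, target = batch.to(device), target.to(device)
    # Autocast here targets CUDA only, so use a no-op context on MPS/CPU.
    ctx = torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
    with ctx:
        loss = nn.functional.mse_loss(model(batch), target)
    optimizer.zero_grad()
    # With enabled=False, scaler.scale() returns the loss unchanged,
    # scaler.step() just calls optimizer.step(), and update() is a no-op.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


torch.manual_seed(0)
device = pick_device()
use_amp = device.type == "cuda"
model = nn.Linear(8, 1).to(device)  # instead of model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x, y = torch.randn(16, 8), torch.randn(16, 1)
losses = [train_step(model, optimizer, scaler, x, y, device, use_amp) for _ in range(5)]
print(device, losses)
```

Using a null context rather than `autocast(enabled=False)` avoids relying on how a given PyTorch version treats autocast for non-CUDA device types.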

Any pointers on how to do it?

Thanks for the hard work on this implementation; I'm looking forward to getting it working on my M1 MacBook.

YannickAaron commented 2 years ago

Since this seemed to be connected to 8105, maybe it works now.