minimaxir / aitextgen

A robust Python tool for text-based AI training and generation using GPT-2.
https://docs.aitextgen.io
MIT License

Add support for Apple MPS GPU #230

Status: Open. tony352 opened this pull request 1 year ago.

tony352 commented 1 year ago

This PR adds support for Apple MPS GPUs (Apple Silicon). It adds a new function, to_gpu_mps, modelled on the existing to_gpu function. Its parameter init defaults to False.
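The PR's actual to_gpu_mps body, and what its init parameter controls, are not shown here. As a rough illustration of the general idea only, the sketch below moves a plain PyTorch module onto the MPS device when one is available and leaves it on the CPU otherwise. The helper name move_to_mps is hypothetical and is not part of aitextgen.

```python
import torch
import torch.nn as nn


def move_to_mps(model: nn.Module) -> nn.Module:
    """Hypothetical helper (not aitextgen's API): move a model to MPS if possible."""
    if torch.backends.mps.is_available():
        # nn.Module.to() moves all parameters and buffers and returns the module
        return model.to(torch.device("mps"))
    # No usable MPS backend (non-Apple hardware, or PyTorch built without MPS):
    # leave the model where it is, typically on the CPU
    return model


model = move_to_mps(nn.Linear(4, 2))
print(next(model.parameters()).device)
```

On an Apple Silicon Mac with an MPS-enabled PyTorch build this should report an mps device; elsewhere it falls back to cpu rather than raising.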

The implementation follows Apple's guidance (see the 'verify' section): https://developer.apple.com/metal/pytorch/
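The verification step on Apple's page boils down to checking torch.backends.mps.is_available() and creating a small tensor on the "mps" device. A minimal sketch of that check, wrapped in a function so it can be reused before calling an MPS code path, might look like this (the function name mps_works is an assumption for illustration):

```python
import torch


def mps_works() -> bool:
    """Return True if a tensor can actually be allocated on the MPS device."""
    if not torch.backends.mps.is_available():
        return False
    # Allocate a one-element tensor on MPS, as in Apple's verification snippet
    x = torch.ones(1, device=torch.device("mps"))
    return x.device.type == "mps"


if mps_works():
    print("MPS device available")
else:
    print("MPS device not found.")
```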

I have not been able to test this, because testing it requires Google Colab and other files I do not have access to.