kha-white / manga-ocr

Optical character recognition for Japanese text, with the main focus being Japanese manga
Apache License 2.0

M1 GPU Support (MPS) #29

Open Roxiun opened 1 year ago

Roxiun commented 1 year ago

Add the MPS device type so that M1/M2 GPUs can be utilised.

See: https://pytorch.org/docs/stable/notes/mps.html

if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")

    # Create a Tensor directly on the mps device
    x = torch.ones(5, device=mps_device)
    # Or
    x = torch.ones(5, device="mps")

    # Any operation happens on the GPU
    y = x * 2

    # Move your model to mps just like any other device
    model = YourFavoriteNet()
    model.to(mps_device)

    # Now every call runs on the GPU
    pred = model(x)
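
For manga-ocr itself, the change amounts to extending the device selection so that MPS is picked when CUDA is unavailable. The helper below is only a minimal sketch of that fallback chain, assuming a standard PyTorch setup; get_device and force_cpu are illustrative names, not necessarily what the actual code or PR uses:

import torch

def get_device(force_cpu: bool = False) -> torch.device:
    # Prefer CUDA, then Apple's MPS backend, then fall back to the CPU.
    if not force_cpu and torch.cuda.is_available():
        return torch.device("cuda")
    if not force_cpu and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# The model and its input tensors then move to the chosen device the usual way:
# model.to(get_device()); inputs = inputs.to(get_device())
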
ccodykid commented 1 year ago

See PR https://github.com/kha-white/manga-ocr/pull/30

Actually, macOS 13.3 is required.
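
If the 13.3 requirement has to be enforced explicitly (torch.backends.mps.is_available() only guarantees macOS 12.3+), a version guard along these lines could sit next to the MPS check. macos_at_least is a hypothetical helper, not part of manga-ocr or PyTorch:

import platform

def macos_at_least(major: int, minor: int) -> bool:
    # platform.mac_ver()[0] is e.g. "13.4.1" on macOS and "" elsewhere.
    release = platform.mac_ver()[0]
    if not release:
        return False
    parts = [int(p) for p in release.split(".")[:2]]
    while len(parts) < 2:
        parts.append(0)
    return (parts[0], parts[1]) >= (major, minor)

# use_mps = macos_at_least(13, 3) and torch.backends.mps.is_available()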

Roxiun commented 1 year ago

In addition to this, if anyone is using mokuro, I also have the following patches for comic-text-detector:

general.py Line 19

CUDA = True if torch.cuda.is_available() else False
DEVICE = 'cuda' if CUDA else 'cpu'

became

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

basemodel.py Line 13

CUDA = True if torch.cuda.is_available() else False
DEVICE = 'cuda' if CUDA else 'cpu'

became

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"