Yes, please submit a PR. I can't test it locally.
@joanvelja I have created a PR to add MPS support. Feel free to give it a try and please provide feedback as you do in the PR thread.
I tried to add MPS support, but it seems that some operations and datatypes used are not currently supported on Apple Silicon GPUs. In particular, the operator aten::erfinv.out and the bfloat16 datatype are not yet supported.
@DrMWeigand Where do you see the usage of erfinv?
According to this PR, erfinv will have the MPS variant soon.
Aside from bf16 support, the only other issue I ran into was addressed by (a modified version of) c02b772.
Cool! What if we don't use bf16 for MPS?
Should be fine, if slower.
I had to upgrade to nightly because of the unavailability of `aten::roll` in 2.0.1, but with `2.2.0-dev*`... Good news: casting to bfloat16 only when `device.type != "mps"` seems to be running. Bad news: it seems to be running extremely slowly; a 2-page PDF has been decoding for >10 min with bs=16. 😄 It just completed with both pages skipped, so I'm rerunning with the error heuristics disabled.
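For clarity, this is roughly the conditional cast I mean (a sketch, not the exact diff; the checkpoint path is a placeholder, and the real change lives in predict.py):

```python
import torch
from nougat import NougatModel

# Sketch of the conditional bf16 cast described above (assumed structure,
# not the exact change in the PR). Pick MPS when available, else CUDA/CPU,
# and only cast to bfloat16 on devices that actually support it.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

checkpoint = "path/to/nougat/checkpoint"  # placeholder; predict.py uses args.checkpoint
model = NougatModel.from_pretrained(checkpoint)
if device.type != "mps":
    # bfloat16 kernels are missing on the MPS backend, so keep full precision there
    model = model.to(torch.bfloat16)
model = model.to(device)
model.eval()
```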
It was triggered on line 103 in predict.py:
model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
leading to the error:
NotImplementedError: The operator 'aten::erfinv.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
As written in the error message, it can be avoided by setting the stated environment variable to 1. This worked as a workaround, but then I ran into the bfloat16 error.
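For anyone else hitting this: you can either export the variable in the shell before launching predict.py, or set it from Python. A minimal sketch of the latter (the key point is that it must be set before torch initializes the MPS backend):

```python
import os

# Workaround from the error message above: let ops without an MPS kernel
# (like aten::erfinv.out) fall back to the CPU. Set it before importing torch
# so the fallback is picked up when the MPS backend initializes.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported after setting the env var on purpose)
```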
I take my previous comment back: it seems to have been an unlucky PDF. Running `predict` on my CV yields mostly reasonable speed (and quality). I'll push the most recent changes, but with the caveat that they have only been manually tested on inference using a nightly PyTorch. This might only require 2.1.0, but I'm not positive, as I haven't tracked down the right commits for the MPS operations.
I am running nougat on an M2 MacBook Air with 8GB RAM and PyTorch 2.2.0 nightly. Since upgrading nougat-api to the latest version (which uses MPS), it is unusable for me because it tries to allocate >13GB of RAM, even with --batchsize 1. It was working fine on CPU with <4GB of RAM (although quite slow). I checked, and it is using bf16. Does anyone have the same issue? I am also willing to contribute some code here if you point me in the right direction.
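In case it helps, a rough sketch of how one can confirm the parameter dtype and see what the MPS allocator is holding (the checkpoint path is a placeholder; the torch.mps memory helpers are available in recent PyTorch releases):

```python
import torch
from nougat import NougatModel

# Rough diagnostic sketch (placeholder checkpoint path): confirm which dtype
# the parameters ended up in and how much memory the MPS allocator is holding.
model = NougatModel.from_pretrained("path/to/nougat/checkpoint")
model = model.to(torch.device("mps"))

print("parameter dtypes:", {p.dtype for p in model.parameters()})
print("MPS allocated (GiB):", torch.mps.current_allocated_memory() / 2**30)
print("MPS driver allocated (GiB):", torch.mps.driver_allocated_memory() / 2**30)
```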
Hello, is this working?
Hi all, why is the code not supporting MPS acceleration? Should it work, or has it not been implemented yet?