facebookresearch / nougat

Implementation of Nougat Neural Optical Understanding for Academic Documents
https://facebookresearch.github.io/nougat/
MIT License

MPS acceleration #61

Closed joanvelja closed 1 year ago

joanvelja commented 1 year ago

Hi all, why does the code not support MPS acceleration? Should it work, or has it just not been implemented yet?

erip commented 1 year ago

It assumes CUDA because of these. You can submit a PR to fix them.
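For anyone picking this up, a minimal sketch of the kind of device-agnostic selection such a PR would need (the helper name is illustrative, not nougat's API):

```python
import torch

def default_device() -> torch.device:
    """Prefer CUDA, then MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(default_device())
```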

lukas-blecher commented 1 year ago

Yes, please submit a PR. I can't test it locally.

erip commented 1 year ago

@joanvelja I have created a PR to add MPS support. Feel free to give it a try and please provide feedback as you do in the PR thread.

DrMWeigand commented 1 year ago

I tried to add MPS support, but it seems that some of the operations and data types used are not currently supported by Apple silicon GPUs. In particular, the operator `aten::erfinv.out` and the dtype `bfloat16` are not yet supported.

nullhook commented 1 year ago

@DrMWeigand Where do you see the usage of erfinv?

According to this PR, erfinv will have the MPS variant soon.

erip commented 1 year ago

Aside from bf16 support, the only other issue I ran into was addressed by (a modified version of) c02b772.

lukas-blecher commented 1 year ago

Cool! What if we don't use bf16 for MPS?

erip commented 1 year ago

Should be fine, if slower.

erip commented 1 year ago

I had to upgrade to a nightly build because `aten::roll` is unavailable in 2.0.1, but with `2.2.0-dev*`...

Good news: only casting to bfloat16 when `device.type != "mps"` seems to be running. Bad news: it seems to run extremely slowly; a 2-page PDF has been decoding for >10 min with `bs=16`. 😄 It just completed with both pages skipped, so I'm rerunning with error heuristics disabled.
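For reference, the conditional cast described above might look roughly like this (a sketch, not the exact diff from the PR; `nn.Linear` stands in for `NougatModel.from_pretrained(...)`):

```python
import torch
from torch import nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Stand-in for NougatModel.from_pretrained(args.checkpoint)
model = nn.Linear(4, 4).to(device)

# The MPS backend has no bfloat16 support, so only downcast off-MPS
if device.type != "mps":
    model = model.to(torch.bfloat16)
```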

DrMWeigand commented 1 year ago

> @DrMWeigand Where do you see the usage of erfinv?
>
> According to this PR, erfinv will have the MPS variant soon.

It was triggered on line 103 in `predict.py`: `model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)`

leading to the error:

NotImplementedError: The operator 'aten::erfinv.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

As the error message says, it can be avoided by setting the stated environment variable to 1. That worked as a workaround, but then I ran into the bfloat16 error.
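For reference, the fallback has to be in the environment before torch initializes, e.g. (a sketch; exporting `PYTORCH_ENABLE_MPS_FALLBACK=1` in the shell before running `predict.py` works the same way):

```python
import os

# Must be set before `import torch`, otherwise the fallback is not picked up
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402

# Ops without an MPS kernel (e.g. aten::erfinv.out on older builds) now
# fall back to the CPU with a warning instead of raising NotImplementedError
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.rand(4, device=device) * 0.5
print(torch.erfinv(x))
```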

erip commented 1 year ago

I take my previous comment back: it seems to have been an unlucky PDF. Running predict on my CV yields mostly reasonable speed (and quality). I'll push the most recent changes, with the caveat that they have only been manually tested on inference with a nightly PyTorch. This might only require 2.1.0, but I'm not positive, as I haven't tracked down the right commits for the MPS operations.

kaieberl commented 1 year ago

I am running nougat on an M2 MacBook Air (8 GB RAM) with PyTorch 2.2.0 nightly. Since upgrading nougat-api to the latest version (which uses MPS), it has become unusable for me: it tries to allocate >13 GB of RAM, even with --batchsize 1. It worked fine on CPU with <4 GB RAM (although quite slowly). I checked, and it is using bf16. Does anyone have the same issue? I am also willing to contribute some code here if you point me in the right direction.

ehartford commented 9 months ago

Hello, is this working?