Open dsanmart opened 1 year ago
We are still integrating the new changes for Chipper, but I tried to see what could be done for MPS. After decoupling the encoder and the decoder, it seems additional changes are needed on the decoder side. I get a warning indicating that some tensors need to be mapped from int64 to int32, which makes greedy decoding as slow as running on CPU, and even slower than CPU when a beam search size of 3 is used. This looks like an issue with the MPS integration in PyTorch itself, even in the latest version. One option would be to modify the generator and test int32 where LongTensor is currently used, or to check with PyTorch whether LongTensor is supported under MPS.
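A minimal sketch of the kind of call that triggers the warning (the public MBart checkpoint here is just a stand-in for the Chipper decoder):

```python
import torch
from transformers import MBartForConditionalGeneration

# Public checkpoint used only as a stand-in for the Chipper decoder.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
decoder = model.get_decoder()

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
decoder = decoder.to(device)

# Token ids default to int64 (LongTensor); on MPS this is where the
# int64 -> int32 mapping warning appears and greedy decoding slows down.
input_ids = torch.ones((1, 16), dtype=torch.int64, device=device)
with torch.no_grad():
    hidden = decoder(input_ids=input_ids).last_hidden_state
```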
Looking at the Apple forums (https://developer.apple.com/forums/thread/712317), int64 operations are supported by the GPU accelerator.
Have you tried using the latest PyTorch nightly build? This issue was previously raised in PyTorch (https://github.com/pytorch/pytorch/issues/96610#issuecomment-1467395916) and also in other repos (https://github.com/Stability-AI/StableLM/issues/61#issuecomment-1528835238). Perhaps your PyTorch version doesn't have the LongTensor ops enabled on MPS? Could you please share the warning message you are getting?
If this was not the issue, how would you approach the first option proposed? Would it be possible to convert the input sequence to int32 before passing it to the decoder and then converting it back to int64 to avoid encountering bugs later? It would look something like this:
input_seq = torch.tensor(input_seq, dtype=torch.int32)  # cast the token ids to int32 before the decoder
output_seq = decoder(input_seq)
output_seq = output_seq.type(torch.int64)  # cast back to int64 for downstream code
I did try the latest version. To make it work, the HF generation code (https://huggingface.co/docs/transformers/main_classes/text_generation) would need to be revised to replace the uses of LongTensor with int32, or with a version that works efficiently on MPS. I tried converting the input ids to int32, but the warning comes from one of the methods in the HF generation code that has no relation to the type of the input ids. It should be tested with an MBARTDecoder, which probably works fine. If there is a setting in which the HF generation code with an MBARTDecoder runs faster on MPS than on CPU, it should be possible to speed up Chipper with MPS.
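A rough way to test that comparison would be something like this (the checkpoint and sequence lengths are arbitrary, used only to exercise the HF generation code on both devices):

```python
import time
import torch
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50").eval()
input_ids = torch.ones((1, 32), dtype=torch.int64)

def time_generate(device_name: str) -> float:
    device = torch.device(device_name)
    m = model.to(device)
    ids = input_ids.to(device)
    start = time.perf_counter()
    with torch.no_grad():
        m.generate(ids, max_new_tokens=64, num_beams=1)  # greedy decoding
    return time.perf_counter() - start

print("cpu:", time_generate("cpu"))
if torch.backends.mps.is_available():
    print("mps:", time_generate("mps"))
```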
I can imagine that you are testing it on the new code for #232? What encoder and decoder architectures are you using in chipper-fast-fine-tuning?
I saw that you are using LongTensors in the prediction for logits_processor. Is what you tried converting these to int32?
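For example, something along these lines inside a decoding step (hypothetical; the processor and logits are dummies just to show where the cast would go):

```python
import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor

# Dummy processor and logits, only to illustrate the int32 round-trip.
logits_processor = LogitsProcessorList([MinLengthLogitsProcessor(5, eos_token_id=2)])

input_ids = torch.ones((1, 4), dtype=torch.int64)  # ids as the generation loop keeps them
scores = torch.randn(1, 50265)                      # next-token logits from the decoder

# Cast the ids to int32 for the processor call, keep the int64 copy for everything else.
scores = logits_processor(input_ids.to(torch.int32), scores)
next_token = scores.argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=-1)
```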
As mentioned by @ajjimeno, the encoder is not available to MPS, but the decoder is the bottleneck and can be run through a CUDA or MPS backend for GPU acceleration. The MPS backend is supported by the PyTorch framework (PyTorch backend support docs). It would just be a matter of checking if MPS is available, detaching the encoder and decoder when MPS is detected instead of running model.generate, and mapping the computational graph of the decoder onto the mps device. Hugging Face example on MPS backend.
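A minimal sketch of that flow, assuming an MBart-style encoder/decoder as a stand-in for Chipper (checkpoint name, dummy inputs, and loop length are placeholders):

```python
import torch
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50").eval()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Detach the encoder and decoder and put them on the selected device.
encoder = model.get_encoder().to(device)
decoder = model.get_decoder().to(device)
lm_head = model.lm_head.to(device)

# Dummy encoder input, just to produce encoder hidden states for the sketch.
input_ids = torch.ones((1, 32), dtype=torch.int64, device=device)
with torch.no_grad():
    encoder_hidden = encoder(input_ids=input_ids).last_hidden_state

# Manual greedy loop instead of model.generate, so the decoder graph runs on mps.
generated = torch.tensor([[model.config.decoder_start_token_id]], device=device)
for _ in range(32):
    with torch.no_grad():
        hidden = decoder(
            input_ids=generated, encoder_hidden_states=encoder_hidden
        ).last_hidden_state
        logits = lm_head(hidden[:, -1, :])
    next_token = logits.argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)
    if next_token.item() == model.config.eos_token_id:
        break
```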