lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

torch_tensorrt compilation fails #80

Open FredHaa opened 2 years ago

FredHaa commented 2 years ago

Hi

Thank you for contributing this amazing repo!

I've tried to compile the model with the new torch_tensorrt TorchScript compiler. However, TorchScript does not seem to support a variable number of arguments, which the various forward() functions in the model definition use heavily:

torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/home/frederik/.local/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 576
    def forward(self, x, **kwargs):
                          ~~~~~~~ <--- HERE
        if self.auto_check_redraw:
            self.proj_updater.redraw_projections()

If I made a pull request that removes the **kwargs to add TorchScript and TensorRT support, would it be merged?
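Here is a minimal sketch of the failure and of the kind of change such a PR would make. It assumes only that PyTorch is installed. `KwargsBlock`, `ExplicitBlock`, `try_script`, and the `mask` argument are hypothetical names chosen for illustration and are not performer-pytorch's actual modules. The first class mirrors the `forward(self, x, **kwargs)` signature from the traceback. The second replaces `**kwargs` with an explicit, typed optional argument that `torch.jit.script` can compile.

```python
from typing import Optional

import torch
from torch import nn


class KwargsBlock(nn.Module):
    # Hypothetical module mirroring the signature in the traceback:
    # extra options are collected in **kwargs, which TorchScript rejects.
    def forward(self, x, **kwargs):
        return x * 2


class ExplicitBlock(nn.Module):
    # Hypothetical TorchScript-friendly variant: each option is a named,
    # type-annotated argument with a default instead of a **kwargs entry.
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = x * 2
        if mask is not None:  # TorchScript refines Optional[Tensor] to Tensor here
            out = out * mask
        return out


def try_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script can compile the module, False on the frontend error."""
    try:
        torch.jit.script(module)
        return True
    except torch.jit.frontend.NotSupportedError:
        return False


if __name__ == "__main__":
    print(try_script(KwargsBlock()))    # False
    print(try_script(ExplicitBlock()))  # True
    scripted = torch.jit.script(ExplicitBlock())
    print(scripted(torch.ones(2, 3)).sum().item())  # 12.0
```

The trade-off is that every option currently passed through `**kwargs` would have to be spelled out, with a type annotation and a default, in each `forward()` along the call chain, because TorchScript treats unannotated parameters as `Tensor`.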