dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.55k stars · 470 forks

Autograd Function error with torch.script #485

Open RobColeman opened 1 year ago

RobColeman commented 1 year ago

When trying to script TabNet for portability after training, compilation fails in the SparseMax and EntMax15 functions:

Could not export Python function call 'Entmax15Function'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:

This is a known issue in PyTorch with script compilation of torch.autograd.Function, which appears to have been triaged with no intention to fix: https://github.com/pytorch/pytorch/issues/22329
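A minimal, self-contained illustration of the underlying limitation, assuming only stock PyTorch. The `ClampGrad` function below is a hypothetical stand-in for `Entmax15Function`/`SparsemaxFunction`; any `torch.autograd.Function` triggers the same problem, because `Function.apply` is an opaque Python call to the JIT:

```python
import torch


# Hypothetical stand-in for Entmax15Function / SparsemaxFunction:
# any custom torch.autograd.Function hits the same JIT limitation.
class ClampGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.clamp(x, min=0.0)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0).to(grad_out.dtype)


class TinyNet(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return ClampGrad.apply(x)


net = TinyNet()
out = net(torch.tensor([-1.0, 2.0]))  # eager mode works fine

# Scripting (or tracing and then saving) fails because the JIT cannot
# export the Python-level Function.apply call (pytorch/pytorch#22329).
# Exact behavior varies by PyTorch version, so we only observe it here:
try:
    torch.jit.script(net)
    print("scripted OK")
except Exception as e:
    print("scripting failed:", type(e).__name__)
```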

What is the current behavior? Compilation with torch.jit.script (or torch.jit.trace) fails.

If the current behavior is a bug, please provide the steps to reproduce.

scripted_tabnet_network = torch.jit.script(tabnet_model.network)

or

traced_script_module = torch.jit.trace(
    tabnet_model.network, input_features_example
)

Expected behavior: the network should compile with torch.jit.script or torch.jit.trace.


Other relevant information: tabnet version: 3.1.1 & 4.0
python version: 3.8+

Suggested fix

Include alternative implementations of SparseMax and EntMax15 that do not use torch.autograd.Function.
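As a rough sketch of what such an alternative could look like (this is an assumption, not the library's implementation): sparsemax can be written with only built-in differentiable ops (sort, cumsum, gather, clamp), so autograd derives the backward pass and no custom `torch.autograd.Function` is needed, which makes the function scriptable. Note the gradient then comes from autograd rather than the closed-form Jacobian used in the paper's custom backward:

```python
import torch


def sparsemax(z: torch.Tensor) -> torch.Tensor:
    """Sparsemax over the last dimension using only differentiable
    built-in ops, so no torch.autograd.Function is required."""
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.size(-1) + 1, dtype=z.dtype, device=z.device)
    z_cumsum = z_sorted.cumsum(dim=-1) - 1.0
    # Support set: positions where k * z_(k) > (cumulative sum - 1)
    support = (k * z_sorted > z_cumsum).to(z.dtype)
    k_star = support.sum(dim=-1, keepdim=True)
    # Threshold tau = (sum of supported entries - 1) / |support|
    tau = z_cumsum.gather(-1, k_star.long() - 1) / k_star
    return torch.clamp(z - tau, min=0.0)


# Unlike the Function-based versions, this compiles under TorchScript:
scripted_sparsemax = torch.jit.script(sparsemax)
```

An analogous rewrite would be needed for Entmax15, whose thresholding step is more involved.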

Optimox commented 1 year ago

Hello @RobColeman,

Do you absolutely need a torch script for your production environment? If you have Python installed in production, then you can install pytorch-tabnet and just use the save/load framework.

Otherwise, please don't hesitate to open a PR with updated versions of Entmax and Sparsemax that are scriptable.