skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

how to integrate pytorch-tabnet into skorch framework? #1029

Closed by pangjac 1 year ago

pangjac commented 1 year ago

Hello, I'm new to the skorch community. I'm currently experimenting with pytorch-tabnet, which is reputed to be a potent solution for tabular data in deep learning, built on PyTorch.

My query pertains to integrating pytorch-tabnet within the scikit-learn framework. I searched for tabnet in the current skorch repository but couldn't find any relevant keywords, which leads me to believe that tabnet hasn't been implemented here yet.

With this in mind, if I wish to explore how to incorporate pytorch-tabnet into our framework, could you recommend where I should start? Is there any preliminary guidance or direction you could provide?

Thank you so much for this wonderful framework!

BenjaminBossan commented 1 year ago

AFAIK, tabnet provides a class that is similar to skorch's NeuralNetClassifier etc., so you can call net.fit(X, y), net.predict(X), etc. As such, there is not really a need to use both of them together. As for how deep the tabnet scikit-learn integration is, I don't know; probably not as deep as skorch's.

There was a discussion once about integrating tabnet with skorch because of the big overlap: https://github.com/dreamquark-ai/tabnet/issues/41, but this was never accomplished. Not sure if @ottonemo has any updates.

ottonemo commented 1 year ago

@pangjac there's this PR, https://github.com/dreamquark-ai/tabnet/pull/279, where I added an example notebook showing how to integrate tabnet with skorch. The PR never got merged, but I hope it can still serve as an example!