muellerzr / Practical-Deep-Learning-for-Coders-2.0

Notebooks for the "A walk with fastai2" Study Group and Lecture Series
Other
744 stars 165 forks

TabNet explainability on custom data #38

Open alexanderwatanabe opened 3 years ago

alexanderwatanabe commented 3 years ago

Hello, thank you for this repo. I am trying to run the TabNet notebook on a custom dataset and have everything working up to the explainability decorator, which fails with this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-37-bfedb2755251> in <module>
----> 1 learn.explain(dl)

<ipython-input-36-c2a2fc0e0447> in explain(x, dl)
      6   for batch_nb, data in enumerate(dl):
      7     with torch.no_grad():
----> 8       out, M_loss, M_explain, masks = x.model(data[0], data[1], True)
      9     for key, value in masks.items():
     10       masks[key] = csc_matrix.dot(value.numpy(), matrix)

~/dev/Practical-Deep-Learning-for-Coders-2.0/.venv-nix/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() takes 3 positional arguments but 4 were given

I am reading the docs to better understand how to fix it. If you have any insights or pointers, they would be appreciated.
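The error message counts `self`, so the model's `forward` accepts only two data inputs, while `explain` passes three (`data[0]`, `data[1]`, and `True`). The sketch below reproduces the same failure in plain Python, without torch. `ToyModule` is hypothetical: its `__call__` dispatches to `forward` the way `nn.Module` does, and the parameter names `x_cat` and `x_cont` are assumptions, not the installed TabNet signature.

```python
import inspect


class ToyModule:
    """Hypothetical stand-in for an nn.Module: calling the object runs forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, x_cat, x_cont):
        # Only categorical and continuous inputs; no explain flag
        return x_cat, x_cont


model = ToyModule()

# The bound method's signature omits self, so only two inputs are listed
print(list(inspect.signature(model.forward).parameters))  # ['x_cat', 'x_cont']

try:
    # Same call shape as x.model(data[0], data[1], True) in the traceback
    model("cat_batch", "cont_batch", True)
except TypeError as err:
    # Python counts self too: 3 positional parameters, 4 positional arguments
    print(err)
```

The exact wording varies slightly by Python version (3.10+ prefixes the class name), but the "takes 3 positional arguments but 4 were given" part matches the traceback.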

muellerzr commented 3 years ago

What version of tabnet do you have installed?

(Replying to alexanderwatanabe's comment of Wed, Dec 2, 2020 at 8:28 PM.)

alexanderwatanabe commented 3 years ago

Here are the libraries I think are relevant in this environment: fastai 2.1.5, fastcore 1.3.6, fast_tabnet 0.2.0, pytorch-tabnet 2.0.1, fastinference 0.0.30, pytorch 1.7.0.

Also, my model is set up for single-variable regression.
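To gather version numbers like these from the active environment, a sketch along these lines works on Python 3.8+ using the standard-library `importlib.metadata`. The distribution names in `PACKAGES` are assumptions about how these libraries are published (name normalization differs between Python versions), and any name that cannot be found is reported as not installed rather than raising.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names; adjust if your install uses different ones
PACKAGES = ["fastai", "fastcore", "fast_tabnet", "pytorch-tabnet", "fastinference", "torch"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:15} {ver or 'not installed'}")
```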

muellerzr commented 3 years ago

If you look at my TabNet notebook, you'll see I changed the install dependencies, because things were breaking and I didn't have time to adjust things.

(Replying to alexanderwatanabe's comment of Wed, Dec 2, 2020 at 8:48 PM.)

alexanderwatanabe commented 3 years ago

Got it thanks!

muellerzr commented 3 years ago

There may be some issues with the new fastai or torch versions, though (I wouldn't be surprised), so let me know if you run into trouble!

(Replying to alexanderwatanabe's comment of Wed, Dec 2, 2020 at 8:53 PM.)

alexanderwatanabe commented 3 years ago

Got it working with your pinned versions and the new versions of fastai/torch. If you have any notes or an outline of how you might approach fixing it for the new versions, I'd love to get involved and contribute. I understand you are probably busy, so it's a standing offer for later if needed!
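One possible first step toward supporting both old and new versions is to inspect the model's `forward` signature before passing the explain flag, so a version mismatch produces a descriptive error instead of the bare `TypeError`. This is a hypothetical sketch: `call_with_explain` is not part of fastai, fast_tabnet, or pytorch-tabnet, it assumes older versions accept a third positional explain flag, and it does not recover masks from the newer API; it only detects whether the old three-argument call is possible.

```python
import inspect


def call_with_explain(model, x_cat, x_cont):
    """Hypothetical helper: call model(x_cat, x_cont, True) only when forward() can take it.

    If the installed forward() has no room for a third positional argument,
    raise a descriptive error instead of the bare "takes 3 positional
    arguments but 4 were given".
    """
    # For a bound method, inspect.signature already omits self
    params = list(inspect.signature(model.forward).parameters.values())
    positional = [p for p in params
                  if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
    takes_varargs = any(p.kind is p.VAR_POSITIONAL for p in params)

    if takes_varargs or len(positional) >= 3:
        return model(x_cat, x_cont, True)

    raise RuntimeError(
        f"{type(model).__name__}.forward() accepts {len(positional)} positional "
        "inputs, so this TabNet version does not take an explain flag; pin the "
        "versions used in the notebook or adapt explain() to the newer API."
    )
```

In the notebook's `explain`, the call `x.model(data[0], data[1], True)` would become `call_with_explain(x.model, data[0], data[1])`.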