dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Issue while adding cat emb #421

Closed gauravbrills closed 2 years ago

gauravbrills commented 2 years ago

Hi Team, I am getting an issue when adding cat embeddings in TabNet. The same works well when I one-hot encode and use TabNet. The error also doesn't give much insight into the cause of the issue. I have verified that I am passing the correct indices and value counts for the embeddings. The only thing to note is that these are of type int64, which I doubt should create an issue.

The config is like:

    cat_idxs = cat_index,
    cat_dims = cat_dims,
    cat_emb_dim = 1,
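
For context, a minimal sketch of how these parameters are wired into the model (variable names like `X_train`, `y_train`, `cat_index`, and `cat_dims` are placeholders, not my exact code):

```python
from pytorch_tabnet.tab_model import TabNetClassifier

# cat_index: positions of the categorical columns in X_train
# cat_dims:  number of distinct categories per categorical column
model = TabNetClassifier(
    cat_idxs=cat_index,   # e.g. [0, 3, 7]
    cat_dims=cat_dims,    # e.g. [12, 4, 250]
    cat_emb_dim=1,        # embedding size for each categorical feature
)
model.fit(X_train, y_train, max_epochs=50)
```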

Find the error trace below:

/content/tabnet/pytorch_tabnet/abstract_model.py in fit(self, X_train, y_train, eval_set, eval_name, eval_metric, loss_fn, weights, max_epochs, patience, batch_size, virtual_batch_size, num_workers, drop_last, callbacks, pin_memory, from_unsupervised, warm_start, augmentations)
    239             self._callback_container.on_epoch_begin(epoch_idx)
    240 
--> 241             self._train_epoch(train_dataloader)
    242 
    243             # Apply predict epoch to all eval sets

/content/tabnet/pytorch_tabnet/abstract_model.py in _train_epoch(self, train_loader)
    455             self._callback_container.on_batch_begin(batch_idx)
    456 
--> 457             batch_logs = self._train_batch(X, y)
    458 
    459             self._callback_container.on_batch_end(batch_idx, batch_logs)

/content/tabnet/pytorch_tabnet/abstract_model.py in _train_batch(self, X, y)
    493             param.grad = None
    494 
--> 495         output, M_loss = self.network(X)
    496 
    497         loss = self.compute_loss(output, y)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/content/tabnet/pytorch_tabnet/tab_network.py in forward(self, x)
    583 
    584     def forward(self, x):
--> 585         x = self.embedder(x)
    586         return self.tabnet(x)
    587 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/content/tabnet/pytorch_tabnet/tab_network.py in forward(self, x)
    859             else:
    860                 cols.append(
--> 861                     self.embeddings[cat_feat_counter](x[:, feat_init_idx].long())
    862                 )
    863                 cat_feat_counter += 1

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    158         return F.embedding(
    159             input, self.weight, self.padding_idx, self.max_norm,
--> 160             self.norm_type, self.scale_grad_by_freq, self.sparse)
    161 
    162     def extra_repr(self) -> str:

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2197         # remove once script supports set_grad_enabled
   2198         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2199     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2200 
   2201 

RuntimeError: CUDA error: device-side assert triggered

It could also be something I am passing wrong, so if you have any quick insight it would help me a lot in diagnosing this.
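
In case it helps with diagnosis: `torch.embedding` triggers this device-side assert when an index falls outside `[0, num_embeddings)`, so a quick sanity check of the raw values against `cat_dims` might narrow it down. A sketch, assuming `X_train` is a NumPy array and `cat_index` / `cat_dims` are the lists from the config above:

```python
import numpy as np

# Every value in a categorical column must lie in [0, dim),
# otherwise nn.Embedding raises the assert seen in the trace.
for idx, dim in zip(cat_index, cat_dims):
    col = X_train[:, idx].astype(np.int64)
    if col.min() < 0 or col.max() >= dim:
        print(f"column {idx}: values span [{col.min()}, {col.max()}], "
              f"but cat_dims allows only {dim} categories")
```

Running the same `fit` with `device_name='cpu'` should also surface a plain `IndexError` instead of the opaque CUDA assert.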

gauravbrills commented 2 years ago

Closing this. It seems I was missing the right cat embeddings and didn't apply label encoding the right way.
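
For anyone landing here later: the fix was label encoding each categorical column to contiguous `0..n-1` codes and deriving `cat_dims` from the fitted encoders so the two stay consistent. A minimal sketch (the DataFrame `df` and the column names are hypothetical):

```python
from sklearn.preprocessing import LabelEncoder

cat_cols = ["city", "channel"]  # hypothetical categorical columns
cat_index = [df.columns.get_loc(c) for c in cat_cols]

cat_dims = []
for c in cat_cols:
    le = LabelEncoder()
    df[c] = le.fit_transform(df[c].astype(str))  # values become 0..n-1
    cat_dims.append(len(le.classes_))            # matches the embedding size
```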