dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

spike in memory when training ends #524

Closed: Borda closed this issue 6 months ago

Borda commented 7 months ago

Describe the bug

What is the current behavior?

Running training on Kaggle with 3M samples: when training reaches the max epoch and ends, the GPU goes idle but RAM starts to grow in a single CPU process, going from 20 GB to 45 GB and eventually crashing the machine.


Other relevant information:

- poetry version: none
- python version: 3.10
- Operating System: Linux
- Additional tools: latest TabNet


Optimox commented 7 months ago

@Borda I think this might be due to the feature importance computation, which can be heavy for large datasets. Have you tried setting compute_importance=False when calling fit?
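
For reference, a minimal sketch of that call, assuming X_train / y_train / X_valid / y_valid are NumPy arrays prepared elsewhere:

```python
from pytorch_tabnet.tab_model import TabNetRegressor

model = TabNetRegressor()
model.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=["mae"],
    compute_importance=False,  # skip the post-training importance pass
)
```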

Borda commented 7 months ago

> Have you tried setting compute_importance=False when calling fit?

Not yet, let me check it... but would it be possible to compute the importance on the GPU (maybe with cuDF), or to parallelize it across all CPU cores?

Optimox commented 7 months ago

The code can certainly be improved, but that's not a quick fix. If you like, you may want to compute the feature importance on a smaller subset after training: model._compute_feature_importances(X_subset)
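
For example, a minimal sketch of that idea on a tail slice of the training data, assuming X_train is a NumPy array and that the helper returns the normalized importances (as the traceback later in this thread suggests); the slice size is arbitrary:

```python
# After fit(..., compute_importance=False) has finished, run the
# importance pass on a smaller slice only (protected helper).
X_subset = X_train[-100_000:]  # arbitrary size, pick what fits your RAM budget
importances = model._compute_feature_importances(X_subset)
```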

Borda commented 7 months ago

> you may want to compute the feature importance on a smaller subset after training: model._compute_feature_importances(X_subset)

That would be a good alternative; I'm just not sure about relying on a protected API, which can change at any time, right?
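
For what it's worth, the traceback in the next comment shows that _compute_feature_importances is just a thin wrapper around the public explain() method, so the same numbers can be reproduced without the protected call. A sketch, assuming a fitted model and a NumPy array X_subset:

```python
import numpy as np

# explain() is public API; the protected helper only sums its first
# return value over rows and normalizes (see the traceback below).
M_explain, _ = model.explain(X_subset, normalize=False)
sum_explain = M_explain.sum(axis=0)
feature_importances = sum_explain / np.sum(sum_explain)
```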

Borda commented 7 months ago

also hit an interesting crash:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[18], line 30
     28 # Train a TabNet model for the current fold
     29 model = TabNetRegressor(**tabnet_params)
---> 30 model._compute_feature_importances(fold_train_features[-100_000:])
     31 model.fit(
     32     X_train=fold_train_features, y_train=fold_train_target,
     33     eval_set=[(fold_valid_features, fold_valid_target)],
     34     eval_metric=['mae'], **FIT_PARAMETERS
     35 )
     36 # Free up memory by deleting fold specific variables

File /opt/conda/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:759, in TabModel._compute_feature_importances(self, X)
    750 def _compute_feature_importances(self, X):
    751     """Compute global feature importance.
    752 
    753     Parameters
   (...)
    757 
    758     """
--> 759     M_explain, _ = self.explain(X, normalize=False)
    760     sum_explain = M_explain.sum(axis=0)
    761     feature_importances_ = sum_explain / np.sum(sum_explain)

File /opt/conda/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:336, in TabModel.explain(self, X, normalize)
    318 def explain(self, X, normalize=False):
    319     """
    320     Return local explanation
    321 
   (...)
    334         Sparse matrix showing attention masks used by network.
    335     """
--> 336     self.network.eval()
    338     if scipy.sparse.issparse(X):
    339         dataloader = DataLoader(
    340             SparsePredictDataset(X),
    341             batch_size=self.batch_size,
    342             shuffle=False,
    343         )

AttributeError: 'TabNetRegressor' object has no attribute 'network'

Optimox commented 7 months ago

You need to first train your model with fit and compute_importance=False, then call model._compute_feature_importances(fold_train_features[-100_000:]). The network can't exist before fit, since the model does not yet know how many targets you have.
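
In other words, a sketch of that order using the variables from the snippet above (tabnet_params, FIT_PARAMETERS and the fold_* arrays are assumed to be defined as in the notebook):

```python
model = TabNetRegressor(**tabnet_params)

# 1. Fit first: the underlying network is only built here, once the number
#    of targets is known. Skip the expensive full-dataset importance pass.
model.fit(
    X_train=fold_train_features, y_train=fold_train_target,
    eval_set=[(fold_valid_features, fold_valid_target)],
    eval_metric=["mae"],
    compute_importance=False,
    **FIT_PARAMETERS,
)

# 2. Only then compute feature importances, on a smaller subset.
model._compute_feature_importances(fold_train_features[-100_000:])
```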

Borda commented 7 months ago

> You need to first train your model with fit and compute_importance=False, then call model._compute_feature_importances(fold_train_features[-100_000:]). The network can't exist before fit, since the model does not yet know how many targets you have.

OK, but if I just want to fit and predict, then I do not need _compute_feature_importances at all, right?

Also, looking a bit at the code, do you have any insight into the root cause of the spike? Is it the sparse matrix computation or the dot product after it?

Optimox commented 7 months ago

I'm not sure what is causing the spike; I would need to dig deeper to give you an answer.

If you don't want to look at the feature importances, then yes, you can just skip the computation.