dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

OOM problem when I search hyperparameters with Tabnet #527

Closed luyuhengCN closed 6 months ago

luyuhengCN commented 7 months ago

Dear maintainers, I came across an OOM issue while searching for optimal TabNet hyperparameters. I pass my dataset (~20 GB) as an external variable to my TabNet network and run 4 parallel searches over multiple sets of hyperparameters. My understanding was that the dataset itself would not be replicated when it is fed to the network multiple times during the parallel runs, yet memory usage exceeded the 200 GB available on my machine. Is there any additional replication of the dataset inside the network's fitting process (i.e., TabNetRegressor().fit())?
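
Roughly, the setup looks like the sketch below (joblib, the parameter grid, and the file paths are illustrative placeholders, not the exact code):

```python
# Minimal sketch of the parallel search setup described above.
import numpy as np
from joblib import Parallel, delayed
from pytorch_tabnet.tab_model import TabNetRegressor

X_train = np.load("X_train.npy")                  # ~20 GB feature matrix (placeholder path)
y_train = np.load("y_train.npy").reshape(-1, 1)   # regressor expects a 2D target

def train_one(params):
    # Each call builds and fits a fresh TabNetRegressor on the shared arrays.
    model = TabNetRegressor(**params)
    model.fit(X_train, y_train, max_epochs=50)
    return params, model.history["loss"][-1]

param_sets = [
    {"n_d": 8, "n_a": 8},
    {"n_d": 16, "n_a": 16},
    {"n_d": 32, "n_a": 32},
    {"n_d": 64, "n_a": 64},
]

# Four searches run concurrently; with process-based backends each worker
# may end up holding its own view or copy of the training data.
results = Parallel(n_jobs=4)(delayed(train_one)(p) for p in param_sets)
```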

What is the current behavior? Out of memory when running the hyperparameter searches in parallel.

If the current behavior is a bug, please provide the steps to reproduce.

Other relevant information:
- tabnet version (via poetry): 4.1.0
- pytorch version: 1.13.1
- python version: 3.9.18
- Operating System: CentOS 7

Optimox commented 7 months ago

Hi, without a code example it is difficult to know what is going on. The code is not optimized to be run multiple times in parallel. I think the best way to speed up training is to play with the batch size, the number of workers, etc. so that you get good GPU utilization, and then simply run your hyperparameter search sequentially.
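
For illustration, a sequential search along those lines might look like the sketch below (it reuses the data and parameter grid from the earlier sketch, assumes a held-out validation split, and the batch_size / num_workers values are placeholders to tune):

```python
# Sequential search sketch: one training at a time, tuned for GPU utilization.
from pytorch_tabnet.tab_model import TabNetRegressor

results = []
for params in param_sets:
    model = TabNetRegressor(**params, device_name="cuda")
    model.fit(
        X_train, y_train,
        eval_set=[(X_valid, y_valid)],
        max_epochs=50,
        batch_size=16384,         # larger batches tend to keep the GPU busy
        virtual_batch_size=512,   # ghost batch-norm chunk size (<= batch_size)
        num_workers=4,            # dataloader workers feeding batches to the GPU
        drop_last=True,
    )
    results.append((params, min(model.history["val_0_mse"])))

best_params, best_mse = min(results, key=lambda r: r[1])
```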

luyuhengCN commented 7 months ago

Hi, thanks for your reply. I monitored the memory usage during .fit(), and I found that create_dataloaders() in utils.py converts X_train to np.float32. I guess this takes up extra memory when X_train is large. I'm not sure whether it could cause the OOM problem.
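
One workaround sketch, assuming X_train starts out as float64 (numpy's default): cast it to float32 before calling fit(), so the full-size float64 array is never held alongside a float32 copy.

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetRegressor

# Cast once, up front: astype(..., copy=False) copies only if the dtype
# actually differs, so this is a no-op for data that is already float32.
X_train = X_train.astype(np.float32, copy=False)
y_train = y_train.astype(np.float32, copy=False)

model = TabNetRegressor()
model.fit(X_train, y_train, max_epochs=50)
```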

Optimox commented 6 months ago

Yes, models are usually trained using float32. We could try to use mixed precision (float16), but the code would still not be meant for multiple trainings in parallel.
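
For reference, a minimal sketch of what mixed-precision training looks like in plain PyTorch (torch.cuda.amp); pytorch-tabnet does not expose this option today, and the model, data, and loop below are generic placeholders.

```python
import torch

# Placeholder data and network; the point is only the autocast/GradScaler pattern.
data = torch.randn(4096, 128)
target = torch.randn(4096, 1)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(data, target), batch_size=512
)

model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU(),
                            torch.nn.Linear(64, 1)).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for x, y in loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():     # forward/loss run in mixed float16/float32
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()       # scale loss to avoid float16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```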