dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.6k stars, 482 forks

Count the number of parameters #513

Closed: CesarLeblanc closed this issue 1 year ago

CesarLeblanc commented 1 year ago

Feature request

What is the expected behavior?

from pytorch_tabnet.tab_model import TabNetClassifier

model = TabNetClassifier()
model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])

# Proposed API: count the trainable parameters directly on the fitted estimator
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
>>> XXX  # the number of trainable parameters in the model

What is the motivation or use case for adding/changing the behavior?

Knowing how many trainable parameters the model has.

How should this be implemented in your opinion?

Exposing a parameters() method on TabNetClassifier and TabNetRegressor objects, as used in the example above.
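
A minimal sketch of how the proposal could look, assuming (as in the pytorch_tabnet versions we have looked at) that the estimator keeps its torch.nn.Module in a network attribute that is only created by fit(). ParameterCountMixin and count_trainable_parameters are hypothetical names, not part of the library:

```python
class ParameterCountMixin:
    """Hypothetical mixin, not part of pytorch_tabnet.

    Assumes the estimator stores its torch.nn.Module in `self.network`,
    which pytorch_tabnet builds inside fit().
    """

    def parameters(self):
        # Delegate to the underlying network so that model.parameters(),
        # as written in the feature request, works on the wrapper itself.
        network = getattr(self, "network", None)
        if network is None:
            raise RuntimeError("The network is only built by fit(); call fit() first.")
        return network.parameters()

    def count_trainable_parameters(self) -> int:
        # Same computation as in the request: only tensors with
        # requires_grad=True contribute to the count.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
```

Mixing it in ahead of the estimator, e.g. class TabNetClassifierWithParams(ParameterCountMixin, TabNetClassifier): pass, would let the print(...) line from the expected-behavior example run unchanged. Delegating to network instead of copying anything keeps the count in sync with whatever fit() builds.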

Are you willing to work on this yourself?

Yes.

CesarLeblanc commented 1 year ago

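This already works without any change to the library, by counting the parameters of the underlying PyTorch module, model.network:

# model.network is the torch.nn.Module created by fit()
print(sum(p.numel() for p in model.network.parameters() if p.requires_grad))
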
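
The if p.requires_grad filter excludes frozen tensors, so the expression above gives the trainable count rather than the total. A small sketch that reports both in a single pass (parameter_counts is a hypothetical helper, not a library function); one pass matters because torch's Module.parameters() returns a generator that is exhausted after it has been iterated once:

```python
def parameter_counts(params):
    """Return (total, trainable) element counts for an iterable of parameters.

    Each item only needs a numel() method and a requires_grad attribute,
    which is what torch.nn.Parameter provides.
    """
    total = 0
    trainable = 0
    for p in params:
        n = p.numel()
        total += n
        # Frozen parameters count toward the total but not the trainable count.
        if p.requires_grad:
            trainable += n
    return total, trainable
```

For a fitted model this would be called as total, trainable = parameter_counts(model.network.parameters()).
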
Optimox commented 1 year ago

@CesarLeblanc, can I close this issue?