dreamquark-ai / tabnet

PyTorch implementation of TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

n_independent_decoder and n_shared_decoder not being used in TabNetPretrainer #478

Closed M-R-T-U-D closed 1 year ago

M-R-T-U-D commented 1 year ago

Describe the bug

See title: the n_shared_decoder and n_indep_decoder arguments accepted by TabNetPretrainer are never forwarded to the underlying network, so they have no effect.

What is the current behavior?

If the current behavior is a bug, please provide the steps to reproduce.

  1. Create a TabNetPretrainer instance, e.g.:

    import torch
    from pytorch_tabnet.pretraining import TabNetPretrainer

    unsupervised_model = TabNetPretrainer(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(
            lr=1e-5,
            weight_decay=1e-3,
            betas=(0.9, 0.9),
        ),
        clip_value=1,
        n_d=32,
        n_a=32,
        n_steps=4,
        gamma=1.3,
        n_independent=2,
        n_shared=2,
        lambda_sparse=1e-3,
        seed=0,
        epsilon=1e-15,
        momentum=0.2,
        mask_type='sparsemax',
        n_shared_decoder=2,  # nb shared GLU blocks for decoding
        n_indep_decoder=2,   # nb independent GLU blocks for decoding
        verbose=5,
        device_name="cuda",
    )
  2. Use summary() from the torchinfo library to show the total number of parameters of the network:

    from torchinfo import summary

    # note: unsupervised_model.network is only created once fit() has been called
    device = "cuda"
    input_data = [torch.randn(1, 150).to(device)]
    summary(
        unsupervised_model.network,
        input_data=input_data,
        depth=3,
        device=device,
        verbose=1,
    )
  3. This results in the following: Total params: 406,684
  4. Now change n_shared_decoder or n_indep_decoder to e.g. 44 and you will see the same total params as in step 3, so changing these parameters does not affect the size of the network.
  5. Looking into the code, the TabNetPretraining network is not passed these parameters when it is initialized, so the default value of 1 is used for both. Relevant code snippets:
    self.network = tab_network.TabNetPretraining(
            self.input_dim,
            pretraining_ratio=self.pretraining_ratio,
            n_d=self.n_d,
            n_a=self.n_a,
            n_steps=self.n_steps,
            gamma=self.gamma,
            cat_idxs=self.cat_idxs,
            cat_dims=self.cat_dims,
            cat_emb_dim=self.cat_emb_dim,
            n_independent=self.n_independent,
            n_shared=self.n_shared,
            epsilon=self.epsilon,
            virtual_batch_size=self.virtual_batch_size,
            momentum=self.momentum,
            mask_type=self.mask_type,
            group_attention_matrix=self.group_matrix.to(self.device),
        ).to(self.device)

    and

    class TabNetPretraining(torch.nn.Module):
    def __init__(
        self,
        ...
        n_shared_decoder=1,
        n_indep_decoder=1,
        ...
    ): ...

Expected behavior

The size of the TabNet network should change when the number of independent and shared layers in the decoder changes. This does not happen, because both parameters are not forwarded from the TabNetPretrainer class to the TabNetPretraining instance (see the sketch below).
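For reference, a minimal sketch of the kind of change that would fix this, assuming the wrapper stores its constructor arguments as self.n_shared_decoder and self.n_indep_decoder (attribute names assumed here, not checked against the source):

    self.network = tab_network.TabNetPretraining(
        self.input_dim,
        pretraining_ratio=self.pretraining_ratio,
        # ... same keyword arguments as in the snippet above ...
        mask_type=self.mask_type,
        n_shared_decoder=self.n_shared_decoder,  # forward the decoder depth (assumed attribute name)
        n_indep_decoder=self.n_indep_decoder,    # forward the decoder depth (assumed attribute name)
        group_attention_matrix=self.group_matrix.to(self.device),
    ).to(self.device)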

Screenshots

Other relevant information:

  - poetry version: -
  - python version: 3.8
  - Operating System: Linux
  - Additional tools: torchinfo

Additional context

Optimox commented 1 year ago

Thanks @M-R-T-U-D,

It looks like I forgot those 2 parameters. Do you think that the above PR would solve the problem?

M-R-T-U-D commented 1 year ago

Hi Optimox, yes, that should fix the problem. Thanks for fixing the bug 👍 Now the shared and independent layers in the decoder will change when they are passed via TabNetPretrainer.

Optimox commented 1 year ago

@M-R-T-U-D I did a test with the bugfix and the number of parameters does change if I add more decoder layers.

I'll try to make a release soon, in the meantime you can use the develop branch.
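For anyone who wants to verify this themselves on the develop branch, a rough sketch of how the parameter count can be compared for two decoder depths; the 150-feature random data, the sample count, and the single training epoch are placeholders, not part of the original report:

    import numpy as np
    from pytorch_tabnet.pretraining import TabNetPretrainer

    X = np.random.randn(2048, 150).astype(np.float32)  # placeholder unlabeled data

    def count_params(n_dec):
        # build a pretrainer with the given decoder depth; the underlying
        # network is only instantiated inside fit()
        model = TabNetPretrainer(n_shared_decoder=n_dec, n_indep_decoder=n_dec)
        model.fit(X, max_epochs=1, batch_size=256)
        return sum(p.numel() for p in model.network.parameters())

    print(count_params(1))
    print(count_params(4))  # larger once the decoder arguments are forwarded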

Out of curiosity, have you been able to play with the attention groups?

M-R-T-U-D commented 1 year ago

Not for now. Do you suspect that there is a bug with that also?

Optimox commented 1 year ago

No, it's just that you seem to be using the library quite heavily, so I would be happy to get feedback about it, since it's not in the original paper.
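For context, attention groups are the feature behind the group_attention_matrix argument seen in the snippet above; on the estimator side they are exposed as a grouped_features argument (name taken from the library's README; the indices below are purely illustrative):

    from pytorch_tabnet.pretraining import TabNetPretrainer

    # share attention across the features of each group (illustrative indices)
    pretrainer = TabNetPretrainer(
        grouped_features=[[0, 1, 2], [10, 11]],
    )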

M-R-T-U-D commented 1 year ago

I'm not planning to use it anytime soon, but I will let you know if I do.