dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

The mask tensor M in tab_network.py needs to be transformed to realize the objective stated in the paper: "γ is a relaxation parameter – when γ = 1, a feature is enforced to be used only at one decision step". #516

Status: Open. sciengineer opened this issue 1 year ago.


Describe the bug

The elements of the mask tensor M are consistently far from 1 and close to zero. As a result, "self.gamma - M", and therefore the updated prior, stays far from zero. However, the paper states that when gamma equals 1, "a feature is enforced to be used only at one decision step". In practice this seems unachievable, which leads me to infer that M needs to be transformed before it is used in the prior update.
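For a single sample, the prior update in the code quoted below can be unrolled. This is a short derivation assuming, as in that code, that the prior starts as a tensor of ones and is multiplied by (gamma - M) once per decision step; P_t and M_t are notation introduced here for the prior row and mask row at step t.

```latex
\begin{gathered}
P_0 = \mathbf{1}, \qquad P_{t+1} = (\gamma - M_t) \odot P_t
\;\Longrightarrow\;
P_T[j] = \prod_{t=0}^{T-1} \bigl(\gamma - M_t[j]\bigr) \\
\gamma = 1,\; 0 \le M_t[j] \le 1:
\qquad P_T[j] = 0 \iff M_t[j] = 1 \text{ for some } t < T
\end{gathered}
```

Because each row of M sums to 1, M_t[j] = 1 means the row is one-hot: feature j is the only feature selected at step t. Any smaller weight leaves a positive factor 1 - M_t[j], so the prior for feature j never reaches zero.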

M is a batch_size × num_features tensor produced by sparsemax. Each entry is interpreted as the weight a feature receives at a decision step, and each row of M sums to 1. Consequently, when there are many features, the entries are either zero or small positive numbers. It is highly unlikely for a row to contain a single element equal to 1 with all other elements equal to zero.
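A small pure-Python sparsemax (the Euclidean projection onto the probability simplex from Martins & Astudillo, 2016) illustrates this. It is a sketch for a single row, not the implementation used in pytorch_tabnet, and the logits are made up.

```python
def sparsemax(z):
    """Sparsemax of one row of logits: Euclidean projection onto the simplex.

    Returns nonnegative weights that sum to 1; low-scoring entries become exactly 0.
    """
    z_sorted = sorted(z, reverse=True)
    cumsum, support_sum, k = 0.0, 0.0, 0
    for i, zi in enumerate(z_sorted, start=1):
        cumsum += zi
        # the i-th largest logit stays in the support while 1 + i * zi exceeds the running sum
        if 1 + i * zi > cumsum:
            k, support_sum = i, cumsum
    tau = (support_sum - 1) / k  # threshold subtracted from every logit
    return [max(zi - tau, 0.0) for zi in z]


# 50 features with evenly spaced (hypothetical) logits: the weight is spread thin
row = sparsemax([0.001 * i for i in range(50)])
print(sum(row), max(row), sum(v > 0 for v in row))

# A one-hot row needs the top logit to beat the runner-up by at least 1
print(sparsemax([5.0] + [0.0] * 49)[:3])  # -> [1.0, 0.0, 0.0]
```

With the evenly spaced logits, 45 of the 50 features stay in the support and the largest weight is below 0.05, even though the row still sums to 1.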

In my view, the authors of the TabNet paper should have used a tensor derived from M that often contains several elements equal to 1, rather than M itself, as the subtrahend in this line: "prior = torch.mul(self.gamma - M, prior)". That would realize the objective stated in the paper: "γ is a relaxation parameter – when γ = 1, a feature is enforced to be used only at one decision step".
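One possible reading of this proposal, sketched for a single sample in plain Python: treat every feature with a nonzero weight in M as fully used (weight 1) before subtracting it from gamma. The update_prior helper and its binarize option are hypothetical illustrations, not part of pytorch_tabnet, and the mask values are made up.

```python
def update_prior(prior, mask, gamma=1.0, binarize=False):
    """One prior update for a single sample (one row).

    binarize=False mirrors the current rule: prior <- (gamma - M) * prior.
    binarize=True is a HYPOTHETICAL variant: any feature with a nonzero
    mask weight counts as fully used (1.0).
    """
    used = [1.0 if m > 0 else 0.0 for m in mask] if binarize else mask
    return [(gamma - u) * p for u, p in zip(used, prior)]


prior = [1.0, 1.0, 1.0, 1.0]
mask = [0.5, 0.25, 0.25, 0.0]  # sparsemax-like row: nonnegative, sums to 1

print(update_prior(prior, mask))                 # -> [0.5, 0.75, 0.75, 1.0]
print(update_prior(prior, mask, binarize=True))  # -> [0.0, 0.0, 0.0, 1.0]
```

Under the current rule every selected feature keeps a positive prior; under the hypothetical variant, with gamma = 1, the prior of each selected feature drops to 0. In PyTorch, a hard threshold such as (M > 0).float() has zero gradient, so the prior would no longer pass gradients back to earlier masks; a real version of this change would need to account for that.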

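def forward(self, x, prior=None):
    x = self.initial_bn(x)

    bs = x.shape[0]  # batch size
    if prior is None:
        # every feature starts with prior 1
        prior = torch.ones((bs, self.attention_dim)).to(x.device)

    M_loss = 0
    att = self.initial_splitter(x)[:, self.n_d :]
    steps_output = []
    for step in range(self.n_steps):
        # M: sparsemax mask, shape (batch_size, num_features), each row sums to 1
        M = self.att_transformers[step](prior, att)
        # sparsity regularizer: batch mean of sum(M * log(M + epsilon))
        M_loss += torch.mean(
            torch.sum(torch.mul(M, torch.log(M + self.epsilon)), dim=1)
        )
        # update prior; with gamma = 1 this is (1 - M) * prior,
        # which becomes 0 only where M == 1
        prior = torch.mul(self.gamma - M, prior)
        # ... rest of the loop (computing masked_x and the step output) omitted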
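The prior updates in this loop can be replayed without the library. In this sketch the masks are fixed, made-up sparsemax-like rows (nonnegative, summing to 1); in TabNet each mask comes from an attentive transformer that also depends on the current prior, which the sketch ignores. run_prior_updates is a hypothetical helper.

```python
def run_prior_updates(masks, gamma=1.0):
    """Replay `prior = (gamma - M) * prior` over decision steps for one sample.

    `masks` holds one hypothetical mask row per step; returns the prior
    before the first step and after every step.
    """
    prior = [1.0] * len(masks[0])  # same initialization as torch.ones(...)
    history = [prior]
    for M in masks:
        prior = [(gamma - m) * p for m, p in zip(M, prior)]
        history.append(prior)
    return history


masks = [
    [0.5, 0.5, 0.0, 0.0],
    [0.5, 0.25, 0.25, 0.0],
    [0.25, 0.25, 0.25, 0.25],
]
for step, prior in enumerate(run_prior_updates(masks)):
    print(step, prior)
```

Feature 0 receives weight at every step, yet with gamma = 1 its prior only shrinks (1.0, 0.5, 0.25, 0.1875) and never reaches zero.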

What is the current behavior?

When gamma = 1, a feature is not restricted to a single decision step, contrary to what the paper states.

If the current behavior is a bug, please provide the steps to reproduce.

Set gamma = 1, then in debug mode inspect the variables M, prior, and masked_x inside the for loop over decision steps.

Expected behavior

As the paper says, when gamma = 1, a feature is enforced to be used only at one decision step.
