lucidrains / tab-transformer-pytorch

Implementation of TabTransformer, an attention network for tabular data, in PyTorch
MIT License

Minor Bug: activation function being applied to output layer in class MLP #9

Closed. rminhas closed this issue 3 years ago.

rminhas commented 3 years ago

The code for class MLP mistakenly applies the activation function to the last (i.e. output) layer. The error is in how the is_last flag is evaluated. The current code is:

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims) - 1)
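
To see the off-by-one concretely, enumerate dims_pairs for an illustrative dims list (the values below are just an example):

dims = [16, 8, 1]
dims_pairs = list(zip(dims[:-1], dims[1:]))   # [(16, 8), (8, 1)]
for ind, (dim_in, dim_out) in enumerate(dims_pairs):
    print(ind, ind >= (len(dims) - 1))        # prints "0 False" and "1 False"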

Since dims_pairs has only len(dims) - 1 entries, ind never reaches len(dims) - 1, so is_last is always False and the activation is appended after the output layer as well. The check should be changed to is_last = ind >= (len(dims) - 2):

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims) - 2)

If you like, I can do a pull request.
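
For reference, here is a self-contained sketch of the MLP with the corrected check. It is a minimal reconstruction rather than the library's exact code, and the Sequential/forward wiring and the example dims are assumptions for illustration:

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            # corrected check: the final (dim_in, dim_out) pair has index len(dims) - 2
            is_last = ind >= (len(dims) - 2)
            layers.append(nn.Linear(dim_in, dim_out))
            # activation only between hidden layers, never after the output layer
            if not is_last and act is not None:
                layers.append(act)
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

# usage sketch: three Linear layers, activations only after the two hidden ones
mlp = MLP([16, 8, 4, 1], act = nn.ReLU())
out = mlp(torch.randn(2, 16))   # shape (2, 1)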

lucidrains commented 3 years ago

@rminhas oh yes, thank you for finding this bug! fixed in 0.1.4