lucidrains / perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
MIT License

Source for the gated GELU MLP #48

Open breuderink opened 3 years ago

breuderink commented 3 years ago

Reading the code, I found the following implementation of the feed-forward MLP in Perceiver IO:

import torch.nn.functional as F
from torch import nn

class GEGLU(nn.Module):
    def forward(self, x):
        # Split the last dimension in half: one half is the value path,
        # the other half gates it through a GELU.
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            # Project to twice the hidden width so GEGLU can chunk it back to dim * mult.
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)
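
As a quick shape check (illustrative sizes only, using the FeedForward above): the first Linear expands to dim * mult * 2 so that GEGLU can chunk it back down to dim * mult.

import torch

ff = FeedForward(dim = 32, mult = 4)
x = torch.randn(1, 16, 32)   # (batch, sequence, dim)
h = ff.net[0](x)             # first Linear: (1, 16, 32 * 4 * 2) = (1, 16, 256)
g = ff.net[1](h)             # GEGLU chunks and gates: (1, 16, 128) = dim * mult
out = ff(x)                  # final Linear maps back to dim: (1, 16, 32)
print(h.shape, g.shape, out.shape)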

I could not find references to a gated GELU in the Perceiver IO paper nor in the code.

Is there a particular reason to use GEGLU instead of GELU?

lucidrains commented 3 years ago

@breuderink ohh this is actually a trick from a Shazeer paper https://arxiv.org/pdf/2002.05202.pdf that should give an extra performance boost, but i should probably make it optional to stay faithful to the original paper
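
For reference, making it optional could look something like the sketch below; the use_geglu flag and the plain-GELU branch are illustrative assumptions, not the repository's actual API:

import torch.nn.functional as F
from torch import nn

class GEGLU(nn.Module):
    # Same gating module as posted above.
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    # Hypothetical `use_geglu` flag: True keeps the GLU-variants trick,
    # False falls back to a plain GELU MLP as in the original paper.
    def __init__(self, dim, mult = 4, use_geglu = True):
        super().__init__()
        if use_geglu:
            self.net = nn.Sequential(
                nn.Linear(dim, dim * mult * 2),
                GEGLU(),
                nn.Linear(dim * mult, dim)
            )
        else:
            self.net = nn.Sequential(
                nn.Linear(dim, dim * mult),
                nn.GELU(),
                nn.Linear(dim * mult, dim)
            )

    def forward(self, x):
        return self.net(x)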