lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
MIT License
384 stars, 13 forks

MLP design in MaskEstimator #15

Closed: YoungloLee closed this issue 10 months ago

YoungloLee commented 10 months ago

Thanks for your great work. In the BS-RoFormer paper, the authors mention:

Each MLP layer consists of a RMS Norm layer, a fully connected layer followed by a Tanh activation, and a fully connected layer followed by a gated linear unit (GLU) layer[29].

However, your MaskEstimator implementation is somewhat different from the paper description.

Here's my implementation (I also checked the number of parameters):

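import torch
from torch import nn
from torch.nn import Module, ModuleList
from typing import Tuple
from beartype import beartype

class TanH(Module):
    # fully connected layer followed by a tanh activation
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x.tanh()

class GLU(Module):
    # fully connected layer to 2 * dim_out, followed by a gated linear unit (a * sigmoid(b))
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * gate.sigmoid()

class MaskEstimator(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            mlp_expansion_factor = 4,
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])

        # one MLP per band: Linear -> Tanh -> Linear -> GLU
        for dim_in in dim_inputs:
            net = []
            net.append(TanH(dim, dim * mlp_expansion_factor))
            net.append(GLU(dim * mlp_expansion_factor, dim_in))
            self.to_freqs.append(nn.Sequential(*net))

    def forward(self, x):
        # x: (..., num_bands, dim) -> tuple of num_bands tensors of shape (..., dim)
        x = x.unbind(dim=-2)

        outs = []

        for band_features, to_freq in zip(x, self.to_freqs):
            freq_out = to_freq(band_features)
            outs.append(freq_out)

        # concatenate the per-band outputs along the feature axis
        return torch.cat(outs, dim=-1)
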
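Since the comment mentions checking parameter counts, the per-band count can also be worked out by hand: with model width D, hidden width H = expansion * D and band output width F, one band MLP has (D*H + H) + (H*2F + 2F) parameters, because Tanh and GLU are parameter-free. A minimal sketch that checks this formula; `band_mlp`, `expected_params`, `count_params` and the sizes are hypothetical, and built-in nn.Tanh / nn.GLU are used because they carry the same parameters as the TanH + GLU classes above.

```python
import torch
from torch import nn

def band_mlp(dim: int, dim_in: int, expansion: int = 4) -> nn.Sequential:
    # Linear -> Tanh -> Linear (to 2 * dim_in) -> GLU; the GLU halves the width back to dim_in
    hidden = dim * expansion
    return nn.Sequential(
        nn.Linear(dim, hidden),
        nn.Tanh(),
        nn.Linear(hidden, dim_in * 2),
        nn.GLU(dim=-1),
    )

def expected_params(dim: int, dim_in: int, expansion: int = 4) -> int:
    # weights + biases of the two linear layers; Tanh and GLU have no parameters
    hidden = dim * expansion
    return (dim * hidden + hidden) + (hidden * 2 * dim_in + 2 * dim_in)

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# hypothetical sizes, for illustration only
dim, dim_inputs = 384, (4, 8, 12)
total = sum(count_params(band_mlp(dim, d)) for d in dim_inputs)
expected = sum(expected_params(dim, d) for d in dim_inputs)
print(total, expected)
```

The paper's RMSNorm, if included per band, would add D more parameters (its gain) to each band.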
lucidrains commented 10 months ago

@YoungloLee hey Younglo

yes i think you are right

want to check 0.2.6 and see if it aligns better? also, do you know if the tanh output corresponds to the real part and the glu output to the imaginary part? i don't have enough domain expertise to know why one mlp output would be bounded between -1 and 1 while the other is not
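To make the bounded vs. unbounded difference concrete, here is a small check with made-up inputs: tanh squashes every element into [-1, 1], while a GLU returns a * sigmoid(b), where the gate lies in (0, 1) but a is an unconstrained linear output, so the product is not confined to [-1, 1].

```python
import torch
import torch.nn.functional as F

# tanh maps any input into [-1, 1]
pre_act = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
tanh_out = pre_act.tanh()

# F.glu splits the last dimension into halves (a, b) and returns a * sigmoid(b)
glu_in = torch.tensor([5.0, -7.0, 3.0, 3.0])  # a = [5, -7], b = [3, 3]
glu_out = F.glu(glu_in, dim=-1)

print(tanh_out)
print(glu_out)
```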

YoungloLee commented 10 months ago

@lucidrains

According to the paper, the tanh output is fed directly into the next fully connected layer, which is followed by the GLU activation:

MLP: Sequential(RMSNorm, Linear, Tanh, Linear, GLU)
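A minimal sketch of that ordering for one band in PyTorch. The hand-written RMSNorm, the helper name `paper_band_mlp`, and the sizes are assumptions for illustration (recent PyTorch releases also ship a built-in RMSNorm); this is not the code in the repository.

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    # hypothetical minimal RMSNorm: x / sqrt(mean(x^2) + eps) * learnable gain
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.gamma

def paper_band_mlp(dim: int, dim_out: int, expansion: int = 4) -> nn.Sequential:
    hidden = dim * expansion
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, hidden),
        nn.Tanh(),                       # intermediate activation, fed into the next Linear
        nn.Linear(hidden, dim_out * 2),  # doubled width, halved again by the GLU
        nn.GLU(dim=-1),                  # final band output: a * sigmoid(b)
    )

mlp = paper_band_mlp(dim=32, dim_out=10)
x = torch.randn(2, 7, 32)  # (batch, time, dim) features for one band
y = mlp(x)
print(y.shape)  # torch.Size([2, 7, 10])
```

Because the Tanh sits between the two linear layers, the only output of the block is the GLU's.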

I'm not certain, but the two activations don't seem to be related to the real and imaginary components. It just appears to be the authors' choice of activation function.

lucidrains commented 10 months ago

@YoungloLee ahh i totally misunderstood, thank you! i'm used to seeing swish for activation functions

ok, 0.2.8 should be accurate to the paper then!

Psarpei commented 10 months ago

Looks pretty good now, and the parameter count is getting really close to the original paper's, but I think the RMS Norm they mentioned is still missing.

lucidrains commented 10 months ago

@Psarpei it is taken care of at the end of the attention layers https://github.com/lucidrains/BS-RoFormer/blob/main/bs_roformer/bs_roformer.py#L302

lucidrains commented 10 months ago

it is standard to do a final normalization at the end of a pre-normalized transformer
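A minimal sketch of why the final norm covers this: in a pre-norm transformer each sublayer normalizes only its own input, so the residual stream leaving the last block is unnormalized, and one final RMSNorm fixes that before anything downstream (here, the mask estimator) consumes it. The class names, the GELU feedforward, and the use of nn.MultiheadAttention without rotary embeddings are assumptions for illustration; the linked repository code is not reproduced here.

```python
import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    # x / rms(x) * gamma, written as L2 normalization: x / ||x|| * sqrt(dim)
    def __init__(self, dim: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma

class PreNormTransformer(nn.Module):
    # hypothetical pre-norm stack: norm -> sublayer -> residual add, then one final norm
    def __init__(self, dim: int, depth: int = 2, heads: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                RMSNorm(dim),
                nn.MultiheadAttention(dim, heads, batch_first=True),
                RMSNorm(dim),
                nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)),
            ]))
        self.final_norm = RMSNorm(dim)

    def forward(self, x):
        for attn_norm, attn, ff_norm, ff in self.layers:
            h = attn_norm(x)
            x = x + attn(h, h, h, need_weights=False)[0]
            x = x + ff(ff_norm(x))
        # the residual stream is never normalized inside the blocks, so normalize it once here
        return self.final_norm(x)

model = PreNormTransformer(dim=32)
tokens = torch.randn(2, 10, 32)
out = model(tokens)
rms = out.pow(2).mean(dim=-1).sqrt()  # per-token RMS of the features handed downstream
print(out.shape, rms.min().item(), rms.max().item())
```

With the gain at its initial value of 1, every output token has RMS 1, i.e. the features reaching the mask estimator are already RMS-normalized.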

Psarpei commented 10 months ago

Aaaaah I see thanks :)