crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

Consider Making Normalization Optional in MappingFeedForwardBlock #94

Open mnslarcher opened 6 months ago

mnslarcher commented 6 months ago

Hi there!

I've noticed that in the forward method of MappingNetwork, you apply RMSNorm to the input:

```python
from torch import nn

# RMSNorm and MappingFeedForwardBlock are defined elsewhere in k-diffusion.
class MappingNetwork(nn.Module):
    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)])
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        x = self.in_norm(x)  # normalize the mapping input once up front
        for block in self.blocks:
            x = block(x)  # each block starts with its own RMSNorm
        x = self.out_norm(x)
        return x
```

However, MappingFeedForwardBlock also performs normalization, which means the first block normalizes input that has already been normalized. Here's the current implementation for reference:

```python
# apply_wd, zero_init, Linear, LinearGEGLU and RMSNorm are k-diffusion helpers.
class MappingFeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
        self.dropout = nn.Dropout(dropout)
        self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))

    def forward(self, x):
        skip = x
        x = self.norm(x)  # pre-norm; in the first block, x is already the output of in_norm
        x = self.up_proj(x)
        x = self.dropout(x)
        x = self.down_proj(x)
        return x + skip  # residual connection
```
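A quick way to see why the double norm is nearly redundant, at least at initialization, is to compare one and two applications of RMSNorm. While its scale is still all ones, RMSNorm output already has unit RMS, so normalizing it again changes almost nothing. The RMSNorm below is a minimal stand-in written for this sketch (an assumption, not the repository's implementation).

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    # Minimal stand-in (not k-diffusion's code): divide by the root-mean-square
    # over the last dim, then apply a learnable per-channel scale.
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))  # initialized to ones

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.scale


torch.manual_seed(0)
d_model = 8
in_norm = RMSNorm(d_model)     # plays the role of MappingNetwork.in_norm
block_norm = RMSNorm(d_model)  # plays the role of the first block's self.norm

x = 3.0 * torch.randn(4, d_model)
once = in_norm(x)
twice = block_norm(once)

print(once.pow(2).mean(dim=-1).sqrt())         # every entry is ~1.0
print(torch.allclose(once, twice, atol=1e-5))  # True: the second norm is ~a no-op
```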

Wouldn't it make sense to introduce an option to toggle normalization in MappingFeedForwardBlock and turn it off for the first block?
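Here is a minimal sketch of that option. It assumes a hypothetical `use_norm` flag that swaps the block's RMSNorm for `nn.Identity`. `RMSNorm` and `LinearGEGLU` below are simplified stand-ins written for this example, `apply_wd` is omitted, and `zero_init` is replaced by `nn.init.zeros_` on the down-projection. None of this is the repository's actual code.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
    # Minimal stand-in with a learnable per-channel scale (assumption).
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.scale


class LinearGEGLU(nn.Linear):
    # Simplified GEGLU stand-in: project to 2 * d_ff, gate one half with GELU of the other.
    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features * 2, bias=bias)

    def forward(self, x):
        x, gate = super().forward(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class MappingFeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0, use_norm=True):  # use_norm is hypothetical
        super().__init__()
        # use_norm=False replaces the pre-norm with a no-op, so the block can sit
        # directly behind an existing normalization layer.
        self.norm = RMSNorm(d_model) if use_norm else nn.Identity()
        self.up_proj = LinearGEGLU(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        nn.init.zeros_(self.down_proj.weight)  # zero-init: each block starts as the identity

    def forward(self, x):
        skip = x
        x = self.norm(x)
        x = self.up_proj(x)
        x = self.dropout(x)
        x = self.down_proj(x)
        return x + skip


class MappingNetwork(nn.Module):
    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        # The first block already receives in_norm's output, so it skips its own norm.
        self.blocks = nn.ModuleList([
            MappingFeedForwardBlock(d_model, d_ff, dropout=dropout, use_norm=(i > 0))
            for i in range(n_layers)
        ])
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        x = self.in_norm(x)
        for block in self.blocks:
            x = block(x)
        return self.out_norm(x)


net = MappingNetwork(n_layers=2, d_model=16, d_ff=32)
out = net(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 16])
print(type(net.blocks[0].norm).__name__, type(net.blocks[1].norm).__name__)  # Identity RMSNorm
```

With `use_norm=(i > 0)`, only the first block drops its norm. Later blocks still normalize their input, which is needed because the residual stream stops being unit-RMS once earlier blocks add their outputs to it.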

To be honest, even with this setup, the RMSNorm in the first block still plays a role. Each RMSNorm has its own learnable scale, so the skip connection and the normalized input to the up-projection can end up at different scales.
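The caveat can be illustrated with the same kind of stand-in RMSNorm (again an assumption, not the repository's code). Suppose training moves the input norm's learnable scale away from one. The residual path then carries the scaled tensor, while the first block's own norm maps its branch input back to unit RMS, so the two no longer match. The `fill_(2.0)` below simply simulates such a learned scale.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    # Minimal stand-in with a learnable per-channel scale (assumption).
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.scale


torch.manual_seed(0)
d_model = 8
in_norm = RMSNorm(d_model)
block_norm = RMSNorm(d_model)

with torch.no_grad():
    in_norm.scale.fill_(2.0)  # simulate a scale that training pushed to 2

x = torch.randn(4, d_model)
skip = in_norm(x)             # what the residual connection carries
branch_in = block_norm(skip)  # what the up-projection sees

print(skip.pow(2).mean(dim=-1).sqrt())       # ~2.0 per row
print(branch_in.pow(2).mean(dim=-1).sqrt())  # ~1.0 per row
```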

Just a thought while reviewing the code; feel free to ignore it if it's not relevant!

stefan-baumann commented 5 months ago

Good catch! Yeah, we should consider changing this. That said, since it also seems to work well with the double norm, we have to weigh whether it's worth introducing potentially breaking changes to the public implementation.

mnslarcher commented 5 months ago

Sure, it's not a big deal. Maybe I'd consider adding something like use_norm=True or norm=True to MappingFeedForwardBlock. That keeps things as they are for now, but later on, if you want to turn off normalization when reusing the block, you'll have the option. Anyway, it's pretty much a ~0-impact thing, probably not worth making the code more complex.
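Here is one concrete way the change could break things. This assumes "breaking" refers to loading existing checkpoints, which the thread does not say. Dropping a block's norm removes its scale parameter from the state dict. The sketch below uses a hypothetical, heavily simplified `Block` and a stand-in `RMSNorm`. It shows that a `use_norm=True` default keeps the parameter names unchanged, while `use_norm=False` makes strict loading of an old checkpoint fail.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    # Minimal stand-in: what matters here is that it owns a `scale` parameter.
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.scale


class Block(nn.Module):
    # Hypothetical, heavily simplified block: optional pre-norm plus one projection.
    def __init__(self, d_model, use_norm=True):
        super().__init__()
        self.norm = RMSNorm(d_model) if use_norm else nn.Identity()
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        return x + self.proj(self.norm(x))


old_checkpoint = Block(8).state_dict()  # stands in for weights saved before the change

# Default use_norm=True: same parameter names, so old checkpoints load strictly.
Block(8).load_state_dict(old_checkpoint)

# use_norm=False: "norm.scale" no longer exists, so strict loading raises.
try:
    Block(8, use_norm=False).load_state_dict(old_checkpoint)
except RuntimeError as err:
    print("strict load failed:", "Unexpected key(s)" in str(err))  # True

# strict=False loads the remaining weights and reports the leftover key instead.
result = Block(8, use_norm=False).load_state_dict(old_checkpoint, strict=False)
print(result.unexpected_keys)  # ['norm.scale']
```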