explainingai-code / DiT-PyTorch

This repo implements Diffusion Transformers (DiT) in PyTorch and provides training and inference code for the CelebHQ dataset.

Maybe something wrong in transformer_layer.py #1

Open ahrismile opened 2 weeks ago

ahrismile commented 2 weeks ago
def forward(self, x, condition):
    # Project the conditioning vector into six modulation parameters (adaLN).
    scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1)
    (pre_attn_shift, pre_attn_scale, post_attn_scale,
     pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = scale_shift_params
    out = x
    # Attention branch: scale/shift the normed input, then gate the residual.
    attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1))
                        + pre_attn_shift.unsqueeze(1))
    out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output)
    # MLP branch: same pattern.
    mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1))
                       + pre_mlp_shift.unsqueeze(1))
    # wrong: this calls self.attn_block instead of self.mlp_block
    out = out + post_mlp_scale.unsqueeze(1) * self.attn_block(mlp_norm_output)
    return out

According to the paper and the video you made, the line marked # wrong should be out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output).

Finally, I really appreciate your video and your work!

explainingai-code commented 2 weeks ago

Thank you @ahrismile. Yes, it should obviously be self.mlp_block; I don't know how I missed this. Fixed it now - https://github.com/explainingai-code/DiT-PyTorch/commit/52d0a7551875d60b63b002a015b68bd60ef9435f
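
For reference, here is a minimal, self-contained sketch of the corrected block. The forward pass follows the snippet above with the one-line fix applied; the surrounding module definitions (the SelfAttention wrapper, the MLP widths, the hidden_size/num_heads arguments) are illustrative assumptions, not the repo's exact implementation.

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    # Thin wrapper so the block can call self.attn_block(x) with a single
    # argument, as in the snippet above (hypothetical stand-in for the
    # repo's attention module).
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.mha(x, x, x)
        return out


class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4):
        super().__init__()
        # Norms carry no learnable affine params; scale/shift come from
        # the conditioning signal instead.
        self.att_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.ff_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn_block = SelfAttention(hidden_size, num_heads)
        self.mlp_block = nn.Sequential(
            nn.Linear(hidden_size, mlp_ratio * hidden_size),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_size, hidden_size),
        )
        # One projection of the condition into all six modulation params.
        self.adaptive_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size),
        )
        # adaLN-Zero (per the DiT paper): zero-init the projection so all
        # gates start at zero and each block begins as the identity.
        nn.init.zeros_(self.adaptive_norm_layer[1].weight)
        nn.init.zeros_(self.adaptive_norm_layer[1].bias)

    def forward(self, x, condition):
        # x: (B, num_tokens, hidden_size); condition: (B, hidden_size)
        scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1)
        (pre_attn_shift, pre_attn_scale, post_attn_scale,
         pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = scale_shift_params
        out = x
        attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1))
                            + pre_attn_shift.unsqueeze(1))
        out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output)
        mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1))
                           + pre_mlp_shift.unsqueeze(1))
        # The fixed line: mlp_block, not attn_block.
        out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output)
        return out


block = DiTBlock(hidden_size=768, num_heads=12)
x = torch.randn(2, 256, 768)   # (batch, patch tokens, hidden)
cond = torch.randn(2, 768)     # e.g. timestep embedding
print(block(x, cond).shape)    # torch.Size([2, 256, 768])

Because both residual branches share the identical modulate-then-gate pattern, the attention and MLP lines differ by a single identifier, which is exactly the kind of copy-paste slip reported in this issue.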