zhjohnchan / M3AE

[MICCAI-2022] This is the official implementation of Multi-Modal Masked Autoencoders for Medical Vision-and-Language Pre-Training.
111 stars, 10 forks

The "init_weights" function for model initialization. #21

Status: Open. Pyy-hah opened this issue 6 months ago.

Pyy-hah commented 6 months ago

Hello,

I've been exploring your project and am particularly interested in how the model is initialized. I noticed that the model's parameters are initialized with the init_weights function from m3ae_utils.py, and I would like to understand this choice better:

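```python
import torch.nn as nn

def init_weights(module):
    # Linear and Embedding weights: normal distribution, mean 0, std 0.02
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    # LayerNorm: identity affine transform (weight 1, bias 0)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    # Linear biases start at zero
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
```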
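
For context, in PyTorch a function like this is typically applied to every submodule through nn.Module.apply, which calls it on each child recursively and then on the parent. The sketch below is an illustration, not the repository's actual call site: it uses a hypothetical toy model (an Embedding, a LayerNorm and a Linear layer) standing in for the real M3AE modules, and checks that each rule took effect.

```python
import torch
import torch.nn as nn


def init_weights(module):
    # Same rules as the snippet quoted above from m3ae_utils.py.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()


# Hypothetical toy model; sizes are arbitrary, not taken from M3AE.
model = nn.Sequential(
    nn.Embedding(1000, 256),
    nn.LayerNorm(256),
    nn.Linear(256, 256),
)

torch.manual_seed(0)
# apply() visits every submodule (children first, then the Sequential itself,
# which matches none of the branches and is left untouched).
model.apply(init_weights)

emb, ln, fc = model[0], model[1], model[2]
print(f"embedding weight std: {emb.weight.std().item():.4f}")
print(f"linear weight std:    {fc.weight.std().item():.4f}")
print(f"linear bias all zero: {bool((fc.bias == 0).all())}")
print(f"layernorm weight all one: {bool((ln.weight == 1).all())}")
```

Because apply recurses over the whole module tree, a single call covers every layer, including ones nested deep inside encoder blocks.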

Could you explain the benefits of initializing the parameters with init_weights rather than relying on PyTorch's default initialization? More generally, how should one choose an appropriate initialization scheme for a given model?
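
To put the std=0.02 choice in perspective, here is a minimal sketch that measures the standard deviation of PyTorch's default initialization for freshly constructed nn.Linear and nn.Embedding layers and sets it beside the fixed 0.02 used by init_weights. The widths are illustrative assumptions, not values taken from M3AE.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Widths are illustrative assumptions, not M3AE's actual hidden sizes.
for width in (256, 768, 1024):
    # PyTorch default for Linear: kaiming_uniform_(a=sqrt(5)),
    # i.e. U(-b, b) with b = 1/sqrt(fan_in).
    linear = nn.Linear(width, width)
    # PyTorch default for Embedding: N(0, 1).
    embedding = nn.Embedding(1000, width)
    # The std of U(-b, b) is b / sqrt(3) = 1 / sqrt(3 * fan_in).
    expected_linear_std = 1.0 / math.sqrt(3 * width)
    print(
        f"width={width:5d}  "
        f"default Linear std={linear.weight.std().item():.4f} "
        f"(analytic {expected_linear_std:.4f})  "
        f"default Embedding std={embedding.weight.std().item():.3f}  "
        f"init_weights std=0.0200"
    )
```

Under these defaults, the Linear weight std shrinks with width and is already close to 0.02 at width 768, whereas embeddings start at std 1. So, under these assumptions, the largest changes init_weights makes are shrinking the embeddings and zeroing the Linear biases; a freshly constructed nn.LayerNorm already has weight 1 and bias 0.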

Understanding the rationale behind your choice of initialization, and the factors to weigh when selecting an initialization scheme, would deepen my understanding and could also help other users, since initialization is an important part of model development.

Thank you for your time and contribution to the field.

Best regards, Zilin Lu