konumaru / commonLit_readability_prize

https://www.kaggle.com/c/commonlitreadabilityprize

`init_weights` configuration #11

Closed konumaru closed 3 years ago

konumaru commented 3 years ago

Reference notebook: https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-infer

import torch.nn as nn
from transformers import RobertaModel


class CommonLitModel(nn.Module):
    def __init__(
        self,
        model_name,
        config,
        multisample_dropout=False,
        output_hidden_states=False,
    ):
        super().__init__()
        self.config = config
        self.roberta = RobertaModel.from_pretrained(
            model_name,
            output_hidden_states=output_hidden_states,
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        # Multi-sample dropout: several parallel dropout masks whose
        # regressor outputs are averaged; otherwise a single dropout.
        if multisample_dropout:
            self.dropouts = nn.ModuleList([
                nn.Dropout(0.5) for _ in range(5)
            ])
        else:
            self.dropouts = nn.ModuleList([nn.Dropout(0.3)])
        self.regressor = nn.Linear(config.hidden_size, 1)
        # Re-initialize only the newly added head layers; the pretrained
        # RoBERTa weights stay untouched.
        self._init_weights(self.layer_norm)
        self._init_weights(self.regressor)

    def _init_weights(self, module):
        # Same scheme Hugging Face uses internally: normal init with
        # config.initializer_range for Linear/Embedding, identity for LayerNorm.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)