Closed konumaru closed 3 years ago
参考notebook: https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-infer
class CommonLitModel(nn.Module):
    """RoBERTa encoder plus the layers of a single-output regression head.

    The constructor creates a pretrained ``RobertaModel``, a ``LayerNorm``,
    one or more ``Dropout`` modules and a ``Linear(hidden_size, 1)``
    regressor. No ``forward`` is defined in this block, so how these
    sub-modules are combined is not specified here.
    """

    def __init__(
        self,
        model_name,
        config,
        multisample_dropout=False,
        output_hidden_states=False,
    ):
        """Create the encoder and head sub-modules.

        Args:
            model_name: Identifier forwarded to ``RobertaModel.from_pretrained``.
            config: Object exposing ``hidden_size`` (width of the LayerNorm
                and of the regressor input) and ``initializer_range``
                (std used by ``_init_weights``).
            multisample_dropout: When true, register five ``Dropout(0.5)``
                modules; otherwise a single ``Dropout(0.3)``.
            output_hidden_states: Forwarded unchanged to
                ``RobertaModel.from_pretrained``.
        """
        super().__init__()
        self.config = config
        self.roberta = RobertaModel.from_pretrained(
            model_name, output_hidden_states=output_hidden_states
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        if multisample_dropout:
            self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])
        else:
            self.dropouts = nn.ModuleList([nn.Dropout(0.3)])
        self.regressor = nn.Linear(config.hidden_size, 1)

        # Only the freshly created head layers are re-initialized; the
        # pretrained encoder weights are left as loaded.
        self._init_weights(self.layer_norm)
        self._init_weights(self.regressor)

    def _init_weights(self, module):
        """Initialize ``module`` in place according to its type.

        * ``nn.Linear``: weight ~ N(0, ``config.initializer_range``), bias 0.
        * ``nn.Embedding``: weight ~ N(0, ``config.initializer_range``), and
          the ``padding_idx`` row (if any) zeroed.
        * ``nn.LayerNorm``: weight 1, bias 0 (skipped for parameters that
          are ``None``, e.g. ``elementwise_affine=False``).

        Any other module type is left untouched.
        """
        # The nn.init helpers run under no_grad, so the parameters are
        # updated without going through ``.data``.
        if isinstance(module, nn.Linear):
            nn.init.normal_(
                module.weight, mean=0.0, std=self.config.initializer_range
            )
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(
                module.weight, mean=0.0, std=self.config.initializer_range
            )
            if module.padding_idx is not None:
                # Indexing yields a view, so zeroing it clears that row of
                # the embedding matrix itself.
                nn.init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            if module.weight is not None:
                nn.init.ones_(module.weight)
参考notebook: https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-infer