lonePatient / BERT-chinese-text-classification-pytorch

This repo contains a PyTorch implementation of a pretrained BERT model for text classification.

When running the model with the CNN head, training stops at loss.backward() with no error message #5

Closed · FOXaaFOX closed this issue 4 years ago

FOXaaFOX commented 4 years ago

Apart from the class name referenced in the main file, nothing else was modified, and the original model runs fine. Could you please advise why training stops as soon as the loss is backpropagated (loss.backward())?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# BertPreTrainedModel, BertModel, and the custom Conv1d wrapper come from the repo's model code;
# n_filters and filter_sizes are defined elsewhere in the project.


class BertCNN(BertPreTrainedModel):
    def __init__(self, config):
        super(BertCNN, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.convs = Conv1d(config.hidden_size, n_filters, filter_sizes)
        self.classifier = nn.Linear(len(filter_sizes) * n_filters, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head_mask=None, labels=None):
        # Sequence output from BERT: [batch_size, seq_len, hidden_size]
        encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
        encoded_layers = self.dropout(encoded_layers)
        # Rearrange to [batch_size, hidden_size, seq_len] for the convolutions
        encoded_layers = encoded_layers.permute(0, 2, 1)
        conved = self.convs(encoded_layers)
        # Max-pool each feature map over the sequence dimension, then concatenate
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        cat = self.dropout(torch.cat(pooled, dim=1))
        logits = self.classifier(cat)
        return logits
```
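For context, the snippet iterates over `conved`, so the `Conv1d` it uses must be a multi-kernel wrapper that applies one convolution per filter size and returns a list of feature maps. A minimal sketch of such a wrapper, assuming it mirrors the repo's version (the class and argument names here are illustrative, not the repo's exact code):

```python
import torch
import torch.nn as nn


class Conv1d(nn.Module):
    """Hypothetical multi-kernel wrapper: one nn.Conv1d per kernel size,
    returning a list of activated feature maps."""

    def __init__(self, in_channels, out_channels, kernel_sizes):
        super(Conv1d, self).__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels, out_channels, kernel_size=k)
            for k in kernel_sizes
        ])

    def forward(self, x):
        # x: [batch_size, hidden_size, seq_len], as produced by the permute in BertCNN.forward
        return [torch.relu(conv(x)) for conv in self.convs]
```

With this shape convention, each feature map is max-pooled over the sequence length and the pooled vectors are concatenated, which is why the classifier expects `len(filter_sizes) * n_filters` input features.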

FOXaaFOX commented 4 years ago

The problem has been solved. It appears to be a PyTorch issue across operating systems: after switching from Windows 10 to Linux, it runs normally.