Open WeiYangBin opened 4 years ago
    def train_step(self, batch_sents, batch_tags, word2id, tag2id):
        self.model.train()
        self.step += 1

        # Prepare the data
        tensorized_sents, lengths = tensorized(batch_sents, word2id)
        tensorized_sents = tensorized_sents.to(self.device)
        targets, lengths = tensorized(batch_tags, tag2id)
        targets = targets.to(self.device)

        # Forward pass
        scores = self.model(tensorized_sents, lengths)

        # Compute the loss and update the parameters
        self.optimizer.zero_grad()
        loss = self.cal_loss_func(scores, targets, tag2id).to(self.device)
        loss.backward()
        self.optimizer.step()

        return loss.item()
The code above is from model/bilstm_crf.py. I'd like to ask: what does self.model.train() do inside the train_step function?
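self.model.train() puts the model into training mode. Some layers behave differently depending on the mode: dropout only drops activations in training mode, and batch normalization only uses per-batch statistics (and updates its running statistics) in training mode.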
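To make the difference concrete, here is a minimal sketch using a standalone `nn.Dropout` layer. The layer, the input tensor, and the variable names are illustrative assumptions and are not taken from model/bilstm_crf.py.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative layer and input, not part of the repository.
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()           # training mode: elements are zeroed at random
train_out = dropout(x)

dropout.eval()            # evaluation mode: dropout passes the input through
eval_out = dropout(x)

num_zeroed = int((train_out == 0).sum())
eval_unchanged = torch.equal(eval_out, x)
print(num_zeroed, eval_unchanged)
```

In training mode roughly half of the elements are zeroed and the survivors are scaled by 1 / (1 - p); in evaluation mode the output equals the input.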
I just tried deleting this line, and the printed results are almost the same.
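One possible explanation, assuming nothing else has switched the model to evaluation mode beforehand: a freshly constructed `nn.Module` is already in training mode, so calling train() again changes nothing. The call matters once eval() has been used (for example during validation), and only if the model contains mode-dependent layers such as dropout. The small model below is a hypothetical stand-in, not the BiLSTM-CRF from the repository.

```python
import torch.nn as nn

# Hypothetical stand-in model, used only to inspect the mode flag.
model = nn.Sequential(nn.Linear(8, 16), nn.Dropout(p=0.5), nn.Linear(16, 4))

fresh_mode = model.training      # modules start out in training mode

model.eval()                     # e.g. switched for a validation pass
after_eval = model.training

model.train()                    # switch back before the next training step
after_train = model.training

# train() and eval() propagate to every submodule.
all_children_training = all(m.training for m in model.modules())
print(fresh_mode, after_eval, after_train, all_children_training)
```

Keeping self.model.train() at the top of train_step is therefore a cheap safeguard: it guarantees the right mode even if a validation step ran in between.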