Open re4388 opened 5 years ago
Change `.data[0]` to `.item()` (indexing a 0-dim tensor with `.data[0]` no longer works in PyTorch ≥ 0.4).
Add `model.train()` at the beginning of each epoch, because `model.eval()` is called at the end of every epoch.
Only the training-loop code needs to be modified; below is the fixed code that worked for me :)
# Train for `num_epoches` epochs, evaluating on the test set after each epoch.
# This script relies on names defined earlier in the file: model, criterion,
# optimizer, train_loader, test_loader, train_dataset, test_dataset,
# batch_size, num_epoches, use_gpu (and `torch` itself).
for epoch in range(num_epoches):
    # model.eval() is called at the end of every epoch, so training mode must
    # be restored at the start of each one.
    model.train()
    print('epoch {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0.0
    running_acc = 0.0
    for i, data in enumerate(train_loader, 1):
        img, label = data
        b, c, h, w = img.size()
        assert c == 1, 'channel must be 1'
        # Drop the single channel dimension: (b, 1, h, w) -> (b, h, w).
        img = img.squeeze(1)
        # Variable wrappers are unnecessary since PyTorch 0.4; tensors track
        # gradients directly.
        if use_gpu:
            img = img.cuda()
            label = label.cuda()
        # Forward pass
        out = model(img)
        loss = criterion(out, label)
        # Scale by batch size so the running total is a per-sample sum
        # (assumes criterion averages over the batch — confirm its reduction).
        running_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        num_correct = (pred == label).sum()
        running_acc += num_correct.item()
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 300 == 0:
            print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epoches, running_loss / (batch_size * i),
                running_acc / (batch_size * i)))
    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(
            train_dataset))))

    model.eval()
    eval_loss = 0.
    eval_acc = 0.
    # `Variable(..., volatile=True)` has had no effect since PyTorch 0.4 (it
    # only warns), so the original evaluation still built autograd graphs.
    # torch.no_grad() disables gradient tracking for the whole test pass.
    with torch.no_grad():
        for data in test_loader:
            img, label = data
            b, c, h, w = img.size()
            assert c == 1, 'channel must be 1'
            img = img.squeeze(1)
            if use_gpu:
                img = img.cuda()
                label = label.cuda()
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item() * label.size(0)
            _, pred = torch.max(out, 1)
            num_correct = (pred == label).sum()
            eval_acc += num_correct.item()
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
        test_dataset)), eval_acc / (len(test_dataset))))
    print()
Change `.data[0]` to `.item()` (indexing a 0-dim tensor with `.data[0]` no longer works in PyTorch ≥ 0.4).
Add `model.train()` at the beginning of each epoch, because `model.eval()` is called at the end of every epoch.
Only the training-loop code needs to be modified; below is the fixed code that worked for me :)