ShusenTang / Dive-into-DL-PyTorch

This project reimplements the MXNet code from the book Dive into Deep Learning (《动手学深度学习》) in PyTorch.
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0

Section 5.5: question about where the batch_count variable is defined in the train_ch5 function #149

Open DesmondLei opened 4 years ago

DesmondLei commented 4 years ago

Bug description: Why is batch_count defined outside the "for epoch in range" loop? Since it is never reset, the denominator in train_l_sum / batch_count keeps growing across epochs, so even if every epoch accumulated the same train_l_sum, later epochs would report a smaller loss. For example, with 100 batches per epoch, epoch 1 divides by 100 but epoch 2 divides by 200, halving the printed loss. This seems wrong.

def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    batch_count = 0  # defined outside the epoch loop; this is the line the issue questions
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
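
For reference, here is a minimal sketch of the suggested fix: reset batch_count at the start of each epoch along with the other accumulators, so the printed loss is an average over that epoch's batches only. This is not the book's official version; it assumes the same imports (torch, time) and the evaluate_accuracy helper from the book's utility module, and the name train_ch5_fixed is just for illustration.

import time
import torch

def train_ch5_fixed(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # reset the per-epoch accumulators, including the batch counter,
        # so train_l_sum / batch_count is a true per-epoch average
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)  # helper from the book's utilities, assumed available
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

Note that since loss is the batch-mean cross entropy, dividing by the per-epoch batch_count gives the mean batch loss; if the last batch is smaller, a per-sample average would instead require accumulating l.cpu().item() * y.shape[0] and dividing by n.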

Version info: pytorch: torchvision: torchtext: ...