deep-learning-with-pytorch / dlwpt-code

Code for the book Deep Learning with PyTorch by Eli Stevens, Luca Antiga, and Thomas Viehmann.

p1ch8/1_convolution.ipynb L2 regularization problem #91

Open JuanCab opened 2 years ago

JuanCab commented 2 years ago

I've been working my way through the Jupyter Notebook for Chapter 8.

When I run the cell that trains using L2 regularization:

```python
model = Net().to(device=device)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop_l2reg(
    n_epochs = 100,
    optimizer = optimizer,
    model = model,
    loss_fn = loss_fn,
    train_loader = train_loader,
)
all_acc_dict["l2 reg"] = validate(model, train_loader, val_loader)
```

The network does not train because the loss is `nan`. I wonder whether there is an error in the definition of `training_loop_l2reg` in the previous cell:

```python
import datetime  # used for the timestamp in the progress message

def training_loop_l2reg(n_epochs, optimizer, model, loss_fn,
                        train_loader):
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)  # device is defined in an earlier cell
            labels = labels.to(device=device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            l2_lambda = 0.001
            # Replace pow(2.0) with abs() for L1 regularization
            l2_norm = sum(p.pow(2.0).sum()
                          for p in model.parameters())
            loss = loss + l2_lambda * l2_norm

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_train += loss.item()

        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch {}, Training loss {}'.format(
                datetime.datetime.now(), epoch,
                loss_train / len(train_loader)))
```
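To check the penalty term on its own, here is a minimal sketch that computes the same sum as `l2_norm` above (and the L1 variant suggested by the comment) on a tiny `nn.Linear` whose parameters are set to known values, so the result can be verified by hand. The helper names `l2_penalty` and `l1_penalty` and the toy model are hypothetical, introduced only for illustration. Printing such a value every few batches may be a cheap way to see whether the penalty or the data loss is the term that becomes non-finite first.

```python
import torch
import torch.nn as nn

def l2_penalty(model: nn.Module) -> torch.Tensor:
    # Hypothetical helper: sum of squared entries of every parameter (weights and biases),
    # the same expression as l2_norm in training_loop_l2reg.
    return sum(p.pow(2.0).sum() for p in model.parameters())

def l1_penalty(model: nn.Module) -> torch.Tensor:
    # Hypothetical helper: same sum with abs() in place of pow(2.0).
    return sum(p.abs().sum() for p in model.parameters())

# Toy model with known parameters: 2x3 weight filled with 2.0, bias of size 2 filled with -1.0.
toy = nn.Linear(3, 2)
with torch.no_grad():
    toy.weight.fill_(2.0)
    toy.bias.fill_(-1.0)

print(l2_penalty(toy).item())  # 6 * 4 + 2 * 1 = 26.0
print(l1_penalty(toy).item())  # 6 * 2 + 2 * 1 = 14.0
```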

I ask because if I train using the `weight_decay` parameter of SGD instead:

```python
model = NetWidth(n_chans1=32).to(device=device)
optimizer = optim.SGD(model.parameters(), weight_decay=0.001, lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs = 100,
    optimizer = optimizer,
    model = model,
    loss_fn = loss_fn,
    train_loader = train_loader,
)

all_acc_dict["width"] = validate(model, train_loader, val_loader)
```

I have no problem with the loss converging.
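For comparing the two runs, note that SGD's `weight_decay=wd` adds `wd * p` to each gradient, which is the gradient of the penalty `(wd / 2) * p**2`. An explicit term `l2_lambda * sum(p**2)` therefore corresponds to `weight_decay=2 * l2_lambda`, so `weight_decay=0.001` matches `l2_lambda=0.0005`, not `0.001`. The sketch below checks this for a single SGD step on a small toy linear model with random data; the model, data, and seed are assumptions chosen only for illustration, not taken from the notebook.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
base = nn.Linear(4, 3)           # toy model (assumption, not from the notebook)
x = torch.randn(8, 4)            # toy inputs
y = torch.randint(0, 3, (8,))    # toy class labels
loss_fn = nn.CrossEntropyLoss()
lr, l2_lambda = 1e-2, 0.001

# (a) Explicit penalty added to the loss, as in training_loop_l2reg.
model_a = copy.deepcopy(base)
opt_a = optim.SGD(model_a.parameters(), lr=lr)
penalty = sum(p.pow(2.0).sum() for p in model_a.parameters())
loss_a = loss_fn(model_a(x), y) + l2_lambda * penalty
opt_a.zero_grad()
loss_a.backward()
opt_a.step()

# (b) Plain loss with SGD weight decay; d/dp of l2_lambda * p**2 is 2 * l2_lambda * p.
model_b = copy.deepcopy(base)
opt_b = optim.SGD(model_b.parameters(), lr=lr, weight_decay=2 * l2_lambda)
loss_b = loss_fn(model_b(x), y)
opt_b.zero_grad()
loss_b.backward()
opt_b.step()

# Both updates should leave the parameters (numerically) identical.
same = all(torch.allclose(pa, pb)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```

The equivalence holds for plain SGD as used in this thread; with momentum or adaptive optimizers the two forms are not guaranteed to produce the same updates.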