deep-learning-with-pytorch / dlwpt-code

Code for the book Deep Learning with PyTorch by Eli Stevens, Luca Antiga, and Thomas Viehmann.
https://www.manning.com/books/deep-learning-with-pytorch

p1ch8: "ResBlock" objects are most likely identical #89

Open aallahyar opened 2 years ago

aallahyar commented 2 years ago

In the scripts of p1ch8 (section 8.5.3, page 227 to be specific), we are making sub-blocks of convolution using the following code:

self.resblocks = nn.Sequential(
            *(n_blocks * [ResBlock(n_chans=n_chans1)]))

However, since `n_blocks * [ResBlock(n_chans=n_chans1)]` repeats a reference to the *same* object rather than creating copies, the weight matrices across the created ResBlocks are identical: the "ten" blocks are one block, registered ten times.
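This is plain Python list semantics and needs no PyTorch to demonstrate; a minimal sketch (the `Block` class here is just a hypothetical stand-in):

```python
class Block:
    """Hypothetical stand-in for ResBlock; any mutable object behaves the same."""
    pass

# List repetition duplicates the reference, not the object.
blocks = 3 * [Block()]
print(blocks[0] is blocks[1] is blocks[2])  # True: one instance, three references

# A comprehension calls the constructor once per element.
fresh = [Block() for _ in range(3)]
print(fresh[0] is fresh[1])  # False: distinct instances
```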

I think the code should be changed to:

self.resblocks = nn.Sequential(
            *[ResBlock(n_chans=n_chans1) for _ in range(n_blocks)])
t-vi commented 2 years ago

Absolutely, thank you for spotting this and reporting.

ftianRF commented 2 years ago

@aallahyar Yes.

Here is a fuller example for other readers:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):  # as in section 8.5.3 (simplified; the book also adds custom weight init)
    def __init__(self, n_chans):
        super().__init__()
        self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(num_features=n_chans)

    def forward(self, x):
        out = self.conv(x)
        out = self.batch_norm(out)
        out = torch.relu(out)
        return out + x

class NetResDeep(nn.Module):
    def __init__(self, n_chans1=32, n_blocks=10):
        super().__init__()
        self.n_chans1 = n_chans1
        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
        self.resblocks = nn.Sequential(*([ResBlock(n_chans=n_chans1)] * n_blocks))    # shown in the book
        #self.resblocks = nn.Sequential(*[ResBlock(n_chans=n_chans1) for _ in range(n_blocks)])    # the right version
        self.fc1 = nn.Linear(n_chans1 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
        out = F.max_pool2d(self.resblocks(out), 2)
        out = out.view(-1, self.n_chans1 * 8 * 8)
        out = torch.relu(self.fc1(out))
        out = self.fc2(out)
        return out

netresdeep = NetResDeep()
id(netresdeep.resblocks[0].conv.weight) == id(netresdeep.resblocks[1].conv.weight)

Output:

True

This result shows that the two `conv.weight` tensors are the same object: all the ResBlocks in the book's version train a single shared set of parameters.
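The two constructions can also be compared directly. The sketch below uses a simplified stand-in for the book's `ResBlock` (no batch norm or custom init, but the same `.conv` attribute), and checks both weight identity and parameter counts; `nn.Module.parameters()` deduplicates shared parameters, so the repeated-reference version reports far fewer trainable weights:

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Simplified stand-in for the book's ResBlock (section 8.5.3)."""
    def __init__(self, n_chans):
        super().__init__()
        self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x)) + x

n_chans1, n_blocks = 32, 10

# Book's version: one instance repeated n_blocks times.
shared = nn.Sequential(*(n_blocks * [ResBlock(n_chans=n_chans1)]))
# Fixed version: a fresh instance per block.
distinct = nn.Sequential(*[ResBlock(n_chans=n_chans1) for _ in range(n_blocks)])

print(shared[0].conv.weight is shared[1].conv.weight)      # True
print(distinct[0].conv.weight is distinct[1].conv.weight)  # False

# parameters() yields each tensor once, so the shared net has 1/10 the parameters.
print(sum(p.numel() for p in shared.parameters()))
print(sum(p.numel() for p in distinct.parameters()))
```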