deep-learning-with-pytorch / dlwpt-code

Code for the book Deep Learning with PyTorch by Eli Stevens, Luca Antiga, and Thomas Viehmann.
https://www.manning.com/books/deep-learning-with-pytorch

Confusion regarding Normalized CIFAR10 dataset in Chapter 7 #70

Open tataganesh opened 3 years ago

tataganesh commented 3 years ago

In Chapter 7, the CIFAR10 dataset is initially loaded as follows:

cifar10 = datasets.CIFAR10(data_path, train=True, download=True)

Then, section 7.1.4 discusses the importance of normalizing the data. The transformed CIFAR10 dataset is loaded as follows:

transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))
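For reference, here is a minimal sketch of the arithmetic that per-channel normalization performs, written in plain Python so it runs without torchvision. The helper name `normalize_pixel` is hypothetical; the means and standard deviations are the ones passed to `Normalize` above, and the input is assumed to be an RGB pixel already scaled to [0, 1], as `ToTensor` produces.

```python
# Per-channel statistics copied from the Normalize call above.
MEAN = (0.4915, 0.4823, 0.4468)
STD = (0.2470, 0.2435, 0.2616)


def normalize_pixel(rgb):
    """Apply (value - mean) / std to each channel of one RGB pixel.

    Hypothetical helper for illustration; assumes values are in [0, 1].
    """
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


# A pixel equal to the channel means maps to zero in every channel.
centered = normalize_pixel(MEAN)

# A pixel one standard deviation above the mean maps to roughly 1.0 per channel.
one_std_above = normalize_pixel(tuple(m + s for m, s in zip(MEAN, STD)))
```

After this step each channel is roughly zero-centered with unit spread, which is the property section 7.1.4 is concerned with.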

However, in section 7.2.1, a dataset containing only the samples with labels 0 and 2 is created from the cifar10 variable, not from transformed_cifar10:

cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]

I am assuming that the cifar10 variable here refers to the normalized CIFAR10 dataset. If so, would it be clearer to replace cifar10 with transformed_cifar10?

cifar2 = [(img, label_map[label]) for img, label in transformed_cifar10 if label in [0, 2]]

This would make it clear to anyone implementing these steps that the normalized data is what is being used to train the network.
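To illustrate why the variable name matters, here is a small sketch that runs without downloading anything. `TinyDataset` is a hypothetical stand-in that mimics the torchvision pattern of applying `transform` when a sample is fetched; the scalar "images", the 0.5 / 0.25 statistics, and the `label_map` value `{0: 0, 2: 1}` are all assumptions made for the example, not values taken from the book's code.

```python
class TinyDataset:
    """Hypothetical stand-in for a dataset with an optional transform."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Raises IndexError past the end, which is what lets a plain
        # `for` loop iterate over the dataset.
        img, label = self.samples[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# Each "image" is a single float here, purely for illustration.
samples = [(0.25, 0), (0.5, 1), (0.75, 2)]

cifar10 = TinyDataset(samples)
transformed_cifar10 = TinyDataset(samples, transform=lambda x: (x - 0.5) / 0.25)

label_map = {0: 0, 2: 1}  # assumed mapping of the two kept classes to 0 and 1

# Same comprehension, different source variable.
cifar2_raw = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]
cifar2_norm = [(img, label_map[label])
               for img, label in transformed_cifar10 if label in [0, 2]]
```

Both lists contain the same two samples with the same remapped labels, but only `cifar2_norm` holds normalized values. Which data ends up in cifar2 depends entirely on what the name in the comprehension is bound to at that point.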