Xzzit / pytorch-tutorial

Introduction to PyTorch: A comprehensive Chinese course available at the provided link.
https://space.bilibili.com/12580263/channel/series
GNU General Public License v3.0

Code raises an error after switching the 04_CNN dataset from MNIST to ImageNet #3

Closed Xzzit closed 1 year ago

Xzzit commented 1 year ago

Question from Bilibili user 大力哥爱金坷垃.

The student's code is as follows:

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(24),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

train_data = ImageFolder(root='/home/user/mq/datasets/ImageNet100/train', transform=transform)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)

# define a CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Sequential(
            nn.Linear(9216, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.fc_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc_1(x)
        logits = self.fc_2(x)
        return logits

# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# train the model
num_epochs = 20

for epoch in range(num_epochs):
    for idx, (img, label) in enumerate(train_data_loader):
        img, label = img.to(device), label.to(device)

        # compute prediction error
        pred = cnn(img)
        loss = loss_fn(pred, label)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# save the model
torch.save(cnn.state_dict(), 'cnn.pth')

Xzzit commented 1 year ago

The error message is: RuntimeError: stack expects each tensor to be equal size, but got [1, 27, 24] at entry 0 and [1, 24, 32] at entry 1

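Cause: when given a single integer, transforms.Resize(24) only resizes the shorter edge of the image to 24 and scales the other edge proportionally, preserving the aspect ratio. Many ImageNet images are not square, so after resizing only the shorter edge is guaranteed to be 24 while the longer edge varies from image to image. The DataLoader then cannot stack tensors of different sizes into one batch. For details on this function, see the official torchvision documentation.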

Fix: add a new line, transforms.CenterCrop(24), directly below transforms.Resize(24).