PacktPublishing / Deep-Learning-with-PyTorch-1.x

Deep Learning with PyTorch 1.x, published by Packt
MIT License
41 stars 37 forks source link

Problem mit dem Code aus Chapter 6: Coded example – standard autoencoder #2

Open AlexanderHuels opened 1 year ago

AlexanderHuels commented 1 year ago

Code führt zu Fehler: Unter 1. fehlen wichtige Importe. Mein Code:

import torch
import os
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import optim
from torchvision.datasets import MNIST  
from torchvision.utils import save_image
import matplotlib.pyplot as plt

Unter 3. werden Farbbilder normalisiert. MNIST Datensatz sind aber Graustufenbilder: Daher muss die Compose-Funktion korrigiert werden und batch_size=batch_size geht auch nicht (ggf. muss bei dataset download=True bei 1. Ausführen gesetzt werden):

transform_image = transforms.Compose([transforms.ToTensor()
                                      ,transforms.Normalize((0.5, ), (0.5,))])
dataset = MNIST('./MNIST_data', transform=transform_image, download=False)
data_loader = DataLoader(dataset, batch_size=64,shuffle=True)

Plotfunktion auf Graustufenbilder anpassen:

def plot_img(image): 
    plt.imshow(image[0],cmap='gray')

Dann kann man es auch plotten:

sample_data = next(iter(data_loader)) 
plot_img(sample_data[0][2])

Bei 6. müsste man noch ein paar Anpassungen machen (bei dem print) und dem fehlenden Verzeichnis:

for epoch in range(number_epochs):
    for data in data_loader:
        image, i = data

        image = image.view(image.size(0), -1)
        image = Variable(image)
        # Forward pass
        output = model(image)
        loss = criterion(output, image)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss:{:.4f}'.format(epoch + 1,number_epochs, loss.item()))
    if epoch % 10 == 0:
        os.makedirs('./mlp_img/', exist_ok=True)
        pic = to_image(output.cpu().data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))
    torch.save(model.state_dict(), './sim_autoencoder.pth')

Dann funktioniert es schon mal. Können Sie schauen, ob es passt?