junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

How can we display/visualize the trained filters for the generators and discriminators? #112

Closed isalirezag closed 6 years ago

isalirezag commented 6 years ago

Does anyone know how to display/visualize the trained filters for the generators and discriminators after training is finished?

junyanz commented 6 years ago

DCGAN paper (Figure 5) used guided backpropagation to visualize the filters in discriminators. Not sure about visualization for generators.
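For reference, guided backpropagation modifies the backward pass at each ReLU so that only positive gradients flow through positions where the forward activation was positive. A minimal numpy sketch of the rule on a toy one-layer conv -> ReLU network (purely illustrative; not the DCGAN authors' implementation, and the function names are my own):

```python
import numpy as np

def conv2d(x, w):
    # valid cross-correlation of a single-channel image with one filter
    H, W = x.shape
    kH, kW = w.shape
    out = np.zeros((H - kH + 1, W - kW + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kH, j:j + kW] * w)
    return out

def guided_backprop(x, w):
    # forward: conv -> ReLU -> sum of activations
    z = conv2d(x, w)
    grad_a = np.ones_like(z)  # d(sum)/d(activation) = 1 everywhere
    # guided rule at the ReLU: pass a gradient only where BOTH the
    # forward pre-activation was positive AND the incoming gradient is positive
    grad_z = grad_a * (z > 0) * (grad_a > 0)
    # backprop through the conv to get a saliency map over the input
    grad_x = np.zeros_like(x)
    kH, kW = w.shape
    for i in range(grad_z.shape[0]):
        for j in range(grad_z.shape[1]):
            grad_x[i:i + kH, j:j + kW] += grad_z[i, j] * w
    return grad_x

x = np.random.rand(8, 8)
w = np.random.randn(4, 4)          # stands in for one 4x4 filter
saliency = guided_backprop(x, w)   # same shape as x
```

In a real network you would apply the same clipping rule at every ReLU (e.g. via backward hooks in PyTorch) and backpropagate from the activation of the single filter you want to visualize.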

canghel commented 6 years ago

Hello! I had the same question, and I am not sure if this is right... I was trying to visualize the weights of the first layer of the G network. For the facades dataset, the 4x4 weights do not seem to capture features, but for my dataset they seem to (e.g. they show a gradient across the square, or a smaller, more intense middle square). I was having trouble following the UNet architecture in the code, so I'm not sure I even have the correct layer. Does anybody know?

### PREAMBLE ##################################################################
import torch
import torch.nn as nn
import os

import numpy as np
import matplotlib.pyplot as plt
import cv2

# the path for the network to load
pathNets = '/home/Documents/placenta/pytorch-CycleGAN-and-pix2pix/checkpoints/facades_pix2pix'
fileToLoad = 'latest_net_G.pth'

### VISUALIZE FILTERS #########################################################
# load the network into the variable 'net'
net = torch.load(os.path.join(pathNets, fileToLoad), map_location='cpu')  # a state_dict (OrderedDict of name -> tensor)

## figuring out dimensions of filters, layers, etc.
for k, v in net.items():
    print(k)

# the filters are of size 3x4x4, and there are 64 of them
# I think this is for the first conv. layer?
print(net["model.model.0.weight"].size())

# plot just one channel of every filter
for jj in range(64):
    # get the jj-th filter, which is 3x4x4, as a numpy array
    temp = np.floor(net["model.model.0.weight"][jj].numpy() * 255)
    # keep the first (0th) input channel in the variable 'img'
    img = temp[0, :, :]
    # plot a grayscale image of that channel of the jj-th filter
    plt.figure()
    plt.imshow(img, vmin=-25, vmax=25, cmap='gray')
    fig = plt.gcf()
    fig.savefig("./filter" + str(jj) + ".png")
    plt.close(fig)  # avoid keeping 64 figures open

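
As a side note, rather than saving 64 separate figures, the filters can be tiled into one mosaic and saved with a single imshow/savefig call. A numpy-only sketch (the 8x8 layout, the `filter_grid` name, and the random stand-in weights are my own choices, not from the script above):

```python
import numpy as np

def filter_grid(weights, rows=8, cols=8):
    """Tile channel 0 of each filter into a rows x cols mosaic.
    weights: array of shape (num_filters, in_channels, kH, kW)."""
    n, _, kh, kw = weights.shape
    grid = np.zeros((rows * kh, cols * kw))
    for idx in range(min(n, rows * cols)):
        r, c = divmod(idx, cols)
        grid[r * kh:(r + 1) * kh, c * kw:(c + 1) * kw] = weights[idx, 0]
    return grid

# random stand-in for net["model.model.0.weight"].numpy()
w = np.random.randn(64, 3, 4, 4)
grid = filter_grid(w)  # one 32x32 image holding all 64 filters
```

The resulting `grid` can then be shown with one `plt.imshow(grid, cmap='gray')` instead of a figure per filter.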
visonpon commented 6 years ago

@canghel I have tried your script, but the images are all white with nothing in them. Maybe some processing step is missing from the script? Thanks~

canghel commented 6 years ago

Hi @visonpon, Thank you for trying it out!! It's great to be able to check my code with someone!

I think what might be wrong is the vmin=-25, vmax=25 in the plt.imshow() call near the end of the loop. I shouldn't have included those limits, since they are specific to the facades network I trained. I wanted to standardize the colour map range to be the same for all of the filters.

For getting the correct max and min of the 0th channel, I think this works (though there may be a shorter command):

np.floor(net["model.model.0.weight"][:, 0, :, :].numpy() * 255).min()
np.floor(net["model.model.0.weight"][:, 0, :, :].numpy() * 255).max()

If you adjust the vmax and vmin, can you see non-white images?
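
An alternative to hand-picking vmin/vmax per network is to rescale each filter slice to [0, 1] independently before plotting, so every filter uses the full gray range regardless of its magnitude. A small sketch (the `normalize_filters` name is my own):

```python
import numpy as np

def normalize_filters(w):
    """Rescale channel 0 of each filter to [0, 1] independently.
    w: array of shape (num_filters, in_channels, kH, kW)."""
    out = np.empty(w[:, 0].shape, dtype=float)
    for i, f in enumerate(w[:, 0]):
        lo, hi = f.min(), f.max()
        # a constant filter has no range; map it to mid-gray
        out[i] = (f - lo) / (hi - lo) if hi > lo else 0.5
    return out
```

Each normalized slice can then be passed to plt.imshow() with no vmin/vmax arguments at all.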