Closed by isalirezag 6 years ago.
The DCGAN paper (Figure 5) used guided backpropagation to visualize the features learned by the discriminator. I'm not sure about visualization for generators.
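Guided backpropagation changes only the backward pass through each ReLU: a gradient is passed back only where the ReLU's forward input was positive *and* the incoming gradient is positive. Below is a minimal NumPy sketch of that rule on a hypothetical two-layer toy network. `guided_relu_backward` and `guided_backprop_saliency` are illustrative names. This is not the DCGAN authors' code or the pix2pix discriminator.

```python
import numpy as np

def guided_relu_backward(pre_activation, grad_output):
    """Guided-backprop rule for a ReLU: keep the gradient only where the
    forward input was positive AND the incoming gradient is positive."""
    return grad_output * (pre_activation > 0) * (grad_output > 0)

def guided_backprop_saliency(x, W1, w2):
    """Hypothetical toy network: score = w2 . relu(W1 @ x).
    Returns the guided-backprop 'image' d(score)/dx for the input x."""
    pre = W1 @ x                      # forward pass: hidden pre-activations
    grad_hidden = w2                  # d(score) / d(relu output)
    grad_pre = guided_relu_backward(pre, grad_hidden)
    return W1.T @ grad_pre            # back through the linear layer to the input

rng = np.random.default_rng(0)
x = rng.normal(size=16)               # stand-in for a flattened input patch
W1 = rng.normal(size=(8, 16))
w2 = rng.normal(size=8)
saliency = guided_backprop_saliency(x, W1, w2)
print(saliency.shape)                 # (16,)
```

For a real convolutional discriminator, the same rule would be applied at every ReLU/LeakyReLU on the way back from one chosen feature map to the input image.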
Hello! I had the same question, and I am not sure if this is right... I was trying to visualize the weights of the first layer of the G network. For the facades dataset, the 4x4 weights do not seem to capture features, but for my dataset they seem to (e.g. they show a gradient across the square, or a smaller, more intense square in the middle). I was having trouble following the UNet architecture in the code, so I'm not sure I even have the correct layer. Does anybody know?
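One way to check which tensor is the first convolution is to scan the 4-D weight tensors in the state dict and take the first one whose input-channel dimension matches the number of image channels. The sketch below assumes the dict preserves layer-registration order, as PyTorch's `state_dict()` does. `first_conv_key` is a hypothetical helper, not part of the repository, and the hand-built dict only stands in for a loaded checkpoint.

```python
import numpy as np

def first_conv_key(state_dict, in_channels=3):
    """Return the first key whose value looks like a Conv2d weight
    (out_channels, in_channels, kH, kW) that takes `in_channels` inputs."""
    for key, value in state_dict.items():
        shape = np.shape(value)       # reads value.shape for arrays or tensors
        if key.endswith(".weight") and len(shape) == 4 and shape[1] == in_channels:
            return key
    return None                       # no matching layer found

# Hypothetical stand-in for a loaded generator state dict.
fake_state = {
    "model.model.0.weight": np.zeros((64, 3, 4, 4)),
    "model.model.1.model.0.weight": np.zeros((128, 64, 4, 4)),
    "model.model.1.model.1.weight": np.zeros(128),
}
print(first_conv_key(fake_state))     # model.model.0.weight
```

Taking the *first* match matters because a `ConvTranspose2d` weight is laid out as (in_channels, out_channels, kH, kW), so an output layer that produces 3 channels can also have 3 in that position.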
```python
### PREAMBLE ##################################################################
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# the path for the network to load
pathNets = '/home/Documents/placenta/pytorch-CycleGAN-and-pix2pix/checkpoints/facades_pix2pix'
fileToLoad = 'latest_net_G.pth'

### VISUALIZE FILTERS #########################################################
# load the saved generator weights (a state dict) into the variable 'net';
# map_location='cpu' lets the checkpoint load on a machine without a GPU
net = torch.load(os.path.join(pathNets, fileToLoad), map_location='cpu')

## figuring out dimensions of filters, layers, etc.
for k, v in net.items():
    print(k, tuple(v.size()))

# the filters are of size 3x4x4, and there are 64 of them
# I think this is for the first conv. layer?
weights = net["model.model.0.weight"]
print(weights.size())

# plot just one channel of every filter
for jj in range(weights.size(0)):
    # get the jj-th filter (3x4x4), scaled by 255 and floored
    temp = torch.floor(weights[jj] * 255)
    # keep only the first (0th) channel as a 4x4 NumPy array
    img = temp[0].numpy()
    # plot a grayscale image of that channel of the jj-th filter
    # (vmin/vmax were picked for the facades network trained here)
    fig = plt.figure()
    plt.imshow(img, vmin=-25, vmax=25, cmap='gray')
    fig.savefig("./filter" + str(jj) + ".png")
    plt.close(fig)  # free the figure so 64 of them don't stay open
```
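Saving 64 separate PNGs makes the filters hard to compare side by side. An alternative is to tile one channel of every filter into a single grid, so one `plt.imshow(grid, cmap='gray')` call shows them all. `tile_filters` is a hypothetical helper. It expects a NumPy array shaped (num_filters, channels, kH, kW), which for the script above would be `net["model.model.0.weight"].numpy()`. The random array below is only a stand-in for trained weights.

```python
import math
import numpy as np

def tile_filters(weights, channel=0, pad=1):
    """Arrange weights[:, channel] (each kH x kW) into a near-square grid,
    separated by `pad` pixels of NaN (which imshow draws as transparent)."""
    filters = weights[:, channel]                 # shape (N, kH, kW)
    n, kh, kw = filters.shape
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    grid = np.full((rows * (kh + pad) - pad, cols * (kw + pad) - pad), np.nan)
    for idx in range(n):
        r, c = divmod(idx, cols)                  # grid cell for this filter
        top, left = r * (kh + pad), c * (kw + pad)
        grid[top:top + kh, left:left + kw] = filters[idx]
    return grid

# Stand-in for 64 filters of size 3x4x4 (random values, not trained weights).
demo = np.random.default_rng(0).normal(size=(64, 3, 4, 4))
grid = tile_filters(demo)
print(grid.shape)  # (39, 39)
```

Because everything ends up in one image, a single `vmin`/`vmax` pair automatically gives all filters the same colour scale.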
@canghel I tried your script, but the images are all white with nothing in them. Is there some processing step missing from the script? Thanks!
Hi @visonpon, thank you for trying it out!! It's great to be able to check my code with someone!
I think what might be wrong is the `vmin=-25, vmax=25` in the `plt.imshow()` call on the third line from the bottom. I shouldn't have included those limits, since they are specific to the facades network I trained. I wanted to standardize the colour map range to be the same for all of the filters.
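With a linear grayscale colormap, `imshow` maps each value v to the gray level (v - vmin) / (vmax - vmin). Values outside [vmin, vmax] are drawn with the end colours by default. So if every scaled weight of another network sits at or above 25, every pixel comes out white. The NumPy sketch below imitates that mapping. `gray_level` is a hypothetical helper that mirrors matplotlib's default linear normalization; it is not matplotlib's own code.

```python
import numpy as np

def gray_level(values, vmin, vmax):
    """0.0 = black, 1.0 = white: linear rescale, then clip to the ends."""
    values = np.asarray(values, dtype=float)
    return np.clip((values - vmin) / (vmax - vmin), 0.0, 1.0)

# Values above vmax all saturate to white...
print(gray_level([30, 120, 255], vmin=-25, vmax=25).tolist())   # [1.0, 1.0, 1.0]
# ...while values inside the range keep distinct gray levels.
print(gray_level([-25, 0, 12.5], vmin=-25, vmax=25).tolist())  # [0.0, 0.5, 0.75]
```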
For getting the correct max and min of the 0th channel, I think this works (though there may be a shorter command):

```python
np.min(np.array(np.floor(net["model.model.0.weight"][:,0,:,:]*255)))
np.max(np.array(np.floor(net["model.model.0.weight"][:,0,:,:]*255)))
```
If you adjust `vmax` and `vmin`, can you see non-white images?
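A shorter way to get those limits is to compute them in one pass, then turn them into a symmetric colour range shared by all filters. The NumPy sketch below assumes the weight tensor has already been converted with `.numpy()`. `shared_limits` is a hypothetical helper, and the random array only stands in for real weights.

```python
import numpy as np

def shared_limits(weights, channel=0, scale=255):
    """Floored min/max of one channel across all filters, plus a symmetric
    (-lim, lim) range so that a weight of zero maps to mid-gray in every plot."""
    scaled = np.floor(weights[:, channel] * scale)
    lo, hi = scaled.min(), scaled.max()
    lim = max(abs(lo), abs(hi))
    return lo, hi, (-lim, lim)

demo = np.random.default_rng(1).normal(scale=0.05, size=(64, 3, 4, 4))
lo, hi, (vmin, vmax) = shared_limits(demo)
print(lo, hi, vmin, vmax)
# then, e.g.: plt.imshow(img, vmin=vmin, vmax=vmax, cmap='gray')
```

A symmetric range keeps positive and negative weights visually comparable, whereas the raw min/max would shift mid-gray away from zero.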
Does anyone know how to display/visualize the trained filters for the generators and discriminators after training is finished?