junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

visualizer error for 6 input and 3 output channels #1442

Open Davegdd opened 2 years ago

Davegdd commented 2 years ago

Hello and thanks a lot for this software and taking the time to read my issue.

I'm trying to train pix2pix to map a combination of two RGB images (6 input channels) to one RGB image (3 output channels). My dataset looks like this (the same portion of the sky in optical, ultraviolet, and infrared (false-coloured), respectively):

[image: sample triplet of the same sky region in optical, ultraviolet, and infrared]

Setting `--input_nc 6` and modifying `__getitem__` in `aligned_dataset.py` to load 2 images (6 channels) like this:

        w, h = AB.size
        w3 = int(w / 3)
        A = AB.crop((0, 0, w3, h))
        B = AB.crop((w3, 0, w3*2, h))
        C = AB.crop((w3*2, 0, w, h))

        # apply the same transform to A, B and C
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))

        A = A_transform(A)
        B = B_transform(B)
        C = B_transform(C)
        B = torch.cat((B, C))

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
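For context, `torch.cat` with no `dim` argument concatenates along dimension 0, which for a CHW image tensor is the channel axis. A NumPy stand-in (the shapes here are illustrative) shows the resulting 6-channel target:

```python
import numpy as np

# Stand-ins for the transformed RGB crops B and C (CHW layout).
B = np.zeros((3, 256, 256), dtype=np.float32)
C = np.ones((3, 256, 256), dtype=np.float32)

# torch.cat((B, C)) defaults to dim 0, i.e. the channel axis for CHW,
# so the target fed to the network has 6 channels.
B6 = np.concatenate((B, C), axis=0)
print(B6.shape)  # (6, 256, 256)
```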

I get the following error after epoch 1:

(epoch: 1, iters: 100, time: 1.336, data: 0.211) G_GAN: 1.438 G_L1: 2.366 D_real: 0.543 D_fake: 0.663 
(epoch: 1, iters: 200, time: 1.338, data: 0.005) G_GAN: 0.956 G_L1: 1.331 D_real: 0.871 D_fake: 0.448 
(epoch: 1, iters: 300, time: 1.338, data: 0.003) G_GAN: 0.834 G_L1: 2.449 D_real: 0.504 D_fake: 0.583 
/usr/local/lib/python3.7/dist-packages/visdom/__init__.py:366: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  return np.array(a)
Traceback (most recent call last):
  File "train.py", line 57, in <module>
    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
  File "/content/pytorch-CycleGAN-and-pix2pix/util/visualizer.py", line 154, in display_current_results
    padding=2, opts=dict(title=title + ' images'))
  File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 389, in wrapped_f
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 1292, in images
    height = int(tensor.shape[2] + 2 * padding)
IndexError: tuple index out of range
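A plausible reconstruction of the failure (the names and shapes below are hypothetical): if `tensor2im` passes the 6-channel array through unchanged, the list of visuals mixes HWC images with different channel counts. `np.array` on such a ragged list builds a 1-D object array (hence the `VisibleDeprecationWarning`), so `tensor.shape[2]` in visdom's `images` no longer exists:

```python
import numpy as np

# Hypothetical ragged list of visuals: a normal 3-channel image next to
# an unsplit 6-channel one (sizes are illustrative).
imgs = [np.zeros((4, 4, 3)), np.zeros((5, 5, 6))]

# With mismatched shapes, numpy builds a 1-D array of objects instead of
# a 4-D image batch, so .shape has only one entry.
arr = np.array(imgs, dtype=object)
print(arr.shape)  # (2,)

try:
    arr.shape[2]  # visdom does int(tensor.shape[2] + 2 * padding)
except IndexError as e:
    print(e)      # tuple index out of range
```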

Are there any other modifications needed beyond the ones in dataset? What would they be? Thanks a lot for any assistance.

junyanz commented 2 years ago

I am not sure whether our current code supports visualization of 6-channel images. Two potential fixes: (1) you can try the wandb visualization and see if it handles them; (2) you may want to modify the visualizer code. We often call the tensor2im function (here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/util/visualizer.py#L41), which only works for 1-channel or 3-channel images. You may want to modify it: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/util/util.py#L9

Davegdd commented 2 years ago

> I am not sure whether our current code supports visualization of 6-channel images. Two potential fixes: (1) you can try the wandb visualization and see if it handles them; (2) you may want to modify the visualizer code. We often call the tensor2im function (here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/util/visualizer.py#L41), which only works for 1-channel or 3-channel images. You may want to modify it: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/util/util.py#L9

Thanks a lot for pointing me to that. I was able to modify `tensor2im` and it's training now. For anyone with the same issue, I just added the following at the same indentation level as the `# grayscale to RGB` block in `tensor2im`:

        if image_numpy.shape[0] == 6:  # two concatenated RGB inputs
            image_numpy, b = np.vsplit(image_numpy, 2)  # keep the first for display

This just splits the two input images that were concatenated to feed the network and returns one of them for visualization as real_A, for reference.
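For anyone who wants to see the split in context, here is a minimal, NumPy-only sketch of a tensor2im-style function with that branch added (the real `util.tensor2im` also handles torch tensors and the grayscale case; this is just an illustration):

```python
import numpy as np

def tensor2im_sketch(image_numpy, imtype=np.uint8):
    """Simplified, NumPy-only sketch of util.tensor2im with the
    6-channel split added. Input: CHW array with values in [-1, 1]."""
    if image_numpy.shape[0] == 1:
        # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    if image_numpy.shape[0] == 6:
        # two concatenated RGB inputs: vsplit cuts along axis 0 (channels),
        # and we keep the first 3-channel image for display
        image_numpy, _ = np.vsplit(image_numpy, 2)
    # map [-1, 1] to [0, 255] and move channels last (HWC)
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)

# A 6-channel CHW array comes back as a single 3-channel HWC image.
out = tensor2im_sketch(np.zeros((6, 8, 8), dtype=np.float32))
print(out.shape)  # (8, 8, 3)
```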