junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

[multi-input for cycleGAN] #1514

Open · louis-xin opened this issue 1 year ago

louis-xin commented 1 year ago

Hello, I want to perform an A1, A2 → B prediction using the CycleGAN model. As you mentioned in #498, I have to make a custom dataset that reads the A1 and A2 images separately, concatenates them into an (x, x, 6) array A, and passes it to the generator to produce an image B of shape (x, x, 3).
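For reference, here is a minimal NumPy sketch of the stacking step described above. It assumes both inputs are already loaded as H×W×3 `uint8` arrays of the same size (random arrays stand in for real images), and the function name `stack_inputs` is hypothetical. PyTorch convolutions expect channels-first data, so the 6-channel input eventually has to be laid out as (6, H, W) rather than (H, W, 6).

```python
import numpy as np


def stack_inputs(a1: np.ndarray, a2: np.ndarray) -> np.ndarray:
    """Hypothetical helper: stack two HxWx3 images into one 6xHxW array."""
    if a1.shape != a2.shape:
        raise ValueError(f"shape mismatch: {a1.shape} vs {a2.shape}")
    hwc = np.concatenate([a1, a2], axis=2)  # (H, W, 6), as described in the issue
    return np.transpose(hwc, (2, 0, 1))     # (6, H, W), channels-first


rng = np.random.default_rng(0)
a1 = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)  # stand-in for A1
a2 = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)  # stand-in for A2
stacked = stack_inputs(a1, a2)
print(stacked.shape)  # (6, 4, 4)
```

Channels 0 to 2 hold A1 and channels 3 to 5 hold A2, so the generator sees both inputs at every pixel location.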

The problem is that the transform doesn't accept a 6-channel np.array. Also, in cycle_gan_model.py line 86, the code only works when the input and output images have the same number of channels.
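To illustrate why that channel check exists, here is a pure-Python sketch of the channel bookkeeping in one CycleGAN step. It assumes the usual setup: G_A maps `input_nc` → `output_nc` channels, G_B maps back, and the identity loss feeds a real B image into G_A (and a real A image into G_B). The helper `cyclegan_channels` is hypothetical and not part of the repository. Under these assumptions, a 6 → 3 setup is channel-consistent for every term except the identity loss, which the repository's `--lambda_identity` option (set to 0) turns off. This sketch only covers the losses; other parts of the pipeline, such as saving 6-channel visuals, are not checked here.

```python
def cyclegan_channels(input_nc, output_nc, lambda_identity):
    """Hypothetical helper: (in_channels, out_channels) of each network call
    in one CycleGAN training step."""
    calls = {
        "G_A(real_A)": (input_nc, output_nc),  # A -> B
        "G_B(real_B)": (output_nc, input_nc),  # B -> A
        "D_A(fake_B)": (output_nc, 1),         # D_A judges domain-B images
        "D_B(fake_A)": (input_nc, 1),          # D_B judges domain-A images
    }
    if lambda_identity > 0.0:
        # The identity loss feeds real_B (output_nc channels) into G_A,
        # which was built for input_nc channels, and vice versa for G_B.
        if input_nc != output_nc:
            raise ValueError("identity loss requires input_nc == output_nc")
        calls["G_A(real_B)"] = (output_nc, output_nc)
        calls["G_B(real_A)"] = (input_nc, input_nc)
    return calls


calls = cyclegan_channels(6, 3, lambda_identity=0.0)
print(calls["G_A(real_A)"], calls["G_B(real_B)"])  # (6, 3) (3, 6)
```

With `lambda_identity` left at a positive value, the same call with 6 and 3 channels raises the `ValueError`, mirroring the check in cycle_gan_model.py.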

Could you give me some insights to solve this problem?

Thanks.


This is my custom dataset

```python
import os
import random

import numpy as np
from PIL import Image

from data.base_dataset import BaseDataset, get_transform, get_params
from data.image_folder import make_dataset


class BucklingDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_A1 = os.path.join(opt.dataroot, opt.phase + 'A1')  # create a path '/path/to/data/trainA1'
        self.dir_A2 = os.path.join(opt.dataroot, opt.phase + 'A2')  # create a path '/path/to/data/trainA2'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

        self.A1_paths = sorted(make_dataset(self.dir_A1, opt.max_dataset_size))  # load images from '/path/to/data/trainA1'
        self.A2_paths = sorted(make_dataset(self.dir_A2, opt.max_dataset_size))  # load images from '/path/to/data/trainA2'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'

        # self.A_paths = np.concatenate([self.A1_paths, self.A2_paths], axis=0)

        self.A1_size = len(self.A1_paths)  # get the size of dataset A1
        self.A2_size = len(self.A2_paths)  # get the size of dataset A2
        self.B_size = len(self.B_paths)  # get the size of dataset B
        btoA = self.opt.direction == 'BtoA'
        self.input_nc = self.opt.output_nc if btoA else 3  # get the number of channels of input image
        self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc  # get the number of channels of output image

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int)      -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor)       -- an image in the input domain
            B (tensor)       -- its corresponding image in the target domain
            A_paths (str)    -- image paths
            B_paths (str)    -- image paths
        """
        A1_path = self.A1_paths[index % self.A1_size]  # make sure index is within the range
        A2_path = self.A2_paths[index % self.A2_size]  # same index, so A1 and A2 stay paired
        if self.opt.serial_batches:  # make sure index is within the range
            index_B = index % self.B_size
        else:  # randomize the index for domain B to avoid fixed pairs
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A1_img = Image.open(A1_path).convert('RGB').resize((512, 512))
        A2_img = Image.open(A2_path).convert('RGB').resize((512, 512))

        # (512, 512, 6) uint8 array; PIL cannot represent 6 channels,
        # so Image.fromarray(A_img) below fails as well
        A_img = np.concatenate([np.asarray(A1_img), np.asarray(A2_img)], axis=2)

        # A_img = Image.fromarray(A_img)
        B_img = Image.open(B_path).convert('RGB')

        # transform_params = get_params(self.opt, A_img.size)
        self.transform_A = get_transform(self.opt)
        self.transform_B = get_transform(self.opt)
        # apply image transformation
        A = self.transform_A(A_img)  # fails: the transform pipeline expects a 3-channel PIL image
        B = self.transform_B(B_img)

        # the model's set_input() reads the 'A_paths' key, not 'A1_paths'
        return {'A': A, 'B': B, 'A_paths': A1_path, 'B_paths': B_path}
        # 'A2_paths': A2_path,

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different number of images,
        we take the maximum of the two.
        """
        return max(self.A1_size, self.B_size)
```
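
One possible way around the 6-channel transform problem, sketched under assumptions: instead of transforming the stacked array, draw the random crop position and flip once, apply them to each 3-channel image separately, and concatenate afterwards. The repository's `aligned_dataset.py` does something similar for paired A/B images with `get_params(opt, size)` and `get_transform(opt, params)`; the resulting per-image tensors could then be joined with `torch.cat([A1, A2], dim=0)`. The NumPy-only version below only illustrates the idea, and the function `paired_crop_flip_stack` and its arguments are hypothetical. Passing three images (A1, A2, A3 → B) gives a 9-channel input, which would need `--input_nc 9`.

```python
import numpy as np


def paired_crop_flip_stack(images, crop_size, rng):
    """Hypothetical helper: crop and flip every HxWx3 uint8 image with the SAME
    random parameters, scale to [-1, 1], and stack channels-first, giving a
    (3*N) x crop_size x crop_size float32 array."""
    h, w, _ = images[0].shape
    if any(img.shape != images[0].shape for img in images):
        raise ValueError("all inputs must share one shape")
    # Draw the random parameters once so every input stays pixel-aligned.
    y = int(rng.integers(0, h - crop_size + 1))
    x = int(rng.integers(0, w - crop_size + 1))
    flip = bool(rng.random() < 0.5)
    out = []
    for img in images:
        patch = img[y:y + crop_size, x:x + crop_size]
        if flip:
            patch = patch[:, ::-1]  # horizontal flip (axis 1 is width)
        chw = np.transpose(patch, (2, 0, 1)).astype(np.float32)
        out.append(chw / 127.5 - 1.0)  # same range as ToTensor + Normalize(0.5, 0.5)
    return np.concatenate(out, axis=0)


rng = np.random.default_rng(0)
a1 = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)  # stand-in for A1
a2 = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)  # stand-in for A2
a = paired_crop_flip_stack([a1, a2], crop_size=4, rng=rng)
print(a.shape)  # (6, 4, 4)
```

The key design choice is that randomness is sampled outside the per-image loop; transforming A1 and A2 with independent random crops would misalign the two inputs.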
youngprogrammerBee commented 1 year ago

Have you solved the problem? I am also working on using A1, A2, A3 to predict B.

louis-xin commented 1 year ago

> Have you solved the problem? I am also working on using A1, A2, A3 to predict B.

No, not yet. Do you have any idea how to do it?

tsshubhamv commented 1 year ago

Any update, @workingpotato?

> > Have you solved the problem? I am also working on using A1, A2, A3 to predict B.
>
> No, not yet. Do you have any idea how to do it?

prabathbr commented 11 months ago

I am also trying this but have the same issue. I stacked my input into a 6-channel NPZ file and changed the dataset loader accordingly. My output has just 1 channel. Has anyone found a solution for this?
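For a stacked `.npz` input like the one described above, a minimal loader sketch could bypass the PIL-based transforms entirely. It assumes each archive stores an H×W×6 `uint8` array under a key named `"A"` and, for the target, an H×W single-channel array under `"B"`; both key names and the function `load_stacked_npz` are hypothetical. With 6 input channels and 1 output channel, the repository's flags would presumably be `--input_nc 6 --output_nc 1`, plus `--lambda_identity 0` because the identity loss needs equal channel counts.

```python
import io

import numpy as np


def load_stacked_npz(source, key="A"):
    """Hypothetical helper: load an HxWxC (or HxW) uint8 array from an .npz
    archive and return a channels-first float32 array scaled to [-1, 1]."""
    with np.load(source) as archive:
        hwc = archive[key]
    if hwc.ndim == 2:  # single-channel image, e.g. an output_nc = 1 target
        hwc = hwc[:, :, None]
    chw = np.transpose(hwc, (2, 0, 1)).astype(np.float32)
    return chw / 127.5 - 1.0  # maps 0..255 to -1..1


# Build a small in-memory archive so the example runs without files on disk.
buf = io.BytesIO()
np.savez(buf,
         A=np.zeros((4, 4, 6), dtype=np.uint8),       # stand-in 6-channel input
         B=np.full((4, 4), 255, dtype=np.uint8))      # stand-in 1-channel target
data = buf.getvalue()

a = load_stacked_npz(io.BytesIO(data), key="A")
b = load_stacked_npz(io.BytesIO(data), key="B")
print(a.shape, b.shape)  # (6, 4, 4) (1, 4, 4)
```

Inside a dataset's `__getitem__`, the returned arrays would still need converting to tensors (for example with `torch.from_numpy`) and any random crops would need shared parameters across inputs.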

kuhhg commented 6 months ago

Have you solved the problem? I am also working with multiple inputs but have no idea how to solve it.