eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

What do imgs["A"], imgs["B"], batch["A"] and batch["B"] mean? Where do "A" and "B" come from? #119

Open jackymail opened 4 years ago

jackymail commented 4 years ago

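import torch
from torch.autograd import Variable
from torchvision.utils import save_image

# val_dataloader, generator, Tensor and opt are globals defined elsewhere in the training script
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    imgs = next(iter(val_dataloader))
    # Note: the generator input is read from key "B" and the target from key "A"
    real_A = Variable(imgs["B"].type(Tensor))
    real_B = Variable(imgs["A"].type(Tensor))
    fake_B = generator(real_A)
    # Stack input, generated output and target along the height dimension (-2 in NCHW)
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)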

boyden commented 4 years ago

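The keys come from the custom ImageDataset class in datasets.py. Its __getitem__ returns a dict, {"A": item_A, "B": item_B}, where "A" is an image from domain A (the files under <root>/<mode>/A) and "B" is an image from domain B (the files under <root>/<mode>/B). When the DataLoader batches these dicts, its default collate function keeps the keys and stacks each value into one tensor, which is why imgs["A"] and batch["A"] work on whole batches. You can customize your own dataset in that file.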
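To see why the keys carry over from single samples to batches, here is a minimal sketch using a hypothetical ToyPairDataset (not part of the repo) that returns the same kind of dict as ImageDataset, with constant-valued tensors instead of real images. The DataLoader's default collate stacks the values for each key separately.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairDataset(Dataset):
    """Hypothetical stand-in for ImageDataset: each sample is {"A": ..., "B": ...}."""

    def __init__(self, n_items=10, size=8):
        self.n_items = n_items
        self.size = size

    def __len__(self):
        return self.n_items

    def __getitem__(self, index):
        # Fake "images": domain A filled with the index, domain B with index + 100
        item_A = torch.full((3, self.size, self.size), float(index))
        item_B = torch.full((3, self.size, self.size), float(index) + 100)
        return {"A": item_A, "B": item_B}


loader = DataLoader(ToyPairDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))

# The default collate function keeps the dict keys and stacks each value
print(sorted(batch.keys()))  # ['A', 'B']
print(batch["A"].shape)      # torch.Size([4, 3, 8, 8])
```

So batch["A"] is just the four "A" tensors stacked along a new batch dimension, and batch["B"] the matching "B" tensors.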

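import glob
import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


def to_rgb(image):
    # Stand-in for the to_rgb helper this snippet calls: convert grayscale/palette images to RGB
    return image.convert("RGB")


class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Fall back to ToTensor() when no transforms are passed (Compose(None) fails when called)
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [transforms.ToTensor()])
        self.unaligned = unaligned

        # Domain A images live in <root>/<mode>/A, domain B images in <root>/<mode>/B
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    def __getitem__(self, index):
        # Wrap the index so the smaller domain is reused when the two folders differ in size
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Unpaired data: pick a random B image
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            # Paired data: take the B image at the same (wrapped) sorted position
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale images to RGB
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        # These keys are where imgs["A"] / batch["A"] (and "B") come from
        return {"A": item_A, "B": item_B}

    def __len__(self):
        # One epoch covers the larger of the two domains
        return max(len(self.files_A), len(self.files_B))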
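A sketch of the folder layout this class expects, and of how aligned mode (unaligned=False) pairs files when the two folders differ in size. It mirrors the glob pattern and the index % len(...) rule from the code above; the helper list_split, the file names and the empty placeholder files are hypothetical, and only the standard library is used (no images are opened).

```python
import glob
import os
import tempfile


def list_split(root, mode, domain):
    # Same pattern as ImageDataset: <root>/<mode>/<domain>/*.*
    return sorted(glob.glob(os.path.join(root, "%s/%s" % (mode, domain)) + "/*.*"))


with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: 3 files in domain A, 2 in domain B
    layout = {"A": ["a0.png", "a1.png", "a2.png"], "B": ["b0.png", "b1.png"]}
    for domain, names in layout.items():
        os.makedirs(os.path.join(root, "train", domain))
        for name in names:
            # Empty placeholder files; only the file names matter here
            open(os.path.join(root, "train", domain, name), "w").close()

    files_A = list_split(root, "train", "A")
    files_B = list_split(root, "train", "B")

    # __len__ is the size of the larger domain
    length = max(len(files_A), len(files_B))

    # Aligned mode: both lists are indexed with index % len(list)
    pairs = [
        (os.path.basename(files_A[i % len(files_A)]),
         os.path.basename(files_B[i % len(files_B)]))
        for i in range(length)
    ]

for pair in pairs:
    print(pair)
# ('a0.png', 'b0.png')
# ('a1.png', 'b1.png')
# ('a2.png', 'b0.png')
```

In aligned mode the pairing relies purely on sorted file order, so for paired data the A and B files should share names or at least sort identically.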