Open jackymail opened 4 years ago
It comes from the ImageDataset object defined in datasets.py. "A" refers to image A, and likewise "B" refers to image B. You can customize your own dataset in that file.
class ImageDataset(Dataset):
    """Dataset pairing images from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Every item is a dict ``{"A": item_A, "B": item_B}`` holding one transformed
    image from each folder.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        """Collect the sorted file lists of both image folders.

        Args:
            root: Directory that contains the ``<mode>/A`` and ``<mode>/B`` folders.
            transforms_: Sequence of transforms handed to ``transforms.Compose``.
                ``None`` (the default) is treated as an empty pipeline.
            unaligned: When true, the B image is picked at random instead of
                by index.
            mode: Name of the sub-directory of ``root`` to read.
        """
        # Compose(None) would be accepted here but blow up on the first
        # __getitem__ call, so turn the default into an empty pipeline.
        self.transform = transforms.Compose(
            [] if transforms_ is None else transforms_
        )
        self.unaligned = unaligned

        # Only entries whose name contains a dot match the "*.*" pattern.
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    def __getitem__(self, index):
        """Return the transformed A/B pair for ``index``.

        The index wraps around each file list, so the shorter folder is
        reused once ``index`` exceeds its length. A folder with no files
        makes the modulo below raise ``ZeroDivisionError``.
        """
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Unpaired mode: B is drawn uniformly, independent of ``index``.
            image_B = Image.open(
                self.files_B[random.randint(0, len(self.files_B) - 1)]
            )
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale (or any other non-RGB) images to RGB.
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Return the size of the larger of the two file lists."""
        return max(len(self.files_A), len(self.files_B))
def sample_images(batches_done):
    """Save one generated sample from the validation loader as a PNG.

    Takes the first batch yielded by ``val_dataloader``, runs ``generator``
    on it and writes input, output and target stacked into a single image at
    ``images/<opt.dataset_name>/<batches_done>.png``.

    Args:
        batches_done: Value used as the output file name.
    """
    batch = next(iter(val_dataloader))

    # NOTE(review): the generator input is read from key "B" and the target
    # from key "A" -- presumably the translation direction is B -> A; confirm.
    source = Variable(batch["B"].type(Tensor))
    target = Variable(batch["A"].type(Tensor))
    generated = generator(source)

    # Join the three batches along the second-to-last dimension
    # (presumably image height for NCHW batches -- confirm).
    panels = (source.data, generated.data, target.data)
    grid = torch.cat(panels, -2)

    out_path = "images/%s/%s.png" % (opt.dataset_name, batches_done)
    save_image(grid, out_path, nrow=5, normalize=True)