tqch / ddpm-torch

Unofficial PyTorch Implementation of Denoising Diffusion Probabilistic Models (DDPM)
MIT License

train my own image datasets #6

WFLiu0327 closed this issue 1 year ago

WFLiu0327 commented 1 year ago

Hello, I have recently been learning about DDPM. Could you tell me how to use your code to train on my own image dataset, where all the images are in a single folder?

tqch commented 1 year ago

Hi there. You can first download this repository and then add a custom dataset class to ddpm_torch/datasets.py. For example:

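import os

import numpy as np
import PIL.Image
import torchvision.datasets as tvds
from torchvision import transforms

# register_dataset comes from ddpm_torch/datasets.py itself, since this class is meant
# to be added to that file; some of the imports above may already be present there.


@register_dataset
class CustomDataset(tvds.VisionDataset):
    """
    My custom dataset: all images live in one folder, <root>/<base_folder>
    """
    base_folder = "mydata"  # subdirectory under data root, e.g. ~/datasets
    resolution = (32, 32)  # re-scaled image resolution (height, width)
    channels = 3  # RGB by default
    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # one (h, w) tuple; Resize(32, 32) passes 32 as the interpolation mode
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0, 1] -> [-1, 1]
    ])   # your custom transformations
    all_size = 30000  # your dataset size, i.e. the number of images in base_folder

    def __init__(
            self,
            root,
            transform=None
    ):
        # Fall back to the class-level transform when none is passed in
        if transform is None:
            transform = type(self).transform
        super().__init__(root, transform=transform)
        # Use self.root: VisionDataset expands "~" in a string root, the raw argument is not expanded
        folder = os.path.join(self.root, self.base_folder)
        self.filename = sorted([
            fname
            for fname in os.listdir(folder)
            if fname.lower().endswith((".png", ".jpg", ".jpeg", ".bmp"))  # also matches .JPG, .PNG, ...
        ], key=lambda name: name.rsplit(".", maxsplit=1)[0])
        # Fixed-seed shuffle so the file order is identical on every run
        np.random.RandomState(1234).shuffle(self.filename)

    def __getitem__(self, index):
        im = PIL.Image.open(os.path.join(self.root, self.base_folder, self.filename[index]))
        # Grayscale, RGBA or palette images would not match the 3-channel Normalize above
        im = im.convert("RGB")

        if self.transform is not None:
            im = self.transform(im)

        return im

    def __len__(self):
        return len(self.filename)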
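Before training, a quick framework-free sanity check can help. The sketch below is a hypothetical helper (not part of ddpm-torch): `list_image_files` lists the files in your image folder using the same extensions and sort key as the class above, so you can set `all_size` from the real count, and `to_model_range` spells out the per-pixel arithmetic of `ToTensor()` followed by `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. It assumes the `~/datasets` root and the `mydata` folder name from the example; adjust both to your setup. The extension match here is case-insensitive, so it also counts files such as `IMG_001.JPG`; if this count differs from what the dataset class reports, check the extension filter in `__init__`.

```python
import os
from typing import List, Tuple

# Hypothetical helper, not part of ddpm-torch; same extensions as CustomDataset.
IMAGE_EXTS: Tuple[str, ...] = (".png", ".jpg", ".jpeg", ".bmp")


def list_image_files(folder: str, exts: Tuple[str, ...] = IMAGE_EXTS) -> List[str]:
    """Return image filenames in `folder`, sorted by name without extension."""
    names = [
        fname
        for fname in os.listdir(os.path.expanduser(folder))
        if fname.lower().endswith(exts)  # case-insensitive extension match
    ]
    # Same sort key as in CustomDataset.__init__: filename without its extension
    return sorted(names, key=lambda name: name.rsplit(".", maxsplit=1)[0])


def to_model_range(pixel: int) -> float:
    """Map one 8-bit pixel value the way ToTensor() + Normalize(0.5, 0.5) do."""
    x = pixel / 255.0  # ToTensor(): uint8 0..255 -> float 0..1
    return (x - 0.5) / 0.5  # Normalize(mean=0.5, std=0.5): 0..1 -> -1..1


if __name__ == "__main__":
    # Assumed location: data root ~/datasets plus base_folder "mydata"
    folder = os.path.expanduser(os.path.join("~", "datasets", "mydata"))
    if os.path.isdir(folder):
        files = list_image_files(folder)
        print(f"{len(files)} images found; set all_size = {len(files)}")
    else:
        print(f"{folder} not found")
    print(to_model_range(0), to_model_range(255))  # -1.0 1.0
```

The second function is why the example uses mean and std of 0.5: it maps pixel values linearly into [-1, 1], the data range used in the DDPM paper.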