Closed WFLiu0327 closed 1 year ago
Hi there. You can first download this repository and then add your custom dataset class to `ddpm_torch/datasets.py`. For example:
@register_dataset
class CustomDataset(tvds.VisionDataset):
    """Image dataset backed by a single flat folder of image files.

    Images are read from ``<root>/<base_folder>``. Only files whose extension
    (case-insensitive) is listed in ``extensions`` are used. The file list is
    sorted and then shuffled with a fixed seed (1234), so the order is
    reproducible across runs and machines.

    Every image is converted to RGB before the transform is applied. When no
    ``transform`` is passed to ``__init__``, the class-level default below is
    used. It resizes to ``resolution``, converts to a tensor and maps pixel
    values to [-1, 1].
    """

    base_folder = "mydata"  # subdirectory under the data root, e.g. ~/datasets/mydata
    resolution = (32, 32)  # (height, width) every image is resized to
    channels = 3  # images are converted to RGB in __getitem__
    extensions = (".png", ".jpg", ".jpeg", ".bmp")  # accepted file types (lower-case)
    # Default transformations, used when no transform is passed to __init__.
    transform = transforms.Compose([
        # Resize takes ONE size argument; Resize(32, 32) would treat the second
        # 32 as the interpolation mode. A (h, w) pair forces square outputs so
        # that images of different aspect ratios can be batched together.
        transforms.Resize(resolution),
        transforms.ToTensor(),
        # Map [0, 1] pixel values to [-1, 1]; one (mean, std) entry per channel.
        transforms.Normalize((0.5,) * channels, (0.5,) * channels),
    ])
    all_size = 30000  # your dataset size

    def __init__(self, root, transform=None):
        """Index the image files found in ``<root>/<base_folder>``.

        Args:
            root: Data root directory that contains ``base_folder``.
            transform: Callable applied to each RGB PIL image. If None, the
                class-level default ``transform`` is used.
        """
        # Without this fallback VisionDataset would store None on the instance,
        # shadowing the class-level default and skipping resize/normalize.
        if transform is None:
            transform = type(self).transform
        super().__init__(root, transform=transform)
        # Use self.root (as __getitem__ does) so both places resolve the same path.
        folder = os.path.join(self.root, self.base_folder)
        # Sort before the seeded shuffle so the result does not depend on the
        # order os.listdir returns entries in. Ties on the stem (e.g. a.png and
        # a.jpg) are broken by the full name to keep the order deterministic.
        self.filename = sorted(
            (
                fname
                for fname in os.listdir(folder)
                if fname.lower().endswith(self.extensions)
            ),
            key=lambda name: (name.rsplit(".", maxsplit=1)[0], name),
        )
        np.random.RandomState(1234).shuffle(self.filename)

    def __getitem__(self, index):
        """Load image number *index*, convert it to RGB and apply the transform."""
        path = os.path.join(self.root, self.base_folder, self.filename[index])
        # The context manager releases the file handle. convert("RGB") loads the
        # pixels into a new image, so it stays valid after the file is closed,
        # and grayscale/RGBA/palette files end up with `channels` channels.
        with PIL.Image.open(path) as im:
            im = im.convert("RGB")
        if self.transform is not None:
            im = self.transform(im)
        return im

    def __len__(self):
        """Return the number of image files found in the dataset folder."""
        return len(self.filename)
Hello, I have recently been learning about DDPM. Can you tell me how to use your code to train on my own image dataset, where all the images are in the same folder?