lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

How to train your inpainting model on my own dataset? #352

Open dreamlychina opened 1 year ago

dreamlychina commented 1 year ago

Thanks for sharing this amazing work. I want to train your inpainting model on my own dataset. Could you share a training script and explain how to prepare the data?

swayampragnya-malla commented 8 months ago

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

output_path="/content/drive/MyDrive/imgen_pytorch/output"

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 256,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True  # whether to split the validation dataset from the training
).cuda()

# instantiate your dataloader, which returns the necessary inputs to the DDPM as a tuple
# in the order of images, text embeddings, then text masks.
# in this case, only images are returned, as this is unconditional training

dataset = Dataset('/content/drive/MyDrive/unconditional_generation/dataset_256', image_size = 256)

trainer.add_train_dataset(dataset, batch_size = 16)

# working training loop

for i in range(20000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main:  # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True)  # returns List[Image]
        images[0].save(f'{output_path}/{i // 100}.png')
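
The modulo checks in the loop fire on every multiple of 50 and 100, and that includes step 0, so a validation pass and a sample image (saved as 0.png) are produced before any meaningful training has happened. A small sketch of that schedule is below. The helper name plan_step and its parameters are hypothetical and not part of imagen-pytorch, and the trainer.is_main check is left out for simplicity.

```python
def plan_step(i, valid_every=50, sample_every=100):
    """Report which side tasks the training loop runs at step i.

    Mirrors the `not (i % n)` checks: they are true when i is a multiple of n,
    which includes i == 0.
    """
    run_valid = i % valid_every == 0
    run_sample = i % sample_every == 0
    # The loop names each sample after the integer quotient i // 100.
    sample_name = f'{i // sample_every}.png' if run_sample else None
    return run_valid, run_sample, sample_name


if __name__ == '__main__':
    for step in (0, 7, 50, 100, 19900):
        print(step, plan_step(step))
```

If the untrained sample at step 0 is not wanted, one option is to also require i > 0 in the sampling condition.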

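This is training code you can use with your custom dataset. It trains an unconditional model (condition_on_text = False), so only images are needed and no text embeddings.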
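
On the data-preparation side, the script only needs a folder of images passed to Dataset. Before starting a long run, it can help to check that the folder actually contains what you expect. The sketch below uses only the standard library. The helper names find_images and summarize are hypothetical, and the extension list is an assumption about which file types the folder loader picks up, so adjust it to your data. This helper also matches extensions case-insensitively, which the library's own loader may not do.

```python
from collections import Counter
from pathlib import Path

# Assumption: extensions the folder-based loader is expected to pick up.
IMAGE_EXTS = ('jpg', 'jpeg', 'png', 'tiff')


def find_images(folder, exts=IMAGE_EXTS):
    """Recursively collect image paths under `folder`, sorted for reproducibility."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f'not a directory: {root}')
    wanted = {f'.{e.lower()}' for e in exts}
    return sorted(
        p for p in root.rglob('*')
        if p.is_file() and p.suffix.lower() in wanted
    )


def summarize(paths):
    """Count files per (lowercased) extension."""
    return dict(Counter(p.suffix.lower().lstrip('.') for p in paths))
```

For example, summarize(find_images('/path/to/your/images')) returns a per-extension count. An empty result means the training script would have nothing to load from that folder.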