lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Correct Way to train both unets #354

Open gauravbyte opened 1 year ago

gauravbyte commented 1 year ago

Hello, I am trying to train the model on a medical dataset so that it generates images from text. I am using the following script for training, but it never prints the training loss or the validation loss, and it never reaches the sampling step. Could you look at the script and let me know what the issue is?

```python
for i in tqdm(range(10000)):
    try:
        loss = trainer.train_step(unet_number = 1, max_batch_size = 32)
        loss = trainer.train_step(unet_number = 2, max_batch_size = 32)
        trainer.update(unet_number = 1)
        trainer.update(unet_number = 2)
        # trainer.update(unet_number=1)
        print(f'train loss: {loss}')
        if not (i % 50):
            valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 32)
            print(f'valid loss {valid_loss}')

        if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
            images = trainer.sample(texts = ['xray of chest', 'mri of brain'], batch_size = 1, return_pil_images = True, stop_at_unet_number = 2)
            images[0].save(f'sample/xray-chest-{i // 100}.png')
            images[1].save(f'sample/brain-mri-{i // 100}.png')

        if not (i % 1000):
            trainer.save(f'checkpoints/checkpoint-{i // 1000}.pt')

    except:
        continue
```
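The silence follows from the bare `except: continue` around the whole loop body. If `train_step` raises on every iteration (the next comment reports such an AssertionError once a second unet number is used), control jumps straight to `continue`, so no `print`, validation, sampling or checkpointing ever runs. A bare `except:` also catches `KeyboardInterrupt`, so Ctrl+C will not stop the loop either. The snippet below reproduces the effect without Imagen. `OneUnetTrainer` and `run` are hypothetical stand-ins, and the stand-in's assertion only mimics the reported behaviour; it is not the library's code.

```python
import traceback


class OneUnetTrainer:
    """Hypothetical stand-in: rejects a second unet number once one has been used."""

    def __init__(self):
        self.unet_being_trained = None

    def train_step(self, unet_number, max_batch_size=32):
        assert self.unet_being_trained in (None, unet_number), (
            'you cannot only train on one unet at a time')
        self.unet_being_trained = unet_number
        return 0.5  # dummy loss


def run(trainer, steps, swallow_errors):
    """Mimic the loop above; return the lines that were actually printed."""
    printed = []
    for i in range(steps):
        try:
            loss = trainer.train_step(unet_number=1)
            loss = trainer.train_step(unet_number=2)  # raises on the stand-in
            line = f'train loss: {loss}'
            print(line)
            printed.append(line)
        except Exception:
            if swallow_errors:
                continue  # same effect as `except: continue`: the print is skipped
            traceback.print_exc()  # show where it failed, then stop
            raise
    return printed
```

`run(OneUnetTrainer(), 3, swallow_errors=True)` prints nothing and returns an empty list, which matches the symptom. With `swallow_errors=False` the traceback is printed and the AssertionError propagates, so in the real script removing the `try`/`except` (or re-raising after logging) is enough to see what is failing.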
xbowlove commented 1 year ago

I encounter the same error. The message printed in the terminal is:

AssertionError: you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet

I don't understand the message. If someone could explain it to me, I would really appreciate it.
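Read literally, the message says that once the trainer has started training one unet it will not accept a different `unet_number`; the route it suggests is to train the unets in stages: train unet 1, save a checkpoint, then resume from that checkpoint and train unet 2. A minimal sketch of one stage follows. It assumes `trainer` is built exactly as in the original script (same `Imagen`, unets and dataset) and exposes the `train_step`, `valid_step`, `sample`, `save` and `is_main` members used there; `train_unet` and its parameters are hypothetical names. It also assumes that `train_step` already performs the optimizer update, which would make the separate `trainer.update(...)` calls in the original script redundant; this should be checked against the installed version of imagen-pytorch.

```python
import os


def train_unet(trainer, unet_number, num_steps, checkpoint_path,
               max_batch_size=32, valid_every=50, sample_every=100,
               sample_texts=None, sample_dir='sample'):
    """Train one unet with an ImagenTrainer-like object, then save a checkpoint.

    Hypothetical helper: `trainer` is assumed to expose train_step, valid_step,
    sample, save and is_main as used in the script above.
    """
    # Create output folders up front; a missing folder would otherwise make .save() fail.
    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)

    losses = []
    for i in range(num_steps):
        # Assumption: train_step performs the optimizer update itself.
        # If your version does not, call trainer.update(unet_number=unet_number) here.
        # No try/except: any error (e.g. the one-unet AssertionError) stops the run visibly.
        loss = trainer.train_step(unet_number=unet_number, max_batch_size=max_batch_size)
        losses.append(loss)
        print(f'unet {unet_number} step {i}: train loss {loss}')

        if i % valid_every == 0:
            valid_loss = trainer.valid_step(unet_number=unet_number, max_batch_size=max_batch_size)
            print(f'unet {unet_number} step {i}: valid loss {valid_loss}')

        if sample_texts and i % sample_every == 0 and trainer.is_main:
            # Sample through the cascade only up to the unet trained in this stage.
            images = trainer.sample(texts=sample_texts, batch_size=1,
                                    return_pil_images=True,
                                    stop_at_unet_number=unet_number)
            for j, image in enumerate(images):
                image.save(os.path.join(sample_dir, f'unet{unet_number}-{i // sample_every}-{j}.png'))

    # Checkpoint that the next stage (the next unet) resumes from.
    trainer.save(checkpoint_path)
    return losses
```

Stage 1 would call `train_unet(trainer, 1, 10000, 'checkpoints/unet1.pt', sample_texts=['xray of chest', 'mri of brain'])`. Stage 2 would run as a fresh process, as the message suggests: rebuild the trainer the same way, call `trainer.load('checkpoints/unet1.pt')`, then `train_unet(trainer, 2, 10000, 'checkpoints/unet2.pt', sample_texts=[...])`, so that sampling with `stop_at_unet_number=2` uses the already trained unet 1.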