lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

train Decoder error: it may be a version problem #243

Closed FTKyaoyuan closed 2 years ago

FTKyaoyuan commented 2 years ago

```python

from torchvision import transforms
from tqdm import tqdm

from dalle2_pytorch import Unet, Decoder, DecoderTrainer
from dalle2_pytorch.dalle2_pytorch import OpenClipAdapter
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader

epochs = 20
batch_size = 4
num_workers = 4

transform1 = transforms.Compose([
    transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
    ]
)

def preproc(img):
    img = transform1(img)
    return img

dataloader = create_image_embedding_dataloader(
    tar_url="/home/zhengshiguang/mscoco/{00000..00059}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    img_embeddings_url="/home/zhengshiguang/output/findaly",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=num_workers,
    batch_size=batch_size,
    index_width=4,                                         # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffle the order the shards are read in
    resample_shards=False,                                 # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
    img_preproc=preproc
)

clip = OpenClipAdapter("ViT-L/14")

unet1 = Unet(
    dim = 128,
    image_embed_dim = 768,
    text_embed_dim = 256,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings = True,
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 768,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
).cuda()
decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

for epoch in range(epochs):
    print("Training epoch %d" % epoch)
    u1_losses = 0
    u2_losses = 0
    for i, data in enumerate(tqdm(dataloader)):
        print("Training batch %d" % i)
        img, emb = data
        for unet_number in (1, 2):
            loss = decoder_trainer(
                img,
                text=emb["img"],
                unet_number=unet_number,  # which unet to train on
                max_batch_size=4          # gradient accumulation - the maximum sub-batch size per forward/backward pass; here 4 / 4 == 1 pass
            )
            print("Loss=%f" % loss)
            decoder_trainer.update(unet_number)  # update the specific unet as well as its exponential moving average
        if i % 100 == 0:
            print("Saving model")
            decoder_trainer.save("best_decoder.pth")
    print("Saving model")
    decoder_trainer.save("last_decoder.pth")
```
FTKyaoyuan commented 2 years ago

The above is my training code, but there is an error [screenshot]. So can I change the existing code [screenshot] to this [screenshot]?

FTKyaoyuan commented 2 years ago

[screenshot] But other errors occurred.

These seem to be related to the version of PyTorch.

Can you tell me which versions of Python and PyTorch dalle2_pytorch requires?

lucidrains commented 2 years ago

@FTKyaoyuan ohh oops, thanks for reporting the first error! https://github.com/lucidrains/DALLE2-pytorch/commit/d0c11b30b081a26dc22fb7cdcb2c6750316acc27

the second error, i believe you are passing in text as floats instead of indices (text token ids)
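For reference, a minimal sketch of what the trainer expects when `cond_on_text_encodings = True`: `text` should be a tensor of integer token ids, not float embeddings. The vocab size and sequence length below follow the README examples; `img` and `unet_number` are assumed to come from the training loop above.

```python
import torch

# text conditioning expects integer token ids of shape (batch, seq_len);
# random ids are used here purely to illustrate the expected dtype/shape
text = torch.randint(0, 49408, (4, 256)).cuda()

loss = decoder_trainer(
    img,                         # image batch from the dataloader
    text = text,                 # long tensor of token ids, NOT float embeddings
    unet_number = unet_number,
    max_batch_size = 4
)
```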

FTKyaoyuan commented 2 years ago

Thanks, I understand. But for the first error: I changed my clip from

```python
from dalle2_pytorch.dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter("ViT-L/14")
```

to

```python
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
clip = OpenAIClipAdapter("ViT-L/14")
```

and the first error no longer occurs. Maybe their image_size return types are different.
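A quick way to check that hypothesis (a hedged sketch; it assumes the adapter exposes an `image_size` attribute, which is what the Decoder compares its `image_sizes` against):

```python
from dalle2_pytorch import OpenAIClipAdapter

clip = OpenAIClipAdapter("ViT-L/14")
# print the type and value the Decoder will see when validating image sizes
print(type(clip.image_size), clip.image_size)
```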

FTKyaoyuan commented 2 years ago

I have solved the second problem.

By the way, what is the minimum GPU required for training with the config file, and how much GPU memory does it need?

FTKyaoyuan commented 2 years ago

And another thing: the example code for using the dataloaders is

```python
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
    shard_width=4,                                         # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffle the order the shards are read in
    resample_shards=False,                                 # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb.shape)  # torch.Size([32, 512])
```

but in my code emb is a dict, so the correct code is probably:

```python
print(img.shape)         # torch.Size([32, 3, 256, 256])
print(emb["img"].shape)  # torch.Size([32, 512])
```
lucidrains commented 2 years ago

@FTKyaoyuan thanks! https://github.com/lucidrains/DALLE2-pytorch/commit/b39653cf962380e1c2e8e4d447aaba608067ddc2