lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

prior model train #250

Open cccusername opened 1 year ago

cccusername commented 1 year ago

While attempting to train the diffusion prior (with train_diffusion_prior.py), I run into the following exception. My dalle2_pytorch version is 1.10.7, and the dataset is MS-COCO.

Traceback (most recent call last):
  File "train_diffusion_prior.py", line 770, in <module>
    main()
  File "/usr/local/lib/python3.8/dist-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.8/dist-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.8/dist-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "train_diffusion_prior.py", line 766, in main
    initialize_training(config_file, accelerator)
  File "train_diffusion_prior.py", line 749, in initialize_training
    train(
  File "train_diffusion_prior.py", line 475, in train
    loss = trainer(text=txt, image_embed=img)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenhw/DALLE2-pytorch-main/dalle2_pytorch/trainer.py", line 107, in inner
    out = fn(model, *args, **kwargs)
  File "/home/chenhw/DALLE2-pytorch-main/dalle2_pytorch/trainer.py", line 405, in forward
    loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenhw/DALLE2-pytorch-main/dalle2_pytorch/dalle2_pytorch.py", line 1464, in forward
    return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
  File "/home/chenhw/DALLE2-pytorch-main/dalle2_pytorch/dalle2_pytorch.py", line 1353, in p_losses
    pred = self.net(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenhw/DALLE2-pytorch-main/dalle2_pytorch/dalle2_pytorch.py", line 1091, in forward
    image_embed = torch.where(
RuntimeError: The size of tensor a (512) must match the size of tensor b (768) at non-singleton dimension 2
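
A quick way to narrow this down is to compare the last dimension of the precomputed embeddings against the embedding dimension the prior is configured for. This is only a minimal sketch: the embedding file path and the config key path are placeholders for whatever your own setup uses.

import json
import numpy as np

# Placeholder paths: point these at your clip-retrieval output and prior config.
emb = np.load("embeddings/img_emb/img_emb_0.npy")
with open("prior_config.json") as f:
    cfg = json.load(f)

print("embedding dim from clip-retrieval:", emb.shape[-1])
# Hypothetical key path; check your own config for where the embed dim is set.
print("embed dim the prior expects:", cfg["prior"]["net"]["dim"])

If the two numbers differ (e.g. 512 vs. 768), the dataloader and the prior were built for different CLIP models.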
Lixin-Liu commented 1 year ago

I have the same problem, have you solved it?

cccusername commented 1 year ago

I have the same problem, have you solved it?

Not yet. It seems to be a problem with the features produced by clip-retrieval.

Lixin-Liu commented 1 year ago

Specifying the model as ViT-L-14 when running clip-retrieval made the problem go away for me.
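
That lines up with the sizes in the traceback: OpenAI's CLIP ViT-B/32 produces 512-dimensional embeddings, while ViT-L/14 produces 768-dimensional ones, which is what a prior configured for ViT-L/14 expects. A quick way to confirm the dimensions, assuming the openai CLIP package is installed (note that clip.load downloads the model weights on first use):

import clip
import torch

for name in ("ViT-B/32", "ViT-L/14"):
    model, _ = clip.load(name, device="cpu")
    with torch.no_grad():
        emb = model.encode_text(clip.tokenize(["a test sentence"]))
    print(name, "embedding dim:", emb.shape[-1])  # 512 for ViT-B/32, 768 for ViT-L/14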

krrishdholakia commented 1 year ago

Curious - what is the dimension of the tensors 'a' and 'b'?
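
From the traceback, both numbers refer to the last (embedding) dimension at the torch.where call: tensor a (512) is most likely the image embedding coming from the dataloader (ViT-B/32-sized), and tensor b (768) the tensor the prior network builds for its configured embedding dimension (ViT-L/14-sized). A minimal reproduction of the same broadcasting failure, with illustrative shapes only:

import torch

keep_mask = torch.ones(4, 1, 1, dtype=torch.bool)
image_embed = torch.randn(4, 1, 512)        # embeddings computed with ViT-B/32
null_image_embed = torch.randn(1, 1, 768)   # placeholder sized for a 768-dim prior

# Raises: The size of tensor a (512) must match the size of tensor b (768)
# at non-singleton dimension 2
torch.where(keep_mask, image_embed, null_image_embed)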