lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License
11.08k stars 1.08k forks

text mask issue while training decoder #286

Open Sevixdd opened 1 year ago

Sevixdd commented 1 year ago

Hello, I get this error when I pass the text embeddings to train the decoder, and I'm quite stuck:

Traceback (most recent call last):
  File "/decoder_test.py", line 103, in <module>
    loss = trainer.forward(img, forward_params, unet_number=unet_number, _device=device)
  File "lib/python3.8/site-packages/dalle2_pytorch/trainer.py", line 107, in inner
    out = fn(model, *args, **kwargs)
  File "lib/python3.8/site-packages/dalle2_pytorch/trainer.py", line 723, in forward
    loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
  File "lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 3268, in forward
    losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
  File "lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 3049, in p_losses
    unet_output = unet(
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 2255, in forward
    text_mask = text_mask[:, :self.max_text_len]
IndexError: too many indices for tensor of dimension 1

Process finished with exit code 1
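The last frame fails on a slicing step. `text_mask[:, :self.max_text_len]` uses two indices, so it only works if `text_mask` is at least 2-D. The following is a minimal PyTorch sketch of that failure mode. It rests on an assumption that has not been checked against this version of the library: the U-Net builds the mask by reducing `text_encodings` over its last (feature) dimension. Under that assumption, passing one pooled embedding per caption, shape `(batch, dim)`, instead of per-token encodings, shape `(batch, seq_len, dim)`, yields a 1-D mask and reproduces the same IndexError. `make_text_mask` and `MAX_TEXT_LEN` are hypothetical names used only for illustration.

```python
import torch

# Hypothetical helper. ASSUMPTION: the mask marks each position whose feature
# vector is not all zeros, reducing over the last dimension, so the mask has
# one fewer dimension than the encodings it is computed from.
def make_text_mask(text_encodings: torch.Tensor) -> torch.Tensor:
    return torch.any(text_encodings != 0.0, dim=-1)

MAX_TEXT_LEN = 256  # hypothetical value, stands in for self.max_text_len
batch, seq_len, dim = 2, 77, 512

# One pooled embedding per caption: shape (batch, dim) -> mask shape (batch,).
pooled = torch.randn(batch, dim)
mask_1d = make_text_mask(pooled)
try:
    mask_1d[:, :MAX_TEXT_LEN]  # two indices on a 1-D tensor
except IndexError as err:
    print("pooled embedding:", err)

# Per-token encodings: shape (batch, seq_len, dim) -> mask shape (batch, seq_len).
per_token = torch.randn(batch, seq_len, dim)
mask_2d = make_text_mask(per_token)[:, :MAX_TEXT_LEN]
print("per-token encodings:", tuple(mask_2d.shape))

# Shape-level workaround only (an assumption, not a documented fix): treat the
# pooled embedding as a length-1 sequence so the mask stays 2-D.
as_sequence = pooled.unsqueeze(1)  # shape (batch, 1, dim)
mask_seq = make_text_mask(as_sequence)[:, :MAX_TEXT_LEN]  # shape (batch, 1)
print("pooled as sequence:", tuple(mask_seq.shape))
```

If this assumption holds, the error points to the shape of the tensor passed as `text_encodings`, not to the trainer. Whether adding a sequence dimension is appropriate, or whether the decoder should instead receive token-level text encodings, depends on how the decoder was configured.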