lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License
8.09k stars 768 forks

Incorrect use of trainer.train_step in text-to-video #350

Open Ytimed2020 opened 1 year ago

Ytimed2020 commented 1 year ago

Hello, I am training on my own text-video dataset. During training, I found that calling trainer.update runs, but calling trainer.train_step raises an error.
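A note on reading the traceback below: the final message says torch.any() received `(bool, dim=int)`, which suggests that `text_embeds != 0.` produced a plain Python bool rather than a tensor. That happens when `text_embeds` is `None` or some other non-tensor value, so one plausible cause (an assumption, not confirmed in this thread) is that the batch reaching the model carries no text embeddings. The sketch below is a hypothetical pre-flight check, not part of imagen-pytorch; the name `diagnose_text_embeds` is made up for illustration, and it is duck-typed so it runs without torch installed.

```python
def diagnose_text_embeds(text_embeds):
    """Hypothetical helper: return a diagnostic string, or None if the value
    behaves like a tensor when compared with 0."""
    if text_embeds is None:
        # `None != 0.` is the plain Python bool True, which matches the
        # `(bool, dim=int)` signature reported in the TypeError.
        return "text_embeds is None: the batch did not supply text embeddings"

    comparison = text_embeds != 0.
    if isinstance(comparison, bool):
        # Strings, lists, floats and similar compare to a single bool;
        # a tensor would compare elementwise and return a tensor instead.
        return (
            f"text_embeds has type {type(text_embeds).__name__}; "
            "comparing it with 0. yields a bool, not a mask tensor"
        )
    return None


if __name__ == "__main__":
    # A missing embedding and a raw caption string are both flagged.
    print(diagnose_text_embeds(None))
    print(diagnose_text_embeds("a raw caption instead of an embedding"))
```

If the check flags a batch, it may be worth inspecting what the dataset's `__getitem__` returns and whether the text embeddings are actually handed to the trainer under the name the model expects.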

Traceback (most recent call last):
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/dataloader.py", line 162, in <module>
    loss1 = trainer.train_step(unet_number = 1)
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/trainer.py", line 613, in train_step
    loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs)
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/trainer.py", line 631, in step_with_dl_iter
    loss = self.forward(**{**kwargs, **model_input})
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/trainer.py", line 136, in inner
    out = fn(model, *args, **kwargs)
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/trainer.py", line 984, in forward
    loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
  File "/mnt/shareEx/yangyudong/.conda/envs/df/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/elucidated_imagen.py", line 820, in forward
    text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/imagen_pytorch.py", line 69, in default
    return d() if callable(d) else d
  File "/mnt/shareEx/yangyudong/stable-diffusion-video/imagen-pytorch/imagen_pytorch/elucidated_imagen.py", line 820, in <lambda>
    text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
TypeError: any() received an invalid combination of arguments - got (bool, dim=int), but expected one of: