lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License
10.94k stars, 1.06k forks

Got an error "forward() got an unexpected keyword argument 'mask'" when running the rainbow_dalle example in a Jupyter notebook #167

Closed: onoduka closed this issue 2 years ago

onoduka commented 2 years ago

TypeError                                 Traceback (most recent call last)
Input In [28], in <cell line: 2>()
      1 dalle_model_file = "data/rainbow_dalle.model"
      2 if not os.path.exists(dalle_model_file):
----> 3     dalle, loss_history = fit(dalle, opt, None, scheduler,
      4                               (captions_array[train_idx, ...], all_image_codes[train_idx, ...], captions_mask[train_idx, ...]), None, 200, 256,
      5                               dalle_model_file, train_dalle_batch,
      6                               n_train_samples=len(train_idx))
      8     plt.plot(loss_history)
      9 else:

Input In [14], in fit(model, opt, criterion, scheduler, train_x, train_y, epochs, batch_size, model_file, trainer, n_train_samples)
     14 model.train()
     15 opt.zero_grad()
---> 16 loss = trainer(model, train_x, train_y, rnd_idx[batch_idx:(batch_idx + batch_size)], criterion)
     17 loss.backward()
     18 losses.append(loss.item())

Input In [26], in train_dalle_batch(vae, train_data, _, idx, _)
      1 def train_dalle_batch(vae, train_data, _, idx, _):
      2     text, image_codes, mask = train_data
----> 3     loss = dalle(text[idx, ...], image_codes[idx, ...], mask=mask[idx, ...], return_loss=True)
      4     return loss

File c:\users\xx\xx\dall\venv\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'mask'

torch 1.11.0
torch-fidelity 0.3.0
torchmetrics 0.9.1
torchvision 0.12.0

Windows 10; torch runs on CPU.
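A hedged, torch-free sketch of why this error surfaces: the last frame of the traceback shows `Module._call_impl` passing its arguments straight through with `forward_call(*input, **kwargs)`, so any keyword that the model's `forward()` signature does not declare raises an ordinary Python `TypeError`. `ToyModule`, `ToyDalle`, `accepts_kwarg` and `call_error` below are hypothetical names for illustration, and `ToyDalle`'s `forward` signature is an assumption, not the real `DALLE` class from either repository.

```python
import inspect


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: __call__ forwards
    *args/**kwargs unchanged to forward(), mirroring the
    `forward_call(*input, **kwargs)` frame in the traceback."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class ToyDalle(ToyModule):
    # Assumed signature for illustration only: no `mask` parameter.
    def forward(self, text, image=None, return_loss=False):
        return 0.0 if return_loss else (text, image)


def accepts_kwarg(fn, name):
    """Return True if `fn` can be called with the keyword argument `name`."""
    for param in inspect.signature(fn).parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return True  # a **kwargs parameter accepts any keyword
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


def call_error(fn, *args, **kwargs):
    """Call `fn` and return the TypeError message, or None if the call succeeds."""
    try:
        fn(*args, **kwargs)
    except TypeError as err:
        return str(err)
    return None


model = ToyDalle()
# Passing `mask` reproduces the "unexpected keyword argument 'mask'" TypeError.
print(call_error(model, "text", "image_codes", mask="mask", return_loss=True))
print(accepts_kwarg(model.forward, "mask"))  # False
```

With a real model, `inspect.signature(dalle.forward)` prints the parameters the installed version actually accepts, which is a quick way to check whether `mask` is supported before passing it.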

lucidrains commented 2 years ago

Could you reopen this issue at https://github.com/lucidrains/dalle-pytorch?