LAION-AI / dalle2-laion

Pretrained Dalle2 from laion

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x10 and 768x1280) #38

Open MikeWangWZHL opened 1 year ago

MikeWangWZHL commented 1 year ago

Thanks for the great effort. I was trying out the inference script with the provided example code:

from PIL import Image as PILImage  # needed for the PILImage.Image type hint

from dalle2_laion import ModelLoadConfig, DalleModelManager
from dalle2_laion.scripts import InferenceScript

class ExampleInference(InferenceScript):
    def run(self, text: str) -> PILImage.Image:
        """
        Takes a string and returns a single image.
        """
        text = [text]
        image_embedding_map = self._sample_prior(text)
        image_embedding = image_embedding_map[0][0]
        image_map = self._sample_decoder(text=text, image_embed=image_embedding)
        return image_map[0][0]

model_config = ModelLoadConfig.from_json_path("path/to/config.json")
model_manager = DalleModelManager(model_config)
inference = ExampleInference(model_manager)
image = inference.run("Hello World")

But I encountered RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x10 and 768x1280). Any idea why? Thanks!

Below is the full error traceback:

│ /shared/nas/data/m1/wangz3/phy-lm-vid/third_party/dalle2-laion/dalle2_laion/scripts/playaround.p │
│ y:64 in <module>                                                                                 │
│                                                                                                  │
│   61 model_config = ModelLoadConfig.from_json_path("/shared/nas/data/m1/wangz3/phy-lm-vid/thi    │
│   62 model_manager = DalleModelManager(model_config)                                             │
│   63 inference = ExampleInference(model_manager)                                                 │
│ ❱ 64 output_im = inference.run("Hello World")                                                    │
│   65 output_im.save(f"test_output_image.jpg")                                                    │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/phy-lm-vid/third_party/dalle2-laion/dalle2_laion/scripts/playaround.p │
│ y:58 in run                                                                                      │
│                                                                                                  │
│   55 │   │   text = [text]                                                                       │
│   56 │   │   image_embedding_map = self._sample_prior(text)                                      │
│   57 │   │   image_embedding = image_embedding_map[0][0]                                         │
│ ❱ 58 │   │   image_map = self._sample_decoder(text=text, image_embed=image_embedding)            │
│   59 │   │   return image_map[0][0]                                                              │
│   60                                                                                             │
│   61 model_config = ModelLoadConfig.from_json_path("/shared/nas/data/m1/wangz3/phy-lm-vid/thi    │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/phy-lm-vid/third_party/dalle2-laion/dalle2_laion/scripts/InferenceScr │
│ ipt.py:227 in _sample_decoder                                                                    │
│                                                                                                  │
│   224 │   │   │   │   │   args["inpaint_image"] = inpaint_image_tensors.to(self.device)          │
│   225 │   │   │   │   │   args["inpaint_mask"] = torch.stack(inpaint_image_masks).to(self.devi   │
│   226 │   │   │   │   │   self.print(f"image tensor shape: {args['inpaint_image'].shape}. mask   │
│ ❱ 227 │   │   │   │   output_images = decoder.sample(**args, cond_scale=cond_scale)              │
│   228 │   │   │   │   for output_image, input_embedding_number in zip(output_images, embedding   │
│   229 │   │   │   │   │   if input_embedding_number not in output_image_map:                     │
│   230 │   │   │   │   │   │   output_image_map[input_embedding_number] = []                      │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/autograd/ │
│ grad_mode.py:28 in decorate_context                                                              │
│                                                                                                  │
│    25 │   │   @functools.wraps(func)                                                             │
│    26 │   │   def decorate_context(*args, **kwargs):                                             │
│    27 │   │   │   with self.__class__():                                                         │
│ ❱  28 │   │   │   │   return func(*args, **kwargs)                                               │
│    29 │   │   return cast(F, decorate_context)                                                   │
│    30 │                                                                                          │
│    31 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:95 in inner                                                                    │
│                                                                                                  │
│     92 │   def inner(model, *args, **kwargs):                                                    │
│     93 │   │   was_training = model.training                                                     │
│     94 │   │   model.eval()                                                                      │
│ ❱   95 │   │   out = fn(model, *args, **kwargs)                                                  │
│     96 │   │   model.train(was_training)                                                         │
│     97 │   │   return out                                                                        │
│     98 │   return inner                                                                          │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:2809 in sample                                                                 │
│                                                                                                  │
│   2806 │   │   │   │                                                                             │
│   2807 │   │   │   │   # denoising loop for image                                                │
│   2808 │   │   │   │                                                                             │
│ ❱ 2809 │   │   │   │   img = self.p_sample_loop(                                                 │
│   2810 │   │   │   │   │   unet,                                                                 │
│   2811 │   │   │   │   │   shape,                                                                │
│   2812 │   │   │   │   │   image_embed = image_embed,                                            │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/autograd/ │
│ grad_mode.py:28 in decorate_context                                                              │
│                                                                                                  │
│    25 │   │   @functools.wraps(func)                                                             │
│    26 │   │   def decorate_context(*args, **kwargs):                                             │
│    27 │   │   │   with self.__class__():                                                         │
│ ❱  28 │   │   │   │   return func(*args, **kwargs)                                               │
│    29 │   │   return cast(F, decorate_context)                                                   │
│    30 │                                                                                          │
│    31 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:2661 in p_sample_loop                                                          │
│                                                                                                  │
│   2658 │   │   is_ddim = timesteps < num_timesteps                                               │
│   2659 │   │                                                                                     │
│   2660 │   │   if not is_ddim:                                                                   │
│ ❱ 2661 │   │   │   return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **k  │
│   2662 │   │                                                                                     │
│   2663 │   │   return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timeste  │
│   2664                                                                                           │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/autograd/ │
│ grad_mode.py:28 in decorate_context                                                              │
│                                                                                                  │
│    25 │   │   @functools.wraps(func)                                                             │
│    26 │   │   def decorate_context(*args, **kwargs):                                             │
│    27 │   │   │   with self.__class__():                                                         │
│ ❱  28 │   │   │   │   return func(*args, **kwargs)                                               │
│    29 │   │   return cast(F, decorate_context)                                                   │
│    30 │                                                                                          │
│    31 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:2533 in p_sample_loop_ddpm                                                     │
│                                                                                                  │
│   2530 │   │   │   │   │   noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = t  │
│   2531 │   │   │   │   │   img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)   │
│   2532 │   │   │   │                                                                             │
│ ❱ 2533 │   │   │   │   img = self.p_sample(                                                      │
│   2534 │   │   │   │   │   unet,                                                                 │
│   2535 │   │   │   │   │   img,                                                                  │
│   2536 │   │   │   │   │   times,                                                                │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/autograd/ │
│ grad_mode.py:28 in decorate_context                                                              │
│                                                                                                  │
│    25 │   │   @functools.wraps(func)                                                             │
│    26 │   │   def decorate_context(*args, **kwargs):                                             │
│    27 │   │   │   with self.__class__():                                                         │
│ ❱  28 │   │   │   │   return func(*args, **kwargs)                                               │
│    29 │   │   return cast(F, decorate_context)                                                   │
│    30 │                                                                                          │
│    31 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:2476 in p_sample                                                               │
│                                                                                                  │
│   2473 │   @torch.no_grad()                                                                      │
│   2474 │   def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None,   │
│   2475 │   │   b, *_, device = *x.shape, x.device                                                │
│ ❱ 2476 │   │   model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, ima  │
│   2477 │   │   noise = torch.randn_like(x)                                                       │
│   2478 │   │   # no noise when t == 0                                                            │
│   2479 │   │   nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))    │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:2442 in p_mean_variance                                                        │
│                                                                                                  │
│   2439 │   def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings =  │
│   2440 │   │   assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder  │
│   2441 │   │                                                                                     │
│ ❱ 2442 │   │   pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_em  │
│   2443 │   │                                                                                     │
│   2444 │   │   if learned_variance:                                                              │
│   2445 │   │   │   pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)                   │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:64 in default                                                                  │
│                                                                                                  │
│     61 def default(val, d):                                                                      │
│     62 │   if exists(val):                                                                       │
│     63 │   │   return val                                                                        │
│ ❱   64 │   return d() if callable(d) else d                                                      │
│     65                                                                                           │
│     66 def cast_tuple(val, length = None, validate = True):                                      │
│     67 │   if isinstance(val, list):                                                             │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:2442 in <lambda>                                                               │
│                                                                                                  │
│   2439 │   def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings =  │
│   2440 │   │   assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder  │
│   2441 │   │                                                                                     │
│ ❱ 2442 │   │   pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_em  │
│   2443 │   │                                                                                     │
│   2444 │   │   if learned_variance:                                                              │
│   2445 │   │   │   pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)                   │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:1854 in forward_with_cond_scale                                                │
│                                                                                                  │
│   1851 │   │   cond_scale = 1.,                                                                  │
│   1852 │   │   **kwargs                                                                          │
│   1853 │   ):                                                                                    │
│ ❱ 1854 │   │   logits = self.forward(*args, **kwargs)                                            │
│   1855 │   │                                                                                     │
│   1856 │   │   if cond_scale == 1:                                                               │
│   1857 │   │   │   return logits                                                                 │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/dalle2_pytorch/ │
│ dalle2_pytorch.py:1916 in forward                                                                │
│                                                                                                  │
│   1913 │   │   # discovered by @mhh0318 in the paper                                             │
│   1914 │   │                                                                                     │
│   1915 │   │   if exists(image_embed) and exists(self.to_image_hiddens):                         │
│ ❱ 1916 │   │   │   image_hiddens = self.to_image_hiddens(image_embed)                            │
│   1917 │   │   │   image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')               │
│   1918 │   │   │   null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)          │
│   1919                                                                                           │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/nn/module │
│ s/module.py:1102 in _call_impl                                                                   │
│                                                                                                  │
│   1099 │   │   # this function, and just call forward.                                           │
│   1100 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1101 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1102 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1103 │   │   # Do not call functions when jit is used                                          │
│   1104 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1105 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/nn/module │
│ s/container.py:141 in forward                                                                    │
│                                                                                                  │
│   138 │   # with Any as TorchScript expects a more precise type                                  │
│   139 │   def forward(self, input):                                                              │
│   140 │   │   for module in self:                                                                │
│ ❱ 141 │   │   │   input = module(input)                                                          │
│   142 │   │   return input                                                                       │
│   143                                                                                            │
│   144                                                                                            │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/nn/module │
│ s/module.py:1102 in _call_impl                                                                   │
│                                                                                                  │
│   1099 │   │   # this function, and just call forward.                                           │
│   1100 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1101 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1102 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1103 │   │   # Do not call functions when jit is used                                          │
│   1104 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1105 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/nn/module │
│ s/linear.py:103 in forward                                                                       │
│                                                                                                  │
│   100 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │
│   101 │                                                                                          │
│   102 │   def forward(self, input: Tensor) -> Tensor:                                            │
│ ❱ 103 │   │   return F.linear(input, self.weight, self.bias)                                     │
│   104 │                                                                                          │
│   105 │   def extra_repr(self) -> str:                                                           │
│   106 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          │
│                                                                                                  │
│ /shared/nas/data/m1/wangz3/miniconda/envs/phy-lm-vid/lib/python3.9/site-packages/torch/nn/functi │
│ onal.py:1848 in linear                                                                           │
│                                                                                                  │
│   1845 │   """                                                                                   │
│   1846 │   if has_torch_function_variadic(input, weight, bias):                                  │
│   1847 │   │   return handle_torch_function(linear, (input, weight, bias), input, weight, bias=  │
│ ❱ 1848 │   return torch._C._nn.linear(input, weight, bias)                                       │
│   1849                                                                                           │
│   1850                                                                                           │
│   1851 def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = No 
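
For reference, the traceback shows the failure inside to_image_hiddens, which (per the frames above) is a Sequential containing a Linear layer. Below is a minimal, self-contained sketch of the same error, assuming a standard nn.Linear with in_features=768 and out_features=1280 — an assumption based on the shapes in the message, not the library's exact layer:

import torch
import torch.nn.functional as F

# nn.Linear(768, 1280) stores its weight as (out_features, in_features) = (1280, 768);
# F.linear computes input @ weight.T, so the weight side of the matmul is (768, 1280).
weight = torch.randn(1280, 768)

ok = torch.randn(1, 768)   # a correctly batched 768-dim image embedding
F.linear(ok, weight)       # works: (1, 768) @ (768, 1280) -> (1, 1280)

bad = torch.randn(1, 10)   # wrong trailing dimension
F.linear(bad, weight)      # RuntimeError: mat1 and mat2 shapes cannot be
                           # multiplied (1x10 and 768x1280)

In other words, whatever tensor reaches that Linear layer has a trailing dimension of 10 instead of the expected embedding dimension of 768.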
clu5 commented 1 year ago

I'm getting the same issue on dalle2_laion==1.1.0 using the upsampler.example.json config.

piaopu0120 commented 1 year ago

In the InferenceScript run method, image_embedding needs a batch-size dimension:

class ExampleInference(InferenceScript):
    def run(self, text: str) -> PILImage.Image:
        """
        Takes a string and returns a single image.
        """
        text = [text]
        image_embedding_map = self._sample_prior(text)
        image_embedding = image_embedding_map[0][0]
        image_embedding = image_embedding.unsqueeze(0)  # <- add this: prepend the batch dimension
        image_map = self._sample_decoder(text=text, image_embed=image_embedding)
        return image_map[0][0]
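
A slightly more defensive variant of the same fix (a sketch, not part of the library; ensure_batched is a hypothetical helper) only adds the batch axis when the prior returns an unbatched embedding:

import torch

def ensure_batched(embedding: torch.Tensor) -> torch.Tensor:
    # The decoder expects (batch, dim), so prepend a batch axis only when
    # a single unbatched embedding of shape (dim,) comes back from the prior.
    return embedding.unsqueeze(0) if embedding.dim() == 1 else embedding

With that, run can call image_embedding = ensure_batched(image_embedding_map[0][0]) regardless of how the prior batches its output.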