awesome-davian / Text2Colors

Code for "Coloring with Words: Guiding Image Colorization through Text-based Palette Generation" - ECCV 2018
MIT License

Dimensionality issue in testing TPN #5

Open FloatingFowl opened 5 years ago

FloatingFowl commented 5 years ago
Torch - 0.4.1
scikit-image - 0.13.1
Python - 3.5.2

When running python3 main.py --mode test_TPN with the configuration above, I get the following error:

Traceback (most recent call last):
  File "main.py", line 79, in <module>
    main(args)
  File "main.py", line 31, in main
    solver.test_TPN()
  File "/home/mohsin/Text2Colors/solver.py", line 404, in test_TPN
    i)
  File "/home/mohsin/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mohsin/Text2Colors/model/TPN.py", line 86, in forward
    gru_input = torch.cat((last_palette, context.squeeze(1)), 1)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 11 and 32 in dimension 0 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:87
gitszu commented 5 years ago

Hi, I also got the same problem. Have you solved it? While debugging, I found that the dimensions of the two tensors do not match, so torch.cat() cannot concatenate them. Do you have any solution for this?
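
The rule behind this error can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not PyTorch code: cat_shape is a hypothetical helper that mimics the shape check torch.cat performs, where every dimension except the concatenation dimension must match and the sizes along that dimension are summed. The shapes in the usage lines are hypothetical, chosen only to reproduce the message from the traceback; they are not taken from the repository.

```python
def cat_shape(shapes, dim):
    """Return the shape torch.cat(tensors, dim) would produce for tensors of these shapes.

    Pure-Python illustration only: every dimension except `dim` must match,
    and the sizes along `dim` are summed.
    """
    first = shapes[0]
    for other in shapes[1:]:
        if len(other) != len(first):
            raise ValueError("tensors must have the same number of dimensions")
        for d, (a, b) in enumerate(zip(first, other)):
            if d != dim and a != b:
                # Mirrors the wording of the RuntimeError in the traceback.
                raise ValueError(
                    "Sizes of tensors must match except in dimension {}. "
                    "Got {} and {} in dimension {}".format(dim, a, b, d)
                )
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


# Hypothetical shapes: dimension 0 agrees, so concatenating along dimension 1 works.
print(cat_shape([(32, 3), (32, 150)], dim=1))  # (32, 153)

# Hypothetical shapes: dimension 0 differs (11 vs 32), so the concatenation is rejected.
try:
    cat_shape([(11, 3), (32, 150)], dim=1)
except ValueError as err:
    print(err)
```

In other words, last_palette and the squeezed context must agree in every dimension except dimension 1 before they can be joined.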

FloatingFowl commented 5 years ago

In model/TPN.py, change AttnDecoderRNN::forward() to this:

    def forward(self, last_palette, last_decoder_hidden, encoder_outputs, each_input_size, i):

        # Compute context vector.
        if i == 0:
            context = torch.mean(encoder_outputs, dim=0, keepdim=True)
            # Compute gru output.
            gru_input = torch.cat((last_palette, context.squeeze(0)), 1)
            gru_hidden = self.gru(gru_input, last_decoder_hidden)

            # Generate palette color.
            #palette = self.out(gru_hidden.squeeze(0))
            palette = self.out(gru_hidden.squeeze(1))
            return palette, context.unsqueeze(0), gru_hidden#, attn_weights

        else:
            attn_weights = self.attn(last_decoder_hidden.squeeze(0), encoder_outputs, each_input_size)
            context = torch.bmm(attn_weights, encoder_outputs.transpose(0,1))

            # Compute gru output.
            gru_input = torch.cat((last_palette, context.squeeze(1)), 1)
            gru_hidden = self.gru(gru_input, last_decoder_hidden)

            # Generate palette color.
            #palette = self.out(gru_hidden.squeeze(0))
            palette = self.out(gru_hidden.squeeze(1))
            return palette, context.unsqueeze(0), gru_hidden#, attn_weights

Then, wherever this method is called, unpack only three return values instead of four, since attn_weights is no longer returned.
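
To see why the patched forward() avoids the mismatch, the sketch below traces tensor shapes through both branches in plain Python rather than running the model. It assumes encoder_outputs is laid out as (seq_len, batch, hidden), last_palette as (batch, 3), and that self.attn returns weights of shape (batch, 1, seq_len); these layouts are inferred from how the snippet indexes the tensors, not confirmed from the repository, and all helper names and sizes are hypothetical.

```python
# Pure-Python shape helpers that mirror the PyTorch operations used in the fix.

def mean_keepdim(shape, dim):
    # torch.mean(x, dim=dim, keepdim=True): the reduced dimension becomes size 1.
    out = list(shape)
    out[dim] = 1
    return tuple(out)


def squeeze(shape, dim):
    # x.squeeze(dim) removes `dim` only if its size is 1; otherwise the shape is unchanged.
    return shape[:dim] + shape[dim + 1:] if shape[dim] == 1 else shape


def transpose(shape, d0, d1):
    # x.transpose(d0, d1) swaps two dimensions.
    out = list(shape)
    out[d0], out[d1] = out[d1], out[d0]
    return tuple(out)


def bmm(a, b):
    # torch.bmm: (B, n, m) x (B, m, p) -> (B, n, p).
    if len(a) != 3 or len(b) != 3 or a[0] != b[0] or a[2] != b[1]:
        raise ValueError("bmm shape mismatch: {} x {}".format(a, b))
    return (a[0], a[1], b[2])


def cat(shapes, dim):
    # torch.cat: all dimensions except `dim` must agree; sizes along `dim` add up.
    first = shapes[0]
    for s in shapes[1:]:
        if len(s) != len(first) or any(
            a != b for d, (a, b) in enumerate(zip(first, s)) if d != dim
        ):
            raise ValueError("cat shape mismatch: {} vs {}".format(first, s))
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


def gru_input_shape(i, seq_len, batch, hidden, palette_dim=3):
    """Trace the shape of gru_input in the patched forward() (assumed layouts)."""
    encoder_outputs = (seq_len, batch, hidden)  # assumption: (seq_len, batch, hidden)
    last_palette = (batch, palette_dim)         # assumption: one color per sample
    if i == 0:
        # context = torch.mean(encoder_outputs, dim=0, keepdim=True) -> (1, batch, hidden)
        context = mean_keepdim(encoder_outputs, 0)
        return cat([last_palette, squeeze(context, 0)], 1)
    # assumption: self.attn returns weights shaped (batch, 1, seq_len)
    attn_weights = (batch, 1, seq_len)
    # context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1)) -> (batch, 1, hidden)
    context = bmm(attn_weights, transpose(encoder_outputs, 0, 1))
    return cat([last_palette, squeeze(context, 1)], 1)


# Hypothetical sizes: 8 tokens, batch of 32, hidden size 150.
print(gru_input_shape(0, seq_len=8, batch=32, hidden=150))  # (32, 153)
print(gru_input_shape(1, seq_len=8, batch=32, hidden=150))  # (32, 153)
```

Under these assumptions, both branches produce a (batch, 3 + hidden) tensor, so the two inputs to torch.cat agree in dimension 0. Note that squeeze(d) silently does nothing when dimension d does not have size 1, which is why squeezing the wrong axis tends to surface later as a torch.cat error rather than at the squeeze itself.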

As for the colourization network, I still haven't got that to work.

minhmanho commented 5 years ago

Hi, thank you @FloatingFowl for solving the problem. However, the colors are not as good as expected: there is always something like a red filter covering the prediction. Is something wrong?

[Image: 14_color2]

FloatingFowl commented 5 years ago

What I described only applies to the text-to-palette network (TPN). The colourization network (PCN) had some issues I haven't been able to fix. If you manage to fix it, however, please let me know :)

Also, how many epochs did you train it for?

minhmanho commented 5 years ago

I made a mistake: I thought install_pre.sh downloads the dataset together with the related pretrained models. Could the authors publish their models? Would we need their exact models for comparison?

zhangbanxian123 commented 4 years ago

@minhmanho, I have reached the same point and also get the red mask. How did you solve it in the end? Thank you very much!