Open FloatingFowl opened 5 years ago
Hi, I also got the same problem. Have you solved it? I debugged the code and found that the dimensions do not match, so torch.cat() cannot concatenate the two tensors. Do you have any solution for this?
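For context, torch.cat requires every dimension except the concatenation axis to match across inputs; NumPy's concatenate enforces the same rule, which makes the failure easy to reproduce. A minimal sketch (the shapes here are illustrative assumptions, not the repo's actual tensor sizes):

```python
import numpy as np

a = np.zeros((1, 3))    # e.g. a palette tensor of shape (batch, 3)
b = np.zeros((2, 150))  # a context tensor with a mismatched leading dim

try:
    # analogous to torch.cat((a, b), 1): axis-0 sizes 1 and 2 disagree
    np.concatenate((a, b), axis=1)
except ValueError as e:
    print("shape mismatch:", e)

# Squeezing a spurious leading dimension first makes the shapes compatible:
c = np.zeros((1, 1, 150))
out = np.concatenate((a, c.squeeze(0)), axis=1)
print(out.shape)  # (1, 153)
```

This is the same class of error the traceback points at: one operand carries an extra (or wrong-sized) dimension, and squeezing it away before the cat is what the fix below does.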
In model/TPN.py, change AttnDecoderRNN::forward() to this:
def forward(self, last_palette, last_decoder_hidden, encoder_outputs, each_input_size, i):
    # Compute context vector.
    if i == 0:
        context = torch.mean(encoder_outputs, dim=0, keepdim=True)
        # Compute gru output.
        gru_input = torch.cat((last_palette, context.squeeze(0)), 1)
        gru_hidden = self.gru(gru_input, last_decoder_hidden)
        # Generate palette color.
        # palette = self.out(gru_hidden.squeeze(0))
        palette = self.out(gru_hidden.squeeze(1))
        return palette, context.unsqueeze(0), gru_hidden  # , attn_weights
    else:
        attn_weights = self.attn(last_decoder_hidden.squeeze(0), encoder_outputs, each_input_size)
        context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))
        # Compute gru output.
        gru_input = torch.cat((last_palette, context.squeeze(1)), 1)
        gru_hidden = self.gru(gru_input, last_decoder_hidden)
        # Generate palette color.
        # palette = self.out(gru_hidden.squeeze(0))
        palette = self.out(gru_hidden.squeeze(1))
        return palette, context.unsqueeze(0), gru_hidden  # , attn_weights
Since the function now returns three values instead of four (attn_weights is no longer returned), also remove the fourth receiving variable wherever this function's result is unpacked.
As for the colourization network, I still haven't got that to work.
Hi, thank you @FloatingFowl for solving the problem. However, the colors are not as good as expected: there always seems to be a red filter covering the prediction. Is anything wrong?
What I said applies to the text-to-palette network. The PCN network has some issues I haven't been able to fix. If you manage to fix it, however, please let me know :)
Also how many epochs did you run it for?
I made a mistake: I thought install_pre.sh downloads the dataset and the related models. Could the authors publish their models? Do we need their exact models for comparison?
I have also reached your stage: there is also a red mask over the output. How did you solve it in the end? Thank you very much!
When running
python3 main.py --mode test_TPN
, I get the following error with the configuration described above: