Jungjaewon / Reference_based_Skectch_Image_Colorization

This repository implements the paper "Reference based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence", which was published in CVPR 2020.

some questions #2

Closed kunkun0w0 closed 1 year ago

kunkun0w0 commented 3 years ago

Triplet loss: I think the positive v_k and the negative v_k might be the v_k of the reference image and of the transformed reference image, respectively.
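
For reference, here is a minimal sketch of a margin-based triplet loss under that reading: q is a query feature, v_pos comes from the reference and v_neg from the transformed reference. The names, the cosine similarity, and the margin value are placeholders, and the exact formulation in the paper may differ.

import torch
import torch.nn.functional as F

def triplet_loss(q, v_pos, v_neg, margin=1.0):
    # q, v_pos, v_neg: (batch, dim) features; v_pos / v_neg stand for the
    # positive / negative v_k discussed above (placeholder sampling)
    sim_pos = F.cosine_similarity(q, v_pos, dim=-1)
    sim_neg = F.cosine_similarity(q, v_neg, dim=-1)
    # push the positive similarity above the negative one by at least `margin`
    return F.relu(margin - sim_pos + sim_neg).mean()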

kunkun0w0 commented 3 years ago

model.py Encoder part

The output should be reshaped first and then permuted to the proper size.

Jungjaewon commented 3 years ago

For now I do not have compute devices available. I will write code for the triplet loss. The encoder code should be changed; it causes a warning about channel stride. Thanks for the issue and your interest!

kunkun0w0 commented 3 years ago

I am sorry to bother you again, but I am still confused about the encoder part. (My English is not good, sorry again.)

In model.py, lines 137-140:

#print('output.size : ', output.size()) # output.size :  torch.Size([2, 992, 16, 16])
b, ch, h, w = output.size()

output = output.reshape((b, h * w, ch)) # output.size :  torch.Size([2, 256, 992])

I notice that you reshape the tensor directly after the channel-wise concatenation. But in the paper, the features seem to be structured as a 2D tensor in which one dimension carries spatial information and the other carries channel information. Since reshape is just a re-numbering operation, I think that even though you reshape to the proper size (batch, 256, 992), the numbers in the tensor will not be arranged the way you want (batch, spatial, channel).

In my opinion, the output should be reshaped first and then permuted to the proper size, so that the numbers in the tensor are arranged as in the paper (batch, spatial, channel):

#print('output.size : ', output.size()) # output.size :  torch.Size([2, 992, 16, 16])
b, ch, h, w = output.size()
output = output.reshape((b, ch, h * w))  # output.size :  torch.Size([2, 992, 256])
output = output.permute(0, 2, 1)  # output.size :  torch.Size([2, 256, 992])

But when I train the model this way with the same config, it is much harder to converge than the original model, especially the recon_loss. This confuses me a lot. Would you like to discuss this phenomenon?

Lastly, since I am a beginner in deep learning, I have no idea how to visualize the attention mechanism as in the paper. Could you give me some ideas for the visualization?

Thank you so much!

Jungjaewon commented 3 years ago

First, thanks for your interest in my repository; it confuses me too. Did you change only that code? I found a link that deals with permute and reshape: https://discuss.pytorch.org/t/difference-between-2-reshaping-operations-reshape-vs-permute/30749/4 I do not understand the link perfectly. The unstable training is possibly related to permute. Tell me your thoughts.

For the visualization, I am not sure for now. In Section 4.5 of the paper, it is related to the attention map, a pixel similarity matrix at the feature level. I will look into it more deeply and share my ideas about it.
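
As a rough starting point (this is a sketch, not code from this repo), one way to visualize the attention for a single query position is to take one row of the hw×hw attention map, reshape it back to the 16×16 feature grid, upsample it, and save it as a heatmap. The variable names, sizes, and file name below are placeholders.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

hw, side = 256, 16
attn = torch.rand(hw, hw)                    # placeholder; use the real attention map here

query_idx = 8 * side + 8                     # pick one query position (row 8, col 8)
heat = attn[query_idx].reshape(side, side)   # similarity of this query to every key position

# upsample to image resolution so it can be overlaid on the reference image
heat = F.interpolate(heat[None, None], size=(256, 256),
                     mode='bilinear', align_corners=False)[0, 0]

plt.imshow(heat.numpy(), cmap='jet')
plt.colorbar()
plt.savefig('attention_query_8_8.png')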

kunkun0w0 commented 3 years ago

permute and reshape

In my opinion:

reshape: create a new empty tensor B with the target size, then move the numbers of the original tensor A into B one by one: the first number of A becomes the first number of B, the second number of A becomes the second number of B, and so on. In this way, the relative positions of the numbers may change.

permute: we rotate the tensor and change the order of its dimensions into what we want. In this way, the relative positions of the numbers cannot change.
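
A tiny runnable example (toy shapes, not from the repo) of the difference:

import torch

# (b, ch, h, w) = (1, 2, 2, 2); channel 0 holds 0..3, channel 1 holds 10..13
x = torch.tensor([[[[ 0.,  1.],
                    [ 2.,  3.]],
                   [[10., 11.],
                    [12., 13.]]]])
b, ch, h, w = x.shape

# reshape only: rows are NOT per-position feature vectors
print(x.reshape(b, h * w, ch))
# tensor([[[ 0.,  1.],
#          [ 2.,  3.],
#          [10., 11.],
#          [12., 13.]]])

# reshape then permute: row i is the ch-dim feature of spatial position i
print(x.reshape(b, ch, h * w).permute(0, 2, 1))
# tensor([[[ 0., 10.],
#          [ 1., 11.],
#          [ 2., 12.],
#          [ 3., 13.]]])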

converge problem

This week I did some tests and found interesting things that I want to share with you. (It might be a little long.) Let's first define three correspondence schemes:

1. Spatial (SCFT)

If we code in the following way:

#print('output.size : ', output.size()) # output.size :  torch.Size([2, 992, 16, 16])
b, ch, h, w = output.size()
output = output.reshape((b, ch, h * w))  # output.size :  torch.Size([2, 992, 256])
output = output.permute(0, 2, 1)  # output.size :  torch.Size([2, 256, 992])

that means we are using spatial information to do the corresponding feature transfer, which is what the paper calls SCFT.

2. Channel with a bit spatial

If we code in the following way:

#print('output.size : ', output.size()) # output.size :  torch.Size([2, 992, 16, 16])
b, ch, h, w = output.size()
output = output.reshape((b, h * w, ch)) # output.size :  torch.Size([2, 256, 992])

that means we are using channel information to do the corresponding feature transfer, but we group the 992 channels into 256 parts, and each part (which can also be understood as a new channel) does not strictly hold global information, because H×W = 256 and 992 is not evenly divisible by 256 (see the small check after these three definitions). I call this correspondence scheme 'Channel with a bit spatial'.

3. Channel

If we code in the following way:

'''encoder part'''
#print('output.size : ', output.size()) # output.size :  torch.Size([2, 992, 16, 16])
b, ch, h, w = output.size()
output = output.reshape((b, ch, h * w)) # output.size :  torch.Size([2, 992, 256])

'''SCFT part'''
# change the attention map size to 256*256

that means we are using channel information to do the corresponding feature transfer, and each channel holds global information. I call this correspondence scheme 'Channel'.
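
To make the "not evenly divisible" point in scheme 2 concrete (the small check mentioned above), the snippet below, which is not from the repo, labels every element with its channel index and looks at which channels land in each row after the (b, h*w, ch) reshape:

import torch

b, ch, h, w = 1, 992, 16, 16

# label every element with the index of the channel it came from
x = torch.arange(ch, dtype=torch.float32).view(1, ch, 1, 1).expand(b, ch, h, w)

rows = x.reshape(b, h * w, ch)   # the "Channel with a bit spatial" arrangement

# each of the 256 rows covers 992 consecutive elements of the (ch, h, w) layout,
# i.e. 992 / 256 = 3.875 channels, so a row mixes several channels and splits one of them
print(rows[0, 0].unique())   # tensor([0., 1., 2., 3.])
print(rows[0, 1].unique())   # tensor([3., 4., 5., 6., 7.])

So each "token" in this arrangement is neither a clean spatial position nor a whole channel.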

The results of training with the three schemes

All three are trained with the same config (without the triplet loss), except that 'Channel' uses batch_size 8 while the others use 16 (because I changed compute devices midway).

5 epochs, logging rec_loss every 100 iterations, 100 records in total (for 'Channel', I log rec_loss every 200 iterations).

Result: Figure_1 (plot attached to the issue)

Conclusion

You can easily see that if we use the 'Channel' corresponding feature transfer, the rec_loss converges easily. Of course, other effects need to be studied with more experiments, but I think it may be a new idea for reference-based sketch image colorization.

Lastly, if you are interested in this or have other interesting things you would like to share with me, you can email me at 975588189@qq.com.

Jungjaewon commented 3 years ago

First, thanks for your interest.

According to your experiments:

the original SCFT is not well suited to sketch colorization in terms of reconstruction loss;

channel correspondence is a good way to make the loss converge;

and if the number of concatenated channels from the encoder were 1024 (or some number divisible by 256), my code (a bit spatial) would work as SCFT, right?

I did not implement the triplet loss. You should add it and run the experiments again.

Also, this clue could be an idea for new research.

When you find more interesting results, please open new issues.