CQFIO / PhotographicImageSynthesis

Photographic Image Synthesis with Cascaded Refinement Networks
https://cqf.io/ImageSynthesis/

Rescaling of lambdas in composite loss? #1

Closed. Quasimondo closed this issue 7 years ago.

Quasimondo commented 7 years ago

In the paper it says: "The hyperparameters {λ_l} are set automatically. They are initialized to the inverse of the number of elements in each layer. After 100 epochs, {λ_l} are rescaled to normalize the expected contribution of each term ‖Φ_l(I) − Φ_l(g(L; θ))‖_1 to the loss." I am trying to find that part in the demo_256p.py code, but I fail to spot it. From what I can see, the weights for p0 to p5 are fixed values.

Can you point me to the part where this is applied?
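
For reference, my reading of the initialization part is just the reciprocal of each feature map's size, something like the sketch below (the layer shapes are my assumption for a 256x512 VGG-19 input; this is not code from the repo):

#sketch of my reading of the initialization (not code from the repo)
#number of elements in each VGG-19 feature map for a 256x512 RGB input
layer_elems = {
    'input':   256*512*3,
    'conv1_2': 256*512*64,
    'conv2_2': 128*256*128,
    'conv3_2': 64*128*256,
    'conv4_2': 32*64*512,
    'conv5_2': 16*32*512,
}
init_lambdas = {name: 1.0/n for name, n in layer_elems.items()}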

Also, I am wondering whether the difference between the weights in demo_256p.py:

#demo_256p.py
p0=compute_error(vgg_real['input'],vgg_fake['input'],label)
p1=compute_error(vgg_real['conv1_2'],vgg_fake['conv1_2'],label)
p2=compute_error(vgg_real['conv2_2'],vgg_fake['conv2_2'],tf.image.resize_area(label,(sp//2,sp)))
p3=compute_error(vgg_real['conv3_2'],vgg_fake['conv3_2'],tf.image.resize_area(label,(sp//4,sp//2)))
p4=compute_error(vgg_real['conv4_2'],vgg_fake['conv4_2'],tf.image.resize_area(label,(sp//8,sp//4)))
p5=compute_error(vgg_real['conv5_2'],vgg_fake['conv5_2'],tf.image.resize_area(label,(sp//16,sp//8)))*10

and the ones in demo_512p.py and demo_1024p.py is intentional?

#demo_512p.py & demo_1024p.py
p0=compute_error(vgg_real['input'],vgg_fake['input'],label)
p1=compute_error(vgg_real['conv1_2'],vgg_fake['conv1_2'],label)/2.6
p2=compute_error(vgg_real['conv2_2'],vgg_fake['conv2_2'],tf.image.resize_area(label,(sp//2,sp)))/4.8
p3=compute_error(vgg_real['conv3_2'],vgg_fake['conv3_2'],tf.image.resize_area(label,(sp//4,sp//2)))/3.7
p4=compute_error(vgg_real['conv4_2'],vgg_fake['conv4_2'],tf.image.resize_area(label,(sp//8,sp//4)))/5.6
p5=compute_error(vgg_real['conv5_2'],vgg_fake['conv5_2'],tf.image.resize_area(label,(sp//16,sp//8)))*10/1.5
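
For context, compute_error is not shown above; I assume it is essentially a label-weighted mean absolute (L1) difference between the real and synthesized VGG feature maps, roughly like this sketch (my assumption, not necessarily the exact definition in the repo):

#sketch of my assumption about compute_error (label-weighted L1 feature difference)
import tensorflow as tf

def compute_error(real, fake, label):
    # real, fake: [N, H, W, C] VGG activations; label: [N, H, W, K] semantic layout
    # reduce over channels first, then weight by the (resized) label map
    diff = tf.expand_dims(tf.reduce_mean(tf.abs(fake - real), axis=3), -1)  # [N, H, W, 1]
    return tf.reduce_mean(label * diff)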
CQFIO commented 7 years ago

I rebalanced the lambda weights after 100 epochs in demo_256p.py in a separate script. I calculated the weights by computing the average of the loss terms l0, l1, l2, l3, l4, l5 over the 100th epoch. I will integrate the code for this. By the way, I do not find that balancing the weights affects the visual quality much.
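
Roughly, the rebalancing looks like this (a sketch of the idea, not the exact separate script):

#sketch of the rebalancing step (idea only, not the exact separate script)
import numpy as np

def rebalance(per_layer_errors):
    # per_layer_errors: one [e0, ..., e5] entry per iteration of the 100th epoch,
    # holding the unweighted compute_error values for input, conv1_2, ..., conv5_2
    means = np.mean(np.asarray(per_layer_errors), axis=0)
    # expected contribution of each term relative to p0
    scales = means / means[0]
    # new weight for each term: divide p_i by scales[i] so all terms contribute equally
    return 1.0 / scales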

For demo_512p.py and demo_1024p.py, I only fine-tune for 20 or 5 epochs, so I just collect the statistics in the 1st epoch and put the weights there.