anantzoid / Conditional-PixelCNN-decoder

TensorFlow implementation of Gated Conditional Pixel Convolutional Neural Network

Error in Mask Implementation #9

Closed: loeweX closed this issue 6 years ago

loeweX commented 6 years ago

There is an error in the implementation of the masks on the horizontal and vertical stacks:

If I use the code to recreate the results from here, I get the following output:

[Image: output_mask]

As you can see, no information is propagated along the first horizontal line. This can be fixed by replacing the code in https://github.com/anantzoid/Conditional-PixelCNN-decoder/blob/055dab695f5754e0787a4a22899bc4918e3d2e1b/models.py#L33-L40 with

with tf.variable_scope("v_stack_1"+i): 
    v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, False, gated=False, mask=None).output()

with tf.variable_scope("h_stack"+i):
    h_stack = GatedCNN([filter_size if full_horizontal else 1, filter_size, conf.f_map], h_stack_in, True, payload=v_stack_1, mask=mask).output()

with tf.variable_scope("h_stack_1"+i):
    h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, True, gated=False, mask=None).output()

i.e., by not applying the mask to convolutions with 1x1 kernels. With this implementation I get the desired output:

[Image: output_mask_correct]
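The sketch below illustrates one plausible way a causal mask can break a 1x1 convolution. It assumes a standard PixelCNN-style mask (type 'a' hides the center weight, type 'b' keeps it); `causal_mask` and `masked_kernel` are hypothetical helpers written for illustration, not the repository's own mask code, which may be constructed differently. Under a type 'a' mask, the only weight of a 1x1 kernel is its center, so masking zeroes the whole kernel.

```python
import numpy as np


def causal_mask(kh, kw, mask_type):
    """Hypothetical PixelCNN-style mask for a (kh, kw) kernel.

    Zeroes every weight in the rows below the center row and every weight
    to the right of the center in the center row. Type 'a' also zeroes the
    center weight itself; type 'b' keeps it.
    """
    assert mask_type in ("a", "b")
    mask = np.ones((kh, kw), dtype=np.float32)
    ch, cw = kh // 2, kw // 2
    mask[ch, cw + 1:] = 0.0   # right of the center in the center row
    mask[ch + 1:, :] = 0.0    # all rows below the center row
    if mask_type == "a":
        mask[ch, cw] = 0.0    # the center weight itself
    return mask


def masked_kernel(weights, mask_type):
    """Mask a (kh, kw, in_ch, out_ch) kernel, skipping 1x1 kernels as proposed above."""
    kh, kw = weights.shape[:2]
    if (kh, kw) == (1, 1):
        # A 1x1 kernel only combines channels at a single position, so it is
        # left unmasked (a type 'a' mask would zero it completely).
        return weights
    return weights * causal_mask(kh, kw, mask_type)[:, :, None, None]


print(causal_mask(3, 3, "a"))
print(causal_mask(1, 1, "a"))  # → [[0.]]
```

Skipping the mask for 1x1 kernels is safe in this sketch because such a kernel never looks at neighbouring pixels; whether the inputs to those 1x1 convolutions are already causal depends on how the vertical and horizontal stacks are shifted in the surrounding model.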

anantzoid commented 6 years ago

Thanks for pointing this out @loeweX. Surprisingly, I still got some results using the former implementation. Did you see any improvement using your method?

loeweX commented 6 years ago

I was using a different dataset, so I can't say anything about the numerical results. However, I first noticed the mistake because errors kept propagating along horizontal lines in the images. This is really only visible with very simple images that contain long edges, so it might not play a big role when evaluating on datasets like CIFAR-10.

Also, it increases the runtime a lot, as it basically doubles the size of the network to be evaluated.

anantzoid commented 6 years ago

Thanks for clarifying. Can you send a pull request?