tscohen / GrouPy

Group Equivariant Convolutional Neural Networks
http://ta.co.nl
License: Other

The axes of the output of function gconv2d are inconsistent with the axes of its input #2

Open yangyu12 opened 7 years ago

yangyu12 commented 7 years ago

Hello @tscohen

I'm trying to use the TensorFlow API in your GrouPy library and ran into a problem. In GrouPy/groupy/gconv/tensorflow_gconv/splitgconv2d.py, I found that the axes of the returned tensor are (batch, out channels, height, width), whereas the input axes are (batch, height, width, in channels).

However, in your TensorFlow sample code, you feed the output of the previous conv layer directly into the next conv layer without reordering its axes. Does that make sense?
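Going from (batch, channels, height, width) to (batch, height, width, channels) needs an axis permutation. A plain reshape only changes the shape and leaves the values in the wrong positions. The minimal sketch below shows the difference. It uses NumPy in place of TensorFlow, and the array sizes are arbitrary assumptions.

```python
import numpy as np

# Hypothetical NCHW activations: batch=2, channels=3, height=4, width=5.
x_nchw = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Correct conversion to NHWC: permute the axes so channels move last.
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

# A plain reshape gives the NHWC shape but keeps the flat memory order,
# so values end up at the wrong (h, w, c) positions.
x_reshaped = x_nchw.reshape(2, 4, 5, 3)

# Element (n, c, h, w) in NCHW must equal element (n, h, w, c) in NHWC.
n, c, h, w = 1, 2, 3, 4
print(x_nhwc[n, h, w, c] == x_nchw[n, c, h, w])  # True
print(np.array_equal(x_nhwc, x_reshaped))        # False
```

In TensorFlow, the same permutation can be written as `tf.transpose(x, [0, 2, 3, 1])`.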

Thanks a lot!

tscohen commented 7 years ago

Thanks for catching this! As you can see, the TensorFlow version has not been as thoroughly battle-tested as the Chainer version (I only ran the unit tests in check_gconv2d.py).

Looking at the gconv2d function, there is a default parameter data_format='NHWC' and a check:

if data_format != 'NHWC':
    raise NotImplementedError('Currently only NHWC data_format is supported. Got: ' + str(data_format))

But I don't remember why we can't support NCHW. The filter transformation should not be affected by data_format, because the filter has the same shape for both NHWC and NCHW. tf.nn.conv2d should also support NCHW, though perhaps it didn't previously.
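You can check the claim that one filter serves both layouts with a naive correlation. The filter has shape (k_h, k_w, in_channels, out_channels), the layout tf.nn.conv2d expects for either data_format. This NumPy sketch makes several assumptions: stride 1, no padding, and hypothetical helper names. It is not GrouPy code. It applies the same filter in both layouts and compares the results after transposing.

```python
import numpy as np

def conv2d_nhwc(x, f):
    """Valid, stride-1 cross-correlation. x: (N, H, W, Cin), f: (kh, kw, Cin, Cout)."""
    kh, kw, _, cout = f.shape
    n, h, w, _ = x.shape
    out = np.zeros((n, h - kh + 1, w - kw + 1, cout))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, kh, kw, Cin)
            out[:, i, j, :] = np.einsum('nabc,abco->no', patch, f)
    return out

def conv2d_nchw(x, f):
    """Same operation for x: (N, Cin, H, W); the filter layout is unchanged."""
    kh, kw, _, cout = f.shape
    n, _, h, w = x.shape
    out = np.zeros((n, cout, h - kh + 1, w - kw + 1))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, Cin, kh, kw)
            out[:, :, i, j] = np.einsum('ncab,abco->no', patch, f)
    return out

rng = np.random.default_rng(0)
x_nhwc = rng.standard_normal((2, 6, 7, 3))
f = rng.standard_normal((3, 3, 3, 4))  # (kh, kw, Cin, Cout), shared by both layouts

y_nhwc = conv2d_nhwc(x_nhwc, f)
y_nchw = conv2d_nchw(np.transpose(x_nhwc, (0, 3, 1, 2)), f)

# Same filter, same result up to the axis permutation.
print(np.allclose(y_nhwc, np.transpose(y_nchw, (0, 2, 3, 1))))  # True
```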

Could you try removing the check and running with data_format='NCHW'? To make sure nothing silently broke, it is probably a good idea to test the equivariance of the layer using something like the code in check_gconv2d.py. You can also do this for your whole network.
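The code in check_gconv2d.py is not reproduced here, but the idea behind such a test can be sketched without TensorFlow. The `p4_lift` function below is hypothetical: a naive single-channel lifting correlation onto the rotation group p4. It correlates the input with the filter rotated by 0, 90, 180 and 270 degrees. If the layer is equivariant, rotating the input by 90 degrees should rotate every output plane by 90 degrees and cyclically shift the rotation axis by one.

```python
import numpy as np

def corr2d(x, f):
    """Valid, stride-1 2D cross-correlation of a single-channel image."""
    kh, kw = f.shape
    h, w = x.shape
    out = np.empty((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * f)
    return out

def p4_lift(x, f):
    """Hypothetical lifting layer: one output plane per 90-degree filter rotation."""
    return np.stack([corr2d(x, np.rot90(f, r)) for r in range(4)])

rng = np.random.default_rng(0)
x = rng.standard_normal((9, 9))  # square input so rot90 keeps the shape
f = rng.standard_normal((3, 3))

out = p4_lift(x, f)
out_rot_input = p4_lift(np.rot90(x), f)

# Expected output transformation: rotate each plane, shift the rotation axis by one.
expected = np.roll(np.rot90(out, 1, axes=(1, 2)), 1, axis=0)
print(np.allclose(out_rot_input, expected))  # True
```

The same pattern works for a whole network: transform the input, transform the output as the group structure predicts, and compare.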

yangyu12 commented 7 years ago

Thanks for your reply! I'm still a bit confused. I tried building a network using only this function. My guess is that if I use it directly for every layer, the actual data formats of the layer outputs are: NHWC (input) -> NCHW -> NHWC -> NCHW -> ... but I'm not sure that's right.
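One way to reason about such a chain is to track which physical axis each layer treats as channels. Suppose a layer reads its input as NHWC but receives a tensor whose axes are really (batch, channels, height, width). The hypothetical bookkeeping helper below is plain Python, with no TensorFlow involved. It shows that such a layer would take the width axis as its channel axis and the channel axis as its height.

```python
def interpret(actual_axes, assumed_format):
    """Hypothetical helper: map each axis letter a layer assumes to the axis the data actually has."""
    return {assumed: actual for assumed, actual in zip(assumed_format, actual_axes)}

layer1_output = ('batch', 'channels', 'height', 'width')  # NCHW, as returned by gconv2d
seen_by_layer2 = interpret(layer1_output, 'NHWC')
print(seen_by_layer2)
# {'N': 'batch', 'H': 'channels', 'W': 'height', 'C': 'width'}
```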

Do you mean the usage above actually works? Maybe I need to reread your paper to get more familiar with the underlying principles :)