tscohen / GrouPy

Group Equivariant Convolutional Neural Networks
http://ta.co.nl

What's the correct way to implement a coset max-pool in Tensorflow on the output of gconv2d? #17

Open sometimescasey opened 5 years ago

sometimescasey commented 5 years ago

Hi Dr. Cohen - thanks so much for providing the GrouPy and gconv_experiments repos.

I was wondering about the correct way to implement a coset max-pool on the output y in your Tensorflow example:

# Construct graph
x = tf.placeholder(tf.float32, [None, 9, 9, 3])

gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='Z2', h_output='D4', in_channels=3, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='D4', h_output='D4', in_channels=64, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

...

print y.shape  # (10, 9, 9, 512) 

My understanding is that the last dimension, 512, comes from the 64 output channels of the last layer multiplied by 8 (the number of output transformations: 4 rotations, each with and without a flip, so 8 total for D4). I assume I'd want to implement a coset maxpool on this output, so that the dimensions are (10, 9, 9, 64) before feeding it to the next layer. (Is this assumption correct?)

I'm not very familiar with Chainer, and I'm having a bit of trouble translating the Chainer code in gconv_experiments over to Tensorflow. I'd appreciate any guidance on recreating the following lines from your paper:

"Next, we replaced each convolution by a p4-convolution (eq. 10 and 11...and added max-pooling over rotations after the last convolution layer."

and

"We took the Z2CNN, replaced each convolution layer by a p4-convolution (eq. 10) followed by a coset max-pooling over rotations."

Is there a distinction between max-pooling over rotations, vs coset max-pooling over rotations?

My best guess would be to do something like the following - would this be correct?

y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info) # (10, 9, 9, 512)
y_reshaped = tf.reshape(y, [-1, 9, 9, 64, 8]) # break the flat 512 into 64 x 8
y_maxpooled = tf.reduce_max(y_reshaped, reduction_indices=[4]) # take max along the last dimension 

Thank you so much!

tscohen commented 5 years ago

Hi Casey,

Your understanding is correct. In the group convolution, we have to sum over: space, rotations/reflections, and channels. For implementation purposes, we fold the rotation/reflection axis into the channel axis, and let conv2d do one big sum over channels+rotations/reflections.
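
Schematically, the folding looks like this (a rough sketch, not the literal GrouPy code; w_expanded stands for a hypothetical filter bank that has already been expanded over the output transformations, with shapes following the D4 example above):

import tensorflow as tf

n_in, n_out = 64, 64  # G-channels in / out
nti, nto = 8, 8       # number of input / output transformations for D4

x = tf.placeholder(tf.float32, [None, 9, 9, n_in * nti])                  # folded input: channels x transformations
w_expanded = tf.placeholder(tf.float32, [3, 3, n_in * nti, n_out * nto])  # hypothetical pre-expanded filter bank
y = tf.nn.conv2d(x, w_expanded, strides=[1, 1, 1, 1], padding='SAME')     # one conv2d sums over channels + transformations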

Indeed, it wasn't entirely clear, but there is no difference between "max-pooling over rotations" and "coset max-pooling over rotations".

In my experience it's best not to pool over rotations until the last layer. This is what we report in the G-CNN paper for rotated MNIST, and it has also been true for other datasets we've worked with in other papers. What we do instead is reduce the number of channels, so that we either 1) keep the number of parameters the same as the baseline, which grows the number of 2d channels, or 2) reduce the number of parameters while keeping the number of 2d channels the same. The first option is typically the most effective, but also more computationally expensive. Usually we still see accuracy improvements with option 2, at no extra computational cost.

If you do decide to perform max pooling over rotations/reflections internally, note that the result of coset pooling wrt H=rotations+reflections should be treated as a planar (Z2) feature map. So use h_input='Z2' for the next gconv layer. In principle, you can also take a p4m feature map and pool only over cosets wrt the group of rotations only or reflections only, but we've never implemented that.
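
For example, a rough sketch of that pattern (shapes and variable names are illustrative, following the D4 example above): pool a D4 feature map over all 8 transformations, then feed the result to the next layer as a planar (Z2) input.

# y has shape [N, 9, 9, 64 * 8] after a D4 gconv with 64 G-channels
y_reshaped = tf.reshape(y, [-1, 9, 9, 64, 8])  # split the channel axis into (G-channels, transformations); see the caveat below
y_pooled = tf.reduce_max(y_reshaped, axis=4)   # coset max-pool over H = rotations + reflections

# the pooled map is an ordinary planar feature map, so the next gconv starts from Z2 again
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='Z2', h_output='D4', in_channels=64, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y_next = gconv2d(input=y_pooled, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                 gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)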

I'm almost certain that your maxpooling implementation is correct; my only slight uncertainty is whether the 512 axis should be unfolded as (64, 8) or (8, 64). I think you did it right though. To test whether it's correct, you can take an image and a 90-degree rotated version of the same image, and feed both through the network. The pooled feature map should behave like a standard 2D image, so you should get 90-degree rotated feature maps. For p4m feature maps, each group of 8 channels should undergo a rotation/flip, and the 8 channels get permuted.
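
Something along these lines would do it (a sketch; it assumes a session sess, the placeholder x and the pooled tensor y_maxpooled from your snippet, and a single image img of shape [1, 9, 9, 3]; the sign convention of np.rot90 may require trying both directions, as below):

import numpy as np

img_rot = np.rot90(img, k=1, axes=(1, 2))  # rotate the image by 90 degrees in the spatial plane
out = sess.run(y_maxpooled, feed_dict={x: img})
out_rot = sess.run(y_maxpooled, feed_dict={x: img_rot})

# if the coset pooling is correct, the pooled maps are 90-degree rotations of each other
match = (np.allclose(np.rot90(out, 1, axes=(1, 2)), out_rot, atol=1e-4) or
         np.allclose(np.rot90(out, -1, axes=(1, 2)), out_rot, atol=1e-4))
print(match)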

In any case, it's always good to do an end-to-end equivariance/invariance check of the network.

Good luck. Let me know if you have more questions.

sometimescasey commented 5 years ago

Wow! Thanks so much for the super detailed and quick reply here. I just have one followup:

What we do instead is reduce the number of channels, so that we either 1) keep the number of parameters the same as the baseline, which grows the number of 2d channels, or 2) reduce the number of parameters while keeping the number of 2d channels the same. The first option is typically the most effective, but also more computationally expensive. Usually we still see accuracy improvements with option 2, at no extra computational cost.

Sorry for the dumb question but I don't think I completely understand. What do you mean when you say "2d channel"?

In your paper, and in Z2CNN.py vs P4CNN.py, you dropped from using 20 channels in / 20 channels out (in all layers but l1 and top) to using 10 channels in / 10 channels out in P4CNN. This makes sense, since 20 x 20 "connections" (times however many individual weights are in each connection, depending on k_size) = 400, and 10 x 4 rotations x 10 = 400, which is why you divided the number of filters by sqrt(4).

I believe this corresponds to option 1), correct? In that case, what does "growing the number of 2d channels" mean?

Thanks again for the detailed help!

tscohen commented 5 years ago

Hi Casey,

In GCNNs, I distinguish two concepts: 1) A G-channel is mathematically modelled as a function on G. 2) A 2d channel / planar channel is just like a channel in a normal CNN.

If you have G=p4, then each G-channel consists of 4 planar channels.

Makes sense?

Best,
Taco

sometimescasey commented 5 years ago

Ah ok, I think I get it. So in (for example) l2 in P4CNN we have 10 G-channels consisting of 4 planar channels each (40 planar channels in total), which is more than the original 20 planar channels in the Z2CNN. But we still have the same number of total parameters between layers (20 x 20 = 400, or 10 x 4 x 10 = 400).

Thanks again!

sometimescasey commented 5 years ago

Hi again Dr. Cohen - I've been trying to reproduce your P4CNN.py in Tensorflow on MNIST_rot by defining the network as follows, but I'm only getting at most 85% test accuracy even after 100 epochs of training, whereas I can reproduce the ~2.28% test error with your Chainer experiments. I was wondering if anything obviously wrong jumps out at you in the network structure below:

def make_network():
    images = tf.placeholder(tf.float32, [None, 28, 28, 1])
    tf.identity(images, 'images')
    labels = tf.placeholder(tf.int64, [None])
    tf.identity(labels, 'labels')
    training = tf.Variable(True, name='training')

    # l1
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='Z2', h_output='C4', in_channels=1, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y1 = gconv2d(input=images, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y1, 'l1')

    # l2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y2 = gconv2d(input=y1, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y2, 'l2')

    pool2 = tf.layers.max_pooling2d(
    inputs=y2,
    pool_size=[2, 2],
    strides=2)
    tf.identity(pool2, 'pool2')

    # l3
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y3 = gconv2d(input=pool2, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y3, 'l3')

    # l4
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y4 = gconv2d(input=y3, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y4, 'l4')

    # l5
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y5 = gconv2d(input=y4, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y5, 'l5')

    # l6
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y6 = gconv2d(input=y5, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y6, 'l6')

    # top
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=4)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y_top = gconv2d(input=y6, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    tf.identity(y_top, 'features')

    # coset max-pool
    y_reshape = tf.reshape(y_top, [-1, 14, 14, 10, 4])
    y_pool = tf.reduce_max(y_reshape, reduction_indices=[4]) 
    # neither unfolding with (10,4) nor (4,10) seems to give an identical output here when I feed a single image, and a 90 degree rotated version of it

    pool2_flat = tf.reshape(y_pool, [-1, 14 * 14 * 10])
    logits = tf.layers.dense(inputs=pool2_flat, units=10)
    tf.identity(logits, 'logits')

    return images, labels, logits, training

I suspect the pooling after l2 may not be correct here, as you've got something called plane_group_spatial_max_pooling() in the Chainer version. But I assumed I wasn't supposed to do the coset maxpool in the middle of the net as you stated above, so does "plane_group_spatial_max_pooling" refer to something else?

In addition, I tried your suggestion of feeding a single image and then a 90-degree rotated image to test whether the unfold should be tf.reshape(y_top, [-1, 14, 14, 10, 4]) or tf.reshape(y_top, [-1, 14, 14, 4, 10]) in the example above. (10, 4) seems to give higher accuracy, but neither results in an identical output from y_pool between the two images. So I assume that means I'm doing something wrong...?

Thanks again for all your help...I'll keep inspecting your Chainer code to see if I can figure out my error but I figure it wouldn't hurt to ask.

tscohen commented 5 years ago

Unless I'm missing something, your network lacks nonlinearities aside from the pooling. Add some relus after each layer.
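
For instance, based on the first layer of your snippet (a sketch; biases omitted):

y1 = gconv2d(input=images, filter=w, strides=[1, 1, 1, 1], padding='SAME',
             gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
y1 = tf.nn.relu(y1)  # pointwise nonlinearity: acts per planar channel, so equivariance is preserved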

After that, check that the weight scale is correct. Have a look at "He initialization" or "Xavier initialization", or just play around with a global weight scale factor until it works. Also make sure to tune the learning rate.
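
For example, one way to get a He-style scale for the gconv weights (a sketch; it treats everything but the last dimension of w_shape as fan-in, which may need adjusting):

import numpy as np

fan_in = np.prod(w_shape[:-1])  # ksize * ksize * number of input planar channels
w = tf.Variable(tf.truncated_normal(w_shape, stddev=np.sqrt(2.0 / fan_in)))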

Another good way to debug is to replace the gconv2d with conv2d, using exactly the same w, and see if it works. If the weight scale is off, then the conv2d net won't work either.

After that, I recommend adding residual blocks.

The pooling we used is indeed just 2x2 spatial max pooling. In our old codebase this required a special function, because we kept the rotation axis and channel axis separate. So we had to fold them together, pool, and then unfold again: https://github.com/tscohen/GrouPy/blob/master/groupy/gconv/chainer_gconv/pooling/plane_group_spatial_max_pooling.py It's easier to just keep the axes folded.
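
In other words, something like the pooling you already have is fine (a sketch reusing y2 from your snippet): spatial max pooling acts on each planar channel independently, so the rotation axis can stay folded into the channel axis.

pool2 = tf.nn.max_pool(y2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')  # shape [N, 14, 14, 40]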