tscohen / GrouPy

Group Equivariant Convolutional Neural Networks
http://ta.co.nl

Transformed filters count towards learnable parameters #21

Open frederikfaye opened 4 years ago

frederikfaye commented 4 years ago

I'm puzzled by the number of trainable parameters in networks using gconv2d.

The script below creates a network using gconv2ds from Z2 to C4 to C4 and counts the number of learnable parameters in the network.

For the groups C4 and D4, the result is S times larger than what I expected, where S is the number of non-translation transformations, i.e. roto-flips (so S=4 for C4 and S=8 for D4).

Specifically, I'd expect a gconv2d layer to have the same number of learnable parameters as a normal 2D conv layer (namely n_feat_maps_in*n_feat_maps_out*kernel_size**2, when both have no biases, as is the case for this repository).

So, for C4, the number of parameters I would expect to be learnable in the example below would be 135 + 315 (135 = 3*5*3**2 for the first layer and 315 = 5*7*3**2 for the second), when it turns out to instead be 135 + 315*4. Similarly for D4, we get 135 + 315*8.

I understand how the total number of parameters should be 135 + 315*4 for C4 and 135 + 315*8 for D4, since the filters are practically speaking different (in that they have been roto-flipped). However, I don't think that they should all be individually learnable (since the roto-flip transformations are not learnable), and I'm worried that there may be a problem in the implementation. It could also very well be that I have misunderstood something fundamental, but isn't the whole point of gconvs related to a group G that they are equivariant to the transformations in G without an increase in the number of trainable parameters?

Finally, the test for equivariance at the end of the script below also fails. Is this related, or am I testing the wrong thing?

For the record, I'm finding the same when using the keras_gcnn Keras implementation, i.e., I get the same (and higher than expected) number of trainable parameters when using the model.summary() method of Keras.

Thank you for your time, and for this awesome work!

import numpy as np
import tensorflow as tf
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util

# Model parameters
kernel_size = 3

n_feat_maps_0 = 3
n_feat_maps_1 = 5
n_feat_maps_2 = 7

group_0 = 'Z2'
group_1 = 'C4'
group_2 = 'C4' # Not currently implemented for C4 --> D4

# Construct graph
x = tf.placeholder(tf.float32, [None, 9, 9, n_feat_maps_0])

# Z2 --> C4 convolution
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input=group_0, h_output=group_1, in_channels=n_feat_maps_0, out_channels=n_feat_maps_1, ksize=kernel_size)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

# C4 --> C4 convolution
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input=group_1, h_output=group_2, in_channels=n_feat_maps_1, out_channels=n_feat_maps_2, ksize=kernel_size)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

# Compute
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
output = sess.run(y, feed_dict={x: np.random.randn(10, 9, 9, 3)})

print(output.shape)  # (10, 9, 9, 28)

# Count the number of trainable parameters
print(np.sum([np.prod(v.shape) for v in tf.trainable_variables()])) # 1395 (135 + 315*4)

# Test equivariance by comparing outputs for rotated versions of same datapoint
datapoint = np.random.randn(9, 9, 3)
input = np.stack([datapoint, np.rot90(datapoint)])
output = sess.run(y, feed_dict={x: input})
print(np.allclose(output[0], np.rot90(output[1]))) # False

sess.close()
tscohen commented 4 years ago

Hi Frederik,

In G-CNNs we distinguish two kinds of channels: what we call G-channels (mathematically a function on G; in practice a collection of S channels), and orientation channels (what one calls a channel in an ordinary CNN: a 2D feature map). The number of channels in your code corresponds to G-channels, hence if you set the number of G-channels equal to the number of channels in your baseline CNN, the G-CNN will have far more orientation channels and also more parameters.

There are two common modes of converting a CNN to a G-CNN: equal parameters, or equal compute / orientation channels. In the first case (equal parameters), you would take the number of channels in your CNN, say N, and divide it by sqrt(S) to get the number of G-channels for your G-CNN. (Here S=4 for a p4-CNN and S=8 for a p4m-CNN.) The number of orientation channels will be N*sqrt(S), because each of the N/sqrt(S) filters gets transformed S times. The number of parameters is the product of N/sqrt(S) [the number of output G-channels] times N*sqrt(S) [the number of input orientation channels] times W*H. So the number of parameters is N^2 * W * H, the same as the CNN.

For equal compute, use N/S as the number of G-channels. This will lead to a reduction of the number of parameters, which may work better or worse depending on how big the network you started with was. For equal parameters we almost always see a substantial improvement. For more details see our papers on G-CNNs.
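
For concreteness, a quick back-of-the-envelope sketch (plain Python, not GrouPy code) of the two conversion modes for an intermediate G -> G layer; the values of N, S, W, H here are illustrative assumptions:

import math

N = 32       # channels in the baseline CNN layer (illustrative)
S = 4        # number of non-translation transformations: 4 for p4, 8 for p4m
W = H = 3    # spatial filter size

# Baseline CNN layer (N channels in, N channels out, no bias)
cnn_params = N * N * W * H                 # 9216

# Equal parameters: N / sqrt(S) G-channels
n_g = int(N / math.sqrt(S))                # 16
gcnn_params = n_g * (n_g * S) * W * H      # out G-channels * in orientation channels * W * H
print(cnn_params, gcnn_params)             # 9216 9216

# Equal compute: N / S G-channels (fewer parameters)
n_gc = N // S                              # 8
print(n_gc * (n_gc * S) * W * H)           # 2304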

frederikfaye commented 4 years ago

Hi Taco, thanks for your answer!

I feel like I didn't explain my point well enough the first time, so I will try to be explicit about the math: For simplicity, let's say we have n_feat_maps_0 = n_feat_maps_1 = n_feat_maps_2 = 1 in my example network (still using Z2 to C4 to C4)*. Hence, we have the situation shown in the following illustration:

[Illustration: feature maps of the Z2 --> C4 --> C4 network; the pixel corresponding to g' = t's' = e r^3 in the last G-channel is highlighted]

Let's look at the last G-channel (or feature map): The four orientation (or planar) channels together form the single G-channel. Each pixel in the G-channel corresponds to a group element (here an element of C4, or actually p4 in the paper's notation). Let's see how the pixel corresponding to the specific element g' = t's' = e r^3 (highlighted in the above illustration, where e is the identity translation) is calculated:

Using Eq. (18) in your paper, with X = p4, and having K = 1, we get

[f * psi](t's') = sum_{ts in p4} f(ts) psi((t's')^{-1} ts)

The question now becomes: how many trainable parameters are involved in this computation? The filter psi is non-zero only at 3*3 = 9 places, and therefore has 9 parameters, which are all trainable. In my understanding, the same psi is used for each of the four orientation channels:

[f * psi](t's') = sum_{ts in p4} f(ts) [L_s' psi]((t')^{-1} ts),   for s' in {e, r, r^2, r^3}

Since the same psi is used for all of them, these four cases have a total of only 9 trainable parameters: for each, psi is simply rotated by a multiple of 90 degrees (and evaluated on different input orientation channels), but since the parameters of the rotation transformations are not trainable, the only trainable parameters remain the 9 in psi.

It seems like the current implementation counts these four rotated versions of psi as containing 9*4 independent trainable parameters (presumably 9 per orientation), and I think this is wrong, hence this question.
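
For concreteness, here is a minimal numpy sketch (my own illustration, not GrouPy code) of the weight sharing I have in mind for the Z2 --> C4 case: four rotated copies of one 3x3 filter, so 9 underlying parameters regardless of how many copies are materialized:

import numpy as np

psi = np.random.randn(3, 3)  # the single learnable 3x3 filter: 9 parameters

# Materialize the four rotated copies used for the four orientation channels.
# np.rot90(psi, k) rotates by k*90 degrees and introduces no new parameters.
filter_bank = np.stack([np.rot90(psi, k) for k in range(4)])
print(filter_bank.shape)  # (4, 3, 3) -- 36 scalars, but only 9 free parameters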

Finally, could you comment on the failed equivariance test mentioned in the original post? Thanks!


*To get the above code to run with n_feat_maps_0 = n_feat_maps_1 = n_feat_maps_2 = 1, change the following two lines as follows:

# output = sess.run(y, feed_dict={x: np.random.randn(10, 9, 9, 3)})
# to
output = sess.run(y, feed_dict={x: np.random.randn(10, 9, 9, n_feat_maps_0)})
# and
# datapoint = np.random.randn(9, 9, 3)
# to
datapoint = np.random.randn(9, 9, n_feat_maps_0)
tscohen commented 4 years ago

Your understanding is almost correct. The only thing is that a filter psi in the second layer is a function on p4 (the rotation+translation group), not a function on Z2 (translation only). So psi itself has four input orientation channels. So psi has 9*9*4 parameters (times however many input G-channels). This is then expanded to yield a filter bank with 4*9*9*4 scalars.

To see this, consider the second equation you posted. Psi is indexed by t and s, i.e. translation and rotation. That's how you can see it is a function on p4 not z2. The filter L_s' psi(ts) is obtained from psi by: rotating each of the 4 orientation channels of psi by s' and additionally cyclically permuting them. It's the same kind of transformation law as for the feature maps (shown in one of the figures in our paper).
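
To illustrate, here is a minimal numpy sketch of this transformation law (an illustration, not the actual GrouPy implementation; the direction of the cyclic shift depends on conventions):

import numpy as np

psi = np.random.randn(4, 3, 3)  # a p4 filter: 4 orientation channels of 3x3

def transform_p4_filter(psi, k):
    # Apply L_{r^k}: rotate each orientation channel spatially by k*90 degrees,
    # then cyclically permute the orientation axis by k.
    rotated = np.stack([np.rot90(ch, k) for ch in psi])
    return np.roll(rotated, k, axis=0)

# Expand one p4 filter into the bank of 4 transformed copies:
bank = np.stack([transform_p4_filter(psi, k) for k in range(4)])
print(bank.shape)  # (4, 4, 3, 3)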

This funny kind of transformation law follows directly from the definition of L_s' (acting on p4 functions) and the group operation. We have L_s' psi(ts) = psi( (s')^{-1} ts). The element (s')^{-1} ts is a roto-translation that first rotates by (s'^{-1} s) and then translates by (s'^{-1} t). You can derive this by e.g. writing the group elements as 3x3 homogeneous matrices [s', 0; 0, 1] * [s, t; 0, 1] and multiplying those out to get [s' s, s' t; 0, 1].
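
Numerically, a quick sketch checking the matrix identity with a 90-degree rotation and an arbitrarily chosen integer translation:

import numpy as np

def hom(k, t):
    # Homogeneous 3x3 matrix [s, t; 0, 1] for a rotation by k*90 degrees
    # followed by a translation t.
    c, s = {0: (1, 0), 1: (0, 1), 2: (-1, 0), 3: (0, -1)}[k % 4]
    return np.array([[c, -s, t[0]],
                     [s,  c, t[1]],
                     [0,  0, 1]])

g_sprime = hom(1, (0, 0))  # s': pure rotation
g_ts = hom(1, (2, 3))      # ts: rotation s, then translation t

lhs = g_sprime @ g_ts
rhs = hom(2, g_sprime[:2, :2] @ np.array([2, 3]))  # [s's, s't; 0, 1]
print(np.allclose(lhs, rhs))  # True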

Regarding the equivariance test, I don't immediately see a mistake / something suspicious. I would trace the execution and look at all the intermediate results - see if they make sense. Maybe it's something trivial like rot90 not doing what you expect or something like that.

frederikfaye commented 4 years ago

Ah, that makes much more sense, thank you for the detailed explanation! (EDIT: I assume you meant "9*4 parameters" instead of "9*9*4 parameters", since the example was using 3x3 filters.)

Regarding the equivariance, I have discovered that SAME padding destroys equivariance when the padding is not symmetric (e.g. in the case of a 2x2 input). There also seem to be some weird rounding errors, but for now, let's focus on the padding problem.

The following script tests for equivariance in two very similar examples: one with VALID padding and a 3x3 input, which is equivariant, and one with SAME padding and a 2x2 input, which is clearly not equivariant (there is no way of transforming the second output to yield the first). Change the padding = line to switch between the two cases.

The reason this happens, I believe, is the asymmetric SAME padding in the 2x2 case (see here for a visual explanation), which destroys the equivariance before any actual convolution takes place.

Should I open another issue with this?

import numpy as np
import tensorflow as tf
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util

np.random.seed(0)
tf.set_random_seed(0)

padding = 'VALID' # 'VALID', 'SAME'
if padding=='SAME':
    input_h, input_w = 2, 2
elif padding=='VALID':
    input_h, input_w = 3, 3

# Model parameters
kernel_size = 2

n_feat_maps_0 = 1
n_feat_maps_1 = 1

group_0 = 'Z2'
group_1 = 'C4'

# Construct graph
x = tf.placeholder(tf.float32, [None, input_h, input_w, n_feat_maps_0])

# Z2 --> C4 convolution
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input=group_0, h_output=group_1, in_channels=n_feat_maps_0, out_channels=n_feat_maps_1, ksize=kernel_size)
w_init = np.array([[2,3],[5,7]]).astype(np.float32)
print('Single 2x2 filter is:')
print(w_init)
w_init = np.expand_dims(np.expand_dims(w_init, -1), -1) # Expand to be compatible
w = tf.Variable(w_init)
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding=padding,
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

# Initialize
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Test equivariance by comparing outputs for rotated versions of same datapoint:
if padding=='SAME':
    datapoint = np.expand_dims(np.array([[1,0], [0,0]]), -1)
elif padding=='VALID':
    datapoint = np.expand_dims(np.array([[0,1,0], [0,0,0], [0,0,0]]), -1)
inputs = np.stack([datapoint, np.rot90(datapoint)])
outputs = sess.run(y, feed_dict={x: inputs})
sess.close()

# Print inputs
input = inputs[0]
input_rot = inputs[1]
print('Network has inputs:')
print(np.squeeze(input))
print('and')
print(np.squeeze(input_rot))
print('')

def arr_pprint(arr):
    for i in range(arr.shape[-1]):
        print(arr[...,i])

# Print outputs
output = outputs[0]
output_rot = outputs[1]
print(f'and outputs (with total shape {outputs.shape}):')
arr_pprint(output)
print('and')
arr_pprint(output_rot)

# Create transformation that should take the first output to the second
# First undo the (here cyclic) permutation
output_rot_unroll = np.roll(output_rot, -1, axis=-1)
# arr_pprint(output_rot_unroll)

# Now undo the rotation
output_rot_unroll_unrot = np.rot90(output_rot_unroll, axes=(1,0))
print('Try to transform (unroll then rotate) second output to first:')
arr_pprint(output_rot_unroll_unrot)

# Compare with output
is_equivariant = np.all(output==output_rot_unroll_unrot)
if is_equivariant:
    print('Network is equivariant!')
else:
    print('Network is NOT equivariant!')
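
For what it's worth, a possible workaround I have in mind (untested, and assuming the asymmetric padding is the only culprit) is to pad symmetrically by hand and use VALID padding, so the framework never inserts the extra row/column on one side only. E.g., replacing the gconv2d call in the script above with:

# Hypothetical workaround: symmetric manual padding + VALID convolution.
# For kernel_size=2 this yields a slightly larger (3x3) output, but the
# padding is now rotation-symmetric.
pad = kernel_size // 2
x_sym = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
y = gconv2d(input=x_sym, filter=w, strides=[1, 1, 1, 1], padding='VALID',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)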