jonkhler / s2cnn

Spherical CNNs
MIT License
939 stars 176 forks

Usage documentation? #20

Closed meder411 closed 5 years ago

meder411 commented 5 years ago

Sorry to spam you all with multiple issues, but is there any usage documentation associated with your s2cnn package? In particular, I'm curious when to use the different grid types and convolution types. Things like so3_equatorial_grid() vs. so3_soft_grid(), etc.

I also notice you explicitly call so3_integrate() in the MNIST example. I am wondering why there is a need for explicit integration, and what the operation is doing (I couldn't find that in the papers).

mariogeiger commented 5 years ago

so3_soft_grid defines the SOFT grid that we use to represent our signals. You can use it to compute the Fourier transform of a signal, but for the special case of the SOFT grid we have an optimized implementation that uses the FFT. The following two snippets do the same thing (but the second one is much faster):

Slow

# generic path: evaluate the Fourier transform by explicit summation
# over the given grid points
grid = so3_soft_grid(b_in)
so3_rft(x, b_out, grid)

Fast

# optimized path: exploit the regular structure of the SOFT grid via an FFT
so3_rfft(x, b_out=b_out)
tscohen commented 5 years ago

Regarding the different grids used for the kernel: you are free to use either one, or define your own grid. Spherical CNNs are very new, so we don't yet know what kind of grid / kernel support is appropriate for each kind of task. There is probably a lot that can be improved in terms of architecture details (for 2D CNNs, it took years to find good architectures, and they're still improving).

You can think of so3_integrate as being analogous to "global average pooling" in a standard CNN. This is sometimes done at the end to get approximate translation invariance (in the spherical CNN so3_integrate leads to rotation invariance). The reason you can't just sum up/average the values at all the points is that the grid points are not spread uniformly over the sphere, so we have to weigh them by the inverse of the density, which is given by the Haar measure. For instance, in the SOFT grid we have a lot of points near the north pole. Now imagine a signal that has high values near the north pole. After we rotate the signal, the large values could be near the equator, in which case we'd have fewer points with a high value. Simply summing the original and rotated pixels would not be rotation invariant, but so3_integrate() would be.
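For concreteness, a minimal sketch of the invariant pooling step, mirroring the MNIST example (the shapes here are made up for illustration):

import torch
from s2cnn import so3_integrate

# x: activations of the last SO3 layer, assumed shape [batch, feature, 2b, 2b, 2b]
x = torch.randn(4, 10, 16, 16, 16)  # dummy input with b = 8, illustration only
# integrate over SO(3) with the proper quadrature weights; the result
# [batch, feature] is (approximately) invariant to rotations of the input
y = so3_integrate(x)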

meder411 commented 5 years ago

Thanks for the detailed explanations! This is a fascinating area of study.

As I’m sure you know, one unusual effect of processing equirectangular images is that the pixel variance along the top and bottom image borders is 0: at the poles, the same pixel is just repeated along the width of the image. In addition to the greater warping the filter must undergo to capture information, this means that a typical perspective CNN will end up overweighting certain pixels. I presume the non-uniform gridding is a mechanism to circumvent this.

One last question (for now at least ;-)): you have an so3_shortcut function that’s commented as “useful for ResNets.” How do spherical convolutions as you’ve defined them translate to a residual network architecture? Is it just a straightforward substitution of a spherical convolution for a typical planar one? Could you please give a short toy code example of a skip connection?

tscohen commented 5 years ago

Indeed, this is what so3_integrate counteracts. Note that the non-uniform grid is not the solution (exactly uniform high-resolution grids on the sphere do not exist). The problem of having a non-uniform grid is solved by using quadrature weights / the Haar measure to weight each sample differently when averaging.
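To make the weighting concrete, here is a rough sketch of a quadrature-weighted mean. The weights w are a stand-in for the actual SOFT quadrature weights, and the beta axis is assumed to be the third-from-last dimension:

import torch

def haar_weighted_mean(x, w):
    # x: SO(3) signal of shape [..., 2b, 2b, 2b] (beta, alpha, gamma axes)
    # w: per-beta quadrature weights of shape [2b] (placeholder values)
    # a plain x.mean() would over-count the densely sampled polar regions;
    # weighting each beta-slice by the Haar density removes that bias
    s = (x * w.view(-1, 1, 1)).sum(dim=(-3, -2, -1))
    return s / (w.sum() * x.shape[-1] * x.shape[-2])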

Yes, a spherical ResNet is just a ResNet where you replace the 2D conv by an S2 or SO3 conv. ResNets compute y = f(x) + x, where f is usually a sequence of convolution / batch norm / relu. I don't currently have a code sample at hand, but it should be pretty straightforward.
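A minimal sketch of such a block, assuming equal input/output features and bandwidth so that the shapes of f(x) and x match (batch norm omitted for brevity):

import torch.nn as nn
import torch.nn.functional as F
from s2cnn import SO3Convolution, so3_near_identity_grid

class SO3ResBlock(nn.Module):
    def __init__(self, features, b):
        super().__init__()
        grid = so3_near_identity_grid()
        # keep nfeature_in == nfeature_out and b_in == b_out so x + f(x) is well defined
        self.conv1 = SO3Convolution(features, features, b_in=b, b_out=b, grid=grid)
        self.conv2 = SO3Convolution(features, features, b_in=b, b_out=b, grid=grid)

    def forward(self, x):
        # x: [batch, features, 2b, 2b, 2b]
        return x + self.conv2(F.relu(self.conv1(x)))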

meder411 commented 5 years ago

This makes sense. I will have to play around with the code a bit I think. If it helps anyone else, I visualized the S2 and SO3 sampling grids with PyPlot in Python. You can find the file attached here. grid_viz.tar.gz

mariogeiger commented 5 years ago

For the S2 visualization you exchanged sin and cos for the beta angle. We use beta = 0 <=> north pole.

Also, you should not set the bandwidth of the SOFT grid to pi (the bandwidth is an integer resolution parameter, not an angle).

meder411 commented 5 years ago

Ah, I thought I may have done that. Yeah, what are good ranges for the soft grid? I used pi because it just showed a lot of points.

Jiankai-Sun commented 5 years ago

For SO3Shortcut, why do you simply assign grid=((0, 0, 0), )? Does this grid parameter need to be changed when we use it?

Thank you!

mariogeiger commented 5 years ago

The grid parameter can be compared to the kernel_size parameter in standard CNNs. But you have more degrees of freedom with grid, because you can choose each point of the kernel's support (in a standard CNN you can only have square grids).

Here grid=((0, 0, 0), ) plays the same role as kernel_size=1 (like here)
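As a sketch, with made-up feature counts and bandwidths:

from s2cnn import SO3Convolution

# a single grid point at the identity mixes feature channels without
# aggregating over neighbouring rotations, analogous to a 1x1 convolution
pointwise = SO3Convolution(nfeature_in=32, nfeature_out=64, b_in=16, b_out=16,
                           grid=((0.0, 0.0, 0.0),))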

Jiankai-Sun commented 5 years ago

Thank you for your quick reply!

I noticed that in the provided s2cnn MNIST Example,

grid_s2 = s2_near_identity_grid()
grid_so3 = so3_near_identity_grid()

self.conv1 = S2Convolution(
            nfeature_in=1,
            nfeature_out=f1,
            b_in=b_in,
            b_out=b_l1,
            grid=grid_s2)

self.conv2 = SO3Convolution(
            nfeature_in=f1,
            nfeature_out=f2,
            b_in=b_l1,
            b_out=b_l2,
            grid=grid_so3)

Here, grid_s2 is a tuple with length 24, and grid_so3 is a tuple with length 72.

>>> pprint(grid_s2)
((0.1308996938995747, 0.0),
 (0.1308996938995747, 0.7853981633974483),
 (0.1308996938995747, 1.5707963267948966),
 (0.1308996938995747, 2.356194490192345),
 (0.1308996938995747, 3.141592653589793),
 (0.1308996938995747, 3.9269908169872414),
 (0.1308996938995747, 4.71238898038469),
 (0.1308996938995747, 5.497787143782138),
 (0.2617993877991494, 0.0),
 (0.2617993877991494, 0.7853981633974483),
 (0.2617993877991494, 1.5707963267948966),
 (0.2617993877991494, 2.356194490192345),
 (0.2617993877991494, 3.141592653589793),
 (0.2617993877991494, 3.9269908169872414),
 (0.2617993877991494, 4.71238898038469),
 (0.2617993877991494, 5.497787143782138),
 (0.39269908169872414, 0.0),
 (0.39269908169872414, 0.7853981633974483),
 (0.39269908169872414, 1.5707963267948966),
 (0.39269908169872414, 2.356194490192345),
 (0.39269908169872414, 3.141592653589793),
 (0.39269908169872414, 3.9269908169872414),
 (0.39269908169872414, 4.71238898038469),
 (0.39269908169872414, 5.497787143782138))
>>> pprint(grid_so3)
((0.1308996938995747, 0.0, -0.39269908169872414),
 (0.1308996938995747, 0.0, 0.0),
 (0.1308996938995747, 0.0, 0.39269908169872414),
 (0.1308996938995747, 0.7853981633974483, -1.1780972450961724),
 (0.1308996938995747, 0.7853981633974483, -0.7853981633974483),
 (0.1308996938995747, 0.7853981633974483, -0.39269908169872414),
 (0.1308996938995747, 1.5707963267948966, -1.9634954084936207),
 (0.1308996938995747, 1.5707963267948966, -1.5707963267948966),
 (0.1308996938995747, 1.5707963267948966, -1.1780972450961724),
 (0.1308996938995747, 2.356194490192345, -2.748893571891069),
 (0.1308996938995747, 2.356194490192345, -2.356194490192345),
 (0.1308996938995747, 2.356194490192345, -1.9634954084936207),
 (0.1308996938995747, 3.141592653589793, -3.5342917352885173),
 (0.1308996938995747, 3.141592653589793, -3.141592653589793),
 (0.1308996938995747, 3.141592653589793, -2.748893571891069),
 (0.1308996938995747, 3.9269908169872414, -4.319689898685965),
 (0.1308996938995747, 3.9269908169872414, -3.9269908169872414),
 (0.1308996938995747, 3.9269908169872414, -3.5342917352885173),
 (0.1308996938995747, 4.71238898038469, -5.105088062083414),
 (0.1308996938995747, 4.71238898038469, -4.71238898038469),
 (0.1308996938995747, 4.71238898038469, -4.319689898685965),
 (0.1308996938995747, 5.497787143782138, -5.890486225480862),
 (0.1308996938995747, 5.497787143782138, -5.497787143782138),
 (0.1308996938995747, 5.497787143782138, -5.105088062083414),
 (0.2617993877991494, 0.0, -0.39269908169872414),
 (0.2617993877991494, 0.0, 0.0),
 (0.2617993877991494, 0.0, 0.39269908169872414),
 (0.2617993877991494, 0.7853981633974483, -1.1780972450961724),
 (0.2617993877991494, 0.7853981633974483, -0.7853981633974483),
 (0.2617993877991494, 0.7853981633974483, -0.39269908169872414),
 (0.2617993877991494, 1.5707963267948966, -1.9634954084936207),
 (0.2617993877991494, 1.5707963267948966, -1.5707963267948966),
 (0.2617993877991494, 1.5707963267948966, -1.1780972450961724),
 (0.2617993877991494, 2.356194490192345, -2.748893571891069),
 (0.2617993877991494, 2.356194490192345, -2.356194490192345),
 (0.2617993877991494, 2.356194490192345, -1.9634954084936207),
 (0.2617993877991494, 3.141592653589793, -3.5342917352885173),
 (0.2617993877991494, 3.141592653589793, -3.141592653589793),
 (0.2617993877991494, 3.141592653589793, -2.748893571891069),
 (0.2617993877991494, 3.9269908169872414, -4.319689898685965),
 (0.2617993877991494, 3.9269908169872414, -3.9269908169872414),
 (0.2617993877991494, 3.9269908169872414, -3.5342917352885173),
 (0.2617993877991494, 4.71238898038469, -5.105088062083414),
 (0.2617993877991494, 4.71238898038469, -4.71238898038469),
 (0.2617993877991494, 4.71238898038469, -4.319689898685965),
 (0.2617993877991494, 5.497787143782138, -5.890486225480862),
 (0.2617993877991494, 5.497787143782138, -5.497787143782138),
 (0.2617993877991494, 5.497787143782138, -5.105088062083414),
 (0.39269908169872414, 0.0, -0.39269908169872414),
 (0.39269908169872414, 0.0, 0.0),
 (0.39269908169872414, 0.0, 0.39269908169872414),
 (0.39269908169872414, 0.7853981633974483, -1.1780972450961724),
 (0.39269908169872414, 0.7853981633974483, -0.7853981633974483),
 (0.39269908169872414, 0.7853981633974483, -0.39269908169872414),
 (0.39269908169872414, 1.5707963267948966, -1.9634954084936207),
 (0.39269908169872414, 1.5707963267948966, -1.5707963267948966),
 (0.39269908169872414, 1.5707963267948966, -1.1780972450961724),
 (0.39269908169872414, 2.356194490192345, -2.748893571891069),
 (0.39269908169872414, 2.356194490192345, -2.356194490192345),
 (0.39269908169872414, 2.356194490192345, -1.9634954084936207),
 (0.39269908169872414, 3.141592653589793, -3.5342917352885173),
 (0.39269908169872414, 3.141592653589793, -3.141592653589793),
 (0.39269908169872414, 3.141592653589793, -2.748893571891069),
 (0.39269908169872414, 3.9269908169872414, -4.319689898685965),
 (0.39269908169872414, 3.9269908169872414, -3.9269908169872414),
 (0.39269908169872414, 3.9269908169872414, -3.5342917352885173),
 (0.39269908169872414, 4.71238898038469, -5.105088062083414),
 (0.39269908169872414, 4.71238898038469, -4.71238898038469),
 (0.39269908169872414, 4.71238898038469, -4.319689898685965),
 (0.39269908169872414, 5.497787143782138, -5.890486225480862),
 (0.39269908169872414, 5.497787143782138, -5.497787143782138),
 (0.39269908169872414, 5.497787143782138, -5.105088062083414))

Question 1: So does that mean the kernel sizes for the MNIST example are 24 and 72, respectively? Is there any possible explanation for adopting such a large kernel size for MNIST? Does it mean learning will work better if we use such a large kernel size in the same way for VGG or ResNet?

Question 2: If it is recommended to just use the same kernel size as the original VGG or ResNet implementation (for VGG the kernel size is usually 2 or 3; for ResNet it is usually 3, 1, or 7), are there any suggestions or rules about how to choose a grid with the specified size? Should we manually specify 3 values (e.g. ((0, 0, 0), (1, 1, 1), (2, 2, 2))) or randomly select 3 elements from the existing grids grid_s2 and grid_so3? What if there are repeated elements among the chosen kernel points (e.g. ((0, 0, 0), (0, 0, 0), (1, 1, 1)))? Presumably a small kernel size can reduce the needed memory :) If the digit is projected onto the northern hemisphere, should we just choose points of the kernel support whose point[2] > 1?

Question 3: I wonder how to set the parameter bandwidth_in for SO3Convolution() and S2Convolution() if the input is a rectangle of size (batch_size, channel, bandwidth, 2 * bandwidth) instead of (batch_size, channel, 2 * bandwidth, 2 * bandwidth)? I am afraid rectangular input is not an easy situation to handle.

Question 4: How can we use different weight initialization methods for s2cnn, e.g. Glorot initialization, Kaiming initialization, and so on?

(You can ignore the next 2 questions if they are not clearly stated :) ) Question 5: I noticed here that, in order to ensure that only the southern hemisphere gets projected (the paper says each digit is projected onto the northern hemisphere), you choose grid[2] <= 1. So is grid[2] == 1 the equator? Is grid[2] <= 1 the southern hemisphere and grid[2] > 1 the northern hemisphere? If we want to project the digit onto the whole sphere, do we just need to remove this line? Probably there is also something to do with these 2 lines.

Question 6: Does s2cnn support projecting a rectangular image instead of a square image onto the sphere? I noticed that lie_learn.spaces.S2.meshgrid(b=b, grid_type=grid_type) can only return theta, phi with shape [2 * bandwidth, 2 * bandwidth], so it seems only square meshgrids are supported. What do we need to do to project a 512 x 1024 x 3 image onto a sphere and then an s2_grid?

Thank you!

tscohen commented 5 years ago

Q1: It is a kernel with 24 points. This is comparable to a 2D kernel of size 5, which has 5 x 5 = 25 points. In 3D you have another dimension, which again increases the number of points / parameters (hence 72 for the SO3 grid).

The grid is a flat list of coordinates. In a 3x3 kernel in a 2D CNN, the analog would be ((0,0), (0, 1), (0,2), (1,0),(1,1),(1,2),(2,0),(2,1),(2,2)).

Q2: It's not about choosing a set of points from grid_s2 or grid_so3, it's about defining the right grid. The length of the grid defines the number of samples / points / parameters. The coordinates themselves (phi, theta) define where on the sphere each point lies. The near_identity_grid is one type of grid that we proposed, where the points are near the north pole (for S2) or near the identity transformation (for SO3). You can come up with any grid you like, though, as in the sketch below. Whether it will work well depends on the characteristics of your data; figuring this out is a research problem. We have only demonstrated that near_identity_grid etc. work reasonably well for the problems we looked at.
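For illustration, a hypothetical hand-rolled S2 grid, built the same way as the (beta, alpha) tuples printed above; whether such a support works well for your data is exactly the open question:

import math

# one point at the north pole plus a ring of 4 points at colatitude pi/16
custom_grid_s2 = ((0.0, 0.0),) + tuple(
    (math.pi / 16, k * math.pi / 2) for k in range(4)
)
# pass it like the built-in grids, e.g.
# S2Convolution(nfeature_in=1, nfeature_out=8, b_in=30, b_out=10, grid=custom_grid_s2)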

Q3: bandwidth_in of one layer should equal the output bandwidth of the previous layer. I'm not sure what you are asking with respect to the square / rectangular grid; can you rephrase?

Q4: S2Conv is just a pytorch Module, with parameters stored in .kernel and .bias. You can change them as you like.
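For example, a sketch re-initializing those tensors in place with standard torch.nn.init schemes (the layer hyper-parameters here are made up):

import torch.nn as nn
from s2cnn import S2Convolution, s2_near_identity_grid

conv = S2Convolution(nfeature_in=1, nfeature_out=8, b_in=30, b_out=10,
                     grid=s2_near_identity_grid())
nn.init.xavier_uniform_(conv.kernel)  # Glorot initialization of the kernel
nn.init.zeros_(conv.bias)             # and a zero bias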

Q5: Whether that's the north or the south pole is not an objective mathematical fact; it depends on how we describe the grid in prose / the name we give to a certain coordinate. I think we said that (theta, phi) = (0, 0) corresponds to the north pole.

Q6: We currently only support the SOFT grid, which is square. You could implement spherical convolution for other grids as well, and this would have other benefits, such as a more homogeneous sampling of the sphere. However, you can also resample a rectangular image onto a square grid. The question is how you want to project your image onto the sphere.

Is your image a planar image or a spherical image? If planar, you need to define a projection map from R2 to S2, e.g. the stereographic projection. Then you can work out, for each point in some square grid on the sphere, where it gets mapped on the plane, and sample that point in the image using bilinear interpolation.

If your image is a spherical image, it must be associated with some grid on the sphere. You need to figure out what it is, i.e. what the (theta, phi) coordinates of each pixel are. Then, again, you can figure out for each point in the SOFT grid where in the rectangular image you need to sample.
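A rough sketch of that last step, assuming an equirectangular panorama whose rows span theta in [0, pi] and columns span phi in [0, 2*pi), resampled onto a 2b x 2b SOFT-style grid with bilinear interpolation:

import math
import torch
import torch.nn.functional as F

def equirect_to_soft(img, b):
    # img: [channels, H, W] panorama; returns [channels, 2b, 2b]
    theta = math.pi * (2 * torch.arange(2 * b, dtype=torch.float32) + 1) / (4 * b)  # colatitudes
    phi = math.pi * torch.arange(2 * b, dtype=torch.float32) / b                    # longitudes
    # convert to grid_sample's normalized [-1, 1] coordinates (x = width = phi)
    y = theta / math.pi * 2 - 1
    x = phi / (2 * math.pi) * 2 - 1
    gy, gx = torch.meshgrid(y, x, indexing="ij")
    grid = torch.stack([gx, gy], dim=-1).unsqueeze(0)  # [1, 2b, 2b, 2]
    return F.grid_sample(img.unsqueeze(0), grid, mode="bilinear",
                         padding_mode="border", align_corners=False)[0]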

Jiankai-Sun commented 5 years ago

For Question 3, what I mean is: how do we set the parameter bandwidth_in if the height and width of the input image differ after projection to a grid? For example, suppose the input is a panorama (probably a panorama is a kind of spherical image already associated with a grid on the sphere) with height 512, width 1024, and 3 channels. Should we set bandwidth_in for the first S2Convolution layer to 512/2 = 256 or 1024/2 = 512? This differs from the s2cnn MNIST example with size 60x60x1 (same height and width), where we can easily set the bandwidth to 60/2 = 30.

Thank you!

tscohen commented 5 years ago

If you want to do things properly, you first need to figure out the coordinate on the sphere associated with each point in your rectangular / panorama grid. Once you have that, you can create a new square grid and use it to sample the rectangular one. The bandwidth of this grid can be anything, but I would expect setting bandwidth = 512 to preserve more detail than bandwidth = 256.

Jiankai-Sun commented 5 years ago

Sure, thank you for your reply!