QUVA-Lab / escnn

Equivariant Steerable CNNs Library for Pytorch https://quva-lab.github.io/escnn/

ESCNN SO(3) 3D CNN example #23

Open Ale9806 opened 1 year ago

Ale9806 commented 1 year ago

Hello! I was wondering if there is an example of a 3D Steerable CNN (using R3Conv). Furthermore, is it possible to use this library to train a model on tensors of shape (batch, channels, width, length, depth)?

Thanks!

Gabri95 commented 1 year ago

Hi @Ale9806

Thanks for the question! You are right; I will try to prepare a simple example in the next few days :) Anyway, R3Conv works exactly like R2Conv, so you can use the SO(2)-steerable CNN example here as a reference. Still, let me prepare a more detailed example.

Regarding the second question, R3Conv is made precisely to work on tensors of shape (batch, channels, width, length, depth). R3PointConv instead works on point clouds and geometric graphs (where nodes have coordinates in the 3D space).

I will keep this issue open until I add an example for R3Conv.

Hope this helps, Gabriele

Ale9806 commented 1 year ago

Thanks Gabriele!

Ale9806 commented 1 year ago

Similarly, is there an example of a 3D G-CNN with the icosahedral group? If so, could you share it?

psteinb commented 1 year ago

Hi @Gabri95 et al

I have been trying to implement a 3D ESCNN over the past few days (not done yet). Two things popped up on my radar:

Bottom line from my side: for a beginner it is super helpful to have a tested and working example (classification, regression). I am happy to help.

Gabri95 commented 1 year ago

Sorry for the delay on this, but I am a bit busy this month, so I have not finished preparing these examples yet. I will try to complete them next week!

@psteinb, regarding the notebook, feel free to open another issue here about it or share a pull request with your proposed solution. We can upload a corrected version of the notebook to this repository. I will notify the people maintaining the https://uvadlc-notebooks.readthedocs.io/ website about the changes (and, of course, acknowledge your contribution there as well).

Thanks, Gabriele

psteinb commented 1 year ago

Digging a bit deeper, I feel a bit lost now, so I can't resist asking here. Note that I still consider myself a newbie in this field, so feel free to correct me wherever needed.

Where I am coming from:

So I was betting on using these induced representations of the group, obtained from a Fourier transform (as alluded to here, at the bottom of the slide), as the output of the first conv layer. I was hoping that this would reduce the memory overhead of my 3D CNN. But bouncing back and forth between the lecture and the escnn docs confused me.

My question: does escnn support all NN operations for the induced representations (as the lecture calls them)? At this point in the lecture, I get the impression that the answer is yes. Sifting through the escnn examples, docs and tests, I get the impression that the answer is no. I would appreciate some guidance.

psteinb commented 1 year ago

In essence, my confusion stems from the mismatch between what the lecture calls a steerable equivariant convolutional NN and how the lecture's tutorial quoted above implements it using escnn.


Gabri95 commented 1 year ago

hi @psteinb ,

Sorry for the confusion: unfortunately, in practice the term "induced representation" is used in two different contexts.

Given two groups H < G (H is a subgroup of G), an induced representation from H to G generates a representation of G from a representation psi of H. Elements of the vector space transforming under an induced representation can be thought of as vector fields over the quotient space G/H, where each vector component transforms according to the representation psi of H.
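For finite groups the construction can be written out explicitly. This is the standard textbook definition, given here as a sketch; the symbols V, f, k and s are introduced only for this illustration.

```latex
% psi : H -> GL(V) is a representation of the subgroup H < G
\begin{aligned}
\mathcal{F} &= \{\, f : G \to V \mid f(gh) = \psi(h)^{-1} f(g)\ \ \forall g \in G,\ h \in H \,\} \\
[\mathrm{Ind}_H^G \psi(k)\, f](g) &= f(k^{-1} g), \qquad k \in G \\
\dim \mathrm{Ind}_H^G \psi &= [G : H]\,\dim \psi = \tfrac{|G|}{|H|}\,\dim \psi \qquad (G \text{ finite})
\end{aligned}
```

Picking one representative s(x) in G for each coset x in G/H turns f into the field x ↦ f(s(x)) in V over G/H, which is the "vector field over the quotient space" picture. With psi trivial this is the permutation representation on the cosets G/H, and with H = {e} it is the regular representation of G.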

1) Now, theoretical papers describing Steerable CNNs, as well as Erik's lectures, use induced representations to describe the steerable features of a steerable CNN. In this case, you can take, for example, G=SE(3) and H=SO(3), such that you obtain vector fields over G/H = R^3, which are indeed the steerable features of the neural network. Here, psi is the representation of SO(3) you chose to define a FieldType (you can think of a GSpace as defining the pair G/H and H, and of a FieldType as associating a representation psi with it, so that together they implicitly define an induced representation).
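To make case 1 concrete, here is a small NumPy sketch (not escnn code; all function names are hypothetical) of how an element (t, R) of SE(3) acts on a field f: R^3 → V through the induced representation. The field is moved in space and its values are transformed by psi(R), i.e. f ↦ [x ↦ psi(R) f(R^{-1}(x - t))].

```python
import numpy as np

def rotation_z(theta: float) -> np.ndarray:
    """Hypothetical helper: rotation in SO(3) about the Z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def induced_action(field, psi, R, t):
    """Hypothetical helper: action of (t, R) in SE(3) on a field f: R^3 -> V.

    Returns the field x -> psi(R) @ f(R^{-1} (x - t)); for a rotation, R^{-1} = R^T.
    """
    def transformed(x):
        return psi(R) @ field(R.T @ (np.asarray(x, dtype=float) - t))
    return transformed

# psi = trivial rep of SO(3): scalar field, values are only moved around
trivial = lambda R: np.eye(1)
# psi = standard rep of SO(3): vector field, values are also rotated
standard = lambda R: R

R = rotation_z(np.pi / 2)
t = np.zeros(3)

scalar = lambda x: np.array([x[0]])              # scalar field f(x) = x_0
const_x = lambda x: np.array([1.0, 0.0, 0.0])    # constant vector field along X
radial = lambda x: np.asarray(x, dtype=float)    # radial vector field f(x) = x

# The scalar field x -> x_0 becomes x -> x_1 after a 90 degree rotation about Z
print(induced_action(scalar, trivial, R, t)([0.0, 1.0, 0.0]))
# Every vector of a constant field is rotated: e_x -> e_y
print(induced_action(const_x, standard, R, t)([0.3, -1.0, 2.0]))
# The radial field is invariant under rotations about the origin
print(induced_action(radial, standard, R, t)([0.3, -1.0, 2.0]))
```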

2) However, induced representations can also be used to generate the representation psi of SO(3) (which we used above to build a FieldType in a steerable CNN) by inducing from another subgroup of SO(3). Forgive me for the confusing notation, but we can now pick G=SO(3), H=SO(2) and a small representation psi of SO(2) (e.g. the trivial one) to generate fields over the sphere G/H = S^2. Call this induced representation rho.

Now we can use rho as a representation of SO(3) in a steerable CNN. Using the same notation as before, we pick G=SE(3), H=SO(3) and psi=rho. These rho-features are quite expressive and more compact than a regular representation of SO(3).
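One way to see why rho is more compact is to count channels after band-limiting to frequencies l ≤ L. The sketch below relies on Frobenius reciprocity, a standard result that is not stated in this thread: the SO(3) irrep of frequency l appears in the induction from SO(2) of the SO(2) character of frequency k as often as frequency k appears in the restriction of irrep l to SO(2), i.e. once if |k| ≤ l. The regular representation of SO(3) instead contains irrep l exactly 2l+1 times. The helper names are hypothetical.

```python
def irrep_dim(l: int) -> int:
    """Dimension of the SO(3) irrep of frequency l."""
    return 2 * l + 1

def multiplicity_in_induced(l: int, k: int) -> int:
    """Multiplicity of SO(3) irrep l in Ind_{SO(2)}^{SO(3)} of the SO(2) character e^{ik theta}.

    Frobenius reciprocity: equals the multiplicity of frequency k in the restriction
    of irrep l to SO(2), which contains each frequency -l..l exactly once.
    For k = 0 (the trivial psi above) the real and complex pictures coincide.
    """
    return 1 if abs(k) <= l else 0

def bandlimited_dim(L: int, multiplicity) -> int:
    """Number of channels when keeping only irreps with frequency l <= L."""
    return sum(multiplicity(l) * irrep_dim(l) for l in range(L + 1))

for L in range(4):
    sphere = bandlimited_dim(L, lambda l: multiplicity_in_induced(l, 0))  # (L+1)^2
    regular = bandlimited_dim(L, irrep_dim)  # irrep l appears 2l+1 times
    print(L, sphere, regular)
```

With L = 2, for instance, a sphere (rho) feature needs 9 channels versus 35 for a band-limited regular feature, and the gap grows as O(L^2) versus O(L^3).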

In my library, when you see layers mentioning Induced[Representation], they refer to the second case: these are modules designed to deal with features whose channels transform under some induced representation. But when you read a theoretical paper or Erik's lectures, you can think of all features of a steerable CNN as a form of induced representation!

Hope this clarifies your doubts a bit!

Let me know if you have more doubts, Gabriele

Gabri95 commented 1 year ago

I'll keep this issue open in case you want to discuss induced representations further.

I have finally included an example of a 3D equivariant CNN. I'm sorry for the delay, but I wanted to train the model and check that it still works well on ModelNet10 (the architecture is essentially the same as the one used in our paper).

Best, Gabriele

Ale9806 commented 1 year ago

Thanks, Gabriele!

Ale9806 commented 1 year ago

Hi, when trying to run the code I get the following error:

     64 _channels = int(round(_channels))
     66 # Build the non-linear layer
     67 # Internally, this module performs an Inverse FT sampling the `_channels` continuous input features on the `S`
     68 # samples, apply ELU pointwise and, finally, recover `_channels` output features with discrete FT.
---> 69 ftelu = FourierELU(self.gspace, _channels, irreps=so3.bl_irreps(L), inplace=True, *grid)
     70 res_type = ftelu.in_type
     72 print(f'ResBlock: {in_type.size} -> {res_type.size} -> {self.out_type.size} | {S_channels}')

AttributeError: 'SO3' object has no attribute 'bl_irreps'

Ale9806 commented 1 year ago

I also tried cloning the repo and running it (instead of running the code with the PyPI package), but it fails the assertions in several files!

Gabri95 commented 1 year ago

hi @Ale9806

Are you sure you are using the latest version of the library? SO3 does have that method, as you can see here.

Could you maybe check escnn.__version__?

Best, Gabriele