QUVA-Lab / escnn

Equivariant Steerable CNNs Library for PyTorch https://quva-lab.github.io/escnn/

Partial parameter expansion to subgroup #2

Status: Open. k8lion opened this issue 2 years ago.

k8lion commented 2 years ago

Hello,

I would like to change the group of an instantiated (and trained) group convolutional layer while preserving the information in its weights. The export function does this when going from equivariance to the original group down to translations only (in the case of convolutions), but I would like to know how to do it in the general case: expanding the learnable parameters from equivariance to the original group to equivariance to any of its subgroups, such as going from C8_on_R2 to C4_on_R2.

Thanks, Kaitlin

Gabri95 commented 2 years ago

Hi Kaitlin,

That's an interesting idea :)

Unfortunately, the library does not directly support this at the moment, but it would be a very nice feature to include!

I think the best place to implement this is inside the RdConv module or in the single-block basis expansion module.

I will try to explain the logic by referring to some equations in our paper.

A G-equivariant convolution relies on a filter basis, parameterized as in Eq. 4 of our paper (Theorem 2.1). This parameterization involves a G-steerable basis (the Y functions).

Let's say we have a pre-trained G'-equivariant layer and we want to re-use its weights in a G-equivariant layer, where G < G' is a subgroup.
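To make the idea of re-using G' weights in a G-equivariant layer concrete, here is a minimal numerical sketch under simplifying assumptions: a 1x1 convolution (no spatial extent) mapping a regular field of C8 to a regular field of C8, so the filter is just an 8 x 8 matrix W with W rho(g) = rho(g) W. The helpers `regular_rep` and `intertwiner_basis` are hypothetical and not part of escnn, and the basis is found by brute force (SVD null space) instead of the library's analytical steerable basis. Since every C8-equivariant filter is also C4-equivariant, projecting the expanded filter onto an orthonormal C4 basis should recover it exactly, now with more free parameters.

```python
import numpy as np


def regular_rep(n: int, r: int) -> np.ndarray:
    """Hypothetical helper: regular representation of rotation r in C_n (cyclic shift matrix)."""
    return np.roll(np.eye(n), r, axis=0)


def intertwiner_basis(generators, tol=1e-9):
    """Hypothetical helper: orthonormal basis (columns = row-major flattened matrices)
    of all W with W rho(g) = rho(g) W for every generator matrix rho(g)."""
    d = generators[0].shape[0]
    eye = np.eye(d)
    # Row-major vec(A X B) = kron(A, B.T) vec(X), so the constraint is linear in vec(W)
    constraints = np.vstack([np.kron(eye, g.T) - np.kron(g, eye) for g in generators])
    _, s, vt = np.linalg.svd(constraints)
    rank = int(np.sum(s > tol))
    return vt[rank:].T  # null space of the constraints


# C8 is generated by the 45-degree rotation (r = 1); its subgroup C4 by r = 2
c8_basis = intertwiner_basis([regular_rep(8, 1)])
c4_basis = intertwiner_basis([regular_rep(8, 2)])

rng = np.random.default_rng(0)
theta_c8 = rng.normal(size=c8_basis.shape[1])  # stand-in for trained C8 weights
w = c8_basis @ theta_c8                         # expanded 8 x 8 filter, flattened

# Both bases are orthonormal, so re-expressing w in the C4 basis is a single matmul
theta_c4 = c4_basis.T @ w
residual = np.linalg.norm(c4_basis @ theta_c4 - w)
print(c8_basis.shape[1], c4_basis.shape[1], residual < 1e-8)
```

This projects the expanded filter rather than working irrep by irrep, so it sidesteps the endomorphism bookkeeping at the cost of solving the constraint numerically for each layer.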

We can use Eq. 5 to decompose each element of the G'-steerable basis (the {Y_j'i'}_j'i') into a number of basis elements that are only G-steerable. Consider one such element Y_j'i', which, before restriction, transforms according to the irrep j' of G'.

After restriction to G, the irrep j' may decompose into several different G-irreps. In particular, the G-irrep j appears [jj'] times in this decomposition (see Eq. 5 again).
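The multiplicities [jj'] can be checked numerically with characters. Below is a minimal sketch for G' = C8 and G = C4 (the rotations by multiples of 90 degrees), using the real irreps of both groups; the irrep labels "trivial", "sign" and "freq1" and the helper names are hypothetical, chosen only for this example. For real irreps the multiplicity is <chi, chi_j> / <chi_j, chi_j>, which matters for the 2D irrep of C4, whose character has squared norm 2.

```python
import numpy as np

N_BIG, N_SMALL = 8, 4  # G' = C8, G = C4 (C4 element s is the C8 rotation r = 2 * s)


def c8_irrep_character(k: int, r: int) -> float:
    """Character of the real irrep of C8 with frequency k (0 <= k <= 4) at rotation r."""
    if k == 0:
        return 1.0
    if k == N_BIG // 2:
        return float((-1) ** r)                     # 1D sign irrep
    return 2.0 * np.cos(2 * np.pi * k * r / N_BIG)  # 2D rotation irrep


# Characters of the real irreps of C4, indexed by the C4 element s = 0..3
C4_IRREP_CHARACTERS = {
    "trivial": np.array([1.0, 1.0, 1.0, 1.0]),
    "sign": np.array([1.0, -1.0, 1.0, -1.0]),
    "freq1": np.array([2 * np.cos(np.pi * s / 2) for s in range(N_SMALL)]),
}


def restriction_multiplicities(k: int) -> dict:
    """Multiplicity [j j'] of each C4 irrep j in the restriction of the C8 irrep j' = k."""
    chi = np.array([c8_irrep_character(k, 2 * s) for s in range(N_SMALL)])
    mult = {}
    for name, chi_j in C4_IRREP_CHARACTERS.items():
        # Divide by <chi_j, chi_j> so that real 2D irreps are counted once, not twice
        m = np.dot(chi, chi_j) / np.dot(chi_j, chi_j)
        mult[name] = int(round(m))
    return mult


for k in range(5):
    print(k, restriction_multiplicities(k))
```

In this case, frequency 2 of C8 splits into two copies of the C4 sign irrep (so [jj'] = 2 there), frequencies 1 and 3 both become the frequency-1 irrep of C4, and frequencies 0 and 4 become trivial.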

Before restriction, we had one parameter for each endomorphism in {c_r^j'}_r. The task is now to express these few endomorphisms as linear combinations of the endomorphisms of the G-irreps appearing in the decomposition of j'. This should boil down to a couple of matrix multiplications that project some vectors (representing the original endomorphisms) onto a different basis, but it might be a bit nasty to code up.
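A minimal sketch of this projection for a single irrep, assuming j' is the frequency-2 irrep of C8 restricted to C4. Under C8 its endomorphisms are spanned by the identity and the 90-degree rotation (2 parameters); after restriction it becomes two copies of the C4 sign irrep, so every 2 x 2 matrix commutes with it (4 parameters). To keep the code short, it skips the explicit change of basis into G-irreps and computes both endomorphism spaces in the original basis of j' by solving E psi(g) = psi(g) E numerically; the two spaces are related by that change of basis, so the coefficients carry over. The helpers `rotation`, `endomorphism_basis` and `psi` are hypothetical, not escnn API.

```python
import numpy as np


def rotation(theta: float) -> np.ndarray:
    """2 x 2 rotation matrix by angle theta."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def endomorphism_basis(generators, tol=1e-9):
    """Hypothetical helper: orthonormal basis (columns = row-major flattened matrices)
    of all E with E psi(g) = psi(g) E for every generator matrix psi(g)."""
    d = generators[0].shape[0]
    eye = np.eye(d)
    constraints = np.vstack([np.kron(eye, g.T) - np.kron(g, eye) for g in generators])
    _, s, vt = np.linalg.svd(constraints)
    return vt[int(np.sum(s > tol)):].T


def psi(r: int, k: int = 2) -> np.ndarray:
    """Real irrep of C8 with frequency k, evaluated at the rotation by r * 45 degrees."""
    return rotation(2 * np.pi * k * r / 8)


old_basis = endomorphism_basis([psi(1)])  # endomorphisms of j' under C8 (generated by r = 1)
new_basis = endomorphism_basis([psi(2)])  # endomorphisms after restriction to C4 (generated by r = 2)

# Column i holds the coordinates of the i-th old endomorphism in the new basis
projection = new_basis.T @ old_basis

c_old = np.array([0.7, -1.3])  # stand-in for trained coefficients {c_r^j'}_r
c_new = projection @ c_old     # the same endomorphism, now with 4 free coefficients
print(old_basis.shape[1], new_basis.shape[1], np.allclose(new_basis @ c_new, old_basis @ c_old))
```

With orthonormal bases on both sides, the map from old to new parameters is the single matrix `projection`; in a real layer it would presumably have to be applied per basis element Y_j'i' and per G-irrep block, which is where the bookkeeping gets nasty.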

I will try to write some code next week, but it might take a bit longer. If you want to try implementing something like this yourself, please feel free to open a pull request in the meantime!

Best, Gabriele