QUVA-Lab / escnn

Equivariant Steerable CNNs Library for Pytorch https://quva-lab.github.io/escnn/

[Passive Rotation] #98

Open · Luo-Z13 opened 6 months ago

Luo-Z13 commented 6 months ago

I'm interested in learning how to use the operators in e2cnn/escnn to implement a function f such that f(Fea_I) = Fea_I_rot. Here, Fea_I = B(I) is the feature of an image I after passing through a backbone B, and Fea_I_rot = B(I_rot) is the feature of I_rot, the rotated version of the same image.

From my understanding, the results of passing I and I_rot through the rotation-equivariant network are the same, as observed in your validation program. Could you please provide some guidance on this matter?
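
For concreteness, here is a toy sketch of the setup I mean (the one-layer escnn "backbone" and the C8 field types are placeholders I made up, not a real network):

    import torch
    from escnn.gspaces import rot2dOnR2
    from escnn.nn import FieldType, GeometricTensor, R2Conv

    gs = rot2dOnR2(8)                                # rotations by multiples of 45°
    in_type = FieldType(gs, [gs.trivial_repr])       # a grayscale image I
    out_type = FieldType(gs, 2 * [gs.regular_repr])  # the feature field Fea_I

    B = R2Conv(in_type, out_type, 3, padding=1)      # stand-in "backbone"

    I = GeometricTensor(torch.randn(1, 1, 8, 8), in_type)
    I_rot = I.transform(gs.fibergroup.element(2))    # I rotated by 90°

    Fea_I = B(I)          # B(I)
    Fea_I_rot = B(I_rot)  # B(I_rot): what f satisfies f(Fea_I) == Fea_I_rot?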

Luo-Z13 commented 5 months ago

Could you please take a moment to look into this question when you have the time? @kalekundert

kalekundert commented 5 months ago

FYI, I'm just a user of this software, not a maintainer. A lot of the math that goes on behind the scenes is beyond my understanding. But I'm familiar with the basics and happy to try to help.

Unfortunately, I don't really understand your question. What's a backbone? Which validation program are you referring to?

If you're asking whether or not you can implement a function $f$ such that:

$$ f(x) = f(g \cdot x) $$

where $x$ is some arbitrary input and $g$ is some sort of transformation that can act on $x$ (e.g. a rotation), then the answer is yes. $f$ in this case would be considered invariant with respect to $g$, not just equivariant. Invariant models are pretty common, because you often want predictions that don't depend on the orientation of the input. There are two ways that I know of to make invariant models using escnn:

  • Via convolution: The output of a convolution will be invariant if (i) each spatial dimension is size=1 and (ii) all of the output representations are trivial. Here's an example of this:

    import torch
    
    from escnn.gspaces import rot3dOnR3
    from escnn.nn import FieldType, GeometricTensor, R3Conv
    from math import radians
    
    gs = rot3dOnR3()
    so3 = gs.fibergroup
    
    # Input: one regular field band-limited to frequency 2 (35 channels);
    # output: 4 trivial (invariant) fields.
    ft1 = FieldType(gs, [so3.bl_regular_representation(2)])
    ft2 = FieldType(gs, 4 * [so3.trivial_representation])
    
    # A 3x3x3 kernel on a 3x3x3 input collapses each spatial dimension to size 1.
    f = R3Conv(ft1, ft2, 3)
    
    x = GeometricTensor(torch.randn(1, 35, 3, 3, 3), ft1)
    
    # 180° rotation around the z-axis, in axis-angle ('EV') parametrization.
    g = so3.element([0, 0, radians(180)], 'EV')
    
    gx = x.transform(g)
    
    print(f(x).tensor.reshape(4))
    print(f(gx).tensor.reshape(4))

    Example output:

    tensor([-1.5842, -1.1770, -0.0731,  0.2437], grad_fn=<ReshapeAliasBackward0>)
    tensor([-1.5842, -1.1770, -0.0731,  0.2437], grad_fn=<ReshapeAliasBackward0>)

    The downside to this approach, as I understand it (and this is getting outside what I really understand), is that this convolution will only be able to use the parts of the input that also have trivial representations. This means that a lot of the latent space ends up being wasted.

  • Via Fourier transforms: The frequency=0 components of a Fourier transform are invariant with respect to rotation, so you can get invariance by doing a Fourier transform and only recovering these components. As above, the spatial dimensions have to be size=1. Here's an example:

    import torch
    
    from escnn.gspaces import rot3dOnR3
    from escnn.nn import GeometricTensor, QuotientFourierPointwise
    from math import radians
    
    gs = rot3dOnR3()
    so3 = gs.fibergroup
    
    # Id of the SO(2) subgroup of rotations around the z-axis; the quotient
    # SO(3)/SO(2) is the sphere.
    so2_z = False, -1
    
    # Fourier transform over the sphere, keeping only the frequency-0
    # (rotation-invariant) irreps in the output.
    f = QuotientFourierPointwise(
          gs, so2_z, 4, so3.bl_irreps(2),
          out_irreps=so3.bl_irreps(0),
          grid=so3.sphere_grid('thomson_cube', N=4)
    )
    ft = f.in_type
    
    x = GeometricTensor(torch.randn(1, ft.size, 1, 1, 1), ft)
    
    # 180° rotation around the z-axis.
    g = so3.element([0, 0, radians(180)], 'EV')
    gx = x.transform(g)
    
    print(f(x).tensor.reshape(4))
    print(f(gx).tensor.reshape(4))

    Example output:

    tensor([1.7091, 0.4297, 0.9685, 0.4096])
    tensor([1.7091, 0.4297, 0.9685, 0.4096])

    The downside to this approach is that you don't end up with very many channels. The specific number depends on the "band limit" of the input representation, but generally you'll end up with 10x fewer channels than you started with.

Luo-Z13 commented 5 months ago

Thank you very much for your response! I apologize if my descriptions were not clear enough due to my limited understanding of this topic.

Specifically, I've referred to the program at https://github.com/QUVA-Lab/escnn/blob/master/examples/e2wrn.py, and from your explanation I understand that the Wide_ResNet network in e2wrn is rotation invariant. I also referenced ReResNet, the rotation-equivariant backbone used in object detection: https://github.com/csuhan/ReDet/issues/133. Its validation program goes through GroupPooling to extract rotation-equivariant features and through a Linear layer to extract the final features (https://github.com/csuhan/ReDet/blob/3eae28f784f771fee8e2305f17a69ac8e84567b0/mmcls/models/backbones/re_resnet.py#L643C13-L643C35), respectively.
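
As I understand it, GroupPooling pools over the orientation channels of each regular field, so its output no longer changes when the input is rotated. A toy escnn sketch of my own (the C8 field type is made up, not the actual ReResNet configuration):

    import torch
    from escnn.gspaces import rot2dOnR2
    from escnn.nn import FieldType, GeometricTensor, GroupPooling

    gs = rot2dOnR2(8)
    ft = FieldType(gs, 3 * [gs.regular_repr])  # 3 regular fields = 24 channels

    pool = GroupPooling(ft)                    # max over the 8 orientations per field

    x = GeometricTensor(torch.randn(1, 24, 1, 1), ft)
    g = gs.fibergroup.element(2)               # 90° rotation

    # With spatial size 1x1, rotating x only permutes the orientation channels,
    # and the per-field max is unchanged:
    print(pool(x).tensor.reshape(3))
    print(pool(x.transform(g)).tensor.reshape(3))  # same 3 values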

So, I would like to know two things: 1) On the GitHub homepage of e2cnn, how are the heat maps of the rotation-equivariant feature fields drawn (the middle one)? 2) For the feature Fea_I produced by the rotation-equivariant network from the image I, and the feature Fea_I_rot produced from the rotated image I_rot, can the two be made identical by rotating only along the H/W dimensions? Or is a further rotation within the orientation channels (for a field type like gs = e2cnn.gspaces.Rot2dOnR2(8); type = e2cnn.nn.FieldType(gs, [gs.regular_repr]*3)) required to make them identical? A toy sketch of what I mean follows below.
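
To make question 2) concrete, here are the two candidate operations (my own C8 example, not code from either repository):

    import torch
    from escnn.gspaces import rot2dOnR2
    from escnn.nn import FieldType, GeometricTensor

    gs = rot2dOnR2(8)
    ft = FieldType(gs, 3 * [gs.regular_repr])  # 3 regular fields = 24 channels

    x = GeometricTensor(torch.randn(1, 24, 9, 9), ft)
    g = gs.fibergroup.element(4)               # 180° rotation (4 * 45°)

    # Full group action: rotates the H/W grid AND permutes the orientation
    # channels inside each regular field.
    gx = x.transform(g)

    # Candidate 1: rotate only the H/W dimensions.
    spatial_only = torch.rot90(x.tensor, k=2, dims=(-2, -1))
    print(torch.allclose(gx.tensor, spatial_only))            # False in general

    # Candidate 2: additionally shift the 8 orientation channels within each
    # field (by 4 positions for a 180° rotation).
    fiber_fixed = torch.roll(
        spatial_only.view(1, 3, 8, 9, 9), shifts=4, dims=2
    ).view(1, 24, 9, 9)
    print(torch.allclose(gx.tensor, fiber_fixed, atol=1e-5))  # True, up to numerics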

Thank you!