QUVA-Lab / e2cnn_experiments

Experiment for General E(2)-Equivariant Steerable CNNs

Questions about making ResNets equivariant with e2cnn #3

Open blazejdolicki opened 2 years ago

blazejdolicki commented 2 years ago

I'm trying to create an equivariant ResNet model based on e2_wide_resnet.py and I would really appreciate it if you could clarify some of my doubts.

Main questions

  1. I see that in your repo e2_wide_resnet.py is the equivariant version of wide_resnet.py and that WideBasic is equivalent to BasicBlock. In standard, non-equivariant CNNs, we pass the arguments in_planes and out_planes to BasicBlock, which correspond to the number of input and output channels of the block. As far as I know, the number of channels in a standard CNN corresponds to the number of representations in an e2cnn.nn.FieldType. However, in WideBasic, besides the input and output FieldTypes in_fiber and out_fiber, you also pass an inner FieldType inner_fiber. So my question is: why are those two FieldTypes (as in the non-equivariant network) not enough, and why do we need the third one?

  2. Apart from BasicBlock, I also want to use a different block needed for most models in the ResNet family: Bottleneck. Bottleneck is a bit more complicated than BasicBlock; this is what its initialization looks like:

    class Bottleneck(nn.Module):
        def __init__(self, inplanes: int, planes: int, ...) -> None:
            super().__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            width = int(planes * (base_width / 64.0)) * groups
            # Both self.conv2 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv1x1(inplanes, width)
            self.bn1 = norm_layer(width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.bn2 = norm_layer(width)
            self.conv3 = conv1x1(width, planes * self.expansion)
            self.bn3 = norm_layer(planes * self.expansion)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride

    which looks like this in the diagram from the ResNeXt paper (for 32 groups). At the beginning, we compute width, which is the number of output channels of self.conv1 and the number of input channels of self.conv2 (128 in the diagram). Given that we need a different number of channels (i.e. a different number of representations in the FieldType), for the equivariant bottleneck I create a new FieldType width_fiber:

    
    # I think len(out_fiber) is the number of channels in E2
    planes = len(out_fiber)
    width = int(planes * (base_width / 64.)) * groups

    # now we need to get the same field type but with `width` representations (number of channels)

    # dirty check if all representations in `in_fiber` are the same (otherwise the next lines are incorrect);
    # the assumption here is that all representations in this FieldType are the same (no mixed types)
    first_rep_type = type(in_fiber.representations[0])
    for rep in in_fiber.representations:
        assert first_rep_type == type(rep)

    # FIXME hardcoded representation, should be the same representation as in_fiber
    in_rep = 'regular'

    # create new fiber with `width` channels
    width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])
    self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
    ...

Does that seem correct to you? Is there a cleaner way to retrieve the representation type from a FieldType instead of hardcoding it?
After doing that, I again need to change the number of channels (from planes to planes * expansion), and I do a similar thing as before for `exp_out_fiber`. Here is the whole code for the E2Bottleneck initialization:

    class E2Bottleneck(nn.EquivariantModule):
        def __init__(self,
                     in_fiber: nn.FieldType,
                     inner_fiber: nn.FieldType,
                     out_fiber: nn.FieldType = None,
                     ...):

            # I think len(out_fiber) is the number of channels in E2
            planes = len(out_fiber)
            width = int(planes * (base_width / 64.)) * groups

            # now we need to get the same field type but with `width` representations (number of channels)

            # dirty check if all representations in `in_fiber` are the same
            first_rep_type = type(in_fiber.representations[0])
            for rep in in_fiber.representations:
                assert first_rep_type == type(rep)

            # FIXME hardcoded representation, should be the same representation as in_fiber
            in_rep = 'regular'

            # create new fiber with `width` channels
            width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])

            self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
            self.bn1 = nn.InnerBatchNorm(width_fiber)
            self.conv2 = conv(width_fiber, width_fiber, stride, groups, dilation, sigma=sigma, F=F, initialize=False)
            self.bn2 = nn.InnerBatchNorm(width_fiber)

            # create new fiber with `planes * self.expansion` channels
            exp_out_fiber = nn.FieldType(in_fiber.gspace,
                                         planes * self.expansion * [in_fiber.gspace.representations[in_rep]])
            self.conv3 = conv1x1(width_fiber, exp_out_fiber, sigma=sigma, F=F, initialize=False)
            self.bn3 = nn.InnerBatchNorm(exp_out_fiber)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
**Smaller questions about e2_wide_resnet.py**

1. Why do we use conv layers with kernel size 3 when `rotations` is 0, 2 or 4, but kernel size 5 for other values?

    if rotations in [0, 2, 4]:
        conv = conv3x3
    else:
        conv = conv5x5

2. Is this initialization correct?

    elif isinstance(module, torch.nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()
    elif isinstance(module, torch.nn.Linear):
        module.bias.data.zero_()


BatchNorm2d isn't even used; should we replace it with InnerBatchNorm or remove that part entirely? (From what I've checked, InnerBatchNorm doesn't have weight or bias instance variables.) Also, why is the standard Linear layer initialized to 0 instead of using a standard initialization?

Looking forward to your reply and please let me know if there is anything unclear in my question :)
Gabri95 commented 2 years ago

Hi @blazejdolicki

Main questions:

  1. If I understood your question correctly, in my code inner_type plays the role of what you called out_planes. You can verify that, when the blocks are instantiated here, I always set the output type to the same type as inner_type. The only exception is the last layer of the model, where I need to output invariant features. This is the only reason why WideBasic has the option to specify a different output type.

  2. The relation between the number of channels in the conventional model and the equivariant one depends on what you want to do. If you actually want to count the number of effective channels in the equivariant model (relevant for preserving the computational/memory cost), you should look at out_fiber.size. Instead, len(out_fiber) tells you the number of independent features. If you use regular_representation, this is equivalent to the number of channels in a GCNN; in this sense, out_fiber.size is |G| times larger, as each GCNN channel effectively stores an activation for each element of the group G. In my code, this behaviour is regulated by the fixparams argument: if it is set to False, I make sure that out_fiber.size = out_planes; if fixparams=True, I scale up the number of channels by roughly sqrt(|G|) to keep the total number of parameters the same. If you instead set len(out_fiber) = out_channels, you are generating a much larger model; I would not recommend this, as it is likely to be a very expensive architecture.
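For concreteness, here is a minimal sketch of the distinction between len(FieldType) and FieldType.size (the field counts are illustrative, not taken from the repo):

    from e2cnn import gspaces, nn

    gs = gspaces.Rot2dOnR2(N=8)                    # C8 symmetry, |G| = 8
    ft = nn.FieldType(gs, 16 * [gs.regular_repr])  # 16 regular fields

    print(len(ft))   # 16  -> number of independent fields (GCNN "channels")
    print(ft.size)   # 128 -> actual number of tensor channels, 16 * |G|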

2.1 I am not sure what you mean by type(in_fiber.representations[0]). Regarding your "dirty check", our new library includes a shortcut for that; check here.

2.2 Once you have verified that all representations in in_fiber are the same, you can just access one as in_fiber.representations[0] and use it to build width_fiber as width_fiber = FieldType(in_fiber.gspace, width * [in_fiber.representations[0]]). This avoids hard-coding in_rep.
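A minimal sketch of this idea, wrapped in a helper (the name widen_fiber is just for illustration, not from the repo):

    from e2cnn import nn

    def widen_fiber(in_fiber: nn.FieldType, width: int) -> nn.FieldType:
        # reuse the representation of the input fiber instead of hard-coding 'regular'
        rep = in_fiber.representations[0]
        assert all(r.name == rep.name for r in in_fiber.representations), \
            "in_fiber mixes different representations"
        return nn.FieldType(in_fiber.gspace, width * [rep])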

2.3 The same trick can be used for the expansion. If you want to preserve the total number of channels or the number of parameters (as recommended earlier), I'd recommend estimating the number of channels with something like this and passing it planes * self.expansion, so that the number of channels is rounded after taking the product.
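A hypothetical sketch of that scaling (not the exact helper from the repo): with fixparams=False the resulting FieldType has size equal to the requested number of channels, with fixparams=True the size is scaled up by roughly sqrt(|G|), and rounding happens only once, after the expansion product:

    import math
    from e2cnn import nn

    def regular_fiber(gspace, channels: int, fixparams: bool = True) -> nn.FieldType:
        n_fields = channels / gspace.fibergroup.order()
        if fixparams:
            n_fields *= math.sqrt(gspace.fibergroup.order())
        return nn.FieldType(gspace, int(n_fields) * [gspace.regular_repr])

    # for the Bottleneck expansion: take the product first, then round
    # exp_out_fiber = regular_fiber(in_fiber.gspace, planes * expansion)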


Smaller questions:

  1. The grid where images are sampled is perfectly symmetric only with respect to 90-degree rotations (contained in C_N for N = 1, 2 or 4). Equivariance to other rotations is necessarily approximate; using slightly wider filters allows for better stability under smaller rotations.

  2. Yeah, the initialization of BatchNorm2d is not necessary, but you don't need to add anything: InnerBatchNorm is already initialized in that way during instantiation. Also, note that I initialized only the bias of torch.nn.Linear to 0, not the weights.

Hope this helped! Let me know if you have other questions.

Gabriele

blazejdolicki commented 2 years ago

Hey Gabriele, thanks for your elaborate answer, this makes it much clearer now! It seems that e2wrn.py is an improved version of e2_wide_resnet.py. Apart from some variable renaming (for example, conv2triv was changed to totrivial), I see that it used to be a parameter of the Wide_ResNet model class, while now it only appears in _wide_layer. What was the reasoning behind that change? Moreover, there is now no GroupPooling layer; isn't it necessary for invariance?

blazejdolicki commented 2 years ago

I have another question related to the comment above. Based on my understanding, it seems that the group space used in a FieldType should correspond to its representations. For example, a FieldType with a trivial gspace should have trivial representations. Looking at the examples, I see that this is not the case: for instance, in the first layer the gspace is Rot2dOnR2() (instead of TrivialOnR2()) while the representations are trivial (for example, in the 3rd code cell here). Can you explain why my reasoning is incorrect?

Gabri95 commented 2 years ago

Hi @blazejdolicki

Sorry for the late reply

Regarding the difference between e2wrn.py and e2_wide_resnet.py: the first is a cleaner example I prepared for the tutorial of the library, while the second is a more flexible model I built to be able to run different experiments. I do not use GroupPooling in e2wrn.py since the last convolutional layer already maps to invariant features (trivial representations), so there is no need to perform further pooling.

Regarding the other question: the gspace describes the symmetry group (e.g. Rot2dOnR2(N=8) = 8 discrete rotations, TrivialOnR2() = no rotations). The representations passed as the second argument, instead, define how this symmetry group transforms the channels.
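A minimal example of the distinction (both FieldTypes live in the same C8 gspace; only the representations differ):

    from e2cnn import gspaces, nn

    gs = gspaces.Rot2dOnR2(N=8)  # the symmetry group: 8 discrete rotations

    # input images: 3 channels that do not mix under rotation -> trivial representations
    in_type = nn.FieldType(gs, 3 * [gs.trivial_repr])

    # intermediate features whose channels permute under the 8 rotations -> regular representations
    feat_type = nn.FieldType(gs, 16 * [gs.regular_repr])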

Take a look at this tutorial; does it make things more clear?

Best, Gabriele

blazejdolicki commented 2 years ago

Thanks, that makes it clearer. I implemented and trained an equivariant model for N=4 rotations, but when I compare the class probabilities for 90-degree rotations of the same images (exact rotations, without interpolation artifacts), sometimes there is a significant difference, such as for the second image here (each subplot title contains the class probabilities for the corresponding image). Do you think this comes from some justifiable numerical inaccuracy, or is there something inherently wrong with the model? Below I'm attaching my model architecture; the invariance is obtained by converting to trivial representations.

E2ResNet(
  (conv1): R2Conv([4-Rotations: 3 representations], [4-Rotations: 16 representations], kernel_size=5, stride=1, padding=2, bias=False)
  (bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
  (maxpool): PointwiseMaxPool()
  (layer1): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
      (conv2): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 16 representations])
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
      (conv2): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 16 representations])
    )
  )
  (layer2): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 32 representations], kernel_size=3, stride=2, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 32 representations])
      (conv2): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 32 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 16 representations], [4-Rotations: 32 representations], kernel_size=1, stride=2, bias=False)
        (1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 32 representations])
      (conv2): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 32 representations])
    )
  )
  (layer3): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 32 representations], [4-Rotations: 64 representations], kernel_size=3, stride=2, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 64 representations])
      (conv2): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 64 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 32 representations], [4-Rotations: 64 representations], kernel_size=1, stride=2, bias=False)
        (1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 64 representations])
      (conv2): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 64 representations])
    )
  )
  (layer4): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 64 representations], [4-Rotations: 128 representations], kernel_size=3, stride=2, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 128 representations])
      (conv2): R2Conv([4-Rotations: 128 representations], [4-Rotations: 128 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 128 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 64 representations], [4-Rotations: 128 representations], kernel_size=1, stride=2, bias=False)
        (1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 128 representations], [4-Rotations: 128 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 128 representations])
      (conv2): R2Conv([4-Rotations: 128 representations], [4-Rotations: 512 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 512 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 512 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 128 representations], [4-Rotations: 512 representations], kernel_size=1, stride=1, bias=False)
        (1): InnerBatchNorm([4-Rotations: 512 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=2, bias=True)
)

this model is created using this script:

class E2BasicBlock(nn.EquivariantModule):
    expansion: int = 1

    def __init__(
        self,
        in_fiber: nn.FieldType,
        inner_fiber: nn.FieldType,
        out_fiber: nn.FieldType = None,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        F: float = 1.,
        sigma: float = 0.45,
    ) -> None:
        super(E2BasicBlock, self).__init__()

        if out_fiber is None:
            out_fiber = in_fiber

        self.in_type = in_fiber
        inner_class = inner_fiber
        self.out_type = out_fiber

        if isinstance(in_fiber.gspace, gspaces.FlipRot2dOnR2):
            rotations = in_fiber.gspace.fibergroup.rotation_order
        elif isinstance(in_fiber.gspace, gspaces.Rot2dOnR2):
            rotations = in_fiber.gspace.fibergroup.order()
        else:
            rotations = 0

        if rotations in [0, 2, 4]:
            conv = conv3x3
        else:
            conv = conv5x5

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv(self.in_type, inner_class, stride=stride, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(inner_class)
        self.relu = nn.ReLU(inner_class, inplace=True)
        self.conv2 = conv(inner_class, self.out_type, sigma=sigma, F=F, initialize=False)
        self.bn2 = nn.InnerBatchNorm(self.out_type)
        # add another relu because the shape changes
        self.relu2 = nn.ReLU(self.out_type, inplace=True)
        self.stride = stride

        # `downsample` in resnet.py is the equivalent of `shortcut` in e2_wide_resnet.py

        self.downsample = None
        if stride != 1 or self.in_type != self.out_type:
            self.downsample = nn.SequentialModule(
                conv1x1(self.in_type, self.out_type, stride=stride, bias=False, sigma=sigma, F=F, initialize=False),
                nn.InnerBatchNorm(self.out_type),
                )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu2(out)

        return out

    # abstract method
    def evaluate_output_shape(self, input_shape):
        raise NotImplementedError

class E2ResNet(torch.nn.Module):

    def __init__(
        self,
        block: Type[Union[E2BasicBlock, E2Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        N: int = 8,
        restrict: int = 1,
        flip: bool = True,
        main_fiber: str = "regular",
        inner_fiber: str = "regular",
        F: float = 1.,
        sigma: float = 0.45,
        deltaorth: bool = False,
        fixparams: bool = True,
        initial_stride: int = 1,
        conv2triv: bool = True,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None
    ) -> None:
        """

        :param block: Type of block used in the model (E2BasicBlock or E2Bottleneck)
        :param layers: Number of blocks in each layer
        :param num_classes:
        :param N:
        :param restrict:
        :param flip: If the model is flip equivariant.
        :param main_fiber:
        :param inner_fiber:
        :param F:
        :param sigma:
        :param deltaorth:
        :param fixparams:
        :param conv2triv:
        :param zero_init_residual:
        :param groups:
        :param width_per_group:
        :param replace_stride_with_dilation:
        """
        super(E2ResNet, self).__init__()

        # Standard initialization of ResNet

        # Number of output channels of the first convolution
        self.inplanes = 64
        self.dilation = 1

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Equivariant part of initialization of ResNet
        self._fixparams = fixparams
        self.conv2triv = conv2triv

        self._layer = 0
        self._N = N

        # if the model is [F]lip equivariant
        self._f = flip

        # level of [R]estriction:
        #   r < 0 : never do restriction, i.e. initial group (either D8 or C8) preserved for the whole network
        #   r = 0 : do restriction before first layer, i.e. initial group doesn't have rotation equivariance (C1 or D1)
        #   r > 0 : restrict after every block, i.e. start with 8 rotations, then restrict to 4 and finally 1
        self._r = restrict

        self._F = F
        self._sigma = sigma

        if self._f:
            self.gspace = gspaces.FlipRot2dOnR2(N)
        else:
            self.gspace = gspaces.Rot2dOnR2(N)

        if self._r == 0:
            id = (0, 1) if self._f else 1
            self.gspace, _, _ = self.gspace.restrict(id)

        # Start building layers

        # field type of layer lifting the Z^2 input to N rotations
        self.in_lifting_type = nn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

        # field type for the first lifted layer
        self.next_in_type = FIBERS[main_fiber](self.gspace, self.inplanes, fixparams=self._fixparams)

        # number of output channels in each outer layer
        num_channels = [64, 128, 256, 512]
        # For this initial cnn, torchvision ResNet uses kernel_size=7, stride=2, padding=3
        # wide_resnet.py uses kernel_size=3, stride=1, padding=1
        # and e2_wide_resnet.py uses kernel_size=5. We follow the latter.
        self.conv1 = conv5x5(self.in_lifting_type, self.next_in_type, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(self.next_in_type)
        self.relu = nn.ReLU(self.next_in_type, inplace=True)
        self.maxpool = nn.PointwiseMaxPool(self.next_in_type, kernel_size=3, stride=2, padding=1)
        # self.layer_i is equivalent to self.block_i in wide_resnet.py (for ith layer)
        # self._make_layer is equivalent to NetworkBlock (which contains the same method) in wide_resnet.py
        # and to _wide_layer in e2_wide_resnet.py

        self.layer1 = self._make_layer(block, num_channels[0], layers[0], stride=initial_stride,
                                       dilate=replace_stride_with_dilation[0],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)

        # first restriction layer
        if self._r > 0:
            id = (0, 4) if self._f else 4
            self.restrict1 = self._restrict_layer(id)
        else:
            self.restrict1 = lambda x: x

        self.layer2 = self._make_layer(block, num_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)

        # second restriction layer
        if self._r > 1:
            id = (0, 1) if self._f else 1
            self.restrict2 = self._restrict_layer(id)
        else:
            self.restrict2 = lambda x: x

        self.layer3 = self._make_layer(block, num_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)

        if self.conv2triv:
            out_fiber = "trivial"
        else:
            out_fiber = None

        self.layer4 = self._make_layer(block, num_channels[3], layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber, out_fiber=out_fiber)

        if not self.conv2triv:
            self.mp = nn.GroupPooling(self.layer4.out_type)

        self.avgpool = torch.nn.AdaptiveAvgPool2d((1,1))
        linear_input_features = self.mp.out_type.size if not self.conv2triv else self.layer4.out_type.size
        self.fc = torch.nn.Linear(linear_input_features, num_classes)

        for module in self.modules():
            if isinstance(module, nn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights.data, module.basisexpansion)
                else:
                    init.generalized_he_init(module.weights.data, module.basisexpansion)
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()

        num_params = sum([p.numel() for p in self.parameters() if p.requires_grad])
        print("Total number of learnable parameters:", num_params)

    def _make_layer(self, block: Type[Union[E2BasicBlock, E2Bottleneck]], planes: int, num_blocks: int,
                    stride: int = 1, dilate: bool = False,
                    main_fiber: str = "regular",
                    inner_fiber: str = "regular",
                    out_fiber: str = None,
                    ) -> nn.SequentialModule:
        self._layer += 1
        logging.info(f"Start building layer {self._layer}")

        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        layers = []

        main_type = FIBERS[main_fiber](self.gspace, planes, fixparams=self._fixparams)
        inner_class = FIBERS[inner_fiber](self.gspace, planes, fixparams=self._fixparams)

        out_f = main_type

        # add first block that starts with `self.inplanes` channels and ends with `planes` channels
        # use stride=`stride` for the first block and stride=1 for all the rest (default value)
        first_block = block(in_fiber=self.next_in_type,
                            inner_fiber=inner_class,
                            out_fiber=out_f,
                            stride=stride,
                            # downsample=downsample,
                            groups=self.groups,
                            base_width=self.base_width,
                            dilation=previous_dilation,
                            sigma=self._sigma,
                            F=self._F)
        layers.append(first_block)

        # create new field type with `planes * block.expansion` channels
        self.next_in_type = first_block.out_type
        out_f = self.next_in_type
        for _ in range(1, num_blocks-1):
            next_block = block(in_fiber=self.next_in_type,
                               inner_fiber=inner_class,
                               out_fiber=out_f,
                               groups=self.groups,
                               base_width=self.base_width,
                               dilation=self.dilation,
                               sigma=self._sigma,
                               F=self._F)
            layers.append(next_block)
            self.next_in_type = out_f

        # add last block
        if out_fiber is None:
            out_fiber = main_fiber
        out_type = FIBERS[out_fiber](self.gspace, planes, fixparams=self._fixparams)

        last_block = block(in_fiber=self.next_in_type,
                           inner_fiber=inner_class,
                           out_fiber=out_type,
                           groups=self.groups,
                           base_width=self.base_width,
                           dilation=self.dilation,
                           sigma=self._sigma,
                           F=self._F)
        layers.append(last_block)
        self.next_in_type = out_f

        logging.info(f"Built layer {self._layer}")

        return nn.SequentialModule(*layers)

    def _restrict_layer(self, subgroup_id):
        layers = list()
        layers.append(nn.RestrictionModule(self.next_in_type, subgroup_id))
        layers.append(nn.DisentangleModule(layers[-1].out_type))
        self.next_in_type = layers[-1].out_type
        self.gspace = self.next_in_type.gspace

        restrict_layer = nn.SequentialModule(*layers)
        return restrict_layer

    def features(self, x):

        x = nn.GeometricTensor(x, self.in_lifting_type)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        out = self.maxpool(x)

        x1 = self.layer1(out)

        x2 = self.layer2(self.restrict1(x1))

        x3 = self.layer3(self.restrict2(x2))

        x4 = self.layer4(x3)
        # out = self.relu(self.mp(self.bn1(out)))

        return x1, x2, x3, x4

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass

        x = nn.GeometricTensor(x, self.in_lifting_type)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(self.restrict1(x))
        x = self.layer3(self.restrict2(x))
        x = self.layer4(x)

        if not self.conv2triv:
            x = self.mp(x)
        x = self.avgpool(x.tensor)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
Gabri95 commented 2 years ago

Hi @blazejdolicki

sorry for the late reply.

No, the equivariance to 90 deg should be almost perfect, with an absolute error probably lower than 1e-5.

At a first look, your code seems ok. Usually this problem arises when you use strided convolutions with odd-sized filters but your input images have even size.

Check page 3 of our paper for more details on this issue and how to solve it.
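As a quick sanity check, here is a minimal sketch of a 90-degree equivariance test (assuming model is the E2ResNet above, in eval mode; the 65x65 input size is a hypothetical choice so that all strided layers see odd-sized inputs):

    import torch

    model.eval()
    x = torch.randn(1, 3, 65, 65)
    with torch.no_grad():
        y = model(x)
        y_rot = model(torch.rot90(x, k=1, dims=(2, 3)))
    # for an invariant model the difference should be tiny, around 1e-5 or smaller
    print((y - y_rot).abs().max().item())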

Let me know if this was your problem

Best, Gabriele