Open blazejdolicki opened 2 years ago
Hi @blazejdolicki
Main questions:
If I understood your question correctly, `inner_type` in my code represents what you called `out_planes`. You can verify that, when the blocks are instantiated here, I always set the output type to the same type as `inner_type`. The only exception is the last layer of the model, where I need to output invariant features. This is the only reason why `WideBasic` has the option to specify a different output type.
The relation between the number of channels in the conventional model and in the equivariant one depends on what you want to do. If you want to count the number of effective channels in the equivariant model (relevant for preserving the computational/memory cost), you should look at `out_fiber.size`. Instead, `len(out_fiber)` tells you the number of independent features. If you use `regular_representation`, this is equivalent to the number of channels in a GCNN; in this sense, `out_fiber.size` is `|G|` times larger, as each GCNN channel effectively stores an activation for each element of the group `G`. In my code, this behaviour is regulated by the `fixparams` argument. If it is set to `False`, I make sure that `out_fiber.size = out_planes`; instead, if `fixparams=True`, I scale up the number of channels by roughly `sqrt(|G|)` to ensure the total number of parameters is the same. If you use `len(out_fiber) = out_channels`, you are generating a much larger model; I'd not recommend this as it's likely going to be a very expensive architecture.
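To make the two counting conventions concrete, here is a minimal sketch in plain Python. It assumes only regular representations (so every field occupies `|G|` channels) and follows the rough rule described above. The helper names `regular_field_count` and `fiber_size` are hypothetical and are not part of the library or of `e2_wide_resnet.py`.

```python
import math


def regular_field_count(planes: int, group_order: int, fixparams: bool) -> int:
    """Hypothetical helper: number of regular fields for a layer that has
    `planes` channels in the conventional (non-equivariant) model."""
    # Each regular field occupies |G| channels, so keeping the total size
    # equal to `planes` needs planes / |G| fields.
    fields = planes / group_order
    if fixparams:
        # Rough rule from the discussion: widen by about sqrt(|G|) so the
        # parameter count stays close to that of the conventional model.
        fields *= math.sqrt(group_order)
    # Round only at the very end.
    return int(fields)


def fiber_size(fields: int, group_order: int) -> int:
    """Total channel count, the analogue of `out_fiber.size`."""
    return fields * group_order


for fixparams in (False, True):
    n_fields = regular_field_count(64, 4, fixparams)
    print(fixparams, n_fields, fiber_size(n_fields, 4))
```

For `C_4` and 64 planes this gives 16 fields (size 64) without `fixparams` and 32 fields (size 128) with it, whereas choosing `len(out_fiber) = 64` would give a size of 256.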
2.1 I am not sure what you mean by `type(in_fiber.representations[0])`. Regarding your "dirty check", our new library includes a shortcut for that, check here.
2.2 Once you have verified that all representations in `in_fiber` are the same, you can just access one as `in_fiber.representations[0]` and use it to build `width_fiber` as `width_fiber = FieldType(in_fiber.gspace, width * [in_fiber.representations[0]])`. This avoids hard-coding `in_rep`.
2.3 The same trick can be used for expansions. If you want to preserve the total number of channels or the number of parameters (as recommended earlier), I'd recommend estimating the number of channels using something like this and passing it `planes * self.expansion`, such that the number of channels is rounded after taking the product.
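The check-then-reuse idea from 2.1 and 2.2 can be sketched as a small duck-typed helper. This is only an illustration: `uniform_representation` and `widened` are hypothetical names, and the `SimpleNamespace` object is a stand-in used so the snippet runs without the library. With a real `FieldType`, the resulting list would be the second argument of `FieldType(in_fiber.gspace, ...)`.

```python
from types import SimpleNamespace


def uniform_representation(field_type):
    """Return the single representation shared by every field of `field_type`.

    Raises ValueError if the field type mixes different representations.
    """
    reps = list(field_type.representations)
    if not reps:
        raise ValueError("field type has no representations")
    first = reps[0]
    # Compare the representations themselves rather than their Python types.
    if any(rep != first for rep in reps[1:]):
        raise ValueError("mixed representations; cannot pick a single one")
    return first


def widened(field_type, width):
    """`width` copies of the shared representation (no hard-coded name)."""
    return width * [uniform_representation(field_type)]


# Stand-in for illustration only: it mimics the `.gspace` and
# `.representations` attributes that the sketch relies on.
fake_fiber = SimpleNamespace(gspace="C4", representations=("regular",) * 16)
print(len(widened(fake_fiber, 128)))
```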
Smaller questions:
The grid on which images are sampled is perfectly symmetric only with respect to rotations by multiples of 90 degrees (the rotations contained in `C_N` for N=1, 2 or 4). Equivariance to other rotations is necessarily approximate; using slightly wider filters allows for better stability to smaller rotations.
Yeah, the initialization of `BatchNorm2d` is not necessary, but you don't need to add anything: `InnerBatchNorm` is already initialized in that way during instantiation. Also, note that I initialized to 0 only the bias, but not the weights, of `torch.nn.Linear`.
Hope this helped! Let me know if you have other questions. Gabriele
Hey Gabriele, thanks for your elaborate answer, this makes it much clearer now! It seems that e2wrn.py is an improved version of e2_wide_resnet.py. Apart from some variable renaming (for example, `conv2triv` was changed to `totrivial`), I see that before it was a parameter of the `Wide_ResNet` model class, while now it's only included in `_wide_layer`. What was the reasoning behind that change? Moreover, now there is no `GroupPooling` layer; isn't it necessary for invariance?
I have another question related to the comment above. Based on my understanding, it seems to me that the group space used in a FieldType should correspond to its representation. For example, a FieldType with a trivial space should have trivial representations. Based on the examples, I see that this is not the case: for example, in the first layer the gspace is Rot2dOnR2() (instead of TrivialOnR2()) while the representations are trivial (for example, in the 3rd code cell here). Can you explain why my reasoning is incorrect?
Hi @blazejdolicki
Sorry for the late reply
Regarding the difference between e2wrn.py and e2_wide_resnet.py: the first is a cleaner example I prepared for the tutorial of the library while the second was a more flexible model I built to be able to run different experiments. I do not use GroupPooling in e2wrn.py since the last convolutional layer already maps to invariant features (trivial representations) so there is no need to perform further pooling.
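As a rough illustration of why pooling over the group gives invariance when the features are regular, and why it is redundant once the features are already trivial, here is a plain-Python sketch for a single regular field of `C_4` at one pixel. The function names are hypothetical; the sketch only mimics the idea of taking the maximum over the channels of a field and does not use the library.

```python
def regular_action(field, k):
    """Act with rotation index k on one regular field of C_N:
    cyclically shift its N channels by k positions."""
    n = len(field)
    return [field[(i - k) % n] for i in range(n)]


def group_pool(field):
    """Max over the group channels of a single field (one value per field)."""
    return max(field)


field = [0.3, -1.2, 2.5, 0.7]  # one regular field of C_4 at one pixel

for k in range(4):
    rotated = regular_action(field, k)
    # The channels are permuted, but their maximum never changes.
    print(k, rotated, group_pool(rotated))
```

A trivial field has a single channel that the group leaves untouched, so there is nothing left to pool over.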
Regarding the other question: the gspace describes the symmetry group (e.g. `Rot2dOnR2(N=8)` = 8 discrete rotations, `TrivialOnR2()` = no rotations). The representations passed as the second argument, instead, define how this symmetry group transforms the channels.
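A small sketch, independent of the library, may help separate the two notions: the group is the same (`C_4` here), but two different representations of it act differently on the channels of a feature vector. The helper names are hypothetical, and the sketch only shows the action on the channels at a single point, ignoring the fact that a rotation also moves the pixels.

```python
def trivial_repr(k, n):
    """Trivial representation of C_n: every rotation acts as the 1x1 identity."""
    return [[1]]


def regular_repr(k, n):
    """Regular representation of C_n: rotation k is the n x n permutation
    matrix that sends channel j to channel (j + k) mod n."""
    return [[1 if i == (j + k) % n else 0 for j in range(n)] for i in range(n)]


def act(matrix, feature):
    """Matrix-vector product: apply a representation matrix to the channels."""
    return [sum(m * x for m, x in zip(row, feature)) for row in matrix]


N = 4
gray_value = [0.8]              # e.g. one colour channel of the input image
regular_feature = [1, 2, 3, 4]  # one regular field in a hidden layer

print(act(trivial_repr(1, N), gray_value))       # channel value unchanged
print(act(regular_repr(1, N), regular_feature))  # channels cyclically shifted
```

This is why an input image can use trivial representations under `Rot2dOnR2`: rotating the image moves the pixels, but the colour values stored at each pixel do not mix.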
Take a look at this tutorial; does it make things more clear?
Best, Gabriele
Thanks, that makes it clearer. I implemented and trained an equivariant model for N=4 rotations, but when I compare the probabilities across rotations by 90 degrees (so perfect rotations without interpolation artifacts) of the same images, sometimes there is a significant difference, such as in the second image here (each subplot title contains the class probabilities for the corresponding image).
Do you think this comes from some justifiable numerical inaccuracy, or is there something inherently wrong with the model? Below I'm attaching my model architecture; the invariance is obtained by converting to trivial representations.
E2ResNet(
  (conv1): R2Conv([4-Rotations: 3 representations], [4-Rotations: 16 representations], kernel_size=5, stride=1, padding=2, bias=False)
  (bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
  (maxpool): PointwiseMaxPool()
  (layer1): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
      (conv2): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 16 representations])
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
      (conv2): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 16 representations])
    )
  )
  (layer2): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 32 representations], kernel_size=3, stride=2, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 32 representations])
      (conv2): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 32 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 16 representations], [4-Rotations: 32 representations], kernel_size=1, stride=2, bias=False)
        (1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 32 representations])
      (conv2): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 32 representations])
    )
  )
  (layer3): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 32 representations], [4-Rotations: 64 representations], kernel_size=3, stride=2, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 64 representations])
      (conv2): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 64 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 32 representations], [4-Rotations: 64 representations], kernel_size=1, stride=2, bias=False)
        (1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 64 representations])
      (conv2): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 64 representations])
    )
  )
  (layer4): SequentialModule(
    (0): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 64 representations], [4-Rotations: 128 representations], kernel_size=3, stride=2, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 128 representations])
      (conv2): R2Conv([4-Rotations: 128 representations], [4-Rotations: 128 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 128 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 64 representations], [4-Rotations: 128 representations], kernel_size=1, stride=2, bias=False)
        (1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): E2BasicBlock(
      (conv1): R2Conv([4-Rotations: 128 representations], [4-Rotations: 128 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True, type=[4-Rotations: 128 representations])
      (conv2): R2Conv([4-Rotations: 128 representations], [4-Rotations: 512 representations], kernel_size=3, stride=1, padding=1, bias=False)
      (bn2): InnerBatchNorm([4-Rotations: 512 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True, type=[4-Rotations: 512 representations])
      (downsample): SequentialModule(
        (0): R2Conv([4-Rotations: 128 representations], [4-Rotations: 512 representations], kernel_size=1, stride=1, bias=False)
        (1): InnerBatchNorm([4-Rotations: 512 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=2, bias=True)
)
This model is created using this script:
# Imports assumed for this script (they were not shown in the original post).
# `conv1x1`, `conv3x3`, `conv5x5`, `FIBERS` and `E2Bottleneck` are assumed to be
# defined elsewhere, e.g. along the lines of e2_wide_resnet.py.
import logging
from typing import List, Optional, Type, Union

import torch
from torch import Tensor
from e2cnn import gspaces
from e2cnn import nn
from e2cnn.nn import init


class E2BasicBlock(nn.EquivariantModule):
    expansion: int = 1

    def __init__(
        self,
        in_fiber: nn.FieldType,
        inner_fiber: nn.FieldType,
        out_fiber: nn.FieldType = None,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        F: float = 1.,
        sigma: float = 0.45,
    ) -> None:
        super(E2BasicBlock, self).__init__()
        if out_fiber is None:
            out_fiber = in_fiber
        self.in_type = in_fiber
        inner_class = inner_fiber
        self.out_type = out_fiber
        if isinstance(in_fiber.gspace, gspaces.FlipRot2dOnR2):
            rotations = in_fiber.gspace.fibergroup.rotation_order
        elif isinstance(in_fiber.gspace, gspaces.Rot2dOnR2):
            rotations = in_fiber.gspace.fibergroup.order()
        else:
            rotations = 0
        if rotations in [0, 2, 4]:
            conv = conv3x3
        else:
            conv = conv5x5
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv(self.in_type, inner_class, stride=stride, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(inner_class)
        self.relu = nn.ReLU(inner_class, inplace=True)
        self.conv2 = conv(inner_class, self.out_type, sigma=sigma, F=F, initialize=False)
        self.bn2 = nn.InnerBatchNorm(self.out_type)
        # add another relu because the shape changes
        self.relu2 = nn.ReLU(self.out_type, inplace=True)
        self.stride = stride
        # `downsample` in resnet.py is the equivalent of `shortcut` in e2_wide_resnet.py
        self.downsample = None
        if stride != 1 or self.in_type != self.out_type:
            self.downsample = nn.SequentialModule(
                conv1x1(self.in_type, self.out_type, stride=stride, bias=False, sigma=sigma, F=F, initialize=False),
                nn.InnerBatchNorm(self.out_type),
            )

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu2(out)
        return out

    # abstract method
    def evaluate_output_shape(self, input_shape):
        raise NotImplementedError
class E2ResNet(torch.nn.Module):
    def __init__(
        self,
        block: Type[Union[E2BasicBlock, E2Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        N: int = 8,
        restrict: int = 1,
        flip: bool = True,
        main_fiber: str = "regular",
        inner_fiber: str = "regular",
        F: float = 1.,
        sigma: float = 0.45,
        deltaorth: bool = False,
        fixparams: bool = True,
        initial_stride: int = 1,
        conv2triv: bool = True,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None
    ) -> None:
        """
        :param block: Type of block used in the model (E2BasicBlock or E2Bottleneck)
        :param layers: Number of blocks in each layer
        :param num_classes:
        :param N:
        :param restrict:
        :param flip: If the model is flip equivariant.
        :param main_fiber:
        :param inner_fiber:
        :param F:
        :param sigma:
        :param deltaorth:
        :param fixparams:
        :param conv2triv:
        :param zero_init_residual:
        :param groups:
        :param width_per_group:
        :param replace_stride_with_dilation:
        """
        super(E2ResNet, self).__init__()
        # Standard initialization of ResNet
        # Number of output channels of the first convolution
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Equivariant part of initialization of ResNet
        self._fixparams = fixparams
        self.conv2triv = conv2triv
        self._layer = 0
        self._N = N
        # if the model is [F]lip equivariant
        self._f = flip
        # level of [R]estriction:
        # r < 0 : never do restriction, i.e. initial group (either D8 or C8) preserved for the whole network
        # r = 0 : do restriction before first layer, i.e. initial group doesn't have rotation equivariance (C1 or D1)
        # r > 0 : restrict after every block, i.e. start with 8 rotations, then restrict to 4 and finally 1
        self._r = restrict
        self._F = F
        self._sigma = sigma
        if self._f:
            self.gspace = gspaces.FlipRot2dOnR2(N)
        else:
            self.gspace = gspaces.Rot2dOnR2(N)
        if self._r == 0:
            id = (0, 1) if self._f else 1
            self.gspace, _, _ = self.gspace.restrict(id)
        # Start building layers
        # field type of layer lifting the Z^2 input to N rotations
        self.in_lifting_type = nn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        # field type for the first lifted layer
        self.next_in_type = FIBERS[main_fiber](self.gspace, self.inplanes, fixparams=self._fixparams)
        # number of output channels in each outer layer
        num_channels = [64, 128, 256, 512]
        # For this initial cnn, torchvision ResNet uses kernel_size=7, stride=2, padding=3
        # wide_resnet.py uses kernel_size=3, stride=1, padding=1
        # and e2_wideresnet.py uses kernel_size=5. We follow the latter.
        self.conv1 = conv5x5(self.in_lifting_type, self.next_in_type, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(self.next_in_type)
        self.relu = nn.ReLU(self.next_in_type, inplace=True)
        self.maxpool = nn.PointwiseMaxPool(self.next_in_type, kernel_size=3, stride=2, padding=1)
        # self.layer_i is equivalent to self.block_i in wide_resnet.py (for ith layer)
        # self._make_layer is equivalent to NetworkBlock (which contains the same method) in wide_resnet.py
        # and to _wide_layer in e2_wide_resnet.py
        self.layer1 = self._make_layer(block, num_channels[0], layers[0], stride=initial_stride,
                                       dilate=replace_stride_with_dilation[0],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)
        # first restriction layer
        if self._r > 0:
            id = (0, 4) if self._f else 4
            self.restrict1 = self._restrict_layer(id)
        else:
            self.restrict1 = lambda x: x
        self.layer2 = self._make_layer(block, num_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)
        # second restriction layer
        if self._r > 1:
            id = (0, 1) if self._f else 1
            self.restrict2 = self._restrict_layer(id)
        else:
            self.restrict2 = lambda x: x
        self.layer3 = self._make_layer(block, num_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)
        if self.conv2triv:
            out_fiber = "trivial"
        else:
            out_fiber = None
        self.layer4 = self._make_layer(block, num_channels[3], layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber, out_fiber=out_fiber)
        if not self.conv2triv:
            self.mp = nn.GroupPooling(self.layer4.out_type)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        linear_input_features = self.mp.out_type.size if not self.conv2triv else self.layer4.out_type.size
        self.fc = torch.nn.Linear(linear_input_features, num_classes)
        for module in self.modules():
            if isinstance(module, nn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights.data, module.basisexpansion)
                else:
                    init.generalized_he_init(module.weights.data, module.basisexpansion)
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()
        num_params = sum([p.numel() for p in self.parameters() if p.requires_grad])
        print("Total number of learnable parameters:", num_params)
    def _make_layer(self, block: Type[Union[E2BasicBlock, E2Bottleneck]], planes: int, num_blocks: int,
                    stride: int = 1, dilate: bool = False,
                    main_fiber: str = "regular",
                    inner_fiber: str = "regular",
                    out_fiber: str = None,
                    ) -> nn.SequentialModule:
        self._layer += 1
        logging.info(f"Start building layer {self._layer}")
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        main_type = FIBERS[main_fiber](self.gspace, planes, fixparams=self._fixparams)
        inner_class = FIBERS[inner_fiber](self.gspace, planes, fixparams=self._fixparams)
        out_f = main_type
        # add first block that starts with `self.inplanes` channels and ends with `planes` channels
        # use stride=`stride` for the first block and stride=1 for all the rest (default value)
        first_block = block(in_fiber=self.next_in_type,
                            inner_fiber=inner_class,
                            out_fiber=out_f,
                            stride=stride,
                            # downsample=downsample,
                            groups=self.groups,
                            base_width=self.base_width,
                            dilation=previous_dilation,
                            sigma=self._sigma,
                            F=self._F)
        layers.append(first_block)
        # create new field type with `planes * block.expansion` channels
        self.next_in_type = first_block.out_type
        out_f = self.next_in_type
        for _ in range(1, num_blocks-1):
            next_block = block(in_fiber=self.next_in_type,
                               inner_fiber=inner_class,
                               out_fiber=out_f,
                               groups=self.groups,
                               base_width=self.base_width,
                               dilation=self.dilation,
                               sigma=self._sigma,
                               F=self._F)
            layers.append(next_block)
            self.next_in_type = out_f
        # add last block
        if out_fiber is None:
            out_fiber = main_fiber
        out_type = FIBERS[out_fiber](self.gspace, planes, fixparams=self._fixparams)
        last_block = block(in_fiber=self.next_in_type,
                           inner_fiber=inner_class,
                           out_fiber=out_type,
                           groups=self.groups,
                           base_width=self.base_width,
                           dilation=self.dilation,
                           sigma=self._sigma,
                           F=self._F)
        layers.append(last_block)
        self.next_in_type = out_f
        logging.info(f"Built layer {self._layer}")
        return nn.SequentialModule(*layers)
    def _restrict_layer(self, subgroup_id):
        layers = list()
        layers.append(nn.RestrictionModule(self.next_in_type, subgroup_id))
        layers.append(nn.DisentangleModule(layers[-1].out_type))
        self.next_in_type = layers[-1].out_type
        self.gspace = self.next_in_type.gspace
        restrict_layer = nn.SequentialModule(*layers)
        return restrict_layer

    def features(self, x):
        x = nn.GeometricTensor(x, self.in_lifting_type)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        out = self.maxpool(x)
        x1 = self.layer1(out)
        x2 = self.layer2(self.restrict1(x1))
        x3 = self.layer3(self.restrict2(x2))
        x4 = self.layer4(x3)
        # out = self.relu(self.mp(self.bn1(out)))
        return x1, x2, x3, x4

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = nn.GeometricTensor(x, self.in_lifting_type)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(self.restrict1(x))
        x = self.layer3(self.restrict2(x))
        x = self.layer4(x)
        if not self.conv2triv:
            x = self.mp(x)
        x = self.avgpool(x.tensor)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
Hi @blazejdolicki
sorry for the late reply.
No, the equivariance to 90 deg should be almost perfect, with an absolute error probably lower than 1e-5.
At first glance, your code seems OK. Usually this problem arises when you use strided convolutions with odd-size filters but your input images have an even size.
Check page 3 of our paper for more details on this issue and how to solve it.
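The effect can be reproduced without any network. The sketch below is a toy illustration in plain Python, not the model above: it treats a stride-2 layer as plain subsampling (keeping every second row and column) and checks whether that commutes with an exact 90-degree rotation. The function names are hypothetical.

```python
def rot90(grid):
    """Rotate a square grid (list of rows) by 90 degrees counterclockwise."""
    n = len(grid)
    return [[grid[j][n - 1 - i] for j in range(n)] for i in range(n)]


def subsample(grid, stride=2):
    """Keep every `stride`-th row and column, starting from index 0."""
    return [row[::stride] for row in grid[::stride]]


def commutes_with_rot90(n):
    """True if subsampling an n x n grid gives the same result whether it is
    applied before or after the rotation."""
    grid = [[i * n + j for j in range(n)] for i in range(n)]
    return subsample(rot90(grid)) == rot90(subsample(grid))


for n in (4, 5, 32, 33):
    print(n, commutes_with_rot90(n))
```

With an odd size, the sampled positions `0, 2, ..., n-1` are mirror-symmetric, so the rotation maps them onto themselves. With an even size, they are mapped onto the odd positions and a different set of pixels is kept. In this toy setting an odd input size, for example obtained by padding or cropping one pixel, avoids the mismatch.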
Let me know if this was your problem
Best, Gabriele
I'm trying to create an equivariant ResNet model based on `e2_wide_resnet.py` and I would really appreciate it if you could clarify some of my doubts.
Main questions
I see that in your repo `e2_wide_resnet.py` is the equivariant version of `wide_resnet.py` and `WideBasic` is equivalent to `BasicBlock`. In standard, non-equivariant CNNs, we pass the arguments `in_planes` and `out_planes` to `BasicBlock`, which correspond to the number of input and output channels of the block. Afaik the number of channels in standard CNNs is equivalent to the number of representations in `e2cnn.nn.FieldType`. However, in `WideBasic`, instead of passing only the input and output FieldTypes, `in_fiber` and `out_fiber`, you also pass the inner FieldType `inner_fiber`. So my question is: why is passing those two FieldTypes (like in a non-equivariant network) not enough, and why do we need the third one?
Apart from using `BasicBlock`, I also want to use a different block needed for most models in the ResNet family: `Bottleneck`. `Bottleneck` is a bit more complicated than `BasicBlock`; this is how its initialization looks:
which looks like this on a diagram from the ResNeXt paper (for 32 groups):
At the beginning, we calculate the `width`, which is the number of output channels of `self.conv1` and the number of input channels of `self.conv2` (128 in the diagram above). Given that we need a different number of channels (or number of representations in the FieldType) for the equivariant bottleneck, I create a new FieldType `width_fiber`:
# now we need to get the same field type but with `width` representations (number of channels)
# dirty check if all representations in `in_fiber` are the same (otherwise next line is incorrect)
# the assumption here is that all representations in this FieldType are the same (no mixed types)
first_rep_type = type(in_fiber.representations[0])
for rep in in_fiber.representations:
    assert first_rep_type == type(rep)
# FIXME hardcoded representation, should be the same representation as in_fiber
in_rep = 'regular'
# create new fiber with `width` channels
width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])
self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
...
class E2Bottleneck(nn.EquivariantModule):
    def __init__(
        self,
        in_fiber: nn.FieldType,
        inner_fiber: nn.FieldType,
        out_fiber: nn.FieldType = None,
        ...):
I think len(out_fiber) is the number of channels in E2
if rotations in [0, 2, 4]:
    conv = conv3x3
else:
    conv = conv5x5
elif isinstance(module, torch.nn.BatchNorm2d):
    module.weight.data.fill_(1)
    module.bias.data.zero_()
elif isinstance(module, torch.nn.Linear):
    module.bias.data.zero_()