Closed: dgcnz closed this 1 month ago
In the worst case, we can make use of a helpful property of the transposed convolution, namely that it is equivalent to some standard convolution (which we can easily work out). It is, however, more efficient to use ConvTranspose3d from PyTorch if possible. See https://arxiv.org/pdf/1603.07285 for a detailed explanation of convolution arithmetic. As such, it is likely that once we have solved how to make the Conv3d layers into (relaxed) group convolutions, the transposed convolution will not be a problem.
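As a sanity check, the equivalence can be demonstrated for the stride-1, zero-padding case: a transposed convolution equals a standard convolution with the kernel spatially flipped, the channel axes swapped, and the input zero-padded by k - 1 (a minimal sketch; the shapes are illustrative, not the actual model's):

```python
import torch
import torch.nn.functional as F

# Minimal demonstration (stride 1, no padding): conv_transpose3d(x, w) equals
# conv3d on an input zero-padded by k - 1, with the kernel spatially flipped
# and its in/out channel axes swapped.
x = torch.randn(1, 2, 4, 4, 4)   # (N, C_in, D, H, W)
w = torch.randn(2, 3, 3, 3, 3)   # transposed-conv weight: (C_in, C_out, k, k, k)
k = w.shape[-1]

y_transposed = F.conv_transpose3d(x, w)

w_conv = torch.flip(w, (-3, -2, -1)).transpose(0, 1)  # (C_out, C_in, k, k, k)
y_standard = F.conv3d(F.pad(x, (k - 1,) * 6), w_conv)

print(torch.allclose(y_transposed, y_standard, atol=1e-5))  # True
```

For stride > 1 the same trick works after inserting stride - 1 zeros between input elements, which is the "fractionally strided" view from the arXiv note above.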
In the Wang 2024 paper under '2.3 Regular Group Convolutions' under 'Lifting Convolution' they state: The first layer of a G-convolutional network typically lifts the input to a function on G.
So perhaps assuming the first layer is replaced by a lifting convolution makes the most sense.
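For intuition, a lifting convolution can be sketched as applying every rotated copy of the kernel to the plain 3D input, so the output gains a group axis (hypothetical helper, not Wang's code; the rotated kernel copies are assumed to be precomputed):

```python
import torch
import torch.nn.functional as F

# Sketch of a lifting convolution: the input lives on Z^3, and we produce one
# feature map per group element by convolving with each rotated kernel copy.
# `rotated_kernels` is assumed precomputed, e.g. by rotating a base kernel
# with each of the 48 octahedral group elements.
def lifting_conv3d(x, rotated_kernels):
    # x: (N, C_in, D, H, W); rotated_kernels: list of (C_out, C_in, k, k, k)
    ws = torch.cat(rotated_kernels, dim=0)          # (n_g * C_out, C_in, k, k, k)
    y = F.conv3d(x, ws, padding=ws.shape[-1] // 2)
    n, _, d, h, w_ = y.shape
    return y.view(n, len(rotated_kernels), -1, d, h, w_)  # (N, n_g, C_out, D, H, W)
```

Subsequent layers then consume functions on the group, i.e. tensors with that extra n_g axis.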
In Wang 2024, they say they use a separable group convolution since the octahedral group has many elements and the non-separable one scales exponentially. They refer to this paper as a source: https://proceedings.mlr.press/v162/knigge22a/knigge22a.pdf (I haven't read that).
From #51 I observed some things that might be useful to resolve Q2.
The smokeplume task (at least in the case of the convnet) takes 1 input and produces 6 (num_outputs) outputs.
It does this auto-regressively: it takes the input, produces output 1, feeds that output back in to produce output 2, and so on. I don't know why they do the torch.cat thing; it seems useless, since it basically just replaces xx = im. Maybe it's to cut the gradient flow, but I spent a bit too much time on this yesterday and right now it is not a priority.
What I haven't found (although I haven't looked hard enough) is why it is 6 outputs. I think this is key to understanding how we will feed 3 inputs to our super-resolution architecture. @MeneerTS @Nesta-gitU
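The rollout described above boils down to a loop like this (a hypothetical sketch, not the actual smokeplume code; `model`, `rollout`, and the shapes are placeholders):

```python
import torch

# Sketch of the auto-regressive rollout: each prediction becomes the next input.
def rollout(model, im, num_outputs=6):
    outputs, xx = [], im
    for _ in range(num_outputs):
        out = model(xx)   # predict the next frame
        outputs.append(out)
        xx = out          # feed the prediction back in
    return torch.stack(outputs, dim=1)  # (N, num_outputs, ...)

# dummy "model" just to show the shapes
preds = rollout(lambda x: x + 1.0, torch.zeros(2, 3, 8, 8))
print(preds.shape)  # torch.Size([2, 6, 3, 8, 8])
```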
Relevant info:
I have attempted to follow the structure of the code from Wang 2022. It is less convenient for steerable, but seems intuitive for regular. This code has the following structure:
1. Rotate the weights.
2. Create a stack of weights, one for each rotation.
3. Apply the stack using conv2d.
I have already coded up the part that can rotate the weights, it generates all the 48 rotation matrices of the octahedral group. (see notebooks/Regular_Octahedral_CNN.ipynb). To apply this structure to the 3D flow case we need to consider some things:
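Since every octahedral group element is a signed permutation matrix, rotating a cubic kernel reduces to an index lookup. A hedged sketch of that step (a hypothetical helper, not the notebook's code):

```python
import torch

# Rotate a cubic (k, k, k) kernel by a signed permutation matrix M from the
# octahedral group: the rotated kernel is w'(x) = w(M^T x), evaluated on the
# centered integer grid, where signed permutations map grid points to grid points.
def rotate_kernel(w: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    k = w.shape[-1]
    assert k % 2 == 1, "assumes an odd kernel size so the center is a grid point"
    c = (k - 1) // 2
    coords = torch.stack(
        torch.meshgrid(*(torch.arange(k),) * 3, indexing="ij"), dim=-1
    ).float() - c                    # centered integer coordinates, (k, k, k, 3)
    src = (coords @ M + c).long()    # source index M^T x for each target x
    return w[..., src[..., 0], src[..., 1], src[..., 2]]
```

Stacking `rotate_kernel(w, M)` over all 48 matrices then gives the weight stack for step 2 above.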
I'm now 90% sure we can/should use this library to implement the (relaxed) octahedral group conv. It has a pretty nice interface that already supports SE(3), SO(3), and O(3). We just have to specify the group transformations and a couple of methods, and reuse their blocks.
I've added a unit test in https://github.com/dgcnz/dl2/commit/9f9ec06eddf5d0ad7697bab95c16d4d1893bb62e that proves that their implementation matches Rui's implementation under appropriate parametrizations. This means we can use this library and save ourselves the work of re-implementing everything from scratch.
I'll proceed with trying to implement the SE(2) relaxed group lifting with their interfaces.
In this commit I added support for SE(2) lifting group convolutions to the gconv library: https://github.com/dgcnz/gconv/commit/1162c9fd6db297070ad4da35b6b985be24d78972
In this commit I added tests to ensure equivalence between Wang's implementation and gconv's: 34dd8d9ce2d718b69bbb6c1473eaf30b51509556
We need to make something like this but for the octahedral group:
```python
def uniform_grid(
    size: tuple[int, int], matrix_only: bool = False, device: Optional[str] = None
) -> Tensor:
    """
    Creates a grid of uniform rotations and reflections. Each O3 element
    is represented as a 10-dimensional vector, where the first element
    denotes the reflection coefficient, i.e., 1 or -1, and the remaining 9
    elements denote the flattened rotation matrix.
    Alternatively, if matrix_only is true, the grid will only consist of
    the rotation matrices, i.e., elements have a shape of 3 by 3.

    Arguments:
        - size: Tuple denoting `(n_rotations, n_reflections)`.
        - matrix_only: If true, will only return the rotation matrix part
          of the O3 grid.

    Returns:
        Tensor of shape `(n_rotations + n_reflections, 10)` or
        `(n_rotations + n_reflections, 3, 3)` if matrix_only is true.
    """
    n_rotations, n_reflections = size
    R1 = so3.uniform_grid(n_rotations, "matrix", device=device)
    R2 = so3.uniform_grid(n_reflections, "matrix", device=device)
    R = torch.cat((R1, R2), dim=0)
    if matrix_only:
        return R
    coeff1 = torch.ones(n_rotations, 1, device=device)
    coeff2 = -1 * torch.ones(n_reflections, 1, device=device)
    coeffs = torch.cat((coeff1, coeff2), dim=0)
    grid = torch.cat((coeffs, R.flatten(-2, -1)), dim=-1)
    return grid
```
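A hedged sketch of what the octahedral analogue could look like: enumerate the 48 signed permutation matrices (the 24 with determinant +1 are the proper rotations) and pack them in the same (coefficient, flattened matrix) format. Names and layout are my assumptions mirroring the O3 grid, not gconv's API:

```python
import itertools
import torch

# Hypothetical octahedral analogue of uniform_grid. The full octahedral group
# is exactly the set of 3x3 signed permutation matrices: 3! permutations times
# 2^3 sign choices = 48 elements; det = +1 gives the 24 proper rotations.
def octahedral_grid(matrix_only: bool = False) -> torch.Tensor:
    mats = []
    for perm in itertools.permutations(range(3)):
        for signs in itertools.product((1.0, -1.0), repeat=3):
            M = torch.zeros(3, 3)
            for row, (col, s) in enumerate(zip(perm, signs)):
                M[row, col] = s
            mats.append(M)
    R = torch.stack(mats)                     # (48, 3, 3)
    det = torch.det(R)                        # +1 rotations, -1 roto-reflections
    R = torch.cat((R[det > 0], R[det < 0]))   # rotations first, like the O3 grid
    if matrix_only:
        return R
    coeffs = torch.cat((torch.ones(24, 1), -torch.ones(24, 1)))
    return torch.cat((coeffs, R.flatten(-2, -1)), dim=-1)  # (48, 10)
```

Unlike the O3 case there is no sampling involved: the group is finite, so the "uniform grid" is just the full element list.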
Just to update: I think we can implement regular octahedral separable convolutions like this (still needs testing).
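For reference, here is my reading of the separable idea as a sketch (hedged: this is not Wang's or gconv's actual code; the `rotate` callable and the `inv_mul[g, h] = index of g⁻¹h` table are assumed to be provided):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of a separable regular group convolution for a finite group with n_g
# elements. The full kernel k(g, h, x) is factored into a shared spatial part
# and a group-mixing part, so parameters scale with n_g + k^3 rather than
# n_g * k^3, which is what makes the 48-element octahedral group affordable.
class SeparableGroupConv3d(nn.Module):
    def __init__(self, in_ch, out_ch, n_g, k):
        super().__init__()
        self.in_ch, self.out_ch, self.n_g, self.k = in_ch, out_ch, n_g, k
        self.w_spatial = nn.Parameter(torch.randn(in_ch, k, k, k) * 0.1)
        self.w_group = nn.Parameter(torch.randn(out_ch, in_ch, n_g) * 0.1)

    def forward(self, x, rotate, inv_mul):
        # x: (N, in_ch * n_g, D, H, W); group index varies fastest in channels
        # 1) depthwise spatial conv with a rotated kernel per group element
        ws = torch.stack(
            [rotate(self.w_spatial[c], h)
             for c in range(self.in_ch) for h in range(self.n_g)]
        ).unsqueeze(1)                                   # (in_ch * n_g, 1, k, k, k)
        h_feat = F.conv3d(x, ws, padding=self.k // 2, groups=self.in_ch * self.n_g)
        # 2) pointwise mixing over channels and group, indexed by g^{-1} h
        mix = self.w_group[:, :, inv_mul]                # (out_ch, in_ch, n_g, n_g)
        mix = mix.permute(0, 2, 1, 3).reshape(self.out_ch * self.n_g,
                                              self.in_ch * self.n_g)
        return F.conv3d(h_feat, mix[:, :, None, None, None])
```

With the trivial group (n_g = 1, rotate the identity), this collapses to a plain depthwise-then-pointwise conv, which is a cheap way to smoke-test the shapes before wiring in the real octahedral tables.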
done
Description
Currently it isn't 100% clear how the model architecture is set up for the equivariant or relaxed equivariant models. For example, if we look at the model architecture diagram: do we have to replace the Conv3d with our group convolution? What about the group lifting? How do we apply a transposed convolution in group space, or should we first pool and then apply the transposed convolution?
Resources:
Tasks
Expected outcomes