dgcnz / dl2

Code for "Effect of equivariance on training dynamics"

Implement super-resolution model architecture #36

Closed dgcnz closed 1 month ago

dgcnz commented 1 month ago

Description

Currently it isn't 100% clear how the model architecture is set up for the equivariant or relaxed equivariant models. For example, looking at the model architecture diagram: do we replace the Conv3d with our group convolution? What about the group lifting? How do we apply a transposed convolution in group space, or should we first pool and then apply the transposed convolution?

Resources:


Tasks

Expected outcomes

MeneerTS commented 1 month ago

Regarding the Transposed Convolution

In the worst case we can make use of a helpful property of the transposed convolution, namely that it is equivalent to some standard convolution (which we can easily work out). It is, however, more efficient to use PyTorch's ConvTranspose3d if possible. See https://arxiv.org/pdf/1603.07285 for a detailed explanation of convolution arithmetic. So once we have solved how to turn the Conv3d layers into (relaxed) group convolutions, the transposed convolution should not be a problem.
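As a sanity check on that property, here is a minimal sketch (2D and made-up shapes for brevity; the 3D case is analogous, and all shapes/names here are illustrative, not from the repo): a `ConvTranspose2d` with stride `s` and padding `p` matches a standard convolution applied to a zero-stuffed input, using the spatially flipped, channel-swapped kernel, exactly as in the Dumoulin & Visin note linked above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, p, k = 2, 1, 3
x = torch.randn(1, 4, 5, 5)   # (N, C_in, H, W)
w = torch.randn(4, 6, k, k)   # transposed-conv weight layout: (C_in, C_out, k, k)

ref = F.conv_transpose2d(x, w, stride=s, padding=p)

# Zero-stuff the input: place samples s apart, then pad by (k - 1 - p).
N, C, H, W = x.shape
x_up = torch.zeros(N, C, (H - 1) * s + 1, (W - 1) * s + 1)
x_up[:, :, ::s, ::s] = x
# Equivalent standard-conv weight: swap in/out channels, flip spatially.
w_eq = w.transpose(0, 1).flip(-1, -2)  # (C_out, C_in, k, k)
out = F.conv2d(x_up, w_eq, padding=k - 1 - p)

assert torch.allclose(ref, out, atol=1e-5)
```

So if the (relaxed) group-conv version of Conv3d exists, the transposed layer can always be reduced to it, at some extra cost in zero-stuffing.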

MeneerTS commented 1 month ago

Regarding the Lifting Convolution

In the Wang 2024 paper under '2.3 Regular Group Convolutions' under 'Lifting Convolution' they state: The first layer of a G-convolutional network typically lifts the input to a function on G.

So perhaps assuming the first layer is replaced by a lifting convolution makes the most sense.
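To make the lifting idea concrete, here is a minimal sketch using the planar C4 rotation group instead of the octahedral group, purely for brevity (names are illustrative): the first layer correlates the input with every rotated copy of the kernel, adding a group axis, and rotating the input then rotates the feature maps while cyclically shifting that axis.

```python
import torch
import torch.nn.functional as F

def c4_lift(x, w):
    """Lifting conv for C4: x (N, C_in, H, W), w (C_out, C_in, k, k)
    -> (N, C_out, 4, H', W'), one slice per kernel rotation."""
    outs = [F.conv2d(x, torch.rot90(w, r, dims=(-2, -1))) for r in range(4)]
    return torch.stack(outs, dim=2)

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)
w = torch.randn(5, 3, 3, 3)
y = c4_lift(x, w)

# Equivariance check: a 90-degree input rotation rotates each feature
# map and shifts the group axis by one step.
y_rot = c4_lift(torch.rot90(x, 1, dims=(-2, -1)), w)
expected = torch.rot90(y.roll(1, dims=2), 1, dims=(-2, -1))
assert torch.allclose(y_rot, expected, atol=1e-5)
```

The octahedral version would be the same pattern with conv3d and the 3D rotation matrices instead of `torch.rot90`.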

MeneerTS commented 1 month ago

Octahedral group, an extra obstacle

In Wang 2024, they say they use a separable group convolution, since the octahedral group has many elements and the non-separable one scales exponentially. They refer to this paper as a source: https://proceedings.mlr.press/v162/knigge22a/knigge22a.pdf (I haven't read that).
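A back-of-the-envelope illustration of why separability matters, assuming the factorization from Knigge et al. where the kernel splits into a per-group-element part and a shared spatial part (the channel sizes are made up, and the exact parameter split depends on the implementation):

```python
# Illustrative numbers (not from the paper): 16 -> 16 channels,
# 3x3x3 kernels, 48-element full octahedral group.
G, c_in, c_out, k = 48, 16, 16, 3

# Non-separable regular group conv: one spatial kernel per
# (out channel, in channel, group element) triple.
full = c_out * c_in * G * k**3

# Separable: a group-mixing part plus one shared spatial kernel.
separable = c_out * c_in * G + c_out * c_in * k**3

print(full, separable)  # 331776 vs 19200, roughly a 17x reduction
```

The gap widens further with kernel size, which is presumably why Wang 2024 needs the separable form for the 48-element group.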

dgcnz commented 1 month ago

From #51 I observed some things that might be useful to resolve Q2.

The smokeplume task (at least in the case of the convnet) takes 1 input and produces 6 (num_outputs) outputs:

https://github.com/dgcnz/dl2/blob/2e4c7406e4032971da92e5cec6b57e4f6da28cba/configs/data/smokeplume.yaml#L4-L6

And it does so auto-regressively: it takes the input, produces output 1, feeds that output back as input to produce output 2, and so on. I don't know why they do the torch.cat thing; it seems useless, it basically just replaces xx = im. Maybe it's to cut the gradient flow, but I spent a bit too much time on this yesterday and right now it is not a priority.
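The rollout described above can be sketched like this (purely illustrative names, not the repo's code):

```python
import torch

def rollout(model, x, n_steps):
    """Autoregressive rollout: each prediction is fed back as the next input."""
    preds = []
    for _ in range(n_steps):
        x = model(x)                     # predict the next frame
        preds.append(x)
    return torch.stack(preds, dim=1)     # (N, n_steps, ...)

# Toy check with a dummy "model" that adds 1 each step.
x0 = torch.zeros(2, 3)
y = rollout(lambda t: t + 1, x0, 6)
assert y.shape == (2, 6, 3)
assert torch.equal(y[:, -1], torch.full((2, 3), 6.0))
```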

https://github.com/dgcnz/dl2/blob/2e4c7406e4032971da92e5cec6b57e4f6da28cba/src/models/wang2022_module.py#L79-L83

What I haven't found (although I haven't looked hard enough) is why it is 6 outputs. I think this is key to understanding how we will feed 3 inputs to our super-resolution architecture. @MeneerTS @Nesta-gitU

dgcnz commented 1 month ago

Relevant info:

MeneerTS commented 1 month ago

Basic setup

I have attempted to follow the structure of the code from Wang 2022. It is less convenient for steerable, but seems intuitive for regular. This code has the following structure:

  1. Rotate the weights
  2. Create a stack of weights for each rotation
  3. Apply the stack using conv2d

I have already coded up the part that rotates the weights; it generates all 48 rotation matrices of the octahedral group (see notebooks/Regular_Octahedral_CNN.ipynb). To apply this structure to the 3D flow case we need to consider some things:

  1. Do we also need to rotate the vectors? See the need for irreducible representations in the steerable case under 'Other Field types' in https://github.com/QUVA-Lab/escnn/blob/master/examples/introduction.ipynb
  2. What will the shape of our weights be? (we are using 3d kernels instead of 2d)
  3. What are actually the sizes of the hidden dimensions? (not mentioned in the paper iirc)
  4. How can we use the separable convolutions mentioned? They referenced this paper: https://proceedings.mlr.press/v162/knigge22a/knigge22a.pdf
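On point 1: for a vector field like the flow's velocity, a rotation has to both move the sample points and rotate each vector's components, v'(x) = R v(R⁻¹x); a scalar field only needs the first step. A minimal 2D sketch of this, with illustrative names (the 3D case is analogous):

```python
import torch

def rotate_vector_field(v):
    """Rotate a 2D vector field by 90 degrees CCW. v: (2, H, W), channels (vx, vy)."""
    v = torch.rot90(v, 1, dims=(-2, -1))   # move the sample points
    vx, vy = v[0], v[1]
    return torch.stack((-vy, vx))          # rotate the vector components too

# Sanity check: four 90-degree turns give back the original field.
v = torch.randn(2, 8, 8)
out = v
for _ in range(4):
    out = rotate_vector_field(out)
assert torch.allclose(out, v)
```

If only `torch.rot90` were applied (as for a scalar field), the velocity vectors would end up pointing the wrong way after rotation.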
dgcnz commented 1 month ago

I'm now 90% sure we can/should use this library to implement the (relaxed) octahedral group conv; it has a pretty nice interface that already has support for SE(3), SO(3) and O(3).


We just have to specify the group transformations and a couple of methods, and reuse their blocks.

dgcnz commented 1 month ago

I've added a unit test in https://github.com/dgcnz/dl2/commit/9f9ec06eddf5d0ad7697bab95c16d4d1893bb62e that shows that their implementation matches Rui's implementation under appropriate parametrizations. This means we can use this library and save ourselves the work of re-implementing everything from scratch.

I'll proceed with trying to implement the SE(2) relaxed group lifting with their interfaces.

dgcnz commented 1 month ago

In this commit I added support for SE(2) lifting group convolutions to the gconv library: https://github.com/dgcnz/gconv/commit/1162c9fd6db297070ad4da35b6b985be24d78972

In this commit I added tests to ensure equivalence between Wang's implementation and gconv's: 34dd8d9ce2d718b69bbb6c1473eaf30b51509556

dgcnz commented 1 month ago

We need to make something like this but for the octahedral group:

https://github.com/dgcnz/gconv/blob/1162c9fd6db297070ad4da35b6b985be24d78972/gconv/geometry/groups/o3.py#L8-L43


```python
def uniform_grid(
    size: tuple[int, int], matrix_only: bool = False, device: Optional[str] = None
) -> Tensor:
    """
    Creates a grid of uniform rotations and reflections. Each O3 element
    is represented as a 10-dimensional vector, where the first element
    denotes the reflection coefficient, i.e., 1 or -1, and the remaining 9
    elements denote the flattened rotation matrix.

    Alternatively, if matrix_only is true, the grid will only consist of
    the rotation matrices, i.e., elements have a shape of 3 by 3.

    Arguments:
        - size: Tuple denoting `(n_rotations, n_reflections)`.
        - matrix_only: If true, will only return the rotation matrix part of the O3 grid.

    Returns:
        Tensor of shape `(n_rotations + n_reflections, 10)` or
        `(n_rotations + n_reflections, 3, 3)` if matrix_only is true.
    """
    n_rotations, n_reflections = size

    R1 = so3.uniform_grid(n_rotations, "matrix", device=device)
    R2 = so3.uniform_grid(n_reflections, "matrix", device=device)
    R = torch.cat((R1, R2), dim=0)

    if matrix_only:
        return R

    coeff1 = torch.ones(n_rotations, 1, device=device)
    coeff2 = -1 * torch.ones(n_reflections, 1, device=device)
    coeffs = torch.cat((coeff1, coeff2), dim=0)

    grid = torch.cat((coeffs, R.flatten(-2, -1)), dim=-1)

    return grid
```
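A sketch of what the octahedral analog could look like (`octahedral_grid` and its layout are my guesses at an interface, not existing gconv code). Since the group is finite there is nothing to sample uniformly: the full octahedral group is exactly the 48 signed permutation matrices, and the determinant plays the role of the reflection coefficient in the O3 format above.

```python
import itertools
import torch

def octahedral_grid(matrix_only: bool = False) -> torch.Tensor:
    """Hypothetical octahedral analog of o3.uniform_grid: the full
    octahedral group = all 3x3 signed permutation matrices (48 elements);
    det +1 gives the 24 proper rotations, det -1 the rotoreflections."""
    mats = []
    for perm in itertools.permutations(range(3)):
        for signs in itertools.product((1.0, -1.0), repeat=3):
            m = torch.zeros(3, 3)
            for row, (col, s) in enumerate(zip(perm, signs)):
                m[row, col] = s
            mats.append(m)
    R = torch.stack(mats)                       # (48, 3, 3)
    if matrix_only:
        return R
    coeffs = torch.linalg.det(R).unsqueeze(-1)  # +1 or -1, like the O3 format
    return torch.cat((coeffs, R.flatten(-2, -1)), dim=-1)  # (48, 10)

grid = octahedral_grid()
assert grid.shape == (48, 10)
```

If only the 24 proper rotations are wanted, filtering on `det > 0` of the `matrix_only` output would do it.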
dgcnz commented 1 month ago

Just to update: I think we can implement regular octahedral separable convolutions like this (still needs testing):

https://github.com/dgcnz/dl2/blob/0ac55bf1ab91ddced382274083f7c26c9631215a/src/models/components/gcnn/oh_t3.py#L24-L49

dgcnz commented 1 month ago

done