uncbiag / ICON

A library for performing image registration using deep learning, regularized by inverse consistency
Other

Input shape requirements #78

Open: koegl opened this issue 2 months ago

koegl commented 2 months ago

What are the input shape requirements for the images passed to the network?

I'm referring to this function, where the shape is fixed. I want to use my own data for training, but I cannot figure out what the shape has to be.

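```python
import icon_registration
import icon_registration.network_wrappers as network_wrappers
import icon_registration.networks as networks


def make_network():
    # phi uses UNet2ChunkyMiddle; psi uses the size-agnostic tallUNet2.
    phi = network_wrappers.FunctionFromVectorField(
        networks.tallUNet(unet=networks.UNet2ChunkyMiddle, dimension=3)
    )
    psi = network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3))

    hires_net = icon_registration.GradientICON(
        # DoubleNet: the second network refines the result of the first.
        network_wrappers.DoubleNet(
            # Runs the two-step (phi, psi) registration on downsampled inputs.
            network_wrappers.DownsampleNet(
                network_wrappers.TwoStepRegistration(phi, psi), dimension=3
            ),
            network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3)),
        ),
        icon_registration.LNCCOnlyInterpolated(sigma=5),  # similarity measure
        3,  # regularization weight
    )
    SCALE = 2  # 1 IS QUARTER RES, 2 IS HALF RES, 4 IS FULL RES
    input_shape = [1, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE]
    # Identity maps are built for this exact shape; inputs are expected to match it.
    hires_net.assign_identity_map(input_shape)
    return hires_net
```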

(from lncc_train_knees.py)
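The `input_shape` line is plain arithmetic on `SCALE`. The pure-Python sketch below tabulates it for each documented `SCALE`. The helper names `knee_input_shape` and `halve_spatial` are hypothetical (not part of icon_registration), and `halve_spatial` rests on an assumption: that every 2x downsampling step divides D, H and W by 2 with integer division.

```python
def knee_input_shape(scale: int) -> list[int]:
    """Shape used in make_network(): [batch, channels, D, H, W]."""
    return [1, 1, 40 * scale, 96 * scale, 96 * scale]


def halve_spatial(shape: list[int], times: int = 1) -> list[int]:
    """Assumed effect of repeated 2x downsampling on the spatial dims.

    Assumption (not taken from the library): each step divides D, H and W
    by 2 using integer division; batch and channel dims are unchanged.
    """
    batch_channels, spatial = shape[:2], shape[2:]
    for _ in range(times):
        spatial = [s // 2 for s in spatial]
    return batch_channels + spatial


for scale, label in [(1, "quarter"), (2, "half"), (4, "full")]:
    print(label, knee_input_shape(scale))
# quarter [1, 1, 40, 96, 96]
# half [1, 1, 80, 192, 192]
# full [1, 1, 160, 384, 384]

print(halve_spatial(knee_input_shape(2), times=4))
# [1, 1, 5, 12, 12]
```

Under that assumption, four halvings of the `SCALE = 2` shape give a spatial size of 5 x 12 x 12, which is the same spatial size as the `skips[depth]` tensor reported below.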

I tried forcing my images to `[1, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE]`, but I get an error in `forward()` of `UNet2ChunkyMiddle` at this line: `x = torch.cat([x, skips[depth]], 1)`

Sizes of tensors must match except in dimension 1. Expected size 6 but got size 1 for tensor number 1 in the list.

These are the shapes:

```
x.shape
torch.Size([6, 256, 4, 6, 6])
skips[depth].shape
torch.Size([1, 256, 5, 12, 12])
```

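`torch.cat(tensors, 1)` requires every tensor to have the same size on every axis except axis 1. PyTorch's message reports only one mismatch (here axis 0, the batch axis: 6 vs 1), which hides that the spatial axes disagree as well. A small pure-Python check makes every mismatching axis visible; the function name `cat_mismatches` is hypothetical and only mirrors the shape rule, it does not call PyTorch.

```python
def cat_mismatches(shapes: list[tuple[int, ...]], dim: int = 1) -> list[tuple[int, int, int]]:
    """Return (axis, expected, got) for every axis other than `dim` where a
    shape differs from the first shape, following torch.cat's shape rule."""
    reference = shapes[0]
    problems = []
    for shape in shapes[1:]:
        if len(shape) != len(reference):
            raise ValueError("tensors must have the same number of dimensions")
        for axis, (expected, got) in enumerate(zip(reference, shape)):
            if axis != dim and expected != got:
                problems.append((axis, expected, got))
    return problems


x_shape = (6, 256, 4, 6, 6)
skip_shape = (1, 256, 5, 12, 12)
print(cat_mismatches([x_shape, skip_shape], dim=1))
# [(0, 6, 1), (2, 4, 5), (3, 6, 12), (4, 6, 12)]
```

Note that `x` no longer has batch size 1 even though the input did, so the mismatch originates inside the network rather than in how the batch was assembled.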
HastingsGreer commented 1 month ago

Hi! This is a great issue. The class `UNet2ChunkyMiddle` is from the ICON paper, written before we were really aspiring to generalize to arbitrary images. Specifically, it only works if the input is one specific size; that size is not documented, and the class is missing any check that its input has that size.

The short-term fix is to switch to `UNet2`, which is used for all stages in the GradICON and uniGradICON papers. The class `UNet2` (you can get an instance with reasonable defaults by calling `icon_registration.networks.tallUNet2(dimension=3)`) is fully parametric over input size, so it won't hit this error.

Long term, the next update of icon needs to add asserts for the input shape of `UNet2ChunkyMiddle` (and possibly deprecate it entirely).
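An input-shape assert of the kind mentioned above could look roughly like the following. This is a hypothetical sketch: `assert_input_shape` and its arguments are not part of icon_registration, and because the size `UNet2ChunkyMiddle` actually needs is undocumented, the expected size is a parameter, and the value used below is a placeholder.

```python
def assert_input_shape(shape, expected_spatial, name="UNet2ChunkyMiddle"):
    """Hypothetical guard for a module that only supports one input size.

    `shape` is (batch, channels, *spatial), e.g. a tuple or a torch.Size
    (torch.Size subclasses tuple, so slicing works the same way).
    Raises ValueError with a readable message instead of letting a later
    torch.cat fail with a confusing size error.
    """
    spatial = tuple(shape[2:])
    expected = tuple(expected_spatial)
    if spatial != expected:
        raise ValueError(
            f"{name} expects spatial size {expected}, got {spatial} "
            f"(full input shape {tuple(shape)})"
        )


# Placeholder value for illustration only; not the real requirement.
EXPECTED_SPATIAL = (40, 96, 96)
assert_input_shape((1, 1, 40, 96, 96), EXPECTED_SPATIAL)  # passes silently
try:
    assert_input_shape((1, 1, 80, 192, 192), EXPECTED_SPATIAL)
except ValueError as err:
    print(err)
# UNet2ChunkyMiddle expects spatial size (40, 96, 96), got (80, 192, 192) (full input shape (1, 1, 80, 192, 192))
```

Called at the start of the module's `forward()`, a check like this would replace the `torch.cat` size error with a message that names the expected input size directly.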