Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Using 3rd Party PyTorch Networks with MONAI loader+transformations? #750

Closed. petteriTeikari closed this issue 4 years ago.

petteriTeikari commented 4 years ago

Is your feature request related to a problem? Please describe.

I wanted to try a 3D segmentation network (namely https://github.com/ozan-oktay/Attention-Gated-Networks/blob/master/models/networks/unet_CT_multi_att_dsv_3D.py) with my existing codebase, whose other parts are similar to the Spleen PyTorch Lightning tutorial (https://github.com/Project-MONAI/MONAI/blob/master/examples/notebooks/spleen_segmentation_3d_lightning.ipynb), by simply replacing

model = monai.networks.nets.UNet()

with

model = unet_CT_multi_att_dsv_3D()
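A swap like this works as long as the replacement honours the same tensor contract that `monai.networks.nets.UNet` has in this pipeline: input `(batch, in_channels, D, H, W)`, output `(batch, n_classes, D, H, W)` logits (the attention U-Net's defaults are `in_channels=1`, `n_classes=2`). A minimal sketch of a shape check, assuming a hypothetical `TinySegNet` stand-in, since the real network needs helper blocks from the linked repository:

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for a third-party 3D segmentation network."""

    def __init__(self, in_channels: int = 1, n_classes: int = 2):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(8, n_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def check_drop_in(model: nn.Module, in_channels: int, n_classes: int,
                  spatial_size=(32, 32, 32)) -> tuple:
    """Run one dummy batch through `model` and verify the output shape."""
    model.eval()
    with torch.no_grad():  # a shape check needs no autograd graph
        x = torch.zeros(1, in_channels, *spatial_size)
        y = model(x)
    expected = (1, n_classes, *spatial_size)
    if tuple(y.shape) != expected:
        raise ValueError(f"expected output {expected}, got {tuple(y.shape)}")
    return tuple(y.shape)


print(check_drop_in(TinySegNet(), in_channels=1, n_classes=2))
```

Note that the helper leaves the model in eval mode; call `model.train()` before training.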

It appears to work (the standard MONAI UNet works fine too), but with this network the model seems to be loaded into GPU memory a second time once validation starts:

[Screenshot: GPU memory usage during training]

[Screenshot: GPU memory usage after switching to the validation set]

Describe the solution you'd like

Is there a recommended way (or an upcoming tutorial) to use existing networks with MONAI once the data pipeline already works for your dataset, without this "double load"? For example, should I release GPU memory between the training and validation splits (e.g. https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530, https://github.com/PyTorchLightning/pytorch-lightning/issues/458)?
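Some background for the question: PyTorch's caching allocator keeps freed blocks reserved for reuse, so tools such as `nvidia-smi` show the memory *reserved* by the process (which stays near its peak) rather than what live tensors currently use. A minimal sketch for logging both numbers around the train/validation switch and optionally handing unused cached blocks back to the driver; `gpu_memory_report` is a hypothetical helper name, and on a CPU-only machine it just returns zeros:

```python
import torch


def gpu_memory_report(tag: str, release_cache: bool = False) -> dict:
    """Return allocated/reserved/peak CUDA memory in MB (zeros without CUDA)."""
    if not torch.cuda.is_available():
        return {"tag": tag, "allocated_mb": 0.0, "reserved_mb": 0.0, "peak_mb": 0.0}
    if release_cache:
        # Returns cached blocks that no tensor uses to the driver;
        # memory held by live tensors (e.g. the model weights) is untouched.
        torch.cuda.empty_cache()
    mb = 1024 ** 2
    report = {
        "tag": tag,
        "allocated_mb": torch.cuda.memory_allocated() / mb,  # held by live tensors
        "reserved_mb": torch.cuda.memory_reserved() / mb,    # held by the caching allocator
        "peak_mb": torch.cuda.max_memory_allocated() / mb,   # high-water mark since last reset
    }
    torch.cuda.reset_peak_memory_stats()  # start a fresh high-water mark for the next phase
    return report
```

For example, `print(gpu_memory_report("after validation", release_cache=True))` at the end of each validation pass. Emptying the cache does not lower the peak that the next validation pass will need; it only makes the reserved memory visible to other processes in the meantime.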

Additional context

The datasets are created with

train_ds = Dataset(data=datalist_train, transform=train_transforms)
val_ds = Dataset(data=datalist_val, transform=val_transforms)

using these transforms:

from monai.transforms import (AddChanneld, Compose, LoadNiftid, Orientationd, Rand3DElasticd,
                              RandCropByPosNegLabeld, RandGaussianNoised, Rotated,
                              ScaleIntensityRanged, Spacingd, ToTensord)
import numpy as np

# train_params and augmentation_params are the user's own config dicts (not shown here).
# Label maps ('seg', 'mask') are resampled with nearest-neighbour interpolation to keep them integer-valued.
train_transforms = Compose([
    LoadNiftid(keys=['img', 'seg', 'mask']),
    Rotated(keys=['img', 'seg', 'mask'], angle=90),  # so that mouth is down
    AddChanneld(keys=['img', 'seg', 'mask']),
    Spacingd(keys=['img', 'seg', 'mask'], pixdim=(1., 1., 1.), mode=('bilinear', 'nearest', 'nearest')),
    Orientationd(keys=['img', 'seg', 'mask'], axcodes='RAS'),
    ScaleIntensityRanged(keys=['img'], a_min=train_params['scale_min'], a_max=train_params['scale_max'],
                         b_min=train_params['intensity_out_min'], b_max=train_params['intensity_out_max'], clip=True),
    RandCropByPosNegLabeld(keys=['img', 'seg', 'mask'], label_key='seg',
                           spatial_size=augmentation_params['patch_size'],
                           pos=1, neg=1,
                           num_samples=4, image_key='img', image_threshold=0),
    Rand3DElasticd(
        keys=['img', 'seg', 'mask'], mode=('bilinear', 'nearest', 'nearest'), prob=1.0,
        sigma_range=(5, 8),
        magnitude_range=(100, 200),
        spatial_size=augmentation_params['patch_size'],
        translate_range=(50, 50, 2),
        rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),  # 5 degrees in each direction
        scale_range=(0.15, 0.15, 0.15),
        padding_mode='border'),
    RandGaussianNoised(keys=['img'], prob=0.5, mean=0.0, std=0.5),
    ToTensord(keys=['img', 'seg', 'mask'])
])
val_transforms = Compose([
    LoadNiftid(keys=['img', 'seg', 'mask']),
    Rotated(keys=['img', 'seg', 'mask'], angle=90),  # so that mouth is down
    AddChanneld(keys=['img', 'seg', 'mask']),
    Spacingd(keys=['img', 'seg', 'mask'], pixdim=(1., 1., 1.), mode=('bilinear', 'nearest', 'nearest')),
    Orientationd(keys=['img', 'seg', 'mask'], axcodes='RAS'),
    ScaleIntensityRanged(keys=['img'], a_min=train_params['scale_min'], a_max=train_params['scale_max'],
                         b_min=train_params['intensity_out_min'], b_max=train_params['intensity_out_max'], clip=True),
    ToTensord(keys=['img', 'seg', 'mask'])
])
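Both transform chains resample `'seg'` and `'mask'` together with `'img'`. With linear-type interpolation a binary label map picks up fractional values at object boundaries, which is why label keys are usually resampled with nearest-neighbour interpolation. A 1D NumPy illustration of the effect (a toy example, not MONAI code):

```python
import numpy as np

# A toy 1D "label map": background (0) then foreground (1), on a 1 mm grid.
labels = np.array([0, 0, 1, 1], dtype=float)
old_grid = np.arange(len(labels), dtype=float)  # positions 0, 1, 2, 3
new_grid = np.arange(0.0, 3.01, 0.5)            # resample to 0.5 mm spacing

# Linear interpolation blends neighbouring labels, creating a 0.5 at the boundary.
linear = np.interp(new_grid, old_grid, labels)

# Nearest-neighbour interpolation only ever copies existing label values.
nearest_idx = np.clip(np.rint(new_grid).astype(int), 0, len(labels) - 1)
nearest = labels[nearest_idx]

print(linear)
print(nearest)
```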

Full network definition:

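import torch
import torch.nn as nn
import torch.nn.functional as F

# UnetConv3, UnetGridGatingSignal3, UnetUp3_CT, UnetDsv3, GridAttentionBlock3D and init_weights
# are helper modules from the Attention-Gated-Networks repository linked above.


class unet_CT_multi_att_dsv_3D(nn.Module):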

    def __init__(self, feature_scale=4, n_classes=2, is_deconv=True, in_channels=1,
                 nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True,
                 filters = [64, 128, 256, 512, 1024]):
        super(unet_CT_multi_att_dsv_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
        self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
        self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)

        # upsampling
        self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
        self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
        self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
        self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

        # deep supervision
        self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
        self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
        self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
        self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)

        # final conv (without any concat)
        self.final = nn.Conv3d(n_classes*4, n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')

    def forward(self, x):

        # Feature Extraction
        conv1 = self.conv1(x)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)

        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        # Gating Signal Generation
        center = self.center(maxpool4)
        gating = self.gating(center)

        # Attention Mechanism
        # Upscaling Part (Decoder)
        g_conv4, att4 = self.attentionblock4(conv4, gating)
        up4 = self.up_concat4(g_conv4, center)
        g_conv3, att3 = self.attentionblock3(conv3, up4)
        up3 = self.up_concat3(g_conv3, up4)
        g_conv2, att2 = self.attentionblock2(conv2, up3)
        up2 = self.up_concat2(g_conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        # Deep Supervision
        dsv4 = self.dsv4(up4)
        dsv3 = self.dsv3(up3)
        dsv2 = self.dsv2(up2)
        dsv1 = self.dsv1(up1)
        final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1))

        return final

    @staticmethod
    def apply_argmax_softmax(pred):
        # softmax over the class channel (probabilities, not log-probabilities)
        probs = F.softmax(pred, dim=1)
        return probs

class MultiAttentionBlock(nn.Module):

    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
        self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
                                                 inter_channels=inter_size, mode=nonlocal_mode,
                                                 sub_sample_factor= sub_sample_factor)
        self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
                                                 inter_channels=inter_size, mode=nonlocal_mode,
                                                 sub_sample_factor=sub_sample_factor)
        self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0),
                                           nn.BatchNorm3d(in_size),
                                           nn.ReLU(inplace=True)
                                           )

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, input, gating_signal):
        gate_1, attention_1 = self.gate_block_1(input, gating_signal)
        gate_2, attention_2 = self.gate_block_2(input, gating_signal)

        return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1)
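The encoder above applies four 2x2x2 max-poolings, and the decoder outputs (`dsv4`, `dsv3`, `dsv2` upsampled by 8, 4, 2) are concatenated at full resolution, so in practice each spatial dimension of the input patch should be divisible by 2^4 = 16 for the shapes to line up; the channel widths are the `filters` list divided by `feature_scale`. A pure-Python sketch of both checks (hypothetical helper names; it does not import the network itself):

```python
from typing import List, Sequence


def encoder_sizes(patch_size: Sequence[int], n_pools: int = 4) -> List[tuple]:
    """Spatial size at each encoder level after repeated 2x2x2 max-pooling."""
    sizes = [tuple(patch_size)]
    for _ in range(n_pools):
        sizes.append(tuple(s // 2 for s in sizes[-1]))  # MaxPool3d(2) floors odd sizes
    return sizes


def patch_is_compatible(patch_size: Sequence[int], n_pools: int = 4) -> bool:
    """True if every dimension survives n_pools halvings without a remainder."""
    return all(s % (2 ** n_pools) == 0 for s in patch_size)


def channel_widths(filters=(64, 128, 256, 512, 1024), feature_scale: int = 4) -> List[int]:
    """Channels per level, mirroring the list comprehension in __init__."""
    return [int(x / feature_scale) for x in filters]


print(encoder_sizes((96, 96, 96)))
print(patch_is_compatible((96, 96, 96)), patch_is_compatible((100, 96, 96)))
print(channel_widths())
```

The same check can be run on `augmentation_params['patch_size']` before the pipeline is built, and on any region size used at validation time.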
Nic-Ma commented 4 years ago

Hi @petteriTeikari ,

Thanks for experimenting with MONAI. I don't quite understand the "double load" issue you describe: how did you detect that the model was loaded twice? Is it specific to PyTorch Lightning?

Thanks.

petteriTeikari commented 4 years ago

@Nic-Ma By "double load" I mean this: during the first training epoch the GPU memory usage is ~2.7 GB, but when the first validation pass starts it rises to ~6.7 GB, and while the validation data was initially being loaded it briefly went above 7.2 GB.
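One possible reading of these numbers (an assumption, not something established in this thread): activation memory grows roughly with the number of voxels pushed through the network per forward pass, and validation often processes larger regions than the training patches, for example with sliding-window inference over whole volumes. A back-of-the-envelope sketch with hypothetical settings:

```python
from math import prod


def voxels_per_pass(n_patches: int, patch_size) -> int:
    """Total voxels in one forward pass: number of patches x voxels per patch."""
    return n_patches * prod(patch_size)


# Hypothetical settings, not taken from this issue:
# training: batch_size=2 volumes x num_samples=4 random crops of 96^3
train_voxels = voxels_per_pass(2 * 4, (96, 96, 96))
# validation: a sliding-window batch of 4 windows of 160^3
val_voxels = voxels_per_pass(4, (160, 160, 160))

ratio = val_voxels / train_voxels
print(train_voxels, val_voxels, round(ratio, 2))
```

Training additionally keeps activations for the backward pass while validation under `torch.no_grad()` does not, so this ratio is only a rough sanity check, not a prediction of the measured memory.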

I have not seen this behavior with the standard MONAI UNet, which (as far as I can tell) allocates the GPU memory it needs for the model at the start.

Also, an overnight training run crashed with the ancdata error somewhere between epochs 30 and 40; until then the loss seemed to be decreasing, so training was more or less working.
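The "received 0 items of ancdata" error is commonly associated with DataLoader worker processes running out of file descriptors while sharing tensors with the main process. Two widely used workarounds, offered here as assumptions since the thread does not confirm the cause: switch PyTorch's tensor-sharing strategy to `file_system`, or raise the process's open-file soft limit (Unix only). A minimal sketch, with `apply_ancdata_workarounds` as a hypothetical helper name:

```python
import resource  # Unix-only standard library module

import torch.multiprocessing as mp


def apply_ancdata_workarounds(min_open_files: int = 65536) -> dict:
    """Apply two common workarounds for 'received 0 items of ancdata'."""
    # 1) Share tensors between DataLoader workers through files
    #    instead of passing file descriptors over Unix sockets.
    if "file_system" in mp.get_all_sharing_strategies():
        mp.set_sharing_strategy("file_system")

    # 2) Raise the soft open-file limit (like `ulimit -n`), never above the hard limit.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    target = min_open_files if hard == resource.RLIM_INFINITY else min(hard, min_open_files)
    if soft != resource.RLIM_INFINITY and soft < target:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
        except (ValueError, OSError):
            pass  # some systems cap this further; keep the existing limit

    return {
        "sharing_strategy": mp.get_sharing_strategy(),
        "open_files_soft_limit": resource.getrlimit(resource.RLIMIT_NOFILE)[0],
    }


print(apply_ancdata_workarounds())
```

It would be called once at the top of the training script, before any DataLoader is created.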

Nic-Ma commented 4 years ago

Hi @petteriTeikari ,

I checked your network implementation but, unfortunately, didn't find any explicit memory-related difference between it and the MONAI UNet. Could you please paste your whole program here? Our PyTorch Lightning experts @marksgraham and @ericspod may also be able to take a look at your issue.
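For context on memory-related differences between networks: the weights are usually a small part of the GPU footprint compared with activations, and they can be sized directly from the parameter count (4 bytes per float32 parameter). A minimal sketch with a hypothetical two-layer model; the same function could be pointed at `unet_CT_multi_att_dsv_3D()` or a configured `monai.networks.nets.UNet`:

```python
import torch.nn as nn


def weight_memory(model: nn.Module) -> tuple:
    """Return (parameter count, bytes needed to store the parameters)."""
    n_params = sum(p.numel() for p in model.parameters())
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return n_params, n_bytes


# Hypothetical toy model, only for illustration.
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3),  # 8*1*27 weights + 8 biases = 224
    nn.ReLU(),
    nn.Conv3d(8, 2, kernel_size=1),  # 2*8*1 weights + 2 biases = 18
)
print(weight_memory(model))
```

During training, gradients add the same amount again, and an optimizer such as Adam keeps two extra tensors per parameter.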

Thanks.

petteriTeikari commented 4 years ago

@Nic-Ma

I had a closer look at what is going on: the increase in GPU memory always occurs at the first validation pass (in the first epoch). A similar nondeterministic error also occurred during training, though with a different error message, "AttributeError: 'Net' object has no attribute 'self'": https://github.com/petteriTeikari/MONAI_lightning_segmentation/issues/1

I will test this without the PyTorch Lightning part to see whether that makes any difference.
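A minimal sketch of a Lightning-free train/validate loop for such a test. Random tensors, a hypothetical tiny model and `nn.CrossEntropyLoss` stand in for the real data, network and loss. The points of interest are `model.eval()` plus `torch.no_grad()` during validation, so that no autograd graph is kept, and switching back to `model.train()` afterwards:

```python
import torch
import torch.nn as nn


def run_epoch(model, optimizer, loss_fn, train_batches, val_batches, device="cpu"):
    """One training pass followed by one validation pass; returns mean losses."""
    model.train()
    train_losses = []
    for img, seg in train_batches:
        img, seg = img.to(device), seg.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(img), seg)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())  # .item() avoids keeping the graph alive

    model.eval()
    val_losses = []
    with torch.no_grad():  # no activations are stored for backprop
        for img, seg in val_batches:
            img, seg = img.to(device), seg.to(device)
            val_losses.append(loss_fn(model(img), seg).item())
    model.train()
    return sum(train_losses) / len(train_losses), sum(val_losses) / len(val_losses)


def random_batch():
    """Hypothetical data: images (B, 1, D, H, W) and integer labels (B, D, H, W)."""
    return torch.randn(2, 1, 16, 16, 16), torch.randint(0, 2, (2, 16, 16, 16))


torch.manual_seed(0)
model = nn.Sequential(nn.Conv3d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv3d(4, 2, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()  # logits (B, C, D, H, W) vs labels (B, D, H, W)

train_loss, val_loss = run_epoch(model, optimizer, loss_fn,
                                 [random_batch() for _ in range(3)],
                                 [random_batch() for _ in range(2)])
print(train_loss, val_loss)
```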