fepegar / unet

"pip install unet": PyTorch Implementation of 1D, 2D and 3D U-Net architecture.

Forward pass going twice through ConvolutionalBlock layers #26

Closed GFabien closed 2 years ago

GFabien commented 4 years ago

Description

While working on a function to generate a Keras-like summary for PyTorch models, based on pytorch-summary, I came across an issue with the ConvolutionalBlock class: its layers are registered twice in the model, which leads to weird consequences.

Indeed, in a forward pass, the model goes twice through every convolutional layer in a ConvolutionalBlock. Here is some code to illustrate this behaviour:

import torch
from unet import UNet3D

def hook(module, i, o):
    # Only report Conv3d layers; print the unique id attached to each one
    class_name = str(module.__class__).split(".")[-1].split("'")[0]
    if class_name != 'Conv3d':
        return
    print(module.id)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input = torch.zeros((1, 3, 64, 64, 64)).to(device)
model = UNet3D(3, 3, num_encoding_blocks=2).to(device)
model.eval()

# Register the forward hook on every submodule, run one forward pass, then remove the hooks
hooks = []
model.apply(lambda module: hooks.append(module.register_forward_hook(hook)))
model(input)
for h in hooks:
    h.remove()

To run it, you need to add this small change to the ConvolutionalBlock class (together with an import os in that module) in order to assign a unique id to every convolutional layer:

conv_layer.id = os.urandom(24).hex()
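
Equivalently, for anyone who prefers not to edit the installed package, the same tagging can be done from outside the class. This is just a convenience sketch, not part of the original report; it has to run after the model is created and before the hooks are registered:

import os
import torch.nn as nn

# Tag every Conv3d in the model with a unique id without touching the library code.
# model.modules() yields each module once, so each layer gets exactly one id.
for module in model.modules():
    if isinstance(module, nn.Conv3d):
        module.id = os.urandom(24).hex()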

Running the hook snippet above then produces output in which every id appears twice (attached screenshot: conv_ids).

Fix

A simple fix seems to be to remove the following lines:

self.conv_layer = conv_layer
self.norm_layer = norm_layer
self.activation_layer = activation_layer
self.dropout_layer = dropout_layer

and to add the out_channels variable as an attribute of the ConvolutionalBlock class, so that the EncodingBlock class can retrieve it through self.conv2.out_channels instead of self.conv2.conv_layer.out_channels.
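
For concreteness, here is a minimal sketch of what the fixed block could look like. The constructor signature and layer choices are simplified assumptions, not the actual source; the point is that layers are registered only once, inside the Sequential, while out_channels becomes a plain integer attribute:

import torch.nn as nn

class ConvolutionalBlock(nn.Module):
    # Simplified sketch of the proposed fix, not the real class.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        norm_layer = nn.BatchNorm3d(out_channels)
        activation_layer = nn.ReLU()
        # No self.conv_layer / self.norm_layer / ... attributes: each layer is
        # now registered exactly once, as a child of the Sequential.
        self.block = nn.Sequential(conv_layer, norm_layer, activation_layer)
        # Plain int attribute, so EncodingBlock can still read self.conv2.out_channels.
        self.out_channels = out_channels

    def forward(self, x):
        return self.block(x)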

If this fix seems reasonable I can take care of the PR.

fepegar commented 2 years ago

Hi, @GFabien. Thanks for reporting this. I know I'm very late, but I discovered only today that I wasn't watching this repo! Sorry about that.

I do get the ID of each module twice (I replaced module.id with id(module)). It must be because of the lines you mention in the Fix section. However, I've debugged the code and the forward pass happens only once. Phew!
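
To make the mechanism explicit, here is a minimal, self-contained sketch (plain torch.nn, independent of this repo) of what happens when the same submodule is reachable both as a direct attribute and through a Sequential: model.apply visits it once per parent, so the forward hook is registered twice and fires twice, even though the layer itself only runs once per forward pass:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        conv = nn.Conv3d(3, 3, kernel_size=3, padding=1)
        self.conv = conv                # registered as a direct child...
        self.seq = nn.Sequential(conv)  # ...and again as a child of the Sequential

    def forward(self, x):
        return self.seq(x)              # the conv itself runs only once

block = Block()

visits = []
block.apply(lambda m: visits.append(m) if isinstance(m, nn.Conv3d) else None)
print(len(visits))  # 2 -> apply reaches the same Conv3d through both parents

hits = []
block.apply(lambda m: m.register_forward_hook(lambda mod, i, o: hits.append(id(mod)))
            if isinstance(m, nn.Conv3d) else None)
block(torch.zeros(1, 3, 8, 8, 8))
print(hits)         # the same id twice: two hooks on one module, a single forward call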

If you think it's an issue that modules are registered twice, please do let me know.