choosehappy / QuickAnnotator

An open-source digital pathology based rapid image annotation tool
BSD 3-Clause Clear License

Exporting trained UNet to ONNX #33

Closed naguileraleal closed 1 year ago

naguileraleal commented 1 year ago

Hi! I've trained a model using QA, and now I'm looking to import it into FastPathology. For that reason, I've downloaded the model from QA's web interface and I'm trying to convert it to ONNX format.

To do that, I'm using the following script:

import torch  # the UNet class used below is defined further down in this same file (pth_2_onnx.py)

# Function to convert the model to ONNX
def Convert_ONNX():
    global model

    # set the model to inference mode
    model.eval()

    # Create a dummy input tensor
    dummy_input = torch.randn((1, 3, 256, 256), requires_grad=True)

    # Export the model
    torch.onnx.export(model,             # model being run
         dummy_input,                    # model input (or a tuple for multiple inputs)
         "./fibrosis.onnx",              # where to save the model
         verbose=True,
         opset_version=11)               # ONNX opset version to export with
    print(" ")
    print('Model has been converted to ONNX')

if __name__ == "__main__":

    # Load the checkpoint on the CPU, regardless of which device it was trained on
    # https://discuss.pytorch.org/t/saving-and-loading-torch-models-on-2-machines-with-different-number-of-gpu-devices/6666
    path = "/home/sofia/Downloads/best_model.pth"
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)

    # Rebuild the UNet from the hyperparameters stored in the checkpoint and load its weights
    model = UNet(n_classes=checkpoint["n_classes"], in_channels=checkpoint["in_channels"],
                 padding=checkpoint["padding"], depth=checkpoint["depth"], wf=checkpoint["wf"],
                 up_mode=checkpoint["up_mode"], batch_norm=checkpoint["batch_norm"])
    model.load_state_dict(checkpoint["model_dict"])

    # Conversion to ONNX
    Convert_ONNX()

When executing this script I get the following error

Traceback (most recent call last):
  File "pth_2_onnx.py", line 169, in <module>
    Convert_ONNX() 
  File "pth_2_onnx.py", line 134, in Convert_ONNX
    torch.onnx.export(model,         # model being run 
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/onnx/_init_.py", line 271, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 694, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 457, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args,
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 125, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 116, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "pth_2_onnx.py", line 67, in forward
    x = up(x, blocks[-i-1])
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "pth_2_onnx.py", line 116, in forward
    up = self.up(x)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/modules/upsampling.py", line 141, in forward
    return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
  File "/home/sofia/Apps/QuickAnnotator/venv/lib/python3.8/site-packages/torch/nn/functional.py", line 3462, in interpolate
    dim = input.dim() - 2  # Number of spatial dimensions.
AttributeError: 'NoneType' object has no attribute 'dim'

I'm unsure about the dimensions in this line: dummy_input = torch.randn((1, 3, 256, 256), requires_grad=True). Are they right?

Any help is much appreciated!

choosehappy commented 1 year ago

i've not used onnx before; does it need a single patch or a batch?

i ask because if you look at the code here:

https://github.com/choosehappy/QuickAnnotator/blob/57a13580a40ea10fe47637b7718bbe7c1051a424/train_model.py#L187

you'll see that torchsummary is able to provide a very nice summary, but doesn't have the batch dimension listed
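Roughly, that summary call looks like this (a minimal sketch; in train_model.py the patch size comes from config.ini, hard-coded to 256 here for illustration):

from torchsummary import summary

# input_size is (channels, height, width) with no batch dimension;
# torchsummary handles the batch dimension internally
summary(model, input_size=(3, 256, 256), device="cpu")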

naguileraleal commented 1 year ago

i've not used onnx before; does it need a single patch or a batch?

Me neither. I have the same doubt. From this tutorial I gather it needs a single patch. I suppose the dummy_input variable should have the same dimensions as the input that goes into the model when running in inference mode. I haven't yet found out what those dimensions are, but I suppose I could find out with a little logging inside make_output_unet_cmd.py.
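Something along these lines, right before the tensor is passed to the model (hypothetical; the actual variable names in make_output_unet_cmd.py will differ):

# log the shape of the batch fed to the network during inference,
# e.g. torch.Size([batch, 3, patch_size, patch_size])
print(f"model input shape: {tuple(batch.shape)}")
output = model(batch)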

choosehappy commented 1 year ago

These are the relevant lines of code:

https://github.com/choosehappy/QuickAnnotator/blob/57a13580a40ea10fe47637b7718bbe7c1051a424/config/config.ini#L56

https://github.com/choosehappy/QuickAnnotator/blob/57a13580a40ea10fe47637b7718bbe7c1051a424/make_output_unet_cmd.py#L33

https://github.com/choosehappy/QuickAnnotator/blob/57a13580a40ea10fe47637b7718bbe7c1051a424/make_output_unet_cmd.py#L136

the input shape should be: Batch x 3 {RGB} x patch_size x patch_size

you can use whatever batch size will fit into GPU memory

it looks like the problem is somewhere else though?

    dim = input.dim() - 2  # Number of spatial dimensions.
AttributeError: 'NoneType' object has no attribute 'dim'

this would seem to suggest that input is None, which is not related to a particular size, but rather to something like its type?
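As a quick sanity check, you could run a plain forward pass with a batched dummy input before involving the exporter (a sketch, assuming a 256-pixel patch size):

import torch

# Batch x 3 (RGB) x patch_size x patch_size; use whatever batch size fits in memory
dummy_input = torch.randn(1, 3, 256, 256)

model.eval()
with torch.no_grad():
    out = model(dummy_input)  # if this already fails, the problem is in the model itself, not in the ONNX export
print(out.shape)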

naguileraleal commented 1 year ago

I spent a little more time debugging the script and the problem arises when the torch.onnx.export() function executes the forward method of the model.

I'll now attach part of the UNet definition so I can refer to it.

class UNet(nn.Module):
    def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
                 batch_norm=False, up_mode='upconv' ,concat=True):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        Using the default arguments will yield the exact version used
        in the original paper

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        self.concat = concat
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
                                                padding, batch_norm))
            prev_channels = 2**(wf+i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
                                            padding, batch_norm , concat))
            prev_channels = 2**(wf+i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path)-1:
                blocks.append(x)
                x = F.avg_pool2d(x, 2)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i-1])

        return self.last(x)

In the second for loop of the forward method, when i=0, the call x = up(x, blocks[-i-1]) returns None as the value of x, so on the next iteration the call is equivalent to x = up(None, blocks[-i-1]). This is what causes the exception. The thing is, I have no idea how to solve this.

I'll attach my best_model.pth in case you want to give it a try yourself.

best_model.zip

choosehappy commented 1 year ago

very strange, but you were able to actually train the model and make the associated output? the challenge is only in the onnx component? may be better to report this issue there?

naguileraleal commented 1 year ago

Hi! There was a bug in my script, specifically in the UNet class definition. I copy/pasted the structure of the class into the script and must have made a mistake during that step. Importing the UNet class directly from QuickAnnotator's unet module fixed my original issue and I was able to convert the model to ONNX.
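For reference, the working conversion ends up being roughly this (a sketch; it assumes the script is run from the QuickAnnotator root so its unet module is importable, and a 256-pixel patch size). In retrospect, the symptom of a forward call silently returning None typically points to a dropped return statement in one of the copied blocks.

import torch
from unet import UNet  # QuickAnnotator's own UNet implementation, instead of a copied class

checkpoint = torch.load("best_model.pth", map_location="cpu")
model = UNet(n_classes=checkpoint["n_classes"], in_channels=checkpoint["in_channels"],
             padding=checkpoint["padding"], depth=checkpoint["depth"], wf=checkpoint["wf"],
             up_mode=checkpoint["up_mode"], batch_norm=checkpoint["batch_norm"])
model.load_state_dict(checkpoint["model_dict"])
model.eval()

torch.onnx.export(model, torch.randn(1, 3, 256, 256), "fibrosis.onnx", opset_version=11)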

I'm closing this issue now. Sorry for the inconvenience, and thanks again for your help.

choosehappy commented 1 year ago

great, glad you figured it out!