PyTorch binding? #2

Open turian opened 3 years ago

turian commented 3 years ago

Would it be possible to offer a pytorch binding? So this could be called from Python code?

divideconcept commented 3 years ago

If you need equivalent models in PyTorch code, there you go (it'll be part of my new software https://torchstudio.ai ) :

2D customizable UNet:

import torch
import torch.nn as nn
import torch.nn.functional as F

def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
    """Build the convolution stack used at one level of the 2D UNet.

    Each of the ``conv_per_block`` stages is ``Conv2d -> ReLU``, optionally
    followed by ``BatchNorm2d``. Convolutions are padded with
    ``(kernel_size - 1) // 2`` so an odd ``kernel_size`` preserves the
    spatial size.

    Args:
        in_channels: int, channels of the tensor fed to the first convolution.
        out_channels: int, channels produced by every convolution in the stack.
        conv_per_block: int, number of convolution stages.
        kernel_size: int, kernel size of every convolution.
        batch_norm: bool, append a batch norm after each ReLU.

    Returns:
        nn.Sequential holding the layers in application order.
    """
    sequence = []
    for i in range(conv_per_block):
        # Only the first convolution sees the block's input channel count.
        sequence.append(nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                                  kernel_size, padding=(kernel_size - 1) // 2))
        sequence.append(nn.ReLU())
        if batch_norm:
            # BatchNorm best after ReLU:
            sequence.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*sequence)

class DownConv(nn.Module):
    """One encoder level of the 2D UNet: a convolution block, then optional 2x downscaling.

    Args:
        in_channels: int, channels of the input tensor.
        out_channels: int, channels produced by the convolution block.
        conv_per_block: int, number of convolutions in the block.
        kernel_size: int, kernel size of the block convolutions (and of the
            strided convolution when ``conv_downscaling`` is set).
        batch_norm: bool, forwarded to ``block``.
        conv_downscaling: bool, downscale with a stride-2 convolution instead
            of a 2x2 max pooling.
        pooling: bool, whether this level downscales at all.
    """

    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        self.pooling = pooling

        self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)

        if self.pooling:
            if not conv_downscaling:
                self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            else:
                self.pool = nn.Conv2d(out_channels, out_channels, kernel_size,
                                      padding=(kernel_size - 1) // 2, stride=2)

    def forward(self, x):
        """Return ``(output, before_pool)``.

        ``before_pool`` is the block output prior to downscaling, kept for the
        decoder's skip connection. ``output`` equals ``before_pool`` when
        ``pooling`` is False.
        """
        x = self.block(x)
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool

class UpConv(nn.Module):
    """One decoder level of the 2D UNet: 2x upscaling, skip merge, convolution block.

    Args:
        in_channels: int, channels of the tensor coming from the deeper level.
        out_channels: int, channels produced by this level.
        conv_per_block: int, number of convolutions in the block.
        kernel_size: int, kernel size of the block convolutions.
        batch_norm: bool, forwarded to ``block``.
        add_merging: bool, merge the skip connection by addition instead of
            channel concatenation.
        conv_upscaling: bool, upscale with nearest-neighbour resize followed
            by a 1x1 convolution instead of a transposed convolution.
    """

    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
                 add_merging, conv_upscaling):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        self.add_merging = add_merging

        if not conv_upscaling:
            self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        else:
            self.upconv = nn.Sequential(
                nn.Upsample(mode='nearest', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1, stride=1))

        # Concatenation doubles the channel count seen by the block; addition does not.
        self.block = block(out_channels*2 if not add_merging else out_channels, out_channels,
                           conv_per_block, kernel_size, batch_norm)

    def forward(self, from_down, from_up):
        """Upscale ``from_up``, merge it with the skip tensor ``from_down``, run the block.

        Args:
            from_down: tensor saved by the matching encoder level.
            from_up: tensor coming from the deeper decoder/encoder level.
        """
        from_up = self.upconv(from_up)
        if not self.add_merging:
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = self.block(x)
        return x

class UNet2D(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    UNet is a convolutional encoder-decoder neural network.

    Default parameters correspond to the original UNet, except
    convolutions use padding to preserve the original size.

    Optional improvements (references):
    Strided Convolution instead of Strided Max Pooling for Downsampling ( https://arxiv.org/pdf/1412.6806.pdf, https://arxiv.org/pdf/1701.03056.pdf , https://arxiv.org/pdf/1606.04797.pdf )
    Resize Convolution instead of Strided Deconvolution for Upsampling ( https://distill.pub/2016/deconv-checkerboard/ , https://www.kaggle.com/mpalermo/remove-grideffect-on-generated-images/notebook , https://arxiv.org/pdf/1806.02658.pdf )
    Partial Convolution to fix Zero-Padding bias ( https://arxiv.org/pdf/1811.11718.pdf , https://github.com/NVIDIA/partialconv )
    BatchNorm ( https://arxiv.org/abs/1502.03167 )
    """

    def __init__(self, in_channels=1, out_channels=2, feature_channels=64,
                       depth=5, conv_per_block=2, kernel_size=3, batch_norm=False,
                       conv_upscaling=False, conv_downscaling=False, add_merging=False):
        """
        Args:
            in_channels: int, number of channels in the input tensor.
            out_channels: int, number of channels in the output tensor.
            feature_channels: int, number of channels in the first and last hidden feature layer.
            depth: int, number of levels
            conv_per_block: int, number of convolutions per level block
            kernel_size: int, kernel size for all block convolutions
            batch_norm: bool, add a batch norm after ReLU
            conv_upscaling: use a nearest upscale+conv instead of transposed convolution
            conv_downscaling: use a strided convolution instead of maxpooling
            add_merging: merge layers from different levels using a add instead of a concat

        Raises:
            ValueError: if ``depth`` is smaller than 1.
        """
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        if depth < 1:
            raise ValueError("depth must be at least 1")

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.feature_channels = feature_channels
        self.depth = depth

        down_convs = []
        up_convs = []

        # create the encoder pathway and add to a list
        # (channel count doubles at every level; the deepest level does not pool)
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.feature_channels*(2**i)
            pooling = i < depth-1

            down_convs.append(DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                       conv_downscaling, pooling=pooling))

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                   conv_upscaling=conv_upscaling, add_merging=add_merging))

        self.conv_final = nn.Conv2d(outs, self.out_channels, kernel_size=1, groups=1, stride=1)

        # add the list of modules to current module
        # (plain Python lists would not register the parameters)
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 output conv.

        The skip merge needs matching sizes, so the input's spatial sizes
        should be divisible by ``2 ** (depth - 1)``.
        """
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # decoder pathway: the i-th decoder level merges with the encoder
        # level just above the one that produced its input
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is applied here. Use nn.CrossEntropyLoss in your
        # training script, as that loss already includes a softmax.
        x = self.conv_final(x)
        return x

1D customizable UNet:

import torch
import torch.nn as nn
import torch.nn.functional as F

def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
    """Build the convolution stack used at one level of the 1D UNet.

    Each of the ``conv_per_block`` stages is ``Conv1d -> ReLU``, optionally
    followed by ``BatchNorm1d``. Convolutions are padded with
    ``(kernel_size - 1) // 2`` so an odd ``kernel_size`` preserves the length.

    Args:
        in_channels: int, channels of the tensor fed to the first convolution.
        out_channels: int, channels produced by every convolution in the stack.
        conv_per_block: int, number of convolution stages.
        kernel_size: int, kernel size of every convolution.
        batch_norm: bool, append a batch norm after each ReLU.

    Returns:
        nn.Sequential holding the layers in application order.
    """
    sequence = []
    for i in range(conv_per_block):
        # Only the first convolution sees the block's input channel count.
        sequence.append(nn.Conv1d(in_channels if i == 0 else out_channels, out_channels,
                                  kernel_size, padding=(kernel_size - 1) // 2))
        sequence.append(nn.ReLU())
        if batch_norm:
            # BatchNorm best after ReLU:
            sequence.append(nn.BatchNorm1d(out_channels))
    return nn.Sequential(*sequence)

class DownConv(nn.Module):
    """One encoder level of the 1D UNet: a convolution block, then optional 2x downscaling.

    Args:
        in_channels: int, channels of the input tensor.
        out_channels: int, channels produced by the convolution block.
        conv_per_block: int, number of convolutions in the block.
        kernel_size: int, kernel size of the block convolutions (and of the
            strided convolution when ``conv_downscaling`` is set).
        batch_norm: bool, forwarded to ``block``.
        conv_downscaling: bool, downscale with a stride-2 convolution instead
            of a size-2 max pooling.
        pooling: bool, whether this level downscales at all.
    """

    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        self.pooling = pooling

        self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)

        if self.pooling:
            if not conv_downscaling:
                self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
            else:
                self.pool = nn.Conv1d(out_channels, out_channels, kernel_size,
                                      padding=(kernel_size - 1) // 2, stride=2)

    def forward(self, x):
        """Return ``(output, before_pool)``.

        ``before_pool`` is the block output prior to downscaling, kept for the
        decoder's skip connection. ``output`` equals ``before_pool`` when
        ``pooling`` is False.
        """
        x = self.block(x)
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool

class UpConv(nn.Module):
    """One decoder level of the 1D UNet: 2x upscaling, skip merge, convolution block.

    Args:
        in_channels: int, channels of the tensor coming from the deeper level.
        out_channels: int, channels produced by this level.
        conv_per_block: int, number of convolutions in the block.
        kernel_size: int, kernel size of the block convolutions.
        batch_norm: bool, forwarded to ``block``.
        add_merging: bool, merge the skip connection by addition instead of
            channel concatenation.
        conv_upscaling: bool, upscale with nearest-neighbour resize followed
            by a size-1 convolution instead of a transposed convolution.
    """

    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
                 add_merging, conv_upscaling):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        self.add_merging = add_merging

        if not conv_upscaling:
            self.upconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=2, stride=2)
        else:
            self.upconv = nn.Sequential(
                nn.Upsample(mode='nearest', scale_factor=2),
                nn.Conv1d(in_channels, out_channels, kernel_size=1, groups=1, stride=1))

        # Concatenation doubles the channel count seen by the block; addition does not.
        self.block = block(out_channels*2 if not add_merging else out_channels, out_channels,
                           conv_per_block, kernel_size, batch_norm)

    def forward(self, from_down, from_up):
        """Upscale ``from_up``, merge it with the skip tensor ``from_down``, run the block.

        Args:
            from_down: tensor saved by the matching encoder level.
            from_up: tensor coming from the deeper decoder/encoder level.
        """
        from_up = self.upconv(from_up)
        if not self.add_merging:
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = self.block(x)
        return x

class UNet1D(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    UNet is a convolutional encoder-decoder neural network.

    This 1D variant is inspired by the
    Wave UNet ( https://arxiv.org/pdf/1806.03185.pdf )
    Default parameters correspond to the Wave UNet.
    Convolutions use padding to preserve the original size.

    Optional improvements (references):
    Strided Convolution instead of Strided Max Pooling for Downsampling ( https://arxiv.org/pdf/1412.6806.pdf, https://arxiv.org/pdf/1701.03056.pdf , https://arxiv.org/pdf/1606.04797.pdf )
    Resize Convolution instead of Strided Deconvolution for Upsampling ( https://distill.pub/2016/deconv-checkerboard/ , https://www.kaggle.com/mpalermo/remove-grideffect-on-generated-images/notebook , https://arxiv.org/pdf/1806.02658.pdf )
    Partial Convolution to fix Zero-Padding bias ( https://arxiv.org/pdf/1811.11718.pdf , https://github.com/NVIDIA/partialconv )
    BatchNorm ( https://arxiv.org/abs/1502.03167 )
    """

    def __init__(self, in_channels=1, out_channels=1, feature_channels=24,
                       depth=12, conv_per_block=1, kernel_size=5, batch_norm=False,
                       conv_upscaling=False, conv_downscaling=False, add_merging=False):
        """
        Args:
            in_channels: int, number of channels in the input tensor.
            out_channels: int, number of channels in the output tensor.
            feature_channels: int, number of channels in the first and last hidden feature layer.
            depth: int, number of levels
            conv_per_block: int, number of convolutions per level block
            kernel_size: int, kernel size for all block convolutions
            batch_norm: bool, add a batch norm after ReLU
            conv_upscaling: use a nearest upsize+conv instead of transposed convolution
            conv_downscaling: use a strided convolution instead of maxpooling
            add_merging: merge layers from different levels using a add instead of a concat

        Raises:
            ValueError: if ``depth`` is smaller than 1.
        """
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        if depth < 1:
            raise ValueError("depth must be at least 1")

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.feature_channels = feature_channels
        self.depth = depth

        down_convs = []
        up_convs = []

        # create the encoder pathway and add to a list
        # (channel count grows linearly by feature_channels per level;
        # the deepest level does not pool)
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.feature_channels*(i+1)
            pooling = i < depth-1

            down_convs.append(DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                       conv_downscaling, pooling=pooling))

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins - self.feature_channels
            up_convs.append(UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                   conv_upscaling=conv_upscaling, add_merging=add_merging))

        self.conv_final = nn.Conv1d(outs, self.out_channels, kernel_size=1, groups=1, stride=1)

        # add the list of modules to current module
        # (plain Python lists would not register the parameters)
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the size-1 output conv.

        The skip merge needs matching lengths, so the input length should be
        divisible by ``2 ** (depth - 1)``.
        """
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # decoder pathway: the i-th decoder level merges with the encoder
        # level just above the one that produced its input
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is applied here. Use nn.CrossEntropyLoss in your
        # training script, as that loss already includes a softmax.
        x = self.conv_final(x)
        return x
swilson314 commented 3 years ago

I've used your c++ model for creating/training a unet model -- many thanks!! I'm now in the process of translating the model to onnx so I can embed it. So the python listing is extremely helpful too, since the onnx export is currently only supported from python. I noticed that the c++ and python models are not precisely the same. In particular, the python model seems to have two new variables, conv_per_block and add_merging. Per comparison with the c++ code, I think conv_per_block=2 and add_merging=false. Did I miss any other differences?

divideconcept commented 3 years ago

That's right, if you leave those default values for conv_per_block and add_merging you'll get the same results as the C++ UNet.

swilson314 commented 3 years ago


swilson314 commented 3 years ago

I found a few places where the python and c++ code wasn't identical (required for onnx). See the SBW comments below:

class UpConv(nn.Module):
    """One decoder level of the 1D UNet, adjusted to match the C++ implementation.

    Differences from the listing posted earlier in this thread (SBW fixes):
    the resize-convolution uses ``kernel_size`` with size-preserving padding
    instead of a size-1 kernel, and concatenation puts the skip tensor first.

    Args:
        in_channels: int, channels of the tensor coming from the deeper level.
        out_channels: int, channels produced by this level.
        conv_per_block: int, number of convolutions in the block.
        kernel_size: int, kernel size of the block convolutions and of the
            resize-convolution.
        batch_norm: bool, forwarded to ``block``.
        add_merging: bool, merge the skip connection by addition instead of
            channel concatenation.
        conv_upscaling: bool, upscale with nearest-neighbour resize followed
            by a convolution instead of a transposed convolution.
    """

    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
                 add_merging, conv_upscaling):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        self.add_merging = add_merging

        if not conv_upscaling:
            self.upconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=2, stride=2)
        else:
            # SBW 2021.10.01 Bug fix: sync with c++ version
            # (kernel_size + padding instead of kernel_size=1).
            self.upconv = nn.Sequential(
                nn.Upsample(mode='nearest', scale_factor=2),
                nn.Conv1d(in_channels, out_channels, kernel_size,
                          padding=(kernel_size-1)//2, groups=1, stride=1))

        # Concatenation doubles the channel count seen by the block; addition does not.
        self.block = block(out_channels*2 if not add_merging else out_channels, out_channels,
                           conv_per_block, kernel_size, batch_norm)

    def forward(self, from_down, from_up):
        """Upscale ``from_up``, merge it with the skip tensor ``from_down``, run the block.

        Args:
            from_down: tensor saved by the matching encoder level.
            from_up: tensor coming from the deeper decoder/encoder level.
        """
        from_up = self.upconv(from_up)
        if not self.add_merging:
            # SBW 2021.10.04 Bug fix: sync with c++ version
            # (skip tensor first, upscaled tensor second).
            x = torch.cat((from_down, from_up), 1)
        else:
            x = from_up + from_down
        x = self.block(x)
        return x
def __init__(self, in_channels=1, out_channels=1, feature_channels=24,
                   depth=12, conv_per_block=1, kernel_size=5, batch_norm=False,
                   conv_upscaling=False, conv_downscaling=False, add_merging=False):
    """Constructor of the 1D UNet, adjusted to match the C++ implementation.

    This is an excerpt meant to live inside the ``UNet1D`` class; the only
    intended difference from the earlier listing is that ``conv_final`` is
    created after the encoder/decoder module lists.

    Args:
        in_channels: int, number of channels in the input tensor.
        out_channels: int, number of channels in the output tensor.
        feature_channels: int, number of channels in the first and last hidden feature layer.
        depth: int, number of levels
        conv_per_block: int, number of convolutions per level block
        kernel_size: int, kernel size for all block convolutions
        batch_norm: bool, add a batch norm after ReLU
        conv_upscaling: use a nearest upsize+conv instead of transposed convolution
        conv_downscaling: use a strided convolution instead of maxpooling
        add_merging: merge layers from different levels using a add instead of a concat

    Raises:
        ValueError: if ``depth`` is smaller than 1.
    """
    # nn.Module must be initialised before sub-modules can be assigned.
    super().__init__()

    if depth < 1:
        raise ValueError("depth must be at least 1")

    self.out_channels = out_channels
    self.in_channels = in_channels
    self.feature_channels = feature_channels
    self.depth = depth

    down_convs = []
    up_convs = []

    # create the encoder pathway and add to a list
    # (channel count grows linearly by feature_channels per level;
    # the deepest level does not pool)
    for i in range(depth):
        ins = self.in_channels if i == 0 else outs
        outs = self.feature_channels*(i+1)
        pooling = i < depth-1

        down_convs.append(DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                   conv_downscaling, pooling=pooling))

    # create the decoder pathway and add to a list
    # - careful! decoding only requires depth-1 blocks
    for i in range(depth-1):
        ins = outs
        outs = ins - self.feature_channels
        up_convs.append(UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                               conv_upscaling=conv_upscaling, add_merging=add_merging))

    # add the list of modules to current module
    # (plain Python lists would not register the parameters)
    self.down_convs = nn.ModuleList(down_convs)
    self.up_convs = nn.ModuleList(up_convs)

    # SBW 2021.10.01 Bug fix: sync with c++ version. conv_final is created
    # after the module lists so it is registered last.
    self.conv_final = nn.Conv1d(outs, self.out_channels, kernel_size=1, groups=1, stride=1)