divideconcept / PyTorch-libtorch-U-Net

A customizable 1D/2D U-Net model for libtorch (PyTorch C++ UNet)
MIT License

PyTorch/libtorch Customizable 1D/2D U-Net

A customizable 1D/2D U-Net model for libtorch (PyTorch C++ UNet)
Robin Lobel, March 2020 - Requires libtorch 1.4.0 or higher. CPU & CUDA compatible. Qt compatible.

The default parameters for the 2D U-Net follow the original U-Net (https://arxiv.org/pdf/1505.04597.pdf) with all core improvements activated, resulting in a fully convolutional 2D network.
The default parameters for the 1D U-Net are inspired by the Wave-U-Net (https://arxiv.org/pdf/1806.03185.pdf) with all core improvements activated, resulting in a fully convolutional 1D network.

You can customize the number of in/out channels, the number of hidden feature channels, the number of levels, the size of the kernel, and activate improvements such as:

You can also display the size of all internal layers the first time you use the network.
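To get a feel for what such a size listing looks like before running the network, the sketch below computes the spatial size at each level by hand. It assumes every level halves the resolution (stride 2), which is typical for a U-Net but is not confirmed here for every configuration of this library. `levelSizes` and `halvesCleanly` are hypothetical helpers, not part of `cunet.h`, and the level counts used with them are illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of cunet.h): spatial size seen at each level,
// ASSUMING every level halves the resolution (stride 2).
std::vector<std::int64_t> levelSizes(std::int64_t inputSize, int levels)
{
    assert(inputSize > 0 && levels >= 0);
    std::vector<std::int64_t> sizes;
    std::int64_t size = inputSize;
    sizes.push_back(size); // level 0 is the input itself
    for (int level = 0; level < levels; ++level)
    {
        size /= 2; // integer division: an odd size loses one sample here
        sizes.push_back(size);
    }
    return sizes;
}

// Hypothetical helper: true when inputSize survives `levels` halvings without
// rounding, so doubling back up at each level returns to the original size.
bool halvesCleanly(std::int64_t inputSize, int levels)
{
    assert(inputSize > 0 && levels >= 0 && levels < 62);
    return inputSize % (std::int64_t{1} << levels) == 0;
}
```

Under that assumption, `levelSizes(512, 4)` walks through 512, 256, 128, 64, 32, and `halvesCleanly` is a quick way to check an input size against a chosen number of levels.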

How to choose the parameters

Usage (2D UNet)

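#include <torch/torch.h>
#include "cunet.h"

int main(int argc, char *argv[])
{
    int batchSize=16;
    int inChannels=3, outChannels=3;
    int height=512, width=512;

    // Build the 2D U-Net with default parameters and an Adam optimizer (learning rate 1e-3)
    CUNet2d model(inChannels,outChannels);
    torch::optim::Adam optim(model->parameters(), torch::optim::AdamOptions(1e-3));

    // Random placeholder data in NCHW layout; replace with your own dataset
    torch::Tensor source=torch::randn({batchSize,inChannels,height,width});
    torch::Tensor target=torch::randn({batchSize,outChannels,height,width});
    torch::Tensor result, loss;

    // Training: every iteration reuses the same batch, so this only demonstrates the loop
    model->train();
    for (int epoch = 0; epoch < 100; epoch++)
    {
        optim.zero_grad();                      // clear gradients from the previous step
        result = model(source);                 // forward pass
        loss = torch::mse_loss(result, target); // mean squared error against the target
        loss.backward();                        // backpropagate
        optim.step();                           // update the weights
    }

    // Inference: switch to evaluation mode and disable gradient tracking
    model->eval();
    torch::NoGradGuard noGrad;
    torch::Tensor validation=torch::randn({batchSize,inChannels,height,width});
    torch::Tensor inference = model(validation);

    return 0;
}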

Usage (1D UNet)

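#include <torch/torch.h>
#include "cunet.h"

int main(int argc, char *argv[])
{
    int batchSize=16;
    int inChannels=2, outChannels=2;
    int size=2048;

    // Build the 1D U-Net with default parameters and an Adam optimizer (learning rate 1e-3)
    CUNet1d model(inChannels,outChannels);
    torch::optim::Adam optim(model->parameters(), torch::optim::AdamOptions(1e-3));

    // Random placeholder data in NCL layout (batch, channels, length); replace with your own dataset
    torch::Tensor source=torch::randn({batchSize,inChannels,size});
    torch::Tensor target=torch::randn({batchSize,outChannels,size});
    torch::Tensor result, loss;

    // Training: every iteration reuses the same batch, so this only demonstrates the loop
    model->train();
    for (int epoch = 0; epoch < 100; epoch++)
    {
        optim.zero_grad();                      // clear gradients from the previous step
        result = model(source);                 // forward pass
        loss = torch::mse_loss(result, target); // mean squared error against the target
        loss.backward();                        // backpropagate
        optim.step();                           // update the weights
    }

    // Inference: switch to evaluation mode and disable gradient tracking
    model->eval();
    torch::NoGradGuard noGrad;
    torch::Tensor validation=torch::randn({batchSize,inChannels,size});
    torch::Tensor inference = model(validation);

    return 0;
}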
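
When adjusting `batchSize` or the input dimensions in the two examples above, it can help to estimate how much memory the input tensors alone occupy. The sketch below is a plain C++ calculation, assuming dense float32 storage (the default dtype of `torch::randn`). `float32Bytes` is a hypothetical helper, not part of `cunet.h` or libtorch.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical helper: bytes needed for one dense float32 tensor of the given shape.
std::int64_t float32Bytes(std::initializer_list<std::int64_t> shape)
{
    std::int64_t elements = 1;
    for (std::int64_t dim : shape)
    {
        assert(dim >= 0);
        elements *= dim; // total element count is the product of all dimensions
    }
    return elements * static_cast<std::int64_t>(sizeof(float));
}
```

For the 2D example, `float32Bytes({16, 3, 512, 512})` gives 50331648 bytes (48 MiB) per tensor; for the 1D example, `float32Bytes({16, 2, 2048})` gives 262144 bytes. This covers only the input or target tensor: the activations and gradients kept inside the network during training take considerably more.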