A customizable 1D/2D U-Net model for libtorch (PyTorch C++ U-Net)
Robin Lobel, March 2020 - Requires libtorch 1.4.0 or higher. CPU & CUDA compatible. Qt compatible.
The default parameters produce the original 2D U-Net ( https://arxiv.org/pdf/1505.04597.pdf ) with all core improvements activated, resulting in a fully convolutional 2D network.
The default parameters for the 1D U-Net are inspired by the Wave-U-Net ( https://arxiv.org/pdf/1806.03185.pdf ) with all core improvements activated, resulting in a fully convolutional 1D network.
You can customize the number of in/out channels, the number of hidden feature channels, the number of levels, the size of the kernel, and activate improvements such as:
You can additionally display the size of all internal layers the first time the network is used.
#include "cunet.h"
// Usage example for the 2D U-Net: fits the model to random tensors for a
// fixed number of steps, then runs a single forward pass in evaluation mode.
// The data is random noise, so this only demonstrates the API calls.
int main(int argc, char *argv[])
{
    // Tensor dimensions used throughout: batch x channels x height x width.
    int batchSize=16;
    int inChannels=3, outChannels=3;
    int height=512, width=512;

    // Only the channel counts are given; every other constructor argument
    // keeps its default value.
    CUNet2d model(inChannels,outChannels);
    torch::optim::Adam optim(model->parameters(), torch::optim::AdamOptions(1e-3));

    torch::Tensor source=torch::randn({batchSize,inChannels,height,width});
    torch::Tensor target=torch::randn({batchSize,outChannels,height,width});

    model->train();
    for (int epoch = 0; epoch < 100; ++epoch)
    {
        optim.zero_grad(); // gradients accumulate across backward() calls
        torch::Tensor result = model(source);
        torch::Tensor loss = torch::mse_loss(result, target);
        loss.backward();
        optim.step();
    }

    model->eval();
    // Inference needs no gradients: disabling autograd recording avoids
    // building and retaining the graph for this forward pass.
    torch::NoGradGuard noGrad;
    torch::Tensor validation=torch::randn({batchSize,inChannels,height,width});
    torch::Tensor inference = model(validation);
    return 0;
}
#include "cunet.h"
// Usage example for the 1D U-Net: fits the model to random tensors for a
// fixed number of steps, then runs a single forward pass in evaluation mode.
// The data is random noise, so this only demonstrates the API calls.
int main(int argc, char *argv[])
{
    // Tensor dimensions used throughout: batch x channels x length.
    int batchSize=16;
    int inChannels=2, outChannels=2;
    int size=2048;

    // Only the channel counts are given; every other constructor argument
    // keeps its default value.
    CUNet1d model(inChannels,outChannels);
    torch::optim::Adam optim(model->parameters(), torch::optim::AdamOptions(1e-3));

    torch::Tensor source=torch::randn({batchSize,inChannels,size});
    torch::Tensor target=torch::randn({batchSize,outChannels,size});

    model->train();
    for (int epoch = 0; epoch < 100; ++epoch)
    {
        optim.zero_grad(); // gradients accumulate across backward() calls
        torch::Tensor result = model(source);
        torch::Tensor loss = torch::mse_loss(result, target);
        loss.backward();
        optim.step();
    }

    model->eval();
    // Inference needs no gradients: disabling autograd recording avoids
    // building and retaining the graph for this forward pass.
    torch::NoGradGuard noGrad;
    torch::Tensor validation=torch::randn({batchSize,inChannels,size});
    torch::Tensor inference = model(validation);
    return 0;
}