This code implements the Permutohedral Lattice for high-dimensional filtering. Read the original paper. If you use this work, please consider citing our paper in addition to the original one.
The code contains:
This code can be used to perform (approximate) bilateral filtering, Gaussian filtering, non-local means, etc. It also supports an arbitrary number of spatial dimensions, input channels and reference channels.
The TensorFlow op has gradients implemented, so it can be used with backpropagation, and it supports batch_size>=1.
This code was made to be used as part of larger algorithms such as Conditional Random Fields (CRFs).
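For reference, the exact filter that the lattice approximates can be written by brute force. The sketch below is a pure-Python 1D bilateral filter, not code from this repository. It assumes that theta_alpha is the spatial standard deviation and theta_beta the standard deviation in reference (color) space, and it normalizes by the sum of weights; the normalization used by the compiled op may differ.

```python
import math


def bilateral_filter_1d(values, reference, theta_alpha, theta_beta):
    """Exact O(n^2) bilateral filter of `values`, guided by `reference`.

    Assumption: theta_alpha scales distances between sample positions and
    theta_beta scales distances between reference values.
    """
    n = len(values)
    out = []
    for i in range(n):
        num = 0.0
        den = 0.0
        for j in range(n):
            d_space = (i - j) / theta_alpha
            d_range = (reference[i] - reference[j]) / theta_beta
            # Gaussian weight in the joint (position, reference) space.
            w = math.exp(-0.5 * (d_space ** 2 + d_range ** 2))
            num += w * values[j]
            den += w
        out.append(num / den)
    return out


# A step edge is preserved because samples across the edge get ~zero weight.
step = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
smoothed = bilateral_filter_1d(step, step, theta_alpha=8, theta_beta=0.125)
```

The cost is quadratic in the number of samples, which is the reason for using a lattice-based approximation on real images.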
Install CMake (version >= 3.9).
Open the file build.sh and change the variables CXX_COMPILER and CUDA_COMPILER to the paths of the C++ and nvcc (CUDA) compilers on your machine.
To compile the code run:
sh build.sh
This will create a directory called build_dir which will contain the compiled code.
This script will try to compile the code for both CPU and GPU at the same time, so if you don't want the GPU part (and still want the script to run) you must change CMakeLists.txt.
Because of the way the GPU (CUDA) code is implemented, the number of spatial dimensions and the number of channels of the input and reference images must be known at compile time. These can be set in the build.sh script as well, by changing the variables SPATIAL_DIMS, INPUT_CHANNELS and REFERENCE_CHANNELS.
If you only need the CPU version, these variables have no effect on it and the values can be chosen at run time.
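As a quick check before building, the three values can be read off the shapes of the tensors you plan to filter. The helper below is a hypothetical convenience, not part of this repository. It assumes the batched, channels-last layout (batch, spatial..., channels) used in the TensorFlow example in this README.

```python
def lattice_build_variables(input_shape, reference_shape):
    """Derive the build.sh variables from batched, channels-last shapes.

    Hypothetical helper: assumes shapes of the form
    (batch, spatial_1, ..., spatial_k, channels).
    """
    if len(input_shape) != len(reference_shape) or len(input_shape) < 3:
        raise ValueError("expected shapes of the form (batch, spatial..., channels)")
    if tuple(input_shape[:-1]) != tuple(reference_shape[:-1]):
        raise ValueError("input and reference must share batch and spatial sizes")
    return {
        "SPATIAL_DIMS": len(input_shape) - 2,
        "INPUT_CHANNELS": input_shape[-1],
        "REFERENCE_CHANNELS": reference_shape[-1],
    }


# Grayscale input guided by an RGB reference.
build_vars = lattice_build_variables((4, 64, 48, 1), (4, 64, 48, 3))
```

The resulting values are the ones to copy into build.sh before compiling the GPU code.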
./build_dir/test_bilateral_cpu Images/input.bmp Images/output.bmp 8 0.125
./build_dir/test_bilateral_gpu Images/input.bmp Images/output.bmp 8 0.125
Look into TFOpTests for actual working examples.
Example of bilateral filtering a 2D grayscale image based on an RGB image.
On GPU, compile with SPATIAL_DIMS=2, INPUT_CHANNELS=1 and REFERENCE_CHANNELS=3.
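import tensorflow as tf
import lattice_filter_op_loader

# Assumes the loader exposes the compiled op library as `module`.
module = lattice_filter_op_loader.module

# batch_size, width and height must be defined by the caller.
input_image = tf.placeholder(tf.float32, shape=(batch_size, width, height, 1))
reference_image = tf.placeholder(tf.float32, shape=(batch_size, width, height, 3))
output = module.lattice_filter(input_image, reference_image, bilateral=True, theta_alpha=8, theta_beta=0.125)
# Then run the graph, load, save images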
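To make the roles of theta_alpha and theta_beta more concrete, here is a minimal sketch of how bilateral feature vectors are usually formed for a permutohedral lattice: each pixel gets its spatial coordinates divided by the spatial scale, followed by its reference channels divided by the range scale. This is the standard construction and an assumption about this op, not code taken from the repository; the function name is hypothetical.

```python
def bilateral_features(reference, theta_alpha, theta_beta):
    """Build one feature vector per pixel of a 2D reference image.

    `reference` is a nested list indexed as [i][j][channel].
    Assumption: spatial coordinates are scaled by 1/theta_alpha and
    reference channels by 1/theta_beta.
    """
    features = []
    for i, row in enumerate(reference):
        for j, pixel in enumerate(row):
            spatial = [i / theta_alpha, j / theta_alpha]
            channels = [c / theta_beta for c in pixel]
            features.append(spatial + channels)
    return features


# A 1x2 RGB reference gives two 5-dimensional feature vectors
# (SPATIAL_DIMS=2 plus REFERENCE_CHANNELS=3).
tiny_reference = [[[0.0, 0.5, 1.0], [0.25, 0.25, 0.25]]]
feats = bilateral_features(tiny_reference, theta_alpha=8, theta_beta=0.125)
```

Under this construction, a larger theta_alpha lets more distant pixels influence each other, and a smaller theta_beta makes the filter more sensitive to differences in the reference image.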
SPATIAL_DIMS, INPUT_CHANNELS and REFERENCE_CHANNELS at run time.