PyTorch implementation of the Sparse Shift Layer (SSL) for 3D, 4D and 5D tensors from "All You Need is a Few Shifts: Designing Efficient Convolutional Neural Networks for Image Classification" (https://arxiv.org/pdf/1903.05285.pdf).
(I am not the author of any of the mentioned articles; I implemented this for my own purposes.)
Shifts tensor data (in memory) by indexes. The value and direction of the shift are learnable and differ between channels. It can be considered a zero-FLOP replacement for depthwise convolution, with 4.5x less memory consumption (compared with a 3x3 depthwise convolution).
GroupedShift: The first known application of the shift operator as a replacement for depthwise convolution. It uses shifts in their exact form on both the forward and backward passes, hence the shift values (weights) are not learnable (and, for simplicity, are applied to groups of channels; see the article for details) and act like hyperparameters.
(This kind of shift is not officially supported here; see the Grouped Shifts note below for how to emulate it.)
Active Shift: Replaces the shift operation with linear (bi-/tri-linear for the 2D/3D cases) interpolation on both the forward and backward passes. The "shift" values become learnable (because they are floats), and moreover shifts are defined per channel.
Sparse Shift Layer (SSL): The combination of the two articles above. The "shift" values are still learnable via interpolation (on the backward pass), while the exact shift operator is used on the forward pass (the "shift" values are simply rounded there). So we get a simple zero-FLOP shift operation (which is also natively quantized, because the shift operator requires integer values) instead of a depthwise convolution!
Sparse stands for the L1 regularization on the weights, which sparsifies the shift values along the channel axis!
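Conceptually, the forward pass looks like the simplified 1D sketch below (zero padding; this is not the package's actual C++/CUDA implementation, and in the real layer the backward pass uses interpolation so the float shifts stay learnable):

```python
import torch

def ssl_forward_1d(x, shifts):
    """Exact-shift forward of SSL for a (N, C, W) tensor.

    x      : float tensor of shape (N, C, W)
    shifts : learnable float shifts of shape (C,), rounded to integers here
    """
    out = torch.zeros_like(x)                     # 'zeros' padding for vacated positions
    for c, s in enumerate(shifts.round().long().tolist()):
        if s == 0:
            out[:, c] = x[:, c]
        elif s > 0:
            out[:, c, s:] = x[:, c, :-s]          # shift channel c to the right by s
        else:
            out[:, c, :s] = x[:, c, -s:]          # shift channel c to the left by |s|
    return out
```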
By default all Shift modules are Sparse Shift Layers! The module always returns output and loss, where the latter is the L1 regularization loss (see theory), which should be added to the overall loss to take effect!
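For example, a minimal sketch of handling the returned pair (the input and task loss below are placeholders, not part of the package):

```python
import torch
from torchshifts import Shift2d

shift_layer = Shift2d(in_channels=8)          # SSL by default
x = torch.randn(2, 8, 32, 32)                 # dummy input

output, reg_loss = shift_layer(x)             # the module returns (output, loss)
task_loss = output.mean()                     # placeholder task loss
total_loss = task_loss + reg_loss             # add the L1 term so sparsity takes effect
total_loss.backward()
```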
Active Shift can be enabled by setting active_flag=True and sparsity_term=0, because there is no need to compute the regularization term (at least in the original article).
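For instance (a minimal sketch):

```python
from torchshifts import Shift2d

# Active Shift: float shifts with interpolation on both passes, no sparsity term.
active_shift = Shift2d(in_channels=8, active_flag=True, sparsity_term=0.)
```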
Grouped Shifts are not officially supported here; however, it is technically possible: set active_flag=False and sparsity_term=0, freeze the .weights parameters from gradient computation via shift_layer.weights.requires_grad = False (inside the C function the gradient for the weights is always computed, so you will not gain any performance), and don't forget to properly re-initialize the .weights values (including channel groups, etc.).
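A rough sketch of such a manual setup (the (channels, 2) weight layout and the grouping pattern below are assumptions, not prescribed by the package; check shift_layer.weights.shape in your installation):

```python
import torch
from torchshifts import Shift2d

shift_layer = Shift2d(in_channels=8, active_flag=False, sparsity_term=0.)
shift_layer.weights.requires_grad = False   # shifts now act as fixed hyperparameters

# Re-initialize the shifts group-wise with fixed integer offsets, e.g. four
# groups of two channels shifted to the four neighbouring positions.
fixed = torch.tensor([[ 0.,  1.], [ 0.,  1.],
                      [ 0., -1.], [ 0., -1.],
                      [ 1.,  0.], [ 1.,  0.],
                      [-1.,  0.], [-1.,  0.]])
with torch.no_grad():
    shift_layer.weights.copy_(fixed.view_as(shift_layer.weights))
```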
We implement several padding variants for filling empty values after shifts: Zeros (default), Border, Periodic (stands for circular shifts!), Reflect and Symmetric. See here for details. (These paddings are also used during the interpolation computation.)
C++17 must be supported by your compiler! (due to constexpr in code)
PyTorch >= 1.7.0;
cd ActiveSparseShifts-PyTorch
python torch_patch.py
(it will also run automatically during step 3) to apply the fix.
This script changes a few lines of code in a single C++ header file, but it does so directly in the Python dist-packages folder.
Please make sure you have the rights to change files inside this folder!
In any case, you only need to do this once per Python environment (PyTorch package).
(If something goes wrong, please inspect torch_patch.py first (it is very simple) and try to reproduce the patch manually.)
python setup.py install or python setup.py bdist_wheel - to install/build the package.
Example:
from torchshifts import Shift1d, Shift2d, Shift3d
shift_layer = Shift1d(in_channels=3)
Additional options for shift layer:
padding(str) - Padding for filling empty values.
Allowed: ['zeros', 'border', 'periodic', 'reflect', 'symmetric']. Default: 'zeros'.
init_shift(float) - Bound for uniform initialization of weights (shifts): [-init_shift, init_shift]. Default: 1.
sparsity_term(float) - Strength of sparsity. Default: 5e-4.
active_flag(bool) - Enable Active Shift instead of SSL. Default: False
emulate_dw(dict) - Pass the parameters of the depthwise convolution that you are trying to replace with the shift layer.
A heuristic is applied to try to emulate its properties (including output shape).
init_thumb_rule(int) - Type of thumb rule for shift initialization. Allowed: Type 1 (default): uniform(-init_shift, init_shift);
Type 2: uniform(0, init_shift) * random_sign.
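Putting a few of these options together (a sketch; the parameter values below are arbitrary examples, not recommendations):

```python
from torchshifts import Shift2d

shift_layer = Shift2d(in_channels=16,
                      padding='periodic',     # circular shifts
                      init_shift=1.,          # uniform init in [-1, 1]
                      sparsity_term=5e-4,     # strength of the L1 regularization
                      init_thumb_rule=2)      # uniform(0, init_shift) * random_sign
```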
Depthwise Convolution Emulation: Provides heuristic rules for emulating a depthwise convolution with a shift layer in terms of output shape and shift kernel behaviour.
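For illustration only, a hypothetical sketch (the exact keys expected by emulate_dw are not documented here; the dict below is an assumption):

```python
from torchshifts import Shift2d

# Replace a 3x3 depthwise conv with stride 2; the dict keys are assumed, not guaranteed.
shift_layer = Shift2d(in_channels=32,
                      emulate_dw={'kernel_size': 3, 'stride': 2, 'padding': 1})
```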
PyTorch Quantization: SSL shifts can be used in a quantized pipeline! Shifts do not need activation tracking, so a model with Shift modules can be easily converted as follows:
from torchshifts import quant_mapping
torch.quantization.convert(<model_with_Shift_module>, ..., mapping=quant_mapping)
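A minimal concrete sketch of that call (the model below is a placeholder; only quant_mapping comes from the package):

```python
import torch
from torchshifts import Shift2d, quant_mapping

model = torch.nn.Sequential(Shift2d(in_channels=8))   # placeholder model with a Shift module
model.eval()
quantized = torch.quantization.convert(model, mapping=quant_mapping)
```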
PyTorch JIT: We support it out of the box:
torch.jit.trace_module(<model_with_Shift_module>)
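For example (a sketch; torch.jit.trace_module expects a dict mapping method names to example inputs):

```python
import torch
from torchshifts import Shift2d

model = Shift2d(in_channels=8)                  # placeholder model with a Shift module
example = torch.randn(1, 8, 32, 32)
traced = torch.jit.trace_module(model, {'forward': example})
```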