
(Realtime) Temporal Convolutions in PyTorch
MIT License

PyTorch-TCN

Streamable (Real-Time) Temporal Convolutional Networks in PyTorch

This Python package provides a flexible and comprehensive implementation of temporal convolutional networks (TCN) in PyTorch, analogous to the popular TensorFlow/Keras package keras-tcn. Like keras-tcn, pytorch-tcn is based on the TCN architecture presented by Bai et al., while also including some features of the original WaveNet architecture (e.g. skip connections) and an option to periodically reset the dilation size, which allows training very deep TCN structures.

Additionally, this package offers a streaming inference option for causal networks (with and without lookahead). This allows data to be processed in small blocks instead of as one whole sequence, which is essential for real-time applications. See the section Streaming Inference for details.

Figure: Dilated causal (left) and non-causal (right) convolutions.

Installation

pip install pytorch-tcn

How to use the TCN class

from pytorch_tcn import TCN

# Constructor signature; type hints and default values are shown
# for reference, pass concrete values when instantiating:
model = TCN(
    num_inputs: int,
    num_channels: ArrayLike,
    kernel_size: int = 4,
    dilations: Optional[ ArrayLike ] = None,
    dilation_reset: Optional[ int ] = None,
    dropout: float = 0.1,
    causal: bool = True,
    use_norm: str = 'weight_norm',
    activation: str = 'relu',
    kernel_initializer: str = 'xavier_uniform',
    use_skip_connections: bool = False,
    input_shape: str = 'NCL',
    embedding_shapes: Optional[ ArrayLike ] = None,
    embedding_mode: str = 'add',
    use_gate: bool = False,
    lookahead: int = 0,
    output_projection: Optional[ int ] = None,
    output_activation: Optional[ str ] = None,
)
# Continue to train/use model for your task

Input and Output shapes

The TCN expects input tensors of shape (N, Cin, L), where N, Cin, L denote the batch size, number of input channels and sequence length, respectively. This matches the input shape expected by 1D convolutions in PyTorch. If you prefer the more common convention for time series data, (N, L, Cin), you can change the expected input shape via the input_shape parameter (see below for details). The output dimensions are ordered the same way as the input dimensions.

Parameters and how to choose meaningful values
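One quantity that helps when choosing kernel_size and dilations is the receptive field of the network, i.e. how many past time steps influence each output step. The sketch below computes it in plain Python, assuming two convolutions per residual block with a shared dilation, as in the Bai et al. architecture (this is an illustration, not a function exported by pytorch-tcn):

```python
def receptive_field(kernel_size, dilations, convs_per_block=2):
    """Receptive field (in time steps) of a stack of dilated conv blocks.

    Assumes each residual block applies `convs_per_block` convolutions
    with the same dilation, as in the Bai et al. TCN architecture.
    """
    return 1 + sum(convs_per_block * (kernel_size - 1) * d for d in dilations)

# Example: kernel_size=4, dilations doubling over four blocks
print(receptive_field(4, [1, 2, 4, 8]))  # 1 + 2*3*(1+2+4+8) = 91
```

If the receptive field is shorter than the temporal dependencies in your data, increase the kernel size, add blocks, or let the dilations grow further.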

Streaming Inference

For kernel sizes > 1, a TCN will always use zero padding to ensure that the output has the same number of time steps as the input. This leads to problems during blockwise processing. For example, let [ X1, X2, X3, X4 ] be an input sequence. With a kernel size of 3 and a dilation rate of 1, the padding length of the first convolutional layer would be 2. Hence, its input would look like [ 0, 0, X1, X2, X3, X4 ] (for a causal network). If the same sequence is divided into two chunks [ X1, X2 ] and [ X3, X4 ], the effective input becomes [ 0, 0, X1, X2 ] + [ 0, 0, X3, X4 ]. These discontinuities in the receptive field of the TCN lead to different (and very likely degraded) outputs for the same input sequence when it is divided into smaller chunks.
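The effect can be reproduced with a plain-Python causal convolution (a minimal sketch, independent of pytorch-tcn):

```python
def causal_conv(x, kernel):
    """1D causal convolution with zero padding (dilation 1)."""
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(kernel[j] * padded[i + j] for j in range(k))
            for i in range(len(x))]

kernel = [0.5, 0.3, 0.2]
x = [1.0, 2.0, 3.0, 4.0]

full = causal_conv(x, kernel)

# Naive chunking: each chunk is zero-padded independently
naive = causal_conv(x[:2], kernel) + causal_conv(x[2:], kernel)

print(full)   # correct output for the whole sequence
print(naive)  # differs from `full` after the chunk boundary
```

The first chunk matches, but every output of the second chunk is wrong because the convolution sees zeros instead of X1 and X2.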

To avoid this issue, a buffer is implemented that stores the network's input history. The history is then used as padding for the next processing step. This way, blockwise processing yields the same results as if the whole sequence were processed at once.
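The buffering idea can be sketched in plain Python for a single convolution with dilation 1 (an illustration of the concept, not the library's implementation):

```python
class StreamingCausalConv:
    """Causal 1D convolution that buffers past inputs between calls
    (a simplified sketch of the buffering idea, not the library code)."""

    def __init__(self, kernel):
        self.kernel = kernel
        self.reset_buffer()

    def reset_buffer(self):
        # History of the last kernel_size - 1 input samples,
        # initialized to zeros (equivalent to zero padding)
        self.buffer = [0.0] * (len(self.kernel) - 1)

    def __call__(self, block):
        k = len(self.kernel)
        padded = self.buffer + list(block)
        out = [sum(self.kernel[j] * padded[i + j] for j in range(k))
               for i in range(len(block))]
        # Keep the last k - 1 samples as padding for the next block
        self.buffer = padded[len(padded) - (k - 1):]
        return out

kernel = [0.5, 0.3, 0.2]
x = [1.0, 2.0, 3.0, 4.0]

conv = StreamingCausalConv(kernel)
chunked = conv(x[:2]) + conv(x[2:])

# Matches processing the whole sequence at once
whole = StreamingCausalConv(kernel)(x)
print(chunked == whole)  # True
```

In a full TCN, each convolutional layer keeps its own buffer (with a length that grows with the dilation), which is why the library exposes a single reset_buffers() call for the whole network.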

For streaming inference, the batch size must be 1.

How to use the streaming option

from pytorch_tcn import TCN

tcn = TCN(
    num_inputs,
    num_channels,
    causal=True,
)

# Important: reset the buffer before processing a new sequence
tcn.reset_buffers()

# blockwise processing
# block should be of shape:
# (1, block_size, num_inputs)
for block in blocks:
    out = tcn.inference(block)

# or alternatively

for block in blocks:
    out = tcn(block, inference=True)

Lookahead

Streaming inference only makes sense for causal networks. However, one may want to use a lookahead into future time frames to increase the modelling accuracy. This is achieved by setting the lookahead parameter to a value greater than 0: it specifies the number of future time steps that are processed in addition to each input block.

Note that lookahead will introduce additional latency in real-time applications.


tcn = TCN(
    num_inputs,
    num_channels,
    causal=True,
    lookahead=1,
)

tcn.reset_buffers()

# block shape: (1, block_size + lookahead, num_inputs)
output = tcn.inference(block)

# output will be of shape:
# (1, block_size, num_outputs)
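The alignment implied by lookahead can be sketched with a plain-Python convolution whose kernel reads one future sample (an illustrative sketch, dilation 1, lookahead 1; not the library code):

```python
def conv_with_lookahead(x, kernel, lookahead):
    """Convolution where the last `lookahead` kernel taps read future
    samples. Past context is zero-padded; the output has
    len(x) - lookahead steps, which is why each streaming block must
    carry `lookahead` extra input frames."""
    k = len(kernel)
    past = k - 1 - lookahead  # number of past samples the kernel sees
    padded = [0.0] * past + list(x)
    n_out = len(x) - lookahead
    return [sum(kernel[j] * padded[i + j] for j in range(k))
            for i in range(n_out)]

kernel = [0.5, 0.3, 0.2]       # last tap reads one step into the future
x = [1.0, 2.0, 3.0, 4.0, 5.0]  # block_size (4) + lookahead (1) samples

out = conv_with_lookahead(x, kernel, lookahead=1)
print(len(out))  # 4 == block_size
```

Each output step depends on one future input sample, so the system must wait for that sample before producing an output; this is the additional latency mentioned above.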