philipperemy / keras-tcn

Keras Temporal Convolutional Network.
MIT License

Possible bug in use_skip_connections #218

Closed: SlapDrone closed this issue 2 years ago

SlapDrone commented 3 years ago

Hey folks,

I've been playing around with the TCN model class and wanted to check with you whether there is an issue in the way the skip connections are included.

The residual block returns two outputs:

https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L151-L156

https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L157-L164

These are, respectively, the channel-matched input plus the output of the convolutional layers (the typical output of a block with a skip connection), and the output of the convolutional layers alone. In my mind, the first output (input + "residual") is what the network should use when use_skip_connections is flagged on, and the second is what should be used when it's flagged off.

Later on when the TCN outputs are built from the residual blocks, these outputs are assigned to the variables x and skip_out respectively, and I think this may be the wrong way round?

https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L316-L320

https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L323-L324

Specifically, when use_skip_connections is on, the TCN outputs the sum of those second outputs (somewhat confusingly named skip_out): the conv block outputs themselves, without adding the inputs back.
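
To make the pattern concrete, here is a minimal sketch of a block with these two outputs (illustrative names only, not the actual tcn.py code):

    from tensorflow.keras.layers import Activation, Add, Conv1D

    def residual_block(inputs, filters, dilation_rate):
        # Dilated causal convolution stack ("the conv block output").
        x = Conv1D(filters, kernel_size=3, dilation_rate=dilation_rate,
                   padding='causal', activation='relu')(inputs)
        # 1x1 conv to match channels, then add the input back (the residual path).
        matched = Conv1D(filters, kernel_size=1, padding='same')(inputs)
        res_act_x = Activation('relu')(Add()([matched, x]))
        # First output: input + conv output; second output: conv output alone.
        return res_act_x, x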

RomanKrajewski commented 3 years ago

I think the error is in the documentation. If I understood the code correctly, the skip_out connections connect the output of each residual block directly to the output of the TCN. The documentation, however, says the use_skip_connections parameter specifies whether there should be a skip connection from the input to each residual block.

The skip connections inside the residual block are not affected by the parameter at all, if I read the code correctly.
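
Under that reading, the TCN-level wiring is roughly the following (a paraphrase of the behaviour being described, not the actual tcn.py code):

    from tensorflow.keras.layers import Add

    def tcn_output(inputs, residual_blocks, use_skip_connections=True):
        x = inputs
        skip_outputs = []
        for block in residual_blocks:
            # Each block returns (input + conv output, conv output alone).
            x, skip_out = block(x)
            skip_outputs.append(skip_out)
        if use_skip_connections:
            # Sum the conv-stack outputs of every block, bypassing the
            # residual paths; this is what the parameter actually toggles.
            x = Add()(skip_outputs)
        return x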

changyuanhong commented 2 years ago

I think it is an error: res_act_x in the residual block maps to x in the TCN, and x in the residual block maps to skip_out in the TCN. self.skip_connections stores the x from each residual block.

philipperemy commented 2 years ago

Hey guys,

I will update the README.

Here is the structure of one residual block (visualized with TensorBoard):

tensorboard --logdir /tmp/logs

And I added:

from tensorflow.keras.callbacks import TensorBoard

tensorboard = TensorBoard(
    log_dir='/tmp/logs',   # directory that `tensorboard --logdir` points at
    histogram_freq=1,      # log weight histograms every epoch
    write_images=True      # log model weights as images
)

and passed it as a callback to the .fit() function.
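
Roughly like this (model, x_train, and y_train are placeholder names, not from this thread):

    model.fit(x_train, y_train,
              epochs=10,                 # any number of epochs
              callbacks=[tensorboard])   # attach the TensorBoard callback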

[screenshot: TensorBoard graph of one residual block]

The skip connections connect together the outputs of the dilated conv stacks (and not the residuals) of all the residual blocks.

It can be visualized here:

[screenshot: TensorBoard graph showing the skip connections]

By dilated conv stack I mean this stack:

[screenshot: the dilated conv stack inside a residual block]

philipperemy commented 2 years ago

I'll close this issue now. But feel free to re-open it. I pushed a new version, 3.4.1, reflecting the updates (mostly renaming variables and updating the README).