Closed SlapDrone closed 2 years ago
I think the error is in the documentation. If I understood the code correctly, the `skip_out` connections connect the output of each residual block directly to the output of the TCN, whereas the documentation says the `skip_connections` parameter specifies whether there should be a skip connection from the input to each residual block.
The skip connections inside the residual blocks are not affected by the parameter at all, if I read the code correctly.
I think it is an error: `_res_actx` in the residual block maps to `x` in the TCN, and `x` in the residual block maps to `skip_out` in the TCN. `self.skip_connections` saves the `x` from each residual block.
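To make the wiring concrete, here is a minimal sketch of the structure being discussed. This is a plain Python/NumPy stand-in with hypothetical names (`residual_block`, `tcn_forward`), not the library's actual code: each block returns both the residual path and the conv-stack output, and the TCN either chains the residual paths or sums the conv-stack outputs.

```python
import numpy as np

def residual_block(x, conv):
    """Stand-in for one TCN residual block: returns both the residual
    path (input added back in) and the conv-stack output alone."""
    conv_out = conv(x)        # output of the dilated conv stack
    res_out = x + conv_out    # channel-matched input + conv output
    return res_out, conv_out

def tcn_forward(x, convs, use_skip_connections):
    """Chain the blocks; optionally sum the per-block conv outputs."""
    skip_outs = []
    for conv in convs:
        x, skip_out = residual_block(x, conv)
        skip_outs.append(skip_out)
    if use_skip_connections:
        # the skip path: sum of the conv-stack outputs of all blocks
        return np.sum(skip_outs, axis=0)
    return x

# Toy "conv stacks": scalar gains standing in for real dilated convs.
convs = [lambda v, k=k: k * v for k in (0.1, 0.2)]
out_skip = tcn_forward(np.ones(3), convs, use_skip_connections=True)
out_plain = tcn_forward(np.ones(3), convs, use_skip_connections=False)
```

Under this reading, `use_skip_connections` only selects which of the two per-block outputs feeds the TCN's final output; it never changes what happens inside a block.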
Hey guys,
I will update the README.
Here is the structure of one residual block (visualized with TensorBoard):

```
tensorboard --logdir logs
```
And I added:

```python
from tensorflow.keras.callbacks import TensorBoard

tensorboard = TensorBoard(
    log_dir='/tmp/logs',
    histogram_freq=1,
    write_images=True
)
```

as a callback in the `.fit()` call.
The skip connections connect the outputs of the dilated conv stacks (and not the residual outputs) of all the residual blocks together. It can be seen in the TensorBoard graph above; by "dilated conv stack" I mean the stack of dilated convolutions inside each residual block.
I'll close this issue now. But feel free to re-open it. I pushed a new version, 3.4.1, reflecting the updates (mostly renaming variables and updating the README).
Hey folks,
I've been playing around a bit with the TCN model class and just wanted to check with you whether there is an issue in the way that the skip connections are included.
The residual block returns two outputs:
https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L151-L156
https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L157-L164
These are, respectively, the channel-matched input + the output of the convolutional layers (the typical output of a block with a skip connection), and the output of the convolutional layers alone. To my mind, the first output (input + "residual") is what the network should use when `use_skip_connections` is flagged on, and the second is what should be used when it's flagged off.

Later on, when the TCN outputs are built from the residual blocks, these outputs are assigned to the variables `x` and `skip_out` respectively, and I think this may be the wrong way round?

https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L316-L320
https://github.com/philipperemy/keras-tcn/blob/2483cd99da5c97cda9763cb9d61baba56436e172/tcn/tcn.py#L323-L324
Specifically, when `use_skip_connections` is on, the TCN outputs the sum of those second outputs (somewhat confusingly named `skip_out`), i.e. the conv-block outputs themselves, without adding the inputs back.
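To illustrate why the naming matters, here is a toy stand-in (plain Python/NumPy, hypothetical names, not the library's code) comparing the two possible assignments of the residual block's two outputs. The two wirings produce different results, so which output is called `x` and which is called `skip_out` is not a cosmetic choice:

```python
import numpy as np

def residual_block(x, conv):
    conv_out = conv(x)        # dilated conv stack alone
    res_out = x + conv_out    # channel-matched input + conv output
    return res_out, conv_out

def tcn_sum_skips(x, convs, feed_residual_forward):
    """Sum the per-block skip outputs.

    feed_residual_forward=True  -> the residual path feeds the next
                                   block; the conv output is summed.
    feed_residual_forward=False -> the swapped reading: the conv output
                                   feeds forward; the residual is summed.
    """
    skips = []
    for conv in convs:
        res_out, conv_out = residual_block(x, conv)
        if feed_residual_forward:
            x, skip = res_out, conv_out
        else:
            x, skip = conv_out, res_out
        skips.append(skip)
    return np.sum(skips, axis=0)

# Toy "conv stacks": scalar gains standing in for real dilated convs.
convs = [lambda v, k=k: k * v for k in (0.5, 0.5)]
a = tcn_sum_skips(np.ones(2), convs, feed_residual_forward=True)
b = tcn_sum_skips(np.ones(2), convs, feed_residual_forward=False)
```

In this toy setup `a` and `b` differ, which is the substance of the question: whichever tensor is summed across blocks defines what the skip path actually carries.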