paul-krug / pytorch-tcn

(Realtime) Temporal Convolutions in PyTorch
MIT License
55 stars 8 forks

Receptive Field? #13

Closed Bruno-TT closed 5 months ago

Bruno-TT commented 5 months ago

Hi,

I'm investigating the behaviour of my TCN with dilations [1,2,4,8,16,32,64] and kernel size 4

The standard receptive field formula would give 763, but when I measure it, my model produces different outputs for inputs up to 2304 steps back, beyond which the output stops changing.

That seemed strange, but I noticed the relationship 2305 = (kernel_size - 1) * 763 + 16, so I assumed I had misunderstood the architecture of the model. However, upon adding another layer, this 2305 number stays the same! Any ideas why this could be the case? How should I correctly calculate the receptive field?
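For reference, here is how I'm computing the standard formula (this assumes two dilated convolutions per residual block, which is my understanding of the architecture):

```python
# Receptive field of a causal TCN, assuming two dilated convolutions
# per residual block:
#   R = 1 + 2 * (kernel_size - 1) * sum(dilations)
def receptive_field(kernel_size, dilations, convs_per_block=2):
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)

print(receptive_field(4, [1, 2, 4, 8, 16, 32, 64]))  # prints 763
```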

Many thanks,

Bruno

paul-krug commented 5 months ago

Could you please post a minimal working example that reproduces your observation? It may have to do with the way you are testing the receptive field, but it is difficult to tell without seeing your code.

Bruno-TT commented 5 months ago
import torch
from pytorch_tcn import TCN

length = 1000
data = torch.randn(1, length, 32)

dilations = [1, 2, 4, 8, 16, 32, 64]
num_channels = [64, 56, 48, 40, 32, 24, 16]
kernel_size = 4
receptive_field = 1 + 2 * (kernel_size - 1) * sum(dilations)

print(f"Receptive Field: {receptive_field}")

tcn = TCN(
    num_inputs=32,
    num_channels=num_channels,
    kernel_size=kernel_size,
    dilations=dilations,
    dropout=0.1,
    causal=True,
    use_norm=None,
    activation='leaky_relu',
    kernel_initializer='xavier_normal',
    use_skip_connections=True,
    input_shape='NLC',
)

tcn.eval()
with torch.no_grad():
    output_on_all_data = tcn(data)[:, -1, :]
    for last_i_rows in range(1, length + 1):
        output = tcn(data[:, -last_i_rows:, :])[:, -1, :]
        if (output == output_on_all_data).all().cpu().item():
            print(f"Output same with last {last_i_rows} rows")
        else:
            print(f"Output different with last {last_i_rows} rows")
Here you are, and thanks so much for the speedy response. The receptive field prints as 763, but the program prints "different" for every i except 1000.

paul-krug commented 5 months ago

OK, so I cross-checked your code, and there is one major issue: with "output == output_on_all_data" you are testing floating-point numbers for exact equality, which is unreliable. If you use "output.isclose(output_on_all_data)" instead, the code works as intended: the messages switch at exactly the index predicted by the receptive field formula. I noticed, however, that after that point you can sometimes still get the message "Output different with last {last_i_rows} rows", and the frequency with which this occurs changes with "use_skip_connections=False" instead of True and with the type of "kernel_initializer". I am fairly sure, though, that this is due to floating-point precision and the tolerance thresholds in torch.isclose(). So my conclusion is that the receptive field works as intended.
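The underlying pitfall is easy to demonstrate even without torch; math.isclose here is just a stand-in for torch.isclose:

```python
import math

# Two mathematically equal quantities need not be bit-identical floats:
a = 0.1 + 0.2
b = 0.3
print(a == b)              # False: exact comparison fails
print(math.isclose(a, b))  # True: tolerance-based comparison succeeds
```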

Bruno-TT commented 4 months ago

Hi, you're right. Rookie mistake on my part, my apologies. Thanks very much for looking into this, and apologies for accusing you of having broken code haha.