The current time downsample breaks causality: because the time axis is not padded enough on the left side, information from later frames leaks into earlier outputs.
Minimal repro/test demonstrating the leak, and showing that the new version below does not leak:
import torch
import torch.nn as nn
from einops import rearrange, pack, unpack
import torch.nn.functional as F
def pack_one(t, pattern):
    """Pack a single tensor with einops ``pack``.

    ``t`` is wrapped in a one-element list and handed to ``pack`` together
    with ``pattern``. The result of ``pack`` is returned as-is, so callers
    unpack it as ``packed, ps`` and later pass ``ps`` to ``unpack_one``.
    """
    packed_with_shapes = pack([t], pattern)
    return packed_with_shapes
def unpack_one(t, ps, pattern):
    """Undo ``pack_one``: unpack ``t`` with einops ``unpack`` and return the
    first (only expected) tensor of the resulting sequence.

    ``ps`` is the packed-shape information produced by the matching
    ``pack_one`` call, and ``pattern`` must be the same pattern used there.
    """
    tensors = unpack(t, ps, pattern)
    return tensors[0]
class TimeDownsample2x(nn.Module):
    """Downsample the time axis by a stride-2 1D convolution with left-only padding.

    The time axis is zero-padded with ``kernel_size - 1`` entries on the left
    and none on the right before the convolution. With stride 2, output step
    ``j`` therefore reads input steps ``2j - (kernel_size - 1)`` through
    ``2j`` and never a later one.

    Args:
        dim: number of channels; the convolution maps ``dim`` -> ``dim``.
        kernel_size: temporal kernel width of the convolution.
    """

    def __init__(
        self,
        dim,
        kernel_size = 3,
    ):
        super().__init__()
        # Amount of left padding applied to the time axis in ``forward``.
        self.time_pad = kernel_size - 1
        # No built-in padding here: all padding is applied manually, left side only.
        self.conv = nn.Conv1d(dim, dim, kernel_size, stride=2)

    def forward(self, x):
        """Apply the temporal downsample.

        ``x`` is laid out as ``(b, c, t, h, w)`` per the rearrange pattern
        below. The result has the same layout with the time axis shortened
        to ``(t - 1) // 2 + 1``.
        """
        # Put time last, then fold every leading axis (b, h, w) into one so
        # the tensor has the (N, c, t) layout that Conv1d expects.
        time_last = rearrange(x, 'b c t h w -> b h w c t')
        flat, packed_shape = pack_one(time_last, '* c t')

        # Pad only the left edge of the last (time) axis.
        padded = F.pad(flat, (self.time_pad, 0))
        pooled = self.conv(padded)

        # Restore the folded leading axes and the original axis order.
        restored = unpack_one(pooled, packed_shape, '* c t')
        return rearrange(restored, 'b h w c t -> b c t h w')
if __name__ == '__main__':
    # Causality check: changing time step 1 of the input must leave time
    # step 0 of the output untouched.
    video = torch.zeros(1, 3, 17, 64, 64)
    layer = TimeDownsample2x(3)

    video[:, :, 0] = 1.0
    video[:, :, 1] = 2.0
    before = layer(video)

    # Modify only input step 1; output step 0 should not change its value.
    video[:, :, 1] = 3.0
    after = layer(video)

    # Follow a single (batch, channel, row, column) position along time.
    before_trace = before[0, 0, :, 0, 0]
    after_trace = after[0, 0, :, 0, 0]
    print(before_trace)
    print(after_trace)

    assert before_trace[0] == after_trace[0]
The current time downsample breaks causality: because the time axis is not padded enough on the left side, information from later frames leaks into earlier outputs.
Minimal repro/test demonstrating the leak, and showing that the new version above does not leak.