lucidrains / magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
MIT License
565 stars 34 forks source link

Fix causal info leakage in time downsample #27

Closed matwilso closed 10 months ago

matwilso commented 10 months ago

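The current time downsample leaks future (non-causal) information because it has insufficient padding on the left side. With `kernel_size - 1` frames of left padding and stride 2, the receptive field of output step `t` ends at input frame `2t`, so no output can see a later frame.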

Minimal repro/test showing the leak, and showing that the new version does not leak:

import torch
import torch.nn as nn
from einops import rearrange, pack, unpack
import torch.nn.functional as F

def pack_one(t, pattern):
    """Pack a single tensor with einops ``pack``.

    Returns the ``(packed, packed_shapes)`` pair, so the result can later be
    undone with ``unpack_one``.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack ``t`` using shapes ``ps``, return the first tensor."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]

class TimeDownsample2x(nn.Module):
    """Causal 2x temporal downsampling for 5-D video tensors.

    A strided ``Conv1d`` runs along the time axis of every spatial position
    independently. The sequence is zero-padded with ``kernel_size - 1``
    frames on the left only. Output step ``t`` therefore covers input frames
    ``[2t - kernel_size + 1, 2t]`` and never looks ahead.

    Args:
        dim: number of channels (used for both input and output).
        kernel_size: temporal width of the convolution kernel.
    """

    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        # left-only padding of (k - 1) frames is what keeps the conv causal
        self.time_pad = kernel_size - 1
        self.conv = nn.Conv1d(dim, dim, kernel_size, stride = 2)

    def forward(self, x):
        """Downsample ``x`` laid out as (b, c, t, h, w) by 2 along time."""
        # put time last and fold batch + spatial dims into one: (b*h*w, c, t)
        frames = rearrange(x, 'b c t h w -> b h w c t')
        frames, packed_shape = pack_one(frames, '* c t')

        causal_frames = F.pad(frames, (self.time_pad, 0))
        reduced = self.conv(causal_frames)

        # restore the (b, h, w) leading dims, then the original axis order
        reduced = unpack_one(reduced, packed_shape, '* c t')
        return rearrange(reduced, 'b h w c t -> b c t h w')

if __name__ == '__main__':
    # Causality check: perturbing frame 1 must leave output step 0 unchanged,
    # because with kernel_size - 1 left padding and stride 2 the receptive
    # field of output step 0 ends at input frame 0.
    img = torch.zeros([1, 3, 17, 64, 64])
    downsample = TimeDownsample2x(3)

    img[:, :, 0] = 1.0
    img[:, :, 1] = 2.0

    out1 = downsample(img)

    # if we modify the 1st element, the 0th element should not change its output
    img[:, :, 1] = 3.0
    out2 = downsample(img)

    # slice out single time slice
    out1_sub = out1[0, 0, :, 0, 0]
    out2_sub = out2[0, 0, :, 0, 0]

    print(out1_sub)
    print(out2_sub)

    # compare output step 0 across every batch item, channel and pixel,
    # not just a single channel at a single location
    assert torch.equal(out1[:, :, 0], out2[:, :, 0]), 'output step 0 depends on a future frame'
    # sanity check: the perturbation must actually reach output step 1,
    # otherwise the causality assertion above proves nothing
    assert not torch.equal(out1[:, :, 1], out2[:, :, 1]), 'perturbing frame 1 had no effect'
lucidrains commented 10 months ago

🤦‍♂️ yes indeed, thank you