hulianyuyy / SEN_CSLR

Self-Emphasizing Network for Continuous Sign Language Recognition (AAAI2023 Oral)
Apache License 2.0

A confusion about frame-by-frame difference calculation #13

Open 95AliceHong opened 1 month ago

95AliceHong commented 1 month ago

Hello, and thank you for your work. Your paper describes TSEM as follows: "By using global average pooling to eliminate spatial dimensions W and H, the difference between adjacent frames is calculated as approximate motion information." However, when I looked at the code for the TSEM module, I found that it is very similar to the code for the SSEM module, and I cannot find any frame-by-frame difference calculation in it. Could you clarify where this calculation happens? I look forward to your reply!

import torch
import torch.nn as nn
import torch.nn.functional as F

class TSEM(nn.Module):
    def __init__(self, input_size ):
        super(TSEM, self).__init__()
        hidden_size = input_size//16
        self.conv_transform = nn.Conv1d(input_size, hidden_size, kernel_size=1, stride=1, padding=0)
        self.conv_back = nn.Conv1d(hidden_size, input_size, kernel_size=1, stride=1, padding=0)
        #self.conv_enhance = nn.Conv1d(hidden_size, hidden_size, kernel_size=9, stride=1, padding=4)
        self.num = 5
        self.conv_enhance = nn.ModuleList([
            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, stride=1, padding=int(i+1), groups=hidden_size, dilation=int(i+1)) for i in range(self.num)
        ])
        self.weights = nn.Parameter(torch.ones(self.num) / self.num, requires_grad=True)
        self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv_transform(x.mean(-1).mean(-1))
        aggregated_out = 0
        for i in range(self.num):
            aggregated_out += self.conv_enhance[i](out) * self.weights[i]
        out = self.conv_back(aggregated_out)
        return x*(F.sigmoid(out.unsqueeze(-1).unsqueeze(-1))-0.5) * self.alpha

class SSEM(nn.Module):
    def __init__(self, input_size ):
        super(SSEM, self).__init__()
        div_channel = input_size//16
        self.conv_transform = nn.Conv3d(input_size, div_channel, kernel_size=(1,1,1))
        self.num = 3
        self.conv_enhance = nn.ModuleList([
            nn.Conv3d(div_channel, div_channel, kernel_size=(9,3,3), padding=(4,i+1,i+1), dilation=(1,i+1,i+1), groups=div_channel) for i in range(self.num)
        ])

        self.weights = nn.Parameter(torch.ones(self.num) / self.num, requires_grad=True)
        self.conv_back = nn.Conv3d(div_channel, input_size, kernel_size=(1,1,1))
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=True)

    def forward(self, x):
        out = self.conv_transform(x)
        aggregated_out = 0
        for i in range(self.num):
            aggregated_out += self.conv_enhance[i](out) * self.weights[i]
        out = self.conv_back(aggregated_out)
        return x*(F.sigmoid(out)-0.5) * self.alpha
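
For reference, the TSEM class above contains no explicit difference. It applies five depthwise Conv1d branches (kernel size 3, dilation and padding both equal to i+1) to the pooled per-frame features and sums them with learned weights. The sketch below only does the arithmetic for those settings, using the output-length formula documented for nn.Conv1d: it checks that every branch keeps the temporal length T and lists how many frames each dilated kernel spans. The helper names and the clip length T = 100 are hypothetical, chosen for illustration.

```python
def conv1d_out_len(t, kernel_size, padding, dilation, stride=1):
    """Output length of a 1D convolution (formula documented for nn.Conv1d)."""
    return (t + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def kernel_span(kernel_size, dilation):
    """Number of consecutive frames covered by one dilated kernel."""
    return dilation * (kernel_size - 1) + 1


T = 100  # arbitrary clip length, for illustration only
out_lens = []
spans = []
for i in range(5):  # self.num = 5 in the TSEM class above
    d = i + 1  # dilation and padding are both i + 1
    out_lens.append(conv1d_out_len(T, kernel_size=3, padding=d, dilation=d))
    spans.append(kernel_span(kernel_size=3, dilation=d))

print(out_lens)  # [100, 100, 100, 100, 100]
print(spans)     # [3, 5, 7, 9, 11]
```

Because every branch returns the same length T, the five outputs can be summed element-wise, which is what the loop over self.conv_enhance does.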

hulianyuyy commented 1 month ago

As stated in the implementation notes in readme.md, we later found that a multi-scale architecture performs on par with the TSEM results reported in the paper, so the released code implements that version instead. The original version, which computes the frame-by-frame difference, was designed as follows:

class TSEM(nn.Module):
    def __init__(self, input_size ):
        super(TSEM, self).__init__()
        hidden_size = input_size//16
        self.conv_transform = nn.Conv1d(input_size, hidden_size, kernel_size=1, stride=1, padding=0)
        self.conv_back = nn.Conv1d(hidden_size*2, input_size, kernel_size=1, stride=1, padding=0)
        self.conv_enhance = nn.Conv1d(hidden_size*2, hidden_size*2, kernel_size=5, stride=1, padding=2)

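    def forward(self, x):
        # x: (B, C, T, H, W); average over W and H to get per-frame features (B, C, T).
        out = self.conv_transform(x.mean(-1).mean(-1))
        B, C, T = out.size()
        # Difference between adjacent frames as approximate motion information.
        out_diff = out[:, :, 1:] - out[:, :, :-1]
        # Zero-pad the first time step so out_diff keeps length T.
        out_diff = torch.cat([out_diff.new_zeros(B, C, 1), out_diff], dim=2)
        out = self.conv_enhance(torch.cat((out_diff, out), 1))
        # Assumed step: project back to input_size channels so the gate broadcasts against x.
        out = self.conv_back(out)
        return x + x*(F.sigmoid(out.unsqueeze(-1).unsqueeze(-1))-0.5)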
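
To make the difference step concrete, here is a minimal sketch on a toy tensor, assuming the (B, C, T, H, W) layout implied by x.mean(-1).mean(-1). The helper name pooled_frame_diff is hypothetical and not part of the repository; it isolates only the pooling and differencing, without the convolutions.

```python
import torch


def pooled_frame_diff(x):
    """x: (B, C, T, H, W) -> pooled features and frame differences, both (B, C, T)."""
    pooled = x.mean(-1).mean(-1)  # global average pooling over W, then H
    diff = pooled[:, :, 1:] - pooled[:, :, :-1]  # diff[t] = pooled[t] - pooled[t-1]
    # Zero-pad at t = 0 so the difference keeps the same length T as the input.
    diff = torch.cat([torch.zeros_like(pooled[:, :, :1]), diff], dim=2)
    return pooled, diff


# Toy clip: 1 sample, 1 channel, 4 frames of 2x2 pixels; frame t is filled with t**2.
x = torch.tensor([0.0, 1.0, 4.0, 9.0]).view(1, 1, 4, 1, 1).expand(1, 1, 4, 2, 2)
pooled, diff = pooled_frame_diff(x)
print(pooled.flatten().tolist())  # [0.0, 1.0, 4.0, 9.0]
print(diff.flatten().tolist())    # [0.0, 1.0, 3.0, 5.0]
```

A clip whose frames are all identical yields an all-zero difference, so the concatenated difference channels respond only where the pooled features change over time.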

95AliceHong commented 1 month ago

Okay, I understand now! Thank you for your reply!