abduallahmohamed / Social-STGCNN

Code for "Social-STGCNN: A Social Spatio-Temporal Graph Convolutional Neural Network for Human Trajectory Prediction" CVPR 2020
MIT License

Questions about some details in paper and codes #13

Closed: simmonssong closed this issue 4 years ago

simmonssong commented 4 years ago

Thank you for your interesting work. I have some questions about details in the paper and the code.

  1. TCN: It seems that a regular TCN is not used. In other words, the layer acts like a mapping from the input sequence to the output sequence along the temporal dimension, implemented with torch.nn.Conv2d. What, then, is the purpose of the code that follows it? Could you explain?
    n, kc, t, v = x.size()
    x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)
    x = torch.einsum('nkctv,kvw->nctw', (x, A))
  2. Sample: Is Social-STGCNN a deterministic model? How can it generate several different trajectories (as in Figure 3, which shows the distributions of multi-modal trajectories)?

Thanks a lot for your nice work.

abduallahmohamed commented 4 years ago

Hi,

  1. We are not using a regular TCN in our work. We only refer to the concept and relate it to the TXPCNN layer, which treats the temporal dimension as feature channels, unlike a TCN, which treats temporal data as pixel values. The einsum is an Einstein summation, a concept you can look up. What we do here is collapse the graph sequence into a single representation of <time, ped, features> using the graphs and their corresponding adjacency matrices (A). In other words, we use A to weight the features of neighboring pedestrians for each pedestrian.

  2. Social-STGCNN is not a deterministic model. As the loss function in the paper shows, we model the trajectory as a bivariate Gaussian distribution and predict its 5 parameters at each time step: mean_x, mean_y, variance_x, variance_y and correlation_xy. Because the model predicts a distribution, you can sample multiple trajectories from it. In our testing we sample 20 trajectories, as this is the community standard for this kind of problem.
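
The sampling step described in item 2 can be sketched as follows. This is a minimal NumPy illustration, not the repository's test code: the function name `sample_trajectories` and the (T, N, 5) parameter layout are hypothetical, and it assumes the spread parameters are already positive standard deviations and the correlation already lies in (-1, 1). If the network outputs variances, take their square root first. How raw network outputs are mapped into these valid ranges is not specified in this thread.

```python
import numpy as np

def sample_trajectories(params, num_samples=20, rng=None):
    """Draw samples from independent per-step bivariate Gaussians.

    params: array of shape (T, N, 5) holding, for each time step and
            pedestrian, (mean_x, mean_y, std_x, std_y, corr_xy).
            Assumes std_x, std_y > 0 and |corr_xy| < 1.
    Returns an array of shape (num_samples, T, N, 2).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Split the last axis into five (T, N) arrays.
    mean_x, mean_y, std_x, std_y, corr = np.moveaxis(params, -1, 0)

    # Two independent standard normal draws per sample, step and pedestrian.
    z1 = rng.standard_normal((num_samples,) + mean_x.shape)
    z2 = rng.standard_normal((num_samples,) + mean_x.shape)

    # Correlate the draws so that (x, y) has the requested covariance.
    x = mean_x + std_x * z1
    y = mean_y + std_y * (corr * z1 + np.sqrt(1.0 - corr ** 2) * z2)
    return np.stack([x, y], axis=-1)

# Illustrative parameters: 12 predicted steps, 3 pedestrians.
T, N = 12, 3
params = np.zeros((T, N, 5))
params[..., 2:4] = 1.0   # unit standard deviations
params[..., 4] = 0.5     # moderate x/y correlation

samples = sample_trajectories(params, num_samples=20, rng=np.random.default_rng(0))
print(samples.shape)  # (20, 12, 3, 2)
```

Each of the 20 slices along the first axis is one candidate trajectory for every pedestrian, which is how a model that predicts a single distribution per step can still yield multiple distinct futures.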

Thanks

simmonssong commented 4 years ago

Thank you.

Question 2: I found the definition in Social-LSTM [1]: a joint of several independent two-dimensional Gaussian distributions.

Question 1: In the ST-GCN [2] model, kernel_size applies to the spatial convolution on the graph, where the adjacency matrix is time-invariant. If I understand correctly, it is a learnable kernel, just as in regular CNNs. In your paper, however, the adjacency matrix is time-variant and non-learnable. So I think torch.einsum('nctv,tvw->nctw', (x, A)) is a better fit, and the kernel_size parameter can be removed.

The complete code is as follows. I'm testing whether this change will influence the result.

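import torch
import torch.nn as nn


class ConvTemporalGraphical(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super(ConvTemporalGraphical,self).__init__()
        # kernel_size is only used to check the first dimension of A in forward()
        self.kernel_size = kernel_size
        # Convolution along the temporal axis only; the vertex axis has kernel size 1
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias)

    def forward(self, x, A):
        # x: (n, c, t, v); A: (t, v, v), one adjacency matrix per time step
        assert A.size(0) == self.kernel_size
        x = self.conv(x)
        # For each time step t, mix the vertex features with A[t]
        x = torch.einsum('nctv,tvw->nctw', (x, A))
        return x.contiguous(), A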
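
The difference between the two einsum expressions discussed above can be checked numerically. The sketch below uses NumPy's `np.einsum`, which accepts the same subscript strings as `torch.einsum` for these cases. All tensor sizes and variable names are illustrative assumptions, not values taken from the repository.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes: batch, channels, time steps, pedestrians.
n, c, t, v = 2, 5, 8, 3

# Proposed form: 'nctv,tvw->nctw'.
x = rng.standard_normal((n, c, t, v))
A = rng.standard_normal((t, v, v))  # one adjacency matrix per time step
out_einsum = np.einsum('nctv,tvw->nctw', x, A)

# Equivalent explicit loop: time step `step` is mixed only with A[step].
out_loop = np.empty_like(x)
for step in range(t):
    out_loop[:, :, step, :] = x[:, :, step, :] @ A[step]
assert np.allclose(out_einsum, out_loop)

# Original form: 'nkctv,kvw->nctw'.
k = 4  # number of kernel slices (illustrative)
xk = rng.standard_normal((n, k, c, t, v))
Ak = rng.standard_normal((k, v, v))
out_orig = np.einsum('nkctv,kvw->nctw', xk, Ak)

# Equivalent explicit sum: slice i is mixed with Ak[i] at every time step,
# and the k results are added together.
out_orig_loop = sum(xk[:, i] @ Ak[i] for i in range(k))
assert np.allclose(out_orig, out_orig_loop)
```

In the original expression the index k is summed out, so every output time step receives contributions from all k matrices. In the proposed expression the index t appears in both inputs and in the output, so A[t] affects only time step t, which matches a time-variant adjacency matrix.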

[1] Alahi, A., Goel, K., Ramanathan, V., Robicquet, A., Fei-Fei, L., & Savarese, S. (2016). Social LSTM: Human trajectory prediction in crowded spaces. Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, 2016-December, 961–971.
[2] Yan, S., Xiong, Y., & Lin, D. (2018). Spatial temporal graph convolutional networks for skeleton-based action recognition. 32nd AAAI Conference on Artificial Intelligence, AAAI 2018, 7444–7452.

abduallahmohamed commented 4 years ago

Hi, thanks for pointing this out. I re-ran the experiments with your suggestion and obtained similar results, and the change makes sense. I have updated the repo accordingly.

d-zh commented 4 years ago

Hi, I wonder which commit corresponds to the version of your paper published at CVPR 2020. Thanks!

abduallahmohamed commented 4 years ago

@d-zh https://github.com/abduallahmohamed/Social-STGCNN/tree/ebd57aaf34d84763825d05cf9d4eff738d8c96bb