quancore / social-lstm

Social LSTM implementation in PyTorch

Bug in test.py #20

Open cuihenggang opened 5 years ago

cuihenggang commented 5 years ago

Hello,

Can someone please take a look at these two lines in test.py? https://github.com/quancore/social-lstm/blob/master/test.py#L202-L203

            total_error += get_mean_error(ret_x_seq[1:sample_args.obs_length].data, orig_x_seq[1:sample_args.obs_length].data, PedsList_seq[1:sample_args.obs_length], PedsList_seq[1:sample_args.obs_length], sample_args.use_cuda, lookup_seq)
            final_error += get_final_error(ret_x_seq[1:sample_args.obs_length].data, orig_x_seq[1:sample_args.obs_length].data, PedsList_seq[1:sample_args.obs_length], PedsList_seq[1:sample_args.obs_length], lookup_seq)

More specifically, why are the prediction errors computed over the range 1:sample_args.obs_length, which I think covers the observed data rather than the predicted frames?
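To make the question concrete, here is a minimal NumPy sketch of which frames each slice covers and of ADE/FDE restricted to the prediction horizon. It assumes the Social LSTM paper's setting of 8 observed and 12 predicted frames, and trajectories stored as arrays of shape (seq_length, num_peds, 2). ade_fde is a hypothetical helper, not the repository's get_mean_error/get_final_error, and it ignores PedsList_seq and lookup_seq (every pedestrian is assumed present in every frame).

```python
import numpy as np

OBS_LENGTH = 8    # observed frames (assumed, as in the Social LSTM paper)
PRED_LENGTH = 12  # predicted frames (assumed)
SEQ_LENGTH = OBS_LENGTH + PRED_LENGTH

frames = list(range(SEQ_LENGTH))
print(frames[1:OBS_LENGTH])           # [1, 2, 3, 4, 5, 6, 7]: observed frames only
print(frames[OBS_LENGTH:SEQ_LENGTH])  # frames 8..19: the prediction horizon


def ade_fde(pred, gt, obs_length=OBS_LENGTH):
    """Average and final displacement error over the predicted frames only.

    pred, gt: arrays of shape (seq_length, num_peds, 2).
    Assumes every pedestrian is present in every frame (no masking).
    """
    # Euclidean distance per (frame, pedestrian), restricted to the horizon
    dist = np.linalg.norm(pred[obs_length:] - gt[obs_length:], axis=-1)
    ade = dist.mean()      # mean over all predicted frames and pedestrians
    fde = dist[-1].mean()  # mean over pedestrians at the last predicted frame
    return float(ade), float(fde)


# Usage: predictions offset by (4, 3) over the horizon, wildly wrong when observed.
gt = np.zeros((SEQ_LENGTH, 2, 2))
pred = np.full_like(gt, 100.0)
pred[OBS_LENGTH:] = [4.0, 3.0]
print(ade_fde(pred, gt))  # (5.0, 5.0): the observed frames do not count
```

Scoring 1:obs_length instead would measure how well the model reproduces frames it was already given, which says little about prediction quality.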

quancore commented 5 years ago

I think you are right. Feel free to send a PR.

Miaowaaaa commented 4 years ago

@cuihenggang Have you fixed this bug? I changed 1:sample_args.obs_length to sample_args.obs_length:seq_lenght, but the error output is nan.
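One possible cause (not confirmed in this thread) is that some pedestrians have no ground truth in the predicted frames, so a plain mean over the horizon averages in undefined values. Below is a minimal NumPy sketch of a masked error that skips those entries and fails loudly when nothing can be scored. The boolean array present, of shape (seq_length, num_peds), is a hypothetical stand-in for the per-frame pedestrian lists that test.py passes around; masked_ade_fde is likewise a hypothetical name.

```python
import numpy as np


def masked_ade_fde(pred, gt, present, obs_length=8):
    """ADE/FDE over the predicted frames, scoring only pedestrians with ground truth.

    pred, gt: float arrays of shape (seq_length, num_peds, 2).
    present:  bool array of shape (seq_length, num_peds); a hypothetical stand-in
              for the per-frame pedestrian lists used in test.py.
    Raises ValueError instead of returning nan when there is nothing to score.
    """
    dist = np.linalg.norm(pred[obs_length:] - gt[obs_length:], axis=-1)
    mask = present[obs_length:]
    if not mask.any():
        raise ValueError("no ground truth in the prediction horizon")
    if not mask[-1].any():
        raise ValueError("no ground truth at the final predicted frame")
    ade = dist[mask].mean()          # masked-out entries (possibly nan) are skipped
    fde = dist[-1][mask[-1]].mean()  # only pedestrians present in the last frame
    return float(ade), float(fde)


# Usage: pedestrian 1 leaves after frame 9; its missing ground truth is nan.
gt = np.zeros((20, 2, 2))
pred = np.zeros_like(gt)
pred[:] = [4.0, 3.0]
present = np.ones((20, 2), dtype=bool)
gt[10:, 1] = np.nan
present[10:, 1] = False
print(np.linalg.norm(pred[8:] - gt[8:], axis=-1).mean())  # nan: unmasked mean
print(masked_ade_fde(pred, gt, present))                   # (5.0, 5.0)
```

Raising instead of returning nan makes it obvious which sequences lack ground truth, rather than letting a single nan poison the total error.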

hzzzzjzyq commented 3 years ago

It may be because the test dataset lacks 12 frames of data. You can use https://github.com/julioba/trajectoryprediction_studienarbeit to test.
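If that is the cause, it may help to check that each test sequence actually contains obs_length + pred_length frames before scoring the prediction horizon. A minimal sketch, assuming you can obtain the number of frames per sequence; short_sequences and seq_frame_counts are hypothetical names, and 8 + 12 frames is an assumed default.

```python
def short_sequences(seq_frame_counts, obs_length=8, pred_length=12):
    """Return the indices of sequences too short to hold a full prediction horizon.

    seq_frame_counts: number of frames available in each test sequence.
    A sequence needs obs_length observed frames plus pred_length ground-truth
    frames before its predictions can be scored.
    """
    needed = obs_length + pred_length
    return [i for i, n in enumerate(seq_frame_counts) if n < needed]


# Usage: the second and third sequences lack ground truth for part of the horizon.
print(short_sequences([20, 8, 19, 25]))  # [1, 2]
```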

aminmanafi commented 2 years ago

Hello, I am working on this paper and have the same issue. Could you help me solve this problem? Once my own work is done, I want to compare it against this paper, so the fix should reproduce the results reported in the paper. Thanks.