quancore / social-lstm

Social LSTM implementation in PyTorch

test total_error/final_error #6

Closed: kingwmk closed this issue 5 years ago

kingwmk commented 5 years ago

Excuse me. It seems that `total_error`/`final_error` in the test stage compute the error between the observed data and itself, because `ret_x_seq` is assigned with `ret_x_seq[:args.obs_length, :, :] = x_seq.clone()` in the `sample` function. Could you please clarify? Thanks very much!

```python
# Record the mean and final displacement error
# (both slices below select frames 1 .. obs_length-1, i.e. the observed window)
total_error += get_mean_error(ret_x_seq[1:sample_args.obs_length].data, orig_x_seq[1:sample_args.obs_length].data, PedsList_seq[1:sample_args.obs_length], PedsList_seq[1:sample_args.obs_length], sample_args.use_cuda, lookup_seq)
final_error += get_final_error(ret_x_seq[1:sample_args.obs_length].data, orig_x_seq[1:sample_args.obs_length].data, PedsList_seq[1:sample_args.obs_length], PedsList_seq[1:sample_args.obs_length], lookup_seq)
```

quancore commented 5 years ago

`ret_x_seq` holds the same values as `x_seq` only in the observed part (that is what the line `ret_x_seq[:args.obs_length, :, :] = x_seq.clone()` does). After the observed part, the prediction part comes from the model. Therefore, `ret_x_seq` is identical to `x_seq` in the observed part, but it holds predicted values in the prediction part.
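The structure described in the reply can be sketched with plain Python lists standing in for the `ret_x_seq` tensor: the first `obs_length` frames are copied from the observed input, and every later frame is produced by the model. The helper `assemble_sequence` and the constant-velocity `toy_step` predictor are hypothetical stand-ins for the Social LSTM forward pass, not the repository's code. The sketch also prints which frame indices the slices `[1:obs_length]` and `[obs_length:]` select.

```python
from typing import Callable, List, Tuple

Point = Tuple[float, float]
Frame = List[Point]  # one (x, y) per pedestrian


def assemble_sequence(observed: List[Frame], pred_length: int,
                      step: Callable[[List[Frame]], Frame]) -> List[Frame]:
    """Hypothetical stand-in for filling ret_x_seq: copy the observed frames,
    then append pred_length frames produced one at a time by `step`."""
    ret = [list(frame) for frame in observed]  # analogous to ret_x_seq[:obs_length] = x_seq.clone()
    for _ in range(pred_length):
        ret.append(step(ret))  # each prediction may condition on earlier predictions
    return ret


def toy_step(history: List[Frame]) -> Frame:
    """Toy predictor (assumption): constant-velocity extrapolation from the last two frames."""
    prev, last = history[-2], history[-1]
    return [(2 * lx - px, 2 * ly - py) for (px, py), (lx, ly) in zip(prev, last)]


if __name__ == "__main__":
    obs_length, pred_length = 4, 3
    observed = [[(float(t), 0.5 * t)] for t in range(obs_length)]
    ret = assemble_sequence(observed, pred_length, toy_step)

    print(ret[:obs_length] == observed)         # True: observed part is a verbatim copy
    print(list(range(len(ret)))[1:obs_length])  # [1, 2, 3]: frames a [1:obs_length] slice covers
    print(list(range(len(ret)))[obs_length:])   # [4, 5, 6]: the predicted frames
```

Whether a displacement metric evaluates the model therefore depends on which slice is passed to it: frames in `[:obs_length]` equal the input by construction.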