coxlab / prednet

Code and models accompanying "Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning"
https://arxiv.org/abs/1605.08104
MIT License

Mean Squared Error calculation appears to be wrong #69

Open dHannasch opened 4 years ago

dHannasch commented 4 years ago

I might be completely off base here, but it looks like the Mean Squared Error (MSE) computation at https://github.com/coxlab/prednet/blob/master/kitti_evaluate.py#L58 compares the true current frame with the predicted next frame.

That is, we compute X_hat = test_model.predict(X_test, batch_size), which as I understand it has as its first element the predicted second frame (predicted from the true first frame). Thus the first element of X_hat[:, 1:] is the predicted third frame (predicted from the true second frame), while the first element of X_test[:, 1:] is that true second frame. So the first element of X_test[:, 1:] - X_hat[:, 1:] appears to be the difference between the true second frame and the predicted third frame, that is, between the prediction and its own input. Presumably what you really want is the difference between the predicted third frame and the true third frame, isn't it?

(Later, of course, you deliberately calculate the difference between the current frame and the previous frame https://github.com/coxlab/prednet/blob/master/kitti_evaluate.py#L59 , but I thought that was just the control group, checking what MSE you get from the "dumb" strategy of always predicting that the next frame will be identical to the current frame.)
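To make the concern above concrete, here is a minimal NumPy sketch on toy data. It is not the repository's code: the helper names, the array shapes, and the "perfect next-frame" output X_next are all assumptions made for illustration. It shows what the same-index comparison would measure if index t of the model output really held a prediction of frame t + 1.

```python
import numpy as np

def mse_same_index(X_true, X_pred):
    # Compares frame t with the output stored at index t (skipping t = 0).
    return np.mean((X_true[:, 1:] - X_pred[:, 1:]) ** 2)

def mse_shifted(X_true, X_pred):
    # Compares frame t + 1 with the output stored at index t.
    return np.mean((X_true[:, 1:] - X_pred[:, :-1]) ** 2)

# Toy data: 1 sequence, 5 timesteps, 2x2 frames whose pixel value equals t.
T = 5
X_true = np.arange(T, dtype=float).reshape(1, T, 1, 1) * np.ones((1, T, 2, 2))

# Hypothetical output under the convention assumed in the question:
# index t holds a (here perfect) prediction of frame t + 1.
X_next = np.roll(X_true, -1, axis=1)
X_next[:, -1] = T  # the frame after the last one, known only in this toy

same = mse_same_index(X_true, X_next)   # nonzero despite perfect predictions
shifted = mse_shifted(X_true, X_next)   # zero once the shift is accounted for
print(same, shifted)
```

Under that assumed convention the same-index MSE would penalize even a perfect predictor, which is the worry raised here. Whether the convention actually applies to PredNet is the open question.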

bill-lotter commented 4 years ago

Hi, it's a good question, but the timestep dimension actually holds the actual frame at time t and the predicted frame at time t, so X_test[:, 1] is the true 2nd frame and X_hat[:, 1] is the predicted 2nd frame given the true 1st frame, X_test[:, 0]. In the code, this shows up here, where the R units and the prediction are updated before A is used: https://github.com/coxlab/prednet/blob/843b452f52bb44ca79ba6f65365e9d60cebed489/prednet.py#L254
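The indexing convention described in this reply can be sketched with a toy predictor. This is only an illustration under stated assumptions: toy_predict is a hypothetical stand-in (simple linear extrapolation), not PredNet, and the array shapes are made up. The point is that when index t of the output is a prediction of frame t built only from earlier frames, the same-index comparison is the aligned one.

```python
import numpy as np

def toy_predict(X):
    """Hypothetical stand-in for test_model.predict (NOT PredNet itself).

    Output index t holds a prediction of frame t that uses only frames
    before t, matching the convention described in the reply above.
    """
    X_hat = np.zeros_like(X)
    # t = 0: nothing has been seen yet, so the prediction stays at zero.
    # t = 1: only frame 0 is known, so repeat it.
    X_hat[:, 1] = X[:, 0]
    # t >= 2: linear extrapolation from the two previous true frames.
    X_hat[:, 2:] = 2 * X[:, 1:-1] - X[:, :-2]
    return X_hat

# Toy "video": 1 sequence, 5 timesteps, 2x2 frames whose pixel value equals t.
T = 5
X_test = np.arange(T, dtype=float).reshape(1, T, 1, 1) * np.ones((1, T, 2, 2))
X_hat = toy_predict(X_test)

# Same-index comparison: frame t versus the prediction of frame t.
mse_model = np.mean((X_test[:, 1:] - X_hat[:, 1:]) ** 2)
# Baseline: use the previous true frame as the prediction.
mse_prev = np.mean((X_test[:, :-1] - X_test[:, 1:]) ** 2)
print(mse_model, mse_prev)
```

In this toy the extrapolating predictor only misses at t = 1, so it scores below the copy-last-frame baseline when both are measured with same-index slices.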