princeton-vl / pytorch_stacked_hourglass

PyTorch implementation of the ECCV 2016 paper "Stacked Hourglass Networks for Human Pose Estimation"
BSD 3-Clause "New" or "Revised" License

`model.train()` usage #44

Closed ashutoshsingh0223 closed 1 year ago

ashutoshsingh0223 commented 1 year ago

The code uses `model.train()` during both the train and valid phases. Wouldn't this affect the running stats of the batchnorm layers? Is this intentional?

https://github.com/princeton-vl/pytorch_stacked_hourglass/blob/ed91059e874f35089dd3a8e692fa895929785b91/task/pose.py#L107-L132
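
For reference, here is a minimal standalone sketch (not this repo's code) of the mechanism: any forward pass in train mode updates BatchNorm's running statistics, even under `torch.no_grad()`:

```python
import torch

bn = torch.nn.BatchNorm2d(3)
x = torch.randn(8, 3, 16, 16)  # stand-in for a batch of validation images

print(bn.running_mean)  # initialized to zeros

bn.train()
with torch.no_grad():
    bn(x)               # no backward pass, yet running stats still update
print(bn.running_mean)  # changed: validation-batch statistics have leaked in

bn.eval()
with torch.no_grad():
    bn(x)               # in eval mode the running stats are used, not updated
print(bn.running_mean)  # unchanged from the previous print
```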

ashutoshsingh0223 commented 1 year ago

And since test.py uses validation examples wouldn't it lead to data leakage? I am just curious and want to know your opinion.

crockwell commented 1 year ago

Re: test.py. Good point! I'm following the convention of prior work, which evaluates on the validation set. MPII does not publicly share the test set labels, so getting final test results requires submitting to their repo. So "test.py" is just used on the val set (weird, I know; it would probably be wise to rename it val.py...).

However, if your first point is correct and batchnorm is indeed in train mode on the val set, that could result in leakage. To be honest, I haven't looked at this code much since 2019, so I can't say for sure off the top of my head whether that's the case. Actually, if you set a breakpoint to verify it is, and wanted to make a pull request, that would be much appreciated! No worries if not.

ashutoshsingh0223 commented 1 year ago

Thank you for replying. I can create a PR, no problem. I think this line from the documentation verifies that running stats are updated during training:

> Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation.

https://pytorch.org/docs/master/generated/torch.nn.BatchNorm2d.html?highlight=batchnorm2d#torch.nn.BatchNorm2d

ashutoshsingh0223 commented 1 year ago

I will create a PR for this.

crockwell commented 1 year ago

Great!

Oh, I meant: if you could run an instance of this training code, just verify that `model.eval()` isn't being called somewhere. The code layout is a bit unusual, so it might be that `model.train()` somehow isn't in effect when we look at the val set during training.

And just to be sure, the model uses `model.eval()` in test.py, right? I'm not sure whether there was some setup where we decided to just keep it in train mode all the time, but I want to be sure before we make a change.

Thanks!

ashutoshsingh0223 commented 1 year ago

Yes, test.py uses `model.eval()`.

No; if you call `model.eval()` or `net.eval()` anywhere, the Trainer's forward method will not return loss values, and the validation phase relies on those loss values.

But just to be completely sure, I will do a sanity check before raising the PR.
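
A minimal version of that check (a sketch; `net` and the surrounding loop are assumptions about this repo's Trainer layout) would be to snapshot the BatchNorm running means before a validation pass and see whether they move:

```python
import torch

def snapshot_bn_stats(model):
    # Clone every BatchNorm layer's running mean so we can diff it later
    return {name: m.running_mean.clone()
            for name, m in model.named_modules()
            if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)}

# Hypothetical usage around the validation phase:
# before = snapshot_bn_stats(net)
# ...run one validation pass...
# after = snapshot_bn_stats(net)
# changed = [k for k in before if not torch.equal(before[k], after[k])]
# print("BN layers updated during validation:", changed or "none")
```

If any layer shows up as changed, batchnorm was in train mode during validation.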

Thank you

crockwell commented 1 year ago

Ahh that's right.

It would be inconvenient not to have any sense of validation during training. I guess the ideal solution would be to compute keypoint accuracy, but that is a bit involved, and loss is a more robust measure of training than accuracy. Here's one idea of what someone did in a similar case: https://discuss.pytorch.org/t/how-to-calculate-validation-loss-for-faster-rcnn/96307

ashutoshsingh0223 commented 1 year ago

Yes, putting specific layers into eval mode seems like the best solution, since it avoids modifying the lower-level classes and methods.
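
A minimal sketch of that idea, assuming a standard `nn.Module` (the helper name is hypothetical, not part of this repo):

```python
import torch

def freeze_batchnorm_stats(model):
    """Keep the model in train mode (so the Trainer's forward still
    returns loss values) but switch BatchNorm layers to eval mode,
    so they normalize with their running statistics instead of
    updating them from validation batches."""
    model.train()
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.eval()
    return model
```

Calling this in place of `model.train()` at the start of the valid phase would keep the Trainer's loss-returning forward path intact while freezing the running statistics.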