gpleiss / efficient_densenet_pytorch

A memory-efficient implementation of DenseNets
MIT License

The BN running mean&var with torch.utils.checkpoint.checkpoint #73

Closed: ljn114514 closed this issue 3 years ago

ljn114514 commented 3 years ago

How do you handle the BN running mean/variance? With checkpointing, BatchNorm is computed twice (once during the forward pass and once during recomputation in the backward pass), so the running mean and variance get updated twice.
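To see why the double update matters, here is a plain-Python toy model (no PyTorch) of the exponential moving average BatchNorm uses for its running mean, `running = (1 - momentum) * running + momentum * batch_mean`, with momentum 0.1 (PyTorch's default). Because the recomputation sees the same inputs, it produces the same batch statistics, and two identical updates with momentum m equal one update with an effective momentum of 1 - (1 - m)^2 = 2m - m^2, i.e. 0.19 instead of 0.1. The function names are hypothetical and only illustrate the arithmetic.

```python
# Toy model (plain Python, not PyTorch) of BatchNorm's running-mean update:
#   running = (1 - momentum) * running + momentum * batch_mean


def ema_update(running: float, batch_mean: float, momentum: float = 0.1) -> float:
    """One running-mean update, as done by a single BatchNorm forward pass."""
    return (1.0 - momentum) * running + momentum * batch_mean


def simulate(batch_means, updates_per_step: int, momentum: float = 0.1) -> float:
    """Apply `updates_per_step` identical updates for each batch.

    updates_per_step=1 models a plain forward pass; updates_per_step=2 models
    checkpointing, where the backward-pass recomputation repeats the forward
    with the same inputs and therefore the same batch statistics.
    """
    running = 0.0  # running mean starts at zero, as in PyTorch
    for mean in batch_means:
        for _ in range(updates_per_step):
            running = ema_update(running, mean, momentum)
    return running


batches = [1.0, 1.0, 1.0]
once = simulate(batches, updates_per_step=1)
twice = simulate(batches, updates_per_step=2)
print(f"single update per step: {once:.4f}")
print(f"double update per step: {twice:.4f}")
```

After three identical batches the running mean has moved 1 - 0.9^3 = 0.271 of the way toward the batch mean with one update per step, but 1 - 0.9^6 ≈ 0.469 with two, so the running statistics track recent batches noticeably faster than the configured momentum implies.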

gpleiss commented 3 years ago

This is a good point. Ideally, PyTorch's batch norm layers should be smart enough to update the running mean/var appropriately (i.e., only once) when used with checkpointing.

If this is not the case, then you should raise an issue with PyTorch, since the checkpointing/batch norm layers are part of their library, not this library.
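Until the framework handles this, one possible user-side workaround is to snapshot the running statistics right after the forward pass and restore them after the backward pass, which discards the duplicate update made by the recomputation. The plain-Python sketch below models that idea; it is an assumption-level illustration, not something this library or PyTorch is known to do. `ToyBatchNormStats` and `checkpointed_step` are hypothetical names. In real PyTorch, the corresponding buffers on a BatchNorm module are `running_mean`, `running_var`, and `num_batches_tracked`, and every BatchNorm layer inside the checkpointed segment would need the same treatment.

```python
class ToyBatchNormStats:
    """Hypothetical stand-in for a BatchNorm layer's running statistics."""

    def __init__(self, momentum: float = 0.1):
        self.momentum = momentum
        self.running_mean = 0.0
        self.num_batches_tracked = 0

    def forward(self, batch):
        mean = sum(batch) / len(batch)
        # Like BatchNorm in training mode, every forward call updates the stats.
        self.num_batches_tracked += 1
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
        return [x - mean for x in batch]

    def state(self):
        return (self.running_mean, self.num_batches_tracked)

    def load_state(self, state):
        self.running_mean, self.num_batches_tracked = state


def checkpointed_step(layer, batch, restore: bool = True):
    """Simulate a checkpointed layer: forward once, then recompute in 'backward'."""
    out = layer.forward(batch)         # forward pass: the wanted update
    snapshot = layer.state()           # stats after exactly one update
    recomputed = layer.forward(batch)  # recomputation during backward: duplicate update
    if restore:
        layer.load_state(snapshot)     # discard the duplicate update
    return out, recomputed


naive = ToyBatchNormStats()
checkpointed_step(naive, [1.0, 3.0], restore=False)
fixed = ToyBatchNormStats()
checkpointed_step(fixed, [1.0, 3.0])
print(naive.running_mean, naive.num_batches_tracked)
print(fixed.running_mean, fixed.num_batches_tracked)
```

The recomputed activations are unchanged by the restore, since the output depends only on the current batch; only the bookkeeping of the running statistics is corrected.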

ljn114514 commented 3 years ago

Ok, thanks for your reply