Closed: ljn114514 closed this issue 3 years ago
How do you deal with the BatchNorm running mean/variance? Because each BatchNorm layer is run twice (once during the forward pass and once during recomputation in the backward pass), its running mean and variance would be updated twice.

This is a good point. Ideally, PyTorch's batch norm layers should be smart enough to update the running mean/var appropriately under checkpointing. If that is not the case, then you should raise an issue with PyTorch, since checkpointing and the batch norm layers are part of their library, not this one.

Ok, thanks for your reply.
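For reference, the double update is easy to demonstrate. Below is a minimal sketch; `use_reentrant=True` selects the classic checkpoint implementation, and the exact behavior may differ across PyTorch versions:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

bn = nn.BatchNorm1d(4)               # training mode by default, so running stats update
x = torch.randn(8, 4, requires_grad=True)

out = checkpoint(bn, x, use_reentrant=True)
print(bn.num_batches_tracked)        # tensor(1): one update from the forward pass

out.sum().backward()                 # backward recomputes the forward through bn
print(bn.num_batches_tracked)        # tensor(2): the recomputation updated the stats again
```

One commonly suggested workaround (a sketch under the assumption that the recomputation sees exactly the same batch, i.e. deterministic inputs and ops): since both updates use identical batch statistics, shrinking the momentum makes two updates compose to the single update you intended. With PyTorch's rule `running = (1 - m) * running + m * batch_stat`, choosing `m_hat = 1 - sqrt(1 - m)` gives `(1 - m_hat)**2 == 1 - m`, so two updates with `m_hat` equal one update with `m`:

```python
import math

m = 0.1                              # the effective momentum you actually want
bn.momentum = 1 - math.sqrt(1 - m)   # two consecutive updates now compose to momentum m
```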