gpleiss / efficient_densenet_pytorch

A memory-efficient implementation of DenseNets
MIT License

Is this really memory efficient? #66

Open · leonardishere opened this issue 4 years ago

leonardishere commented 4 years ago

I see the memory consumption chart in the readme, but after looking at the code, I have doubts that this implementation is fully memory efficient. I see the call to cp.checkpoint in _DenseLayer.forward(), but I don't see some of the other modifications that were called for in the paper, specifically post-activation normalization and contiguous concatenation. Am I missing something?
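For context, here is a minimal, self-contained sketch of the kind of checkpointed bottleneck I mean; the names `DenseLayerSketch` and `bn_function` are invented for illustration, and this is not the repository's exact code:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class DenseLayerSketch(nn.Module):
    """Illustrative sketch only -- not the repo's _DenseLayer."""

    def __init__(self, num_input_features, growth_rate, bn_size=4):
        super().__init__()
        # Pre-activation ordering: BN -> ReLU -> 1x1 conv on the concatenation.
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, bn_size * growth_rate,
                              kernel_size=1, bias=False)

    def bn_function(self, *prev_features):
        # The concatenation and the BN/ReLU outputs are the memory-hogging
        # intermediates being discussed.
        concated = torch.cat(prev_features, dim=1)
        return self.conv(self.relu(self.norm(concated)))

    def forward(self, *prev_features):
        # checkpoint() frees the concat/BN/ReLU outputs after the forward pass
        # and recomputes them during backward, trading compute for memory.
        if any(f.requires_grad for f in prev_features):
            return cp.checkpoint(self.bn_function, *prev_features)
        return self.bn_function(*prev_features)
```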

If I understand your approach, you are using a method that still requires quadratic memory and computation, but tossing the memory-hogging intermediate values and recomputing them later?

gpleiss commented 4 years ago

> but I don't see some of the other modifications that were called for in the paper, specifically post-activation normalization and contiguous concatenation. Am I missing something?

This implementation and the default DenseNet implementation both use pre-activation normalization and contiguous concatenation. Error increases without pre-activation normalization, and training time suffers significantly without contiguous concatenation. The purpose of the technical report was to design a memory-efficient implementation under the constraint that we keep pre-activation normalization and contiguous concatenation (as in the original DenseNet implementation).
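To illustrate the ordering distinction only (a sketch with invented helper names, not code from the repository or the report):

```python
import torch.nn as nn

def pre_activation_layer(in_channels, out_channels):
    # BN -> ReLU -> conv: the ordering this implementation and the original
    # DenseNet use; the BN sees the full concatenation of earlier features.
    return nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
    )

def post_activation_layer(in_channels, out_channels):
    # conv -> BN -> ReLU: the alternative ordering asked about above; per the
    # report, dropping pre-activation normalization increases error.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
```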

> If I understand your approach, you are using a method that still requires quadratic memory and computation, but tossing the memory-hogging intermediate values and recomputing them later?

We are tossing the memory-hogging intermediate values (as described in the technical report), and that is what makes the memory consumption linear. Figure 3 in the technical report explains this. Storing the intermediate activations is what causes the quadratic memory consumption, whereas the total number of output features is linear in depth.
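To make the counting concrete, here is a back-of-the-envelope sketch; the growth rate and block depths below are arbitrary illustration values, not numbers from the report:

```python
# Rough counting argument: output feature maps grow linearly with depth,
# while concatenation-sized intermediates stored per layer grow quadratically.
def feature_map_counts(num_layers, growth_rate, init_features):
    # Feature maps the dense block actually produces: linear in depth.
    output_features = init_features + num_layers * growth_rate

    # If every layer also keeps a concatenation-sized intermediate (its
    # BN/ReLU input), layer i holds roughly init_features + i * growth_rate
    # maps, so the total grows quadratically in depth.
    stored_intermediates = sum(init_features + i * growth_rate
                               for i in range(num_layers))
    return output_features, stored_intermediates

for depth in (12, 24, 48):
    linear, quadratic = feature_map_counts(depth, growth_rate=32, init_features=64)
    print(f"{depth:2d} layers: {linear:5d} output maps vs {quadratic:6d} intermediate maps")
```

Once the concat/BN/ReLU intermediates are recomputed during the backward pass instead of stored, only the output feature maps remain in memory, which is the linear term.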