gpleiss / efficient_densenet_pytorch

A memory-efficient implementation of DenseNets
MIT License

Question: why use bn_function on 1x1 conv, not on 3x3 conv #60

Closed lizhenstat closed 4 years ago

lizhenstat commented 4 years ago

Hi, thanks for your great work.

I have one question about the forward function (https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py#L35): why do you use cp.checkpoint only on the 1x1 convolution? Is there a problem with applying it to the 3x3 convolution?

Thanks in advance

gpleiss commented 4 years ago

Hi @lizhenstat - the first set of convolutions has a quadratic memory cost, whereas the second set only has a linear memory cost.

To see this: the first set of convolutions maps from num_previous_filters -> bn_size * growth_rate channels. Storing the normalized input to this operation requires storing a feature map with num_previous_filters channels. Since num_previous_filters grows linearly with depth, doing this for every layer incurs a quadratic cost.
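To make the growth concrete, here is a small sketch of that accounting. The layer count, initial feature count, and growth rate below are made-up illustrative values, not the repository's defaults:

```python
# Hypothetical sketch: count the channels that would have to be stored if
# every layer's normalized 1x1-conv input were kept for the backward pass.
# All numeric values here are made-up examples.
num_init_features = 64
growth_rate = 32
num_layers = 12

# The input to layer i's 1x1 conv is the concatenation of all previous
# feature maps: num_init_features + i * growth_rate channels.
stored_1x1 = sum(num_init_features + i * growth_rate
                 for i in range(num_layers))

# This arithmetic series grows quadratically in num_layers, which is why
# checkpointing the bottleneck is where the memory savings come from.
print(stored_1x1)  # -> 2880 for these example values
```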

However, the second set of convolutions only has an input of bn_size * growth_rate channels. Since this is constant across layers, it incurs only a linear cost.

See the tech report for more details :)