cypw / PyTorch-MFNet

MIT License
252 stars, 56 forks

High gpu memory usage #13

Closed: ntomita closed this issue 5 years ago

ntomita commented 5 years ago

I've noticed that MFNet uses more GPU memory than ResNet for image processing.

Settings:
- Framework: PyTorch
- ResNet model: resnet18 from torchvision
- MFNet model: /network/mfnet_3d.py modified for 2D processing
- Input size: 128x3x224x224, with gradient computation enabled

Results: observed GPU memory consumption was 8 GB for MFNet vs. 3.4 GB for ResNet.
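As a rough sanity check, assuming float32 inputs (PyTorch's default dtype), the 128x3x224x224 input batch itself occupies only about 73.5 MiB. So almost none of the 8 GB vs. 3.4 GB gap comes from the data; it must come from what the model keeps in memory, such as activations stored for backpropagation:

```python
# Rough sanity check: memory taken by the input batch alone.
# Assumption: inputs are float32 (4 bytes per element), PyTorch's default dtype.
BYTES_PER_FLOAT32 = 4

def tensor_bytes(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Return the size in bytes of a dense tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element

input_shape = (128, 3, 224, 224)  # N x C x H x W from the settings above
input_mib = tensor_bytes(input_shape) / 2**20
print(f"input batch: {input_mib:.1f} MiB")  # input batch: 73.5 MiB
```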

Each model's parameter count matches the paper, but MFNet's actual memory consumption doesn't reflect its reduced FLOPs. Have you observed the same behavior, or could this be caused by the 2D conversion of the model? I'm confident that my modifications follow the description of the 2D architecture in Table 2, and the conversion shouldn't be tricky. Any ideas? (By the way, I would still appreciate it if you could release an official 2D version of MFNet, even though that's not the main point of your work.)

ntomita commented 5 years ago

I ran follow-up experiments and observed that MFNet actually consumes less memory than resnet18 during inference, but more during training. I suppose this is a characteristic of an architecture with many 1x1 convolutions: it produces a longer autograd graph, with more intermediate activations kept for the backward pass, despite its compact parameter size. I'm going to close this issue, as it's not related to the implementation.
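A minimal back-of-envelope sketch of that reasoning, under simplifying assumptions that are mine, not MFNet's actual layout: float32 tensors, no bias, only each convolution's output counted (BatchNorm/ReLU outputs, which autograd typically also keeps, are ignored), and a hypothetical 64-channel 56x56 feature map. It compares one 3x3 convolution with a chain of four 1x1 convolutions. The chain has fewer parameters but keeps four times as many activations alive for the backward pass, while the inference peak (roughly one layer's input plus output) is the same. Real PyTorch numbers also depend on the caching allocator and on which tensors each op saves.

```python
# Back-of-envelope comparison: parameters vs. activation memory.
# Assumptions (hypothetical, for illustration only): float32 tensors, no bias,
# only conv outputs are counted; BatchNorm/ReLU outputs are ignored.
BYTES = 4  # float32

def conv_params(c_in, c_out, k):
    """Weight count of a k x k convolution without bias."""
    return c_in * c_out * k * k

def map_bytes(shape, batch):
    """Bytes of one activation map of shape (C, H, W) for the whole batch."""
    c, h, w = shape
    return batch * c * h * w * BYTES

def training_activation_bytes(output_shapes, batch):
    """Training: every layer output is kept until the backward pass."""
    return sum(map_bytes(s, batch) for s in output_shapes)

def inference_peak_bytes(input_shape, output_shapes, batch):
    """Inference (no grad): roughly only a layer's input and output are alive at once."""
    shapes = [input_shape] + list(output_shapes)
    return max(map_bytes(a, batch) + map_bytes(b, batch) for a, b in zip(shapes, shapes[1:]))

fmap = (64, 56, 56)  # hypothetical feature-map shape
batch = 128

# Option A: a single 3x3 conv, 64 -> 64 channels.
params_3x3 = conv_params(64, 64, 3)
train_3x3 = training_activation_bytes([fmap], batch)
infer_3x3 = inference_peak_bytes(fmap, [fmap], batch)

# Option B: four stacked 1x1 convs, 64 -> 64 channels each.
params_1x1 = 4 * conv_params(64, 64, 1)
train_1x1 = training_activation_bytes([fmap] * 4, batch)
infer_1x1 = inference_peak_bytes(fmap, [fmap] * 4, batch)

print(f"3x3:    {params_3x3} params, train {train_3x3 / 2**20:.0f} MiB, infer peak {infer_3x3 / 2**20:.0f} MiB")
print(f"4x 1x1: {params_1x1} params, train {train_1x1 / 2**20:.0f} MiB, infer peak {infer_1x1 / 2**20:.0f} MiB")
```

Under these assumptions the 1x1 chain has 16384 weights vs. 36864, yet stores about 392 MiB of activations in training vs. about 98 MiB, with an identical inference peak of about 196 MiB.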

FYI: memory consumption (MB) of MFNet and resnet18 at different batch sizes (columns give the batch size).

| Mode | Model | 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128 |
|---|---|---|---|---|---|---|---|---|---|
| Inference | MFNet | 611 | 615 | 621 | 627 | 647 | 665 | 735 | 815 |
| Inference | resnet18 | 663 | 663 | 667 | 695 | 697 | 755 | 871 | 1089 |
| Train | MFNet | 685 | 757 | 877 | 1167 | 1689 | 2629 | 4483 | 8175 |
| Train | resnet18 | 679 | 715 | 761 | 867 | 1011 | 1445 | 2133 | 3541 |
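One way to read these numbers: the slope between batch sizes 64 and 128 approximates the extra memory per added sample, separating it from the roughly constant baseline seen at batch size 1. That baseline presumably includes fixed costs such as the CUDA context and model weights, but the measurement method isn't stated, so this is an assumption. A small sketch using the measured values:

```python
# Marginal memory per extra sample, from the batch-64 and batch-128 measurements (MB).
table = {
    ("inference", "mfnet"):    {64: 735.0,  128: 815.0},
    ("inference", "resnet18"): {64: 871.0,  128: 1089.0},
    ("train", "mfnet"):        {64: 4483.0, 128: 8175.0},
    ("train", "resnet18"):     {64: 2133.0, 128: 3541.0},
}

def marginal_mb_per_sample(points, lo=64, hi=128):
    """Slope of memory vs. batch size between two measured batch sizes."""
    return (points[hi] - points[lo]) / (hi - lo)

slopes = {key: marginal_mb_per_sample(pts) for key, pts in table.items()}
for (mode, model), slope in slopes.items():
    print(f"{mode:9s} {model:8s} {slope:7.3f} MB/sample")
```

In this rough reading, MFNet adds about 57.7 MB per training sample vs. 22 MB for resnet18 (about 2.6x), but only 1.25 MB vs. about 3.4 MB per sample in inference. That matches the observation that the extra cost appears only when activations are kept for backpropagation.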