lucidrains / pixel-level-contrastive-learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in Pytorch
MIT License

Same size models -> wildly different RAM usage #13

Closed · makarr closed 3 years ago

makarr commented 3 years ago

Following the usage in the README, there are no problems if the model looks like this:

from torchvision import models
from pixel_level_contrastive_learning import PixelCL

resnet = models.resnet50()
learner = PixelCL(resnet, ...)

But when I try the following, the results are very different:

from torchvision.models.segmentation import deeplabv3_resnet50
resnet = deeplabv3_resnet50().backbone
learner = PixelCL(resnet, ...)

In this case, RAM usage spirals out of control and my cloud instance crashes, even though both ResNets have ~23M parameters. The class in the second case is IntermediateLayerGetter instead of ResNet, which might be the root of the issue, but glancing through the source I can't see what would cause the problem.
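For reference, a quick way to sanity-check the parameter counts (just a sketch using torchvision; the plain model comes out slightly larger because it still has its fc head):

from torchvision import models
from torchvision.models.segmentation import deeplabv3_resnet50

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(models.resnet50()))              # plain ResNet-50
print(count_params(deeplabv3_resnet50().backbone))  # DeepLab backbone: same trunk, dilated layers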

Also, a small suggestion: it might be nice if this and BYOL supported grayscale. Segmentation is a popular task in medical imaging, and annotated data is hard to come by, so unsupervised pretraining for segmentation is super helpful in this domain. (Example) I copied the code and made the necessary changes (to the augmentations and elsewhere), but in the end I wasn't able to use the package as-is.
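A lighter-weight workaround than forking the augmentations might be to repeat the grayscale channel to three before handing the batch to the learner, so the stock 3-channel pipeline applies unchanged. A rough sketch (shapes made up, learner built as in the snippets above):

import torch

gray_batch = torch.randn(4, 1, 256, 256)   # e.g. unlabeled single-channel medical images
rgb_batch = gray_batch.repeat(1, 3, 1, 1)  # duplicate the channel so the 3-channel augmentations still work

loss = learner(rgb_batch)                  # learner constructed as in the earlier snippets
loss.backward()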

Anyway, thanks for your work. You might also want to take a look at Segmentation Transformer, which uses ViT for segmentation. Unfortunately, in my testing it underperforms DeepLabV3 on small-to-medium-sized datasets, and contrastive pretraining doesn't help much. I haven't yet tried the denoising / masking / corrupted-token approach.

WilhelmT commented 3 years ago

I think the backbone for DeepLabV3 has layers modified to use a smaller stride, i.e. larger feature maps, which consume more memory for the same number of parameters. You can check the output shape of the network to verify this.
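Something like this (a rough sketch with a dummy 224x224 input) should show the difference:

import torch
from torchvision import models
from torchvision.models.segmentation import deeplabv3_resnet50

x = torch.randn(1, 3, 224, 224)

plain = torch.nn.Sequential(*list(models.resnet50().children())[:-2])  # drop avgpool and fc
print(plain(x).shape)                     # torch.Size([1, 2048, 7, 7])  -> output stride 32

backbone = deeplabv3_resnet50().backbone  # IntermediateLayerGetter, returns a dict of feature maps
print(backbone(x)['out'].shape)           # torch.Size([1, 2048, 28, 28]) -> output stride 8, ~16x the activations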

makarr commented 3 years ago

@WilhelmT That's a good point, so I dug into it. The third and fourth layers of the DeepLab backbone use dilation, so the outputs are larger overall. To test whether this was causing the memory issue, I ran the following:

import copy
from torchvision import models
from pixel_level_contrastive_learning import PixelCL

resnet = models.resnet50()
deeplab = models.segmentation.deeplabv3_resnet50()
# graft the dilated layers from the DeepLab backbone onto the plain ResNet-50
for name, module in deeplab.backbone._modules.items():
    if name in resnet._modules:
        resnet._modules[name] = copy.deepcopy(module)
learner = PixelCL(resnet, ...)

The RAM usage was ~5x greater (1.33 vs. 6.6 GB), but it didn't take up 16+ GB as before.

This doesn't exactly get to the bottom of the issue, but at least I can use PixelCL now. Thanks for your comment. I guess I'll close.