Closed makarr closed 3 years ago
I think the backbone for DeepLabv3 has its layers modified with a smaller stride, i.e. larger feature maps, which consume more memory despite the same number of parameters. You can check the output shape of the network to verify this.
@WilhelmT That's a good point so I dug into it. The third and fourth layers of the Deeplab backbone use dilation so the overall outputs are larger. To test whether this was causing the memory issue, I ran the following:
```python
import copy

from torchvision import models
from pixel_level_contrastive_learning import PixelCL

resnet = models.resnet50()
deeplab = models.segmentation.deeplabv3_resnet50()

# Copy the DeepLab backbone's (dilated) layers into a plain ResNet
for name, module in deeplab.backbone._modules.items():
    if name in resnet._modules:
        resnet._modules[name] = copy.deepcopy(module)

learner = PixelCL(resnet, ...)
```
The RAM usage was still ~5x greater (1.33 GB vs. 6.6 GB), but it didn't take up 16+ GB as before.
This doesn't exactly get to the bottom of the issue, but at least I can use PixelCL
now. Thanks for your comment. I guess I'll close.
Following the usage in the README, there are no problems if the model looks like this:
But when I try the following, the results are very different:
In this case RAM usage spirals out of control and my cloud instance crashes, even though both ResNets have ~23M params. The class in the second case is `IntermediateLayerGetter` instead of `ResNet`, which might be the root of the issue, but glancing through the source I can't see what would cause the problem.

Also, a small suggestion: it might be nice if this and BYOL supported grayscale. Segmentation is a popular task in medical imaging, where annotated data is hard to come by, so unsupervised pretraining for segmentation is super helpful in this domain. (Example) I copied the code and made the necessary changes (in augmentations and elsewhere), but I ended up being unable to use the package.
Anyway, thanks for your work. You might also want to take a look at Segmentation Transformer, which uses ViT for segmentation. Unfortunately, in my testing it underperforms DeepLabV3 on small-to-medium-sized datasets, and contrastive pretraining doesn't help much. I haven't yet tried the denoising / masking / corrupted-token approach.