lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
2.85k stars, 246 forks

Feature Request: New SSL algorithm called SparK: Sparse and Hierarchical masKed modeling #1462

Open Djoels opened 7 months ago

Djoels commented 7 months ago

It would be great if this new MAE-style method called SparK was introduced to lightly.

Paper: https://arxiv.org/abs/2301.03580 (featured as an ICLR'23 Spotlight)
Code: https://github.com/keyu-tian/SparK

It was successfully applied to medical image applications, as documented in this Nature paper: https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main

philippmwirth commented 6 months ago

Hey @Djoels , thanks for bringing this up! We will take a look shortly and add it to the model tracker if relevant 🙂

guarin commented 6 months ago

This looks interesting indeed, thanks a lot for the issue! Added it to the methods tracker and will consider it for the paper session next week.

johnsutor commented 6 months ago

I can take this issue.

guarin commented 6 months ago

Thanks for looking into this @johnsutor! The original codebase implements the sparse net in a quite hacky way (see code here) and I was wondering whether it would be possible to pass the masks explicitly to the forward function instead of assigning them to a global variable. Maybe this would be interesting to explore, wdyt?

johnsutor commented 6 months ago

I'll investigate and get back to you!

johnsutor commented 5 months ago

Seems fairly straightforward to achieve based on https://github.com/keyu-tian/SparK/tree/main/pretrain#regarding-sparse-convolution. I don't mind giving it a stab. My plan is to implement the encoder and decoder from their code base (https://github.com/keyu-tian/SparK/tree/main/pretrain) within https://github.com/lightly-ai/lightly/tree/master/lightly/models, naming the file something like spark.py. If this sounds good, I'll give it a go.

guarin commented 5 months ago

Sounds good! Thanks a lot for looking into it.

Maybe create a lightly/models/sparse subdirectory and put it there. You could even name the file sparse_resnet.py. And it would be great if you could keep the same structure as the original ResNet in torchvision. Then it would be easy to convert from sparse ResNet to dense ResNet and vice versa.

johnsutor commented 5 months ago

I went ahead and implemented a resnet compatible with the standard torchvision library, so that we don't have to add timm as a dependency.

Furthermore, I managed to pass the mask at runtime without setting a global variable, by using a forward pre-hook. This is how it looks so far:

from typing import Optional, Tuple

import torch
from torch import nn

# SparseConv2d, SparseMaxPooling, SparseAvgPooling, SparseBatchNorm2d and
# SparseSyncBatchNorm2d are assumed to be defined elsewhere (ported from SparK).


class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int, sync_bn: bool = False):
        """Sparse Encoder as used by SparK [0]

        Default params are the ones explained in the original code base.
        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling https://arxiv.org/abs/2301.03580

        Args:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.
            sync_bn:
                Whether or not to use Sync Batch Norm in this model.

        """
        super().__init__()
        # Mask of the current batch; set in forward() and injected into the
        # sparse layers by mask_hook.
        self.mask: Optional[torch.Tensor] = None
        self.sp_backbone = self.dense_model_to_sparse(m=backbone, sbn=sync_bn)
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

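    def mask_hook(
        self, module: nn.Module, args: Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Forward pre-hooks receive (module, args) only. The returned tuple
        # replaces the positional args, so the layer is called as layer(x, mask).
        return (args[0], self.mask)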

    def dense_model_to_sparse(self, m: nn.Module, sbn: bool = False):
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(
                m.weight.shape[0],
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats,
            )
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError

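        for name, child in m.named_children():
            oup.add_module(name, self.dense_model_to_sparse(child, sbn=sbn))
        del m
        # The hook is registered only on the sparse leaf layers above. Containers
        # (Sequential, BasicBlock, ...) take a single input and must not receive
        # the mask as a second argument.
        return oup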

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        assert mask is not None, "Mask must be supplied for training"
        # Store the mask so that mask_hook can pass it to every sparse layer.
        self.mask = mask
        return self.sp_backbone(x, hierarchical=True)
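
A quick standalone check of the forward pre-hook mechanism the snippet relies on may be useful. This is a minimal sketch, assuming a hypothetical toy layer (`ToySparseLayer`) whose `forward` takes `(x, mask)`; it is not the SparK or lightly code. PyTorch calls forward pre-hooks with `(module, args)` only, and a returned tuple replaces the positional arguments, so the layer receives the mask without it being passed explicitly:

```python
import torch
from torch import nn


class ToySparseLayer(nn.Module):
    """Hypothetical stand-in for a sparse layer whose forward takes (x, mask)."""

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Zero out the masked (inactive) positions.
        return x * mask


class ToyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mask = None
        self.layer = ToySparseLayer()
        # Only the leaf layer gets the hook, never a container module.
        self.layer.register_forward_pre_hook(self.mask_hook)

    def mask_hook(self, module, args):
        # The returned tuple replaces args: the layer runs as layer(x, self.mask).
        return (args[0], self.mask)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        self.mask = mask
        # Called with x only; the hook supplies the mask.
        return self.layer(x)


encoder = ToyEncoder()
x = torch.ones(1, 1, 2, 2)
mask = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
print(encoder(x, mask))
```

Registering the hook on a container such as `nn.Sequential` would instead call its `forward` with two positional arguments and raise a `TypeError`.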

If that works, I'll go ahead and implement the SparK module as well. The one thing I'm thinking about changing there is making the forward pass return only the reconstructions, and perhaps creating a separate method for calculating the reconstruction loss. This would keep the code similar to the masked autoencoder.
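
For reference, a separate loss function in the spirit of the MAE setup could look like the sketch below. It assumes the target follows SparK's reference code: the image is split into patches, each patch is normalized by its own mean and standard deviation, and the MSE is averaged over masked patches only. `patchify` and `reconstruction_loss` are hypothetical names, not existing lightly APIs:

```python
import torch


def patchify(imgs: torch.Tensor, patch_size: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, L, patch_size**2 * C) with L = (H // p) * (W // p)."""
    B, C, H, W = imgs.shape
    h, w = H // patch_size, W // patch_size
    x = imgs.reshape(B, C, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, h, w, p, p, C)
    return x.reshape(B, h * w, patch_size * patch_size * C)


def reconstruction_loss(
    rec: torch.Tensor,
    imgs: torch.Tensor,
    active: torch.Tensor,
    patch_size: int,
    eps: float = 1e-6,
) -> torch.Tensor:
    """MSE on per-patch normalized pixels, averaged over masked patches only.

    rec:    (B, L, patch_size**2 * C) reconstructed patches.
    active: (B, 1, h, w) bool mask, True = visible patch.
    """
    target = patchify(imgs, patch_size)
    mean = target.mean(dim=-1, keepdim=True)
    std = (target.var(dim=-1, keepdim=True) + eps).sqrt()
    target = (target - mean) / std
    per_patch = ((rec - target) ** 2).mean(dim=-1)  # (B, L)
    masked = (~active).flatten(1).to(per_patch.dtype)  # (B, L), 1 = masked
    return (per_patch * masked).sum() / (masked.sum() + 1e-8)


imgs = torch.randn(2, 3, 8, 8)
active = torch.rand(2, 1, 2, 2) > 0.5  # 2x2 grid of 4x4 patches
loss = reconstruction_loss(torch.zeros(2, 4, 48), imgs, active, patch_size=4)
```

Keeping the loss outside `forward` means the module only returns reconstructions, and the training loop decides how to score them.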

guarin commented 5 months ago

Oh wow, thanks a lot for looking into this! It looks really good!

I have some comments/questions:

Here is a draft of a version that doesn't use hooks. Instead, it stores a reference to a single SparseMask object on every module that needs access to the mask. The modules can then read (or modify) this mask in their forward pass, and because the object is shared across all modules, they all see the same mask. I also moved the dense_model_to_sparse function out of the SparseEncoder class, as it doesn't need access to the class. This also makes it easier to reuse the function in other modules.

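from typing import Union

import torch
from torch import Tensor, nn

# SparseConv2d, SparseMaxPooling, SparseAvgPooling, SparseBatchNorm2d and
# SparseSyncBatchNorm2d are assumed to be defined elsewhere and to accept a
# sparse_mask keyword argument.


class SparseMask:
    def __init__(self):
        # Mask of the current batch, shared by reference with all sparse modules.
        self.mask: Union[Tensor, None] = None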

class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int):
        """Sparse Encoder as used by SparK [0]

        Default params are the ones explained in the original code base
        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling https://arxiv.org/abs/2301.03580

        Attributes:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.

        """
        super().__init__()
        self.sparse_mask = SparseMask()
        self.sparse_backbone = dense_model_to_sparse(
            m=backbone,
            sparse_mask=self.sparse_mask,
        )
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        # All submodules will now have access to the sparse mask
        self.sparse_mask.mask = mask
        return self.sparse_backbone(x, hierarchical=True)

@torch.no_grad()  # required for the in-place copy_ into the new parameters below
def dense_model_to_sparse(m: nn.Module, sparse_mask: SparseMask) -> nn.Module:
    oup = m
    if isinstance(m, nn.Conv2d):
        m: nn.Conv2d
        bias = m.bias is not None
        oup = SparseConv2d(
            m.in_channels,
            m.out_channels,
            kernel_size=m.kernel_size,
            stride=m.stride,
            padding=m.padding,
            dilation=m.dilation,
            groups=m.groups,
            bias=bias,
            padding_mode=m.padding_mode,
            sparse_mask=sparse_mask,
        )
        oup.weight.copy_(m.weight)
        if bias:
            oup.bias.copy_(m.bias)
    elif isinstance(m, nn.MaxPool2d):
        m: nn.MaxPool2d
        oup = SparseMaxPooling(
            m.kernel_size,
            stride=m.stride,
            padding=m.padding,
            dilation=m.dilation,
            return_indices=m.return_indices,
            ceil_mode=m.ceil_mode,
            sparse_mask=sparse_mask,
        )
    elif isinstance(m, nn.AvgPool2d):
        m: nn.AvgPool2d
        oup = SparseAvgPooling(
            m.kernel_size,
            m.stride,
            m.padding,
            ceil_mode=m.ceil_mode,
            count_include_pad=m.count_include_pad,
            divisor_override=m.divisor_override,
            sparse_mask=sparse_mask,
        )
    elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
        m: nn.BatchNorm2d
        oup = (SparseSyncBatchNorm2d if isinstance(m, nn.SyncBatchNorm) else SparseBatchNorm2d)(
            m.weight.shape[0],
            eps=m.eps,
            momentum=m.momentum,
            affine=m.affine,
            track_running_stats=m.track_running_stats,
            sparse_mask=sparse_mask,
        )
        oup.weight.copy_(m.weight)
        oup.bias.copy_(m.bias)
        oup.running_mean.copy_(m.running_mean)
        oup.running_var.copy_(m.running_var)
        oup.num_batches_tracked.copy_(m.num_batches_tracked)
        if hasattr(m, "qconfig"):
            oup.qconfig = m.qconfig
    elif isinstance(m, (nn.Conv1d,)):
        raise NotImplementedError

    for name, child in m.named_children():
        oup.add_module(name, dense_model_to_sparse(child, sparse_mask=sparse_mask))
    del m
    return oup
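
To illustrate the shared-object idea with something runnable, here is a minimal sketch of a sparse batch norm that reads the mask from a shared `SparseMask`. It assumes the mask is a bool tensor of shape (B, 1, H, W), True marking visible positions, already at the feature map's resolution. Normalizing only the visible positions through `nn.BatchNorm1d` follows the idea behind SparK's sparse batch norm, but this is an illustration rather than the final implementation:

```python
from typing import Optional

import torch
from torch import Tensor, nn


class SparseMask:
    def __init__(self) -> None:
        self.mask: Optional[Tensor] = None  # (B, 1, H, W) bool, True = visible


class SparseBatchNorm2d(nn.BatchNorm1d):
    """Sketch: batch norm whose statistics use only the visible positions."""

    def __init__(self, num_features: int, sparse_mask: SparseMask, **kwargs) -> None:
        super().__init__(num_features, **kwargs)
        self.sparse_mask = sparse_mask  # shared by reference, not copied

    def forward(self, x: Tensor) -> Tensor:
        mask = self.sparse_mask.mask
        assert mask is not None and mask.shape[-2:] == x.shape[-2:]
        active = mask[:, 0]  # (B, H, W)
        bhwc = x.permute(0, 2, 3, 1)
        normed = super().forward(bhwc[active])  # (N_visible, C)
        out = torch.zeros_like(bhwc)  # masked positions stay zero
        out[active] = normed
        return out.permute(0, 3, 1, 2)


shared = SparseMask()
net = nn.Sequential(SparseBatchNorm2d(4, shared), SparseBatchNorm2d(4, shared))
mask = torch.zeros(2, 1, 3, 3, dtype=torch.bool)
mask[:, :, :2] = True  # first two rows visible
shared.mask = mask  # set once, seen by every layer holding `shared`
y = net(torch.randn(2, 4, 3, 3))
```

Because both layers hold a reference to the same object, assigning `shared.mask` once before the forward pass is enough.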
johnsutor commented 5 months ago

Hey, thanks for checking it out! Regarding your bullets:

  1. I think that approach of storing the sparse mask on the modules works (this was going to be my second approach if we didn't use hooks, so that works well for me haha).
  2. The input size is used by the SparK module (not by the encoder) to determine which patches get masked when reshaping the image tensor, so I'll remove it from the encoder. For the SparK module itself, we could determine the input size on the fly, though I'm not sure whether this would hurt training if different batches have different spatial dimensions. Ultimately it's only used to calculate how many patches to keep from the flattened representation, so it's your call whether to define the input size up front to enforce consistent spatial dimensions.
  3. With the prototype SparK-compatible ResNet that I have in the works, I calculate the feature map channels like so:
        with torch.no_grad():
            # x is a dummy batch, e.g. torch.zeros(1, 3, input_size, input_size),
            # passed through the torchvision-style stem first.
            x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
            self._feature_map_channels = []
            for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
                x = layer(x)
                self._feature_map_channels.append(x.shape[1])

    Perhaps, for a more general-purpose feature extractor that works with any backbone, we can determine the resolution of each feature map by calling create_feature_extractor during initialization and comparing the feature map size to the input size. Alternatively, we can call get_graph_node_names and return the intermediate outputs up to the final pooling and linear layers. This should work with most models.

  4. Sounds good!
  5. That's fine by me! I tried to keep the code as close to the original as possible to avoid breaking anything, but I doubt that change will alter anything.
guarin commented 5 months ago
  1. Haha perfect!
  2. Sounds good :)
  3. I am wondering whether the feature map channels have to be known in advance. Are they used for anything other than the mask resizing? I imagine we could calculate the size on the fly in the forward pass of the sparse modules. Something along the lines of this:
    class SparseConv2d(Conv2d):
        def forward(self, x: Tensor) -> Tensor:
            x = super().forward(x)
            # Resize the shared mask to the spatial size of this layer's output.
            mask = get_mask_with_size(self.sparse_mask, x)
            x = apply_mask(x, mask)
            return x
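
A runnable version of this pseudocode could look like the sketch below. It assumes the shared mask is stored at the coarsest resolution (for example input_size // 32) and that every feature map size is an integer multiple of it. `get_mask_with_size` and `apply_mask` keep the names from the snippet, but their bodies are illustrative guesses at the intent:

```python
import torch
from torch import Tensor, nn


class SparseMask:
    def __init__(self) -> None:
        self.mask = None  # (B, 1, h, w) bool at the coarsest resolution


def get_mask_with_size(sparse_mask: SparseMask, x: Tensor) -> Tensor:
    """Upsample the coarse mask to x's spatial size by repeating entries."""
    mask = sparse_mask.mask
    (H, W), (h, w) = x.shape[-2:], mask.shape[-2:]
    assert H % h == 0 and W % w == 0, "feature map must be a multiple of the mask"
    return mask.repeat_interleave(H // h, dim=2).repeat_interleave(W // w, dim=3)


def apply_mask(x: Tensor, mask: Tensor) -> Tensor:
    return x * mask.to(x.dtype)  # zero out inactive positions


class SparseConv2d(nn.Conv2d):
    def __init__(self, *args, sparse_mask: SparseMask, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.sparse_mask = sparse_mask

    def forward(self, x: Tensor) -> Tensor:
        x = super().forward(x)
        return apply_mask(x, get_mask_with_size(self.sparse_mask, x))


shared = SparseMask()
shared.mask = torch.tensor([[[[True, False], [False, True]]]])  # (1, 1, 2, 2)
conv1 = SparseConv2d(3, 8, 3, padding=1, sparse_mask=shared)
conv2 = SparseConv2d(8, 8, 3, stride=2, padding=1, sparse_mask=shared)
y = conv1(torch.randn(1, 3, 8, 8))  # mask upsampled to 8x8
z = conv2(y)  # mask upsampled to 4x4
```

With this approach the mask resizing itself needs no knowledge of the feature map sizes in advance; each layer derives it from its own output shape.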
johnsutor commented 5 months ago

The feature map channels are used in step three of the forward process, where the hierarchical dense features for decoding are calculated. When the SparK module is created, it creates a mask token and a densify norm layer, which it uses when filling in the masked locations with the mask token; both need to know the number of channels. We can circumvent the norm issue by using lazy batch normalization, and perhaps we could create the mask token itself on the fly during the first pass, right before this line?
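
A minimal sketch of the lazy approach: `nn.LazyBatchNorm2d` infers the channel count from the first batch, and a hypothetical `mask_token` parameter is created on the first forward pass. Two caveats apply. SparK's densify norm is, as far as I can tell, a sparse norm over visible positions only, whereas this plain batch norm also sees the masked positions. And, as with any lazy module, the optimizer should only be constructed after a dry-run forward pass so that it picks up the newly created parameters:

```python
import torch
from torch import nn


class LazyDensify(nn.Module):
    """Sketch: fill masked positions of a feature map with a learnable token."""

    def __init__(self) -> None:
        super().__init__()
        self.norm = nn.LazyBatchNorm2d()  # channel count inferred on first call
        self.register_parameter("mask_token", None)  # created lazily in forward

    def forward(self, x: torch.Tensor, active: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W); active: (B, 1, H, W) bool, True = visible.
        if self.mask_token is None:
            # First pass: C is now known, so the token can be created.
            self.mask_token = nn.Parameter(x.new_zeros(1, x.shape[1], 1, 1))
        x = self.norm(x)
        # Broadcast the (1, C, 1, 1) token into every masked position.
        return torch.where(active, x, self.mask_token)


densify = LazyDensify()
feat = torch.randn(2, 8, 4, 4)
active = torch.zeros(2, 1, 4, 4, dtype=torch.bool)
active[:, :, :2] = True  # top half visible
out = densify(feat, active)  # dry run: materializes norm and mask token
```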

johnsutor commented 5 months ago

Update: I've been busy with other commitments; I'll get back to it when I can. If you want, I can commit the code I've been working on.

mileseverett commented 2 weeks ago

@johnsutor did you end up uploading the code anywhere?

johnsutor commented 2 weeks ago

@mileseverett never did, but I have more time now so I'll have to get back to working on it. Thanks for reminding me!