Open · Djoels opened this issue 7 months ago

It would be great if this new MAE-style method called SparK was introduced to lightly.

Paper: https://arxiv.org/abs/2301.03580 (ICLR'23 Spotlight)
Code: https://github.com/keyu-tian/SparK

It has also been successfully applied to medical imaging applications, as documented in this Nature paper: https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main
Hey @Djoels , thanks for bringing this up! We will take a look shortly and add it to the model tracker if relevant 🙂
This looks interesting indeed, thanks a lot for the issue! Added it to the methods tracker and will consider it for the paper session next week.
I can take this issue.
Thanks for looking into this @johnsutor! The original code base implements the sparse net in quite a hacky way (see code here), and I was wondering whether it would be possible to pass the masks explicitly to the forward function instead of assigning them to a global variable. Maybe this would be interesting to explore, wdyt?
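For context, a simplified sketch of the two patterns (the global variable loosely mirrors the original code base; the function names are illustrative and mask resizing is omitted):

```python
import torch
from torch import nn

# Pattern in the original code base (simplified): a module-level variable
# holds the current mask, and the sparse ops read it at forward time.
_cur_active: torch.Tensor = None  # set from the outside before every forward


def sparse_forward_global(conv: nn.Conv2d, x: torch.Tensor) -> torch.Tensor:
    return conv(x) * _cur_active  # hidden dependency on global state


# The alternative discussed here: make the dependency explicit.
def sparse_forward_explicit(
    conv: nn.Conv2d, x: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    return conv(x) * mask
```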
I'll investigate and get back to you!
Seems fairly straightforward to achieve based on https://github.com/keyu-tian/SparK/tree/main/pretrain#regarding-sparse-convolution. I don't mind giving it a stab. My thought is to implement the encoder and decoder from their code base (https://github.com/keyu-tian/SparK/tree/main/pretrain) within https://github.com/lightly-ai/lightly/tree/master/lightly/models, naming the file something like spark.py. If this sounds good, I'll give it a go.
Sounds good! Thanks a lot for looking into it.
Maybe create a lightly/models/sparse subdirectory and put it there. You could even name the file sparse_resnet.py. And it would be great if you could keep the same structure as the original resnet in torchvision. Then it would be easy to convert from sparse resnet to dense resnet and vice versa.
I went ahead and implemented a resnet compatible with the standard torchvision library, so that we don't have to add timm as a dependency. Furthermore, I managed to pass the mask at runtime without setting a global variable by using a pre-forward hook. This is how it looks so far:
```python
from typing import Optional, Tuple

import torch
from torch import nn

# SparseConv2d, SparseMaxPooling, SparseAvgPooling, SparseBatchNorm2d and
# SparseSyncBatchNorm2d are the sparse ops from the SparK code base, adapted
# here to accept the mask as a second forward argument.


class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int, sync_bn: bool = False):
        """Sparse Encoder as used by SparK [0].

        Default params are the ones explained in the original code base.

        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical
            Masked Modeling, https://arxiv.org/abs/2301.03580

        Attributes:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.
            sync_bn:
                Whether or not to use Sync Batch Norm in this model.
        """
        super().__init__()
        self.mask: Optional[torch.Tensor] = None
        self.sp_backbone = self.dense_model_to_sparse(m=backbone, sbn=sync_bn)
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

    def mask_hook(self, module: nn.Module, input: Tuple[torch.Tensor, ...]):
        # Forward pre-hooks receive (module, input). Returning a new tuple
        # replaces the input, so the current mask gets injected as a second
        # argument to the sparse module's forward.
        return (input[0], self.mask)

    def dense_model_to_sparse(self, m: nn.Module, sbn: bool = False) -> nn.Module:
        oup = m
        if isinstance(m, nn.Conv2d):
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.MaxPool2d):
            oup = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.AvgPool2d):
            oup = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(
                m.weight.shape[0],
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats,
            )
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError
        for name, child in m.named_children():
            oup.add_module(name, self.dense_model_to_sparse(child, sbn=sbn))
        del m
        return oup

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        if mask is not None:
            self.mask = mask
        assert self.mask is not None, "Mask must be supplied for training"
        return self.sp_backbone(x, hierarchical=True)
```
If that works, I'll go ahead and implement the SparK module as well. The one thing I'm thinking about altering there is configuring the forward pass to return only the reconstructions, and perhaps creating a separate method for calculating the reconstruction loss. This would keep the code similar to the masked autoencoder (MAE) implementation.
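For illustration, a rough sketch of that split; all names and shapes are placeholders rather than a final API, with the mask assumed to be (B, 1, H, W) and 1 = visible:

```python
import torch
from torch import nn


class SparK(nn.Module):
    """Sketch: forward returns reconstructions only, the loss is separate."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        features = self.encoder(images, mask)
        return self.decoder(features)

    def reconstruction_loss(
        self, reconstructions: torch.Tensor, images: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        # Per-pixel L2 averaged over channels; only masked positions count.
        loss = ((reconstructions - images) ** 2).mean(dim=1, keepdim=True)
        masked = 1.0 - mask
        return (loss * masked).sum() / masked.sum().clamp(min=1.0)
```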
Oh wow, thanks a lot for looking into this! It looks really good!
I have some comments/questions:

- Is the `input_size` parameter needed? The user has to pass the `mask` anyways, can we not infer the size from it?
- Does the backbone really have to implement the `get_downsample_ratio` and `get_feature_map_channels` methods that we can call? I think we should be able to calculate those within the sparse modules and adapt the mask accordingly.
- I don't think we need the `sync_bn` parameter, as we can check for the module type when we make the conversion (see code below).
- `parameter.data.copy_` is deprecated and `parameter.copy_` should be used instead (changed this already in the code below). I was actually wondering whether we even need to copy the parameters, can we not just assign them with `oup.weight = m.weight` etc.?

Here is the draft for a version that doesn't use hooks. Instead, it saves a `SparseMask` object on all modules that need access to the mask. The modules can then modify this mask in their forward pass. As the object is shared across all modules, they'll all have access to it. I also moved the `dense_model_to_sparse` function outside of the `SparseEncoder` class, as it doesn't really need access to the class. This would also make it easier to reuse the method in other modules.
```python
from typing import Union

import torch
from torch import Tensor, nn

# The Sparse* ops are assumed to be variants of the SparK ops that accept
# a sparse_mask argument and read the current mask from it during forward.


class SparseMask:
    def __init__(self):
        self.mask: Union[Tensor, None] = None


class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int):
        """Sparse Encoder as used by SparK [0].

        Default params are the ones explained in the original code base.

        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical
            Masked Modeling, https://arxiv.org/abs/2301.03580

        Attributes:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.
        """
        super().__init__()
        self.sparse_mask = SparseMask()
        self.sparse_backbone = dense_model_to_sparse(
            m=backbone,
            sparse_mask=self.sparse_mask,
        )
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

    def forward(self, x: Tensor, mask: Tensor):
        # All submodules will now have access to the sparse mask.
        self.sparse_mask.mask = mask
        return self.sparse_backbone(x, hierarchical=True)


def dense_model_to_sparse(m: nn.Module, sparse_mask: SparseMask) -> nn.Module:
    oup = m
    if isinstance(m, nn.Conv2d):
        bias = m.bias is not None
        oup = SparseConv2d(
            m.in_channels,
            m.out_channels,
            kernel_size=m.kernel_size,
            stride=m.stride,
            padding=m.padding,
            dilation=m.dilation,
            groups=m.groups,
            bias=bias,
            padding_mode=m.padding_mode,
            sparse_mask=sparse_mask,
        )
        # In-place copies into parameters need to happen under no_grad.
        with torch.no_grad():
            oup.weight.copy_(m.weight)
            if bias:
                oup.bias.copy_(m.bias)
    elif isinstance(m, nn.MaxPool2d):
        oup = SparseMaxPooling(
            m.kernel_size,
            stride=m.stride,
            padding=m.padding,
            dilation=m.dilation,
            return_indices=m.return_indices,
            ceil_mode=m.ceil_mode,
            sparse_mask=sparse_mask,
        )
    elif isinstance(m, nn.AvgPool2d):
        oup = SparseAvgPooling(
            m.kernel_size,
            m.stride,
            m.padding,
            ceil_mode=m.ceil_mode,
            count_include_pad=m.count_include_pad,
            divisor_override=m.divisor_override,
            sparse_mask=sparse_mask,
        )
    elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
        sparse_bn = (
            SparseSyncBatchNorm2d
            if isinstance(m, nn.SyncBatchNorm)
            else SparseBatchNorm2d
        )
        oup = sparse_bn(
            m.weight.shape[0],
            eps=m.eps,
            momentum=m.momentum,
            affine=m.affine,
            track_running_stats=m.track_running_stats,
            sparse_mask=sparse_mask,
        )
        with torch.no_grad():
            oup.weight.copy_(m.weight)
            oup.bias.copy_(m.bias)
            oup.running_mean.copy_(m.running_mean)
            oup.running_var.copy_(m.running_var)
            oup.num_batches_tracked.copy_(m.num_batches_tracked)
        if hasattr(m, "qconfig"):
            oup.qconfig = m.qconfig
    elif isinstance(m, (nn.Conv1d,)):
        raise NotImplementedError
    for name, child in m.named_children():
        oup.add_module(name, dense_model_to_sparse(child, sparse_mask=sparse_mask))
    del m
    return oup
```
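And a hypothetical usage of the draft with a torchvision backbone (assuming the Sparse* modules accept the sparse_mask argument as sketched above):

```python
import torch
from torchvision.models import resnet50

backbone = resnet50()
shared_mask = SparseMask()
sparse_backbone = dense_model_to_sparse(backbone, sparse_mask=shared_mask)

# Before each forward pass, update the shared mask; every sparse submodule
# sees the new value through the shared SparseMask object.
shared_mask.mask = torch.ones(1, 1, 7, 7)
```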
Hey, thanks for checking it out! Regarding your bullets: for `get_feature_map_channels`, we can record the channels with a dummy forward pass through the resnet stages:
```python
with torch.no_grad():
    self._feature_map_channels = []
    x = self.layer1(x)
    self._feature_map_channels.append(x.shape[1])
    x = self.layer2(x)
    self._feature_map_channels.append(x.shape[1])
    x = self.layer3(x)
    self._feature_map_channels.append(x.shape[1])
    x = self.layer4(x)
    self._feature_map_channels.append(x.shape[1])
```
Perhaps, for a more general-purpose feature extractor that works with arbitrary modules, we can determine the resolution of the feature maps by calling create_feature_extractor during initialization and comparing the feature map size to the input size. Or we can call get_graph_node_names and return the intermediate outputs up until the final pooling and linear layers. This should work with most models.
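A minimal sketch of the first idea, assuming a torchvision ResNet with stages named layer1 through layer4 (the helper name is made up):

```python
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor


def infer_feature_info(backbone: torch.nn.Module, input_size: int = 224):
    # Grab the output of each residual stage with a dummy forward pass.
    return_nodes = {f"layer{i}": f"feat{i}" for i in range(1, 5)}
    extractor = create_feature_extractor(backbone, return_nodes=return_nodes)
    with torch.no_grad():
        feats = list(extractor(torch.zeros(1, 3, input_size, input_size)).values())
    channels = [f.shape[1] for f in feats]
    # Downsample ratio: input resolution over the last feature map resolution.
    downsample_ratio = input_size // feats[-1].shape[-1]
    return downsample_ratio, channels


print(infer_feature_info(resnet50()))  # (32, [256, 512, 1024, 2048])
```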
As for adapting the mask, we could apply it directly within the Sparse modules. Something along the lines of this:
```python
from torch import Tensor
from torch.nn import Conv2d


class SparseConv2d(Conv2d):
    def forward(self, x: Tensor) -> Tensor:
        x = super().forward(x)
        # Downsample the shared mask to the feature map size and apply it.
        mask = get_mask_with_size(self.sparse_mask, x)
        x = apply_mask(x, mask)
        return x
```
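where get_mask_with_size and apply_mask would be small helpers along these lines (shapes are assumptions: the shared mask is (B, 1, H, W) with 1 = visible):

```python
import torch
import torch.nn.functional as F


def get_mask_with_size(sparse_mask: "SparseMask", x: torch.Tensor) -> torch.Tensor:
    # Downsample the shared input-resolution mask to the feature map size.
    return F.interpolate(sparse_mask.mask.float(), size=x.shape[-2:], mode="nearest")


def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Zero out features at masked positions; broadcasts over channels.
    return x * mask
```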
The feature map channels are used in step three of the forward process, where the hierarchical dense features are calculated for decoding. When the SparK module is created, it creates a mask token and a densify norm layer, used when it fills in the masked locations with the mask token. We can circumvent the norm issue by using lazy batch normalization, and perhaps, for the mask token itself, we can create it on the fly from the first pass right before this line?
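For illustration, a sketch of that lazy approach (the module and attribute names are made up for the example; the mask is again (B, 1, H, W) with 1 = visible):

```python
from typing import Optional

import torch
from torch import nn


class LazyDensify(nn.Module):
    """Fill masked positions with a mask token created on the first pass."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LazyBatchNorm2d()  # infers num_features on first call
        self.mask_token: Optional[nn.Parameter] = None

    def forward(self, feat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        if self.mask_token is None:
            # Create the token once the channel count is known.
            self.mask_token = nn.Parameter(
                torch.zeros(1, feat.shape[1], 1, 1, device=feat.device)
            )
        feat = feat * mask + self.mask_token * (1.0 - mask)
        return self.norm(feat)
```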
Update: I've been busy with other life commitments; I'll get back to this when I can. If you want, I can commit the code that I've been working on.
@johnsutor did you end up uploading the code anywhere?
@mileseverett never did, but I have more time now so I'll have to get back to working on it. Thanks for reminding me!