New SSL algorithm called SparK: Sparse and Hierarchical masKed modeling #1462
Hey @Djoels, thanks for bringing this up! We will take a look shortly and add it to the model tracker if relevant 🙂
This looks interesting indeed, thanks a lot for the issue! Added it to the methods tracker and will consider it for the paper session next week.
I can take this issue.
Thanks for looking into this @johnsutor! The original codebase implements the sparse net in quite a hacky way (see code here), and I was wondering whether it would be possible to pass the masks explicitly to the forward function instead of assigning them to a global variable. Maybe this would be interesting to explore, wdyt?
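One possible middle ground is to keep the mask as state on a wrapper module and inject it into the submodule's inputs with a forward pre-hook, so no module-level global is needed. A minimal sketch (all names here are hypothetical, not from the SparK codebase):

```python
from typing import Optional

import torch
from torch import nn


class MaskedConvWrapper(nn.Module):
    """Hypothetical sketch: route a mask to a submodule through a forward
    pre-hook on the wrapper instead of a module-level global."""

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv
        self.mask: Optional[torch.Tensor] = None
        # The hook rewrites conv's positional inputs right before it runs.
        self.conv.register_forward_pre_hook(self._apply_mask)

    def _apply_mask(self, module: nn.Module, inputs: tuple):
        if self.mask is None:
            return inputs  # no mask set: leave the input untouched
        return (inputs[0] * self.mask,)  # zero out masked positions

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        self.mask = mask  # state lives on the wrapper, not in a global
        return self.conv(x)


wrapper = MaskedConvWrapper(nn.Conv2d(3, 8, kernel_size=3, padding=1))
out = wrapper(torch.randn(2, 3, 16, 16), torch.ones(2, 1, 16, 16))
print(out.shape)  # torch.Size([2, 8, 16, 16])
```

Forward pre-hooks receive `(module, inputs)` and may return a replacement tuple, which is what makes this pattern possible without touching the submodule's code.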
I'll investigate and get back to you!
Seems fairly straightforward to achieve based on https://github.com/keyu-tian/SparK/tree/main/pretrain#regarding-sparse-convolution. I don't mind giving it a stab. My thought is to implement the encoder and decoder from their code base (https://github.com/keyu-tian/SparK/tree/main/pretrain) within https://github.com/lightly-ai/lightly/tree/master/lightly/models, naming the file something like …
Sounds good! Thanks a lot for looking into it. Maybe create a …
I went ahead and implemented a ResNet compatible with the standard interface (i.e., one that implements `get_downsample_ratio()` and `get_feature_map_channels()`). Furthermore, I managed to pass the mask at runtime without setting a global variable by using a forward pre-hook. This is how it looks so far:

```python
class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int, sync_bn: bool = False):
        """Sparse encoder as used by SparK [0].

        Default params are the ones explained in the original code base.

        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical
            Masked Modeling, https://arxiv.org/abs/2301.03580

        Attributes:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.
            sync_bn:
                Whether or not to use Sync Batch Norm in this model.
        """
        super().__init__()
        self.mask: Optional[torch.Tensor] = None
        self.sp_backbone = self.dense_model_to_sparse(m=backbone, sbn=sync_bn)
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

    def mask_hook(self, module: nn.Module, input: Tuple[torch.Tensor, ...]):
        # Forward pre-hooks are called with (module, input) only; returning a
        # tuple replaces the positional arguments passed to forward().
        return (input[0], self.mask)

    def dense_model_to_sparse(self, m: nn.Module, sbn: bool = False) -> nn.Module:
        oup = m
        if isinstance(m, nn.Conv2d):
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.MaxPool2d):
            oup = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.AvgPool2d):
            oup = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(
                m.weight.shape[0],
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats,
            )
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.Conv1d):
            raise NotImplementedError
        for name, child in m.named_children():
            oup.add_module(name, self.dense_model_to_sparse(child, sbn=sbn))
        # Hooks are registered only on the converted sparse layers above;
        # registering one on every container module would break their forward().
        return oup

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if mask is not None:
            self.mask = mask
        assert self.mask is not None, "Mask must be supplied for training"
        return self.sp_backbone(x, hierarchical=True)
```

If that works, I'll go ahead and implement the SparK module as well. The one thing I'm thinking about altering there is configuring the forward pass to return only the reconstructions, and perhaps creating a separate method for calculating the reconstruction loss. This is to keep the code similar to the masked autoencoder.
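To illustrate the MAE-style split described above, here is a rough sketch (the class and method names are hypothetical, and the convention of `mask == 1` marking masked pixels is an assumption):

```python
import torch
import torch.nn.functional as F
from torch import nn


class SparKModule(nn.Module):
    """Hypothetical sketch: forward() returns only the reconstructions, and
    the reconstruction loss lives in a separate method, as in MAE."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        features = self.encoder(images, mask)
        return self.decoder(features)

    def reconstruction_loss(self, reconstructions, targets, mask):
        # MSE averaged over masked positions only (mask == 1 marks masked pixels).
        per_pixel = F.mse_loss(reconstructions, targets, reduction="none")
        return (per_pixel * mask).sum() / mask.sum().clamp(min=1.0)


# Tiny smoke test with identity encoder/decoder stand-ins.
class _Enc(nn.Module):
    def forward(self, x, mask):
        return x


class _Dec(nn.Module):
    def forward(self, x):
        return x


model = SparKModule(_Enc(), _Dec())
imgs = torch.zeros(1, 3, 4, 4)
mask = torch.ones(1, 3, 4, 4)
rec = model(imgs, mask)
loss = model.reconstruction_loss(rec, imgs, mask)
print(loss.item())  # 0.0 — zeros reconstruct zeros exactly
```

Keeping the loss out of `forward` means downstream users can swap in a different reconstruction target (e.g. normalized pixels) without touching the model.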
Oh wow, thanks a lot for looking into this! It looks really good! I have some comments/questions:
Here is the draft for a version that doesn't use hooks. Instead, it saves a …
Hey, thanks for checking it out! In regard to your bullets:

```python
with torch.no_grad():
    self._feature_map_channels = []
    x = self.layer1(x)
    self._feature_map_channels.append(x.shape[1])
    x = self.layer2(x)
    self._feature_map_channels.append(x.shape[1])
    x = self.layer3(x)
    self._feature_map_channels.append(x.shape[1])
    x = self.layer4(x)
    self._feature_map_channels.append(x.shape[1])
```

For a more general-purpose feature extractor that works with arbitrary modules, we could determine the resolution of each feature map by calling create_feature_extractor during initialization and comparing the feature map sizes to the input size. Alternatively, we could call get_graph_node_names and return the intermediate outputs up to the final pooling and linear layers. This should work with most modules.
The feature map channels are used in step three of the forward process, where the hierarchical dense features are calculated for decoding. When the SparK module is created, it creates a mask token and a densify norm layer for filling in the masked locations with the mask token. We can circumvent the norm issue using lazy batch normalization, and for the mask token itself, perhaps we can create it on the fly during the first pass, right before this line?
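A rough sketch of the lazy-creation idea (the module name is hypothetical; it assumes `mask` is 1 at visible positions and uses `nn.LazyBatchNorm2d` so the channel count is inferred on first use):

```python
from typing import Optional

import torch
from torch import nn


class LazyDensify(nn.Module):
    """Hypothetical sketch: defer creating the mask token until the first
    forward pass, and let nn.LazyBatchNorm2d infer its channel count."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LazyBatchNorm2d()  # num_features inferred on first call
        self.mask_token: Optional[nn.Parameter] = None

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        if self.mask_token is None:
            # Shape (1, C, 1, 1) so the token broadcasts over the feature map.
            self.mask_token = nn.Parameter(
                torch.zeros(1, x.shape[1], 1, 1, device=x.device)
            )
        # Fill masked locations (mask == 0) with the token, then normalize.
        x = x * mask + self.mask_token * (1.0 - mask)
        return self.norm(x)


dense = LazyDensify()
out = dense(torch.randn(2, 16, 8, 8), torch.ones(2, 1, 8, 8))
print(out.shape)  # torch.Size([2, 16, 8, 8])
```

One caveat with lazy parameter creation is that optimizers must be constructed after the first forward pass so the new parameters are included.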
Update: I've been busy with other life requirements; I'll get back to it when I can. If you want, I can commit the code that I've been working on.
@johnsutor did you end up uploading the code anywhere?
@mileseverett never did, but I have more time now, so I'll have to get back to working on it. Thanks for reminding me!
It would be great if this new MAE-style method, SparK, were introduced to lightly.
Paper: https://arxiv.org/abs/2301.03580 (ICLR 2023 Spotlight)
Code: https://github.com/keyu-tian/SparK
It was successfully applied to medical imaging, as documented in this Nature paper:
https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main
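For context on what such a method needs at the data level: MAE/SparK-style pretraining masks out most image patches and reconstructs them. A minimal sketch of random patch-mask generation (the helper name and the convention of 1 = kept patch are assumptions, not from the SparK codebase):

```python
import torch


def random_patch_mask(batch: int, grid: int, mask_ratio: float = 0.6) -> torch.Tensor:
    """Return a (batch, 1, grid, grid) binary mask; 1 marks a kept patch."""
    n = grid * grid
    n_keep = round(n * (1 - mask_ratio))
    scores = torch.rand(batch, n)          # random score per patch
    idx = scores.argsort(dim=1)[:, :n_keep]  # indices of patches to keep
    mask = torch.zeros(batch, n)
    mask.scatter_(1, idx, 1.0)
    return mask.view(batch, 1, grid, grid)


m = random_patch_mask(2, 7, mask_ratio=0.6)
print(m.sum(dim=(1, 2, 3)))  # tensor([20., 20.]) — 40% of 49 patches kept
```

The mask is later upsampled to the feature-map or pixel resolution before being applied to the sparse encoder's input.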