Open ivanstepanovftw opened 7 months ago
An email has been sent to the paper's authors regarding this concern. I am still waiting for an answer.
CC: @liuguilin1225 @fitsumreda @bryancatanzaro
I have added a basic Conv2d to the comparison, since many people asked to reproduce your results, and I see that the outcome is completely opposite to what the paper reports.
Code:
```python from contextlib import contextmanager from functools import partial from typing import Tuple, Any, Callable import torch import torch.nn.functional as F from matplotlib import pyplot as plt from torch import nn, Tensor class PartialConv2d(nn.Conv2d): def __init__(self, *args, **kwargs): # whether the mask is multi-channel or not if 'multi_channel' in kwargs: self.multi_channel = kwargs['multi_channel'] kwargs.pop('multi_channel') else: self.multi_channel = False if 'return_mask' in kwargs: self.return_mask = kwargs['return_mask'] kwargs.pop('return_mask') else: self.return_mask = False super(PartialConv2d, self).__init__(*args, **kwargs) if self.multi_channel: self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])) else: self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])) self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3] self.last_size = (None, None, None, None) self.update_mask = None self.mask_ratio = None def forward(self, input, mask_in=None): assert len(input.shape) == 4 if mask_in is not None or self.last_size != tuple(input.shape): self.last_size = tuple(input.shape) with torch.no_grad(): if mask_in is None: # if mask is not provided, create a mask if self.multi_channel: mask = torch.ones_like(input) else: mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype) else: mask = mask_in self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1) # for mixed precision training, change 1e-8 to 1e-6 self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8) # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8) self.update_mask = 
torch.clamp(self.update_mask, 0, 1) self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input) if self.bias is not None: bias_view = self.bias.view(1, self.out_channels, 1, 1) output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view output = torch.mul(output, self.update_mask) else: output = torch.mul(raw_out, self.mask_ratio) if self.return_mask: return output, self.update_mask else: return output class MaskedConv2d(nn.Conv2d): def __init__( self, in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', eps=1e-8, multichannel: bool = False, partial_conv: bool = False, device=None, dtype=None ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype) if multichannel: self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False) else: self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False) self.eps = eps self.multichannel = multichannel self.partial_conv = partial_conv def get_mask( self, input: torch.Tensor, mask: torch.Tensor | None ) -> (torch.Tensor, torch.Tensor): if mask is None: if self.multichannel: mask = torch.ones_like(input) else: mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype) else: if self.multichannel: mask = mask.expand_as(input) else: mask = mask.expand(1, 1, *input.shape[2:]) return mask def forward( self, input: torch.Tensor, mask: torch.Tensor | None = None ) -> (torch.Tensor, torch.Tensor | None): if mask is not None: input *= mask mask = self.get_mask(input, mask) if self.partial_conv: output = F.conv2d(input, self.weight, None, self.stride, 
self.padding, self.dilation, self.groups) mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1) mask_kernel_numel = self.mask_weight.data.shape[1:].numel() mask_ratio = mask_kernel_numel / (mask + self.eps) mask.clamp_(0, 1) # Apply re-weighting and bias output *= mask_ratio if self.bias is not None: output += self.bias.view(-1, 1, 1) output *= mask else: output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1) max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] mask = mask / max_vals return output, mask def extra_repr(self): return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}" class MaskedPixelUnshuffle(nn.PixelUnshuffle): def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None): return super().forward(input), super().forward(mask) if mask is not None else None class MaskedSequential(nn.Sequential): def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None): for module in self: input, mask = module(input, mask) return input, mask @contextmanager def register_hooks( model: torch.nn.Module, hook: Callable, predicate: Callable[[str, torch.nn.Module], bool], **hook_kwargs ): handles = [] try: for name, module in model.named_modules(): if predicate(name, module): hook: Callable = partial(hook, name=name, **hook_kwargs) handle = module.register_forward_hook(hook) handles.append(handle) yield handles finally: for handle in handles: handle.remove() def activations_recorder_hook( module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor, name: str, *, storage: dict[str, Any] ): if name in storage: if isinstance(storage[name], list): storage[name].append(output) else: 
storage[name] = [storage[name], output] else: storage[name] = output def forward_with_activations( model: torch.nn.Module, predicate: Callable[[str, torch.nn.Module], bool], *model_args, **model_kwargs, ) -> Tuple[torch.Tensor, dict[str, Any]]: storage = {} with register_hooks(model, activations_recorder_hook, predicate, storage=storage): output = model(*model_args, **model_kwargs) return output, storage def test_it(): torch.manual_seed(37) in_channels = 3 downscale_factor = 2 scale = 1 base = 2 depth = 8 visualize_depth = 6 eps = 1e-8 conv = [] for i in range(depth): conv.append(nn.PixelUnshuffle(downscale_factor)) conv.append(nn.Conv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False) ) conv = nn.Sequential(*conv) pconv = [] for i in range(depth): pconv.append(MaskedPixelUnshuffle(downscale_factor)) pconv.append(PartialConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True) ) pconv = MaskedSequential(*pconv) mpconv = [] for i in range(depth): mpconv.append(MaskedPixelUnshuffle(downscale_factor)) mpconv.append(MaskedConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True) ) mpconv = MaskedSequential(*mpconv) mconv = [] for i in range(depth): mconv.append(MaskedPixelUnshuffle(downscale_factor)) mconv.append(MaskedConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, 
bias=False, multichannel=True, partial_conv=False) ) mconv = MaskedSequential(*mconv) with torch.no_grad(): print(f"{conv=}") print(f"{pconv=}") print(f"{mpconv=}") print(f"{mconv=}") print(f"{list(conv.state_dict().keys())=}") print(f"{list(pconv.state_dict().keys())=}") print(f"{list(mpconv.state_dict().keys())=}") print(f"{list(mconv.state_dict().keys())=}") pconv.load_state_dict(conv.state_dict()) mpconv.load_state_dict(conv.state_dict()) mconv.load_state_dict(conv.state_dict()) # x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth) x = torch.randn(1, in_channels, 512, 512) mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x) def is_conv_predicate(name: str, module: torch.nn.Module): return isinstance(module, torch.nn.Conv2d) y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x) (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv) (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv) (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv) assert not torch.allclose(y_conv, y_mpconv) assert torch.allclose(y_mpconv, y_pconv) assert not torch.allclose(y_mconv, y_mpconv) print(f"{activations_pconv.keys()=}") # ['1', '3', '5', '7', '9', '11', '13', '15'] # fig, axs = plt.subplots(nrows=visualize_depth, ncols=4, figsize=(12, 8), dpi=180) fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180) axs = axs.flatten() for impl_i, (name, y, mask, activations) in enumerate([ ("conv", y_conv, None, activations_conv), ("pconv", y_pconv, mask_pconv, activations_pconv), ("mpconv", y_mpconv, mask_mpconv, activations_mpconv), ("mconv", y_mconv, mask_mconv, activations_mconv) ]): batch_i = 0 for depth_i in range(visualize_depth): # ax = axs[depth_i * 4 + impl_i] ax = axs[impl_i * 
visualize_depth + depth_i] layer_output = activations[f"{depth_i * 2 + 1}"] if isinstance(layer_output, torch.Tensor): output = layer_output[batch_i] mask_output = None else: output = layer_output[0][batch_i] mask_output = layer_output[1][batch_i] assert output.dim() == 3 mean = output.mean() std = output.std(unbiased=False) skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps) kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps) print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}") # ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std) ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std) ax.set_title(f"{name} {depth_i=}") ax.axis('off') # plt.suptitle(f"Depth {depth_i}") plt.show() if __name__ == '__main__': test_it() ```
Output:
``` conv=Sequential( (0): PixelUnshuffle(downscale_factor=2) (1): Conv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (2): PixelUnshuffle(downscale_factor=2) (3): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (4): PixelUnshuffle(downscale_factor=2) (5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (6): PixelUnshuffle(downscale_factor=2) (7): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (8): PixelUnshuffle(downscale_factor=2) (9): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (10): PixelUnshuffle(downscale_factor=2) (11): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (12): PixelUnshuffle(downscale_factor=2) (13): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (14): PixelUnshuffle(downscale_factor=2) (15): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) pconv=MaskedSequential( (0): MaskedPixelUnshuffle(downscale_factor=2) (1): PartialConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (2): MaskedPixelUnshuffle(downscale_factor=2) (3): PartialConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (4): MaskedPixelUnshuffle(downscale_factor=2) (5): PartialConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (6): MaskedPixelUnshuffle(downscale_factor=2) (7): PartialConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (8): MaskedPixelUnshuffle(downscale_factor=2) (9): PartialConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (10): MaskedPixelUnshuffle(downscale_factor=2) (11): PartialConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (12): MaskedPixelUnshuffle(downscale_factor=2) (13): PartialConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False) (14): MaskedPixelUnshuffle(downscale_factor=2) (15): PartialConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) mpconv=MaskedSequential( (0): MaskedPixelUnshuffle(downscale_factor=2) (1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (2): MaskedPixelUnshuffle(downscale_factor=2) (3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (4): MaskedPixelUnshuffle(downscale_factor=2) (5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (6): MaskedPixelUnshuffle(downscale_factor=2) (7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (8): MaskedPixelUnshuffle(downscale_factor=2) (9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (10): MaskedPixelUnshuffle(downscale_factor=2) (11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (12): MaskedPixelUnshuffle(downscale_factor=2) (13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) (14): MaskedPixelUnshuffle(downscale_factor=2) (15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True) ) mconv=MaskedSequential( (0): MaskedPixelUnshuffle(downscale_factor=2) (1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) (2): MaskedPixelUnshuffle(downscale_factor=2) (3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False, eps=1e-08, multichannel=True, partial_conv=False) (4): MaskedPixelUnshuffle(downscale_factor=2) (5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) (6): MaskedPixelUnshuffle(downscale_factor=2) (7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) (8): MaskedPixelUnshuffle(downscale_factor=2) (9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) (10): MaskedPixelUnshuffle(downscale_factor=2) (11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) (12): MaskedPixelUnshuffle(downscale_factor=2) (13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) (14): MaskedPixelUnshuffle(downscale_factor=2) (15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False) ) list(conv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight'] list(pconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight'] list(mpconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight'] list(mconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight'] activations_pconv.keys()=dict_keys(['1', '3', '5', '7', '9', '11', '13', '15']) name='conv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264) name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3238), 
skewness=tensor(0.0080), kurtosis=tensor(3.0212) name='conv', depth_i=2, mean=tensor(6.5161e-06), std=tensor(0.1855), skewness=tensor(0.0049), kurtosis=tensor(3.0922) name='conv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1054), skewness=tensor(0.0081), kurtosis=tensor(3.0650) name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0589), skewness=tensor(-0.0125), kurtosis=tensor(3.1699) name='conv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0316), skewness=tensor(-0.0147), kurtosis=tensor(3.2110) name='pconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276) name='pconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518) name='pconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635) name='pconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829) name='pconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324) name='pconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953) name='mpconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276) name='mpconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518) name='mpconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635) name='mpconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829) name='mpconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324) name='mpconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953) name='mconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), 
skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264) name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3232), skewness=tensor(0.0085), kurtosis=tensor(3.0263) name='mconv', depth_i=2, mean=tensor(-9.7408e-06), std=tensor(0.1844), skewness=tensor(0.0053), kurtosis=tensor(3.1119) name='mconv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1039), skewness=tensor(0.0093), kurtosis=tensor(3.1074) name='mconv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0571), skewness=tensor(-0.0164), kurtosis=tensor(3.2821) name='mconv', depth_i=5, mean=tensor(0.0006), std=tensor(0.0296), skewness=tensor(-0.0277), kurtosis=tensor(3.3867) ```
It looks even worse with a partially occluded mask.
Code:
```python from contextlib import contextmanager from functools import partial from typing import Tuple, Any, Callable import torch import torch.nn.functional as F from matplotlib import pyplot as plt from torch import nn, Tensor class PartialConv2d(nn.Conv2d): def __init__(self, *args, **kwargs): # whether the mask is multi-channel or not if 'multi_channel' in kwargs: self.multi_channel = kwargs['multi_channel'] kwargs.pop('multi_channel') else: self.multi_channel = False if 'return_mask' in kwargs: self.return_mask = kwargs['return_mask'] kwargs.pop('return_mask') else: self.return_mask = False super(PartialConv2d, self).__init__(*args, **kwargs) if self.multi_channel: self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])) else: self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])) self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3] self.last_size = (None, None, None, None) self.update_mask = None self.mask_ratio = None def forward(self, input, mask_in=None): assert len(input.shape) == 4 if mask_in is not None or self.last_size != tuple(input.shape): self.last_size = tuple(input.shape) with torch.no_grad(): if mask_in is None: # if mask is not provided, create a mask if self.multi_channel: mask = torch.ones_like(input) else: mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype) else: mask = mask_in self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1) # for mixed precision training, change 1e-8 to 1e-6 self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8) # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8) self.update_mask = 
torch.clamp(self.update_mask, 0, 1) self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input) if self.bias is not None: bias_view = self.bias.view(1, self.out_channels, 1, 1) output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view output = torch.mul(output, self.update_mask) else: output = torch.mul(raw_out, self.mask_ratio) if self.return_mask: return output, self.update_mask else: return output class MaskedConv2d(nn.Conv2d): def __init__( self, in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', eps=1e-8, multichannel: bool = False, partial_conv: bool = False, device=None, dtype=None ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype) if multichannel: self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False) else: self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False) self.eps = eps self.multichannel = multichannel self.partial_conv = partial_conv def get_mask( self, input: torch.Tensor, mask: torch.Tensor | None ) -> (torch.Tensor, torch.Tensor): if mask is None: if self.multichannel: mask = torch.ones_like(input) else: mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype) else: if self.multichannel: mask = mask.expand_as(input) else: mask = mask.expand(1, 1, *input.shape[2:]) return mask def forward( self, input: torch.Tensor, mask: torch.Tensor | None = None ) -> (torch.Tensor, torch.Tensor | None): if mask is not None: input *= mask mask = self.get_mask(input, mask) if self.partial_conv: output = F.conv2d(input, self.weight, None, self.stride, 
self.padding, self.dilation, self.groups) mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1) mask_kernel_numel = self.mask_weight.data.shape[1:].numel() mask_ratio = mask_kernel_numel / (mask + self.eps) mask.clamp_(0, 1) # Apply re-weighting and bias output *= mask_ratio if self.bias is not None: output += self.bias.view(-1, 1, 1) output *= mask else: output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1) max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] mask = mask / max_vals return output, mask def extra_repr(self): return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}" class MaskedPixelUnshuffle(nn.PixelUnshuffle): def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None): return super().forward(input), super().forward(mask) if mask is not None else None class MaskedSequential(nn.Sequential): def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None): for module in self: input, mask = module(input, mask) return input, mask @contextmanager def register_hooks( model: torch.nn.Module, hook: Callable, predicate: Callable[[str, torch.nn.Module], bool], **hook_kwargs ): handles = [] try: for name, module in model.named_modules(): if predicate(name, module): hook: Callable = partial(hook, name=name, **hook_kwargs) handle = module.register_forward_hook(hook) handles.append(handle) yield handles finally: for handle in handles: handle.remove() def activations_recorder_hook( module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor, name: str, *, storage: dict[str, Any] ): if name in storage: if isinstance(storage[name], list): storage[name].append(output) else: 
storage[name] = [storage[name], output] else: storage[name] = output def forward_with_activations( model: torch.nn.Module, predicate: Callable[[str, torch.nn.Module], bool], *model_args, **model_kwargs, ) -> Tuple[torch.Tensor, dict[str, Any]]: storage = {} with register_hooks(model, activations_recorder_hook, predicate, storage=storage): output = model(*model_args, **model_kwargs) return output, storage def test_it(): torch.manual_seed(37) in_channels = 3 downscale_factor = 2 scale = 1 base = 2 depth = 8 visualize_depth = 6 eps = 1e-8 conv = [] for i in range(depth): conv.append(nn.PixelUnshuffle(downscale_factor)) conv.append(nn.Conv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False) ) conv = nn.Sequential(*conv) pconv = [] for i in range(depth): pconv.append(MaskedPixelUnshuffle(downscale_factor)) pconv.append(PartialConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True) ) pconv = MaskedSequential(*pconv) mpconv = [] for i in range(depth): mpconv.append(MaskedPixelUnshuffle(downscale_factor)) mpconv.append(MaskedConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True) ) mpconv = MaskedSequential(*mpconv) mconv = [] for i in range(depth): mconv.append(MaskedPixelUnshuffle(downscale_factor)) mconv.append(MaskedConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, 
bias=False, multichannel=True, partial_conv=False) ) mconv = MaskedSequential(*mconv) with torch.no_grad(): print(f"{conv=}") print(f"{pconv=}") print(f"{mpconv=}") print(f"{mconv=}") print(f"{list(conv.state_dict().keys())=}") print(f"{list(pconv.state_dict().keys())=}") print(f"{list(mpconv.state_dict().keys())=}") print(f"{list(mconv.state_dict().keys())=}") pconv.load_state_dict(conv.state_dict()) mpconv.load_state_dict(conv.state_dict()) mconv.load_state_dict(conv.state_dict()) # x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth) x = torch.randn(1, in_channels, 512, 512) x_mask = torch.ones_like(x) x_mask[..., 128:256, 128:256] = 0 def is_conv_predicate(name: str, module: torch.nn.Module): return isinstance(module, torch.nn.Conv2d) y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x * x_mask) (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, x_mask) (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, x_mask) (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, x_mask) assert not torch.allclose(y_conv, y_mpconv) assert torch.allclose(y_mpconv, y_pconv) assert not torch.allclose(y_mconv, y_mpconv) print(f"{activations_pconv.keys()=}") # ['1', '3', '5', '7', '9', '11', '13', '15'] # fig, axs = plt.subplots(nrows=visualize_depth, ncols=4, figsize=(12, 8), dpi=180) fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180) axs = axs.flatten() for impl_i, (name, y, mask, activations) in enumerate([ ("conv", y_conv, None, activations_conv), ("pconv", y_pconv, mask_pconv, activations_pconv), ("mpconv", y_mpconv, mask_mpconv, activations_mpconv), ("mconv", y_mconv, mask_mconv, activations_mconv) ]): batch_i = 0 for depth_i in range(visualize_depth): # ax = axs[depth_i * 4 + impl_i] ax = axs[impl_i * visualize_depth + depth_i] layer_output = 
activations[f"{depth_i * 2 + 1}"] if isinstance(layer_output, torch.Tensor): output = layer_output[batch_i] mask_output = None else: output = layer_output[0][batch_i] mask_output = layer_output[1][batch_i] assert output.dim() == 3 mean = output.mean() std = output.std(unbiased=False) skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps) kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps) print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}") # ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std) ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std) ax.set_title(f"{name} {depth_i=}") ax.axis('off') # plt.suptitle(f"Depth {depth_i}") plt.show() if __name__ == '__main__': test_it() ```
Output (notice the large kurtosis, which means the distribution has heavier tails, i.e. more extreme outliers):
name='conv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3134), skewness=tensor(0.0081), kurtosis=tensor(3.2148)
name='conv', depth_i=2, mean=tensor(0.0002), std=tensor(0.1794), skewness=tensor(0.0086), kurtosis=tensor(3.2706)
name='conv', depth_i=3, mean=tensor(6.1037e-06), std=tensor(0.1016), skewness=tensor(0.0055), kurtosis=tensor(3.2192)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0566), skewness=tensor(-0.0155), kurtosis=tensor(3.2757)
name='conv', depth_i=5, mean=tensor(0.0004), std=tensor(0.0301), skewness=tensor(-0.0230), kurtosis=tensor(3.1709)
name='pconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='pconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='pconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='pconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='pconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='pconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mpconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='mpconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='mpconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='mpconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='mpconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='mpconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3124), skewness=tensor(0.0086), kurtosis=tensor(3.2303)
name='mconv', depth_i=2, mean=tensor(0.0001), std=tensor(0.1776), skewness=tensor(0.0087), kurtosis=tensor(3.3192)
name='mconv', depth_i=3, mean=tensor(-1.3955e-05), std=tensor(0.0991), skewness=tensor(0.0069), kurtosis=tensor(3.3181)
name='mconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0537), skewness=tensor(-0.0256), kurtosis=tensor(3.4699)
name='mconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0266), skewness=tensor(-0.0291), kurtosis=tensor(3.3908)
partialconv performs even worse than regular convolution on an object detection task (a DETR-like model trained by minimizing a Hungarian loss). Training was performed on batches of images of different sizes, each with its respective mask.
image 1:
image 2:
I have received a response from the authors. I will provide further details via email.
I have released Masked Convolution for Diverse Sample Sizes, so you can now use the fix from this issue under a permissive license: https://github.com/ivanstepanovftw/masked_torch
I have implemented partialconv and ran into the problem that layer activations peak at the edges, even though Figure 5 of the "Partial Convolution based Padding" paper explicitly says: "Red rectangles show the strong activation regions from VGG19 network with zero padding":
I started to double-check my implementation, and it turned out to be equivalent to the one in this repo. After that I started to think about why this is happening. After some trial and error I came up with a simple solution: just convolve the mask with `mask_weight`, then normalize the mask by dividing it by its maximum value.
Here is the code for your reference, so that you can double-check your implementation and mine, and try the fix yourself:
Code
```python from contextlib import contextmanager from functools import partial from typing import Tuple, Any, Callable import torch import torch.nn.functional as F from matplotlib import pyplot as plt from torch import nn, Tensor class PartialConv2d(nn.Conv2d): def __init__(self, *args, **kwargs): # whether the mask is multi-channel or not if 'multi_channel' in kwargs: self.multi_channel = kwargs['multi_channel'] kwargs.pop('multi_channel') else: self.multi_channel = False if 'return_mask' in kwargs: self.return_mask = kwargs['return_mask'] kwargs.pop('return_mask') else: self.return_mask = False super(PartialConv2d, self).__init__(*args, **kwargs) if self.multi_channel: self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])) else: self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])) self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3] self.last_size = (None, None, None, None) self.update_mask = None self.mask_ratio = None def forward(self, input, mask_in=None): assert len(input.shape) == 4 if mask_in is not None or self.last_size != tuple(input.shape): self.last_size = tuple(input.shape) with torch.no_grad(): if mask_in is None: # if mask is not provided, create a mask if self.multi_channel: mask = torch.ones_like(input) else: mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype) else: mask = mask_in self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1) # for mixed precision training, change 1e-8 to 1e-6 self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8) # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8) self.update_mask = 
torch.clamp(self.update_mask, 0, 1) self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input) if self.bias is not None: bias_view = self.bias.view(1, self.out_channels, 1, 1) output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view output = torch.mul(output, self.update_mask) else: output = torch.mul(raw_out, self.mask_ratio) if self.return_mask: return output, self.update_mask else: return output class MaskedConv2d(nn.Conv2d): def __init__( self, in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', eps=1e-8, multichannel: bool = False, partial_conv: bool = False, device=None, dtype=None ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype) if multichannel: self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False) else: self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False) self.eps = eps self.multichannel = multichannel self.partial_conv = partial_conv def get_mask( self, input: torch.Tensor, mask: torch.Tensor | None ) -> (torch.Tensor, torch.Tensor): if mask is None: if self.multichannel: mask = torch.ones_like(input) else: mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype) else: if self.multichannel: mask = mask.expand_as(input) else: mask = mask.expand(1, 1, *input.shape[2:]) return mask def forward( self, input: torch.Tensor, mask: torch.Tensor | None = None ) -> (torch.Tensor, torch.Tensor | None): if mask is not None: input *= mask mask = self.get_mask(input, mask) if self.partial_conv: output = F.conv2d(input, self.weight, None, self.stride, 
self.padding, self.dilation, self.groups) mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1) mask_kernel_numel = self.mask_weight.data.shape[1:].numel() mask_ratio = mask_kernel_numel / (mask + self.eps) mask.clamp_(0, 1) # Apply re-weighting and bias output *= mask_ratio if self.bias is not None: output += self.bias.view(-1, 1, 1) output *= mask else: output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1) max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] mask = mask / max_vals return output, mask def extra_repr(self): return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}" class MaskedPixelUnshuffle(nn.PixelUnshuffle): def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None): return super().forward(input), super().forward(mask) if mask is not None else None class MaskedSequential(nn.Sequential): def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None): for module in self: input, mask = module(input, mask) return input, mask @contextmanager def register_hooks( model: torch.nn.Module, hook: Callable, predicate: Callable[[str, torch.nn.Module], bool], **hook_kwargs ): handles = [] try: for name, module in model.named_modules(): if predicate(name, module): hook: Callable = partial(hook, name=name, **hook_kwargs) handle = module.register_forward_hook(hook) handles.append(handle) yield handles finally: for handle in handles: handle.remove() def activations_recorder_hook( module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor, name: str, *, storage: dict[str, Any] ): if name in storage: if isinstance(storage[name], list): storage[name].append(output) else: 
storage[name] = [storage[name], output] else: storage[name] = output def forward_with_activations( model: torch.nn.Module, predicate: Callable[[str, torch.nn.Module], bool], *model_args, **model_kwargs, ) -> Tuple[torch.Tensor, dict[str, Any]]: storage = {} with register_hooks(model, activations_recorder_hook, predicate, storage=storage): output = model(*model_args, **model_kwargs) return output, storage def test_it(): torch.manual_seed(37) in_channels = 3 downscale_factor = 2 scale = 1 base = 2 depth = 8 visualize_depth = 4 eps = 1e-8 pconv = [] for i in range(depth): pconv.append(MaskedPixelUnshuffle(downscale_factor)) pconv.append(PartialConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True) ) pconv = MaskedSequential(*pconv) mpconv = [] for i in range(depth): mpconv.append(MaskedPixelUnshuffle(downscale_factor)) mpconv.append(MaskedConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True) ) mpconv = MaskedSequential(*mpconv) mconv = [] for i in range(depth): mconv.append(MaskedPixelUnshuffle(downscale_factor)) mconv.append(MaskedConv2d( in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2, out_channels=scale * base ** i * downscale_factor ** 2, kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False) ) mconv = MaskedSequential(*mconv) with torch.no_grad(): print(f"{pconv=}") print(f"{mpconv=}") print(f"{mconv=}") print(f"{list(pconv.state_dict().keys())=}") print(f"{list(mpconv.state_dict().keys())=}") print(f"{list(mconv.state_dict().keys())=}") mpconv.load_state_dict(pconv.state_dict()) 
mconv.load_state_dict(pconv.state_dict()) x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth) mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x) def is_conv_predicate(name: str, module: torch.nn.Module): return isinstance(module, torch.nn.Conv2d) (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv) (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv) (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv) assert torch.allclose(y_mpconv, y_pconv) assert not torch.allclose(y_mconv, y_mpconv) print(f"{activations_pconv.keys()=}") # ['1', '3', '5', '7', '9', '11', '13', '15'] # fig, axs = plt.subplots(nrows=visualize_depth, ncols=3, figsize=(12, 8), dpi=180) fig, axs = plt.subplots(nrows=3, ncols=visualize_depth, figsize=(12, 8), dpi=180) axs = axs.flatten() for impl_i, (name, y, mask, activations) in enumerate([ ("pconv", y_pconv, mask_pconv, activations_pconv), ("mpconv", y_mpconv, mask_mpconv, activations_mpconv), ("mconv", y_mconv, mask_mconv, activations_mconv) ]): batch_i = 0 for depth_i in range(visualize_depth): # ax = axs[depth_i * 3 + impl_i] ax = axs[impl_i * visualize_depth + depth_i] output = activations[f"{depth_i * 2 + 1}"][0][batch_i] mask_output = activations[f"{depth_i * 2 + 1}"][1][batch_i] mean = output.mean() std = output.std(unbiased=False) skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps) kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps) print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}") ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std) ax.set_title(f"{name} {depth_i=}") ax.axis('off') # plt.suptitle(f"Depth {depth_i}") plt.show() if __name__ == '__main__': test_it() ```
Output:
- `pconv` is the original implementation of partial conv (this repo)
- `mpconv` is my implementation of partial conv
- `mconv` is my approach: masked convolution

Here are also the activations on real images: