Closed: Lucien66 closed this issue 2 years ago.
Can you share a code example?
class Dynamic_conv2d(nn.Module):
    """2-D convolution whose kernel is a per-sample mixture of ``K`` candidate kernels.

    The mixing coefficients are not computed by this module; the caller
    supplies them in ``forward`` as ``inputs['weights']``.

    Args:
        in_planes: number of input channels; must be divisible by ``groups``.
        out_planes: number of output channels.
        kernel_size: side length of the square kernel (a single ``int``).
        stride, padding, dilation: forwarded unchanged to ``F.conv2d``.
        groups: number of channel groups, with the same meaning as in ``nn.Conv2d``.
        if_bias: if True, also learn ``K`` candidate bias vectors.
        K: number of candidate kernels.
        init_weight: if True, re-initialise every candidate kernel with
            ``nn.init.kaiming_uniform_`` and set the biases to zero.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=1,
                 dilation=1, groups=1, if_bias=True, K=5, init_weight=False):
        super(Dynamic_conv2d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.if_bias = if_bias
        self.K = K
        # Candidate kernels: (K, out_planes, in_planes // groups, kernel_size, kernel_size).
        self.weight = nn.Parameter(
            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size),
            requires_grad=True,
        )
        if self.if_bias:
            # torch.zeros instead of torch.Tensor(...): the latter returns
            # uninitialised memory, which left garbage biases whenever
            # init_weight was False.
            self.bias = nn.Parameter(torch.zeros(K, out_planes), requires_grad=True)
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform init for each candidate kernel; zero each candidate bias."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])
            if self.if_bias:
                nn.init.constant_(self.bias[i], 0)

    def forward(self, inputs):
        """Convolve each sample with its own attention-weighted kernel.

        Args:
            inputs: dict with the keys
                ``'x'``: input tensor of shape (batch, in_planes, H, W);
                ``'weights'``: mixing coefficients of shape (batch, K), one row
                per sample (presumably softmax-normalised by the caller;
                not checked here).

        Returns:
            Tensor of shape (batch, out_planes, H_out, W_out).
        """
        x = inputs['x']
        softmax_attention = inputs['weights']
        batch_size, _, height, width = x.size()
        # Fold the batch into the channel axis. A single grouped conv2d with
        # groups * batch_size groups then applies a different kernel to each sample.
        x = x.contiguous().view(1, -1, height, width)
        # (batch, K) @ (K, out * in/groups * k * k) gives one mixed kernel per sample.
        # The per-group input width is in_planes // groups, not in_planes;
        # using in_planes broke every groups > 1 configuration.
        aggregate_weight = torch.mm(softmax_attention, self.weight.view(self.K, -1)).view(
            -1, self.in_planes // self.groups, self.kernel_size, self.kernel_size
        )
        aggregate_bias = None
        if self.bias is not None:
            # (batch, K) @ (K, out_planes) flattened to one bias per output channel.
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        output = F.conv2d(
            x,
            weight=aggregate_weight,
            bias=aggregate_bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups * batch_size,
        )
        # Unfold the batch back out of the channel axis.
        return output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
`inputs['weights']` comes from the outputs of other networks.
THOP currently does not register counting rules for functions; it only supports defining rules for `nn.Module`s (passed to `thop.profile` through its `custom_ops` argument).
You may want to check the example for matmul:
https://github.com/Lyken17/pytorch-OpCounter/blob/d6d8ec033b973c931f1ded6dbd637a82c6247906/thop/vision/basic_hooks.py#L140
and its registration here https://github.com/Lyken17/pytorch-OpCounter/blob/master/thop/profile.py#L53
My network uses a dynamic structure built on `F.conv2d`. Should I define the counting rule for `F.conv2d` myself?