XinJCheng / CSPN

Convolutional Spatial Propagation Network

up_pooling is not fast #8

dontLoveBugs closed this issue 5 years ago

dontLoveBugs commented 5 years ago

I think your up_pooling operation is slow. I recommend using the code below, which does the unpooling with a single grouped transposed convolution.

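```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Unpool(nn.Module):
    """Unpool: 2*2 unpooling with zero padding."""

    def __init__(self, num_channels, stride=2):
        super(Unpool, self).__init__()
        self.num_channels = num_channels
        self.stride = stride

    def forward(self, x):
        # One kernel per channel with a single 1 at the top-left corner,
        # created on the same device and dtype as the input.
        weights = torch.zeros(self.num_channels, 1, self.stride, self.stride,
                              dtype=x.dtype, device=x.device)
        weights[:, :, 0, 0] = 1
        # Grouped transposed convolution inserts zeros between input pixels.
        return F.conv_transpose2d(x, weights, stride=self.stride, groups=self.num_channels)
```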
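As a quick sanity check on the suggestion above, here is a minimal sketch that compares the transposed-convolution unpooling with a plain strided assignment into a zero tensor. The function names `unpool_conv` and `unpool_index` are hypothetical and are not part of the CSPN repository. The sketch assumes a 4D `(N, C, H, W)` input.

```python
import torch
import torch.nn.functional as F


def unpool_conv(x, stride=2):
    """Zero-insertion unpooling via a grouped transposed convolution."""
    channels = x.shape[1]
    # One (stride x stride) kernel per channel with a single 1 at the top-left.
    weights = torch.zeros(channels, 1, stride, stride, dtype=x.dtype, device=x.device)
    weights[:, :, 0, 0] = 1
    return F.conv_transpose2d(x, weights, stride=stride, groups=channels)


def unpool_index(x, stride=2):
    """Reference version: copy the input into every stride-th position of a zero tensor."""
    n, c, h, w = x.shape
    out = x.new_zeros(n, c, h * stride, w * stride)
    out[:, :, ::stride, ::stride] = x
    return out


# Small example: batch of 1, 2 channels, 2x2 spatial size.
x = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2)
y = unpool_conv(x)
print(y.shape)
print(torch.allclose(y, unpool_index(x)))
```

Both versions should give the same output, so the choice between them is mostly about speed on the target device, which is worth measuring rather than assuming.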

XinJCheng commented 5 years ago

Thanks for your suggestion.