zuoxingdong / VIN_PyTorch_Visdom

PyTorch implementation of Value Iteration Networks (VIN): Clean, Simple and Modular. Visualization in Visdom.

How did you implement channel-wise Max Pooling? #4

Closed. Youngkl0726 closed this issue 6 years ago.

zuoxingdong commented 6 years ago

@Youngkl0726 Thanks for asking. It has a fancy name, channel-wise max pooling, but in practice it is quite simple: it happens in this line. When we have a feature tensor of 'virtual Q-values', the channel-wise max can be computed by simply taking a max over the channel dimension of the tensor. Hope it helps.
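
For illustration, a minimal sketch (the shapes here are made up, not taken from the repo): with a tensor q of virtual Q-values of shape [batch, n_actions, H, W], the channel-wise max is a one-liner:

import torch

q = torch.randn(32, 8, 16, 16)            # [batch, n_actions (channels), H, W]; hypothetical shapes
v, _ = torch.max(q, dim=1, keepdim=True)  # max over the channel dimension -> [32, 1, 16, 16]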

zuoxingdong commented 6 years ago

@Youngkl0726 By the way, I am planning to upgrade the code to the latest PyTorch and integrate it into the lagom framework (an RL infrastructure on top of PyTorch that I have developed over the past year).

If you are interested, it would be very helpful to get some feedback on the new version of the code.

martinodonnell commented 4 years ago

Not an issue, more of a question. I am trying to do the same channel pooling but with a specific kernel size. I have used a lot of for loops for this. Is there any way to speed it up?

The input is [32, 512, 7, 7] ([batch size, channel dim, width, height]). I have been using the values in the constructor to ensure that a filter size of 7 works with this input:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelPool(nn.Module):

    def __init__(self, kernel_size=7, stride=2, padding=3, dilation=1,
                 return_indices=False, ceil_mode=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode

    def forward(self, input):
        n, c, w, h = input.size()
        # Zero-pad the channel dimension so the kernel fits at the boundaries
        input = F.pad(input, (0, 0, 0, 0, self.padding, self.padding), "constant", 0)

        # For each image in the batch, slide a window of kernel_size channels
        # with the given stride and take the max over each window
        output = torch.stack([
            torch.stack([
                input[x, index:index + self.kernel_size].max(dim=0)[0]
                for index in range(0, input.size(1) - self.kernel_size + 1, self.stride)])
            for x in range(n)])

        return output  # stays on the same device as the input
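
One way to drop the Python loops entirely (a sketch, not from the thread: it treats the channel axis as an extra spatial axis and pools along it with F.max_pool3d; channel_pool is a hypothetical helper name):

import torch
import torch.nn.functional as F

def channel_pool(x, kernel_size=7, stride=2, padding=3):
    # [N, C, H, W] -> [N, 1, C, H, W]: expose the channel axis as a spatial axis,
    # then max-pool along that axis only with a (k, 1, 1) kernel
    x = x.unsqueeze(1)
    x = F.max_pool3d(x, kernel_size=(kernel_size, 1, 1),
                     stride=(stride, 1, 1),
                     padding=(padding, 0, 0))
    return x.squeeze(1)

x = torch.randn(32, 512, 7, 7)
print(channel_pool(x).shape)  # torch.Size([32, 256, 7, 7])

Note that max pooling implicitly pads with -inf rather than zeros, so the output can differ from the zero-padded loop above when every value in a window is negative.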