mattiaspaul / OBELISK

MIDL 2018 / MedIA 2019: One Binary Extremely Large and Inflecting Sparse Kernel (PyTorch)
MIT License

[Feature Request] 2D Implementation #2

Open khoshsirat opened 5 years ago

khoshsirat commented 5 years ago

I am not able to implement the 2D version. I am using the sample code on the first page, but the number of output channels after calling the grid_sample function does not match the number of spatial filter offsets. It would be very helpful if you could just add a very simple 2D implementation to the models.py file.
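Here is a minimal snippet illustrating the mismatch I am seeing (the sizes are just an example):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 64, 64)                # B x C x H x W
    sample_grid = torch.zeros(1, 1, 100, 2)      # 1 x    1     x #samples x 2
    offsets = torch.randn(1, 512, 1, 2) * 0.15   # 1 x #offsets x    1     x 2
    y = F.grid_sample(x, sample_grid + offsets)
    print(y.shape)                               # 1 x 4 x 512 x 100: channels stay at C,
                                                 # the offsets land in a spatial dimension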

mattiaspaul commented 5 years ago

Dear Ali, thanks for your interest in our work. Please find below a short excerpt of the 2D implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class OBELISKnet(nn.Module): #original trainable brief layer in obelisk
    def __init__(self):
        super(OBELISKnet, self).__init__()

        self.offset = nn.Parameter(torch.randn(1,512,1,2)*0.15)
        self.linear1 = nn.Conv1d(512,256,1,groups=4,bias=False)
        self.batch1 = nn.BatchNorm1d(256)
        self.linear2 = nn.Conv1d(256, 64,1,bias=False)
        self.batch2 = nn.BatchNorm1d(64)

    def forward(self, x, sample_grid):
        # sample_grid: 1 x    1     x #samples x 2
        # offsets:     1 x #offsets x     1    x 2
        B,C,H,W = x.size()

        x = F.grid_sample(x, (sample_grid + self.offset)).view(B,512,-1)
        x = self.batch1(self.linear1(x))
        x = self.batch2(self.linear2(x))

        return x

class ClassifyNet(nn.Module):
    def __init__(self):
        super(ClassifyNet,self).__init__()
        self.linear1 = nn.Conv2d(64, 64,1,bias=False)
        self.batch1 = nn.BatchNorm2d(64)
        self.linear2 = nn.Conv2d(64, 32,1,bias=False)
        self.batch2 = nn.BatchNorm2d(32)
        self.linear3 = nn.Conv2d(32, 32,1,bias=False)
        self.batch3 = nn.BatchNorm2d(32)
        self.predict = nn.Conv2d(32, 8, 1)

    def forward(self, x):

        x = self.batch1(self.linear1(x))
        x = self.batch2(self.linear2(x))
        x = self.batch3(self.linear3(x))

        return self.predict(x)

Note that this is the simple, so-called unary version. For the binary version (subtracting two kernel elements), you'd have to change the tensor size of self.offset to e.g. 1,512,1,4 and call grid_sample twice:

    x = (F.grid_sample(x, (sample_grid + self.offset[:,:,:,:2]))
         -F.grid_sample(x, (sample_grid + self.offset[:,:,:,2:]))).view(B,512,-1)

Also note that in this case the input images should be pre-smoothed (simply some average pooling with stride=1) for best performance. You can call this function either with random spatial sampling coordinates or with a (coarse) regular grid, e.g.:

    identity = torch.eye(3)[:2,:].unsqueeze(0)
    meshgrid = F.affine_grid(identity,torch.ones(1,1,50,50).size()).view(1,1,-1,2)

Please let me know if you need more details.
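For completeness, here is how the excerpt above could be called end to end (a quick sketch, not taken from the repository; the image size, grid resolution and batch size are example values):

    img = torch.randn(1, 1, 96, 96)   #single-channel 2D input, B=1 as assumed in the excerpt
    identity = torch.eye(3)[:2,:].unsqueeze(0)
    meshgrid = F.affine_grid(identity, torch.ones(1,1,50,50).size()).view(1,1,-1,2)

    net = OBELISKnet()
    head = ClassifyNet()
    feat = net(img, meshgrid)         #1 x 64 x 2500 (one feature vector per grid sample)
    feat = feat.view(1, 64, 50, 50)   #restore the spatial layout for the 1x1-conv head
    pred = head(feat)                 #1 x 8 x 50 x 50 class scores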

soumickmj commented 4 years ago

Dear Mattias, I modified the code of obeliskhybrid_visceral and obelisk_visceral for 2D. Please let me know what you think about them. They run without errors, but maybe there is some logical problem? Thanks.

This is a hybrid OBELISK CNN model that combines two obelisk layers with traditional CNNs; the layers have 512 and 128 trainable offsets and about 230k trainable weights in total.

class HybridObelisk2D(nn.Module):
    def __init__(self, out_channels, full_res):
        super(HybridObelisk2D, self).__init__()
        self.out_channels = out_channels
        H_in1 = full_res[0]; W_in1 = full_res[1]
        H_in2 = (H_in1+1)//2; W_in2 = (W_in1+1)//2 #half resolution
        self.half_res = torch.Tensor([H_in2,W_in2]).long(); half_res = self.half_res
        H_in4 = (H_in2+1)//2; W_in4 = (W_in2+1)//2 #quarter resolution
        self.quarter_res = torch.Tensor([H_in4,W_in4]).long(); quarter_res = self.quarter_res
        H_in8 = (H_in4+1)//2; W_in8 = (W_in4+1)//2 #eighth resolution
        self.eighth_res = torch.Tensor([H_in8,W_in8]).long(); eighth_res = self.eighth_res

        #U-Net Encoder
        self.conv0 = nn.Conv2d(1, 4, 3, padding=1)
        self.batch0 = nn.BatchNorm2d(4)
        self.conv1 = nn.Conv2d(4, 16, 3, stride=2, padding=1)
        self.batch1 = nn.BatchNorm2d(16)
        self.conv11 = nn.Conv2d(16, 16, 3, padding=1)
        self.batch11 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.batch2 = nn.BatchNorm2d(32)

        # Obelisk Encoder (for simplicity using regular sampling grid)
        # the first obelisk layer has 128, the second 512 trainable offsets
        # sample_grid: 1 x    1     x #samples x 2
        # offsets:     1 x #offsets x     1    x 2
        self.sample_grid1 = F.affine_grid(torch.eye(2,3).unsqueeze(0),torch.Size((1,1,quarter_res[0],quarter_res[1]))).view(1,1,-1,2).detach()
        self.sample_grid1.requires_grad = False
        self.sample_grid2 = F.affine_grid(torch.eye(2,3).unsqueeze(0),torch.Size((1,1,eighth_res[0],eighth_res[1]))).view(1,1,-1,2).detach()
        self.sample_grid2.requires_grad = False

        self.offset1 = nn.Parameter(torch.randn(1,128,1,2)*0.05)
        self.linear1a = nn.Conv2d(4*128,128,1,groups=4,bias=False)
        self.batch1a = nn.BatchNorm2d(128)
        self.linear1b = nn.Conv2d(128,32,1,bias=False)
        self.batch1b = nn.BatchNorm2d(128+32)
        self.linear1c = nn.Conv2d(128+32,32,1,bias=False)
        self.batch1c = nn.BatchNorm2d(128+64)
        self.linear1d = nn.Conv2d(128+64,32,1,bias=False)
        self.batch1d = nn.BatchNorm2d(128+96)
        self.linear1e = nn.Conv2d(128+96,32,1,bias=False)

        self.offset2 = nn.Parameter(torch.randn(1,512,1,2)*0.05)
        self.linear2a = nn.Conv2d(512,128,1,groups=4,bias=False)
        self.batch2a = nn.BatchNorm2d(128)
        self.linear2b = nn.Conv2d(128,32,1,bias=False)
        self.batch2b = nn.BatchNorm2d(128+32)
        self.linear2c = nn.Conv2d(128+32,32,1,bias=False)
        self.batch2c = nn.BatchNorm2d(128+64)
        self.linear2d = nn.Conv2d(128+64,32,1,bias=False)
        self.batch2d = nn.BatchNorm2d(128+96)
        self.linear2e = nn.Conv2d(128+96,32,1,bias=False)

        #U-Net Decoder
        self.conv6bU = nn.Conv2d(64, 32, 3, padding=1)
        self.batch6bU = nn.BatchNorm2d(32)
        self.conv6U = nn.Conv2d(64+16, 32, 3, padding=1)
        self.batch6U = nn.BatchNorm2d(32)
        self.conv8 = nn.Conv2d(32, out_channels, 1)

    def forward(self, inputImg):

        B,C,H,W = inputImg.size()
        device = inputImg.device
        leakage = 0.05 #leaky ReLU used for conventional CNNs

        #unet-encoder
        x00 = F.avg_pool2d(inputImg,3,padding=1,stride=1)

        x1 = F.leaky_relu(self.batch0(self.conv0(inputImg)), leakage)
        x = F.leaky_relu(self.batch1(self.conv1(x1)),leakage)
        x2 = F.leaky_relu(self.batch11(self.conv11(x)),leakage)
        x = F.leaky_relu(self.batch2(self.conv2(x2)),leakage)

        #in this model two obelisk layers with fewer spatial offsets are used
        #obelisk layer 1
        x_o1 = F.grid_sample(x1, (self.sample_grid1.to(device).repeat(B,1,1,1) + self.offset1)).view(B,-1,self.quarter_res[0],self.quarter_res[1])
        #1x1 kernel dense-net
        x_o1 = F.relu(self.linear1a(x_o1))
        x_o1a = torch.cat((x_o1,F.relu(self.linear1b(self.batch1a(x_o1)))),dim=1)
        x_o1b = torch.cat((x_o1a,F.relu(self.linear1c(self.batch1b(x_o1a)))),dim=1)
        x_o1c = torch.cat((x_o1b,F.relu(self.linear1d(self.batch1c(x_o1b)))),dim=1)
        x_o1d = F.relu(self.linear1e(self.batch1d(x_o1c)))
        x_o1 = F.interpolate(x_o1d, size=[self.half_res[0],self.half_res[1]], mode='bilinear', align_corners=False)

        #obelisk layer 2
        x_o2 = F.grid_sample(x00, (self.sample_grid2.to(device).repeat(B,1,1,1) + self.offset2)).view(B,-1,self.eighth_res[0],self.eighth_res[1])
        x_o2 = F.relu(self.linear2a(x_o2))
        #1x1 kernel dense-net
        x_o2a = torch.cat((x_o2,F.relu(self.linear2b(self.batch2a(x_o2)))),dim=1)
        x_o2b = torch.cat((x_o2a,F.relu(self.linear2c(self.batch2b(x_o2a)))),dim=1)
        x_o2c = torch.cat((x_o2b,F.relu(self.linear2d(self.batch2c(x_o2b)))),dim=1)
        x_o2d = F.relu(self.linear2e(self.batch2d(x_o2c)))
        x_o2 = F.interpolate(x_o2d, size=[self.quarter_res[0],self.quarter_res[1]], mode='bilinear', align_corners=False)

        #unet-decoder
        x = F.leaky_relu(self.batch6bU(self.conv6bU(torch.cat((x,x_o2),1))),leakage)
        x = F.interpolate(x, size=[self.half_res[0],self.half_res[1]], mode='bilinear', align_corners=False)
        x = F.leaky_relu(self.batch6U(self.conv6U(torch.cat((x,x_o1,x2),1))),leakage)
        x = F.interpolate(self.conv8(x), size=[H,W], mode='bilinear', align_corners=False)

        return x
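As a quick sanity check, this is how I call it (a sketch; out_channels and full_res are example values, and the input is a single-channel 2D image):

    net = HybridObelisk2D(out_channels=8, full_res=[128, 128])
    img = torch.randn(2, 1, 128, 128)
    out = net(img)
    print(out.shape)   #2 x 8 x 128 x 128, i.e. full-resolution logits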

This is the original OBELISK model as described in the MIDL 2018 paper: the most simple Obelisk-Net, with one deformable convolution followed by a 1x1 Dense-Net. It contains around 130k trainable parameters and 1024 binary offsets.

class Obelisk2D(nn.Module):
    def __init__(self, out_channels, full_res):
        super(Obelisk2D, self).__init__()
        self.out_channels = out_channels
        self.full_res = full_res
        H_in1 = full_res[0]; W_in1 = full_res[1]
        H_in2 = (H_in1+1)//2; W_in2 = (W_in1+1)//2 #half resolution
        self.half_res = torch.Tensor([H_in2,W_in2]).long(); half_res = self.half_res
        H_in4 = (H_in2+1)//2; W_in4 = (W_in2+1)//2 #quarter resolution
        self.quarter_res = torch.Tensor([H_in4,W_in4]).long(); quarter_res = self.quarter_res

        #Obelisk Layer
        # sample_grid: 1 x    1     x #samples x 2
        # offsets:     1 x #offsets x     1    x 2

        self.sample_grid1 = F.affine_grid(torch.eye(2,3).unsqueeze(0),torch.Size((1,1,quarter_res[0],quarter_res[1])))
        self.sample_grid1.requires_grad = False

        #in this model (binary-variant) two spatial offsets are paired,
        #so each of the 1024 kernels stores two (x,y) offsets = 4 values
        self.offset1 = nn.Parameter(torch.randn(1,1024,1,4)*0.05)

        #Dense-Net with 1x1 kernels
        self.LIN1 = nn.Conv2d(1024, 256, 1, bias=False, groups=4) #grouped convolutions
        self.BN1 = nn.BatchNorm2d(256)
        self.LIN2 = nn.Conv2d(256, 128, 1, bias=False)
        self.BN2 = nn.BatchNorm2d(128)

        self.LIN3a = nn.Conv2d(128, 32, 1, bias=False)
        self.BN3a = nn.BatchNorm2d(128+32)
        self.LIN3b = nn.Conv2d(128+32, 32, 1, bias=False)
        self.BN3b = nn.BatchNorm2d(128+64)
        self.LIN3c = nn.Conv2d(128+64, 32, 1, bias=False)
        self.BN3c = nn.BatchNorm2d(128+96)
        self.LIN2d = nn.Conv2d(128+96, 32, 1, bias=False)
        self.BN2d = nn.BatchNorm2d(256)

        self.LIN4 = nn.Conv2d(256, out_channels, 1)

    def forward(self, inputImg, sample_grid=None):

        B,C,H,W = inputImg.size()
        if sample_grid is None:
            sample_grid = self.sample_grid1
        sample_grid = sample_grid.to(inputImg.device)
        #pre-smooth image (has to be done in advance for original models)
        #x00 = F.avg_pool2d(inputImg,3,padding=1,stride=1)

        _,H_grid,W_grid,_ = sample_grid.size()
        #binary variant: subtract the samples taken at the two paired 2D offsets
        x_o = F.grid_sample(inputImg, (sample_grid.view(1,1,-1,2).repeat(B,1,1,1) + self.offset1[:,:,:,0:2])).view(B,-1,H_grid,W_grid) \
            - F.grid_sample(inputImg, (sample_grid.view(1,1,-1,2).repeat(B,1,1,1) + self.offset1[:,:,:,2:4])).view(B,-1,H_grid,W_grid)

        x1 = F.relu(self.BN1(self.LIN1(x_o)))
        x2 = self.BN2(self.LIN2(x1))

        x3a = torch.cat((x2,F.relu(self.LIN3a(x2))),dim=1)
        x3b = torch.cat((x3a,F.relu(self.LIN3b(self.BN3a(x3a)))),dim=1)
        x3c = torch.cat((x3b,F.relu(self.LIN3c(self.BN3b(x3b)))),dim=1)
        x2d = torch.cat((x3c,F.relu(self.LIN2d(self.BN3c(x3c)))),dim=1)

        x4 = self.LIN4(self.BN2d(x2d))
        #return half-resolution segmentation/prediction
        return F.interpolate(x4, size=[self.half_res[0],self.half_res[1]], mode='bilinear', align_corners=False)
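And the corresponding smoke test for this variant (again with example values; the input is pre-smoothed in advance, as the comment in forward suggests):

    net = Obelisk2D(out_channels=8, full_res=[128, 128])
    img = F.avg_pool2d(torch.randn(2, 1, 128, 128), 3, padding=1, stride=1)   #pre-smoothing for the binary variant
    out = net(img)
    print(out.shape)   #2 x 8 x 64 x 64, i.e. a half-resolution prediction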