[Open] khoshsirat opened this issue 5 years ago
Dear Ali, thanks for your interest in our work. Please find below a short excerpt of the 2D implementation:
class OBELISKnet(nn.Module):  # original trainable "brief" layer in OBELISK
    """Unary 2D OBELISK layer with 512 trainable spatial offsets.

    The input image is bilinearly sampled at ``sample_grid + offset`` for each
    of the 512 offsets. The responses at every sampling location are then
    mixed by a grouped and a dense 1x1 convolution, each followed by batch
    norm.

    Args:
        in_channels: channels of the input image. The original excerpt only
            supported single-channel images. With ``C`` channels,
            ``grid_sample`` yields ``C * 512`` responses per location, so the
            first 1x1 convolution is sized accordingly. Defaults to 1.
    """

    def __init__(self, in_channels=1):
        super(OBELISKnet, self).__init__()
        self.in_channels = in_channels
        # offsets: 1 x #offsets x 1 x 2 normalised (x, y) displacements, shaped
        # to broadcast against a (1|B) x 1 x #samples x 2 sampling grid
        self.offset = nn.Parameter(torch.randn(1, 512, 1, 2) * 0.15)
        self.linear1 = nn.Conv1d(512 * in_channels, 256, 1, groups=4, bias=False)
        self.batch1 = nn.BatchNorm1d(256)
        self.linear2 = nn.Conv1d(256, 64, 1, bias=False)
        self.batch2 = nn.BatchNorm1d(64)

    def forward(self, x, sample_grid):
        """Compute OBELISK features at every sampling location.

        Args:
            x: B x C x H x W image batch with ``C == in_channels``.
            sample_grid: 1 x 1 x #samples x 2 (or B x 1 x #samples x 2)
                normalised sampling coordinates.

        Returns:
            B x 64 x #samples feature tensor.

        Raises:
            ValueError: if the number of input channels differs from
                ``in_channels``.
        """
        B, C = x.shape[0], x.shape[1]
        if C != self.in_channels:
            raise ValueError(
                "OBELISKnet expects {} input channel(s), got {}".format(
                    self.in_channels, C))
        # (1|B) x 512 x #samples x 2. grid_sample requires the grid's batch
        # size to equal the input's, so broadcast it explicitly to B.
        grid = (sample_grid + self.offset).expand(B, -1, -1, -1)
        # B x C x 512 x #samples -> B x (C*512) x #samples
        x = F.grid_sample(x, grid, align_corners=False)
        x = x.reshape(B, C * self.offset.size(1), -1)
        x = self.batch1(self.linear1(x))
        x = self.batch2(self.linear2(x))
        return x
class ClassifyNet(nn.Module):
    """Per-location classifier built from 1x1 convolutions with batch norm.

    Args:
        in_channels: feature channels of the input. Defaults to 64, the
            number of features produced by ``OBELISKnet``. Note that
            ``OBELISKnet`` returns a 3D B x 64 x #samples tensor, while this
            module uses ``Conv2d`` and needs a 4D input, so reshape first
            (e.g. to B x 64 x H x W for a regular H x W grid).
        num_classes: number of predicted classes. Defaults to 8.
    """

    def __init__(self, in_channels=64, num_classes=8):
        super(ClassifyNet, self).__init__()
        self.linear1 = nn.Conv2d(in_channels, 64, 1, bias=False)
        self.batch1 = nn.BatchNorm2d(64)
        self.linear2 = nn.Conv2d(64, 32, 1, bias=False)
        self.batch2 = nn.BatchNorm2d(32)
        self.linear3 = nn.Conv2d(32, 32, 1, bias=False)
        self.batch3 = nn.BatchNorm2d(32)
        self.predict = nn.Conv2d(32, num_classes, 1)

    def forward(self, x):
        """Map B x in_channels x H x W features to B x num_classes x H x W scores."""
        # NOTE(review): there is no activation between the conv/BN stages, so
        # at inference this stack is a single affine map of its input. The
        # other models in this file use (leaky) ReLUs; confirm this is intended.
        x = self.batch1(self.linear1(x))
        x = self.batch2(self.linear2(x))
        x = self.batch3(self.linear3(x))
        return self.predict(x)
Note that this is the simple, so-called unary version. For the binary version (subtracting two kernel elements), you would have to change the tensor size of `self.offset` to e.g. `(1,512,1,4)` and call `grid_sample` twice: `x = (F.grid_sample(x, sample_grid + self.offset[:,:,:,:2]) - F.grid_sample(x, sample_grid + self.offset[:,:,:,2:])).view(B,512,-1)`. Also note that in this case the input images should be pre-smoothed (simply some average pooling with `stride=1`) for best performance. The `.view(B,512,-1)` assumes a single-channel input and a grid whose batch size matches `B`: `grid_sample` returns `B x C x 512 x #samples`, i.e. `C*512` responses per location. You can call this function either with random spatial sampling coordinates or with a (coarse) regular grid, e.g.: `identity = torch.eye(3)[:2,:].unsqueeze(0)`; `meshgrid = F.affine_grid(identity, torch.ones(1,1,50,50).size()).view(1,1,-1,2)`. Please let me know if you need more details.
Dear Mattias, I adapted the code of obeliskhybrid_visceral and obelisk_visceral to 2D (see below). Please let me know what you think. It runs without errors, but could there be a logical problem? Thanks.
class HybridObelisk2D(nn.Module):
    """2D port of the hybrid OBELISK + U-Net model (obeliskhybrid_visceral).

    A small U-Net (full -> half -> quarter resolution) is combined with two
    OBELISK layers whose trainable offsets are applied to regular sampling
    grids:

    * Layer 1 has 128 offsets. It samples the 4-channel conv0 features on a
      quarter-resolution grid, and its output enters the decoder at half
      resolution.
    * Layer 2 has 512 offsets. It samples the pre-smoothed input on an
      eighth-resolution grid, and its output enters the decoder at quarter
      resolution.

    Args:
        out_channels: number of output channels (e.g. segmentation classes).
        full_res: (H, W) of the input images. All internal resolutions and
            the sampling grids are derived from it, so the encoder's
            quarter-resolution output must match ``quarter_res``.
    """

    def __init__(self, out_channels, full_res):
        super(HybridObelisk2D, self).__init__()
        self.out_channels = out_channels
        H_in1, W_in1 = int(full_res[0]), int(full_res[1])
        # a 3x3 conv with stride 2 and padding 1 maps n pixels to (n + 1) // 2
        H_in2, W_in2 = (H_in1 + 1) // 2, (W_in1 + 1) // 2  # half resolution
        H_in4, W_in4 = (H_in2 + 1) // 2, (W_in2 + 1) // 2  # quarter resolution
        H_in8, W_in8 = (H_in4 + 1) // 2, (W_in4 + 1) // 2  # eighth resolution
        self.half_res = torch.tensor([H_in2, W_in2], dtype=torch.long)
        self.quarter_res = torch.tensor([H_in4, W_in4], dtype=torch.long)
        self.eighth_res = torch.tensor([H_in8, W_in8], dtype=torch.long)

        # U-Net encoder
        self.conv0 = nn.Conv2d(1, 4, 3, padding=1)
        self.batch0 = nn.BatchNorm2d(4)
        self.conv1 = nn.Conv2d(4, 16, 3, stride=2, padding=1)
        self.batch1 = nn.BatchNorm2d(16)
        self.conv11 = nn.Conv2d(16, 16, 3, padding=1)
        self.batch11 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.batch2 = nn.BatchNorm2d(32)

        # OBELISK encoder (regular sampling grids for simplicity).
        # The first layer has 128 trainable offsets, the second 512.
        # sample_grid: 1 x 1 x #samples x 2; offsets: 1 x #offsets x 1 x 2
        identity = torch.eye(2, 3).unsqueeze(0)
        self.sample_grid1 = F.affine_grid(
            identity, torch.Size((1, 1, H_in4, W_in4)), align_corners=False
        ).view(1, 1, -1, 2).detach()
        self.sample_grid2 = F.affine_grid(
            identity, torch.Size((1, 1, H_in8, W_in8)), align_corners=False
        ).view(1, 1, -1, 2).detach()

        # layer 1: 4 input channels x 128 offsets = 512 sampled channels
        self.offset1 = nn.Parameter(torch.randn(1, 128, 1, 2) * 0.05)
        self.linear1a = nn.Conv2d(4 * 128, 128, 1, groups=4, bias=False)
        self.batch1a = nn.BatchNorm2d(128)
        self.linear1b = nn.Conv2d(128, 32, 1, bias=False)
        self.batch1b = nn.BatchNorm2d(128 + 32)
        self.linear1c = nn.Conv2d(128 + 32, 32, 1, bias=False)
        self.batch1c = nn.BatchNorm2d(128 + 64)
        self.linear1d = nn.Conv2d(128 + 64, 32, 1, bias=False)
        self.batch1d = nn.BatchNorm2d(128 + 96)
        self.linear1e = nn.Conv2d(128 + 96, 32, 1, bias=False)

        # layer 2: 1 input channel x 512 offsets = 512 sampled channels
        self.offset2 = nn.Parameter(torch.randn(1, 512, 1, 2) * 0.05)
        self.linear2a = nn.Conv2d(512, 128, 1, groups=4, bias=False)
        self.batch2a = nn.BatchNorm2d(128)
        self.linear2b = nn.Conv2d(128, 32, 1, bias=False)
        self.batch2b = nn.BatchNorm2d(128 + 32)
        self.linear2c = nn.Conv2d(128 + 32, 32, 1, bias=False)
        self.batch2c = nn.BatchNorm2d(128 + 64)
        self.linear2d = nn.Conv2d(128 + 64, 32, 1, bias=False)
        self.batch2d = nn.BatchNorm2d(128 + 96)
        self.linear2e = nn.Conv2d(128 + 96, 32, 1, bias=False)

        # U-Net decoder
        self.conv6bU = nn.Conv2d(64, 32, 3, padding=1)
        self.batch6bU = nn.BatchNorm2d(32)
        self.conv6U = nn.Conv2d(64 + 16, 32, 3, padding=1)
        self.batch6U = nn.BatchNorm2d(32)
        self.conv8 = nn.Conv2d(32, out_channels, 1)

    @staticmethod
    def _obelisk_sample(img, sample_grid, offset, out_hw):
        """Sample ``img`` at ``sample_grid + offset``; fold offsets into channels.

        Returns B x (C * #offsets) x out_hw[0] x out_hw[1]. The channel index
        is ``c * #offsets + k``.
        """
        B = img.size(0)
        # 1 x #offsets x #samples x 2, broadcast to the batch size
        grid = (sample_grid.to(img.device) + offset).expand(B, -1, -1, -1)
        sampled = F.grid_sample(img, grid, align_corners=False)
        return sampled.reshape(B, -1, out_hw[0], out_hw[1])

    @staticmethod
    def _dense_block(x, growth_layers, final_bn, final_lin):
        """Apply a 1x1 dense-net: concat ReLU(lin(bn(x))) per stage, then project."""
        for bn, lin in growth_layers:
            x = torch.cat((x, F.relu(lin(bn(x)))), dim=1)
        return F.relu(final_lin(final_bn(x)))

    def forward(self, inputImg):
        """Segment a B x 1 x H x W batch; returns B x out_channels x H x W."""
        H, W = inputImg.shape[-2:]
        leakage = 0.05  # leaky-ReLU slope of the conventional CNN path
        half_hw = self.half_res.tolist()
        quarter_hw = self.quarter_res.tolist()
        eighth_hw = self.eighth_res.tolist()

        # U-Net encoder; x00 is the pre-smoothed input used by OBELISK layer 2
        x00 = F.avg_pool2d(inputImg, 3, padding=1, stride=1)
        x1 = F.leaky_relu(self.batch0(self.conv0(inputImg)), leakage)
        x = F.leaky_relu(self.batch1(self.conv1(x1)), leakage)
        x2 = F.leaky_relu(self.batch11(self.conv11(x)), leakage)
        x = F.leaky_relu(self.batch2(self.conv2(x2)), leakage)

        # OBELISK layer 1 + 1x1 dense-net, upsampled to half resolution
        x_o1 = self._obelisk_sample(x1, self.sample_grid1, self.offset1, quarter_hw)
        x_o1 = F.relu(self.linear1a(x_o1))
        x_o1 = self._dense_block(
            x_o1,
            ((self.batch1a, self.linear1b),
             (self.batch1b, self.linear1c),
             (self.batch1c, self.linear1d)),
            self.batch1d, self.linear1e)
        x_o1 = F.interpolate(x_o1, size=half_hw, mode='bilinear', align_corners=False)

        # OBELISK layer 2 + 1x1 dense-net, upsampled to quarter resolution
        x_o2 = self._obelisk_sample(x00, self.sample_grid2, self.offset2, eighth_hw)
        x_o2 = F.relu(self.linear2a(x_o2))
        x_o2 = self._dense_block(
            x_o2,
            ((self.batch2a, self.linear2b),
             (self.batch2b, self.linear2c),
             (self.batch2c, self.linear2d)),
            self.batch2d, self.linear2e)
        x_o2 = F.interpolate(x_o2, size=quarter_hw, mode='bilinear', align_corners=False)

        # U-Net decoder
        x = F.leaky_relu(self.batch6bU(self.conv6bU(torch.cat((x, x_o2), 1))), leakage)
        x = F.interpolate(x, size=half_hw, mode='bilinear', align_corners=False)
        x = F.leaky_relu(self.batch6U(self.conv6U(torch.cat((x, x_o1, x2), 1))), leakage)
        return F.interpolate(self.conv8(x), size=[H, W], mode='bilinear', align_corners=False)
class Obelisk2D(nn.Module):
    """2D port of the binary OBELISK segmentation model (obelisk_visceral).

    1024 trainable *pairs* of 2D spatial offsets are applied to a regular
    quarter-resolution sampling grid. Each feature is the difference of the
    image sampled at the two offsets of a pair. A 1x1 dense-net maps these
    features to ``out_channels`` scores, which are returned at half
    resolution.

    Args:
        out_channels: number of output channels (e.g. segmentation classes).
        full_res: (H, W) of the input images.
    """

    def __init__(self, out_channels, full_res):
        super(Obelisk2D, self).__init__()
        self.out_channels = out_channels
        self.full_res = full_res
        H_in1, W_in1 = int(full_res[0]), int(full_res[1])
        H_in2, W_in2 = (H_in1 + 1) // 2, (W_in1 + 1) // 2  # half resolution
        H_in4, W_in4 = (H_in2 + 1) // 2, (W_in2 + 1) // 2  # quarter resolution
        self.half_res = torch.tensor([H_in2, W_in2], dtype=torch.long)
        self.quarter_res = torch.tensor([H_in4, W_in4], dtype=torch.long)

        # OBELISK layer: regular sampling grid, 1 x H_in4 x W_in4 x 2
        self.sample_grid1 = F.affine_grid(
            torch.eye(2, 3).unsqueeze(0), torch.Size((1, 1, H_in4, W_in4)),
            align_corners=False).detach()
        # Binary variant: each of the 1024 pairs needs two full 2D offsets,
        # stored as (x_a, y_a, x_b, y_b) in the last dimension.
        self.offset1 = nn.Parameter(torch.randn(1, 1024, 1, 4) * 0.05)

        # Dense-Net with 1x1 kernels
        self.LIN1 = nn.Conv2d(1024, 256, 1, bias=False, groups=4)  # grouped convolutions
        self.BN1 = nn.BatchNorm2d(256)
        self.LIN2 = nn.Conv2d(256, 128, 1, bias=False)
        self.BN2 = nn.BatchNorm2d(128)
        self.LIN3a = nn.Conv2d(128, 32, 1, bias=False)
        self.BN3a = nn.BatchNorm2d(128 + 32)
        self.LIN3b = nn.Conv2d(128 + 32, 32, 1, bias=False)
        self.BN3b = nn.BatchNorm2d(128 + 64)
        self.LIN3c = nn.Conv2d(128 + 64, 32, 1, bias=False)
        self.BN3c = nn.BatchNorm2d(128 + 96)
        self.LIN2d = nn.Conv2d(128 + 96, 32, 1, bias=False)
        self.BN2d = nn.BatchNorm2d(256)
        self.LIN4 = nn.Conv2d(256, out_channels, 1)

    def _sample_pair_differences(self, img, sample_grid):
        """Return img(grid + a_k) - img(grid + b_k) for every offset pair k.

        ``sample_grid`` is N x H_grid x W_grid x 2 with N == 1 or N == B. The
        result is B x (C * 1024) x H_grid x W_grid.
        """
        B = img.size(0)
        N, H_grid, W_grid, _ = sample_grid.size()
        grid = sample_grid.reshape(N, 1, H_grid * W_grid, 2)
        grid_a = (grid + self.offset1[..., :2]).expand(B, -1, -1, -1)
        grid_b = (grid + self.offset1[..., 2:]).expand(B, -1, -1, -1)
        diff = (F.grid_sample(img, grid_a, align_corners=False)
                - F.grid_sample(img, grid_b, align_corners=False))
        return diff.reshape(B, -1, H_grid, W_grid)

    def forward(self, inputImg, sample_grid=None):
        """Predict B x out_channels x half_res scores for a B x 1 x H x W batch.

        Args:
            inputImg: input image batch. The binary variant works best on a
                pre-smoothed image, e.g. F.avg_pool2d(img, 3, padding=1,
                stride=1); this smoothing must be done by the caller.
            sample_grid: optional 1 x H_grid x W_grid x 2 sampling grid.
                Defaults to the regular quarter-resolution grid.
        """
        if sample_grid is None:
            sample_grid = self.sample_grid1
        sample_grid = sample_grid.to(inputImg.device)
        features = self._sample_pair_differences(inputImg, sample_grid)

        x1 = F.relu(self.BN1(self.LIN1(features)))
        x2 = self.BN2(self.LIN2(x1))
        x3a = torch.cat((x2, F.relu(self.LIN3a(x2))), dim=1)
        x3b = torch.cat((x3a, F.relu(self.LIN3b(self.BN3a(x3a)))), dim=1)
        x3c = torch.cat((x3b, F.relu(self.LIN3c(self.BN3b(x3b)))), dim=1)
        x2d = torch.cat((x3c, F.relu(self.LIN2d(self.BN3c(x3c)))), dim=1)
        x4 = self.LIN4(self.BN2d(x2d))
        # return half-resolution segmentation/prediction
        return F.interpolate(x4, size=self.half_res.tolist(), mode='bilinear',
                             align_corners=False)
I am not able to implement the 2D version. I am using the sample code on the first page, but the number of output channels after calling `grid_sample` does not match the number of spatial filter offsets. It would be very helpful if you could add a very simple 2D implementation to the models.py file.