Open CWanli opened 3 years ago
期待作者早日开源 (Looking forward to the author open-sourcing the code soon.)
This is my version of TGM and TRM, but it does not work on the segmentation task. I hope you can tell me what went wrong.
import torch
import torch.nn as nn
import torch.nn.functional as F
class LowRankModule(nn.Module):
    """Low-rank tensor attention built from summed rank-1 (C, H, W) factors.

    For each of ``rank`` components the input is average-pooled into a
    channel vector ``[B, C, 1, 1]``, a height vector ``[B, 1, H, 1]`` and a
    width vector ``[B, 1, 1, W]``. Each vector goes through its own 1x1
    convolution and a sigmoid; the three results are combined into a rank-1
    outer product ``[B, C, H, W]``. The components are weighted by the
    learnable vector ``lamd``, summed into an attention map, and the input is
    multiplied by that map element-wise.

    Args:
        in_channels: Number of input channels ``C``.
        rank: Number of rank-1 components; ``None`` means ``in_channels``.

    NOTE(review): ``conv_h`` / ``conv_w`` are ``Conv2d(1, 1, 1)`` applied to
    single-channel tensors, so each reduces to ``a * v + b`` with two scalars
    and cannot mix information across positions. A design that treats the
    H / W axes as conv channels would need a fixed feature-map size -- TODO
    confirm against the reference implementation.
    """

    def __init__(self, in_channels, rank=8):
        super().__init__()
        self.in_channels = in_channels
        self.rank = in_channels if rank is None else rank
        # Must stay a registered nn.Parameter. The former
        # ``nn.Parameter(...).cuda()`` produced a plain non-leaf tensor, so
        # lamd was never trained, never moved by .to()/.double(), missing
        # from state_dict, and the constructor failed without CUDA.
        self.lamd = nn.Parameter(torch.rand(self.rank))
        conv_h, conv_w, conv_c = [], [], []
        # Interleaved creation keeps the weight-initialisation RNG order
        # identical to the original implementation.
        for _ in range(self.rank):
            conv_h.append(nn.Conv2d(1, 1, 1))
            conv_w.append(nn.Conv2d(1, 1, 1))
            conv_c.append(nn.Conv2d(in_channels, in_channels, 1))
        self.conv_c = nn.ModuleList(conv_c)
        self.conv_h = nn.ModuleList(conv_h)
        self.conv_w = nn.ModuleList(conv_w)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the low-rank attention map.

        Args:
            x: 4-D tensor ``[B, C, H, W]`` with ``C == in_channels``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D [B, C, H, W] tensor, got {tuple(x.shape)}")
        # Average over the two axes not being kept. This matches the old
        # adaptive_avg_pool3d calls but states the reduced axes explicitly.
        x_cp = x.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
        x_hp = x.mean(dim=(1, 3), keepdim=True)  # [B, 1, H, 1]
        x_wp = x.mean(dim=(1, 2), keepdim=True)  # [B, 1, 1, W]
        attn_w = torch.zeros_like(x)
        for r in range(self.rank):
            x_c = torch.sigmoid(self.conv_c[r](x_cp))
            x_h = torch.sigmoid(self.conv_h[r](x_hp))
            x_w = torch.sigmoid(self.conv_w[r](x_wp))
            # Broadcasting [B,C,1,1] * [B,1,H,1] * [B,1,1,W] yields the
            # rank-1 outer product [B,C,H,W] without bmm/view round-trips.
            attn_w += self.lamd[r] * (x_c * x_h * x_w)
        return x * attn_w
This is my version of TGM and TRM, but it does not work on the segmentation task. I hope you can tell me what went wrong.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LowRankModule(nn.Module):
    """Low-rank tensor attention built from summed rank-1 (C, H, W) factors.

    For each of ``rank`` components the input is average-pooled into a
    channel vector ``[B, C, 1, 1]``, a height vector ``[B, 1, H, 1]`` and a
    width vector ``[B, 1, 1, W]``. Each vector goes through its own 1x1
    convolution and a sigmoid; the three results are combined into a rank-1
    outer product ``[B, C, H, W]``. The components are weighted by the
    learnable vector ``lamd``, summed into an attention map, and the input is
    multiplied by that map element-wise.

    Args:
        in_channels: Number of input channels ``C``.
        rank: Number of rank-1 components; ``None`` means ``in_channels``.

    NOTE(review): ``conv_h`` / ``conv_w`` are ``Conv2d(1, 1, 1)`` applied to
    single-channel tensors, so each reduces to ``a * v + b`` with two scalars
    and cannot mix information across positions. A design that treats the
    H / W axes as conv channels would need a fixed feature-map size -- TODO
    confirm against the reference implementation.
    """

    def __init__(self, in_channels, rank=8):
        super().__init__()
        self.in_channels = in_channels
        self.rank = in_channels if rank is None else rank
        # Must stay a registered nn.Parameter. The former
        # ``nn.Parameter(...).cuda()`` produced a plain non-leaf tensor, so
        # lamd was never trained, never moved by .to()/.double(), missing
        # from state_dict, and the constructor failed without CUDA.
        self.lamd = nn.Parameter(torch.rand(self.rank))
        conv_h, conv_w, conv_c = [], [], []
        # Interleaved creation keeps the weight-initialisation RNG order
        # identical to the original implementation.
        for _ in range(self.rank):
            conv_h.append(nn.Conv2d(1, 1, 1))
            conv_w.append(nn.Conv2d(1, 1, 1))
            conv_c.append(nn.Conv2d(in_channels, in_channels, 1))
        self.conv_c = nn.ModuleList(conv_c)
        self.conv_h = nn.ModuleList(conv_h)
        self.conv_w = nn.ModuleList(conv_w)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the low-rank attention map.

        Args:
            x: 4-D tensor ``[B, C, H, W]`` with ``C == in_channels``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D [B, C, H, W] tensor, got {tuple(x.shape)}")
        # Average over the two axes not being kept. This matches the old
        # adaptive_avg_pool3d calls but states the reduced axes explicitly.
        x_cp = x.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
        x_hp = x.mean(dim=(1, 3), keepdim=True)  # [B, 1, H, 1]
        x_wp = x.mean(dim=(1, 2), keepdim=True)  # [B, 1, 1, W]
        attn_w = torch.zeros_like(x)
        for r in range(self.rank):
            x_c = torch.sigmoid(self.conv_c[r](x_cp))
            x_h = torch.sigmoid(self.conv_h[r](x_hp))
            x_w = torch.sigmoid(self.conv_w[r](x_wp))
            # Broadcasting [B,C,1,1] * [B,1,H,1] * [B,1,1,W] yields the
            # rank-1 outer product [B,C,H,W] without bmm/view round-trips.
            attn_w += self.lamd[r] * (x_c * x_h * x_w)
        return x * attn_w
Very close. I just uploaded my version.
So the author's code requires the feature map to have a fixed resolution of 512*8*8, whereas the convolutions in Colorblank's version are 1*1 — is that the difference from the author's implementation?
统一回复一下: 我最近沉迷股市,然后被市场教育了,现在打算回归初心,打工赚钱。代码很快就会开源,同时我也征求到了DMNet和APCNet作者Junjun He的同意,届时也会上传他的代码。各种细节也可以email我直接获取。
Dear all: I got addicted to the stock market and lost a lot of money. Now I am coming back to research and coding. The source code will be released soon. Additionally, Junjun He (the first author of DMNet and APCNet) has approved releasing his code as well. If you have any questions, feel free to contact me.