pswena opened this issue 2 years ago
Hi,
Thank you for your interest in our work. You are correct: the output of the second refinement stage that regresses the key-points has 20 channels rather than 21. The reason is that we are only interested in the 20 key-points and not the background. After we published our work, we noticed that skipping the refinement of the background pixels (the 21st channel) gives slightly better performance with reduced compute.
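For readers who want to try this, here is a minimal sketch of skipping the background channel during refinement. The function name, the plain MSE loss, and the assumption that the background is the last of the 21 channels are placeholders for illustration, not the repo's actual implementation:

```python
import torch
import torch.nn.functional as F

def keypoint_refinement_loss(pred, target, n_keypoints=20):
    """Placeholder loss that ignores the background channel.

    pred, target: [B, 21, 56, 56]; assumes the background is the last channel
    and uses a plain MSE in place of the actual refinement loss.
    """
    pred_kp = pred[:, :n_keypoints]      # keep only the 20 key-point channels
    target_kp = target[:, :n_keypoints]  # drop the background channel from the target too
    return F.mse_loss(pred_kp, target_kp)

# The loss is computed on 20 channels even though the network emits 21.
loss = keypoint_refinement_loss(torch.randn(2, 21, 56, 56), torch.rand(2, 21, 56, 56))
```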
In your paper, stage 2 returns `kp`, whose shape should be `[B, 20, 56, 56]`, but in your code the shape is `[B, 21, 56, 56]`. In your code, `kp = self.L5(self.res2(x) + self.L4(joint))`, but `self.L5` outputs 21 channels, not 20.
```python
class FineRegressor(nn.Module):
    """ Key-Point Refinement Network """

    def __init__(self, n2=20):
        super(FineRegressor, self).__init__()
        self.N2 = n2
        self.Normalize = nn.Softmax(dim=2)
        self.MaxPool = nn.MaxPool2d(2, 2)
        self.L1 = nn.Sequential(nn.Conv2d(24, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
        self.HR1 = nn.Sequential(nn.Conv2d(64, 64, 7), nn.BatchNorm2d(64), nn.ReLU(),
                                 nn.Conv2d(64, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
                                 nn.Conv2d(128, 256, 1), nn.BatchNorm2d(256), nn.ReLU(),
                                 nn.ConvTranspose2d(256, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
                                 nn.ConvTranspose2d(128, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
        self.res1 = nn.Sequential(nn.Conv2d(64, 64, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.L2 = nn.Sequential(nn.ConvTranspose2d(64, self.N2 + 1, 7), nn.BatchNorm2d(self.N2 + 1), nn.ReLU(),
                                nn.Conv2d(self.N2 + 1, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
        self.L3 = nn.Sequential(nn.Conv2d(64, 64, 7), nn.BatchNorm2d(64), nn.ReLU(),
                                nn.Conv2d(64, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
                                nn.Conv2d(128, 256, 1), nn.BatchNorm2d(256), nn.ReLU())
        self.res2 = nn.Sequential(nn.Conv2d(64, 64, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.L4 = nn.Sequential(nn.ConvTranspose2d(256, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
                                nn.ConvTranspose2d(128, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
        # self.L5 outputs self.N2 + 1 = 21 channels, i.e. the 20 key-points plus background
        self.L5 = nn.Sequential(nn.ConvTranspose2d(64, self.N2 + 1, 7), nn.BatchNorm2d(self.N2 + 1), nn.ReLU())
        self.pose_branch1 = nn.Sequential(nn.Conv2d(256, 128, 7), nn.BatchNorm2d(128), nn.ReLU(),
                                          nn.Conv2d(128, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
        self.pose_branch2 = nn.Sequential(nn.Conv2d(64, 32, 7), nn.BatchNorm2d(32), nn.ReLU())
        self.FC = nn.Sequential(nn.Linear(2048, 256, bias=True), nn.Dropout(0.5), nn.Linear(256, 8, bias=True))
```
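A quick way to confirm this is to push a dummy tensor through a layer shaped like `self.L5`. The 50x50 input resolution below is an assumption chosen only so the 7x7 transposed convolution yields a 56x56 output; it is not taken from the repo:

```python
import torch
import torch.nn as nn

# Stand-in for self.L5 with N2 = 20 (so N2 + 1 = 21 output channels).
L5 = nn.Sequential(nn.ConvTranspose2d(64, 21, 7), nn.BatchNorm2d(21), nn.ReLU())

x = torch.randn(2, 64, 50, 50)  # assumed input size: [B, 64, 50, 50]
print(L5(x).shape)              # torch.Size([2, 21, 56, 56]): 21 channels, not 20
```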