princeton-vl / pytorch_stacked_hourglass

Pytorch implementation of the ECCV 2016 paper "Stacked Hourglass Networks for Human Pose Estimation"
BSD 3-Clause "New" or "Revised" License

A question about the loss calculation function #11

Closed. xdTin closed this issue 4 years ago

xdTin commented 4 years ago

Hello~ Thanks for your work; it helps me understand the paper. But I have a question about how the loss is calculated in the code. In posenet.py, the calc_loss function in the PoseNet class confuses me:

def calc_loss(self, combined_hm_preds, heatmaps):
    combined_loss = []
    for i in range(self.nstack):
        combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))
    combined_loss = torch.stack(combined_loss, dim=1)
    return combined_loss

In particular, I cannot understand the following lines:

for i in range(self.nstack):
    combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))

I don't understand why combined_hm_preds[0][:,i] is used to compute the loss. Shouldn't combined_hm_preds be a list containing eight predictions (one per stack)? Why not combined_hm_preds[i]?

Thank you very much in advance for any answer to this question.

crockwell commented 4 years ago

The line is correct as written, although it can indeed be a bit confusing. Line 60 of posenet.py first stacks the predictions along axis 1, and then we wrap that tensor in a list in task/pose.py. The reason we format combined_hm_preds as a list is for use later in task/pose.py (if I recall correctly, to keep the formatting consistent with the loss), although there could certainly be other correct implementations as well.
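
To make the shapes concrete, here is a minimal sketch (assuming nstack=8 and 16 output channels as in the paper; the variable names are illustrative, not copied from the repo):

import torch

batch, nstack, oup_dim = 2, 8, 16
# posenet.py stacks the per-hourglass heatmap predictions along dim 1
preds = torch.randn(batch, nstack, oup_dim, 64, 64)
# task/pose.py then wraps the stacked tensor in a list
combined_hm_preds = [preds]
# combined_hm_preds[0] is the full (batch, nstack, channels, H, W) tensor,
# and [:, i] selects the i-th hourglass prediction for every sample in the batch
for i in range(nstack):
    print(combined_hm_preds[0][:, i].shape)  # torch.Size([2, 16, 64, 64])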

BriFuture commented 4 years ago

Hi @crockwell, line 60 of posenet.py indicates that the shape of the PoseNet output tensor should be (Batch, StackSize, Channels, Height, Width). Am I right?

If the first dim stands for the batch size, how can combined_hm_preds[0][:,i] use all inputs from the batch?

Assume that the batch size is 2, nstack is 3, and out_channels is 16:

>>> import torch
>>> pred = torch.randn((2, 3, 16, 64, 64))
>>> pred[0].shape
torch.Size([3, 16, 64, 64])
>>> pred[0][:, 0].shape 
torch.Size([3, 64, 64])
>>> for i in range(3):
...   print(pred[0][:, i].shape)
... 
torch.Size([3, 64, 64])
torch.Size([3, 64, 64])
torch.Size([3, 64, 64])

Thank you.

BriFuture commented 4 years ago

Oh, I just found that the following code wraps the prediction tensor in a list, which effectively adds an extra leading dimension:

if type(combined_hm_preds)!=list and type(combined_hm_preds)!=tuple:
    combined_hm_preds = [combined_hm_preds]
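
So, continuing my example above, once the tensor is wrapped in a list the indexing keeps the whole batch (a quick check, assuming the same 2/3/16 shapes):

>>> combined_hm_preds = [pred]
>>> combined_hm_preds[0].shape
torch.Size([2, 3, 16, 64, 64])
>>> combined_hm_preds[0][:, 0].shape
torch.Size([2, 16, 64, 64])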