omasaht / headpose-fsanet-pytorch

Pytorch implementation of FSA-Net: Learning Fine-Grained Structure Aggregation for Head Pose Estimation from a Single Image
MIT License

Maybe a bug #4

Closed tiandunx closed 3 years ago

tiandunx commented 3 years ago

Thanks for this excellent work! It really helps me a lot. But I think there is a bug at line 237 in model.py. Your implementation is U1 = U1.view(-1, uw*uh, ch), U2 = U2.view(-1, uw*uh, ch), but U1's actual shape is (batch_size, channel, h, w). If you simply view it as (batch, w*h, c), the last dimension does not hold the 3 channel values of a pixel but consecutive pixels along the width dimension.

omasaht commented 3 years ago

Hey tiandunx, it's not a bug. See my explanation below:

U1, U2 and U3 have shape (batch_size, channel, h, w). I reshape each of them to (batch_size, h*w, channel). I concatenate U1, U2 and U3 along dim 1 to get U, whose shape is (batch_size, h*w*3, channel). Then I do a matrix multiplication with S1, S2 and S3, whose shape is (batch_size, n_hat, n) where n = h*w*3 and n_hat = 7. This gives me Ubar_1, Ubar_2 and Ubar_3, whose shape is (batch_size, n_hat, channel).

Please note that channel here does not refer to image channels (i.e. RGB/3). These are the output channels we get after the final convolution in the MultiStreamMultiStage module, which is 64 in this case.
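
To make the shape bookkeeping concrete, here is a minimal sketch of the flow I just described. The sizes (batch of 2, 64 channels, 8x8 feature maps, n_hat = 7) and the random S1 are only illustrative stand-ins, and the reshape shown is the current one that the rest of this thread is about:

import torch

batch, ch, h, w, n_hat = 2, 64, 8, 8, 7
U1 = torch.randn(batch, ch, h, w)          # stream features, (batch, ch, h, w)
U2 = torch.randn(batch, ch, h, w)
U3 = torch.randn(batch, ch, h, w)

# reshape each stream to (batch, h*w, ch), then concatenate along dim 1
U1r = U1.view(batch, h * w, ch)
U2r = U2.view(batch, h * w, ch)
U3r = U3.view(batch, h * w, ch)
U = torch.cat((U1r, U2r, U3r), dim=1)      # (batch, h*w*3, ch)

S1 = torch.randn(batch, n_hat, h * w * 3)  # (batch, n_hat, n) with n = h*w*3
Ubar_1 = torch.matmul(S1, U)               # (batch, n_hat, ch)
print(Ubar_1.shape)                        # torch.Size([2, 7, 64])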

tiandunx commented 3 years ago

Thanks for your explanation. For sure we should do the matrix multiplication along the channel dimension, as you have done here. But what worries me is that since U1 (and U2, U3) has shape (batch_size, channel, h, w), its physical memory layout is width-last, i.e. (width_pixel_0, width_pixel_1, width_pixel_2, ...). In memory, two consecutive numbers are neighboring width pixels, not channels. Even if you reshape it as (batch_size, h*w, C), you do not change the physical memory layout. Let me show you an example.

[screenshot of the example tensor omitted] x has shape 2x2x3, where the first 2 is the number of channels, height = 2 and width = 3. If we simply view it as (height * width, channels), then the last two consecutive numbers are not the channel values of one pixel but neighboring width pixels. Pairs like (1, 7) are the expected behavior.
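
Since the screenshot is not visible here, the same point can be shown with a small made-up tensor; assuming values like torch.arange(12).view(2, 2, 3) (the original numbers may differ), the mismatch looks like this:

import torch

x = torch.arange(12).view(2, 2, 3)   # (channels=2, height=2, width=3)
# channel 0: [[0, 1, 2], [3, 4, 5]]
# channel 1: [[6, 7, 8], [9, 10, 11]]

bad = x.view(2 * 3, 2)               # naive view as (h*w, channels)
print(bad[1])                        # tensor([2, 3]) -> two width pixels, not a channel pair

good = x.view(2, 2 * 3).t()          # transpose so rows really are channel pairs
print(good[1])                       # tensor([1, 7]) -> channel values of pixel (h=0, w=1)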

omasaht commented 3 years ago

Hey tiandunx, thank you for explaining the bug to me. I understand what you mean now! I will make some time and update the repository to fix this bug, and I will also look out for similar bugs elsewhere in the code. I believe the right fix would be: U1 = torch.transpose(U1.view(-1, ch, h*w), 1, 2). Do you agree?
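
As a quick sanity check with made-up sizes (not the actual model tensors), the transposed view agrees with an explicit permute, so each row really is the channel vector of one spatial position:

import torch

batch, ch, h, w = 2, 64, 8, 8
U1 = torch.randn(batch, ch, h, w)

fixed = torch.transpose(U1.view(-1, ch, h * w), 1, 2)   # (batch, h*w, ch)
reference = U1.flatten(2).permute(0, 2, 1)              # same result, spelled differently
print(torch.equal(fixed, reference))                    # True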

tiandunx commented 3 years ago

Agreed! In fact, you can simplify the code. Here is how I reimplemented it:

U1 = U1.view(batch_size, ch, -1)
U2 = U2.view(batch_size, ch, -1)
U3 = U3.view(batch_size, ch, -1)
U = torch.cat((U1, U2, U3), dim=2)
U_bar1 = torch.bmm(S1, torch.transpose(U, 1, 2))

tiandunx commented 3 years ago

Here is another question that really puzzles me. When I use your original code, everything works fine, but when I fix this potential bug, PyTorch throws the following warning message:

Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

It seems there is a channels_last operator somewhere, but when I change the code back to your current implementation, the warning message is gone. I cannot find a workaround. Do you have any idea?
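
For reference, one minimal way to inspect which memory format a tensor is in (the sizes here are arbitrary, just for illustration):

import torch

x = torch.randn(2, 64, 8, 8)
print(x.is_contiguous())                                    # True  (default NCHW layout)
print(x.is_contiguous(memory_format=torch.channels_last))   # False

y = x.to(memory_format=torch.channels_last)                 # same values, NHWC strides
print(y.is_contiguous(memory_format=torch.channels_last))   # True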

omasaht commented 3 years ago

Where did you encounter this warning? I have now made the fix and retrained the model using the 'Train Model' notebook, and I did not get any warning. My code is:

U1 = U1.view(-1, ch, uh*uw)          # (batch, ch, uh*uw)
U2 = U2.view(-1, ch, uh*uw)
U3 = U3.view(-1, ch, uh*uw)

U = torch.cat((U1, U2, U3), dim=2)   # (batch, ch, 3*uh*uw)

U = torch.transpose(U, 1, 2)         # (batch, 3*uh*uw, ch)
Ubar_1 = torch.matmul(S1, U)         # (batch, n_hat, ch)
...

PyTorch version: 1.5.1

omasaht commented 3 years ago

I just looked up this warning. It seems that when you take the transpose, you make the U tensor non-contiguous, i.e. its physical memory layout stays the same but the order of its strides changes. S1 is contiguous, so when you call the batch matrix multiply operator you are passing PyTorch one contiguous and one non-contiguous tensor, and I believe PyTorch operators by default expect contiguous tensors. To remove the warning, you can do U = torch.transpose(U, 1, 2).contiguous(). This copies the tensor and reorders it in physical memory so that it is contiguous again. That should remove the warning you are seeing; let me know if it works.
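
Here is a small standalone check of that explanation (the tensor sizes are arbitrary, just for illustration):

import torch

U = torch.randn(2, 64, 192)      # (batch, ch, 3*h*w), made-up sizes
Ut = torch.transpose(U, 1, 2)    # same storage, swapped strides
print(Ut.is_contiguous())        # False

Ut = Ut.contiguous()             # copies and re-lays out the data
print(Ut.is_contiguous())        # True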

tiandunx commented 3 years ago

Well, it works fine for me. Also, when I changed my PyTorch from 1.6.0 to 1.7.0, the warning message went away, so I guess it is related to that specific version. Thank you. I'll spend some time checking whether any improvement can be gained after fixing this bug.