sniklaus / pytorch-pwc

a reimplementation of PWC-Net in PyTorch that matches the official Caffe version
GNU General Public License v3.0

Just want to confirm if my analysis of your commit 5f4d7de is correct #53

Closed Etienne66 closed 2 years ago

Etienne66 commented 2 years ago

In commit 5f4d7def149c71ee1610c527e5502264e847c940, in correlation.py, you moved the save_for_backward call to the end of the forward definition of the class _FunctionCorrelation, just before the return. I am guessing that is because rbot0 and rbot1 contain only zeros at the point in the code where save_for_backward was called before. I haven't finished training my model, which uses your model as well, but training is definitely taking longer and using the CPU a lot more. I'm guessing that is because backward finally has some data for rbot0 and rbot1.

Not sure why you don't like putting comments in your code or in your commits explaining why something was changed, but a lot of PhD students seem to be that way. I'm a programmer and I get chewed out by my boss if I do that. Oh well... I really do appreciate you improving your code, so I can't complain too much :smile:

sniklaus commented 2 years ago

The commit doesn't actually change the behavior, it just makes the code easier to read. rbot0 and rbot1 are modified in place, so even though the previous version called save_for_backward before populating rbot0 and rbot1, the saved tensors weren't just zeros. You can verify this using gradcheck, for example. That is, change all float to double in correlation.py and run the following code. Then change the location of save_for_backward and try again. The gradcheck will pass both times. As for why I don't use comments, my philosophy is that if your code requires comments then the code itself isn't written well enough (just my opinion though, I am well aware that the readability could still be improved with comments).

```python
import torch
import sys; sys.path.insert(0, './correlation'); import correlation

class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
    # end

    def forward(self, tenOne, tenTwo):
        return correlation.FunctionCorrelation(tenOne=tenOne, tenTwo=tenTwo)
    # end
# end

netTest = Network().cuda()

tenOne = torch.randn(2, 4, 10, 10).double().cuda().requires_grad_()
tenTwo = torch.randn(2, 4, 10, 10).double().cuda().requires_grad_()

torch.autograd.gradcheck(func=netTest, inputs=tuple([tenOne, tenTwo]))
```

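
For readers without a GPU, the same experiment can be sketched on the CPU with a toy function. This is only an illustration under stated assumptions: `ToyProduct` and its `save_early` flag are hypothetical and not part of this repository, and a plain `copy_` stands in for the CUDA rearrange kernel that fills rbot0 and rbot1 in place.

```python
import torch


class ToyProduct(torch.autograd.Function):
    """Hypothetical stand-in: computes one * two through buffers filled in place."""

    @staticmethod
    def forward(ctx, one, two, save_early):
        rbot0 = one.new_zeros(one.shape)
        rbot1 = two.new_zeros(two.shape)

        if save_early:
            ctx.save_for_backward(rbot0, rbot1)  # buffers still hold zeros here

        rbot0.copy_(one)  # in-place fill, standing in for the rearrange kernel
        rbot1.copy_(two)
        out = rbot0 * rbot1

        if not save_early:
            ctx.save_for_backward(rbot0, rbot1)  # buffers already populated

        return out

    @staticmethod
    def backward(ctx, grad_out):
        rbot0, rbot1 = ctx.saved_tensors
        # d(one * two)/d(one) = two and d(one * two)/d(two) = one;
        # the boolean flag receives no gradient.
        return grad_out * rbot1, grad_out * rbot0, None


one = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
two = torch.randn(2, 3, dtype=torch.double, requires_grad=True)

ok_early = torch.autograd.gradcheck(lambda a, b: ToyProduct.apply(a, b, True), (one, two))
ok_late = torch.autograd.gradcheck(lambda a, b: ToyProduct.apply(a, b, False), (one, two))
print(ok_early, ok_late)
```

If the early save captured a snapshot of zeros, the analytical gradient in the `save_early` case would be zero and gradcheck would reject it.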
Etienne66 commented 2 years ago

I'm not quite buying that theory that save_for_backward lets you modify the values afterwards. The reason I say this is that my epochs were taking 29 hours but after this change they are now taking 30 hours. Something is taking longer to calculate and this was the only change I made.

sniklaus commented 2 years ago

> I'm not quite buying that theory that save_for_backward lets you modify the values afterwards.

If you don't believe me then I encourage you to give it a try yourself. Print the average of rbot0 in the backward function and play around with the implementation. rbot0 (and rbot1) will be nonzero regardless of where save_for_backward is called, but if you instead remove the two calls to kernel_Correlation_rearrange then rbot0 (and rbot1) will be zero.
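
The suggested experiment could look roughly like the sketch below. It is a CPU-only illustration: `ToySquare`, the `variant` argument, and the `seen_means` dictionary are hypothetical names, and `copy_` again stands in for kernel_Correlation_rearrange.

```python
import torch

seen_means = {}  # hypothetical log: variant name -> mean of the buffer seen in backward


class ToySquare(torch.autograd.Function):
    """Hypothetical stand-in: computes x * x and saves a buffer holding x."""

    @staticmethod
    def forward(ctx, x, variant):
        buf = x.new_zeros(x.shape)
        ctx.variant = variant

        if variant == "save_first":
            ctx.save_for_backward(buf)  # called while buf is still all zeros

        if variant != "never_filled":
            buf.copy_(x)  # in-place fill, standing in for the rearrange kernel

        if variant != "save_first":
            ctx.save_for_backward(buf)

        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (buf,) = ctx.saved_tensors
        seen_means[ctx.variant] = buf.mean().item()  # the "print the average" step
        return grad_out * 2 * buf, None  # the variant string receives no gradient


for variant in ("save_first", "save_last", "never_filled"):
    x = torch.full((4,), 3.0, requires_grad=True)
    ToySquare.apply(x, variant).sum().backward()

print(seen_means)
```

Only the variant that skips the in-place fill should report a mean of zero; the two save placements should report the same nonzero mean.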

Etienne66 commented 2 years ago

It isn't so much that I don't believe you. I know you have more experience with this than I do. I just don't understand why that change would be more computationally expensive if the value could be modified after save_for_backward. Also, the default for new_zeros is requires_grad = False, and that is what is declared right before the original save_for_backward, so a gradcheck on that tensor wouldn't show anything. Besides, I always thought in-place changes applied to the variables being passed in, not the ones being passed out. It seems like save_for_backward would retain the value at the time it was called, but I can't find a lot of documentation on it.
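
Both points can be probed directly. The sketch below assumes default autograd settings (no saved-tensor hooks); `ToyIdentity` and the `observed` dictionary are hypothetical names. It records whether the buffer requires grad, and whether the tensor handed back in backward uses the same memory as the buffer that was filled after the save call.

```python
import torch

observed = {}  # hypothetical log of what forward and backward see


class ToyIdentity(torch.autograd.Function):
    """Hypothetical stand-in: returns a copy of x and saves a scratch buffer."""

    @staticmethod
    def forward(ctx, x):
        buf = x.new_zeros(x.shape)
        observed["buf_requires_grad"] = buf.requires_grad
        observed["ptr_in_forward"] = buf.data_ptr()
        ctx.save_for_backward(buf)  # buf is all zeros at this point
        buf.copy_(x)                # filled in place afterwards
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        (buf,) = ctx.saved_tensors
        observed["ptr_in_backward"] = buf.data_ptr()
        observed["saved_values"] = buf.tolist()
        return grad_out  # identity: d(x)/dx = 1


x = torch.tensor([1.0, 2.0], requires_grad=True)
ToyIdentity.apply(x).sum().backward()
print(observed)
```

The buffer not requiring grad does not matter for gradcheck, which compares numerical and analytical gradients with respect to the inputs; the buffer only needs to hold the right values when backward reads it.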

Etienne66 commented 2 years ago

The one thing I do see is that mark_dirty must be used to mark any input that is modified in place by the forward function. It would make more sense if that were supposed to be used on rbot0 and rbot1. I didn't really understand the use of mark_dirty and I could not find a single example.
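
For reference, a minimal mark_dirty sketch is shown below. `AddOneInplace` is a hypothetical example, not code from this repository. As the quoted description says, mark_dirty concerns tensors passed into forward as inputs, whereas rbot0 and rbot1 are allocated inside forward, so it would not seem to apply to them.

```python
import torch


class AddOneInplace(torch.autograd.Function):
    """Hypothetical example: adds 1 to its input in place."""

    @staticmethod
    def forward(ctx, x):
        x.add_(1)          # modifies the input tensor itself
        ctx.mark_dirty(x)  # tell autograd that this input was changed in place
        return x

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out    # d(x + 1)/dx = 1


a = torch.tensor([1.0, 2.0], requires_grad=True)
x = a.clone()  # in-place ops are not allowed on a leaf tensor that requires grad
y = AddOneInplace.apply(x)
y.sum().backward()
print(x, a.grad)
```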

sniklaus commented 2 years ago

I am afraid that I don't know what happens in PyTorch internally. Both placements of save_for_backward produce the same result though, so if one is faster than the other for you, there is nothing that should hold you back from using the faster one. :slightly_smiling_face:

Etienne66 commented 2 years ago

Well, as long as you are sure. I still wonder why the computations are taking longer though. It certainly seemed like it made a big difference in my total loss as well. I'm a PL/SQL developer in my professional life and I'm still very new to Python and PyTorch. I appreciate all of the information @sniklaus