sniklaus / pytorch-pwc

a reimplementation of PWC-Net in PyTorch that matches the official Caffe version
GNU General Public License v3.0

How can I add backward support for the correlation package? #10

Closed: Lotayou closed this issue 5 years ago

Lotayou commented 5 years ago

Hey, I'm running some experiments on video generation, and I would like to use the L1 loss between the optical flow of the real video and that of the fake video as a temporal constraint. This means that I need to back-propagate the gradients of the flow with respect to the fake video frames.
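A minimal sketch of the loss side of this setup, assuming both flow fields have been flattened into plain lists of floats (the helper names are hypothetical). The L1 loss itself only hands the flow network a gradient of sign(fake - real) / n per element; carrying that gradient further back to the fake frames is what requires every layer of the flow network, the correlation layer included, to implement backward. The real-video flow is treated as a constant target, since the real frames are not being optimized.

```python
def l1_flow_loss(flow_fake, flow_real):
    """Mean absolute difference between two flattened flow fields."""
    n = len(flow_fake)
    return sum(abs(a - b) for a, b in zip(flow_fake, flow_real)) / n


def l1_flow_loss_grad(flow_fake, flow_real):
    """Subgradient of l1_flow_loss w.r.t. flow_fake: sign(fake - real) / n.

    flow_real is treated as a constant target, so it receives no gradient.
    Where fake == real the L1 loss is not differentiable; 0 is used there.
    """
    n = len(flow_fake)
    return [((a > b) - (a < b)) / n for a, b in zip(flow_fake, flow_real)]
```

In PyTorch, autograd derives the same per-element gradient from `(flow_fake - flow_real).abs().mean()`; the sketch only shows what arrives at the output of the flow network before it has to be propagated through the network's layers.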

However, I found that the backward function in your correlation package is not implemented:

    def backward(self, gradOutput):
        first, second = self.saved_tensors

        assert gradOutput.is_contiguous()

        # allocate zero-filled gradients only for inputs that require them
        gradFirst = first.new_zeros(first.size()) if self.needs_input_grad[0] else None
        gradSecond = second.new_zeros(second.size()) if self.needs_input_grad[1] else None

        if first.is_cuda:
            raise NotImplementedError()  # no CUDA backward kernel yet

        else:
            raise NotImplementedError()  # no CPU backward either

        # end

        return gradFirst, gradSecond
    # end
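To make concrete what the missing backward has to compute, here is a slow pure-Python reference. It is a sketch, not the repository's kernel, and it assumes a single image (no batch dimension), stride 1, zero padding, and normalization by the channel count C. Writing out[k][y][x] = (1/C) Σ_c first[c][y][x] · second[c][y+dy][x+dx], the gradient with respect to `first` gathers `second` at the shifted position, while the gradient with respect to `second` is scattered back to the shifted position. All function names are hypothetical.

```python
# Hypothetical pure-Python reference for a correlation layer (stride 1,
# zero padding, normalized by channel count). Tensors are nested lists
# indexed [channel][y][x]; the batch dimension is omitted for clarity.


def _offsets(max_disp):
    # Displacements ordered with dy in the outer loop and dx in the inner loop.
    return [(dy, dx)
            for dy in range(-max_disp, max_disp + 1)
            for dx in range(-max_disp, max_disp + 1)]


def _zeros(channels, height, width):
    return [[[0.0] * width for _ in range(height)] for _ in range(channels)]


def correlation_forward(first, second, max_disp):
    """out[k][y][x] = (1/C) * sum_c first[c][y][x] * second[c][y+dy][x+dx]."""
    C, H, W = len(first), len(first[0]), len(first[0][0])
    offsets = _offsets(max_disp)
    out = _zeros(len(offsets), H, W)
    for k, (dy, dx) in enumerate(offsets):
        for y in range(H):
            for x in range(W):
                y2, x2 = y + dy, x + dx
                if 0 <= y2 < H and 0 <= x2 < W:  # otherwise zero padding
                    out[k][y][x] = sum(first[c][y][x] * second[c][y2][x2]
                                       for c in range(C)) / C
    return out


def correlation_backward(first, second, grad_out, max_disp):
    """Return (grad_first, grad_second) given grad_out = dL/d(out)."""
    C, H, W = len(first), len(first[0]), len(first[0][0])
    grad_first = _zeros(C, H, W)
    grad_second = _zeros(C, H, W)
    for k, (dy, dx) in enumerate(_offsets(max_disp)):
        for y in range(H):
            for x in range(W):
                y2, x2 = y + dy, x + dx
                if not (0 <= y2 < H and 0 <= x2 < W):
                    continue  # padded positions contribute no gradient
                g = grad_out[k][y][x] / C
                for c in range(C):
                    # d out / d first[c][y][x] gathers the shifted second value
                    grad_first[c][y][x] += g * second[c][y2][x2]
                    # d out / d second[c][y2][x2] scatters back to the shifted spot
                    grad_second[c][y2][x2] += g * first[c][y][x]
    return grad_first, grad_second
```

The scatter into `grad_second` is what makes a parallel GPU kernel trickier: several output positions write to the same element, so a kernel either uses atomic adds or inverts the loop so that each element of `grad_second` gathers from the positions (y - dy, x - dx).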

I found that the official PWC-Net implementation provides a complete correlation package, but I'm not quite sure how to incorporate it into your project. Can you give me a hint on what to do?

Thanks!
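One alternative hint, offered as an option rather than as what the repository does: if speed is not critical, the correlation can be written entirely with standard PyTorch tensor operations, so autograd derives the backward pass automatically and no hand-written kernel is needed. This sketch assumes zero padding, stride 1, a default maximum displacement of 4 (81 output channels), channel ordering with dy in the outer loop, and normalization by the channel count; check these against the CUDA kernel in the repository before swapping it in. The function name `correlation` is hypothetical.

```python
import torch
import torch.nn.functional as F


def correlation(first, second, max_disp=4):
    """Cost volume between two (N, C, H, W) feature maps using only differentiable ops.

    Output channel k = (dy + max_disp) * (2 * max_disp + 1) + (dx + max_disp)
    holds mean_c first[:, c, y, x] * second[:, c, y + dy, x + dx], with zeros
    where (y + dy, x + dx) falls outside the image.
    """
    _, _, height, width = first.shape
    # Zero-pad `second` so every shifted window stays in bounds.
    padded = F.pad(second, [max_disp, max_disp, max_disp, max_disp])
    volumes = []
    for dy in range(-max_disp, max_disp + 1):
        for dx in range(-max_disp, max_disp + 1):
            shifted = padded[:, :,
                             max_disp + dy:max_disp + dy + height,
                             max_disp + dx:max_disp + dx + width]
            # Channel-wise dot product, normalized by the number of channels.
            volumes.append((first * shifted).mean(dim=1, keepdim=True))
    return torch.cat(volumes, dim=1)
```

Because every operation here is differentiable, `torch.autograd.gradcheck` can confirm the gradients on small double-precision inputs. The trade-off is that the Python loop over all offsets is slower and allocates more intermediate memory than a fused CUDA kernel.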

sniklaus commented 5 years ago

I have added it in 36f4b4f; please let me know if there are any issues with it.