cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] Incorrect gradient sizes when back-propagating through TLT - tensor product #1247

Open Balandat opened 4 years ago

Balandat commented 4 years ago

🐛 Bug

A gradient shape mismatch error is raised when back-propagating through the product of a TriangularLazyTensor (TLT) and a tensor. This was discovered while sampling from an MVN posterior using base samples in https://github.com/pytorch/botorch/issues/513

This issue is caused by the new TriangularLazyTensor introduced in #1102: the following code works fine before the merge (on 3e87f849c56c5ad018ac535052080b6649b244da) but fails on the merge commit 4e6f2d0b0988409f312cdfe97eff0754bc356d48.

I haven't been able to dig much deeper, but the plot thickens.

To reproduce

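import torch
from gpytorch.lazy import BlockDiagLazyTensor, NonLazyTensor

# Batch of two 2x2 blocks; only the first one requires grad
E = torch.stack([
    torch.eye(2, requires_grad=True),
    torch.eye(2),
])
bdlt = BlockDiagLazyTensor(NonLazyTensor(E))  # 4x4 block-diagonal operator
clt = bdlt.root_decomposition()
root = clt.root  # a TriangularLazyTensor on the merge commit

base_samples = torch.randn(4, 1)
out = root.matmul(base_samples)
out.sum().backward()  # raises the RuntimeError below on the merge commit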

Stack trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-74-0ac2634cc9e4> in <module>
     11 base_samples = torch.randn(4, 1)
     12 out = root.matmul(base_samples)
---> 13 out.sum().backward()

~/miniconda3/envs/botorch/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    183                 products. Defaults to ``False``.
    184         """
--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    186 
    187     def register_hook(self, hook):

~/miniconda3/envs/botorch/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
--> 127         allow_unreachable=True)  # allow_unreachable flag
    128 
    129 

RuntimeError: Function MatmulBackward returned an invalid gradient at index 1 - got [4, 4] but expected shape compatible with [2, 2, 2]
Exception raised from validate_outputs at /Users/distiller/project/conda/conda-bld/pytorch_1595629430416/work/torch/csrc/autograd/engine.cpp:602 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >) + 169 (0x107f72199 in libc10.dylib)
frame #1: torch::autograd::validate_outputs(std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&, std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> >&, std::__1::function<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > (std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&)> const&) + 5332 (0x119a68b24 in libtorch_cpu.dylib)
frame #2: torch::autograd::Engine::evaluate_function(std::__1::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::__1::shared_ptr<torch::autograd::ReadyQueue> const&) + 2483 (0x119a63973 in libtorch_cpu.dylib)
frame #3: torch::autograd::Engine::thread_main(std::__1::shared_ptr<torch::autograd::GraphTask> const&) + 764 (0x119a6266c in libtorch_cpu.dylib)
frame #4: torch::autograd::Engine::execute_with_graph_task(std::__1::shared_ptr<torch::autograd::GraphTask> const&, std::__1::shared_ptr<torch::autograd::Node>) + 1023 (0x119a6c6cf in libtorch_cpu.dylib)
frame #5: torch::autograd::python::PythonEngine::execute_with_graph_task(std::__1::shared_ptr<torch::autograd::GraphTask> const&, std::__1::shared_ptr<torch::autograd::Node>) + 53 (0x112210145 in libtorch_python.dylib)
frame #6: torch::autograd::Engine::execute(std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&, std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> > const&, bool, bool, std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&) + 662 (0x119a6ad36 in libtorch_cpu.dylib)
frame #7: torch::autograd::python::PythonEngine::execute(std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&, std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> > const&, bool, bool, std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&) + 82 (0x11220ff42 in libtorch_python.dylib)
frame #8: THPEngine_run_backward(THPEngine*, _object*, _object*) + 2174 (0x112210aee in libtorch_python.dylib)
<omitting python frames>
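The check that fails here is PyTorch's validation that every gradient returned by a backward function is shape-compatible with the corresponding input. The pure-PyTorch sketch below is not GPyTorch code, and BlockDiagMatrix and backprop are hypothetical names. It builds a dense 4x4 block-diagonal matrix from a (2, 2, 2) stack of blocks and triggers the same kind of error when backward returns the dense (4, 4) gradient instead of one shaped like the (2, 2, 2) input. It assumes PyTorch 1.8 or newer for torch.block_diag.

```python
import torch


class BlockDiagMatrix(torch.autograd.Function):
    """Hypothetical op: dense block-diagonal matrix from a (K, b, b) stack of blocks.

    With correct_grad=False, backward returns the dense (K*b, K*b) gradient,
    mimicking the shape mistake discussed in this issue.
    """

    @staticmethod
    def forward(ctx, blocks, correct_grad):
        ctx.num_blocks, ctx.block_size = blocks.shape[0], blocks.shape[-1]
        ctx.correct_grad = correct_grad
        return torch.block_diag(*blocks.unbind(0))

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.correct_grad:
            # Wrong: the gradient has the dense output's shape (K*b, K*b)
            return grad_output, None
        k, b = ctx.num_blocks, ctx.block_size
        # Right: keep only the diagonal blocks, giving shape (K, b, b)
        diag_blocks = [grad_output[i * b:(i + 1) * b, i * b:(i + 1) * b] for i in range(k)]
        return torch.stack(diag_blocks), None


def backprop(correct_grad):
    blocks = torch.eye(2).repeat(2, 1, 1).requires_grad_(True)  # shape (2, 2, 2)
    base_samples = torch.randn(4, 1)
    out = BlockDiagMatrix.apply(blocks, correct_grad) @ base_samples
    out.sum().backward()
    return blocks.grad


try:
    backprop(correct_grad=False)
except RuntimeError as err:
    print(err)  # message reports an invalid gradient at index 0

print(backprop(correct_grad=True).shape)  # torch.Size([2, 2, 2])
```

The two branches differ only in the shape of the returned gradient: the forward output is a dense matrix, but autograd expects the gradient to match the block representation the op was built from.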

Expected Behavior

Gradients are computed without error.

Additional context

Note that this is not simply a blatant bug in TLT, since the following works fine:

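import torch
from gpytorch.lazy import TriangularLazyTensor

D = torch.eye(2, requires_grad=True)
tlt = TriangularLazyTensor(D)  # TLT backed by a dense 2x2 tensor
x = torch.rand(2, 1)
out = tlt.matmul(x)
out.sum().backward()  # works

wjmaddox commented 4 years ago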

It looks like the cause is that _quad_form_derivative returns gradients with the wrong shapes. A workaround is simply to delete that method from TriangularLazyTensor. In your example:

root._quad_form_derivative(torch.randn(4, 1), torch.randn(4, 1))[0].shape
# torch.Size([4, 4])

super(TriangularLazyTensor, root)._quad_form_derivative(torch.randn(4, 1), torch.randn(4, 1))[0].shape
# torch.Size([2, 2, 2])
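As a sketch of why the two shapes differ, assume _quad_form_derivative(left_vecs, right_vecs) is meant to return the derivative of the scalar sum(left_vecs * (A @ right_vecs)) with respect to the tensors that A is built from. Writing L and R for left_vecs and right_vecs, and A for an n x n operator:

```latex
\mathcal{L}(A) = \sum_{i,j} L_{ij}\,(A R)_{ij} = \operatorname{tr}\!\left(L^{\top} A R\right),
\qquad
\frac{\partial \mathcal{L}}{\partial A} = L R^{\top} \in \mathbb{R}^{n \times n}.

% If A = \operatorname{blockdiag}(A_1, \dots, A_K) with A_k \in \mathbb{R}^{b \times b}, n = K b,
% and L_k, R_k denote rows (k-1)b+1, \dots, kb of L and R:
\frac{\partial \mathcal{L}}{\partial A_k} = L_k R_k^{\top} \in \mathbb{R}^{b \times b},
\qquad k = 1, \dots, K.
```

Here K = 2 and b = 2, so gradients with respect to the stacked blocks have shape (2, 2, 2), while the dense L R^T has shape (4, 4). That matches the two shapes printed above and the mismatch reported in the stack trace.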

Looking at the source before this commit, the difference is that CholLazyTensor never had its own _quad_form_derivative method and instead used the default LazyTensor _quad_form_derivative.
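The idea behind a generic, autograd-based quadratic-form derivative can be illustrated with a simplified sketch. quad_form_derivative_via_autograd and block_diag_matmul are hypothetical helpers, not GPyTorch's actual implementation. The sketch differentiates sum(left_vecs * (A @ right_vecs)) with respect to each tensor A is built from, so each gradient automatically takes that tensor's shape. The block-diagonal operator is emulated with torch.block_diag (PyTorch 1.8 or newer).

```python
import torch


def quad_form_derivative_via_autograd(matmul_fn, representation, left_vecs, right_vecs):
    """Hypothetical helper: d/d(rep) of sum(left_vecs * (A @ right_vecs)).

    Autograd differentiates with respect to each representation tensor, so every
    returned gradient has the same shape as that tensor.
    """
    with torch.enable_grad():
        reps = [r.detach().requires_grad_(True) for r in representation]
        loss = (left_vecs * matmul_fn(*reps, right_vecs)).sum()
        return torch.autograd.grad(loss, reps)


def block_diag_matmul(blocks, rhs):
    # Dense stand-in for a block-diagonal operator built from a (K, b, b) stack
    return torch.block_diag(*blocks.unbind(0)) @ rhs


blocks = torch.eye(2).repeat(2, 1, 1)  # representation of shape (2, 2, 2)
left_vecs, right_vecs = torch.randn(4, 1), torch.randn(4, 1)
(grad_blocks,) = quad_form_derivative_via_autograd(
    block_diag_matmul, (blocks,), left_vecs, right_vecs
)
print(grad_blocks.shape)  # torch.Size([2, 2, 2]), not torch.Size([4, 4])
```

Each entry of grad_blocks equals the matching diagonal block of left_vecs @ right_vecs.T. Returning that dense product directly, without restricting it to the blocks, would yield the (4, 4) shape instead.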