pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Failure on non-contiguous gradients for F.pad/permute combination #32705

Closed bnehoran closed 4 years ago

bnehoran commented 4 years ago

🐛 Bug

Some PyTorch primitives expect the gradient passed to them during the backward pass to be contiguous, but not every function produces a contiguous gradient in its backward pass. When two incompatible functions are strung together -- one whose backward pass returns a non-contiguous gradient, and another whose backward pass expects a contiguous gradient as input -- automatic differentiation fails. In particular, permute and pad don't play well together, as the following example shows:

# short example to reproduce the error:

import torch
import torch.nn.functional as F

inv = torch.zeros((3, 8), dtype=torch.float).requires_grad_()
indices = torch.zeros((2, 3), dtype=torch.long)

# build a dense (4, 4, 8) tensor from a sparse one, so to_dense is in the graph
comb = torch.sparse.FloatTensor(indices, inv, (4, 4, 8)).to_dense()
# pad the two leading dims by 1 on each side -> (6, 6, 8)
big = F.pad(comb, (0, 0, 1, 1, 1, 1))
# flatten, transpose, add a batch dim -> (1, 8, 36); the permute makes it non-contiguous
shaped = big.view(-1, 8).permute(1, 0).unsqueeze(0)
res = F.fold(shaped, output_size=(5, 5),
             kernel_size=(2, 2), padding=(1, 1))
res.sum().backward()

which results in

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. (view at ../aten/src/ATen/native/TensorShape.cpp:1185)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x65 (0x7f0ab4b2eb15 in /usr/people/bnehoran/pytorch/torch/lib/libc10.so)
frame #1: at::native::view(at::Tensor const&, c10::ArrayRef<long>) + 0x2fb (0x7f0ac8ce33cb in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x118f433 (0x7f0ac8edf433 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x11dfb17 (0x7f0ac8f2fb17 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x1054db3 (0x7f0ac8da4db3 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #5: at::native::sparse_mask_cpu(at::Tensor const&, at::Tensor const&) + 0x84 (0x7f0ac8da6614 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x126c8e3 (0x7f0ac8fbc8e3 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0x11decb4 (0x7f0ac8f2ecb4 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #8: at::Tensor c10::Dispatcher::callUnboxed<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::OperatorHandle const&, at::Tensor const&, at::Tensor const&) const + 0xb2 (0x7f0acdfcac12 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x2d94771 (0x7f0acaae4771 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x11decb4 (0x7f0ac8f2ecb4 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #11: at::native::to_dense_backward(at::Tensor const&, at::Tensor const&) + 0x21a (0x7f0ac8ca7b8a in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x12a4f53 (0x7f0ac8ff4f53 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x2b1a4be (0x7f0aca86a4be in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x11decb4 (0x7f0ac8f2ecb4 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #15: at::Tensor c10::Dispatcher::callUnboxed<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::OperatorHandle const&, at::Tensor const&, at::Tensor const&) const + 0xb2 (0x7f0acdfcac12 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_python.so)
frame #16: torch::autograd::generated::ToDenseBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0xf9 (0x7f0aca6a06f9 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x2feebdb (0x7f0acad3ebdb in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x16ac (0x7f0acad3a34c in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x591 (0x7f0acad3b511 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::Engine::thread_init(int) + 0x49 (0x7f0acad32b89 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_cpu.so)
frame #21: torch::autograd::python::PythonEngine::thread_init(int) + 0x48 (0x7f0ace244cc8 in /usr/people/bnehoran/pytorch/torch/lib/libtorch_python.so)
frame #22: <unknown function> + 0xc819d (0x7f0af8e5819d in /usr/people/bnehoran/anaconda3/envs/pytorch/lib/libstdc++.so.6)
frame #23: <unknown function> + 0x76ba (0x7f0b15f766ba in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #24: clone + 0x6d (0x7f0b15cac41d in /lib/x86_64-linux-gnu/libc.so.6)

Tested on the master branch (1.5.0a0+7fdc6cb), but the bug was already present in versions as early as 1.2.

Edit: The error in https://github.com/pytorch/pytorch/issues/28650 might be related to this issue.
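
To make the claim above concrete (that permute hands back a non-contiguous gradient), here is a small sketch that is not part of the original report; it uses a gradient hook to inspect the gradient that reaches the leaf:

import torch

x = torch.randn(3, 4, requires_grad=True)
# the hook fires with the gradient exactly as the autograd engine computed it
x.register_hook(lambda g: print("grad contiguous:", g.is_contiguous()))

y = x.permute(1, 0) * torch.randn(4, 3)  # the multiply materializes a dense grad_output
y.sum().backward()                        # expected to print: grad contiguous: False

Since permute's backward simply permutes the incoming gradient, the result generally has non-standard strides, which is exactly the kind of tensor that a .view() inside some other backward formula can reject.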

cc @ezyang @SsnL @albanD @zou3519 @gqchen

albanD commented 4 years ago

The gradient formula should be updated to contain the proper calls to .contiguous().
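
As a sketch of what that pattern looks like (illustrative only; the real formulas live in the derivatives.yaml / ATen definitions, not in a Python Function like this one):

import torch

class FlattenOp(torch.autograd.Function):
    # toy op whose backward needs a flat view of the incoming gradient
    @staticmethod
    def forward(ctx, x):
        ctx.in_shape = x.shape
        return x.reshape(-1)

    @staticmethod
    def backward(ctx, grad_out):
        # without .contiguous(), the .view() can fail when autograd hands us a
        # gradient with non-standard strides
        return grad_out.contiguous().view(ctx.in_shape)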

gchanan commented 4 years ago

Wouldn't reshape also potentially work (and avoid making things contiguous in cases where the view is valid)?

albanD commented 4 years ago

Right, I forgot that .view() can now handle some non-contiguous Tensors. Indeed, reshape() is better.
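
For reference, a quick sketch of the difference (nothing specific to this issue): reshape() falls back to a copy exactly when the requested shape cannot be expressed as a view, while view() raises:

import torch

t = torch.randn(4, 8).permute(1, 0)  # non-contiguous: shape (8, 4), strides (1, 8)
# t.view(-1)                         # would raise: "view size is not compatible ..."
flat = t.reshape(-1)                 # works: copies only because no valid view exists
print(t.is_contiguous(), flat.is_contiguous())  # False True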

gchanan commented 4 years ago

The issue looks to be the backwards definition of sparse_mask.

Make sure that https://github.com/pytorch/pytorch/issues/28650 also passes.
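
A narrower probe of that hypothesis (my own construction, not from the report; it assumes the stack trace above, where to_dense's backward calls sparse_mask on the incoming gradient and the .view() fails inside sparse_mask_cpu):

import torch

indices = torch.zeros((2, 1), dtype=torch.long)
values = torch.ones((1, 8), requires_grad=True)
dense = torch.sparse_coo_tensor(indices, values, (4, 4, 8)).to_dense()

# feed a non-contiguous gradient straight into to_dense's backward
grad = torch.ones(8, 4, 4).permute(1, 2, 0)  # shape (4, 4, 8), non-contiguous
dense.backward(grad)                          # expected to hit the same RuntimeError on affected builds

If this snippet fails with the same error while a contiguous grad succeeds, that would point at the sparse_mask/to_dense backward rather than at pad or permute themselves.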

zou3519 commented 4 years ago

I can't reproduce this or https://github.com/pytorch/pytorch/issues/28650, so both look to be fixed on master. @bnehoran could you try installing a nightly build of PyTorch to confirm whether the problem still persists for you?

zou3519 commented 4 years ago

I'm optimistically closing this because it seems fixed on master. Please feel free to reopen if this is not the case.

bnehoran commented 4 years ago

Yeah, awesome. It seems to have been fixed sometime over the past couple of weeks.