ankane / torch.rb

Deep learning for Ruby, powered by LibTorch

.backward() broken when called against LibTorch 10.2 #7

Closed derrelldurrett closed 4 years ago

derrelldurrett commented 4 years ago

I've been learning to use LibTorch from Ruby by following your examples (OK, cutting and pasting them), and I ran into an error that can be reproduced with this sequence:

    x = Torch.ones(2, 2, requires_grad: true)
    y = x + 2
    z = y * y * 3
    out = z.mean
    out.backward
    out.backward(Torch.randn(1, 10))
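For reference, the first `out.backward` in this sequence has a known answer. With x = ones(2, 2), out = (1/4) * sum(3 * (x_i + 2)**2), so each partial derivative is 6 * (x_i + 2) / 4 = 4.5, and `x.grad` should hold 4.5 in every entry. The plain-Ruby sketch below checks that value with central finite differences, without torch.rb. The helper names (`forward`, `analytic_grad`, `numeric_grad`) are hypothetical and only model the repro's graph on a flattened array of four values.

```ruby
# Plain-Ruby model of the repro's graph (no torch.rb needed):
# out = mean(3 * (x + 2)**2) over the four entries of a 2x2 tensor of ones.
def forward(xs)
  zs = xs.map { |x| y = x + 2; y * y * 3 }
  zs.sum / zs.size.to_f
end

# Analytic gradient: d(out)/d(x_i) = 6 * (x_i + 2) / n
def analytic_grad(xs)
  xs.map { |x| 6.0 * (x + 2) / xs.size }
end

# Central finite differences as an independent check of the analytic result
def numeric_grad(xs, h = 1e-6)
  xs.each_index.map do |i|
    plus = xs.dup
    plus[i] += h
    minus = xs.dup
    minus[i] -= h
    (forward(plus) - forward(minus)) / (2 * h)
  end
end

xs = [1.0, 1.0, 1.0, 1.0]
p forward(xs)                                   # => 27.0
p analytic_grad(xs)                             # => [4.5, 4.5, 4.5, 4.5]
p numeric_grad(xs).map { |g| g.round(6) }       # => [4.5, 4.5, 4.5, 4.5]
```

A known value like this gives a framework-independent target to compare `x.grad` against once `backward` is wired up.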

I'm using LibTorch 10.2, so I'm guessing backward() has gained the retain_graph boolean since you wrote ext.cpp.

So, following your development instructions, I cloned torch.rb and made what I think are the relevant changes in tensor.rb:

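    def backward(gradient = nil, retain_graph: false)
      # Hack: when retain_graph is set without a gradient, pass a placeholder
      # tensor so the native _backward takes the branch that forwards retain_graph.
      gradient = Torch.empty(0) if gradient.nil? && retain_graph
      _backward(gradient, retain_graph)
    end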

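and in ext.cpp:

    .define_method(
      "_backward",
      *[](Tensor& self, Object gradient, bool retain_graph) {
        // A nil gradient falls through to the no-argument overload, so
        // retain_graph is ignored in that branch (hence the Ruby-side hack).
        return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient), retain_graph);
      })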

The changes to tensor.rb are a complete hack, but appear to work (though I'm not sure how best to test them -- if you have any suggestions, I'm open to them).

Should I open a pull request?

ankane commented 4 years ago

Hey @derrelldurrett, thanks for the report. A PR would be great. It'd be good to add the create_graph option as well and follow the same behavior as Python and C++ (relevant docs).
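One concrete piece of "the same behavior as Python": in PyTorch's `Tensor.backward`, `create_graph` defaults to false and `retain_graph` defaults to the value of `create_graph`. The Ruby sketch below models only that defaulting logic. `resolve_backward_args` and `BackwardArgs` are hypothetical names, not torch.rb API, and the resolved values would still need a native method that accepts all three; check the C++ `backward` signature in your LibTorch version before wiring it up.

```ruby
# Hypothetical helper (not part of torch.rb): resolves backward's keyword
# arguments the way PyTorch documents them, where retain_graph defaults
# to the value of create_graph.
BackwardArgs = Struct.new(:gradient, :retain_graph, :create_graph)

def resolve_backward_args(gradient = nil, retain_graph: nil, create_graph: false)
  # nil means "not given", so an explicit retain_graph: false is respected
  retain_graph = create_graph if retain_graph.nil?
  BackwardArgs.new(gradient, retain_graph, create_graph)
end

p resolve_backward_args.to_a                       # => [nil, false, false]
p resolve_backward_args(create_graph: true).to_a   # => [nil, true, true]
p resolve_backward_args(retain_graph: true).to_a   # => [nil, true, false]
```

Using nil rather than false as the default for retain_graph is what lets a wrapper tell "not given" apart from an explicit `retain_graph: false`.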

From this tutorial, it looks like the default gradient is Torch.tensor(1.0):

"Let’s backprop now. Because out contains a single scalar, out.backward() is equivalent to out.backward(torch.tensor(1.))."

(Also, I think 10.2 might be the CUDA version, since 1.5.1 is the latest version of LibTorch.)

Edit: Rereading it, that default may be specific to the example.
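In PyTorch the implicit default is not specific to that example, but it is limited: a gradient may be omitted only when the output has exactly one element, in which case a gradient of ones (1.0 for a scalar) is used; otherwise PyTorch raises "grad can be implicitly created only for scalar outputs". The plain-Ruby sketch below models that rule with shapes as arrays. `implicit_gradient` is a hypothetical helper, not torch.rb API, and it returns a bare 1.0 where a real implementation would build a tensor.

```ruby
# Hypothetical helper (not part of torch.rb): models PyTorch's rule that a
# gradient may be omitted only for single-element outputs.
def implicit_gradient(output_shape, gradient = nil)
  return gradient unless gradient.nil?

  numel = output_shape.reduce(1, :*) # [] (a scalar) has one element
  unless numel == 1
    raise ArgumentError, "grad can be implicitly created only for scalar outputs"
  end
  1.0
end

p implicit_gradient([])          # => 1.0 (scalar output, like out = z.mean)
g = [[1.0, 1.0], [1.0, 1.0]]
p implicit_gradient([2, 2], g)   # => [[1.0, 1.0], [1.0, 1.0]] (passed through)
begin
  implicit_gradient([2, 2])      # non-scalar output with no gradient
rescue ArgumentError => e
  puts e.message
end
```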

derrelldurrett commented 4 years ago

You are, of course, correct that 10.2 is the CUDA version.

I'll try to add the create_graph flag as well (unless I'm missing something, it should be fairly trivial).

derrelldurrett commented 4 years ago

    3:14 PM Can't Create Pull Request
    Push failed:
    remote: Permission to ankane/torch.rb.git denied to derrelldurrett.
    unable to access 'https://github.com/ankane/torch.rb.git/': The requested URL returned error: 403

ankane commented 4 years ago

Hey @derrelldurrett, you'll need to fork and push there.

ankane commented 4 years ago

Cleaning up issues