Open dhpollack opened 6 years ago
I think this is indeed related to the interaction between pybind11 and classes like at::IntList. We have well-maintained and specialized bindings for at::Tensor since we expose those directly in Python, but even in PyTorch we don't bind IntList and Scalar into Python. pybind11 comes with a lot of built-in support for std:: types, so if you want an easy life w.r.t. Python bindings, you should use std::vector<at::Tensor>, int64_t, std::tuple<at::Tensor, int64_t>, etc.
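As a rough sketch of that advice, the snippet below declares a C++ function whose non-tensor arguments are plain `double`, `bool`, and `std::vector<int64_t>`, and binds it through `torch.utils.cpp_extension.load_inline`. The op `scale_and_reshape`, its parameters, and the module name `nontensor_args_demo` are hypothetical and made up for illustration. Building it assumes a working C++ compiler and ninja.

```python
# C++ source compiled by load_inline, which prepends `#include <torch/extension.h>`.
CPP_SOURCE = r"""
#include <vector>

// Hypothetical op for illustration: scales `x`, reshapes it, optionally negates it.
// Non-tensor arguments use plain C++ / std:: types that pybind11 converts on its own:
// Python float -> double, Python bool -> bool, list of ints -> std::vector<int64_t>.
torch::Tensor scale_and_reshape(torch::Tensor x,
                                double factor,
                                std::vector<int64_t> shape,
                                bool negate) {
  // The std::vector converts implicitly to the array-reference type that
  // reshape expects (at::IntList in this thread, IntArrayRef in later releases).
  torch::Tensor out = (x * factor).reshape(shape);
  return negate ? -out : out;
}
"""


def build_extension():
    """Compile CPP_SOURCE and return the Python module (needs torch, a compiler, ninja)."""
    from torch.utils.cpp_extension import load_inline

    return load_inline(
        name="nontensor_args_demo",       # hypothetical module name
        cpp_sources=CPP_SOURCE,
        functions=["scale_and_reshape"],  # a pybind11 binding is generated for this name
    )


if __name__ == "__main__":
    import torch

    ext = build_extension()
    out = ext.scale_and_reshape(torch.arange(6.0), 0.5, [2, 3], False)
    print(out.shape)
```

From Python, the call passes an ordinary float, list, and bool. Nothing tensor-specific is needed for those arguments.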
I ended up using std::vector<int64_t> for what I needed, but it wasn't intuitive (nor am I sure even now that it's the right way to do it). Specifically, the docs for the CTC loss on master make it seem like one could pass a non-tensor for at::IntList, but perhaps that's just a typo.
Even a negative example would be nice, like "you can't use at::Scalar as an input, use a normal int64_t or scalar_t instead". I guess because the current example only uses tensor inputs, it's a bit difficult to know when to use non-tensor inputs, and it's quite a common use case to cover.
To extend on this point: in the Python wrapper, non-tensor inputs don't seem to fit in ctx.save_for_backward (I did some "hacky" stuff instead to pass them to the backward method, and I am still wondering what the right way to do this is...), and we also have to return a None gradient for each of them in the backward method, even if they are int64_t or bool. It would be great if all of this could be clarified.
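For reference, here is a minimal sketch of one common pattern for the wrapper side. It assumes that plain Python values can be stored as attributes on `ctx`, while `save_for_backward` is reserved for tensors. The `ScaledPow` function and its `exponent` and `scale` arguments are hypothetical, chosen only to show one int and one float input. In a real extension the forward and backward bodies would call into the compiled C++ module instead.

```python
import torch


class ScaledPow(torch.autograd.Function):
    """Hypothetical op: y = scale * x ** exponent, with non-tensor exponent and scale."""

    @staticmethod
    def forward(ctx, x, exponent, scale):
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(x)
        # Non-tensor values (int, float, bool, lists) are kept as plain attributes.
        ctx.exponent = exponent
        ctx.scale = scale
        return scale * x.pow(exponent)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx [scale * x^n] = scale * n * x^(n - 1)
        grad_x = grad_output * ctx.scale * ctx.exponent * x.pow(ctx.exponent - 1)
        # backward returns one value per forward input; non-tensor inputs get None.
        return grad_x, None, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = ScaledPow.apply(x, 2, 0.5)  # exponent is an int, scale is a float
y.sum().backward()
```

The `None` entries are placeholders that keep the returned tuple aligned with the inputs of `forward`. They do not mean a gradient is computed for the int or the float.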
Agreed. I had the same problems getting ints, float arrays and other stuff into my C extensions for the older PyTorch, and now I find myself at the same point again: the tutorial has only Tensor inputs and there is simply no other documentation available... Could you perhaps just add some dummy ints, floats and arrays to the LLTM example? :)
It would be nice if there were more examples, especially examples with different types of inputs. Currently, the only input type is at::Tensor, but what about the other types? Specifically, I had a lot of trouble using at::Scalar and at::IntList as inputs and instead used int and std::vector<int>. This might be an issue with pybind11 rather than PyTorch extensions specifically, but for users with limited knowledge of pybind11 and ATen, examples can be very helpful.