LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.
Apache License 2.0

`Tensor::dropout2d` bindings not generated? #168

Open grtlr opened 4 years ago

grtlr commented 4 years ago

It seems that no bindings are generated for `Tensor::dropout2d` (or its higher-dimensional variants). I'm probably missing something, but is there currently a way to achieve this in tch-rs? Or would it be possible to generate these bindings? Also, do you think it would make sense to add dropout modules to `tch::nn`, similar to the Python API?

I've found #136. Is the process described there still up to date? If it does not require too much knowledge of OCaml, I could also try adding these bindings, maybe with a little help from your side.

Regardless, I wanted to thank you again for creating this crate. It continues to be of tremendous value to me and my research.

grtlr commented 4 years ago

Looking at the PyTorch implementation of `torch.nn.functional.dropout2d` (which is used by `torch.nn.Dropout2d`), we should be able to just call `tch::Tensor::feature_dropout` without generating any new bindings.

LaurentMazare commented 4 years ago

Indeed, there doesn't seem to be anything specific for this in the C++ API that tch-rs binds to. `feature_dropout` is generated properly and should hopefully do the trick. If you think it would be convenient to have some `dropout2d`/`dropout3d` functions in the `nn` module so that these are easier to discover, this should be straightforward (maybe `dropout` should also be added there in this case).
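For readers unsure what "feature" dropout means compared with element-wise dropout, here is a minimal plain-Rust sketch of the channel-wise behaviour. It deliberately does not use tch-rs: the function name `feature_dropout_with_mask`, the `channels x (height * width)` layout, and the caller-supplied `keep` mask are all assumptions made for illustration, not part of the crate. The real binding draws the mask randomly and operates on tensors.

```rust
/// Channel-wise ("feature") dropout on a single sample stored as
/// `channels x (height * width)`.
///
/// Hypothetical helper for illustration only: the per-channel `keep` mask is
/// passed in explicitly so the result is deterministic, whereas a real
/// implementation would sample it with drop probability `p`.
fn feature_dropout_with_mask(
    input: &[Vec<f32>],
    keep: &[bool],
    p: f32,
    train: bool,
) -> Vec<Vec<f32>> {
    assert_eq!(input.len(), keep.len(), "one mask entry per channel");
    assert!((0.0..1.0).contains(&p), "p must be in [0, 1)");

    // In evaluation mode dropout is the identity.
    if !train {
        return input.to_vec();
    }

    // Surviving channels are rescaled so the expected activation is unchanged.
    let scale = 1.0 / (1.0 - p);

    input
        .iter()
        .zip(keep.iter())
        .map(|(channel, &kept)| {
            // The whole channel is either kept (and scaled) or zeroed.
            channel
                .iter()
                .map(|&x| if kept { x * scale } else { 0.0 })
                .collect()
        })
        .collect()
}
```

With `p = 0.5` and a mask that keeps the first of two channels, the first channel is doubled and the second is zeroed in full, which is the difference from ordinary dropout, where each element is dropped independently.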

grtlr commented 4 years ago

I'll work with feature_dropout a bit.

> If you think it would be convenient to have some dropout2d/dropout3d functions in the nn module so that it's easier to discover these, this should be straightforward (maybe dropout should also be added there in this case).

Once I'm confident that it works as expected, I will open a PR.

jerry73204 commented 3 years ago

@LaurentMazare Could we get the binding generation process documented? Even though there is a Makefile at the root, it requires the OCaml toolchain, which potential contributors may not be familiar with.