LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.
Apache License 2.0

Default arguments are not exposed. #512

Open Narsil opened 2 years ago

Narsil commented 2 years ago

Hi.

Thanks for this amazing lib.

I just stumbled upon the old issues noting that default arguments are not exposed:

https://github.com/LaurentMazare/tch-rs/issues/132 https://github.com/LaurentMazare/tch-rs/issues/133 https://github.com/LaurentMazare/tch-rs/issues/350

Here the use case is the new BLOOM model, which uses the baddbmm operator. I had to come up with a fork https://github.com/LaurentMazare/tch-rs/compare/main...Narsil:tch-rs:support_optional_baddbmm_args which doesn't feel right (and will be a mess to maintain).

I'm not sure I have a clean and elegant solution to suggest for this one. Maybe adding Option<P>, but that feels like more work than necessary; the other option would be to expose all params behind a feature flag.
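To make the Option<P> idea concrete, here is a minimal, self-contained sketch. The names are hypothetical (this is not the tch-rs API), and plain scalars stand in for tensors; beta and alpha default to 1 when None, matching PyTorch's documented baddbmm defaults:

```rust
// Illustrative only: a hand-rolled signature in the spirit of the
// Option<P> idea, with made-up names (not the tch-rs API).
fn baddbmm_like(input: f64, batch_product: f64, beta: Option<f64>, alpha: Option<f64>) -> f64 {
    // out = beta * input + alpha * (batch1 @ batch2), with scalars standing in for tensors
    beta.unwrap_or(1.0) * input + alpha.unwrap_or(1.0) * batch_product
}

fn main() {
    // Callers still have to spell out None for the defaults, which is
    // the ergonomic cost of this approach.
    println!("{}", baddbmm_like(2.0, 3.0, None, None)); // 5
    println!("{}", baddbmm_like(2.0, 3.0, Some(0.0), Some(2.0))); // 6
}
```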

I created a new issue since all the other ones are rather old and closed. Just wanted to raise awareness.

LaurentMazare commented 2 years ago

Right, that's indeed a limitation of the current approach (and of optional arguments not having a very idiomatic representation in Rust; there is the Default trait or the builder pattern, but these are a bit verbose). I've just made PR #513, which makes it possible to generate two versions of some selected functions: one without the optional arguments, as is currently the case, and one requiring all these scalar optional arguments to be passed. I'm not sure how good of an idea this is, so I'm certainly keen to get some feedback, and it would also be great if you could check that it works for your use case.
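For reference, the verbosity of the builder alternative mentioned above can be sketched as follows. All names here are made up for illustration, and scalars again stand in for tensors; the defaults beta = 1 and alpha = 1 follow PyTorch's baddbmm documentation:

```rust
// A minimal sketch of the builder pattern for optional scalar arguments.
struct BaddbmmArgs {
    beta: f64,
    alpha: f64,
}

impl Default for BaddbmmArgs {
    // PyTorch documents beta = 1 and alpha = 1 as the defaults.
    fn default() -> Self {
        BaddbmmArgs { beta: 1.0, alpha: 1.0 }
    }
}

impl BaddbmmArgs {
    fn beta(mut self, beta: f64) -> Self { self.beta = beta; self }
    fn alpha(mut self, alpha: f64) -> Self { self.alpha = alpha; self }
    fn apply(&self, input: f64, batch_product: f64) -> f64 {
        self.beta * input + self.alpha * batch_product
    }
}

fn main() {
    // Fairly verbose for the caller, which is the drawback noted above.
    let out = BaddbmmArgs::default().alpha(2.0).apply(2.0, 3.0);
    println!("{out}"); // 8
}
```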

Narsil commented 2 years ago

I confirm your PR works!

I'm not sure we should do that for all functions; it's already pretty verbose between the f_XX, XX, and XX_ functions.

But at least there's a way to add such functions more naturally going forward!

Thanks! (I've never done any serious ML code, so I took a look, but figured my hack was faster than adding the code you just wrote ^_^)

LaurentMazare commented 2 years ago

I certainly agree with your point that the f_XX, XX, XX_, and now XX_s functions are far from ideal. If these were only free functions and not methods on the tensor type, we could namespace them via different mods, so the functions would be tensor::fallible::XX, tensor::XX, tensor::with_scalar_args::XX, but with methods this is trickier. Maybe with some traits we could achieve something similar, with resolution depending only on the trait currently in scope. Anyway, I'll merge my small changes for now so that hopefully they help your use case, and we can revisit down the line.
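The trait-in-scope idea can be sketched in plain Rust. This is purely hypothetical (the module and method names are invented, not tch-rs code): the same method name resolves to a different signature depending on which trait is imported.

```rust
// Hypothetical sketch: two traits provide a method with the same name
// but different signatures, and `use` picks which one is callable.
struct Tensor(f64);

mod plain {
    pub trait Baddbmm {
        fn baddbmm(&self, product: f64) -> f64;
    }
}

mod with_scalar_args {
    pub trait Baddbmm {
        fn baddbmm(&self, product: f64, beta: f64, alpha: f64) -> f64;
    }
}

impl plain::Baddbmm for Tensor {
    fn baddbmm(&self, product: f64) -> f64 {
        self.0 + product // defaults beta = alpha = 1
    }
}

impl with_scalar_args::Baddbmm for Tensor {
    fn baddbmm(&self, product: f64, beta: f64, alpha: f64) -> f64 {
        beta * self.0 + alpha * product
    }
}

fn main() {
    let t = Tensor(2.0);
    {
        use plain::Baddbmm; // only the default-argument version is in scope
        println!("{}", t.baddbmm(3.0)); // 5
    }
    {
        use with_scalar_args::Baddbmm; // only the explicit version is in scope
        println!("{}", t.baddbmm(3.0, 1.0, 2.0)); // 8
    }
}
```

Importing both traits into the same scope would make the call ambiguous, which hints at why this approach gets tricky at the scale of a generated API.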

Narsil commented 2 years ago

> Anyway I'll merge my small changes for now so that hopefully it helps for your use case and we can revisit down the line.

Thanks a lot. Actually, I really like simple approaches, and traits are sometimes more cumbersome than helpful, as they make it harder to figure out where the source/doc actually lives. The docs search is convenient enough to find what I'm looking for. When I saw multiple functions for the same feature, I only had to look at the signature to see the difference (I also looked at the source to confirm my understanding).

What about a feature gate, feature = "optional-args", that would enable all the _s functions? I understand gen.ml is not currently part of the build process, but I'm just throwing the idea out there, because wanting non-default arguments is naturally a power-user feature, and maybe the fallible functions could be gated too. (I actually LOVE the idea of being able to really catch torch exceptions as Results; it makes it much easier to write robust code, even if I expect CUDA to be a pain once it has hit certain bugs.)

Another slight note: I found myself using the f_ functions more, even when I was just unwrapping, so that the panic would originate in my code rather than in tch-rs, without having to RUST_BACKTRACE the whole thing.
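The point about where the panic originates can be shown with a tiny, self-contained pattern (illustrative only, not the tch-rs source): the f_ variant returns a Result, while the plain variant unwraps inside the library, so a failure through the plain variant panics in library code instead of at the call site.

```rust
// Fallible variant: the caller decides how to handle the error.
fn f_op(x: f64) -> Result<f64, String> {
    if x < 0.0 {
        Err(format!("negative input: {x}"))
    } else {
        Ok(x.sqrt())
    }
}

// Convenience variant: panics here, inside the "library", when f_op fails,
// so the backtrace points at library internals rather than user code.
fn op(x: f64) -> f64 {
    f_op(x).unwrap()
}

fn main() {
    // Writing f_op(...).unwrap() at the call site instead keeps the panic
    // location (and the useful part of the backtrace) in user code.
    println!("{}", f_op(9.0).unwrap()); // 3
    println!("{}", op(4.0)); // 2
}
```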

And finally, I'm currently using tch-rs because tokenizers is already written in Rust, and I'm doing experiments for an API where the threading model is important; I want to avoid Python altogether to sidestep the whole GIL mess and multiprocessing message passing. I'm glad I found a real-world work use case for these bindings. Again, thanks for making them!

LaurentMazare commented 1 year ago

Just a quick update on this: to avoid the additional _s messiness, I've removed the suffix, so the parameters now appear on the default version of the function (baddbmm_s no longer exists, but baddbmm takes the additional parameters). I also agree that the f_ functions are not convenient to use at the moment, although they should be the go-to functions for proper error handling. I'm considering sketching a 1.0 version of this crate at some point; its huge breaking change would be removing the f_X functions and having all the X functions return a Result<...> (which is hard to avoid as soon as a function runs some C++ code that may raise an exception).

Narsil commented 1 year ago

Btw, I'm doing this for fun: https://github.com/Narsil/tcheck/ which hopefully can turn every tensor error into a compile-time check.