LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.
Apache License 2.0
4.27k stars, 340 forks

Request for torch.distributions primitives #29

Open jerry73204 opened 5 years ago

jerry73204 commented 5 years ago

It seems the current tch-rs lacks distribution primitives like torch.distributions and tf.distributions. Though it already has probability generators, there's still room for probability arithmetic and deduction.

I see a promising crate, rv, that may fit my needs. However, its interface is not general enough: we cannot pass a mean/std tensor to rv's Gaussian new() function.

Porting torch.distributions may solve the problem. It would take some patience to make it work; alternatively, I could go ahead and improve rv. I'd like to know if the author has any plan to port the module, or would rather leave it to other crates.

In case you're wondering where this feature is needed: I'm implementing the ELBO loss for GQN (paper).

LaurentMazare commented 5 years ago

Thanks for the suggestion. This indeed seems like a nice thing to add. The underlying torch primitives for sampling should already exist in the Rust wrappers, e.g. normal, so I guess this would mostly consist of adding some trait for distributions and implementing it by mimicking the Python implementation, e.g. Normal. Would that be useful to you? Also, which methods/distributions would be the most interesting to you?

jerry73204 commented 5 years ago

Yes, I see that normal function, but I need probability arithmetic rather than probability generators. Take GQN for example: I need a Normal object whose parameters are defined by tensors, and I need to compute log_prob (the log probability density) for any value under that Normal. I also have to compute the KL divergence between two normals. There's no need for sampling.

Torch and TensorFlow have their own such functions. Interestingly, you'll see NotImplementedError if you look into Torch's source code. So I bet improving rv would be a good direction. Currently I just write the raw formulas, taking care with floating-point precision.

LaurentMazare commented 5 years ago

Do you need to backprop through your KL divergence? If that's the case, I'm not sure that rv could be used out of the box, but maybe I'm missing something? When looking at the 'Normal' distribution in torch, I don't see much NotImplementedError besides in the base Distribution class. Also, the KL divergence for the normal distribution can be found here.

jerry73204 commented 5 years ago

The NotImplementedError is here. Sorry for the imprecise comment.

Backprop is desired in my case. As far as I know, torch.distributions is an add-on in the Python pytorch package rather than part of libtorch, so an implementation in Rust would be like a completely new library.

Let me write some code for this and see what a proper way to build this feature would be.

LaurentMazare commented 5 years ago

Re NotImplementedError: indeed, that's the base class for distributions in pytorch, so nothing is implemented there; the various methods are implemented in derived classes like 'Normal'. The distributions code is included in the main pytorch repo and package (unlike, say, vision, which is an external repo and pypi package). I don't have a strong opinion on where this belongs, in the main crate or in a separate one, but starting with an external crate sounds good, and if there is some upside to merging it into the main repo we can consider it later. Let me know if you notice any pytorch primitives missing from tch-rs that could be useful to you!

LaurentMazare commented 5 years ago

Just to mention that I added a variational auto-encoder example to tch-rs. This includes a KL divergence loss here. It's certainly very far from what a nice distributions API would provide, but it may be handy.

vegapit commented 5 years ago

It is just a case of adding numerical approximation functions for the pdfs and cdfs of popular statistical distributions. I personally just implemented the Normal cdf derived in this article using tensors, and could calculate the gradients with the library as usual:

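```rust
use tch::Tensor;

// Approximate standard normal CDF: 1 / (exp(-358x/23 + 111 * atan(37x/294)) + 1)
fn norm_cdf(x: &Tensor) -> Tensor {
    let denom = ((-358.0 * x / 23.0) + 111.0 * (37.0 * x / 294.0).atan()).exp() + 1.0;
    1.0 / denom
}
```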
LaurentMazare commented 5 years ago

@vegapit yes, that's mostly about adding such functions and probably some traits for the various distributions. You can see the implementation for the normal distribution in the Python API here. The code for the cdf is a bit different from yours and relies on torch.erf. Not sure which one has better precision.

```python
def cdf(self, value):
    return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))
```
vegapit commented 5 years ago

@LaurentMazare I did not know you had already added the error function implementation to Torch, otherwise I would have used it. I do not know what numerical approximation torch.erf uses, but my guess is that its precision must be similar to that of the function I described above.

LaurentMazare commented 5 years ago

Btw, the error function is also available in tch-rs: https://docs.rs/tch/0.1.0/tch/struct.Tensor.html#method.erf (as the Rust low-level wrappers are automatically generated, we mostly get these for free).

jerry73204 commented 5 years ago

A little update here: my previous project, gqnrs, has some useful traits for probability distributions, though only the Normal distribution is there. Could we start a working branch to fill in the blanks for the {Bernoulli, Exponential, Categorical, etc.} distributions?

LaurentMazare commented 5 years ago

Yes, that kind of trait would indeed probably be useful. @vegapit, do you think this would cover your use case?

vegapit commented 5 years ago

I use Torch to solve for maximum likelihood in non-linear parametrised models. The density functions can be reconstructed or approximated using tensor methods, but I guess it is clearly more user-friendly to provide them in wrappers as PyTorch does.

spebern commented 4 years ago

I started to work on a crate to port the distributions: https://github.com/spebern/tch-distr-rs Besides porting, the most tedious work is to test everything.

I think the best way to ease porting is to test against the Python implementations directly, by supplying the same inputs and comparing the outputs. Later, this can also be extended with fuzzed inputs.

The Distribution trait is open for discussion and pull requests are more than welcome to add more distributions/tests.

LaurentMazare commented 4 years ago

This looks very nice, thanks for sharing! Once this has been polished/tested a bit, we could probably mention it in the tch-rs main readme file for better discoverability if that's ok with you.

spebern commented 4 years ago

That would be really nice! It definitely needs some polishing and more implemented distributions, but I really think that testing against the Python implementations saves a lot of work.

dbsxdbsx commented 2 years ago

> This looks very nice, thanks for sharing! Once this has been polished/tested a bit, we could probably mention it in the tch-rs main readme file for better discoverability if that's ok with you.

@LaurentMazare, I wonder if it would be a good idea to include @spebern's tch-distr-rs as a feature of tch-rs, so that users don't need to add two crates when using torch in Rust.

LaurentMazare commented 2 years ago

If it's just to avoid having an additional dependency for crates that would want to use this, I would lean more towards keeping an external crate, and in general having smaller composable crates for the bits that are outside of the core tch-rs, e.g. I'm more thinking about moving the vision models in their own crate, the RL bits to their own thing too etc.

dbsxdbsx commented 2 years ago

> If it's just to avoid having an additional dependency for crates that would want to use this, I would lean more towards keeping an external crate, and in general having smaller composable crates for the bits that are outside of the core tch-rs, e.g. I'm more thinking about moving the vision models in their own crate, the RL bits to their own thing too etc.

@LaurentMazare, the reason why I hope tch-distr-rs could be part of tch-rs is that in PyTorch, the distributions code is part of the Python torch module, even though it is not part of the C++ library. Meanwhile, I don't think it is appropriate to treat tch-distr-rs as a tool ONLY for reinforcement learning or some other specific field.

Therefore, I suggest making it an optional feature, which would be flexible (users could decide whether to include it through the "features" section of Cargo.toml) and would make the transition easy for users familiar with PyTorch.