Open marcoct opened 4 years ago
Hi! I am glad to hear the Gen team is willing to collaborate on this package! I also don't like the current state of DistributionsAD's dependencies. It was mostly driven by a desperation to make things work without thinking too much about loading time and dependencies. I am open to improvement suggestions. As for Combinatorics, we only use the combinations function, so re-implementing it shouldn't be too hard.
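To give a sense of how small such a re-implementation could be, here is a minimal sketch of a k-combinations iterator. It is written in Python purely for illustration; the actual replacement would be Julia. It yields combinations in lexicographic order of positions. Whether that matches every detail of Combinatorics' combinations (ordering, return type) is an assumption.

```python
def combinations(items, k):
    """Yield every k-element combination of `items` as a list, in lexicographic
    order of positions (illustrative stand-in, not Combinatorics.jl itself)."""
    n = len(items)
    if k < 0 or k > n:
        return
    idx = list(range(k))  # positions of the current combination
    while True:
        yield [items[i] for i in idx]
        # Find the rightmost position that can still be advanced.
        i = k - 1
        while i >= 0 and idx[i] == i + n - k:
            i -= 1
        if i < 0:
            return  # all combinations produced
        idx[i] += 1
        # Reset every later position to the smallest valid value.
        for j in range(i + 1, k):
            idx[j] = idx[j - 1] + 1
```

For example, `list(combinations([1, 2, 3], 2))` gives `[[1, 2], [1, 3], [2, 3]]`.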
Actually Combinatorics can be removed entirely in the next release after https://github.com/JuliaDiff/ReverseDiff.jl/pull/125 is merged and released.
Thanks, @marcoct! It would be great to work together on the common set of libraries underlying Turing and Gen and avoid duplicate efforts where possible. Also, we're happy to move DistributionsAD to a more suitable org account in the future, e.g. JuliaStats or JuliaPPL.
Have you benchmarked the dependencies? I suggest doing that before duplicating any implementation. E.g., ExprTools is a very lightweight package that contains basically only splitdef and combinedef from MacroTools, and it loads a lot faster than MacroTools since it has no dependencies. If some packages increase compilation time significantly, IMO one could either try to fix these problems upstream or move commonly used functions to separate packages (but not DistributionsAD) so that multiple packages can make use of them (such as DistributionsAD and even the upstream package). ExprTools, for instance, was motivated by the fact that MacroTools tends to increase compilation times significantly (apparently due to its dependency on DataStructures), even if one is only interested in the splitdef and combinedef functionality.
Another thing is that I've seen and debugged a case where Zygote increased test times from around 160 seconds to over 2000 seconds (see https://discourse.julialang.org/t/extremely-long-compilation-time-2000-sec-on-win32/26496?u=devmotion and the links therein). I assume some of these issues are fixed on newer Julia versions (I've never actually checked, since we disabled tests on 32-bit and made Zygote an optional dependency), but to me it seems more likely that the AD packages and typical candidates such as MacroTools and Requires have a more significant influence on compile time than lightweight packages such as ExprTools. However, of course, that's not completely clear without benchmarking :man_shrugging:
Going even further, for compilation times and number of dependencies, it might actually be more helpful to split DistributionsAD into multiple packages such as DistributionsTracker, DistributionsForwardDiff, DistributionsZygote, and DistributionsReverseDiff for dedicated AD backends. Then people who only want to use, e.g., ForwardDiff would not be affected by the compile times of the other AD packages. Unfortunately, it seems that's not (completely) possible since, e.g., adjoints for Tracker are sometimes computed with ForwardDiff: https://github.com/TuringLang/DistributionsAD.jl/blob/e3989c66f9d9a926c1a6142f7380b2fc8f0ace02/src/univariate.jl#L279-L286 Hence this is not a viable approach, and I guess in the future it might be better to define forward-, reverse-, and mixed-mode derivatives using ChainRulesCore without depending on any specific AD backend (ForwardDiff2 already works with ChainRulesCore, and Zygote support should be available soon: https://github.com/FluxML/Zygote.jl/pull/366). For the time being, it would at least be better to depend only on ZygoteRules rather than Zygote (from a quick glance at the code it seems we could actually drop Zygote as a dependency; I'll check it locally and make a PR if that's the case).
What is the opinion of the developers of this package on splitting the gradient code from the AD integration code into separate packages? For example, one package would contain the gradient implementations, and several other packages would contain the glue code for different AD backends (as suggested by @devmotion above).
I could imagine that the gradient code could be placed in Distributions.jl (see https://github.com/JuliaStats/Distributions.jl/issues/1085), or perhaps in a separate package (e.g. DistributionGradients).
This would make it possible to use the gradient logic without depending on any AD system at all. For example, in Gen some specialized modeling languages make use of gradients of log-densities but do not use any of the Julia AD packages.
I also see https://github.com/matbesancon/DistributionsDiff.jl/blob/master/src/DistributionsDiff.jl, but this looks like a work in progress, and also does not appear to be what I'm proposing.
I'm just proposing to have simple plain-Julia implementations of functions like this: https://github.com/probcomp/Gen/blob/master/src/modeling_library/distributions/laplace.jl#L15-L26. Then, each AD-glue package would invoke these internally.
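As a rough illustration of such AD-free code, here is a Python sketch of the Laplace(mu, b) log density and its gradient with respect to the value and both parameters. The function names are hypothetical. Returning 0 as the (sub)gradient at x == mu, where the density has a kink, is an assumption and not necessarily what Gen does.

```python
import math


def laplace_logpdf(x, mu, b):
    """Log density of Laplace(mu, b) at x, for scale b > 0."""
    return -math.log(2.0 * b) - abs(x - mu) / b


def laplace_logpdf_grad(x, mu, b):
    """Return (d/dx, d/dmu, d/db) of laplace_logpdf. No AD is involved.

    At x == mu the log density is not differentiable in x or mu; this sketch
    returns the subgradient 0 there (an arbitrary but common convention).
    """
    diff = x - mu
    s = (diff > 0) - (diff < 0)  # sign(x - mu), with sign(0) == 0
    dx = -s / b
    dmu = s / b
    db = -1.0 / b + abs(diff) / b ** 2
    return dx, dmu, db
```

An AD glue package would only need to call `laplace_logpdf_grad` inside its pullback; callers that do not use AD can call it directly.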
One potential issue (which is probably a separate discussion, because I see it shows up in the current DistributionsAD code as well) is that the user may want gradients with respect to some subset of the parameters, and always computing the gradients for all parameters may be wasteful, and only computing gradients with respect to individual parameters would also be wasteful. But I suppose this could be handled by allowing additional optimized implementations for different use cases to be implemented as a performance optimization as needed, and otherwise falling back to the version that computes gradients with respect to all parameters. Maybe this is the sort of thing that ChainRulesCore is designed to help with. But for various reasons (stability, performance optimization) I think it can make sense to separate low-level code that is accessible from a more conservative programming model from more abstract programming models.
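Here is a minimal Python sketch of that fallback idea for a Normal(mu, sigma) log density. It uses a hypothetical registry of specialized kernels keyed by the set of requested parameters, plus a fallback that computes the full gradient and discards the unused entries. All names are illustrative and not part of any existing package.

```python
def normal_logpdf_grad_all(x, mu, sigma):
    """Gradient of the Normal(mu, sigma) log density w.r.t. all of (x, mu, sigma)."""
    z = (x - mu) / sigma
    return {"x": -z / sigma, "mu": z / sigma, "sigma": (z * z - 1.0) / sigma}


# Hypothetical optimized kernels, registered only for the subsets that matter.
_SPECIALIZED = {
    frozenset({"mu"}): lambda x, mu, sigma: {"mu": (x - mu) / sigma ** 2},
}


def normal_logpdf_grad(x, mu, sigma, wrt):
    """Gradient w.r.t. the parameter names in `wrt`.

    Uses a specialized kernel if one is registered for exactly this subset;
    otherwise computes every partial derivative and keeps only the requested ones.
    """
    key = frozenset(wrt)
    kernel = _SPECIALIZED.get(key)
    if kernel is not None:
        return kernel(x, mu, sigma)
    full = normal_logpdf_grad_all(x, mu, sigma)
    return {name: full[name] for name in key}
```

With n parameters there are 2^n - 1 non-empty subsets, so specializations would only be worth adding for the combinations that are actually hot in practice.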
What is the opinion of the developers of this package on splitting the gradient code from the AD integration code into separate packages?
By gradient code, do you mean (for instance) implementation of forward-, reverse-, and mixed-mode derivatives with ChainRulesCore? Or do you have something else in mind?
I also see matbesancon/DistributionsDiff.jl:src/DistributionsDiff.jl@master , but this looks like a work in progress, and also does not appear to be what I'm proposing.
Nice, I didn't know about DistributionsDiff! It seems its goal is actually to implement the ChainRulesCore interface for Distributions.
I'm just proposing to have simple plain-Julia implementations of functions like this: https://github.com/probcomp/Gen/blob/master/src/modeling_library/distributions/laplace.jl#L15-L26
IMO the "nicer" approach would be to implement the ChainRulesCore interface, as started in DistributionsDiff. Then AD backends just have to support this interface and you don't need any glue code anymore, as far as I understand. And by implementing forward-, reverse-, and mixed-mode derivatives you can support all different kinds of AD backends. If this works out as intended, you would never have to deal with these backends in your code.
By gradient code, do you mean (for instance) implementation of forward-, reverse-, and mixed-mode derivatives with ChainRulesCore? Or do you have something else in mind?
I mean functions that take as input the parameters of a distribution and a value of the random variate, and return gradients of the log density with respect to parameters and/or the value of the random choice. This code would not know that AD exists. This code itself would use a simpler conceptual vocabulary that is specialized to the problem of computing certain derivatives associated with probability distributions.
Then, users of reverse-mode AD (or more precisely, implementers of packages like DistributionsAD) could invoke this code in their implementation of pullbacks. But users who just want derivatives of probability distributions, and are not using AD at all, can still use this code, without needing to depend on or adopt any interfaces or packages designed for AD.
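Here is a Python sketch of what such glue might look like, with hypothetical names. It pairs plain gradient code for the Exponential(rate) log density with a wrapper that returns the primal value and a pullback. That is the shape ChainRulesCore's rrule uses. Only the wrapper knows about reverse mode; the gradient function can be called on its own.

```python
import math


def exponential_logpdf(x, rate):
    """log p(x) for Exponential(rate), with x >= 0 and rate > 0."""
    return math.log(rate) - rate * x


def exponential_logpdf_grad(x, rate):
    """Plain gradient code: (d/dx, d/drate). Knows nothing about AD."""
    return -rate, 1.0 / rate - x


def exponential_logpdf_rrule(x, rate):
    """Reverse-mode glue: return the primal value and a pullback that maps an
    output cotangent to cotangents for (x, rate)."""
    y = exponential_logpdf(x, rate)
    dx, drate = exponential_logpdf_grad(x, rate)

    def pullback(ybar):
        # Scale the plain gradient by the incoming cotangent.
        return ybar * dx, ybar * drate

    return y, pullback
```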
IMO the "nicer" approach would be to implement the ChainRulesCore interface, as started in DistributionsDiff. Then AD backends just have to support this interface and you don't need any glue code anymore, as far as I understand. And by implementing forward-, reverse-, and mixed-mode derivatives you can support all different kinds of AD backends. If this works out as intended, you would never have to deal with these backends in your code.
Right. Something implicit in my suggestion is that the existence of glue code that separates this specialized gradient logic from particular usage patterns (e.g. in the context of AD) is actually useful. Users that just want to obtain a particular gradient associated with a distribution can use the specialized gradient code however they want (e.g. in a highly optimized implementation that doesn't use AD at all), and AD users can (with a single line of glue code) make this code accessible for their usage pattern.
Users that just want to obtain a particular gradient associated with a distribution can use the specialized gradient code however they want (e.g. in a highly optimized implementation that doesn't use AD at all),
It is possible to use ChainRules directly, without any AD package. But that's maybe not what you had in mind?
It is my opinion (and I am biased, as I am now the main maintainer of ChainRules) that the correct place for the definitions of custom sensitivity rules (as in overloads of ChainRulesCore.rrule and ChainRulesCore.frule) is in the package that declares the primal function. This is kind of the same thing as for Plots.jl/RecipesBase recipes.
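For comparison with a reverse rule, a forward rule maps input tangents to an output tangent. Here is a minimal Python sketch for the Cauchy(loc, scale) log density. It loosely follows the shape of ChainRulesCore's frule, which returns the primal value together with the output tangent; the tangent for the function itself is omitted here. The names are hypothetical.

```python
import math


def cauchy_logpdf(x, loc, scale):
    """Log density of Cauchy(loc, scale) at x, for scale > 0."""
    u = (x - loc) / scale
    return -math.log(math.pi * scale) - math.log1p(u * u)


def cauchy_logpdf_frule(tangents, x, loc, scale):
    """Forward rule: given input tangents (dx, dloc, dscale), return the
    primal value and the directional derivative of the log density."""
    dx, dloc, dscale = tangents
    u = (x - loc) / scale
    common = 2.0 * u / (scale * (1.0 + u * u))
    gx = -common                       # d/dx
    gloc = common                      # d/dloc
    gscale = -1.0 / scale + u * common  # d/dscale
    value = cauchy_logpdf(x, loc, scale)
    return value, gx * dx + gloc * dloc + gscale * dscale
```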
IMO ideally DistributionsAD (which basically only contains type piracy :smile:) shouldn't exist at all. Of course, it's very understandable that nobody wants Distributions to depend on a bunch of AD backends (which would be the case if one just copied the implementation of DistributionsAD), and that's why something like ChainRulesCore would be extremely useful IMO. Nevertheless, I guess even if definitions of sensitivity rules end up in Distributions at some point, it would still be useful (and probably also faster) to start working on, e.g., ChainRulesCore-based sensitivity rules in a separate package (such as DistributionsDiff apparently) for a while before integrating them into Distributions.
I agree that we should try to port most of DistributionsAD (DAD) to ChainRules. The main problem I can see with that is that sometimes we don't define a custom adjoint in DAD in the traditional sense. Sometimes a distribution's implementation would be too restrictive for AD to go through it, so we implement a shadow version of that distribution that's more AD-compatible (whatever that means; it can mean different things for different AD backends), and then we use dispatch or Zygote's adjoint to "guide" the AD package through an AD-friendly chain of methods and types. There is a fair bit of that in DAD. But in general, I agree with the sentiment that DAD should use ChainRules wherever possible.
As for splitting the package into separate packages, I am down with that but the core package should probably keep ForwardDiff support. ForwardDiff is used by Tracker, ReverseDiff and Zygote for broadcasting and scalar functions so supporting FD by default seems like a good idea. Also the FD bits of DAD are quite lightweight.
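ForwardDiff is built on dual numbers, which is why it is cheap for scalar functions. Here is a minimal Python sketch of the idea (not ForwardDiff's implementation, and supporting only the few operations needed here). A Dual type carries a value and a derivative and is used to differentiate a Normal(mu, sigma) log density with respect to one argument at a time.

```python
import math


class Dual:
    """Minimal forward-mode dual number val + der*eps, with eps**2 == 0."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def _lift(other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        o = Dual._lift(other)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __sub__(self, other):
        o = Dual._lift(other)
        return Dual(self.val - o.val, self.der - o.der)

    def __rsub__(self, other):
        return Dual._lift(other) - self

    def __mul__(self, other):
        o = Dual._lift(other)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__

    def __truediv__(self, other):
        o = Dual._lift(other)
        return Dual(self.val / o.val, (self.der * o.val - self.val * o.der) / o.val ** 2)

    def __rtruediv__(self, other):
        return Dual._lift(other) / self


def dual_log(a):
    """log that also works on Dual inputs (chain rule: d log a = da / a)."""
    if isinstance(a, Dual):
        return Dual(math.log(a.val), a.der / a.val)
    return math.log(a)


def normal_logpdf(x, mu, sigma):
    """Normal(mu, sigma) log density, generic over floats and Duals."""
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - dual_log(sigma) - 0.5 * z * z
```

Seeding one argument with derivative 1.0, e.g. `normal_logpdf(1.0, 0.0, Dual(2.0, 1.0)).der`, gives the partial derivative with respect to sigma (here -0.375).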
I also agree with @marcoct that computing all the derivatives wrt all the parameters is wasteful when we only want some. Ideally, the Julia compiler can figure out that we don't use some of those derivatives so it will not compute them, but I am not sure we can rely on the compiler to do this for us. Selectively computing only some of the derivatives is somewhat difficult though because you have to account for all the combinations of arguments that we may want to differentiate wrt. And I don't think this is a trivial optimization to add to the reverse-mode AD packages that we have.
On having a separate package for derivative functions of logpdf with respect to the arguments of a distribution and (if continuous) the random variable: I think that will work, but only when the output of logpdf is a scalar. When the output is a vector, e.g. logpdf(::MvNormal, ::Matrix), you would have to define the full Jacobian J, which is wasteful because most of the time we only need the operator x -> J x or x -> J' x. If you are going as far as defining these operators, then we might as well do it in ChainRules, since that's mostly useful in the AD context. Whether such a package would be useful compared to simply implementing the derivatives in ChainRules is a separate issue, though. Perhaps adding some Distributions-specific utility functions to ChainRules is all that's needed to get the API you desire from that other package.
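To illustrate why the operator form is preferable, here is a Python sketch using a simplified stand-in for the MvNormal case: independent Normal(mu, sigma) log densities evaluated at n samples, giving an n-vector output. The vector-Jacobian product v' J is computed in O(n) without materializing J. The dense Jacobian is built only to check the result. All function names are hypothetical.

```python
import math


def normal_logpdf_vec(xs, mu, sigma):
    """Normal(mu, sigma) log density at each sample: one output per input."""
    c = -0.5 * math.log(2 * math.pi) - math.log(sigma)
    return [c - 0.5 * ((x - mu) / sigma) ** 2 for x in xs]


def normal_logpdf_vec_vjp(xs, mu, sigma, v):
    """Return v' J for J = d(outputs)/d(xs, mu, sigma), without building J.

    The block of J w.r.t. xs is diagonal, so its contribution is elementwise;
    the mu and sigma columns are reduced with v-weighted sums. Cost: O(n).
    """
    zs = [(x - mu) / sigma for x in xs]
    xbar = [-vi * z / sigma for vi, z in zip(v, zs)]
    mubar = sum(vi * z / sigma for vi, z in zip(v, zs))
    sigmabar = sum(vi * (z * z - 1.0) / sigma for vi, z in zip(v, zs))
    return xbar, mubar, sigmabar


def normal_logpdf_vec_jacobian(xs, mu, sigma):
    """Dense n x (n + 2) Jacobian, for comparison only: O(n^2) memory."""
    n = len(xs)
    J = [[0.0] * (n + 2) for _ in range(n)]
    for i, x in enumerate(xs):
        z = (x - mu) / sigma
        J[i][i] = -z / sigma                   # d y_i / d x_i (diagonal)
        J[i][n] = z / sigma                    # d y_i / d mu
        J[i][n + 1] = (z * z - 1.0) / sigma    # d y_i / d sigma
    return J
```

In reverse mode only `normal_logpdf_vec_vjp` is ever needed. The dense version exists here only to show what it would cost to hand the full Jacobian around.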
As a first step, it might be helpful to make Tracker-support optional (via Requires.jl). Tracker has a lot of transitive dependencies (e.g. via NNlib), and will likely become less and less relevant in the future (in favor of Zygote).
This is part of https://github.com/TuringLang/DistributionsAD.jl/pull/95.
Oh, sorry, didn't see that one - thanks, @devmotion !
Are there already plans to support ChainRulesCore.jl in the future? From what I understand, that would potentially cover both Zygote and the upcoming ForwardDiff2 with a single dependency.
Yes, this is the plan (and #95 already makes use of ChainRules).
Awesome, thanks!
This package is great work and the Gen team would love to collaborate and contribute to it. One concern I have is that DistributionsAD brings in a list of dependencies beyond those of Zygote, Tracker, ForwardDiff, and ReverseDiff, including some potentially unnecessary packages and binaries. For example, these are the added dependencies introduced if I add DistributionsAD to the Gen project:
[861a8166] + Combinatorics v0.7.0
[ced4e74d] + DistributionsAD v0.1.4
[e2ba6199] + ExprTools v0.1.1
[8f5d6c58] + EzXML v1.1.0
[d8418881] + Intervals v1.1.0
[94ce4f54] + Libiconv_jll v1.16.0+2
[78c3b35d] + Mocking v0.7.1
[f27b6e38] + Polynomials v0.8.0
[3cdcf5f2] + RecipesBase v1.0.1
[f269a46b] + TimeZones v1.1.1
[02c8fc9c] + XML2_jll v2.9.9+4
[83775a58] + Zlib_jll v1.2.11+9
This chain of dependencies seems to be responsible for many of these:
Combinatorics => Polynomials => Intervals => TimeZones => EzXML
It would be possible to break the chain and remove the dependence on Intervals in DistributionsAD itself, in Combinatorics, or in Polynomials. In any of these cases, if only a single function is used from the package one level down the chain, it might make sense to pay the cost of re-implementing it instead of adding these dependencies.