TuringLang / DistributionsAD.jl

Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
MIT License

Definition of pullback for `logpdf` is overly optimistic #121

Open willtebbutt opened 4 years ago

willtebbutt commented 4 years ago

This definition is overly optimistic about which inputs it can handle.

In particular, it hijacks control away from this method in Stheno and causes AD to do something entirely inappropriate: if this rule didn't exist, my code would work just fine. It causes problems similar to type piracy; see this well-known ChainRules issue, which explains the core of the problem.

TL;DR: defining rules for abstract types causes problems. Since we currently need to be able to work with abstract types, you have to be really careful about which abstract types you implement rules for.
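A toy model of the problem in plain Julia. Nothing here is the real ChainRules or Zygote API: `toy_rrule`, `differentiate`, `AbstractDist`, `MyGP` and `loglik` are hypothetical stand-ins. It only shows how a rule added later on an abstract type silently takes over a call that previously fell back to automatic derivation:

```julia
abstract type AbstractDist end

# Fallback "rule": `nothing` means no hand-written rule, so the AD derives one.
toy_rrule(f, args...) = nothing

# Hypothetical rule lookup performed by the AD.
function differentiate(f, args...)
    rule = toy_rrule(f, args...)
    return rule === nothing ? :derived_by_ad : rule
end

loglik(d::AbstractDist, x) = 0.0  # placeholder forward function

# A downstream package defines its own subtype and relies on the AD fallback.
struct MyGP <: AbstractDist end
before = differentiate(loglik, MyGP(), [1.0, 2.0])

# Later, another package adds a catch-all rule on the abstract type...
toy_rrule(::typeof(loglik), d::AbstractDist, x) = :generic_rule

# ...and the very same call is now routed to that rule instead.
after = differentiate(loglik, MyGP(), [1.0, 2.0])

println((before, after))  # (:derived_by_ad, :generic_rule)
```

The downstream code did not change; only loading the extra method did, which is why this behaves like type piracy.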

@mohamed82008 any thoughts on how this implementation could be made less aggressive? It's currently blocking Stheno-Turing integration, and is related to this issue.

devmotion commented 4 years ago

One could keep an explicit list of supported and tested distributions in DistributionsAD, and only define it for those.

However, I think a proper fix would be to change the implementation of `logpdf` in Distributions such that the defaults do not use the in-place method `logpdf!` but just `map` the out-of-place `logpdf` for single samples (that also seems much more robust in general). Then the Zygote definitions here could be removed, it seems.
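A minimal sketch of the proposed default, assuming a hypothetical `ToyNormal` type whose `toy_logpdf`/`toy_logpdf!` only play the roles of Distributions' `logpdf`/`logpdf!` (the real implementations differ). The in-place version fills a preallocated buffer, which Zygote cannot differentiate through because it does not support array mutation; the `map` version only calls the single-sample method:

```julia
struct ToyNormal
    μ::Float64
    σ::Float64
end

# Out-of-place log-density of a single sample.
toy_logpdf(d::ToyNormal, x::Real) = -((x - d.μ) / d.σ)^2 / 2 - log(d.σ) - log(2π) / 2

# Current-style default: allocate a buffer, then mutate it element by element.
function toy_logpdf!(out::AbstractVector, d::ToyNormal, xs::AbstractVector)
    for i in eachindex(out, xs)
        out[i] = toy_logpdf(d, xs[i])
    end
    return out
end
toy_logpdf_inplace(d::ToyNormal, xs::AbstractVector) = toy_logpdf!(similar(xs, Float64), d, xs)

# Proposed default: map the single-sample method over the samples; no mutation.
toy_logpdf_map(d::ToyNormal, xs::AbstractVector) = map(x -> toy_logpdf(d, x), xs)
```

Both return the same values; only the `map` version stays inside the subset of Julia that Zygote can handle without a custom rule.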

In general, I try not to look too carefully at the implementations in DistributionsAD: there's type piracy all over the place, and as this example shows, it can lead to all kinds of problems...

willtebbutt commented 4 years ago

> However, I think a proper fix would be to change the implementation of `logpdf` in Distributions such that the defaults do not use the in-place method `logpdf!` but just `map` the out-of-place `logpdf` for single samples (that also seems much more robust in general). Then the Zygote definitions here could be removed, it seems.

This seems reasonable to me. Is this something that the Distributions folks might be receptive to, do you think?

devmotion commented 4 years ago

> Is this something that the Distributions folks might be receptive to, do you think?

I don't know, but I assume they might be fine with it. I mean, it seems reasonable to me even without Zygote 🤷

mohamed82008 commented 4 years ago

> @mohamed82008 any thoughts on how this implementation could be made less aggressive?

This is a hard one. Without making Distributions.jl compatible with Zygote, we need a catch-all adjoint here to serve as the adjoint for Distributions.jl's catch-all method. In your case, a workaround would be to define an adjoint for your method that calls `pullback` on a differently named function that has no adjoint.
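A toy sketch of that workaround in plain Julia; every name is hypothetical and `toy_rrule`/`differentiate` only mimic an AD's rule lookup. The user's method forwards to an implementation function under a different name, and a more specific "bridge" rule for the user's type sends the AD to that function, where no rule exists and the AD falls back to deriving the gradient itself:

```julia
abstract type AbstractDist end
struct MyGP <: AbstractDist end

# Fallback: no rule, so the AD derives the gradient itself.
toy_rrule(f, args...) = nothing
function differentiate(f, args...)
    rule = toy_rrule(f, args...)
    return rule === nothing ? (:derived_by_ad, f) : rule
end

# Another package's catch-all forward method and catch-all rule.
loglik(d::AbstractDist, x) = 0.0
toy_rrule(::typeof(loglik), d::AbstractDist, x) = (:generic_rule, loglik)

# The user's real implementation lives under a name that has no rules...
_mygp_loglik(d::MyGP, x) = -sum(abs2, x) / 2
loglik(d::MyGP, x) = _mygp_loglik(d, x)

# ...and a more specific bridge rule redirects the AD to it.
toy_rrule(::typeof(loglik), d::MyGP, x) = differentiate(_mygp_loglik, d, x)

result = differentiate(loglik, MyGP(), [1.0, 2.0])
```

Here `result` is `(:derived_by_ad, _mygp_loglik)`: the catch-all rule is bypassed for `MyGP` without touching the other package.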

Thinking about the bigger problem linked in that issue (I didn't read the whole issue, so I'm not sure if this has been discussed), I think we can essentially formalise the workaround used here by adding an additional dispatch layer that lets you modify the "method-rrule matching rule". IMO, every method should have its own adjoint. If a more specific method was important enough to have in the forward pass, then it makes sense that we may need to special-case the reverse pass. But sometimes we may not need that, because a sufficiently generic reverse pass can be the adjoint of many forward methods. IMO this problem can be mostly solved by giving more control to developers and perhaps changing defaults. So when I define a new Julia function, I could tell ChainRules: please don't match my function using Julia's multiple-dispatch criteria, but treat it as its own thing. Under the hood, this can literally be implemented using the workaround proposed here, i.e. by defining a "bridge rule" that calls another function with no adjoint methods.

We can also have an option at the rule definition site telling ChainRules not to match the rule to any forward method whose signature is more specific, and instead to apply this adjoint only to the "specific signatures" provided. This double-headed approach can let us mix and match between "method-rrule matching rules", sometimes using normal multiple dispatch where it seems not to break things and other times matching the exact signature. But more importantly, the approach proposed here lets the user opt out of the rules defined in ChainRules at the forward method definition site.

mohamed82008 commented 4 years ago

So when I define a new function, if I suspect ChainRules is hurting my performance or correctness, I can either define a correct and performant rule for my method or opt out of ChainRules for this method. Defining a new type is more complicated, because we automatically "sign in" to a few methods in the forward pass. So perhaps we can also provide an opt-out mechanism based on types, not just methods.

willtebbutt commented 4 years ago

Thanks for your thoughts @mohamed82008 -- I think we're on the same page in regards to the problem.

> I can either define a correct and performant rule for my method or opt out of ChainRules for this method.

Could you elaborate with some pseudo / example code or something? I'm struggling to understand what you're proposing, but would be keen to understand better.

I would generally be much more in favour of an opt-in mechanism. My reasoning for that is that we should view an inability of AD to automatically derive a rule as the norm, rather than the exception.

mohamed82008 commented 4 years ago

> Could you elaborate with some pseudo / example code or something? I'm struggling to understand what you're proposing, but would be keen to understand better.

Defining a correct or performant rule is easy: just use ChainRules. Opting out can be done by overloading `Zygote.has_chain_rrule` (https://github.com/FluxML/Zygote.jl/blob/2fc416464ca4910d19618f589b0c93f595b16afb/src/compiler/chainrules.jl#L12), which I think should live in ChainRules anyway.

mohamed82008 commented 4 years ago

An opt-in mechanism (opt-out by default) would be hard to implement, though. This is because when we check for a rule, we check the concrete types to see if there is a method of `rrule` that can take these types as arguments. An opt-out-by-default approach means that any `rrule` defined on abstract types will never get a match, unless there is a rule for the particular concrete types that "forwards" to the abstract `rrule` definition. This doesn't feel Julian at all, in the sense that we wouldn't be taking advantage of the type system or multiple dispatch to simplify code. For example, we would need rules for `Float64`, `Float32`, `Float16`, `DoubleFloat`, etc. I don't think that's feasible. A default opt-in approach, with occasional opting out here or there for special types, makes more sense to me from an implementation point of view.
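A rough illustration of that trade-off: a hypothetical exact-type registry `EXACT_RULES` models matching on concrete types only, while `dispatch_lookup` uses ordinary multiple dispatch on an abstract type. Neither is how ChainRules actually stores rules:

```julia
# Model 1 (hypothetical): rules keyed on the exact concrete argument type,
# so every float type needs its own explicit entry.
const EXACT_RULES = Dict{Type,Symbol}(Float64 => :float_rule)
exact_lookup(x) = get(EXACT_RULES, typeof(x), nothing)

# Model 2: ordinary multiple dispatch; one method on an abstract type
# covers Float64, Float32, Float16, BigFloat, ...
dispatch_lookup(x::AbstractFloat) = :float_rule
dispatch_lookup(x) = nothing

println(exact_lookup(1.0f0))     # nothing
println(dispatch_lookup(1.0f0))  # float_rule
```

The exact-type model is safe but needs an entry per concrete type; the dispatch model is concise but also matches types nobody tested, which is the tension discussed in this thread.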

mohamed82008 commented 4 years ago

I think what you are really advocating for here is rule definition for "narrow" abstract types, e.g. `AbstractFloat` instead of `Real`, or `DenseVector` instead of `AbstractVector`. Trait-based rule matching would be useful here as well, so that rules can make certain "assumptions" about their inputs, e.g. that they are dense, sparse, O(1)-sized, etc. The rule lookup then needs to find a rule whose assumptions hold. I think trait-based matching makes more sense than type-based matching or defining rules on concrete types only. Proper language-level support for traits may help here. See the discussion in https://github.com/JuliaLang/julia/issues/37790.
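A minimal sketch of trait-based matching using the common "Holy traits" dispatch pattern. `StorageTrait`, `IsDense`, `IsOther`, `storage` and `toy_rule` are hypothetical names, not part of ChainRules:

```julia
abstract type StorageTrait end
struct IsDense <: StorageTrait end
struct IsOther <: StorageTrait end

# Trait assignment: plain Arrays are dense; anything else makes no promise.
storage(::Type{<:Array}) = IsDense()
storage(::Type) = IsOther()

# Entry point: compute the trait, then dispatch on it instead of on the type.
toy_rule(f, x) = toy_rule(storage(typeof(x)), f, x)
toy_rule(::IsDense, f, x) = :dense_rule  # this rule may assume dense storage
toy_rule(::IsOther, f, x) = nothing      # no assumption holds: no rule

println(toy_rule(sum, [1.0, 2.0]))  # dense_rule
println(toy_rule(sum, 1:3))         # nothing
```

A package owning some other dense array type could opt in by adding one `storage` method for it, without touching or widening the rules themselves.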

willtebbutt commented 4 years ago

> Defining a correct or performant rule is easy: just use ChainRules. Opting out can be done by overloading `Zygote.has_chain_rrule` (https://github.com/FluxML/Zygote.jl/blob/2fc416464ca4910d19618f589b0c93f595b16afb/src/compiler/chainrules.jl#L12), which I think should live in ChainRules anyway.

Hmm, yes. This could also be done by writing an `rrule` that returns `nothing`.
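A toy version of that opt-out, mirroring ChainRulesCore's convention that the fallback `rrule` returns `nothing` (meaning "no rule"). `toy_rrule`, `differentiate`, `loglik` and the types are hypothetical stand-ins:

```julia
abstract type AbstractDist end
struct MyGP <: AbstractDist end
struct OtherDist <: AbstractDist end

loglik(d::AbstractDist, x) = 0.0  # placeholder forward function

toy_rrule(f, args...) = nothing                                 # fallback: no rule
toy_rrule(::typeof(loglik), d::AbstractDist, x) = :generic_rule  # someone else's catch-all
toy_rrule(::typeof(loglik), d::MyGP, x) = nothing               # opt out for MyGP only

# `something` returns its first argument that is not `nothing`.
differentiate(f, args...) = something(toy_rrule(f, args...), :derived_by_ad)

println(differentiate(loglik, MyGP(), 1.0))       # derived_by_ad
println(differentiate(loglik, OtherDist(), 1.0))  # generic_rule
```

The more specific method returning `nothing` wins dispatch for `MyGP`, so only that type falls back to automatic derivation.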

I mean, the best way to implement an opt-in mechanism is to just not define `rrule`s for abstract types, and whenever you find that AD doesn't work for a particular type, define an `rrule` that calls the default function. For example, there are finitely many `Distribution` subtypes defined in Distributions.jl. For each of them you could define (maybe via metaprogramming) a trivial `rrule` that calls some `rrule`-like function (maybe a function called `loglikelihood_rrule_helper` or something).
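A sketch of that metaprogrammed opt-in list, assuming hypothetical types in place of the real `Distribution` subtypes and a toy `toy_rrule` in place of `ChainRulesCore.rrule`; `loglikelihood_rrule_helper` is just the placeholder name from the comment:

```julia
abstract type AbstractDist end
struct ToyNormal <: AbstractDist end
struct ToyGamma <: AbstractDist end
struct ThirdPartyDist <: AbstractDist end  # e.g. defined in someone else's package

toy_loglik(d::AbstractDist, x) = 0.0  # placeholder forward function
toy_rrule(f, args...) = nothing        # fallback: no rule, AD derives one

# Shared implementation that every generated rule forwards to.
loglikelihood_rrule_helper(d, x) = (:helper_rule, typeof(d))

# One trivial rule per explicitly supported concrete type; no rule is
# ever defined on AbstractDist itself.
for T in (:ToyNormal, :ToyGamma)
    @eval toy_rrule(::typeof(toy_loglik), d::$T, x) = loglikelihood_rrule_helper(d, x)
end

println(toy_rrule(toy_loglik, ThirdPartyDist(), 1.0))  # nothing
```

Types outside the list, such as `ThirdPartyDist`, are left to the AD's default behaviour instead of being captured by a catch-all rule.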

More generally, the symptoms of what DistributionsAD is doing look very much like the symptoms of type piracy. Suppose I write some code in my package that works just fine with AD when I don't load Turing. Then I load Turing, and it breaks. This is exactly the kind of thing we try to prevent by avoiding type piracy.

One way to see this as type piracy is to consider that when I wrote my code, I also implicitly "wrote" a method of `pullback` (or whatever function `@adjoint` generates) by not defining one, i.e. I explicitly intended the method that Zygote derives automatically to be the one that is used. DistributionsAD then defines a more specific method of `pullback` that overrides the default behaviour.

I will grant you that you could construe this as a problem with the way Zygote works, but it feels like the kind of thing we should really be solving at the ChainRules level.

mohamed82008 commented 4 years ago

The DistributionsAD approach is breaking all the Julia rules and it needs to go. But this package was born out of the need to "fix" differentiating most of the distributions with all the AD packages. This meant different workarounds for different packages. Some of those "workarounds" made it back into ReverseDiff or were ported to ChainRules, while others remained. In a way, defining `rrule`s on any type we don't own is type piracy, but doing so on abstract types is especially bad for the reason you outline. So, in summary, I am in favour of removing the method in question here if removing it doesn't break anything, or if you have a better implementation.