FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Use ChainRules types #603

Open oxinabox opened 4 years ago

oxinabox commented 4 years ago

This issue is about switching Zygote over to use ChainRules' types by default. When #366 is merged, rules coming from ChainRules will use its types, such as Composite and AbstractZero, but things created via source code transformation (SCT) will still use NamedTuple and nothing.

This is fine, since the two are mutually compatible: accum falls back to +, and ChainRules overloads + for its types.
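A minimal sketch of that compatibility, assuming a simplified accumulator modeled on Zygote's accum fallback (skip nothing, otherwise call +). HypoAbstractZero, HypoZero and hypo_accum are hypothetical stand-ins, not the real Zygote or ChainRules definitions:

```julia
# Hypothetical stand-ins for a ChainRules-style zero type (not the real API).
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end

# The package owns HypoAbstractZero, so extending Base.:+ here is not piracy.
Base.:+(x, ::HypoAbstractZero) = x
Base.:+(::HypoAbstractZero, x) = x
# Resolves the ambiguity between the two methods above when both sides are zeros.
Base.:+(z::HypoAbstractZero, ::HypoAbstractZero) = z

# Simplified accumulation (hypothetical): `nothing` is skipped, anything else uses +.
hypo_accum(x, y) = x === nothing ? y : y === nothing ? x : x + y

hypo_accum(nothing, 2.0)      # SCT-style `nothing` is skipped
hypo_accum(HypoZero(), 2.0)   # rule-style zero is absorbed via +
```

Both the SCT-style nothing and the rule-style zero are absorbed, which is why the two representations can coexist during the transition.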

The use of Base types in Zygote tends to cause issues, as it's hard to add methods for NamedTuple and nothing without committing type piracy. This applies especially to defining linear operations on them (e.g. overloading functions from LinearAlgebra), as well as to defining addition (+). Discussion specific to Composite and structured differential types is in #462.
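To make the piracy point concrete: Base defines no + for nothing, and a package adding Base.:+(::Nothing, ::Nothing) would own neither the function nor the type. A package-owned type (here the hypothetical HypoZero) can take such methods, including LinearAlgebra ones, without piracy. This is an illustrative sketch, not ChainRules' actual method set:

```julia
using LinearAlgebra

# Base defines no `+` for `nothing`; a package adding Base.:+(::Nothing, ::Nothing)
# would own neither the function nor the type, i.e. it would be type piracy.
nothing_supports_plus = try
    nothing + nothing
    true
catch err
    err isa MethodError ? false : rethrow()
end

# A package-owned zero type (hypothetical, for illustration only).
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end

# Extending Base and LinearAlgebra functions is fine because the package owns
# HypoAbstractZero; the zero simply propagates through these linear operations.
Base.:+(z::HypoAbstractZero, ::HypoAbstractZero) = z
LinearAlgebra.dot(z::HypoAbstractZero, ::AbstractVector) = z
```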

Related to:

#454 #419 #329

oxinabox commented 4 years ago

A discussion @MikeInnes and I had today:

Lyndon White:ox: I am now going through and fixing broken tests, or rather fixing issues that the tests reveal. One of the tests fails unless I define Base.:*(::Int, ::Nothing) = nothing, which presumably means that somewhere in Zygote (or perhaps inside a ZygoteRule?) there was something removing that nothing before it could hit that multiplication. If I just went fully over to ChainRules types, this would not be an issue, since AbstractZero does define multiplication. Does this mean there is some code somewhere that could be deleted?

Mike Innes By default, @adjoint rules ignore nothing entirely. So I guess it’s just that we’re passing nothing to an rrule, which tries to use it like a numeric gradient

Lyndon White:ox: yep

Mike Innes We can just add that dispatch to the CR rule wrapper function though, right?

Lyndon White:ox: Yeah, we can. What do you do when one input is nothing and the other is non-nothing? And so on and so forth for nested tuples of inputs?

Mike Innes Since it’s an rrule there’s only one input

Lyndon White:ox: But that input could be a structure (including a tuple)

Mike Innes The tuple/struct functions can just carry on using the Zygote definitions for now
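Mike's suggestion of handling this in the rule wrapper might look roughly like the following sketch, assuming a hypothetical hypo_wrap_input that maps nothing to a zero type before the cotangent reaches a ChainRules-style pullback, recursing into tuples for structured inputs. None of these names are Zygote's actual glue code:

```julia
# Hypothetical zero differential that absorbs multiplication by numbers.
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end
Base.:*(z::HypoZero, ::Number) = z
Base.:*(::Number, z::HypoZero) = z

# Hypothetical wrapper: convert Zygote's `nothing` to the zero type before
# calling a rule, recursing into tuples so nested inputs are handled too.
hypo_wrap_input(::Nothing) = HypoZero()
hypo_wrap_input(t::Tuple) = map(hypo_wrap_input, t)
hypo_wrap_input(x) = x

# A ChainRules-style pullback for `a * b` that treats its cotangent as numeric.
mul_pullback(a, b) = Δ -> (Δ * b, a * Δ)

pb = mul_pullback(2, 3)
pb(hypo_wrap_input(nothing))   # zeros propagate instead of erroring
```

Calling the same pullback with a raw nothing instead throws a MethodError, the same kind of failure the Base.:*(::Int, ::Nothing) workaround was covering.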

Lyndon White:ox: indeed, I am just curious

Mike Innes If you have an adjoint that takes a struct, all you usually do with the struct elements is forward them on to other gradient functions

Lyndon White:ox: That seems like it could mean missing out on some CSE (common subexpression elimination).

Mike Innes I can imagine that in some cases you’d need to explicitly cast to a reasonable zero type, but that doesn’t seem to have come up too much in practice

Lyndon White:ox: Yeah, I have only seen a few structure pullbacks, so my instincts are not great. I think Will probably has a ton of them in his GP-via-Kalman-filter package, though.

Mike Innes It would be useful to know if he runs into issues. I imagine we could implement a generic “cast zero” that supports AbstractZero and nothing
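A generic "cast zero" of the kind Mike imagines could be as small as this sketch; cast_zero, HypoAbstractZero and HypoZero are hypothetical names, and materializing via zero(primal) assumes the primal's type supports zero:

```julia
# Hypothetical ChainRules-style zero type (not the real API).
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end

# Hypothetical "cast zero": materialize a concrete zero shaped like the primal
# when the differential is `nothing` or a zero type; otherwise pass it through.
cast_zero(primal, ::Nothing) = zero(primal)
cast_zero(primal, ::HypoAbstractZero) = zero(primal)
cast_zero(primal, dx) = dx

cast_zero([1.0, 2.0], HypoZero())   # a dense Vector{Float64} of zeros
```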

Lyndon White:ox: I think we do just want to get rid of nothing; there are just too many random issues that show up from nothing not supporting +.

Mike Innes Perhaps; my impression has been that most of those kinds of issues have highlighted things you want to fix anyway

Lyndon White:ox: That is true, at least to an extent.

Mike Innes Julia gives you a lot of ease-of-use/performance tradeoffs in cases like this, what with generic code often being slow

Lyndon White:ox: I don’t think this is one of those cases. Also, there are a bunch of things that are simplified if you can define methods on them (beyond + and *). Apparently adding Base.:/(z::AbstractZero, ::Any) = z massively simplified one of the most complicated rules, relating, I think, to the derivative of a structure (probably SVD or Cholesky).
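The thread does not show the SVD/Cholesky rule itself, so as a hypothetical stand-in, here is how a zero type that absorbs /, * and unary - lets a pullback (for scalar x / y) be written once, with no branches for zero cotangents. HypoZero and div_pullback are illustrative names:

```julia
# Hypothetical zero type mirroring the quoted Base.:/(z::AbstractZero, ::Any) = z.
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end
Base.:/(z::HypoAbstractZero, ::Any) = z
Base.:*(z::HypoZero, ::Number) = z
Base.:-(z::HypoZero) = z

# Pullback of f(x, y) = x / y, written once with no `nothing`/zero special cases:
# ∂f/∂x = 1/y and ∂f/∂y = -x/y^2, each scaled by the cotangent Δ.
div_pullback(x, y) = Δ -> (Δ / y, -Δ * x / y^2)

pb = div_pullback(6.0, 2.0)
pb(HypoZero())   # the zero flows through unchanged
```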

Mike Innes That’s interesting

Lyndon White:ox: I would like to create a package that adds traits to all linear operators, so we could easily make them propagate zeros like that (idea not fully formed).

Mike Innes I’m thinking of cases like getindex, where AbstractZero would propagate ‘for free’, but actually having them in your matrices is terrible for performance compared to casting to Float64, since then you lose BLAS for the rest of the computation.

Lyndon White:ox: Oh yeah that is true

Mike Innes Of course we can just remove them, but then it’s equivalent to nothing while being easier to get wrong. That’s just one case, of course.

Lyndon White:ox: In a matrix, if you are in that case, you should probably be using a sparse type. You can’t use nothing there either, for the same reason. I conceptualize AbstractZero as equivalent to the type of a structural zero in a sparse matrix: it acts the same, including in the presence of NaN. I don’t see how it is easier to get wrong in that case?
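The structural-zero analogy can be sketched with a hypothetical strong-zero type: a stored 0.0 multiplied by NaN becomes NaN under IEEE arithmetic, while the zero type stays a zero. HypoZero is an illustrative name, and this models only the semantics described above, not ChainRules' implementation:

```julia
# Hypothetical strong-zero type: multiplication by any number leaves it a zero.
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end
Base.:*(z::HypoZero, ::Number) = z
Base.:*(::Number, z::HypoZero) = z

# A stored numeric zero is contaminated by NaN, as IEEE arithmetic requires...
numeric = 0.0 * NaN
# ...whereas the zero type behaves like a structural zero and stays zero.
structural = HypoZero() * NaN
```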

Mike Innes E.g. getindex has given us a lot of pain because nothing keeps making it throw errors. But it’s actually good that we cast them away where possible. If we had AbstractZero, those cases would have worked fine but been slow. So I think in that case the errors were a good motivator to fix the problem (not that you couldn’t write the same code over AbstractZero).

Lyndon White:ox: That matrix case might actually be a reason to use AbstractZero, since we can define convert(T, ::AbstractZero) = zero(T), and then setindex! will do the conversion for us.
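A sketch of the convert idea, assuming a hypothetical HypoZero type: because assigning into an Array{Float64} converts the value to Float64, one convert method is enough to keep gradient storage dense and BLAS-friendly, which addresses the getindex performance concern:

```julia
# Hypothetical zero type; the convert method mirrors convert(T, ::AbstractZero) = zero(T).
abstract type HypoAbstractZero end
struct HypoZero <: HypoAbstractZero end
Base.convert(::Type{T}, ::HypoAbstractZero) where {T<:Number} = zero(T)

grad = ones(3)         # dense Vector{Float64}
grad[2] = HypoZero()   # setindex! calls convert(Float64, HypoZero()), storing 0.0
```

The array's element type never changes, so downstream linear algebra still sees a plain Vector{Float64}.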

Mike Innes That is true

Mike Innes My other concern is about the type of the expression gradient(f, x). With x a struct, it seems wrong that this is Union{Number,Struct}, because you can write code that treats this as a number, and it’ll break when x actually has a gradient.

Lyndon White:ox: AbstractZero does not subtype Number

Mike Innes Really, I think the answer here is that the type is Differential, where Zero is a differential, but in that case the name is a bit misleading. Not that I have a better idea. This may or may not really be worth worrying about, but I think you appreciate these kinds of issues.

Thinking about AbstractZero as ‘the identity differential’ rather than zero(x) for all x does make me feel a bit warmer to it though

DhairyaLGandhi commented 4 years ago

FWIW, I think the concept of an AbstractGradient type, with ZeroGrad as an identity, sounds reasonable. I feel this was the implicit definition behind having nothing as a valid gradient.

oxinabox commented 4 years ago

One important thing to do as part of this PR is to make sure there is a clean deprecation path, so we don't break all existing custom rules that rely on NamedTuple and nothing. Part of this might be adding functionality to ZygoteRules to do the conversion.

oxinabox commented 4 years ago

Advantages of this:

nickrobinson251 commented 4 years ago

Another issue we should make sure is resolved by/when changing to ChainRules types: https://github.com/FluxML/Zygote.jl/issues/802