JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

happens-to-be-zero zero for tangent space of primal #476

Open willtebbutt opened 3 years ago

willtebbutt commented 3 years ago

Motivation

Sometimes it would be very convenient, for type-stability reasons, to have access to the element of the tangent space of a particular primal which happens to be zero, e.g. 0.0 and zeros(5, 4) rather than ZeroTangent.

The example I recently encountered where this would be helpful was reduce, specifically this line in Zygote's implementation of map. It's currently not possible to make this bit type-stable unless you know a priori whether or not the container you're mapping over is empty (a Tuple's length is part of its type, a Vector's is not). So the pullback for map(function_with_fields, (5.0, 4.0)) infers, but the pullback for map(function_with_fields, [5.0, 4.0]) does not. This is because the init kwarg is generally of a different type from the elements of Δf_and_args[1].

However, if we had access to a zero whose type doesn't change when we add cotangents to it, things ought to be type-stable.
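A minimal Base-only illustration of the problem (not Zygote's actual map adjoint, which accumulates Δf_and_args[1]): when init has a different type from the elements being summed, the result type depends on whether the container is empty. The names unstable_sum and stable_sum are hypothetical.

```julia
# `init` is an Int while the elements are Float64, so the result is an Int
# for an empty input and a Float64 otherwise.
unstable_sum(ts::Vector{Float64}) = reduce(+, ts; init = 0)

# `init` already has the element type, so the result is always a Float64.
stable_sum(ts::Vector{Float64}) = reduce(+, ts; init = 0.0)

unstable_sum(Float64[])     # 0   (an Int)
unstable_sum([1.0, 2.0])    # 3.0 (a Float64)
stable_sum(Float64[])       # 0.0

# Inference only sees the Int/Float64 split in the first case.
Base.return_types(unstable_sum, (Vector{Float64},))  # not a concrete type
Base.return_types(stable_sum, (Vector{Float64},))    # [Float64]
```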

Implementation

We know that the zero always exists, because the tangent space to a primal is a vector space, so there aren't any concerns regarding existence.

It's pretty clear that the right way to do this for composite types is via recursion (think rand_tangent, but zero rather than random), so we would just need to define it for primitives. It might get a little interesting here, because there are multiple possible tangent types for a Float64 primal (Float64, Float32, Float16, Int, etc.) or a Vector primal (any AbstractVector of the same length with appropriate element types), so we might need additional information (such as the target tangent type) in order to do this.
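A rough sketch of the recursive approach for composite types, assuming plain NamedTuple tangents for structs (ChainRulesCore itself would use Tangent{T}) and Base.zero for numbers and numeric arrays. The name zero_tangent_sketch is hypothetical, and the sketch sidesteps the question of which tangent type to pick for primitives by always mirroring the primal's own type.

```julia
# Hypothetical recursive "zero tangent": like rand_tangent, but zero instead of random.
# Leaves: numbers and numeric arrays use Base.zero, which keeps the primal's type.
zero_tangent_sketch(x::Number) = zero(x)
zero_tangent_sketch(x::AbstractArray{<:Number}) = zero(x)

# Containers of non-numbers: recurse elementwise.
zero_tangent_sketch(x::AbstractArray) = map(zero_tangent_sketch, x)
zero_tangent_sketch(x::Tuple) = map(zero_tangent_sketch, x)
zero_tangent_sketch(x::NamedTuple) = map(zero_tangent_sketch, x)

# Structs: recurse into each field and collect the results in a NamedTuple.
function zero_tangent_sketch(x::T) where {T}
    names = fieldnames(T)
    vals = map(n -> zero_tangent_sketch(getfield(x, n)), names)
    return NamedTuple{names}(vals)
end

struct Point
    x::Float64
    ys::Vector{Float64}
end

zero_tangent_sketch(Point(1.0, [2.0, 3.0]))  # (x = 0.0, ys = [0.0, 0.0])
```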

Note that we do need to know the value of the primal for the same reasons that we need to have the primals hanging around in our projection functionality.

Anyway, I thought I'd bring this up because it's not something that we've thought much about before (ZeroTangent is often a really good option). It might be easier simply to avoid situations like this most of the time (e.g. in the example I mentioned, by using sized containers), but it's pretty annoying that the pullback for map isn't type-stable when mapping a closure over a Vector, because people do that a lot.

oxinabox commented 3 years ago

Would this be solved by https://github.com/JuliaLang/julia/issues/38241 which would allow us to basically always use ZeroTangent and never have to use 0.0 for performance. I can look into getting that resolved sooner rather than later if so.

willtebbutt commented 3 years ago

I don't think so. I've been encountering this with Zygote types (see the line linked above), so I'm using nothing rather than ZeroTangent.

mcabbott commented 3 years ago

So you want a function that's a lot like dx = zero(x), but tries to guarantee that [dx1, dx2, dx3] will be of uniform type, when some are zero and some nonzero.

there are multiple possible tangent types for a Float64 primal (Float64, Float32, Float16, Int, etc)

Projection should convert all of these to Float64. But not all numbers:

julia> p = ProjectTo(1)
ProjectTo{Float64}()

julia> p(2f3)  # projected to Float64
2000.0

julia> p(Dual(1,2))  # passes through
Dual{Nothing}(1,2)

or a Vector primal (any AbstractVector of the same length with

What would be easy to do is to always use Fill(zero(T), ...) so that e.g. [dx1, dx2, dx3] isa Vector{<:AbstractVector} which will at least make some dispatch happy:

julia> @which reduce(hcat, [Fill(2,3), fill(4,3)])
reduce(::typeof(hcat), A::AbstractVector{<:AbstractVecOrMat}) in Base at abstractarray.jl:1624

but still won't be stable.

JuliaLang/julia#38241 which would allow us to basically always use ZeroTangent and never have to use 0.0 for performance.

Even with this, you'd still miss BLAS, right?

willtebbutt commented 3 years ago

So you want a function that's a lot like dx = zero(x), but tries to guarantee that [dx1, dx2, dx3] will be of uniform type, when some are zero and some nonzero.

Nearly. To be really specific, start with two things: a primal x, and a vector ts::Vector{T} of things in the tangent space of x. I want to sum ts. In general, to implement the summation you're going to need to write something like

reduce(+, ts; init=a_zero_tangent)

where a_zero_tangent is some representation of the zero element of the tangent space of x. What I want is a function zero_element(x, T) which picks the representation of a_zero_tangent such that the concrete type of the output of the sum is inferrable.

So if x is a Float64 and T == Float64 then I would want it to return 0.0.

Similarly, if x is a Vector{Float64} and T == Vector{Float64}, then I would want zeros(size(x)).
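The two cases above could be sketched as follows, using only Base. zero_element and sum_tangents are hypothetical names, and only Real and plain Array tangent types are covered; structured tangents such as Fill would need their own methods.

```julia
using Test  # for @inferred

# Hypothetical `zero_element(x, T)`: a zero in the tangent space of primal `x`,
# represented so that adding tangents of type `T` to it keeps the type `T`.
zero_element(::Real, ::Type{T}) where {T<:Real} = zero(T)
zero_element(x::AbstractArray, ::Type{T}) where {T<:Array} = zeros(eltype(T), size(x))

# Sum a vector of tangents of `x`; the result type no longer depends on emptiness.
sum_tangents(x, ts::Vector{T}) where {T} = reduce(+, ts; init = zero_element(x, T))

sum_tangents(1.0, [0.5, 1.5])                        # 2.0
sum_tangents(1.0, Float64[])                         # 0.0
sum_tangents([1.0, 2.0], [[1.0, 2.0], [3.0, 4.0]])   # [4.0, 6.0]
sum_tangents([1.0, 2.0], Vector{Float64}[])          # [0.0, 0.0]

@inferred sum_tangents([1.0, 2.0], Vector{Float64}[])  # inferred as Vector{Float64}
```

Taking both x and T matters here: T fixes the representation, while the primal x supplies the size that T alone cannot.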

Structured array tangents are maybe more interesting. Suppose that x is a Vector{Float64} but T == Fill{Float64}; then I suspect the optimal thing to do would be to make the zero a Fill{Float64}(0.0, size(x)) so that the sum can be performed efficiently. It would, of course, be totally valid to make the zero zeros(size(x)) again. That would be sub-optimal but would still give you type stability, although it might have worse performance (locally) than the type-unstable version.

If x is a Float64 and T == Real (perhaps because there's a mix of tangent precisions somehow), then I think you'd want to make the initialisation a Float64 and somehow assert that the result must be a Float64 or something?
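A minimal sketch of that Real case, assuming every element promotes with Float64 to Float64 (true for Float32, Float16 and Int, but not for e.g. BigFloat, where the assertion would throw). The name sum_real_tangents is hypothetical.

```julia
# Hypothetical: tangents of a Float64 primal with mixed precisions, so eltype is Real.
# Start from a Float64 zero and assert the result type; this assumes every element
# promotes with Float64 to Float64.
sum_real_tangents(ts::Vector{Real}) = reduce(+, ts; init = 0.0)::Float64

sum_real_tangents(Real[1.0, 2.0f0, 3])                 # 6.0
sum_real_tangents(Real[])                              # 0.0
Base.return_types(sum_real_tangents, (Vector{Real},))  # [Float64]
```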

Projection should convert all of these to Float64. But not all numbers:

Oh, interesting, I didn't realise we were being that strict with projection. I guess I had been thinking that an Int is a perfectly fine tangent for a Float64 (no need to project) in the same sense that a Diagonal{Float64} is a perfectly good tangent representation for a Matrix{Float64}. A discussion for a different place perhaps...

mcabbott commented 3 years ago

being that strict

There's no maths in this, but accidental Float32 -> Float64 promotion is such an easy performance bug to introduce, and we can kill it globally at last. But for Hessians etc. you may need dx::Dual, so it can't be too strict. There's an argument that integers should be non-differentiable; making gradient(sin, 1) fail seems unfriendly, but nobody could think of a real use for integer gradients (beyond saving money on chalk).

ts::Vector{T} such that the concrete type of [sum(ts)] is inferrable.

Maybe this is easier than what I had, in that you're given T. For numbers init=zero(T) is done automatically.
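A small Base-only illustration of the difference: the empty case is only handled automatically when zero can be called on the element type itself, and there is no zero for the type Vector{Float64}, so array-valued tangents need an explicit init.

```julia
# Number tangents: Base supplies the empty-sum value via zero(Float64).
sum(Float64[])                            # 0.0

# Array tangents: there is no zero(::Type{Vector{Float64}}), so the empty
# case throws unless an init value is supplied.
threw = try
    sum(Vector{Float64}[])
    false
catch
    true
end                                       # true

# With an explicit init (the `init` keyword of `sum` needs Julia 1.6 or later),
# the empty case works, but the size has to come from the primal.
sum(Vector{Float64}[]; init = zeros(3))   # [0.0, 0.0, 0.0]
```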

I'm a bit confused by that line in the map adjoint though. How do you get map over an empty array to have a nontrivial gradient? Are any such cases not errors for other reasons?

willtebbutt commented 3 years ago

Maybe this is easier than what I had, in that you're given T. For numbers init=zero(T) is done automatically.

Exactly. Having both pieces of information is really helpful.

I'm a bit confused by that line in the map adjoint though. How do you get map over an empty array to have a nontrivial gradient? Are any such cases not errors for other reasons?

Here's an example of a programme that works:

julia> function foo(x)
           return x[1] + 3 * sum(map(sin, x[2:end]))
       end
foo (generic function with 1 method)

julia> Zygote.gradient(foo, randn(1))
([1.0],)

julia> Zygote.gradient(foo, randn(2))
([1.0, 2.9999914288727743],)

oxinabox commented 3 years ago

JuliaLang/julia#38241 which would allow us to basically always use ZeroTangent and never have to use 0.0 for performance.

Even with this, you'd still miss BLAS, right?

Right. Could we use Tullio for that?