JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core to allow you to define rules for your functions in your packages, without depending on any particular AD system.

Helper for creating `InplaceableThunks` when it is just broadcast + #274

Open nickrobinson251 opened 3 years ago

nickrobinson251 commented 3 years ago

This is a follow-up to this discussion in JuliaDiff/ChainRules.jl#336.

JuliaDiff/ChainRules.jl#336 improves the array rules for sum by changing the code, e.g. in the case of sum(abs2, x), from 2 .* real.(ȳ) .* x to

InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),        # val
    dx -> dx .+= 2 .* real.(ȳ) .* x    # add!(dx)
)

This makes two improvements:

It took me a while to work out why (2) was an improvement. The docs on InplaceableThunks say

add! should be defined such that: ithunk.add!(Δ) = Δ .+= ithunk.val but it should do this more efficiently than simply doing this directly.

Looking at the code above, where val = 2 .* real.(ȳ) .* x, why is add!(dx) = dx .+= 2 .* real.(ȳ) .* x "more efficient" than add!(dx) = dx .+= val? By copying the code for val into the add! function we get a single expression, allowing the broadcast to be "fused", and thereby avoid allocating an intermediate val = 2 .* real.(ȳ) .* x array.
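
To make the fusion point concrete, here is a minimal sketch using only Base Julia. The function names add_unfused! and add_fused! are made up for illustration and are not part of ChainRulesCore. Both compute the same result, but the first builds val as a temporary array before accumulating, while the second is a single fused broadcast.

```julia
# Hypothetical helper names; only Base Julia is used here.
x = [1.0, 2.0, 3.0]
ȳ = 2.0   # cotangent of the scalar output of sum(abs2, x)

# Unfused: `val` is materialised as a temporary array, then added to dx.
function add_unfused!(dx, ȳ, x)
    val = 2 .* real.(ȳ) .* x
    dx .+= val
    return dx
end

# Fused: the whole right-hand side is one broadcast expression,
# so no intermediate array is needed.
add_fused!(dx, ȳ, x) = dx .+= 2 .* real.(ȳ) .* x

dx1 = add_unfused!(zeros(3), ȳ, x)
dx2 = add_fused!(zeros(3), ȳ, x)
```

Comparing the two with @allocated (after a warm-up call) is one way to see the difference in allocations on a larger x.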

So that's cool! (Aside: there are some good blog posts about Julia's loop fusion and broadcast magic)

But it did mean we had to copy code. This issue is to ask "can we do this without having to copy code?" i.e. it's about API / user-friendliness / reducing code / syntactic stuff (which might in turn make this performance improvement more widely used in our array rules).

I see two options, but perhaps there are others:

(A) create a macro like @inplaceable_thunk

If we did this, code such as

x_thunk = InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),
    dx -> dx .+= 2 .* real.(ȳ) .* x
)

could instead be written more succinctly as

x_thunk = @inplaceable_thunk(2 .* real.(ȳ) .* x)

(B) have @thunk always return an InplaceableThunk with the add! function defined like above (i.e. copying in the code for val)

I'm not sure if (B) is a valid option. But perhaps it is, if users are expected to go via the add!! function (which checks is_inplaceable_destination).
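
As a rough sketch of what option (A) could look like, the macro below splices the same expression into both a lazy value and a fused in-place add. To keep it self-contained it uses a stand-in struct, ToyInplaceableThunk, and a made-up macro name, @toy_inplaceable_thunk; neither is ChainRulesCore API, and a real implementation would construct an InplaceableThunk instead.

```julia
# Stand-in for InplaceableThunk (an assumption for illustration only).
struct ToyInplaceableThunk{V,A}
    val::V     # zero-argument closure that computes the value lazily
    add!::A    # function that accumulates the value into dx in place
end

# Hypothetical macro: copies `ex` into both closures so the user writes it once.
macro toy_inplaceable_thunk(ex)
    body = esc(ex)  # evaluate the expression in the caller's scope
    return :(ToyInplaceableThunk(() -> $body, dx -> dx .+= $body))
end

x = [1.0, 2.0]
ȳ = 3.0
x_thunk = @toy_inplaceable_thunk(2 .* real.(ȳ) .* x)

dx = [1.0, 1.0]
x_thunk.add!(dx)   # accumulates into dx with a single fused broadcast
```

Because the expression is spliced in before lowering, the dx .+= ... in the second closure should fuse in the same way as the hand-written version.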

oxinabox commented 3 years ago

my worry is that here we are mixing our design to fuse with broadcasting. Ideally we would get an error if you tried to do add!!([1,2], @inplaceablethunk(1)) but wouldn't this make it give the result [2, 3]?

nickrobinson251 commented 3 years ago

yes, i suppose that'd be true. And I think that rules out (B), but not (A).

add!!([1, 2], @thunk(1)) would be an error (as now), but add!!([1, 2], @inplaceablethunk(1)) would return [2, 3].
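
The difference comes from plain Julia semantics rather than anything specific to thunks, as this small sketch (Base only) shows: a broadcasted in-place add silently expands a scalar, whereas a plain + between an array and a scalar throws.

```julia
dx = [1, 2]
dx .+= 1   # the scalar is broadcast across every element of dx

# Plain (non-broadcast) addition of an array and a scalar is not defined.
threw = try
    [1, 2] + 1
    false
catch err
    err isa MethodError
end
```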

add!!([1, 2], @inplaceablethunk(1)) is not something you'd ever want to do. add!!([1, 2], @inplaceablethunk([1, 10])) is the kind of thing you might want.

is this a new kind of risk? We could already write stuff like add!!([1, 2], InplaceableThunk(@thunk(1), dx -> dx .+= 1)). But perhaps this increases the risk of accidentally mis-using InplaceableThunk too much?

Or perhaps i'm not quite understanding the worry?

oxinabox commented 3 years ago

yeah, maybe it isn't a realistic worry. Dynamic language and all, can't use types to ensure correctness etc, need to use tests. Tests will catch this

mcabbott commented 3 years ago

A variant of (B) is to make @inplaceablethunk(1) an error -- if the top-level expression isn't broadcasting then it shouldn't guess.

Another nice thing to have would be a 2-argument @inplaceablethunk(f(x,dy), dx -> g!(dx,x,dy)). Writing InplaceableThunk(@thunk(...), ...) the whole time is noisy.
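
One way such a check might be written is sketched below. The helper name is_toplevel_broadcast is hypothetical, and the test is deliberately simple: it only recognises dotted operator calls such as a .* b and dotted function calls such as f.(a), and does not try to handle forms like @. or dotted assignment.

```julia
# Hypothetical helper: does the outermost expression look like a broadcast?
function is_toplevel_broadcast(ex)
    ex isa Expr || return false
    if ex.head === :call
        # Dotted operators such as `a .* b` parse as a call to the symbol `.*`.
        op = ex.args[1]
        return op isa Symbol && startswith(string(op), ".")
    elseif ex.head === :.
        # `f.(a, b)` parses as Expr(:., :f, Expr(:tuple, :a, :b)),
        # while field access `a.b` has a QuoteNode as its second argument.
        return length(ex.args) == 2 && ex.args[2] isa Expr &&
               ex.args[2].head === :tuple
    end
    return false
end
```

A macro could call this on its argument and throw an error when it returns false, so that something like a bare literal 1 is rejected instead of being silently broadcast.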

oxinabox commented 3 years ago

This also relates to the question of projecting with inplaceable thunks. When the in-place part is nothing more than the broadcasting of the out-of-place form, it is kind of the easy case, since one can just project before adding, though that will cause an allocation and so render it pointless.

In https://github.com/JuliaDiff/ChainRulesCore.jl/pull/393#issuecomment-881003612 @mcabbott said

For many rules which end up broadcasting dx -> dx .+= ..., when projection is only about eltype it would be nice to insert it. I guess you can just write dx -> dx .+= projector.element.(...) by hand.

Which solves that allocation if it is only about the eltype.
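
A small sketch of that idea, using only Base Julia: toy_project_element below is a stand-in (an assumption, not the real projector.element) that maps each element onto the reals, which is roughly what an element-type projection onto a real array amounts to. Because it is applied element-wise inside the broadcast, it fuses with the in-place add and needs no intermediate array.

```julia
# Stand-in for an element-type projector onto real numbers (hypothetical).
toy_project_element(z) = real(z)

x = [1.0, 2.0]
g = [1.0 + 2.0im, 3.0 - 1.0im]   # a complex-valued incoming gradient

# In-place add with the projection fused into the same broadcast.
add! = dx -> dx .+= toy_project_element.(g .* x)

dx = zeros(2)   # real-valued accumulator
add!(dx)
```

Without the projection, storing the complex values into the real-valued dx would fail with an InexactError.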


I wonder if we can even escape from using InplaceableThunk for this case.

Remembering that the lowering is

julia> f(x) = x .+= 1
f (generic function with 1 method)

julia> @code_lowered f([1,2])
CodeInfo(
1 ─ %1 = Base.broadcasted(Main.:+, x, 1)
│   %2 = Base.materialize!(x, %1)
└──      return %2
)

I wonder if we could have a macro like @bthunk(x*y') that leverages that to give us just the broadcasted object (or probably something that wraps it), and that has overloads for ProjectTo, such that, for example, ProjectTo{Diagonal} would cause it not to materialize objects off the diagonal.
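
To illustrate what "just the broadcasted object" means, here is a minimal sketch using the functions that appear in the lowering above. It builds the lazy object by hand for the broadcast x .* y' rather than through any macro, and the variable names are only for illustration.

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize, materialize!

x = [1.0, 2.0]
y = [3.0, 4.0]

# A lazy, un-materialised representation of `x .* y'`; nothing is computed yet.
bc = broadcasted(*, x, y')

# Out-of-place use: compute the full array when the value is needed.
val = materialize(bc)

# In-place use: equivalent to `dx .+= x .* y'`, fused into a single loop.
dx = ones(2, 2)
materialize!(dx, broadcasted(+, dx, bc))
```

The same lazy object serves both the out-of-place and the in-place path, which is the property a wrapper type would presumably rely on.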

Possibly the whole trick here might be that, rather than customizing the data being added, we instead customize the +? We could have a projected_+ that overloads its broadcast style to do the projection in that step?

mcabbott commented 3 years ago

If there were a method add!!(dx::AbstractArray, plus::Broadcasted) then it could accept un-materialised broadcasts of just the RHS, just 2 .* real.(ȳ) .* x, and turn them into dx .+= .... Then a method add!!(dx::Diagonal, plus::Broadcasted) could, as you say, customise its materialisation to iterate only the diagonal, although making it fast might be tricky. add!!(dx::UpperTriangular, plus::Broadcasted) is much easier: it just writes to the parent. All could insert element-type projectors into the fused broadcast.
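
A rough sketch of those two methods, under the name toy_add!! so as not to suggest this is the actual add!! in ChainRulesCore. The UpperTriangular method writes the whole broadcast into the parent array and relies on the wrapper ignoring the lower triangle, which is one reading of "just writes to the parent".

```julia
using LinearAlgebra: UpperTriangular
using Base.Broadcast: Broadcasted, broadcasted

# Hypothetical generic method: fuse the lazy right-hand side into `dx .+= ...`.
toy_add!!(dx::AbstractArray, plus::Broadcasted) = dx .+= plus

# Hypothetical structured method: accumulate into the parent storage.
# Entries written below the diagonal of the parent are ignored by the wrapper.
function toy_add!!(dx::UpperTriangular, plus::Broadcasted)
    parent(dx) .+= plus
    return dx
end

x = [1.0, 2.0]
dx = ones(2)
toy_add!!(dx, broadcasted(*, 2, x))

U = UpperTriangular(ones(2, 2))
toy_add!!(U, broadcasted(*, 2, [1.0 2.0; 3.0 4.0]))
```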