JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

pullback on map(transform, colvec) breaks types #113

Closed theogf closed 4 years ago

theogf commented 4 years ago

Here is an MWE:

using KernelFunctions, Zygote
A = KernelFunctions.ColVecs(rand(10, 10))
t = ScaleTransform(rand())
typeof(map(t, A)) # ColVecs
v, g = pullback(map, t, A)
typeof(v) # Vector{Vector}

This makes AD with Zygote infeasible at the moment.

willtebbutt commented 4 years ago

You should pretty much be able to copy+paste code from Stheno.jl. I've adapted it a bit below to reflect what we've got going on:

@adjoint function ColVecs(X::AbstractMatrix)
    pullback_ColVecs(Δ::NamedTuple{(:X,)}) = (Δ.X,)
    return ColVecs(X), pullback_ColVecs
end

Stheno's version is a little more liberal with the types it allows for Δ -- this is a mistake that I've not gotten around to fixing yet.

Additionally, if we're going to use getindex on a ColVecs in any code that needs to be differentiated through, then we also need to define a custom adjoint for getindex, to make sure that it does the right thing. Currently it will fall back to Zygote's default implementation for AbstractVectors, which is necessarily a bit overzealous about the cases it assumes it handles correctly.

I would strongly recommend avoiding backprop-ing through getindex at all costs: it's (currently, unavoidably) incredibly slow. We might need to open a separate issue about this.

theogf commented 4 years ago

I think you implemented these adjoints already, but I don't think that's the problem. The problem is that the value returned by pullback is currently a vector of vectors, which suggests that the pullback for map on a vector is defined in a custom manner. I think we need to define our own adjoint for map(t::Transform, x::ColVecs).

willtebbutt commented 4 years ago

Oh, weird, that might be the case. We probably need more careful testing of the transform code then. I'll address it at the same time as #88.

theogf commented 4 years ago

Here is the bad guy : https://github.com/FluxML/Zygote.jl/blob/3e0a904cae5565bc7b458ad6d575c531ce70dd96/src/lib/array.jl#L180

theogf commented 4 years ago

OK, so just using a new name (not Base.map) lets Zygote run smoothly. Should we look for a way to solve this? Or should we stop using Base.map and just define a new function, e.g. apply(t::Transform, X::ColVecs), or even keep the kernel approach and use (t::Transform)(X::ColVecs)? Any opinions, @willtebbutt @devmotion?

willtebbutt commented 4 years ago

We can just overload map in the cases that we care about, and open an issue on Zygote. I refuse to not use the appropriate function from Base because Zygote is being greedy.

theogf commented 4 years ago

I might be wrong about the way to do it, but I think we need to create an adjoint literally for each case where we overloaded Base.map with ColVecs or RowVecs, that's a lot...

devmotion commented 4 years ago

Isn't it sufficient to define

ZygoteRules.@adjoint function Base.map(f, x::ColVecs)
  ...
end

ZygoteRules.@adjoint function Base.map(f, x::RowVecs)
  ...
end

?

theogf commented 4 years ago

But how do you get a general pullback function for any f ?

willtebbutt commented 4 years ago

We should be able to do it with a bit of indirection

Base.map(f::Transform, x::ColVecs) = vectorised_evaluation(f, x)
ZygoteRules.@adjoint function Base.map(f::Transform, x::ColVecs)
    return Zygote.pullback(vectorised_evaluation, f, x)
end

And replace all implementations of map(f::Transform, x::ColVecs) with vectorised_evaluation for f. vectorised_evaluation has taken the place of apply, but is clearly meant to apply the transform to collections of inputs rather than being ambiguous.

It's reasonable to ask at this point why we don't just replace map with vectorised_evaluation everywhere. To which I maintain that it's a point of principle. This shouldn't be an issue -- the fact that we can't implement a custom method for map and have Zygote do the correct thing is a bug in Zygote on which I don't want to start depending.
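As a rough illustration of the dispatch side of this indirection only, here is a minimal sketch that leaves Zygote out entirely. The types Cols and Scale and the helper _map are hypothetical stand-ins for ColVecs, ScaleTransform and vectorised_evaluation, not the package's actual definitions.

```julia
# Hypothetical toy wrapper: a matrix viewed as a vector of its columns.
struct Cols{T} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end
Base.size(c::Cols) = (size(c.X, 2),)
Base.getindex(c::Cols, i::Int) = c.X[:, i]

# Hypothetical toy transform: scales its input by a constant.
struct Scale{T}
    s::T
end
(t::Scale)(x::AbstractVector) = t.s .* x

# The vectorised implementation lives under its own name...
_map(t::Scale, c::Cols) = Cols(t.s .* c.X)

# ...and Base.map only forwards to it, so a custom adjoint for map
# could simply delegate to the pullback of _map.
Base.map(t::Scale, c::Cols) = _map(t, c)

A = Cols(ones(2, 3))
t = Scale(2.0)

typed = map(t, A)               # hits the specialised method, stays a Cols
generic = map(x -> 2 .* x, A)   # generic fallback, a plain Vector of Vectors
```

The specialised method keeps the wrapper type, while the generic fallback collects into a Vector of Vectors, which is the same loss of type that the pullback shows in the MWE.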

theogf commented 4 years ago

I can open an issue on Zygote, but I am not even sure what I should write... That the adjoint for map on vectors does not respect the typing?

willtebbutt commented 4 years ago

That's cool. I'll open one.

theogf commented 4 years ago

Ok I for now used the indirection by using the name _map (close enough) and tests are passing :)

willtebbutt commented 4 years ago

Ok I for now used the indirection by using the name _map (close enough) and tests are passing :)

Nice, much better name :)

theogf commented 4 years ago

Will be solved in #114