Closed by theogf 4 years ago
You should pretty much be able to copy+paste code from Stheno.jl. I've adapted it a bit below to reflect what we've got going on:
@adjoint function ColVecs(X::AbstractMatrix)
    pullback_ColVecs(Δ::NamedTuple{(:X,)}) = (Δ.X,)
    return ColVecs(X), pullback_ColVecs
end
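To see what this adjoint does in practice, here is a minimal sketch. The `ColVecs` struct below is a stand-in written for this example (the package's real definition may differ), and `sumsq` is a hypothetical helper that only touches the wrapped matrix.

```julia
using Zygote

# Stand-in for the package's ColVecs: a matrix viewed as a vector of its columns.
# (Assumed definition, for illustration only.)
struct ColVecs{T,TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
    X::TX
end
Base.size(x::ColVecs) = (size(x.X, 2),)
Base.getindex(x::ColVecs, i::Int) = x.X[:, i]

# The adjoint from the comment above: the cotangent of a ColVecs arrives as a
# NamedTuple with a single field X, which is handed straight back to the matrix.
Zygote.@adjoint function ColVecs(X::AbstractMatrix)
    pullback_ColVecs(Δ::NamedTuple{(:X,)}) = (Δ.X,)
    return ColVecs(X), pullback_ColVecs
end

# Hypothetical helper: sum of squares of all entries.
sumsq(x::ColVecs) = sum(abs2, x.X)

X = [1.0 2.0; 3.0 4.0]
(g,) = Zygote.gradient(A -> sumsq(ColVecs(A)), X)  # g should equal 2 .* X
```

Because the pullback only accepts `NamedTuple{(:X,)}`, a cotangent of any other shape (for example a vector of vectors) should raise a `MethodError` rather than be silently mishandled.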
Stheno's version is a little more liberal with the types it allows for Δ -- this is a mistake that I've not gotten around to fixing yet.
Additionally, if we're going to be using getindex on a ColVecs in any code that needs to be differentiated through, then we also need to define a custom adjoint for getindex, to make sure that it does the right thing (currently it will fall back to Zygote's default implementation for AbstractVectors, which is necessarily a bit overzealous in terms of the things that it thinks it's correct for).
I would strongly recommend trying to avoid backprop-ing through getindex at all costs -- it's (currently, unavoidably) incredibly slow. We might need to open a separate issue about this.
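As a rough illustration of what such a getindex adjoint could look like, and of why it is slow, here is a sketch. It assumes the same stand-in `ColVecs` definition as before (not necessarily the package's), and the pullback shown is one possible choice, not the implementation that was eventually merged.

```julia
using Zygote

# Stand-in ColVecs (assumed definition, for illustration only).
struct ColVecs{T,TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
    X::TX
end
Base.size(x::ColVecs) = (size(x.X, 2),)
Base.getindex(x::ColVecs, i::Int) = x.X[:, i]

# Sketch of a custom getindex adjoint: the cotangent of column i is written
# into an otherwise-zero matrix, returned as the X field of a NamedTuple.
Zygote.@adjoint function Base.getindex(x::ColVecs, i::Int)
    function pullback_getindex(Δ)
        ΔX = zero(x.X)   # a full matrix is allocated on every indexing call
        ΔX[:, i] .= Δ
        return ((X = ΔX,), nothing)  # nothing: no gradient for the integer index
    end
    return x.X[:, i], pullback_getindex
end

X = [1.0 2.0; 3.0 4.0]
(g,) = Zygote.gradient(A -> sum(ColVecs(A)[2]), X)  # only column 2 is nonzero
```

Each call allocates a matrix the size of `X` just to carry one column of cotangent, which is one reason indexing inside differentiated code is expensive.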
I think you implemented these adjoints already; however, I don't think that's the problem.
The problem comes from the fact that the value returned by pullback is currently a vector of vectors, which seems to indicate that the pullback for map on a vector is defined in a custom manner. I think we need to define our own adjoint for map(t::Transform, x::ColVecs).
Oh, weird, that might be the case. We probably need more careful testing of the transform code then. I'll address it at the same time as #88.
Ok, so just using a new name (not Base.map) lets Zygote run smoothly.
Should we look for a way to solve this? Or should we stop using Base.map and just define a new function, e.g. apply(t::Transform, X::ColVecs), or even keep the kernel approach and use (t::Transform)(X::ColVecs)?
Any opinions, @willtebbutt @devmotion?
We can just overload map in the cases that we care about, and open an issue on Zygote. I refuse to not use the appropriate function from Base because Zygote is being greedy.
I might be wrong about the way to do it, but I think we need to create an adjoint literally for each case where we overloaded Base.map with ColVecs or RowVecs. That's a lot...
Isn't it sufficient to define
ZygoteRules.@adjoint function Base.map(f, x::ColVecs)
...
end
ZygoteRules.@adjoint function Base.map(f, x::RowVecs)
...
end
?
But how do you get a general pullback function for any f?
We should be able to do it with a bit of indirection:
Base.map(f::Transform, x::ColVecs) = vectorised_evaluation(f, x)
ZygoteRules.@adjoint function Base.map(f::Transform, x::ColVecs)
    return Zygote.pullback(vectorised_evaluation, f, x)
end
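A self-contained sketch of this indirection follows. The `ColVecs` struct, the `Transform` abstract type and the toy `ScaleTransform` are stand-ins defined only for this example, and the helper is named `_map` here; none of these definitions are taken from the package itself.

```julia
using Zygote

# Stand-in ColVecs (assumed definition, for illustration only).
struct ColVecs{T,TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
    X::TX
end
Base.size(x::ColVecs) = (size(x.X, 2),)
Base.getindex(x::ColVecs, i::Int) = x.X[:, i]

# Hypothetical transform hierarchy with one toy transform.
abstract type Transform end
struct ScaleTransform{T<:Real} <: Transform
    s::T
end

# The real work lives in a separately named function...
_map(t::ScaleTransform, x::ColVecs) = ColVecs(t.s .* x.X)

# ...and Base.map simply forwards to it.
Base.map(t::Transform, x::ColVecs) = _map(t, x)

# Route differentiation of map through _map, so that Zygote's generic
# adjoint for map over arrays is not used for this combination of types.
Zygote.@adjoint function Base.map(t::Transform, x::ColVecs)
    return Zygote.pullback(_map, t, x)
end

X = [1.0 2.0; 3.0 4.0]
y = map(ScaleTransform(2.0), ColVecs(X))  # stays a ColVecs
(g,) = Zygote.gradient(s -> sum(map(ScaleTransform(s), ColVecs(X)).X), 2.0)
```

Only one adjoint is needed per container type: each new transform adds a `_map` method, and Zygote differentiates that method in the usual way. For this toy transform the gradient with respect to `s` should be `sum(X)`.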
And replace all implementations of map(f::Transform, x::ColVecs) with vectorised_evaluation for f -- vectorised_evaluation has taken the place of apply, but is clearly meant to apply the transform to collections of inputs rather than being ambiguous.
It's reasonable to ask at this point why we don't just replace map with vectorised_evaluation everywhere. To which I maintain that it's a point of principle. This shouldn't be an issue -- the fact that we can't implement a custom method for map and have Zygote do the correct thing is a bug in Zygote on which I don't want to start depending.
I can open an issue on Zygote, but I am not even sure what I should write... That the adjoint for map on vectors does not respect the typing?
That's cool. I'll open one.
Ok, for now I used the indirection with the name _map (close enough) and tests are passing :)
Nice, much better name :)
Will be solved in #114
Here is an MWE:
This makes AD with Zygote infeasible at the moment.