cortner closed this issue 11 months ago.
Here is a minimal working example for the problem I'm seeing:
```julia
using Polynomials4ML, Lux, ReverseDiff, Random, Zygote, Optimisers

rng = MersenneTwister()
model = Chain(LinearLayer(10, 1), WrappedFunction(x -> x[1]))
ps, st = Lux.setup(rng, model)
x = randn(10)
model(x, ps, st)[1]   # forward pass; Lux returns (output, state)

# gradient with respect to the input x
Zygote.gradient(_x -> model(_x, ps, st)[1], x)[1]
ReverseDiff.gradient(_x -> model(_x, ps, st)[1], x)

# gradient with respect to the flattened parameters p (destructure from Optimisers)
p, rest = destructure(ps)
Zygote.gradient(_p -> model(x, rest(_p), st)[1], p)[1]
ReverseDiff.gradient(_p -> model(x, rest(_p), st)[1], p)
# ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, Nothing} to Float64 is not defined. Please use `ReverseDiff.value` instead.
```
I suspect that ReverseDiff simply doesn't use the rrule because we didn't tell it to.
```
help?> ReverseDiff.@grad_from_chainrules
  @grad_from_chainrules f(args...; kwargs...)

  The @grad_from_chainrules macro provides a way to import adjoints(rrule) defined
  in ChainRules to ReverseDiff. One must provide a method signature to import the
  corresponding rrule. In the provided method signature, one should replace the
  types of arguments to which one wants to take derivatives with respect with
  ReverseDiff.TrackedReal and ReverseDiff.TrackedArray respectively. For example,
  we can import rrule of f(x::Real, y::Array) like below:

  ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
  ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Array)
  ReverseDiff.@grad_from_chainrules f(x::Real, y::TrackedArray)
```
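For reference, here is a minimal sketch of how the macro is meant to be used, assuming ChainRulesCore is installed. The function `mysquare` and its rrule are hypothetical stand-ins, not part of Polynomials4ML. After `@grad_from_chainrules` is applied to the tracked signature, ReverseDiff records the call as a single instruction and uses the rrule's pullback in the reverse pass instead of tracing through the function body.

```julia
using ChainRulesCore, ReverseDiff

# Hypothetical function standing in for a layer with a hand-written rrule.
mysquare(x::AbstractArray) = x .^ 2

function ChainRulesCore.rrule(::typeof(mysquare), x::AbstractArray)
    y = mysquare(x)
    # pullback: the cotangent of x .^ 2 is 2 .* x .* ȳ
    mysquare_pullback(ȳ) = (NoTangent(), 2 .* x .* unthunk(ȳ))
    return y, mysquare_pullback
end

# Import the rrule into ReverseDiff for tracked array arguments.
ReverseDiff.@grad_from_chainrules mysquare(x::ReverseDiff.TrackedArray)

loss(x) = sum(mysquare(x))
g = ReverseDiff.gradient(loss, [1.0, 2.0, 3.0])
println(g)  # [2.0, 4.0, 6.0]
```

The primal method is defined on `AbstractArray` rather than `AbstractVector` so that the `TrackedArray` method generated by the macro is strictly more specific and dispatch is unambiguous.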
Hm ... looking at the implementation of the linear layer, the problem seems to be elsewhere. I think I can fix it and will make a PR.
... and it also fixed differentiation of my complex ACE model.
This is very good news. No ReverseDiff extension needed for now. We can return to it later if we need it for performance optimisation.
In order to do reverse-over-reverse differentiation, we need (at least for the time being) to use ReverseDiff for the second derivative. I think this is best implemented via a package extension that simply "registers" all of our rrules with ReverseDiff. It is not 100% clear to me why this is needed, but right now ReverseDiff gives me error messages about trying to convert a tracked real to a real, even though this wasn't needed in my early experiments. Cf.