ACEsuit / Polynomials4ML.jl

Polynomials for ML: fast evaluation, batching, differentiation
MIT License

ReverseDiff Compatibility #69

Closed cortner closed 11 months ago

cortner commented 11 months ago

In order to be able to do reverse over reverse, we need (for the time being at least) to use ReverseDiff for the second derivative. I think this is best implemented via an extension. This extension would simply "register" all the rrules we have with ReverseDiff.

It is not 100% clear to me why this is needed (in my early experiments it wasn't), but right now I get error messages from ReverseDiff about trying to convert a tracked real to a real. Cf.

https://github.com/cortner/2023MitacsDV/blob/master/co_doublebackprop/ace_example_2.jl
cortner commented 11 months ago

Here is a minimal working example for the problem I'm seeing:


using Polynomials4ML, Lux, ReverseDiff, Random, Zygote, Optimisers
rng = MersenneTwister()

model = Chain(LinearLayer(10, 1), WrappedFunction(x -> x[1]))
ps, st = Lux.setup(rng, model)

# gradients with respect to the input x: all of these work
x = randn(10)
model(x, ps, st)[1]
Zygote.gradient(_x -> model(_x, ps, st)[1], x)[1]
ReverseDiff.gradient(_x -> model(_x, ps, st)[1], x)

# gradients with respect to the parameters: Zygote works, ReverseDiff fails
p, rest = destructure(ps)
Zygote.gradient(_p -> model(x, rest(_p), st)[1], p)[1]
ReverseDiff.gradient(_p -> model(x, rest(_p), st)[1], p)
# ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, Nothing} to Float64 is not defined. Please use `ReverseDiff.value` instead.
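For reference, this class of error can be reproduced with ReverseDiff alone by writing a tracked value into a pre-allocated `Float64` buffer. This is a minimal sketch to illustrate the mechanism, not Polynomials4ML code:

```julia
using ReverseDiff

# Writing a TrackedReal into a Float64 array triggers the same conversion
# error, because setindex! tries to convert the tracked number to the
# buffer's fixed eltype.
function f(x)
    y = zeros(Float64, 1)   # eltype hard-coded to Float64
    y[1] = 2 * x[1]         # x[1] is a TrackedReal under ReverseDiff
    return y[1]
end

ReverseDiff.gradient(f, [1.0])
# ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{...} to Float64 is not defined.
```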

I suspect that ReverseDiff simply doesn't use the rrule because we didn't tell it to.

help?> ReverseDiff.@grad_from_chainrules
  @grad_from_chainrules f(args...; kwargs...)

  The @grad_from_chainrules macro provides a way to import adjoints(rrule) defined
  in ChainRules to ReverseDiff. One must provide a method signature to import the
  corresponding rrule. In the provided method signature, one should replace the
  types of arguments to which one wants to take derivatives with respect with
  ReverseDiff.TrackedReal and ReverseDiff.TrackedArray respectively. For example,
  we can import rrule of f(x::Real, y::Array) like below:

  ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
  ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Array)
  ReverseDiff.@grad_from_chainrules f(x::Real, y::TrackedArray)
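If a ReverseDiff extension does become necessary, it could look roughly like the following. This is only a sketch: the extension module name and the `evaluate` method signature (in particular the basis type) are illustrative placeholders, not the actual Polynomials4ML API:

```julia
# Hypothetical package extension that registers existing ChainRules rrules
# with ReverseDiff, so that ReverseDiff stops tracing through the primal code.
# Names below are placeholders for illustration.
module Polynomials4MLReverseDiffExt

using Polynomials4ML, ReverseDiff

# Use the ChainRules rrule whenever the input array is tracked:
ReverseDiff.@grad_from_chainrules Polynomials4ML.evaluate(basis::Polynomials4ML.PolyBasis4ML, x::ReverseDiff.TrackedArray)

end
```

One such line would be needed per differentiable entry point whose rrule should be imported.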
cortner commented 11 months ago

Hm ... looking at the implementation of the linear layer, the problem seems to lie elsewhere. I think I can fix it and will make a PR.
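A common fix for this class of error is to allocate the layer's output buffer with an eltype promoted from the inputs, so tracked numbers can be stored without conversion. A sketch under that assumption (`linear_forward` is a hypothetical stand-in, not the actual layer code or the PR):

```julia
using LinearAlgebra

# Sketch: a linear forward pass whose output buffer adapts its eltype to
# the inputs, so it also holds ReverseDiff tracked numbers without the
# TrackedReal -> Float64 conversion attempted by a Float64 buffer.
function linear_forward(W::AbstractMatrix, x::AbstractVector)
    T = promote_type(eltype(W), eltype(x))  # TrackedReal if x is tracked
    out = zeros(T, size(W, 1))
    mul!(out, W, x)
    return out
end
```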

cortner commented 11 months ago

#70 seems to at least fix the mini-example above.

cortner commented 11 months ago

.... and it also fixed my complex ACE model differentiation.

This is very good news. No ReverseDiff extension needed for now. We can return to it later if we need it for performance optimisation.