prbzrg closed this issue 2 months ago.
All of them work with Zygote. The error is also different when I don't use StatefulLuxLayer.
This is very likely a DI issue caused by closures (I think there is an option to mark f as constant).
cc @gdalle @wsmoses
Yeah, I presume this succeeds if you just call Enzyme.gradient/jacobian(Reverse, Const(f), ...), no?
If the function is truly constant, can you try AutoEnzyme(function_annotation=Enzyme.Const)?
It didn't work. But I got a different error. I updated the first post.
~Ok this is more of a Lux issue, but I am very surprised Enzyme is hitting Octavian :sweat:. I explicitly try my best to circumvent all loopvec/octavian/polyester for enzyme and just call BLAS or use loops.~
The dispatches are defined only for ReverseMode; they need to be extended to ForwardMode.
The proper solution would probably be to bite the bullet and write Enzyme rules for https://github.com/LuxDL/LuxLib.jl/blob/c185f04183d760b84d0dcfa2b49511255cd1e7dc/src/impl/matmul.jl#L233-L238 instead of switching implementations.
A smaller reproducer:
```julia
using Lux, Enzyme, Random
n = 2
r = rand(Float32, n, n)
nn = Chain(Dense(n => n, tanh))
ps, st = Lux.setup(Random.default_rng(), nn)
Enzyme.autodiff(Forward, Const(LuxCore.stateless_apply), Duplicated,
    Const(nn), Duplicated(r, one.(r)), Const(ps))
```
Reverse-mode Enzyme works fine:
```julia
DifferentiationInterface.jacobian(snn, AutoEnzyme(; function_annotation=Enzyme.Const, mode=Enzyme.Reverse), r)
```
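For anyone trying to reproduce the working reverse-mode call on its own, here is a minimal self-contained sketch. It assumes (the thread does not show this) that `snn` is a `StatefulLuxLayer` wrapping the model, parameters and state from the smaller reproducer; the `{true}` type parameter and the variable names are illustrative choices, not taken from the original post.

```julia
using Lux, Enzyme, Random
using DifferentiationInterface

n = 2
r = rand(Float32, n, n)
nn = Chain(Dense(n => n, tanh))
ps, st = Lux.setup(Random.default_rng(), nn)

# Assumption: `snn` bundles model, parameters and state so it can be called as snn(x),
# returning only the output (no state), which is what DI expects of `f`.
snn = StatefulLuxLayer{true}(nn, ps, st)

# Mark the callable as constant so Enzyme does not try to differentiate through
# the parameters captured inside `snn`; use reverse mode, which works here.
backend = AutoEnzyme(; function_annotation=Enzyme.Const, mode=Enzyme.Reverse)

# Dense applied to an n×n matrix treats columns as a batch, so the output is n×n
# and DI returns a flattened (n*n)×(n*n) Jacobian.
J = DifferentiationInterface.jacobian(snn, backend, r)
```

Swapping `mode=Enzyme.Reverse` for `mode=Enzyme.Forward` should correspond to the forward-mode path that, per the discussion, fails because the relevant dispatches only cover ReverseMode.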