ChrisRackauckas closed this issue 1 month ago.
Note that by default this only emits a warning, because the call is wrapped in a try/catch:
https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/concrete_solve.jl#L23-L27
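To illustrate why the failure only shows up as a warning, here is a minimal sketch of the general try/catch fallback pattern in plain Julia. The names `vjp_with_fallback`, `broken_vjp` and `working_vjp` are hypothetical and do not reproduce the linked SciMLSensitivity code; the point is that a caught error is downgraded to a warning and a different method is used instead, so explicitly choosing the method (as in the next comment) is what lets the error surface.

```julia
# Hypothetical sketch of a try/catch fallback; not SciMLSensitivity's internals.
function vjp_with_fallback(preferred, fallback, args...)
    try
        return preferred(args...)
    catch err
        # The error is reported only as a warning, then the fallback is used.
        @warn "Preferred VJP failed; using fallback instead" exception = err
        return fallback(args...)
    end
end

broken_vjp(x) = error("simulated Enzyme failure")  # stands in for a failing VJP
working_vjp(x) = 2 .* x                            # stands in for a working VJP

vjp_with_fallback(broken_vjp, working_vjp, [1.0, 2.0])  # warns, returns [2.0, 4.0]
```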
You can force the error to surface by passing the sensitivity algorithm explicitly, i.e. changing the call to `solve(prob, Vern9(), p = θ, saveat = ts, sensealg = GaussAdjoint(autojacvec = EnzymeVJP()))`.
A more minimal version:
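```julia
using LuxLib
using ComponentArrays, Random, Enzyme

ps = (; w = rand(5, 5), b = rand(5))
xtest = rand(5, 1)
dx = zeros(size(xtest))  # shadow of xtest; Enzyme accumulates the gradient here

# sum(w * x .+ b), computed through LuxLib's fused dense kernel
function test_function(x, ps, st)
    y = LuxLib.fused_dense_bias_activation(identity, ps.w, x, ps.b)
    return sum(y)
end

st = NamedTuple()
test_function(xtest, ps, st)

# Parameters stored in a plain NamedTuple
@time autodiff(Reverse, test_function, Active, Duplicated(xtest, dx), Const(ps), Const(st))

# Same call with the parameters wrapped in a ComponentArray
ps_ca = ComponentArray(ps)
fill!(dx, 0)  # reset the accumulated gradient before the second call
@time autodiff(
    Reverse, test_function, Active, Duplicated(xtest, dx), Const(ps_ca), Const(st))
```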
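As a sanity check on what a successful `autodiff` call should write into `dx`, the reference gradient can be computed in plain Julia without Enzyme or LuxLib. This sketch assumes `fused_dense_bias_activation(identity, w, x, b)` equals `w * x .+ b`; the helpers `dense_sum`, `analytic_grad` and `fd_grad` are hypothetical. Since `d/dx sum(w * x .+ b)` is `w' * ones`, every column of the gradient is the vector of column sums of `w`, which a central finite-difference check confirms.

```julia
using Random

# Plain-Julia stand-in for the minimal example's loss (assumed equivalent).
dense_sum(x, w, b) = sum(w * x .+ b)

# Analytic gradient w.r.t. x: each column equals the column sums of w.
analytic_grad(w, x) = repeat(vec(sum(w; dims = 1)), 1, size(x, 2))

# Central finite differences as an independent check.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

Random.seed!(42)
w = rand(5, 5); b = rand(5); x = rand(5, 1)
isapprox(analytic_grad(w, x), fd_grad(z -> dense_sum(z, w, b), x); atol = 1e-6)
```

Note that because Enzyme accumulates into the shadow, calling `autodiff` twice with the same `dx` without resetting it yields twice this gradient.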
@avik-pal that minimization was perfect.
I know what's causing this and will start working on a fix.
@ChrisRackauckas, the latest jll bump I just pushed should remedy this.
However, the broadcast here now requires runtime activity (with it turned on, everything runs). I'll see if we can improve alias analysis to fix that, but at least you should now be unblocked.
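For reference, a minimal sketch of turning runtime activity on, assuming Enzyme.jl 0.13 or later, where it is a per-call mode setting via `set_runtime_activity`; older Enzyme.jl releases used the global switch `Enzyme.API.runtimeActivity!(true)` instead. The toy function `f` is only a stand-in to show the call shape, not the failing broadcast from this issue.

```julia
using Enzyme  # assumes Enzyme.jl >= 0.13

f(x) = sum(abs2, x)  # toy stand-in, not the broadcast from this issue

x = [1.0, 2.0, 3.0]
dx = zero(x)  # shadow that receives the gradient

# Wrap the mode so activity is resolved at runtime instead of statically.
autodiff(set_runtime_activity(Reverse), f, Active, Duplicated(x, dx))

dx  # gradient of sum(abs2, x) is 2x, i.e. [2.0, 4.0, 6.0]
```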
For the mixed-activity broadcast performance: https://fwd.gymni.ch/k03gKb
This should now be fixed on main with the latest jll bump; please reopen if it persists.
Found while updating https://github.com/SciML/SciMLSensitivity.jl/pull/1052.
It throws this error: