swilliamson7 opened 4 months ago
What version of things are you on? This worked for me (though I needed to add `Enzyme.API.runtimeActivity!(true)`).
I'm on Enzyme v0.12.22, but I don't have the runtime activity flag, so I can try adding that.
Whoops, that was wrong: I'm actually on Enzyme v0.12.25#main.
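A minimal sketch of one way to capture the exact versions in play, using only the standard `Pkg` and `InteractiveUtils` modules. `Pkg.status` should also show a tracked branch such as `#main`, but the exact output format may vary between Julia/Pkg releases.

```julia
using InteractiveUtils  # provides versioninfo() outside the REPL
using Pkg

versioninfo()           # Julia version, OS and CPU details
Pkg.status("Enzyme")    # installed Enzyme version (and tracked branch, if any)
```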
Adding the flag did make the error go away, but oddly the error still happens with my version of the code that uses a struct:
```julia
using Enzyme

Enzyme.API.runtimeActivity!(true)

mutable struct model
    J::Float64
    nt::Int
    k::Float64
    r::Float64
    x0::Vector{Float64}
    data_steps::StepRange{Int64, Int64}
    data::Array{Float64}
end

function integrate(model)
    nt = model.nt
    dt = 0.001
    x0 = model.x0
    k = model.k
    r = model.r
    data = model.data
    data_steps = model.data_steps
    J = model.J   # local copy: updates to J below are not written back to model.J

    # State-transition matrix for one step of size dt, built from k and r
    A = [1 0 0 dt 0 0;
         0 1 0 0 dt 0;
         0 0 1 0 0 dt;
         -2*k*dt k*dt 0 1-r*dt 0 0;
         k*dt -3*k*dt k*dt 0 1-r*dt 0;
         0 k*dt -2*k*dt 0 0 1-r*dt
    ]

    # Observation operator (identity: every state component is observed)
    E = [1 0 0 0 0 0;
         0 1 0 0 0 0;
         0 0 1 0 0 0;
         0 0 0 1 0 0;
         0 0 0 0 1 0;
         0 0 0 0 0 1
    ]

    # Forcing terms; B, u and Gamma are not used in the loop below
    sigma_forcing = 0.01
    B = [1 0 0 0 0 0;
         0 0 0 0 0 0;
         0 0 0 0 0 0;
         0 0 0 0 0 0;
         0 0 0 0 0 0;
         0 0 0 0 0 0]
    function u(t)
        return sigma_forcing * randn(6)
    end
    Gamma = [1 0 0 0 0 0;
             0 0 0 0 0 0;
             0 0 0 0 0 0;
             0 0 0 0 0 0;
             0 0 0 0 0 0;
             0 0 0 0 0 0]

    all_states = zeros(6, nt+1)
    all_states[:,1] = x0
    x = x0
    current_timestep = dt
    for j in 1:nt+1
        x = A * x
        all_states[:, j] = x
        current_timestep += dt
        # accumulate the misfit against the data at observation steps
        if j in data_steps
            J += sum((E * data[:, j] - E * x).^2)
        end
        data[:, j] = x   # overwrites model.data with the simulated state
    end
    return nothing
end

nt = 10000

# First run: simulate the "true" trajectory into parameters.data
parameters = model(
    0.0,
    10000,
    30.,
    0.5,
    [1., 2., 3., 0., 0., 0.],
    3000:300:7000,
    zeros(6, nt+1)
)
integrate(parameters)

# Noisy observations built from the simulated trajectory
data = parameters.data + 0.5 .* randn(6, nt+1)

parameters_for_enzyme = model(
    0.0,
    10000,
    30.,
    0.5,
    [1., 2., 3., 0., 0., 0.],
    3000:300:7000,
    data
)

derivatives = Enzyme.make_zero(parameters_for_enzyme)
autodiff(ReverseWithPrimal,
    integrate,
    Duplicated(parameters_for_enzyme, derivatives)
)
```
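One thing worth checking in the struct version, separately from the error itself: `J = model.J` copies a `Float64` into a local variable, so `J += ...` never updates `model.J`, and `integrate` returns `nothing`. As written, reverse mode has no output to seed, so `derivatives` would be expected to stay zero even once the call succeeds. Below is a minimal sketch, assuming the aim is dJ/dx0: accumulate the cost into the struct field and seed that field's shadow with `1.0` before calling `autodiff`. The `Toy` struct and `cost!` function are hypothetical stand-ins for the real model, and the call follows the Enzyme v0.12-era API already used in this thread.

```julia
using Enzyme

# Hypothetical, simplified stand-in for the `model` struct above
mutable struct Toy
    J::Float64            # cost accumulator, written by cost!
    x0::Vector{Float64}   # parameters we want dJ/dx0 for
end

# Adds sum(x0 .^ 2) to m.J in place, so the cost lives in the struct
function cost!(m::Toy)
    s = 0.0
    for xi in m.x0
        s += xi^2
    end
    m.J += s
    return nothing
end

m  = Toy(0.0, [1.0, 2.0, 3.0])
dm = Enzyme.make_zero(m)   # shadow struct with every field zeroed
dm.J = 1.0                 # seed: adjoint of the output cost

# Reverse pass; the primal also runs, so m.J is updated as well
autodiff(Reverse, cost!, Const, Duplicated(m, dm))

println(m.J)     # accumulated cost
println(dm.x0)   # gradient dJ/dx0, which should equal 2 .* m.x0
```

The design choice here is simply that the quantity being differentiated must be observable by Enzyme, either as a return value (with `Active` return activity) or as a mutated field whose shadow is seeded.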
I'm in the process of trying to minimize a bigger error, but with my smaller code I've hit an `invoke is not a generic function` error I haven't seen before. This may well be from me doing something silly, but the error output seems to point to Enzyme code, so I'm not sure.
The code I'm running:
And the error I see: