SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Adjoints of observed variables #760

Open iliailmer opened 1 year ago

iliailmer commented 1 year ago

I was trying to adapt this example from the tutorials so that I can use a more customized data sample. For example, if I have states x1, x2, x3, I measure the following quantities:

mq = [y1~x1, y2~x2+x3]

Then I try to define the loss as

function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = tsteps)
    data_true = [data_sample[v.rhs] for v in mq]
    data = [sol[v.rhs] for v in mq] # this is where the code fails
    loss = sum(sum((data[i] .- data_true[i]) .^ 2 for i in eachindex(data)))
    return loss, sol
end

but I am getting ArgumentError: invalid index: x1(t) of type Term{Real, Base.ImmutableDict{DataType, Any}} at the line

data = [sol[v.rhs] for v in mq]

What can be done to fix this?

ChrisRackauckas commented 1 year ago

This boils down to a simple issue: observed variables currently do not have adjoints defined. A PR for this went stale: https://github.com/SciML/SciMLBase.jl/pull/85. We should revive it, and since we need to finish https://github.com/SciML/SciMLBase.jl/pull/342 when Yingbo comes back, @YingboMa let's take a day to solve these two before continuing to other things.

iliailmer commented 1 year ago

@ChrisRackauckas @YingboMa Thank you! For now, using numeric indices solves the issue.

BernhardAhrens commented 1 year ago

@iliailmer Could you give some more details for your workaround?

"For now, using numeric indices solves the issue."

I think I have a similar use case. Thank you!

iliailmer commented 1 year ago

@BernhardAhrens Sure. If the states are [x1, x2] and the measured quantities are [x1, x2^2], then I replace the line

data = [sol[v.rhs] for v in mq] # this is where the code fails

with

data = [sol[1, :], sol[2, :] .^ 2] # row i is the time series of state i
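
A minimal sketch of this workaround in a full loss function, under stated assumptions: the two-state linear model `rhs!`, the parameter values in `p_true`, and the helper `measured` are hypothetical stand-ins and not part of the original code. It also assumes that positions 1 and 2 match the order of the states in the solved problem, which should be checked when the problem is generated from a symbolic system.

```julia
using OrdinaryDiffEq

# Hypothetical two-state linear system standing in for the real model.
function rhs!(du, u, p, t)
    du[1] = -p[1] * u[1]
    du[2] = p[1] * u[1] - p[2] * u[2]
    return nothing
end

tsteps = 0.0:0.1:1.0
p_true = [1.5, 0.5]  # assumed "true" parameters, used only to generate sample data
prob = ODEProblem(rhs!, [1.0, 0.0], (0.0, 1.0), p_true)

# Measured quantities written with numeric state indices: y1 = x1, y2 = x2^2.
# sol[i, :] is the time series of the i-th state over the saved time points.
measured(sol) = [sol[1, :], sol[2, :] .^ 2]

# Synthetic data sample generated from the assumed true parameters.
data_true = measured(solve(prob, Tsit5(), saveat = tsteps))

function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = tsteps)
    data = measured(sol)
    # Sum of squared errors over every measured quantity and time point.
    l = sum(sum(abs2, data[i] .- data_true[i]) for i in eachindex(data))
    return l, sol
end
```

Because the measured quantities are built from plain array indexing instead of symbolic indexing, the loss no longer goes through sol[v.rhs], which is the call that fails above.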

ChrisRackauckas commented 1 year ago

https://github.com/SciML/SciMLBase.jl/pull/479 handles a lot of cases, but sol[sym] is still missing an overload.