SciML / SciMLBase.jl

The Base interface of the SciML ecosystem
https://docs.sciml.ai/SciMLBase/stable
MIT License
131 stars, 97 forks

fix `observed` for DEintegrator #595

Closed: oscardssmith closed this 8 months ago

oscardssmith commented 8 months ago

This still needs a test, but it turns out both of these methods are unnecessary (they just fall back to the SciMLFunction), and the observed implementation here is simply incorrect.

codecov[bot] commented 8 months ago

Codecov Report

Attention: 10 lines in your changes are missing coverage. Please review.

Comparison is base (0c62711) 26.25% compared to head (1cb0ec2) 30.53%. Report is 4 commits behind head on master.

| Files | Patch % | Lines |
|---|---|---|
| src/scimlfunctions.jl | 0.00% | 9 Missing ⚠️ |
| src/problems/problem_interface.jl | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #595      +/-   ##
==========================================
+ Coverage   26.25%   30.53%   +4.27%
==========================================
  Files          54       54
  Lines        4113     4120       +7
==========================================
+ Hits         1080     1258     +178
+ Misses       3033     2862     -171
```

oscardssmith commented 8 months ago

tests added. I think this is ready to merge.

AayushSabharwal commented 8 months ago

I get the is_observed fallback, but what is wrong with the observed implementation?

Without it, observed expressions won't work for integrators, i.e. getu(integrator, sys.x + sys.y)(integrator) would error?

AayushSabharwal commented 8 months ago

I guess the is_observed method can also be removed for AbstractSciMLProblem in problems/problem_interface.jl as well.

oscardssmith commented 8 months ago

the correct observed function would just call getobserved. This version gets the observed function and then calls it with incorrect arguments (it doesn't pass in u, p, or t).
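To make the "missing u, p, t" point concrete, here is a minimal, self-contained sketch in plain Julia (no SciML packages). `ToyIntegrator`, `toy_obs` and `observed_getter` are hypothetical names for illustration only, not SciMLBase or SymbolicIndexingInterface API. It assumes a generated observed function of the form `obs(sym, u, p, t)`, matching the `(u, p, t) -> obs(sym, u, p, t)` closure used later in this thread. A getter built from it has to forward the integrator's state, parameters and time.

```julia
# Hypothetical stand-in for a DE integrator: just state, parameters and time.
struct ToyIntegrator
    u::Vector{Float64}
    p::Vector{Float64}
    t::Float64
end

# Hypothetical generated observed function with the (sym, u, p, t) signature.
# The observed quantity :x_plus_y is p[1] * (u[1] + u[2]).
function toy_obs(sym::Symbol, u, p, t)
    sym === :x_plus_y || error("unknown observed symbol $sym")
    return p[1] * (u[1] + u[2])
end

# Correct pattern: the getter forwards u, p and t taken from the integrator.
observed_getter(obs, sym) = integ -> obs(sym, integ.u, integ.p, integ.t)

integ = ToyIntegrator([1.0, 2.0], [10.0], 0.5)
getter = observed_getter(toy_obs, :x_plus_y)
value = getter(integ)   # 10.0 * (1.0 + 2.0)

# Broken pattern: calling the observed function without u, p, t has no
# matching method, so Julia throws a MethodError.
broken = try
    toy_obs(:x_plus_y)
    false
catch err
    err isa MethodError
end
```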

oscardssmith commented 8 months ago

I guess the is_observed method can also be removed for AbstractSciMLProblem in problems/problem_interface.jl as well.

good call on this. When removing it, I discovered that the observed function was broken in the same way as the integrator one :laughing:

AayushSabharwal commented 8 months ago

the correct observed function would just call getobserved. This version gets the observed function and then calls it with incorrect arguments (it doesn't pass in u, p, or t).

Where is this function defined? There's no method that SII.observed will fall back to now.

Also, you'll need to disambiguate between ModelingToolkit.observed and SII.observed in the test, since they're both loaded.

oscardssmith commented 8 months ago

SII.observed falls back to the definition in SymbolicIndexingInterface (specifically, observed(sys, sym) = observed(symbolic_container(sys), sym)).

oscardssmith commented 8 months ago

let's see how CI likes this.

AayushSabharwal commented 8 months ago

SII.observed falls back to the definition in SymbolicIndexingInterface (specifically, observed(sys, sym) = observed(symbolic_container(sys), sym)).

Yeah, but eventually something needs to implement it; otherwise it'll fall back to the system, which doesn't.

oscardssmith commented 8 months ago

it gets implemented by SciMLFunction.
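The dispatch chain discussed here can be pictured with a toy model in plain Julia. Every name in it (`WrapIntegrator`, `WrapProblem`, `WrapFunction`, `wrap_observed`, `wrap_symbolic_container`, `BareSystem`) is hypothetical and only mimics the fallback quoted above, observed(sys, sym) = observed(symbolic_container(sys), sym); it is not the actual SciMLBase or SymbolicIndexingInterface code. The last lines illustrate the concern raised in the thread: if the chain ends at an object that implements nothing, the call fails.

```julia
# Hypothetical wrappers that only delegate; none of these are real SciML types.
abstract type WrapContainer end

# The only type that actually implements `wrap_observed`.
struct WrapFunction
    obs::Function   # maps a symbol to a (u, p, t) -> value closure
end

struct WrapProblem <: WrapContainer
    f::WrapFunction
end

struct WrapIntegrator <: WrapContainer
    prob::WrapProblem
end

# Each wrapper names the object that holds the symbolic information.
wrap_symbolic_container(p::WrapProblem) = p.f
wrap_symbolic_container(i::WrapIntegrator) = i.prob

# Generic fallback: forward to the contained object.
wrap_observed(x::WrapContainer, sym) = wrap_observed(wrap_symbolic_container(x), sym)

# Terminal implementation on the function object.
wrap_observed(f::WrapFunction, sym) = f.obs(sym)

f = WrapFunction(sym -> (u, p, t) -> u[1] + u[2])
integ = WrapIntegrator(WrapProblem(f))
getter = wrap_observed(integ, :x_plus_y)   # integrator -> problem -> function
result = getter([1.0, 2.0], nothing, 0.0)

# A container with no implementation anywhere in its chain: MethodError.
struct BareSystem end
no_impl = try
    wrap_observed(BareSystem(), :x)
    false
catch err
    err isa MethodError
end
```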

oscardssmith commented 8 months ago

I believe the symbolic indexing tests are passing (they run locally at least), but I'm having some trouble figuring out which of the CI runs is the related one.

AayushSabharwal commented 8 months ago

It's CI / test (Downstream, 1) which seems to fail before it even gets to that point. The error seems to be an Optimization.jl thing?

oscardssmith commented 8 months ago

Oh, I think the Optimization problem is https://github.com/SciML/SciMLBase.jl/pull/600

oscardssmith commented 8 months ago

rebased. Let's try this again.

ChrisRackauckas commented 8 months ago

@Vaibhavdixit02 I thought the Optimization stats merging was handled?

Vaibhavdixit02 commented 8 months ago

It is; this looks like something else, iiuc. The downstream CI, right?

ChrisRackauckas commented 8 months ago

oh, I must've not refreshed. It's an @AayushSabharwal issue then 😅

AayushSabharwal commented 8 months ago

Yeah this is a bad test I fixed in https://github.com/SciML/SciMLBase.jl/pull/584

get_obs = getu(sys_simplified, lorenz1.x + lorenz2.x)
get_obs_arr = getu(sys_simplified, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y])

should be

get_obs = getu(prob, lorenz1.x + lorenz2.x)
get_obs_arr = getu(prob, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y])

ChrisRackauckas commented 8 months ago

which PR?

AayushSabharwal commented 8 months ago

584

AayushSabharwal commented 8 months ago

Bump. This is kinda needed for MTK

ChrisRackauckas commented 8 months ago

This looked related: https://github.com/SciML/SciMLBase.jl/actions/runs/7649534210/job/20844091535?pr=595#step:6:2460 ?

oscardssmith commented 8 months ago

agreed. I'm unsure why this PR would have regressed inference, though.

AayushSabharwal commented 8 months ago

Could you try rebasing this?

oscardssmith commented 8 months ago

rebased.

ChrisRackauckas commented 8 months ago

Downstream still failing.

oscardssmith commented 8 months ago

yes. I still have no idea why this PR causes us to lose inference precision.

AayushSabharwal commented 8 months ago

Changing the is_observed function in scimlfunctions.jl to

SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_sys(fn) ? is_observed(fn.sys, sym) : has_observed(fn)

helps, but the same error occurs later on.
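A minimal plain-Julia sketch of what that one-liner decides, assuming a function object that may or may not carry a system and may or may not carry an observed function. All names (`MaybeSysFunction`, `ToySys`, `toy_is_observed`, `toy_has_sys`, `toy_has_observed`) are hypothetical stand-ins, not SciMLBase's own `has_sys`/`has_observed`. Both branches of the ternary return a `Bool`.

```julia
# Hypothetical function wrapper: `sys` and `observed` may each be `nothing`.
struct MaybeSysFunction{S, O}
    sys::S
    observed::O
end

# Hypothetical system that knows which symbols are observed.
struct ToySys
    observed_syms::Vector{Symbol}
end

toy_is_observed(sys::ToySys, sym) = sym in sys.observed_syms

toy_has_sys(fn::MaybeSysFunction) = fn.sys !== nothing
toy_has_observed(fn::MaybeSysFunction) = fn.observed !== nothing

# Same shape as the proposed method: ask the system when there is one,
# otherwise report whether an observed function exists at all.
toy_is_observed(fn::MaybeSysFunction, sym) =
    toy_has_sys(fn) ? toy_is_observed(fn.sys, sym) : toy_has_observed(fn)

with_sys = MaybeSysFunction(ToySys([:z]), (sym, u, p, t) -> 0.0)
no_sys = MaybeSysFunction(nothing, (sym, u, p, t) -> 0.0)
bare = MaybeSysFunction(nothing, nothing)
```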

AayushSabharwal commented 8 months ago

Changing SII.observed in the same file to:

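function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym)
  if has_observed(fn)
    if is_time_dependent(fn)
      # Use the one-argument form fn.observed(sym) when the generated observed
      # function supports it; otherwise wrap the (sym, u, p, t) form in a closure.
      return if hasmethod(fn.observed, Tuple{typeof(sym)})
        fn.observed(sym)
      else
        let obs = fn.observed, sym = sym
          (u, p, t) -> obs(sym, u, p, t)
        end
      end
    else
      # Time-independent case: same logic, but the closure takes no t.
      return if hasmethod(fn.observed, Tuple{typeof(sym)})
        fn.observed(sym)
      else
        let obs = fn.observed, sym = sym
          (u, p) -> obs(sym, u, p)
        end
      end
    end
  end
  # Only reached when fn has no observed function at all.
  error("SciMLFunction does not have observed")
end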

Fixes this. I'm not completely sure why. We need the if condition because the generated observed for SDEs doesn't allow calling it with only a symbol (sdeprob.f.observed(sym) throws a MethodError). hasmethod is static so the branches should be removed at compile time.
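The `hasmethod` branching can be reproduced in plain Julia with two hypothetical observed functions: one that supports the one-argument call `obs(sym)` (as the ODE version is expected to), and one that, like the SDE case described here, only has the full `obs(sym, u, p, t)` method. `curried_obs`, `full_only_obs` and `observed_fn` are illustrative names, not MTK or SciMLBase functions. The sketch checks runtime behavior only; it does not test the claim that the branch is removed at compile time.

```julia
# Supports both the one-argument "curried" form and the full form.
curried_obs(sym::Symbol) = (u, p, t) -> u[1] * t
curried_obs(sym::Symbol, u, p, t) = u[1] * t

# Only the full form exists; calling full_only_obs(sym) would be a MethodError.
full_only_obs(sym::Symbol, u, p, t) = u[1] * t

# Pick the curried form when it exists, otherwise wrap the full form
# in a closure with the same (u, p, t) signature.
function observed_fn(obs, sym)
    if hasmethod(obs, Tuple{typeof(sym)})
        return obs(sym)
    else
        let obs = obs, sym = sym
            return (u, p, t) -> obs(sym, u, p, t)
        end
    end
end

g1 = observed_fn(curried_obs, :x)
g2 = observed_fn(full_only_obs, :x)
r1 = g1([2.0], nothing, 3.0)
r2 = g2([2.0], nothing, 3.0)
```

Either way the caller gets a `(u, p, t)` function, which is what lets the same getter work whether or not the generated observed function supports the one-argument form.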

oscardssmith commented 8 months ago

Something feels deeply wrong to me here. Should sdeprob.f.observed(sym) just work? Where is the ODE version (odeprob.f.observed(sym)) done?

oscardssmith commented 8 months ago

that seems to have made things much worse.

AayushSabharwal commented 8 months ago

Should sdeprob.f.observed(sym) just work? Where is the ODE version (odeprob.f.observed(sym)) done?

Yeah it should work. All of this is implemented in MTK, and the SDE version just isn't up to spec.

that seems to have made things much worse.

Yeah, that's weird.

oscardssmith commented 8 months ago

Since our timezones match quite poorly, would you mind taking this over (e.g. by making your own branch so you can push without waiting for me)? It would be very nice to get this merged soon.

AayushSabharwal commented 8 months ago

Succeeded by https://github.com/SciML/SciMLBase.jl/pull/615

I can't close this PR