AlgebraicJulia / DiagrammaticEquations.jl

MIT License
9 stars, 1 fork, source link

Oscillator Decapode code generation, state inference fails #28

Closed: jClugstor closed this issue 4 months ago

jClugstor commented 4 months ago
using Catlab
using Catlab.Graphics
using CombinatorialSpaces
using Decapodes

using Catlab
using Catlab.Graphics
using CombinatorialSpaces
using CombinatorialSpaces.ExteriorCalculus
using DiagrammaticEquations
using DiagrammaticEquations.Deca
using Decapodes

# Oscillator model: ∂ₜ(X) == V and ∂ₜ(V) == -k*X, with X and V declared as 0-forms
# and k declared as a Constant.
oscillator = @decapode begin
    X::Form0
    V::Form0

    k::Constant

    ∂ₜ(X) == V
    ∂ₜ(V) == -k*(X)
end

# Reproduction: the code generated here reads V but never binds it from the state `u`.
decapode_code = gensim(oscillator, dimension = 1)

# Reported to return [:X, :k]; V is missing. NOTE(review): presumably because V is the
# output (`tgt`) of ∂ₜ(X), which the "not produced by any operation" rule excludes.
infer_state_names(oscillator)

gensim(oscillator, dimension = 1) generates:

   #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:570 =#
    function simulate(mesh, operators, hodge = GeometricHodge())
        #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:570 =#
        #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:571 =#
        begin
            #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:174 =#
        end
        #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:572 =#
        begin
            #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:468 =#
        end
        #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:573 =#
        begin
            #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:227 =#
            var"__•1" = Decapodes.FixedSizeDiffCache(Vector{Float64}(undef, nparts(mesh, :V)))
            __V̇ = Decapodes.FixedSizeDiffCache(Vector{Float64}(undef, nparts(mesh, :V)))
        end
        #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:574 =#
        f(du, u, p, t) = begin
                #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:574 =#
                #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:575 =#
                begin
                    #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:252 =#
                    X = u.X
                    k = p.k
                    # NOTE(review): only X and k are bound here. V is never loaded from `u`,
                    # yet `getproperty(du, :X) .= V` below reads it (the reported bug).
                end
                #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:576 =#
                var"•1" = Decapodes.get_tmp(var"__•1", u)
                V̇ = Decapodes.get_tmp(__V̇, u)
                var"•1" .= (.-)(k)
                V̇ .= var"•1" .* X
                #= /home/jadonclugston/.julia/packages/Decapodes/qxJAY/src/simulation.jl:577 =#
                getproperty(du, :X) .= V
                getproperty(du, :V) .= V̇
            end
    end
end

`V` is read (`getproperty(du, :X) .= V`) but is never bound from the state `u`: there is no `V = u.V` line.

infer_state_names(oscillator) does not include V:

Out[1]: 2-element Vector{Symbol}:
 :X
 :k
jpfairbanks commented 4 months ago

I manually overrode the state names for gensim and the generated simulation code looks like it will work.

# Workaround: pass the state names explicitly so V is loaded from `u`.
decapode_code = gensim(oscillator, [:X, :k, :V]; dimension = 1)

yields


quote
    #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:570 =#
    function simulate(mesh, operators, hodge = GeometricHodge())
        #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:570 =#
        #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:571 =#
        begin
            #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:174 =#
        end
        #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:572 =#
        begin
            #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:468 =#
        end
        #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:573 =#
        begin
            #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:227 =#
            var"__•1" = Decapodes.FixedSizeDiffCache(Vector{Float64}(undef, nparts(mesh, :V)))
            __V̇ = Decapodes.FixedSizeDiffCache(Vector{Float64}(undef, nparts(mesh, :V)))
        end
        #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:574 =#
        f(du, u, p, t) = begin
                #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:574 =#
                #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:575 =#
                begin
                    #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:252 =#
                    X = u.X
                    k = p.k
                    V = u.V
                    # V is now bound from `u` because :V was passed as an explicit state name.
                end
                #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:576 =#
                var"•1" = Decapodes.get_tmp(var"__•1", u)
                V̇ = Decapodes.get_tmp(__V̇, u)
                var"•1" .= (.-)(k)
                V̇ .= var"•1" .* X
                #= /Users/fairbanksj/AlgebraicJulia/Decapodes.jl/src/simulation.jl:577 =#
                getproperty(du, :X) .= V
                getproperty(du, :V) .= V̇
            end
    end
end
jpfairbanks commented 4 months ago

I think this would be a slightly more generic fix:

# Original definition of state variables: those that cannot be calculated directly from the formulas.
# function DiagrammaticEquations.infer_states(d::SummationDecapode)
#   filter(parts(d, :Var)) do v
#       length(incident(d, v, :tgt)) == 0 &&
#       length(incident(d, v, :res)) == 0 &&
#       length(incident(d, v, :sum)) == 0 &&
#       d[v, :type] != :Literal
#   end
# end

# State names under the stock inference rule (reported as [:X, :k], with :V missing).
oscillator |> infer_state_names

"""
    variables_having_derivative(d::SummationDecapode)

Return the indices of the variables of `d` that should be treated as simulation state.

A variable is a state if either
- nothing computes it: it is not the output (`tgt`/`res`/`sum`) of any unary operation,
  binary operation, or summation, and it is not a `Literal`; or
- its time derivative is taken: it is the `src` of a `∂ₜ` or `dt` unary operation.

The second rule covers systems like `∂ₜ(X) == V; ∂ₜ(V) == -k*X`, where `V` is produced
by `∂ₜ(X)` (so the first rule alone drops it) yet still evolves in time.

Variables found by the first rule come first, in their original order, followed by any
additional derivative sources. Existing state orderings are therefore only extended,
never permuted.
"""
function variables_having_derivative(d::SummationDecapode)
    # Variables that no operation produces; they must be supplied from outside.
    parentless = filter(parts(d, :Var)) do v
        isempty(incident(d, v, :tgt)) &&
            isempty(incident(d, v, :res)) &&
            isempty(incident(d, v, :sum)) &&
            d[v, :type] != :Literal
    end
    # `dt` is accepted as an alternative spelling of the time derivative `∂ₜ`.
    differentiated = union(d[incident(d, :∂ₜ, :op1), :src],
                           d[incident(d, :dt, :op1), :src])
    return union(parentless, differentiated)
end

# Local workaround: replace the package's existing `infer_states` method for
# `SummationDecapode` so that later code generation sees the extended state set.
# This is acceptable in a script. Do not do it in package code; upstream the fix instead.
function DiagrammaticEquations.infer_states(d::SummationDecapode)
    return variables_having_derivative(d)
end

# Inferred state indices; the source of ∂ₜ(V) (that is, V) is now among them.
oscillator |> variables_having_derivative
decapode_code = gensim(oscillator; dimension = 1)

@lukem12345, what do you think about upstreaming this to DiagrammaticEquations?

lukem12345 commented 4 months ago

Yeah, the boolean check here just needs to be a little different, since `dt` is also valid (in addition to `∂ₜ`).