HighDimensionalEconLab / DifferentiableStateSpaceModels.jl

MIT License

Test of AD with named tuples #41

Closed jlperla closed 3 years ago

jlperla commented 3 years ago

The ChainRules crew tell me that it is now reasonable to pass named tuples into AD, which might let us make a major syntax improvement in the DSL/perturbation solver proof of concept. Better to do it now than later.

The thing to check is basically whether we can write a custom rrule that hooks into the named tuples approach.

Take the following code for the primal:

# The rrule will only provide a tangent for `p`, not for the fixed `p_f`
function f(p, p_f)
    all_p = merge(p, p_f)

    return [all_p.a^2 + all_p.b
            all_p.c^2 + all_p.d^2]
end

function g_1(val)
    p = (a = val[1], b = val[2], d = val[3])
    p_f = (c = 4.0,)
    y = f(p, p_f)
    return sum(y)
end

function g_2(val)
    p = (b = val[1], d = val[2])
    p_f = (a = 5.0, c = 4.0,)
    y = f(p, p_f)
    return sum(y)
end
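Before writing any rule, it helps to know what the gradients should be. Both objectives are quadratic, so by hand the gradient of g_1 with respect to val = [a, b, d] is [2a, 1, 2d], and the gradient of g_2 with respect to val = [b, d] is [1, 2d]. The sketch below checks this with a plain central finite difference, giving a package-free baseline that any AD result should match. The helper `fd_gradient` and the step `h = 1e-6` are illustrative choices, not part of the package, and the primal is repeated so the script runs on its own.

```julia
# Primal from the issue: only `p` is meant to be differentiated, `p_f` is fixed
function f(p, p_f)
    all_p = merge(p, p_f)
    return [all_p.a^2 + all_p.b
            all_p.c^2 + all_p.d^2]
end

function g_1(val)
    p = (a = val[1], b = val[2], d = val[3])
    p_f = (c = 4.0,)
    return sum(f(p, p_f))
end

function g_2(val)
    p = (b = val[1], d = val[2])
    p_f = (a = 5.0, c = 4.0)
    return sum(f(p, p_f))
end

# Hypothetical helper: central finite-difference gradient of a scalar function
function fd_gradient(g, x::Vector{Float64}; h = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        grad[i] = (g(x .+ e) - g(x .- e)) / (2h)
    end
    return grad
end

fd_gradient(g_1, [1.0, 2.0, 3.0])  # ≈ [2a, 1, 2d] = [2.0, 1.0, 6.0]
fd_gradient(g_2, [2.0, 3.0])       # ≈ [1, 2d] = [1.0, 6.0]
```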

Then we want to create a custom rrule for f that plugs in forward rules for the analytical derivatives. Of course, this would be fancier in the real setup.

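using ChainRulesCore

# The positional signature must match the primal `f(p, p_f)`
function ChainRulesCore.rrule(::typeof(f), p::NamedTuple, p_f::NamedTuple)
    val = f(p, p_f)
    all_p = merge(p, p_f)

    function f_pb(Δsol)
        Δ = unthunk(Δsol)
        # Analytic partials of dot(Δ, f(p, p_f)) for every parameter name
        partials = (a = 2 * all_p.a * Δ[1], b = Δ[1],
                    c = 2 * all_p.c * Δ[2], d = 2 * all_p.d * Δ[2])
        # Go through keys(p): keep only the keys that are in `p`, in the same order as `p`
        Δp = Tangent{typeof(p)}(; NamedTuple{keys(p)}(partials)...)
        return NoTangent(), Δp, NoTangent()  # no tangent for `f` itself or for the fixed `p_f`
    end

    return val, f_pb
end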

After that, we want to make sure that Zygote works on g_1(val), g_2(val), etc.
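One way to run that check is a standalone script like the one below. It is a sketch that assumes ChainRulesCore and Zygote are installed. The rrule hard-codes the analytic partials of this toy `f` for every parameter name (a, b, c, d) and then keeps only the names that appear in `p`, in `p`'s order, which is the part the issue asks about. The `partials` name and the use of `NamedTuple{keys(p)}(...)` to select fields are illustrative choices, not an established pattern from the package.

```julia
using ChainRulesCore
using Zygote

function f(p, p_f)
    all_p = merge(p, p_f)
    return [all_p.a^2 + all_p.b
            all_p.c^2 + all_p.d^2]
end

# Custom rule: tangent only for `p`; `p_f` gets NoTangent()
function ChainRulesCore.rrule(::typeof(f), p::NamedTuple, p_f::NamedTuple)
    val = f(p, p_f)
    all_p = merge(p, p_f)
    function f_pb(Δsol)
        Δ = unthunk(Δsol)
        # Hand-written partials of dot(Δ, f) for every name the toy model knows about
        partials = (a = 2 * all_p.a * Δ[1], b = Δ[1],
                    c = 2 * all_p.c * Δ[2], d = 2 * all_p.d * Δ[2])
        # Select the entries for the keys of `p`, in the order they appear in `p`
        Δp = Tangent{typeof(p)}(; NamedTuple{keys(p)}(partials)...)
        return NoTangent(), Δp, NoTangent()
    end
    return val, f_pb
end

function g_1(val)
    p = (a = val[1], b = val[2], d = val[3])
    p_f = (c = 4.0,)
    return sum(f(p, p_f))
end

function g_2(val)
    p = (b = val[1], d = val[2])
    p_f = (a = 5.0, c = 4.0)
    return sum(f(p, p_f))
end

Zygote.gradient(g_1, [1.0, 2.0, 3.0])[1]  # expected [2a, 1, 2d] = [2.0, 1.0, 6.0]
Zygote.gradient(g_2, [2.0, 3.0])[1]       # expected [1, 2d] = [1.0, 6.0]
```

If the rule is picked up for both key sets, the gradients agree with the hand-derived values, which is the behavior the issue wants to confirm before changing the DSL syntax.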

jlperla commented 3 years ago

From lyndon:

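using ChainRulesCore

f(nt) = nt.a + nt.b

# Qualify the name so this extends ChainRulesCore.rrule rather than defining an unrelated `rrule`
function ChainRulesCore.rrule(::typeof(f), nt::NamedTuple)
    y = nt.a + nt.b
    # The tangent for a NamedTuple input is a Tangent keyed by field name
    f_pullback(dy) = (NoTangent(), Tangent{typeof(nt)}(; a = dy, b = dy))

    return y, f_pullback
end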

Returning Tangent{typeof(nt)} does the real trick here: it is the adjoint (differential) type for named tuples. Otherwise I think this might be pretty simple, and just a question of calling the appropriate analytic derivative to fill it in.
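A minimal usage sketch of lyndon's pattern, assuming ChainRulesCore and Zygote are installed: calling Zygote on the rule should give a gradient keyed by field name, since Zygote converts the returned `Tangent` back into a named tuple. The input values below are arbitrary.

```julia
using ChainRulesCore
using Zygote

f(nt) = nt.a + nt.b

function ChainRulesCore.rrule(::typeof(f), nt::NamedTuple)
    y = nt.a + nt.b
    # d(a + b)/da = d(a + b)/db = 1, so both fields receive dy
    f_pullback(dy) = (NoTangent(), Tangent{typeof(nt)}(; a = dy, b = dy))
    return y, f_pullback
end

x = (a = 1.0, b = 2.0)
grad = Zygote.gradient(f, x)[1]  # gradient with respect to both fields of x is 1.0
```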

jlperla commented 3 years ago

Also, lyndon asked that, if you get a simple example working with named tuples, you submit it as a PR to the ChainRules docs. I think he just meant something very simple along the lines of the comment above.