jlperla closed 3 years ago
From lyndon:
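using ChainRulesCore

# Primal: sum the two fields of a named tuple.
f(nt) = nt.a + nt.b

# Extend ChainRulesCore.rrule (qualified) so AD backends can find the rule.
function ChainRulesCore.rrule(::typeof(f), nt::NamedTuple)
    y = nt.a + nt.b
    # No tangent for f itself; a structural Tangent holds the adjoint of each field.
    f_pullback(dy) = (NoTangent(), Tangent{typeof(nt)}(; a=dy, b=dy))
    return y, f_pullback
end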
That is the real trick here: the pullback returns the adjoint as a `Tangent`, which is the difference type for named tuples. Otherwise, I think this might be pretty simple and just a question of calling the appropriate analytic derivative to fill it in.
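To see what the pullback hands back, the rule can be called directly without any AD backend. This is only a minimal sketch: the input values `(a=1.0, b=2.0)` and the names `pb`, `df`, `dnt` are made up for illustration, and it repeats lyndon's definitions so it runs on its own.

```julia
using ChainRulesCore

f(nt) = nt.a + nt.b

function ChainRulesCore.rrule(::typeof(f), nt::NamedTuple)
    y = nt.a + nt.b
    f_pullback(dy) = (NoTangent(), Tangent{typeof(nt)}(; a=dy, b=dy))
    return y, f_pullback
end

nt = (a=1.0, b=2.0)      # illustrative input, not from the real model
y, pb = rrule(f, nt)     # primal value plus the pullback closure
df, dnt = pb(1.0)        # seed the pullback with dy = 1

# df is NoTangent() because f has no fields to differentiate.
# dnt is a Tangent over the named tuple type; its fields are read by name.
@show y df dnt.a dnt.b
```

Since `f` is a plain sum, each field of the tangent is just the incoming `dy`; in the real setup those slots are where the analytic derivatives would go.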
Also, lyndon asked that, if you get a simple example working with named tuples, you submit it as a PR to the ChainRules docs. I think he just meant something very simple, given his comment above.
The ChainRules crew tell me that it is now reasonable to pass named tuples into AD, which might let us make a major syntax improvement in the DSL/perturbation solver proof of concept. Better to do it now than later.
The thing to check is basically whether we can write a custom rrule that hooks into the named tuples approach.
Take the following code for the primal,
Then what we want to do is create a custom rrule for `f`, where we plug in forward rules for the analytical derivatives. Of course, this would be fancier in the real setup. After that, we want to make sure that using Zygote on `g_1(val)`, `g_2(val)`, etc. all works.
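A rough sketch of that check, under loud assumptions: the real primal and the real `g_1`/`g_2` are not shown in this thread, so the bodies of `g_1` and `g_2` below are hypothetical stand-ins that just build a named tuple from `val` and call `f`. The `f` and its rrule are the ones from lyndon's snippet.

```julia
using ChainRulesCore
using Zygote

f(nt) = nt.a + nt.b

# Custom rule; must be defined before the first Zygote call that hits f.
function ChainRulesCore.rrule(::typeof(f), nt::NamedTuple)
    y = nt.a + nt.b
    f_pullback(dy) = (NoTangent(), Tangent{typeof(nt)}(; a=dy, b=dy))
    return y, f_pullback
end

# Hypothetical wrappers (not from the original issue).
g_1(val) = f((a=val, b=2 * val))   # analytic derivative: 1 + 2 = 3
g_2(val) = f((a=val^2, b=3.0))     # analytic derivative: 2 * val

val = 1.5
grad_1 = Zygote.gradient(g_1, val)[1]
grad_2 = Zygote.gradient(g_2, val)[1]

# Compare against the hand-computed derivatives.
@show grad_1 ≈ 3.0
@show grad_2 ≈ 2 * val
```

If both comparisons hold, Zygote is differentiating through the named tuple construction and picking up the `Tangent` from the custom rule, which is the part the DSL would rely on.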