Open VarLad opened 2 years ago
Thanks for reporting this! It turns out that Pluto adds some magic to the call, so it is traced differently than in a normal REPL. In the REPL:
julia> Yota.Umlaut.trace(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
(9.0, Tape{BaseCtx}
inp %1::var"#101#102"
inp %2::Vector{Float64}
%3 = broadcasted(+, %2, 1)::Broadcasted{}
%4 = materialize(%3)::Vector{Float64}
%5 = sum(%4)::Float64
)
In Pluto:
(9.0,
Tape{Umlaut.BaseCtx}
inp %1::var"#3#4"{typeof(+), typeof(sum)}
inp %2::Vector{Float64}
%3 = getfield(%1, :sum)::typeof(sum) # <---
%4 = getfield(%1, :+)::typeof(+) # <---
%5 = broadcasted(%4, %2, 1)::Broadcasted{}
%6 = materialize(%5)::Vector{Float64}
%7 = %3(%6)::Float64)
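The type of `%1` in the Pluto trace, `var"#3#4"{typeof(+), typeof(sum)}`, suggests that the traced function is a closure holding `+` and `sum` as captured fields, which would explain the two extra `getfield` calls. The effect can be reproduced in plain Julia with a minimal sketch. The names `make_closure`, `op`, and `agg` are hypothetical and only illustrate the shape of the closure; this is not how Pluto itself generates the code.

```julia
# Hypothetical helper: returns a closure that captures `op` and `agg`,
# similar in shape to the closure type seen in the Pluto trace.
make_closure(op, agg) = x -> agg(op.(x, 1))

f = make_closure(+, sum)

# Captured variables are stored as fields of the closure object.
println(fieldnames(typeof(f)))
println(getfield(f, :agg) === sum)
println(getfield(f, :op) === (+))

# Calling the closure reads both fields before using them;
# a tracer records those reads as getfield calls.
println(f([1.0, 2.0, 3.0]))
```

Since the tracer sees those field reads as ordinary calls, differentiation needs a rule for `getfield` even though the captured functions carry no gradient.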
One workaround is to define the missing rrule:
using ChainRulesCore
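# Treat field access as non-differentiable: no tangent for getfield, the object, or the field name.
ChainRulesCore.rrule(::typeof(getfield), x, f::Symbol) = (getfield(x, f), dy -> (NoTangent(), NoTangent(), NoTangent()))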
After this the example gives the correct result:
grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
# ==> (9.0, (NoTangent(), [1.0, 1.0, 1.0]))
Whereas it works normally in the REPL