ForceBru closed this issue 2 years ago.
FWIW the error is slightly different on master:
If I disable all broadcasting-related rules in Yota, the error becomes more straightforward. I think it tells us that the rules from https://github.com/JuliaDiff/ChainRules.jl/pull/644 (which ought to cover this kind of broadcasting) are not being called:
julia> using Yota
julia> normal_pdf(x, mean, var) = exp(-(x - mean)^2 / (2var)) / sqrt(2π * var);
julia> Yota.grad((x, mu) -> sum(log, normal_pdf.(x, mu, 1.0)), rand(10), 1.0)
ERROR: No deriative rule found for op %5 = materialize(%4)::Vector{Float64} , try defining it using
ChainRulesCore.rrule(::typeof(Base.Broadcast.materialize), ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(normal_pdf), Tuple{Vector{Float64}, Float64, Float64}}) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/dev/Yota/src/grad.jl:185
...
# same error on a much simpler broadcast, which should use derivatives_given_output
julia> Yota.grad((x, mu) -> sum(log, atan.(x, mu)), rand(10), 1.0)
ERROR: No deriative rule found for op %5 = materialize(%4)::Vector{Float64} , try defining it using ...
# and an even simpler one, which has its own rrule(broadcasted, +, ...)
julia> Yota.grad((x, mu) -> sum(log, x .+ mu), rand(10), 1.0)
ERROR: No deriative rule found for op %7 = materialize(%5)::Vector{Float64} , try defining it using ...
(jl_fINTq0) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_fINTq0/Project.toml`
[082447d4] ChainRules v1.43.2
[92992a2b] Umlaut v0.4.2 `https://github.com/dfdx/Umlaut.jl.git#main`
[cd998857] Yota v0.7.5 `~/.julia/dev/Yota`
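For context, a dotted call such as x .+ mu is lowered into two steps: a lazy Base.Broadcast.broadcasted call that builds a Broadcasted object, and a Base.Broadcast.materialize call that actually computes the result. That second step is the materialize(%4) op the errors above complain about. The following is a minimal sketch of the two steps using only Base. The variable names are illustrative.

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

x, mu = rand(10), 1.0

# Step 1: build a lazy Broadcasted object; nothing is computed yet
bc = broadcasted(+, x, mu)
@assert bc isa Broadcasted
println(typeof(bc))  # a Broadcasted{DefaultArrayStyle{1}, ...} wrapper, like the one in the error above

# Step 2: materialize allocates the output array and fills it
y = materialize(bc)

# Together the two steps reproduce the dotted call
@assert y == x .+ mu

# The lowered form of the original expression contains the same
# Base.broadcasted / Base.materialize pair
println(Meta.@lower sum(log, x .+ mu))
```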
Oh, I think I messed up broadcasting on main recently :( I'm currently fixing a possibly related bug, so let's see how it works after that (ETA ~1-2 days).
@mcabbott If I understand correctly, JuliaDiff/ChainRules.jl#644 doesn't provide an rrule(materialize, ...), so if you commented that rule out of Yota too (as I just did), it makes sense that Yota started complaining about the missing rule. At least, when I uncommented the rule, this example started to work (on the fix-kw-rrule branch):
julia> grad((x, mu) -> sum(log, x .+ mu), rand(10), 1.0)
(2.515783933810798, (ZeroTangent(), [0.5835340506097648, 0.6717831140646412, 0.9651554343975317, 0.5720477538411006, 0.9756415208897008, 0.7392737693342455, 0.9473341524270764, 0.9914851197981195, 0.9458598901263214, 0.5826024579088742], 7.974717263397377))
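As a sanity check on the numbers above: for f(x, mu) = sum(log, x .+ mu), the gradient is 1 / (x_i + mu) for each x_i and sum(1 ./ (x .+ mu)) for mu. The mu component should therefore equal the sum of the x components. The ten printed entries do add up to about 7.9747, which matches the printed mu gradient. Because x comes from rand(10) and mu = 1.0, every x entry also lies in (0.5, 1]. Below is a small sketch of the closed form with a finite-difference check. It uses fresh random input, so it will not reproduce the exact values above.

```julia
# Closed-form gradient of f(x, mu) = sum(log, x .+ mu)
f(x, mu) = sum(log, x .+ mu)
grad_x(x, mu) = 1 ./ (x .+ mu)          # ∂f/∂x_i = 1 / (x_i + mu)
grad_mu(x, mu) = sum(1 ./ (x .+ mu))    # ∂f/∂mu = Σ_i 1 / (x_i + mu)

x, mu = rand(10), 1.0
gx, gmu = grad_x(x, mu), grad_mu(x, mu)

# The mu-gradient is the sum of the x-gradient entries ...
@assert gmu ≈ sum(gx)
# ... and each entry lies in (0.5, 1] because x_i + mu is in [1, 2)
@assert all(0.5 .< gx .<= 1.0)

# Central finite-difference check on mu
h = 1e-6
fd_mu = (f(x, mu + h) - f(x, mu - h)) / (2h)
@assert isapprox(fd_mu, gmu; rtol = 1e-5)
```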
rrule(broadcasted, normal_pdf, ...) still doesn't reach the new generic broadcasting rule, though. Is it correct that I need to invoke the signature with ::BroadcastStyle to make this work?
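For reference, Base handles a call like broadcasted(normal_pdf, x, mu, 1.0) in three steps. It wraps each argument with broadcastable, combines their styles, and then forwards to a method of broadcasted whose first argument is that BroadcastStyle. That style-taking method is the signature discussed here. A tape that records only the outer call never reaches it. The following is a minimal sketch of the forwarding done by hand. combine_styles and broadcastable are unexported Base internals, so treat it as illustrative.

```julia
using Base.Broadcast: BroadcastStyle, DefaultArrayStyle, broadcastable, broadcasted,
                      combine_styles, materialize

normal_pdf(x, mean, var) = exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)

x, mu = rand(10), 1.0

# What Base does inside broadcasted(normal_pdf, x, mu, 1.0):
# 1. wrap each argument so it can take part in broadcasting
args = map(broadcastable, (x, mu, 1.0))
# 2. combine per-argument styles: Vector -> DefaultArrayStyle{1}, scalars -> DefaultArrayStyle{0}
style = combine_styles(args...)
@assert style isa BroadcastStyle
@assert style == DefaultArrayStyle{1}()
# 3. call the style-taking method, which builds the lazy Broadcasted object
bc = broadcasted(style, normal_pdf, args...)

# Materializing gives the same result as the dotted call
@assert materialize(bc) ≈ normal_pdf.(x, mu, 1.0)
```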
Oh right, sorry, the rule for materialize is indeed still needed. And yes, I think Yota somehow needs to fall through to the signature with ::BroadcastStyle (and more Refs).
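The materialize rule can be very small if we assume the preceding rrule(broadcasted, ...) has already returned an ordinary array. Under that assumption, materialize of a non-Broadcasted value returns it unchanged, so the pullback just passes the cotangent through. The following is a minimal sketch that assumes ChainRulesCore is installed. It is an illustration, not Yota's actual definition.

```julia
using ChainRulesCore
using Base.Broadcast: materialize

# Assumption: by the time materialize is reached, the forward pass of
# rrule(broadcasted, ...) has already produced an ordinary array, so
# materialize(x) === x and the pullback is the identity on the cotangent.
function ChainRulesCore.rrule(::typeof(materialize), x::AbstractArray)
    materialize_pullback(dy) = (NoTangent(), dy)  # no tangent for materialize itself
    return materialize(x), materialize_pullback
end

y, back = ChainRulesCore.rrule(materialize, [1.0, 2.0, 3.0])
dself, dx = back([0.1, 0.2, 0.3])
```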
The fix is now on main, featuring the new generic broadcasting from https://github.com/JuliaDiff/ChainRules.jl/pull/644
Thanks! Are you thinking of making a release sometime soon?
@mcabbott Yep, tagged v0.7.4.