Closed mcabbott closed 2 years ago
It's a bug. Yota thinks Optimisers.Restructure(...) is a primitive and records it to the tape instead of tracing it down to (:new, T, args...) (represented as __new__(T, args...) on the tape). Yota thinks it's a primitive because typeof(Restructure) returns UnionAll, which belongs to module Core, and we don't trace deeper than that. I have an idea how to fix it; I'll try to implement it tonight.
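For anyone following along, the primitive check described above is easy to reproduce in the REPL (a quick illustration, assuming Optimisers is installed):

```julia
using Optimisers

# Restructure is a parametric struct, so the constructor object's type
# is UnionAll rather than a DataType:
typeof(Optimisers.Restructure)                 # UnionAll

# and UnionAll itself is defined in module Core, which is why the
# tracer stops there and records the call as a primitive:
parentmodule(typeof(Optimisers.Restructure))   # Core
```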
Just to note, I was also seeing this in the latest release. But when I cloned the repository in order to add some @show statements for debugging, the problem went away (well, actually it was replaced by a different bug). @mcabbott, are you seeing the same behavior on master?
This seems to work on main:
using Optimisers, ChainRulesCore
function gradient(f, xs...)
println("Yota gradient!")
_, g = Yota.grad(f, xs...)
g[2:end]
end;
# a quick hack, not really tested
function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
# a more robust implementation would be to do backward conversion:
# return x, Δ -> (NoTangent(), NoTangent(), convert(typeof(x), Δ))
# but it doesn't work for ZeroTangent(), so passing Δ as is
return x, Δ -> (NoTangent(), NoTangent(), Δ)
end
m1 = collect(1:3.0);
gradient(m -> destructure(m)[1][1], m1)[1]
I get the same error on latest Yota + Umlaut.
The rule for convert silences the error, but doesn't actually make the struct requested. It isn't used above, but if I change the code to something which does use that part, it fails:
julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
Yota gradient!
ERROR: MethodError: no method matching _rebuild(::Vector{Float64}, ::Int64, ::ZeroTangent, ::Int64; walk::typeof(Optimisers._Tangent_biwalk), prune::NoTangent)
Closest candidates are:
_rebuild(::Any, ::Any, ::AbstractVector, ::Any; walk, kw...)
@ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82
_rebuild(::Any, ::Any, ::AbstractVector) got unsupported keyword arguments "walk", "prune"
@ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82
Stacktrace:
[1] (::Optimisers.var"#_flatten_back#18"{Vector{Float64}, Int64, Int64})(::Tangent{Tuple{Vector{Float64}, Int64, Int64}, Tuple{ZeroTangent, NoTangent, NoTangent}})
@ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:77
[2] mkcall(fn::Umlaut.Variable, args::Umlaut.Variable; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
[3] mkcall
@ ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179 [inlined]
[4] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:164
cf. Zygote:
julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
(nothing, [1.0, 0.0, 0.0])
It looks like a different error; tracing constructors of UnionAll, which caused the previous error, works correctly this time:
julia> trace((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0]; ctx=GradCtx())
(1.0, Tape{GradCtx}
inp %1::var"#212#213"
inp %2::Vector{Float64}
inp %3::Vector{Float64}
%5, %6 = [%4] = rrule(YotaRuleConfig(), _flatten, %2)
%8, %9 = [%7] = rrule(YotaRuleConfig(), indexed_iterate, %5, 1)
%11, %12 = [%10] = rrule(YotaRuleConfig(), getfield, %8, 1)
%14, %15 = [%13] = rrule(YotaRuleConfig(), getfield, %8, 2)
%17, %18 = [%16] = rrule(YotaRuleConfig(), indexed_iterate, %5, 2, %14)
%20, %21 = [%19] = rrule(YotaRuleConfig(), getfield, %17, 1)
%23, %24 = [%22] = rrule(YotaRuleConfig(), getfield, %17, 2)
%26, %27 = [%25] = rrule(YotaRuleConfig(), indexed_iterate, %5, 3, %23)
%29, %30 = [%28] = rrule(YotaRuleConfig(), getfield, %26, 1)
%32, %33 = [%31] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
%35, %36 = [%34] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
%38, %39 = [%37] = rrule(YotaRuleConfig(), convert, Vector{Float64}, %2)
%41, %42 = [%40] = rrule(YotaRuleConfig(), convert, Int64, %20)
%44, %45 = [%43] = rrule(YotaRuleConfig(), fieldtype, %35, 3)
%47, %48 = [%46] = rrule(YotaRuleConfig(), convert, %44, %29)
%50, %51 = [%49] = rrule(YotaRuleConfig(), __new__, %35, %38, %41, %47) # <-- this is the internal constructor of Restructure
%53, %54 = [%52] = rrule(YotaRuleConfig(), tuple, %11, %50)
%56, %57 = [%55] = rrule(YotaRuleConfig(), getindex, %53, 2)
%59, %60 = [%58] = rrule(YotaRuleConfig(), getproperty, %56, model)
%62, %63 = [%61] = rrule(YotaRuleConfig(), getproperty, %56, offsets)
%65, %66 = [%64] = rrule(YotaRuleConfig(), getproperty, %56, length)
%68, %69 = [%67] = rrule(YotaRuleConfig(), _rebuild, %59, %62, %3, %65)
%71, %72 = [%70] = rrule(YotaRuleConfig(), getindex, %68, 1)
)
_rebuild() doesn't accept ZeroTangent() as a cotangent value. The real question is whether ZeroTangent() is correct here, and if so, why Zygote doesn't hit the same problem. I will need to understand more about Optimisers internals and the generated graph to answer these questions.
It's not impossible there are bugs in _rebuild, sorry, it's pretty messy. Will take a look at some point.
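If it helps triage, one possible direction (purely a hypothetical sketch, not tested against Optimisers internals; _normalize_flat is a made-up name) would be to materialize an AbstractZero cotangent into a concrete zero vector before it reaches _rebuild, whose method expects an AbstractVector in that position:

```julia
using ChainRulesCore

# Hypothetical helper: turn an AbstractZero flat cotangent into a
# concrete zero vector of known length, so downstream code that
# dispatches on ::AbstractVector still works.
_normalize_flat(flat::AbstractVector, len) = flat
_normalize_flat(::AbstractZero, len) = zeros(len)
```

Whether a zero vector is semantically right here is exactly the open question above, so this is only a shape-level fix.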
Trying to find a simpler example of what I thought was the original problem, with a struct from here:
julia> using Yota, ChainRulesCore
julia> struct Multiplier{T} # from test_helpers in ChainRules
x::T
end
julia> (m::Multiplier)(y) = m.x * y
julia> function ChainRulesCore.rrule(m::Multiplier, y)
Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
return m(y), Multiplier_pullback
end
julia> grad(x -> x(3.0), Multiplier(5.0)) # perfect
(15.0, (ZeroTangent(), Tangent{Multiplier{Float64}}(x = 3.0,)))
julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %3 = Multiplier(%2)::Multiplier{Float64}, try defining it using
ChainRulesCore.rrule(::UnionAll, ::Float64) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:170
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:211
...
julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
inp %1::var"#8#9"
inp %2::Float64
%3 = Multiplier(%2)::Multiplier{Float64}
%5, %6 = [%4] = rrule(Yota.YotaRuleConfig(), %3, 3.0)
)
(jl_lUY4C1) pkg> st Yota
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lUY4C1/Project.toml`
[cd998857] Yota v0.7.3
That's the tagged version. On master:
(jl_lpE53K) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lpE53K/Project.toml`
[92992a2b] Umlaut v0.2.5 `https://github.com/dfdx/Umlaut.jl.git#main`
[cd998857] Yota v0.7.4 `https://github.com/dfdx/Yota.jl.git#main`
julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %9 = convert(%7, %2)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(convert), ::DataType, ::Float64) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:170
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:211
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:222
...
julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
inp %1::var"#14#15"
inp %2::Float64
%4, %5 = [%3] = rrule(Yota.YotaRuleConfig(), apply_type, Multiplier, Float64)
%7, %8 = [%6] = rrule(Yota.YotaRuleConfig(), fieldtype, %4, 1)
%9 = convert(%7, %2)::Float64
%11, %12 = [%10] = rrule(Yota.YotaRuleConfig(), __new__, %4, %9)
%14, %15 = [%13] = rrule(Yota.YotaRuleConfig(), %11, 3.0)
)
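Incidentally, the apply_type / fieldtype / convert / __new__ sequence in this tape matches what Julia's default inner constructor lowers to; you can see it with @code_lowered (the exact lowered form varies between Julia versions, so treat this as a sketch):

```julia
# A stand-in parametric struct, analogous to Multiplier:
struct Wrap{T}
    x::T
end

# The default inner constructor converts each field to its declared
# type (via fieldtype/convert) before calling new, which is where the
# traced convert comes from:
@code_lowered Wrap{Float64}(5.0)
```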
julia> function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
# Version with re-conversion via ProjectTo? Maybe this is only right for number types...
# Also, why not convert on the forward pass?
return x, Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(Δ))
end
julia> grad(x -> Multiplier(x)(3.0), 5.0)
(15.0, (ZeroTangent(), 3.0))
That's why your rule targets convert, since apply_type is now being applied. But why is convert being called at all?
Sorry for the silence - I've been working on some bug fixes and improvements that may affect this question too. In particular, I added lineinfo to call nodes, and here's what it shows:
struct Multiplier{T} # from test_helpers in ChainRules
x::T
end
(m::Multiplier)(y) = m.x * y
function ChainRulesCore.rrule(m::Multiplier, y)
Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
return m(y), Multiplier_pullback
end
mult1(x) = x(3.0)
mult2(x) = Multiplier(x)(3.0)
_, tape = trace(mult2, 5.0; ctx=GradCtx())
Result:
(15.0, Tape{GradCtx}
inp %1::typeof(mult2)
inp %2::Float64
%4, %5 = [%3] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64) # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%7, %8 = [%6] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64) # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%9 = convert(Float64, %2)::Float64 # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%11, %12 = [%10] = rrule(YotaRuleConfig(), __new__, %7, %9) # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%14, %15 = [%13] = rrule(YotaRuleConfig(), %11, 3.0) # Main.mult2 at /home/azbs/work/Yota/src/_main3.jl:37
)
So convert(Float64, %2) happens in the object constructor, even though %2 is already a Float64. I tried to trick the compiler into not adding convert(), but it seems to be an essential detail of the lowered code. Nevertheless, we can now simplify the rrule for convert to a stricter version:
function ChainRulesCore.rrule(::typeof(convert), ::Type{T}, x::T) where T
return x, Δ -> (NoTangent(), NoTangent(), Δ)
end
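Since the constructor's convert is a no-op when the value already has the field's declared type, the x::T constraint above only matches that no-op case and simply passes the cotangent through. A quick sanity check of the rule in isolation (assuming the definition above has been evaluated):

```julia
using ChainRulesCore

y, back = rrule(convert, Float64, 5.0)
y           # 5.0, unchanged by the no-op convert
back(1.0)   # (NoTangent(), NoTangent(), 1.0)
```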
Both the top example and the Multiplier one work on Yota 0.8 and Julia 1.8, which is great.
On Julia nightly, something seems to go wrong, perhaps of interest (and might be why I saw errors in https://github.com/FluxML/Optimisers.jl/pull/105):
julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: Unexpected expression: $(Expr(:static_parameter, 1))
Full IRCode:
2 1 ─ %4 = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│ %1 = Core.apply_type(Main.Multiplier, %4)::Core.Const(Multiplier{Float64})
│ %2 = (%1)(_2)::Multiplier{Float64}
└── return %2
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:333
[3] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Tuple{UnionAll, Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
[4] trace_call!(::Umlaut.Tracer{Yota.GradCtx}, ::Type, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:290
[5] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:315
[6] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Vector{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
[7] trace(f::Function, args::Float64; ctx::Yota.GradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:556
[8] gradtape(f::Function, args::Float64; ctx::Yota.GradCtx, seed::Int64)
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:291
Apparently, Julia 1.9 changes the way static parameters (i.e. type parameters, {T}) are used in IRCode. You can update Umlaut to 0.4.5 to account for this. (No changes to the version of Yota itself are needed at the moment.)
Great, thanks. Then I'll mark this as closed.
I'm surprised by this error, which, if I understand right, comes from this line https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl#L31 constructing a struct which isn't in fact used. Is this the desired behaviour, or can all (default?) constructors be handled automatically somehow?
Xref https://github.com/FluxML/Optimisers.jl/issues/96