Closed cscherrer closed 1 year ago
Looks like tuple is a somewhat special function:
julia> foo(args...) = tuple(args...)
foo (generic function with 1 method)
julia> tuple(1.2)
(1.2,)
julia> foo(1.2)
(1.2,)
julia> code_lowered(tuple, (Float64,))
Core.CodeInfo[]
julia> code_lowered(foo, (Float64,))
1-element Vector{Core.CodeInfo}:
CodeInfo(
1 ─ %1 = Core._apply_iterate(Base.iterate, Main.tuple, args)
└── return %1
)
So in general the error makes sense to me. However, it's unclear why the tracer tried to recurse into tuple, since it is marked as a primitive in most contexts:
julia> Umlaut.isprimitive(Umlaut.BaseCtx(), tuple, 1.2)
true
julia> Umlaut.isprimitive(Umlaut.BaseCtx(), tuple, 1.2)
true
I'll try to debug it in the next couple of days.
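For reference, the difference between foo and tuple above can be checked with Base alone. This is only a sketch: has_lowered_code is a hypothetical helper written for illustration, not something Umlaut provides, and it assumes that builtins such as tuple are instances of Core.Builtin and carry no lowered code to trace into.

```julia
# Hypothetical helper (not part of Umlaut): true if `f` has lowered code
# that a tracer could recurse into for the given argument types.
# Builtins are rejected up front, so `code_lowered` is never called on them.
has_lowered_code(f, types) = !(f isa Core.Builtin) && !isempty(code_lowered(f, types))

foo(args...) = tuple(args...)

println(tuple isa Core.Builtin)               # `tuple` is a builtin
println(has_lowered_code(foo, (Float64,)))    # generic function with a method
println(has_lowered_code(tuple, (Float64,)))  # builtin, nothing to trace
```

A guard of this kind is one possible way to avoid recursing into a builtin; whether it fits the tracer's design is a separate question.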
Could you please post the package versions? With the following specs:
(Yota) pkg> st
Project Yota v0.7.3
Status `~/work/Yota/Project.toml`
[082447d4] ChainRules v1.28.1
[d360d2e6] ChainRulesCore v1.14.0
[26cc04aa] FiniteDifferences v0.12.24
[872c559c] NNlib v0.8.4
[92992a2b] Umlaut v0.2.4
[37e2e46d] LinearAlgebra
[9a3f8284] Random
[10745b16] Statistics
[8dfed614] Test
[cf7118a7] UUIDs
I see a different error:
ERROR: ArgumentError: argument is not a generic function
Stacktrace:
[1] which(f::Any, t::Any)
@ Base ./reflection.jl:1320
[2] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/ESm8s/src/trace.jl:341
[3] #gradtape#90
@ ~/work/Yota/src/grad.jl:243 [inlined]
[4] make_rrule(f::Function, args::Float64)
@ Yota ~/work/Yota/src/cr_api.jl:128
[5] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Float64)
@ Yota ~/work/Yota/src/cr_api.jl:170
[6] rrule(::Yota.YotaRuleConfig, ::typeof(Core._apply_iterate), ::typeof(iterate), ::typeof(tuple), ::Tuple{Float64}, ::Tuple{})
@ Yota ~/work/Yota/src/rulesets.jl:28
[84d833dd] TransformVariables v0.6.2
[92992a2b] Umlaut v0.2.4
[cd998857] Yota v0.7.3
julia> versioninfo()
Julia Version 1.8.0-rc1
Commit 6368fdc656 (2022-05-27 18:33 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: 32 × AMD Ryzen Threadripper 2950X 16-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
Threads: 1 on 32 virtual cores
Technical note, MRE (julia-1.8.0-rc1):
using TransformVariables, Umlaut, Yota
trace(TransformVariables._transform_tuple, TransformVariables.NoLogJac(), zeros(2), 2, (asℝ,); ctx=Yota.GradCtx())
I understand the problem, but my naive fix in #114 didn't work as expected. As a workaround for now, you can define this rule:
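using ChainRulesCore  # provides rrule and NoTangent

function ChainRulesCore.rrule(::typeof(tuple), args...)
    y = tuple(args...)
    # no tangent for `tuple` itself, then one cotangent per argument
    return y, dy -> (NoTangent(), collect(dy)...)
end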
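To make the shape of that pullback concrete, here is a Base-only sketch of the same idea. NoTangentStandIn and tuple_with_pullback are hypothetical names used only for illustration (the stand-in replaces ChainRulesCore.NoTangent so the snippet runs without any packages); it is not the rule Yota uses.

```julia
# Stand-in for ChainRulesCore.NoTangent, so no packages are needed here.
struct NoTangentStandIn end

# Mirrors the workaround rule: the forward value plus a pullback that
# returns "no tangent" for the function and one cotangent per argument.
function tuple_with_pullback(args...)
    y = tuple(args...)
    pullback(dy) = (NoTangentStandIn(), collect(dy)...)
    return y, pullback
end

y, back = tuple_with_pullback(1.0, 2.0)
grads = back((10.0, 20.0))
println(y)      # the forward result, a 2-tuple
println(grads)  # stand-in marker followed by the two cotangents
```

The pullback simply passes each component of dy back to the argument in the same position, which is why no arithmetic appears in it.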
Note that you will still need a few more rrules (at least first(::NamedTuple)).
In essence, this code touches a pretty sensitive piece of the codebase that dynamically generates new rrules. This code works fine with rules defined in ChainRules.jl (or just without a specific config), but the rule for tuple(args...) only exists in Yota and is not processed correctly by rrule_via_ad(). I tried to fix the invocation, but it led to an endless recursion between rrule() and rrule_via_ad(), and it isn't obvious how to make them flexible and friendly to each other. So it's kind of a design question that I need to sleep on.
Even more minimal reproducible examples:
trace(x -> Core._apply_iterate(Base.iterate, Core.tuple, x), 2.0; ctx=Yota.GradCtx())
Yota.rrule_via_ad(Yota.YotaRuleConfig(), Core.tuple, 2.0)
Thanks @dfdx, this is very helpful!
> Note that you still will need a few more rrules
I'm not sure how specific this is to Yota (maybe it's a more general ChainRules question?), but do you know of a good strategy for knowing which rrules should be defined, first to get things working and then to improve performance?
Usually, when an rrule is missing, you get an error similar to this:
ERROR: No deriative rule found for op %228 = first(%226)::NamedTuple{(:a, :b), Tuple{Float64, Float64}}, try defining it using
ChainRulesCore.rrule(::typeof(first), ::Tuple{NamedTuple{(:a, :b), Tuple{Float64, Float64}}, TransformVariables.NoLogJac, Int64}) = ...
So the general rule is just to try it out and see what's missing.
In your case, though, Yota hit a double issue: not only was the rrule missing, but AD also went through the valley of the shadow of recursive rrule_via_ad, which led to a complex combination of otherwise harmless corner cases.
(I'm a bit overwhelmed these days, but I remember this issue isn't fully resolved yet)
Thanks @dfdx, I think I see. I'll leave this open so we remember the tuple case isn't fully addressed yet. But it looks like your suggestion should work to avoid the issue.
rrule() for tuple is now in the library and tested with several use cases, so I'm closing this issue. Note that the following still doesn't work:
- rrules unrelated to the topic
- tuple itself (e.g. trace(tuple, 2.0)), which is not supposed to work anyway
Using tuple in any other context should work fine now, e.g.:
julia> grad(x -> sum(tuple(x, x)), 2.0)
(4.0, (ChainRulesCore.ZeroTangent(), 2.0))
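As a quick sanity check of the numbers above without Yota, the same function can be differentiated numerically. This is only a sketch: central_diff is a hypothetical helper, and the step size is an arbitrary choice.

```julia
# The function from the example above.
f(x) = sum(tuple(x, x))

# Hypothetical helper: central finite difference with step `h`.
central_diff(g, x; h=1e-6) = (g(x + h) - g(x - h)) / (2h)

val = f(2.0)
dfdx = central_diff(f, 2.0)
println(val)   # value of f at 2.0
println(dfdx)  # numerical derivative, close to 2
```

Since f(x) = x + x, the value at 2.0 is 4.0 and the derivative is 2, matching the 4.0 and 2.0 in the grad output.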
Feel free to reopen if you encounter any more issues.
Hi, I think I'm seeing a bug in Umlaut when called from Yota. Here's a MWE:
To get some idea where it's breaking, ...
To get some more detail, I looked at the original stack trace, which includes a call to Umlaut.getcode. Adding a line @show f, types to that function, I see that it's trying to call code_lowered(tuple, (Float64,)), which leads to the crash.