jakobnissen opened this issue 3 years ago
Currently, the Julia compiler doesn't run type inference on Threads.@spawned code blocks, as shown below:
julia> code_typed(; optimize = false) do
fetch(Threads.@spawn 1 + 2)
end |> first
CodeInfo(
1 ─ (#34 = %new(Main.:(var"#34#36")))::Core.Const(var"#34#36"())
│ %2 = #34::Core.Const(var"#34#36"())
│ (task = Base.Threads.Task(%2))::Task
│ %4 = false::Core.Const(false)
│ Base.setproperty!(task, :sticky, %4)::Any
└── goto #3 if not false
2 ─ Core.Const(:(Base.Threads.put!(Main.:(var"##sync#41"), task)))::Union{}
3 ┄ Base.Threads.schedule(task)::Any
│ %9 = task::Task
│ %10 = Main.fetch(%9)::Any
└── return %10
) => Any
I think https://github.com/JuliaLang/julia/pull/39773 is an attempt to enable various optimizations, including type inference, for multithreaded contexts like this (am I right, @tkf?).
That said, for the time being, I'm willing to make JET special-case some multithreading code and enable JET analysis on, e.g., @spawn blocks. However, I won't extend the native type inference routine in any other respect, such as annotating the return type of a @spawn code block; that is a job for https://github.com/JuliaLang/julia/pull/39773 (technically, @spawned code is internally represented as a closure, and I guess it's not too hard to implement a special analysis pass on it).
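As a rough illustration of the closure point, here is a simplified sketch (not the actual expansion of Threads.@spawn, which also handles $-interpolation, @sync registration and thread pools). It mirrors the Task construction, sticky flag and schedule steps visible in the code_typed output above: the spawned expression becomes a zero-argument closure, and running inference on that closure directly recovers a concrete return type, while fetch itself is inferred as Any.

```julia
# Simplified sketch of what `Threads.@spawn 1 + 2` builds (not the real macro expansion)
body = () -> 1 + 2       # the spawned expression becomes a zero-argument closure
t = Task(body)           # wrap the closure in a Task
t.sticky = false         # mirrors `setproperty!(task, :sticky, false)` in the output above
schedule(t)
result = fetch(t)        # the value is 3, but `fetch` is inferred as ::Any

# Running inference on the closure by itself does recover a concrete type
closure_types = Base.return_types(body, ())   # [Int] (Int64 on a 64-bit build)
```

This is the sense in which a dedicated analysis pass on the closure could, in principle, annotate what fetch loses.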
My question is: is it enough to handle the @spawn and @threads macros to support analysis of multithreaded code? Or are there other constructs that are commonly used for multithreading?
I've only ever used @spawn and @threads, and don't know of other ways.
Having JET work for these two macros would be nice, but there is also something to be said for simply waiting until Base Julia provides it.
It depends on how easy or hard it is to implement in JET, I suppose.
After #124:
julia> using JET
julia> foo() = fetch(Threads.@spawn 1 + "foo")
foo (generic function with 1 method)
julia> @report_call foo()
═════ 1 possible error found ═════
┌ @ threadingconstructs.jl:174 Base.Threads.Task(#3)
│┌ @ task.jl:5 #self#(f, 0)
││┌ @ threadingconstructs.jl:170 Main.+(1, "foo")
│││ no matching method found for call signature: Main.+(1, "foo")
││└──────────────────────────────
Any
For technical reasons, JET currently has a limitation: it runs the additional analysis pass on Task construction, regardless of whether the task is actually scheduled or not.
I'd like to run the analysis pass only when we encounter a schedule(::Task) call, but I'm not sure there is a good solution for that, so I'll let this issue be closed as is for now.
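The limitation can be seen with a task that is constructed but never scheduled. In the sketch below, make_task is a hypothetical helper for illustration: it builds a Task whose body would throw a MethodError, yet at runtime nothing happens until the task is scheduled. A pass that runs at Task construction, as described above, would presumably still see the 1 + "foo" call either way.

```julia
# Hypothetical helper: builds a Task whose body would throw if it ever ran
make_task() = Task(() -> 1 + "foo")

t = make_task()
started_before = istaskstarted(t)   # false: constructed but not scheduled, body never ran

# Only once the task is scheduled does the MethodError actually occur
schedule(t)
failure = try
    wait(t)
    nothing
catch err
    err                             # `wait` rethrows the task's failure as a TaskFailedException
end
```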
I think a challenging point here is that fetch(@async f()) is effectively Base.invokelatest(f):
julia> f() = 1
f (generic function with 1 method)
julia> t = Task(f)
Task (runnable) @0x00007f5c77464c40
julia> f() = 2
f (generic function with 1 method)
julia> schedule(t)
Task (done) @0x00007f5c77464c40
julia> fetch(t)
2
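For comparison, here is a sketch of the world-age behavior being referred to, using the hypothetical names h and compare_worlds. A function that redefines a method mid-call keeps seeing the old definition on a direct call, while Base.invokelatest dispatches in the newest world, which matches what the scheduled Task does in the REPL session above.

```julia
h() = 1   # hypothetical function; redefined below

function compare_worlds()
    @eval h() = 2                  # redefine h in a newer world age
    direct = h()                   # still runs in this call's world age: old method
    latest = Base.invokelatest(h)  # dispatches in the latest world: new method
    return direct, latest
end

direct, latest = compare_worlds()  # (1, 2)
```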
If we can change how Task interacts with the world age, I think it's actually fixable without bringing out the Big Gun (JuliaLang/julia#39773).
You can find related discussions here:
We've also noticed that report_opt behaves differently in a multithreaded context:
julia> function abmult(r::Int)
if r < 0
r = -r
end
# the closure assigned to `f` makes the variable `r` captured
f = x -> x * r
return f
end;
julia> JET.@report_opt abmult(42)
═════ 3 possible errors found ═════
┌ @ REPL[97]:2 r = Core.Box(_7::Int64)
│ captured variable `r` detected
└──────────────
┌ @ REPL[97]:2 %7 < 0
│ runtime dispatch detected: (%7::Any < 0)::Any
└──────────────
┌ @ REPL[97]:3 -(%14)
│ runtime dispatch detected: -(%14::Any)::Any
└──────────────
julia> JET.@report_opt fetch(Threads.@spawn abmult(42))
═════ 3 possible errors found ═════
┌ @ task.jl:360 wait(t)
│┌ @ task.jl:343 Base._wait(t)
││┌ @ task.jl:301 lock(%13)
│││ runtime dispatch detected: lock(%13::Any)::Any
││└───────────────
││┌ @ task.jl:304 wait(%30)
│││ runtime dispatch detected: wait(%30::Any)::Any
││└───────────────
││┌ @ task.jl:307 unlock(%40)
│││ runtime dispatch detected: unlock(%40::Any)::Any
││└───────────────
Note that the boxed variable is not reported when abmult is @spawned. We were hoping to use JET to help detect cases where one has to interpolate into the @spawned task to prevent boxing.
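For reference, the kind of fix meant here looks roughly like the sketch below (hypothetical function names, modeled on abmult). Because r is reassigned before the spawned closure captures it, the first version captures r the same way abmult does, which the report_opt output above shows as a Core.Box. Interpolating with $r, which Threads.@spawn supports, copies the current value into the closure instead. Both versions return the same result; only the captured representation differs.

```julia
# Hypothetical examples modeled on abmult: `r` is reassigned, then captured by @spawn
function spawn_captured(r::Int)
    if r < 0
        r = -r
    end
    # The spawned closure captures the reassigned variable `r` (boxed, as in abmult)
    return fetch(Threads.@spawn r * 2)
end

function spawn_interpolated(r::Int)
    if r < 0
        r = -r
    end
    # `$r` copies the current value of `r` into the closure, so no box is needed
    return fetch(Threads.@spawn $r * 2)
end

spawn_captured(-3)       # 6
spawn_interpolated(-3)   # 6
```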
@Drvi FYI, you can do:
julia> g() = Threads.@spawn abmult(42)
g (generic function with 1 method)
julia> JET.@report_opt g()
═════ 3 possible errors found ═════
┌ g() @ Main ./threadingconstructs.jl:377
│┌ Task(f::var"#10#11") @ Base ./task.jl:5
││┌ (::var"#10#11")() @ Main ./threadingconstructs.jl:373
│││┌ abmult(r::Int64) @ Main ./REPL[11]:2
││││ captured variable `r` detected
│││└────────────────────
│││┌ abmult(r::Int64) @ Main ./REPL[11]:2
││││ runtime dispatch detected: (%6::Any < 0)::Any
│││└────────────────────
│││┌ abmult(r::Int64) @ Main ./REPL[11]:3
││││ runtime dispatch detected: -(%13::Any)::Any
│││└────────────────────
Small example:
Output: