JuliaGPU / GPUCompiler.jl

Reusable compiler infrastructure for Julia GPU backends.
Other
155 stars 50 forks source link

GPUCompiler code typed is type unstable whereas regular code typed is fine #587

Open wsmoses opened 2 months ago

wsmoses commented 2 months ago
function func_mixed_call(N)
    allargs = Expr[]
    typeargs = Union{Symbol,Expr}[]
    exprs2 = Union{Symbol,Expr}[]
    for i in 1:N
        arg = Symbol("arg_$i")
        targ = Symbol("T$i")
        e = :($arg::$targ)
        push!(allargs, e)
        push!(typeargs, targ)

        inarg = quote
            if RefTypes[1+$i]
                $arg[]
            else
                $arg
            end
        end
        push!(exprs2, inarg)
    end

    quote
        @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)}
            fexpr = :f
            if RefTypes[1]
                fexpr = :(($fexpr)[])
            end
            exprs2 = Union{Symbol,Expr}[]
            for i in 1:$N
                arg = Symbol("arg_$i")
                inarg = if RefTypes[1+i]
                    :($arg[])
                else
                    :($arg)
                end
                push!(exprs2, inarg)
            end
                return quote
                    Base.@_inline_meta
                    $fexpr($(exprs2...))
                end
        end
    end
end

for N in 0:10
    eval(func_mixed_call(N))
end

function make(x, y, z)
   function inner(); for i in z x[i] = y; end
   end
end

m = make(ones(10), 3.0, 1:3)

function threading_run(func)
    for i = 1:10
        func()
    end
end

using GPUCompiler

Base.@kwdef struct TestTarget <: AbstractCompilerTarget
end
GPUCompiler.llvm_triple(::TestTarget) = Sys.MACHINE

struct TestCompilerParams<: AbstractCompilerParams 
end

# TODO: We shouldn't blanket opt-out
# GPUCompiler.check_invocation(job::CompilerJob{TestTarget}, entry::LLVM.Function) = nothing

GPUCompiler.runtime_slug(job::CompilerJob{TestTarget}) = "enzyme"

@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Integer)
    # primal function. Inferred here to get return type
    _tt = (TT.parameters...,)

    primal_tt = Tuple{_tt...} # map(eltype, _tt)...}

    primal = GPUCompiler.methodinstance(F, primal_tt, world)

    return primal
end

function get_job(@nospecialize(func), @nospecialize(tt))
    world = Base.get_world_counter()
    primal = fspec(Core.Typeof(func), tt, world)
    target = TestTarget()
    params = TestCompilerParams()
    return GPUCompiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world)
end
function enzyme_code_typed(@nospecialize(func), @nospecialize(types); kwargs...)
    job = get_job(func, types; kwargs...)
    GPUCompiler.code_typed(job; kwargs...)
end

@show enzyme_code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}})
using InteractiveUtils
@show @code_typed runtime_mixed_call(Val((false,true)), threading_run, Ref(m))

On 1.10 output is

wmoses@beast:~/git/Enzyme.jl (cai) $ ./julia-1.10.2/bin/julia --project
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.10.2 (2024-03-01)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

julia> include("sad.jl")
enzyme_code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}}) = Any[CodeInfo(
1 ─ %1 = (isa)(arg_1, Base.RefValue{var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}})::Bool
└──      goto #3 if not %1
2 ─ %3 = π (arg_1, Base.RefValue{var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}})
│   %4 = Base.getfield(%3, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└──      goto #4
3 ─ %6 = Base.getindex(arg_1)::Any
└──      goto #4
4 ┄ %8 = φ (#2 => %4, #3 => %6)::Any
│        (f)(%8)::Nothing
└──      return nothing
) => Nothing]
#= /home/wmoses/git/Enzyme.jl/sad.jl:102 =# @code_typed(runtime_mixed_call(Val((false, true)), threading_run, Ref(m))) = CodeInfo(
1 ── %1  = Base.getfield(arg_1, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└───       goto #17 if not true
2 ┄─ %3  = φ (#1 => 1, #16 => %41)::Int64
│    %4  = Core.getfield(%1, :z)::UnitRange{Int64}
│    %5  = Base.getfield(%4, :start)::Int64
│    %6  = Base.getfield(%4, :stop)::Int64
│    %7  = Base.slt_int(%6, %5)::Bool
└───       goto #4 if not %7
3 ──       goto #5
4 ── %10 = Base.getfield(%4, :start)::Int64
│    %11 = Base.getfield(%4, :start)::Int64
└───       goto #5
5 ┄─ %13 = φ (#3 => true, #4 => false)::Bool
│    %14 = φ (#4 => %10)::Int64
│    %15 = φ (#4 => %11)::Int64
│    %16 = Base.not_int(%13)::Bool
└───       goto #11 if not %16
6 ┄─ %18 = φ (#5 => %14, #10 => %29)::Int64
│    %19 = φ (#5 => %15, #10 => %30)::Int64
│    %20 = Core.getfield(%1, :x)::Vector{Float64}
│    %21 = Core.getfield(%1, :y)::Float64
│          Base.arrayset(true, %20, %21, %18)::Vector{Float64}
│    %23 = Base.getfield(%4, :stop)::Int64
│    %24 = (%19 === %23)::Bool
└───       goto #8 if not %24
7 ──       goto #9
8 ── %27 = Base.add_int(%19, 1)::Int64
└───       goto #9
9 ┄─ %29 = φ (#8 => %27)::Int64
│    %30 = φ (#8 => %27)::Int64
│    %31 = φ (#7 => true, #8 => false)::Bool
│    %32 = Base.not_int(%31)::Bool
└───       goto #11 if not %32
10 ─       goto #6
11 ┄       goto #12
12 ─ %36 = (%3 === 10)::Bool
└───       goto #14 if not %36
13 ─       goto #15
14 ─ %39 = Base.add_int(%3, 1)::Int64
└───       goto #15
15 ┄ %41 = φ (#14 => %39)::Int64
│    %42 = φ (#13 => true, #14 => false)::Bool
│    %43 = Base.not_int(%42)::Bool
└───       goto #17 if not %43
16 ─       goto #2
17 ┄       goto #18
18 ─       return nothing
) => Nothing
CodeInfo(
1 ── %1  = Base.getfield(arg_1, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└───       goto #17 if not true
2 ┄─ %3  = φ (#1 => 1, #16 => %41)::Int64
│    %4  = Core.getfield(%1, :z)::UnitRange{Int64}
│    %5  = Base.getfield(%4, :start)::Int64
│    %6  = Base.getfield(%4, :stop)::Int64
│    %7  = Base.slt_int(%6, %5)::Bool
└───       goto #4 if not %7
3 ──       goto #5
4 ── %10 = Base.getfield(%4, :start)::Int64
│    %11 = Base.getfield(%4, :start)::Int64
└───       goto #5
5 ┄─ %13 = φ (#3 => true, #4 => false)::Bool
│    %14 = φ (#4 => %10)::Int64
│    %15 = φ (#4 => %11)::Int64
│    %16 = Base.not_int(%13)::Bool
└───       goto #11 if not %16
6 ┄─ %18 = φ (#5 => %14, #10 => %29)::Int64
│    %19 = φ (#5 => %15, #10 => %30)::Int64
│    %20 = Core.getfield(%1, :x)::Vector{Float64}
│    %21 = Core.getfield(%1, :y)::Float64
│          Base.arrayset(true, %20, %21, %18)::Vector{Float64}
│    %23 = Base.getfield(%4, :stop)::Int64
│    %24 = (%19 === %23)::Bool
└───       goto #8 if not %24
7 ──       goto #9
8 ── %27 = Base.add_int(%19, 1)::Int64
└───       goto #9
9 ┄─ %29 = φ (#8 => %27)::Int64
│    %30 = φ (#8 => %27)::Int64
│    %31 = φ (#7 => true, #8 => false)::Bool
│    %32 = Base.not_int(%31)::Bool
└───       goto #11 if not %32
10 ─       goto #6
11 ┄       goto #12
12 ─ %36 = (%3 === 10)::Bool
└───       goto #14 if not %36
13 ─       goto #15
14 ─ %39 = Base.add_int(%3, 1)::Int64
└───       goto #15
15 ┄ %41 = φ (#14 => %39)::Int64
│    %42 = φ (#13 => true, #14 => false)::Bool
│    %43 = Base.not_int(%42)::Bool
└───       goto #17 if not %43
16 ─       goto #2
17 ┄       goto #18
18 ─       return nothing
) => Nothing

cc @vchuravy

wsmoses commented 2 months ago
wmoses@beast:~/git/GPUCompiler.jl ((HEAD detached at origin/master)) $ git log
commit 8b513be9e2230fe0dd1905b805e25fa049b24d1d (HEAD, tag: v0.26.5, origin/master, origin/HEAD)
Author: Tim Besard <tim.besard@gmail.com>
Date:   Fri May 24 10:25:09 2024 +0200

    Bump version.
vchuravy commented 2 months ago
julia-repl> @show code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}})

1-element Vector{Any}:
 CodeInfo(
1 ─ %1 = (isa)(arg_1, Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}})::Bool
└──      goto #3 if not %1
2 ─ %3 = π (arg_1, Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}})
│   %4 = Base.getfield(%3, :x)::var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}
└──      goto #4
3 ─ %6 = Base.getindex(arg_1)::Any
└──      goto #4
4 ┄ %8 = φ (#2 => %4, #3 => %6)::Any
│        (f)(%8)::Nothing
└──      return nothing
) => Nothing

Ref{typeof(m)} is not the same as typeof(Ref(m)).

julia> typeof(Ref(m))
Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}}