EnzymeAD / Reactant.jl

MIT License
26 stars 2 forks source link

Handling exponents #27

Open gdalle opened 2 weeks ago

gdalle commented 2 weeks ago

I'm trying out Reactant to compile DifferentiationInterface operators, and running into this error in my test suite: https://github.com/gdalle/DifferentiationInterface.jl/actions/runs/9642905092/job/26591732655?pr=325

I think it's because broadcasting `^` is not supported yet?

gdalle commented 2 weeks ago

Related MWE, although the failure is different:

julia> using Reactant, Enzyme

julia> f(x) = sum(x .^ 2)
f (generic function with 1 method)

julia> g(x) = Enzyme.gradient(Enzyme.Reverse, f, x)
g (generic function with 1 method)

julia> x1 = [1.0, 2.0];

julia> x2 = Reactant.ConcreteRArray(x1);

julia> g(x1)
2-element Vector{Float64}:
 2.0
 4.0

julia> g(x2)
...
ERROR: LLVM error: function failed verification (4)
Stacktrace:
  [1] handle_error(reason::Cstring)
    @ LLVM ~/.julia/packages/LLVM/6cDbl/src/core/context.jl:168
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/aioBJ/src/api.jl:156
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:3697
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:5832
  [5] codegen
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:5110 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6639
  [7] _thunk
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6639 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6677 [inlined]
  [9] (::Enzyme.Compiler.var"#28587#28588"{DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6746
 [10] JuliaContext(f::Enzyme.Compiler.var"#28587#28588"{DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/nWT2N/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/nWT2N/src/driver.jl:42
 [12] #s2003#28586
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6697 [inlined]
 [13] var"#s2003#28586"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, ::Type, ::Type, ::Type, tt::Any, ::Type, ::Type, ::Type, ::Type, ::Type, ::Any)
    @ Enzyme.Compiler ./none:0
 [14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [15] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:304 [inlined]
 [16] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
 [17] gradient
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:1005 [inlined]
 [18] g(x::Reactant.ConcreteRArray{Float64, (2,), 1})
    @ Main ./REPL[32]:1
 [19] top-level scope
    @ REPL[33]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> Reactant.compile(g, (x2,))
ERROR: MethodError: no method matching broadcast_to_size(::Base.RefValue{typeof(^)}, ::Tuple{Int64})

Closest candidates are:
  broadcast_to_size(::Base.Broadcast.Extruded, ::Any)
   @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:798
  broadcast_to_size(::T, ::Any) where T<:Number
   @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:790
  broadcast_to_size(::Reactant.TracedRArray, ::Any)
   @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:779
  ...

Stacktrace:
  [1] (::Reactant.var"#50#51"{Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Base.Broadcast.Extruded{Reactant.TracedRArray{…}, Tuple{…}, Tuple{…}}, Base.RefValue{Val{…}}}}})(a::Base.RefValue{typeof(^)})
    @ Reactant ./none:0
  [2] iterate(::Base.Generator{Tuple{Base.RefValue{typeof(^)}, Base.Broadcast.Extruded{Reactant.TracedRArray{…}, Tuple{…}, Tuple{…}}, Base.RefValue{Val{…}}}, Reactant.var"#50#51"{Base.Broadcast.Broadcasted{Nothing, Tuple{…}, typeof(Base.literal_pow), Tuple{…}}}})
    @ Base ./generator.jl:47
  [3] _copyto!(dest::Reactant.TracedRArray{Float64, (2,), 1}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Reactant.TracedRArray{Float64, (2,), 1}, Base.RefValue{Val{2}}}})
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:842
  [4] copyto!
    @ ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:760 [inlined]
  [5] copyto!
    @ ./broadcast.jl:956 [inlined]
  [6] copy
    @ ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:750 [inlined]
  [7] overdub(context::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, f::typeof(copy), args::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{…}, Tuple{…}, typeof(Base.literal_pow), Tuple{…}})
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:637
  [8] materialize(::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{1}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Reactant.TracedRArray{Float64, (2,), 1}, Base.RefValue{Val{2}}}})
    @ ./broadcast.jl:903 [inlined]
  [9] materialize
    @ ./broadcast.jl:903 [inlined]
 [10] f(::Reactant.TracedRArray{Float64, (2,), 1})
    @ ./REPL[31]:1 [inlined]
 [11] f
    @ ./REPL[31]:1 [inlined]
 [12] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(f), ::Reactant.TracedRArray{Float64, (2,), 1})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [13] (::Reactant.var"#5#13"{typeof(f), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float64, (2,), 1}}})()
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/utils.jl:67
 [14] block!(f::Reactant.var"#5#13"{typeof(f), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float64, (2,), 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Block.jl:201
 [15] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Reactant.TracedRArray{Float64, (2,), 1}}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/utils.jl:60
 [16] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(autodiff), ::ReverseMode{false, FFIABI, false}, f::Const{typeof(f)}, ::Type{Active}, args::Duplicated{Reactant.TracedRArray{…}})
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:167
 [17] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(f), ::Type{Active}, ::Duplicated{Reactant.TracedRArray{Float64, (2,), 1}})
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
 [18] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
 [19] gradient
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:1005 [inlined]
 [20] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(gradient), ::ReverseMode{false, FFIABI, false}, ::typeof(f), ::Reactant.TracedRArray{Float64, (2,), 1})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [21] g(::Reactant.TracedRArray{Float64, (2,), 1})
    @ ./REPL[32]:1 [inlined]
 [22] g
    @ ./REPL[32]:1 [inlined]
 [23] (::Reactant.var"#5#13"{typeof(g), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float64, (2,), 1}}})()
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/utils.jl:67
 [24] block!(f::Reactant.var"#5#13"{typeof(g), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float64, (2,), 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Block.jl:201
 [25] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Reactant.ConcreteRArray{Float64, (2,), 1}}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/utils.jl:60
 [26] (::Reactant.var"#104#109"{Reactant.MLIR.IR.Module, typeof(g), Tuple{Reactant.ConcreteRArray{Float64, (2,), 1}}, Int64})()
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1073
 [27] mmodule!(f::Reactant.var"#104#109"{Reactant.MLIR.IR.Module, typeof(g), Tuple{Reactant.ConcreteRArray{Float64, (2,), 1}}, Int64}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Module.jl:93
 [28] #103
    @ ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1072 [inlined]
 [29] context!(f::Reactant.var"#103#108"{typeof(g), Tuple{Reactant.ConcreteRArray{Float64, (2,), 1}}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Context.jl:71
 [30] compile(f::typeof(g), args::Tuple{Reactant.ConcreteRArray{Float64, (2,), 1}}; pipeline_options::String, client::Nothing)
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1070
 [31] compile(f::typeof(g), args::Tuple{Reactant.ConcreteRArray{Float64, (2,), 1}})
    @ Reactant ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1063
 [32] top-level scope
    @ REPL[35]:1
Some type information was truncated. Use `show(err)` to see complete types.
wsmoses commented 2 weeks ago

You can't pass `RArray`s to `autodiff` without compiling first (i.e. the call must go through `Reactant.compile`).

gdalle commented 2 weeks ago

Okay, that explains the first failure (although such a warning probably belongs near the top of the README). How about the second one, from `Reactant.compile`?

gdalle commented 2 weeks ago

In the DI test suite, the compilation error was the following, but both seem related to a lack of support for exponents:

MethodError: no method matching elem_apply(::typeof(^), ::Reactant.TracedRArray{Float64, (2, 3), 2}, ::Reactant.TracedRArray{Int64, (2, 3), 2})

  Closest candidates are:
    elem_apply(::typeof(*), ::Any, ::Reactant.TracedRArray{ElType, Shape, N}) where {ElType, Shape, N}
     @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:495
    elem_apply(::typeof(*), ::Reactant.TracedRArray{ElType, Shape, N}, ::Any) where {ElType, Shape, N}
     @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:483
    elem_apply(::typeof(max), ::Reactant.TracedRArray{ElType, Shape, N}, ::Any) where {ElType, Shape, N}
     @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:483
    ...

  Stacktrace:
    [1] _copyto!(dest::Reactant.TracedRArray{Float64, (2, 3), 2}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(^), Tuple{Reactant.TracedRArray{Float64, (2, 3), 2}, Int64}})
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:842
    [2] copyto!
      @ ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:760 [inlined]
    [3] copyto!
      @ ./broadcast.jl:956 [inlined]
    [4] copy
      @ ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:750 [inlined]
    [5] overdub(context::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, f::typeof(copy), args::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(^), Tuple{Reactant.TracedRArray{Float64, (2, 3), 2}, Int64}})
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/overloads.jl:637
    [6] materialize(::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{2}, Nothing, typeof(^), Tuple{Reactant.TracedRArray{Float64, (2, 3), 2}, Int64}})
      @ ./broadcast.jl:903 [inlined]
    [7] materialize
      @ ./broadcast.jl:903 [inlined]
    [8] var"
      @ ~/work/DifferentiationInterface.jl/DifferentiationInterface.jl/DifferentiationInterfaceTest/src/scenarios/default.jl:162 [inlined]
    [9] #arr_to_num_aux_linalg#18
      @ ~/work/DifferentiationInterface.jl/DifferentiationInterface.jl/DifferentiationInterfaceTest/src/scenarios/default.jl:162 [inlined]
   [10] recurse(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::DifferentiationInterfaceTest.var"##arr_to_num_aux_linalg#18", ::Int64, ::Int64, ::typeof(DifferentiationInterfaceTest.arr_to_num_aux_linalg), ::Reactant.TracedRArray{Float64, (2, 3), 2})
      @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
   [11] kwcall(::@NamedTuple{α::Int64, β::Int64}, ::typeof(DifferentiationInterfaceTest.arr_to_num_aux_linalg), ::Reactant.TracedRArray{Float64, (2, 3), 2})
      @ ~/work/DifferentiationInterface.jl/DifferentiationInterface.jl/DifferentiationInterfaceTest/src/scenarios/default.jl:162 [inlined]
   [12] arr_to_num_aux_linalg
      @ ~/work/DifferentiationInterface.jl/DifferentiationInterface.jl/DifferentiationInterfaceTest/src/scenarios/default.jl:162 [inlined]
   [13] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(Core.kwcall), ::@NamedTuple{α::Int64, β::Int64}, ::typeof(DifferentiationInterfaceTest.arr_to_num_aux_linalg), ::Reactant.TracedRArray{Float64, (2, 3), 2})
      @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
   [14] arr_to_num_linalg(::Reactant.TracedRArray{Float64, (2, 3), 2})
      @ ~/work/DifferentiationInterface.jl/DifferentiationInterface.jl/DifferentiationInterfaceTest/src/scenarios/default.jl:206 [inlined]
   [15] arr_to_num_linalg
      @ ~/work/DifferentiationInterface.jl/DifferentiationInterface.jl/DifferentiationInterfaceTest/src/scenarios/default.jl:206 [inlined]
   [16] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(DifferentiationInterfaceTest.arr_to_num_linalg), ::Reactant.TracedRArray{Float64, (2, 3), 2})
      @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
   [17] (::Reactant.var"#5#13"{typeof(DifferentiationInterfaceTest.arr_to_num_linalg), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float64, (2, 3), 2}}})()
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/utils.jl:67
   [18] block!(f::Reactant.var"#5#13"{typeof(DifferentiationInterfaceTest.arr_to_num_linalg), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float64, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Block.jl:201
   [19] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Reactant.ConcreteRArray{Float64, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool)
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/utils.jl:60
   [20] (::Reactant.var"#104#109"{Reactant.MLIR.IR.Module, typeof(DifferentiationInterfaceTest.arr_to_num_linalg), Tuple{Reactant.ConcreteRArray{Float64, (2, 3), 2}}, Int64})()
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1073
   [21] mmodule!(f::Reactant.var"#104#109"{Reactant.MLIR.IR.Module, typeof(DifferentiationInterfaceTest.arr_to_num_linalg), Tuple{Reactant.ConcreteRArray{Float64, (2, 3), 2}}, Int64}, blk::Reactant.MLIR.IR.Module)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Module.jl:93
   [22] #103
      @ ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1072 [inlined]
   [23] context!(f::Reactant.var"#103#108"{typeof(DifferentiationInterfaceTest.arr_to_num_linalg), Tuple{Reactant.ConcreteRArray{Float64, (2, 3), 2}}, Int64}, ctx::Reactant.MLIR.IR.Context)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/RKFgC/src/mlir/IR/Context.jl:71
   [24] compile(f::typeof(DifferentiationInterfaceTest.arr_to_num_linalg), args::Tuple{Reactant.ConcreteRArray{Float64, (2, 3), 2}}; pipeline_options::String, client::Nothing)
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1070
   [25] compile(f::typeof(DifferentiationInterfaceTest.arr_to_num_linalg), args::Tuple{Reactant.ConcreteRArray{Float64, (2, 3), 2}})
      @ Reactant ~/.julia/packages/Reactant/RKFgC/src/Reactant.jl:1063
gdalle commented 4 days ago

@wsmoses I'm excited to help you try out Reactant with DI. Do you think this is an easy one to solve? My test suite makes heavy use of exponentiation.

wsmoses commented 2 days ago

@gdalle yeah we are in the process of generalizing the broadcasting support which will consequently solve this [and other issues].

And yeah I'm really excited too! I think DI is almost certainly going to be the best way to use Reactant -- and also should come with none of the perf issues I kept mentioning when using Enzyme.jl directly :)