EnzymeAD / Reactant.jl

MIT License
65 stars 6 forks source link

Forward-mode differentiation is broken #183

Closed mofeing closed 3 weeks ago

mofeing commented 3 weeks ago

MWE

x = rand(2,2)
y = adapt(Reactant.ConcreteRArray, x)
f = Reactant.compile((y,)) do z
    Enzyme.gradient(Forward, sum, z)
end

gives the following error:

ERROR: MethodError: no method matching dynamic_update_slice(::Nothing, ::Reactant.MLIR.IR.Value, ::Vector{Reactant.MLIR.IR.Value})

Closest candidates are:
  dynamic_update_slice(::Reactant.MLIR.IR.Value, ::Reactant.MLIR.IR.Value, ::Vector{Reactant.MLIR.IR.Value}; result, location)
   @ Reactant ~/.julia/artifacts/d665d0322ecfb603a7b4fae4c2b4f316fb46098a/StableHLO.inc.jl:1462

Stacktrace:
  [1] setindex!(::Reactant.TracedRArray{Float64, 2}, ::Float64, ::Int64, ::Int64)
    @ Reactant ~/.julia/packages/Reactant/rRa4g/src/TracedRArray.jl:113
  [2] _setindex!
    @ ./abstractarray.jl:1431 [inlined]
  [3] setindex!
    @ ./abstractarray.jl:1396 [inlined]
  [4] #111
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1541 [inlined]
  [5] macro expansion
    @ ./ntuple.jl:72 [inlined]
  [6] ntuple(f::Enzyme.var"#111#112"{Reactant.TracedRArray{Float64, 2}, Int64}, ::Val{4})
    @ Base ./ntuple.jl:69
  [7] onehot
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1537 [inlined]
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1833 [inlined]
  [9] create_shadows
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1811 [inlined]
 [10] gradient(::ForwardMode{false, FFIABI, false, false}, ::typeof(sum), ::Reactant.TracedRArray{Float64, 2})
    @ Enzyme ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1970
 [11] #39
    @ ~/Developer/TenetDMRGMPI/gd.jl:48 [inlined]
 [12] (::Tuple{})(none::Reactant.TracedRArray{Float64, 2})
    @ Base.Experimental ./<missing>:0
 [13] (::Reactant.var"#26#35"{var"#39#40", Reactant.MLIR.IR.Block, Vector{Union{…}}, Tuple{Reactant.TracedRArray{…}}})()
    @ Reactant ~/.julia/packages/Reactant/rRa4g/src/utils.jl:113
 [14] block!(f::Reactant.var"#26#35"{var"#39#40", Reactant.MLIR.IR.Block, Vector{…}, Tuple{…}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Block.jl:201
 [15] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant ~/.julia/packages/Reactant/rRa4g/src/utils.jl:81
 [16] make_mlir_fn
    @ ~/.julia/packages/Reactant/rRa4g/src/utils.jl:30 [inlined]
 [17] #6
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:261 [inlined]
 [18] block!(f::Reactant.Compiler.var"#6#11"{var"#39#40", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Block.jl:201
 [19] #5
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:260 [inlined]
 [20] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, var"#39#40", Tuple{…}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Module.jl:93
 [21] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:257
 [22] compile_mlir!
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:256 [inlined]
 [23] (::Reactant.Compiler.var"#30#32"{var"#39#40", Tuple{ConcreteRArray{Float64, 2}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:584
 [24] context!(f::Reactant.Compiler.var"#30#32"{var"#39#40", Tuple{ConcreteRArray{…}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Context.jl:71
 [25] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:581
 [26] compile_xla
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:575 [inlined]
 [27] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:608
 [28] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:607
 [29] top-level scope
    @ ~/Developer/demo.jl:47
Some type information was truncated. Use `show(err)` to see complete types.
mofeing commented 3 weeks ago

The issue seems to be a bad implementation of `similar` for `TracedRArray` in https://github.com/EnzymeAD/Reactant.jl/blob/af6b87eae701edb504e2eadd8cc58151d02bd321/src/TracedRArray.jl#L124-L127. Compiling `similar` on its own already fails:

julia> x = rand(2,2)
2×2 Matrix{Float64}:
 0.0845062  0.288305
 0.355494   0.0131065

julia> y = Reactant.to_rarray(x)
2×2 ConcreteRArray{Float64, 2}:
 0.0845062  0.288305
 0.355494   0.0131065

julia> f = @compile similar(y)
ERROR: MethodError: no method matching type(::Nothing)

Closest candidates are:
  type(::Reactant.MLIR.IR.Attribute)
   @ Reactant ~/Developer/Reactant.jl/src/mlir/IR/Attribute.jl:41
  type(::Reactant.MLIR.IR.Value)
   @ Reactant ~/Developer/Reactant.jl/src/mlir/IR/Value.jl:104

Stacktrace:
  [1] transpose_val(val::Nothing)
    @ Reactant ~/Developer/Reactant.jl/src/utils.jl:20
  [2] (::Reactant.var"#28#37"{Vector{Union{Reactant.TracedRArray, Reactant.TracedRNumber}}})()
    @ Reactant ~/Developer/Reactant.jl/src/utils.jl:143
  [3] block!(f::Reactant.var"#28#37"{Vector{Union{…}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Block.jl:201
  [4] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant ~/Developer/Reactant.jl/src/utils.jl:140
  [5] make_mlir_fn
    @ ~/Developer/Reactant.jl/src/utils.jl:30 [inlined]
  [6] #6
    @ ~/Developer/Reactant.jl/src/Compiler.jl:260 [inlined]
  [7] block!(f::Reactant.Compiler.var"#6#11"{typeof(similar), Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Block.jl:201
  [8] #5
    @ ~/Developer/Reactant.jl/src/Compiler.jl:259 [inlined]
  [9] mmodule!(f::Reactant.Compiler.var"#5#10"{…}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Module.jl:93
 [10] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
    @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:256
 [11] compile_mlir!
    @ ~/Developer/Reactant.jl/src/Compiler.jl:255 [inlined]
 [12] (::Reactant.Compiler.var"#30#32"{typeof(similar), Tuple{ConcreteRArray{Float64, 2}}})()
    @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:583
 [13] context!(f::Reactant.Compiler.var"#30#32"{typeof(similar), Tuple{ConcreteRArray{…}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Context.jl:71
 [14] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing)
    @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:580
 [15] compile_xla
    @ ~/Developer/Reactant.jl/src/Compiler.jl:574 [inlined]
 [16] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing)
    @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:607
 [17] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
    @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:606
 [18] top-level scope
    @ ~/Developer/Reactant.jl/src/Compiler.jl:367
Some type information was truncated. Use `show(err)` to see complete types.
mofeing commented 3 weeks ago

Forward-mode differentiation still seems to be broken. The following stacktrace suggests that it is skipping the overridden `Enzyme.autodiff` method:

julia> x = rand(2,2)
2×2 Matrix{Float64}:
 0.802631  0.182577
 0.688512  0.696927

julia> y = adapt(Reactant.ConcreteRArray, x)
2×2 ConcreteRArray{Float64, 2}:
 0.802631  0.182577
 0.688512  0.696927

julia> f = Reactant.compile((y,)) do z
           Enzyme.gradient(Forward, sum, z)
       end
ERROR: AssertionError: Base.isconcretetype(typ)
Stacktrace:
  [1] abs_typeof(arg::LLVM.LoadInst, partial::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/absint.jl:468
  [2] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:6931
  [3] codegen
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:5931 [inlined]
  [4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8206
  [5] _thunk
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8206 [inlined]
  [6] cached_compilation
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8247 [inlined]
  [7] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8379
  [8] #s2070#19068
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8516 [inlined]
  [9] 
    @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [11] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::typeof(Reactant._copyto!), df::Nothing, df_2::Nothing, df_3::Nothing, df_4::Nothing, primal_1::Reactant.TracedRArray{…}, shadow_1_1::Reactant.TracedRArray{…}, shadow_1_2::Reactant.TracedRArray{…}, shadow_1_3::Reactant.TracedRArray{…}, shadow_1_4::Reactant.TracedRArray{…}, primal_2::Base.Broadcast.Broadcasted{…}, shadow_2_1::Base.Broadcast.Broadcasted{…}, shadow_2_2::Base.Broadcast.Broadcasted{…}, shadow_2_3::Base.Broadcast.Broadcasted{…}, shadow_2_4::Base.Broadcast.Broadcasted{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/rules/jitrules.jl:290
 [12] copyto!
    @ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:535 [inlined]
 [13] copyto!
    @ ./broadcast.jl:956 [inlined]
 [14] copy
    @ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:526 [inlined]
 [15] materialize
    @ ./broadcast.jl:903 [inlined]
 [16] broadcast
    @ ./broadcast.jl:841 [inlined]
 [17] #mapreduce#83
    @ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:389 [inlined]
 [18] fwddiffe4julia__mapreduce_83_16760wrap
    @ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:0
 [19] macro expansion
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8136 [inlined]
 [20] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::BatchDuplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7702
 [21] (::Enzyme.Compiler.ForwardModeThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7491
 [22] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::Reactant.var"##mapreduce#83", df::Nothing, df_2::Nothing, df_3::Nothing, df_4::Nothing, primal_1::Colon, shadow_1_1::Nothing, shadow_1_2::Nothing, shadow_1_3::Nothing, shadow_1_4::Nothing, primal_2::Nothing, shadow_2_1::Nothing, shadow_2_2::Nothing, shadow_2_3::Nothing, shadow_2_4::Nothing, primal_3::typeof(mapreduce), shadow_3_1::Nothing, shadow_3_2::Nothing, shadow_3_3::Nothing, shadow_3_4::Nothing, primal_4::typeof(identity), shadow_4_1::Nothing, shadow_4_2::Nothing, shadow_4_3::Nothing, shadow_4_4::Nothing, primal_5::typeof(Base.add_sum), shadow_5_1::Nothing, shadow_5_2::Nothing, shadow_5_3::Nothing, shadow_5_4::Nothing, primal_6::Reactant.TracedRArray{…}, shadow_6_1::Reactant.TracedRArray{…}, shadow_6_2::Reactant.TracedRArray{…}, shadow_6_3::Reactant.TracedRArray{…}, shadow_6_4::Reactant.TracedRArray{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/rules/jitrules.jl:305
 [23] mapreduce
    @ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:370 [inlined]
 [24] _sum
    @ ./reducedim.jl:1015 [inlined]
 [25] _sum
    @ ./reducedim.jl:1014 [inlined]
 [26] sum
    @ ./reducedim.jl:1010 [inlined]
 [27] fwddiffe4julia_sum_19579wrap
    @ ./reducedim.jl:0
 [28] macro expansion
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8136 [inlined]
 [29] enzyme_call
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7702 [inlined]
 [30] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7491 [inlined]
 [31] autodiff
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:647 [inlined]
 [32] autodiff
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:512 [inlined]
 [33] macro expansion
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:2068 [inlined]
 [34] gradient(::ForwardMode{…}, ::typeof(sum), ::Reactant.TracedRArray{…}; chunk::Nothing, shadows::Tuple{…})
    @ Enzyme ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1970
 [35] gradient(::ForwardMode{false, FFIABI, false, false}, ::typeof(sum), ::Reactant.TracedRArray{Float64, 2})
    @ Enzyme ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1970
 [36] #25
    @ ./REPL[14]:2 [inlined]
 [37] (::Tuple{})(none::Reactant.TracedRArray{Float64, 2})
    @ Base.Experimental ./<missing>:0
 [38] (::Reactant.var"#26#35"{var"#25#26", Reactant.MLIR.IR.Block, Vector{Union{…}}, Tuple{Reactant.TracedRArray{…}}})()
    @ Reactant ~/.julia/packages/Reactant/dXqEG/src/utils.jl:113
 [39] block!(f::Reactant.var"#26#35"{var"#25#26", Reactant.MLIR.IR.Block, Vector{…}, Tuple{…}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Block.jl:201
 [40] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant ~/.julia/packages/Reactant/dXqEG/src/utils.jl:81
 [41] make_mlir_fn
    @ ~/.julia/packages/Reactant/dXqEG/src/utils.jl:30 [inlined]
 [42] #6
    @ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:260 [inlined]
 [43] block!(f::Reactant.Compiler.var"#6#11"{var"#25#26", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Block.jl:201
 [44] #5
    @ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:259 [inlined]
 [45] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, var"#25#26", Tuple{…}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Module.jl:93
 [46] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:256
 [47] compile_mlir!
    @ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:255 [inlined]
 [48] (::Reactant.Compiler.var"#30#32"{Bool, var"#25#26", Tuple{ConcreteRArray{Float64, 2}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:636
 [49] context!(f::Reactant.Compiler.var"#30#32"{Bool, var"#25#26", Tuple{ConcreteRArray{…}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Context.jl:71
 [50] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:633
 [51] compile_xla
    @ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:627 [inlined]
 [52] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:660
 [53] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:659
Some type information was truncated. Use `show(err)` to see complete types.
mofeing commented 3 weeks ago

Continuing on this topic: if I call `Enzyme.autodiff` directly, I get a different error:

julia> Reactant.compile((y,)) do z
           dz = similar(z)
           Enzyme.autodiff(Forward, sum, Active, Duplicated(z,dz))
       end
ERROR: MethodError: no method matching (::Reactant.var"#needs_primal#14")(::Type{ForwardMode{false, FFIABI, true, false}})

Closest candidates are:
  (::Reactant.var"#needs_primal#14")(::Type{<:ReverseMode{ReturnPrimal}}) where ReturnPrimal
   @ Reactant ~/.julia/packages/Reactant/e4gfb/src/Interpreter.jl:275

Stacktrace:
  [1] autodiff(::ForwardMode{false, FFIABI, true, false}, f::Const{typeof(sum)}, ::Type{Active}, args::Duplicated{Reactant.TracedRArray{Float64, 2}})
    @ Reactant ~/.julia/packages/Reactant/e4gfb/src/Interpreter.jl:279
  [2] autodiff
    @ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:512 [inlined]
  [3] #19
    @ ./REPL[16]:3 [inlined]
  [4] (::Tuple{})(none::Reactant.TracedRArray{Float64, 2})
    @ Base.Experimental ./<missing>:0
  [5] (::Reactant.var"#26#35"{var"#19#20", Reactant.MLIR.IR.Block, Vector{Union{Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 2}}})()
    @ Reactant ~/.julia/packages/Reactant/e4gfb/src/utils.jl:113
  [6] block!(f::Reactant.var"#26#35"{var"#19#20", Reactant.MLIR.IR.Block, Vector{Union{Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Block.jl:201
  [7] make_mlir_fn(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant ~/.julia/packages/Reactant/e4gfb/src/utils.jl:81
  [8] make_mlir_fn
    @ ~/.julia/packages/Reactant/e4gfb/src/utils.jl:30 [inlined]
  [9] #6
    @ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:260 [inlined]
 [10] block!(f::Reactant.Compiler.var"#6#11"{var"#19#20", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Block.jl:201
 [11] #5
    @ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:259 [inlined]
 [12] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, var"#19#20", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Module.jl:93
 [13] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:256
 [14] compile_mlir!
    @ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:255 [inlined]
 [15] (::Reactant.Compiler.var"#30#32"{Bool, var"#19#20", Tuple{ConcreteRArray{Float64, 2}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:636
 [16] context!(f::Reactant.Compiler.var"#30#32"{Bool, var"#19#20", Tuple{ConcreteRArray{Float64, 2}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Context.jl:71
 [17] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:633
 [18] compile_xla
    @ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:627 [inlined]
 [19] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:660
 [20] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:659
 [21] top-level scope
    @ REPL[16]:1

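@wsmoses we need to fix `needs_primal`: it only has a method for `ReverseMode`, so it throws a `MethodError` when called with `ForwardMode`.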

wsmoses commented 3 weeks ago

I mean, if the overlay table doesn't guarantee that it's always overloaded in the context of inlining, we should either stop inlining or do something like https://github.com/EnzymeAD/Enzyme.jl/blob/72763e9aa28978ac820c286c01d4bfd00aa451a3/src/compiler/interpreter.jl#L363

mofeing commented 3 weeks ago

The only thing left to solve is #189.