Closed mofeing closed 3 weeks ago
The issue seems to be a bad implementation of `similar` for `TracedRArray` in https://github.com/EnzymeAD/Reactant.jl/blob/af6b87eae701edb504e2eadd8cc58151d02bd321/src/TracedRArray.jl#L124-L127
julia> x = rand(2,2)
2×2 Matrix{Float64}:
0.0845062 0.288305
0.355494 0.0131065
julia> y = Reactant.to_rarray(x)
2×2 ConcreteRArray{Float64, 2}:
0.0845062 0.288305
0.355494 0.0131065
julia> f = @compile similar(y)
ERROR: MethodError: no method matching type(::Nothing)
Closest candidates are:
type(::Reactant.MLIR.IR.Attribute)
@ Reactant ~/Developer/Reactant.jl/src/mlir/IR/Attribute.jl:41
type(::Reactant.MLIR.IR.Value)
@ Reactant ~/Developer/Reactant.jl/src/mlir/IR/Value.jl:104
Stacktrace:
[1] transpose_val(val::Nothing)
@ Reactant ~/Developer/Reactant.jl/src/utils.jl:20
[2] (::Reactant.var"#28#37"{Vector{Union{Reactant.TracedRArray, Reactant.TracedRNumber}}})()
@ Reactant ~/Developer/Reactant.jl/src/utils.jl:143
[3] block!(f::Reactant.var"#28#37"{Vector{Union{…}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Block.jl:201
[4] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
@ Reactant ~/Developer/Reactant.jl/src/utils.jl:140
[5] make_mlir_fn
@ ~/Developer/Reactant.jl/src/utils.jl:30 [inlined]
[6] #6
@ ~/Developer/Reactant.jl/src/Compiler.jl:260 [inlined]
[7] block!(f::Reactant.Compiler.var"#6#11"{typeof(similar), Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Block.jl:201
[8] #5
@ ~/Developer/Reactant.jl/src/Compiler.jl:259 [inlined]
[9] mmodule!(f::Reactant.Compiler.var"#5#10"{…}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Module.jl:93
[10] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
@ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:256
[11] compile_mlir!
@ ~/Developer/Reactant.jl/src/Compiler.jl:255 [inlined]
[12] (::Reactant.Compiler.var"#30#32"{typeof(similar), Tuple{ConcreteRArray{Float64, 2}}})()
@ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:583
[13] context!(f::Reactant.Compiler.var"#30#32"{typeof(similar), Tuple{ConcreteRArray{…}}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR ~/Developer/Reactant.jl/src/mlir/IR/Context.jl:71
[14] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing)
@ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:580
[15] compile_xla
@ ~/Developer/Reactant.jl/src/Compiler.jl:574 [inlined]
[16] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing)
@ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:607
[17] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
@ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:606
[18] top-level scope
@ ~/Developer/Reactant.jl/src/Compiler.jl:367
Some type information was truncated. Use `show(err)` to see complete types.
It seems that forward-mode differentiation is still broken. The following stacktrace suggests that it is skipping the overridden `Enzyme.autodiff` method:
julia> x = rand(2,2)
2×2 Matrix{Float64}:
0.802631 0.182577
0.688512 0.696927
julia> y = adapt(Reactant.ConcreteRArray, x)
2×2 ConcreteRArray{Float64, 2}:
0.802631 0.182577
0.688512 0.696927
julia> f = Reactant.compile((y,)) do z
Enzyme.gradient(Forward, sum, z)
end
ERROR: AssertionError: Base.isconcretetype(typ)
Stacktrace:
[1] abs_typeof(arg::LLVM.LoadInst, partial::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/absint.jl:468
[2] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:6931
[3] codegen
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:5931 [inlined]
[4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8206
[5] _thunk
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8206 [inlined]
[6] cached_compilation
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8247 [inlined]
[7] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8379
[8] #s2070#19068
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8516 [inlined]
[9]
@ Enzyme.Compiler ./none:0
[10] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[11] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::typeof(Reactant._copyto!), df::Nothing, df_2::Nothing, df_3::Nothing, df_4::Nothing, primal_1::Reactant.TracedRArray{…}, shadow_1_1::Reactant.TracedRArray{…}, shadow_1_2::Reactant.TracedRArray{…}, shadow_1_3::Reactant.TracedRArray{…}, shadow_1_4::Reactant.TracedRArray{…}, primal_2::Base.Broadcast.Broadcasted{…}, shadow_2_1::Base.Broadcast.Broadcasted{…}, shadow_2_2::Base.Broadcast.Broadcasted{…}, shadow_2_3::Base.Broadcast.Broadcasted{…}, shadow_2_4::Base.Broadcast.Broadcasted{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/rules/jitrules.jl:290
[12] copyto!
@ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:535 [inlined]
[13] copyto!
@ ./broadcast.jl:956 [inlined]
[14] copy
@ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:526 [inlined]
[15] materialize
@ ./broadcast.jl:903 [inlined]
[16] broadcast
@ ./broadcast.jl:841 [inlined]
[17] #mapreduce#83
@ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:389 [inlined]
[18] fwddiffe4julia__mapreduce_83_16760wrap
@ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:0
[19] macro expansion
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8136 [inlined]
[20] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::BatchDuplicated{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7702
[21] (::Enzyme.Compiler.ForwardModeThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7491
[22] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::Reactant.var"##mapreduce#83", df::Nothing, df_2::Nothing, df_3::Nothing, df_4::Nothing, primal_1::Colon, shadow_1_1::Nothing, shadow_1_2::Nothing, shadow_1_3::Nothing, shadow_1_4::Nothing, primal_2::Nothing, shadow_2_1::Nothing, shadow_2_2::Nothing, shadow_2_3::Nothing, shadow_2_4::Nothing, primal_3::typeof(mapreduce), shadow_3_1::Nothing, shadow_3_2::Nothing, shadow_3_3::Nothing, shadow_3_4::Nothing, primal_4::typeof(identity), shadow_4_1::Nothing, shadow_4_2::Nothing, shadow_4_3::Nothing, shadow_4_4::Nothing, primal_5::typeof(Base.add_sum), shadow_5_1::Nothing, shadow_5_2::Nothing, shadow_5_3::Nothing, shadow_5_4::Nothing, primal_6::Reactant.TracedRArray{…}, shadow_6_1::Reactant.TracedRArray{…}, shadow_6_2::Reactant.TracedRArray{…}, shadow_6_3::Reactant.TracedRArray{…}, shadow_6_4::Reactant.TracedRArray{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/vgArw/src/rules/jitrules.jl:305
[23] mapreduce
@ ~/.julia/packages/Reactant/dXqEG/src/TracedRArray.jl:370 [inlined]
[24] _sum
@ ./reducedim.jl:1015 [inlined]
[25] _sum
@ ./reducedim.jl:1014 [inlined]
[26] sum
@ ./reducedim.jl:1010 [inlined]
[27] fwddiffe4julia_sum_19579wrap
@ ./reducedim.jl:0
[28] macro expansion
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:8136 [inlined]
[29] enzyme_call
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7702 [inlined]
[30] ForwardModeThunk
@ ~/.julia/packages/Enzyme/vgArw/src/compiler.jl:7491 [inlined]
[31] autodiff
@ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:647 [inlined]
[32] autodiff
@ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:512 [inlined]
[33] macro expansion
@ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:2068 [inlined]
[34] gradient(::ForwardMode{…}, ::typeof(sum), ::Reactant.TracedRArray{…}; chunk::Nothing, shadows::Tuple{…})
@ Enzyme ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1970
[35] gradient(::ForwardMode{false, FFIABI, false, false}, ::typeof(sum), ::Reactant.TracedRArray{Float64, 2})
@ Enzyme ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:1970
[36] #25
@ ./REPL[14]:2 [inlined]
[37] (::Tuple{})(none::Reactant.TracedRArray{Float64, 2})
@ Base.Experimental ./<missing>:0
[38] (::Reactant.var"#26#35"{var"#25#26", Reactant.MLIR.IR.Block, Vector{Union{…}}, Tuple{Reactant.TracedRArray{…}}})()
@ Reactant ~/.julia/packages/Reactant/dXqEG/src/utils.jl:113
[39] block!(f::Reactant.var"#26#35"{var"#25#26", Reactant.MLIR.IR.Block, Vector{…}, Tuple{…}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Block.jl:201
[40] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
@ Reactant ~/.julia/packages/Reactant/dXqEG/src/utils.jl:81
[41] make_mlir_fn
@ ~/.julia/packages/Reactant/dXqEG/src/utils.jl:30 [inlined]
[42] #6
@ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:260 [inlined]
[43] block!(f::Reactant.Compiler.var"#6#11"{var"#25#26", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Block.jl:201
[44] #5
@ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:259 [inlined]
[45] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, var"#25#26", Tuple{…}}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Module.jl:93
[46] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:256
[47] compile_mlir!
@ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:255 [inlined]
[48] (::Reactant.Compiler.var"#30#32"{Bool, var"#25#26", Tuple{ConcreteRArray{Float64, 2}}})()
@ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:636
[49] context!(f::Reactant.Compiler.var"#30#32"{Bool, var"#25#26", Tuple{ConcreteRArray{…}}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/dXqEG/src/mlir/IR/Context.jl:71
[50] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:633
[51] compile_xla
@ ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:627 [inlined]
[52] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:660
[53] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
@ Reactant.Compiler ~/.julia/packages/Reactant/dXqEG/src/Compiler.jl:659
Some type information was truncated. Use `show(err)` to see complete types.
Continuing on this topic: if I call `Enzyme.autodiff` directly, I get a different error:
julia> Reactant.compile((y,)) do z
dz = similar(z)
Enzyme.autodiff(Forward, sum, Active, Duplicated(z,dz))
end
ERROR: MethodError: no method matching (::Reactant.var"#needs_primal#14")(::Type{ForwardMode{false, FFIABI, true, false}})
Closest candidates are:
(::Reactant.var"#needs_primal#14")(::Type{<:ReverseMode{ReturnPrimal}}) where ReturnPrimal
@ Reactant ~/.julia/packages/Reactant/e4gfb/src/Interpreter.jl:275
Stacktrace:
[1] autodiff(::ForwardMode{false, FFIABI, true, false}, f::Const{typeof(sum)}, ::Type{Active}, args::Duplicated{Reactant.TracedRArray{Float64, 2}})
@ Reactant ~/.julia/packages/Reactant/e4gfb/src/Interpreter.jl:279
[2] autodiff
@ ~/.julia/packages/Enzyme/vgArw/src/Enzyme.jl:512 [inlined]
[3] #19
@ ./REPL[16]:3 [inlined]
[4] (::Tuple{})(none::Reactant.TracedRArray{Float64, 2})
@ Base.Experimental ./<missing>:0
[5] (::Reactant.var"#26#35"{var"#19#20", Reactant.MLIR.IR.Block, Vector{Union{Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 2}}})()
@ Reactant ~/.julia/packages/Reactant/e4gfb/src/utils.jl:113
[6] block!(f::Reactant.var"#26#35"{var"#19#20", Reactant.MLIR.IR.Block, Vector{Union{Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Block.jl:201
[7] make_mlir_fn(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
@ Reactant ~/.julia/packages/Reactant/e4gfb/src/utils.jl:81
[8] make_mlir_fn
@ ~/.julia/packages/Reactant/e4gfb/src/utils.jl:30 [inlined]
[9] #6
@ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:260 [inlined]
[10] block!(f::Reactant.Compiler.var"#6#11"{var"#19#20", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Block.jl:201
[11] #5
@ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:259 [inlined]
[12] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, var"#19#20", Tuple{ConcreteRArray{Float64, 2}}}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Module.jl:93
[13] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:256
[14] compile_mlir!
@ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:255 [inlined]
[15] (::Reactant.Compiler.var"#30#32"{Bool, var"#19#20", Tuple{ConcreteRArray{Float64, 2}}})()
@ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:636
[16] context!(f::Reactant.Compiler.var"#30#32"{Bool, var"#19#20", Tuple{ConcreteRArray{Float64, 2}}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/e4gfb/src/mlir/IR/Context.jl:71
[17] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:633
[18] compile_xla
@ ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:627 [inlined]
[19] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:660
[20] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}})
@ Reactant.Compiler ~/.julia/packages/Reactant/e4gfb/src/Compiler.jl:659
[21] top-level scope
@ REPL[16]:1
@wsmoses we need to fix `needs_primal`.
I mean, if the overlay table doesn't guarantee that the method is always overloaded in the context of inlining, we should either stop inlining or do something like https://github.com/EnzymeAD/Enzyme.jl/blob/72763e9aa28978ac820c286c01d4bfd00aa451a3/src/compiler/interpreter.jl#L363
The only thing left to solve is #189.
The following MWE gives this error: