Open — avik-pal opened 1 month ago
Trying to use Enzyme the way I currently am gives me:
julia> reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3))(model)
┌ Error: Enzyme failed to compile the backward pass. Differentiation will be disabled for this model.
│ exception = UndefKeywordError: keyword argument `kwargs` not assigned
└ @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:46
Full Stacktrace
ERROR: UndefKeywordError: keyword argument `kwargs` not assigned
Stacktrace:
[1] call
@ ~/.julia/packages/Cassette/4UsSX/src/context.jl:454 [inlined]
[2] fallback
@ ~/.julia/packages/Cassette/4UsSX/src/context.jl:452 [inlined]
[3] overdub(ctx::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, f::typeof(throw), exception::UndefKeywordError)
@ Reactant ~/.julia/packages/Cassette/4UsSX/src/context.jl:278
[4] apply(::LuxReactantExt.var"
@ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
[5] apply
@ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
[6] overdub(::Cassette.Context{…}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#12#15"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[7] (::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#12#15"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}})()
@ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:53
[8] block!(f::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#12#15"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Block.jl:198
[9] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{LuxReactantExt.var"#12#15"{…}, StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, kwargs::Tuple{}, name::String, concretein::Bool)
@ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:46
[10] make_mlir_fn
@ /mnt/research/lux/Reactant.jl/src/utils.jl:16 [inlined]
[11] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:926
[12] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
[13] (::Reactant.var"#99#104"{LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:924
[14] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
[15] compile(f::LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, args::Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}; pipeline_options::String, client::Nothing)
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
[16] compile(f::LuxReactantExt.var"#12#15"{true, @NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
[17] __to_reactant_adaptor(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
@ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:43
[18] adapt(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/reactant.jl:9
[19] (::ToReactantAdaptor{true, Matrix{Float32}})(x::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
[20] top-level scope
@ REPL[22]:1
[21] top-level scope
@ none:1
Some type information was truncated. Use `show(err)` to see complete types.
Ok fixed that. Now we are back to
ERROR: MethodError: no method matching (Reactant.TracedRArray{Float32, Shape, 2} where Shape)(::Tuple{}, ::Reactant.MLIR.IR.Value)
Stacktrace:
[1] make_zero(::Type{Reactant.TracedRArray{Float32, Shape, 2} where Shape}, seen::IdDict{Any, Any}, prev::Reactant.TracedRArray{Float32, (5, 10), 2}, ::Val{false})
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:149
[2] #42
@ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1258 [inlined]
[3] ntuple
@ ./ntuple.jl:19 [inlined]
[4] make_zero
@ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1256 [inlined]
[5] make_zero(::Type{@NamedTuple{…}}, seen::IdDict{Any, Any}, prev::@NamedTuple{weight::Reactant.TracedRArray{…} where Shape, bias::Reactant.TracedRArray{…} where Shape}, ::Val{false})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1263
[6] (::Enzyme.Compiler.var"#42#43"{Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}, false, IdDict{Any, Any}, Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}})(i::Int64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1258
[7] ntuple(f::Enzyme.Compiler.var"#42#43"{Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}, false, IdDict{Any, Any}, Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}}, n::Int64)
@ Base ./ntuple.jl:19
[8] make_zero
@ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1256 [inlined]
[9] make_zero(::Type{@NamedTuple{…}}, seen::IdDict{Any, Any}, prev::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{}}, ::Val{false})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1263
[10] make_zero (repeats 2 times)
@ ~/.julia/packages/EnzymeCore/Z0CgU/src/EnzymeCore.jl:237 [inlined]
[11] overdub
@ /mnt/research/lux/Reactant.jl/src/overloads.jl:358 [inlined]
[12] #42
@ /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:35 [inlined]
[13] overdub(::Cassette.Context{…}, ::LuxReactantExt.var"#42#45"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[14] var"
@ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
[15] #apply#142
@ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
[16] overdub(::Cassette.Context{…}, ::Reactant.var"##apply#142", ::@Kwargs{}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#42#45"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[17] apply(::LuxReactantExt.var"
@ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
[18] apply
@ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
[19] overdub(::Cassette.Context{…}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#42#45"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[20] (::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#42#45"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}})()
@ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:53
[21] block!(f::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#42#45"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Block.jl:198
[22] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{LuxReactantExt.var"#42#45"{…}, StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, kwargs::Tuple{}, name::String, concretein::Bool)
@ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:46
[23] make_mlir_fn
@ /mnt/research/lux/Reactant.jl/src/utils.jl:16 [inlined]
[24] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:926
[25] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
[26] (::Reactant.var"#99#104"{LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:924
[27] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
[28] compile(f::LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, args::Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}; pipeline_options::String, client::Nothing)
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
[29] compile(f::LuxReactantExt.var"#42#45"{true, @NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
[30] __to_reactant_adaptor(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
@ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:44
[31] adapt(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/reactant.jl:9
[32] (::ToReactantAdaptor{true, Matrix{Float32}})(x::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
[33] top-level scope
@ REPL[29]:1
[34] top-level scope
@ none:1
Some type information was truncated. Use `show(err)` to see complete types.
Attention: Patch coverage is 0.62500%, with 159 lines in your changes missing coverage. Please review.
Project coverage is 80.26%. Comparing base (60c595e) to head (8ce9707). Report is 26 commits behind head on main.
:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.
Some more progress. Now we get
1-element ExceptionStack:
MethodError: no method matching EnzymeCore.Duplicated(::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}})
Closest candidates are:
EnzymeCore.Duplicated(::T1, ::T1) where T1
@ EnzymeCore ~/.julia/packages/EnzymeCore/iXoTK/src/EnzymeCore.jl:65
EnzymeCore.Duplicated(::T1, ::T1, ::Bool) where T1
@ EnzymeCore ~/.julia/packages/EnzymeCore/iXoTK/src/EnzymeCore.jl:65
Stacktrace:
[1] call
@ ~/.julia/packages/Cassette/4UsSX/src/context.jl:454 [inlined]
[2] fallback
@ ~/.julia/packages/Cassette/4UsSX/src/context.jl:452 [inlined]
[3] _overdub_fallback(::Any, ::Vararg{Any})
@ ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:596 [inlined]
[4] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::Type{EnzymeCore.Duplicated}, ::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:596
[5] #38
@ /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:42 [inlined]
[6] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, ::StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, ::Reactant.TracedRArray{Float32, (10, 3), 2})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[7] var"
@ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
[8] #apply#1
@ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
[9] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::Reactant.var"##apply#1", ::@Kwargs{}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, ::StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, ::Reactant.TracedRArray{Float32, (10, 3), 2})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[10] apply(::LuxReactantExt.var"
@ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
[11] apply
@ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
[12] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, ::StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, ::Reactant.TracedRArray{Float32, (10, 3), 2})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[13] (::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.TracedRArray{Float32, (10, 3), 2}}})()
@ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:53
[14] block!(f::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.TracedRArray{Float32, (10, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Block.jl:198
[15] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool)
@ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:46
[16] make_mlir_fn
@ /mnt/research/lux/Reactant.jl/src/utils.jl:16 [inlined]
[17] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:921
[18] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
[19] (::Reactant.var"#99#104"{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:919
[20] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
[21] compile(f::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}; pipeline_options::String, client::Nothing)
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
[22] compile(f::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
[23] __to_reactant_adaptor(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
@ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:51
[24] adapt(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/reactant.jl:9
[25] (::ToReactantAdaptor{true, Matrix{Float32}})(x::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
[26] top-level scope
@ REPL[9]:1
[27] top-level scope
@ none:1
@wsmoses can you review how I am compiling the gradient? I must be doing something wrong — the `ps` elements are `Array` instead of `TracedRArray`.
@avik-pal can this be done in a way that Reactant compiles the whole update, not just the gradient as separate from the inference pass? Specifically, I expect there to be a substantial perf improvement from doing so — including the model update actually fully occurring in place.
E.g. the function reactant compiles being something like
function update(model, x, learning_rate)
grads = gradient(model, x)
update!(model, grads, learning_rate[1])
nothing
end
Not with the layers API. Currently, if we can accelerate just the neural network part, I would consider it a good win. Also, having it like this makes it possible to use regular Julia ops for cases where we can't compile to Reactant, for example, the ODE solves happen in Julia and the neural network is via XLA.
We can add `AutoReactant` for the training API, where we can compile the entire pass in the first call and reuse it in subsequent calls (similar to what we do for Enzyme).
There was a bug in my code. Fixed that but now I get:
AssertionError: !fnwrapped
Stacktrace:
[1] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:927
[2] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
[3] (::Reactant.var"#99#104"{LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:924
[4] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
[5] compile(f::LuxReactantExt.var"#36#39"{true}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}; pipeline_options::String, client::Nothing)
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
[6] compile(f::LuxReactantExt.var"#36#39"{true}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
@ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
[7] __to_reactant_adaptor(to::ToReactantAdaptor{true, Xoshiro, Matrix{Float32}, typeof(identity)}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, input_prototype::Matrix{Float32}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, eltype_adaptor::Nothing)
@ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:91
[8] __to_reactant_adaptor(to::ToReactantAdaptor{true, Xoshiro, Matrix{Float32}, typeof(identity)}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
@ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:65
[9] adapt
@ /mnt/research/lux/Lux.jl/src/transform/reactant.jl:23 [inlined]
[10] (::ToReactantAdaptor{true, Xoshiro, Matrix{Float32}, typeof(identity)})(x::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
@ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
[11] top-level scope
@ REPL[16]:1
[12] top-level scope
@ none:1
[13] eval
@ ./boot.jl:385 [inlined]
[14] eval
@ ./Base.jl:88 [inlined]
[15] repleval(m::Module, code::Expr, ::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:229
[16] (::VSCodeServer.var"#112#114"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:192
[17] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging ./logging.jl:515
[18] with_logger
@ ./logging.jl:627 [inlined]
[19] (::VSCodeServer.var"#111#113"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:193
[20] #invokelatest#2
@ ./essentials.jl:892 [inlined]
[21] invokelatest(::Any)
@ Base ./essentials.jl:889
[22] (::VSCodeServer.var"#64#65")()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/eval.jl:34
Seems like there is some kind of Boxing going on somewhere
Okay things are working mostly now, we just need a copyto! for TracedRArray
using Reactant, Lux, Random, ComponentArrays
model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3))(model)
Gives me a Pipeline Failed
with error: size of operand dimension 0 (3) is not equal to 1 or size of result dimension 0 (2)
But the gradient seems to work fine for Enzyme with the XLA compilation.
@avik-pal you should open an issue with the pipeline error on Reactant, once the prereqs are merged
@wsmoses seems like an incorrect generation?
Module:
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<3x10xf32>, %arg3: tensor<10x5xf32>, %arg4: tensor<10x5xf32>) -> (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<5x3xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<5xf32>
%0 = stablehlo.reshape %arg1 : (tensor<1x5xf32>) -> tensor<5xf32>
%1 = stablehlo.dot_general %arg4, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x5xf32>, tensor<3x10xf32>) -> tensor<5x3xf32>
%2 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xf32>) -> tensor<5x3xf32>
%3 = stablehlo.add %1, %2 : tensor<5x3xf32>
%4 = stablehlo.tanh %3 : tensor<5x3xf32>
%5 = stablehlo.reduce(%4 init: %cst_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<5x3xf32>, tensor<f32>) -> tensor<f32>
%6 = stablehlo.multiply %4, %4 : tensor<5x3xf32>
%7 = stablehlo.subtract %cst, %6 : tensor<5x3xf32>
%8 = stablehlo.reduce(%7 init: %cst_1) across dimensions = [1] : (tensor<5x3xf32>, tensor<5xf32>) -> tensor<5xf32>
reducer(%arg5: tensor<5xf32>, %arg6: tensor<5xf32>) {
%13 = stablehlo.add %arg5, %arg6 : tensor<5xf32>
stablehlo.return %13 : tensor<5xf32>
}
%9 = stablehlo.dot_general %arg2, %7, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x10xf32>, tensor<5x3xf32>) -> tensor<10x5xf32>
%10 = stablehlo.add %arg3, %9 : tensor<10x5xf32>
%11 = stablehlo.reshape %8 : (tensor<5xf32>) -> tensor<1x5xf32>
%12 = stablehlo.add %arg0, %11 : tensor<1x5xf32>
return %12, %5, %10 : tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>
}
}
terminate called after throwing an instance of 'xla::XlaRuntimeError'
what(): UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[5]:
<unknown>:0: note: see current operation: "func.return"(%15, %8, %13) : (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) -> ()
@avik-pal the lux fixes (and named tuple) just landed and were released.
I'll give the reduction error a go shortly, but at minimum we can see what works (and perhaps mark that as expected broken to start with)
Currently the julia session crashes because of the broken reverse pass, so can't mark it broken
I will try to see what kind of models compile for the forward pass at least
Can we have a no copy transfer between Julia AbstractArrays and Reactant/XLA Arrays? This makes life simpler to support wrapper types like ComponentArrays
.
Also we can keep the parameters as regular Julia arrays which works more nicely with the current optimisers and such
Not easily as we need to own the data.
Similarly, the model — and ideally the inputs — are always kept as RArrays here.
Also for better performance the optimizers themselves are compiled by reactant
On Sat, Jun 1, 2024 at 8:33 AM Avik Pal @.***> wrote:
Can we have a no copy transfer between Julia AbstractArrays and Reactant/XLA Arrays? This makes life simpler to support wrapper types like ComponentArrays.
Also we can keep the parameters as regular Julia arrays which works more nicely with the current optimisers and such
— Reply to this email directly, view it on GitHub https://github.com/LuxDL/Lux.jl/pull/665#issuecomment-2143324103, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXFOCHIUSBR7OQWBYO3ZFFTNJAVCNFSM6AAAAABIKL3C3GVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCNBTGMZDIMJQGM . You are receiving this because you were mentioned.Message ID: <LuxDL/Lux. @.***>
Also for better performance the optimizers themselves are compiled by reactant
Right, but I don't think we would be able to compile NeuralODE-style models yet, right? So having an eager version that can perform operations directly on RArrays seems like a good tradeoff to run part of the model in regular Julia.
I might pull out the AutoReactant code (compiling the training iteration) into a separate PR because that would be easier to merge.
There’s no reason why we couldn’t in theory, but I don’t think we do right now.
Worth testing and opening an issue so we know what to work on though
On Sat, Jun 1, 2024 at 5:46 PM Avik Pal @.***> wrote:
Also for better performance the optimizers themselves are compiled by reactant
Right, but I don't think we would be able to compile NeuralODE-style models yet, right? So having an eager version that can perform operations directly on RArrays seems like a good tradeoff to run part of the model in regular Julia.
I might pull out the AutoReactant code (compiling the training iteration) into a separate PR because that would be easier to merge.
— Reply to this email directly, view it on GitHub https://github.com/LuxDL/Lux.jl/pull/665#issuecomment-2143493123, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXBAIIEC7N2TEXAKBB3ZFHUGXAVCNFSM6AAAAABIKL3C3GVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCNBTGQ4TGMJSGM . You are receiving this because you were mentioned.Message ID: <LuxDL/Lux. @.***>
@avik-pal fix has landed, can we retry this?
This one is too broadly scoped, so I will hold it off.
First, I want to finish #673, which compiles the entire training loop and doesn't need to worry about users doing unwanted things to the parameters.
Example Usage
This follows the same structure as SimpleChains. User demands a conversion and provides an input prototype.
Upstream Needs
LuxCore.apply
instead of LuxCore.apply
TODOs
__make_reactant_array