LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Auto compile Lux models to reactant #665

Open avik-pal opened 1 month ago

avik-pal commented 1 month ago

Example Usage

This follows the same structure as SimpleChains: the user explicitly requests a conversion and provides an input prototype.

using Reactant, Lux, Random

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3); force_compile_backward=true)(model)
ps, st = Lux.setup(Random.default_rng(), reactant_model)

x = randn(Float32, 10, 3)

reactant_model(x, ps, st)

Upstream Needs

TODOs

avik-pal commented 1 month ago

Using Enzyme the way I currently am gives me:

julia> reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3))(model)
┌ Error: Enzyme failed to compile the backward pass. Differentiation will be disabled for this model.
│   exception = UndefKeywordError: keyword argument `kwargs` not assigned
└ @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:46

Full Stacktrace

ERROR: UndefKeywordError: keyword argument `kwargs` not assigned
Stacktrace:
  [1] call
    @ ~/.julia/packages/Cassette/4UsSX/src/context.jl:454 [inlined]
  [2] fallback
    @ ~/.julia/packages/Cassette/4UsSX/src/context.jl:452 [inlined]
  [3] overdub(ctx::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, f::typeof(throw), exception::UndefKeywordError)
    @ Reactant ~/.julia/packages/Cassette/4UsSX/src/context.jl:278
  [4] apply(::LuxReactantExt.var"
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
  [5] apply
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
  [6] overdub(::Cassette.Context{…}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#12#15"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
  [7] (::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#12#15"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:53
  [8] block!(f::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#12#15"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Block.jl:198
  [9] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{LuxReactantExt.var"#12#15"{…}, StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:46
 [10] make_mlir_fn
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:16 [inlined]
 [11] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:926
 [12] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
 [13] (::Reactant.var"#99#104"{LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:924
 [14] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
 [15] compile(f::LuxReactantExt.var"#12#15"{true, @NamedTuple{…}}, args::Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}; pipeline_options::String, client::Nothing)
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
 [16] compile(f::LuxReactantExt.var"#12#15"{true, @NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
 [17] __to_reactant_adaptor(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
    @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:43
 [18] adapt(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/reactant.jl:9
 [19] (::ToReactantAdaptor{true, Matrix{Float32}})(x::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
 [20] top-level scope
    @ REPL[22]:1
 [21] top-level scope
    @ none:1
Some type information was truncated. Use `show(err)` to see complete types.
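The warning above describes a fallback: if the backward pass fails to compile, the forward pass is kept and differentiation is disabled for the model. Below is a minimal stdlib-only sketch of that pattern. `try_compile_backward` and its `force` keyword are hypothetical names. Treating `force` as the analogue of `force_compile_backward=true` (raise instead of warn) is an assumption; the PR does not state this.

```julia
# Hypothetical sketch of "warn and disable differentiation on failure".
# `compile_backward` stands in for whatever builds the compiled gradient;
# this is NOT the LuxReactantExt implementation.
function try_compile_backward(compile_backward; force::Bool = false)
    try
        return compile_backward()
    catch err
        # Assumed analogue of `force_compile_backward=true`: surface the error.
        force && rethrow()
        @warn "Failed to compile the backward pass. Differentiation will be disabled for this model." exception = err
        return nothing   # `nothing` means "no compiled gradient available"
    end
end

# A compile step that always fails, mimicking the UndefKeywordError above.
failing() = throw(UndefKeywordError(:kwargs))

grad_fn = try_compile_backward(failing)            # logs a warning, returns nothing
ok = try_compile_backward(() -> (x -> 2x))         # succeeds, returns the closure
println(grad_fn === nothing, " ", ok(3))
```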
avik-pal commented 1 month ago

OK, fixed that. Now we are back to:

ERROR: MethodError: no method matching (Reactant.TracedRArray{Float32, Shape, 2} where Shape)(::Tuple{}, ::Reactant.MLIR.IR.Value)
Stacktrace:
  [1] make_zero(::Type{Reactant.TracedRArray{Float32, Shape, 2} where Shape}, seen::IdDict{Any, Any}, prev::Reactant.TracedRArray{Float32, (5, 10), 2}, ::Val{false})
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:149
  [2] #42
    @ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1258 [inlined]
  [3] ntuple
    @ ./ntuple.jl:19 [inlined]
  [4] make_zero
    @ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1256 [inlined]
  [5] make_zero(::Type{@NamedTuple{…}}, seen::IdDict{Any, Any}, prev::@NamedTuple{weight::Reactant.TracedRArray{…} where Shape, bias::Reactant.TracedRArray{…} where Shape}, ::Val{false})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1263
  [6] (::Enzyme.Compiler.var"#42#43"{Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}, false, IdDict{Any, Any}, Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}})(i::Int64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1258
  [7] ntuple(f::Enzyme.Compiler.var"#42#43"{Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}, false, IdDict{Any, Any}, Tuple{@NamedTuple{…}, @NamedTuple{…}, @NamedTuple{}}}, n::Int64)
    @ Base ./ntuple.jl:19
  [8] make_zero
    @ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1256 [inlined]
  [9] make_zero(::Type{@NamedTuple{…}}, seen::IdDict{Any, Any}, prev::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{}}, ::Val{false})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:1263
 [10] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/Z0CgU/src/EnzymeCore.jl:237 [inlined]
 [11] overdub
    @ /mnt/research/lux/Reactant.jl/src/overloads.jl:358 [inlined]
 [12] #42
    @ /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:35 [inlined]
 [13] overdub(::Cassette.Context{…}, ::LuxReactantExt.var"#42#45"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [14] var"
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
 [15] #apply#142
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
 [16] overdub(::Cassette.Context{…}, ::Reactant.var"##apply#142", ::@Kwargs{}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#42#45"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [17] apply(::LuxReactantExt.var"
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
 [18] apply
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
 [19] overdub(::Cassette.Context{…}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#42#45"{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [20] (::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#42#45"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:53
 [21] block!(f::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#42#45"{…}, StatefulLuxLayer{…}, Reactant.TracedRArray{…}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Block.jl:198
 [22] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{LuxReactantExt.var"#42#45"{…}, StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:46
 [23] make_mlir_fn
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:16 [inlined]
 [24] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:926
 [25] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
 [26] (::Reactant.var"#99#104"{LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:924
 [27] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
 [28] compile(f::LuxReactantExt.var"#42#45"{true, @NamedTuple{…}}, args::Tuple{StatefulLuxLayer{…}, Reactant.ConcreteRArray{…}}; pipeline_options::String, client::Nothing)
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
 [29] compile(f::LuxReactantExt.var"#42#45"{true, @NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
 [30] __to_reactant_adaptor(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
    @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:44
 [31] adapt(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/reactant.jl:9
 [32] (::ToReactantAdaptor{true, Matrix{Float32}})(x::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}, layer_3::WrappedFunction{…}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
 [33] top-level scope
    @ REPL[29]:1
 [34] top-level scope
    @ none:1
Some type information was truncated. Use `show(err)` to see complete types.
codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 0.62500% with 159 lines in your changes missing coverage. Please review.

Project coverage is 80.26%. Comparing base (60c595e) to head (8ce9707). Report is 26 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| ext/LuxReactantExt/layer.jl | 0.00% | 101 Missing :warning: |
| src/transform/reactant.jl | 0.00% | 15 Missing :warning: |
| ext/LuxReactantExt/utils.jl | 0.00% | 13 Missing :warning: |
| ext/LuxReactantExt/train.jl | 0.00% | 12 Missing :warning: |
| src/layers/extension.jl | 0.00% | 10 Missing :warning: |
| src/contrib/training.jl | 0.00% | 6 Missing :warning: |
| src/utils.jl | 0.00% | 2 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #665      +/-   ##
==========================================
- Coverage   87.11%   80.26%   -6.85%
==========================================
  Files          50       55       +5
  Lines        2515     2671     +156
==========================================
- Hits         2191     2144      -47
- Misses        324      527     +203
```
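The figures in this report can be cross-checked against the raw hit and miss counts. The sketch below assumes Codecov truncates percentages to two decimals instead of rounding them. That behavior is inferred from these numbers, not taken from Codecov's documentation. Under that assumption, 0.625% patch coverage with 159 missing lines is consistent with 1 covered line out of 160.

```julia
# Recompute the Codecov totals from the raw counts in the report.
# Assumption: Codecov truncates (does not round) to two decimals.
pct(hits, lines) = trunc(100 * hits / lines; digits = 2)

base = (hits = 2191, misses = 324)   # main
head = (hits = 2144, misses = 527)   # PR head

base_cov = pct(base.hits, base.hits + base.misses)   # 2191 / 2515
head_cov = pct(head.hits, head.hits + head.misses)   # 2144 / 2671
delta = round(head_cov - base_cov; digits = 2)

# The per-file "Missing" column sums to the 159 uncovered patch lines;
# one covered line out of 160 then gives the reported patch coverage.
missing_per_file = [101, 15, 13, 12, 10, 6, 2]
patch_missing = sum(missing_per_file)
patch_cov = 100 * 1 / (patch_missing + 1)

println((base_cov, head_cov, delta, patch_missing, patch_cov))
```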


avik-pal commented 1 month ago

Some more progress. Now we get:

1-element ExceptionStack:
MethodError: no method matching EnzymeCore.Duplicated(::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}})

Closest candidates are:
  EnzymeCore.Duplicated(::T1, ::T1) where T1
   @ EnzymeCore ~/.julia/packages/EnzymeCore/iXoTK/src/EnzymeCore.jl:65
  EnzymeCore.Duplicated(::T1, ::T1, ::Bool) where T1
   @ EnzymeCore ~/.julia/packages/EnzymeCore/iXoTK/src/EnzymeCore.jl:65

Stacktrace:
  [1] call
    @ ~/.julia/packages/Cassette/4UsSX/src/context.jl:454 [inlined]
  [2] fallback
    @ ~/.julia/packages/Cassette/4UsSX/src/context.jl:452 [inlined]
  [3] _overdub_fallback(::Any, ::Vararg{Any})
    @ ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:596 [inlined]
  [4] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::Type{EnzymeCore.Duplicated}, ::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:596
  [5] #38
    @ /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:42 [inlined]
  [6] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, ::StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, ::Reactant.TracedRArray{Float32, (10, 3), 2})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
  [7] var"
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
  [8] #apply#1
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:11 [inlined]
  [9] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::Reactant.var"##apply#1", ::@Kwargs{}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, ::StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, ::Reactant.TracedRArray{Float32, (10, 3), 2})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [10] apply(::LuxReactantExt.var"
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
 [11] apply
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:10 [inlined]
 [12] overdub(::Cassette.Context{Reactant.var"##TraceCtx#Name", Nothing, Nothing, Cassette.var"##PassType#235", Nothing, Nothing}, ::typeof(Reactant.apply), ::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, ::StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, ::Reactant.TracedRArray{Float32, (10, 3), 2})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [13] (::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.TracedRArray{Float32, (10, 3), 2}}})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:53
 [14] block!(f::Reactant.var"#5#13"{typeof(Reactant.apply), Tuple{}, Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.TracedRArray{Float32, (10, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Block.jl:198
 [15] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant /mnt/research/lux/Reactant.jl/src/utils.jl:46
 [16] make_mlir_fn
    @ /mnt/research/lux/Reactant.jl/src/utils.jl:16 [inlined]
 [17] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:921
 [18] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
 [19] (::Reactant.var"#99#104"{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:919
 [20] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
 [21] compile(f::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}; pipeline_options::String, client::Nothing)
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
 [22] compile(f::LuxReactantExt.var"#38#41"{true, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
 [23] __to_reactant_adaptor(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
    @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:51
 [24] adapt(to::ToReactantAdaptor{true, Matrix{Float32}}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/reactant.jl:9
 [25] (::ToReactantAdaptor{true, Matrix{Float32}})(x::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
 [26] top-level scope
    @ REPL[9]:1
 [27] top-level scope
    @ none:1

@wsmoses, can you review how I am compiling the gradient? I must be doing something wrong: the `ps` elements are `Array`s instead of `TracedRArray`s.
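The "Closest candidates" list shows that `Duplicated` expects the primal and the shadow to share a single type `T1`. In the error, however, the primal is a NamedTuple of `Matrix{Float32}` and the shadow is a NamedTuple of `TracedRArray`s, which matches the observation that the `ps` elements are plain arrays. The stdlib-only sketch below reproduces that dispatch rule with a hypothetical `MyDuplicated` type, which is not Enzyme's. `Matrix{Float64}` stands in for the mismatched traced type.

```julia
# Hypothetical stand-in for an AD "duplicated" wrapper. Like the candidate
# `Duplicated(::T1, ::T1) where T1`, the auto-generated constructor
# `MyDuplicated(::T, ::T) where T` only matches when both arguments share a type.
struct MyDuplicated{T}
    val::T    # primal (e.g. the parameters)
    dval::T   # shadow that accumulates the gradient
end

primal_ps = (weight = rand(Float32, 5, 10), bias = rand(Float32, 5, 1))
shadow_ok = map(zero, primal_ps)         # same NamedTuple type as the primal
d = MyDuplicated(primal_ps, shadow_ok)   # dispatches fine

# A shadow whose leaf types differ from the primal's has no matching method.
shadow_bad = map(x -> zeros(Float64, size(x)), primal_ps)
err = try
    MyDuplicated(primal_ps, shadow_bad)
    nothing
catch e
    e
end
println(typeof(err))   # MethodError
```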

wsmoses commented 1 month ago

@avik-pal, can this be done in a way that Reactant compiles the whole update, rather than compiling the gradient separately from the inference pass? I expect a substantial performance improvement from doing so, including the model update actually occurring fully in place.

For example, the function Reactant compiles could be something like:

function update(model, x, learning_rate)
   grads = gradient(model, x)
   update!(model, grads, learning_rate[1])
   nothing
end
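Below is a plain-Julia sketch of the fused step described above. The gradient and the in-place parameter update sit in one function, so a tracing compiler would see both. Everything here is an illustrative assumption: the model is a single linear map `W * x` with a squared-error loss, the gradient is derived by hand instead of coming from Enzyme, and `train_step!` is a hypothetical name, not Lux or Reactant API. The learning rate is passed as a one-element vector and read as `learning_rate[1]`, mirroring the snippet above.

```julia
# Hypothetical fused training step (not Lux/Reactant API).
# Model: y = W * x; loss: L(W) = sum((W * x .- y_target) .^ 2) / 2,
# whose gradient is dL/dW = (W * x .- y_target) * x'.
function train_step!(W::AbstractMatrix, x::AbstractMatrix, y_target::AbstractMatrix,
                     learning_rate::AbstractVector)
    residual = W * x .- y_target     # forward pass
    grad = residual * x'             # hand-written backward pass, same shape as W
    W .-= learning_rate[1] .* grad   # parameter update happens in place
    return nothing
end

W = zeros(Float32, 1, 2)
x = Float32[1 0; 0 1]
y_target = Float32[3 4]
lr = Float32[0.5]

train_step!(W, x, y_target, lr)
println(W)   # Float32[1.5 2.0]

for _ in 1:50
    train_step!(W, x, y_target, lr)
end
println(W)   # close to Float32[3.0 4.0]
```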
avik-pal commented 1 month ago

Not with the layers API. For now, I would consider accelerating just the neural-network part a good win. Keeping the two separate also lets us use regular Julia ops where we can't compile to Reactant: for example, the ODE solve happens in Julia while the neural network runs via XLA.

We can add `AutoReactant` for the training API, which compiles the entire pass on the first call and reuses it on subsequent calls (similar to what we do for Enzyme).
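The compile-on-first-call pattern mentioned above can be sketched in plain Julia as a callable wrapper that caches the compiled step. `CachedStep` and `fake_compile` are hypothetical names. `fake_compile` only counts its invocations and stands in for a real compilation step.

```julia
# Hypothetical sketch of "compile on the first call, reuse afterwards".
mutable struct CachedStep{F,C}
    step::F                               # the function to accelerate
    compile::C                            # turns `step` into a faster callable
    compiled::Union{Nothing,Function}     # cache, filled on the first call
end
CachedStep(step, compile) = CachedStep(step, compile, nothing)

function (cs::CachedStep)(args...)
    if cs.compiled === nothing
        cs.compiled = cs.compile(cs.step)   # first call: compile and cache
    end
    return cs.compiled(args...)             # later calls: reuse the cached result
end

compile_count = Ref(0)
fake_compile(f) = (compile_count[] += 1; f)   # stand-in that just counts calls

cs = CachedStep((a, b) -> a + b, fake_compile)
cs(1, 2); cs(3, 4); cs(5, 6)
println(compile_count[])   # 1
```

A real backend would also have to recompile when input types or shapes change. The traced types in the stack traces here encode the array shape, for example `ConcreteRArray{Float32, (10, 3), 2}`. This sketch ignores that.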

avik-pal commented 1 month ago

There was a bug in my code. I fixed that, but now I get:

AssertionError: !fnwrapped
Stacktrace:
  [1] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:927
  [2] mmodule!(f::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Module.jl:89
  [3] (::Reactant.var"#99#104"{LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64})()
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:924
  [4] context!(f::Reactant.var"#99#104"{LuxReactantExt.var"#36#39"{true}, Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/Reactant.jl/src/mlir/IR/Context.jl:68
  [5] compile(f::LuxReactantExt.var"#36#39"{true}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}}; pipeline_options::String, client::Nothing)
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:922
  [6] compile(f::LuxReactantExt.var"#36#39"{true}, args::Tuple{StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, Reactant.ConcreteRArray{Float32, (10, 3), 2}})
    @ Reactant /mnt/research/lux/Reactant.jl/src/Reactant.jl:917
  [7] __to_reactant_adaptor(to::ToReactantAdaptor{true, Xoshiro, Matrix{Float32}, typeof(identity)}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, input_prototype::Matrix{Float32}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_3::@NamedTuple{}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, eltype_adaptor::Nothing)
    @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:91
  [8] __to_reactant_adaptor(to::ToReactantAdaptor{true, Xoshiro, Matrix{Float32}, typeof(identity)}, model::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
    @ LuxReactantExt /mnt/research/lux/Lux.jl/ext/LuxReactantExt.jl:65
  [9] adapt
    @ /mnt/research/lux/Lux.jl/src/transform/reactant.jl:23 [inlined]
 [10] (::ToReactantAdaptor{true, Xoshiro, Matrix{Float32}, typeof(identity)})(x::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}})
    @ Lux /mnt/research/lux/Lux.jl/src/transform/types.jl:6
 [11] top-level scope
    @ REPL[16]:1
 [12] top-level scope
    @ none:1
 [13] eval
    @ ./boot.jl:385 [inlined]
 [14] eval
    @ ./Base.jl:88 [inlined]
 [15] repleval(m::Module, code::Expr, ::String)
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:229
 [16] (::VSCodeServer.var"#112#114"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:192
 [17] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging ./logging.jl:515
 [18] with_logger
    @ ./logging.jl:627 [inlined]
 [19] (::VSCodeServer.var"#111#113"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:193
 [20] #invokelatest#2
    @ ./essentials.jl:892 [inlined]
 [21] invokelatest(::Any)
    @ Base ./essentials.jl:889
 [22] (::VSCodeServer.var"#64#65")()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/eval.jl:34

It seems like there is some kind of boxing going on somewhere.

avik-pal commented 1 month ago

Okay, things are mostly working now; we just need a copyto! for TracedRArray.
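To make the copyto! request concrete: while tracing, an array holds a symbolic value rather than a data buffer, so one way to express an in-place copy is to make the destination refer to the source's value. ToyTracedArray below is a hypothetical toy, with a string standing in for an MLIR SSA value; it is not Reactant's TracedRArray and not a claim about how Reactant implements this.

```julia
# Hypothetical toy stand-in for a traced array: it carries a symbolic value
# (a string standing in for an MLIR SSA value) and a shape, not real data.
mutable struct ToyTracedArray{N}
    value::String
    shape::NTuple{N,Int}
end

Base.size(x::ToyTracedArray) = x.shape

# In-place copy during tracing: there is no buffer to overwrite, so the
# destination simply starts referring to the source's symbolic value.
function Base.copyto!(dest::ToyTracedArray, src::ToyTracedArray)
    size(dest) == size(src) ||
        throw(DimensionMismatch("sizes $(size(dest)) and $(size(src)) differ"))
    dest.value = src.value
    return dest
end

a = ToyTracedArray("%arg0", (10, 3))
b = ToyTracedArray("%7", (10, 3))
copyto!(a, b)   # `a` now refers to the value "%7"
```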

avik-pal commented 1 month ago
using Reactant, Lux, Random, ComponentArrays

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3))(model)

This gives me a "Pipeline Failed" error: size of operand dimension 0 (3) is not equal to 1 or size of result dimension 0 (2). However, the Enzyme gradient seems to work fine with the XLA compilation.
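That message most likely comes from the shape rule StableHLO enforces for broadcast_in_dim: each operand dimension must either have size 1 or match the size of the result dimension it is mapped to. A minimal checker for that rule, as a sketch; broadcast_in_dim_ok is a hypothetical helper, and the failing shapes in the second call are made up to reproduce the reported 3-vs-2 mismatch, since the actual operand shapes are not shown.

```julia
# Sketch of the StableHLO broadcast_in_dim shape rule: operand dimension d must
# have size 1 or equal the size of result dimension dims[d] (dims are 0-based).
function broadcast_in_dim_ok(operand_shape::Tuple, result_shape::Tuple, dims::Tuple)
    length(dims) == length(operand_shape) || return false
    for (d, od) in enumerate(operand_shape)
        rd = result_shape[dims[d] + 1]   # convert 0-based dim index to Julia's 1-based
        (od == 1 || od == rd) || return false
    end
    return true
end

broadcast_in_dim_ok((5,), (5, 3), (0,))   # bias of length 5 across a 5x3 matrix: allowed
broadcast_in_dim_ok((3,), (2, 3), (0,))   # hypothetical shapes: 3 is neither 1 nor 2
```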

wsmoses commented 1 month ago

@avik-pal you should open an issue on Reactant with the pipeline error once the prerequisites are merged.

avik-pal commented 1 month ago

@wsmoses this looks like incorrect code generation?

Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<3x10xf32>, %arg3: tensor<10x5xf32>, %arg4: tensor<10x5xf32>) -> (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<5x3xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<5xf32>
    %0 = stablehlo.reshape %arg1 : (tensor<1x5xf32>) -> tensor<5xf32>
    %1 = stablehlo.dot_general %arg4, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x5xf32>, tensor<3x10xf32>) -> tensor<5x3xf32>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xf32>) -> tensor<5x3xf32>
    %3 = stablehlo.add %1, %2 : tensor<5x3xf32>
    %4 = stablehlo.tanh %3 : tensor<5x3xf32>
    %5 = stablehlo.reduce(%4 init: %cst_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<5x3xf32>, tensor<f32>) -> tensor<f32>
    %6 = stablehlo.multiply %4, %4 : tensor<5x3xf32>
    %7 = stablehlo.subtract %cst, %6 : tensor<5x3xf32>
    %8 = stablehlo.reduce(%7 init: %cst_1) across dimensions = [1] : (tensor<5x3xf32>, tensor<5xf32>) -> tensor<5xf32>
     reducer(%arg5: tensor<5xf32>, %arg6: tensor<5xf32>)  {
      %13 = stablehlo.add %arg5, %arg6 : tensor<5xf32>
      stablehlo.return %13 : tensor<5xf32>
    }
    %9 = stablehlo.dot_general %arg2, %7, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x10xf32>, tensor<5x3xf32>) -> tensor<10x5xf32>
    %10 = stablehlo.add %arg3, %9 : tensor<10x5xf32>
    %11 = stablehlo.reshape %8 : (tensor<5xf32>) -> tensor<1x5xf32>
    %12 = stablehlo.add %arg0, %11 : tensor<1x5xf32>
    return %12, %5, %10 : tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>
  }
}
terminate called after throwing an instance of 'xla::XlaRuntimeError'
  what():  UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[5]: 
<unknown>:0: note: see current operation: "func.return"(%15, %8, %13) : (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) -> ()
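For reference, the reduction at %8 appears to be the bias gradient: a sum over the batch dimension in which the reducer adds scalars. In valid StableHLO such a reduce takes a scalar init value (tensor<f32>, like %cst_0 used by %5) and a scalar reducer, rather than the tensor<5xf32> init and reducer seen above, which is what the "Reduction function must return a scalar" error points at. The sketch below assumes, as %5 suggests, that the function being differentiated is sum(tanh.(W * x .+ b)) for a single Dense layer with a 5x10 weight and a 3-sample batch; dense_tanh_grads is a hypothetical hand-written helper, not Enzyme output, useful as a reference to compare the compiled gradient against.

```julia
using Random

# Assumed loss, read off %5 in the module: sum(tanh.(W * x .+ b)) for one Dense layer.
loss(W, b, x) = sum(tanh.(W * x .+ b))

# Hand-written gradients mirroring the module's backward ops.
function dense_tanh_grads(W, b, x)
    y = tanh.(W * x .+ b)        # forward activations (%4)
    dy = 1 .- y .^ 2             # tanh' seeded with ones (%7)
    gW = dy * x'                 # contracts the batch dimension (%9)
    gb = vec(sum(dy; dims=2))    # scalar adds over the batch dimension (%8)
    return gW, gb
end

rng = MersenneTwister(0)
W, b, x = randn(rng, 5, 10), randn(rng, 5), randn(rng, 10, 3)
gW, gb = dense_tanh_grads(W, b, x)
size(gW), size(gb)   # ((5, 10), (5,))
```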
wsmoses commented 1 month ago

@avik-pal the Lux fixes (and NamedTuple support) just landed and were released.

I'll give the reduction error a go shortly, but at a minimum we can see what works (and perhaps mark that as expected-broken to start with).

avik-pal commented 1 month ago

Currently the Julia session crashes because of the broken reverse pass, so I can't mark it as broken.

avik-pal commented 1 month ago

I will try to see what kinds of models compile, at least for the forward pass.

avik-pal commented 1 month ago

Can we have a zero-copy transfer between Julia AbstractArrays and Reactant/XLA arrays? That would make it simpler to support wrapper types like ComponentArrays.

Also, we could then keep the parameters as regular Julia arrays, which work more nicely with the current optimisers and such.

wsmoses commented 1 month ago

Not easily, as we need to own the data.

Similarly, the model and, ideally, the inputs are always kept as RArrays here.

Also, for better performance, the optimizers themselves are compiled by Reactant.

In reply to Avik Pal, Sat, Jun 1, 2024 at 8:33 AM.

avik-pal commented 1 month ago

> Also, for better performance, the optimizers themselves are compiled by Reactant.

Right, but I don't think we would be able to compile NeuralODE-style models yet, right? So having an eager version that can operate directly on RArrays seems like a good tradeoff, since it lets us run part of the model in regular Julia.

I might pull out the AutoReactant code (compiling the training iteration) into a separate PR because that would be easier to merge.

wsmoses commented 1 month ago

In theory there's no reason we couldn't, but I don't think we support it right now.

It's worth testing and opening an issue, though, so we know what to work on.

In reply to Avik Pal, Sat, Jun 1, 2024 at 5:46 PM.

wsmoses commented 2 weeks ago

@avik-pal the fix has landed; can we retry this?

avik-pal commented 2 weeks ago

This one is too broadly scoped, so I will hold off on it.

First, I want to finish #673, which compiles the entire training loop and doesn't need to worry about users doing unwanted things to the parameters.
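As a rough illustration of what compiling the entire training iteration covers: the forward pass, the gradient and the parameter update live in one function, so a compiler sees all three at once and the parameters are only touched inside it. This is a hypothetical plain-Julia sketch; train_step, fit, the linear model and the hand-derived mean-squared-error gradients (standing in for Enzyme) are assumptions, not the API from #673.

```julia
using Random

# One full training iteration for a hypothetical linear model y = W * x .+ b:
# forward pass, MSE loss, hand-derived gradients, and an SGD update.
function train_step(ps, x, y, η)
    r = ps.weight * x .+ ps.bias .- y          # residuals
    n = size(x, 2)                             # batch size
    loss = sum(abs2, r) / n
    gW = (2 / n) .* (r * x')                   # d loss / d weight
    gb = (2 / n) .* vec(sum(r; dims=2))        # d loss / d bias
    new_ps = (weight = ps.weight .- η .* gW, bias = ps.bias .- η .* gb)
    return loss, new_ps                        # parameters are returned, not mutated
end

# Driver loop kept inside a function to avoid global soft-scope issues.
function fit(ps, x, y, η, nsteps)
    losses = Float64[]
    for _ in 1:nsteps
        l, ps = train_step(ps, x, y, η)
        push!(losses, l)
    end
    return ps, losses
end

rng = MersenneTwister(1)
x = randn(rng, 3, 16)
y = randn(rng, 2, 3) * x .+ 0.5
ps, losses = fit((weight = zeros(2, 3), bias = zeros(2)), x, y, 0.1, 100)
```

Because train_step is a pure function of its inputs, the same shape of code is what a backend could trace and compile once, then reuse for every iteration.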