EnzymeAD / Reactant.jl

MIT License

[WIP] Add a Lux example #5

Closed: avik-pal closed this pull request 2 weeks ago

avik-pal commented 1 month ago

Needs https://github.com/EnzymeAD/Reactant.jl/issues/4

wsmoses commented 1 month ago

Also, you only need to trace through a variable if it contains data, like an array.

wsmoses commented 1 month ago

Let’s do it.

Also, btw, this will test both CPU and GPU.

On Tue, May 14, 2024 at 12:09 PM Avik Pal wrote:

avik-pal commented on this pull request.

In test/nn_lux.jl https://github.com/EnzymeAD/Reactant.jl/pull/5#discussion_r1600525134:

+comp = f(cmodel, cnoisy, cps, cst) @. comp[3] @. f(cmodel, cnoisy)
+
+# To train the model, we use batches of 64 samples, and one-hot encoding:
+target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
+loader = Flux.DataLoader((noisy, target); batchsize=64, shuffle=true);
+# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
+
+optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc.
+
+# Training loop, using the whole data set 1000 times:
+losses = []
+for epoch in 1:1_000
+    for (x, y) in loader
+        loss, grads = Flux.withgradient(model) do m

Do you want to pull in `Lux.Experimental.apply_gradients!` and compile it together?


avik-pal commented 1 month ago

A couple of problems:

  1. BatchNorm doesn't seem to work because it tries to trace through `mean`, and the traced arrays don't have `getindex` defined.
  2. Without BatchNorm: after the function is compiled and I try to run it, I get:
julia> comp = f(cmodel, cnoisy, cps, cst)
ERROR: UndefVarError: `layer_2` not defined
 [1] (::Reactant.var"#109#110")(arg_1::Chain{…}, arg_2::Reactant.ConcreteRArray{…}, arg_3::@NamedTuple{…}, arg_4::@NamedTuple{…})
   @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:704
 [2] top-level scope
   @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:29
Some type information was truncated. Use `show(err)` to see complete types.
  3. In `create_result`, should we add a case for `NamedTuple` similar to the other cases? The returned state is always a `NamedTuple`.
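The question is about structural recursion over nested results. As a minimal pure-Julia sketch (it assumes nothing about Reactant's actual `create_result`; the `rebuild` function and the sample state are hypothetical), a `NamedTuple` case sits next to the array and tuple cases, recursing into the values and reattaching the same field names:

```julia
# Hypothetical helper, not Reactant's create_result.
# Leaf case: arrays are handed to `leaf` (e.g. to wrap a raw buffer).
rebuild(leaf, x::AbstractArray) = leaf(x)
# Tuple case: rebuild each element and keep it a Tuple.
rebuild(leaf, x::Tuple) = map(y -> rebuild(leaf, y), x)
# NamedTuple case: rebuild the values and reattach the same field names.
rebuild(leaf, x::NamedTuple{names}) where {names} =
    NamedTuple{names}(map(y -> rebuild(leaf, y), values(x)))
# Everything else (numbers, Val flags, functions, ...) passes through.
rebuild(leaf, x) = x

# Hypothetical Lux-like layer state with a BatchNorm-style entry.
st = (layer_1 = NamedTuple(),
      layer_2 = (running_mean = [1.0, 2.0], training = Val(true)),
      layer_3 = NamedTuple())
new_st = rebuild(x -> 2 .* x, st)
```

Empty per-layer states such as `layer_1` come back as empty `NamedTuple`s, so the overall structure of the state is preserved.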

wsmoses commented 1 month ago

It is possible to add `getindex` etc. here, but I'm intentionally preventing that for now so we can ensure we trace fully vectorized code.

And yeah, we should just add an overload for `mean`. Want to give it a shot?
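To illustrate what "fully vectorized" means for `mean`, here is a hypothetical `vectorized_mean` on plain `Array`s, written only with a whole-array `sum` reduction and a broadcasted division, with no element-by-element `getindex` loop. It is a sketch of the shape such an overload could take, not Reactant's actual `mean` overload:

```julia
# Hypothetical helper: a mean expressed as one whole-array reduction
# (`sum`, optionally along `dims`) plus a broadcasted division.
function vectorized_mean(x::AbstractArray; dims=:)
    total = sum(x; dims=dims)
    # Number of elements folded into each output entry.
    n = dims isa Colon ? length(x) : prod(size(x, d) for d in dims)
    return total ./ n
end

x = Float32[1 2; 3 4]
vectorized_mean(x)          # 2.5f0
vectorized_mean(x; dims=1)  # 1×2 Matrix{Float32}: [2.0 3.0]
```

Because the only array-level operations are a reduction and an elementwise division, a tracer that supports those two operations never needs scalar indexing. A per-feature mean for BatchNorm-style statistics would follow the same pattern with a suitable `dims`.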

avik-pal commented 1 month ago

This doesn't seem right:

julia> @code_lowered  f(cmodel, cnoisy, cps, cst)
1 ─ %1  = Reactant.XLA.synced_buffer
│   %2  = Base.getfield
│   %3  = Base.getfield
│   %4  = (%3)(arg_3, Reactant.layer_2)
│   %5  = (%2)(%4, Reactant.weight)
│   %6  = Base.getproperty(%5, :data)
│         sbuf_1 = (%1)(%6)
│   %8  = Reactant.XLA.synced_buffer
│   %9  = Base.getfield
│   %10 = Base.getfield
│   %11 = (%10)(arg_3, Reactant.layer_1)
│   %12 = (%9)(%11, Reactant.weight)
│   %13 = Base.getproperty(%12, :data)
│         sbuf_2 = (%8)(%13)
│   %15 = Reactant.XLA.synced_buffer
│   %16 = Base.getfield
│   %17 = Base.getfield
│   %18 = (%17)(arg_3, Reactant.layer_1)
│   %19 = (%16)(%18, Reactant.bias)
│   %20 = Base.getproperty(%19, :data)
│         sbuf_3 = (%15)(%20)
│   %22 = Reactant.XLA.synced_buffer
│   %23 = Base.getfield
│   %24 = Base.getfield
│   %25 = (%24)(arg_3, Reactant.layer_2)
│   %26 = (%23)(%25, Reactant.bias)
│   %27 = Base.getproperty(%26, :data)
│         sbuf_4 = (%22)(%27)
│   %29 = Reactant.XLA.synced_buffer
│   %30 = Base.getproperty(arg_2, :data)
│         sbuf_5 = (%29)(%30)
│   %32 = $(Expr(:gc_preserve_begin, :(sbuf_1), :(sbuf_2), :(sbuf_3), :(sbuf_4), :(sbuf_5)))
│   %33 = Reactant.XLA.ExecutableCall
│   %34 = Base.getproperty(sbuf_1, :buffer)
│   %35 = Base.getproperty(sbuf_2, :buffer)
│   %36 = Base.getproperty(sbuf_3, :buffer)
│   %37 = Base.getproperty(sbuf_4, :buffer)
│   %38 = Base.getproperty(sbuf_5, :buffer)
│   %39 = Core.tuple(%34, %35, %36, %37, %38)
│   %40 = Reactant.Val(1)
│   %41 = (%33)(Reactant.XLA.LoadedExecutable(Ptr{Nothing} @0x000000001114fd40), %39, (0x01, 0x01, 0x01, 0x01, 0x01), %40)
│         linearized_results = %41
│         $(Expr(:gc_preserve_end, :(%32)))
│         concrete_res_1 = Base.getindex(linearized_results, 1)
│         result = (Reactant.ConcreteRArray{Float32, (2, 1000), 2})(concrete_res_1)
└──       return result

It shouldn't be `Reactant.layer_2`: the field name is being resolved as a global in the `Reactant` module.

avik-pal commented 1 month ago

> So it is possible to add getindex/etc here, but I'm intentionally preventing so now so we can ensure we trace fully vectorized code.

Makes sense. I added ArrayInterface so downstream code can easily check for that.

wsmoses commented 1 month ago

Ah, we should probably add an escape in the macro.

wsmoses commented 1 month ago

Wait, that's odd, though. Shouldn't it be a symbol that's being looked up there?

avik-pal commented 1 month ago

> wait that's odd though it should be a symbol there being looked up?

It was being interpolated directly; it needed a `Meta.quot`.
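A small pure-Julia reproduction of the bug and the fix (the variable names are hypothetical). Interpolating a `Symbol` into generated code splices it in as a bare identifier, so it is resolved as a global variable, which matches the `UndefVarError` for `layer_2` reported above. Quoting it with `Meta.quot` makes the generated code contain the literal symbol instead:

```julia
# Field name that the generated code should look up.
name = :layer_2

# Direct interpolation: the Symbol becomes a bare identifier, i.e. a
# reference to a (global) variable named `layer_2`.
bad_ex = :(getfield(nt, $name))

# Meta.quot quotes the Symbol, so the generated code contains the
# literal symbol :layer_2.
good_ex = :(getfield(nt, $(Meta.quot(name))))

# Turn both expressions into functions of `nt` (evaluated in Main).
f_bad = eval(:(nt -> $bad_ex))
f_good = eval(:(nt -> $good_ex))

params = (layer_1 = 1, layer_2 = 2)
f_good(params)  # returns 2
# f_bad(params) throws UndefVarError, since no global `layer_2` exists
```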

avik-pal commented 1 month ago

function xlogy(x, y)
    result = x * log(y)
    # define 0 * log(0) as 0 instead of NaN
    return ifelse(iszero(x), zero(result), result)
end

function crossentropy(ŷ, y)
    return .-sum(xlogy.(y, ŷ))
end

function loss_function(model, x, y, ps, st)
    y_hat, _ = model(x, ps, st)
    return crossentropy(y_hat, y)
end

compiled_loss_function = Reactant.compile(
    loss_function, (cmodel, cnoisy, ctarget, cps, cst))

~`elem_apply` is not defined for `xlogy`. How do I trace into the body of `xlogy`? There isn't a direct mapping for it in https://openxla.org/stablehlo/spec#log~

The actual problem is that the eltypes don't match:

MethodError: no method matching elem_apply(::typeof(xlogy), ::Reactant.TracedRArray{Bool, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2})
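In the error above, the traced `elem_apply` call receives a `Bool` array and a `Float32` array and finds no matching method. In plain Julia the same computation works because `Bool * Float32` promotes. One possible workaround, which is an assumption here and not something settled in this thread, is to convert the one-hot target to the prediction eltype before calling the loss. A plain-Julia sketch with made-up 2×2 data:

```julia
# Elementwise x * log(y), defined as 0 when x == 0 (so 0 * log(0) is not NaN).
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

# Cross-entropy summed over all entries.
crossentropy(ŷ, y) = -sum(xlogy.(y, ŷ))

# A Bool one-hot target next to Float32 predictions: mixed eltypes.
target = Bool[1 0; 0 1]
ŷ = Float32[0.5 0.25; 0.5 0.75]

# Converting the target up front gives both operands the same eltype,
# so an elementwise kernel only ever sees Float32 on both sides.
target_f32 = Float32.(target)
loss = crossentropy(ŷ, target_f32)  # equals -(log(0.5f0) + log(0.75f0))
```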

avik-pal commented 1 month ago

My failed attempt at defining the comparisons:

# stablehlo.compare produces a boolean (i1) tensor, so the result eltype RT is Bool.
# hlocomp = 0 is a placeholder: how to pass the comparison-direction enum is the open question.
for (jlop, hloop, hlocomp, RT) in ((:(Base.:(==)), :compare, 0, :Bool),)
    @eval begin
        function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
            return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
                comparison_direction=$hlocomp), 1))
        end

        function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
            rhs = promote_to(lhs, rhs)
            return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=$hlocomp), 1))
        end

        function elem_apply(::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
            lhs = promote_to(rhs, lhs)
            return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=$hlocomp), 1))
        end
    end
end

How do I pass the comparison direction enum? https://openxla.org/stablehlo/spec#compare

wsmoses commented 1 month ago


wsmoses commented 1 month ago

@avik-pal comparisons added here: 0f7a912ca9cfd2ce1a96491052a16eab899cc9a7

avik-pal commented 1 month ago

It seems that compiling the gradient hits:

julia> compiled_gradient = Reactant.compile(
           gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst))
ERROR: MethodError: no method matching (Reactant.TracedRArray{Float32, Shape, 2} where Shape)(::Tuple{}, ::Reactant.MLIR.IR.Value)
  [1] make_zero(::Type{…}, seen::IdDict{…}, prev::Reactant.TracedRArray{…}, ::Val{…})
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:149
  [2] #42
    @ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251 [inlined]
  [3] ntuple
    @ ./ntuple.jl:19 [inlined]
  [4] make_zero
    @ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249 [inlined]
  [5] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
  [6] (::Enzyme.Compiler.var"#42#43"{Tuple{…}, false, IdDict{…}, Tuple{…}})(i::Int64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251
  [7] ntuple(f::Enzyme.Compiler.var"#42#43"{Tuple{…}, false, IdDict{…}, Tuple{…}}, n::Int64)
    @ Base ./ntuple.jl:19
  [8] make_zero(::Type{…}, seen::IdDict{…}, prev::Tuple{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249
  [9] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
 [10] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/Z0CgU/src/EnzymeCore.jl:237 [inlined]
 [11] overdub
    @ /mnt/research/ongoing/lux/Reactant.jl/src/overloads.jl:358 [inlined]
 [12] gradient_loss_function(::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}})
    @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
 [13] gradient_loss_function
    @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
 [14] overdub(::Cassette.Context{…}, ::typeof(gradient_loss_function), ::Chain{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
 [15] (::Reactant.var"#5#13"{typeof(gradient_loss_function), Tuple{}, Reactant.MLIR.IR.Block, Vector{…}, Tuple{…}})()
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:53
 [16] block!(f::Reactant.var"#5#13"{…}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Block.jl:198
 [17] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool)
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:46
 [18] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, typeof(gradient_loss_function), Tuple{…}, Int64})()
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:927
 [19] mmodule!(f::Reactant.var"#100#105"{…}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Module.jl:89
 [20] #99
    @ /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:925 [inlined]
 [21] context!(f::Reactant.var"#99#104"{typeof(gradient_loss_function), Tuple{…}, Int64}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Context.jl:68
 [22] compile(f::typeof(gradient_loss_function), args::Tuple{…}; pipeline_options::String, client::Nothing)
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:923
 [23] compile(f::typeof(gradient_loss_function), args::Tuple{…})
    @ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:918
 [24] top-level scope
    @ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:70
Some type information was truncated. Use `show(err)` to see complete types.
wsmoses commented 1 month ago

Oh, we should just provide a hook into `make_zero` for `TracedRArray`.
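The error comes from calling a type whose `Shape` parameter is still unbound (`TracedRArray{Float32, Shape, 2} where Shape`); calling such a `UnionAll` finds no matching constructor method, as the `MethodError` shows. Below is a pure-Julia sketch of that failure and of a hook that instead rebuilds the zero value from the concrete type of the value it receives. `ShapedArray` and `zero_like` are hypothetical stand-ins, not Reactant's or Enzyme's actual types or `make_zero` signature:

```julia
# Hypothetical stand-in for a traced array: the shape lives in the type.
struct ShapedArray{T,Shape,N}
    paths::Tuple
    data::Vector{T}
end

# A make_zero-style hook (hypothetical name) that takes the concrete
# type from the value itself, so Shape is always bound.
zero_like(prev::ShapedArray) = typeof(prev)((), zero(prev.data))

prev = ShapedArray{Float32,(2,),1}((), Float32[1, 2])
z = zero_like(prev)  # a ShapedArray{Float32, (2,), 1} holding zeros

# Calling the partially parameterized type instead fails, mirroring the
# MethodError above: no constructor method matches a UnionAll over Shape.
LooseType = ShapedArray{Float32,S,1} where S
# LooseType((), Float32[1, 2])  # MethodError
```

The design point is that a field type like the one in the parameter `NamedTuple` can be a `UnionAll`, but the runtime value stored in it always has a concrete type, so a hook that dispatches on the value never has to guess `Shape`.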


avik-pal commented 1 month ago
julia> cps
(layer_1 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[-0.98918384 0.190184; 0.046477042 -1.0701349; -0.36382833 0.8563723], Float32[0.0; 0.0; 0.0;;])), layer_2 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[0.8828562 -0.19665341 -0.70401317; -0.67718965 0.056223422 -0.2397092], Float32[0.0; 0.0;;])), layer_3 = NamedTuple())

Shouldn't the Shape here be fixed?

wsmoses commented 1 month ago

It depends on the requirements of the type. If something takes a `Vector{Float64}` as a member variable, by default we replace it with the union over all sizes, since that's semantically equivalent (for example, if you have code that changes the size).

But where the sizes can be kept consistent, it may be nice to fully type the size.

Check out our trace-type function (I forget the exact name).
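The trade-off can be seen in plain Julia with a hypothetical type `Sized{Shape}` that carries its shape as a type parameter (a stand-in loosely modeled on the `Shape` parameter in the `cps` output above, not Reactant's actual type). A field declared as `Sized{S} where S` ("the union over sizes") accepts a value of any shape, but the declared field type is not concrete; pinning the shape gives a fully typed field that rejects other sizes:

```julia
# Hypothetical array type that carries its shape as a type parameter.
struct Sized{Shape}
    data::Vector{Float32}
end

# Convenience constructor: a zero-filled array of the given shape.
Sized(shape::Tuple) = Sized{shape}(zeros(Float32, prod(shape)))

# Loose: the field accepts any shape (union over sizes).
const LooseParams = NamedTuple{(:weight,), Tuple{Sized{S} where S}}
# Tight: the field is pinned to a single shape.
const TightParams = NamedTuple{(:weight,), Tuple{Sized{(3, 2)}}}

loose = LooseParams((Sized((4, 2)),))  # accepted: any shape fits
tight = TightParams((Sized((3, 2)),))  # accepted: shape matches exactly
# TightParams((Sized((4, 2)),))        # MethodError: cannot convert
```

The loose form is what you want when code may change a size; the tight form gives the compiler fully concrete field types when sizes are known to stay fixed.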


avik-pal commented 2 weeks ago

This still needs the reduce-pipeline error to be fixed before it is ready to be merged.

wsmoses commented 2 weeks ago

That fix is here: https://github.com/EnzymeAD/Enzyme-JAX/pull/93, and there will hopefully be a JLL with it later today.