Also you only need to trace through if a variable contains data like an array
Let’s do it.
Also btw this will test both CPU and GPU
On Tue, May 14, 2024 at 12:09 PM, Avik Pal commented on this pull request, in test/nn_lux.jl (https://github.com/EnzymeAD/Reactant.jl/pull/5#discussion_r1600525134):
+comp = f(cmodel, cnoisy, cps, cst)
+@. comp[3]
+@. f(cmodel, cnoisy)
+
+# To train the model, we use batches of 64 samples, and one-hot encoding:
+target = Flux.onehotbatch(truth, [true, false])  # 2×1000 OneHotMatrix
+loader = Flux.DataLoader((noisy, target); batchsize=64, shuffle=true);
+# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
+
+optim = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.
+
+# Training loop, using the whole data set 1000 times:
+losses = []
+for epoch in 1:1_000
- for (x, y) in loader
- loss, grads = Flux.withgradient(model) do m
You want to pull in Lux.Experimental.apply_gradients! and compile it together?
Couple of problems:

batchnorm doesn't seem to work because it tries to trace through mean, and the arrays don't have getindex defined:

julia> comp = f(cmodel, cnoisy, cps, cst)
ERROR: UndefVarError: `layer_2` not defined
Stacktrace:
[1] (::Reactant.var"#109#110")(arg_1::Chain{…}, arg_2::Reactant.ConcreteRArray{…}, arg_3::@NamedTuple{…}, arg_4::@NamedTuple{…})
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:704
[2] top-level scope
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:29
Some type information was truncated. Use `show(err)` to see complete types.
create_result: should we add a case for NamedTuple similar to the other cases? The returned state is always a NamedTuple.

So it is possible to add getindex/etc. here, but I'm intentionally preventing it for now so we can ensure we trace fully vectorized code.
And yeah, we should just add an overload for mean, want to give it a shot?
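If it helps, a minimal sketch of what such an overload might look like, assuming sum and broadcasting on TracedRArray are already traceable; it just lowers mean onto those so nothing ever scalar-indexes the array (not the actual implementation):

using Statistics

# Sketch: route mean through sum so tracing never needs getindex on the array.
function Statistics.mean(x::Reactant.TracedRArray; dims=:)
    if dims === Colon()
        return sum(x) / length(x)
    else
        return sum(x; dims=dims) ./ prod(size(x, d) for d in dims)
    end
end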
This doesn't seem right:
julia> @code_lowered f(cmodel, cnoisy, cps, cst)
CodeInfo(
1 ─ %1 = Reactant.XLA.synced_buffer
│ %2 = Base.getfield
│ %3 = Base.getfield
│ %4 = (%3)(arg_3, Reactant.layer_2)
│ %5 = (%2)(%4, Reactant.weight)
│ %6 = Base.getproperty(%5, :data)
│ sbuf_1 = (%1)(%6)
│ %8 = Reactant.XLA.synced_buffer
│ %9 = Base.getfield
│ %10 = Base.getfield
│ %11 = (%10)(arg_3, Reactant.layer_1)
│ %12 = (%9)(%11, Reactant.weight)
│ %13 = Base.getproperty(%12, :data)
│ sbuf_2 = (%8)(%13)
│ %15 = Reactant.XLA.synced_buffer
│ %16 = Base.getfield
│ %17 = Base.getfield
│ %18 = (%17)(arg_3, Reactant.layer_1)
│ %19 = (%16)(%18, Reactant.bias)
│ %20 = Base.getproperty(%19, :data)
│ sbuf_3 = (%15)(%20)
│ %22 = Reactant.XLA.synced_buffer
│ %23 = Base.getfield
│ %24 = Base.getfield
│ %25 = (%24)(arg_3, Reactant.layer_2)
│ %26 = (%23)(%25, Reactant.bias)
│ %27 = Base.getproperty(%26, :data)
│ sbuf_4 = (%22)(%27)
│ %29 = Reactant.XLA.synced_buffer
│ %30 = Base.getproperty(arg_2, :data)
│ sbuf_5 = (%29)(%30)
│ %32 = $(Expr(:gc_preserve_begin, :(sbuf_1), :(sbuf_2), :(sbuf_3), :(sbuf_4), :(sbuf_5)))
│ %33 = Reactant.XLA.ExecutableCall
│ %34 = Base.getproperty(sbuf_1, :buffer)
│ %35 = Base.getproperty(sbuf_2, :buffer)
│ %36 = Base.getproperty(sbuf_3, :buffer)
│ %37 = Base.getproperty(sbuf_4, :buffer)
│ %38 = Base.getproperty(sbuf_5, :buffer)
│ %39 = Core.tuple(%34, %35, %36, %37, %38)
│ %40 = Reactant.Val(1)
│ %41 = (%33)(Reactant.XLA.LoadedExecutable(Ptr{Nothing} @0x000000001114fd40), %39, (0x01, 0x01, 0x01, 0x01, 0x01), %40)
│ linearized_results = %41
│ $(Expr(:gc_preserve_end, :(%32)))
│ concrete_res_1 = Base.getindex(linearized_results, 1)
│ result = (Reactant.ConcreteRArray{Float32, (2, 1000), 2})(concrete_res_1)
└── return result
)
It shouldn't be Reactant.layer_2.
So it is possible to add getindex/etc. here, but I'm intentionally preventing it for now so we can ensure we trace fully vectorized code.
Makes sense, I added ArrayInterface to allow easy checking for that in downstream code
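The exact trait that ends up defined isn't shown here, so this is just a sketch of the kind of downstream check it enables, using ArrayInterface.fast_scalar_indexing (my assumption for the relevant trait) to pick a vectorized path instead of scalar indexing:

using ArrayInterface

# Generic downstream pattern: ask the trait instead of trying getindex and failing.
function sum_of_squares(x::AbstractArray)
    if ArrayInterface.fast_scalar_indexing(typeof(x))
        s = zero(eltype(x))
        for i in eachindex(x)
            s += x[i]^2        # scalar-indexing loop is fine for plain Arrays
        end
        return s
    else
        return sum(abs2, x)    # stay fully vectorized for arrays that forbid getindex
    end
end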
Ah we should probably add an escape in the macro
wait that's odd though it should be a symbol there being looked up?
It was directly getting interpolated, needed a Meta.quot
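A minimal, generic illustration of that pitfall (not the actual Reactant macro code): interpolating a Symbol straight into a quoted expression splices it as an identifier, so the generated code does a global lookup like Reactant.layer_2; wrapping it in Meta.quot keeps it as a literal Symbol argument.

field = :layer_2

# Splices the bare identifier `layer_2`, which later resolves as a global lookup:
bad  = :(Base.getfield(arg_3, $field))               # :(Base.getfield(arg_3, layer_2))

# Keeps the Symbol as a literal second argument to getfield:
good = :(Base.getfield(arg_3, $(Meta.quot(field))))  # :(Base.getfield(arg_3, :layer_2))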
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function crossentropy(ŷ, y)
    return .-sum(xlogy.(y, ŷ))
end

function loss_function(model, x, y, ps, st)
    y_hat, _ = model(x, ps, st)
    return crossentropy(y_hat, y)
end

compiled_loss_function = Reactant.compile(
    loss_function, (cmodel, cnoisy, ctarget, cps, cst))
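Once that compiles (it doesn't yet, see below), the compiled function is called with the same concrete arguments:

compiled_loss_function(cmodel, cnoisy, ctarget, cps, cst)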
~elem_apply is not defined for xlogy; how do I trace into the body of xlogy? There isn't a direct mapping for that in https://openxla.org/stablehlo/spec#log~

The eltypes don't match:
MethodError: no method matching elem_apply(::typeof(xlogy), ::Reactant.TracedRArray{Bool, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2})
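One way to sidestep the Bool/Float32 mismatch, sketched here as an option rather than what the test actually does, is to promote the one-hot targets to Float32 before wrapping them for Reactant:

target  = Flux.onehotbatch(truth, [true, false])           # Bool-valued 2×1000 OneHotMatrix
ctarget = Reactant.ConcreteRArray(Array{Float32}(target))  # Float32 copy, so both broadcast arguments share an eltype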
My failed attempt at defining the comparisons
for (jlop, hloop, hlocomp, RT) in ((:(Base.:(==)), :compare, 0, :ElType),)
    @eval begin
        function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N},
                            rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
            return TracedRArray{$RT,Shape,N}((),
                MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
                                                              comparison_direction=$hlocomp), 1))
        end
        function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N},
                            rhs) where {ElType,Shape,N}
            rhs = promote_to(lhs, rhs)
            return TracedRArray{$RT,Shape,N}((),
                MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
                                                              comparison_direction=$hlocomp), 1))
        end
        function elem_apply(::typeof($jlop), lhs,
                            rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
            lhs = promote_to(rhs, lhs)
            return TracedRArray{$RT,Shape,N}((),
                MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
                                                              comparison_direction=$hlocomp), 1))
        end
    end
end
How do I pass the comparison direction enum? https://openxla.org/stablehlo/spec#compare
@avik-pal comparisons added here: 0f7a912ca9cfd2ce1a96491052a16eab899cc9a7
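For anyone following along, a rough sketch of the shape of the fix; the real code is in the commit above, and the attribute helper used here (stablehloComparisonDirectionAttrGet through the MLIR C-API bindings) is my assumption for how the direction string becomes an attribute:

for (jlop, hlocomp) in ((:(Base.:(==)), "EQ"), (:(Base.:(!=)), "NE"), (:(Base.:(<)), "LT"))
    @eval function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N},
                              rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
        # the direction is an MLIR attribute built from a string, not a plain integer
        dir = MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)
        # comparisons always produce Bool arrays, regardless of the input eltype
        return TracedRArray{Bool,Shape,N}((),
            MLIR.IR.result(MLIR.Dialects.stablehlo.compare(lhs.mlir_data, rhs.mlir_data;
                                                           comparison_direction=dir), 1))
    end
end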
Seems like compiling the gradient is hitting this:
julia> compiled_gradient = Reactant.compile(
gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst))
ERROR: MethodError: no method matching (Reactant.TracedRArray{Float32, Shape, 2} where Shape)(::Tuple{}, ::Reactant.MLIR.IR.Value)
Stacktrace:
[1] make_zero(::Type{…}, seen::IdDict{…}, prev::Reactant.TracedRArray{…}, ::Val{…})
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:149
[2] #42
@ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251 [inlined]
[3] ntuple
@ ./ntuple.jl:19 [inlined]
[4] make_zero
@ ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249 [inlined]
[5] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
[6] (::Enzyme.Compiler.var"#42#43"{Tuple{…}, false, IdDict{…}, Tuple{…}})(i::Int64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1251
[7] ntuple(f::Enzyme.Compiler.var"#42#43"{Tuple{…}, false, IdDict{…}, Tuple{…}}, n::Int64)
@ Base ./ntuple.jl:19
[8] make_zero(::Type{…}, seen::IdDict{…}, prev::Tuple{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1249
[9] make_zero(::Type{…}, seen::IdDict{…}, prev::@NamedTuple{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/NVk8T/src/compiler.jl:1256
[10] make_zero (repeats 2 times)
@ ~/.julia/packages/EnzymeCore/Z0CgU/src/EnzymeCore.jl:237 [inlined]
[11] overdub
@ /mnt/research/ongoing/lux/Reactant.jl/src/overloads.jl:358 [inlined]
[12] gradient_loss_function(::Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, layer_3::WrappedFunction{typeof(softmax)}}}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::Reactant.TracedRArray{Float32, (2, 1000), 2}, ::@NamedTuple{layer_1::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_2::@NamedTuple{weight::Reactant.TracedRArray{Float32, Shape, 2} where Shape, bias::Reactant.TracedRArray{Float32, Shape, 2} where Shape}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}})
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
[13] gradient_loss_function
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:61 [inlined]
[14] overdub(::Cassette.Context{…}, ::typeof(gradient_loss_function), ::Chain{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Cassette ~/.julia/packages/Cassette/4UsSX/src/overdub.jl:0
[15] (::Reactant.var"#5#13"{typeof(gradient_loss_function), Tuple{}, Reactant.MLIR.IR.Block, Vector{…}, Tuple{…}})()
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:53
[16] block!(f::Reactant.var"#5#13"{…}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Block.jl:198
[17] make_mlir_fn(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool)
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/utils.jl:46
[18] (::Reactant.var"#100#105"{Reactant.MLIR.IR.Module, typeof(gradient_loss_function), Tuple{…}, Int64})()
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:927
[19] mmodule!(f::Reactant.var"#100#105"{…}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Module.jl:89
[20] #99
@ /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:925 [inlined]
[21] context!(f::Reactant.var"#99#104"{typeof(gradient_loss_function), Tuple{…}, Int64}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/ongoing/lux/Reactant.jl/src/mlir/IR/Context.jl:68
[22] compile(f::typeof(gradient_loss_function), args::Tuple{…}; pipeline_options::String, client::Nothing)
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:923
[23] compile(f::typeof(gradient_loss_function), args::Tuple{…})
@ Reactant /mnt/research/ongoing/lux/Reactant.jl/src/Reactant.jl:918
[24] top-level scope
@ /mnt/research/ongoing/lux/Reactant.jl/test/nn_lux.jl:70
Some type information was truncated. Use `show(err)` to see complete types.
Oh, we should just provide a hook into make_zero for TracedRArray.
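A sketch of what such a hook could look like, using the four-argument make_zero signature visible in the stack trace above; which generic function it should extend (EnzymeCore vs Enzyme.Compiler) and how the traced zero gets built are assumptions here, not the actual fix:

function Enzyme.Compiler.make_zero(::Type{RT}, seen::IdDict, prev::RT,
                                   ::Val{copy_if_inactive}) where {copy_if_inactive,RT<:Reactant.TracedRArray}
    haskey(seen, prev) && return seen[prev]   # keep aliasing consistent with Enzyme's recursion
    res = Reactant.elem_apply(zero, prev)     # traced zeros with prev's concrete eltype and Shape
    seen[prev] = res
    return res
end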
julia> cps
(layer_1 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[-0.98918384 0.190184; 0.046477042 -1.0701349; -0.36382833 0.8563723], Float32[0.0; 0.0; 0.0;;])), layer_2 = @NamedTuple{weight::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape, bias::Reactant.ConcreteRArray{Float32, Shape, 2} where Shape}((Float32[0.8828562 -0.19665341 -0.70401317; -0.67718965 0.056223422 -0.2397092], Float32[0.0; 0.0;;])), layer_3 = NamedTuple())
Shouldn't the Shape here be fixed?
Depends on the requirements of the type. If something takes a Vector{Float64} as a member variable, we default to replacing it with the union over sizes, since that's semantically equivalent (if, say, you have code that changes the size).
But if it's possible to leave it consistent, it may be nice to fully type the size.
Check out our trace type function (I forget the exact name).
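Concretely, for the 3×2 weight of layer_1 above, the two choices look like this (just illustrating the types; the trace-type function decides which one gets used):

# Shape baked into the type, fully concrete:
Reactant.ConcreteRArray{Float32, (3, 2), 2}

# Shape left free, the union over sizes shown in the cps printout, which stays
# valid even if some code resizes that field:
Reactant.ConcreteRArray{Float32, Shape, 2} where Shape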
This still needs the reduce pipeline error to be fixed before it is ready to be merged
That fix is here: https://github.com/EnzymeAD/Enzyme-JAX/pull/93 and will have a jll later today hopefully
Needs https://github.com/EnzymeAD/Reactant.jl/issues/4