cc @sethaxen re (2) since they were playing around with enzyme rules for BLAS. Might you be interested in helping add the cudaMemcpy/cublas/NNlib/etc stuff?
You reminded me that I forgot to link https://github.com/FluxML/NNlib.jl/issues/503 :)
> cc @sethaxen re (2) since they were playing around with enzyme rules for BLAS. Might you be interested in helping add the cudaMemcpy/cublas/NNlib/etc stuff?
Interested, yes, but am still lobbying to take it on as a work project. The BLAS/Enzyme work is more aligned with the other things I work on and is probably all I can focus on right now.
Either way, figuring out where the rules should go and at what level they are needed is useful for whoever takes this on.
I think the right way to do it is in two steps.
@ruletransfer conv(x)
I honestly don't think this would take all that long. The reason not to carry over the chain rules wholesale is that many of them aren't necessary for Enzyme, but doing it like this would allow one to just pick out the necessary rules and get NNlib converted rather quickly. Then of course there can always be improvements to use less memory and such, but I'd say we do this conversion first, so that Enzyme is at least strictly better than Zygote for DL and the ecosystem can move, and then worry about grabbing the last bit of performance out of each rule.
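To make that concrete, a purely illustrative sketch of how such a rule-transfer macro might be applied to NNlib's convolution entry points (note: @ruletransfer does not exist yet, so this does not run; the NNlib functions shown are real, but whether a transfer macro could handle their signatures directly is an assumption):

# Hypothetical usage of the proposed macro: wrap NNlib's existing
# ChainRulesCore.rrule definitions as EnzymeRules for just the functions
# Enzyme actually needs, instead of importing every chain rule.
@ruletransfer NNlib.conv(x, w, cdims)           # forward convolution
@ruletransfer NNlib.∇conv_data(dy, w, cdims)    # pullback w.r.t. the input
@ruletransfer NNlib.∇conv_filter(x, dy, cdims)  # pullback w.r.t. the weights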
@ToucheSir so in the example you gave above, Flux has some runtime activity return mismatches (now featuring the better backtraces which have landed on main).
julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
ERROR: Enzyme execution failed.
Mismatched activity for: ret {} addrspace(10)* %186, !dbg !418 const val: %186 = call fastcc noalias nonnull {} addrspace(10)* @julia_Array_2108({} addrspace(10)* noalias nocapture nofree nonnull readnone align 64 undef, [2 x i64] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %185) #403, !dbg !497
Type tree: {}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] Dense
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:174
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:2790
[2] Dense
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:174
[3] macro expansion
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53 [inlined]
[4] _applychain
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53 [inlined]
[5] Chain
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:51 [inlined]
[6] loss
@ ./REPL[5]:1 [inlined]
[7] loss
@ ./REPL[5]:0 [inlined]
[8] diffejulia_loss_1881_inner_8wrap
@ ./REPL[5]:0
[9] macro expansion
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9369 [inlined]
[10] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Float32)
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9047
[11] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9010
[12] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
@ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:213
[13] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
@ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:222
[14] top-level scope
@ REPL[9]:1
[15] top-level scope
@ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
Some type information was truncated. Use `show(err)` to see complete types.
This was then fixed by marking the Dense function on basic.jl:170 as @inline, roughly as sketched below.
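For reference, the change has roughly this shape (a sketch only; the real method in Flux's basic.jl also does eltype and size handling):

@inline function (a::Dense)(x::AbstractVecOrMat)
    # The @inline lets Enzyme see through the call, avoiding the mismatched-activity error above.
    return a.σ.(a.weight * x .+ a.bias)
end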
Of course this then hit a second one below:
julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
ERROR: Enzyme execution failed.
Mismatched activity for: ret {} addrspace(10)* %106, !dbg !398 const val: %106 = call fastcc noalias nonnull {} addrspace(10)* @julia_Array_2004({} addrspace(10)* noalias nocapture nofree nonnull readnone align 64 undef, [2 x i64] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %105) #399, !dbg !459
Type tree: {}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] -
@ ./abstractarraymath.jl:218
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:2790
[2] -
@ ./abstractarraymath.jl:218
[3] #_norm_layer_forward#302
@ ~/.julia/packages/Flux/n3cOc/src/layers/normalise.jl:247
[4] _norm_layer_forward
@ ~/.julia/packages/Flux/n3cOc/src/layers/normalise.jl:225 [inlined]
[5] BatchNorm
@ ~/.julia/packages/Flux/n3cOc/src/layers/normalise.jl:351
[6] macro expansion
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53 [inlined]
[7] _applychain
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53
[8] Chain
@ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:51 [inlined]
[9] loss
@ ./REPL[4]:1 [inlined]
[10] loss
@ ./REPL[4]:0 [inlined]
[11] diffejulia_loss_1755_inner_8wrap
@ ./REPL[4]:0
[12] macro expansion
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9369 [inlined]
[13] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Float32)
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9047
[14] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9010
[15] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
@ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:213
[16] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
@ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:222
[17] top-level scope
@ REPL[8]:1
[18] top-level scope
@ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
Some type information was truncated. Use `show(err)` to see complete types.
How does Flux feel about adding some relevant @inlines?
Okay, now after fixing KA.jl (https://github.com/JuliaGPU/KernelAbstractions.jl/pull/412), @ToucheSir your snippet runs successfully (I didn't check values):
julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %77 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %40) #249, !dbg !470"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %120 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %49) #249, !dbg !591"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %124 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %80) #249, !dbg !472"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %130 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #249, !dbg !518"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %224 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #250, !dbg !695"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %230 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %86) #250, !dbg !720"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %397 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #251, !dbg !1047"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %403 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #251, !dbg !1076"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %77 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %40) #255, !dbg !470"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %120 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %49) #255, !dbg !591"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %124 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %80) #255, !dbg !472"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %130 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #255, !dbg !518"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %224 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #255, !dbg !695"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %230 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %86) #255, !dbg !720"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %397 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #255, !dbg !1047"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %403 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #255, !dbg !1076"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing, nothing),)
julia> println(dmodel)
Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))
With the fast BLAS mode enabled, the perf below for your micro-benchmark code is as follows (though note the numbers didn't seem to match Zygote, so @ToucheSir it would be great if you have some cycles to identify what code causes the divergence).
julia> @btime Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
9.090 μs (77 allocations: 5.42 KiB)
((nothing, nothing),)
julia> Zygote.gradient(model->loss(model, x), model)
((layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0], σ = nothing), (λ = nothing, β = Float32[0.38328338, -0.49341357, -0.9959768, -0.5516981], γ = Float32[0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing), (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2), σ = nothing)),),)
julia> @btime Zygote.gradient(model->loss(model, x), model)
144.153 μs (652 allocations: 39.97 KiB)
((layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0], σ = nothing), (λ = nothing, β = Float32[0.38328338, -0.49341357, -0.9959768, -0.5516981], γ = Float32[0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing), (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2), σ = nothing)),),)
I don't have access to a machine to test this right now, but the difference is likely due to BatchNorm being in auto train mode in Zygote and not in Enzyme. Running Flux.trainmode!(model) before the autodiff call should help confirm that.
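That is, something like this (reusing the names from the snippet above):

Flux.trainmode!(model)   # force train-mode behaviour for norm layers, matching what Zygote's auto-trainmode hook does
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
Flux.testmode!(model)    # optionally restore inference-mode behaviour afterwards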
Given the continued progress on the Enzyme side of things, @ToucheSir I think the next step here would be to isolate what part of the Flux code causes the Enzyme/Zygote answers to differ, so it can be fixed.
I just realized the example above doesn't actually show the discrepancy; what code are you running which does? As I mentioned earlier, the likeliest fix is manually running Flux.trainmode!(model) before using Enzyme, since all ADs currently need to opt in to the automatic train-mode-when-differentiating mechanism Flux uses for norm layers.
Oh maybe I did that incorrectly then.
Nevertheless, it would be interesting to start doing some Flux + Enzyme correctness tests; then we can start diving into the performance (I already see a lot of optimizations we should be applying but aren't, so I'm hopeful we can iterate on that).
I tried some Flux models with Enzyme to see whether the gradients match Zygote.
The following, based off the above, is with Julia 1.10.0, Enzyme main (283a1c5) and Flux 0.14.10.
using Enzyme, Flux, Random, Test
Enzyme.API.runtimeActivity!(true)
loss(model, x) = sum(model(x))
function test_model(model, x, mi)
    println(model)
    l = loss(model, x)
    Flux.reset!(model)
    grads_flux = Flux.gradient(m -> loss(m, x), model)[1]
    grads_enzyme = Flux.fmap(model) do x
        x isa Array ? zero(x) : x
    end
    Flux.reset!(model)
    Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Const(x))
    @testset "Model $mi" begin
        Flux.reset!(model)
        @test loss(model, x) == l # Check loss doesn't change with multiple runs
        for i in eachindex(grads_flux.layers)
            layer_flux = grads_flux.layers[i]
            layer_enzyme = grads_enzyme.layers[i]
            for field in (:weight, :bias, :scale)
                if hasfield(typeof(layer_flux), field)
                    @test isapprox(getfield(layer_flux, field), getfield(layer_enzyme, field))
                end
            end
            if hasfield(typeof(layer_flux), :cell)
                for field in (:Wh, :Wi, :b)
                    @test isapprox(getfield(layer_flux.cell, field), getfield(layer_enzyme.cell, field))
                end
            end
        end
    end
end
The good news is that the following all work. I steered clear of normalisation or anything that changes with train/test for now.
models_xs = [
    [
        Chain(Dense(2 => 4), Dense(4 => 2)),
        randn(Float32, 2, 1),
    ],
    [
        f64(Chain(Dense(2 => 4), Dense(4 => 2))),
        randn(Float64, 2, 1),
    ],
    [
        Chain(Dense(2 => 4, relu), Dense(4 => 2)),
        randn(Float32, 2, 1),
    ],
    [
        Chain(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2)),
        randn(Float32, 2),
    ],
    [
        Chain(Conv((3, 3), 3 => 7, relu), Conv((3, 3), 7 => 7, relu)),
        rand(Float32, 10, 10, 3, 50),
    ],
    [
        Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())),
        rand(Float32, 100, 100, 3, 50),
    ],
    [
        Maxout(() -> Dense(5 => 7, tanh), 3),
        randn(Float32, 5, 1),
    ],
    [
        Chain(RNN(3 => 5), RNN(5 => 3)),
        randn(Float32, 3, 10),
    ],
    [
        Chain(LSTM(3 => 5), LSTM(5 => 3)),
        randn(Float32, 3, 10),
    ],
]
for (mi, (model, x)) in enumerate(models_xs)
test_model(model, x, mi)
end
The following models error when run with Enzyme:
models_xs = [
    [
        SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false)),
        randn(Float32, 5, 1),
    ],
    [
        Chain(ConvTranspose((3, 3), 3 => 7, stride=2)),
        rand(Float32, 10, 10, 3, 50),
    ],
    [
        Chain(GRU(3 => 5)),
        randn(Float32, 3, 10),
    ],
    [
        fmap(cu, Chain(Dense(2 => 4), Dense(4 => 2))), # Requires using CUDA
        cu(randn(Float32, 2, 1)),
    ],
]
And this one gives slightly different gradients:
models_xs = [
    [
        Chain(Conv((5, 5), 3 => 7), MeanPool((5,5), pad=SamePad())),
        rand(Float32, 100, 100, 3, 50),
    ],
]
If it is helpful I can open individual issues and add the working cases to the Enzyme tests.
Yes, individual issues (with corresponding error traces) would be highly helpful! I'd also separately be interested in which ones fail if runtime activity is off.
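Concretely, that would mean flipping the flag from the test setup above and rerunning the loop, e.g. (best in a fresh session, since the flag affects how code is compiled):

Enzyme.API.runtimeActivity!(false)
for (mi, (model, x)) in enumerate(models_xs)
    test_model(model, x, mi)
end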
@ToucheSir the SkipConnection one seems to be a pure Flux issue potentially?
julia> models_xs = [
[
SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false)),
randn(Float32, 5, 1),
],]
1-element Vector{Vector{Any}}:
[SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Bilinear((9, 5) => 3; bias=false)), Float32[0.56165487; 1.2769437; … ; 0.798284; 0.12582794;;]]
julia> for (mi, (model, x)) in enumerate(models_xs)
test_model(model, x, mi)
end
SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Bilinear((9, 5) => 3; bias=false))
Model 1: Error During Test at REPL[11]:14
Got exception outside of a @test
MethodError: no method matching getindex(::Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, ::Symbol)
Closest candidates are:
getindex(::Tuple, ::Colon)
@ Base tuple.jl:37
getindex(::Tuple, ::Int64)
@ Base tuple.jl:31
getindex(::Tuple, ::CartesianIndex{1})
@ Base multidimensional.jl:882
...
Stacktrace:
[1] #getindex#173
@ Flux ~/.julia/packages/MacroTools/Cf2ok/src/examples/forward.jl:18 [inlined]
[2] getindex(x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, args::Symbol)
@ Flux ~/.julia/packages/MacroTools/Cf2ok/src/examples/forward.jl:17
[3] macro expansion
@ ./REPL[11]:20 [inlined]
[4] macro expansion
@ ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[5] test_model(model::SkipConnection{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Flux.Bilinear{typeof(identity), Array{Float32, 3}, Bool}}, x::Matrix{Float32}, mi::Int64)
@ Main ./REPL[11]:15
[6] top-level scope
@ ./REPL[20]:2
[7] eval
@ Core ./boot.jl:385 [inlined]
[8] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
@ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
[9] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
@ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
[10] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
@ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
[11] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
@ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
[12] run_repl(repl::REPL.AbstractREPL, consumer::Any)
@ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
[13] (::Base.var"#1013#1015"{Bool, Bool, Bool})(REPL::Module)
@ Base ./client.jl:432
[14] #invokelatest#2
@ Base ./essentials.jl:887 [inlined]
[15] invokelatest
@ Base ./essentials.jl:884 [inlined]
[16] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
@ Base ./client.jl:416
[17] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:333
[18] _start()
@ Base ./client.jl:552
Test Summary: | Pass Error Total Time
Model 1 | 1 1 2 0.2s
ERROR: Some tests did not pass: 1 passed, 0 failed, 1 errored, 0 broken.
The ConvTranspose issue has been posted to NNlib.jl: https://github.com/FluxML/NNlib.jl/issues/565. It requires a rule implementation/extension for conv.
> @ToucheSir the SkipConnection one seems to be a pure Flux issue potentially?
Doesn't look like it:
julia> model = SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false))
SkipConnection(
Chain(
Dense(5 => 20, tanh), # 120 parameters
Dense(20 => 9, tanh), # 189 parameters
),
Bilinear((9, 5) => 3; bias=false), # 135 parameters
) # Total: 5 arrays, 444 parameters, 2.094 KiB.
julia> x = randn(Float32, 5, 1)
5×1 Matrix{Float32}:
-0.6271728
-0.5722281
-1.7240835
0.43075645
0.044463925
julia> model(x)
3×1 Matrix{Float32}:
0.39723554
0.15903589
-0.38918468
I believe the code here is making invalid assumptions about the structure of models and their gradients (i.e. that the model is always a single-layer Chain of Denses, which is not true for SkipConnection), so the error is caused by the test code instead of Flux.
for i in eachindex(grads_flux.layers)
    layer_flux = grads_flux.layers[i]
    layer_enzyme = grads_enzyme.layers[i]
    for field in (:weight, :bias, :scale)
        if hasfield(typeof(layer_flux), field)
            @test isapprox(getfield(layer_flux, field), getfield(layer_enzyme, field))
        end
    end
    if hasfield(typeof(layer_flux), :cell)
        for field in (:Wh, :Wi, :b)
            @test isapprox(getfield(layer_flux.cell, field), getfield(layer_enzyme.cell, field))
        end
    end
end
I suppose it's that little bit of test code erring then, which means maybe Enzyme works on it :) [or maybe not].
For me the test code runs the Flux version fine but errors on the Enzyme version. The Enzyme error is attached as error.txt. I notice your stacktrace contains julia-1.10.0-rc2, Billy; I don't know if that affects things.
The gradient checking code makes assumptions and might fail to check all gradients, but ideally it shouldn't error itself for anything with indexable .layers, since it uses hasfield.
@jgreener64 can you try on latest main (i.e. not the release)? There have been fixes for precisely that error that have landed on main since.
Still errors for me on main (877e1d90).
Can you confirm your jll version via st? (it should be version 0.0.100)
Yes, it is with jll version 0.0.100. The error message looked similar but may have changed a bit; I can't check right now.
Given that many of the core points here have been resolved, I'm going to propose closing this issue unless there are any objections.
1) Type-stable deep learning libraries. All the Enzyme tests on Flux now pass, due both to Enzyme gaining more support for type-unstable code and to Flux now being mostly type stable. I'm not sure what the exact status here is for Lux (e.g. https://github.com/LuxDL/Lux.jl/issues/605). But given that at least one DL library is happy, I'll call this good for now (and presume others can follow suit shortly).
2) cuBLAS/cudaMemcpy/etc. rules. NNlib now has EnzymeRules for the relevant functions, and we have also added cuBLAS rules to our existing BLAS support. There is still a need for cudaMemcpy and some Julia-side JIT CUDA fixups for Enzyme, but we need tests to see whether the existing support is sufficient for DL, and the broader CUDA runtime function support can be split into its own issue for whatever fails.
3) Scheduling. Like I said at the top, this is key to good performance, and incidentally distinct from AD. I have started playing with a repo, Reactant.jl (https://github.com/EnzymeAD/Reactant.jl), which aims to resolve this. It can take a Julia function, compile it into MLIR, and run fancy optimizations on top of it, including using EnzymeMLIR for AD, and create relevant executables for CPU/GPU/TPU via XLA. It is very much in progress, but nevertheless a problem outside Enzyme now.
> I'm not sure what the exact status here is for Lux
The tests are in, and 1 small patch was needed. The release should be available in an hour or so.
As mentioned, closing this now for the reasons above.
Copying over and summarizing some discussion from Slack:
@wsmoses:
A type stable example for 1):
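(The original snippet isn't preserved here; a minimal example of the kind being referred to, reusing the loss and zeroed-shadow pattern from the tests above, would look roughly like this.)

using Enzyme, Flux

loss(m, x) = sum(m(x))

model = Chain(Dense(2 => 4, relu), Dense(4 => 2))           # small, type-stable model
dmodel = Flux.fmap(p -> p isa Array ? zero(p) : p, model)   # zeroed shadow; Enzyme accumulates gradients into it
x = randn(Float32, 2, 1)

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
dmodel.layers[1].weight   # gradient w.r.t. the first layer's weights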
As mentioned on Slack, I'd be happy to provide more if people have ideas of model types they'd like to see.
Riffing on 2): @touchesir:
@masonprotter: