EnzymeAD / Reactant.jl

MIT License
67 stars 7 forks source link

2nd order AD fails #298

Open avik-pal opened 2 days ago

avik-pal commented 2 days ago
using Reactant, Enzyme, Lux, Random, LinearAlgebra

# Device adaptors: `xdev` comes from `reactant_device()` and is applied below to
# move parameters/state/inputs; `cdev` (from `cpu_device()`) is unused in this snippet.
const xdev = reactant_device()
const cdev = cpu_device()

# One dense layer (5 -> 5, GELU). Its parameters and state are set up with the
# default RNG, moved with `xdev`, and wrapped in a StatefulLuxLayer so that the
# model can be called as a plain function `potential(x)`.
model = Dense(5 => 5, gelu);
ps, st = xdev(Lux.setup(Random.default_rng(), model));
potential = StatefulLuxLayer{true}(model, ps, st)

# EnzymeMLIR cannot batch (chunk) forward-mode derivatives yet, so the chunk
# size is pinned to 1 via `chunk=Val(1)`.
"""
    ∇potential(potential, x)

Compute the forward-mode Jacobian of `potential` at `x`, view it as a matrix with
`length(x)` columns, and return its diagonal reshaped to `size(x)` — i.e. the
entries `∂potential(x)[i] / ∂x[i]`.

Assumes `potential(x)` has the same number of elements as `x` (TODO confirm);
otherwise the final reshape to `size(x)` fails.
"""
function ∇potential(potential, x)
    ncols = length(x)
    jac = only(Enzyme.jacobian(Forward, potential, x; chunk=Val(1)))
    jac_matrix = reshape(jac, :, ncols)
    # `diag` indexes element-by-element, which Reactant arrays only allow
    # inside an explicit scalar-indexing region.
    diagonal = @allowscalar diag(jac_matrix)
    return reshape(diagonal, size(x))
end

"""
    ∇²potential(potential, x)

Differentiate `∇potential(potential, ·)` once more with forward-mode Enzyme
(chunk size 1) and return the resulting Jacobian as a matrix with `length(x)`
columns.

NOTE: when this is traced by Reactant with plain `Forward`, the inner Enzyme call
inside `∇potential` currently hits `AssertionError: Base.isconcretetype(typ)`.
Using `set_abi(Forward, ReactantABI)` as the mode was reported to work.
"""
function ∇²potential(potential, x)
    inner = Base.Fix1(∇potential, potential)
    jac = only(Enzyme.jacobian(Forward, inner, x; chunk=Val(1)))
    return reshape(jac, :, length(x))
end

# Random Float32 input of size (5, 3), moved with `xdev`. The first dimension
# presumably matches the layer's 5 input features — TODO confirm.
x_ra = xdev(randn(Float32, 5, 3))

# Trace ∇²potential with Reactant and show the generated HLO.
@code_hlo ∇²potential(potential, x_ra)

A non-minimal example taken from https://github.com/LuxDL/Lux.jl/issues/614

avik-pal commented 2 days ago
Error Msg

```julia ERROR: AssertionError: Base.isconcretetype(typ) Stacktrace: [1] abs_typeof(arg::LLVM.LoadInst, partial::Bool, seenphis::Set{LLVM.PHIInst}) @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/absint.jl:557 [2] abs_typeof(arg::LLVM.LoadInst) @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/absint.jl:283 [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing) @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7066 [4] codegen @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:6146 [inlined] [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [6] _thunk @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [inlined] [7] cached_compilation @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8509 [inlined] [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…}) @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8641 [9] #s2105#19135 @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8778 [inlined] [10] @ Enzyme.Compiler ./none:0 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any}) @ Core ./boot.jl:707 [12] autodiff @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:633 [inlined] [13] autodiff @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined] [14] macro expansion @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2090 [inlined] [15] gradient(::ForwardMode{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…}; chunk::Val{…}, shadows::Tuple{…}) @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1970 [16] #jacobian#133 @ 
/mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2177 [inlined] [17] jacobian @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2176 [inlined] [18] ∇potential(potential::StatefulLuxLayer{…}, x::Reactant.TracedRArray{…}) @ Main /mnt/software/lux/Lux.jl/docs/src/manual/nested_autodiff_reactant.md:17 [19] Fix1 @ ./operators.jl:1127 [inlined] [20] #apply#24 @ /mnt/software/lux/Reactant.jl/src/utils.jl:37 [inlined] [21] apply @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined] [22] (::Tuple{})(none::Base.Fix1{typeof(∇potential), StatefulLuxLayer{…}}, none::Tuple{Reactant.TracedRArray{…}}) @ Base.Experimental ./:0 [23] (::Reactant.var"#32#42"{Bool, Bool, typeof(Reactant.apply), Tuple{…}, Vector{…}, Tuple{…}})() @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:148 [24] block!(f::Reactant.var"#32#42"{…}, blk::Reactant.MLIR.IR.Block) @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201 [25] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool) @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:120 [26] make_mlir_fn @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined] [27] #make_mlir_fn#25 @ /mnt/software/lux/Reactant.jl/src/utils.jl:53 [inlined] [28] make_mlir_fn @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined] [29] overload_autodiff(::ForwardMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…}) @ Reactant /mnt/software/lux/Reactant.jl/src/Interpreter.jl:373 [30] autodiff @ /mnt/software/lux/Reactant.jl/src/Interpreter.jl:660 [inlined] [31] autodiff @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined] [32] macro expansion @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2090 [inlined] [33] gradient(::ForwardMode{…}, ::Base.Fix1{…}, ::Reactant.TracedRArray{…}; chunk::Val{…}, shadows::Tuple{…}) @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1970 
[34] #jacobian#133 @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2177 [inlined] [35] jacobian @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2176 [inlined] [36] ∇²potential @ /mnt/software/lux/Lux.jl/docs/src/manual/nested_autodiff_reactant.md:23 [inlined] [37] (::Tuple{})(none::StatefulLuxLayer{…}, none::Reactant.TracedRArray{…}) @ Base.Experimental ./:0 [38] (::Reactant.var"#32#42"{Bool, Bool, typeof(∇²potential), Tuple{…}, Vector{…}, Tuple{…}})() @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:157 [39] block!(f::Reactant.var"#32#42"{…}, blk::Reactant.MLIR.IR.Block) @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201 [40] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool) @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:120 [41] make_mlir_fn @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined] [42] #10 @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:286 [inlined] [43] block!(f::Reactant.Compiler.var"#10#15"{typeof(∇²potential), Tuple{…}}, blk::Reactant.MLIR.IR.Block) @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201 [44] #9 @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined] [45] mmodule!(f::Reactant.Compiler.var"#9#14"{…}, blk::Reactant.MLIR.IR.Module) @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92 [46] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool) @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:282 [47] compile_mlir! 
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:281 [inlined] [48] #6 @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:276 [inlined] [49] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{…}, typeof(∇²potential), Tuple{…}}, ctx::Reactant.MLIR.IR.Context) @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76 [50] compile_mlir(f::Function, args::Tuple{StatefulLuxLayer{…}, ConcreteRArray{…}}; kwargs::@Kwargs{optimize::Bool}) @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:274 Some type information was truncated. Use `show(err)` to see complete types. ```

wsmoses commented 1 day ago

Just for fun, what happens if you use `set_abi(Forward, ReactantABI)` instead?

avik-pal commented 20 hours ago

That did work!

wsmoses commented 19 hours ago

Yeah, so this again stems from the general problem that anything based on abstract-interpreter tricks fails to go through type-unstable code.

The fix we applied earlier was to have our abstract interpreter (absint) replace `Forward` with `set_abi(Forward, ReactantABI)`. This makes things much nicer: the replacement also happens at the call site of `autodiff`/`jacobian`/etc., so type-unstable intermediates no longer cause problems. It also means we can call it natively, as above. Unfortunately, this only applies in the top-level absint.

The solution here is probably to have the absint replace type-unstable calls with `my_call(...)`, which itself runs the callee again inside an absint.