EnzymeAD / Reactant.jl

MIT License
60 stars 4 forks source link

`Enzyme.ReverseSplitWithPrimal` is not supported #215

Open avik-pal opened 2 hours ago

avik-pal commented 2 hours ago
using Enzyme, Reactant

"""
    f(x)

Scalar objective used by the reproducer: square `x` elementwise (`x .* x`) and
return the sum of `abs2` over the result. For real inputs this is `sum(x .^ 4)`.
"""
function f(x)
    squared = x .* x
    return sum(abs2, squared)
end

"""
    enzyme_split_mode(x)

Differentiate `f` at `x` using Enzyme's split reverse mode (`ReverseSplitWithPrimal`).
The augmented forward thunk runs first and yields a tape plus the primal value. The
reverse thunk then consumes that tape, seeded with `1.0` for `f`'s scalar output.

Returns `(primal, shadow)`. `primal` is the value produced by the forward pass.
`shadow` is the zero-initialised shadow of `x` (from `Enzyme.make_zero`) that was
passed as the `Duplicated` partner of `x` to both passes.
"""
function enzyme_split_mode(x)
    shadow = Enzyme.make_zero(x)

    # `Active` is given unparameterised; Enzyme fills in the concrete return type.
    fwd, rev = autodiff_thunk(
        ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(x)}
    )

    # Both wrappers are immutable, so one instance of each serves both passes.
    const_f = Const(f)
    dup_x = Duplicated(x, shadow)

    # Forward pass: tape, primal value, and a return shadow that is unused here.
    tape, primal, _ = fwd(const_f, dup_x)

    # Reverse pass: seed 1.0 for the scalar output, replaying the recorded tape.
    rev(const_f, dup_x, 1.0, tape)

    return (primal, shadow)
end

# Baseline: run the objective and the split-mode gradient on a plain
# Vector{Float64}, with no Reactant tracing involved.
x = rand(10)

f(x)
enzyme_split_mode(x)

# Wrap the same data as a Reactant array. The stacktrace shows this arrives as a
# `ConcreteRArray{Float64, 1}` and is traced as a `TracedRArray` inside the call.
x_ra = Reactant.to_rarray(x)

# Lower `enzyme_split_mode` through Reactant's compiler (`compile_mlir`, per the
# stacktrace) with `optimize = true`. This is the call that reaches Enzyme's
# `create_abi_wrapper` and fails its `Base.allocatedinline(actualRetType)`
# assertion. There, the traced return type is a
# `Union{TracedRNumber{Float64}, TracedRArray{Float64}}`.
@code_hlo optimize = true enzyme_split_mode(x_ra)
avik-pal commented 2 hours ago

error:

ERROR: AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Union{Reactant.TracedRNumber{Float64}, Reactant.TracedRArray{Float64}}, rettype = Active{Union{Reactant.TracedRNumber{Float64}, Reactant.TracedRArray{Float64}}}
Stacktrace:
  [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{Nothing}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:4287
  [2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:4023
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:7156
  [4] codegen
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:5972 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8267
  [6] _thunk
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8267 [inlined]
  [7] cached_compilation
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8308 [inlined]
  [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{0x0000000000007b05}, ::Type{Const{typeof(f)}}, ::Type{Active}, tt::Type{Tuple{Duplicated{…}}}, ::Val{Enzyme.API.DEM_ReverseModeGradient}, ::Val{1}, ::Val{(false, false)}, ::Val{true}, ::Val{false}, ::Type{FFIABI}, ::Val{false}, ::Val{false})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8440
  [9] #s2080#19075
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8577 [inlined]
 [10] var"#s2080#19075"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ErrIfFuncWritten::Any, RuntimeActivity::Any, ::Any, ::Type, ::Type, ::Type, tt::Any, ::Type, ::Type, ::Type, ::Type, ::Type, ::Type, ::Type, ::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] autodiff_thunk
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/Enzyme.jl:969 [inlined]
 [13] enzyme_split_mode
    @ ./REPL[14]:3 [inlined]
 [14] (::Tuple{})(none::Reactant.TracedRArray{Float64, 1})
    @ Base.Experimental ./<missing>:0
 [15] (::Reactant.var"#26#35"{Bool, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 1}}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:148
 [16] block!(f::Reactant.var"#26#35"{Bool, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [17] make_mlir_fn(f::Function, args::Vector{ConcreteRArray{Float64, 1}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:112
 [18] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
 [19] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:271 [inlined]
 [20] block!(f::Reactant.Compiler.var"#6#11"{typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [21] #5
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:270 [inlined]
 [22] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
 [23] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{ConcreteRArray{Float64, 1}}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:267
 [24] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:266 [inlined]
 [25] #2
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:261 [inlined]
 [26] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:71
 [27] compile_mlir(f::Function, args::Vector{ConcreteRArray{Float64, 1}}; kwargs::@Kwargs{optimize::Bool})
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:259
 [28] macro expansion
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:409 [inlined]
 [29] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.

From the error, it seems to hit Enzyme proper. (I missed the obvious: of course it hits Enzyme proper, since we don't have our own `autodiff_thunk`.)