EnzymeAD / Reactant.jl

MIT License
46 stars 2 forks source link

TypeError on compiling broadcast #54

Open avik-pal opened 1 month ago

avik-pal commented 1 month ago
julia> xr = Reactant.ConcreteRArray(rand(Float32, 2, 3))
2×3 Reactant.ConcreteRArray{Float32, (2, 3), 2}:
 0.184252  0.863562  0.0996157
 0.14061   0.574859  0.236953

julia> Reactant.@code_hlo broadcast(tanh, xr)
ERROR: TypeError: in typeassert, expected Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}, got a value of type Reactant.TracedRArray{Float32, (2, 3), 2}
Stacktrace:
  [1] (::Reactant.var"#398#408"{typeof(broadcast), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{typeof(tanh), Reactant.TracedRArray{Float32, (2, 3), 2}}})()
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
  [2] block!(f::Reactant.var"#398#408"{typeof(broadcast), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{typeof(tanh), Reactant.TracedRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
  [3] make_mlir_fn(f::Function, args::Vector{Any}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
  [4] make_mlir_fn
    @ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
  [5] #127
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1186 [inlined]
  [6] block!(f::Reactant.var"#127#132"{typeof(broadcast), Vector{Any}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
  [7] #126
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1185 [inlined]
  [8] mmodule!(f::Reactant.var"#126#131"{Reactant.MLIR.IR.Module, typeof(broadcast), Vector{Any}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Module.jl:93
  [9] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Any}; optimize::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1182
 [10] (::var"#73#74")()
    @ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1314
 [11] context!(f::var"#73#74", ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
 [12] top-level scope
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1312
 [13] top-level scope
    @ none:1
avik-pal commented 1 month ago

This might be related: compiling a broadcast of a function that has no direct StableHLO mapping defined also fails with an error:

julia> using NNlib, Reactant

julia> act_fn(x) = swish.(x)
act_fn (generic function with 1 method)

julia> Reactant.@code_hlo act_fn(xr)
ERROR: MethodError: no method matching (::Core.OpaqueClosure{Tuple{Reactant.TracedRArray{Float32, (), 0}, Tuple{}}, Union{}})(::Reactant.TracedRArray{Float32, (), 0})
This error has been manually thrown, explicitly, so the method may exist but be intentionally marked as unimplemented.

Closest candidates are:
  (::Core.OpaqueClosure{Tuple{Reactant.TracedRArray{Float32, (), 0}, Tuple{}}, Union{}})(::Reactant.TracedRArray{Float32, (), 0}, ::Tuple{}) (method too new to be called from this world context.)
   @ Core :0

Stacktrace:
  [1] (::Reactant.var"#398#408"{typeof(swish), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (), 0}}})()
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
  [2] block!(f::Reactant.var"#398#408"{typeof(swish), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (), 0}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
  [3] make_mlir_fn(f::Function, args::Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
  [4] make_mlir_fn
    @ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
  [5] elem_apply(f::Function, args::Reactant.TracedRArray{Float32, (2, 3), 2})
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:190
  [6] _copyto!
    @ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:555 [inlined]
  [7] copyto!
    @ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:470 [inlined]
  [8] copyto!
    @ ./broadcast.jl:920 [inlined]
  [9] copy
    @ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:461 [inlined]
 [10] materialize
    @ ./broadcast.jl:867 [inlined]
 [11] act_fn
    @ ./REPL[90]:1 [inlined]
 [12] (::Tuple{})(none::Reactant.TracedRArray{Float32, (2, 3), 2})
    @ Base.Experimental ./<missing>:0
 [13] (::Reactant.var"#398#408"{typeof(act_fn), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}})()
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
 [14] block!(f::Reactant.var"#398#408"{typeof(act_fn), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
 [15] make_mlir_fn(f::Function, args::Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
 [16] make_mlir_fn
    @ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
 [17] #127
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1186 [inlined]
 [18] block!(f::Reactant.var"#127#132"{typeof(act_fn), Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
 [19] #126
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1185 [inlined]
 [20] mmodule!(f::Reactant.var"#126#131"{Reactant.MLIR.IR.Module, typeof(act_fn), Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Module.jl:93
 [21] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}; optimize::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1182
 [22] (::var"#77#78")()
    @ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1314
 [23] context!(f::var"#77#78", ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
 [24] top-level scope
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1312
 [25] top-level scope
    @ none:1

Note that I am running a version of Reactant that defines broadcasting for `sigmoid`, so `swish(x) = x * sigmoid(x)` should compile fine.