Open avik-pal opened 1 month ago
Might be related: if we try to compile a broadcasted function that doesn't have a StableHLO mapping defined, we get an error (here `xr` is a 2×3 `Float32` `Reactant.ConcreteRArray`):
julia> using NNlib, Reactant
julia> act_fn(x) = swish.(x)
act_fn (generic function with 1 method)
julia> Reactant.@code_hlo act_fn(xr)
ERROR: MethodError: no method matching (::Core.OpaqueClosure{Tuple{Reactant.TracedRArray{Float32, (), 0}, Tuple{}}, Union{}})(::Reactant.TracedRArray{Float32, (), 0})
This error has been manually thrown, explicitly, so the method may exist but be intentionally marked as unimplemented.
Closest candidates are:
(::Core.OpaqueClosure{Tuple{Reactant.TracedRArray{Float32, (), 0}, Tuple{}}, Union{}})(::Reactant.TracedRArray{Float32, (), 0}, ::Tuple{}) (method too new to be called from this world context.)
@ Core :0
Stacktrace:
[1] (::Reactant.var"#398#408"{typeof(swish), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (), 0}}})()
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
[2] block!(f::Reactant.var"#398#408"{typeof(swish), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (), 0}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
[3] make_mlir_fn(f::Function, args::Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
[4] make_mlir_fn
@ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
[5] elem_apply(f::Function, args::Reactant.TracedRArray{Float32, (2, 3), 2})
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:190
[6] _copyto!
@ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:555 [inlined]
[7] copyto!
@ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:470 [inlined]
[8] copyto!
@ ./broadcast.jl:920 [inlined]
[9] copy
@ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:461 [inlined]
[10] materialize
@ ./broadcast.jl:867 [inlined]
[11] act_fn
@ ./REPL[90]:1 [inlined]
[12] (::Tuple{})(none::Reactant.TracedRArray{Float32, (2, 3), 2})
@ Base.Experimental ./<missing>:0
[13] (::Reactant.var"#398#408"{typeof(act_fn), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}})()
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
[14] block!(f::Reactant.var"#398#408"{typeof(act_fn), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
[15] make_mlir_fn(f::Function, args::Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
[16] make_mlir_fn
@ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
[17] #127
@ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1186 [inlined]
[18] block!(f::Reactant.var"#127#132"{typeof(act_fn), Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
[19] #126
@ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1185 [inlined]
[20] mmodule!(f::Reactant.var"#126#131"{Reactant.MLIR.IR.Module, typeof(act_fn), Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Module.jl:93
[21] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}; optimize::Bool)
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1182
[22] (::var"#77#78")()
@ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1314
[23] context!(f::var"#77#78", ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
[24] top-level scope
@ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1312
[25] top-level scope
@ none:1
Note that I am running a version of Reactant that has broadcasting defined for `sigmoid`, so `swish(x) = x * sigmoid(x)` should compile fine.