Open mofeing opened 1 month ago
I'm randomly running into a crash when printing the output of calling stablehlo.einsum just after calling it.
stablehlo.einsum
I believe it might be a problem with buffer synchronization because...
XLA.await
XLA.is_ready
Also, the problem might be linked with using struct types, as I've been unable to recreate the error if working directly on arrays.
struct
julia> f(a′, b′) 2048×2048 Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}: 0.0 233.947 239.041 242.471 244.843 233.684 234.742 242.731 241.478 235.834 246.115 228.468 162.586 153.2 … 241.457 238.066 238.748 231.795 250.899 Error showing value of type Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}: ERROR: ArgumentError: can't repeat a string -1 times Stacktrace: [1] repeat(s::String, r::Int64) @ Base ./strings/substring.jl:263 [2] repeat @ Base ./strings/substring.jl:260 [inlined] [3] print_matrix_row(io::IOContext{Base.TTY}, X::AbstractVecOrMat, A::Vector{Tuple{Int64, Int64}}, i::Int64, cols::Vector{Int64}, sep::String, idxlast::Int64) @ Base ./arrayshow.jl:118 [4] _print_matrix(io::IOContext{Base.TTY}, X::AbstractVecOrMat, pre::String, sep::String, post::String, hdots::String, vdots::String, ddots::String, hmod::Int64, vmod::Int64, rowsA::UnitRange{Int64}, colsA::UnitRange{Int64}) @ Base ./arrayshow.jl:254 [5] print_matrix(io::IOContext{Base.TTY}, X::Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}, pre::String, sep::String, post::String, hdots::String, vdots::String, ddots::String, hmod::Int64, vmod::Int64) @ Base ./arrayshow.jl:171 [6] print_matrix(io::IO, X::AbstractVecOrMat, pre::AbstractString, sep::AbstractString, post::AbstractString, hdots::AbstractString, vdots::AbstractString, ddots::AbstractString, hmod::Integer, vmod::Integer) @ Base ./arrayshow.jl:171 [inlined] [7] print_array @ ./arrayshow.jl:358 [inlined] [8] show(io::IOContext{Base.TTY}, ::MIME{Symbol("text/plain")}, X::Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}) @ Base ./arrayshow.jl:399 [9] (::REPL.var"#55#56"{REPL.REPLDisplay{REPL.LineEditREPL}, MIME{Symbol("text/plain")}, Base.RefValue{Any}})(io::Any) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:273 [10] with_repl_linfo(f::Any, repl::REPL.LineEditREPL) @ REPL 
/gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:569 [11] display(d::REPL.REPLDisplay, mime::MIME{Symbol("text/plain")}, x::Any) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:259 [12] display(d::REPL.REPLDisplay, x::Any) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:278 [13] display(x::Any) @ Base.Multimedia ./multimedia.jl:340 [14] #invokelatest#2 @ ./essentials.jl:887 [inlined] [15] invokelatest @ ./essentials.jl:884 [inlined] [16] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay}) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:315 [17] (::REPL.var"#57#58"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:284 [18] with_repl_linfo(f::Any, repl::REPL.LineEditREPL) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:569 [19] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:282 [20] (::REPL.var"#do_respond#80"{Bool, Bool, REPL.var"#93#103"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:911 [21] (::VSCodeServer.var"#103#106"{REPL.var"#do_respond#80"{Bool, Bool, REPL.var"#93#103"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt}})(mi::REPL.LineEdit.MIState, buf::IOBuffer, ok::Bool) @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:122 [22] #invokelatest#2 @ Base ./essentials.jl:887 [inlined] [23] invokelatest @ Base 
./essentials.jl:884 [inlined] [24] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState) @ REPL.LineEdit /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/LineEdit.jl:2656 [25] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef) @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:1312 [26] (::REPL.var"#62#68"{REPL.LineEditREPL, REPL.REPLBackendRef})() @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:386
using Reactant
using Cassette

# Reproduction script: an `AbstractArray` wrapper that stores an array together with
# one `Symbol` label per index, contracted through `stablehlo.einsum`.

"""
    Tensor(data, inds)

Wrap the array `data` together with the index labels `inds`. Sizes and element access
are forwarded to the wrapped array.
"""
struct Tensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
    inds::Vector{Symbol}
end

# Convenience constructor: derive the type parameters from `data`.
function Tensor(data::A, inds::AbstractVector{Symbol}) where {T,N,A<:AbstractArray{T,N}}
    return Tensor{T,N,A}(data, inds)
end

# Array-interface methods; each one delegates to the wrapped array.
Base.parent(x::Tensor) = x.data
Base.size(t::Tensor) = size(parent(t))
Base.@propagate_inbounds function Base.getindex(t::Tensor, i...)
    return getindex(parent(t), i...)
end

n = 2048
a = Tensor(rand(n, n), [:i, :j]);
b = Tensor(rand(n, n), [:j, :k])

# The same two tensors with their data moved into `Reactant.ConcreteRArray`s.
a′ = Tensor(Reactant.ConcreteRArray(a.data), a.inds);
b′ = Tensor(Reactant.ConcreteRArray(b.data), b.inds);

# Generic fallback: multiply the wrapped arrays directly.
function contract(a, b)
    return a.data * b.data
end

# Method for tensors wrapping `Reactant.TracedRArray`s: emits a `stablehlo.einsum`
# operation whose equation is built from the index labels of `a` and `b`.
function contract(
    a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}
) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray}
    labels_a = collect(a.inds)
    labels_b = collect(b.inds)

    # Labels present in both operands, and the labels that remain in the result.
    shared = intersect(labels_a, labels_b)
    labels_out::Vector{Symbol} = setdiff(union(labels_a, labels_b), shared)

    T = Base.promote_eltype(a, b)
    elty = Reactant.MLIR.IR.Type(T)

    lhs = parent(a).mlir_data
    rhs = parent(b).mlir_data

    # Result shape is taken from the first dimension of `a` and the second of `b`.
    rsize = (size(a.data, 1), size(b.data, 2))
    restype = Reactant.MLIR.IR.TensorType(rsize, elty)
    equation = Reactant.MLIR.IR.Attribute(
        "$(join(labels_a)),$(join(labels_b))->$(join(labels_out))"
    )

    op = Reactant.MLIR.Dialects.stablehlo.einsum(
        lhs, rhs; result_0=restype, einsum_config=equation
    )
    value = Reactant.MLIR.IR.result(op)

    traced = Reactant.TracedRArray{T,rsize,length(labels_out)}((), value)
    return Tensor(traced, labels_out)
end

f = Reactant.compile(contract, (a′, b′))
f(a′, b′)
Tensor
using Reactant
using Cassette

# Reproduction script without the wrapper struct: the same `stablehlo.einsum` call,
# applied directly to Reactant arrays.

n = 2048
a = rand(n, n);
b = rand(n, n);
a′ = Reactant.ConcreteRArray(a);
b′ = Reactant.ConcreteRArray(b);

# Generic fallback: ordinary matrix product.
function matmul(a, b)
    return a * b
end

# Method for `Reactant.TracedRArray`s: emits a `stablehlo.einsum` operation with the
# fixed equation "ij,jk->ik".
function matmul(
    a::Reactant.TracedRArray{Ta,Sa,Na}, b::Reactant.TracedRArray{Tb,Sb,Nb}
) where {Ta,Tb,Sa,Sb,Na,Nb}
    ir = Reactant.MLIR.IR

    T = Base.promote_type(Ta, Tb)
    elty = ir.Type(T)

    # Result shape comes from the size type parameters of the two operands.
    rsize = (Sa[1], Sb[2])
    restype = ir.TensorType(rsize, elty)
    equation = ir.Attribute("ij,jk->ik")

    op = Reactant.MLIR.Dialects.stablehlo.einsum(
        a.mlir_data, b.mlir_data; result_0=restype, einsum_config=equation
    )
    return Reactant.TracedRArray{T,rsize,2}((), ir.result(op))
end

# Under a `Reactant.TraceCtx`, call `matmul` directly with the given arguments.
function Cassette.overdub(ctx::Reactant.TraceCtx, f::typeof(matmul), args...; kwargs...)
    return f(args...; kwargs...)
end

f = Reactant.compile(matmul, (a′, b′))
f(a′, b′)
I'm randomly running into a crash when printing the output of calling `stablehlo.einsum` just after calling it.

I believe it might be a problem with buffer synchronization because... `XLA.await` on the result buffer and `XLA.is_ready` returns true.

Also, the problem might be linked with using `struct` types, as I've been unable to recreate the error if working directly on arrays.

MWE
Without `Tensor`, it seems to work