EnzymeAD / Reactant.jl

MIT License
26 stars 2 forks source link

Random crashes when printing the `ConcreteRArray` result of a compiled call #9

Open mofeing opened 1 month ago

mofeing commented 1 month ago

I intermittently run into a crash when printing the result of a stablehlo.einsum call immediately after the call returns.

I believe it might be a problem with buffer synchronization because...

  1. The measured run time is the same for different matrix sizes
  2. The error is more likely to appear for larger matrix sizes
  3. The error still happens even if I call XLA.await on the result buffer and XLA.is_ready returns true

Also, the problem might be linked to using struct types, since I've been unable to reproduce the error when working directly on arrays.

julia> f(a′, b′)
2048×2048 Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}:
   0.0    233.947  239.041  242.471  244.843  233.684  234.742  242.731  241.478  235.834  246.115  228.468  162.586   153.2     …  241.457   238.066   238.748   231.795   250.899   Error showing value of type Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}:
ERROR: ArgumentError: can't repeat a string -1 times
Stacktrace:
  [1] repeat(s::String, r::Int64)
    @ Base ./strings/substring.jl:263
  [2] repeat
    @ Base ./strings/substring.jl:260 [inlined]
  [3] print_matrix_row(io::IOContext{Base.TTY}, X::AbstractVecOrMat, A::Vector{Tuple{Int64, Int64}}, i::Int64, cols::Vector{Int64}, sep::String, idxlast::Int64)
    @ Base ./arrayshow.jl:118
  [4] _print_matrix(io::IOContext{Base.TTY}, X::AbstractVecOrMat, pre::String, sep::String, post::String, hdots::String, vdots::String, ddots::String, hmod::Int64, vmod::Int64, rowsA::UnitRange{Int64}, colsA::UnitRange{Int64})
    @ Base ./arrayshow.jl:254
  [5] print_matrix(io::IOContext{Base.TTY}, X::Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}}, pre::String, sep::String, post::String, hdots::String, vdots::String, ddots::String, hmod::Int64, vmod::Int64)
    @ Base ./arrayshow.jl:171
  [6] print_matrix(io::IO, X::AbstractVecOrMat, pre::AbstractString, sep::AbstractString, post::AbstractString, hdots::AbstractString, vdots::AbstractString, ddots::AbstractString, hmod::Integer, vmod::Integer)
    @ Base ./arrayshow.jl:171 [inlined]
  [7] print_array
    @ ./arrayshow.jl:358 [inlined]
  [8] show(io::IOContext{Base.TTY}, ::MIME{Symbol("text/plain")}, X::Tensor{Float64, 2, Reactant.ConcreteRArray{Float64, (2048, 2048), 2}})
    @ Base ./arrayshow.jl:399
  [9] (::REPL.var"#55#56"{REPL.REPLDisplay{REPL.LineEditREPL}, MIME{Symbol("text/plain")}, Base.RefValue{Any}})(io::Any)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:273
 [10] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:569
 [11] display(d::REPL.REPLDisplay, mime::MIME{Symbol("text/plain")}, x::Any)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:259
 [12] display(d::REPL.REPLDisplay, x::Any)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:278
 [13] display(x::Any)
    @ Base.Multimedia ./multimedia.jl:340
 [14] #invokelatest#2
    @ ./essentials.jl:887 [inlined]
 [15] invokelatest
    @ ./essentials.jl:884 [inlined]
 [16] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay})
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:315
 [17] (::REPL.var"#57#58"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:284
 [18] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:569
 [19] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:282
 [20] (::REPL.var"#do_respond#80"{Bool, Bool, REPL.var"#93#103"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:911
 [21] (::VSCodeServer.var"#103#106"{REPL.var"#do_respond#80"{Bool, Bool, REPL.var"#93#103"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt}})(mi::REPL.LineEdit.MIState, buf::IOBuffer, ok::Bool)
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.79.2/scripts/packages/VSCodeServer/src/repl.jl:122
 [22] #invokelatest#2
    @ Base ./essentials.jl:887 [inlined]
 [23] invokelatest
    @ Base ./essentials.jl:884 [inlined]
 [24] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
    @ REPL.LineEdit /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/LineEdit.jl:2656
 [25] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef)
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:1312
 [26] (::REPL.var"#62#68"{REPL.LineEditREPL, REPL.REPLBackendRef})()
    @ REPL /gpfs/apps/MN5/GPP/JULIA/1.10.0/INTEL/share/julia/stdlib/v1.10/REPL/src/REPL.jl:386

MWE

using Reactant
using Cassette

"""
    Tensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}

Array wrapper that stores an array `data` together with a vector of index labels
`inds`. Size queries and indexing are forwarded to `data`.
"""
struct Tensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
    inds::Vector{Symbol}
end

# Outer constructor: take the type parameters from `data` and accept any
# `AbstractVector{Symbol}` of labels.
function Tensor(data::A, inds::AbstractVector{Symbol}) where {T,N,A<:AbstractArray{T,N}}
    return Tensor{T,N,A}(data, inds)
end

# The wrapped array is the parent.
function Base.parent(tensor::Tensor)
    return tensor.data
end

# The size is the size of the wrapped array.
function Base.size(tensor::Tensor)
    return size(parent(tensor))
end

# Indexing is forwarded unchanged to the wrapped array.
Base.@propagate_inbounds function Base.getindex(tensor::Tensor, I...)
    return getindex(parent(tensor), I...)
end

# Side length of the square test matrices (the report says larger sizes make the
# error more likely).
n = 2048
# Host-side tensors with random entries; both carry the label `:j`.
a = Tensor(rand(n,n), [:i, :j]);
b = Tensor(rand(n,n), [:j, :k])

# The same data wrapped in `Reactant.ConcreteRArray`, reusing the index labels.
a′ = Tensor(Reactant.ConcreteRArray(a.data), a.inds);
b′ = Tensor(Reactant.ConcreteRArray(b.data), b.inds);

"""
    contract(a, b)

Generic fallback: multiply the `data` fields of `a` and `b` with `*`.
"""
function contract(a, b)
    return a.data * b.data
end
"""
    contract(a::Tensor, b::Tensor)

Tracing method of `contract` for tensors whose data is a `Reactant.TracedRArray`.

Emits a `stablehlo.einsum` operation that removes every index label present in both
`a.inds` and `b.inds`, and returns a `Tensor` wrapping the traced result. The result is
labelled with the remaining labels, in the order they appear in `a.inds` and then
`b.inds`.

The result shape is built from the extents of the remaining labels, so it matches the
einsum specification for any labelling rather than only for the `ij,jk->ik` layout.

# Throws
- `DimensionMismatch`: if the same label is attached to dimensions of different extent.
"""
function contract(
    a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}
) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray}
    ia = a.inds
    ib = b.inds

    # Labels present in both operands are dropped; the rest label the output.
    shared = ia ∩ ib
    ic = setdiff(ia ∪ ib, shared)

    # Record the extent of every label so the output shape follows `ic` instead of
    # assuming (first dimension of a, second dimension of b).
    extents = Dict{Symbol,Int}()
    for (tensor, inds) in ((a, ia), (b, ib))
        for (dim, ind) in enumerate(inds)
            extent = size(parent(tensor), dim)
            if get!(extents, ind, extent) != extent
                throw(
                    DimensionMismatch(
                        "index :$ind has extents $(extents[ind]) and $extent"
                    ),
                )
            end
        end
    end
    rsize = Tuple(extents[ind] for ind in ic)

    T = Base.promote_eltype(a, b)
    mlirty = Reactant.MLIR.IR.Type(T)
    result_0 = Reactant.MLIR.IR.TensorType(rsize, mlirty)

    # NOTE(review): label names are concatenated verbatim, so this presumably only
    # works with single-character labels — confirm against the einsum spec parser.
    einsum_config = Reactant.MLIR.IR.Attribute("$(join(ia)),$(join(ib))->$(join(ic))")

    op = Reactant.MLIR.Dialects.stablehlo.einsum(
        parent(a).mlir_data, parent(b).mlir_data; result_0, einsum_config
    )
    result = Reactant.MLIR.IR.result(op)

    data = Reactant.TracedRArray{T,rsize,length(ic)}((), result)
    return Tensor(data, ic)
end

# Compile `contract` for the ConcreteRArray-backed tensors.
f = Reactant.compile(contract, (a′,b′))

# Evaluating this at the REPL displays the result; that display step is where the
# reported "can't repeat a string -1 times" error intermittently appears.
f(a′,b′)

Without the `Tensor` wrapper, it seems to work:

using Reactant
using Cassette

# Side length of the square test matrices.
n = 2048
# Host-side random matrices, used directly without a wrapper struct.
a = rand(n, n);
b = rand(n, n);

# The same data wrapped in `Reactant.ConcreteRArray`.
a′ = Reactant.ConcreteRArray(a);
b′ = Reactant.ConcreteRArray(b);

"""
    matmul(a, b)

Generic fallback: return `a * b`.
"""
function matmul(a, b)
    return a * b
end

"""
    matmul(a::Reactant.TracedRArray, b::Reactant.TracedRArray)

Tracing method of `matmul`: emit a `stablehlo.einsum` operation with the spec
`"ij,jk->ik"` and wrap its result in a `Reactant.TracedRArray`.

The result element type is `promote_type(Ta, Tb)` and the result shape is
`(Sa[1], Sb[2])`, both taken from the type parameters of the operands.
"""
function matmul(
    a::Reactant.TracedRArray{Ta,Sa,Na}, b::Reactant.TracedRArray{Tb,Sb,Nb}
) where {Ta,Sa,Na,Tb,Sb,Nb}
    IR = Reactant.MLIR.IR

    elty = Base.promote_type(Ta, Tb)
    shape = (Sa[1], Sb[2])

    # Result tensor type and einsum spec passed to the operation as attributes.
    out_type = IR.TensorType(shape, IR.Type(elty))
    spec = IR.Attribute("ij,jk->ik")

    op = Reactant.MLIR.Dialects.stablehlo.einsum(
        a.mlir_data, b.mlir_data; result_0=out_type, einsum_config=spec
    )

    return Reactant.TracedRArray{elty,shape,2}((), IR.result(op))
end

# Under a `Reactant.TraceCtx`, call `matmul` directly with the given arguments.
function Cassette.overdub(ctx::Reactant.TraceCtx, f::typeof(matmul), args...; kwargs...)
    return f(args...; kwargs...)
end

# Compile `matmul` for the plain ConcreteRArray inputs.
f = Reactant.compile(matmul, (a′,b′))

# Per the report, displaying this result has not reproduced the error.
f(a′,b′)