Getting type stability with EinCode #97

sethaxen opened 4 years ago

sethaxen commented 4 years ago

I have a function that multiplies a square matrix into one dimension of an n-dimensional array, so the returned array has the same size as the input array. Because the dimension to multiply over is only known at runtime, I use EinCode. However, the result is not type-stable. Is there a good way to give OMEinsum more information so the compiler can infer the return type? More generally, what's the best way to contract over a single index shared by two arrays when that index is only known at runtime?

julia> using OMEinsum

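julia> function f(M::AbstractMatrix, V::AbstractArray; dim=1)
           n = ndims(V)
           dimsV = Tuple(Base.OneTo(n))          # label each axis of V by its position
           dimsY = Base.setindex(dimsV, 0, dim)  # in the output, axis `dim` gets the new label 0
           dimsM = (0, dim)                      # M maps label `dim` to label 0
           code = EinCode((dimsV, dimsM), dimsY)
           return einsum(code, (V, M))
       end
f (generic function with 1 method)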

julia> M, V = randn(4, 4), randn(10, 4, 2);

julia> f(M, V; dim=2)
julia> using Test

julia> @inferred f(M, V; dim=2)
ERROR: return type Array{Float64,3} does not match inferred return type Any
mcabbott commented 4 years ago

If you comment out the last line of f, its inferred return type is EinCode{_A,_B} where _B where _A, so I don't know how much hope there is of the final result being type-stable.
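The EinCode{_A,_B} output suggests the index labels live in type parameters, and here they depend on the runtime value of dim, so inference can only conclude "some EinCode". A minimal illustration of that effect, using two hypothetical types (LabelsInType, LabelsInValue) that are not OMEinsum's own: one stores the labels as a type parameter, the other as an ordinary field.

```julia
using Test

# Hypothetical types for illustration only (not OMEinsum's): the first stores
# the index labels as a type parameter, the second as an ordinary field.
struct LabelsInType{L} end
struct LabelsInValue
    labels::NTuple{2,Int}
end

# The concrete return type depends on the *value* of dim, so inference can only
# conclude "some LabelsInType{...}", which is not a concrete type.
labels_in_type(dim::Int) = LabelsInType{(0, dim)}()

# The return type is LabelsInValue for every dim, so this is type-stable.
labels_in_value(dim::Int) = LabelsInValue((0, dim))

@inferred labels_in_value(2)
# @inferred labels_in_type(2) would fail: the inferred type is not concrete.
```

Keeping labels in values trades some compile-time specialization for a result type that inference can pin down.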

I think the work here is ultimately done by TensorOperations, which keeps dimensions and strides as values, not types. So this is stable:

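julia> using TensorOperations

julia> function f4(M, V; dim)
           IA = (-1, 0)                                   # labels of M: (output, contracted)
           IB = ntuple(d -> d == dim ? 0 : d, ndims(V))   # V's axis `dim` is contracted
           # IC = (-1, filter(!=(dim), ntuple(+, ndims(V)))...)
           IC = ntuple(d -> d == dim ? -1 : d, ndims(V))  # output keeps V's axis order
           TensorOperations.tensorcontract(M, IA, V, IB, IC)
       end
f4 (generic function with 1 method)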

julia> f4(M, V; dim=2) ≈ f(M, V; dim=2)

julia> @code_warntype f4(M, V; dim=2)
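For comparison, here is a sketch of the same contraction using only Base functions, not OMEinsum or TensorOperations. It moves the contracted dimension to the front, does a single matrix product, and moves the dimension back. The helper name mul_along is hypothetical. All sizes and permutations are NTuples whose length comes from ndims(V), which is part of the array type, so inference should be able to determine a concrete Array result even though dim is a runtime value.

```julia
using Test

# Hypothetical helper (not from any package): multiply the matrix M into
# dimension `dim` of V, i.e. Y[.., k, ..] = sum_j M[k, j] * V[.., j, ..].
function mul_along(M::AbstractMatrix, V::AbstractArray{T,N}, dim::Int) where {T,N}
    1 <= dim <= N || throw(ArgumentError("dim must be between 1 and $N"))
    # Permutation moving `dim` to the front, and its inverse. Both are NTuple{N,Int},
    # so their length is known from the type of V even though `dim` is not.
    perm = ntuple(i -> i == 1 ? dim : (i <= dim ? i - 1 : i), Val(N))
    iperm = ntuple(i -> i < dim ? i + 1 : (i == dim ? 1 : i), Val(N))
    Vp = permutedims(V, perm)                          # contracted axis is now first
    sz = size(Vp)
    P = M * reshape(Vp, sz[1], :)                      # one matrix-matrix product
    Yp = reshape(P, (size(M, 1), Base.tail(sz)...))    # back to N dimensions
    return permutedims(Yp, iperm)                      # restore the original axis order
end

M, V = randn(4, 4), randn(10, 4, 2)
# @inferred throws if the actual return type differs from the inferred one.
Y = @inferred mul_along(M, V, 2)
```

The permutedims calls copy the data, so this trades some memory traffic for a single BLAS call and a predictable return type.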
FuZhiyu commented 3 months ago

I was surprised to find that even when the indices are fixed in the code, the result is still type-unstable:

using OMEinsum, InteractiveUtils

a, b = randn(2, 2), randn(2, 2)
function einproduct(a, b)
    # c = ein"ij,jk -> ik"(a, b)
    @ein c[i,k] := a[i,j] * b[j,k]
    return c
end
@code_warntype einproduct(a, b)

The inferred return type is Any. Why is that?

GiggleLiu commented 3 months ago

Thanks for the issue. Type stability is not a design goal of OMEinsum. OMEinsum often handles tensors of rank greater than 20, where the number of possible output types explodes, so keeping compilation time low takes priority.

High-order tensors appear in many applications:

  1. quantum circuit simulation (https://github.com/nzy1997/TensorQEC.jl)
  2. probabilistic inference (https://github.com/TensorBFS/TensorInference.jl)
  3. combinatorial optimization (https://github.com/QuEraComputing/GenericTensorNetworks.jl)
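Since type stability is not a goal of OMEinsum, a common Julia workaround is to assert the result type at the call site when you know it; code after the assertion is then inferred concretely. A minimal sketch, with a hypothetical opaque_matmul standing in for the einsum call. Applied to the example above it would read c = ein"ij,jk -> ik"(a, b)::Matrix{Float64}, assuming the result for two Matrix{Float64} inputs is a Matrix{Float64}. The assertion adds a runtime type check and throws a TypeError if that assumption is wrong.

```julia
using Test

# Hypothetical stand-in for a call whose return type cannot be inferred
# (for example an einsum call): hiding the product in a Vector{Any} means
# inference only sees Any.
opaque_matmul(a, b) = Any[a * b][1]

# Without an annotation, Any propagates to everything downstream.
unannotated(a, b) = sum(opaque_matmul(a, b))

# With a type assertion on the result, the caller supplies the type it knows,
# and code after the assertion is inferred concretely again.
function annotated(a::Matrix{Float64}, b::Matrix{Float64})
    c = opaque_matmul(a, b)::Matrix{Float64}
    return sum(c)
end

a, b = randn(2, 2), randn(2, 2)
@inferred annotated(a, b)
```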