under-Peter / OMEinsum.jl

One More Einsum for Julia! With runtime order-specification and high-level adjoints for AD
https://under-peter.github.io/OMEinsum.jl/dev/
MIT License

Getting type stability with EinCode #97

Open sethaxen opened 4 years ago

sethaxen commented 4 years ago

I have a function that computes the product of a square matrix along one dimension of an n-dimensional array. Thus, the returned array is of the same size as the passed array. Because the dimension over which to multiply is only known at runtime, I use EinCode. However, the result is not type-stable. Is there a good way to give OMEinsum more information so the compiler can figure out the return type? Or maybe more generally, what's the best way to contract over a single index shared by two arrays, where the index is only known at runtime?

julia> using OMEinsum

julia> function f(M::AbstractMatrix, V::AbstractArray; dim=1)
           n = ndims(V)
           dimsV = Tuple(Base.OneTo(n))
           dimsY = Base.setindex(dimsV, 0, dim)
           dimsM = (0, dim)
           code = EinCode((dimsV, dimsM), dimsY)
           return einsum(code, (V, M))
       end

julia> M, V = randn(4, 4), randn(10, 4, 2);

julia> f(M, V; dim=2)
10×4×2 Array{Float64,3}:
[:, :, 1] =
  1.19814   -2.83308   -0.82374    6.23831
 -3.856      1.35973    0.168978   1.15039
  3.60948   -2.782     -0.735527   1.44291
 -4.52866    0.361779   0.807384   3.24125
  2.74821    1.30956    1.20418   -5.25221
  4.45576   -0.632032  -1.40112   -5.93926
 -2.1384     0.81895    0.187812  -1.01684
  4.51044   -1.39046   -0.798984  -3.6388
 -0.987397  -0.393374  -1.85841   -0.326891
 -3.02511    2.97092    2.33957   -3.35689

[:, :, 2] =
  1.9988    -2.7311     -2.85731     3.38059
 -5.63312    2.61159     3.5489      7.22906
  1.58536    0.74342    -0.0612845  -5.44578
  0.957018  -0.0174554   0.838485    0.054773
  1.81001   -1.62433    -0.753998    0.165946
  2.69391   -0.0213057  -1.24054    -6.89847
  3.61053   -2.85339    -1.76307    -1.98227
  4.4069    -0.590834    0.724681    0.698118
 -5.60072    1.33233     1.42462     4.45287
 -2.31928   -0.103913    1.75607     7.84296

julia> using Test

julia> @inferred f(M, V; dim=2)
ERROR: return type Array{Float64,3} does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[57]:1
mcabbott commented 4 years ago

If you comment out the last line of f, its inferred return type is EinCode{_A,_B} where _B where _A, because the index labels live in type parameters whose values are only known at runtime. So I don't know how much hope there is of the final type being stable.
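A minimal Base-only sketch of why this happens, assuming only what the EinCode{_A,_B} display suggests (labels stored as type parameters). LabeledCode and ValueCode are hypothetical types for illustration, not OMEinsum's: building a code whose labels are type parameters from a runtime dim cannot be inferred to a concrete type, while storing the same labels in ordinary fields can.

```julia
# Hypothetical types for illustration only; they are not part of OMEinsum.

# Index labels stored as type parameters: the concrete type depends on runtime values.
struct LabeledCode{IX,IY} end

# Index labels stored as ordinary fields: the concrete type depends only on tuple lengths.
struct ValueCode
    ix::NTuple{2,Int}
    iy::NTuple{1,Int}
end

make_labeled(dim::Int) = LabeledCode{(dim, 0),(0,)}()
make_value(dim::Int) = ValueCode((dim, 0), (0,))

# Ask inference for each return type, given only that the argument is an Int.
labeled_T = only(Base.return_types(make_labeled, (Int,)))
value_T = only(Base.return_types(make_value, (Int,)))

println(labeled_T, "  concrete? ", isconcretetype(labeled_T))
println(value_T, "  concrete? ", isconcretetype(value_T))
```

The first line reports a non-concrete type, so anything computed from it downstream (such as the einsum result) is also uninferable; the second reports the concrete ValueCode.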

I think the work here is ultimately done by TensorOperations, which keeps dimensions and strides as values, not types. So this is stable:

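julia> using TensorOperations

julia> function f4(M, V; dim)
           IA = (-1, 0)  # labels of M: -1 is the new output index, 0 is contracted
           IB = ntuple(d -> d==dim ? 0 : d, ndims(V))  # V's index at dim is the contracted one
           # IC = (-1, filter(!=(dim), ntuple(+, ndims(V)))...)
           IC = ntuple(d -> d==dim ? -1 : d, ndims(V))  # output: -1 takes the place of dim
           TensorOperations.tensorcontract(M, IA, V, IB, IC)
       end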
f4 (generic function with 1 method)

julia> f4(M, V; dim=2) ≈ f(M, V, dim=2)
true

julia> @code_warntype f4(M, V; dim=2)
...
Body::Array{Float64,3}
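For comparison, a minimal Base-only sketch (no OMEinsum or TensorOperations) of the same operation: multiplying by a matrix along a runtime-chosen dimension. The name matmul_along is hypothetical. As in the TensorOperations version, dim is an ordinary runtime value and only the rank N appears in the types, so the return type is inferable. The array is viewed as pre × size(V, dim) × post and contracted with a plain loop.

```julia
using Test

# Hypothetical helper: Y[..., i, ...] = sum_j M[i, j] * V[..., j, ...], contracting along dim.
function matmul_along(M::AbstractMatrix, V::AbstractArray{T,N}, dim::Integer) where {T,N}
    1 <= dim <= N || throw(ArgumentError("dim must lie in 1:$N"))
    sz = size(V)
    size(M, 2) == sz[dim] || throw(DimensionMismatch("size(M, 2) must equal size(V, dim)"))
    # Collapse the dimensions before and after dim; ntuple with Val(N) keeps this type-stable.
    pre = prod(ntuple(d -> d < dim ? sz[d] : 1, Val(N)))
    post = prod(ntuple(d -> d > dim ? sz[d] : 1, Val(N)))
    V3 = reshape(V, pre, sz[dim], post)
    R = promote_type(eltype(M), T)
    Y3 = zeros(R, pre, size(M, 1), post)
    for k in 1:post, j in 1:sz[dim], i in 1:size(M, 1)
        m = M[i, j]
        for p in 1:pre  # innermost loop runs over the contiguous leading dimension
            Y3[p, i, k] += m * V3[p, j, k]
        end
    end
    # Restore the original shape, with size(M, 1) in place of size(V, dim).
    return reshape(Y3, ntuple(d -> d == dim ? size(M, 1) : sz[d], Val(N)))
end

M, V = randn(4, 4), randn(10, 4, 2)
Y = @inferred matmul_along(M, V, 2)  # @inferred throws if the return type is not inferred
```

For dim = 2, each slice should satisfy Y[:, :, k] ≈ V[:, :, k] * transpose(M), which makes a quick sanity check. A tuned version would more likely permute dim to the front and call BLAS; the loop only keeps the sketch short.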
FuZhiyu commented 3 months ago

I'm surprised to find that even when the indices are known at compile time, the result is still type-unstable:

using OMEinsum, InteractiveUtils  # InteractiveUtils provides @code_warntype outside the REPL

a, b = randn(2, 2), randn(2, 2)
function einproduct(a, b)
    # c = ein"ij,jk -> ik"(a,b)
    @ein c[i,k] := a[i,j] * b[j,k]
    return c
end
@code_warntype einproduct(a, b)

The inferred return type is Any. Why would this be?

GiggleLiu commented 3 months ago

Thanks for the issue. Type stability is deliberately left out of consideration in OMEinsum. OMEinsum often handles tensors of rank greater than 20, and encoding the index structure in types would produce an explosion of possible output types, each needing its own compiled specialization, so reducing compilation time has a higher priority.

High-order tensors appear in many applications:

  1. quantum circuit simulation (https://github.com/nzy1997/TensorQEC.jl)
  2. probabilistic inference (https://github.com/TensorBFS/TensorInference.jl)
  3. combinatorial optimization (https://github.com/QuEraComputing/GenericTensorNetworks.jl)
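Given that design choice, a common user-side mitigation in Julia is to assert the expected type at the call site, so that code downstream of the assertion can be inferred. A minimal Base-only sketch: the hypothetical opaque_contract stands in for any call whose return type is inferred as Any (such as einproduct above). With OMEinsum itself the same idea would read ein"ij,jk->ik"(a, b)::Matrix{Float64}, assuming the result really is a Matrix{Float64}; if the assertion is wrong it throws a TypeError at runtime instead of silently passing a different type along.

```julia
using Test

# Hypothetical stand-in for an uninferable call: routing the result through a
# Vector{Any} hides its type from inference, so the return type is inferred as Any.
function opaque_contract(a::AbstractMatrix, b::AbstractMatrix)
    box = Any[a * b]
    return box[1]
end

# Type assertion at the call site: inference now knows the result is a Matrix{T}.
function stable_contract(a::Matrix{T}, b::Matrix{T}) where {T}
    c = opaque_contract(a, b)::Matrix{T}
    return c
end

a, b = randn(2, 2), randn(2, 2)
c = @inferred stable_contract(a, b)  # @inferred throws if the return type is not inferred
```

The assertion costs one runtime type check; the alternative, a function barrier, passes the uninferred result straight into a second function so that only that call is dynamically dispatched.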