Jutho / TensorOperations.jl

Julia package for tensor contractions and related operations
https://jutho.github.io/TensorOperations.jl/stable/

tensorcontract type stability #96

Open: khanley6 opened this issue 3 years ago

khanley6 commented 3 years ago

Hi there,

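It seems that tensorcontract, with its current definition, is type unstable. The return type is inferred only as the abstract Array, because the default output labels come from symdiff(IA, IB), which returns an Array{Int64,1} whose length the compiler cannot know: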

julia> @code_warntype tensorcontract(A, a_idx, B, b_idx)
Variables
  #self#::Core.Compiler.Const(TensorOperations.tensorcontract, false)
  A::Array{Int64,2}
  IA::Tuple{Int64,Int64}
  B::Array{Int64,2}
  IB::Tuple{Int64,Int64}

Body::Array
1 ─ %1 = TensorOperations.symdiff(IA, IB)::Array{Int64,1}
│   %2 = (#self#)(A, IA, B, IB, %1)::Array
└──      return %2
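The instability above comes from the output rank depending on a run-time length. Below is a minimal sketch of the same effect. The two functions, output_from_vector and output_from_tuple, are hypothetical stand-ins for illustration only and are not part of TensorOperations.jl. When labels arrive as a Vector{Int}, the compiler can infer only an Array of unknown rank. An NTuple{N,Int} carries N in its type, so a concrete Array{Float64,N} is inferred.

```julia
using Test

# Hypothetical stand-ins, not TensorOperations.jl API: allocate an output array
# whose rank equals the number of output labels (each dimension of size 2).
output_from_vector(labels::Vector{Int}) = zeros(ntuple(_ -> 2, length(labels)))
output_from_tuple(labels::NTuple{N,Int}) where {N} = zeros(ntuple(_ -> 2, Val(N)))

# Rank depends on a run-time length: @inferred throws because the inferred
# type (an Array of unknown rank) differs from the actual Matrix{Float64}.
@test_throws ErrorException @inferred output_from_vector([1, 2])

# Rank N is read from the argument type, so Array{Float64,2} is inferred.
C = @inferred output_from_tuple((1, 2))
println(size(C))
```

Labels given as tuples follow the second pattern.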

However, changing the definition as below (tested) appears to make it type stable:

julia> mytensorcontract(A, IA, B, IB, IC = Tuple(symdiff(IA, IB))) =
          tensorcontract(A, tuple(IA...), B, tuple(IB...), IC)
julia> @code_warntype mytensorcontract(A, a_idx, B, b_idx)
Variables
  #self#::Core.Compiler.Const(mytensorcontract, false)
  A::Array{Int64,2}
  IA::Tuple{Int64,Int64}
  B::Array{Int64,2}
  IB::Tuple{Int64,Int64}

Body::Union{}
1 ─ %1 = Main.symdiff(IA, IB)::Array{Int64,1}
│   %2 = Main.tuple(%1)::Tuple{Array{Int64,1}}
│        (#self#)(A, IA, B, IB, %2)
└──      Core.Compiler.Const(:(return %3), false)

I have made the change locally and all tests pass.

Is there any reason to prefer the current definition over the proposed one? I think there is still an issue when an array is passed explicitly as IC:

c_idx::Array{Int} = ...
@code_warntype mytensorcontract(A, a_idx, B, b_idx, c_idx) #-> not stable
@code_warntype mytensorcontract(A, a_idx, B, b_idx, Tuple(c_idx)) #-> stable
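The following sketch shows why converting at the call site helps. Here, to_tuple is a hypothetical helper for illustration only. When Tuple(c_idx) is evaluated before the call, the callee receives an argument whose concrete type records the length. The same conversion performed inside a function body, where only Vector{Int} is known, can only be inferred as a tuple of unknown length. The same reasoning suggests that a default argument such as IC = Tuple(symdiff(IA, IB)) does not by itself fix the length at compile time.

```julia
# Hypothetical helper, for illustration only.
to_tuple(v::Vector{Int}) = Tuple(v)

labels = [1, 3]                # Vector{Int}: the length is not part of the type
t = to_tuple(labels)           # at run time the value is a 2-tuple...
println(typeof(t))             # ...and its concrete type records that length

# Inside to_tuple the compiler only knows the input is a Vector{Int},
# so the inferred return type is a tuple of unknown length.
rt = only(Base.return_types(to_tuple, (Vector{Int},)))
println(rt, " concrete? ", isconcretetype(rt))
```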

Jutho commented 2 weeks ago

I know this is an old one, but the current interface of tensorcontract supports passing the labels as tuples and specifying IC explicitly, which provides type stability. Can this be closed?